Closed
Labels: export-triaged, module: aotinductor, module: dynamic shapes, oncall: export, oncall: pt2, triaged
Description
🐛 Describe the bug
Hi. I was looking at a case where it is beneficial to tell torch.export that a dimension is a multiple of 2, so I passed dynamic_shapes={"x": {1: 2 * bs}}. Export runs fine.
However, when I tried to AOT-compile the exported program, I ran into a C++ compilation error saying that s1 is not declared. The problem is that while s1 is a sympy.Symbol, 2*s1 is a <class 'sympy.core.mul.Mul'> instead.
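The type distinction above can be checked directly with sympy. This sketch assumes that the symbol torch.export creates for `bs` behaves like a plain integer, positive `sympy.Symbol` named `s1` (the exact assumptions PyTorch attaches may differ). It only shows that the size expression for dim 1 is a `Mul` rather than a `Symbol`, while the root symbol and its coefficient can still be recovered from it.

```python
import sympy

# Assumption: s1 stands in for the symbol torch.export creates for `bs`.
s1 = sympy.Symbol("s1", integer=True, positive=True)
size_expr = 2 * s1  # the size of dim 1 of x, declared as 2 * bs

print(type(s1))                              # <class 'sympy.core.symbol.Symbol'>
print(type(size_expr))                       # <class 'sympy.core.mul.Mul'>
print(isinstance(size_expr, sympy.Symbol))   # False
print(size_expr.free_symbols)                # {s1}
print(size_expr.as_coeff_Mul())              # (2, s1)

# Solving size_expr == 6 (the concrete size in the repro) recovers s1.
print(sympy.solve(sympy.Eq(size_expr, 6), s1))  # [3]
```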
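Repro:
```python
import torch
from torch.export import Dim


class TestModel(torch.nn.Module):
    def forward(self, x):
        # Return the (dynamic) size of dim 1 as the model output.
        return x.shape[1]


model = TestModel().cuda()
x = torch.rand(2, 6, 6).cuda()
args = (x,)

bs = Dim("bs", max=128)
# Declare dim 1 of x as a derived dim that is always a multiple of 2.
ep = torch.export.export(
    model, args, dynamic_shapes={"x": {1: 2 * bs}}, strict=False
)

# Fails to compile: error: use of undeclared identifier 's1'
aot_model = torch._inductor.aot_compile(ep.module(), args)
```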
Generated code (note that s1 is never defined):
```cpp
auto inputs = steal_from_raw_handles_to_raii_handles(input_handles, 1);
auto arg0_1 = std::move(inputs[0]);
inputs.clear();
auto& kernels = static_cast<AOTInductorModelKernels&>(*this->kernels_.get());
RAIIAtenTensorHandle scalar_to_tensor_0 = scalar_to_tensor_handle(2L*s1);
output_handles[0] = scalar_to_tensor_0.release();
```
C++ compilation error:
```
error: use of undeclared identifier 's1'
593 | RAIIAtenTensorHandle scalar_to_tensor_0 = scalar_to_tensor_handle(2L*s1);
| ^
1 error generated.
```
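A toy model of the failure mode may help. Everything in this sketch is hypothetical: `SizeExpr`, `naive_bindings`, `derived_aware_bindings`, and the emitted `int64_t sN = argN.size(i);` lines are illustrative stand-ins, not AOTInductor's actual codegen. It contrasts a binder that only declares symbols appearing as bare input sizes (so `s1`, which only appears inside `2*s1`, is never declared) with one that recovers the root symbol by dividing the runtime size by the coefficient.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass(frozen=True)
class SizeExpr:
    # Hypothetical stand-in for a sympy size expression of the form coeff * symbol.
    # symbol=None means the dimension is static and its size equals coeff.
    coeff: int
    symbol: Optional[str] = None


def naive_bindings(arg: str, sizes: List[SizeExpr]) -> List[str]:
    """Declare only symbols that appear as a bare size (coeff == 1),
    mirroring the 'only sympy.Symbol is handled' behaviour in the issue."""
    lines = []
    for dim, e in enumerate(sizes):
        if e.symbol is not None and e.coeff == 1:
            lines.append(f"int64_t {e.symbol} = {arg}.size({dim});")
    return lines


def derived_aware_bindings(arg: str, sizes: List[SizeExpr]) -> List[str]:
    """Also recover the root symbol from a derived size such as 2 * s1
    by dividing the runtime size by the coefficient."""
    lines = []
    bound = set()
    for dim, e in enumerate(sizes):
        if e.symbol is None or e.symbol in bound:
            continue  # static dim, or symbol already declared
        rhs = f"{arg}.size({dim})"
        if e.coeff != 1:
            rhs = f"{rhs} / {e.coeff}"
        lines.append(f"int64_t {e.symbol} = {rhs};")
        bound.add(e.symbol)
    return lines


# x has shape (2, 6, 6) with dim 1 declared as 2 * bs (s1 in the generated code).
sizes = [SizeExpr(2), SizeExpr(2, "s1"), SizeExpr(6)]
print(naive_bindings("arg0_1", sizes))          # [] : s1 is never declared
print(derived_aware_bindings("arg0_1", sizes))  # ['int64_t s1 = arg0_1.size(1) / 2;']
```

The division is only meaningful because the input is declared to be a multiple of the coefficient; a real fix would presumably rely on export's existing shape checks for that rather than re-deriving them in the wrapper.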
Versions
trunk
cc @ezyang @chauhang @penguinwu @bobrenjc93 @avikchaudhuri @gmagogsfm @zhxchen17 @tugsbayasgalan @angelayi @suo @ydwu4 @desertfire @chenyang78