Closed
Description
The original repro came from torchtune; a minified repro is below:
```python
import torch

@torch.compile(dynamic=True)
def f(x, y):
    x_view = x.view(-1, 4)  # shape [1, 4]
    y_view = y.view(-1, 4)  # shape [2, 4]
    return x_view * y_view  # broadcasts to [2, 4]

x = torch.randn(4)
y = torch.randn(8)
out = f(x, y)
```
Here, Inductor needs to lower a broadcasting aten.mul() whose operands have shapes [1, 4] and [2, 4]. It fails with:
```
  File "/home/hirsheybar/local/b/pytorch/torch/_inductor/lowering.py", line 408, in broadcast_symbolic_shapes
    V.graph.sizevars.guard_equals(x, y)
  File "/home/hirsheybar/local/b/pytorch/torch/_inductor/sizevars.py", line 418, in guard_equals
    assert self.shape_env.evaluate_expr(sympy.Eq(left, right))
```
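For reference, the mismatch between eager semantics and the failing lowering can be modeled in plain Python. This is a hypothetical sketch, not Inductor's implementation. `broadcast_shapes` follows the standard NumPy/PyTorch broadcasting rules. `guarded_broadcast_dim` is an invented helper. It assumes the lowering only stretches a dimension it can prove is 1 and otherwise guards that the two sizes are equal, as the `guard_equals` call in the traceback suggests. With `dynamic=True`, the leading dims of `x_view` and `y_view` are presumably symbolic expressions derived from the input sizes. If so, the size-1 dim of `x_view` may not be provably 1 at compile time.

```python
from typing import Sequence, Tuple


def broadcast_shapes(a: Sequence[int], b: Sequence[int]) -> Tuple[int, ...]:
    """Broadcast two concrete shapes using NumPy/PyTorch-style rules."""
    ndim = max(len(a), len(b))
    # Right-align the shapes by left-padding the shorter one with 1s.
    pa = (1,) * (ndim - len(a)) + tuple(a)
    pb = (1,) * (ndim - len(b)) + tuple(b)
    out = []
    for da, db in zip(pa, pb):
        if da == 1:
            out.append(db)  # a size-1 dim stretches to match the other operand
        elif db == 1 or da == db:
            out.append(da)
        else:
            raise ValueError(f"shapes {tuple(a)} and {tuple(b)} are not broadcastable")
    return tuple(out)


def guarded_broadcast_dim(da: int, db: int, da_known_one: bool, db_known_one: bool) -> int:
    """Hypothetical model of broadcasting one dimension at compile time.

    Only a dim the compiler can *prove* is 1 may stretch; otherwise the two
    sizes must be equal (mirroring the guard_equals call in the traceback).
    The integer values stand in for the runtime sizes.
    """
    if da_known_one:
        return db
    if db_known_one:
        return da
    if da != db:
        raise RuntimeError(f"guard failed: {da} != {db}")
    return da


# Eager semantics: [1, 4] * [2, 4] is a valid broadcast.
print(broadcast_shapes([1, 4], [2, 4]))  # (2, 4)

# Compile-time model: x_view's leading dim is 1 at runtime, but if it is a
# symbolic size not known to be 1, the equality guard fails (1 != 2).
try:
    guarded_broadcast_dim(1, 2, da_known_one=False, db_known_one=False)
except RuntimeError as e:
    print(e)  # guard failed: 1 != 2
```

Under this model, the lowering would need to handle a dimension that is 1 at runtime but not statically known to be 1, rather than falling back to an equality guard.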
cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @chauhang @penguinwu