
inductor can't broadcast tensors when they have dynamic shapes: #136640

@bdhirsh

Description


The original repro came from torchtune; a minified repro is below:

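import torch

@torch.compile(dynamic=True)
def f(x, y):
    x_view = x.view(-1, 4)  # shape [1, 4] for the 4-element x below
    y_view = y.view(-1, 4)  # shape [2, 4] for the 8-element y below
    return x_view * y_view  # broadcasting mul: [1, 4] * [2, 4] -> [2, 4]


x = torch.randn(4)
y = torch.randn(8)
out = f(x, y)  # fails while inductor lowers the broadcasting mul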

Here, inductor needs to lower a broadcasting aten.mul() whose operands have shapes [1, 4] and [2, 4]. It fails with:

  File "/home/hirsheybar/local/b/pytorch/torch/_inductor/lowering.py", line 408, in broadcast_symbolic_shapes
    V.graph.sizevars.guard_equals(x, y)
  File "/home/hirsheybar/local/b/pytorch/torch/_inductor/sizevars.py", line 418, in guard_equals
    assert self.shape_env.evaluate_expr(sympy.Eq(left, right))
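The traceback suggests why the broadcast is rejected: broadcast_symbolic_shapes only treats a dimension as broadcastable when it is literally 1, and otherwise calls guard_equals on the two sizes. The sketch below is a simplified pure-Python model of that logic, not inductor's actual code. Sym, hint and broadcast_symbolic are hypothetical names, and the labels "s0//4" and "s1//4" are illustrative only. The model assumes that under dynamic=True the leading dimensions of x_view and y_view are symbolic sizes whose traced values are 1 and 2, so neither matches the literal-1 check and the equality guard fails.

```python
from dataclasses import dataclass
from itertools import zip_longest
from typing import Tuple, Union


@dataclass(frozen=True)
class Sym:
    """Hypothetical stand-in for a symbolic size: a name plus the
    concrete value ("hint") observed at trace time."""
    name: str
    hint: int


Dim = Union[int, Sym]


def hint(d: Dim) -> int:
    # Concrete value of a dimension, symbolic or not.
    return d.hint if isinstance(d, Sym) else d


def broadcast_symbolic(a: Tuple[Dim, ...], b: Tuple[Dim, ...]) -> Tuple[Dim, ...]:
    """Simplified model of symbolic broadcasting (hypothetical)."""
    out = []
    # Align trailing dimensions; missing leading dims count as 1.
    for x, y in zip_longest(reversed(a), reversed(b), fillvalue=1):
        if y == 1:  # only a literal 1 matches; Sym("u", 1) == 1 is False
            out.append(x)
        elif x == 1:
            out.append(y)
        else:
            # Model of guard_equals: the two sizes must be equal.
            if hint(x) != hint(y):
                raise AssertionError(f"guard_equals failed: {x} vs {y}")
            out.append(x)
    return tuple(reversed(out))


# Static shapes: [1, 4] * [2, 4] broadcasts to [2, 4].
print(broadcast_symbolic((1, 4), (2, 4)))

# Symbolic leading dims (hints 1 and 2): the literal-1 check misses,
# so the equality guard is evaluated and fails.
try:
    broadcast_symbolic((Sym("s0//4", 1), 4), (Sym("s1//4", 2), 4))
except AssertionError as e:
    print(e)
```

In this model, the static call succeeds because the leading 1 is a plain integer. The symbolic call fails even though its traced value is also 1, which matches the failure mode in the traceback above.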

cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @chauhang @penguinwu
