
broken torch.compile with "meta" device tensors #144607

@koute

Description


🐛 Describe the bug

Consider the following code:

import torch

@torch.compile
def foobar(x):
    return x * 2

def test(device):
    foobar(torch.empty((1, 16, 128, 128), device = device))
    foobar(torch.empty((1, 32, 64, 64), device = device))

# OK
test("cuda")
print("cuda ok")

# Fails
test("meta")
print("meta ok")

Running test with the "cuda" device works, but running it with the "meta" device fails with the following exception:

Traceback (most recent call last):
  File ".venv/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 1446, in _call_user_compiler
    compiled_fn = compiler_fn(gm, self.example_inputs())
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".venv/lib/python3.11/site-packages/torch/_dynamo/repro/after_dynamo.py", line 129, in __call__
    compiled_gm = compiler_fn(gm, example_inputs)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".venv/lib/python3.11/site-packages/torch/_dynamo/repro/after_dynamo.py", line 129, in __call__
    compiled_gm = compiler_fn(gm, example_inputs)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".venv/lib/python3.11/site-packages/torch/__init__.py", line 2234, in __call__
    return compile_fx(model_, inputs_, config_patches=self.config)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".venv/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 1521, in compile_fx
    return aot_autograd(
           ^^^^^^^^^^^^^
  File ".venv/lib/python3.11/site-packages/torch/_dynamo/backends/common.py", line 72, in __call__
    cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".venv/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 1071, in aot_module_simplified
    compiled_fn = dispatch_and_compile()
                  ^^^^^^^^^^^^^^^^^^^^^^
  File ".venv/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 1056, in dispatch_and_compile
    compiled_fn, _ = create_aot_dispatcher_function(
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".venv/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 522, in create_aot_dispatcher_function
    return _create_aot_dispatcher_function(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".venv/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 759, in _create_aot_dispatcher_function
    compiled_fn, fw_metadata = compiler_fn(
                               ^^^^^^^^^^^^
  File ".venv/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 179, in aot_dispatch_base
    compiled_fw = compiler(fw_module, updated_flat_args)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".venv/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 1350, in fw_compiler_base
    return _fw_compiler_base(model, example_inputs, is_inference)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".venv/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 1421, in _fw_compiler_base
    return inner_compile(
           ^^^^^^^^^^^^^^
  File ".venv/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 475, in compile_fx_inner
    return wrap_compiler_debug(_compile_fx_inner, compiler_name="inductor")(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".venv/lib/python3.11/site-packages/torch/_dynamo/repro/after_aot.py", line 85, in debug_wrapper
    inner_compiled_fn = compiler_fn(gm, example_inputs)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".venv/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 661, in _compile_fx_inner
    compiled_graph = FxGraphCache.load(
                     ^^^^^^^^^^^^^^^^^^
  File ".venv/lib/python3.11/site-packages/torch/_inductor/codecache.py", line 1334, in load
    compiled_graph = compile_fx_fn(
                     ^^^^^^^^^^^^^^
  File ".venv/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 570, in codegen_and_compile
    compiled_graph = fx_codegen_and_compile(gm, example_inputs, **fx_kwargs)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".venv/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 859, in fx_codegen_and_compile
    graph.run(*example_inputs)
  File ".venv/lib/python3.11/site-packages/torch/_inductor/graph.py", line 780, in run
    return super().run(*args)
           ^^^^^^^^^^^^^^^^^^
  File ".venv/lib/python3.11/site-packages/torch/fx/interpreter.py", line 146, in run
    self.env[node] = self.run_node(node)
                     ^^^^^^^^^^^^^^^^^^^
  File ".venv/lib/python3.11/site-packages/torch/_inductor/graph.py", line 1319, in run_node
    result = super().run_node(n)
             ^^^^^^^^^^^^^^^^^^^
  File ".venv/lib/python3.11/site-packages/torch/fx/interpreter.py", line 203, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".venv/lib/python3.11/site-packages/torch/_inductor/graph.py", line 1024, in call_function
    raise LoweringException(e, target, args, kwargs).with_traceback(
  File ".venv/lib/python3.11/site-packages/torch/_inductor/graph.py", line 1021, in call_function
    out = lowerings[target](*args, **kwargs)  # type: ignore[index]
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".venv/lib/python3.11/site-packages/torch/_inductor/lowering.py", line 361, in wrapped
    out = decomp_fn(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".venv/lib/python3.11/site-packages/torch/_inductor/lowering.py", line 2844, in empty_strided
    pointwise.realize()
  File ".venv/lib/python3.11/site-packages/torch/_inductor/ir.py", line 6282, in realize
    return self.data.realize()
           ^^^^^^^^^^^^^^^^^^^
  File ".venv/lib/python3.11/site-packages/torch/_inductor/ir.py", line 6367, in realize
    layout=FlexibleLayout(
           ^^^^^^^^^^^^^^^
  File ".venv/lib/python3.11/site-packages/torch/_inductor/ir.py", line 3254, in __init__
    super().__init__(device, dtype, size, strides)
  File ".venv/lib/python3.11/site-packages/torch/_inductor/ir.py", line 2900, in __init__
    assert all(isinstance(s, (Expr, int)) for s in size)
torch._inductor.exc.LoweringException: AssertionError: 
  target: aten.empty_strided.default
  args[0]: (1, s0, s1, s2)
  args[1]: (s0*s1*s2, s1*s2, s2, 1)
  kwargs: {'dtype': torch.float32, 'device': device(type='meta')}

This only happens when foobar is called twice inside test and the tensor in the second call has a different size, which presumably triggers a recompilation with dynamic shapes (note the symbolic sizes s0, s1, s2 in the assertion failure above).
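
As a hedged workaround sketch (not verified on this exact setup): if the failure really is tied to the dynamic-shape recompilation, passing dynamic=False to torch.compile should force a separate static-shape compilation per input size and avoid the symbolic strides entirely.

import torch

# Workaround sketch (assumption: the assertion is triggered by the symbolic
# sizes introduced by the automatic dynamic-shape recompilation).
# dynamic=False makes torch.compile use static shapes, recompiling for each
# distinct input size instead.
@torch.compile(dynamic=False)
def foobar(x):
    return x * 2

def test(device):
    foobar(torch.empty((1, 16, 128, 128), device=device))
    foobar(torch.empty((1, 32, 64, 64), device=device))

test("meta")
print("meta ok (if the dynamic-shape assumption holds)")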

Versions

(The collect_env.py script doesn't work for me, so I'm listing the versions manually.)

torch 2.5.1
triton 3.1.0
python 3.11.8
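
(Roughly the equivalent of the following snippet, as a hypothetical stand-in for collect_env:)

import sys
import torch
import triton

# Manual substitute for torch.utils.collect_env, reporting only the
# version numbers listed above.
print("torch ", torch.__version__)        # 2.5.1
print("triton", triton.__version__)       # 3.1.0
print("python", sys.version.split()[0])   # 3.11.8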

cc @chauhang @penguinwu @ezyang @bobrenjc93 @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @aakhundov @BoyuanFeng
