-
Notifications
You must be signed in to change notification settings - Fork 26.3k
Closed
Labels
actionable, good first issue, module: dynamic shapes, module: inductor, oncall: pt2, triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
Description
🐛 Describe the bug
Consider the following code:
import torch


@torch.compile
def foobar(x):
    """Return ``x * 2``; compiled on first call by ``torch.compile``."""
    return x * 2


def test(device):
    """Call the compiled ``foobar`` twice on *device* with two different shapes.

    The second call uses a different tensor shape than the first, so
    ``torch.compile`` has to compile ``foobar`` again. Judging by the failing
    ``empty_strided`` call in the traceback (sizes ``s0, s1, s2``), that
    second compilation uses symbolic (dynamic) sizes.
    """
    foobar(torch.empty((1, 16, 128, 128), device=device))
    foobar(torch.empty((1, 32, 64, 64), device=device))


# OK: both calls succeed on CUDA.
test("cuda")
print("cuda ok")
# Fails: the second call raises torch._inductor.exc.LoweringException
# (AssertionError while lowering aten.empty_strided) on the meta device.
test("meta")
print("meta ok")

Running `test` with `"cuda"` works, but running it with the `"meta"` device fails with the following exception:
Traceback (most recent call last):
File ".venv/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 1446, in _call_user_compiler
compiled_fn = compiler_fn(gm, self.example_inputs())
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File ".venv/lib/python3.11/site-packages/torch/_dynamo/repro/after_dynamo.py", line 129, in __call__
compiled_gm = compiler_fn(gm, example_inputs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File ".venv/lib/python3.11/site-packages/torch/_dynamo/repro/after_dynamo.py", line 129, in __call__
compiled_gm = compiler_fn(gm, example_inputs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File ".venv/lib/python3.11/site-packages/torch/__init__.py", line 2234, in __call__
return compile_fx(model_, inputs_, config_patches=self.config)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File ".venv/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 1521, in compile_fx
return aot_autograd(
^^^^^^^^^^^^^
File ".venv/lib/python3.11/site-packages/torch/_dynamo/backends/common.py", line 72, in __call__
cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File ".venv/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 1071, in aot_module_simplified
compiled_fn = dispatch_and_compile()
^^^^^^^^^^^^^^^^^^^^^^
File ".venv/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 1056, in dispatch_and_compile
compiled_fn, _ = create_aot_dispatcher_function(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File ".venv/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 522, in create_aot_dispatcher_function
return _create_aot_dispatcher_function(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File ".venv/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 759, in _create_aot_dispatcher_function
compiled_fn, fw_metadata = compiler_fn(
^^^^^^^^^^^^
File ".venv/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 179, in aot_dispatch_base
compiled_fw = compiler(fw_module, updated_flat_args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File ".venv/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 1350, in fw_compiler_base
return _fw_compiler_base(model, example_inputs, is_inference)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File ".venv/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 1421, in _fw_compiler_base
return inner_compile(
^^^^^^^^^^^^^^
File ".venv/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 475, in compile_fx_inner
return wrap_compiler_debug(_compile_fx_inner, compiler_name="inductor")(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File ".venv/lib/python3.11/site-packages/torch/_dynamo/repro/after_aot.py", line 85, in debug_wrapper
inner_compiled_fn = compiler_fn(gm, example_inputs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File ".venv/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 661, in _compile_fx_inner
compiled_graph = FxGraphCache.load(
^^^^^^^^^^^^^^^^^^
File ".venv/lib/python3.11/site-packages/torch/_inductor/codecache.py", line 1334, in load
compiled_graph = compile_fx_fn(
^^^^^^^^^^^^^^
File ".venv/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 570, in codegen_and_compile
compiled_graph = fx_codegen_and_compile(gm, example_inputs, **fx_kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File ".venv/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 859, in fx_codegen_and_compile
graph.run(*example_inputs)
File ".venv/lib/python3.11/site-packages/torch/_inductor/graph.py", line 780, in run
return super().run(*args)
^^^^^^^^^^^^^^^^^^
File ".venv/lib/python3.11/site-packages/torch/fx/interpreter.py", line 146, in run
self.env[node] = self.run_node(node)
^^^^^^^^^^^^^^^^^^^
File ".venv/lib/python3.11/site-packages/torch/_inductor/graph.py", line 1319, in run_node
result = super().run_node(n)
^^^^^^^^^^^^^^^^^^^
File ".venv/lib/python3.11/site-packages/torch/fx/interpreter.py", line 203, in run_node
return getattr(self, n.op)(n.target, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File ".venv/lib/python3.11/site-packages/torch/_inductor/graph.py", line 1024, in call_function
raise LoweringException(e, target, args, kwargs).with_traceback(
File ".venv/lib/python3.11/site-packages/torch/_inductor/graph.py", line 1021, in call_function
out = lowerings[target](*args, **kwargs) # type: ignore[index]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File ".venv/lib/python3.11/site-packages/torch/_inductor/lowering.py", line 361, in wrapped
out = decomp_fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^
File ".venv/lib/python3.11/site-packages/torch/_inductor/lowering.py", line 2844, in empty_strided
pointwise.realize()
File ".venv/lib/python3.11/site-packages/torch/_inductor/ir.py", line 6282, in realize
return self.data.realize()
^^^^^^^^^^^^^^^^^^^
File ".venv/lib/python3.11/site-packages/torch/_inductor/ir.py", line 6367, in realize
layout=FlexibleLayout(
^^^^^^^^^^^^^^^
File ".venv/lib/python3.11/site-packages/torch/_inductor/ir.py", line 3254, in __init__
super().__init__(device, dtype, size, strides)
File ".venv/lib/python3.11/site-packages/torch/_inductor/ir.py", line 2900, in __init__
assert all(isinstance(s, (Expr, int)) for s in size)
torch._inductor.exc.LoweringException: AssertionError:
target: aten.empty_strided.default
args[0]: (1, s0, s1, s2)
args[1]: (s0*s1*s2, s1*s2, s2, 1)
kwargs: {'dtype': torch.float32, 'device': device(type='meta')}
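This only happens when `foobar` is called twice inside `test` and the tensor in the second call has a different size than the first. The second call makes `torch.compile` recompile `foobar` with symbolic sizes, as the `s0, s1, s2` in the failing `aten.empty_strided` call above show.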
Versions
(The `collect_env.py` script doesn't work for me, so I'm pasting the versions manually.)
torch 2.5.1
triton 3.1.0
python 3.11.8
cc @chauhang @penguinwu @ezyang @bobrenjc93 @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @aakhundov @BoyuanFeng
Metadata
Assignees
Labels
actionable, good first issue, module: dynamic shapes, module: inductor, oncall: pt2, triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)