-
Notifications
You must be signed in to change notification settings - Fork 26.3k
Closed
Closed
Copy link
Labels
oncall: cpu inductor (CPU Inductor issues for Intel team to triage), oncall: pt2
Description
🐛 Describe the bug
When re-running a cached Inductor graph that was saved with max-autotune (`--inductor-compile-mode max-autotune`) and `--freezing` enabled, the cached graph is missing its constant attributes, and the cache lookup crashes with an `AttributeError`.
Presumably this is related to the caching of constant tensors in torch/_inductor/codegen/cpp_gemm_template.py.
Repro
Run twice:
python benchmarks/dynamo/torchbench.py --accuracy --float32 -dcpu --dynamic-shapes --dynamic-batch-only --inference --inductor --inductor-compile-mode max-autotune --freezing --only BERT_pytorch
On first run:
cpu eval BERT_pytorch
pass
On second run:
cpu eval BERT_pytorch
ERROR:common:
Traceback (most recent call last):
File "/localdisk/fmitchel/pytorch/benchmarks/dynamo/common.py", line 3054, in check_accuracy
new_result = optimized_model_iter_fn(model_copy, example_inputs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/localdisk/fmitchel/pytorch/torch/_dynamo/eval_frame.py", line 573, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/localdisk/fmitchel/pytorch/torch/_dynamo/convert_frame.py", line 1380, in __call__
return self._torchdynamo_orig_callable(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/localdisk/fmitchel/pytorch/torch/_dynamo/convert_frame.py", line 1164, in __call__
result = self._inner_convert(
^^^^^^^^^^^^^^^^^^^^
File "/localdisk/fmitchel/pytorch/torch/_dynamo/convert_frame.py", line 547, in __call__
return _compile(
^^^^^^^^^
File "/localdisk/fmitchel/pytorch/torch/_dynamo/convert_frame.py", line 986, in _compile
guarded_code = compile_inner(code, one_graph, hooks, transform)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/localdisk/fmitchel/pytorch/torch/_dynamo/convert_frame.py", line 715, in compile_inner
return _compile_inner(code, one_graph, hooks, transform)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/localdisk/fmitchel/pytorch/torch/_utils_internal.py", line 95, in wrapper_function
return function(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/localdisk/fmitchel/pytorch/torch/_dynamo/convert_frame.py", line 750, in _compile_inner
out_code = transform_code_object(code, transform)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/localdisk/fmitchel/pytorch/torch/_dynamo/bytecode_transformation.py", line 1361, in transform_code_object
transformations(instructions, code_options)
File "/localdisk/fmitchel/pytorch/torch/_dynamo/convert_frame.py", line 231, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/localdisk/fmitchel/pytorch/torch/_dynamo/convert_frame.py", line 662, in transform
tracer.run()
File "/localdisk/fmitchel/pytorch/torch/_dynamo/symbolic_convert.py", line 2868, in run
super().run()
File "/localdisk/fmitchel/pytorch/torch/_dynamo/symbolic_convert.py", line 1052, in run
while self.step():
^^^^^^^^^^^
File "/localdisk/fmitchel/pytorch/torch/_dynamo/symbolic_convert.py", line 962, in step
self.dispatch_table[inst.opcode](self, inst)
File "/localdisk/fmitchel/pytorch/torch/_dynamo/symbolic_convert.py", line 3048, in RETURN_VALUE
self._return(inst)
File "/localdisk/fmitchel/pytorch/torch/_dynamo/symbolic_convert.py", line 3033, in _return
self.output.compile_subgraph(
File "/localdisk/fmitchel/pytorch/torch/_dynamo/output_graph.py", line 1136, in compile_subgraph
self.compile_and_call_fx_graph(
File "/localdisk/fmitchel/pytorch/torch/_dynamo/output_graph.py", line 1382, in compile_and_call_fx_graph
compiled_fn = self.call_user_compiler(gm)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/localdisk/fmitchel/pytorch/torch/_dynamo/output_graph.py", line 1432, in call_user_compiler
return self._call_user_compiler(gm)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/localdisk/fmitchel/pytorch/torch/_dynamo/output_graph.py", line 1483, in _call_user_compiler
raise BackendCompilerFailed(self.compiler_fn, e).with_traceback(
File "/localdisk/fmitchel/pytorch/torch/_dynamo/output_graph.py", line 1462, in _call_user_compiler
compiled_fn = compiler_fn(gm, self.example_inputs())
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/localdisk/fmitchel/pytorch/torch/_dynamo/repro/after_dynamo.py", line 130, in __call__
compiled_gm = compiler_fn(gm, example_inputs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/localdisk/fmitchel/pytorch/torch/__init__.py", line 2314, in __call__
return compile_fx(model_, inputs_, config_patches=self.config)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/localdisk/fmitchel/pytorch/torch/_inductor/compile_fx.py", line 1552, in compile_fx
return compile_fx(
^^^^^^^^^^^
File "/localdisk/fmitchel/pytorch/torch/_inductor/compile_fx.py", line 1863, in compile_fx
return aot_autograd(
^^^^^^^^^^^^^
File "/localdisk/fmitchel/pytorch/torch/_dynamo/backends/common.py", line 83, in __call__
cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/localdisk/fmitchel/pytorch/torch/_functorch/aot_autograd.py", line 1145, in aot_module_simplified
compiled_fn = AOTAutogradCache.load(
^^^^^^^^^^^^^^^^^^^^^^
File "/localdisk/fmitchel/pytorch/torch/_functorch/_aot_autograd/autograd_cache.py", line 754, in load
compiled_fn = dispatch_and_compile()
^^^^^^^^^^^^^^^^^^^^^^
File "/localdisk/fmitchel/pytorch/torch/_functorch/aot_autograd.py", line 1131, in dispatch_and_compile
compiled_fn, _ = create_aot_dispatcher_function(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/localdisk/fmitchel/pytorch/torch/_functorch/aot_autograd.py", line 580, in create_aot_dispatcher_function
return _create_aot_dispatcher_function(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/localdisk/fmitchel/pytorch/torch/_functorch/aot_autograd.py", line 830, in _create_aot_dispatcher_function
compiled_fn, fw_metadata = compiler_fn(
^^^^^^^^^^^^
File "/localdisk/fmitchel/pytorch/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 201, in aot_dispatch_base
compiled_fw = compiler(fw_module, updated_flat_args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/localdisk/fmitchel/pytorch/torch/_inductor/compile_fx.py", line 1457, in fw_compiler_freezing
optimized_function = inner_compile(
^^^^^^^^^^^^^^
File "/localdisk/fmitchel/miniforge3/envs/pt-dev/lib/python3.11/contextlib.py", line 81, in inner
return func(*args, **kwds)
^^^^^^^^^^^^^^^^^^^
File "/localdisk/fmitchel/pytorch/torch/_inductor/compile_fx.py", line 569, in compile_fx_inner
return wrap_compiler_debug(_compile_fx_inner, compiler_name="inductor")(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/localdisk/fmitchel/pytorch/torch/_dynamo/repro/after_aot.py", line 102, in debug_wrapper
inner_compiled_fn = compiler_fn(gm, example_inputs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/localdisk/fmitchel/pytorch/torch/_inductor/compile_fx.py", line 660, in _compile_fx_inner
mb_compiled_graph, cache_info = FxGraphCache.load_with_key(
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/localdisk/fmitchel/pytorch/torch/_inductor/codecache.py", line 1308, in load_with_key
compiled_graph, cache_info = FxGraphCache._lookup_graph(
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/localdisk/fmitchel/pytorch/torch/_inductor/codecache.py", line 1090, in _lookup_graph
artifact_path = graph.after_deserialization(constants)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/localdisk/fmitchel/pytorch/torch/_inductor/output_code.py", line 568, in after_deserialization
constants.unwrap(self),
^^^^^^^^^^^^^^^^^^^^^^
File "/localdisk/fmitchel/pytorch/torch/_inductor/output_code.py", line 294, in unwrap
return {
^
File "/localdisk/fmitchel/pytorch/torch/_inductor/output_code.py", line 295, in <dictcomp>
name: getattr(self.gm, name)
^^^^^^^^^^^^^^^^^^^^^^
File "/localdisk/fmitchel/pytorch/torch/nn/modules/module.py", line 1928, in __getattr__
raise AttributeError(
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
AttributeError: '<lambda>' object has no attribute 'constant195'
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
You can suppress this exception and fall back to eager by setting:
import torch._dynamo
torch._dynamo.config.suppress_errors = True
TorchDynamo optimized model failed to run because of following error
fail_to_run
Versions
| name | target_branch | target_commit | refer_branch | refer_commit |
|---|---|---|---|---|
| torchbench | main | 766a5e3a | main | 766a5e3a |
| torch | main | 2682e5e | main | e29dabb |
| torchvision | main | 0.19.0a0+d23a6e1 | main | 0.19.0a0+d23a6e1 |
| torchtext | main | 0.16.0a0+b0ebddc | main | 0.16.0a0+b0ebddc |
| torchaudio | main | 2.5.0a0+332760d | main | 2.5.0a0+332760d |
| torchdata | main | 0.7.0a0+11bb5b8 | main | 0.7.0a0+11bb5b8 |
Metadata
Labels
oncall: cpu inductor (CPU Inductor issues for Intel team to triage), oncall: pt2