Labels
module: dynamo, oncall: pt2, triaged (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
Description
🐛 Describe the bug
Hi, we found that Dynamo cannot handle tensor attributes assigned from the user side. Is this expected, and are there any suggestions for working around it?
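For reference, the same pattern runs fine in eager mode; here is a minimal uncompiled check (essentially the minified repro below, just without torch.compile):

import torch

def toy_example():
    def add_one(x):
        return x + 1
    t = torch.nn.Parameter(torch.tensor(1.))
    t.add_one = add_one  # plain Python attribute assignment on a Parameter works in eager mode
    return t.add_one(t)

print(toy_example())  # expected: tensor(2., grad_fn=<AddBackward0>)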
Error logs
Traceback (most recent call last):
  File "dynamo_tensor_attr.py", line 17, in <module>
    r1 = compiled_fn()
  File "/Users/chenxiny/miniforge3/envs/torch-metal/lib/python3.8/site-packages/torch/_dynamo/eval_frame.py", line 433, in _fn
    return fn(*args, **kwargs)
  File "dynamo_tensor_attr.py", line 12, in toy_example
    t.add_one = add_one
  File "/Users/chenxiny/miniforge3/envs/torch-metal/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 1116, in __call__
    return self._torchdynamo_orig_callable(
  File "/Users/chenxiny/miniforge3/envs/torch-metal/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 948, in __call__
    result = self._inner_convert(
  File "/Users/chenxiny/miniforge3/envs/torch-metal/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 472, in __call__
    return _compile(
  File "/Users/chenxiny/miniforge3/envs/torch-metal/lib/python3.8/site-packages/torch/_utils_internal.py", line 84, in wrapper_function
    return StrobelightCompileTimeProfiler.profile_compile_time(
  File "/Users/chenxiny/miniforge3/envs/torch-metal/lib/python3.8/site-packages/torch/_strobelight/compile_time_profiler.py", line 129, in profile_compile_time
    return func(*args, **kwargs)
  File "/Users/chenxiny/miniforge3/envs/torch-metal/lib/python3.8/contextlib.py", line 75, in inner
    return func(*args, **kwds)
  File "/Users/chenxiny/miniforge3/envs/torch-metal/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 817, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/Users/chenxiny/miniforge3/envs/torch-metal/lib/python3.8/site-packages/torch/_dynamo/utils.py", line 231, in time_wrapper
    r = func(*args, **kwargs)
  File "/Users/chenxiny/miniforge3/envs/torch-metal/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 636, in compile_inner
    out_code = transform_code_object(code, transform)
  File "/Users/chenxiny/miniforge3/envs/torch-metal/lib/python3.8/site-packages/torch/_dynamo/bytecode_transformation.py", line 1185, in transform_code_object
    transformations(instructions, code_options)
  File "/Users/chenxiny/miniforge3/envs/torch-metal/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 178, in _fn
    return fn(*args, **kwargs)
  File "/Users/chenxiny/miniforge3/envs/torch-metal/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 582, in transform
    tracer.run()
  File "/Users/chenxiny/miniforge3/envs/torch-metal/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 2451, in run
    super().run()
  File "/Users/chenxiny/miniforge3/envs/torch-metal/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 893, in run
    while self.step():
  File "/Users/chenxiny/miniforge3/envs/torch-metal/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 805, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/Users/chenxiny/miniforge3/envs/torch-metal/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 499, in wrapper
    return inner_fn(self, inst)
  File "/Users/chenxiny/miniforge3/envs/torch-metal/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 1459, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/Users/chenxiny/miniforge3/envs/torch-metal/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 743, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/Users/chenxiny/miniforge3/envs/torch-metal/lib/python3.8/site-packages/torch/_dynamo/variables/misc.py", line 680, in call_function
    return self.obj.call_method(tx, self.name, args, kwargs)
  File "/Users/chenxiny/miniforge3/envs/torch-metal/lib/python3.8/site-packages/torch/_dynamo/variables/tensor.py", line 476, in call_method
    return wrap_fx_proxy(
  File "/Users/chenxiny/miniforge3/envs/torch-metal/lib/python3.8/site-packages/torch/_dynamo/variables/builder.py", line 1713, in wrap_fx_proxy
    return wrap_fx_proxy_cls(target_cls=TensorVariable, **kwargs)
  File "/Users/chenxiny/miniforge3/envs/torch-metal/lib/python3.8/site-packages/torch/_dynamo/variables/builder.py", line 1798, in wrap_fx_proxy_cls
    example_value = get_fake_value(proxy.node, tx, allow_non_graph_fake=True)
  File "/Users/chenxiny/miniforge3/envs/torch-metal/lib/python3.8/site-packages/torch/_dynamo/utils.py", line 1853, in get_fake_value
    raise TorchRuntimeError(str(e)).with_traceback(e.__traceback__) from None
  File "/Users/chenxiny/miniforge3/envs/torch-metal/lib/python3.8/site-packages/torch/_dynamo/utils.py", line 1785, in get_fake_value
    ret_val = wrap_fake_exception(
  File "/Users/chenxiny/miniforge3/envs/torch-metal/lib/python3.8/site-packages/torch/_dynamo/utils.py", line 1300, in wrap_fake_exception
    return fn()
  File "/Users/chenxiny/miniforge3/envs/torch-metal/lib/python3.8/site-packages/torch/_dynamo/utils.py", line 1786, in <lambda>
    lambda: run_node(tx.output, node, args, kwargs, nnmodule)
  File "/Users/chenxiny/miniforge3/envs/torch-metal/lib/python3.8/site-packages/torch/_dynamo/utils.py", line 1921, in run_node
    raise RuntimeError(make_error_message(e)).with_traceback(
  File "/Users/chenxiny/miniforge3/envs/torch-metal/lib/python3.8/site-packages/torch/_dynamo/utils.py", line 1905, in run_node
    return getattr(args[0], node.target)(*args[1:], **kwargs)
torch._dynamo.exc.TorchRuntimeError: Failed running call_method add_one(*(Parameter(FakeTensor(..., size=(), requires_grad=True)),), **{}):
'FakeTensor' object has no attribute 'add_one'

from user code:
  File "dynamo_tensor_attr.py", line 13, in torch_dynamo_resume_in_toy_example_at_12
    return t.add_one()

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information

You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True
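For context, the failure looks consistent with Dynamo tracing the Parameter as a FakeTensor, which does not carry user-assigned attributes. Below is a minimal sketch of that behavior using the internal FakeTensorMode API (our assumption for illustration, not part of the original traceback):

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

t = torch.nn.Parameter(torch.tensor(1.))
t.add_one = lambda x: x + 1  # user-assigned attribute on the real tensor

fake_mode = FakeTensorMode()
fake_t = fake_mode.from_tensor(t)  # Dynamo traces with fake tensors like this one
print(hasattr(t, "add_one"))       # True
print(hasattr(fake_t, "add_one"))  # expected: False, matching the error above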
Minified repro
import torch
import torch._dynamo as dynamo

dynamo.reset()


def toy_example():
    def add_one(x):
        return x + 1

    t = torch.nn.Parameter(torch.tensor(1.))
    t.add_one = add_one
    return t.add_one(t)


compiled_fn = torch.compile(toy_example, backend="inductor")
r1 = compiled_fn()
print(f"r1 = {r1}")

Versions
PyTorch version: 2.5.0.dev20240729
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A
OS: macOS 14.3.1 (arm64)
GCC version: Could not collect
Clang version: 15.0.0 (clang-1500.3.9.4)
CMake version: version 3.30.0
Libc version: N/A
Python version: 3.8.18 (default, Sep 11 2023, 08:17:16) [Clang 14.0.6 ] (64-bit runtime)
Python platform: macOS-14.3.1-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Apple M1 Pro
Versions of relevant libraries:
[pip3] numpy==1.24.3
[pip3] onnx==1.16.1
[pip3] onnx2torch==1.5.14
[pip3] optree==0.11.0
[pip3] pytorch-lightning==1.7.7
[pip3] torch==2.5.0a0+git1f961ad
[pip3] torchaudio==2.4.0.dev20240729
[pip3] torchmetrics==0.10.0
[pip3] torchvision==0.12.0
[conda] numpy 1.22.3 pypi_0 pypi
[conda] numpy-base 1.24.3 py38h90707a3_0
[conda] onnx2torch 1.5.14 pypi_0 pypi
[conda] optree 0.11.0 pypi_0 pypi
[conda] pytorch-lightning 1.7.7 pypi_0 pypi
[conda] torch 1.12.0.dev20220518 pypi_0 pypi
[conda] torchaudio 2.4.0.dev20240729 pypi_0 pypi
[conda] torchmetrics 0.10.0 pypi_0 pypi
[conda] torchvision 0.12.0 pypi_0 pypi
cc @ezyang @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @amjames @rec