-
Notifications
You must be signed in to change notification settings - Fork 26.3k
Fix bugs about torch.fx.experimental.proxy_tensor.make_fx #141022
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/141022
Note: Links to docs will display an error until the docs builds have been completed. ❗ 1 Active SEVsThere are 1 currently active SEVs. If your PR is affected, please view them below: ✅ No FailuresAs of commit 1e62e3d with merge base 93aef68 ( This comment was automatically generated by Dr. CI and updates every 15 minutes. |
|
@pytorchbot label "topic: not user facing" |
Detailed description:
The codes below will raise an error
```Python
import torch
from torch.fx.experimental.proxy_tensor import make_fx
def func(a):
b = a + 1
c = b.view(-1)
c.add_(1)
return b
input = torch.randn(2)
out = make_fx(func)(input)
```
The error info are like below:
```Python
...
File "/root/Git.d/pytorch/pytorch/torch/_dynamo/codegen.py", line 34, in <module>
from .variables.torch_function import TensorWithTFOverrideVariable
File "/root/Git.d/pytorch/pytorch/torch/_dynamo/variables/torch_function.py", line 185, in <module>
populate_builtin_to_tensor_fn_map()
File "/root/Git.d/pytorch/pytorch/torch/_dynamo/variables/torch_function.py", line 146, in populate_builtin_to_tensor_fn_map
inp0 = torch.ones(1)
File "/root/Git.d/pytorch/pytorch/torch/fx/experimental/proxy_tensor.py", line 1240, in __torch_function__
return func(*args, **kwargs)
File "/root/Git.d/pytorch/pytorch/torch/utils/_stats.py", line 21, in wrapper
return fn(*args, **kwargs)
File "/root/Git.d/pytorch/pytorch/torch/fx/experimental/proxy_tensor.py", line 1342, in __torch_dispatch__
return proxy_call(self, func, self.pre_dispatch, args, kwargs)
File "/root/Git.d/pytorch/pytorch/torch/fx/experimental/proxy_tensor.py", line 907, in proxy_call
name=proxy_mode.tracer.graph._target_to_str(func.overloadpacket.__name__),
AttributeError: 'PythonKeyTracer' object has no attribute 'graph'
...
```
Solutions:
Import torch._dynamo before dispatch_trace is called to avoid the context set before dispatch_trace from affecting the torch._dynamo import.
ghstack-source-id: ed28768
Pull Request resolved: #141022
|
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
|
The merge job was canceled or timed out. This most often happen if two merge requests were issued for the same PR, or if merge job was waiting for more than 6 hours for tests to finish. In later case, please do not hesitate to reissue the merge command |
|
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
…1022) Detailed description: The codes below will raise an error ```Python import torch from torch.fx.experimental.proxy_tensor import make_fx def func(a): b = a + 1 c = b.view(-1) c.add_(1) return b input = torch.randn(2) out = make_fx(func)(input) ``` The error info are like below: ```Python ... File "/root/Git.d/pytorch/pytorch/torch/_dynamo/codegen.py", line 34, in <module> from .variables.torch_function import TensorWithTFOverrideVariable File "/root/Git.d/pytorch/pytorch/torch/_dynamo/variables/torch_function.py", line 185, in <module> populate_builtin_to_tensor_fn_map() File "/root/Git.d/pytorch/pytorch/torch/_dynamo/variables/torch_function.py", line 146, in populate_builtin_to_tensor_fn_map inp0 = torch.ones(1) File "/root/Git.d/pytorch/pytorch/torch/fx/experimental/proxy_tensor.py", line 1240, in __torch_function__ return func(*args, **kwargs) File "/root/Git.d/pytorch/pytorch/torch/utils/_stats.py", line 21, in wrapper return fn(*args, **kwargs) File "/root/Git.d/pytorch/pytorch/torch/fx/experimental/proxy_tensor.py", line 1342, in __torch_dispatch__ return proxy_call(self, func, self.pre_dispatch, args, kwargs) File "/root/Git.d/pytorch/pytorch/torch/fx/experimental/proxy_tensor.py", line 907, in proxy_call name=proxy_mode.tracer.graph._target_to_str(func.overloadpacket.__name__), AttributeError: 'PythonKeyTracer' object has no attribute 'graph' ... ``` Solutions: Import torch._dynamo before dispatch_trace is called to avoid the context set before dispatch_trace from affecting the torch._dynamo import. Pull Request resolved: pytorch#141022 Approved by: https://github.com/ezyang
…1022) Detailed description: The codes below will raise an error ```Python import torch from torch.fx.experimental.proxy_tensor import make_fx def func(a): b = a + 1 c = b.view(-1) c.add_(1) return b input = torch.randn(2) out = make_fx(func)(input) ``` The error info are like below: ```Python ... File "/root/Git.d/pytorch/pytorch/torch/_dynamo/codegen.py", line 34, in <module> from .variables.torch_function import TensorWithTFOverrideVariable File "/root/Git.d/pytorch/pytorch/torch/_dynamo/variables/torch_function.py", line 185, in <module> populate_builtin_to_tensor_fn_map() File "/root/Git.d/pytorch/pytorch/torch/_dynamo/variables/torch_function.py", line 146, in populate_builtin_to_tensor_fn_map inp0 = torch.ones(1) File "/root/Git.d/pytorch/pytorch/torch/fx/experimental/proxy_tensor.py", line 1240, in __torch_function__ return func(*args, **kwargs) File "/root/Git.d/pytorch/pytorch/torch/utils/_stats.py", line 21, in wrapper return fn(*args, **kwargs) File "/root/Git.d/pytorch/pytorch/torch/fx/experimental/proxy_tensor.py", line 1342, in __torch_dispatch__ return proxy_call(self, func, self.pre_dispatch, args, kwargs) File "/root/Git.d/pytorch/pytorch/torch/fx/experimental/proxy_tensor.py", line 907, in proxy_call name=proxy_mode.tracer.graph._target_to_str(func.overloadpacket.__name__), AttributeError: 'PythonKeyTracer' object has no attribute 'graph' ... ``` Solutions: Import torch._dynamo before dispatch_trace is called to avoid the context set before dispatch_trace from affecting the torch._dynamo import. Pull Request resolved: pytorch#141022 Approved by: https://github.com/ezyang
Stack from ghstack (oldest at bottom):
Detailed description:
The codes below will raise an error
The error info are like below:
Solutions:
Import torch._dynamo before dispatch_trace is called to avoid the context set before dispatch_trace from affecting the torch._dynamo import.
cc @ezyang @SherlockNoMad @EikanWang @jgong5 @wenzhe-nrv