Skip to content

Conversation

@fffrog
Copy link
Collaborator

@fffrog fffrog commented Nov 19, 2024

Stack from ghstack (oldest at bottom):

Detailed description:

The codes below will raise an error

import torch
from torch.fx.experimental.proxy_tensor import make_fx

def func(a):
    b = a + 1
    c = b.view(-1)
    c.add_(1)
    return b

input = torch.randn(2)
out = make_fx(func)(input)

The error info are like below:

...
  File "/root/Git.d/pytorch/pytorch/torch/_dynamo/codegen.py", line 34, in <module>
    from .variables.torch_function import TensorWithTFOverrideVariable
  File "/root/Git.d/pytorch/pytorch/torch/_dynamo/variables/torch_function.py", line 185, in <module>
    populate_builtin_to_tensor_fn_map()
  File "/root/Git.d/pytorch/pytorch/torch/_dynamo/variables/torch_function.py", line 146, in populate_builtin_to_tensor_fn_map
    inp0 = torch.ones(1)
  File "/root/Git.d/pytorch/pytorch/torch/fx/experimental/proxy_tensor.py", line 1240, in __torch_function__
    return func(*args, **kwargs)
  File "/root/Git.d/pytorch/pytorch/torch/utils/_stats.py", line 21, in wrapper
    return fn(*args, **kwargs)
  File "/root/Git.d/pytorch/pytorch/torch/fx/experimental/proxy_tensor.py", line 1342, in __torch_dispatch__
    return proxy_call(self, func, self.pre_dispatch, args, kwargs)
  File "/root/Git.d/pytorch/pytorch/torch/fx/experimental/proxy_tensor.py", line 907, in proxy_call
    name=proxy_mode.tracer.graph._target_to_str(func.overloadpacket.__name__),
AttributeError: 'PythonKeyTracer' object has no attribute 'graph'
...

Solutions:
Import torch._dynamo before dispatch_trace is called to avoid the context set before dispatch_trace from affecting the torch._dynamo import.

cc @ezyang @SherlockNoMad @EikanWang @jgong5 @wenzhe-nrv

[ghstack-poisoned]
@pytorch-bot
Copy link

pytorch-bot bot commented Nov 19, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/141022

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEVs

There are 1 currently active SEVs. If your PR is affected, please view them below:

✅ No Failures

As of commit 1e62e3d with merge base 93aef68 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@fffrog
Copy link
Collaborator Author

fffrog commented Nov 19, 2024

@pytorchbot label "topic: not user facing"

[ghstack-poisoned]
fffrog added a commit that referenced this pull request Nov 19, 2024
Detailed description:

The codes below will raise an error
```Python
import torch
from torch.fx.experimental.proxy_tensor import make_fx

def func(a):
    b = a + 1
    c = b.view(-1)
    c.add_(1)
    return b

input = torch.randn(2)
out = make_fx(func)(input)
```

The error info are like below:
```Python
...
  File "/root/Git.d/pytorch/pytorch/torch/_dynamo/codegen.py", line 34, in <module>
    from .variables.torch_function import TensorWithTFOverrideVariable
  File "/root/Git.d/pytorch/pytorch/torch/_dynamo/variables/torch_function.py", line 185, in <module>
    populate_builtin_to_tensor_fn_map()
  File "/root/Git.d/pytorch/pytorch/torch/_dynamo/variables/torch_function.py", line 146, in populate_builtin_to_tensor_fn_map
    inp0 = torch.ones(1)
  File "/root/Git.d/pytorch/pytorch/torch/fx/experimental/proxy_tensor.py", line 1240, in __torch_function__
    return func(*args, **kwargs)
  File "/root/Git.d/pytorch/pytorch/torch/utils/_stats.py", line 21, in wrapper
    return fn(*args, **kwargs)
  File "/root/Git.d/pytorch/pytorch/torch/fx/experimental/proxy_tensor.py", line 1342, in __torch_dispatch__
    return proxy_call(self, func, self.pre_dispatch, args, kwargs)
  File "/root/Git.d/pytorch/pytorch/torch/fx/experimental/proxy_tensor.py", line 907, in proxy_call
    name=proxy_mode.tracer.graph._target_to_str(func.overloadpacket.__name__),
AttributeError: 'PythonKeyTracer' object has no attribute 'graph'
...
```

Solutions:
Import torch._dynamo before dispatch_trace is called to avoid the context set before dispatch_trace from affecting the torch._dynamo import.

ghstack-source-id: ed28768
Pull Request resolved: #141022
@ezyang
Copy link
Contributor

ezyang commented Nov 20, 2024

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Nov 20, 2024
@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@pytorchmergebot
Copy link
Collaborator

The merge job was canceled or timed out. This most often happen if two merge requests were issued for the same PR, or if merge job was waiting for more than 6 hours for tests to finish. In later case, please do not hesitate to reissue the merge command
For more information see pytorch-bot wiki.

@fffrog
Copy link
Collaborator Author

fffrog commented Nov 20, 2024

@pytorchbot merge

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

Ryo-not-rio pushed a commit to Ryo-not-rio/pytorch that referenced this pull request Dec 2, 2024
…1022)

Detailed description:

The codes below will raise an error
```Python
import torch
from torch.fx.experimental.proxy_tensor import make_fx

def func(a):
    b = a + 1
    c = b.view(-1)
    c.add_(1)
    return b

input = torch.randn(2)
out = make_fx(func)(input)
```

The error info are like below:
```Python
...
  File "/root/Git.d/pytorch/pytorch/torch/_dynamo/codegen.py", line 34, in <module>
    from .variables.torch_function import TensorWithTFOverrideVariable
  File "/root/Git.d/pytorch/pytorch/torch/_dynamo/variables/torch_function.py", line 185, in <module>
    populate_builtin_to_tensor_fn_map()
  File "/root/Git.d/pytorch/pytorch/torch/_dynamo/variables/torch_function.py", line 146, in populate_builtin_to_tensor_fn_map
    inp0 = torch.ones(1)
  File "/root/Git.d/pytorch/pytorch/torch/fx/experimental/proxy_tensor.py", line 1240, in __torch_function__
    return func(*args, **kwargs)
  File "/root/Git.d/pytorch/pytorch/torch/utils/_stats.py", line 21, in wrapper
    return fn(*args, **kwargs)
  File "/root/Git.d/pytorch/pytorch/torch/fx/experimental/proxy_tensor.py", line 1342, in __torch_dispatch__
    return proxy_call(self, func, self.pre_dispatch, args, kwargs)
  File "/root/Git.d/pytorch/pytorch/torch/fx/experimental/proxy_tensor.py", line 907, in proxy_call
    name=proxy_mode.tracer.graph._target_to_str(func.overloadpacket.__name__),
AttributeError: 'PythonKeyTracer' object has no attribute 'graph'
...
```

Solutions:
Import torch._dynamo before dispatch_trace is called to avoid the context set before dispatch_trace from affecting the torch._dynamo import.

Pull Request resolved: pytorch#141022
Approved by: https://github.com/ezyang
pobin6 pushed a commit to pobin6/pytorch that referenced this pull request Dec 5, 2024
…1022)

Detailed description:

The codes below will raise an error
```Python
import torch
from torch.fx.experimental.proxy_tensor import make_fx

def func(a):
    b = a + 1
    c = b.view(-1)
    c.add_(1)
    return b

input = torch.randn(2)
out = make_fx(func)(input)
```

The error info are like below:
```Python
...
  File "/root/Git.d/pytorch/pytorch/torch/_dynamo/codegen.py", line 34, in <module>
    from .variables.torch_function import TensorWithTFOverrideVariable
  File "/root/Git.d/pytorch/pytorch/torch/_dynamo/variables/torch_function.py", line 185, in <module>
    populate_builtin_to_tensor_fn_map()
  File "/root/Git.d/pytorch/pytorch/torch/_dynamo/variables/torch_function.py", line 146, in populate_builtin_to_tensor_fn_map
    inp0 = torch.ones(1)
  File "/root/Git.d/pytorch/pytorch/torch/fx/experimental/proxy_tensor.py", line 1240, in __torch_function__
    return func(*args, **kwargs)
  File "/root/Git.d/pytorch/pytorch/torch/utils/_stats.py", line 21, in wrapper
    return fn(*args, **kwargs)
  File "/root/Git.d/pytorch/pytorch/torch/fx/experimental/proxy_tensor.py", line 1342, in __torch_dispatch__
    return proxy_call(self, func, self.pre_dispatch, args, kwargs)
  File "/root/Git.d/pytorch/pytorch/torch/fx/experimental/proxy_tensor.py", line 907, in proxy_call
    name=proxy_mode.tracer.graph._target_to_str(func.overloadpacket.__name__),
AttributeError: 'PythonKeyTracer' object has no attribute 'graph'
...
```

Solutions:
Import torch._dynamo before dispatch_trace is called to avoid the context set before dispatch_trace from affecting the torch._dynamo import.

Pull Request resolved: pytorch#141022
Approved by: https://github.com/ezyang
@github-actions github-actions bot deleted the gh/fffrog/34/head branch December 21, 2024 02:05
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ciflow/inductor ciflow/trunk Trigger trunk jobs on your pull request fx Merged open source release notes: fx release notes category topic: not user facing topic category

Projects

None yet

Development

Successfully merging this pull request may close these issues.

6 participants