propagate .meta info when replacing subgraphs in fx #87255
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/87255. Note: links to docs will display an error until the docs builds have completed.
✅ No failures as of commit 3385d45.
Fixes pytorch/torchdynamo#1708

Our FX subgraph partitioner works by taking all of the original output nodes from a subgraph and replacing them with a new `call_module` node in the graph. If the original subgraph outputs had fake tensors or other metadata stored in their `.meta` attribute, though, that information was getting lost when we spliced in the subgraph.

Losing metadata on an FX graph also seems like an easy trap to fall into, so I'm wondering if there are any better guardrails we can add. I ended up fixing it in this PR by adding an optional kwarg to propagate meta info directly in `fx.Node.replace_all_uses_with`, since propagating metadata seems like a pretty core thing.
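For concreteness, here's a minimal sketch of the failure mode and the manual workaround that this PR folds into the API. The module `M`, the `example_value` key, and the sigmoid replacement are all made up for illustration; only `replace_all_uses_with` and the `.meta` dict come from the PR itself:

```python
import torch
import torch.fx

class M(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x) + 1

gm = torch.fx.symbolic_trace(M())

# Attach metadata to a node, the way fake-tensor propagation would
# ("example_value" is just an illustrative key).
relu_node = next(n for n in gm.graph.nodes if n.target is torch.relu)
relu_node.meta["example_value"] = torch.empty(4)

# Splice in a replacement node. A freshly created node starts with an
# empty .meta, so without an explicit copy the metadata is lost, which
# is exactly what happened when the partitioner swapped in call_module.
with gm.graph.inserting_after(relu_node):
    new_node = gm.graph.call_function(torch.sigmoid, args=relu_node.args)
relu_node.replace_all_uses_with(new_node)
new_node.meta = relu_node.meta.copy()  # the manual step this PR automates
gm.graph.erase_node(relu_node)
gm.recompile()
```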
wconstab left a comment:
Seems like a good change, but I wonder how we can tell if we're still missing cases. Should the default be true?
Yeah - I was pretty torn about whether or not copying metadata should be explicit vs. implicit. Right now it's explicit - you specify the kwarg when you want metadata copied. I wasn't sure about making it implicit: I guess I could flip it so that we copy metadata by default, but only if there's no metadata already set on the new node (and if there is, maybe we should just silently not copy metadata instead of erroring). I'm open to either / other opinions though.
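To make that trade-off concrete, the implicit variant being floated would look roughly like the sketch below. `replace_and_maybe_propagate` is a hypothetical helper for illustration, not an API from this PR:

```python
# Hypothetical sketch of the "implicit" behavior discussed above:
# copy .meta by default, but only when the replacement node has no
# metadata of its own; otherwise silently leave it alone.
def replace_and_maybe_propagate(old_node, new_node):
    old_node.replace_all_uses_with(new_node)
    if not new_node.meta:
        # Replacement node is "fresh": safe to inherit the old metadata.
        new_node.meta = old_node.meta.copy()
    # If new_node.meta is already populated, keep it rather than erroring.
    return new_node
```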
The review comments below refer to this hunk around the `replace_all_uses_with` definition:

```python
        f'args = {_format_arg(self.args)}, kwargs = {_format_arg(self.kwargs)})'

    @compatibility(is_backward_compatible=True)
    def replace_all_uses_with(self,
```
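The hunk cuts off mid-signature; for context, the extended API presumably ends up looking something like the sketch below. The kwarg name `propagate_meta` and its default of `False` are assumptions on my part rather than quotes from the diff, and any other pre-existing parameters are elided:

```python
# Hedged sketch of the extended signature (kwarg name is a guess):
@compatibility(is_backward_compatible=True)
def replace_all_uses_with(self,
                          replace_with: 'Node',
                          *,
                          propagate_meta: bool = False) -> List['Node']:
    """Replace all uses of this node with replace_with; when
    propagate_meta is True, also copy this node's .meta dict
    onto the replacement node."""
    ...
```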
This is an FX public API change (it shouldn't be BC-breaking, though), so I'll have to update the tests.
Who else should review public FX API changes? @SherlockNoMad?
This is not a BC-breaking change, so it should be fine.
Pull Request resolved: pytorch#87255. Approved by: https://github.com/wconstab, https://github.com/SherlockNoMad