[invoke_subgraph] Fake tensor prop caching #149087

Conversation
Dr. CI (bot):
🔗 Helpful links: 🧪 see artifacts and rendered test results at hud.pytorch.org/pr/149087.
Note: links to docs will display an error until the doc builds have completed.
✅ You can merge normally! (3 unrelated failures.) As of commit 4386a6c with merge base ce54c43. UNSTABLE: the following jobs are marked as unstable, possibly due to flakiness on trunk.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
torch/_higher_order_ops/utils.py (outdated):

```python
registered_hop_fake_fns: dict[torch._ops.OpOverload, Callable] = {}


def register_hop_fake(hop, fn=None):
```
Reviewer: nit: call this just register_fake? The fqn already has HOP in it (torch._higher_order_ops.register_fake).
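A minimal sketch of how that registration could work under the suggested name; the decorator-or-direct behavior is inferred from the `fn=None` default in the diff above, and the body is an illustrative assumption, not the PR's actual implementation:

```python
from typing import Callable

import torch

# Mirrors the registry in the diff above.
registered_hop_fake_fns: dict[torch._ops.OpOverload, Callable] = {}


def register_fake(hop, fn=None):
    # Usable directly, register_fake(hop, fn), or as a decorator,
    # @register_fake(hop), matching the fn=None default above.
    def register(func):
        registered_hop_fake_fns[hop] = func
        return func

    return register if fn is None else register(fn)
```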
```python
def __hash__(self):
    return id(self.subgraph)
```
Reviewer: nit: Can you add some sort of repr so we know what we're dealing with when we're debugging?
Author: Added.
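A minimal sketch of the kind of __repr__ being asked for, assuming the wrapper shape implied by the __hash__ diff above; the class name and format string are illustrative assumptions:

```python
class SubgraphWrapper:
    # Illustrative stand-in for the wrapper class in the diff.
    def __init__(self, subgraph):
        self.subgraph = subgraph

    def __hash__(self):
        return id(self.subgraph)

    def __repr__(self):
        # Identify the wrapped subgraph by type and address so cache
        # keys are easy to correlate while debugging.
        return (
            f"{type(self).__name__}("
            f"subgraph={type(self.subgraph).__name__} "
            f"at {hex(id(self.subgraph))})"
        )
```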
```python
# For debugging / testing: Validate that the output synthesized
# from the cache matches the output created by normal dispatch.
with disable_fake_tensor_cache(self):
    self._crosscheck_cache_output(output, func, types, args, kwargs)
```
Reviewer: btw, when is self.cache_crosscheck_enabled set to True? I assume it is set to True somewhere in testing, which is why this change was necessary.
Author: Yes, in test_fake_tensor.py with

```python
torch._dynamo.config.fake_tensor_cache_crosscheck_enabled = True
```
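For context, a hedged sketch of what this crosscheck amounts to; only the config flag, the method name, and the disable-cache wrapper come from the thread, while the body below (and its single-tensor simplification) is an illustrative assumption:

```python
class FakeTensorModeSketch:
    # Stand-in class; in PyTorch this logic lives on FakeTensorMode.
    def _crosscheck_cache_output(self, output, func, types, args, kwargs):
        # The caller has already disabled the cache (see the diff
        # above), so this dispatch recomputes the output from scratch.
        true_output = func(*args, **(kwargs or {}))
        # Compare the cache-synthesized output with the recomputed one;
        # real code would compare full metadata and handle multi-tensor
        # outputs, this sketch checks a single tensor only.
        assert true_output.shape == output.shape
        assert true_output.dtype == output.dtype
        assert true_output.device == output.device
```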
```python
# caching implementation, e.g., data dependent ops or ops that modify
# the inputs.
from torch._higher_order_ops.utils import registered_hop_fake_fns
```
Reviewer: the type of func is now wrong; it needs to be updated to a union.
Author: So, I tried this earlier and it causes a bunch of mypy failures, because func is then expected to have tags and other attributes all over the place, which are not present on HigherOrderOperator.
Reviewer: yeah, don't worry about it.
```python
isinstance(func, torch._ops.HigherOrderOperator)
and func in registered_hop_fake_fns
```
Reviewer: I'm not sure this is entirely correct: if the result of the HOP has a data-dependent output shape or a dynamic output shape, then we need to bail out (pytorch/torch/_subclasses/fake_tensor.py, lines 1444 to 1448 at 2fcfae7):

```python
if torch.Tag.data_dependent_output in func.tags:
    raise _BypassDispatchCache("data dependent output")
if torch.Tag.dynamic_output_shape in func.tags:
    raise _BypassDispatchCache("dynamic output shape")
```
Author: Now I remember why I needed a validator function. Ah.
Author: Adding new logic that goes through the subgraph nodes and checks each of them.
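A hedged sketch of what such a node walk could look like, applying the same bypass rules as the fake_tensor.py snippet quoted above; the helper name and exact structure are illustrative assumptions, not the PR's code:

```python
import torch
from torch._subclasses.fake_tensor import _BypassDispatchCache


def _validate_subgraph_for_caching(subgraph: torch.fx.GraphModule) -> None:
    # Walk every call_function node in the subgraph and bail out of
    # caching if any op is ineligible, mirroring the per-op checks
    # quoted from fake_tensor.py above.
    for node in subgraph.graph.nodes:
        if node.op != "call_function":
            continue
        tags = getattr(node.target, "tags", ())
        if torch.Tag.data_dependent_output in tags:
            raise _BypassDispatchCache("data dependent output in subgraph")
        if torch.Tag.dynamic_output_shape in tags:
            raise _BypassDispatchCache("dynamic output shape in subgraph")
```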
```python
elif isinstance(arg, torch.fx.GraphModule):
    # This is used for invoke_subgraph where id(graph_module) allows
    # us to cache fake outputs
    result.append(type(arg))
    result.append(id(arg))
elif isinstance(arg, FunctionalizeCtxWrapper):
    result.append(hash(arg))
```
Reviewer: Umm, does id(arg) assume that the GraphModule stays alive forever? (What if the GraphModule gets deallocated and another one gets allocated in its stead?) We might need some cache invalidation mechanism via weakref.finalize.
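A minimal sketch of the weakref.finalize idea (the actual fix is the PR the author links below; the `cache` and key layout here are illustrative stand-ins):

```python
import weakref


def watch_graph_module(gm, cache, keys_for_gm):
    # When `gm` is garbage collected, evict every cache entry whose key
    # embedded id(gm), so a later GraphModule that happens to reuse the
    # same address cannot alias a stale entry. The callback must not
    # capture `gm` itself, or it would keep the module alive forever.
    def evict(keys=tuple(keys_for_gm)):
        for key in keys:
            cache.pop(key, None)

    weakref.finalize(gm, evict)
```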
Author: Doing it here: #149667.
zou3519 left a comment:
I think we forgot to discuss the recursive case (we did talk about it half a year ago, and I am remembering it now): what happens if invoke_subgraph has a subgraph containing an operator that is not eligible for FakeTensor caching? We shouldn't allow that invoke_subgraph to be cached. There's probably an efficient strategy for checking this: during FakeTensorProp for invoke_subgraph, do the FakeTensorProp for the subgraph first, and then, if that didn't hit any ineligible operators, we say the invoke_subgraph can be cached.
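A hedged sketch of that strategy, combining the node walk from the earlier sketch with torch.fx's FakeTensorProp; the function name and the cacheable-flag plumbing are assumptions for illustration, not how the PR wires it up:

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.passes.fake_tensor_prop import FakeTensorProp


def fake_prop_invoke_subgraph(
    subgraph: torch.fx.GraphModule, operands, mode: FakeTensorMode
):
    # Decide eligibility first by walking the inner nodes, then run the
    # inner fake-tensor propagation; only an eligible subgraph's output
    # may be cached by the outer invoke_subgraph.
    cacheable = all(
        torch.Tag.data_dependent_output not in getattr(n.target, "tags", ())
        and torch.Tag.dynamic_output_shape not in getattr(n.target, "tags", ())
        for n in subgraph.graph.nodes
        if n.op == "call_function"
    )
    output = FakeTensorProp(subgraph, mode).propagate(*operands)
    return output, cacheable
```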
zou3519 left a comment: tests failing
zou3519 left a comment: LGTM
Starting merge as part of PR stack under #150036.
Landed via the ghstack stack:

- Pull Request resolved: #149087 (this PR; redoing #137808). Approved by: https://github.com/zou3519. ghstack-source-id: e1a91e0.
- Pull Request resolved: #149667. Approved by: https://github.com/zou3519. ghstack dependencies: #149087.
- Pull Request resolved: #150036. Approved by: https://github.com/angelayi. ghstack dependencies: #149087, #149667.
- Pull Request resolved: #148953. Approved by: https://github.com/zou3519. ghstack dependencies: #149087, #149667, #150036.
- Pull Request resolved: #150090. Approved by: https://github.com/eellison, https://github.com/zou3519. ghstack dependencies: #149087, #149667, #150036, #148953.
Stack from ghstack (oldest at bottom):
Redoing #137808