-
Notifications
You must be signed in to change notification settings - Fork 26.3k
[hierarchical-compilation][invoke-subgraph] Add fake tensor caching #137808
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
[ghstack-poisoned]
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/137808
Note: Links to docs will display an error until the docs builds have been completed. ❌ 1 New FailureAs of commit 4dd2319 with merge base 3d0aa6f ( NEW FAILURE - The following job has failed:
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
…r caching" [ghstack-poisoned]
…r caching" [ghstack-poisoned]
…r caching" [ghstack-poisoned]
…r caching" [ghstack-poisoned]
…r caching" [ghstack-poisoned]
…r caching" [ghstack-poisoned]
…r caching" [ghstack-poisoned]
…r caching" [ghstack-poisoned]
…r caching" [ghstack-poisoned]
…r caching" [ghstack-poisoned]
| if isinstance(func, torch._ops.HigherOrderOperator): | ||
| # For invoke_subgraph op, if the identifier is set then its safe to | ||
| # cache the fake tensor result. | ||
| from torch._higher_order_ops.utils import ( | ||
| registered_hop_fake_tensor_cache_key_validation_fns, | ||
| ) | ||
|
|
||
| if func in registered_hop_fake_tensor_cache_key_validation_fns: | ||
| return registered_hop_fake_tensor_cache_key_validation_fns[func]( | ||
| self, *args, **kwargs | ||
| ) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry I kind of forgot the discussion we had on Thursday. I remember we reasoned out that everything was safe? If that's the case, can we include the reasoning here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see, we're going the conservative approach
…r caching" [ghstack-poisoned]
| if isinstance(func, torch._ops.HigherOrderOperator): | ||
| # For invoke_subgraph, ignore the subgraph arg. We rely on |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why do we need to do this here ? cant the op impl / op caching handle this?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There's a subgraph here. We're trying to avoid needing to hash the subgraph everytime it is seen. So that's what the identifier is for.
Taking suggestions for how to make this better
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If the subgraph is repeated, can we have FakeTensorCache cache the hashing of it ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I like that better, what say you @anijain2305 ?
As a part of this, you could define what it means for a "subgraph to be cacheable" to be what the current "invoke subgraph fake tensor cache validator function" is and get rid of the "fake tensor cache validator registration" thing.
zou3519
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm fine with the approach. I can imagine that we might want to dig it out and replace it with something better later (it's a bit conservative). Just want to poke on two points a bit (if we can remove the .tags from the HOP and I'm trying to figure out what the crossref thing is)
| self.tags = () | ||
| self.is_view = False |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@anijain2305 do we still need these?
| # For invoke_subgraph op, if the identifier is set then its safe to | ||
| # cache the fake tensor result. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this comment is not correct anymore right? It's only safe to cache the result if the validation fn returns True.
| if isinstance(func, torch._ops.HigherOrderOperator): | ||
| # For invoke_subgraph op, if the identifier is set then its safe to | ||
| # cache the fake tensor result. | ||
| from torch._higher_order_ops.utils import ( | ||
| registered_hop_fake_tensor_cache_key_validation_fns, | ||
| ) | ||
|
|
||
| if func in registered_hop_fake_tensor_cache_key_validation_fns: | ||
| return registered_hop_fake_tensor_cache_key_validation_fns[func]( | ||
| self, *args, **kwargs | ||
| ) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see, we're going the conservative approach
…r caching" [ghstack-poisoned]
…r caching" [ghstack-poisoned]
|
Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as |
Redoing #137808 [ghstack-poisoned]
Redoing #137808 [ghstack-poisoned]
Redoing #137808 [ghstack-poisoned]
Redoing #137808 [ghstack-poisoned]
Redoing #137808 [ghstack-poisoned]
Redoing #137808 [ghstack-poisoned]
Redoing #137808 [ghstack-poisoned]
Redoing #137808 [ghstack-poisoned]
Redoing #137808 [ghstack-poisoned]
Redoing #137808 [ghstack-poisoned]
Redoing #137808 Pull Request resolved: #149087 Approved by: https://github.com/zou3519
Redoing pytorch#137808 Pull Request resolved: pytorch#149087 Approved by: https://github.com/zou3519
Stack from ghstack (oldest at bottom):