Fix torch.compile FunctionalTensor inputs for higherOrderOps #107604
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/107604
Note: Links to docs will display an error until the docs builds have been completed.
✅ 1 Unrelated Failure as of commit 4b2410e with merge base 67bb3c0. UNSTABLE: the following job failed but was likely due to flakiness present on trunk and has been marked as unstable.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
test/dynamo/test_subclasses.py
Outdated
```python
actual = normalize_gm(
    backend.graphs[exp_n_graph - 1].print_readable(print_output=False)
)
self.assertEqual(actual, exp_graph)
```
nit: can we assertExpectedInline somehow on this?
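Something along these lines, roughly; the expected string below is just a placeholder that expecttest would regenerate when the test is run with `EXPECTTEST_ACCEPT=1`:

```python
actual = normalize_gm(
    backend.graphs[exp_n_graph - 1].print_readable(print_output=False)
)
# Inline-expect variant: the second argument holds the expected graph text and
# can be auto-updated by running the test with EXPECTTEST_ACCEPT=1.
self.assertExpectedInline(
    actual,
    """\
class GraphModule(torch.nn.Module):
    ...
""",
)
```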
torch/_subclasses/fake_tensor.py
Outdated
```diff
-def is_fake(x):
-    if isinstance(x, FakeTensor):
+def is_fake(x, fake_mode=None):
```
I don't really understand why we need to pass in a fake_mode to is_fake. Are we saying that a FakeTensor might have a different fake_mode than expected?
Yes, this could happen when the input is a FakeTensor created outside of the compiled region, e.g. `make_fx(torch.func.functionalize(compiled_f))(x)`. But theoretically, FakeTensors should already be fakified by dynamo by the time we call the first `is_fake` right now. I guess it's mostly to add more safety checks and enforce an invariant that all fake tensors in dynamo have the same fake_mode as the current InstructionTranslator.
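A rough sketch of that scenario (not the exact test in this PR; `cond` is just a representative higherOrderOp, and this is the kind of path that used to hit the assertion error):

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from functorch.experimental.control_flow import cond  # representative higherOrderOp


@torch.compile(backend="eager", fullgraph=True)
def compiled_f(x):
    # The compiled region contains a higher-order op, so dynamo has to wrap
    # its inputs via wrap_fx_proxy with an example_value.
    return cond(x.sum() > 0, lambda x: x.sin(), lambda x: x.cos(), [x])


x = torch.randn(3)
# Tracing the compiled function under an outer functionalize/make_fx call means
# the tensors reaching dynamo are wrappers created outside its own fake mode.
gm = make_fx(torch.func.functionalize(compiled_f))(x)
```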
My problem right now is that the `is_fake(x, fake_mode)` API doesn't really make sense. If you want to uphold an invariant that all fake tensors have the same fake_mode as the current InstructionTranslator, then instead of complicating the `is_fake` check we should add an assertion somewhere (maybe post-fakification?) that all FakeTensors have the same fake mode.
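Roughly something like the following; the helper name and the exact hook point (`tx`, the collection of tensors to check) are illustrative, not a concrete proposal for where it should live:

```python
from torch._subclasses.fake_tensor import FakeTensor


def assert_single_fake_mode(tx, tensors):
    # Post-fakification sanity check: every FakeTensor that dynamo tracks must
    # belong to the current InstructionTranslator's fake mode.
    for t in tensors:
        if isinstance(t, FakeTensor):
            assert t.fake_mode is tx.fake_mode, (
                "found a FakeTensor created by a different FakeTensorMode "
                "than the current InstructionTranslator's"
            )
```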
Yeah, that also sounds good to me! I can change the implementation.
@pytorchbot merge

Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Stack from ghstack (oldest at bottom):
Before this PR, the added test, which feeds FunctionalTensorWrapper inputs to a higherOrderOperator, hits an assertion error at this line of code.
The key change in this PR is this check:
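```python
elif (
    isinstance(example_value, FakeTensor)
    and example_value.fake_mode is tx.fake_mode
):
```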
The original intention seems to be handling the case where we wrap an fx proxy for an intermediate fake tensor that is produced by some tensor ops and an example value is provided (as is the case for higherOrderOps here). A fakified FunctionalTensorWrapper(FakeTensor) always fails this check. This PR changes the check to whether the example value is already fakified by tx.fake_mode.
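Conceptually, the new condition behaves like the helper sketched below. The name `is_fakified_by` is illustrative and not the actual code in the diff; it assumes the private helpers `torch._is_functional_tensor` and `torch._from_functional_tensor` for unwrapping a FunctionalTensorWrapper.

```python
import torch
from torch._subclasses.fake_tensor import FakeTensor


def is_fakified_by(t, fake_mode) -> bool:
    # "Already fakified by fake_mode" should also cover a FunctionalTensorWrapper
    # whose inner tensor is a FakeTensor owned by that mode, not just a plain
    # FakeTensor.
    if isinstance(t, FakeTensor):
        return t.fake_mode is fake_mode
    if isinstance(t, torch.Tensor) and torch._is_functional_tensor(t):
        inner = torch._from_functional_tensor(t)
        return isinstance(inner, FakeTensor) and inner.fake_mode is fake_mode
    return False
```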
cc @voznesenskym @penguinwu @anijain2305 @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @aakhundov