Disable python dispatcher in fallthrough for PyOperators#95891

Closed
angelayi wants to merge 5 commits into master from cond_pyd

Conversation

@angelayi
Contributor

@angelayi angelayi commented Mar 2, 2023

Possible fix for #89037

Context: The existing fallthrough implementation for PyOperators causes the PythonDispatcher to redispatch to itself infinitely. This happens because of this line, which permanently adds the PythonDispatcher key to the dispatch key set that we read on this line. We temporarily fixed this by excluding the PythonDispatcher key from the global key set (here), but that runs into an issue in the implementation for the functionalization key, where we want to call functionalize on the true/false subgraphs and make_fx to check for aliasing/mutations, which requires having the PythonDispatcher key.

Our attempt at fixing this is to modify the fallthrough function to ignore the PythonDispatcher key when generating the keys to redispatch to. This should prevent the infinite recursion without modifying the global state, so the PythonDispatcher key remains available.
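
To make the loop easier to picture, here is a minimal toy simulation of the redispatch behavior described above. It does not use or mirror the real PyTorch code: the key names, their ordering, the `FALLTHROUGH` set, and the `trace` helper are all assumptions made for illustration. The buggy variant models PythonDispatcher always being present in the key set; the other variant ignores it when choosing the next key.

```python
# Toy model: key names and ordering are illustrative, not PyTorch's real DispatchKey enum.
PRIORITY = ["PythonDispatcher", "PythonTLSSnapshot", "AutogradCPU", "CPU"]
FALLTHROUGH = {"PythonTLSSnapshot", "AutogradCPU"}  # keys registered as fallthrough


def highest(keys):
    """Pick the key that comes first in PRIORITY."""
    return min(keys, key=PRIORITY.index)


def trace(keys, ignore_python_dispatcher, max_steps=8):
    """Return the keys visited, stopping at a real kernel or after max_steps."""
    visited = []
    current = highest(keys)
    for _ in range(max_steps):
        visited.append(current)
        if current == "PythonDispatcher":
            # The Python dispatcher hands off to the best non-PythonDispatcher key.
            current = highest(keys - {"PythonDispatcher"})
        elif current in FALLTHROUGH:
            # Candidates are the keys strictly after the current one...
            after = {k for k in keys if PRIORITY.index(k) > PRIORITY.index(current)}
            if not ignore_python_dispatcher:
                # ...but the buggy variant always sees PythonDispatcher in the set.
                after.add("PythonDispatcher")
            current = highest(after)
        else:
            return visited  # reached a non-fallthrough kernel
    return visited


ALL_KEYS = set(PRIORITY)
looping = trace(ALL_KEYS, ignore_python_dispatcher=False)
fixed = trace(ALL_KEYS, ignore_python_dispatcher=True)
```

In this toy, `looping` bounces between PythonDispatcher and PythonTLSSnapshot until the step limit is hit, while `fixed` walks down to the CPU kernel. Note that the global key set (`ALL_KEYS`) is never modified in either variant.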

@pytorch-bot

pytorch-bot bot commented Mar 2, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/95891

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 Failures

As of commit cea298f:

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@angelayi angelayi added the topic: not user facing topic category label Mar 2, 2023
@angelayi angelayi marked this pull request as ready for review March 2, 2023 18:46
@angelayi angelayi requested a review from tugsbayasgalan March 2, 2023 18:46
@ezyang
Contributor

ezyang commented Mar 3, 2023

how urgent is this

@angelayi
Contributor Author

angelayi commented Mar 3, 2023

> how urgent is this

semi-urgent? This is blocking turning on functionalization for DPE which uses the control flow ops. Right now we use a hacky version of functionalization which just skips the control flow ops, but it should be removed...

@ezyang
Contributor

ezyang commented Mar 3, 2023

I feel like there is probably a much simpler fix for the problem. Thinking.

"""
try:
gm = make_fx(branch)(*fake_inputs)
gm = make_fx(branch)(*inputs)
Contributor

nit: try to avoid unrelated refactors like this; they make it harder for the reviewer to see what's going on

Contributor Author

Refactored those fixes into #95988, so those changes should disappear in this PR after that one merges!

@ezyang
Contributor

ezyang commented Mar 3, 2023

see #89037 (comment)

@ezyang
Contributor

ezyang commented Mar 7, 2023

Hmmph. I don't like this, because you're smearing out state on the PyOperator that should be on a per-call basis. And in fact, it's not even right, because if I fall through a key and then redispatch, I will redo that key (the fallthrough is not sticky!)

@ezyang
Contributor

ezyang commented Mar 7, 2023

Can you tell me more about why the approach we described in VC didn't work out?

@angelayi
Contributor Author

angelayi commented Mar 7, 2023

> Can you tell me more about why the approach we described in VC didn't work out?

My understanding from the VC is that the order we want is this: every key that has run (besides the PythonDispatcher) should not be dispatched to again, and control should go back to the PythonDispatcher, which dispatches to the following key. So the order should look something like PythonDispatcher -> PythonTLSSnapshot -> PythonDispatcher -> AutogradCPU -> PythonDispatcher ....

The way I thought to do that was to use the ExcludeDispatchKeyGuard to prevent those keys from being dispatched to again. But because that affects the global set of keys, it prevented the inner make_fx call we make in cond from running correctly.

> if I fallthrough a key and then redispatch, I will redo that key (the fallthrough is not sticky!)

If you fallthrough a key wouldn't it get added to the list of keys that have been run already and redispatch to PythonDispatcher?
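
The ordering discussed above, and one way to read the concern about per-call state, can be sketched with a toy operator. Everything here is hypothetical (`ToyOperator`, its method names, and the key list); it is not PyTorch's PyOperator. The sketch only shows why a record of "keys that already ran" stored on the operator leaks into later calls, while the same record kept per call does not.

```python
# Toy model with hypothetical names; this is not PyTorch's PyOperator.
PRIORITY = ["PythonTLSSnapshot", "AutogradCPU", "CPU"]


class ToyOperator:
    def __init__(self):
        # Operator-level record of keys that already ran: shared by every call.
        self.already_run = set()

    def call_with_operator_state(self, keys):
        order = []
        for key in PRIORITY:
            if key in keys and key not in self.already_run:
                self.already_run.add(key)
                # Control returns to the Python dispatcher before each key.
                order += ["PythonDispatcher", key]
        return order

    def call_with_per_call_state(self, keys):
        already_run = set()  # fresh for each call
        order = []
        for key in PRIORITY:
            if key in keys and key not in already_run:
                already_run.add(key)
                order += ["PythonDispatcher", key]
        return order


op = ToyOperator()
keys = set(PRIORITY)
first = op.call_with_operator_state(keys)
second = op.call_with_operator_state(keys)  # state from the first call leaks in
per_call = [op.call_with_per_call_state(keys) for _ in range(2)]
```

In this toy, `first` follows the PythonDispatcher -> PythonTLSSnapshot -> PythonDispatcher -> AutogradCPU pattern, `second` comes back empty because the operator still remembers the first call, and both entries of `per_call` are identical.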

@ezyang
Contributor

ezyang commented Mar 7, 2023

> The way I thought to do that was to use the ExcludeDispatchKeyGuard to prevent those keys from being dispatched to again. But because that affects the global set of keys, it prevented the inner make_fx call we make in cond from running correctly.

Not necessary, because the Python dispatcher can compute the correct key to go to. Then you just call it directly (op_dk, or just call the callable in your Python-side dispatch dict).
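
As a rough illustration of the difference between the two approaches, the sketch below contrasts a guard that mutates a global excluded set with a direct call into a Python-side table. The names `EXCLUDED`, `exclude_guard`, `inner_trace_sees`, and `PY_TABLE` are hypothetical stand-ins, not PyTorch APIs, and the nested function only imitates a call that needs a given key to stay visible.

```python
from contextlib import contextmanager

# Toy stand-in for a global (thread-local) excluded key set; hypothetical name.
EXCLUDED = set()


@contextmanager
def exclude_guard(key):
    """Toy analogue of an exclude guard: hides `key` from everything nested inside."""
    newly_added = key not in EXCLUDED
    EXCLUDED.add(key)
    try:
        yield
    finally:
        if newly_added:
            EXCLUDED.discard(key)


def inner_trace_sees(key):
    """Stands in for a nested call (such as tracing a branch) that needs `key`."""
    return key not in EXCLUDED


# Hypothetical Python-side dispatch dict: key name -> callable.
PY_TABLE = {"AutogradCPU": lambda: inner_trace_sees("PythonTLSSnapshot")}


def call_with_guard(already_run, target):
    # Excluding the key that already ran also hides it from the nested trace.
    with exclude_guard(already_run):
        return PY_TABLE[target]()


def call_directly(target):
    # The dispatcher has already computed `target`, so just invoke its callable.
    return PY_TABLE[target]()


guarded = call_with_guard("PythonTLSSnapshot", "AutogradCPU")
direct = call_directly("AutogradCPU")
```

Here `guarded` reports that the nested call cannot see the excluded key, while `direct` reports that it can, since no global state was touched on the way to the handler.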

@ezyang
Contributor

ezyang commented Mar 8, 2023

No test?

ezyang added a commit that referenced this pull request Mar 8, 2023
Fallthrough is modeled as a mask which we use to remove keys from the
compute dispatch key set for eligibility.

It's possible this addresses #89037
in a better way than #95891 but I
cannot easily tell as the original repro no longer works and the new PR
does not have a test.

Signed-off-by: Edward Z. Yang <[email protected]>

[ghstack-poisoned]
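
The "fallthrough as a mask" idea in the commit message above can be sketched with plain integer bit sets. This is a toy under stated assumptions: the bit layout, the constant names, and the helpers `highest_priority` and `eligible` are invented for illustration and are not the real dispatch key set implementation.

```python
# Hypothetical bit layout: a higher bit means higher priority in this toy.
CPU = 1 << 0
AUTOGRAD_CPU = 1 << 1
PYTHON_TLS_SNAPSHOT = 1 << 2


def highest_priority(keyset):
    """Return the highest set bit of a non-empty key set."""
    if keyset == 0:
        raise ValueError("empty key set")
    return 1 << (keyset.bit_length() - 1)


def eligible(keyset, fallthrough_mask):
    """Remove every fallthrough key from the set before picking a key."""
    return keyset & ~fallthrough_mask


keyset = CPU | AUTOGRAD_CPU | PYTHON_TLS_SNAPSHOT
fallthrough_mask = AUTOGRAD_CPU | PYTHON_TLS_SNAPSHOT
chosen = highest_priority(eligible(keyset, fallthrough_mask))
```

Because the mask is applied before a key is chosen, fallthrough keys are never selected in the first place, so no per-key redispatch step is needed in this toy.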
pytorchmergebot pushed a commit that referenced this pull request Mar 8, 2023
Fallthrough is modeled as a mask which we use to remove keys from the
compute dispatch key set for eligibility.

It's possible this addresses #89037
in a better way than #95891 but I
cannot easily tell as the original repro no longer works and the new PR
does not have a test.

Signed-off-by: Edward Z. Yang <[email protected]>
Pull Request resolved: #96304
Approved by: https://github.com/zou3519, https://github.com/albanD, https://github.com/zhxchen17
@zhxchen17
Contributor

Trying to reach out to Angela to see what she'll do for this PR. She's on PTO right now.

@ezyang
Contributor

ezyang commented Mar 10, 2023

Please check whether the actual use case you needed isn't already fixed on master; I landed a set of orthogonal changes which should fix infinite fallthrough loops. I have no way of testing since this PR doesn't have a test.

cyyever pushed a commit to cyyever/pytorch_private that referenced this pull request Mar 12, 2023
Fallthrough is modeled as a mask which we use to remove keys from the
compute dispatch key set for eligibility.

It's possible this addresses pytorch/pytorch#89037
in a better way than pytorch/pytorch#95891 but I
cannot easily tell as the original repro no longer works and the new PR
does not have a test.

Signed-off-by: Edward Z. Yang <[email protected]>
Pull Request resolved: pytorch/pytorch#96304
Approved by: https://github.com/zou3519, https://github.com/albanD, https://github.com/zhxchen17
@angelayi
Contributor Author

Yup, the actual case is fixed on master. #96635 removes the existing hacky fallthrough (the tests already landed previously).

@angelayi angelayi closed this Mar 13, 2023
ydwu4 added a commit to ydwu4/pytorch that referenced this pull request Mar 13, 2023
Fallthrough is modeled as a mask which we use to remove keys from the
compute dispatch key set for eligibility.

It's possible this addresses pytorch#89037
in a better way than pytorch#95891 but I
cannot easily tell as the original repro no longer works and the new PR
does not have a test.

Signed-off-by: Edward Z. Yang <[email protected]>
Pull Request resolved: pytorch#96304
Approved by: https://github.com/zou3519, https://github.com/albanD, https://github.com/zhxchen17
@github-actions github-actions bot deleted the cond_pyd branch September 2, 2024 02:01