-
Notifications
You must be signed in to change notification settings - Fork 26.3k
Make relationship between outputs and inputs of auto_functionalize static and not depending on addresses of new inputs passed #134315
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
…atic and not depending on new inputs passed [ghstack-poisoned]
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/134315
Note: Links to docs will display an error until the docs builds have been completed. ❌ 35 New Failures, 1 Unrelated FailureAs of commit eb59152 with merge base 938f37b ( NEW FAILURES - The following jobs have failed:
BROKEN TRUNK - The following job failed but were present on the merge base:👉 Rebase onto the `viable/strict` branch to avoid these failures
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
|
ah why :
|
…ionalize static and not depending on new inputs passed" In the previous diff, auto_functionalize used to have the form ``` [x', y'] = auto_functionalize( foo, x, y, _arg0_base=x, _arg1_base = y, _all_bases=[x,y]) ``` Which can be read as auto_functionalize 1) has two outputs len(all_bases) 2) the first output starts with all_bases[0] `x` then it will observe the mutations each arg that have arg_base =`x` In this case the mutation on x. (which is read either from copy(x) that is passed to foo). 3) same for the second output, it will be start with y and observe the mutations of the arg y. Now one problem with the above is if a pass run and decided that y=x then we get : ``` [x', y'] = auto_functionalize( foo, x, x, _arg0_base=x, _arg1_base = x,_ all_bases=[x,x]) ``` This is problematic because 1) x' now start with x and observe both mutations of first and second arg of foo, because _arg0_base=x and _arg1_base = x which changes the semantics of the function. and make it wrong even if do no re-inplace anything! 2) same for y' not it starts with x but observe both mutations. **Solution** Make the relationships between the outputs and the args that we need to observe their mutations static and not changeable. ``` [x', y'] = auto_functionalize( foo, x, y, _all_bases=[x,y], _observe_mutation_from=[[arg0],[arg1]) ``` we use the addresses recorded at the time of the functionlization to know for each base, which args need to have their mutations observed 1) x' output starts with x and observe mutations of arg0 2) y' output starts with and observe mutations of arg1 if we do CSE we get ``` [x', y'] = auto_functionalize( foo, x, x, _all_bases=[x,x], _observe_mutation_from=[[arg0],[arg1]) ``` 1) x' output starts with x and observe mutations of arg0 2) y' output starts with x and observe mutations of arg1 this makes inductor/test_inplacing_pass.py::TestReinplacingPassCorrectness::test_multi_output_intermediate pass. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang rec [ghstack-poisoned]
…ionalize static and not depending on new inputs passed" In the previous diff, auto_functionalize used to have the form ``` [x', y'] = auto_functionalize( foo, x, y, _arg0_base=x, _arg1_base = y, _all_bases=[x,y]) ``` Which can be read as auto_functionalize 1) has two outputs len(all_bases) 2) the first output starts with all_bases[0] `x` then it will observe the mutation from each arg that have arg_base =`x` In this case the mutation on arg0. 3) same for the second output, it will be start with y and observe the mutations of the arg1. Now one problem with the above is if a pass run and decided that y=x then we get : ``` [x', y'] = auto_functionalize( foo, x, x, _arg0_base=x, _arg1_base = x,_ all_bases=[x,x]) ``` This is problematic because 1) x' now start with x and observe both mutations of first and second arg of foo, because _arg0_base=x and _arg1_base = x which changes the semantics of the function. and make it wrong even if do not re-inplace anything! 2) same for y' it starts with x but observe both mutations. **Solution** Make the relationships between the outputs and the args that we need to observe their mutations static and not changeable. ``` [x', y'] = auto_functionalize( foo, x, y, _all_bases=[x,y], _observe_mutation_from=[[arg0],[arg1]) ``` we use the addresses recorded at the time of the functionlization to know for each base, which args need to have their mutations observed 1) x' output starts with x and observe mutations of arg0 2) y' output starts with and observe mutations of arg1 if we do CSE we get ``` [x', y'] = auto_functionalize( foo, x, x, _all_bases=[x,x], _observe_mutation_from=[[arg0],[arg1]) ``` 1) x' output starts with x and observe mutations of arg0 2) y' output starts with x and observe mutations of arg1 this makes inductor/test_inplacing_pass.py::TestReinplacingPassCorrectness::test_multi_output_intermediate pass. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang rec [ghstack-poisoned]
…ionalize static and not depending on addresses of new inputs passed" In the previous diff, auto_functionalize used to have the form ``` [x', y'] = auto_functionalize( foo, x, y, _arg0_base=x, _arg1_base = y, _all_bases=[x,y]) ``` Which can be read as auto_functionalize 1) has two outputs len(all_bases) 2) the first output starts with all_bases[0] `x` then it will observe the mutation from each arg that have arg_base =`x` In this case the mutation on arg0. 3) same for the second output, it will be start with y and observe the mutations of the arg1. Now one problem with the above is if a pass run and decided that y=x then we get : ``` [x', y'] = auto_functionalize( foo, x, x, _arg0_base=x, _arg1_base = x,_ all_bases=[x,x]) ``` This is problematic because 1) x' now start with x and observe both mutations of first and second arg of foo, because _arg0_base=x and _arg1_base = x which changes the semantics of the function. and make it wrong even if do not re-inplace anything! 2) same for y' it starts with x but observe both mutations. **Solution** Make the relationships between the outputs and the args that we need to observe their mutations static and not changeable. ``` [x', y'] = auto_functionalize( foo, x, y, _all_bases=[x,y], _observe_mutation_from=[[arg0],[arg1]) ``` we use the addresses recorded at the time of the functionlization to know for each base, which args need to have their mutations observed 1) x' output starts with x and observe mutations of arg0 2) y' output starts with and observe mutations of arg1 if we do CSE we get ``` [x', y'] = auto_functionalize( foo, x, x, _all_bases=[x,x], _observe_mutation_from=[[arg0],[arg1]) ``` 1) x' output starts with x and observe mutations of arg0 2) y' output starts with x and observe mutations of arg1 this makes inductor/test_inplacing_pass.py::TestReinplacingPassCorrectness::test_multi_output_intermediate pass. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang rec [ghstack-poisoned]
|
closed for the favor of #134409 |
Stack from ghstack (oldest at bottom):
In the previous diff, auto_functionalize used to have the form
Which can be read as auto_functionalize
xthen it will observe the mutation from each arg that have arg_base =xIn this case the mutation on arg0.
Now one problem with the above is if a pass run and decided that y=x
then we get :
This is problematic because
Solution
Make the relationships between the outputs and the args that we need to observe their mutations static and not
changeable.
we use the addresses recorded at the time of the functionlization to know for each base, which args need to have their mutations observed
if we do CSE we get
this makes inductor/test_inplacing_pass.py::TestReinplacingPassCorrectness::test_multi_output_intermediate
pass.
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang @rec