Skip to content

Conversation

@laithsakka
Copy link
Contributor

@laithsakka laithsakka commented Aug 23, 2024

Stack from ghstack (oldest at bottom):

In the previous diff, auto_functionalize used to have the form

[x', y'] = auto_functionalize( foo, x, y, _arg0_base=x, _arg1_base = y, _all_bases=[x,y])

Which can be read as auto_functionalize

  1. has two outputs len(all_bases)
  2. the first output starts with all_bases[0] x then it will observe the mutation from each arg that have arg_base =x
    In this case the mutation on arg0.
  3. same for the second output, it will be start with y and observe the mutations of the arg1.

Now one problem with the above is if a pass run and decided that y=x
then we get :

[x', y'] = auto_functionalize( foo, x, x, _arg0_base=x, _arg1_base = x,_ all_bases=[x,x])

This is problematic because

  1. x' now start with x and observe both mutations of first and second arg of foo, because _arg0_base=x and _arg1_base = x which changes the semantics of the function. and make it wrong even if do not re-inplace anything!
  2. same for y' it starts with x but observe both mutations.

Solution
Make the relationships between the outputs and the args that we need to observe their mutations static and not
changeable.

[x', y'] = auto_functionalize( foo, x, y, _all_bases=[x,y], _observe_mutation_from=[[arg0],[arg1])

we use the addresses recorded at the time of the functionlization to know for each base, which args need to have their mutations observed

  1. x' output starts with x and observe mutations of arg0
  2. y' output starts with and observe mutations of arg1

if we do CSE we get

[x', y'] = auto_functionalize( foo, x, x, _all_bases=[x,x], _observe_mutation_from=[[arg0],[arg1])
  1. x' output starts with x and observe mutations of arg0
  2. y' output starts with x and observe mutations of arg1

this makes inductor/test_inplacing_pass.py::TestReinplacingPassCorrectness::test_multi_output_intermediate
pass.

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang @rec

…atic and not depending on new inputs passed

[ghstack-poisoned]
@pytorch-bot
Copy link

pytorch-bot bot commented Aug 23, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/134315

Note: Links to docs will display an error until the docs builds have been completed.

❌ 35 New Failures, 1 Unrelated Failure

As of commit eb59152 with merge base 938f37b (image):

NEW FAILURES - The following jobs have failed:

BROKEN TRUNK - The following job failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

laithsakka added a commit that referenced this pull request Aug 23, 2024
…atic and not depending on new inputs passed

ghstack-source-id: 4e96e0e
Pull Request resolved: #134315
@laithsakka laithsakka requested a review from zou3519 August 23, 2024 06:38
@laithsakka laithsakka changed the title make relationship between outputs and inputs of auto_functionalize static and not depending on new inputs passed Make relationship between outputs and inputs of auto_functionalize static and not depending on new inputs passed Aug 23, 2024
@laithsakka laithsakka marked this pull request as draft August 23, 2024 14:20
@laithsakka
Copy link
Contributor Author

ah
I will change
[x', y'] = auto_functionalize( foo, x, y, _arg0_base_adr=1, _arg1_base_adrs = 2, _all_bases=[x,y], _all_bases_adrs=[1,2])
to
[x', y'] = auto_functionalize( foo, x, y, _all_bases=[x,y], observe mutation_from=[[arg0],[arg1]])

why :

  1. more readable
  2. easier for unit tests since addresses are not the same across runs .

…ionalize static and not depending on new inputs passed"


In the previous diff, auto_functionalize used to have the form
```
[x', y'] = auto_functionalize( foo, x, y, _arg0_base=x, _arg1_base = y, _all_bases=[x,y])
```
Which can be read as auto_functionalize
1)  has two outputs len(all_bases)
2) the first output starts with all_bases[0] `x` then it will observe the mutations each arg that have arg_base =`x`
In this case the mutation on x. (which is read either from copy(x) that is passed to foo).  
3) same for the second output, it will be start with y and observe the mutations of the arg y. 


Now one problem with the above is if a pass run and decided that y=x
then we get :
```
[x', y'] = auto_functionalize( foo, x, x, _arg0_base=x, _arg1_base = x,_ all_bases=[x,x])
```
This is problematic because 
1) x' now start with x and observe both mutations of first and second arg of foo, because _arg0_base=x and _arg1_base = x which changes the semantics of the function. and make it wrong even if do no re-inplace anything!
2) same for y' not it starts with x but observe both mutations. 

**Solution**
Make the relationships between the outputs and the args that we need to observe their mutations static and not
changeable. 
```
[x', y'] = auto_functionalize( foo, x, y, _all_bases=[x,y], _observe_mutation_from=[[arg0],[arg1])
```
we use the addresses recorded at the time of the functionlization to know for each base, which args need to have their mutations observed 
1) x' output starts with x and observe mutations of arg0
2) y' output starts with  and observe mutations of arg1

if we do CSE we get 
```
[x', y'] = auto_functionalize( foo, x, x, _all_bases=[x,x], _observe_mutation_from=[[arg0],[arg1])
```
1) x' output starts with x and observe mutations of arg0
2) y' output starts with x and observe mutations of arg1

this makes inductor/test_inplacing_pass.py::TestReinplacingPassCorrectness::test_multi_output_intermediate
 pass. 

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang rec

[ghstack-poisoned]
laithsakka added a commit that referenced this pull request Aug 23, 2024
…atic and not depending on new inputs passed

ghstack-source-id: 8fd2c38
Pull Request resolved: #134315
@laithsakka laithsakka marked this pull request as ready for review August 23, 2024 15:26
…ionalize static and not depending on new inputs passed"


In the previous diff, auto_functionalize used to have the form
```
[x', y'] = auto_functionalize( foo, x, y, _arg0_base=x, _arg1_base = y, _all_bases=[x,y])
```
Which can be read as auto_functionalize
1)  has two outputs len(all_bases)
2) the first output starts with all_bases[0] `x` then it will observe the mutation from each arg that have arg_base =`x`
In this case the mutation on arg0.
3) same for the second output, it will be start with y and observe the mutations of the arg1. 


Now one problem with the above is if a pass run and decided that y=x
then we get :
```
[x', y'] = auto_functionalize( foo, x, x, _arg0_base=x, _arg1_base = x,_ all_bases=[x,x])
```
This is problematic because 
1) x' now start with x and observe both mutations of first and second arg of foo, because _arg0_base=x and _arg1_base = x which changes the semantics of the function. and make it wrong even if do not re-inplace anything!
2) same for y' it starts with x but observe both mutations. 

**Solution**
Make the relationships between the outputs and the args that we need to observe their mutations static and not
changeable. 
```
[x', y'] = auto_functionalize( foo, x, y, _all_bases=[x,y], _observe_mutation_from=[[arg0],[arg1])
```
we use the addresses recorded at the time of the functionlization to know for each base, which args need to have their mutations observed 
1) x' output starts with x and observe mutations of arg0
2) y' output starts with  and observe mutations of arg1

if we do CSE we get 
```
[x', y'] = auto_functionalize( foo, x, x, _all_bases=[x,x], _observe_mutation_from=[[arg0],[arg1])
```
1) x' output starts with x and observe mutations of arg0
2) y' output starts with x and observe mutations of arg1

this makes inductor/test_inplacing_pass.py::TestReinplacingPassCorrectness::test_multi_output_intermediate
 pass. 

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang rec

[ghstack-poisoned]
laithsakka added a commit that referenced this pull request Aug 23, 2024
…atic and not depending on new inputs passed

ghstack-source-id: d6542c0
Pull Request resolved: #134315
@laithsakka laithsakka changed the title Make relationship between outputs and inputs of auto_functionalize static and not depending on new inputs passed Make relationship between outputs and inputs of auto_functionalize static and not depending on addresses of new inputs passed Aug 23, 2024
…ionalize static and not depending on addresses of new inputs passed"


In the previous diff, auto_functionalize used to have the form
```
[x', y'] = auto_functionalize( foo, x, y, _arg0_base=x, _arg1_base = y, _all_bases=[x,y])
```
Which can be read as auto_functionalize
1)  has two outputs len(all_bases)
2) the first output starts with all_bases[0] `x` then it will observe the mutation from each arg that have arg_base =`x`
In this case the mutation on arg0.
3) same for the second output, it will be start with y and observe the mutations of the arg1. 


Now one problem with the above is if a pass run and decided that y=x
then we get :
```
[x', y'] = auto_functionalize( foo, x, x, _arg0_base=x, _arg1_base = x,_ all_bases=[x,x])
```
This is problematic because 
1) x' now start with x and observe both mutations of first and second arg of foo, because _arg0_base=x and _arg1_base = x which changes the semantics of the function. and make it wrong even if do not re-inplace anything!
2) same for y' it starts with x but observe both mutations. 

**Solution**
Make the relationships between the outputs and the args that we need to observe their mutations static and not
changeable. 
```
[x', y'] = auto_functionalize( foo, x, y, _all_bases=[x,y], _observe_mutation_from=[[arg0],[arg1])
```
we use the addresses recorded at the time of the functionlization to know for each base, which args need to have their mutations observed 
1) x' output starts with x and observe mutations of arg0
2) y' output starts with  and observe mutations of arg1

if we do CSE we get 
```
[x', y'] = auto_functionalize( foo, x, x, _all_bases=[x,x], _observe_mutation_from=[[arg0],[arg1])
```
1) x' output starts with x and observe mutations of arg0
2) y' output starts with x and observe mutations of arg1

this makes inductor/test_inplacing_pass.py::TestReinplacingPassCorrectness::test_multi_output_intermediate
 pass. 

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang rec

[ghstack-poisoned]
laithsakka added a commit that referenced this pull request Aug 25, 2024
…atic and not depending on new inputs passed

ghstack-source-id: de3c24e
Pull Request resolved: #134315
@laithsakka
Copy link
Contributor Author

closed for the favor of #134409

@laithsakka laithsakka closed this Aug 26, 2024
@github-actions github-actions bot deleted the gh/laithsakka/49/head branch October 1, 2024 02:14
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant