Re-design functionalization to minimize miss in-placing when args are views. #133045
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/133045
Note: Links to docs will display an error until the docs builds have been completed.
❌ 1 New Failure, 1 Unrelated Failure
As of commit 1505e8b with merge base 938f37b:
NEW FAILURE - The following job has failed:
BROKEN TRUNK - The following job failed but was present on the merge base:
👉 Rebase onto the `viable/strict` branch to avoid these failures
This comment was automatically generated by Dr. CI and updates every 15 minutes.
…en args are views."

Previous functionalization fails to re-inplace arguments when they are views over other tensors; see issue #131192. The new functionalization makes re-inplacing easier when views are used. It works as follows:

**A) Functionalization pass**

Consider a program:

```
t = [..]
y = view(t)
z = view(t)
foo(y, z)  // where y and z are mutable inputs to custom op foo
return (t, y, z)
```

1) When we perform functionalization, we track the base of each arg (if the arg is a view; otherwise the arg is its own base). We add those bases to a list, _all_bases=[t] in this case, and pass it to ``auto_functionalize``.
2) We also add an arg for each mutated arg that maps the arg to its base: _y_base=t and _z_base=t in the example above.
3) The new output of ``auto_functionalize`` is the args in _all_bases.
4) If the mutated input is an array of tensors, then _x_base=[b1, b2, b3] is an array of bases.
5) We inform the C++ functionalization that t has changed to t', and ask it to regenerate any view over t that it encounters later in the code from the new value t'. This is done by calling the sequence replace(t, t'), commit_update(t) and sync(t).

For example, for the function above, the program after functionalization is:

```
t = [..]
y = t[0]
z = t[1]
t' = auto_functionalize(foo, y, z, _y_base=t, _z_base=t, _all_bases=[t])
y' = t'[0]
z' = t'[1]
return (t', y', z')
```

**B) Semantics of auto_functionalize**

The new semantics of auto_functionalize are as follows:

```
1. copy all mutated inputs into new variables input'
2. call the custom op on the copies of the mutated inputs.
3. for each base in _all_bases:
       base' = base
       for each mutated input':
           if _input_base != base:
               continue
           if base' is input:
               base' = input'
           else if input is a view on base':
               base' = base'.as_strided_scatter(input', input'.size(), input'.stride(), input'.storage_offset())
4. return [base'...]  ## all the new base' values
```

**C) Re-inplace pass**

The following changes are applied to the re-inplace pass:

1. If the argument to a function is a view of another input, we do not re-inplace unless we find the copy node that reflects the mutation to the base.
2. Each argument can be re-inplaced if neither the argument nor any of its aliases is used after the auto_functionalize (before the copy node, if it exists).

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang rec
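The semantics in section B can be illustrated with a small sketch. It is not the PyTorch implementation: tensors are modeled as flat Python lists, views as 1-D (size, stride, offset) descriptors, and the names `View`, `read`, `strided_scatter` and `auto_functionalize_sketch` are hypothetical. It also assumes the mutated views do not overlap.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List


@dataclass
class View:
    """Toy 1-D stand-in for a tensor view into a base's flat storage."""
    base: str    # name of the base this view was taken from
    size: int
    stride: int
    offset: int


def read(storage: List[float], v: View) -> List[float]:
    """Copy out the elements that the view selects."""
    return [storage[v.offset + i * v.stride] for i in range(v.size)]


def strided_scatter(storage: List[float], src: List[float], v: View) -> List[float]:
    """Out-of-place write of src into the view's positions (as_strided_scatter analogue)."""
    out = list(storage)
    for i, value in enumerate(src):
        out[v.offset + i * v.stride] = value
    return out


def auto_functionalize_sketch(
    op: Callable[..., None],
    inputs: Dict[str, View],
    bases: Dict[str, List[float]],
) -> Dict[str, List[float]]:
    # 1. Copy all mutated inputs into new variables.
    copies = {name: read(bases[v.base], v) for name, v in inputs.items()}
    # 2. Call the custom op on the copies; it mutates them in place.
    op(**copies)
    # 3. Rebuild each base from the mutated copies that belong to it.
    new_bases: Dict[str, List[float]] = {}
    for base_name, storage in bases.items():
        new_storage = storage
        for name, v in inputs.items():
            if v.base != base_name:
                continue
            new_storage = strided_scatter(new_storage, copies[name], v)
        new_bases[base_name] = new_storage
    # 4. Return all the new base values.
    return new_bases


def foo(y: List[float], z: List[float]) -> None:
    """Hypothetical custom op that mutates both of its inputs."""
    for i in range(len(y)):
        y[i] += 1
    for i in range(len(z)):
        z[i] *= 10


t = [1.0, 2.0, 3.0, 4.0]
views = {"y": View("t", 2, 1, 0), "z": View("t", 2, 1, 2)}
new_bases = auto_functionalize_sketch(foo, views, {"t": t})
```

The original `t` is left untouched and the mutation is only visible in `new_bases["t"]`, which is the property the re-inplace pass later relies on when deciding whether the copy can be elided.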
There is one failure, which I fix in
      def forward(self, x):
          cos = torch.ops.aten.cos.default(x)
    -     auto_functionalized = torch.ops.higher_order.auto_functionalized(torch.ops.testlib.foo.default, x = x, z = cos); x = cos = None
    +     auto_functionalized = torch.ops.higher_order.auto_functionalized(torch.ops.testlib.foo.default, x = x, z = cos, _x_base = x, _z_base = cos, _all_bases = [x, cos]); x = cos = None
Why do we need _x_base, _z_base, and _all_bases? Aren't _x_base and _z_base enough?
I am changing it to
    [x', y'] = auto_functionalize(foo, x, y, _all_bases=[x, y], _observe_mutation_from=[[arg0], [arg1]])
see https://github.com//pull/134315
Initially I extracted the relationships dynamically.
Ah, back to your question: _all_bases removes the repetitions and specifies the order of the outputs.
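As a rough illustration of that answer, the sketch below deduplicates bases while keeping first-seen order, then maps each mutated argument to the position of its base in the output list. The helper `collect_bases` and the string names standing in for tensors are hypothetical, not part of the PR.

```python
from typing import Dict, List


def collect_bases(arg_to_base: Dict[str, str]) -> List[str]:
    """Deduplicate bases, keeping the order in which they are first seen."""
    all_bases: List[str] = []
    for base in arg_to_base.values():
        if base not in all_bases:
            all_bases.append(base)
    return all_bases


# y and z are views of the same base t; w is a view of u.
arg_to_base = {"y": "t", "z": "t", "w": "u"}
all_bases = collect_bases(arg_to_base)

# Position of each mutated argument's base in the list of outputs.
output_index = {arg: all_bases.index(base) for arg, base in arg_to_base.items()}
```

With only the per-argument base kwargs, a consumer would have to rediscover that `y` and `z` share a base and agree on an output order; an explicit list makes both fixed.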
    def __call__(
        self,
        /,
        self_,  # noqa: B902
Why the change?
Local linter, I guess. I will revert it.
        _only_clone_these_tensors: Optional[Tuple[str, ...]] = None,
        **kwargs: Any,
    ) -> Tuple[Any, Tuple[Tensor, ...]]:
        _all_bases: List[Tensor] = kwargs.pop("_all_bases", [])
I just realized we need to handle the following case in this PR: if the base has no view relationship with the tensor (maybe due to a graph pass), then we don't know what the right thing to do is. How do we decide what the right arguments to as_strided_scatter are?
So, in addition to recording x_base for each x, we also need to either:
- record x.stride() and x.storage_offset().
- record the view chain that produced x from x_base.
The former is easier so I'd recommend that.
This needs to happen in this PR (sorry I didn't realize earlier)
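To make the first option concrete, here is a minimal sketch, assuming 1-D flat lists in place of tensors. `ViewMeta` and `scatter_with_meta` are hypothetical names. The point is that once size, stride and storage offset are recorded up front, the scatter back into the base does not depend on the mutated value still being a view of that base.

```python
from typing import List, NamedTuple


class ViewMeta(NamedTuple):
    """Layout recorded at functionalization time (1-D toy version)."""
    size: int
    stride: int
    storage_offset: int


def scatter_with_meta(base: List[int], src: List[int], meta: ViewMeta) -> List[int]:
    """Write src into a copy of base at the positions described by meta."""
    assert len(src) == meta.size
    out = list(base)
    for i, value in enumerate(src):
        out[meta.storage_offset + i * meta.stride] = value
    return out


base = [0, 0, 0, 0, 0, 0]
# Recorded while x was still known to be a view of base (every other element from 1).
x_meta = ViewMeta(size=3, stride=2, storage_offset=1)

# Later, x is a standalone value with no view relationship to base.
x_mutated = [7, 8, 9]
new_base = scatter_with_meta(base, x_mutated, x_meta)
```

Recording the full view chain would carry more information, but the three integers are enough for an as_strided-style scatter in this toy setting.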
No worries. Let's chat about this, and about the change I made to handle test_multi_output_intermediate, when you have time. I would like to understand this comment better.
Under this design we wouldn't need to pass x anymore (it doesn't matter what's in x!). We can pass (x_sizes, x_strides, x_storage_offset) instead of x.
Also, the semantics of auto_functionalized become:
    x_base = maybe_clone(x_base)  # if we can't reinplace it
    x = x_base.as_strided(x_sizes, x_strides, x_storage_offset)
    mutable_op(x)
But we also want to somehow not emit an as_strided when it's unnecessary.
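The three steps proposed in this comment could be sketched as follows. This is a toy model under stated assumptions, not the real operator: `StridedView` is a hypothetical 1-D write-through view over a flat list standing in for `as_strided`, and the `can_reinplace` flag stands in for whatever the re-inplace pass decides.

```python
from typing import Callable, List


class StridedView:
    """Toy 1-D write-through view over a flat list."""

    def __init__(self, storage: List[int], size: int, stride: int, storage_offset: int):
        self.storage = storage
        self.size = size
        self.stride = stride
        self.storage_offset = storage_offset

    def __len__(self) -> int:
        return self.size

    def __getitem__(self, i: int) -> int:
        return self.storage[self.storage_offset + i * self.stride]

    def __setitem__(self, i: int, value: int) -> None:
        self.storage[self.storage_offset + i * self.stride] = value


def auto_functionalized_sketch(
    mutable_op: Callable[[StridedView], None],
    x_base: List[int],
    x_size: int,
    x_stride: int,
    x_storage_offset: int,
    can_reinplace: bool,
) -> List[int]:
    # maybe_clone: keep the base only when mutating it in place is known to be safe.
    x_base = x_base if can_reinplace else list(x_base)
    # Regenerate the view from the recorded layout; the old x is not needed.
    x = StridedView(x_base, x_size, x_stride, x_storage_offset)
    mutable_op(x)
    return x_base


def add_one(x: StridedView) -> None:
    """Hypothetical mutable op: increments every element of its input."""
    for i in range(len(x)):
        x[i] = x[i] + 1


original = [0, 0, 0, 0]
out = auto_functionalized_sketch(add_one, original, 2, 2, 0, can_reinplace=False)
snapshot_after_clone = list(original)  # the cloned call left original untouched

reused = auto_functionalized_sketch(add_one, original, 2, 2, 0, can_reinplace=True)
```

In the cloned call the result is a fresh list; in the re-inplaced call the returned object is the original base itself, mutated through the regenerated view.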
zou3519 left a comment
I'll be back to review the rest, but we might need to change how the auto_functionalized node works (+ had some comments to simplify it).
            basis[arg_name].append(None)
            continue

        base = tensor if tensor._base is None else tensor._base
We should check if tensor._base is available under torch.inference_mode(). If it's not then... we're gonna need to add more infra to keep track of bases.
    with torch.inference_mode():
        x = torch.randn(3)
        f(x)
…Inline and add options to skip comments and to skip empty lines"

I had a nightmare rewriting tests in test_misc.py, specifically:
1. Graphs can have comments that refer to my files ("/lsakka/.."). We really don't care about comments, so add an option to ignore comments.
2. Empty lines added when EXPECTTEST_ACCEPT=1 are changed by the linter, causing the tests or the linter to fail. Add a flag to ignore empty lines.
3. EXPECTTEST_ACCEPT fails when the text has some non-readable characters. Those should not affect string comparison, and they also cause weird diffs when tests fail. I removed ANSI escape characters.

#133045 this is used in
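A rough sketch of the kind of normalization this commit message describes, applied to text before it is compared. The function `normalize`, its parameter names, and the whole-line `#` comment rule are assumptions for illustration and are not the expecttest API.

```python
import re

# Matches CSI-style ANSI escape sequences such as "\x1b[31m".
ANSI_ESCAPE = re.compile(r"\x1b\[[0-9;?]*[ -/]*[@-~]")


def normalize(text: str, skip_comments: bool = False, skip_empty_lines: bool = False) -> str:
    """Strip ANSI escapes, then optionally drop comment-only and empty lines."""
    text = ANSI_ESCAPE.sub("", text)
    kept = []
    for line in text.splitlines():
        if skip_comments and line.lstrip().startswith("#"):
            continue
        if skip_empty_lines and not line.strip():
            continue
        kept.append(line)
    return "\n".join(kept)


actual = (
    "\x1b[31mdef forward(self, x):\x1b[0m\n"
    "    # File: /some/path.py:3\n"
    "\n"
    "    return x\n"
)
expected = "def forward(self, x):\n    return x"
normalized = normalize(actual, skip_comments=True, skip_empty_lines=True)
```

Normalizing both sides the same way keeps machine-specific paths in comments and linter-driven blank-line changes from producing spurious diffs.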
Closed in favor of #134409.
…kip comments and to skip empty lines (#134248)
Pull Request resolved: #134248
Approved by: https://github.com/aorenste
ghstack dependencies: #133639, #134364