[invoke_subgraph] Support None in the fwd output #150082
[ghstack-poisoned]
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/150082
Note: Links to docs will display an error until the docs builds have been completed.
✅ You can merge normally! (2 Unrelated Failures)
As of commit df4afb4 with merge base 15dbad2:
UNSTABLE - The following jobs are marked as unstable, possibly due to flakiness on trunk:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
    def test_return_none_from_fwd(self):
        @mark_compile_region
        def gn(x):
            return x * 2, None, x * 3

        def fn(x):
            ys = gn(x)
            return ys[0] + ys[2]

        opt_fn = torch.compile(fn, backend="inductor", fullgraph=True)
Can we see an expecttest for the graph?
Added expecttests for both the Dynamo graph and the AOT forward and backward graphs.
A subgraph returning None is totally ok with Inductor. I worked on plumbing that last year. Here is the output code from Inductor for the test case - https://www.internalfb.com/phabricator/paste/view/P1768286550 The approach here -
    # force the grad_outs to be contiguous. Some of the grads can be None,
    # because the forward outs could be None. Filter them out.
    contiguous_grad_outs = []
    for o in grad_outs:
        if o is not None:
            contiguous_grad_outs.append(o.contiguous())
    contiguous_grad_outs = tuple(contiguous_grad_outs)
Can you assert that the only None grad_outs are the ones where the forward was None? Because this happens during tracing of the outer graph, this won't increase runtime.
The code here is only correct if we do not change ctx.set_materialize_grads (default True). In the future, I expect we'll want to set it to False to improve performance, which will lead to the following correctness issue. The assertion will help us catch these issues when we flip it.
- An output being None does imply that the grad_out is None.
- However, at trace time of the outer forward+backward, if ctx.set_materialize_grads=False it is possible that an out to invoke_subgraph is a Tensor but a grad_out is None. This happens if the gradient for that Tensor was never computed or if autograd optimized it away (because it thought it was zero).
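For illustration, the requested assertion could look roughly like the sketch below. This is not the code that landed in the PR: `filter_grad_outs` and `FakeTensor` are hypothetical names introduced here, and `FakeTensor` is a plain-Python stand-in for a tensor so the example runs without PyTorch. The idea is to pair each `grad_out` with its forward output and allow a `None` gradient only where the forward output was also `None`.

```python
class FakeTensor:
    """Hypothetical stand-in for torch.Tensor; only .contiguous() is needed here."""

    def __init__(self, name):
        self.name = name

    def contiguous(self):
        # A real tensor may return a contiguous copy; this sketch returns itself.
        return self


def filter_grad_outs(fwd_outs, grad_outs):
    """Hypothetical helper: drop None grads, but only where the forward output was None."""
    assert len(fwd_outs) == len(grad_outs)
    contiguous_grad_outs = []
    for out, grad in zip(fwd_outs, grad_outs):
        if grad is None:
            # A None grad for a tensor output would mean autograd skipped
            # materializing it, which the filtering logic does not handle.
            assert out is None, "grad_out is None but the forward output is not"
            continue
        contiguous_grad_outs.append(grad.contiguous())
    return tuple(contiguous_grad_outs)


# Forward returned (tensor, None, tensor), matching the test case in this PR.
y0, y2 = FakeTensor("x * 2"), FakeTensor("x * 3")
g0, g2 = FakeTensor("grad of x * 2"), FakeTensor("grad of x * 3")
kept = filter_grad_outs((y0, None, y2), (g0, None, g2))
```

Because the check runs while the outer graph is being traced, it would add no cost at runtime, as noted in the comment above.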
zou3519 left a comment
LGTM, but added a comment for an extra assertion we should add
Starting merge as part of PR stack under #150450
…ass (#150450) Pull Request resolved: #150450 Approved by: https://github.com/zou3519 ghstack dependencies: #150082
…alse (#150486) I am not sure if this is the right way. Pull Request resolved: #150486 Approved by: https://github.com/zou3519 ghstack dependencies: #150082, #150450
…t module (#150556) Pull Request resolved: #150556 Approved by: https://github.com/bdhirsh, https://github.com/zou3519 ghstack dependencies: #150082, #150450, #150486
…150561) I am unable to come up with a testcase. It passes many end-to-end tests that fail with ReshapeError at https://ossci-raw-job-status.s3.amazonaws.com/log/39717218372  Pull Request resolved: #150561 Approved by: https://github.com/zou3519, https://github.com/bdhirsh ghstack dependencies: #150082, #150450, #150486, #150556
ghstack-source-id: abb58a3 Pull Request resolved: pytorch/pytorch#150082
Stack from ghstack (oldest at bottom):