FX pass to move input mutations into submodule #82602
Conversation
It would be really nice to see a before/after print.
# these mutations into an opaque submodule
# so our graph infra can assume a functional graph.
if config.use_functionalize:
    move_input_mutations_into_submodule(fw_module)
So if you do this after, partition_fn must be able to deal with mutations in fx_g. Can it, @Chillee ?
I re-remembered that it can't, haha, so I'll have to change this. The min-cut partition function code calls eliminate_dead_code(), so we want the copy_() ops to be hidden away before then.
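Roughly the concern, as a toy torch.fx example (illustrative only; whether DCE actually drops the node depends on the FX version's purity checks):

```python
import torch
import torch.fx

def f(x, buf):
    y = x.sin()
    # The result of copy_ is never used, so a DCE pass that only looks at
    # users (and doesn't treat copy_ as impure) can delete the mutation.
    buf.copy_(y)
    return y.cos()

gm = torch.fx.symbolic_trace(f)
print(gm.graph)                  # contains a call_method 'copy_' node with no users
gm.graph.eliminate_dead_code()   # may erase the copy_ node, depending on the FX version's purity checks
gm.recompile()
print(gm.graph)
```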
module_node = fx_g.graph.call_module(
    submodule_name,
    args=tuple(node_to_placeholder.keys()),
    kwargs=None)
I know we debated possible representations a bit, but this particular rep is actually not one that I had been thinking about. Having the mutation be in a submodule in the original graph is not great, because it still means the outer graph is not functional! (If you call a mutating function inside a graph, that makes your graph mutating.) I feel like there's an obligation for this submodule call to (somehow) not live in the Graph itself, since living in the Graph makes it vulnerable to DCE again.
Hmmmm I see. The alternative is probably just to leave it out of the fx.Graph (but still keep it in the GraphModule), and then manually call the submodule later on in AOTAutograd, after the compiled forward gets executed? That doesn't seem too bad.
I guess my thought was that anything that operates on the fx.Graph (DCE + compilers) would just see a custom submodule and know to treat it as an opaque object. But that's probably not totally right - we can't really enforce that a graph pass won't move the submodule around earlier in the graph (which would be wrong).
Although, if I want to be able to invoke the submodule separately outside of the graph, I'll need to somehow get the inputs for the submodule - probably by updating the original graph to make them additional outputs (rough sketch below). Extra complexity (since I'll also need to remove the pytree stuff in the graph to do that), but it doesn't seem too bad.
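Something along these lines, maybe (a sketch with made-up names, assuming the pytree flattening has already been stripped so the output node holds a flat tuple):

```python
import torch.fx

def add_mutated_inputs_as_outputs(fx_g: torch.fx.GraphModule, updated_input_nodes):
    # Find the (single) output node and append the "updated input" values to
    # what the graph returns, so an epilogue outside the graph can copy them
    # back into the real inputs after the compiled function runs.
    output_node = next(n for n in fx_g.graph.nodes if n.op == "output")
    old_outputs = output_node.args[0]
    if not isinstance(old_outputs, (tuple, list)):
        old_outputs = (old_outputs,)
    output_node.args = (tuple(old_outputs) + tuple(updated_input_nodes),)
    fx_g.recompile()
    return fx_g
```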
You can always write passes so they operate correctly in the presence of mutation, but we don't want to force that on pass writers. Opaque submodules at final lowering are unlikely to be subject to heavy optimization, so it's easier to deal with arbitrary side effects there. But even, e.g., finding fusion groups is currently not correct in the presence of mutation (though mostly this is because of DCE calls). The problem here is that you're adding this module all the way at the beginning, before all of the optimization passes, so you're forcing them to, at the very least, know not to remove your module.
Extra outputs is what I would expect to see if it's external.
You could also have it not actually be external, but stored on the GraphModule and not actually part of the graph. But passes would need to know to propagate it.
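Concretely, something like this (toy sketch; the attribute name is made up):

```python
import torch
import torch.nn as nn
import torch.fx

class InputMutationEpilogue(nn.Module):
    def forward(self, inpt, updated):
        inpt.copy_(updated)

def f(x):
    return x.sin()

gm = torch.fx.symbolic_trace(f)
# Stored as a plain attribute on the GraphModule: there's no call_module node
# for it in gm.graph, so DCE and other graph passes never see it - but they
# also won't know to propagate it unless they're taught to.
gm.input_mutation_epilogue = InputMutationEpilogue()

x = torch.randn(3)
out = gm(x)
gm.input_mutation_epilogue(x, out)  # the caller invokes it manually after the graph runs
```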
hmm yeah, having it actually be external feels better. I'll try that.
test/test_functionalization.py
input_clone = inpt.clone()
input_clone2 = inpt.clone()
input_clone3 = inpt.clone()
input_clone4 = inpt.clone()
The changes in this file are just minor QoL changes
Ok, I beefed up this PR (so mutations get taken out before partitioning) and added better testing. @Chillee lmk what you think of the AOTAutograd changes. I also dumped my tests in

I think what we probably want is that once this lands, we should turn this + re-inplacing on by default in the benchmark testing on torchbench + timm + hugging face (cc @anijain2305). I'm still working on getting that set up so I can actually run it locally and confirm that nothing breaks.
It looks like this pass isn't playing well with the partitioning code - it fails
I'm pushing some more on this based on #85036, since adding an epilogue to AOTAutograd should unblock a few models that were previously hitting dynamo's fallback. This isn't ready for review yet though - waiting to sanity check that some tests pass first.
@Chillee @ezyang this should be ready for another round. Some high level questions/notes:
(a) What do you all think of the partitioner calling convention changes?
(b) Should functionalization (and the same epilogue infra) be refactored to also run in the aot_dispatch_base() case? It looks like functionalization doesn't even run there today (I guess I'm surprised that it isn't breaking anything?)
(c) There are a bunch of test failures that are due to fake tensors not being turned on (we need to know the requires_grad-ness of the outputs to know what to mark properly in the
I'm deferring to @Chillee for this.
mutated_input_args = [x for pair in zip(original_inputs_needing_mutation, mutated_inputs) for x in pair]
# TODO: this epilogue should also be responsible for generating outputs
# that are aliases of inputs.
input_mutation_epilogue(*mutated_input_args)
So after the pass, all mutation would be in opaque mutation epilogue and backends lose visibility there.
We are missing out on fusion opportunities here.
Would we be able to opt in to inlining the mutation epilogue back into the main graph?
> So after the pass, all mutation would be in opaque mutation epilogue
TBC, this is only for captured graphs with input mutations. Intermediate mutations in a graph would have already been removed by functionalization.
The idea in this PR as it stands is that instead of the backend seeing an operator like:
native_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor(a!) running_mean, Tensor(a!) running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)
It'll see a purely functional version, that returns the updated inputs instead of mutating them directly:
native_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor running_mean, Tensor running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor, Tensor, Tensor)
And if mutated inputs to that operator happen to correspond to graph inputs (which is true for running_mean/running_var variables for batch norm), then at the end of the graph we'll have an epilogue that copies the "updated inputs" back to the original inputs:
def compiled_fn(inpt1, inpt2):
# ... first run the entirely functional compiled function
outs, mutated_inpt1, mutated_inpt2 = real_compiled_fn(inpt1, inpt2)
inpt1.copy_(mutated_inpt1)
inpt2.copy_(mutated_inpt2)
return outs
> Would we be able to opt-in to inline the mutation epilogue back into the main graph?
I remember @Chillee bringing this up before - we probably can? Although for now, this PR doesn't do that and just ensures that we do the "correct" thing first.
It's also not clear to me - why would this prevent fusions? The main disadvantage as I see it is that you can end up using more memory - you have to keep the buffer for both the original input and the updated input around while the compiled function is running.
Thanks for the explanation, I think we are mostly on the same page here. I'm mostly asking about mutation on inputs, since that's what normalization layers use for their running-stats updates.
> It's also not clear to me - why would this prevent fusions?
It's the epilogue in-place copies that we're missing out on, since those are cheap and easy to handle in normalization kernels. I.e., if we keep them in the epilogue, they won't be visible to the fuser backend.
wconstab left a comment:
Looks correct to me. Thanks @bdhirsh
I'm actually going to close this PR and create a fresh one. Thanks for the stamp, and sorry about that Will. This is mostly because:
(1) Ed's "trace with functionalization in one pass" PR has landed and changed how we want this PR to work pretty substantially - we no longer have to back
(2) There are a bunch of other edge cases that are worth thinking about more holistically. Ed has a great doc on them here: https://docs.google.com/document/d/19UoIh_SVrMy_b2Sx5ZaeOJttm6P0Qmyss2rdBuyfoic/edit?usp=sharing
When functionalization is turned on in AOT Autograd, we want to hide input mutations in the graph so that the backend compiler doesn't need to worry about seeing `copy_()` ops in the graph. This PR does that by hiding it in an opaque submodule.

Right now this logic happens after the partitioning, and we're relying on partitioning to always leave the `copy_()` nodes in the forward graph (which... probably needs some more testing, but I think is fine?).

I added light testing for this pass by including it in the existing `test_functionalization.py` tests, but I'm planning to try hooking this into the torchbench suite, which will let us get rid of this code: https://github.com/pytorch/torchdynamo/blob/5040d49795dde35f0112e27a6744015d44318deb/torchdynamo/optimizations/training.py#L59
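Roughly, the intended before/after on the forward graph (a hand-written illustration with made-up names, not the exact graphs this pass produces):

```python
import torch

# BEFORE the pass: the functionalized forward that the partitioner/backend
# would otherwise see still ends with a copy_() into a graph input.
def forward_before(primals_1, primals_2):
    sin = torch.sin(primals_2)
    add = torch.add(primals_1, sin)
    primals_1.copy_(add)          # input mutation, visible to the backend compiler
    return (sin,)

# AFTER the pass: the copy_() is tucked away in an opaque submodule, so the
# main graph the backend compiler sees is purely functional.
class InputMutationSubmodule(torch.nn.Module):
    def forward(self, primals_1, add):
        primals_1.copy_(add)

class ForwardAfter(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.input_mutations = InputMutationSubmodule()

    def forward(self, primals_1, primals_2):
        sin = torch.sin(primals_2)
        add = torch.add(primals_1, sin)
        self.input_mutations(primals_1, add)  # opaque call; everything else is functional
        return (sin,)
```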