Don't keep unnecessary saved_inputs alive #16583
Conversation
This greatly improves memory efficiency of certain ops like Dropout2d. Previously, they were implemented as `input * mask`, where mask never requires_grad, but we didn't use that knowledge in the forward, and (in the case of an in-place dropout) kept input.clone() around for the backward, even though it would simply be ignored. This patch tries to address this by emitting guards for stores like this, but only when they are as simple as checking whether a single value requires_grad.
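As an editorial illustration of the point above (a minimal Python sketch, not taken from the PR): for `input * mask`, autograd only ever needs the mask, so saving the input is pure waste when the mask never requires grad.

```python
import torch

# d(x * mask)/dx is just `mask`, so only `mask` is needed for this backward;
# `x` would only be needed to produce a gradient for `mask`, which is never
# requested because mask.requires_grad is False.
x = torch.randn(4, 8, requires_grad=True)
mask = (torch.rand(4, 8) > 0.5).float()   # never requires grad

y = x * mask
y.sum().backward()

assert torch.equal(x.grad, mask)
```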
Idea: save |
|
This looks awesome. A before & after comparison on generated code would be great! :) |
|
Oh yes, I forgot to show an example change. Here's the case of `mul_`.

Old code:

Tensor & VariableType::mul_(Tensor & self, const Tensor & other) const {
  ...
  if (compute_requires_grad( self, other )) {
    grad_fn = std::shared_ptr<MulBackward0>(new MulBackward0(), deleteFunction);
    grad_fn->set_next_edges(collect_next_edges( self, other ));
    grad_fn->self_ = SavedVariable(self.clone(), false);
    grad_fn->other_ = SavedVariable(other, false);
  }
  ...
}

New code:

Tensor & VariableType::mul_(Tensor & self, const Tensor & other) const {
  ...
  if (compute_requires_grad( self, other )) {
    grad_fn = std::shared_ptr<MulBackward0>(new MulBackward0(), deleteFunction);
    grad_fn->set_next_edges(collect_next_edges( self, other ));
    if (grad_fn->should_compute_output(1)) {
      grad_fn->self_ = SavedVariable(self.clone(), false);
    }
    if (grad_fn->should_compute_output(0)) {
      grad_fn->other_ = SavedVariable(other, false);
    }
  }
  ...
}
|
TensorGeometry seems like a nice idea, I'll try that in a next patch! |
|
Failures are either timeouts or data loader multiprocessing tests, which I doubt are affected by this change. |
|
I'm guessing this doesn't also magically fix #15115, but I'm going to check anyway. |
|
oh, it appears it does fix #15115. |
# In the end the memory usage should remain equal, because neither of
# (x + 2) and ((x + 2) * m) should be kept alive for backward, while the
# previous allocation of z had the same size as the current one.
self.assertEqual(base_mem, end_mem)
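For context, a rough sketch (not the PR's actual test; it assumes a CUDA device and uses `torch.cuda.memory_allocated`) of the kind of liveness the quoted comment is reasoning about:

```python
import torch

x = torch.randn(1024, 1024, device='cuda', requires_grad=True)
m = (torch.rand(1024, 1024, device='cuda') > 0.5).float()  # no grad

torch.cuda.synchronize()
base_mem = torch.cuda.memory_allocated()

z = (x + 2) * m
torch.cuda.synchronize()
end_mem = torch.cuda.memory_allocated()

# With the guards from this patch, the mul only saves `m` (needed for x's
# gradient) and skips saving `(x + 2)` (only needed for m's gradient, which
# is never requested), so the growth is roughly one tensor: z itself.
# Without the guards, the saved `(x + 2)` would roughly double that figure.
print(end_mem - base_mem, 'vs one tensor =', x.numel() * x.element_size())
```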
Sure, I'll add that test, but I still think this one is robust and tests an important thing. The reason I took up this patch is that we were using on the order of gigabytes of extra memory for CNNs that used in-place Dropout2d (which is effectively `input * mask`), and that should never happen!
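To make that concrete, a rough sketch (editorial, not from the PR; it assumes a CUDA device, and exact numbers depend on the caching allocator and PyTorch version) of how the extra memory shows up around an in-place Dropout2d:

```python
import torch
import torch.nn.functional as F

base = torch.randn(64, 128, 32, 32, device='cuda', requires_grad=True)
x = base * 2                     # non-leaf, so in-place ops on it are allowed

torch.cuda.synchronize()
before = torch.cuda.memory_allocated()

F.dropout2d(x, p=0.5, training=True, inplace=True)

torch.cuda.synchronize()
after = torch.cuda.memory_allocated()

# Before this patch, the in-place multiply saved an input-sized clone of `x`
# for backward even though it is never used (only the mask is), so the delta
# was roughly the full activation size. With the patch it should shrink to
# little more than the per-channel mask.
print('bytes retained across in-place dropout:', after - before)
```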
Also, note that this patch doesn't change the unpacking behavior. So we will still run all the unpacks, except that variables that weren't saved will unpack as undefined tensors (which is fine, because they won't be used anyway).
right, I'm not saying one test or the other is better, I'm saying they test different things and should both be tested.
facebook-github-bot
left a comment
@gchanan has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
facebook-github-bot
left a comment
@gchanan has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
facebook-github-bot
left a comment
@gchanan has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
|
I don't know why there are failing tests. |
|
I don't see any tests failing. Where are those? Can you rerun them? It used to be green. |
|
I see 4 failing tests:
|
facebook-github-bot
left a comment
@gchanan has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
|
Trying rebase again. |
|
@apaszke this has consistently timed out on ROCm builds, across a few rebases. Can you take a look? |
|
How can I debug this? It adds a bit of code in |
|
@apaszke sorry, I should have said ROCm tests (the error message says build, but it's on a test). Here is the latest example: |
|
Still, do we have instructions that would let me reproduce a ROCm build? |
|
@pytorchbot retest this please |
facebook-github-bot
left a comment
@gchanan has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Summary: Fixes pytorch#16577.

This greatly improves memory efficiency of certain ops like Dropout2d. Previously, they were implemented as `input * mask`, where mask never requires_grad, but we didn't use that knowledge in the forward, and (in the case of an in-place dropout) kept input.clone() around for the backward, even though it would simply be ignored. This patch tries to address this by emitting guards for stores like this, but only when they are as simple as checking whether a single value requires_grad.

Interestingly, the same optimizations apply to methods like bmm, baddbmm, etc., but _not to mm nor addmm_, because of how their derivatives are defined. Apparently they unnecessarily use `mat1` to compute the derivative of `mat1` just to improve the error message in case `mat1` was sparse. I'd like to apply this optimization to that case too, but I don't want to lose the nicer error message, so if anyone has any ideas for solutions, please let me know...

Full list of operators affected by this patch:

* _nnpack_spatial_convolution
* addbmm
* addcdiv
* addcmul
* addmv
* addr
* baddbmm
* bmm
* cross
* div
* dot
* fmod
* ger
* index_add_
* mul
* mv
* scatter_add_

Pull Request resolved: pytorch#16583
Differential Revision: D13900881
Pulled By: gchanan
fbshipit-source-id: dd0aeb2ab58c4b6aa95b37b46d3255b3e014291c
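A small illustration (editorial, not part of the PR) of why ops like bmm can skip saving an operand: for `C = A @ B`, the gradient with respect to `A` only involves `B`, so `A` never needs to be saved when `B` does not require grad.

```python
import torch

A = torch.randn(3, 4, requires_grad=True)
B = torch.randn(4, 5)                      # requires_grad is False

C = A @ B
C.sum().backward()

# grad_A = grad_C @ B.T (here grad_C is all ones); computing it needs B but
# never A, which is why A does not have to be kept alive for backward.
assert torch.allclose(A.grad, torch.ones(3, 5) @ B.t())
```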