accumulate grads #19577
Conversation
torch/nn/parallel/distributed.py
Outdated
Do you need to skip _sync_params as well? Otherwise, params in self.modules_params[0] will be broadcast to other replicas in every forward, which is not necessary if self.accumulate_grads=True, no?
@mrshenli Right, I should skip that as well. And I should also document not calling optimizer.step() when it is set to True.
torch/nn/parallel/distributed.py
Outdated
Would I be correct to assume that optimizer.step() should not be called when accumulate_grads=True?
The common use case of this is to set it to …
soumith left a comment:
This option makes no sense to me.
You are doing a DistributedDataParallel but you are not actually synchronizing gradients across all replicas?
Then all you have is an ensemble of models, which you can do much more easily without even using DDP.
Adding some background for this PR: sometimes a batch of inputs is too large to fit into GPU memory and has to be split into several micro-batches. However, to improve efficiency, it helps to sync params/gradients only at original batch boundaries instead of at micro-batch boundaries. Hence, @chenyangyu1988 would like to control when to skip or apply the sync.
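For context, this is the plain single-process version of the pattern described above: accumulate gradients over micro-batches locally and step the optimizer once per real batch. A minimal sketch only; the toy model, sizes, and names such as `micro_batches` are illustrative and not from this PR:

```python
import torch
import torch.nn as nn

# Minimal sketch of local gradient accumulation over micro-batches
# (no DDP involved yet; all names here are illustrative).
model = nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.MSELoss()

# One "real" batch of 32 samples, split into 4 micro-batches of 8.
inputs = torch.randn(32, 10)
targets = torch.randn(32, 1)
micro_batches = list(zip(inputs.chunk(4), targets.chunk(4)))

optimizer.zero_grad()
for x, y in micro_batches:
    loss = loss_fn(model(x), y)
    # Scale so the accumulated gradient matches a single full-batch backward.
    (loss / len(micro_batches)).backward()
optimizer.step()
```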
pietern left a comment:
We should do two things before we can land:
- Remove the initializer kwarg. This setting will be off by default. The caller will flip it around every couple of iterations, so they may as well do that right at the first iteration (the code they end up writing won't be different).
- Add a function `accumulate_grads(enable: bool)`, similar to `train(enable: bool)`. Users shouldn't be poking around in member variables directly but rather use the setter function to control this. This is how we control the API surface of the class.
Summary: Pull Request resolved: pytorch#19577. Add a flag to enable accumulating grads locally; this is a critical way to increase batch size when training big models (e.g., when bound by GPU memory). This will also increase training speed because it reduces the number of all-reduce calls. Differential Revision: D15035604 fbshipit-source-id: ae115cee8e6d1a8e84e936cd8ec8a3f988906cdf
apaszke left a comment:
This is a reasonable request, but I don't think we need this feature. You can emulate multiple forward passes by creating a wrapper model which takes a list of sub-batches to forward, runs the backward on each of them separately, and only then returns. Then, you wrap only this outer model in DistributedDataParallel.
Otherwise, this definitely shouldn't be a property, and a better implementation would make it a context manager. I'm thinking of something like this:
```python
model = DistributedDataParallel(...)

def train(...):
    for batch in ...:
        with model.sync_only_on_exit():
            for subbatch in batch:
                model(subbatch).backward()
```

But really, I'm not 100% convinced that we need this.
@apaszke Could you elaborate on how to create a wrapper model which takes a list of sub-batches to forward, runs the backward on each of them separately, and only then returns? For the latest DistributedDataParallel, we need to call prepare_for_backward in its forward function. I am not sure how to do this with a context manager, because when we exit the context all forwards and backwards have already finished, and there is no easy way to kick off gradient sync with the current design.
The context wrapper approach won't be able to overlap any reduction with the final backward pass, because we only know that we want to reduce again after the call to backward has already returned. Either way, we'll need some way to toggle this behavior on the DDP module.

```python
# Alternative 1: simple but ugly and easy to screw up.
# Doesn't deal with the edge condition here.
for index, batch in enumerate(batches):
    model.accumulate_grads(index % 5 != 4)  # True for [0, 3], False for [4]
    model(batch).backward()
```

```python
# Alternative 2: wrapper (TODO: find better name)
class AccumulateGradientsAcrossBatches(nn.Module):
    def __init__(self, model, batch_size):
        super().__init__()
        self.model = model
        self.batch_size = batch_size

    def forward(self, megabatch):
        batches = torch.chunk(megabatch, self.batch_size)
        self.model.accumulate_grads(True)
        for batch in batches[0:-1]:
            self.model(batch).backward()
        self.model.accumulate_grads(False)
        return self.model(batches[-1])

model = AccumulateGradientsAcrossBatches(model, 32)
for megabatch in batches:
    model(megabatch).backward()
```
@soumith @apaszke Does @pietern's suggestion look good to you? @chenyangyu1988 If we all have consensus on that, I think you only need to change the constructor kwarg into a setter function.
@mrshenli Sounds great, I already made it a function rather than a construction-time argument :)
```python
def forward(self, *inputs, **kwargs):
    self._sync_params()
    if not self.accumulate_grads:
```
It seems to me that _sync_params needs to be called on the first mini-batch, while prepare_for_backward needs to be called on the last mini-batch. Do we need an additional arg? @pietern
@mrshenli I will make it work only for len(device_ids) == 1 to unblock us, and we can optimize it further with AccumulateGradientsAcrossBatches later.
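To make the control flow being discussed concrete, here is a simplified, self-contained sketch; it is not the actual DistributedDataParallel code. `_sync_params` and `prepare_for_backward` above are DDP internals, while the stub methods and the `accumulate_grads` flag below only illustrate where such a flag would gate them:

```python
import torch
import torch.nn as nn

class AccumulatingDDPSketch(nn.Module):
    """Illustrative stand-in, NOT the real DistributedDataParallel.

    Shows where an accumulate_grads flag would gate the parameter broadcast
    and the setup of gradient reduction for the backward pass.
    """

    def __init__(self, module):
        super().__init__()
        self.module = module
        self.accumulate_grads = False

    def _sync_params(self):
        # Placeholder: the real DDP broadcasts params/buffers to replicas here.
        pass

    def _prepare_for_backward(self, output):
        # Placeholder: the real DDP registers hooks here so grads are
        # all-reduced while backward runs.
        pass

    def forward(self, *inputs, **kwargs):
        if not self.accumulate_grads:
            self._sync_params()          # only sync at real batch boundaries
        output = self.module(*inputs, **kwargs)
        if torch.is_grad_enabled() and not self.accumulate_grads:
            self._prepare_for_backward(output)
        return output
```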
Ok, I think my suggestion with the wrapper model might not work. Still, those things absolutely shouldn't be setters, because that's just too error prone. If anything, they should be context managers like the one I proposed. An acceptable alternative which doesn't involve a context manager would be to add a parameter to the constructor which allows one to disable automatic gradient syncs, and then trigger them manually every time they are needed. So the code would look like this:

```python
model = DistributedDataParallel(..., auto_sync=False)

def train(...):
    for batch in ...:
        for subbatch in batch:
            model(subbatch).backward()
        model.sync_grads()
```
@apaszke The DDP overlaps the gradient sync with the backward call to get more performance gain, so we probably need some way to make that happen during the last backward call.
@apaszke Can you elaborate on why not a setter here, considering we have existing ones, e.g. train()? I like your proposed alternative, but we need to find an easy and robust way to incorporate the overlapping of backward with gradient reduction. Before the last call to backward we need to let DDP know it needs to reduce the grads as they are computed. This could look something like this:

```python
model = DistributedDataParallel(..., auto_sync=False)

def train(...):
    for batch in ...:
        for index, subbatch in enumerate(batch):
            output = model(subbatch)
            if index == len(batch) - 1:
                model.sync_grads(output)
            output.backward()
```

We could opt to store a copy of the model output in DDP so that you don't have to pass it back to sync_grads().
@chenyangyu1988, are you still working on this? Accumulating gradients with DDP is an important feature that can make a big speed difference when training with large batches.
@myleott @chenyangyu1988 I am picking this up.
We got multiple requests for this feature. Let's revive the discussion and try to get some consensus. I agree with @apaszke that a setter is indeed more error prone than a context manager. And I also agree with @pietern that it is not favorable to let DDP keep the outputs around until exiting the context. Given that people request this feature because they need to deal with tight memory limits, keeping the output around a little longer might even trigger OOM errors. So, how about we just provide a `no_sync()` context manager? The application would then look like:

```python
with ddp.no_sync():
    for input in inputs:
        ddp(input).backward()
ddp(one_more_input).backward()
```

@myleott @chenyangyu1988 @soumith Does this API look OK?
Yep, that works for us.
The proposed API is implemented in #21736.
Summary: The first attempt and more discussions are available in #19577

#### Goal

Allow toggling DDP gradient synchronization across iterations. With this feature, users may accumulate grads in module variables, and only kick off the expensive grad synchronization every few iterations.

#### Concerns

Our first attempt in #19577 tries to do it using a variable or a function. But apaszke made a good point that it would be error prone, and favored a context manager instead.

#### Proposed Solution

Instead of providing an `accumulate_grads` variable/function/context, we provide a `DistributedDataParallel.no_sync()` context manager. And it does exactly what the name suggests, i.e., disable DDP grad synchronization within the context. Note that `accumulate_grads` means `no_sync` + no optimizer step, where the latter is not controlled by DDP.

It is true that users need to call another `model(input).backward()` after exiting the context, and this is indeed more verbose. But I think it is OK, as one major concern in the previous discussion was to prevent users from running into errors without knowing it. This API should reaffirm the expected behavior, and does not interfere with other use cases when accumulating grads is not required.

The application would then look like:

```python
with ddp.no_sync():
    for input in inputs:
        ddp(input).backward()
ddp(one_more_input).backward()
optimizer.step()
```

chenyangyu1988 myleott
Pull Request resolved: #21736
Differential Revision: D15805215
Pulled By: mrshenli
fbshipit-source-id: 73405797d1e39965c52016af5cf45b15525ce21c
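As an aside, the mechanics of such a context manager can be very small: flip a flag on entry and restore it on exit, with forward() consulting the flag. The following is only a sketch of that idea, not the actual PyTorch implementation, and the attribute name `require_sync` is an assumption made for illustration:

```python
from contextlib import contextmanager

import torch.nn as nn

class NoSyncSketch(nn.Module):
    """Toy illustration of the no_sync() idea; NOT the actual DDP code."""

    def __init__(self, module):
        super().__init__()
        self.module = module
        self.require_sync = True  # hypothetical flag consulted by forward()

    @contextmanager
    def no_sync(self):
        old_value = self.require_sync
        self.require_sync = False
        try:
            yield
        finally:
            self.require_sync = old_value

    def forward(self, *inputs, **kwargs):
        # A real DDP would only arrange the gradient all-reduce when
        # self.require_sync is True.
        return self.module(*inputs, **kwargs)
```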
This was superseded by #21736.
Summary:
Add a flag to enable accumulating grads locally; this is a critical way to increase batch size when training big models (e.g., when bound by GPU memory).
This will also increase training speed because it reduces the number of all-reduce calls.
Differential Revision: D15035604