
Conversation

@chenyangyu1988
Contributor

Summary:
Add a flag to enable accumulating grads locally. This is a critical way to increase the effective batch size when training big models (e.g., when bound by GPU memory).
This will also increase training speed because it reduces the number of all-reduce calls.

Differential Revision: D15035604
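Roughly, the intended usage of the flag looks like the sketch below (names follow the diff under review and are still subject to change; `net`, `local_rank`, `micro_batches`, and `optimizer` are placeholders):

```python
# Rough usage sketch of the proposed flag; the API shape is still under review.
model = torch.nn.parallel.DistributedDataParallel(net, device_ids=[local_rank])

for i, micro_batch in enumerate(micro_batches):
    # Accumulate grads locally on all but the last micro-batch, so the
    # all-reduce only happens once per full batch.
    model.accumulate_grads = (i != len(micro_batches) - 1)
    model(micro_batch).backward()

optimizer.step()
optimizer.zero_grad()
```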

@pytorchbot added the module: nn (Related to torch.nn) label on Apr 22, 2019
Contributor


Do you need to skip _sync_params as well? Otherwise, params in self.modules_params[0] will be broadcast to other replicas in every forward, which is not necessary if self.accumulate_grads=True, no?

Contributor Author


@mrshenli Right, I should skip that as well. I should also document that optimizer.step should not be called while this is set to True.

Contributor


Would I be correct if I assume that optimizer.step() should not be called when accumulate_grads=True?

@pietern
Contributor

pietern commented Apr 22, 2019

The common use case is to set it to False every couple of iterations, so I think accepting a value in the constructor is not that useful. Instead, we can assume the default is False, and that users will have to call a function model.accumulate_grads(enable: bool) every iteration to control what to do.
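For example, something along these lines (`accumulation_steps`, `batches`, `model`, and `optimizer` are placeholders):

```python
# Sketch of the per-iteration toggle suggested above.
accumulation_steps = 4
for i, batch in enumerate(batches):
    # Accumulate on all but the last micro-batch of each group; sync on the last.
    model.accumulate_grads(i % accumulation_steps != accumulation_steps - 1)
    model(batch).backward()
    if i % accumulation_steps == accumulation_steps - 1:
        optimizer.step()
        optimizer.zero_grad()
```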

Contributor

@soumith left a comment


this option makes no sense to me.
You are doing a DistributedDataParallel but you are not actually synchronizing gradients across all replicas?
Then all you have is an ensemble of models, which you can do much easier without even using DDP.

@mrshenli
Contributor

Adding some background for this PR: sometimes a batch of inputs is too large to fit into GPU memory, so it has to be split into several micro-batches. However, to improve efficiency, it would be helpful to apply params/gradients sync only at the original batch boundaries instead of at micro-batch boundaries. Hence, @chenyangyu1988 would like to control when to skip/apply sync.
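For reference, single-process gradient accumulation already works by just splitting the batch and delaying the optimizer step, as in the toy sketch below (all names are placeholders); the problem is that under DDP every one of those backward() calls would also trigger an all-reduce.

```python
import torch

# Toy single-process gradient accumulation: split an oversized batch into
# micro-batches, accumulate grads locally, and step once per full batch.
model = torch.nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

big_batch = torch.randn(64, 10)            # pretend this is too big for one pass
micro_batches = torch.chunk(big_batch, 4)  # 4 micro-batches of 16

optimizer.zero_grad()
for micro_batch in micro_batches:
    loss = model(micro_batch).sum()
    loss.backward()                        # grads accumulate in param.grad
optimizer.step()                           # one optimizer step per original batch
```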

Contributor

@pietern left a comment


We should do two things before we can land:

  1. Remove the initializer kwarg. This setting will be off by default. The caller will flip it around every couple of iterations, so they may as well do that right at the first iteration (the code they end up writing won't be different).
  2. Add a function accumulate_grads(enable: bool), similar to train(enable: bool). Users shouldn't be poking around in member variables directly; they should use the setter function to control this (see the sketch after this list). This is how we control the API surface of the class.
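A rough sketch of the shape such a setter could take, written against a stand-in wrapper class rather than the real DistributedDataParallel (class and attribute names here are illustrative only):

```python
import torch.nn as nn

class DDPWithToggle(nn.Module):
    """Illustrative stand-in: shows the shape of the setter, not the real DDP."""

    def __init__(self, module):
        super().__init__()
        self.module = module
        self._accumulate_grads = False  # off by default, as suggested in (1)

    def accumulate_grads(self, enable: bool = True):
        # Setter in the style of nn.Module.train(mode); returns self for chaining.
        self._accumulate_grads = enable
        return self

    def forward(self, *inputs, **kwargs):
        # The real DDP would skip param broadcast / grad reduction here
        # whenever self._accumulate_grads is True.
        return self.module(*inputs, **kwargs)
```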

Summary:
Pull Request resolved: pytorch#19577

Add a flag to enable accumulating grads locally. This is a critical way to increase the effective batch size when training big models (e.g., when bound by GPU memory).
This will also increase training speed because it reduces the number of all-reduce calls.

Differential Revision: D15035604

fbshipit-source-id: ae115cee8e6d1a8e84e936cd8ec8a3f988906cdf
Contributor

@apaszke left a comment


This is a reasonable request, but I don't think we need this feature. You can emulate multiple forward passes by creating a wrapper model which takes a list of sub-batches to forward, runs the backward on each of them separately, and only then returns. Then, you wrap only this outer model in DistributedDataParallel.

Otherwise, this definitely shouldn't be a property, and a better implementation would make it a context manager. I'm thinking of something like this:

model = DistributedDataParallel(...)

def train(...):
  for batch in ...:
    with model.sync_only_on_exit():
      for subbatch in batch:
        model(subbatch).backward()

But really, I'm not 100% convinced that we need this.

@chenyangyu1988
Contributor Author

@apaszke Could you elaborate on how to create a wrapper model that takes a list of sub-batches, runs forward and backward on each of them separately, and only then returns?

For the latest DistributedDataParallel, we need to call prepare_for_backward in its forward function. I am not sure how to do this with a context manager, because by the time we exit the context all forward and backward passes have already finished, and there is no easy way to kick off gradient sync with the current design.

@pietern
Contributor

pietern commented May 2, 2019

The context wrapper approach won't be able to overlap any reduction with the final backward pass, because we only know that we want to reduce again after the call to backward has already returned.

Either way, we'll need some way to toggle this behavior on the DistributedDataParallel class, because it controls how it executes its forward function. This doesn't need to be a construction-time parameter because you would toggle it anyway; a setter function would be nice though. @mrshenli and I just chatted about this and I think there are two possible approaches here: one that doesn't look nice and is easy to screw up, and one that is better looking and harder to screw up. @apaszke I think your first suggestion maps to this second approach, can you confirm? This approach does require you to add the loss computation to your forward function, so that the call to backward can happen inline.

# Alternative 1: simple but ugly and easy to screw up.
# Doesn't deal with edge condition here.
for index, batch in enumerate(batches):
  model.accumulate_grads(index % 5 != 4) # True for [0, 3], False for [4]
  model(batch).backward()

# Alternative 2: wrapper (TODO: find better name)
class AccumulateGradientsAcrossBatches(nn.Module):
  def __init__(self, model, batch_size):
    super().__init__()
    self.model = model
    self.batch_size = batch_size
    
  def forward(self, megabatch):
    batches = torch.chunk(megabatch, self.batch_size)
    self.model.accumulate_grads(True)
    for batch in batches[0:-1]:
      self.model(batch).backward()
    self.model.accumulate_grads(False)
    return self.model(batches[-1])

model = AccumulateGradientsAcrossBatches(model, 32)
for megabatch in batches:
  model(megabatch).backward()

@mrshenli
Contributor

mrshenli commented May 2, 2019

@soumith @apaszke Does @pietern's suggestion look good to you?

@chenyangyu1988 If we all have consensus on that, I think you only need to change accumulate_grads from a construction-time argument to a function, which should be sufficient to unblock you. We will add AccumulateGradientsAcrossBatches later (maybe under a different name).

@chenyangyu1988
Contributor Author

@mrshenli Sounds great, I have already made it a function rather than a construction-time argument :)


def forward(self, *inputs, **kwargs):
    self._sync_params()
    if not self.accumulate_grads:
Contributor


It seems to me that _sync_params needs to be called on the first mini-batch, while prepare_for_backward needs to be called on the last mini-batch. Do we need an additional arg? @pietern
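One way to picture that is with extra arguments to forward, roughly as in the sketch below (purely illustrative; the attribute and helper names do not match the actual diff):

```python
# Purely illustrative sketch of the point above, not the actual DDP code:
# broadcast params only before the first micro-batch, and arm gradient
# reduction only before the last one.
def forward(self, *inputs, is_first=True, is_last=True, **kwargs):
    if is_first or not self.accumulate_grads:
        self._sync_params()                    # broadcast params/buffers
    output = self.module(*inputs, **kwargs)
    if is_last or not self.accumulate_grads:
        self._prepare_for_backward(output)     # reduce grads during backward
    return output
```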

Contributor Author

@chenyangyu1988 May 6, 2019


@mrshenli I will make it work only for len(device_ids) == 1 to unblock us, and we can optimize it further with AccumulateGradientsAcrossBatches.

@apaszke
Contributor

apaszke commented May 8, 2019

OK, I think my suggestion with the wrapper model might not work. Still, those things absolutely shouldn't be setters, because that's just too error prone. If anything, they should be context managers like the one I proposed.

An acceptable alternative that doesn't involve a context manager would be to add a parameter to the constructor that allows one to disable automatic gradient syncs, and then trigger them manually every time they are needed. So the code would look like this:

model = DistributedDataParallel(..., auto_sync=False)

def train(...):
  for batch in ...:
    for subbatch in batch:
      model(subbatch).backward()
    model.sync_grads()

@chenyangyu1988
Contributor Author

@apaszke DDP overlaps the gradient sync with the backward call to get more performance, so we probably need some way to make that happen during the last backward call.

@pietern
Contributor

pietern commented May 9, 2019

@apaszke Can you elaborate on why not a setter here, considering we have existing ones, e.g. train(bool)? I agree they can be error prone, cause problems w.r.t. boundary conditions, etc.

I like your proposed alternative, but we need to find an easy and robust way to incorporate the overlapping of backward with gradient reduction. Before the last call to backward we need to let DDP know it needs to reduce the grads as they are computed. This could look something like this:

model = DistributedDataParallel(..., auto_sync=False)

def train(...):
  for batch in ...:
    for index, subbatch in enumerate(batch):
      output = model(subbatch)
      if index == len(batch) - 1:
        model.sync_grads(output)
      output.backward()

We could opt to store a copy of the model output in DDP so that you don't have to pass it back to sync_grads, but that could have unintended side effects by keeping the output alive longer than expected.

@myleott

myleott commented Jun 4, 2019

@chenyangyu1988, are you still working on this? Accumulating gradients with DDP is an important feature that can make a big speed difference when training with large batches.

@mrshenli
Contributor

@myleott @chenyangyu1988 I am picking this up.

@mrshenli
Contributor

mrshenli commented Jun 12, 2019

We have gotten multiple requests for this feature. Let's revive the discussion and try to reach some consensus.

I agree with @apaszke that setters are indeed more error prone, and that a context manager is a better fit.

And I also agree with @pietern that it is not desirable to let DDP keep the outputs around until exiting the context. Given that people request this feature because they need to deal with tight memory limits, keeping the output around a little longer might even trigger OOM errors.

So, how about we just provide a DistributedDataParallel.no_sync() context manager? It does not sync at all within the context. It is true that users then need to call another model(input).backward() after exiting the context, which is more verbose. But I think that is OK, since one major concern in the previous discussion was preventing users from running into errors without knowing it, and the extra call reaffirms the expected behavior.

The application would then look like:

with ddp.no_sync():
  for input in inputs:
    ddp(input).backward()

ddp(one_more_input).backward()

@myleott @chenyangyu1988 @soumith Does this API look OK?

@myleott

myleott commented Jun 13, 2019

Yep that works for us.

@mrshenli
Contributor

The proposed API is implemented in #21736.
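At its core, no_sync() amounts to temporarily flipping a flag that forward consults, roughly like the stand-alone sketch below (the class and flag names here are illustrative; see #21736 for the actual change):

```python
from contextlib import contextmanager

class _NoSyncSketch:
    """Illustrative stand-in for DistributedDataParallel, not the real class."""

    def __init__(self):
        self.require_backward_grad_sync = True  # illustrative flag name

    @contextmanager
    def no_sync(self):
        # Disable grad synchronization inside the context and restore the
        # previous setting on exit, even if an exception is raised.
        old_value = self.require_backward_grad_sync
        self.require_backward_grad_sync = False
        try:
            yield
        finally:
            self.require_backward_grad_sync = old_value
```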

facebook-github-bot pushed a commit that referenced this pull request Jun 20, 2019
Summary:
The first attempt and more discussions are available in #19577

#### Goal

Allow toggling DDP gradient synchronization across iterations. With this feature, users may accumulate grads in module variables and only kick off the expensive grad synchronization every few iterations.

#### Concerns

Our first attempt in #19577 tries to do it using a variable or a function, but apaszke made a good point that this would be error prone, and favors a context manager instead.

#### Proposed Solution

Instead of providing an `accumulate_grads` variable/function/context, we provide a `DistributedDataParallel.no_sync()` context manager, which does exactly what the name suggests, i.e., disables DDP grad synchronization within the context. Note that `accumulate_grads` means `no_sync` plus no optimizer step, where the latter is not controlled by DDP.

It is true that users need to call another `model(input).backward()` after exiting the context, which is indeed more verbose. But I think it is OK, since one major concern in the previous discussion was preventing users from running into errors without knowing it. This API reaffirms the expected behavior, and does not interfere with other use cases where accumulating grads is not required.

The application would then look like:

```python
with ddp.no_sync():
  for input in inputs:
    ddp(input).backward()

ddp(one_more_input).backward()
optimizer.step()
```

chenyangyu1988 myleott
Pull Request resolved: #21736

Differential Revision: D15805215

Pulled By: mrshenli

fbshipit-source-id: 73405797d1e39965c52016af5cf45b15525ce21c
iotamudelta pushed a commit to ROCm/pytorch that referenced this pull request Jun 21, 2019
@pietern
Contributor

pietern commented Jul 10, 2019

This was superseded by #21736.

@pietern closed this Jul 10, 2019