
[RFC] Should DDP support custom reduction logic for registered module buffers? #63041

@rohan-varma

Description

@rohan-varma

🚀 Feature

Support custom logic (i.e. collective reduction) for registered module buffers as part of DDP training. Currently, when buffers are registered as part of a model, DDP broadcasts them from rank 0, or does nothing if broadcast_buffers=False. However, we have seen some use cases (such as here), as well as an internal request, where broadcast is not sufficient and the user would like more control over the reduction.

Motivation

The main motivation is that users would like to register certain tensors as buffers so that they are saved as part of state_dict and thus included in the model checkpoint, but broadcast does not make semantic sense for these buffers. One example is tracking the number of positive/negative examples a rank has seen and using that to compute a weighted loss, where we may want to allreduce these counts instead of broadcasting rank 0's values.
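To make the distinction concrete, here is a pure-Python sketch (no real `torch.distributed`) contrasting the two semantics for the positive/negative-count example. The per-rank counts are made up for illustration.

```python
# Simulate 3 ranks, each having seen different (pos, neg) example counts.
per_rank_counts = [(30, 70), (90, 10), (50, 50)]

# broadcast_buffers semantics: every rank ends up with rank 0's counts,
# discarding what the other ranks actually observed.
broadcast_result = [per_rank_counts[0]] * len(per_rank_counts)

# allreduce(SUM) semantics: every rank ends up with the global counts.
total = (sum(p for p, _ in per_rank_counts),
         sum(n for _, n in per_rank_counts))
allreduce_result = [total] * len(per_rank_counts)

pos, neg = allreduce_result[0]
global_ratio = pos / (pos + neg)  # the globally correct weighting, 170/300
```

Only the allreduce variant yields a loss weighting that reflects all ranks' data; broadcasting silently substitutes rank 0's local statistics everywhere.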

Pitch

@pritamdamania87 brought up the idea of a communication hook (similar to the existing gradient communication hooks) that allows users to specify custom buffer reductions. It would default to broadcast to maintain backward compatibility. Because different buffers may require different reductions, it may have to be a per-buffer callable specified by the user. A rough sketch of how this could work:

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("my_buf", torch.zeros(1))

ddp = DDP(MyModel())
ddp.register_buffer_comm_hook(ddp.module.my_buf, lambda buf: dist.all_reduce(buf))

def register_buffer_comm_hook(self, buf, fn):
    assert buf in self.buffers()
    self._buffer_hooks[buf] = fn

def _sync_params(self):
    # default to broadcast_coalesced if _buffer_hooks is empty
    for buf in self.modules_buffers:
        self._buffer_hooks.get(buf, lambda buf: dist.broadcast(buf, 0))(buf)
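The dispatch pattern above can be exercised standalone by stubbing out the collectives. This is a minimal sketch, not DDP's implementation: `FakeDDP`, its buffer names, and the recorded-call list are all illustrative stand-ins.

```python
# Sketch of the proposed per-buffer hook dispatch with torch.distributed
# stubbed out. Hooks are keyed by buffer name; unregistered buffers fall
# back to the default broadcast-from-rank-0 behavior.
class FakeDDP:
    def __init__(self, buffers):
        self.modules_buffers = dict(buffers)  # name -> current value
        self._buffer_hooks = {}               # name -> reduction fn
        self.calls = []                       # record which collective ran

    def register_buffer_comm_hook(self, name, fn):
        assert name in self.modules_buffers
        self._buffer_hooks[name] = fn

    def _broadcast(self, name):
        # stand-in for dist.broadcast(buf, src=0)
        self.calls.append(("broadcast", name))

    def _sync_params(self):
        for name in self.modules_buffers:
            hook = self._buffer_hooks.get(name, self._broadcast)
            hook(name)

ddp = FakeDDP({"pos_neg_counts": (0, 0), "running_mean": 0.0})
# stand-in for dist.all_reduce(buf) on the counts buffer
ddp.register_buffer_comm_hook(
    "pos_neg_counts", lambda name: ddp.calls.append(("all_reduce", name)))
ddp._sync_params()
# pos_neg_counts goes through all_reduce; running_mean keeps the broadcast default
```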

Performance consideration

We currently coalesce all buffers and broadcast them together instead of broadcasting them individually, which reduces the number of collective communications we launch and allows us to do it all in one go. With this sort of custom hook, it is unlikely we will be able to do such a coalesced operation, because each buffer could have different reduction logic; this might impact performance if the user chooses to use this feature.
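One possible mitigation, not proposed in the RFC itself, would be to group buffers that share the same reduction and launch one coalesced collective per group rather than one per buffer. A pure-Python sketch of that planning step (all names illustrative):

```python
from collections import defaultdict

def plan_collectives(buffer_hooks, buffers, default="broadcast"):
    """Group buffer names by reduction so each group can be coalesced
    into a single collective launch."""
    groups = defaultdict(list)
    for name in buffers:
        groups[buffer_hooks.get(name, default)].append(name)
    return dict(groups)

plan = plan_collectives(
    buffer_hooks={"pos_count": "all_reduce", "neg_count": "all_reduce"},
    buffers=["pos_count", "neg_count", "running_mean", "running_var"],
)
# 2 collective launches instead of 4: one all_reduce group, one broadcast group
```

This keeps the common case (no hooks registered) identical to today's single coalesced broadcast, and degrades to one launch per distinct reduction rather than one per buffer.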

Alternatives

Instead of hooks, which give more control, it might be simpler to let the user set a strategy for how each buffer is reduced, such as broadcast, allreduce, etc. This would be simpler but less configurable.
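A sketch of what that strategy-based alternative could look like — a closed enum of reductions per buffer instead of arbitrary callables. The enum members and `resolve` helper are hypothetical, not an existing API:

```python
from enum import Enum, auto

class BufferReduction(Enum):
    BROADCAST = auto()       # current DDP behavior: rank 0's value wins
    ALL_REDUCE_SUM = auto()  # e.g. for counters accumulated per rank
    ALL_REDUCE_MEAN = auto() # e.g. for running statistics

def resolve(strategies, buffer_name):
    # Unspecified buffers keep the default broadcast semantics,
    # preserving backward compatibility.
    return strategies.get(buffer_name, BufferReduction.BROADCAST)

strategies = {"pos_count": BufferReduction.ALL_REDUCE_SUM}
```

Because the set of strategies is closed, DDP could still coalesce buffers per strategy; the cost is that anything outside the enum (custom quantization, async reduction, etc.) is impossible, which is the configurability trade-off noted above.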

cc @pietern @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @agolynski @SciPioneer @H-Huang @mrzzd @cbalioglu @gcramer23

Metadata

Labels

- module: ddp — Issues/PRs related to distributed data parallel training
- oncall: distributed — Add this issue/PR to distributed oncall triage queue
- triaged — This issue has been looked at by a team member, and triaged and prioritized into an appropriate module
