[RFC] DDP Communication Hook #39272

@pritamdamania87

Description

Motivation

There are several algorithms, such as GossipGrad and gradient compression, that involve different communication strategies for parameter synchronization while running DistributedDataParallel training. The high-level idea here is to provide a flexible hook that lets users specify how DDP aggregates gradients across multiple workers. Apart from enabling experimentation with the ideas mentioned above, this hook would be very useful for researchers to try out new ideas.

Proposal

Provide a hook that performs all the necessary processing for communication and returns a Future indicating completion of the task.

// DDP APIs
ddp.register_comm_hook(
    state: object,
    hook: callable)

// DDP combines multiple parameters into a bucket before doing an allreduce.
class GradBucket:
  Tensor bucket_data
  // Optionally, in the future this can be enhanced with parameter-to-bucket
  // mappings as well.

// The state object is passed to the hook and can be used to maintain and
// update any state information that users would like to keep as part of the
// training process. Examples: error feedback in gradient compression, peers
// to communicate with next in GossipGrad, etc.

// Hook signature
// The hook is passed the state object and the DDP grad bucket. It is called
// once the bucket is ready. The hook can perform whatever processing is
// needed and return a Future indicating completion of any async work
// (e.g. an allreduce). If the hook doesn't perform any communication, it can
// also just return a completed Future. The Future should hold the new value
// of the bucket. Once a bucket is ready, the c10d reducer calls this hook,
// uses the bucket returned by the Future as the new bucket, and uses that
// bucket to copy grads back to the individual parameters.
def ddp_comm_hook(state: object, bucket: GradBucket) -> Future
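
For reference, DDP's default behavior could itself be expressed as a hook. Below is a minimal sketch assuming the proposed API above; the bucket_data field, work.get_future(), and the averaging callback follow the interface sketched in this RFC and the existing c10d Python API, not a finalized implementation.

import torch
import torch.distributed as dist

def allreduce_hook(state: object, bucket: "GradBucket") -> torch.futures.Future:
    # Allreduce the flattened bucket asynchronously and retrieve the
    # Future associated with the communication (NCCL-only for now,
    # see the warning at the end of this RFC).
    tensor = bucket.bucket_data
    work = dist.all_reduce(tensor, async_op=True)
    fut = work.get_future()

    def average(fut):
        # The reducer copies the Future's value back into the parameter
        # grads, so the hook itself is responsible for averaging.
        return fut.value()[0] / dist.get_world_size()

    return fut.then(average)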

Examples

Gradient Compression

ddp.register_comm_hook(
    state=None,
    hook=fp16_compress)

def fp16_compress(state: object, bucket: Tensor) -> Future:
    # Compress the bucket to fp16 before communicating it.
    compressed_bucket = bucket.to(torch.float16)
    # (A state object, if provided, could cache compressed_bucket here,
    # e.g. for error feedback in later iterations.)
    work = dist.all_reduce(compressed_bucket, async_op=True)
    # NOTE: We also provide an API called "get_future" to retrieve a Future
    # associated with the completion of c10d.ProcessGroupNCCL's work.
    allreduce_future = work.get_future()

    def decompress(fut):
        compressed_bucket = fut.wait()
        return compressed_bucket.to(torch.float32)

    return allreduce_future.then(decompress)
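
The state argument becomes useful once compression keeps history across iterations, e.g. the error feedback mentioned in the proposal above. A minimal sketch under that assumption follows; the FP16ErrorFeedbackState class, its error dict, and keying buckets by data_ptr() are hypothetical illustrations of how state might be used, written in the same pseudocode register as the examples above.

import torch
import torch.distributed as dist

class FP16ErrorFeedbackState:
    # Hypothetical state object: keeps the fp16 rounding error of each
    # bucket so it can be added back before the next compression.
    def __init__(self):
        self.error = {}  # keyed by the bucket tensor's data pointer

def fp16_compress_with_error_feedback(state, bucket):
    key = bucket.data_ptr()
    if key in state.error:
        # Add back the residual left over from the previous iteration.
        bucket.add_(state.error[key])
    compressed_bucket = bucket.to(torch.float16)
    # Remember what fp16 rounding discarded, for the next iteration.
    state.error[key] = bucket - compressed_bucket.to(torch.float32)

    work = dist.all_reduce(compressed_bucket, async_op=True)

    def decompress(fut):
        return fut.wait().to(torch.float32)

    return work.get_future().then(decompress)

ddp.register_comm_hook(
    state=FP16ErrorFeedbackState(),
    hook=fp16_compress_with_error_feedback)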

GossipGrad

class GossipGradState:
    def get_next_peers(self) -> List[Peer]
    # Modifies the passed-in grad in place after gossiping. Returns a Future
    # indicating completion of the work.
    def gossip_grads(self, peers: List[Peer], grad: Tensor) -> Future

ddp.register_comm_hook(
    state=GossipGradState(...),
    hook=gossip_grad_hook)

def gossip_grad_hook(state: object, bucket: Tensor) -> Future:
    peers = state.get_next_peers()
    return state.gossip_grads(peers, bucket)
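
How the state object is implemented is left open by this RFC. Purely for illustration, here is a minimal sketch assuming integer ranks as peers, a simple rotating peer selection, and a synchronous isend/irecv exchange wrapped in an already-completed torch.futures.Future; none of these choices are part of the proposal.

import torch
import torch.distributed as dist
from typing import List

class GossipGradState:
    def __init__(self):
        self.step = 0

    def get_next_peers(self) -> List[int]:
        # Hypothetical peer selection: rotate through the other ranks,
        # one peer per call.
        world_size = dist.get_world_size()
        if world_size == 1:
            return []
        rank = dist.get_rank()
        offset = self.step % (world_size - 1) + 1
        self.step += 1
        return [(rank + offset) % world_size]

    def gossip_grads(self, peers: List[int], grad: torch.Tensor) -> torch.futures.Future:
        # Exchange the bucket with each peer and average in place.
        recv_buf = torch.empty_like(grad)
        for peer in peers:
            reqs = [dist.isend(grad, dst=peer), dist.irecv(recv_buf, src=peer)]
            for req in reqs:
                req.wait()
            grad.add_(recv_buf)
        grad.div_(len(peers) + 1)
        # The exchange above is synchronous, so return an already-completed Future.
        fut = torch.futures.Future()
        fut.set_result(grad)
        return fut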

Sub-tasks:

Warning: the get_future API is currently only supported by the NCCL backend. Please refer to #42048 for Gloo support.

cc @pietern @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @xush6528 @osalpekar
