
Make DDP failure recoverable #21344

@mrshenli

Description


@kuttas, @pietern, and I discussed how to make DDP failure recoverable. The recovery process would involve the following steps:

import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel

m = SomeModel()
dist.init_process_group(backend="gloo")  # or any supported backend
ddp = DistributedDataParallel(m)
# ... a failure occurs during training ...
dist.destroy_process_group()
dist.init_process_group(backend="gloo")  # re-initialize the process group
del ddp
ddp = DistributedDataParallel(m)         # re-wrap the same original module

This does not work in today's DDP. Currently, to get better performance, DDP assigns the original module as the first module replica instead of creating a new one. It then creates a new Reducer that adds post hooks to sync params. However, because every reconstructed DDP instance wraps the same original module, all of their reducers add hooks to the same set of variables. Hence, after 10 recoveries, each param (variable) in the original module carries 11 hooks installed by 11 different reducers, of which only the last one is still alive. We thought about several potential solutions; they are listed below, after a toy illustration of the problem.
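For illustration only, here is a toy sketch of the accumulation problem. It is not DDP's actual mechanism: it uses Tensor.register_hook instead of the Reducer's C++ post hooks, and the inner loop stands in for "reconstructing DDP around the same module".

import torch
import torch.nn as nn

m = nn.Linear(4, 4)

# Re-"wrapping" the same module three times, each time registering a hook on
# every parameter, mimics what a reconstructed Reducer does today.
for _ in range(3):
    for p in m.parameters():
        p.register_hook(lambda grad: grad)  # stands in for the Reducer's sync hook

# After three rounds, each parameter carries three hooks, and all of them
# fire on backward, even though only the last "wrapper" is still alive.
m(torch.randn(2, 4)).sum().backward()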

Solution 1: Force Module Replication

Force DDP to create new module replicas instead of reusing the original module. That way, the variables in the replicas (and the hooks attached to them) die together with the DDP instance. But it would make DDP slower. Maybe make it an option? A rough user-level equivalent is sketched below.
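Continuing from the snippet at the top of the issue, a user-level approximation of this idea could look like the following. The use of copy.deepcopy is my assumption for illustration; DDP's internal replication path would do this in C++ and broadcast state to keep replicas in sync.

import copy

# Wrap a fresh copy of the model so that the Reducer's hooks attach to the
# copy's parameters; when the DDP instance (and the copy) is dropped, the
# hooks go away with it.
replica = copy.deepcopy(m)
ddp = DistributedDataParallel(replica)

# Caveat: updates now land on `replica`'s parameters, so the original `m`
# no longer reflects training progress unless its state is copied back.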

Solution 2: Delete Hooks in Destructor

I feel the best approach would be to delete those hooks from the model's variables when the Reducer is destructed, but I didn't find a clean way to do that. The add_post_hook function takes unique pointers, and we can get the hooks back through post_hooks; directly looping through the hooks vector and finding the target to delete seems too hackish.
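The hooks in question live in C++, but a Python-level analogue of what a Reducer destructor would need to do looks like this. The HookedWrapper class is purely hypothetical; the point is that Tensor.register_hook returns a RemovableHandle, which is exactly the removal semantics Solution 2 would need on the C++ side.

import torch.nn as nn

class HookedWrapper:
    """Hypothetical wrapper: keep handles to every hook we register so that
    tear-down can remove them instead of leaving dead hooks behind."""

    def __init__(self, module: nn.Module):
        self.module = module
        self._handles = [
            p.register_hook(lambda grad: grad)  # placeholder for a sync hook
            for p in module.parameters()
            if p.requires_grad
        ]

    def remove_hooks(self):
        for h in self._handles:
            h.remove()          # RemovableHandle cleanly detaches the hook
        self._handles.clear()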

Solution 3: Create New Variables (?)

Not sure if this can work. Instead of creating replicas (as in Solution 1), let DDP create a new variable for every parameter in the original module. All DDP forward and backward passes would use those new variables. I think this won't work if the application wraps only part of the model in DDP, because there would be two disjoint autograd graphs (?)
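A hypothetical sketch of what "a new variable for every parameter" could mean, reusing m from the snippet above. The detach-and-share-storage approach is my assumption, not a decided design; it also makes the disjoint-graph concern concrete.

# Create new leaf variables that share storage with the original parameters.
# The Reducer's hooks would attach to these leaves and die with the DDP
# instance, but autograd history recorded through the original parameters is
# invisible to them, hence the disjoint-graph concern when only part of the
# model is wrapped in DDP.
new_params = {
    name: p.detach().requires_grad_(True)
    for name, p in m.named_parameters()
}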

@soumith @gchanan @ezyang thoughts?
