@kuttas @pietern and I had a discussion on how to make DDP failure recoverable. The process involves the following steps:
```python
m = SomeModel()
dist.init_process_group()
ddp = DistributedDataParallel(m)
# got error
dist.destroy_process_group()
dist.init_process_group()
del ddp
ddp = DistributedDataParallel(m)
```

This does not work in today's DDP. Currently, to get better performance, DDP assigns the original module as the first module replica instead of creating a new one. It then creates a new Reducer that adds post hooks to sync params. However, because every reconstructed DDP instance wraps the same original module, all of their Reducers add hooks to the same set of variables. Hence, after 10 recoveries, each param (variable) in the original module will carry 11 hooks introduced by 11 different Reducers, of which only the last one is still alive. We thought about several potential solutions, listed below.
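Before the solutions, for concreteness, here is a hedged sketch of the kind of application-level retry loop that would hit this; `SomeModel`, `run_one_epoch`, and `RecoverableError` are hypothetical placeholders for application code, not existing APIs.

```python
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel

m = SomeModel()
dist.init_process_group()
ddp = DistributedDataParallel(m)

for attempt in range(10):
    try:
        run_one_epoch(ddp)        # hypothetical training step that may fail
        break
    except RecoverableError:      # hypothetical recoverable failure
        # Tear down and rebuild the process group and the DDP wrapper.
        dist.destroy_process_group()
        dist.init_process_group()
        del ddp
        # Each new DDP instance builds a new Reducer, which registers another
        # set of post hooks on the same parameters of the original module `m`.
        ddp = DistributedDataParallel(m)
```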
Solution 1: Force Module Replication
Force DDP to create new module replicas instead of reusing the original module. That way, the variables in the replicas die together with the DDP instance. But it will make DDP slower. Maybe make it an option?
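A minimal sketch of the idea, assuming a wrapper that deep-copies the module before handing it to DDP; `make_recoverable_ddp` is a hypothetical helper, not an existing API.

```python
import copy
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel


def make_recoverable_ddp(module: nn.Module) -> DistributedDataParallel:
    # Hypothetical helper: wrap a deep copy instead of the original module,
    # so the Reducer's hooks attach to the replica's parameters and are
    # released together with the DDP instance.
    replica = copy.deepcopy(module)
    return DistributedDataParallel(replica)

# On recovery, call make_recoverable_ddp(module) again; the old replica and
# its hooks are garbage-collected with the old DDP instance. The extra copy
# (and copying updated parameters back into the original module) is part of
# what makes this approach slower.
```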
Solution 2: Delete Hooks in Destructor
I feel the best way would be to delete those hooks from the model's variables when a Reducer is destructed, but I didn't find a clean way to do that. The add_post_hook function takes unique pointers, and we can read the hooks back through post_hooks, but directly looping through the hooks vector to find the target to delete seems too hackish.
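The real hooks live on the C++ side (the Reducer registers post hooks on the parameters' grad accumulators), so the sketch below is only a Python-level analogue of the idea: keep removable handles for the hooks you register and remove them when the owner is destroyed. It uses `Tensor.register_hook`, which returns a handle with a `remove()` method; it is not how the Reducer is actually implemented.

```python
import torch


class HookOwner:
    """Python-level analogue of Solution 2: remember hook handles and remove
    them when the owner goes away, so rebuilt owners do not pile up hooks on
    the same parameters."""

    def __init__(self, params):
        self._handles = [p.register_hook(self._on_grad) for p in params]

    def _on_grad(self, grad):
        # Placeholder for the work a Reducer post hook would do (e.g. marking
        # a gradient as ready for all-reduce).
        return grad

    def remove_hooks(self):
        for handle in self._handles:
            handle.remove()
        self._handles = []

    def __del__(self):
        # Removing hooks on destruction is the analogue of deleting the
        # Reducer's post hooks in its destructor.
        self.remove_hooks()
```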
Solution 3: Create New Variables (?)
Not sure if this can work. Instead of creating replicas (as in Solution 1), let DDP create a new variable for every parameter in the original module; all DDP forward and backward passes would use those new variables. I think this won't work if the application only wraps part of the model with DDP, because there would then be two disjoint autograd graphs (?)
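A hedged sketch to make the concern concrete: the fresh leaf variables receive the gradients, while the original parameters stay outside the graph, so any part of the model not rewritten to use the new variables would be disconnected.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

m = nn.Linear(4, 4)

# Create a fresh leaf variable for every parameter (sharing storage via detach()).
new_params = {name: p.detach().requires_grad_(True) for name, p in m.named_parameters()}

x = torch.randn(2, 4)
# If the forward pass is rewritten to use the new variables, gradients flow
# into new_params, not into the original parameters.
out = F.linear(x, new_params["weight"], new_params["bias"])
out.sum().backward()

print(new_params["weight"].grad is not None)  # True
print(m.weight.grad)                          # None: the original params sit in a disjoint graph
```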