🚀 Feature
Propose a few built-in local SGD algorithms for parameter averaging that can complement the existing DDP communication hook for gradient averaging. Furthermore, we can provide an interface similar to the DDP communication hook so that OSS users can plug in their own model averaging approaches. Another option is encapsulating model averaging algorithms as optimizers.
Background
Compared with SGD running in a single process, a key step of parallel synchronous SGD in PyTorch is updating the global parameters at each step by allreduce. Specifically, this step can be executed as either gradient averaging or parameter averaging (a.k.a. model averaging). Currently, PyTorch DDP only supports gradient averaging. As explained in Section 2.2 of the DDP VLDB paper, this is mainly because 1) parameter averaging can be detrimental to model accuracy due to the potential divergence of optimizer states, especially if momentum is involved; and 2) gradient averaging can overlap communication with the computation of the backward pass for better performance, since it is kicked off earlier (during loss.backward()) than parameter averaging (which can only run after optimizer.step()).
Moreover, in order to reduce communication cost and support larger-scale training, the DDP communication hook (#39272) is designed to support flexible gradient averaging strategies (e.g., gradient compression) that can be more bandwidth-efficient than vanilla allreduce. To leverage the communication hook feature, the user can either choose a built-in hook (e.g., FP16 compression, PowerSGD) or implement a customized hook.
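For reference, registering a communication hook on a DDP-wrapped model looks roughly like the sketch below. This is only a sketch: ddp_model is assumed to be a DistributedDataParallel instance, and the exact GradBucket accessor names may differ slightly across releases.

import torch.distributed as dist
from torch.distributed.algorithms.ddp_comm_hooks import default_hooks

# Option 1: a built-in hook that compresses gradients to FP16 before allreduce.
ddp_model.register_comm_hook(state=None, hook=default_hooks.fp16_compress_hook)

# Option 2: a minimal customized hook that simply averages the bucketed gradients.
def allreduce_hook(state, bucket):
    # bucket.buffer() returns the flattened gradients of this bucket (accessor name may vary by release).
    tensor = bucket.buffer()
    fut = dist.all_reduce(tensor, op=dist.ReduceOp.SUM, async_op=True).get_future()
    return fut.then(lambda f: f.value()[0] / dist.get_world_size())

# ddp_model.register_comm_hook(state=None, hook=allreduce_hook)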
Time to Revisit the Design Choice
When the design choice between gradient averaging and parameter averaging was made, it seems that only the advantages of gradient averaging were noted. If we revisit this design choice, we can also find some interesting advantages of parameter averaging.
If we set aside the potential divergence of optimizer states caused by parameter averaging, a key premise of the mathematical equivalence between gradient averaging and parameter averaging is that full synchronization among all processes is carried out at every step: given the same initial global parameters, the same global gradients lead to the same global parameter delta, and hence the same updated global parameters. In other words, if the training does not sync at every step, gradient averaging becomes problematic, and parameter averaging is our only choice.
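Concretely, for plain SGD without momentum on N workers with learning rate \eta, if every worker starts the step from the same parameters w_t and computes a local gradient g_i^{(t)}, then

w_{t+1} = w_t - \eta \cdot \frac{1}{N} \sum_{i=1}^{N} g_i^{(t)} = \frac{1}{N} \sum_{i=1}^{N} \left( w_t - \eta\, g_i^{(t)} \right),

i.e., averaging gradients and then stepping yields exactly the same parameters as stepping locally and then averaging, precisely because every worker started the step from the same w_t.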
Therefore, just like the motivation behind the DDP communication hook, if we aim to optimize the communication of distributed training, we should also exploit the advantages of parameter averaging over gradient averaging. If DDP can also support parameter averaging as another kind of communication interface, it opens up a major communication optimization opportunity, local SGD, which decreases the synchronization frequency.
What’s Local SGD?
As the algorithm below shows, local SGD allows the model on each replica to evolve independently and only averages the parameters every T step(s). In particular, there are two extreme cases: 1) if T = 1, local SGD is equivalent to classical parallel SGD that fully syncs at every step; 2) if T = the total number of steps, local SGD is equivalent to one-shot averaging (https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.374.1458&rep=rep1&type=pdf), which only syncs at the last step. Besides extensive empirical studies, many recent works have also provided solid convergence analysis of local SGD and its variants under different assumptions, e.g., strongly/weakly convex and non-convex settings.
(Algorithm figure from “Local SGD Converges Fast and Communicates Little”)
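In code, a minimal local SGD training loop might look like the sketch below. This assumes model, optimizer, criterion and loader already exist on each process, and average_parameters is an illustrative helper, not an existing API.

import torch.distributed as dist

def average_parameters(model):
    # Allreduce each parameter and divide by the world size to obtain the global average.
    world_size = dist.get_world_size()
    for param in model.parameters():
        dist.all_reduce(param.data, op=dist.ReduceOp.SUM)
        param.data /= world_size

T = 4  # synchronization period; T = 1 recovers fully synchronous parallel SGD
for step, (inputs, targets) in enumerate(loader):
    optimizer.zero_grad()
    loss = criterion(model(inputs), targets)
    loss.backward()    # no gradient allreduce here; each replica trains locally
    optimizer.step()
    if (step + 1) % T == 0:
        average_parameters(model)  # parameter averaging every T steps

Note that the model here is a plain local module rather than a DDP-wrapped one, since gradient synchronization at every step is exactly what local SGD avoids.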
Advantages of Local SGD
Reducing Communication Cost Without Additional Computation Overheads
Clearly, if local SGD runs allreduce every T steps, the total communication cost can ideally be reduced by a factor of T. Compared with many other approaches (e.g., gradient compression) that can also significantly reduce the communication cost, a prominent advantage of local SGD is that it usually does not incur any non-trivial additional computation overhead. This can be critical for its practicality, given that our DDP implementation already largely overlaps communication and computation, so extra computation is hard to hide.
Additionally, it should be noted that local SGD does not necessarily conflict with compression approaches. Compressing parameters in local SGD may further improve performance.
Smoothing Out Variation in Computation Costs on Different Workers
Ideally, every DDP worker runs the same amount of computation, and hence each worker should consume the same amount of computation time before the synchronization point at every step. However, as the figure below shows, in practice there can often be some variation in computation time, and the overall computation time is always determined by the slowest worker, which is not necessarily the same one across steps.
(Figure from “Adaptive Communication Strategy in Local-Update SGD”)
Since local SGD consolidates the computation of multiple steps before each communication, even if we disregard the reduction in communication cost, it can still reduce the overall wall-clock time by smoothing out the variation in per-step computation time, even when communication is not a major performance bottleneck.
Natural Bonding With Federated Learning
Local SGD has quite a few similarities with federated learning (https://en.wikipedia.org/wiki/Federated_learning): both allow the model on each worker to evolve independently for multiple steps and avoid frequent synchronization. Therefore, we may be able to leverage some efforts from federated learning (e.g., parameter compression before synchronization across workers, auxiliary steps when averaging parameters for accuracy) for further optimizations. (Perhaps certain gradient compression methods can be also adapted for parameter compression.)
What About the no_sync() Context Manager?
Note that the DDP no_sync() context manager shown in the tutorial cannot meet the requirements here, because it requires moving optimizer.step() out of the inner loop (gradients are merely accumulated locally inside the context), whereas local SGD runs step() at every iteration inside the training loop.
with ddp.no_sync():
    for input in inputs:
        ddp(input).backward()  # gradients are accumulated locally; no allreduce is triggered
ddp(one_more_input).backward()  # the backward pass outside no_sync() synchronizes the accumulated gradients
# With the no_sync() context manager, step() is not applied at each iteration,
# but only once after all gradients have been accumulated and synchronized.
local_sgd_optimizer.step()
Steps
As the first step, we can provide multiple built-in algorithms for model averaging and run some experiments. Later, we can encapsulate these algorithms as optimizers.
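As a rough illustration of the second step, the “encapsulate as an optimizer” option could look something like the hypothetical sketch below; the class and method names are made up for illustration and do not refer to an existing API.

import torch.distributed as dist

class PeriodicAveragingOptimizer:
    # Hypothetical wrapper: applies the local optimizer at every step and
    # averages parameters across all workers every `period` steps.
    def __init__(self, optimizer, params, period):
        self.optimizer = optimizer
        self.params = list(params)
        self.period = period
        self.step_count = 0

    def zero_grad(self):
        self.optimizer.zero_grad()

    def step(self):
        self.optimizer.step()  # local update on each replica
        self.step_count += 1
        if self.step_count % self.period == 0:
            world_size = dist.get_world_size()
            for p in self.params:
                dist.all_reduce(p.data, op=dist.ReduceOp.SUM)
                p.data /= world_size

With such a wrapper, the no_sync() workaround above becomes unnecessary: the training loop simply calls local_sgd_optimizer.step() at every iteration, and the averaging period is handled internally.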
cc @pietern @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @agolynski @SciPioneer @H-Huang @mrzzd @cbalioglu @gcramer23