Reduce EMA RAM usage and training overhead with local-shard EMA #20
chijw wants to merge 1 commit into thu-ml:main
Conversation
Pull request overview
This PR optimizes EMA (Exponential Moving Average) memory usage in FSDP training by keeping EMA as local shards instead of materializing full parameters via summon_full_params() on every step. Full EMA state is only gathered at checkpoint save time.
Changes:
- Remove `summon_full_params()` from EMA init/update/copy paths, operating on local shards instead
- Add a `full_state_dict()` method that temporarily swaps EMA weights into the FSDP module to gather a full checkpoint via `fsdp_state_dict()`
- Update all trainer `save()` methods to use `full_state_dict(self.model.generator)` instead of `state_dict()`
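The shard-local update behind the first bullet can be illustrated with a small pure-Python sketch (hypothetical helper names, not the repository's API; the real code operates on the local parameter shards yielded by the FSDP-wrapped module). The point is that each rank applies `shadow = decay * shadow + (1 - decay) * param` only to the shards it already owns, so no per-step gather or cross-rank communication is needed:

```python
# Shard-local EMA update sketch. `local_shards` stands in for the
# parameter shards a single rank owns; `shadow` is that rank's EMA copy.
def ema_update_local(shadow, local_shards, decay=0.999):
    for name, values in local_shards.items():
        if name not in shadow:
            # First update: initialize the shadow shard as a copy of the live shard.
            shadow[name] = list(values)
            continue
        s = shadow[name]
        for i, v in enumerate(values):
            # shadow = decay * shadow + (1 - decay) * param, elementwise
            s[i] = decay * s[i] + (1.0 - decay) * v
    return shadow
```

Because every rank only touches its own shard, the full EMA tensor never has to exist on any single rank between checkpoints.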
Reviewed changes
Copilot reviewed 9 out of 9 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| utils/distributed.py | Remove summon_full_params, add full_state_dict() for checkpoint export |
| long_video/utils/distributed.py | Mirror of above changes for long_video module |
| trainer/gan.py | Use full_state_dict() at save time |
| trainer/diffusion.py | Use full_state_dict() at save time |
| trainer/distillation.py | Use full_state_dict() at save time |
| trainer/naive_cd.py | Use full_state_dict() at save time |
| long_video/trainer/gan.py | Use full_state_dict() at save time |
| long_video/trainer/distillation.py | Use full_state_dict() at save time |
| long_video/trainer/diffusion.py | Use full_state_dict() at save time |
```python
for n, p in fsdp_module.module.named_parameters():
    if n in self.shadow:
        p.data.copy_(self.shadow[n].to(dtype=p.dtype, device=p.device))

checkpoint = fsdp_state_dict(fsdp_module)
shadow_checkpoint = {}
for n in self.shadow:
    k = n
    if k not in checkpoint and k.startswith("model._fsdp_wrapped_module."):
        k = k.replace("model._fsdp_wrapped_module.", "model.", 1)
    if k in checkpoint:
        shadow_checkpoint[n] = checkpoint[k]

for n, p in fsdp_module.module.named_parameters():
    if n in live_state:
        p.data.copy_(live_state[n].to(dtype=p.dtype, device=p.device))

return shadow_checkpoint
```

(The same snippet appears in both `utils/distributed.py` and `long_video/utils/distributed.py`.)
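The key remapping in the middle of the snippet above can be isolated into a standalone helper (the helper name is illustrative; the prefix string is taken from the diff). A shadow key recorded under FSDP's wrapped prefix is normalized so it matches the key that `fsdp_state_dict()` actually emits:

```python
# Normalize a shadow key against the gathered checkpoint: if the key was
# recorded under FSDP's "_fsdp_wrapped_module" prefix but the checkpoint
# uses unwrapped names, strip the prefix once. Returns None on no match.
def resolve_checkpoint_key(name, checkpoint):
    k = name
    if k not in checkpoint and k.startswith("model._fsdp_wrapped_module."):
        k = k.replace("model._fsdp_wrapped_module.", "model.", 1)
    return k if k in checkpoint else None
```

This keeps the exported `generator_ema` keys compatible with the existing checkpoint format even though the shadow dict was built from the wrapped module.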
We really appreciate your help and dedication! This has resolved some long-standing efficiency issues. We will test your code within this week and merge it once we confirm everything works well. Thanks again, and best regards!
Glad to help! |

Summary
This PR changes the EMA path so that EMA is maintained as local shards during training instead of materializing full parameters on every rank.
The previous implementation used `summon_full_params()` in the EMA hot path, which adds unnecessary communication and keeps a full CPU EMA copy on each rank. With this change, each rank updates only its local EMA shard during training, which reduces both EMA memory usage and per-step overhead.

To preserve the existing checkpoint format, `generator_ema` is still exported as a full state dict at save time. Since EMA is shard-local during training, `full_state_dict()` reuses the FSDP-wrapped module together with `fsdp_state_dict()` to gather the full checkpoint, instead of introducing a separate EMA-specific export path.

Changes
- Keep `EMA_FSDP` shard-local during training
- Remove `summon_full_params()` from EMA init/update/copy
- Gather full `generator_ema` only at save time via `self.generator_ema.full_state_dict(self.model.generator)`
- Use `.to(dtype=..., device=...)` when copying EMA tensors back, for better compatibility with newer PyTorch versions
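The save-time gather described above reduces to a swap/gather/restore pattern, sketched here with plain dicts (hypothetical names; a dict stands in for the FSDP module and for `fsdp_state_dict()`): stash the live weights, swap the EMA shards in, gather a checkpoint, then put the live weights back.

```python
# Swap/gather/restore sketch of full_state_dict(): the module briefly
# carries the EMA weights while the full state dict is gathered, then the
# live training weights are restored unchanged.
def full_state_dict_sketch(params, shadow):
    live_state = dict(params)                  # stash live weights
    params.update(shadow)                      # swap EMA weights in
    checkpoint = dict(params)                  # stands in for fsdp_state_dict(fsdp_module)
    params.update(live_state)                  # restore live weights
    return {n: checkpoint[n] for n in shadow}  # EMA-only state dict
```

The design choice is that the expensive full gather happens once per checkpoint instead of once per step, while training state is left exactly as it was.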