Reduce EMA RAM usage and training overhead with local-shard EMA #20

Closed
chijw wants to merge 1 commit into thu-ml:main from chijw:main

Conversation

chijw commented Mar 15, 2026

Summary

This PR changes the EMA path so that EMA is maintained as local shards during training instead of materializing full parameters on every rank.

The previous implementation used summon_full_params() in the EMA hot path, which added unnecessary communication and kept a full CPU EMA copy on each rank. With this change, each rank updates only its local EMA shard during training, reducing both EMA memory usage and per-step overhead.

To preserve the existing checkpoint format, generator_ema is still exported as a full state dict at save time. Since EMA is shard-local during training, full_state_dict() temporarily swaps the EMA weights into the FSDP-wrapped module and gathers the full checkpoint via fsdp_state_dict(), instead of introducing a separate EMA-specific export path.
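
In outline, the shard-local update can look like the following sketch (illustrative only, not the PR's EMA_FSDP implementation; it assumes FSDP with use_orig_params=True, so named_parameters() yields each rank's local shards outside forward/backward):

import torch

class LocalShardEMA:
    """Sketch of a shard-local EMA; the PR's EMA_FSDP differs in detail."""

    def __init__(self, fsdp_module, decay=0.999):
        self.decay = decay
        # Each rank clones only the parameter data it holds locally,
        # so no full parameters are ever materialized.
        self.shadow = {n: p.detach().clone()
                       for n, p in fsdp_module.named_parameters()}

    @torch.no_grad()
    def update(self, fsdp_module):
        # Purely local step, no cross-rank communication:
        #   shadow = decay * shadow + (1 - decay) * param
        for n, p in fsdp_module.named_parameters():
            self.shadow[n].mul_(self.decay).add_(p.detach(), alpha=1.0 - self.decay)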

Changes

  • keep EMA_FSDP shard-local during training
  • remove summon_full_params() from EMA init/update/copy
  • export full generator_ema only at save time
  • switch trainer save paths to use self.generator_ema.full_state_dict(self.model.generator) (usage sketched below)
  • use to(dtype=..., device=...) when copying EMA tensors back, for better compatibility with newer PyTorch versions
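
As a usage illustration, the trainer save paths then reduce to the following (a hypothetical snippet; ckpt_path and the surrounding checkpoint dict layout are made up for the example):

import torch

# Gather the full EMA state dict only at save time; training keeps shards local.
ema_state = self.generator_ema.full_state_dict(self.model.generator)
torch.save({"generator_ema": ema_state}, ckpt_path)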

Copilot AI review requested due to automatic review settings March 15, 2026 03:51

Copilot AI left a comment


Pull request overview

This PR optimizes EMA (Exponential Moving Average) memory usage in FSDP training by keeping EMA as local shards instead of materializing full parameters via summon_full_params() on every step. Full EMA state is only gathered at checkpoint save time.

Changes:

  • Remove summon_full_params() from EMA init/update/copy paths, operating on local shards instead
  • Add full_state_dict() method that temporarily swaps EMA weights into the FSDP module to gather a full checkpoint via fsdp_state_dict()
  • Update all trainer save() methods to use full_state_dict(self.model.generator) instead of state_dict()

Reviewed changes

Copilot reviewed 9 out of 9 changed files in this pull request and generated 3 comments.

Summary of changes per file:

  • utils/distributed.py: Remove summon_full_params(), add full_state_dict() for checkpoint export
  • long_video/utils/distributed.py: Mirror of the above changes for the long_video module
  • trainer/gan.py: Use full_state_dict() at save time
  • trainer/diffusion.py: Use full_state_dict() at save time
  • trainer/distillation.py: Use full_state_dict() at save time
  • trainer/naive_cd.py: Use full_state_dict() at save time
  • long_video/trainer/gan.py: Use full_state_dict() at save time
  • long_video/trainer/distillation.py: Use full_state_dict() at save time
  • long_video/trainer/diffusion.py: Use full_state_dict() at save time


Comment on lines +125 to +141

# Swap the local EMA shards into the wrapped module's parameters.
for n, p in fsdp_module.module.named_parameters():
    if n in self.shadow:
        p.data.copy_(self.shadow[n].to(dtype=p.dtype, device=p.device))

# Gather a full state dict through the existing FSDP export path.
checkpoint = fsdp_state_dict(fsdp_module)

# Shadow keys may still carry the FSDP wrapper prefix; strip it so
# they match the keys of the gathered checkpoint.
shadow_checkpoint = {}
for n in self.shadow:
    k = n
    if k not in checkpoint and k.startswith("model._fsdp_wrapped_module."):
        k = k.replace("model._fsdp_wrapped_module.", "model.", 1)
    if k in checkpoint:
        shadow_checkpoint[n] = checkpoint[k]

# Restore the live training weights captured before the swap.
for n, p in fsdp_module.module.named_parameters():
    if n in live_state:
        p.data.copy_(live_state[n].to(dtype=p.dtype, device=p.device))

return shadow_checkpoint
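
For context, live_state in the excerpt holds the training weights captured before the swap, earlier in the method. Omitting the key-remapping step shown above, the overall shape of full_state_dict() is roughly the following (a sketch of a method on the EMA class, assuming fsdp_state_dict() is the project's existing full-state-dict helper):

@torch.no_grad()
def full_state_dict(self, fsdp_module):
    # 1) Capture the live training weights so they can be restored afterwards.
    live_state = {n: p.detach().clone()
                  for n, p in fsdp_module.module.named_parameters()
                  if n in self.shadow}
    # 2) Swap the local EMA shards into the wrapped module.
    for n, p in fsdp_module.module.named_parameters():
        if n in self.shadow:
            p.data.copy_(self.shadow[n].to(dtype=p.dtype, device=p.device))
    # 3) Gather a full state dict through the existing FSDP export path.
    checkpoint = fsdp_state_dict(fsdp_module)
    # 4) Restore the live weights and return the gathered EMA checkpoint.
    for n, p in fsdp_module.module.named_parameters():
        if n in live_state:
            p.data.copy_(live_state[n].to(dtype=p.dtype, device=p.device))
    return checkpoint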
zhuhz22 (Member) commented Mar 15, 2026

We really appreciate your help and dedication! This has resolved some long-standing efficiency issues. We will test your code within the week and merge it once we confirm everything works well. Thanks again, and best regards!

zhuhz22 (Member) commented Mar 19, 2026

We have adopted your improvements and sincerely appreciate them. We did not auto-merge this pull request directly because we wanted to retain the original code for the Rolling Forcing section in the long_video subdirectory for copyright reasons.

zhuhz22 closed this Mar 19, 2026
chijw (Author) commented Mar 19, 2026

Glad to help!
