
[RFC][Pipelining] RNG state communication, avoid seed checkpoint #139304

@wconstab

Description

Background

During model initialization, each PP stage initializes its own portion of the model. This is preferable to initializing the whole model in one place and then transferring it to each stage, since materializing an entire large model in one place has high memory requirements.

If random states are not managed, each PP stage would start from the same RNG seed and each model chunk could end up with similar or even identical weights. This would likely lead to numerical issues during training. It also makes loss curves from PP training less comparable with loss curves from non-PP training, which lowers confidence in PP training and adds obstacles to debugging.

Currently, torchtitan implements a workaround where the whole model is initialized on CPU offline in one process, saved to disk, and then loaded by each pipeline stage. (Loading does not require loading the whole model into RAM, only the portion used locally.) This produces consistent model initialization, but at a runtime and UX cost of generating and using the seed checkpoint.

Proposal

Instead, we can manage the RNG state automatically inside torch.pipelining, to avoid identical weights.

Option 1. Use different seeds per stage

We could simply re-seed the RNG on each PP stage, without worrying about actual initialization parity with non-PP training.

This option is potentially faster since it allows each stage to initialize in parallel. It should also guarantee 'safe' initializations, since no two stages would have a chance of producing duplicate or similar initial weights. However, this initialization would not 'match' the initialization values for non-PP training, leading to less comparability between loss curves.
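A minimal sketch of Option 1 follows. The helper name and the offset scheme (`base_seed + stage_index`) are hypothetical illustrations, not an established API; any scheme that deterministically gives each stage a distinct seed would do.

```python
import torch

def reseed_for_stage(base_seed: int, stage_index: int) -> None:
    # Hypothetical helper: give every pipeline stage a distinct,
    # deterministic seed so no two stages can draw identical init weights.
    torch.manual_seed(base_seed + stage_index)

# Usage: each stage calls this before building its model chunk.
reseed_for_stage(1234, stage_index=0)
w0 = torch.empty(4, 4).normal_()
reseed_for_stage(1234, stage_index=1)
w1 = torch.empty(4, 4).normal_()
assert not torch.equal(w0, w1)  # stages draw from different streams
```

Because each stage seeds independently, all stages can initialize in parallel with no communication, which is the source of the speed advantage.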

Option 2. Sequentially initialize

After stage 0 initializes, it could capture its own RNG states and send them to stage 1, which in turn loads the states and initializes, then saves the new states and sends them to stage 2, and so on. This serializes initialization, leading to an observable slowdown at startup time, but it matches perfectly (bitwise) the initialization values that would be observed in non-PP training.
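The handoff can be sketched as below. This is an in-process simulation under the assumption that stages exchange the RNG state as a byte tensor (e.g. via `torch.distributed` point-to-point send/recv); the helper names are hypothetical, and a full version would also handle CUDA RNG states.

```python
import torch

def pack_rng_state() -> torch.Tensor:
    # torch.get_rng_state() returns a ByteTensor, which is directly
    # usable with torch.distributed point-to-point send/recv.
    return torch.get_rng_state()

def unpack_rng_state(state: torch.Tensor) -> None:
    torch.set_rng_state(state)

# Sequential handoff: stage 0 initializes, forwards its RNG state,
# and stage 1 resumes the stream exactly where stage 0 left off.
torch.manual_seed(0)
chunk0 = torch.rand(8)        # stage 0's init
state = pack_rng_state()      # would be dist.send(...) to stage 1
unpack_rng_state(state)       # would be dist.recv(...) on stage 1
chunk1 = torch.rand(8)        # stage 1's init

# Reference: non-PP init draws both chunks from one uninterrupted stream.
torch.manual_seed(0)
ref0 = torch.rand(8)
ref1 = torch.rand(8)
assert torch.equal(chunk0, ref0)
assert torch.equal(chunk1, ref1)  # bitwise match with non-PP init
```

The final asserts show why this option reproduces non-PP initialization bitwise: restoring the saved state makes the save/send/recv/restore round trip a no-op with respect to the RNG stream.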

This option requires more complexity:

  • virtual stages mean more than one round of communication across PP ranks is needed
  • V-schedules mean neighboring stages sometimes live on the same rank, so seed communication should be skipped in those cases
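The V-schedule case can be sketched as a small planning pass that decides, per stage boundary, whether the RNG handoff is a real send/recv or a local no-op. The placement function below is one illustrative V-style layout (stages fold back onto ranks), not the actual schedule implementation.

```python
def stage_to_rank_v(stage: int, num_ranks: int, num_stages: int) -> int:
    # Illustrative V-style placement for num_stages == 2 * num_ranks:
    # the first half of the stages ascends over ranks, the second half
    # descends, so e.g. rank 1 of 2 holds stages 1 and 2.
    if stage < num_ranks:
        return stage
    return num_stages - 1 - stage

def rng_handoff_plan(num_ranks: int, num_stages: int):
    # For each boundary (s -> s+1), communicate only when the two
    # neighboring stages live on different ranks.
    plan = []
    for s in range(num_stages - 1):
        src = stage_to_rank_v(s, num_ranks, num_stages)
        dst = stage_to_rank_v(s + 1, num_ranks, num_stages)
        plan.append((s, s + 1, "local" if src == dst else "send/recv"))
    return plan

# With 2 ranks and 4 stages, the middle handoff (1 -> 2) stays on rank 1.
print(rng_handoff_plan(2, 4))
# [(0, 1, 'send/recv'), (1, 2, 'local'), (2, 3, 'send/recv')]
```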

Note: seed save/restore can be implemented like this: https://gist.github.com/wconstab/6bdea055eff9a7904fdae595c6cdac6e

It might be worth implementing both options and experimenting with the tradeoffs. I suspect most users would prefer option 1 if it's significantly faster and we can show that it leads to equally accurate training.

cc @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @d4l3k @c-p-i-o @tianyu-l


Labels: oncall: distributed, triaged
