Description
Background
During model initialization, each PP stage initializes its own portion of the model. This is preferable to initializing the entire model in one place and then transferring it to each stage, since initializing a whole large model in a single process has a high memory requirement.
If RNG state is not managed, each PP stage starts from the same RNG seed, so the model chunks could end up with similar or even identical weights. That would likely lead to numerical issues during training. It also makes the loss curves of PP training less comparable with those of non-PP training, which lowers confidence in PP training and adds obstacles to debugging.
Currently, torchtitan works around this by initializing the whole model on CPU offline in one process, saving it to disk, and having each pipeline stage load it. (Loading does not require reading the whole model into RAM, only the portion used locally.) This yields consistent model initialization, but at a runtime and UX cost of generating and using the seed checkpoint.
Proposal
Instead, we can manage the RNG state automatically inside torch.pipelining, to avoid identical weights.
Option 1. Use different seeds per stage
We could simply re-seed the RNG on each PP stage, without worrying about actual initialization parity with non-PP training.
This option is potentially faster since it allows each stage to initialize in parallel. It should also guarantee a 'safe' initialization, since no two stages would have a chance of producing duplicate or similar initial weights. However, this initialization would not match the initialization values for non-PP training, leading to less comparability between loss curves.
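A minimal sketch of Option 1, assuming a user-chosen `base_seed` and a hypothetical `stage_index` identifying this rank's stage (not an existing torch.distributed.pipelining API):

```python
import torch

def seed_for_stage(base_seed: int, stage_index: int) -> None:
    # Derive a distinct seed per pipeline stage so no two stages
    # initialize their model chunk from the same RNG stream.
    torch.manual_seed(base_seed + stage_index)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(base_seed + stage_index)

# Usage (hypothetical): call before constructing/initializing this
# stage's model chunk, e.g. seed_for_stage(1234, pp_rank).
```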
Option 2. Sequentially initialize
After stage 0 initializes, it can save its RNG state and send it to stage 1, which in turn restores that state and initializes, then saves the new state and sends it to stage 2, and so on. This serializes initialization, leading to an observable slowdown at startup, but it matches bitwise the initialization values that would be observed in non-PP training.
This option requires more complexity:
- virtual stages mean more than one round of communication across PP ranks is needed
- with V-schedules, neighboring stages sometimes live on the same rank, so seed communication should be skipped in those cases
Note: seed save/restore can be implemented like this: https://gist.github.com/wconstab/6bdea055eff9a7904fdae595c6cdac6e
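A minimal sketch of Option 2 under simplifying assumptions: one stage per rank, an already-initialized process group spanning the PP ranks, and a hypothetical `init_stage_module` helper that builds this stage's model chunk. Virtual stages and V-schedules would need the extra handling noted above.

```python
import torch
import torch.distributed as dist

def sequential_init(stage_index: int, num_stages: int, init_stage_module):
    if stage_index > 0:
        # Receive the RNG state left behind by the previous stage and adopt it.
        buf = [None]
        dist.recv_object_list(buf, src=stage_index - 1)
        torch.set_rng_state(buf[0])

    # Consume the RNG stream exactly as a non-PP initialization would.
    module = init_stage_module(stage_index)

    if stage_index < num_stages - 1:
        # Forward the post-init RNG state so the next stage continues the same stream.
        dist.send_object_list([torch.get_rng_state()], dst=stage_index + 1)
    return module
```

CUDA RNG state could be passed along in the same way via torch.cuda.get_rng_state / torch.cuda.set_rng_state if initialization happens on GPU.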
It might be worth implementing both options and experimenting with the tradeoffs. I suspect most users would prefer option 1 if it's significantly faster and we can show it leads to equally accurate training.
cc @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @d4l3k @c-p-i-o @tianyu-l