Background
In distributed training scenarios, RNG initialization matters for ensuring correct model initialization, and in some cases also for controlling random ops during training (e.g. dropout). For model initialization, the objective is that ranks initializing different shards of a layer (TP, FSDP) or different layers (PP) get different values, while ranks initializing replicas of the same layer or the same shard (DDP or HSDP) get the same values.
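As a toy, single-process illustration of this requirement (no parallelism involved, just plain seeding): generators seeded identically reproduce the same initialization, as replicas must, while different seeds produce different values, as shards must.

```python
import torch

# Toy illustration: identical seeds reproduce identical init values
# (the replica case); different seeds yield different values (the shard case).
def init_shard(seed: int) -> torch.Tensor:
    torch.manual_seed(seed)
    return torch.randn(4)

assert torch.equal(init_shard(42), init_shard(42))      # replicas: same values
assert not torch.equal(init_shard(42), init_shard(43))  # shards: different values
```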
DTensor RNG Tracker
DTensor currently provides RNG management, but there are several issues with it:
- it does not account for pipeline parallelism, and needs to be extended
- its design goals are over-ambitious (trying to match the behavior of single-GPU models, which is not generally possible in sharded settings), leading to an overly complex implementation
- user-set RNG seeds are not always respected, and may be overwritten or ignored depending on which RNG tracker implementation DTensor selects (a choice that is not user-visible and depends on the degrees of parallelism used)
Proposed API
torch.distributed.manual_seed(seed, sharded_groups=[pp, tp, fsdp])
An API like this could address several needs:
- initialize ranks in 'sharded_groups' with different seeds by adding a different per-rank offset to the user-provided seed, while seeding ranks that only differ along other (replicated) groups with the same seed, to handle the replica case
- additionally, a 'sharded_dims' arg would be accepted for DeviceMesh users; only one of the two may be used at a time
- the order of the sharded groups specifies the order in which the seed offset is incremented, giving control to the user and enabling things like matching the seeding behavior of Megatron or some existing trainer (see the sketch after this list)
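Here is a minimal sketch of what such an API could compute; the mixed-radix offset rule below is an assumption chosen to match the examples later in this issue, not a committed implementation:

```python
import torch
import torch.distributed as dist

def manual_seed(seed: int, sharded_groups: list[dist.ProcessGroup]) -> None:
    """Hypothetical sketch: seed this rank's default generator with `seed`
    plus an offset derived from its position in the sharded groups."""
    # Mixed-radix encoding of this rank's coordinates across the sharded
    # groups, in the order they are given; earlier groups stride the offset
    # by the sizes of all later groups.
    offset = 0
    for group in sharded_groups:
        offset = offset * dist.get_world_size(group) + dist.get_rank(group)
    # Ranks that differ only along replicated groups (e.g. DDP) compute the
    # same offset and therefore get the same seed.
    torch.manual_seed(seed + offset)
```

For sharded_groups=[pp, tp] this reduces to seed + pp_rank * tp_size + tp_rank, which is the ordering illustrated below.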
To enable this, we should change a few things about DTensor:
- avoid overwriting the user-provided seed: require the user to set the seed rather than doing it automatically
- don't broadcast the seed. This can lead to collective hangs in some cases, or simply unexpected/wrong seeding behavior, since DTensor broadcasts to the whole world rather than knowing which groups are sharded
- use one RNG tracker consistently, rather than switching implementations depending on the parallelism; rely on this new API for goals like matching Megatron RNG behavior
For the 'ordering' of seed offsets, this is what I had in mind
Ordering of sharded groups, 8 GPUs: DDP=2, PP=2, TP=2
- DDP 0
  - PP0: ranks [0, 1] (TP0, TP1)
  - PP1: ranks [2, 3] (TP0, TP1)
- DDP 1
  - PP0: ranks [4, 5] (TP0, TP1)
  - PP1: ranks [6, 7] (TP0, TP1)
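For reference, this is the layout a 2x2x2 DeviceMesh would give; the dimension names here are purely illustrative and not part of the proposal:

```python
from torch.distributed.device_mesh import init_device_mesh

# 8 ranks arranged as (ddp, pp, tp) = (2, 2, 2), e.g. launched via torchrun.
# Ranks are laid out row-major (tp fastest, ddp slowest), so ddp=0 covers
# ranks 0-3 ([0, 1] on pp=0, [2, 3] on pp=1) and ddp=1 covers ranks 4-7.
mesh = init_device_mesh("cuda", (2, 2, 2), mesh_dim_names=("ddp", "pp", "tp"))
```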
manual_seed(123, sharded_groups=[PP, TP])
Seed formula = 123 + (pp_rank * tp_size + tp_rank), with tp_size = 2:

| Rank | Seed | pp_rank | tp_rank |
|------|---------|---|---|
| 0 | 123 + 0 | 0 | 0 |
| 1 | 123 + 1 | 0 | 1 |
| 2 | 123 + 2 | 1 | 0 |
| 3 | 123 + 3 | 1 | 1 |
| 4 | 123 + 0 | 0 | 0 |
| 5 | 123 + 1 | 0 | 1 |
| 6 | 123 + 2 | 1 | 0 |
| 7 | 123 + 3 | 1 | 1 |
manual_seed(123, sharded_groups=[TP, PP])
Seed formula = 123 + (tp_rank * pp_size + pp_rank), with pp_size = 2:

| Rank | Seed | tp_rank | pp_rank |
|------|---------|---|---|
| 0 | 123 + 0 | 0 | 0 |
| 1 | 123 + 2 | 1 | 0 |
| 2 | 123 + 1 | 0 | 1 |
| 3 | 123 + 3 | 1 | 1 |
| 4 | 123 + 0 | 0 | 0 |
| 5 | 123 + 2 | 1 | 0 |
| 6 | 123 + 1 | 0 | 1 |
| 7 | 123 + 3 | 1 | 1 |
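A quick sanity check of both orderings, reusing the hypothetical mixed-radix offset rule from the sketch above; the rank-to-coordinate mapping follows the 8-GPU layout shown earlier:

```python
pp_size, tp_size = 2, 2

def offset(order, coords):
    # coords maps a group name to (rank_in_group, group_size)
    off = 0
    for name in order:
        rank_in_group, size = coords[name]
        off = off * size + rank_in_group
    return off

for rank in range(8):
    pp_rank = (rank % 4) // tp_size   # ranks 4-7 form the second DDP replica
    tp_rank = rank % tp_size
    coords = {"PP": (pp_rank, pp_size), "TP": (tp_rank, tp_size)}
    print(rank,
          123 + offset(["PP", "TP"], coords),
          123 + offset(["TP", "PP"], coords))

# [PP, TP]: ranks 0-3 get 123, 124, 125, 126; [TP, PP]: 123, 125, 124, 126.
# Ranks 4-7 repeat the same seeds, since DDP is not in sharded_groups.
```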
cc @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @d4l3k @c-p-i-o