[RFC] Distributed manual_seed API for N-D parallelism #140301

@wconstab

Description

Background

In distributed training, RNG initialization matters for correct model initialization and, in some cases, for controlling random ops during training (e.g. dropout). For model initialization, the objective is that ranks initializing different shards of a layer (TP, FSDP) or different layers (PP) get different values, while ranks initializing replicas of the same layer or the same shard (DDP or HSDP) get the same values.

DTensor RNG Tracker

DTensor currently manages RNG state through its RNG tracker. There are some issues with this:

  • It does not account for pipeline parallelism and needs to be extended.
  • The design goals are over-ambitious (trying to match the behavior of single-GPU models, which is not generally possible), leading to an overly complex implementation.
  • The user-set RNG seed is not always respected and may be overwritten or ignored, depending on which RNG tracker implementation DTensor selects (this choice is not user-visible, but depends on the degrees of parallelism used).

Proposed API

torch.distributed.manual_seed(seed, sharded_groups=[pp, tp, fsdp])

An API like this could solve several needs (a sketch of the offset computation follows the list):

  • Initialize ranks in 'sharded_groups' with different seeds by adding a different per-rank offset to the user-provided seed, while seeding ranks that only differ along other (replicated) groups with the same seed, to handle replica cases.
  • Additionally, a sharded_dims arg would be accepted for DeviceMesh users; only one of sharded_groups / sharded_dims may be used at a time.
  • The order of the sharded groups specifies the order in which the seed offset is incremented, giving control to the user and enabling things like matching the seeding behavior of Megatron or some existing trainer.
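
Below is a minimal sketch of how the per-rank offset could be computed, assuming sharded_groups is a list of already-created ProcessGroups. The function name mirrors the proposal; nothing here is an existing torch API.

```python
import torch
import torch.distributed as dist

def manual_seed(seed: int, sharded_groups: list[dist.ProcessGroup]) -> None:
    """Sketch of the proposed API (hypothetical, not in torch today).

    Ranks that differ in any of `sharded_groups` get different offsets;
    ranks that only differ in other (replicated) groups get the same offset.
    """
    # Row-major offset over the sharded groups: the first group in the list
    # varies slowest, the last varies fastest.
    offset = 0
    for group in sharded_groups:
        offset = offset * dist.get_world_size(group) + dist.get_rank(group)
    torch.manual_seed(seed + offset)
```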

To enable this, we should change a few things about DTensor (a usage sketch follows the list):

  • Avoid overwriting the user-provided seed. Require the user to set the seed, but don't do it automatically.
  • Don't broadcast the seed. This can lead to collective hangs in some cases, or simply to unexpected/wrong seeding behavior, since DTensor broadcasts to the whole world rather than knowing which groups are sharded.
  • Use one RNG tracker consistently. Avoid switching it depending on the parallelism, and rely on this new API for goals like matching Megatron RNG behavior.
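
For illustration, here is how a DeviceMesh user might call the proposed API under these changes. manual_seed refers to the sketch above and is hypothetical; init_device_mesh and DeviceMesh.get_group are existing APIs.

```python
from torch.distributed.device_mesh import init_device_mesh

# 8 GPUs arranged as 2 (ddp) x 2 (pp) x 2 (tp); ddp is replicated, so only
# the pp and tp groups are passed as sharded_groups.
mesh = init_device_mesh("cuda", (2, 2, 2), mesh_dim_names=("ddp", "pp", "tp"))

# The user sets the seed explicitly; DTensor would respect it instead of
# broadcasting or overwriting it.
manual_seed(123, sharded_groups=[mesh.get_group("pp"), mesh.get_group("tp")])
```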

For the 'ordering' of seed offsets, this is what I had in mind:

Ordering of sharded groups (8 GPUs: DDP, PP, TP)

DDP 0
  PP0     PP1
TP0 TP1 TP0 TP1
[0,  1], [2,  3]

DDP 1
  PP0     PP1
TP0 TP1 TP0 TP1
[4,  5], [6,  7]

manual_seed(123, sharded_groups=[PP, TP])
formula: seed = 123 + (pp_rank * tp_size(=2) + tp_rank)

Rank   Seed      pp_rank  tp_rank
0      123 + 0   0        0
1      123 + 1   0        1
2      123 + 2   1        0
3      123 + 3   1        1
4      123 + 0   0        0
5      123 + 1   0        1
6      123 + 2   1        0
7      123 + 3   1        1


manual_seed(123, sharded_groups=[TP, PP])
formula: seed = 123 + (tp_rank * pp_size(=2) + pp_rank)

Rank   Seed      tp_rank  pp_rank
0      123 + 0   0        0
1      123 + 2   1        0
2      123 + 1   0        1
3      123 + 3   1        1
4      123 + 0   0        0
5      123 + 2   1        0
6      123 + 1   0        1
7      123 + 3   1        1
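
As a quick sanity check, the following pure-Python snippet reproduces both tables from the diagram's rank layout (no process groups are involved; the rank-to-(ddp, pp, tp) mapping is taken from the 8-GPU example above):

```python
def offset(ranks_and_sizes):
    """Row-major offset: the first group varies slowest, the last fastest."""
    off = 0
    for rank, size in ranks_and_sizes:
        off = off * size + rank
    return off

for global_rank in range(8):
    pp_rank = (global_rank % 4) // 2   # layout from the diagram: DDP(2) x PP(2) x TP(2)
    tp_rank = global_rank % 2
    seed_pp_tp = 123 + offset([(pp_rank, 2), (tp_rank, 2)])  # sharded_groups=[PP, TP]
    seed_tp_pp = 123 + offset([(tp_rank, 2), (pp_rank, 2)])  # sharded_groups=[TP, PP]
    print(global_rank, seed_pp_tp, seed_tp_pp)
```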

cc @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @d4l3k @c-p-i-o

Labels: module: c10d, oncall: distributed, triaged
