
Conversation

wconstab (Contributor) commented Nov 1, 2024

Stack from ghstack (oldest at bottom):

Add new 'seeded_module_init' API to PipelineScheduleMulti

  • note: skipped for PipelineScheduleSingle for now. Could add it there
    too, but I want to delete that class and unify them.

Example:

```
schedule.seeded_module_init(
    # required arg, can be any fn that accepts 'module' as arg;
    # can be ModuleClass.init_weights
    my_initializer_fn,

    # optional args
    mode="serial",
    seed=1234,
)
```
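For reference, a minimal example of what such an initializer could look like (a hypothetical function, assuming the stage's module defines an init_weights method, as in the test case discussed below):

```
import torch

def my_initializer_fn(module: torch.nn.Module) -> None:
    # Hypothetical example: any per-stage init logic works;
    # here we simply delegate to the module's own init_weights().
    module.init_weights()
```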

Implement both 'serial' and 'parallel' options described in #139304.

'parallel' mode sets a unique seed per rank, but does not worry about
parity with non-PP initialization. It is the fastest because it
parallelizes initialization and uses no communication.

'serial' mode does initialization one stage at a time, communicating the
RNG states from rank to rank as needed so that the resulting
initialization is equivalent to non-PP initialization where the whole
model is initialized on one rank.

The overall initialization is delegated to the 'schedule' since it knows which
stages it owns, which ranks they are on, and what order they go in.
But the per-stage initialization is handled by a user-passed function
that can do anything it wants to the given module owned by that stage.
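For intuition, here is a minimal sketch of the 'serial' idea under simplified assumptions (one stage per rank; the stage fields prev_rank, next_rank, group, submod, is_first, and is_last are modeled on the snippets quoted in the review below, with prev_rank assumed). This is a sketch of the idea, not the PR's actual implementation:

```
import torch
import torch.distributed as dist

def serial_seeded_init(stage, initializer_fn, initial_seed, device):
    # Sketch: initialize stages one at a time, handing the RNG state down the
    # pipeline so the result matches initializing the whole model on one rank.
    if stage.is_first:
        torch.manual_seed(initial_seed)
    else:
        # Need tensors of the right size/dtype to recv into, so start from the
        # current state and overwrite it with the previous stage's state.
        state_cpu = torch.get_rng_state().to(device)
        state_cuda = torch.cuda.get_rng_state(device=device).to(device)
        dist.recv(state_cpu, src=stage.prev_rank, group=stage.group)
        dist.recv(state_cuda, src=stage.prev_rank, group=stage.group)
        torch.set_rng_state(state_cpu.cpu())
        torch.cuda.set_rng_state(state_cuda.cpu(), device=device)

    initializer_fn(stage.submod)  # user-provided per-stage init

    if not stage.is_last:
        # Hand the post-init RNG state to the rank owning the next stage.
        dist.send(torch.get_rng_state().to(device), dst=stage.next_rank, group=stage.group)
        dist.send(torch.cuda.get_rng_state(device=device).to(device), dst=stage.next_rank, group=stage.group)
```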

pytorch-bot commented Nov 1, 2024

Dr. CI: ✅ No failures as of commit 16503c4 (merge base could not be retrieved). See artifacts and rendered test results at hud.pytorch.org/pr/139445.

pytorch-bot added the oncall: distributed (Add this issue/PR to distributed oncall triage queue) label Nov 1, 2024
wconstab added a commit that referenced this pull request Nov 1, 2024 (ghstack-source-id: 83f0cae, Pull Request resolved: #139445), with the same description as above plus a TODO: add tests and see if this actually works; see whether 'serial' mode is slow for large PP setups.

wconstab added the release notes: distributed (pipeline) and module: pipelining (Pipeline Parallelism) labels Nov 1, 2024
wconstab added a commit that referenced this pull request Nov 1, 2024 (ghstack-source-id: ddcde6f, Pull Request resolved: #139445), with the same description.
wconstab added a commit that referenced this pull request Nov 1, 2024 (ghstack-source-id: 35ce7ed, Pull Request resolved: #139445), with the same description.
"Simply stop passing it, and everything should still work fine."
)

def seeded_module_init(
Contributor Author (wconstab):
#bikeshed_me
"seeded_module_init"? Meh.
I first used 'initialize_stage', but I realized there is a preexisting '_initialize_stage' API right below, which is confusing.

Contributor:

Do you want to also reorder the name so it begins with the verb (init), to be consistent with _initialize_stage? Not a very important nit.


```
def seeded_module_init(
    self,
    stage_initializer: Callable[[torch.nn.Module], None],
```
Contributor Author (wconstab):
Should 'stage_initializer' be an optional argument? I could check stage.submod for the presence of an 'init_weights' method and, if present, use that as the default; otherwise, if 'stage_initializer' is not provided, I would raise an error.

```
# need to know the size/dtype of the state tensors so might as well get current state then overwrite
state_cpu = torch.get_rng_state().to(device)
state_cuda = torch.cuda.get_rng_state(device=device).to(device)
torch.distributed.recv(
```
Contributor Author (wconstab):
Is it overkill to mess with the CPU RNG state? I included it since, in theory, someone could compute something on the CPU side and then copy it to the GPU for their tensor init.

f"Pipeline group size {self.group_size} cannot be larger than number of stages {self.num_stages}"
)

def stage_global_rank(peer_rank):
Contributor Author (wconstab):
Note: just moved this function from a derived class up to the base class, so it can be used from everywhere.

kwen2501 (Collaborator) commented Nov 1, 2024

Does "parallel" mode need help from torch.pipelining?

wconstab (Contributor Author) commented Nov 1, 2024

Does "parallel" mode need help from torch.pipelining?

Not strictly. It only needs to know 'pp_rank' to set a different seed per pp-rank, and then do init as usual. However, once we add an API to pipelining, I think the only use case for 'parallel' would be to get back some performance if 'serial' turned out to be too slow, and having it be a flag flip seems nicer than having to edit your code to do init differently.
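For illustration, a minimal sketch of doing that by hand in user code, without the proposed API (pp_group, base_seed, and model_part are hypothetical names):

```
import torch
import torch.distributed as dist

def parallel_mode_init(model_part: torch.nn.Module, pp_group, base_seed: int = 1234) -> None:
    # 'parallel' mode by hand: one distinct seed per PP rank, no communication,
    # and no parity with non-PP initialization.
    pp_rank = dist.get_rank(group=pp_group)
    torch.manual_seed(base_seed + pp_rank)
    model_part.init_weights()  # assumes the stage module defines init_weights()
```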

kwen2501 (Collaborator) commented Nov 1, 2024

I don't have a problem with landing the proposal as an experimental API. The description in the RFC indicates that a pipeline schedule can be of some help for serialized init, hence attaching a util here. (Although it is still debatable whether a pipeline schedule -- a "run" thing -- should get caught up in model init, but anyway.) The reason for suggesting we make it experimental is so that we can be off the hook for maintenance for now.

That said, just thinking out loud:

Re "parallel" mode:
Per @wconstab's comment above, users can probably do it easily / cleanly in their own code, so this on its own does not motivate the proposed API.

Re "serialized" mode:
The proposed API assumes the user provides a single init function for all the stage models. How easy / general is that (e.g. would that need some amount of if-else in the init function? Anything else?)

Re customer base:
If there is only a small set of users that need serialized init, what's wrong with recommending checkpoint-based init to them? It seems similar to a training-recovery step or an initial load step in model tuning. And would it be more performant than serialized init, since all ranks can then load from the checkpoint concurrently?

My point being: for the majority of users, should we avoid giving them the feeling that they should/could change what they are doing today, i.e. set different seeds, init the modules, hook it up with a pipeline schedule, run, etc.?

wconstab (Contributor Author) commented Nov 1, 2024

I think we should run experiments at larger scale to see if serialized init is noticeably slow. If not, I would be happy to recommend it as the default way, and just consider 'parallel' mode as an escape hatch.

> My point being: for the majority of users, should we avoid giving them the feeling that they should/could change what they are doing today, i.e. set different seeds, init the modules, hook it up with a pipeline schedule, run, etc.?

Users can freely ignore this helper API and do their own seeding if they want.

> The proposed API assumes the user provides a single init function for all the stage models.

Agreed that this is a bit weak in general. However, if users follow the practice of writing 'init_weights' per submodule, then it is quite easy to have a generalized top-level initializer that delegates. That is what we have in torchtitan.

In the torchtitan code,
[screenshot: the existing 'for m in model_parts' parallelize-and-initialize loop]

I was thinking maybe we can get rid of the 'for m in model_parts' loop here, define a helper 'parallelize_and_initialize(module)' that can be reused in both the PP and non-PP branches, and then have the overall code look like this:

```
def parallelize_and_initialize(m):
    # apply SPMD-style PT-D techniques
    models_parallelize_fns[model_name](m, world_mesh, parallel_dims, job_config)
    m.to_empty(device=init_device)
    m.init_weights(buffer_device=buffer_device)
    m.train()

# apply parallelisms and initialization
if parallel_dims.pp_enabled:
    # apply PT-D Pipeline Parallel
    pp_schedule, model_parts = models_pipelining_fns[model_name](
        model, pp_mesh, parallel_dims, job_config, device, model_config, loss_fn
    )
    pp_schedule.seeded_module_init(parallelize_and_initialize)
else:
    parallelize_and_initialize(model)
    model_parts = [model]
```

I'm not totally sure it's good to bundle parallelization and initialization this way, but it might be OK?

Anyway, I think there are two points:
(a) should we offer serial init?
(b) if we are going to have serial init, what should the API be?

tianyu-l (Contributor) left a comment:

Looks quite nice!! Had some inline questions.

```
# recv, set RNG seed
if mode == "serial":
    if stage.is_first:
        torch.manual_seed(initial_seed)
```
Contributor:
Can we make initial_seed optional, i.e. default to None instead of 0, so that the semantics are the same as single-device (in particular, the device that hosts the first stage) even when the user doesn't specify initial_seed?

Contributor Author (wconstab):
I don't understand; 0 is a valid seed too.

If we had initial_seed=None, what should I pass to manual_seed? I can't turn the current RNG state (a tensor) back into a seed (an int), so I am forced to use an initial seed value here, I think.

Contributor:

> If we had initial_seed=None, what should I pass to manual_seed?

I was thinking we just don't call manual_seed if the seed is None. The point is, without PP a user can call init_weights without providing a seed; I wonder if we can reproduce the same values with PP.

Contributor Author (wconstab):

I got your point; yeah, I can make that change. For 'parallel' mode we'd still need to convert None to 0, but for 'serial' mode we can just avoid setting anything on stage 0 for None and then propagate from there.
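A minimal sketch of that proposed handling, with hypothetical names (mode, initial_seed, is_first_stage, pp_rank); this is not the code in this PR:

```
from typing import Optional
import torch

def maybe_set_seed(mode: str, initial_seed: Optional[int], is_first_stage: bool, pp_rank: int) -> None:
    if mode == "parallel":
        # 'parallel' still needs an integer base seed, so fall back to 0 when None.
        base = 0 if initial_seed is None else initial_seed
        torch.manual_seed(base + pp_rank)
    elif mode == "serial" and is_first_stage and initial_seed is not None:
        # 'serial': only seed stage 0, and only if the user asked for a seed;
        # later stages receive RNG state from the previous rank instead.
        torch.manual_seed(initial_seed)
```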

Contributor:

I would suggest keeping initial_seed. The name of this function, seeded_module_init, already implicitly indicates that some seeding API is going to be called. IMHO, having initial_seed is consistent with the function naming.

Contributor:

Then maybe we should change the function name? How about seed_aware_init?

My idea is: we should follow the flow of "keep single-device semantics" -> "decide functionality" -> "design the API (with a minimal interface)". In this case, without PP users don't need to explicitly provide a seed here to do init (in torchtitan there is already a config to globally set the seed); with PP we can just mimic that behavior.

Comment on lines +1064 to +1066
```
if mode == "parallel":
    seed = initial_seed + self.rank
    torch.manual_seed(seed)
```
Contributor:

I have two questions.

  1. Before this PR (if we don't do a seeded checkpoint), are the random seeds the same on all PP ranks (assuming torchtitan users don't set them)? If not, then it sounds expected that PP shouldn't cause convergence issues.
  2. What's the implication of this change for FSDP/TP ranks on the same PP rank? Would DTensor ensure different init on different FSDP/TP ranks? I guess this is the same for a non-PP setup. cc: @wz337

Contributor Author (wconstab):

For q1: it's up to the user code, but if they don't do anything, then every rank would start with the same seed.

For q2: I don't remember how DTensor manages RNG. We should check.

Contributor:

If the weights are DTensors, DTensor will perform seed synchronization. cc @XilunWu

tianyu-l linked an issue Nov 2, 2024 that may be closed by this pull request
XilunWu (Contributor) commented Nov 5, 2024

After going through this PR, I believe DTensor RNG won't have an impact on this, because DTensor RNG is only used for DTensor random ops (e.g. torch.distributed.tensor.randn), while the PP RNG handling here directly calls torch.manual_seed. There's no interaction between them, unless we're using DTensor ops to initialize tensors.

Code pointer: https://github.com/pytorch/pytorch/blob/main/torch/distributed/tensor/_random.py#L91-L99
You can see that DTensor RNG only sets up the seed for the DTensor random ops' Python context.

Comment on lines +90 to +94
```
def init_weights(self):
    torch.nn.init.xavier_uniform_(self.net1.weight)
    torch.nn.init.xavier_uniform_(self.net2.weight)
    torch.nn.init.zeros_(self.net1.bias)
    torch.nn.init.zeros_(self.net2.bias)
```
Contributor:

Is this the place where model init happens?

Contributor Author (wconstab):

For the test case, yes. But 'self' could already be an FSDP model; I just didn't test it that way.

```
torch.distributed.send(
    state_cpu, dst=stage.next_rank, group=stage.group
)
torch.distributed.send(
```
Member:

Given that the "serial" seeded init and the newly added shape inference both use send/recv ops, is there a possibility that these ops can conflict with each other?

Contributor Author (wconstab):

I don't think so, unless I have a bug. They should run in a deterministic order that is the same on all ranks (init first, then shape inference).

wconstab (Contributor Author) commented:

Abandoning this as not widely useful. It works fine for PP only (1D parallel), but will not work as intended when combined with SPMD/DTensor parallelism. That's because DTensor/SPMD parallelism does not have a way to guarantee the same initialization behavior under different sharding configurations.

wconstab closed this Nov 27, 2024
github-actions bot deleted the gh/wconstab/355/head branch December 28, 2024 02:02

Labels: module: pipelining (Pipeline Parallelism), oncall: distributed, release notes: distributed (pipeline)

Linked issue that may be closed by this pull request: [RFC][Pipelining] RNG state communication, avoid seed checkpoint

7 participants