[pipelining] RNG seed management for PP init #139445
Conversation
Add new 'seeded_module_init' API to PipelineScheduleMulti

Note: skipped for PipelineScheduleSingle for now. Could add it there too, but I want to delete that class and unify.

Example:

```
schedule.seeded_module_init(
    # required arg, can be any fn that accepts 'module' as arg;
    # can be ModuleClass.init_weights
    my_initializer_fn,
    # optional args
    mode="serial",
    seed=1234
)
```

Implement both 'serial' and 'parallel' options described in #139304.

'parallel' mode sets a unique seed per rank, but does not worry about parity with non-PP initialization. It is fastest because it parallelizes initialization and uses no communication.

'serial' mode does initialization one stage at a time, communicating the RNG states from rank to rank as needed so that the resulting initialization is equivalent to non-PP initialization where the whole model is initialized on one rank.

The overall initialization is delegated to the 'schedule' since it knows which stages it owns, which ranks they are on, and what order they go in. But the per-stage initialization is handled by a user-passed function that can do anything it wants to the given module owned by that stage.

ghstack-source-id: 35ce7ed
Pull Request resolved: #139445
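For orientation, here is a minimal sketch of the 'serial' mechanism described above. This is not the PR's exact code; it assumes a stage object exposing is_first/is_last, prev_rank/next_rank, group, and submod, as the review diffs below suggest.

```
import torch
import torch.distributed as dist

def serial_stage_init(stage, stage_initializer, initial_seed, device):
    # First stage: start from the user's seed. Later stages: receive the RNG
    # state that the previous stage's rank left off with.
    if stage.is_first:
        torch.manual_seed(initial_seed)
    else:
        # Grab the current states just to get tensors of the right size/dtype,
        # then overwrite them with the received states (NCCL needs GPU tensors).
        state_cpu = torch.get_rng_state().to(device)
        state_cuda = torch.cuda.get_rng_state(device=device).to(device)
        dist.recv(state_cpu, src=stage.prev_rank, group=stage.group)
        dist.recv(state_cuda, src=stage.prev_rank, group=stage.group)
        torch.set_rng_state(state_cpu.cpu())
        torch.cuda.set_rng_state(state_cuda.cpu(), device=device)

    # Run the user-provided initializer on just this stage's submodule.
    stage_initializer(stage.submod)

    # Hand the advanced RNG state off to the rank that owns the next stage.
    if not stage.is_last:
        dist.send(torch.get_rng_state().to(device), dst=stage.next_rank, group=stage.group)
        dist.send(
            torch.cuda.get_rng_state(device=device).to(device),
            dst=stage.next_rank,
            group=stage.group,
        )
```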
| "Simply stop passing it, and everything should still work fine." | ||
| ) | ||
|
|
||
| def seeded_module_init( |
#bikeshed_me
"seeded_module_init"? meh.
I first used 'initialize_stage', but I realized there is a preexisting '_initialize_stage' API right below, which is confusing.
Do you also want to change the order so the name begins with the verb (init), to be consistent with _initialize_stage? Not a very important nit.
def seeded_module_init(
    self,
    stage_initializer: Callable[[torch.nn.Module], None],
Should 'stage_initializer' be an optional argument? I could check stage.submod for the presence of an 'init_weights' method and, if present, use that as the default. Otherwise, if 'stage_initializer' is not provided, I would raise an error.
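If that direction were taken, a minimal sketch of the fallback could look like this (the helper name and placement are hypothetical):

```
def _resolve_initializer(stage, stage_initializer=None):
    # Use the explicitly provided initializer when given.
    if stage_initializer is not None:
        return stage_initializer
    # Otherwise fall back to a conventional per-module 'init_weights' method,
    # if the stage's submodule defines one.
    init_weights = getattr(stage.submod, "init_weights", None)
    if callable(init_weights):
        return lambda module: module.init_weights()
    raise ValueError(
        "No stage_initializer provided and the stage module has no init_weights()"
    )
```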
# need to know the size/dtype of the state tensors so might as well get current state then overwrite
state_cpu = torch.get_rng_state().to(device)
state_cuda = torch.cuda.get_rng_state(device=device).to(device)
torch.distributed.recv(
Is it overkill to mess with the CPU RNG state? I included it since, in theory, someone could use the CPU side to calculate something and then copy that to GPU for their tensor init.
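For example, an initializer like the following (purely illustrative) would only be reproducible across PP and non-PP runs if the CPU generator state travels along with the CUDA one:

```
import torch

def init_weights_via_cpu(module: torch.nn.Linear) -> None:
    # Draw values from the CPU generator, then copy them onto the parameter's device.
    # If only the CUDA RNG state were propagated between stages, this pattern
    # would not match single-rank initialization.
    cpu_weights = torch.empty_like(module.weight, device="cpu")
    torch.nn.init.trunc_normal_(cpu_weights, std=0.02)
    with torch.no_grad():
        module.weight.copy_(cpu_weights)
```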
| f"Pipeline group size {self.group_size} cannot be larger than number of stages {self.num_stages}" | ||
| ) | ||
|
|
||
| def stage_global_rank(peer_rank): |
Note: I just moved this function from a derived class up to the base class, so it can be used everywhere.
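For reference, a plausible shape for that helper (a sketch; the actual body in the PR may differ) is to map a rank within the pipeline process group to its global rank:

```
import torch.distributed as dist

def stage_global_rank(group, peer_rank: int) -> int:
    # Translate a rank local to the pipeline process group into a global rank;
    # with the default (world) group the mapping is the identity.
    if group is None:
        return peer_rank
    return dist.get_global_rank(group, peer_rank)
```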
Does "parallel" mode need help from |
Not strictly. It only needs to know 'pp_rank' to set different seeds per PP rank, and then do init as usual. However, once we add an API to pipelining, I think the only use case for 'parallel' would be to get back some performance if 'serial' was too slow, and having it be a flag flip seems nicer than having to edit your code to do init differently.
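Put differently, without this API a user could approximate 'parallel' mode themselves with something like the sketch below (pp_group and base_seed stand in for whatever the training script already has):

```
import torch
import torch.distributed as dist

def seed_per_pp_rank(pp_group: dist.ProcessGroup, base_seed: int = 0) -> None:
    # What 'parallel' mode boils down to: offset a base seed by the rank within
    # the pipeline-parallel group, then run the usual init code on this rank.
    pp_rank = dist.get_rank(pp_group)
    torch.manual_seed(base_seed + pp_rank)  # seeds both CPU and CUDA generators
```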
I don't have a problem landing the proposal as an experimental API. The description in the RFC indicates that a pipeline schedule can be of some help for serialized init, hence attaching a util here. (Although it is still debatable whether a pipeline schedule -- a "run" thing -- should get caught up in model init, but anyway.) The reason for suggesting making it experimental is so that we can be off the hook for maintenance for now.

That said, just thinking out loud:

Re "serialized" mode:
Re customer base: My point being, for the majority of users, should we avoid giving them the feeling that they should/could change what they are doing today? I.e. set different seeds, init the modules, hook it up with the pipeline schedule, run, etc.
I think we should run experiments at larger scale to see if serialized init is noticeably slow. If not, I would be happy to recommend it as the default way, and just consider 'parallel' mode as an escape hatch.
Users can freely ignore this helper API and do their own seeding if they want.
Agreed that this is a bit weak in general. However, if users follow the practice of writing 'init_weights' per submodule, then it is quite easy to have a generalized top-level initializer that delegates. That is what we have in titan. I was thinking maybe we can get rid of the 'for m in model_parts' loop here, define a helper 'parallelize_and_initialize(module)' that can be reused in the PP and non-PP branches, and then just have the overall code look like the sketch below. I'm not totally sure it's good to bundle the parallelize and initialize steps in this way, but it might be OK? Anyway... I think there are 2 points
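Roughly like this (a sketch; parallelize_and_initialize, parallelize_fn, and the init_weights convention are assumptions borrowed from the torchtitan-style discussion above, not code from this PR):

```
import torch

def parallelize_and_initialize(module, parallelize_fn, device):
    # Apply SPMD parallelism (TP/FSDP/...), materialize on the target device,
    # then run the module's own init_weights. Reusable from both branches.
    module = parallelize_fn(module)
    module.to_empty(device=device)
    module.init_weights()
    return module

# PP branch: let the schedule drive per-stage init in the chosen mode.
# schedule.seeded_module_init(
#     lambda m: parallelize_and_initialize(m, parallelize_fn, device),
#     mode="serial",
# )

# Non-PP branch: initialize the whole model directly.
# model = parallelize_and_initialize(model, parallelize_fn, device)
```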
tianyu-l left a comment
Looks quite nice!! Had some inline questions.
# recv, set RNG seed
if mode == "serial":
    if stage.is_first:
        torch.manual_seed(initial_seed)
Can we make initial_seed optional, i.e. default to None instead of 0?
That way the semantics are the same as single-device (in particular, on the device that hosts the first stage) even when the user doesn't specify initial_seed.
I don't understand. 0 is a valid seed too.
If we had initial_seed=None, what should I pass to manual_seed? I can't turn the current RNG state (a tensor) back into a seed (an int), so I am forced to use an initial seed value here, I think.
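To spell out the constraint (these are existing torch APIs):

```
import torch

state = torch.get_rng_state()  # a uint8 tensor holding the full generator state
print(state.dtype)             # torch.uint8 -- there is no inverse back to an int seed

torch.manual_seed(1234)        # manual_seed only accepts an integer seed
torch.set_rng_state(state)     # restoring a captured state requires set_rng_state
```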
> If we had initial_seed=None, what should I pass to manual_seed?

I was thinking just don't call manual_seed if the seed is None. The point is, without PP a user can call init_weights without providing a seed. I wonder if we can reproduce the same values with PP.
I got your point, yeah, I can make that change. For 'parallel' mode we'd still need to convert None to 0, but for 'serial' mode we can just avoid setting anything on stage 0 for None and then propagate from there.
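A sketch of that change (assuming this lives in the seed-setup branch of seeded_module_init; not the code as written):

```
import torch

def _maybe_seed(mode: str, initial_seed, pp_rank: int, is_first_stage: bool) -> None:
    # Hypothetical handling of initial_seed=None.
    if mode == "parallel":
        # 'parallel' still needs a concrete base seed to offset per rank.
        base = 0 if initial_seed is None else initial_seed
        torch.manual_seed(base + pp_rank)
    elif mode == "serial":
        if is_first_stage and initial_seed is not None:
            # Only seed when asked; otherwise keep whatever RNG state the first
            # stage's rank already has, matching single-device behavior.
            torch.manual_seed(initial_seed)
        # Non-first stages receive their RNG state from the previous rank either way.
```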
I would suggest keeping initial_seed. The name of this function, seeded_module_init, already implicitly indicates that some seeding API is going to be called. IMHO, having initial_seed is consistent with the function naming.
Then maybe we should change the function name? How about seed_aware_init?
My idea is: we should follow the flow of "keep single-device semantics" -> "decide functionality" -> "design API (with a minimal interface)". In this case, without PP users don't need to explicitly provide a seed here to do init (in torchtitan there is already a config to globally set the seed); with PP we can just mimic that behavior.
if mode == "parallel":
    seed = initial_seed + self.rank
    torch.manual_seed(seed)
I have two questions.
- Before this PR (if we don't do seeded checkpointing), are the random seeds the same on all PP ranks (assuming torchtitan users don't set them)? If not, then it sounds expected that PP shouldn't cause convergence issues.
- What's the implication of this change for FSDP/TP ranks on the same PP rank? Would DTensor ensure different init on different FSDP/TP ranks? I guess this is the same as for a non-PP setup. cc: @wz337
For q1: it's up to the user code, but if they don't do anything, then every rank would start with the same seed.
For q2: I don't remember how DTensor manages RNG. We should check.
If the weights are DTensors, DTensor will perform seed synchronization. cc: @XilunWu
After going through this PR, I believe DTensor RNG won't have an impact on this, because DTensor RNG is only used in DTensor random ops (e.g. code pointer: https://github.com/pytorch/pytorch/blob/main/torch/distributed/tensor/_random.py#L91-L99).
def init_weights(self):
    torch.nn.init.xavier_uniform_(self.net1.weight)
    torch.nn.init.xavier_uniform_(self.net2.weight)
    torch.nn.init.zeros_(self.net1.bias)
    torch.nn.init.zeros_(self.net2.bias)
Is this the place where model init happens?
For the test case, yes. But 'self' could already be an FSDP model; I just didn't test it that way.
torch.distributed.send(
    state_cpu, dst=stage.next_rank, group=stage.group
)
torch.distributed.send(
Given that the "serial" seeded init and the newly added shape inference both use send/recv ops, is there a possibility that these ops can conflict with each other?
I don't think so, unless I have a bug. They should run in a deterministic order that is the same on all ranks (init first then shape).
Abandoning this as not widely useful. It works fine for PP only (1D parallel), but will not work as intended when combined with SPMD/DTensor parallelism. That's because DTensor/SPMD parallelism does not have a way to guarantee the same initialization behavior under different sharding configurations.

Stack from ghstack (oldest at bottom):
Add new 'seeded_module_init' API to PipelineScheduleMulti
too, but i want to delete that class and unify
Example:
Implement both 'serial' and 'parallel' options described in #139304
'parallel' mode sets a unique seed per rank, but does not worry about
parity with non-PP initialization. It is fastest becuase it
parallelizes initialization and uses no communication.
'serial' mode does initialization one stage at a time, communicating the
RNG states from rank to rank as needed so that the resulting
initialization is equivalent to non-PP initialization where the whole
model is initialized on one rank.
The overall initialization is delegated to the 'schedule' since it knows which
stages it owns, which ranks they are on, and what order they go in.
But the per-stage initialization is handled by a user-passed function
that can do anything it wants to the given module owned by that stage.