init FSDP through from_pretrained#46102
Conversation
|
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
vasqu
left a comment
There was a problem hiding this comment.
LGTM overall just 2 questions just to be sure
| return isinstance(module, FullyShardedDataParallel) | ||
|
|
||
|
|
||
| def initialize_fsdp( |
There was a problem hiding this comment.
Just to be sure, we did not include the PR which introduced this in any release yet? I think it's fine then, if not we should at least add a deprecation cycle
There was a problem hiding this comment.
this function was not used (I forgot to clean it). The way we instantiated it so far is through .from_pretrained which calls init_device_mesh which calls _ensure_torch_distributed (and run an init_process_group)
There was a problem hiding this comment.
Ah no I meant as in deprecation cylce logic to keep BC if needed. I guess we didnt include this in any release so it's easy to remove?
There was a problem hiding this comment.
i dont think it was include in any release so should be good
| # `apply_tensor_parallel` see the shared-parameter graph and can route tied | ||
| # entries (e.g. `lm_head` -> `embed_tokens`) correctly. `_finalize_model_loading` | ||
| # re-runs `tie_weights` after the checkpoint is loaded to handle missing-key edge cases. | ||
| model.tie_weights() |
There was a problem hiding this comment.
Hmm, i guess this wasn't caught before. Could we add a small test or did something fail?
There was a problem hiding this comment.
it's because the tests before was applying FSDP in 2 steps
model = AutoModelForCausalLM.from_config(config).to(device_map)
# from_config -> post_init() -> init_weights() -> tie_weights()
model = apply_fully_shard_data_parallel(model, device_mesh, fsdp_plan=auto_plan)
Now, im calling it with .from_pretrained which apply fsdp before tying the weights that's why the test test_fsdp2_sharding_structure_tied failed
There was a problem hiding this comment.
Mmm this is quite weird because we are tying a bit too many times no?
There was a problem hiding this comment.
ArthurZucker
left a comment
There was a problem hiding this comment.
Please wait for me or @Cyrilvallez when tie weights is called :) :) :)
| # `apply_tensor_parallel` see the shared-parameter graph and can route tied | ||
| # entries (e.g. `lm_head` -> `embed_tokens`) correctly. `_finalize_model_loading` | ||
| # re-runs `tie_weights` after the checkpoint is loaded to handle missing-key edge cases. | ||
| model.tie_weights() |
There was a problem hiding this comment.
Mmm this is quite weird because we are tying a bit too many times no?
| # `apply_tensor_parallel` see the shared-parameter graph and can route tied | ||
| # entries (e.g. `lm_head` -> `embed_tokens`) correctly. `_finalize_model_loading` | ||
| # re-runs `tie_weights` after the checkpoint is loaded to handle missing-key edge cases. | ||
| model.tie_weights() |
There was a problem hiding this comment.
| # entries (e.g. `lm_head` -> `embed_tokens`) correctly. `_finalize_model_loading` | ||
| # re-runs `tie_weights` after the checkpoint is loaded to handle missing-key edge cases. | ||
| model.tie_weights() | ||
| model = distribute_model(model, distributed_config, device_mesh) |
There was a problem hiding this comment.
distribute_model will call apply_fully_shard_data_parallel which has:
if is_weights_tied and hasattr(model, "tie_weights"):
# Re-tie weights.
# fully_shard replaces nn.Parameter objects (swapping data for DTensor shards),
# which breaks weight tying (e.g. lm_head.weight is no longer embed_tokens.weight).
# Re-tying makes lm_head._parameters["weight"] point to the new DTensor parameter
# so gradients accumulate correctly into a single buffer.
model.tie_weights()* Revert "init FSDP through from_pretrained (#46102)" This reverts commit 0588858. * Revert "Fix FSDP2 and distributed checkpointing imports for older PyTorch versions (#46141)" This reverts commit 634500b. * Revert "Update cohere2_moe tp_plan (#46189)" This reverts commit e65c3a2. * Revert "FSDP + TP & native save/load distributed (#45028)" This reverts commit 9ba8e85. * fix * they should have been deleted I think * these are actually needed changes * oops
* clean + fix fsdp tied weights * dispatch attn to default * linting
* Revert "init FSDP through from_pretrained (huggingface#46102)" This reverts commit 0588858. * Revert "Fix FSDP2 and distributed checkpointing imports for older PyTorch versions (huggingface#46141)" This reverts commit 634500b. * Revert "Update cohere2_moe tp_plan (huggingface#46189)" This reverts commit e65c3a2. * Revert "FSDP + TP & native save/load distributed (huggingface#45028)" This reverts commit 9ba8e85. * fix * they should have been deleted I think * these are actually needed changes * oops
* clean + fix fsdp tied weights * dispatch attn to default * linting
* Revert "init FSDP through from_pretrained (huggingface#46102)" This reverts commit 0588858. * Revert "Fix FSDP2 and distributed checkpointing imports for older PyTorch versions (huggingface#46141)" This reverts commit 634500b. * Revert "Update cohere2_moe tp_plan (huggingface#46189)" This reverts commit e65c3a2. * Revert "FSDP + TP & native save/load distributed (huggingface#45028)" This reverts commit 9ba8e85. * fix * they should have been deleted I think * these are actually needed changes * oops
Instantiate FSDP through
.from_pretrainedinstead