Skip to content

init FSDP through from_pretrained#46102

Merged
3outeille merged 9 commits into
mainfrom
clean-fsdp-init
May 26, 2026
Merged

init FSDP through from_pretrained#46102
3outeille merged 9 commits into
mainfrom
clean-fsdp-init

Conversation

@3outeille
Copy link
Copy Markdown
Member

@3outeille 3outeille commented May 20, 2026

Instantiate FSDP through .from_pretrained instead

@3outeille 3outeille requested a review from ArthurZucker May 20, 2026 07:12
@HuggingFaceDocBuilderDev
Copy link
Copy Markdown

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@3outeille 3outeille changed the title clean + fix fsdp tied weights init FSDP through from_pretrained May 26, 2026
Copy link
Copy Markdown
Contributor

@vasqu vasqu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM overall just 2 questions just to be sure

return isinstance(module, FullyShardedDataParallel)


def initialize_fsdp(
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just to be sure, we did not include the PR which introduced this in any release yet? I think it's fine then, if not we should at least add a deprecation cycle

Copy link
Copy Markdown
Member Author

@3outeille 3outeille May 26, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this function was not used (I forgot to clean it). The way we instantiated it so far is through .from_pretrained which calls init_device_mesh which calls _ensure_torch_distributed (and run an init_process_group)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah no I meant as in deprecation cylce logic to keep BC if needed. I guess we didnt include this in any release so it's easy to remove?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i dont think it was include in any release so should be good

# `apply_tensor_parallel` see the shared-parameter graph and can route tied
# entries (e.g. `lm_head` -> `embed_tokens`) correctly. `_finalize_model_loading`
# re-runs `tie_weights` after the checkpoint is loaded to handle missing-key edge cases.
model.tie_weights()
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, i guess this wasn't caught before. Could we add a small test or did something fail?

Copy link
Copy Markdown
Member Author

@3outeille 3outeille May 26, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it's because the tests before was applying FSDP in 2 steps

model = AutoModelForCausalLM.from_config(config).to(device_map)
# from_config -> post_init() -> init_weights() -> tie_weights()
model = apply_fully_shard_data_parallel(model, device_mesh, fsdp_plan=auto_plan)

Now, im calling it with .from_pretrained which apply fsdp before tying the weights that's why the test test_fsdp2_sharding_structure_tied failed

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Gotcha, makes sense

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mmm this is quite weird because we are tying a bit too many times no?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Comment thread tests/test_fsdp_mixin.py Outdated
@3outeille 3outeille added this pull request to the merge queue May 26, 2026
Merged via the queue into main with commit 0588858 May 26, 2026
33 checks passed
@3outeille 3outeille deleted the clean-fsdp-init branch May 26, 2026 18:46
Copy link
Copy Markdown
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please wait for me or @Cyrilvallez when tie weights is called :) :) :)

# `apply_tensor_parallel` see the shared-parameter graph and can route tied
# entries (e.g. `lm_head` -> `embed_tokens`) correctly. `_finalize_model_loading`
# re-runs `tie_weights` after the checkpoint is loaded to handle missing-key edge cases.
model.tie_weights()
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mmm this is quite weird because we are tying a bit too many times no?

# `apply_tensor_parallel` see the shared-parameter graph and can route tied
# entries (e.g. `lm_head` -> `embed_tokens`) correctly. `_finalize_model_loading`
# re-runs `tie_weights` after the checkpoint is loaded to handle missing-key edge cases.
model.tie_weights()
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

# entries (e.g. `lm_head` -> `embed_tokens`) correctly. `_finalize_model_loading`
# re-runs `tie_weights` after the checkpoint is loaded to handle missing-key edge cases.
model.tie_weights()
model = distribute_model(model, distributed_config, device_mesh)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

distribute_model will call apply_fully_shard_data_parallel which has:

    if is_weights_tied and hasattr(model, "tie_weights"):
        # Re-tie weights.
        # fully_shard replaces nn.Parameter objects (swapping data for DTensor shards),
        # which breaks weight tying (e.g. lm_head.weight is no longer embed_tokens.weight).
        # Re-tying makes lm_head._parameters["weight"] point to the new DTensor parameter
        # so gradients accumulate correctly into a single buffer.
        model.tie_weights()

vasqu added a commit that referenced this pull request May 28, 2026
* Revert "init FSDP through from_pretrained (#46102)"

This reverts commit 0588858.

* Revert "Fix FSDP2 and distributed checkpointing imports for older PyTorch versions (#46141)"

This reverts commit 634500b.

* Revert "Update cohere2_moe tp_plan (#46189)"

This reverts commit e65c3a2.

* Revert "FSDP + TP & native save/load distributed (#45028)"

This reverts commit 9ba8e85.

* fix

* they should have been deleted I think

* these are actually needed changes

* oops
yuchenxie4645 pushed a commit to yuchenxie4645/transformers that referenced this pull request May 28, 2026
* clean + fix fsdp tied weights

* dispatch attn to default

* linting
yuchenxie4645 pushed a commit to yuchenxie4645/transformers that referenced this pull request May 28, 2026
* Revert "init FSDP through from_pretrained (huggingface#46102)"

This reverts commit 0588858.

* Revert "Fix FSDP2 and distributed checkpointing imports for older PyTorch versions (huggingface#46141)"

This reverts commit 634500b.

* Revert "Update cohere2_moe tp_plan (huggingface#46189)"

This reverts commit e65c3a2.

* Revert "FSDP + TP & native save/load distributed (huggingface#45028)"

This reverts commit 9ba8e85.

* fix

* they should have been deleted I think

* these are actually needed changes

* oops
kashif pushed a commit to kashif/transformers that referenced this pull request Jun 1, 2026
* clean + fix fsdp tied weights

* dispatch attn to default

* linting
kashif pushed a commit to kashif/transformers that referenced this pull request Jun 1, 2026
* Revert "init FSDP through from_pretrained (huggingface#46102)"

This reverts commit 0588858.

* Revert "Fix FSDP2 and distributed checkpointing imports for older PyTorch versions (huggingface#46141)"

This reverts commit 634500b.

* Revert "Update cohere2_moe tp_plan (huggingface#46189)"

This reverts commit e65c3a2.

* Revert "FSDP + TP & native save/load distributed (huggingface#45028)"

This reverts commit 9ba8e85.

* fix

* they should have been deleted I think

* these are actually needed changes

* oops
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants