FSDP + TP & native save/load distributed #45028

Merged: 3outeille merged 133 commits into main from refactor-tp-dtensor on May 19, 2026

3outeille (Member) commented Mar 26, 2026

Highlights

  • One config, one call. Pass a DistributedConfig to from_pretrained to apply
    TP, FSDP2, or both — sharding happens during weight load (shard-on-read), not
    after.
    model = AutoModelForCausalLM.from_pretrained(
        name,
        distributed_config=DistributedConfig(
            tp_size=2, fsdp_size=2,
            enable_sequence_parallel=True,
        ),
        dtype=torch.bfloat16,
    )
  • DTensor-based TP. Leverages the PyTorch TP API. TP plans are dicts of
    {module_pattern: style_name}; style names ("colwise_allgather",
    "rowwise_allreduce", "packed_colwise", "moe_experts_allreduce",
    "colwise_loss_parallel", …) resolve through the ALL_PARALLEL_STYLES registry to
    torch.distributed.tensor.parallel.ParallelStyle instances, which torch's
    parallelize_module then applies.
  • FSDP2 native in transformers. See "FSDP2 native support in transformers" (#44083).
  • Shard-on-read via DtensorShardOperation. Each rank streams the full checkpoint
    tensors and immediately slices each one down to its local DTensor shard, for any
    combination of placements (Replicate, Shard(d), _StridedShard(d, sf=N)) on a mesh
    of any dimensionality. The class encapsulates the (mesh, placements) pair so
    slicing logic isn't repeated at every call site, and handles both shipped
    checkpoint layouts: one stacked tensor (e.g. [num_experts, in, out]) or N
    per-expert tensors. No full-tensor materialization on rank 0; no post-load
    redistribute.
  • Sequence parallelism for activations and norms via per-model _sp_plan entries
    toggled with enable_sequence_parallel=True.
  • MoE expert parallelism + packed-weight sharding for fused gate_up_proj /
    grouped_mm kernels, including _StridedShard for interleaved shards and an
    all-reduce-on-backward path for routing weights via _AllReduceBackward.
  • Distributed checkpointing for both model and optimizer.
    model.save_pretrained(ckpt_dir, distributed_checkpoint=True)
    save_optimizer_distributed(model, optimizer, opt_dir)

    model = AutoModelForCausalLM.from_pretrained(
        ckpt_dir, distributed_config=DistributedConfig(...),
    )
    load_optimizer_distributed(model, optimizer, opt_dir)
  • Includes optimizer-state fusion handling (get_fusion_metadata /
    unfuse_optimizer_state / fuse_optimizer_state) so a single optimizer_state slot
    covering a fused parameter like gate_up_proj is split/rejoined cleanly across
    save/load.
  • Resume under a different parallelism layout. Save under FSDP=2 × TP=2, reload
    under TP=4, continue training — same checkpoint, different topology. The PR's demo
    trains 5 steps under one config, reloads under another, finishes training, then runs
    inference under a third — and verifies the model overfits the target sentence
    verbatim.
  • Standard safetensors export. model.save_pretrained(dir) (no
    distributed_checkpoint) writes a fully gathered, plain safetensors checkpoint
    that loads anywhere: single GPU, different parallelism, vLLM, etc.
  • Distributed-aware utilities. clip_grad_norm handles DTensor parameters across
    the full mesh; optimizer save/load auto-disables foreach / fused kernels on
    parameter groups that mix regular tensors and DTensors.
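
To make the plan format from the DTensor-based TP highlight concrete, here is a minimal sketch of resolving a {module_pattern: style_name} dict against module names. The `resolve_plan` helper, the stand-in `REGISTRY`, the wildcard matching via `fnmatch`, and the module names are all assumptions for illustration. In the PR, the ALL_PARALLEL_STYLES registry maps to ParallelStyle instances, and its matching rules may differ.

```python
from fnmatch import fnmatchcase

# Hypothetical stand-in for the registry: style name -> short description.
# In the PR the values are ParallelStyle instances, not strings.
REGISTRY = {
    "colwise_allgather": "placeholder for a column-wise style",
    "rowwise_allreduce": "placeholder for a row-wise style",
}


def resolve_plan(plan, module_names, registry=REGISTRY):
    """Map each module name to the style of the first plan pattern it matches."""
    resolved = {}
    for name in module_names:
        for pattern, style_name in plan.items():
            if fnmatchcase(name, pattern):
                if style_name not in registry:
                    raise KeyError(f"unknown parallel style: {style_name!r}")
                resolved[name] = style_name
                break  # first matching pattern wins in this sketch
    return resolved


plan = {
    "lm_head": "colwise_allgather",
    "layers.*.self_attn.o_proj": "rowwise_allreduce",
}
names = ["lm_head", "layers.0.self_attn.o_proj", "layers.0.input_layernorm"]
resolved = resolve_plan(plan, names)
```

Modules that match no pattern (here the layer norm) are simply left out of the result, and a misspelled style name fails at resolution time rather than deep inside sharding.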
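
The shard-on-read highlight comes down to range arithmetic: given a tensor shape, a mesh shape, and a rank's mesh coordinates, compute which slice of the full tensor that rank keeps. The sketch below covers only Replicate and Shard(d). The `local_slices` function, and the encoding of placements as `None` (replicate) or an int (shard that tensor dim), are simplifications invented for this example and not the DtensorShardOperation API. `_StridedShard` is deliberately left out.

```python
from math import ceil


def local_slices(shape, mesh_shape, coords, placements):
    """Return one slice per tensor dim selecting this rank's local shard.

    placements[i] describes mesh dim i: None means Replicate, an int d means
    Shard(d). Chunks use ceil-division, so trailing ranks may get a short or
    empty shard.
    """
    bounds = [(0, size) for size in shape]
    for mesh_size, coord, dim in zip(mesh_shape, coords, placements):
        if dim is None:  # Replicate: this mesh dim does not narrow the range
            continue
        start, stop = bounds[dim]
        chunk = ceil((stop - start) / mesh_size)
        new_start = min(start + coord * chunk, stop)
        bounds[dim] = (new_start, min(new_start + chunk, stop))
    return tuple(slice(start, stop) for start, stop in bounds)


# An [8, 6] weight on a 2 x 2 mesh: mesh dim 0 shards rows, mesh dim 1 shards columns.
rows_cols = local_slices((8, 6), (2, 2), coords=(1, 0), placements=(0, 1))
# The same weight on a 1-D mesh of size 4, sharded on columns only.
tp4_last = local_slices((8, 6), (4,), coords=(3,), placements=(1,))
```

Because each rank only needs its own coordinates to compute the slice, the same checkpoint tensor can be read under a different mesh without any cross-rank redistribution. In this sketch that is what changing the layout between save and load amounts to.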
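
The packed-weight highlight mentions interleaved shards for fused gate_up_proj. A small sketch shows why a plain contiguous split is not enough. It assumes the fused weight stacks all gate rows before all up rows, and it uses lists of labels in place of tensors. The function names are hypothetical, and sizes are assumed to divide evenly.

```python
def contiguous_shard(rows, rank, world):
    """Naive Shard(0): split the rows into `world` contiguous chunks."""
    chunk = len(rows) // world
    return rows[rank * chunk:(rank + 1) * chunk]


def packed_shard(rows, rank, world, num_packed=2):
    """Shard each packed block separately, then concatenate the per-block shards."""
    block = len(rows) // num_packed
    out = []
    for b in range(num_packed):
        out.extend(contiguous_shard(rows[b * block:(b + 1) * block], rank, world))
    return out


# Fused weight rows: gate rows g0..g3 followed by up rows u0..u3.
fused = ["g0", "g1", "g2", "g3", "u0", "u1", "u2", "u3"]
naive = contiguous_shard(fused, rank=0, world=2)
packed = packed_shard(fused, rank=0, world=2)
```

With the naive split, rank 0 would hold only gate rows and rank 1 only up rows, so neither could compute its share of the gated product locally. The packed split gives every rank matching gate and up rows.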
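
The optimizer-state fusion highlight describes splitting and rejoining the state of a fused parameter across save/load. The following sketch illustrates that round trip on plain lists. The helper names `split_fused_state` and `join_fused_state`, the state layout, and the part names are assumptions made for the example, not the signatures of get_fusion_metadata, unfuse_optimizer_state, or fuse_optimizer_state. Scalar entries such as a step counter are omitted.

```python
def split_fused_state(state, fused_name, parts):
    """Split each per-row state entry of a fused parameter into named parts.

    `parts` is an ordered list of (name, num_rows) pairs.
    """
    out = {}
    offset = 0
    for name, rows in parts:
        out[name] = {
            key: values[offset:offset + rows]
            for key, values in state[fused_name].items()
        }
        offset += rows
    return out


def join_fused_state(split, fused_name, parts):
    """Inverse of split_fused_state: concatenate the parts back in order."""
    keys = split[parts[0][0]].keys()
    return {
        fused_name: {
            key: [v for name, _ in parts for v in split[name][key]]
            for key in keys
        }
    }


state = {"gate_up_proj": {"exp_avg": [1, 2, 3, 4], "exp_avg_sq": [5, 6, 7, 8]}}
parts = [("gate_proj", 2), ("up_proj", 2)]
split = split_fused_state(state, "gate_up_proj", parts)
rejoined = join_fused_state(split, "gate_up_proj", parts)
```

The property worth checking in any such scheme is that the round trip is lossless: splitting and then rejoining must reproduce the original state exactly.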

Reproduction

"""
torchrun --nproc_per_node=4 overfit_demo.py

The script overfits one sentence in the following steps:
    - Train the first half using FSDP=2 + TP=2
    - Save the model and optimizer as a distributed checkpoint
    - Reload the model and optimizer from the distributed checkpoint
    - Train the rest in TP=4 (changed distributed config)
    - Save the model as a plain safetensors checkpoint and the optimizer as a distributed checkpoint
    - Reload the model from the single safetensors file
    - Run inference in TP=4 and assert greedy generation reproduces the sentence verbatim
"""

import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.distributed import DistributedConfig
from transformers.distributed.utils import (
    clip_grad_norm,
    load_optimizer_distributed,
    save_optimizer_distributed,
)

NAME = "Isotonic/TinyMixtral-4x248M-MoE"
TEXT = "In a quiet village nestled between rolling hills and a slow river, the autumn mornings arrived with mist that hung low over the fields and a sky that turned from grey to pale gold as the sun climbed."
CKPT = "./checkpoints"
OPT = os.path.join(CKPT, "optimizer")
STEPS = 10
HALF = STEPS // 2

rank, local_rank = int(os.environ["RANK"]), int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

tokenizer = AutoTokenizer.from_pretrained(NAME)
ids = tokenizer(TEXT, return_tensors="pt").input_ids.to(f"cuda:{local_rank}")

# Train first half, distributed-save model + optimizer.
model = AutoModelForCausalLM.from_pretrained(
    NAME,
    distributed_config=DistributedConfig(tp_size=2, fsdp_size=2, enable_sequence_parallel=True),
    dtype=torch.bfloat16,
)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
model.train()
for step in range(0, HALF):
    loss = model(ids, labels=ids).loss
    loss.backward()
    total_norm = clip_grad_norm(model.parameters(), max_norm=1.0)
    optimizer.step()
    optimizer.zero_grad()
    if rank == 0:
        print(f"step {step:>2} | loss {loss.item():.5f} grad norm {total_norm.item():.5f}")

model.save_pretrained(CKPT, distributed_checkpoint=True)
save_optimizer_distributed(model, optimizer, OPT)
del model, optimizer
torch.cuda.empty_cache()

# Reload model + optimizer from the distributed checkpoint, train the rest.

model = AutoModelForCausalLM.from_pretrained(
    CKPT,
    distributed_config=DistributedConfig(tp_size=4, enable_sequence_parallel=True),
    dtype=torch.bfloat16,
)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
load_optimizer_distributed(model, optimizer, OPT)

model.train()
for step in range(HALF, STEPS):
    loss = model(ids, labels=ids).loss
    loss.backward()
    total_norm = clip_grad_norm(model.parameters(), max_norm=1.0)
    optimizer.step()
    optimizer.zero_grad()
    if rank == 0:
        print(f"step {step:>2} | loss {loss.item():.5f} grad norm {total_norm.item():.5f}")

# Save a plain safetensors checkpoint, then reload it and run inference in TP=4.
model.save_pretrained(CKPT)
save_optimizer_distributed(model, optimizer, OPT + "_tp4")
del model, optimizer
torch.cuda.empty_cache()

model = AutoModelForCausalLM.from_pretrained(
    CKPT,
    distributed_config=DistributedConfig(tp_size=4),
    dtype=torch.bfloat16,
)
model.eval()
prompt = tokenizer("In a quiet village", return_tensors="pt").to(f"cuda:{local_rank}")
out = model.generate(**prompt, max_new_tokens=ids.shape[-1] - prompt.input_ids.shape[-1], do_sample=False)

got, want = out[0].tolist(), ids[0].tolist()
if rank == 0:
    print(f"generated: {tokenizer.decode(got, skip_special_tokens=True)!r}")
    print(f"expected: {tokenizer.decode(want, skip_special_tokens=True)!r}")
assert got == want, (
    f"generation mismatch at index {next((i for i, (g, e) in enumerate(zip(got, want)) if g != e), -1)}"
)

torch.distributed.destroy_process_group()

$ torchrun --nproc_per_node=4 overfit_demo.py
W0513 16:49:43.716000 715651 torch/distributed/run.py:851] 
W0513 16:49:43.716000 715651 torch/distributed/run.py:851] *****************************************
W0513 16:49:43.716000 715651 torch/distributed/run.py:851] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
W0513 16:49:43.716000 715651 torch/distributed/run.py:851] *****************************************
Loading weights: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 111/111 [00:00<00:00, 4594.39it/s]
Loading weights: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 111/111 [00:00<00:00, 4261.65it/s]
Loading weights: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 111/111 [00:00<00:00, 4122.51it/s]
Loading weights: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 111/111 [00:00<00:00, 3596.90it/s]
step  0 | loss 4.97818 grad norm 10.12500
step  1 | loss 2.75203 grad norm 7.34375
step  2 | loss 0.41595 grad norm 4.09375
step  3 | loss 0.03579 grad norm 1.35156
step  4 | loss 0.01942 grad norm 1.25000
Loading weights: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 111/111 [00:00<00:00, 1795.49it/s]
Loading weights: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 111/111 [00:00<00:00, 1802.65it/s]
Loading weights: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 111/111 [00:00<00:00, 1714.74it/s]
Loading weights: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 111/111 [00:00<00:00, 1695.72it/s]
[transformers] Param group 0 mixes regular tensors and DTensors; disabling foreach/fused optimizer kernels for that group so distributed optimizer save/load can materialize state.
[transformers] Param group 0 mixes regular tensors and DTensors; disabling foreach/fused optimizer kernels for that group so distributed optimizer save/load can materialize state.
[transformers] Param group 0 mixes regular tensors and DTensors; disabling foreach/fused optimizer kernels for that group so distributed optimizer save/load can materialize state.
[transformers] Param group 0 mixes regular tensors and DTensors; disabling foreach/fused optimizer kernels for that group so distributed optimizer save/load can materialize state.
step  5 | loss 0.00169 grad norm 0.07031
step  6 | loss 0.00736 grad norm 0.69922
step  7 | loss 0.00137 grad norm 0.04810
step  8 | loss 0.00170 grad norm 0.06885
step  9 | loss 0.00065 grad norm 0.01953
Writing model shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:01<00:00,  1.08s/it]
Loading weights: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 111/111 [00:00<00:00, 5458.52it/s]
Loading weights: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 111/111 [00:00<00:00, 6130.57it/s]
[transformers] Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
[transformers] Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Loading weights: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 111/111 [00:00<00:00, 6980.13it/s]
[transformers] Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Loading weights: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 111/111 [00:00<00:00, 6672.51it/s]
[transformers] Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
generated: 'In a quiet village nestled between rolling hills and a slow river, the autumn mornings arrived with mist that hung low over the fields and a sky that turned from grey to pale gold as the sun climbed.'
expected: 'In a quiet village nestled between rolling hills and a slow river, the autumn mornings arrived with mist that hung low over the fields and a sky that turned from grey to pale gold as the sun climbed.'
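
The log above prints a single `grad norm` per step even though the parameters are sharded. One way to obtain such a value is to sum squared gradient entries per rank, combine the sums across ranks, and take the square root. The sketch below shows that arithmetic on plain floats. The function names are hypothetical, the cross-rank sum stands in for an all-reduce, and the clip coefficient follows the convention of `torch.nn.utils.clip_grad_norm_`. How the PR's clip_grad_norm handles replicated parameters is not shown here.

```python
from math import sqrt


def global_grad_norm(local_sq_sums):
    """Combine per-rank sums of squared gradient entries into one L2 norm.

    In a real run the sum over ranks would be an all-reduce; replicated
    parameters must be counted once, which this sketch does not model.
    """
    return sqrt(sum(local_sq_sums))


def clip_coefficient(total_norm, max_norm, eps=1e-6):
    """Scale factor applied to every gradient; never larger than 1."""
    return min(1.0, max_norm / (total_norm + eps))


# Two ranks, each holding one gradient shard: [3.0] and [4.0].
norm = global_grad_norm([3.0 ** 2, 4.0 ** 2])
coef = clip_coefficient(norm, max_norm=1.0)
```

Under this convention the reported norm is the value before clipping, which would be consistent with the log showing norms above `max_norm=1.0` in the early steps.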
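
The warning in the log ("Param group 0 mixes regular tensors and DTensors; disabling foreach/fused optimizer kernels for that group") suggests a per-group check on parameter types. The sketch below models that check with stand-in classes. `Plain`, `Sharded`, and `disable_fast_paths_for_mixed_groups` are hypothetical names, and the real helper in the PR may inspect and modify groups differently.

```python
class Plain:
    """Stand-in for a regular tensor parameter."""


class Sharded:
    """Hypothetical stand-in for a DTensor parameter."""


def disable_fast_paths_for_mixed_groups(param_groups, sharded_type=Sharded):
    """Turn off foreach/fused in groups that mix sharded and plain parameters.

    Returns the indices of the groups that were changed.
    """
    changed = []
    for index, group in enumerate(param_groups):
        kinds = {isinstance(p, sharded_type) for p in group["params"]}
        if len(kinds) == 2:  # both sharded and non-sharded parameters present
            group["foreach"] = False
            group["fused"] = False
            changed.append(index)
    return changed


groups = [
    {"params": [Plain(), Sharded()], "foreach": None, "fused": True},
    {"params": [Sharded(), Sharded()], "foreach": None, "fused": True},
]
changed = disable_fast_paths_for_mixed_groups(groups)
```

Only the mixed group is downgraded; a group made up entirely of sharded (or entirely of plain) parameters keeps its original settings.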

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@3outeille 3outeille force-pushed the refactor-tp-dtensor branch 2 times, most recently from fcea5ce to f98e208 Compare April 4, 2026 16:53
Comment thread src/transformers/models/qwen3/modeling_qwen3.py Outdated
3outeille and others added 4 commits April 13, 2026 15:33
- Add apply_fully_shard_data_parallel() with auto/manual mode block detection
- FSDP vs DDP loss/grad parity tests
- Distributed test helpers (testing_utils.py)
- is_fsdp_enabled(), is_fsdp_managed_module() utilities
- Minimal FSDP hooks in from_pretrained
- FSDP-aware flash attention check
- DtensorShardOperation for range-math shard-on-read
- spawn_materialize() enhancements
- from_pretrained wiring for distributed config
- Shard operation helpers in tensor_parallel
- Shard-on-read and LoadStateDictConfig tests
@3outeille 3outeille force-pushed the fsdp-core-model-loading branch from 607cc11 to 739332c Compare April 13, 2026 14:14
- Replace hook-based TP with DTensor-based TPStyle API
- TPStyle dataclass with dense kinds: colwise, rowwise, vocab
- apply_tensor_parallel() using PyTorch parallelize_module
- verify_tp_plan() for plan validation
- Update dense model configs (llama, mistral, qwen2, phi, glm) to TPStyle
- DTensor apply_rotary_pos_emb guard for llama, mistral, qwen3
- Extended DistributedConfig with tp/fsdp size and plan fields
- DistributedConfig serialization in configuration_utils
- MXFP4 NotImplementedError for DTensor TP
- Dense TP tests
@3outeille 3outeille force-pushed the fsdp-core-model-loading branch from dbc9619 to c567240 Compare April 14, 2026 09:54
@3outeille 3outeille force-pushed the refactor-tp-dtensor branch from 34a5085 to eb428cc Compare April 14, 2026 09:54
- Re-export is_fsdp_enabled and is_fsdp_managed_module from
  integrations/fsdp.py (moved to distributed/utils.py)
- Remove unused # type: ignore comments in generation/utils.py
@3outeille 3outeille force-pushed the fsdp-core-model-loading branch from c567240 to c1dab9e Compare April 14, 2026 13:44
@github-actions
Contributor

CI Results

Workflow Run ⚙️

Commit Info

Context  Commit    Description
RUN      5423e3eb  workflow commit (merge commit)
PR       7fb37af0  branch commit (from PR)
main     38a8b55f  base commit (on main)

⚠️ Model CI failed to report results

The test failure analysis could not be completed. Please check the workflow run for details.

@3outeille 3outeille enabled auto-merge May 19, 2026 08:07
@3outeille 3outeille disabled auto-merge May 19, 2026 08:08
@3outeille 3outeille enabled auto-merge May 19, 2026 11:18
@3outeille 3outeille disabled auto-merge May 19, 2026 11:18
Restores legitimate improvements that were accidentally undone during a
stale merge of main into fsdp-vs-ddp:

- Restore test_resize_embeddings_untied_no_reinit_on_post_init
- Restore clipseg / Timm / evolla / parakeet_* / pi0 / musicflamingo
  special-cases
- Restore skip_base_model parameter on test_reverse_loading_mapping
- Restore "is not None" guard on subconfig in test_initialization
- Fix typo: "ot" -> "or" in test_reverse_loading_mapping assert message
@github-actions
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: afmoe, apertus, arcee, aria, audioflamingo3, bamba, bitnet, cohere, cohere2, csm, cwm, data2vec, dbrx, deepseek_v2, deepseek_v3, deepseek_v4

@3outeille 3outeille added this pull request to the merge queue May 19, 2026
Merged via the queue into main with commit 9ba8e85 May 19, 2026
158 of 162 checks passed
@3outeille 3outeille deleted the refactor-tp-dtensor branch May 19, 2026 13:03
Comment on lines 500 to 509
if enable_sp:
    base_model = getattr(model, model.base_model_prefix, model)

    def _inject_sp_metadata(mod, args, kwargs):
        input_ids = kwargs.get("input_ids", args[0] if args else None)
        if input_ids is None:
            return args, kwargs
        if "position_ids" not in kwargs or kwargs["position_ids"] is None:
            seq_len = input_ids.shape[1]
            kwargs["position_ids"] = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)
Collaborator


Flagging again: let's remove this.

Comment on lines +1101 to +1104
param_value = torch.nn.Parameter(dtensor_param, requires_grad=ref.requires_grad)
# super important otherwise _init_weight will re-init the param
param_value._is_hf_initialized = True
setattr(module_obj, param_name, param_value)
Collaborator


NICE

@remi-or remi-or mentioned this pull request May 20, 2026
vasqu added a commit that referenced this pull request May 28, 2026
* Revert "init FSDP through from_pretrained (#46102)"

This reverts commit 0588858.

* Revert "Fix FSDP2 and distributed checkpointing imports for older PyTorch versions (#46141)"

This reverts commit 634500b.

* Revert "Update cohere2_moe tp_plan (#46189)"

This reverts commit e65c3a2.

* Revert "FSDP + TP & native save/load distributed (#45028)"

This reverts commit 9ba8e85.

* fix

* they should have been deleted I think

* these are actually needed changes

* oops
IlyasMoutawwakil added a commit that referenced this pull request May 28, 2026
Resolves the FSDP+TP rewrite (PR #45028) which moved
`src/transformers/integrations/tensor_parallel.py` to
`src/transformers/distributed/tensor_parallel.py` under a new `MoEExpertsParallel`
TPStyle API. Accepted the deletion of the old TP file; `to_local` is now sourced
from `transformers.distributed.utils` so the FP8/sonicmoe/deepgemm integrations
import a single canonical helper.

V4 config: adopted upstream's `moe_experts_allreduce` rename and `base_model_fsdp_plan`,
preserved our indexer TP entries (`q_b_proj` colwise, `scorer.weights_proj` colwise,
`scorer` all_reduce).

Mega-MoE TP hooks (router-side remap skip, process_group injection, post-forward
all_reduce skip) are not yet ported to the new MoEExpertsParallel lifecycle —
tracked as a separate task.
yuchenxie4645 pushed a commit to yuchenxie4645/transformers that referenced this pull request May 28, 2026
* init

* FSDP2 (fully_shard) integration

- Add apply_fully_shard_data_parallel() with auto/manual mode block detection
- FSDP vs DDP loss/grad parity tests
- Distributed test helpers (testing_utils.py)
- is_fsdp_enabled(), is_fsdp_managed_module() utilities
- Minimal FSDP hooks in from_pretrained
- FSDP-aware flash attention check

* DistributedConfig + shard-on-read loading

- DtensorShardOperation for range-math shard-on-read
- spawn_materialize() enhancements
- from_pretrained wiring for distributed config
- Shard operation helpers in tensor_parallel
- Shard-on-read and LoadStateDictConfig tests

* TPStyle API + dense model tensor parallelism

- Replace hook-based TP with DTensor-based TPStyle API
- TPStyle dataclass with dense kinds: colwise, rowwise, vocab
- apply_tensor_parallel() using PyTorch parallelize_module
- verify_tp_plan() for plan validation
- Update dense model configs (llama, mistral, qwen2, phi, glm) to TPStyle
- DTensor apply_rotary_pos_emb guard for llama, mistral, qwen3
- Extended DistributedConfig with tp/fsdp size and plan fields
- DistributedConfig serialization in configuration_utils
- MXFP4 NotImplementedError for DTensor TP
- Dense TP tests

* revert some files

* Add distributed training scripts

- train_fsdp_tp.py: minimal FSDP+TP training example
- train_fsdp_tp_torchtitan_style.py: torchtitan-style training example
- verify_loading.py: save/load roundtrip verification
- run_compare.sh: FSDP+TP vs FSDP-only comparison
- run_verify_all.sh: run verification across all modes
- tmp_generate.py: quick generation test

* Remove train_fsdp_tp_torchtitan_style.py

* unify the utils for fsdp

* Fix CI: re-export moved FSDP utils + remove stale type: ignore

- Re-export is_fsdp_enabled and is_fsdp_managed_module from
  integrations/fsdp.py (moved to distributed/utils.py)
- Remove unused # type: ignore comments in generation/utils.py

* Fix ruff formatting in core_model_loading.py

* Fix ruff linting and formatting

* Backport new TP/FSDP API from orchestration-save-load branch

* Fix DTensor imports in Copied-from model files

* MoE expert parallelism + sequence parallelism (huggingface#45408)

* MoE expert parallelism + sequence parallelism

- Add PackedColwiseParallel for fused gate_up_proj weights
- Add MoEExpertsParallel with per-expert DTensor sharding
- Add PrepareModuleInputOutput for SP allgather/split hooks
- Add _AllReduceBackward for MoE routing weight gradients
- Extend TPStyle with moe_experts, packed_colwise, activation, module kinds
- _StridedShard handling in core_model_loading for interleaved weights
- MoE model configs: mixtral, deepseek_v3, qwen3 with SP plans
- DTensor rotary_pos_emb guard for mixtral

* Fix ruff linting and formatting

* Fix ruff formatting in core_model_loading.py

* Restore _IdentityOp accidentally removed in 25a1f48

The _IdentityOp class (added by PR huggingface#44983) was accidentally deleted
during the MoE expert parallelism work. It is needed by
finegrained_fp8.py and metal_quantization.py as a pass-through
reverse_op for dequantize operations.

Co-Authored-By: Claude Opus 4.6 (1M context) <[email protected]>

* Backport new TP/FSDP API + fix DTensor imports in Copied-from models

* from_pretrained orchestration + distributed save/load (huggingface#45409)

* from_pretrained orchestration + save/load

- Add gather_full_state_dict() for DTensor→full tensor saving
- Add convert_strided_to_shard() / restore_strided_from_shard() for DCP
- Add _redistribute_dtensor() helper
- Full distributed_config integration in from_pretrained/save_pretrained
- Rename apply_fsdp2 → apply_fully_shard_data_parallel
- save_optimizer() / load_optimizer() in distributed/utils
- Trainer integration with distributed_config
- Updated FSDP and TP tests for new orchestration API
- DTensor shard-on-read test updates

* revert distributed utils

* eaaea

* all tests for core modeling are passing

* populate import from init for tp

* ruff

* ruff

---------

Co-authored-by: Claude Opus 4.6 (1M context) <[email protected]>

* do monkey patching for rotary

* Revert modeling file diffs to match fsdp-core-model-loading base

Restores modeling files to their base branch versions so the PR diff
only shows the distributed/patches.py monkey-patch approach instead of
noisy function moves in modeling files.

Co-Authored-By: Claude Opus 4.6 (1M context) <[email protected]>

* Migrate all model TP plans from strings to TPStyle

- Convert string plan values ("colwise", "rowwise", etc.) to TPStyle
  objects across 66+ model configs and modular files
- Consolidate MoE expert sub-entries into TPStyle("moe_experts", ...)
  with shard_plan
- Remove "replicated_with_grad_allreduce" entries (not needed for
  DTensor TP)
- Migrate _tp_plan class attributes in modeling files from
  "colwise_gather_output" to TPStyle("colwise", "allgather")
- Add TypeError in apply_tensor_parallel for unsupported plan values
- Remove old TensorParallelLayer tests (API removed in DTensor refactor)
- Regenerate auto-generated files via modular converter

* Restore mxfp4.py to match base branch

* Drop mla_kv_a_proj and moe_identity_expert from TP plans

These string plan values have no TPStyle equivalent in the DTensor
system. Remove them to avoid TypeError at apply_tensor_parallel time.
Affected models: deepseek_v2, glm4_moe_lite, glm_moe_dsa, longcat_flash.

* more comments

* fix tp for most models.  PyTorch doesn't implement all placement conversions (e.g. _StridedShard↔Shard). We force replicate beforehand

* fix tp through _replicate_dtensor

* revert small change

* push temporary fix for TP and strided shard for backward

* refactor a bit

* patches for rotary

* refactor MoEExpertsParallel

* fix tp for last models

* refactor moe expert parallels

* linting

* add sp plan for models

* add deepseek v2 sp plan

* undo sp plan for some tricky models

* remove lm_head from  config

* first pass of refactoring dtensor shard operator

* better refacto

* batter explanation of DtensorShardOperation

* refactor dtensor test to reflect real world scenario

* more comments

* fix tp olmo hybrid and exaone

* Enhance tensor parallel weight tying logic to prevent clobbering of lm_head when embed_tokens is not in the plan.

* fix fsdp mixin test due to missing args

* fix test non model

* skip sp plan for exaone and olmo hybrid

* linting

* fix import for ci

* test distributed config

* attempt to fix guarding import ci

* fix ci check repro

* add ALL_PARALLEL_STYLES registry alongside TPStyle

* route apply_tensor_parallel through ALL_PARALLEL_STYLES

* migrate modular files to string-based TP plans

* migrate standalone configs and modelings to string-based TP plans

* delete TPStyle dataclass

* fix use_local_output defaults for SequenceParallel and PrepareModuleInput in registry

* use parallel style from torch

* revert changes in weight converter

* remove dead code in set_param_for_module

* remove dead code

* cleaning again

* cleaning

* revert change

* linting

* refactor dtensor shard ops

* revert some stuff in core model loading

* core model loading clean

* guarding import

* better separation tensor parall and generic utils

* isolate DtensorShardOperation into a separate file

* no need to patch rotary

* better seperation

* simplify gather_full_state_dict

* simplify _replicate_dtensor

* fix and clean _replicate_dtensor

* better doc for DtensorShardOperation

* fix saving optimizer with DCP for fused weights

* save_pretrained(distributed_checkpoint=true)

* linting

* refactor into a single function _dtensor_from_local_like

* zeros_like instead of empty_like

* move tp and fsdp under distributed

* distribute_model

* fix deadlock when saving

* clip grad norm function

* maybe_disable_foreach_and_fused_for_mixed_dtensor_groups

* better TP api for ease of understanding

* remove shard_param to make it easier

* fix import in test

* _swap_dtensor_params_for_local

* fix qwen3 nanochat dots1

* add tpu

* move TP refactor experimentation scripts to backup branch

Move ad-hoc training / verification / compare scripts off this branch
into refactor-tp-dtensor-scripts so the diff stays focused on library
changes.

Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>

* linting

* register distributed sharding_utils and utils in __init__

Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>

* rename TP plan styles to match new ALL_PARALLEL_STYLES registry

Replace pre-refactor names that no longer exist in
src/transformers/distributed/tensor_parallel.py:
  rowwise -> rowwise_allreduce
  moe_tp_experts -> moe_experts_allreduce
  replicated_with_grad_allreduce -> activation_seq_dim_2

Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>

* enable EP

* Add enable_expert_parallel configuration option in test_distributed_config

* no more auto mode

* edit fsdp plan to every other models

* update fsdp mixin tests

* linting

* fix test fsdp

* fsdp linting

* revert gitignore

* _apply within for loop

* rename

* doc sp plan

* fix

* unified settattr + torch no grad + _local_tensor

* revert

* linting

* fix ruff

* make check-repository-consistency

* trigger fsdp mixin test in CI

* fix fsdp ci

* Reset tests/test_modeling_common.py to main

Restores legitimate improvements that were accidentally undone during a
stale merge of main into fsdp-vs-ddp:

- Restore test_resize_embeddings_untied_no_reinit_on_post_init
- Restore clipseg / Timm / evolla / parakeet_* / pi0 / musicflamingo
  special-cases
- Restore skip_base_model parameter on test_reverse_loading_mapping
- Restore "is not None" guard on subconfig in test_initialization
- Fix typo: "ot" -> "or" in test_reverse_loading_mapping assert message

---------

Co-authored-by: Claude Opus 4.6 (1M context) <[email protected]>
yuchenxie4645 pushed a commit to yuchenxie4645/transformers that referenced this pull request May 28, 2026
kashif pushed a commit to kashif/transformers that referenced this pull request Jun 1, 2026
* init

* FSDP2 (fully_shard) integration

- Add apply_fully_shard_data_parallel() with auto/manual mode block detection
- FSDP vs DDP loss/grad parity tests
- Distributed test helpers (testing_utils.py)
- is_fsdp_enabled(), is_fsdp_managed_module() utilities
- Minimal FSDP hooks in from_pretrained
- FSDP-aware flash attention check

* DistributedConfig + shard-on-read loading

- DtensorShardOperation for range-math shard-on-read
- spawn_materialize() enhancements
- from_pretrained wiring for distributed config
- Shard operation helpers in tensor_parallel
- Shard-on-read and LoadStateDictConfig tests

* TPStyle API + dense model tensor parallelism

- Replace hook-based TP with DTensor-based TPStyle API
- TPStyle dataclass with dense kinds: colwise, rowwise, vocab
- apply_tensor_parallel() using PyTorch parallelize_module
- verify_tp_plan() for plan validation
- Update dense model configs (llama, mistral, qwen2, phi, glm) to TPStyle
- DTensor apply_rotary_pos_emb guard for llama, mistral, qwen3
- Extended DistributedConfig with tp/fsdp size and plan fields
- DistributedConfig serialization in configuration_utils
- MXFP4 NotImplementedError for DTensor TP
- Dense TP tests

* revert some files

* Add distributed training scripts

- train_fsdp_tp.py: minimal FSDP+TP training example
- train_fsdp_tp_torchtitan_style.py: torchtitan-style training example
- verify_loading.py: save/load roundtrip verification
- run_compare.sh: FSDP+TP vs FSDP-only comparison
- run_verify_all.sh: run verification across all modes
- tmp_generate.py: quick generation test

* Remove train_fsdp_tp_torchtitan_style.py

* unify the utils for fsdp

* Fix CI: re-export moved FSDP utils + remove stale type: ignore

- Re-export is_fsdp_enabled and is_fsdp_managed_module from
  integrations/fsdp.py (moved to distributed/utils.py)
- Remove unused # type: ignore comments in generation/utils.py

* Fix ruff formatting in core_model_loading.py

* Fix ruff linting and formatting

* Backport new TP/FSDP API from orchestration-save-load branch

* Fix DTensor imports in Copied-from model files

* MoE expert parallelism + sequence parallelism (huggingface#45408)

* MoE expert parallelism + sequence parallelism

- Add PackedColwiseParallel for fused gate_up_proj weights
- Add MoEExpertsParallel with per-expert DTensor sharding
- Add PrepareModuleInputOutput for SP allgather/split hooks
- Add _AllReduceBackward for MoE routing weight gradients
- Extend TPStyle with moe_experts, packed_colwise, activation, module kinds
- _StridedShard handling in core_model_loading for interleaved weights
- MoE model configs: mixtral, deepseek_v3, qwen3 with SP plans
- DTensor rotary_pos_emb guard for mixtral
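
The packed-weight items above are easier to follow with a small index example. This is a minimal sketch, assuming `gate_proj` and `up_proj` are concatenated along dim 0 of the fused `gate_up_proj` weight. The helper `packed_colwise_rows` is hypothetical and is not the library's implementation. With two ranks, a contiguous split would give rank 0 all gate rows and rank 1 all up rows, whereas each rank needs a slice of both, which is what a strided layout provides.

```python
def packed_colwise_rows(total_rows: int, num_packed: int, tp_size: int, rank: int) -> list[int]:
    """Row indices of a fused weight that one TP rank keeps.

    Hypothetical helper for illustration only; not part of transformers.
    """
    if total_rows % (num_packed * tp_size) != 0:
        raise ValueError("rows must divide evenly across packed blocks and ranks")
    block = total_rows // num_packed  # rows per packed projection (gate, up)
    shard = block // tp_size          # rows each rank keeps inside one projection
    rows = []
    for b in range(num_packed):
        start = b * block + rank * shard
        rows.extend(range(start, start + shard))
    return rows


# 8 fused rows: rows 0-3 are gate_proj, rows 4-7 are up_proj; 2 TP ranks.
rank0 = packed_colwise_rows(total_rows=8, num_packed=2, tp_size=2, rank=0)
rank1 = packed_colwise_rows(total_rows=8, num_packed=2, tp_size=2, rank=1)
print(rank0)  # [0, 1, 4, 5]
print(rank1)  # [2, 3, 6, 7]
```

Each rank ends up with the same fraction of every packed projection, so the local matmul output can still be split into gate and up halves.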

* Fix ruff linting and formatting

* Fix ruff formatting in core_model_loading.py

* Restore _IdentityOp accidentally removed in 25a1f48

The _IdentityOp class (added by PR huggingface#44983) was accidentally deleted
during the MoE expert parallelism work. It is needed by
finegrained_fp8.py and metal_quantization.py as a pass-through
reverse_op for dequantize operations.

Co-Authored-By: Claude Opus 4.6 (1M context) <[email protected]>

* Backport new TP/FSDP API + fix DTensor imports in Copied-from models

* from_pretrained orchestration + distributed save/load (huggingface#45409)

* from_pretrained orchestration + save/load

- Add gather_full_state_dict() for DTensor→full tensor saving
- Add convert_strided_to_shard() / restore_strided_from_shard() for DCP
- Add _redistribute_dtensor() helper
- Full distributed_config integration in from_pretrained/save_pretrained
- Rename apply_fsdp2 → apply_fully_shard_data_parallel
- save_optimizer() / load_optimizer() in distributed/utils
- Trainer integration with distributed_config
- Updated FSDP and TP tests for new orchestration API
- DTensor shard-on-read test updates

* revert distributed utils

* eaaea

* all tests for core modeling are passing

* populate import from init for tp

* ruff

* ruff

---------

Co-authored-by: Claude Opus 4.6 (1M context) <[email protected]>

* do monkey patching for rotary

* Revert modeling file diffs to match fsdp-core-model-loading base

Restores modeling files to their base branch versions so the PR diff
only shows the distributed/patches.py monkey-patch approach instead of
noisy function moves in modeling files.

Co-Authored-By: Claude Opus 4.6 (1M context) <[email protected]>

* Migrate all model TP plans from strings to TPStyle

- Convert string plan values ("colwise", "rowwise", etc.) to TPStyle
  objects across 66+ model configs and modular files
- Consolidate MoE expert sub-entries into TPStyle("moe_experts", ...)
  with shard_plan
- Remove "replicated_with_grad_allreduce" entries (not needed for
  DTensor TP)
- Migrate _tp_plan class attributes in modeling files from
  "colwise_gather_output" to TPStyle("colwise", "allgather")
- Add TypeError in apply_tensor_parallel for unsupported plan values
- Remove old TensorParallelLayer tests (API removed in DTensor refactor)
- Regenerate auto-generated files via modular converter

* Restore mxfp4.py to match base branch

* Drop mla_kv_a_proj and moe_identity_expert from TP plans

These string plan values have no TPStyle equivalent in the DTensor
system. Remove them to avoid TypeError at apply_tensor_parallel time.
Affected models: deepseek_v2, glm4_moe_lite, glm_moe_dsa, longcat_flash.

* more comments

* fix tp for most models. PyTorch doesn't implement all placement conversions (e.g. _StridedShard↔Shard), so we force replication beforehand

* fix tp through _replicate_dtensor

* revert small change

* push temporary fix for TP and strided shard for backward

* refactor a bit

* patches for rotary

* refactor MoEExpertsParallel

* fix tp for last models

* refactor moe expert parallels

* linting

* add sp plan for models

* add deepseek v2 sp plan

* undo sp plan for some tricky models

* remove lm_head from config

* first pass of refactoring dtensor shard operator

* better refactor

* better explanation of DtensorShardOperation

* refactor dtensor test to reflect real world scenario

* more comments

* fix tp olmo hybrid and exaone

* Enhance tensor parallel weight tying logic to prevent clobbering of lm_head when embed_tokens is not in the plan.

* fix fsdp mixin test due to missing args

* fix test non model

* skip sp plan for exaone and olmo hybrid

* linting

* fix import for ci

* test distributed config

* attempt to fix guarding import ci

* fix ci check repro

* add ALL_PARALLEL_STYLES registry alongside TPStyle

* route apply_tensor_parallel through ALL_PARALLEL_STYLES
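
The registry routing in the two commits above can be pictured as a name lookup followed by pattern matching on module names. This is a rough sketch under stated assumptions: `EXAMPLE_STYLES` and `resolve_plan` are hypothetical names, the registry values are placeholder strings rather than torch `ParallelStyle` instances, `*` is assumed to stand for a layer index, and raising `ValueError` on an unknown name is a choice made for this example. The plan shown is illustrative and is not taken from a real model.

```python
import re

# Hypothetical stand-in for a style registry: names map to placeholder objects.
EXAMPLE_STYLES = {
    "colwise_allgather": "<colwise style>",
    "rowwise_allreduce": "<rowwise style>",
}


def resolve_plan(plan: dict[str, str], module_names: list[str]) -> dict[str, str]:
    """Map each matching module name to its style object; fail on unknown names."""
    resolved = {}
    for pattern, style_name in plan.items():
        if style_name not in EXAMPLE_STYLES:
            raise ValueError(f"unknown style {style_name!r} for pattern {pattern!r}")
        # Assumption: "*" in a pattern stands for a numeric layer index.
        regex = re.compile(re.escape(pattern).replace(r"\*", r"\d+"))
        for name in module_names:
            if regex.fullmatch(name):
                resolved[name] = EXAMPLE_STYLES[style_name]
    return resolved


plan = {
    "layers.*.proj_in": "colwise_allgather",
    "layers.*.proj_out": "rowwise_allreduce",
}
names = ["layers.0.proj_in", "layers.0.proj_out", "layers.1.proj_in", "embed_tokens"]
resolved = resolve_plan(plan, names)
print(sorted(resolved))
```

Keeping plans as plain strings means model configs stay serializable, while the lookup step is the single place where an unsupported name is rejected.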

* migrate modular files to string-based TP plans

* migrate standalone configs and modelings to string-based TP plans

* delete TPStyle dataclass

* fix use_local_output defaults for SequenceParallel and PrepareModuleInput in registry

* use parallel style from torch

* revert changes in weight converter

* remove dead code in set_param_for_module

* remove dead code

* cleaning again

* cleaning

* revert change

* linting

* refactor dtensor shard ops

* revert some stuff in core model loading

* core model loading clean

* guarding import

* better separation of tensor parallel and generic utils

* isolate DtensorShardOperation into a separate file

* no need to patch rotary

* better separation

* simplify gather_full_state_dict

* simplify _replicate_dtensor

* fix and clean _replicate_dtensor

* better doc for DtensorShardOperation

* fix saving optimizer with DCP for fused weights

* save_pretrained(distributed_checkpoint=True)

* linting

* refactor into a single function _dtensor_from_local_like

* zeros_like instead of empty_like

* move tp and fsdp under distributed

* distribute_model

* fix deadlock when saving

* clip grad norm function

* maybe_disable_foreach_and_fused_for_mixed_dtensor_groups

* better TP api for ease of understanding

* remove shard_param to make it easier

* fix import in test

* _swap_dtensor_params_for_local

* fix qwen3 nanochat dots1

* add tpu

* move TP refactor experimentation scripts to backup branch

Move ad-hoc training / verification / compare scripts off this branch
into refactor-tp-dtensor-scripts so the diff stays focused on library
changes.

Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>

* linting

* register distributed sharding_utils and utils in __init__

Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>

* rename TP plan styles to match new ALL_PARALLEL_STYLES registry

Replace pre-refactor names that no longer exist in
src/transformers/distributed/tensor_parallel.py:
  rowwise -> rowwise_allreduce
  moe_tp_experts -> moe_experts_allreduce
  replicated_with_grad_allreduce -> activation_seq_dim_2
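
For a plan written against the old names, the renames listed above amount to a value substitution. This is a minimal sketch; `RENAMED_STYLES` and `migrate_tp_plan` are hypothetical names used only for illustration, and the module patterns in the example plan are made up.

```python
# Old -> new style names, copied from the commit message above.
RENAMED_STYLES = {
    "rowwise": "rowwise_allreduce",
    "moe_tp_experts": "moe_experts_allreduce",
    "replicated_with_grad_allreduce": "activation_seq_dim_2",
}


def migrate_tp_plan(plan: dict[str, str]) -> dict[str, str]:
    """Return a copy of `plan` with pre-refactor style names replaced."""
    return {pattern: RENAMED_STYLES.get(style, style) for pattern, style in plan.items()}


old_plan = {
    "layers.*.mlp.down_proj": "rowwise",
    "layers.*.mlp.experts": "moe_tp_experts",
    "layers.*.mlp.up_proj": "packed_colwise",  # not renamed, passes through
}
new_plan = migrate_tp_plan(old_plan)
print(new_plan)
```

Only exact matches are rewritten, so names that already exist in the new registry are left untouched.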

Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>

* enable EP

* Add enable_expert_parallel configuration option in test_distributed_config

* no more auto mode

* edit fsdp plan for all other models

* update fsdp mixin tests

* linting

* fix test fsdp

* fsdp linting

* revert gitignore

* _apply within for loop

* rename

* doc sp plan

* fix

* unified setattr + torch no grad + _local_tensor

* revert

* linting

* fix ruff

* make check-repository-consistency

* trigger fsdp mixin test in CI

* fix fsdp ci

* Reset tests/test_modeling_common.py to main

Restores legitimate improvements that were accidentally undone during a
stale merge of main into fsdp-vs-ddp:

- Restore test_resize_embeddings_untied_no_reinit_on_post_init
- Restore clipseg / Timm / evolla / parakeet_* / pi0 / musicflamingo
  special-cases
- Restore skip_base_model parameter on test_reverse_loading_mapping
- Restore "is not None" guard on subconfig in test_initialization
- Fix typo: "ot" -> "or" in test_reverse_loading_mapping assert message

---------

Co-authored-by: Claude Opus 4.6 (1M context) <[email protected]>
kashif pushed a commit to kashif/transformers that referenced this pull request Jun 1, 2026
* Revert "init FSDP through from_pretrained (huggingface#46102)"

This reverts commit 0588858.

* Revert "Fix FSDP2 and distributed checkpointing imports for older PyTorch versions (huggingface#46141)"

This reverts commit 634500b.

* Revert "Update cohere2_moe tp_plan (huggingface#46189)"

This reverts commit e65c3a2.

* Revert "FSDP + TP & native save/load distributed (huggingface#45028)"

This reverts commit 9ba8e85.

* fix

* they should have been deleted I think

* these are actually needed changes

* oops