FSDP + TP & native save/load distributed#45028
Merged
Merged
Conversation
|
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
fcea5ce to
f98e208
Compare
3outeille
commented
Apr 8, 2026
- Add apply_fully_shard_data_parallel() with auto/manual mode block detection - FSDP vs DDP loss/grad parity tests - Distributed test helpers (testing_utils.py) - is_fsdp_enabled(), is_fsdp_managed_module() utilities - Minimal FSDP hooks in from_pretrained - FSDP-aware flash attention check
- DtensorShardOperation for range-math shard-on-read - spawn_materialize() enhancements - from_pretrained wiring for distributed config - Shard operation helpers in tensor_parallel - Shard-on-read and LoadStateDictConfig tests
607cc11 to
739332c
Compare
- Replace hook-based TP with DTensor-based TPStyle API - TPStyle dataclass with dense kinds: colwise, rowwise, vocab - apply_tensor_parallel() using PyTorch parallelize_module - verify_tp_plan() for plan validation - Update dense model configs (llama, mistral, qwen2, phi, glm) to TPStyle - DTensor apply_rotary_pos_emb guard for llama, mistral, qwen3 - Extended DistributedConfig with tp/fsdp size and plan fields - DistributedConfig serialization in configuration_utils - MXFP4 NotImplementedError for DTensor TP - Dense TP tests
1aa7f5f to
11b55a2
Compare
3 tasks
- train_fsdp_tp.py: minimal FSDP+TP training example - train_fsdp_tp_torchtitan_style.py: torchtitan-style training example - verify_loading.py: save/load roundtrip verification - run_compare.sh: FSDP+TP vs FSDP-only comparison - run_verify_all.sh: run verification across all modes - tmp_generate.py: quick generation test
dbc9619 to
c567240
Compare
34a5085 to
eb428cc
Compare
- Re-export is_fsdp_enabled and is_fsdp_managed_module from integrations/fsdp.py (moved to distributed/utils.py) - Remove unused # type: ignore comments in generation/utils.py
c567240 to
c1dab9e
Compare
Contributor
CI ResultsCommit Info
The test failure analysis could not be completed. Please check the workflow run for details. |
Restores legitimate improvements that were accidentally undone during a stale merge of main into fsdp-vs-ddp: - Restore test_resize_embeddings_untied_no_reinit_on_post_init - Restore clipseg / Timm / evolla / parakeet_* / pi0 / musicflamingo special-cases - Restore skip_base_model parameter on test_reverse_loading_mapping - Restore "is not None" guard on subconfig in test_initialization - Fix typo: "ot" -> "or" in test_reverse_loading_mapping assert message
Contributor
|
[For maintainers] Suggested jobs to run (before merge) run-slow: afmoe, apertus, arcee, aria, audioflamingo3, bamba, bitnet, cohere, cohere2, csm, cwm, data2vec, dbrx, deepseek_v2, deepseek_v3, deepseek_v4 |
Comment on lines
500
to
509
| if enable_sp: | ||
| base_model = getattr(model, model.base_model_prefix, model) | ||
|
|
||
| def _inject_sp_metadata(mod, args, kwargs): | ||
| input_ids = kwargs.get("input_ids", args[0] if args else None) | ||
| if input_ids is None: | ||
| return args, kwargs | ||
| if "position_ids" not in kwargs or kwargs["position_ids"] is None: | ||
| seq_len = input_ids.shape[1] | ||
| kwargs["position_ids"] = torch.arange(seq_len, device=input_ids.device).unsqueeze(0) |
Collaborator
There was a problem hiding this comment.
flagging again let's remove
Comment on lines
+1101
to
+1104
| param_value = torch.nn.Parameter(dtensor_param, requires_grad=ref.requires_grad) | ||
| # super important otherwise _init_weight will re-init the param | ||
| param_value._is_hf_initialized = True | ||
| setattr(module_obj, param_name, param_value) |
Closed
This was referenced May 25, 2026
vasqu
added a commit
that referenced
this pull request
May 28, 2026
* Revert "init FSDP through from_pretrained (#46102)" This reverts commit 0588858. * Revert "Fix FSDP2 and distributed checkpointing imports for older PyTorch versions (#46141)" This reverts commit 634500b. * Revert "Update cohere2_moe tp_plan (#46189)" This reverts commit e65c3a2. * Revert "FSDP + TP & native save/load distributed (#45028)" This reverts commit 9ba8e85. * fix * they should have been deleted I think * these are actually needed changes * oops
IlyasMoutawwakil
added a commit
that referenced
this pull request
May 28, 2026
Resolves the FSDP+TP rewrite (PR #45028) which moved `src/transformers/integrations/tensor_parallel.py` to `src/transformers/distributed/tensor_parallel.py` under a new `MoEExpertsParallel` TPStyle API. Accepted the deletion of the old TP file; `to_local` is now sourced from `transformers.distributed.utils` so the FP8/sonicmoe/deepgemm integrations import a single canonical helper. V4 config: adopted upstream's `moe_experts_allreduce` rename and `base_model_fsdp_plan`, preserved our indexer TP entries (`q_b_proj` colwise, `scorer.weights_proj` colwise, `scorer` all_reduce). Mega-MoE TP hooks (router-side remap skip, process_group injection, post-forward all_reduce skip) are not yet ported to the new MoEExpertsParallel lifecycle — tracked as a separate task.
yuchenxie4645
pushed a commit
to yuchenxie4645/transformers
that referenced
this pull request
May 28, 2026
* init * FSDP2 (fully_shard) integration - Add apply_fully_shard_data_parallel() with auto/manual mode block detection - FSDP vs DDP loss/grad parity tests - Distributed test helpers (testing_utils.py) - is_fsdp_enabled(), is_fsdp_managed_module() utilities - Minimal FSDP hooks in from_pretrained - FSDP-aware flash attention check * DistributedConfig + shard-on-read loading - DtensorShardOperation for range-math shard-on-read - spawn_materialize() enhancements - from_pretrained wiring for distributed config - Shard operation helpers in tensor_parallel - Shard-on-read and LoadStateDictConfig tests * TPStyle API + dense model tensor parallelism - Replace hook-based TP with DTensor-based TPStyle API - TPStyle dataclass with dense kinds: colwise, rowwise, vocab - apply_tensor_parallel() using PyTorch parallelize_module - verify_tp_plan() for plan validation - Update dense model configs (llama, mistral, qwen2, phi, glm) to TPStyle - DTensor apply_rotary_pos_emb guard for llama, mistral, qwen3 - Extended DistributedConfig with tp/fsdp size and plan fields - DistributedConfig serialization in configuration_utils - MXFP4 NotImplementedError for DTensor TP - Dense TP tests * revert some files * Add distributed training scripts - train_fsdp_tp.py: minimal FSDP+TP training example - train_fsdp_tp_torchtitan_style.py: torchtitan-style training example - verify_loading.py: save/load roundtrip verification - run_compare.sh: FSDP+TP vs FSDP-only comparison - run_verify_all.sh: run verification across all modes - tmp_generate.py: quick generation test * Remove train_fsdp_tp_torchtitan_style.py * unify the utils for fsdp * Fix CI: re-export moved FSDP utils + remove stale type: ignore - Re-export is_fsdp_enabled and is_fsdp_managed_module from integrations/fsdp.py (moved to distributed/utils.py) - Remove unused # type: ignore comments in generation/utils.py * Fix ruff formatting in core_model_loading.py * Fix ruff linting and formatting * Backport new TP/FSDP API from orchestration-save-load branch * Fix DTensor imports in Copied-from model files * MoE expert parallelism + sequence parallelism (huggingface#45408) * MoE expert parallelism + sequence parallelism - Add PackedColwiseParallel for fused gate_up_proj weights - Add MoEExpertsParallel with per-expert DTensor sharding - Add PrepareModuleInputOutput for SP allgather/split hooks - Add _AllReduceBackward for MoE routing weight gradients - Extend TPStyle with moe_experts, packed_colwise, activation, module kinds - _StridedShard handling in core_model_loading for interleaved weights - MoE model configs: mixtral, deepseek_v3, qwen3 with SP plans - DTensor rotary_pos_emb guard for mixtral * Fix ruff linting and formatting * Fix ruff formatting in core_model_loading.py * Restore _IdentityOp accidentally removed in 25a1f48 The _IdentityOp class (added by PR huggingface#44983) was accidentally deleted during the MoE expert parallelism work. It is needed by finegrained_fp8.py and metal_quantization.py as a pass-through reverse_op for dequantize operations. Co-Authored-By: Claude Opus 4.6 (1M context) <[email protected]> * Backport new TP/FSDP API + fix DTensor imports in Copied-from models * from_pretrained orchestration + distributed save/load (huggingface#45409) * from_pretrained orchestration + save/load - Add gather_full_state_dict() for DTensor→full tensor saving - Add convert_strided_to_shard() / restore_strided_from_shard() for DCP - Add _redistribute_dtensor() helper - Full distributed_config integration in from_pretrained/save_pretrained - Rename apply_fsdp2 → apply_fully_shard_data_parallel - save_optimizer() / load_optimizer() in distributed/utils - Trainer integration with distributed_config - Updated FSDP and TP tests for new orchestration API - DTensor shard-on-read test updates * revert distributed utils * eaaea * all tests for core modeling are passing * populate import from init for tp * ruff * ruff --------- Co-authored-by: Claude Opus 4.6 (1M context) <[email protected]> * do monkey patching for rotary * Revert modeling file diffs to match fsdp-core-model-loading base Restores modeling files to their base branch versions so the PR diff only shows the distributed/patches.py monkey-patch approach instead of noisy function moves in modeling files. Co-Authored-By: Claude Opus 4.6 (1M context) <[email protected]> * Migrate all model TP plans from strings to TPStyle - Convert string plan values ("colwise", "rowwise", etc.) to TPStyle objects across 66+ model configs and modular files - Consolidate MoE expert sub-entries into TPStyle("moe_experts", ...) with shard_plan - Remove "replicated_with_grad_allreduce" entries (not needed for DTensor TP) - Migrate _tp_plan class attributes in modeling files from "colwise_gather_output" to TPStyle("colwise", "allgather") - Add TypeError in apply_tensor_parallel for unsupported plan values - Remove old TensorParallelLayer tests (API removed in DTensor refactor) - Regenerate auto-generated files via modular converter * Restore mxfp4.py to match base branch * Drop mla_kv_a_proj and moe_identity_expert from TP plans These string plan values have no TPStyle equivalent in the DTensor system. Remove them to avoid TypeError at apply_tensor_parallel time. Affected models: deepseek_v2, glm4_moe_lite, glm_moe_dsa, longcat_flash. * more comments * fix tp for most models. PyTorch doesn't implement all placement conversions (e.g. _StridedShard↔Shard). We force replicate beforehand * fix tp through _replicate_dtensor * revert small change * push temporary fix for TP and strided shard for backward * refactor a bit * patches for rotary * refactor MoEExpertsParallel * fix tp for last models * refactor moe expert parallels * linting * add sp plan for models * add deepseek v2 sp plan * undo sp plan for some tricky models * remove lm_head from config * first pass of refactoring dtensor shard operator * better refacto * batter explanation of DtensorShardOperation * refactor dtensor test to reflect real world scenario * more comments * fix tp olmo hybrid and exaone * Enhance tensor parallel weight tying logic to prevent clobbering of lm_head when embed_tokens is not in the plan. * fix fsdp mixin test due to missing args * fix test non model * skip sp plan for exaone and olmo hybrid * linting * fix import for ci * test distributed config * attempt to fix guarding import ci * fix ci check repro * add ALL_PARALLEL_STYLES registry alongside TPStyle * route apply_tensor_parallel through ALL_PARALLEL_STYLES * migrate modular files to string-based TP plans * migrate standalone configs and modelings to string-based TP plans * delete TPStyle dataclass * fix use_local_output defaults for SequenceParallel and PrepareModuleInput in registry * use parallel style from torch * revert changes in weight converter * remove dead code in set_param_for_module * remove dead code * cleaning again * cleaning * revert change * linting * refactor dtensor shard ops * revert some stuff in core model loading * core model loading clean * guarding import * better separation tensor parall and generic utils * isolate DtensorShardOperation into a separate file * no need to patch rotary * better seperation * simplify gather_full_state_dict * simplify _replicate_dtensor * fix and clean _replicate_dtensor * better doc for DtensorShardOperation * fix saving optimizer with DCP for fused weights * save_pretrained(distributed_checkpoint=true) * linting * refactor into a single function _dtensor_from_local_like * zeros_like instead of empty_like * move tp and fsdp under distributed * distribute_model * fix deadlock when saving * clip grad norm function * maybe_disable_foreach_and_fused_for_mixed_dtensor_groups * better TP api for ease of understanding * remove shard_param to make it easier * fix import in test * _swap_dtensor_params_for_local * fix qwen3 nanochat dots1 * add tpu * move TP refactor experimentation scripts to backup branch Move ad-hoc training / verification / compare scripts off this branch into refactor-tp-dtensor-scripts so the diff stays focused on library changes. Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]> * linting * register distributed sharding_utils and utils in __init__ Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]> * rename TP plan styles to match new ALL_PARALLEL_STYLES registry Replace pre-refactor names that no longer exist in src/transformers/distributed/tensor_parallel.py: rowwise -> rowwise_allreduce moe_tp_experts -> moe_experts_allreduce replicated_with_grad_allreduce -> activation_seq_dim_2 Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]> * enable EP * Add enable_expert_parallel configuration option in test_distributed_config * no more auto mode * edit fsdp plan to every other models * update fsdp mixin tests * linting * fix test fsdp * fsdp linting * revert gitignore * _apply within for loop * rename * doc sp plan * fix * unified settattr + torch no grad + _local_tensor * revert * linting * fix ruff * make check-repository-consistency * trigger fsdp mixin test in CI * fix fsdp ci * Reset tests/test_modeling_common.py to main Restores legitimate improvements that were accidentally undone during a stale merge of main into fsdp-vs-ddp: - Restore test_resize_embeddings_untied_no_reinit_on_post_init - Restore clipseg / Timm / evolla / parakeet_* / pi0 / musicflamingo special-cases - Restore skip_base_model parameter on test_reverse_loading_mapping - Restore "is not None" guard on subconfig in test_initialization - Fix typo: "ot" -> "or" in test_reverse_loading_mapping assert message --------- Co-authored-by: Claude Opus 4.6 (1M context) <[email protected]>
yuchenxie4645
pushed a commit
to yuchenxie4645/transformers
that referenced
this pull request
May 28, 2026
* Revert "init FSDP through from_pretrained (huggingface#46102)" This reverts commit 0588858. * Revert "Fix FSDP2 and distributed checkpointing imports for older PyTorch versions (huggingface#46141)" This reverts commit 634500b. * Revert "Update cohere2_moe tp_plan (huggingface#46189)" This reverts commit e65c3a2. * Revert "FSDP + TP & native save/load distributed (huggingface#45028)" This reverts commit 9ba8e85. * fix * they should have been deleted I think * these are actually needed changes * oops
kashif
pushed a commit
to kashif/transformers
that referenced
this pull request
Jun 1, 2026
* init * FSDP2 (fully_shard) integration - Add apply_fully_shard_data_parallel() with auto/manual mode block detection - FSDP vs DDP loss/grad parity tests - Distributed test helpers (testing_utils.py) - is_fsdp_enabled(), is_fsdp_managed_module() utilities - Minimal FSDP hooks in from_pretrained - FSDP-aware flash attention check * DistributedConfig + shard-on-read loading - DtensorShardOperation for range-math shard-on-read - spawn_materialize() enhancements - from_pretrained wiring for distributed config - Shard operation helpers in tensor_parallel - Shard-on-read and LoadStateDictConfig tests * TPStyle API + dense model tensor parallelism - Replace hook-based TP with DTensor-based TPStyle API - TPStyle dataclass with dense kinds: colwise, rowwise, vocab - apply_tensor_parallel() using PyTorch parallelize_module - verify_tp_plan() for plan validation - Update dense model configs (llama, mistral, qwen2, phi, glm) to TPStyle - DTensor apply_rotary_pos_emb guard for llama, mistral, qwen3 - Extended DistributedConfig with tp/fsdp size and plan fields - DistributedConfig serialization in configuration_utils - MXFP4 NotImplementedError for DTensor TP - Dense TP tests * revert some files * Add distributed training scripts - train_fsdp_tp.py: minimal FSDP+TP training example - train_fsdp_tp_torchtitan_style.py: torchtitan-style training example - verify_loading.py: save/load roundtrip verification - run_compare.sh: FSDP+TP vs FSDP-only comparison - run_verify_all.sh: run verification across all modes - tmp_generate.py: quick generation test * Remove train_fsdp_tp_torchtitan_style.py * unify the utils for fsdp * Fix CI: re-export moved FSDP utils + remove stale type: ignore - Re-export is_fsdp_enabled and is_fsdp_managed_module from integrations/fsdp.py (moved to distributed/utils.py) - Remove unused # type: ignore comments in generation/utils.py * Fix ruff formatting in core_model_loading.py * Fix ruff linting and formatting * Backport new TP/FSDP API from orchestration-save-load branch * Fix DTensor imports in Copied-from model files * MoE expert parallelism + sequence parallelism (huggingface#45408) * MoE expert parallelism + sequence parallelism - Add PackedColwiseParallel for fused gate_up_proj weights - Add MoEExpertsParallel with per-expert DTensor sharding - Add PrepareModuleInputOutput for SP allgather/split hooks - Add _AllReduceBackward for MoE routing weight gradients - Extend TPStyle with moe_experts, packed_colwise, activation, module kinds - _StridedShard handling in core_model_loading for interleaved weights - MoE model configs: mixtral, deepseek_v3, qwen3 with SP plans - DTensor rotary_pos_emb guard for mixtral * Fix ruff linting and formatting * Fix ruff formatting in core_model_loading.py * Restore _IdentityOp accidentally removed in 25a1f48 The _IdentityOp class (added by PR huggingface#44983) was accidentally deleted during the MoE expert parallelism work. It is needed by finegrained_fp8.py and metal_quantization.py as a pass-through reverse_op for dequantize operations. Co-Authored-By: Claude Opus 4.6 (1M context) <[email protected]> * Backport new TP/FSDP API + fix DTensor imports in Copied-from models * from_pretrained orchestration + distributed save/load (huggingface#45409) * from_pretrained orchestration + save/load - Add gather_full_state_dict() for DTensor→full tensor saving - Add convert_strided_to_shard() / restore_strided_from_shard() for DCP - Add _redistribute_dtensor() helper - Full distributed_config integration in from_pretrained/save_pretrained - Rename apply_fsdp2 → apply_fully_shard_data_parallel - save_optimizer() / load_optimizer() in distributed/utils - Trainer integration with distributed_config - Updated FSDP and TP tests for new orchestration API - DTensor shard-on-read test updates * revert distributed utils * eaaea * all tests for core modeling are passing * populate import from init for tp * ruff * ruff --------- Co-authored-by: Claude Opus 4.6 (1M context) <[email protected]> * do monkey patching for rotary * Revert modeling file diffs to match fsdp-core-model-loading base Restores modeling files to their base branch versions so the PR diff only shows the distributed/patches.py monkey-patch approach instead of noisy function moves in modeling files. Co-Authored-By: Claude Opus 4.6 (1M context) <[email protected]> * Migrate all model TP plans from strings to TPStyle - Convert string plan values ("colwise", "rowwise", etc.) to TPStyle objects across 66+ model configs and modular files - Consolidate MoE expert sub-entries into TPStyle("moe_experts", ...) with shard_plan - Remove "replicated_with_grad_allreduce" entries (not needed for DTensor TP) - Migrate _tp_plan class attributes in modeling files from "colwise_gather_output" to TPStyle("colwise", "allgather") - Add TypeError in apply_tensor_parallel for unsupported plan values - Remove old TensorParallelLayer tests (API removed in DTensor refactor) - Regenerate auto-generated files via modular converter * Restore mxfp4.py to match base branch * Drop mla_kv_a_proj and moe_identity_expert from TP plans These string plan values have no TPStyle equivalent in the DTensor system. Remove them to avoid TypeError at apply_tensor_parallel time. Affected models: deepseek_v2, glm4_moe_lite, glm_moe_dsa, longcat_flash. * more comments * fix tp for most models. PyTorch doesn't implement all placement conversions (e.g. _StridedShard↔Shard). We force replicate beforehand * fix tp through _replicate_dtensor * revert small change * push temporary fix for TP and strided shard for backward * refactor a bit * patches for rotary * refactor MoEExpertsParallel * fix tp for last models * refactor moe expert parallels * linting * add sp plan for models * add deepseek v2 sp plan * undo sp plan for some tricky models * remove lm_head from config * first pass of refactoring dtensor shard operator * better refacto * batter explanation of DtensorShardOperation * refactor dtensor test to reflect real world scenario * more comments * fix tp olmo hybrid and exaone * Enhance tensor parallel weight tying logic to prevent clobbering of lm_head when embed_tokens is not in the plan. * fix fsdp mixin test due to missing args * fix test non model * skip sp plan for exaone and olmo hybrid * linting * fix import for ci * test distributed config * attempt to fix guarding import ci * fix ci check repro * add ALL_PARALLEL_STYLES registry alongside TPStyle * route apply_tensor_parallel through ALL_PARALLEL_STYLES * migrate modular files to string-based TP plans * migrate standalone configs and modelings to string-based TP plans * delete TPStyle dataclass * fix use_local_output defaults for SequenceParallel and PrepareModuleInput in registry * use parallel style from torch * revert changes in weight converter * remove dead code in set_param_for_module * remove dead code * cleaning again * cleaning * revert change * linting * refactor dtensor shard ops * revert some stuff in core model loading * core model loading clean * guarding import * better separation tensor parall and generic utils * isolate DtensorShardOperation into a separate file * no need to patch rotary * better seperation * simplify gather_full_state_dict * simplify _replicate_dtensor * fix and clean _replicate_dtensor * better doc for DtensorShardOperation * fix saving optimizer with DCP for fused weights * save_pretrained(distributed_checkpoint=true) * linting * refactor into a single function _dtensor_from_local_like * zeros_like instead of empty_like * move tp and fsdp under distributed * distribute_model * fix deadlock when saving * clip grad norm function * maybe_disable_foreach_and_fused_for_mixed_dtensor_groups * better TP api for ease of understanding * remove shard_param to make it easier * fix import in test * _swap_dtensor_params_for_local * fix qwen3 nanochat dots1 * add tpu * move TP refactor experimentation scripts to backup branch Move ad-hoc training / verification / compare scripts off this branch into refactor-tp-dtensor-scripts so the diff stays focused on library changes. Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]> * linting * register distributed sharding_utils and utils in __init__ Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]> * rename TP plan styles to match new ALL_PARALLEL_STYLES registry Replace pre-refactor names that no longer exist in src/transformers/distributed/tensor_parallel.py: rowwise -> rowwise_allreduce moe_tp_experts -> moe_experts_allreduce replicated_with_grad_allreduce -> activation_seq_dim_2 Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]> * enable EP * Add enable_expert_parallel configuration option in test_distributed_config * no more auto mode * edit fsdp plan to every other models * update fsdp mixin tests * linting * fix test fsdp * fsdp linting * revert gitignore * _apply within for loop * rename * doc sp plan * fix * unified settattr + torch no grad + _local_tensor * revert * linting * fix ruff * make check-repository-consistency * trigger fsdp mixin test in CI * fix fsdp ci * Reset tests/test_modeling_common.py to main Restores legitimate improvements that were accidentally undone during a stale merge of main into fsdp-vs-ddp: - Restore test_resize_embeddings_untied_no_reinit_on_post_init - Restore clipseg / Timm / evolla / parakeet_* / pi0 / musicflamingo special-cases - Restore skip_base_model parameter on test_reverse_loading_mapping - Restore "is not None" guard on subconfig in test_initialization - Fix typo: "ot" -> "or" in test_reverse_loading_mapping assert message --------- Co-authored-by: Claude Opus 4.6 (1M context) <[email protected]>
kashif
pushed a commit
to kashif/transformers
that referenced
this pull request
Jun 1, 2026
* Revert "init FSDP through from_pretrained (huggingface#46102)" This reverts commit 0588858. * Revert "Fix FSDP2 and distributed checkpointing imports for older PyTorch versions (huggingface#46141)" This reverts commit 634500b. * Revert "Update cohere2_moe tp_plan (huggingface#46189)" This reverts commit e65c3a2. * Revert "FSDP + TP & native save/load distributed (huggingface#45028)" This reverts commit 9ba8e85. * fix * they should have been deleted I think * these are actually needed changes * oops
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Highlights
DistributedConfigtofrom_pretrainedto applyTP,FSDP2, or both — sharding happens during weight load (shard-on-read), notafter.
are dicts of
{module_pattern: style_name}; style names ("colwise_allgather", "rowwise_allreduce", "packed_colwise", "moe_experts_allreduce", "colwise_loss_parallel", …)resolve through theALL_PARALLEL_STYLESregistry totorch.distributed.tensor.parallel.ParallelStyleinstances and are applied by torch'sparallelize_module.tensors and immediately slices down to its local
DTensorshard for any combinationof placements on any-D mesh —
Replicate, Shard(d), _StridedShard(d, sf=N). Theclass encapsulates the
(mesh, placements)pair so slicing logic isn't repeated atevery call site, and handles both shipped checkpoint layouts: one stacked tensor
(e.g.
[num_experts, in, out]) orN per-experttensors. No full-tensormaterialization on
rank 0; no post-load redistribute.toggled with enable_sequence_parallel=True.
grouped_mm kernels, including
_StridedShardfor interleaved shards and anall-reduce-on-backwardpath forrouting weightsvia_AllReduceBackward.get_fusion_metadata/unfuse_optimizer_state/fuse_optimizer_state) so a singleoptimizer_stateslotcovering a fused parameter like
gate_up_projis split/rejoined cleanly acrosssave/load.FSDP=2 × TP=2, reloadunder
TP=4, continue training — same checkpoint, different topology. The PR's demotrains 5 steps under one config, reloads under another, finishes training, then runs
inference under a third — and verifies the model overfits the target sentence
verbatim.
distributed_checkpoint) writes a fully-gathered, plain safetensors checkpoint
that loads anywhere — single GPU, different parallelism, vLLM, etc.
clip_grad_normhandlesDTensorparameters acrossthe full mesh; optimizer
save/loadauto-disablesforeach / fusedkernels onparameter groups that mix regular tensors and
DTensors.Reproduction