Conversation

@kashif (Contributor) commented Aug 15, 2025

What does this PR do?

Add support for context parallelism (CP) in the Trainer.

self.is_tp_enabled = getattr(self.accelerator.state, "torch_tp_plugin", None) is not None
self.is_cp_enabled = (
    getattr(self.accelerator.state, "parallelism_config", None) is not None
    and getattr(self.accelerator.state.parallelism_config, "cp_size", 1) > 1
)
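For reference, here is how that detection reads in isolation, as a minimal hypothetical helper (the function name is mine; the attribute lookups follow the snippet above):

def is_context_parallel_enabled(accelerator_state) -> bool:
    # CP is treated as enabled only when a parallelism_config exists on the
    # accelerator state and its cp_size is greater than 1.
    parallelism_config = getattr(accelerator_state, "parallelism_config", None)
    return parallelism_config is not None and getattr(parallelism_config, "cp_size", 1) > 1
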
@SalmanMohammadi (Contributor) commented Aug 15, 2025

Should we only rely on parallelism_config to configure CP?

self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None
self.is_tp_enabled = getattr(self.accelerator.state, "torch_tp_plugin", None) is not None
self.is_cp_enabled = (
    getattr(self.accelerator.state, "parallelism_config", None) is not None
    and getattr(self.accelerator.state.parallelism_config, "cp_size", 1) > 1
)

Contributor:

It would be great to just use self.parallelism_config = getattr(self.accelerator, "parallelism_config", None), and also to have a reference to parallelism_config in TrainerState.
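A minimal sketch of that suggestion, as I read it (assumed shape, not code from this PR):

self.parallelism_config = getattr(self.accelerator, "parallelism_config", None)
# Derive the CP flag from the stored config instead of reaching into accelerator.state each time.
self.is_cp_enabled = (
    self.parallelism_config is not None
    and getattr(self.parallelism_config, "cp_size", 1) > 1
)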

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Comment on lines 2102 to 2114
if not self.fsdp:
    from accelerate.utils import FullyShardedDataParallelPlugin

    self.fsdp_plugin = FullyShardedDataParallelPlugin(
        fsdp_version=2,
        auto_wrap_policy="transformer_based_wrap",
        state_dict_type="FULL_STATE_DICT",
    )
else:
    # Ensure FSDP v2 is used when context parallelism is enabled
    if self.fsdp_config.get("version", 1) != 2:
        logger.warning("Context parallelism requires FSDP v2. Updating FSDP config to use version 2.")
        self.fsdp_config["version"] = 2
Contributor:

Shouldn't it warn the user when it's enabling FSDP without explicit configuration from the user?
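A minimal sketch of that suggestion (the warning text is hypothetical; the surrounding structure follows the diff above):

if not self.fsdp:
    from accelerate.utils import FullyShardedDataParallelPlugin

    # Assumed addition: tell the user that CP is implicitly enabling FSDP v2 for them.
    logger.warning(
        "Context parallelism is enabled but no FSDP configuration was provided; "
        "enabling FSDP v2 with default settings."
    )
    self.fsdp_plugin = FullyShardedDataParallelPlugin(
        fsdp_version=2,
        auto_wrap_policy="transformer_based_wrap",
        state_dict_type="FULL_STATE_DICT",
    )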

and num_items_in_batch is not None
):
loss *= self.accelerator.num_processes
# if (
Contributor:

TODO: Need to understand why we need this realistically.
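One possible reading, stated as an assumption rather than a confirmed answer: when num_items_in_batch is set, each rank's loss is a token-level sum normalized by the global token count, but the subsequent reduction across ranks takes a mean, so multiplying by num_processes turns that mean back into the intended sum. A toy arithmetic sketch under that assumption:

# Toy numbers, purely illustrative.
num_processes = 2
per_rank_loss_sums = [12.0, 8.0]   # sum of token losses on each rank
num_items_in_batch = 10            # global token count across ranks

# Desired global loss: total token loss / global token count
desired = sum(per_rank_loss_sums) / num_items_in_batch                                              # 2.0

# A plain mean over ranks of the per-rank normalized losses undercounts by num_processes
mean_reduced = sum(s / num_items_in_batch for s in per_rank_loss_sums) / num_processes              # 1.0

# Scaling each rank's loss by num_processes before the mean recovers the desired value
rescaled = sum(s * num_processes / num_items_in_batch for s in per_rank_loss_sums) / num_processes  # 2.0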

@S1ro1 (Contributor) commented Aug 18, 2025

Current status is a bit funky: the outputs are correct, but the logged losses are sometimes off:
[screenshot of the logged losses]

@kashif (Contributor, Author) commented Aug 22, 2025

@SunMarc I have fixed the issues you raised

logger.info(f"Saving model checkpoint to {output_dir}")

# Defer to accelerate's get_state_dict when using distributed setups that require special state dict handling
if state_dict is None and (self.is_fsdp2 or self.is_deepspeed_enabled):
    state_dict = self.accelerator.get_state_dict(self.model)
@S1ro1 (Contributor) commented Aug 22, 2025

We don't need this at all; save_pretrained works fine with torch parallelism. I suppose we want to keep this for non-transformers models only?
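A minimal sketch of keeping the fallback only for non-transformers models (the isinstance check is an assumption about the intended condition, not code from this PR):

from transformers import PreTrainedModel

# save_pretrained already handles the parallel cases for transformers models,
# so only gather a full state dict for custom (non-transformers) models.
if state_dict is None and not isinstance(self.model, PreTrainedModel):
    if getattr(self.accelerator, "is_fsdp2", False) or self.is_deepspeed_enabled:
        state_dict = self.accelerator.get_state_dict(self.model)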

@S1ro1 (Contributor) commented Aug 22, 2025

Failing tests seem unrelated.

@SunMarc (Member) left a comment:

Thanks! A few nits, but overall LGTM.

Comment on lines 4286 to 4287
if state_dict is None and (getattr(self.accelerator, "is_fsdp2", False) or self.is_deepspeed_enabled):
    state_dict = self.accelerator.get_state_dict(self.model)
Member:

Is there an issue with how things are currently handled? Just to better understand.

Contributor:

I think it would just silently fail at this point, but that's with custom models, which is a rather rare use case.

Comment on lines 3862 to 3865
if (
    getattr(self.accelerator, "parallelism_config") is not None
    and self.accelerator.parallelism_config.cp_enabled
):
Member:

We still need to fix this potentially, but otherwise we can do that in a follow-up.
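A minimal sketch of one possible fix, assuming the concern is the missing getattr default (a guess at the intent, not code from this PR):

# Passing a default avoids an AttributeError when no parallelism_config is set.
parallelism_config = getattr(self.accelerator, "parallelism_config", None)
if parallelism_config is not None and parallelism_config.cp_enabled:
    ...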

@SunMarc requested a review from winglian on August 22, 2025, 12:50.
@SunMarc (Member) left a comment:

Minor nit and we can merge!

@kashif (Contributor, Author) commented Aug 25, 2025

on it @SunMarc

@SunMarc (Member) left a comment:

Thanks!

@SunMarc enabled auto-merge (squash) on August 26, 2025, 09:20.
@SunMarc merged commit 6d2bb1e into main on August 26, 2025.
25 checks passed.
@SunMarc deleted the trainer-cp branch on August 26, 2025, 09:28.
@sfc-gh-sbekman (Contributor) commented Oct 23, 2025

Am I missing something, or was this feature merged without adding any tests?

I'm working on an HF Trainer integration PR for ALST/UlyssesSP via huggingface/accelerate#3817, and I was hoping to find some existing CP tests I could extend or copy, but I can't find any.

How will you know if this feature breaks if you have no tests? The Accelerate side doesn't test most of this feature either. I'm puzzled.

@kashif (Contributor, Author) commented Oct 25, 2025

CI test PR #41860
