[Trainer] accelerate contextparallel support in trainer #40205
Conversation
src/transformers/trainer.py
Outdated
self.is_tp_enabled = getattr(self.accelerator.state, "torch_tp_plugin", None) is not None
self.is_cp_enabled = (
    getattr(self.accelerator.state, "parallelism_config", None) is not None
    and getattr(self.accelerator.state.parallelism_config, "cp_size", 1) > 1
Should we only rely on parallelism_config to configure CP?
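For illustration, a minimal sketch of a parallelism_config-only check (assuming accelerate exposes `parallelism_config` with a `cp_size` attribute, as in the hunk above; the helper name is hypothetical):

```python
def is_cp_enabled(accelerator) -> bool:
    """Detect context parallelism from accelerate's parallelism_config alone (sketch)."""
    parallelism_config = getattr(accelerator.state, "parallelism_config", None)
    return parallelism_config is not None and getattr(parallelism_config, "cp_size", 1) > 1
```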
src/transformers/trainer.py
Outdated
self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None
self.is_tp_enabled = getattr(self.accelerator.state, "torch_tp_plugin", None) is not None
self.is_cp_enabled = (
It would be great to just use self.parallelism_config = getattr(self.accelerator, "parallelism_config", None), and also to have a reference to parallelism_config in TrainerState.
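A hedged sketch of that suggestion, not the merged code: cache accelerate's parallelism_config once on the trainer and mirror it on TrainerState. Treating `parallelism_config` as a TrainerState field is an assumption for illustration.

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class TrainerStateSketch:
    # hypothetical field mirroring accelerator.parallelism_config
    parallelism_config: Optional[Any] = None


def attach_parallelism_config(trainer) -> None:
    # cache once on the trainer, then expose the same reference on the state
    trainer.parallelism_config = getattr(trainer.accelerator, "parallelism_config", None)
    trainer.state.parallelism_config = trainer.parallelism_config
```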
src/transformers/training_args.py
Outdated
if not self.fsdp:
    from accelerate.utils import FullyShardedDataParallelPlugin

    self.fsdp_plugin = FullyShardedDataParallelPlugin(
        fsdp_version=2,
        auto_wrap_policy="transformer_based_wrap",
        state_dict_type="FULL_STATE_DICT",
    )
else:
    # Ensure FSDP v2 is used when context parallelism is enabled
    if self.fsdp_config.get("version", 1) != 2:
        logger.warning("Context parallelism requires FSDP v2. Updating FSDP config to use version 2.")
        self.fsdp_config["version"] = 2
Shouldn't it warn the user when it's enabling FSDP without explicit configuration from the user?
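For example, a hedged sketch of what that warning could look like, reusing the plugin defaults from the hunk above (the helper name is hypothetical):

```python
import logging

from accelerate.utils import FullyShardedDataParallelPlugin

logger = logging.getLogger(__name__)


def build_default_fsdp_plugin_for_cp() -> FullyShardedDataParallelPlugin:
    # Warn loudly: context parallelism is silently turning on FSDP v2 for the user.
    logger.warning(
        "Context parallelism is enabled but no FSDP configuration was provided; "
        "enabling FSDP v2 with transformer_based_wrap and FULL_STATE_DICT by default."
    )
    return FullyShardedDataParallelPlugin(
        fsdp_version=2,
        auto_wrap_policy="transformer_based_wrap",
        state_dict_type="FULL_STATE_DICT",
    )
```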
src/transformers/trainer.py
Outdated
    and num_items_in_batch is not None
):
    loss *= self.accelerator.num_processes
# if (
TODO: Need to understand why we need this realistically.
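One possible reading of why the scaling is needed (an assumption, not a confirmed explanation): when the loss is normalized by a global token count, the gradient all-reduce later averages over ranks, dividing by num_processes a second time; multiplying by num_processes cancels that extra division. A toy sketch:

```python
def effective_global_loss(per_rank_loss_sums: list[float], global_num_tokens: int) -> float:
    """Toy check: scaling each rank's loss by the world size keeps the global normalization correct."""
    num_processes = len(per_rank_loss_sums)
    # each rank computes local_sum / global_num_tokens, then scales by num_processes
    scaled = [(s / global_num_tokens) * num_processes for s in per_rank_loss_sums]
    # the gradient all-reduce effectively averages across ranks (divides by num_processes)
    return sum(scaled) / num_processes  # == sum(per_rank_loss_sums) / global_num_tokens
```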
@SunMarc I have fixed the issues you raised.
src/transformers/trainer.py
Outdated
| logger.info(f"Saving model checkpoint to {output_dir}") | ||
|
|
||
| # Defer to accelerate's get_state_dict when using distributed setups that require special state dict handling | ||
| if state_dict is None and (self.is_fsdp2 or self.is_deepspeed_enabled): |
We don't need this at all. save_pretrained works with torch parallelism just fine. I suppose we do want to keep this for non-transformers models only?
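A hedged sketch of that idea, built from the attributes already shown in the hunks above (the helper name is hypothetical): only fall back to accelerate's get_state_dict when the unwrapped model is not a PreTrainedModel, since save_pretrained already handles the transformers case.

```python
from transformers import PreTrainedModel


def maybe_gather_state_dict(trainer, state_dict=None):
    needs_special_handling = getattr(trainer.accelerator, "is_fsdp2", False) or trainer.is_deepspeed_enabled
    if state_dict is None and needs_special_handling:
        # save_pretrained already copes with torch parallelism for transformers models,
        # so only gather a full state dict for custom (non-PreTrainedModel) models.
        if not isinstance(trainer.accelerator.unwrap_model(trainer.model), PreTrainedModel):
            state_dict = trainer.accelerator.get_state_dict(trainer.model)
    return state_dict
```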
Failing tests seem unrelated.
SunMarc
left a comment
Thanks! A few nits, but overall LGTM.
src/transformers/trainer.py
Outdated
if state_dict is None and (getattr(self.accelerator, "is_fsdp2", False) or self.is_deepspeed_enabled):
    state_dict = self.accelerator.get_state_dict(self.model)
Is there an issue with how things are currently handled? Just to better understand.
I think it would just silently fail at this point, but that only affects custom models, which is a rather rare use case.
if (
    getattr(self.accelerator, "parallelism_config") is not None
    and self.accelerator.parallelism_config.cp_enabled
):
We still need to fix that potentially, but we can do it in a follow-up otherwise.
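A small sketch of the deferred fix (assumption: the point is simply to give getattr a default so the check cannot raise when no parallelism config is set):

```python
def cp_is_enabled(accelerator) -> bool:
    # the default avoids an AttributeError when accelerate exposes no parallelism_config
    parallelism_config = getattr(accelerator, "parallelism_config", None)
    return parallelism_config is not None and parallelism_config.cp_enabled
```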
Co-authored-by: Marc Sun <[email protected]>
SunMarc
left a comment
Minor nit and we can merge!
On it @SunMarc.
SunMarc
left a comment
Thanks!
Am I missing something, or was this feature merged without adding any tests? I'm working on an HF Trainer integration PR for ALST/UlyssesSP via huggingface/accelerate#3817 and I was hoping to have some existing CP tests I could extend/copy, but I can't find any. How will you know if this feature breaks if you have no tests? The Accelerate side doesn't test most of this feature either. I'm puzzled.
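For reference, a hedged sketch of the kind of test the comment asks for. The script path and its contents are assumptions; only the transformers testing utilities imported below are existing helpers.

```python
from transformers.testing_utils import (
    TestCasePlus,
    execute_subprocess_async,
    get_torch_dist_unique_port,
    require_torch_multi_gpu,
)


class TrainerContextParallelTest(TestCasePlus):
    @require_torch_multi_gpu
    def test_train_with_context_parallelism(self):
        # `tests/trainer/run_context_parallel.py` is a hypothetical launcher script that
        # builds a tiny model, enables cp_size=2 through accelerate, and calls trainer.train().
        cmd = (
            f"torchrun --nproc_per_node=2 --master_port={get_torch_dist_unique_port()} "
            "tests/trainer/run_context_parallel.py"
        ).split()
        execute_subprocess_async(cmd, env=self.get_env())
```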
CI test PR: #41860

What does this PR do?
Add support for context parallelism in the Trainer.