[FSDP2] Relaxed even sharding requirement for all-gather extensions #137005
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/137005
Note: Links to docs will display an error until the docs builds have been completed.
❌ 4 New Failures as of commit 7767d23 with merge base 8c29a0d.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
cc: @weifengpy for your thoughts on this
```python
        out: Optional[torch.Tensor] = None,
    ) -> Union[Tuple[torch.Tensor, Tuple[torch.Tensor, ...]], None]:
        assert metadata is None, f"{metadata}"
        (tensor,) = all_gather_outputs
```
Curious where we removed the padding? The unit test works, so it should be handled.
Ah, good question. FSDP will trim the padding for you given the original size of the parameter.
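For readers following along, here is a minimal sketch of the `fsdp_post_all_gather` side of such an extension, consistent with the reviewed snippet above. The subclass name is illustrative, and the return convention (the unsharded tensor plus the inner tensors to free) is an assumption based on the existing extension contract, not code from this PR. The point is that the hook does not need to strip padding itself, since FSDP2 trims the all-gathered tensor back to the parameter's original size.

```python
from typing import Any, Optional, Tuple, Union

import torch


class MyExtensionTensor(torch.Tensor):  # illustrative subclass name
    def fsdp_post_all_gather(
        self,
        all_gather_outputs: Tuple[torch.Tensor, ...],
        metadata: Any,
        param_dtype: torch.dtype,
        *,
        out: Optional[torch.Tensor] = None,
    ) -> Union[Tuple[torch.Tensor, Tuple[torch.Tensor, ...]], None]:
        assert metadata is None, f"{metadata}"
        # No manual un-padding here: by the time this hook runs, FSDP has
        # already trimmed the all-gather padding using the parameter's
        # original size (see the discussion above).
        (tensor,) = all_gather_outputs
        # Assumed return convention: the unsharded tensor to use, plus the
        # inner tensors whose storage FSDP may free when resharding.
        return tensor, (tensor,)
```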
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours).
Merge failed. Reason: 1 job has failed: linux-binary-libtorch-pre-cxx11 / libtorch-cpu-shared-with-deps-pre-cxx11-test / test
The linux-binary-libtorch-pre-cxx11 / libtorch-cpu-shared-with-deps-pre-cxx11-test / test failure is not related:
@pytorchbot merge -i
Merge started. Your change will be merged while ignoring the following 3 checks: linux-binary-libtorch-pre-cxx11 / libtorch-cpu-shared-with-deps-pre-cxx11-test / test, linux-binary-manywheel / manywheel-py3_9-cuda12_4-test / test, linux-binary-manywheel / manywheel-py3_9-cuda12_1-test / test
Merge failed. Reason: 1 job has failed: linux-binary-manywheel / manywheel-py3_9-cuda11_8-test / test
Same issue, looks unrelated:
@pytorchbot merge -i
Merge started. Your change will be merged while ignoring the following 4 checks: linux-binary-libtorch-pre-cxx11 / libtorch-cpu-shared-with-deps-pre-cxx11-test / test, linux-binary-manywheel / manywheel-py3_9-cuda12_4-test / test, linux-binary-manywheel / manywheel-py3_9-cuda12_1-test / test, linux-binary-manywheel / manywheel-py3_9-cuda11_8-test / test
Stack from ghstack (oldest at bottom):
This PR relaxes the even sharding requirement for the all-gather extensions.
`fsdp_pre_all_gather` now expects the signature:

```diff
  def fsdp_pre_all_gather(
      self,
      mesh: DeviceMesh,
+     outer_size: torch.Size,
+     outer_stride: Tuple[int, ...],
      module: nn.Module,
      mp_policy: MixedPrecisionPolicy,
  ) -> Tuple[Tuple[torch.Tensor, ...], Any]:
```

- `outer_stride` will always be contiguous strides since FSDP2 only supports contiguous strides for now.
- If the parameter is unevenly sharded (i.e. its local shard is padded), the user is responsible for padding their all-gather inputs accordingly in `fsdp_pre_all_gather` (see the sketch below). This is risky territory because if the user does not do so, then this may manifest as a NCCL timeout, as only the ranks with padding will error out. However, I am not aware of any way around this.

cc @XilunWu @H-Huang @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o
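To make the padding responsibility above concrete, here is a hedged sketch of an `fsdp_pre_all_gather` that pads its all-gather input for an unevenly sharded parameter. It assumes dim-0 sharding over a 1D mesh with a per-rank padded size of `ceil(outer_size[0] / mesh.size())`, and `_local_tensor` is a hypothetical attribute holding the unpadded local shard; this illustrates the idea and is not code from this PR.

```python
import math
from typing import Any, Tuple

import torch
import torch.nn.functional as F


class MyExtensionTensor(torch.Tensor):  # illustrative subclass name
    def fsdp_pre_all_gather(
        self,
        mesh,  # DeviceMesh the parameter is sharded over (assumed 1D)
        outer_size: torch.Size,
        outer_stride: Tuple[int, ...],
        module: torch.nn.Module,
        mp_policy,  # MixedPrecisionPolicy
    ) -> Tuple[Tuple[torch.Tensor, ...], Any]:
        # Assumed padded per-rank size for dim-0 sharding over a 1D mesh.
        padded_dim0 = math.ceil(outer_size[0] / mesh.size())
        local = self._local_tensor  # hypothetical: the unpadded local shard
        pad_rows = padded_dim0 - local.size(0)
        if pad_rows > 0:
            # Pad dim 0 so every rank contributes an equally sized all-gather
            # input; a size mismatch across ranks can otherwise surface as a
            # NCCL timeout, as noted above.
            local = F.pad(local, [0, 0] * (local.dim() - 1) + [0, pad_rows])
        return (local,), None
```

The metadata is kept as `None` here to match the `fsdp_post_all_gather` sketch earlier in the thread, which relies on FSDP trimming the padding after the all-gather.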