[FSDP2] Fixed 2D mismatched grad placements #136237
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/136237
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit eaf893c with merge base c64ae60.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
```
if isinstance(grad, AsyncCollectiveTensor):
    grad = grad.wait()
assert isinstance(grad, DTensor), f"{type(grad)}"
if any(pl.is_partial() for pl in grad.placements):
```
Previously, we changed any partial placements to replicate, mainly targeting the case where a replicated RMSNorm.weight had partial gradients and we needed to trigger the all-reduce by converting from partial to replicated.
However, for the pos_embeddings in our toy Transformer class, the parameter has a Shard(0) placement but still gets a Partial gradient. We do not want to convert it to Replicate but rather to Shard(0).
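A hedged sketch of what this change amounts to (not the PR's exact code; the helper name, the `tp_placements` argument, and the import path are illustrative assumptions):

```
# Hedged sketch: redistribute a Partial gradient to the parameter's own TP
# placements instead of always converting Partial -> Replicate.
from torch.distributed.tensor import DTensor  # may be torch.distributed._tensor on older versions

def fix_grad_placements(grad: DTensor, tp_placements) -> DTensor:
    if any(pl.is_partial() for pl in grad.placements):
        # Old behavior: [Replicate() if pl.is_partial() else pl for pl in grad.placements]
        # New behavior: reuse the parameter's TP spec placements, so a Shard(0)
        # parameter's Partial gradient reduce-scatters instead of all-reducing.
        grad = grad.redistribute(placements=tuple(tp_placements))
    return grad
```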
I think that, since the viewing into the reduce-scatter output uses torch.as_strided, we were silently handling a larger/replicated gradient for pos_embeddings.weight.
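For illustration only (not the FSDP2 code path): torch.as_strided returns whatever view size/stride you request without checking how much of the source buffer the sharded gradient should actually occupy, which is why the oversized gradient went unnoticed.

```
import torch

# as_strided builds the requested view even though the source buffer is larger than
# the shape we view, so a replicated-size gradient can be silently "viewed down"
# to the sharded size with no error or warning.
buf = torch.arange(16.0)                               # stand-in for a replicated-size buffer
shard = torch.as_strided(buf, size=(8,), stride=(1,))  # view of only the first 8 elements
print(shard.shape)  # torch.Size([8]) -- the extra 8 elements are ignored silently
```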
cc: @mori360 do you know if I need to make any changes to the CI files to have the added test run in CI? We prefer to run it with 4 GPUs.

The 2d_composability file is tested under multigpu-test.sh, so we don't need to change the CI files.
```
- placements = [
-     Replicate() if pl.is_partial() else pl for pl in grad.placements
- ]
+ placements = self._tp_spec.placements
```
cc: @tianyu-l on this change just as a heads up
```
for ref_param, (param_name, param) in zip(
    ref_model.parameters(), model.named_parameters()
):
    full_grad = param.grad.full_tensor()
```
Should we also assert that param.grad is sharded here? The test does not seem to differentiate Replicate vs. Shard.
Sounds good!
I added a check specifically for pos_embeddings.weight and its gradient's TP placement, because that was the parameter that exercised the bug. It is hard to assert that param.grad is sharded in general, since we would have to case on the parallelize plan (e.g. norm weights are not sharded).
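A hedged sketch of such a targeted check (the `pos_embeddings.weight` name follows the toy Transformer discussed above; treating the last mesh dim as the TP dim and expecting Shard(0) there are assumptions, not the exact test code):

```
from torch.distributed.tensor import Shard  # may be torch.distributed._tensor on older versions

for param_name, param in model.named_parameters():
    if param_name.endswith("pos_embeddings.weight"):
        # The gradient should keep the parameter's Shard(0) placement on the TP
        # mesh dim rather than being silently widened to Replicate().
        self.assertEqual(param.grad.placements[-1], Shard(0))
```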
```
CUDA_VISIBLE_DEVICES=2,3,6,7 pytest test/distributed/_composable/test_composability/test_2d_composability.py -k test_train_parity_2d_transformer
```

cc @XilunWu @H-Huang @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o
@awgu has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
```
CUDA_VISIBLE_DEVICES=2,3,6,7 pytest test/distributed/_composable/test_composability/test_2d_composability.py -k test_train_parity_2d_transformer
```

Differential Revision: [D62964658](https://our.internmc.facebook.com/intern/diff/D62964658)
Pull Request resolved: pytorch#136237
Approved by: https://github.com/weifengpy
Stack from ghstack (oldest at bottom):
- shard_placement_fn arg #136221

cc @XilunWu @H-Huang @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o
Differential Revision: D62964658