
Conversation

@awgu (Collaborator) commented Sep 17, 2024

Stack from ghstack (oldest at bottom):

```
CUDA_VISIBLE_DEVICES=2,3,6,7 pytest test/distributed/_composable/test_composability/test_2d_composability.py -k test_train_parity_2d_transformer
```

cc @XilunWu @H-Huang @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o

Differential Revision: D62964658

@pytorch-bot (bot) commented Sep 17, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/136237

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit eaf893c with merge base c64ae60:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot added the oncall: distributed and release notes: distributed (fsdp) labels Sep 17, 2024
```
if isinstance(grad, AsyncCollectiveTensor):
    grad = grad.wait()
assert isinstance(grad, DTensor), f"{type(grad)}"
if any(pl.is_partial() for pl in grad.placements):
```
@awgu (Collaborator Author)

Previously, we changed any Partial placements to Replicate, mainly targeting the case where a replicated RMSNorm.weight had Partial gradients and we needed to trigger the all-reduce by converting from Partial to Replicate.

However, for the pos_embeddings in our toy Transformer class, the parameter has a Shard(0) placement while its gradient is still Partial. We do not want to convert the gradient to Replicate but rather to Shard(0).
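
As a minimal sketch of the intended behavior (not the exact FSDP2 code path): redistribute a Partial gradient to the parameter's own TP placements rather than forcing Replicate. The function name and the `tp_spec` argument are illustrative assumptions, and depending on the PyTorch version the import may live under torch.distributed._tensor instead of torch.distributed.tensor.

```
from torch.distributed.tensor import DTensor

def redistribute_partial_grad(grad: DTensor, tp_spec) -> DTensor:
    # Sketch only: `tp_spec` stands in for the parameter's TP sharding spec,
    # e.g. one whose placements contain Shard(0) for pos_embeddings.weight.
    if any(pl.is_partial() for pl in grad.placements):
        # Redistribute to the parameter's own placements: for a Shard(0)
        # parameter this reduce-scatters into a Shard(0) gradient instead of
        # all-reducing into a Replicate one.
        grad = grad.redistribute(placements=tuple(tp_spec.placements))
    return grad
```

For a replicated parameter such as RMSNorm.weight, the parameter's placements are Replicate anyway, so this still triggers the same all-reduce as before.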

@awgu (Collaborator Author)

I think that since the viewing into the reduce-scatter output uses torch.as_strided, we were silently handling a larger/replicated gradient for pos_embeddings.weight.
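
As a small self-contained illustration of why this can fail silently (not the FSDP2 code itself): torch.as_strided will happily carve the expected sharded view out of a buffer that is larger than expected, so no shape error surfaces.

```
import torch

# Illustration only: view a 4-element "shard" out of a gradient that is in
# fact full/replicated (8 elements); as_strided raises no error.
full_grad = torch.arange(8.0)
shard_view = full_grad.as_strided((4,), (1,), storage_offset=0)
print(shard_view)  # tensor([0., 1., 2., 3.])
```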

@awgu (Collaborator Author) commented Sep 18, 2024

cc: @mori360 do you know if I need to make any changes to CI files to have the added test run in CI? We prefer to run it with 4 GPUs.

@awgu added the release notes: distributed (fsdp2) label and removed the release notes: distributed (fsdp) label Sep 18, 2024
@awgu marked this pull request as ready for review September 18, 2024 00:01
@mori360 (Contributor) commented Sep 18, 2024

do you know if I need to make any changes to CI files to have the added test run in CI? We prefer to run it with 4 GPUs.

The 2d_composability file is tested under multigpu-test.sh, so we don't need to make any changes to the CI files.

```
-placements = [
-    Replicate() if pl.is_partial() else pl for pl in grad.placements
-]
+placements = self._tp_spec.placements
```
@awgu (Collaborator Author)

cc: @tianyu-l on this change just as a heads up

```
for ref_param, (param_name, param) in zip(
    ref_model.parameters(), model.named_parameters()
):
    full_grad = param.grad.full_tensor()
```
Contributor

Should we also assert that param.grad is sharded here? The test does not seem to differentiate Replicate vs. Shard.

@awgu (Collaborator Author)

Sounds good!

@awgu (Collaborator Author)

I added a check specifically for pos_embeddings.weight and its gradient's TP placement because that was the parameter that exercised the bug. It is hard to assert that param.grad is sharded in general since we would have to case on the parallelize plan (e.g., norm weights are not sharded).
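
For reference, a hedged sketch of what such a targeted check could look like in the test; the name matching and assertion wording are assumptions for illustration, not the committed test code (and the import may live under torch.distributed._tensor on older versions).

```
from torch.distributed.tensor import Shard

for param_name, param in model.named_parameters():
    if "pos_embeddings" in param_name:
        # The positional embedding is TP-sharded on dim 0, so its gradient
        # should keep a Shard(0) placement rather than being replicated.
        assert any(
            isinstance(pl, Shard) and pl.dim == 0
            for pl in param.grad.placements
        ), f"unexpected grad placements for {param_name}: {param.grad.placements}"
```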

```
CUDA_VISIBLE_DEVICES=2,3,6,7 pytest test/distributed/_composable/test_composability/test_2d_composability.py -k test_train_parity_2d_transformer
```


cc XilunWu H-Huang kwen2501 wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o

[ghstack-poisoned]
@awgu (Collaborator Author) commented Sep 18, 2024

@awgu has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@pytorch-bot added the ciflow/trunk label Sep 19, 2024
@awgu (Collaborator Author) commented Sep 19, 2024

@pytorchbot merge

@pytorchmergebot (Collaborator)

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

Chao1Han pushed a commit to Chao1Han/pytorch that referenced this pull request Sep 20, 2024
```
CUDA_VISIBLE_DEVICES=2,3,6,7 pytest test/distributed/_composable/test_composability/test_2d_composability.py -k test_train_parity_2d_transformer
```

Differential Revision: [D62964658](https://our.internmc.facebook.com/intern/diff/D62964658)
Pull Request resolved: pytorch#136237
Approved by: https://github.com/weifengpy
@github-actions deleted the gh/awgu/641/head branch October 20, 2024 02:09
