
Conversation

@weifengpy (Contributor) commented Sep 4, 2024

Stack from ghstack (oldest at bottom):

When CPU offloading is enabled and a user loads a GPU state dict, FSDP2 throws a non-obvious error during backward:

```
RuntimeError: attempting to assign a gradient with device type 'cpu' to a tensor with device type 'cuda'. Please ensure that the gradient and the tensor are on the same device
```

This PR throws a more explicit error that names the parameters that must be moved because of CPU offloading:

```
FSDP parameters should be materialized on cpu when enabling cpu offloading. For example, load cpu state dict or call module.to_empty(device="cpu"). Found following parameters on non-cpu device: ['0.weight']
```
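For reference, the contract this error enforces looks roughly like the following (a minimal sketch, not code from this PR; the two-layer model is hypothetical, and it assumes torch.distributed is already initialized, e.g. via torchrun, with the FSDP2 APIs under `torch.distributed._composable.fsdp`):

```python
import torch
import torch.nn as nn
from torch.distributed._composable.fsdp import CPUOffloadPolicy, fully_shard

# Standard FSDP2 meta-device init flow with CPU offloading enabled.
with torch.device("meta"):
    model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))
for layer in model:
    fully_shard(layer, offload_policy=CPUOffloadPolicy())
fully_shard(model, offload_policy=CPUOffloadPolicy())

# Wrong: materializing parameters on GPU leaves them on CUDA, which this
# PR now rejects with the explicit error above instead of the confusing
# gradient-device error at backward time.
# model.to_empty(device="cuda")

# Right: with CPU offloading, sharded parameters must be materialized on
# CPU, either like this or by loading a CPU state dict.
model.to_empty(device="cpu")
```

The same applies when loading a checkpoint: load the state dict on CPU (e.g. `map_location="cpu"`) so the sharded parameters stay on CPU.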

`pytest -s test/distributed/_composable/fsdp/test_fully_shard_state_dict.py -k test_dp_state_dict_cpu_offload`

cc @XilunWu @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o

[ghstack-poisoned]
@pytorch-bot (bot) commented Sep 4, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/135156

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit 5a6a402 with merge base 356f14e:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot added the oncall: distributed and release notes: distributed (fsdp) labels Sep 4, 2024
weifengpy added a commit that referenced this pull request Sep 4, 2024
ghstack-source-id: d9d7725
Pull Request resolved: #135156
weifengpy marked this pull request as draft September 4, 2024 22:24
…ading"



`pytest -s distributed/_composable/fsdp/test_fully_shard_training.py -k test_to_float64_after_init`

resolve cpu offload error in TorchTune: meta-pytorch/torchtune#1412

cc XilunWu H-Huang awgu kwen2501 wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o

[ghstack-poisoned]
weifengpy changed the title from "[FSDP2] move DTensor and local tensor to cpu for cpu offloading" to "[FSDP2] construct DTensor parameters from cpu offloaded local tensors" Sep 5, 2024
weifengpy added a commit that referenced this pull request Sep 5, 2024
ghstack-source-id: 64b2d1c
Pull Request resolved: #135156
weifengpy marked this pull request as ready for review September 5, 2024 22:47
weifengpy requested a review from awgu September 5, 2024 22:47
weifengpy marked this pull request as draft September 7, 2024 00:11
@weifengpy (Contributor Author)

Synced, and I will modify the PR to throw an error for a GPU state dict instead of moving it to CPU implicitly.

…cal tensors"


resolve cpu offload error in TorchTune: meta-pytorch/torchtune#1412

this PR constructs DTensor from cpu offloaded local tensor

`pytest -s test/distributed/_composable/fsdp/test_fully_shard_state_dict.py -k test_dp_state_dict_cpu_offload`


cc XilunWu H-Huang awgu kwen2501 wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o

[ghstack-poisoned]
When CPU offloading is enabled and a user loads a GPU state dict, FSDP2 throws a non-obvious error during backward:
```
RuntimeError: attempting to assign a gradient with device type 'cpu' to a tensor with device type 'cuda'. Please ensure that the gradient and the tensor are on the same device
```

This PR throws a more explicit error that names the parameters that must be moved because of CPU offloading:

```
FSDP parameters should be materialized on cpu when enabling cpu offloading. For example, load cpu state dict or call module.to_empty(device="cpu"). Found following parameters on non-cpu device: {param_names_not_on_cpu}
```

`pytest -s test/distributed/_composable/fsdp/test_fully_shard_state_dict.py -k test_dp_state_dict_cpu_offload`


cc XilunWu H-Huang awgu kwen2501 wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o

[ghstack-poisoned]
weifengpy added a commit that referenced this pull request Sep 10, 2024
ghstack-source-id: 4f88a48
Pull Request resolved: #135156
weifengpy marked this pull request as ready for review September 10, 2024 17:02
@weifengpy (Contributor Author)

@awgu I repurposed the PR to throw an error message when loading a GPU state dict. Ready for review.

@awgu (Collaborator) left a comment

Just one suggestion: include the device in the error message in `_validate_cpu_offload_params`.

```
]
if param_names_not_on_cpu:
    raise RuntimeError(
        "FSDP parameters should be materialized on cpu when enabling cpu offloading. "
```
@awgu (Collaborator)

nit: I think we can capitalize CPU

Suggested change:
```
-"FSDP parameters should be materialized on cpu when enabling cpu offloading. "
+"FSDP parameters should be materialized on CPU when enabling CPU offloading. "
```

```
if param_names_not_on_cpu:
    raise RuntimeError(
        "FSDP parameters should be materialized on cpu when enabling cpu offloading. "
        'For example, load cpu state dict or call module.to_empty(device="cpu"). '
```
@awgu (Collaborator)

Suggested change:
```
-'For example, load cpu state dict or call module.to_empty(device="cpu"). '
+'For example, load a CPU state dict or call module.to_empty(device="cpu"). '
```

Comment on lines 580 to 584
```
param_names_not_on_cpu = [
    fsdp_param._param_fqn
    for fsdp_param in self.fsdp_params
    if fsdp_param.sharded_param.device.type != "cpu"
]
```
@awgu (Collaborator)

This is just a suggestion related to below. Specifically, I think it would be helpful to include the device in the error message in this case.

Suggested change:
```
-param_names_not_on_cpu = [
-    fsdp_param._param_fqn
-    for fsdp_param in self.fsdp_params
-    if fsdp_param.sharded_param.device.type != "cpu"
-]
+fsdp_params_not_on_cpu = [
+    fsdp_param
+    for fsdp_param in self.fsdp_params
+    if fsdp_param.sharded_param.device.type != "cpu"
+]
```

```
raise RuntimeError(
    "FSDP parameters should be materialized on cpu when enabling cpu offloading. "
    'For example, load cpu state dict or call module.to_empty(device="cpu"). '
    f"Found following parameters on non-cpu device: {param_names_not_on_cpu}\n"
```
@awgu (Collaborator)

part 2 of the suggestion of including the device in the error message
(needs some formatting)

Suggested change (with the stray bracket removed):
```
-f"Found following parameters on non-cpu device: {param_names_not_on_cpu}\n"
+f"Found following parameters on non-cpu device: {[(fsdp_param._param_fqn, fsdp_param.sharded_param.device) for fsdp_param in fsdp_params_not_on_cpu]}\n"
```

weifengpy added a commit that referenced this pull request Sep 10, 2024
ghstack-source-id: 61ecbe5
Pull Request resolved: #135156
@weifengpy (Contributor Author)

@pytorchmergebot merge

@pytorch-bot added the ciflow/trunk label Sep 10, 2024
@pytorchmergebot (Collaborator)

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@pytorchmergebot (Collaborator)

The merge job was canceled or timed out. This most often happens if two merge requests were issued for the same PR, or if the merge job was waiting for more than 6 hours for tests to finish. In the latter case, please do not hesitate to reissue the merge command.
For more information see pytorch-bot wiki.

@weifengpy (Contributor Author)

@pytorchmergebot merge

@pytorchmergebot (Collaborator)

Merge started. Your change will be merged once all checks pass (ETA 0-4 Hours).

@pytorchmergebot (Collaborator)

Merge failed

Reason: 1 job has failed: trunk / macos-py3-arm64 / build

weifengpy added a commit that referenced this pull request Sep 11, 2024
ghstack-source-id: 037e2d8
Pull Request resolved: #135156
@weifengpy (Contributor Author)

@pytorchmergebot merge

@pytorchmergebot (Collaborator)

Merge started. Your change will be merged once all checks pass (ETA 0-4 Hours).

tolleybot pushed a commit to tolleybot/pytorch that referenced this pull request Sep 14, 2024
Pull Request resolved: pytorch#135156
Approved by: https://github.com/awgu
Chao1Han pushed a commit to Chao1Han/pytorch that referenced this pull request Sep 20, 2024
Pull Request resolved: pytorch#135156
Approved by: https://github.com/awgu
github-actions bot deleted the gh/weifengpy/14/head branch October 12, 2024 02:07

Labels

ciflow/trunk, Merged, oncall: distributed, release notes: distributed (fsdp2)
