
Conversation


@medivh-xp medivh-xp commented Apr 13, 2023

Consider a custom backend built on the privateuse1 dispatch key with semantics identical to CUDA's (CUDA being by far the most popular reference), named, for example, 'my_device' and registered under the matching module name torch.my_device.
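
For context, such a backend is typically wired up along the following lines. This is a minimal sketch, assuming the vendor extension provides a module mirroring the torch.cuda API; my_device_module below is a hypothetical stand-in, not code from this PR:

import types
import torch

# Give the reserved privateuse1 dispatch key the external name "my_device".
torch.utils.rename_privateuse1_backend("my_device")

# Stand-in for the vendor extension module that mirrors the torch.cuda API.
my_device_module = types.ModuleType("my_device")

# Expose it as torch.my_device so it can be looked up by name.
torch._register_device_module("my_device", my_device_module)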

This PR aims to accommodate such a backend so that it can be integrated directly into the current FSDP implementation.

The main issues addressed are:

1. Device selection when FSDP wraps a Module without Parameters

Users typically organize FSDP code as follows (Module is any torch.nn.Module; FSDP is torch.distributed.fsdp.FullyShardedDataParallel):

m = Module().to('my_device:0')
fsdp_m = FSDP(m)

or like this:

m = Module()
fsdp_m = FSDP(m, device_id=torch.device('my_device', 0))

If the model has Parameters, everything works fine because FSDP prioritizes the device on which the Parameters live. However, for a Module without Parameters, the to() call has no observable effect, and FSDP falls back to assuming the current CUDA device, which makes it impossible to place such Modules on any device other than the current CUDA device. Therefore, when FSDP is called with a device_id argument, that argument now takes top priority.
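
The resulting selection order can be sketched as follows. This is an illustrative helper, not the exact code in this PR; _get_compute_device is a hypothetical name:

import torch
import torch.nn as nn

def _get_compute_device(module: nn.Module, device_id) -> torch.device:
    # 1. An explicit device_id argument takes top priority.
    if device_id is not None:
        return torch.device(device_id)
    # 2. Otherwise, prefer the device where the Parameters already live.
    params = list(module.parameters())
    if params:
        return params[0].device
    # 3. Last resort for Modules without Parameters: the current CUDA device.
    return torch.device("cuda", torch.cuda.current_device())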

2. Abstraction of a CUDA-like device

In addition to compute_device, _FSDPState now includes a device_handler member. In practice, this device_handler is simply a reference to either torch.cuda or torch.my_device. From now on, code that operates on _FSDPState should go through state.device_handler to create, wait on, or synchronize streams, exactly as it previously used torch.cuda.
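
In sketch form, the handler resolves from the compute device and stream operations route through it. This assumes state is an _FSDPState whose compute_device has already been determined, and that the custom module exposes the same stream API as torch.cuda:

import torch

# Resolve the handler once from the compute device.
device_handler = (
    torch.cuda
    if state.compute_device.type == "cuda"
    else getattr(torch, state.compute_device.type)
)

# Stream operations then go through the handler instead of torch.cuda.
side_stream = device_handler.Stream()
with device_handler.stream(side_stream):
    ...  # kernels issued here run on side_stream
device_handler.current_stream().wait_stream(side_stream)
device_handler.synchronize()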

pytorch-bot bot commented Apr 13, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/99024

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit a5152f8:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the release notes: distributed (fsdp) label Apr 13, 2023
@medivh-xp medivh-xp force-pushed the fsdp_head branch 4 times, most recently from 09b5229 to 7f2b840 on April 15, 2023 07:40
@medivh-xp medivh-xp changed the title from "Allow the FSDP constructor to run even without CUDA" to "Making fsdp device-agnostic" Apr 15, 2023
@medivh-xp medivh-xp changed the title from "Making fsdp device-agnostic" to "Making fsdp device-agnostic for device implement cuda-semantics" Apr 15, 2023
@medivh-xp medivh-xp force-pushed the fsdp_head branch 19 times, most recently from 1d7a2ac to f8599de on April 19, 2023 09:17
@medivh-xp medivh-xp marked this pull request as ready for review April 19, 2023 09:18
@medivh-xp medivh-xp requested a review from mrshenli as a code owner April 19, 2023 09:18
@medivh-xp medivh-xp force-pushed the fsdp_head branch 5 times, most recently from 7eb7bd3 to 9fb91a9 on April 20, 2023 11:46
@mikaylagawarecki mikaylagawarecki added the triaged label Apr 25, 2023

@awgu awgu left a comment


The approach looks good to me! I left some nits and can approve next review.

@medivh-xp medivh-xp force-pushed the fsdp_head branch 2 times, most recently from 9cd5566 to 4cc28f6 on April 26, 2023 02:51
@zhaojuanmao

wondering whether CI has some setup to test non-cuda devices? if so, could we add some unit tests for non-cuda devices?


awgu commented Apr 26, 2023

wondering whether CI has some setup to test non-cuda devices? if so, could we add some unit tests for non-cuda devices?

This might be challenging since I believe @medivh-xp uses custom hardware (correct me if I am wrong); otherwise, I am not sure there are easily accessible CUDA-like devices that we can use in CI.

Personally, as long as this does not regress the CUDA code path, I am okay with landing. This is similar to adding the _param_extensions for 2D support, where we added some generalizations at key points in the execution.


@awgu awgu left a comment


This looks good to me! I just left one more nit.

@medivh-xp (Author)

wondering whether CI has some setup to test non-cuda devices? if so, could we add some unit tests for non-cuda devices?

This might be challenging since I believe @medivh-xp uses custom hardware (correct me if I am wrong); otherwise, I am not sure there are easily accessible CUDA-like devices that we can use in CI.

Personally, as long as this does not regress the CUDA code path, I am okay with landing. This is similar to adding the _param_extensions for 2D support, where we added some generalizations at key points in the execution.

Yes! We use custom hardware and will ensure that it supports CUDA semantics, so that we can benefit directly from the excellent features built by the community.

@awgu awgu added the ciflow/trunk label Apr 27, 2023

awgu commented Apr 27, 2023

@pytorchbot merge

@pytorchmergebot

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here.

