Generalization of distributed UT content to enable non cuda device execution #131758
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/131758
Note: Links to docs will display an error until the docs builds have been completed.
❌ 3 New Failures as of commit 5fb4048 with merge base beb46de. The following jobs have failed:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
559aadf to 45baafa
@pytorchbot rebase
@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here
Successfully rebased
45baafa to 0fcc085
@albanD: could you please help with the review? Thanks.
0fcc085 to fee499d
7b2e1cd to c6401b0
I wonder if we really need a new class, or whether we should just update the behavior of MultiProcessTestCase in a safe way?
It seems like all of these changes are designed to be 'safe', but the current PR doesn't test this. Modifying the existing class would force it to be exercised by all the existing tests and make sure it really is safe.
@wconstab: thanks for your comment. Yes, I was also thinking along similar lines initially; however, that would need extensive changes to all of the already existing files. The idea is to incrementally introduce changes to the existing files via the new class and eventually merge MultiProcessTestCase with the new class.
I think it's fine to make all the extensive changes at once, given they are exercised by CI. It's actually riskier to land a new thing that isn't well tested, in case it doesn't really work as advertised or in case we only half-migrate to it.
Also cc @fegin, who has been thinking more about test base refactors. We should probably decide if we want to adopt DistributedTestBase more widely or build more on top of MPTC.
Thanks @wconstab, I will check how large the impact of adapting the change in all files would be.
But in general there are a lot of different ways this is currently being implemented.
For example, the class MultiProcContinousTest introduced with #125648 implements a mechanism that does not use MPTC, and the device-related nuances are passed to the individual test file for implementation.
Whereas the class DTensorTestBase in torch/testing/_internal/distributed/_tensor/common_dtensor.py inherits MPTC but tries to put the device-related nuances in the common code.
And there are a few more.
Hence the suggestion of a staggered migration to a common base class that takes into account all the challenges, module by module. Since Gaudi/XPU wants to use the distributed test content for validation, we can take up this migration effort.
Perhaps gradually refactoring things out of MPTC, as this PR proposes, is better, considering there are multiple features coupled in the MPTC base class right now?
I guess our consideration is the number of TestCases to maintain.
What tests are you interested in today? test_c10d_... ? If so, it seems we can integrate the changes into MPTC first, then gradually extend to other TCs.
The device-backend logic should probably live in distributed_c10d.py somewhere. I'm not sure of the best way to do it, but it would seem reasonable to have a helper that is usable not only in tests.
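As a rough illustration of what such a helper might look like (the function name and the device-to-backend table below are assumptions for illustration, not an existing API in distributed_c10d.py):

```python
import torch

# Illustrative assumption: a default process-group backend per device type.
_DEFAULT_BACKENDS = {
    "cpu": "gloo",
    "cuda": "nccl",
    "hpu": "hccl",
}


def default_backend_for_device(device) -> str:
    """Return a reasonable default c10d backend for a device or device type."""
    device_type = torch.device(device).type
    if device_type not in _DEFAULT_BACKENDS:
        raise ValueError(f"No default backend known for device type '{device_type}'")
    return _DEFAULT_BACKENDS[device_type]
```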
Sure, I think in general we have a lot of scope for harmonizing the APIs for different devices in torch.distributed. Can I add that as part of a follow-up PR?
I was thinking I've seen this code snippet before. I found at least one case in the profiler:
if self.use_device and hasattr(torch, self.use_device):
    device_module = getattr(torch, self.use_device)
I think it'd be better to provide a utility for this, perhaps in the torch.device module. Something like this:
def get_module(device: Union[str, torch.device]) -> Module:
Maybe something you could add in a follow-up PR?
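A minimal sketch of the proposed utility, following the signature suggested above; neither the function nor its placement in a torch.device module exists today, so the names here are illustrative:

```python
from types import ModuleType
from typing import Union

import torch


def get_module(device: Union[str, torch.device]) -> ModuleType:
    """Return the torch device module for a device, e.g. torch.cuda for 'cuda:0'."""
    device_type = torch.device(device).type
    if not hasattr(torch, device_type):
        raise AttributeError(f"torch has no module for device type '{device_type}'")
    return getattr(torch, device_type)


# Example: get_module("cuda").device_count()
```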
@wconstab: there are multiple places where the handle is retrieved from the device,
e.g. _get_device_handle in torch/distributed/fsdp/_common_utils.py.
Agree that _get_device_module seems to be a more general way than calling:
if TEST_ABC and torch.abc.device_count() > 0 everywhere.
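For illustration, the two patterns being compared might look roughly like this (flag and helper names are placeholders, not the final API):

```python
import torch

# Per-device pattern repeated in each test file:
#   if TEST_CUDA and torch.cuda.device_count() > 1: ...
#   if TEST_HPU and torch.hpu.device_count() > 1: ...

# Device-agnostic equivalent via the device module:
def has_enough_devices(device_type: str, required: int = 2) -> bool:
    device_module = getattr(torch, device_type, None)
    return device_module is not None and device_module.device_count() >= required
```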
c6401b0 to 677086a
@wconstab: Could you please let me know if you are OK with my responses?
677086a to d93f5d7
@wconstab: could you please provide your response. Thank you.
kwen2501 left a comment
I like the idea of using device_handle() or _get_device_module() to make our tests device-agnostic. I recommend picking one TestCase of most interest to refactor first.
Can we use the device_handle way (as later in this PR) for every device type?
We can have a global device_handle that gets set based on TEST_HPU, TEST_CUDA, etc.
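A rough sketch of the global-handle idea, assuming the TEST_CUDA / TEST_HPU flags mentioned above are importable from the common test utilities; the exact names and placement are illustrative:

```python
import torch
from torch.testing._internal.common_utils import TEST_CUDA, TEST_HPU

# Pick one device type for the whole test module, then use device_handle
# everywhere instead of hard-coded torch.cuda calls.
if TEST_HPU:
    DEVICE_TYPE = "hpu"
elif TEST_CUDA:
    DEVICE_TYPE = "cuda"
else:
    DEVICE_TYPE = "cpu"

device_handle = getattr(torch, DEVICE_TYPE, None)
```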
Same here
Same here.
It seems this mapping is available through the backend_feature map?
@pytorchbot rebase
@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here
Successfully rebased
d93f5d7 to 5fb4048
This PR needs a
# Motivation
This PR is an extension of #131758. As described in #131758, these changes are looking to make distributed UTs more accessible to users of all device types. It is a demonstration of a few changes discussed by @kwen2501 and @jgong5 in the discussion for #131758 (#131758 (comment)).
This PR contains two types of changes. The first is to the common distributed folder, where we have added a new class derived from MultiProcessTestCase which helps abstract out the process group creation/deletion and other functionality for a given device. New generalized content can be added by deriving from this base class. It also includes other misc changes for Gaudi support.
The second changed file is test_functional_api, a test file in common distributed. This file is a POC for how we can use this new class to write more device-agnostic distributed test cases. The following changes have been made to test_functional_api.py:
- Functionality has been added to test for non-CUDA devices, using Intel HPU as an example.
- Multiple setup steps previously required by MultiProcessTestCase have been abstracted out.
- Misc adaptations to allow for a general call to accelerators while adding test skips, instead of explicitly skipping for multiple GPUs.
- skipIfHpu flags have been added to enable skipping a few multithreaded test cases which are as yet not supported on HPUs.
NOTE: Within test_functional_api, there are tests which require the use of some multithreading functions which are as yet not supported on HPUs. These have been skipped for HPU using the skipHPU decorator. I will be raising a separate PR to improve usability of said decorators in a device-agnostic setting in the manner suggested by @kwen2501 in a comment on this PR.
This PR is a cleaned-up version of a previous PR (#136988), which I closed due to human error. I have addressed some of the comments made by @kwen2501 in this as well.
Pull Request resolved: #138216
Approved by: https://github.com/kwen2501, https://github.com/guangyey
Motivation
The distributed UT content is primarily targeted at CUDA devices; however, other devices like the Intel Gaudi support much of the functionality being validated.
Since the code has explicit CUDA API calls, it becomes hard to adapt these tests for non-CUDA devices.
Here we introduce a new class derived from MultiProcessTestCase, which helps abstract out the process group creation/deletion and other functionality for a given device.
The tests can be instantiated per device using existing utilities such as instantiate_device_type_tests, which will pass the device as an argument and accordingly create the PG with the right backend.
New generalized content can be added by deriving from this base class.
This also includes other misc changes for Gaudi support.
Note that this is a follow-up of functionality introduced with PR #126970.
Example Use: Here is a snippet from existing content (distributed/tests/test_c10d_nccl.py) to illustrate the use.
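The actual snippet is not reproduced here; the sketch below is a hypothetical illustration of the intended pattern rather than the content of test_c10d_nccl.py, with class and helper names (DistributedTestBase, create_pg, the device-to-backend mapping) assumed from the discussion above rather than taken from the final implementation:

```python
import torch
import torch.distributed as dist
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_distributed import MultiProcessTestCase


class DistributedTestBase(MultiProcessTestCase):
    """Hypothetical device-agnostic base: abstracts PG creation/deletion per device."""

    def create_pg(self, device):
        # The device-to-backend mapping here is an illustrative assumption.
        backend = {"cuda": "nccl", "hpu": "hccl", "cpu": "gloo"}[torch.device(device).type]
        dist.init_process_group(
            backend=backend,
            rank=self.rank,
            world_size=self.world_size,
            store=dist.FileStore(self.file_name, self.world_size),
        )

    def destroy_pg(self):
        dist.destroy_process_group()


class AllReduceTest(DistributedTestBase):
    def test_allreduce_sum(self, device):
        # `device` is supplied per device type by instantiate_device_type_tests below.
        self.create_pg(device)
        t = torch.ones(4, device=device)
        dist.all_reduce(t)  # defaults to SUM across all ranks
        self.assertEqual(t, torch.full((4,), float(self.world_size), device=device))
        self.destroy_pg()


instantiate_device_type_tests(AllReduceTest, globals())
```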
cc @XilunWu @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang @amjames