Conversation

@rahulsingh-intel
Copy link
Contributor

@rahulsingh-intel rahulsingh-intel commented Nov 5, 2024

Motivation: Generalize the unit tests so that they can be executed on both CUDA and non-CUDA devices.
Changes: General changes in the common_dtensor module for device-type generalization so that the tests can also be executed on non-CUDA devices.

cc @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o

@pytorch-bot
Copy link

pytorch-bot bot commented Nov 5, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/139749

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit bbd0afc with merge base 73278e6:

UNSTABLE - The following job failed but was likely due to flakiness present on trunk and has been marked as unstable:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the oncall: distributed Add this issue/PR to distributed oncall triage queue label Nov 5, 2024
@rahulsingh-intel
Copy link
Contributor Author

@pytorchbot label "topic: not user facing"

@pytorch-bot pytorch-bot bot added the topic: not user facing topic category label Nov 5, 2024
@rahulsingh-intel
Copy link
Contributor Author

@pytorchbot label "topic: not user facing"

Hi @kwen2501, @fegin, please review the changes.

@rahulsingh-intel
Copy link
Contributor Author

@pytorch-bot
Copy link

pytorch-bot bot commented Nov 7, 2024

You don't have permissions to rebase this PR since you are a first time contributor. If you think this is a mistake, please contact PyTorch Dev Infra.

@ankurneog
Copy link

@pytorchbot rebase

@pytorchmergebot
Copy link
Collaborator

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot
Copy link
Collaborator

Successfully rebased dtensor_common onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout dtensor_common && git pull --rebase)

@rahulsingh-intel
Copy link
Contributor Author

Hi @kwen2501, can you please review and approve the changes?

@colesbury colesbury requested a review from wconstab November 8, 2024 17:04
@colesbury colesbury added the triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module label Nov 8, 2024
Copy link
Collaborator

@kwen2501 kwen2501 left a comment

Thanks for the effort. Overall this looks good to me. I left a couple of comments. I think the goal is to use device_type as much as possible and reduce the dependency on "cuda" or "hpu" strings. You did that well in most places; just a few remaining spots need to be covered. Thanks!
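
Concretely, that pattern could look something like the following minimal sketch (the class and method names are illustrative, not the exact ones in common_dtensor, and it assumes the torch.accelerator API available in recent PyTorch):

    import torch

    class DeviceAgnosticTestBase:  # hypothetical mixin, for illustration only
        @property
        def device_type(self) -> str:
            # Prefer whichever accelerator is present; fall back to CPU.
            if torch.accelerator.is_available():
                return torch.accelerator.current_accelerator().type
            return "cpu"

        def make_input(self) -> torch.Tensor:
            # No "cuda"/"hpu" literals at the call site.
            return torch.randn(4, 4, device=self.device_type)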

Copy link
Collaborator

For my education, why is there a numerical difference?

Copy link
Contributor Author

Corrected.

Comment on lines 984 to 990
Copy link
Collaborator

Can you please restore the original format?

Copy link
Collaborator

Is this device_type definition used somewhere given that there is a

    @property
    def device_type(self) -> str:

below?
Nit: can you restore the original two-line formatting?

Copy link
Collaborator

Can you make device_id a property of the class?

@property
def device_id(self) -> torch.device:
    device_count = torch.get_device_module(self.device_type).device_count()
    return torch.device(self.device_type, self.rank % device_count)

Thanks!
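
If that property were adopted, a call site could then look roughly like this (a sketch only; build_mesh is a made-up method on the hypothetical test class, and the mesh shape is illustrative):

    import torch
    from torch.distributed.device_mesh import init_device_mesh

    def build_mesh(self):
        # Pin this rank to its local device and build a 1-D mesh without any
        # cuda-specific calls; device_type and device_id come from the class.
        torch.get_device_module(self.device_type).set_device(self.device_id)
        return init_device_mesh(self.device_type, (self.world_size,))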

Comment on lines 351 to 352
Copy link
Collaborator

Same

Copy link
Collaborator

I wonder if we could truly generalize this line rather than adding another device string here?

Comment on lines 40 to 48
Copy link
Collaborator

I wonder if we could first define DEVICE_TYPE and then DEVICE_COUNT, so that we can call _get_device_module(DEVICE_TYPE) and then device_count()?
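
For illustration, the suggested ordering might look like the sketch below, using the public torch.get_device_module and torch.accelerator helpers as stand-ins for the internal ones:

    import torch

    # Resolve the device type first, then take the count from the matching
    # device module instead of hard-coding torch.cuda.device_count().
    DEVICE_TYPE = (
        torch.accelerator.current_accelerator().type
        if torch.accelerator.is_available()
        else "cpu"
    )
    DEVICE_COUNT = (
        torch.get_device_module(DEVICE_TYPE).device_count()
        if DEVICE_TYPE != "cpu"
        else 1
    )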

Comment on lines 49 to 51
Copy link
Collaborator

I wonder if we could build a map instead? Thanks!
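
For example, a small device-type-to-backend map in place of an if/elif chain; the keys and values below are a guess at what the harness needs, based on the backends mentioned later in this thread:

    # Illustrative map; out-of-tree devices would extend it in one place.
    DEVICE_TO_PG_BACKEND = {
        "cpu": "gloo",
        "cuda": "nccl",
        "hpu": "hccl",
        "xpu": "xccl",
    }

    DEVICE_TYPE = "cuda"  # e.g. resolved as in the earlier sketch
    PG_BACKEND = DEVICE_TO_PG_BACKEND.get(DEVICE_TYPE, "gloo")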

Comment on lines 304 to 318
Copy link
Collaborator

How about reusing the PG_BACKEND map above?

Copy link
Contributor

@zhangxiaoli73 zhangxiaoli73 Nov 14, 2024

If we still keep those if...else statements, then a mapping list has to be maintained and changed for every newly added backend, even if it is out-of-tree.

Then, is it possible to use an undefined backend, which will find the corresponding device backend by device type during process group init? E.g. nccl for cuda, gloo for cpu, hccl for hpu, and xccl for xpu if hccl and xccl are registered out-of-tree.

In code logic:

  1. Init the process group with an undefined backend.
    https://github.com/pytorch/pytorch/blob/main/torch/distributed/distributed_c10d.py#L1569

  2. Use the backend name to query a backend_config, which has a device_backend_map.
    https://github.com/pytorch/pytorch/blob/main/torch/distributed/distributed_c10d.py#L1816

An undefined backend name will use the default device backend map, which is gloo for cpu and nccl for cuda. Out-of-tree backends will also be added to this default map via https://github.com/pytorch/pytorch/blob/main/torch/distributed/distributed_c10d.py#L254

  3. Register the available backends from backend_config.get_device_backend_map() with the process group.
    https://github.com/pytorch/pytorch/blob/main/torch/distributed/distributed_c10d.py#L1818
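
A rough sketch of that flow (single-process setup; the address and port values are placeholders):

    import os
    import torch.distributed as dist

    # Initialize without naming a backend: c10d then consults its
    # device -> backend map (gloo for cpu, nccl for cuda, plus whatever an
    # out-of-tree device registered via dist.Backend.register_backend).
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")

    dist.init_process_group(rank=0, world_size=1)  # note: no backend= argument
    # ... run the DTensor / collective tests against the default group ...
    dist.destroy_process_group()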

Copy link
Contributor Author

How about reusing the PG_BACKEND map above?

Hi @kwen2501, modified. Can you please review?

This can be a straight one-to-one mapping for the default backends: #140536

Copy link
Contributor Author

Hi @kwen2501, please review.

Copy link
Contributor Author

Hi @kwen2501, CI ran fine; please approve after review.

@ankurneog
Copy link

@pytorchbot rebase

@pytorchmergebot
Copy link
Collaborator

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot
Copy link
Collaborator

Successfully rebased dtensor_common onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout dtensor_common && git pull --rebase)

@rahulsingh-intel
Copy link
Contributor Author

Hi @kwen2501, please review.

@ankurneog
Copy link

@kwen2501: Could you please help with the approval? It would be good if we could get this in before the code freeze for v2.6.0. Thank you.

@rahulsingh-intel
Copy link
Contributor Author

@kwen2501 Gentle reminder!

@ankurneog
Copy link

@kwen2501: Gentle reminder, could you please help with the approval? Thank you.

Copy link
Collaborator

@kwen2501 kwen2501 left a comment

LGTM. Sorry about the delay.

@rahulsingh-intel
Copy link
Contributor Author

@pytorchmergebot rebase

@pytorchmergebot
Copy link
Collaborator

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot
Copy link
Collaborator

Rebase failed due to Command git -C /home/runner/work/pytorch/pytorch rebase refs/remotes/origin/viable/strict pull/139749/head returned non-zero exit code 1

Rebasing (1/12)
Auto-merging test/distributed/_tensor/test_dtensor_compile.py
Auto-merging test/distributed/_tensor/test_random_ops.py
CONFLICT (content): Merge conflict in test/distributed/_tensor/test_random_ops.py
Auto-merging test/distributed/_tensor/test_redistribute.py
Auto-merging torch/testing/_internal/distributed/_tensor/common_dtensor.py
error: could not apply 57761392580... Tests Generelization for multiple accelerator devices
hint: Resolve all conflicts manually, mark them as resolved with
hint: "git add/rm <conflicted_files>", then run "git rebase --continue".
hint: You can instead skip this commit: run "git rebase --skip".
hint: To abort and get back to the state before "git rebase", run "git rebase --abort".
hint: Disable this message with "git config advice.mergeConflict false"
Could not apply 57761392580... Tests Generelization for multiple accelerator devices

Raised by https://github.com/pytorch/pytorch/actions/runs/12279796394

@rahulsingh-intel
Copy link
Contributor Author

@pytorchmergebot rebase

@rahulsingh-intel
Copy link
Contributor Author

@pytorchbot merge

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@pytorchmergebot
Copy link
Collaborator

Merge failed

Reason: 3 mandatory check(s) failed. The first few are:

Dig deeper by viewing the failures on hud

Details for Dev Infra team: raised by workflow job.

Failing merge rule: Core Maintainers

@rahulsingh-intel
Copy link
Contributor Author

@pytorchbot merge

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@pytorchmergebot
Copy link
Collaborator

Merge failed

Reason: 3 mandatory check(s) failed. The first few are:

Dig deeper by viewing the failures on hud

Details for Dev Infra team: raised by workflow job.

Failing merge rule: Core Maintainers

@guangyey
Copy link
Collaborator

"Try to land this since the failure is unrelated."
@pytorchbot merge

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

Labels

ciflow/trunk (Trigger trunk jobs on your pull request), Merged, oncall: distributed (Add this issue/PR to distributed oncall triage queue), open source, topic: not user facing (topic category), triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

9 participants