
Conversation

@doru1004 (Contributor) commented Oct 10, 2024

This patch improves the performance of individual reductions on MI300X. These improvements are measured on individual sum reduction operations of varying sizes. The patch impacts the following tensor types:

  • 1D tensors
  • 2D tensors when reducing along dimension 0
  • 2D tensors when reducing along dimension 1

Runtime is reduced by between 0% and 75%, depending on tensor shape.

The patch uses the maximum number of threads per CU and the number of CUs to control the number of threadblocks in various situations (i.e., for various reduction types and tensor dimensions).
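For reference, the two hardware quantities the heuristic is driven by can be queried directly from PyTorch on recent builds. Below is a minimal sketch of an occupancy-based block-count cap; `blocks_for_reduction` is a hypothetical helper for illustration only, not the actual heuristic used in the reduction kernels:

```python
import torch

# Query the device properties the heuristic is built on. On MI300X,
# multi_processor_count reports the number of CUs.
props = torch.cuda.get_device_properties(0)
num_cus = props.multi_processor_count
max_threads_per_cu = props.max_threads_per_multi_processor

# Hypothetical helper: cap the grid at the number of threadblocks that
# can be resident at once for a given block size (full occupancy).
def blocks_for_reduction(block_size: int) -> int:
    return num_cus * (max_threads_per_cu // block_size)

print(f"CUs: {num_cus}, max threads/CU: {max_threads_per_cu}")
print(f"blocks at block_size=512: {blocks_for_reduction(512)}")
```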

cc @jeffdaily @sunway513 @jithunnair-amd @pruthvistony @ROCmSupport @dllehr-amd @jataylo @hongxiayang @naromero77amd

pytorch-bot bot commented Oct 10, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/137737

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

✅ You can merge normally! (3 Unrelated Failures)

As of commit a510da8 with merge base c272526:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

BROKEN TRUNK - The following jobs failed but were already present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

linux-foundation-easycla bot commented Oct 10, 2024

CLA Signed

The committers listed above are authorized under a signed CLA.

  • ✅ login: doru1004 / name: Gheorghe-Teodor Bercea (a510da8)

@pytorch-bot bot added the `module: rocm` (AMD GPU support for PyTorch) and `release notes: cuda` (release notes category) labels on Oct 10, 2024
@doru1004 force-pushed the performance-tuning-upstream branch 2 times, most recently from 0e00ac8 to e23df4e, on October 11, 2024 15:44
@pruthvistony added the `rocm` (this tag is for PRs from the ROCm team), `ciflow/periodic` (trigger jobs run periodically on master (periodic.yml) on the PR), and `ciflow/rocm` (trigger "default" config CI on ROCm) labels on Oct 11, 2024
@doru1004 force-pushed the performance-tuning-upstream branch from e23df4e to 993f17e on October 15, 2024 13:46
@facebook-github-bot (Contributor) commented:

@Mellonta has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@doru1004 (Contributor, Author) commented:

This patch also addresses the performance issue reported in #132964. It handles both the 1D and 2D cases (along DIM 0 and DIM 1).

For the examples in that issue, the performance impact is as follows (a minimal timing sketch for reproducing such measurements follows the list):

  • 1D: improves by 8%
  • 2D (along DIM 1): improves by 61%
  • 2D (along DIM 0): same performance as before
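These timings can be gathered with standard CUDA-event timing; the sketch below uses illustrative shapes, not the exact ones from #132964:

```python
import torch

def time_op(fn, iters=100):
    # CUDA events measure device time rather than host launch overhead.
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    fn()                      # warm-up
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # average ms per call

x1 = torch.randn(1 << 24, device="cuda")     # 1D tensor
x2 = torch.randn(4096, 4096, device="cuda")  # 2D tensor

print("1D sum        :", time_op(lambda: x1.sum()), "ms")
print("2D sum, dim=0 :", time_op(lambda: x2.sum(dim=0)), "ms")
print("2D sum, dim=1 :", time_op(lambda: x2.sum(dim=1)), "ms")
```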

@doru1004 force-pushed the performance-tuning-upstream branch 2 times, most recently from bb2bd99 to f62670b, on October 22, 2024 15:58
jerrymannil pushed a commit to ROCm/pytorch that referenced this pull request Oct 22, 2024
Release version of upstream PR [137737](pytorch#137737). This adds support for GPUs with a smaller number of CUs; the smaller-CU optimization will be upstreamed later, once it has baked in the release branch.
@xw285cornell (Contributor) left a comment

Nice! Verified it improves some ops from 15 ms to 1.5 ms.

@jianyuh (Member) commented Oct 23, 2024

@pytorchbot rebase

@pytorchmergebot (Collaborator)

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot (Collaborator)

Successfully rebased performance-tuning-upstream onto refs/remotes/origin/viable/strict; please pull locally before adding more changes (for example, via `git checkout performance-tuning-upstream && git pull --rebase`)

@pytorchmergebot force-pushed the performance-tuning-upstream branch from f62670b to ad01179 on October 23, 2024 04:18
@jerrymannil (Collaborator)

@pytorchbot merge

@pytorch-bot bot added the `ciflow/trunk` (trigger trunk jobs on your pull request) label on Oct 23, 2024
@pytorchmergebot (Collaborator)

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here

@pytorchmergebot (Collaborator)

Merge failed

Reason: 2 jobs have failed, first few of them are: periodic / linux-focal-cuda11.8-py3.10-gcc9-debug / test (default, 3, 5, lf.linux.4xlarge.nvidia.gpu), Meta Internal-Only Changes Check

Details for Dev Infra team: raised by workflow job

@jerrymannil (Collaborator)

@pytorchbot rebase

@pytorchmergebot (Collaborator)

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot (Collaborator)

Successfully rebased performance-tuning-upstream onto refs/remotes/origin/viable/strict; please pull locally before adding more changes (for example, via `git checkout performance-tuning-upstream && git pull --rebase`)

@pytorchmergebot force-pushed the performance-tuning-upstream branch from ad01179 to a510da8 on October 23, 2024 20:24
@jerrymannil (Collaborator)

@jianyuh We are hitting the "Meta Internal-Only Changes Check" failure. Can you help resolve it?

@jerrymannil (Collaborator)

> @jianyuh We are hitting the "Meta Internal-Only Changes Check" failure. Can you help resolve it?

Please ignore; it's passing now after the rebase.

@facebook-github-bot (Contributor)

@pytorchbot merge

(Initiating merge automatically since Phabricator Diff has merged)

@pytorchmergebot (Collaborator)

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here

jerrymannil pushed a commit to ROCm/pytorch that referenced this pull request Nov 5, 2024

jerrymannil pushed a commit to ROCm/pytorch that referenced this pull request Nov 19, 2024
caaatch22 pushed a commit to caaatch22/pytorch that referenced this pull request Jan 6, 2025
…ch#137737)

Pull Request resolved: pytorch#137737
Approved by: https://github.com/eqy, https://github.com/jeffdaily, https://github.com/pruthvistony, https://github.com/xw285cornell
caaatch22 pushed a commit to caaatch22/pytorch that referenced this pull request Jan 7, 2025
…ch#137737)
rocm-mici pushed a commit to ROCm/pytorch that referenced this pull request Jan 14, 2025

jithunnair-amd pushed a commit to ROCm/pytorch that referenced this pull request Mar 17, 2025