-
Notifications
You must be signed in to change notification settings - Fork 26.3k
Fix mm accuracy in ROCm for some inputs #116537
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/116537
Note: Links to docs will display an error until the docs builds have been completed. ✅ You can merge normally! (8 Unrelated Failures)As of commit ba7fafd with merge base 57491d2 ( FLAKY - The following jobs failed but were likely due to flakiness present on trunk:
BROKEN TRUNK - The following jobs failed but were present on the merge base:👉 Rebase onto the `viable/strict` branch to avoid these failures
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
|
@pytorchbot label ciflow/periodic |
|
@pytorchbot label ciflow/trunk |
albanD
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A more detailed PR description and a test case in OpInfo that triggers this failure would be great.
malfet
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, but see comments about some minor issues (but as tests are limited to ROCm platform I'm not requesting changes)
|
@pytorchbot merge |
|
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
Merge failedReason: 6 jobs have failed, first few of them are: rocm, linux-binary-libtorch-pre-cxx11, linux-binary-manywheel, trunk, linux-binary-libtorch-cxx11-abi Details for Dev Infra teamRaised by workflow job |
|
@xw285cornell has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator. |
|
Any chance we can change the hipblaslt behavior to avoid cuda/hip divergence? |
Which part of it? The bias issue or the double type not being supported or both? |
| #if defined(USE_ROCM) | ||
| // This condition is needed for mm case on ROCm for hipblasLt path. | ||
| // Passing the bias ptr as null to avoid accuracy issues for mm case. | ||
| (&result != &self) ? self.const_data_ptr<scalar_t>() : nullptr, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@jeffdaily I meant here that uses bias==nullptr to avoid setting the attributes in computeDesc iiuc. I wonder why setting those epilog attr will end up with wrong results.
|
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
Merge failedReason: 2 jobs have failed, first few of them are: rocm, periodic Details for Dev Infra teamRaised by workflow job |
|
pytorchmergebot got confused. I had to remove and re-add ciflow labels. Hopefully all missing ciflows are available to mergebot now. |
|
@pytorchbot merge -f "unrelated failures" |
Merge startedYour change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
This PR fixes the accuracy issues for hipblasLT for mm case on ROCm.
This PR is a follow up to the integration PR #114329 and #114890
The accuracy issue arises for mm usecase for ROCm where hipblasLT is enabled, and a bias has been passed which is not required. This PR addresses that issue.
Added a unit-test case for this issue (bias=None) case.
cc @jeffdaily @sunway513 @jithunnair-amd @pruthvistony @ROCmSupport @dllehr-amd @jataylo @hongxiayang