[ROCm] Add bfloat16 support in linear algebra on ROCm #27719
Conversation
@gottbrath this is the BLAS bringup for bf16
This patch is described as ROCm-specific, but the contents of the diff suggest to me that it is turning on bfloat16 on regular CUDA as well. What's going on here? Additionally, I didn't see any test modifications.
@ezyang thanks for looking at it! I don't think we can discriminate between CUDA and ROCm in the Declarations? Test cases: that's a good point. We've tested with actual scripts, but let me see if we can also enable some unit tests here.
I guess what I'm mostly wondering is: does this PR also accidentally add support for CUDA at the same time? Or will the CUDA paths just error?
@ezyang CUDA paths will just error if the bfloat16 type is used.
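A minimal sketch of the behavior described above on a plain (non-ROCm) CUDA build; the error type and message are assumptions, not captured output from this PR:

```python
import torch

# On a non-ROCm CUDA build at the time of this PR, bf16 GEMM is
# expected to raise rather than compute. The exact error text below
# is an assumption.
a = torch.randn(4, 4, device="cuda").to(torch.bfloat16)
try:
    torch.mm(a, a)
except RuntimeError as err:
    print("bfloat16 GEMM rejected on this backend:", err)
```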
Thanks. This looks good to go; it just needs the merge conflict resolved.
bddppq left a comment
Could you add some tests? Maybe refer to #27259 to see how to add bfloat16 tests.
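A hypothetical sketch of such a test; the class name, tolerances, and structure here are assumptions for illustration, not the actual tests from #27259:

```python
import unittest
import torch

class TestBfloat16Matmul(unittest.TestCase):
    # Hypothetical test sketch: compare a bf16 GPU matmul against a
    # float32 reference computed from the same inputs.
    @unittest.skipIf(not torch.cuda.is_available(), "requires a GPU")
    def test_bfloat16_mm(self):
        a = torch.randn(8, 16, device="cuda").to(torch.bfloat16)
        b = torch.randn(16, 4, device="cuda").to(torch.bfloat16)
        out = torch.mm(a, b)
        ref = torch.mm(a.float(), b.float())
        # bf16 keeps only ~8 mantissa bits, so compare loosely.
        self.assertTrue(torch.allclose(out.float(), ref, atol=1e-1, rtol=1e-2))

if __name__ == "__main__":
    unittest.main()
```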
   THCTensor_(freeCopyTo)(state, cr, r_);
 }
-#elif defined(THC_REAL_IS_HALF)
+#elif defined(THC_REAL_IS_HALF)  || defined(THC_REAL_IS_BFLOAT16)
minor: double space
Can you please add a description of what these changes are for and why? Also, please add tests.
facebook-github-bot left a comment
@bddppq has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
This adds backend support to PyTorch on ROCm for GEMM-style matrix multiplications with bf16 data and output (i.e., bgemm), and enables the operators that depend on bgemm. With this change, bf16 matrices on ROCm can be multiplied on the GPU.
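A minimal usage sketch of what this enables, assuming a ROCm build (where the "cuda" device is backed by HIP):

```python
import torch

# With this change, the matmul below dispatches to the new bf16
# bgemm path on ROCm instead of erroring out.
a = torch.randn(64, 128, device="cuda").to(torch.bfloat16)
b = torch.randn(128, 32, device="cuda").to(torch.bfloat16)
c = torch.matmul(a, b)
print(c.dtype, c.shape)  # torch.bfloat16, torch.Size([64, 32])
```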