
Conversation

@iotamudelta
Contributor

@iotamudelta iotamudelta commented Oct 10, 2019

This adds backend support for GEMM-style matrix multiplications with data and output in bf16 to PyTorch on ROCm (i.e., bgemm).

It also enables the operators that depend on bgemm.

With this change, bf16 matrices can be multiplied on the GPU on ROCm.
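
For illustration, a minimal usage sketch (the shapes are arbitrary; on ROCm builds the GPU is still addressed through the "cuda" device name):

    import torch

    # Create bf16 operands on the GPU (ROCm exposes the device as "cuda").
    a = torch.randn(64, 128, device="cuda").to(torch.bfloat16)
    b = torch.randn(128, 32, device="cuda").to(torch.bfloat16)

    # With this change, the product is computed on the GPU via bgemm.
    c = torch.matmul(a, b)
    print(c.dtype, c.shape)  # torch.bfloat16 torch.Size([64, 32])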

@iotamudelta iotamudelta added the module: rocm (AMD GPU support for PyTorch) and open source labels Oct 10, 2019
@iotamudelta iotamudelta requested a review from bddppq October 10, 2019 22:28
@pytorchbot pytorchbot added the module: cublas (problem related to cuBLAS support), module: cuda (related to torch.cuda and CUDA support in general), module: internals (related to internal abstractions in c10 and ATen), and module: operators labels Oct 10, 2019
@cpuhrsch cpuhrsch added the triaged label (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module) Oct 11, 2019
@iotamudelta
Contributor Author

@gottbrath this is the BLAS bringup for bf16

@iotamudelta iotamudelta requested a review from ezyang November 15, 2019 21:44
@ezyang
Contributor

ezyang commented Nov 18, 2019

This patch is described as ROCm-specific, but the contents of the diff suggest to me that it is turning on bfloat16 on regular CUDA as well. What's going on here?

Additionally, I didn't see any test modifications.

@iotamudelta
Contributor Author

@ezyang thanks for looking at it! I don't think we can discriminate between CUDA and ROCm in the Declarations, can we?

Test cases: that's a good point. We've tested with actual scripts, but let me see if we can also enable some unit tests here.

@ezyang
Contributor

ezyang commented Nov 19, 2019

I guess what I'm mostly wondering is: does this PR also accidentally add support for CUDA at the same time? Or will the CUDA paths just error?

@rohithkrn
Contributor

@ezyang CUDA paths will just error if the bfloat16 type is used.
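
For illustration, a sketch of that expected behavior on a plain CUDA build (the exact exception type and message here are assumptions, not taken from this PR):

    import torch

    a = torch.randn(4, 4, device="cuda").to(torch.bfloat16)
    b = torch.randn(4, 4, device="cuda").to(torch.bfloat16)

    try:
        torch.matmul(a, b)
    except RuntimeError as e:
        # Without a bf16 gemm backend, dispatch is expected to raise
        # rather than silently compute a wrong result.
        print("bf16 gemm rejected on this build:", e)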

@ezyang
Contributor

ezyang commented Nov 20, 2019

Thanks. This looks good to go; it just needs the merge conflict resolved.

Contributor

@bddppq bddppq left a comment

Could you add some tests? Maybe refer to #27259 to see how to add bfloat16 tests.

@bddppq bddppq requested a review from izdeby November 21, 2019 07:10
THCTensor_(freeCopyTo)(state, cr, r_);
}
-#elif defined(THC_REAL_IS_HALF)
+#elif defined(THC_REAL_IS_HALF)  || defined(THC_REAL_IS_BFLOAT16)
Contributor

minor: double space

@izdeby
Contributor

izdeby commented Nov 21, 2019

Can you please add some description of what these changes are for and why? Also, please add tests.

@iotamudelta
Contributor Author

@bddppq @izdeby tests are incoming.

@izdeby I thought the title was pretty self-explanatory, but I've added more words to the description now. OK?

@rohithkrn
Contributor

@bddppq @izdeby @ezyang the tests for GEMMs are under tensor_op_tests. Added a bfloat16_precision argument (defaults to 1e-5) to the argument list and enabled bfloat16 tests for GEMM ops on ROCm.
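
A minimal sketch of what such a check does (illustrative only, not the actual tensor_op_tests harness; the helper name check_bf16_gemm and the loosened 1e-2 tolerance for bf16 are assumptions):

    import torch

    def check_bf16_gemm(m=32, n=32, k=64, bfloat16_precision=1e-2):
        # Reference product in float32.
        a = torch.randn(m, k, device="cuda")
        b = torch.randn(k, n, device="cuda")
        ref = torch.matmul(a, b)
        # Same product through the bf16 path, upcast for comparison.
        out = torch.matmul(a.to(torch.bfloat16), b.to(torch.bfloat16)).float()
        # bf16 keeps only ~8 mantissa bits, so the tolerance must be far
        # looser than the 1e-5 default used for full-precision dtypes.
        assert torch.allclose(out, ref, rtol=bfloat16_precision,
                              atol=bfloat16_precision)

    check_bf16_gemm()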

Contributor

@facebook-github-bot facebook-github-bot left a comment

@bddppq has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
