-
Notifications
You must be signed in to change notification settings - Fork 26.3k
fix gemm call for CUDABlas for THCUNN conv, #23545 #23552
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
test/test_nn.py
Outdated
|
|
||
| @unittest.skipIf(not TEST_CUDA, 'CUDA not available') | ||
| def test_ConvTranspose2d_half_cublas_gemm(self): | ||
| inputs = torch.randn(1, 1, 16, 16, device='cuda', dtype=torch.half) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
you can also force CuDNN to be off (for future safety) using the context manager we have available (i think it's called with torch.backends.cudnn something something
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks! I've added the context manager.
|
@pytorchbot merge this please also @soumith , this probably wants to be in 1.2 too i suppose? |
facebook-github-bot
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@ezyang is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Summary: * Swapped `CUBLAS_OP_N` for `'n'` * added a test This PR should fix pytorch/pytorch#23545. Thanks at AlphabetMan for reporting the initial issue reported in [the forum](https://discuss.pytorch.org/t/cuda-10-1-error-using-transposeconv2d-with-output-padding-1/51414?u=ptrblck) as well as ngimel for the guidance. Pull Request resolved: pytorch/pytorch#23552 Differential Revision: D16580986 Pulled By: ezyang fbshipit-source-id: abc0bce1e84d9c9d96d44ae0296951725adc8424
CUBLAS_OP_Nfor'n'This PR should fix #23545.
Thanks at AlphabetMan for reporting the initial issue reported in the forum as well as @ngimel for the guidance.