Make CUDA triu / tril support batches of size > 65535 #21067
Conversation
ngimel left a comment:
I'm OK with this PR if we only need functional support; if performance is important, some changes are necessary. The code is nicely simplified.
@ngimel This is not necessarily a performance-critical kernel - it is generally used in tandem with many linear algebra operations. However, since you mentioned that indexing based on int32 helps a lot for smaller inputs, I've incorporated it into the PR.
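For context, a minimal sketch of the kind of int32 dispatch being referred to here (hypothetical names, not code from this PR): use 32-bit indexing when every linear index fits in int32_t, and fall back to 64-bit indexing otherwise.

```cuda
// Hypothetical host-side dispatch: 32-bit index arithmetic tends to be
// cheaper on the GPU, so prefer it when the element count allows.
#include <cstdint>
#include <limits>

template <typename KernelLauncher>
void dispatch_index_type(int64_t numel, KernelLauncher launch) {
  if (numel <= static_cast<int64_t>(std::numeric_limits<int32_t>::max())) {
    launch(int32_t{0});   // instantiate the kernel with 32-bit indices
  } else {
    launch(int64_t{0});   // fall back to 64-bit indices for huge tensors
  }
}

// Possible usage (triu_kernel is a hypothetical kernel templated on the
// index type, as sketched further below):
//   dispatch_index_type(numel, [&](auto idx) {
//     using index_t = decltype(idx);
//     triu_kernel<scalar_t, index_t><<<blocks, threads>>>(/* ... */);
//   });
```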
@pytorchbot rebase this please
@pytorchbot merge this please
facebook-github-bot left a comment:
@ezyang is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Summary: In the previous implementation of triu / tril, the batch size was passed in the 2nd dimension of the grid. This dimension is limited to 65535, which means that performing triu / tril on a tensor with batch size > 65535 would throw an error. This PR removes the dependence on the 2nd grid dimension, along with the corresponding contiguity constraints.

Changelog:
- Compute offset, row and col in the kernel
- Use the 1st dimension of the grid alone
- Remove unnecessary contiguity checks on tensors as a result of this change

Pull Request resolved: pytorch/pytorch#21067
Differential Revision: D15572501
Pulled By: ezyang
fbshipit-source-id: 93851cb661918ce794d43eeb12c8a38762e1358c
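As a rough illustration of the approach described in the summary (a hypothetical sketch, not the PR's actual kernel), each thread decodes its row and column from a flat linear index over the whole batched tensor, so the launch needs only gridDim.x and is no longer bound by the 65535 limit on the grid's 2nd dimension. The kernel name, parameters, and diagonal handling below are assumptions for illustration.

```cuda
#include <cuda_runtime.h>
#include <cstdint>

// Hypothetical triu kernel: one flat index per element, walked with a
// grid-stride loop. `IndexType` would be int32_t or int64_t depending on
// the element count (see the dispatch sketch above).
template <typename scalar_t, typename IndexType>
__global__ void triu_kernel(scalar_t* out, const scalar_t* in,
                            IndexType rows, IndexType cols,
                            IndexType total_elems, int64_t diagonal) {
  IndexType stride = (IndexType)gridDim.x * blockDim.x;
  for (IndexType linear = blockIdx.x * (IndexType)blockDim.x + threadIdx.x;
       linear < total_elems; linear += stride) {
    // Recover (row, col) from the flat index; the batch offset is already
    // folded into `linear`, so no blockIdx.y / blockIdx.z is needed.
    IndexType col = linear % cols;
    IndexType row = (linear / cols) % rows;
    // triu keeps elements on or above the chosen diagonal.
    out[linear] =
        (static_cast<int64_t>(col) - static_cast<int64_t>(row) >= diagonal)
            ? in[linear]
            : scalar_t(0);
  }
}
```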