-
Notifications
You must be signed in to change notification settings - Fork 26.3k
Add batched version of trtrs #18025
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add batched version of trtrs #18025
Conversation
156e143 to
e3421eb
Compare
- Remove single batch TH/THC implementations
e3421eb to
645c07c
Compare
Seems like you don't need a magma_queue_t object
c9baeb2 to
33ed8fa
Compare
| Tensor result_tmp; | ||
| result_tmp = at::_cholesky_solve_helper(self, A, upper); | ||
| result.resize_as_(result_tmp).copy_(result_tmp); | ||
| return result; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice catch!
|
@pytorchbot retest this please |
|
@ifedan Is this good to go? |
|
facebook-github-bot
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@ifedan is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
facebook-github-bot
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@ifedan is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Summary: - Remove single batch TH/THC implementations - Remove `_batch_trtrs_lower` from `multivariate_normal` - Add tests for batched behavior - Modify trtrs_backward to accommodate for batched case - Modify docs In a future PR, this will be renamed to `triangular_solve`. Pull Request resolved: pytorch/pytorch#18025 Differential Revision: D14523004 Pulled By: ifedan fbshipit-source-id: 11c6a967d107f969b60e5a5c73ce6bb8099ebbe1
_batch_trtrs_lowerfrommultivariate_normalIn a future PR, this will be renamed to
triangular_solve.