[Ready] Make potrs batched #13453
Conversation
Force-pushed from 00d213d to b4ff163.
test/test_cuda.py (Outdated)
Force-pushed from b4ff163 to 314f875.
- This is a straightforward PR, building on the batch inverse PR, except for one change:
- The GENERATE_LINALG_HELPER_n_ARGS macro has been removed, since it is not very general and the resulting code is not very copy-pasty.
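As a rough illustration of what batched `potrs` computes, here is a NumPy sketch (the helper name `batched_potrs` is hypothetical, and this reconstructs the matrix and calls a generic solver instead of doing the two triangular solves the real LAPACK-backed kernel performs):

```python
import numpy as np

def batched_potrs(b, L, upper=False):
    """Solve A x = b for each matrix in the batch, where A was factored
    as A = L @ L.T (upper=False) or A = U.T @ U (upper=True).

    Sketch only: reconstructs A and uses np.linalg.solve rather than
    back-substituting against the Cholesky factor directly."""
    if upper:
        A = np.swapaxes(L, -2, -1) @ L   # A = U^T U
    else:
        A = L @ np.swapaxes(L, -2, -1)   # A = L L^T
    return np.linalg.solve(A, b)

# Batched example: two independent 3x3 SPD systems.
rng = np.random.default_rng(0)
M = rng.standard_normal((2, 3, 3))
A = M @ np.swapaxes(M, -2, -1) + 3 * np.eye(3)  # SPD per batch entry
L = np.linalg.cholesky(A)                       # lower-triangular factors
b = rng.standard_normal((2, 3, 1))
x = batched_potrs(b, L)
print(np.allclose(A @ x, b))  # True
```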
Force-pushed from 314f875 to 551360a.
@zou3519 This is ready for review, just for your information.

@vishwakftw I'll take a look later today or tomorrow.

Thank you, appreciate it. :-)

No, thank you for your contribution :)
  template<class scalar_t>
- void lapackGetri(int n, scalar_t *a, int lda, int *ipiv, scalar_t *work, int lwork, int *info) {
+ void lapackGetri(int n, scalar_t* a, int lda, int* ipiv, scalar_t* work, int lwork, int* info) {
facebook-github-bot left a comment: @zou3519 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
  '_cumsum.*', '_cumprod.*', '_sum.*', '_prod.*', '_th_.*',
- 'arange.*', 'range.*', '_gesv.*', '_getri.*', '_inverse.*', 'slice',
  'randint(_out)?',
+ 'arange.*', 'range.*', '_gesv.*', '_getri.*', '_inverse.*', '_potrs.*',
test/test_torch.py (Outdated)
  b = cast(torch.randn(2, 1, 3, 4, 6))
  L = get_cholesky(A, upper)
  x = torch.potrs(b, L, upper=upper)
  x_exp = torch.Tensor(solve(A.cpu().numpy(), b.cpu().numpy()))
test/test_torch.py (Outdated)
  self.assertEqual(x.data, cast(x_exp))

  # broadcasting A
  A = cast(random_symmetric_pd_matrix(4))
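The broadcasting case exercised here (a single 4x4 `A` against right-hand sides with extra batch dimensions, as in the `torch.randn(2, 1, 3, 4, 6)` test above) can be sketched in NumPy, whose `np.linalg.solve` applies the same batch-broadcasting rules; this is an illustration of the semantics, not PyTorch's implementation:

```python
import numpy as np

rng = np.random.default_rng(1)
M = rng.standard_normal((4, 4))
A = M @ M.T + 4 * np.eye(4)               # single SPD coefficient matrix
b = rng.standard_normal((2, 1, 3, 4, 6))  # batched right-hand sides

# A (4, 4) is broadcast against b's leading batch dims (2, 1, 3),
# so each batch entry of b is solved against the same A.
x = np.linalg.solve(A, b)
print(x.shape)  # (2, 1, 3, 4, 6)
print(np.allclose(A @ x, b))  # True
```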
zou3519 left a comment: lgtm, thank you @vishwakftw!
facebook-github-bot left a comment: @zou3519 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
@zou3519 is there anything that I need to do?
facebook-github-bot left a comment: @zou3519 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
@vishwakftw I think it should be good, I'll let you know if any action is required.

@zou3519 just a notification: there were merge conflicts after the recent changes to
facebook-github-bot left a comment: @zou3519 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Summary:
- This is a straightforward PR, building up on the batch inverse PR, except for one change:
- The GENERATE_LINALG_HELPER_n_ARGS macro has been removed, since it is not very general and the resulting code is actually not very copy-pasty.

Billing of changes:
- Add batching for `potrs`
- Add relevant tests
- Modify doc string

Minor changes:
- Remove `_gesv_single`, `_getri_single` from `aten_interned_strings.h`.
- Add test for CUDA `potrs` (2D Tensor op)
- Move the batched shape checking to `LinearAlgebraUtils.h`

Pull Request resolved: pytorch/pytorch#13453
Reviewed By: soumith
Differential Revision: D12942039
Pulled By: zou3519
fbshipit-source-id: 1b8007f00218e61593fc415865b51c1dac0b6a35
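The batched shape checking mentioned in the summary (moved to `LinearAlgebraUtils.h`) enforces rules of roughly this kind; a hypothetical Python sketch of the checks, not the actual C++ helper:

```python
def check_potrs_batched_shapes(b_shape, a_shape):
    """Hypothetical sketch of the shape validation a batched linear-algebra
    helper performs: both inputs are at least matrices, the coefficient
    matrices are square, and A's columns match b's rows."""
    if len(a_shape) < 2 or len(b_shape) < 2:
        raise ValueError("expected tensors with at least 2 dimensions")
    if a_shape[-1] != a_shape[-2]:
        raise ValueError("A must be batches of square matrices")
    if a_shape[-1] != b_shape[-2]:
        raise ValueError("A's columns must match b's rows")

# Valid: batch of two 3x3 systems with 3x1 right-hand sides.
check_potrs_batched_shapes((2, 3, 1), (2, 3, 3))
```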