-
Notifications
You must be signed in to change notification settings - Fork 26.3k
Enable broadcasting of batch dimensions RHS and LHS tensors for lu_solve #24333
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Changelog: - Enable broadcasting of RHS and LHS tensors for lu_solve. This means that you can now have RHS with size `3 x 2` and LHS with size `4 x 3 x 3` for instance - Remove deprecated behavior of having 2D tensors for RHS. Now all tensors have to have a last dimension which equals the number of right hand sides - Modified docs Test Plan: - Add tests for new behavior in test_torch.py with a port to test_cuda.py
|
@pytorchbot rebase this please |
|
@pytorchbot rebase this please |
zou3519
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code looks correct. I had some comments on style and cleaning up the testing code
|
@zou3519 I actually missed out on adding tests for broadcasting behavior earlier. I've added them now, just FYI. |
|
@zou3519 except for test refactoring and specifying sizes in error message (which will be addressed in follow-up PRs), the PR should be good to review again. |
| return torch.stack(all_matrices).reshape(*(batches + (l, l))) | ||
|
|
||
|
|
||
| def random_linalg_solve_processed_inputs(A_dims, b_dims, gen_fn, transform_fn, cast_fn): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Once this PR is merged, I will use this function in other places in the test suite for *solve methods. This would reduce duplication of code.
|
@pytorchbot rebase this please |
zou3519
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you, the code looks a lot better now :)
I had some last comments about repetitiveness in the testing code
…into lu_solve-new-version
|
@pytorchbot rebase this please |
zou3519
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice! Thank you
facebook-github-bot
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@zou3519 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
…lve (#24333) Summary: Changelog: - Enable broadcasting of RHS and LHS tensors for lu_solve. This means that you can now have RHS with size `3 x 2` and LHS with size `4 x 3 x 3` for instance - Remove deprecated behavior of having 2D tensors for RHS. Now all tensors have to have a last dimension which equals the number of right hand sides - Modified docs Pull Request resolved: pytorch/pytorch#24333 Test Plan: - Add tests for new behavior in test_torch.py with a port to test_cuda.py Differential Revision: D17165463 Pulled By: zou3519 fbshipit-source-id: cda5d5496ddb29ed0182bab250b5d90f8f454aa6
Summary: Changelog: - De-duplicate the code in tests for torch.solve, torch.cholesky_solve, torch.triangular_solve - Skip tests explicitly if requirements aren't met for e.g., if NumPy / SciPy aren't available in the environment - Add generic helpers for these tests in test/common_utils.py Pull Request resolved: #25733 Test Plan: - All tests should pass to confirm that the change is not erroneous Clears one point specified in the discussion in #24333. Differential Revision: D17315330 Pulled By: zou3519 fbshipit-source-id: c72a793e89af7e2cdb163521816d56747fd70a0e
Changelog:
3 x 2and LHS with size4 x 3 x 3for instanceTest Plan: