-
Notifications
You must be signed in to change notification settings - Fork 26.3k
Port SVD to ATen, enable batching for matrix inputs #21588
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
Test failures are expected. I removed the TH bindings. |
|
I’ll investigate the test failures, they are related to this PR. |
…tch_svd function and use at::svd instead
|
@pytorchbot rebase this please |
|
@pytorchbot rebase this please |
|
@nairbv can you review this, please? |
- Add a comment about why empty tensors are created on the CPU while the input is a CUDA tensor in _create_U_S_VT - Print matrix in SVD doc example - strides() --> stride() in docs
|
@pytorchbot rebase this please |
|
looks like failing tests were due to whatever this fixed: |
|
Yes, I’ll manually rebase. |
|
LGTM |
facebook-github-bot
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@nairbv is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Summary: Changelog: - Port SVD TH implementation to ATen/native/BatchLinearAlgebra.cpp - Port SVD THC implementation to ATen/native/cuda/BatchLinearAlgebra.cu - Allow batches of matrices as arguments to `torch.svd` - Remove existing implementations in TH and THC - Update doc string - Update derivatives to support batching - Modify nuclear norm implementation to use at::svd instead of _batch_svd - Remove _batch_svd as it is redundant Pull Request resolved: pytorch/pytorch#21588 Test Plan: - Add new test suite for SVD in test_torch.py with port to test_cuda.py - Add tests in common_methods_invocations.py for derivative testing Differential Revision: D16266115 Pulled By: nairbv fbshipit-source-id: e89bb0dbd8f2d58bd758b7830d2389c477aa61fb
Changelog:
torch.svdTest Plan: