-
Notifications
You must be signed in to change notification settings - Fork 26.3k
[ready] Add matrix_power #10068
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[ready] Add matrix_power #10068
Conversation
1. Tests added 2. Doc string added TODO: 3. Derivative for n <= 0
| } | ||
| return identities; | ||
| } else if (n < 0) { | ||
| AT_CHECK(a.dim() == 2, "Negative powers for batch matrices are currently not supported") |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
|
|
||
| Tensor matrix_power(const Tensor& a, int64_t n) { | ||
| AT_CHECK(a.dim() >= 2 && at::isFloatingType(a.type().scalarType()), | ||
| "pinverse(", a.type(), "{", a.sizes(), "}): expected a tensor " |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
|
|
||
| Tensor result, z; | ||
| int64_t r; | ||
| while (n > 0) { |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
| return identities; | ||
| } else if (n < 0) { | ||
| AT_CHECK(a.dim() == 2, "Negative powers for batch matrices are currently not supported") | ||
| Tensor a_ = at::native::inverse(a); |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
| AT_CHECK(a.dim() >= 2 && at::isFloatingType(a.type().scalarType()), | ||
| "matrix_power(", a.type(), "{", a.sizes(), "}): expected a tensor " | ||
| "of floating types with dim at least 2"); | ||
| if (n == 0) { |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
|
This is ready for review. |
|
@fmassa Could you review this please? |
|
@fmassa Ping :) |
| "matrix_power(", a.type(), "{", a.sizes(), "}): expected a tensor " | ||
| "of floating types with dim at least 2"); | ||
| if (n == 0) { | ||
| return a.clone().copy_(at::eye(a.size(-2), a.options()).expand_as(a)); |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
fmassa
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Apart from the issue in the n==0 case, everything else looks good to me
…input tensor from the autograd graph This reverts commit 32fd6f7.
facebook-github-bot
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
SsnL has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
|
@pytorchbot retest this please |
|
@vishwakftw I just retriggered the CI to see if the ASAN test failure is consistent, if so we probably should look at how to fix it or skip it. |
|
Build failure seems to be unrelated to the PR. |
|
@yf225 is this good to go? Sorry about the reminder. |
facebook-github-bot
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yf225 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
|
|
||
| # Single matrix, but full rank | ||
| # This is for negative powers | ||
| from test_autograd import random_fullrank_matrix_distinct_singular_value |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
|
@yf225 could you try importing now? I am sorry if you are busy with other tasks - thought I should send a reminder just in case. |
facebook-github-bot
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yf225 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
|
@yf225 is this good to go? |
|
@vishwakftw Apologies for the delay - there are some internal dependency errors that I am still trying to figure out how to fix. I will have an update here as soon as possible. |
|
No worries @yf225 . Thank you for helping out. |
|
@vishwakftw Sorry for the delay on this - I think your original idea of putting |
|
No problem @yf225. I will revert the previous commit right away. Thank you again for helping out with this. |
This reverts commit dcf13e5.
facebook-github-bot
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yf225 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
|
@yf225 did the internal tests pass? |
|
Closing in favour of #11421 |
Summary: vishwakftw Your patch needed some updates because the default native function dispatches changed from `[function, method]` to `[function]`. The CI was run before that change happened so it still shows green, but the internal test caught it. I did some changes when rebasing and updating so I didn't just force push to your branch. Let's see if this passes CI and internal test. If it does, let me know if you want me to force push to your branch or use this PR instead. Note to reviewers: patch was already approved at #10068 . cc yf225 Pull Request resolved: #11421 Differential Revision: D9733407 Pulled By: SsnL fbshipit-source-id: cf2ed293bb9942dcc5158934ff4def2f63252599
Summary: vishwakftw Your patch needed some updates because the default native function dispatches changed from `[function, method]` to `[function]`. The CI was run before that change happened so it still shows green, but the internal test caught it. I did some changes when rebasing and updating so I didn't just force push to your branch. Let's see if this passes CI and internal test. If it does, let me know if you want me to force push to your branch or use this PR instead. Note to reviewers: patch was already approved at pytorch#10068 . cc yf225 Pull Request resolved: pytorch#11421 Differential Revision: D9733407 Pulled By: SsnL fbshipit-source-id: cf2ed293bb9942dcc5158934ff4def2f63252599
Uh oh!
There was an error while loading. Please reload this page.