Add forward AD for torch.linalg.eigh #62163
Conversation
💊 CI failures summary and remediations
As of commit c3fc5ed (more details on the Dr. CI page):
🕵️ 1 new failure recognized by patterns. The following CI failures do not appear to be due to upstream breakages:
@albanD, CI fails with …
lezcano left a comment:
A few notes on forward mode AD for maps with constrained inputs.
auto F = eigenvalues.unsqueeze(-2) - eigenvalues.unsqueeze(-1);
F.diagonal(/*offset=*/0, /*dim1=*/-2, /*dim2=*/-1).fill_(INFINITY);
F = F.pow(-1);
It's better to divide by F than to compute the inverse explicitly and then multiply.
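A minimal Python sketch of this suggestion (the tensors `lam` and `tmp` below are placeholders, not the PR's variables): dividing by the difference matrix directly gives the same result as forming its elementwise inverse and multiplying, with one fewer intermediate.

```python
import torch

# Placeholder inputs, for illustration only.
lam = torch.tensor([1.0, 2.0, 4.0], dtype=torch.float64)  # stand-in for `eigenvalues`
tmp = torch.randn(3, 3, dtype=torch.float64)

F = lam.unsqueeze(-2) - lam.unsqueeze(-1)      # F[i, j] = lam[j] - lam[i]
F.diagonal(dim1=-2, dim2=-1).fill_(float("inf"))

out_a = tmp * F.pow(-1)   # explicit elementwise inverse, then multiply
out_b = tmp / F           # single division, as suggested
assert torch.allclose(out_a, out_b)
```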
// https://ora.ox.ac.uk/objects/uuid:8d0c0a29-c92b-4153-a1d2-38b276e93124
// Section 3.1 Eigenvalues and eigenvectors
auto tmp = at::matmul(at::matmul(eigenvectors.transpose(-2, -1).conj(), input_tangent), eigenvectors);
I believe the input tangent needs to be symmetrised as `const auto in = 0.5 * (input_tangent + input_tangent.transpose(-2, -1).conj());`.
Denote by Her(n) the n x n Hermitian matrices and by U(n) the unitary matrices. We have eigh : Her(n) -> R^n x U(n). This means that the differential of eigh goes from the tangent space of Her(n) to the tangent space of R^n times the tangent space of U(n) (at the output matrices). Now, the tangent space of Her(n) is Her(n) itself, as Her(n) is just a linear subspace. As such, the input tangent needs to be Hermitian for this function to make sense.
Computing (A + A^H)/2 happens to be the orthogonal projection of a matrix onto the space Her(n). The theorem that formalises all this says that the differential of a map on an embedded manifold is the differential of the map on the total space restricted to the tangent space of the embedded manifold.
I wrote all this down in a comment in eigh_backward. In general, it would be nice if the way we compute forward and backward were somewhat equivalent (i.e. one is "taking the transpose" of the other).
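For concreteness, here is a minimal Python sketch (not the PR's code; variable names are illustrative) of the projection onto Her(n) and of the forward-mode formula from Section 3.1 of the reference above, for a real symmetric input:

```python
import torch

torch.manual_seed(0)
A = torch.randn(4, 4, dtype=torch.float64)
A = 0.5 * (A + A.transpose(-2, -1))            # Hermitian (here: real symmetric) primal
dA = torch.randn(4, 4, dtype=torch.float64)
dA = 0.5 * (dA + dA.transpose(-2, -1))         # project the tangent onto Her(n)

L, U = torch.linalg.eigh(A)
tmp = U.transpose(-2, -1).conj() @ dA @ U      # U^H dA U

dL = tmp.diagonal(dim1=-2, dim2=-1)            # eigenvalue tangent: diag(U^H dA U)
F = L.unsqueeze(-2) - L.unsqueeze(-1)          # F[i, j] = L[j] - L[i]
F.diagonal(dim1=-2, dim2=-1).fill_(float("inf"))
dU = U @ (tmp / F)                             # eigenvector tangent

# Sanity check of the eigenvalue tangent against finite differences
# (loose tolerance; valid for generic inputs with distinct eigenvalues).
eps = 1e-6
L_fd = (torch.linalg.eigvalsh(A + eps * dA) - torch.linalg.eigvalsh(A)) / eps
assert torch.allclose(dL, L_fd, atol=1e-3)
```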
F.diagonal(/*offset=*/0, /*dim1=*/-2, /*dim2=*/-1).fill_(INFINITY);
F = F.pow(-1);
auto tmp = at::matmul(at::matmul(eigenvectors.transpose(-2, -1).conj(), input_tangent), eigenvectors);
input_tangent should be projected onto the Hermitian matrices (see the comment above).
CI Flow Status
⚛️ CI Flow Ruleset - Version:
You can add a comment to the PR and tag @pytorchbot with the following commands:
# ciflow rerun, "ciflow/default" will always be added automatically
@pytorchbot ciflow rerun
# ciflow rerun with additional labels "-l <ciflow/label_name>", which is equivalent to adding these labels manually and triggering the rerun
@pytorchbot ciflow rerun -l ciflow/scheduled -l ciflow/slow
For more information, please take a look at the CI Flow Wiki.
Codecov Report
@@ Coverage Diff @@
## master #62163 +/- ##
==========================================
- Coverage 66.60% 62.00% -4.60%
==========================================
Files 716 716
Lines 92689 92689
==========================================
- Hits 61735 57472 -4263
- Misses 30954 35217 +4263
albanD left a comment:
LGTM
I'll wait for @lezcano's final approval and then land.
auto hermitian_tangent = 0.5*(input_tangent + input_tangent.transpose(-2, -1).conj());
auto tmp = at::matmul(at::matmul(eigenvectors.transpose(-2, -1).conj(), hermitian_tangent), eigenvectors);
auto eigenvectors_tangent = at::matmul(eigenvectors, tmp.div(E));
Nit: you can return directly here to avoid the extra assignment.
decorators=[skipCUDAIfNoMagma, skipCUDAIfRocm, skipCPUIfNoLapack],
skips=(
    # Gradcheck for complex hangs for this function
    SkipInfo('TestGradients', 'test_forward_mode_AD', dtypes=complex_types()),),
Do we have an issue open tracking this? Do we want one if not?
There is no open issue about it. I added a NotImplementedError for complex inputs for now; we'll enable it later.
LGTM modulo @albanD's comments.
The only point is that it might be worth landing this after support for multiple outputs in forward-mode AD is added, since at the moment this forward AD does the same computation twice: once for the eigenvalues and once for the eigenvectors.
Now that this is ready, I think we can just add this to master, and merge the implementations once multi-output support is added.
@albanD has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
This PR adds forward mode differentiation for torch.linalg.eigh and a few other functions required for tests to pass.
For some reason, running tests for torch.linalg.eigvalsh and complex torch.linalg.eigh hangs. These tests are skipped for now.
cc @ezyang @albanD @zou3519 @gqchen @pearu @nikitaved @soulitzer @lezcano @Varal7 @jianyuh @mruberry @heitorschueroff @walterddr @IvanYashchuk @xwang233
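As an illustration of what this enables, here is a minimal sketch using the public forward-mode API in torch.autograd.forward_ad (tensor names are placeholders, not part of the PR):

```python
import torch
import torch.autograd.forward_ad as fwAD

# Hermitian (here: real symmetric) primal and tangent.
A = torch.randn(3, 3, dtype=torch.float64)
A = 0.5 * (A + A.transpose(-2, -1))
dA = torch.randn(3, 3, dtype=torch.float64)
dA = 0.5 * (dA + dA.transpose(-2, -1))

with fwAD.dual_level():
    dual_A = fwAD.make_dual(A, dA)
    eigenvalues, eigenvectors = torch.linalg.eigh(dual_A)
    dL = fwAD.unpack_dual(eigenvalues).tangent   # directional derivative of the eigenvalues
    dU = fwAD.unpack_dual(eigenvectors).tangent  # directional derivative of the eigenvectors
```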
torch.linalg.eigvalshand complextorch.linalg.eighhangs. These tests are skipped for now.cc @ezyang @albanD @zou3519 @gqchen @pearu @nikitaved @soulitzer @lezcano @Varal7 @jianyuh @mruberry @heitorschueroff @walterddr @IvanYashchuk @xwang233