Conversation

ncullen93 commented Jul 9, 2017

This implements the symeig function for Variables. I grabbed the gradient implementation directly from Theano, and I test it in this gist to show that it gets the same gradient as the Theano version and that it passes gradcheck. I've also used it in actual networks, so I'm fairly confident it's correct.

I obviously still need to add actual tests, so any direction on that (and anything else) is much appreciated.
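Along the lines of that gist, a minimal gradcheck sketch (written against the current `torch.linalg` API rather than this PR's `symeig`) might be:

```python
import torch

torch.manual_seed(0)

def sym_eigvals(x):
    # symmetrize the input so the decomposition (and its gradient) is well-defined
    a = 0.5 * (x + x.transpose(-1, -2))
    return torch.linalg.eigvalsh(a)

x = torch.randn(5, 5, dtype=torch.double, requires_grad=True)
# compares the analytical gradient against finite differences in double precision
assert torch.autograd.gradcheck(sym_eigvals, (x,))
```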

x, w, v = ctx.saved_variables

# gives an error if I don't do this..
x = x.data

tri1 = lambda a: torch.triu(a, 1)

def G(n):
    return sum([v[:, m] * grad_v.t()[n].matmul(v[:, m]) / (w[n] - w[m])
                for m in range(N) if m != n])

g = sum([torch.ger(v[:, n], v[:, n] * grad_w[n] + G(n))
         for n in range(N)])
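As a sanity check on the eigenvalue part of this gradient, a small stand-alone sketch (hypothetical, using today's `torch.linalg.eigh` rather than this PR's `symeig`) can compare the hand-rolled sum of outer products against autograd:

```python
import torch

torch.manual_seed(0)
n = 4
a = torch.randn(n, n, dtype=torch.double)
a = 0.5 * (a + a.T)          # symmetrize
a.requires_grad_(True)

e, v = torch.linalg.eigh(a)  # ascending eigenvalues, eigenvectors in columns
c = torch.arange(1.0, n + 1, dtype=torch.double)
(c * e).sum().backward()     # a loss that depends on the eigenvalues only

# the grad_w term: sum_k c[k] * outer(v[:, k], v[:, k])
manual = sum(c[k] * torch.outer(v[:, k], v[:, k]) for k in range(n)).detach()
sym = lambda m: 0.5 * (m + m.T)
assert torch.allclose(sym(a.grad), manual)
```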

fmassa (Member) left a comment

Thanks for the PR! Do you think it would be possible to get eig as well?
I made some comments that I think need to be addressed before this PR can be merged. You could get the template for the tests from the gesv or inverse functions (I can give them to you tomorrow if needed; typing from my phone).


@staticmethod
def forward(ctx, input, eigenvectors=False, upper=True):
    ctx.eigenvectors = eigenvectors
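A complete sketch of such a Function (hypothetical, not the PR's code: it uses today's `torch.linalg.eigh` internally and keeps only the eigenvalue term of the gradient) could look like:

```python
import torch
from torch.autograd import Function

class SymEig(Function):
    @staticmethod
    def forward(ctx, input):
        # compute eigenvectors unconditionally: backward needs them even
        # when the caller only asked for eigenvalues
        e, v = torch.linalg.eigh(input)
        ctx.save_for_backward(e, v)
        return e, v

    @staticmethod
    def backward(ctx, grad_e, grad_v):
        e, v = ctx.saved_tensors
        # eigenvalue term only; the eigenvector (grad_v) term is omitted here
        return v @ torch.diag(grad_e) @ v.T

# usage: the gradient of the sum of eigenvalues of a symmetric matrix is the identity
a = torch.randn(4, 4, dtype=torch.double)
a = 0.5 * (a + a.T)
a.requires_grad_(True)
e, v = SymEig.apply(a)
e.sum().backward()
assert torch.allclose(a.grad, torch.eye(4, dtype=torch.double))
```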


ncullen93 (Author) commented Jul 13, 2017

Major questions right now:

  1. What to do if `eigenvectors=False`. I think an exception should be raised, because it's not clear how to compute the gradient without the eigenvectors.

  2. Why do I have to call `.data` on `x`, `w`, and `v`? Am I saving them for backward correctly?

Also, standard `eig` doesn't have a gradient implemented in Theano, and neither does TensorFlow; both only have the symmetric version. I'll check around the literature, but it seems unlikely right now.
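The obstruction is easy to see with a sketch (using today's `torch.linalg.eigvals`): a real non-symmetric matrix such as a 90-degree rotation has purely imaginary eigenvalues.

```python
import torch

# a 90-degree rotation: real entries, purely imaginary eigenvalues
rot = torch.tensor([[0.0, -1.0],
                    [1.0,  0.0]])
w = torch.linalg.eigvals(rot)  # eigenvalues are +/- 1j (order not guaranteed)
assert w.is_complex()
assert torch.allclose(w.imag.abs(), torch.ones(2))
```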

formatting for pyflakes
apaszke (Contributor) commented Jul 13, 2017

  1. You should probably override the flag and save the result for backward.
  2. They're saved correctly, but something must be wrong with your backward implementation. What error do you get when you don't unpack the variables?

apaszke (Contributor) commented Jul 13, 2017

Ad 2: you can use `@once_differentiable` together with `ctx.saved_tensors`. It will work, but it won't be differentiable twice.
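For illustration, here is a toy Function (hypothetical, unrelated to symeig) with `@once_differentiable` on its backward, which then works with plain tensors from `ctx.saved_tensors`:

```python
import torch
from torch.autograd import Function
from torch.autograd.function import once_differentiable

class Square(Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_out):
        x, = ctx.saved_tensors
        # ordinary tensor math; this path cannot be differentiated a second time
        return 2 * x * grad_out

x = torch.tensor([3.0], requires_grad=True)
Square.apply(x).backward()
print(x.grad)  # tensor([6.])
```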

NC Cullen added 2 commits July 13, 2017 16:19
small typo
removed `.data` call (not needed) and renamed `symeig` return values to `e` and `V` instead of `w` and `x` to be consistent with pytorch documentation.
ncullen93 (Author) commented Jul 13, 2017

Re 1: OK, great, that's a good solution. I will work on that.

Re 2: Well, I just tested it again: in the codebase it only works without `.data`, but implemented stand-alone it only works with `.data`. [EDIT]: Ah, it's because the stand-alone example in the docs uses `saved_tensors` while the backend uses `saved_variables`. I think it's all good now.

apaszke (Contributor) commented Jul 13, 2017

It seems that the only problem is that you're mixing tensors with Variables, so just making sure that you keep everything wrapped should be enough. Add a couple of prints and see why they get mixed up.

facebook-github-bot pushed a commit that referenced this pull request Jul 25, 2018
Summary:
Partially fixes #6890. (backward pass for non-symmetric eigen-decomposition is not implemented in other packages, e.g. autograd, mxnet, tensorflow, presumably because the eigenvalues can be imaginary for the general case, and AFAIK we cannot support complex numbers).

This patch adds a backward function for the symmetric eigen-decomposition function `torch.symeig`. The formula used is taken from [here](http://eprints.maths.ox.ac.uk/1079/1/NA-08-01.pdf). Unit tests are added to verify correctness.

There is still one outstanding issue, which is how to handle the case where the `symeig` is called with `eigenvectors=False`. In this case, the eigenvectors are returned as a zero tensor, but the backward computation for the eigenvalues depends on the eigenvectors. There was a previous attempt to implement this in #2026, where apaszke mentioned that the `eigenvectors` argument should be overridden so that they are saved for the backwards pass. The forward code is autogenerated, though, and it isn't clear to me how that would be done. I'd appreciate any guidance. For now, there is a unit test that will fail until that issue is resolved.
Pull Request resolved: #8586

Reviewed By: ezyang

Differential Revision: D8872760

Pulled By: SsnL

fbshipit-source-id: 76614495d0f9c118fec163a428f32e5480b4d115
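For reference, the backward formula from that paper can be sketched stand-alone and checked against autograd (hypothetical names; `torch.linalg.eigh` stands in for `symeig`, and `F` holds 1/(e_j - e_i) off the diagonal, zeros on it):

```python
import torch

torch.manual_seed(0)
n = 5
a = torch.randn(n, n, dtype=torch.double)
a = 0.5 * (a + a.T)  # symmetrize
a.requires_grad_(True)

e, v = torch.linalg.eigh(a)
grad_e = torch.randn(n, dtype=torch.double)     # upstream grad w.r.t. eigenvalues
grad_v = torch.randn(n, n, dtype=torch.double)  # upstream grad w.r.t. eigenvectors
((grad_e * e).sum() + (grad_v * v).sum()).backward()

ed, vd = e.detach(), v.detach()
diff = ed.unsqueeze(0) - ed.unsqueeze(1)        # diff[i, j] = e[j] - e[i]
diff.fill_diagonal_(float("inf"))
F = diff.reciprocal()                           # zero on the diagonal

inner = torch.diag(grad_e) + F * (vd.T @ grad_v)
manual = vd @ inner @ vd.T
sym = lambda m: 0.5 * (m + m.T)
# the two gradients agree on the subspace of symmetric perturbations
assert torch.allclose(sym(a.grad), sym(manual))
```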
jramseyer pushed a commit to jramseyer/pytorch that referenced this pull request Jul 30, 2018
zou3519 (Contributor) commented Jul 31, 2018

I think this is superseded by #8586; please reopen if I am incorrect.

@zou3519 zou3519 closed this Jul 31, 2018
goodlux pushed a commit to goodlux/pytorch that referenced this pull request Aug 15, 2018