🐛 Bug
It seems that torch.lobpcg (https://pytorch.org/docs/stable/torch.html?highlight=lobpcg#torch.lobpcg) always breaks when trying to take gradients via backward().
To Reproduce
Here's a minimal example showing lobpcg breaking.
# lob.py
import torch as T
T.autograd.set_detect_anomaly(True)
A = T.randn(10, 10)
A.requires_grad_()
S = A.matmul(A.t())
e, v = T.lobpcg(S, k=3)
S_hat = T.einsum('ij,j,kj->ik', v, e, v) # v * diag(e) * v^T
loss = S_hat.abs().sum()
loss.backward() # breaks here

Running that code produces the following error.
Warning: Error detected in MmBackward. Traceback of forward call that caused the error:
File "lob.py", line 9, in <module>
e, v = T.lobpcg(S, k=3)
File "/usr/local/lib/python3.5/dist-packages/torch/_lobpcg.py", line 261, in lobpcg
worker.run()
File "/usr/local/lib/python3.5/dist-packages/torch/_lobpcg.py", line 408, in run
self.update()
File "/usr/local/lib/python3.5/dist-packages/torch/_lobpcg.py", line 343, in update
self._update_ortho()
File "/usr/local/lib/python3.5/dist-packages/torch/_lobpcg.py", line 498, in _update_ortho
self.X[:, nc:] = mm(S_, Z[:, :n - nc])
(print_stack at /pytorch/torch/csrc/autograd/python_anomaly_mode.cpp:60)
Traceback (most recent call last):
File "lob.py", line 12, in <module>
loss.backward() # breaks here
File "/usr/local/lib/python3.5/dist-packages/torch/tensor.py", line 198, in backward
torch.autograd.backward(self, gradient, retain_graph, create_graph)
File "/usr/local/lib/python3.5/dist-packages/torch/autograd/__init__.py", line 100, in backward
allow_unreachable=True) # allow_unreachable flag
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [10, 5]], which is output 0 of SliceBackward, is at version 14; expected version 11 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!
I have a feeling that the problem is that torch.lobpcg's implementation uses an in-place operation where it shouldn't.
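For context, here's a minimal sketch of that failure pattern in isolation (just an illustration of the error class, not the actual code path inside torch/_lobpcg.py): writing into a slice of a tensor that autograd has saved for an earlier op's backward bumps its version counter and produces the same RuntimeError.

import torch
torch.autograd.set_detect_anomaly(True)

a = torch.randn(10, 5, requires_grad=True)
x = a * 2          # non-leaf tensor tracked by autograd
y = x.pow(2)       # PowBackward saves x for its backward pass
x[:, 2:] = 0.0     # in-place write into a slice bumps x's version counter

# backward needs the original (saved) x, which was modified in place, so this
# raises "one of the variables needed for gradient computation has been
# modified by an inplace operation"
y.sum().backward()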
This happened with torch.__version__ == '1.5.0+cpu', installed with pip, on Windows 10 WSL (Windows Subsystem for Linux) with Python 3.5.2.
Can this be fixed, or is torch.lobpcg not meant to support autograd?
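In case it helps anyone hitting the same thing: as a stop-gap for small dense matrices (just a sketch, assuming a full eigendecomposition is acceptable and that torch.symeig's autograd support covers this case), the same loss can be computed with torch.symeig, which does backpropagate:

import torch as T

A = T.randn(10, 10, requires_grad=True)
S = A.matmul(A.t())

# full symmetric eigendecomposition; eigenvalues come back in ascending order
e_all, v_all = T.symeig(S, eigenvectors=True)
e, v = e_all[-3:], v_all[:, -3:]   # keep the 3 largest pairs, like k=3

S_hat = T.einsum('ij,j,kj->ik', v, e, v)   # v * diag(e) * v^T
S_hat.abs().sum().backward()               # works; A.grad is populated

This obviously gives up the point of lobpcg (large/sparse problems where only the top k pairs are wanted), so it's only a workaround, not a fix.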
cc @ezyang @albanD @zou3519 @gqchen @pearu @nikitaved @vincentqb @vishwakftw @jianyuh @mruberry @ssnl