🐛 Describe the bug
Hi,
I'm using PyTorch version 1.10.0+cu102.
I noticed that the results of torch.linalg.lstsq are not reproducible, whereas those of torch.lstsq are. By "not reproducible" I mean that running the same operation twice on the same inputs yields different results. This is a problem for me because it makes my experiments irreproducible, which was not the case with the previous implementation.
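To test this more systematically than with a single elementwise `==`, a small helper can repeat a call and compare every result bitwise with `torch.equal`, once for each CPU driver that `torch.linalg.lstsq` accepts through its `driver` argument (`gelsy` is the CPU default). This is only a sketch: `is_bitwise_reproducible` is a hypothetical helper, and which drivers pass depends on the PyTorch build and machine, so no expected output is given.

```python
import torch


def is_bitwise_reproducible(fn, n_runs=5):
    """Call fn() n_runs times; True only if every result is bitwise equal to the first."""
    first = fn()
    return all(torch.equal(first, fn()) for _ in range(n_runs - 1))


torch.manual_seed(2020)
a = torch.rand(10, 3)
b = torch.rand(10, 3)

# CPU drivers accepted by torch.linalg.lstsq; 'gelsy' is used when driver=None on CPU.
for driver in ("gels", "gelsy", "gelsd", "gelss"):
    # Bind driver as a default argument so each lambda uses the current loop value.
    ok = is_bitwise_reproducible(
        lambda d=driver: torch.linalg.lstsq(a, b, driver=d).solution
    )
    print(f"{driver}: bitwise reproducible = {ok}")
```

If only some drivers fail the check, passing a passing `driver` explicitly may serve as a stopgap; this is a suggestion, not a confirmed fix.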
Example:
>>> import torch
>>> torch.manual_seed(2020)
<torch._C.Generator object at 0x7f7ed8d359b0>
>>> a = torch.rand(10, 3)
>>> b = torch.rand(10, 3)
>>> a
tensor([[0.4869, 0.1052, 0.5883],
[0.1161, 0.4949, 0.2824],
[0.5899, 0.8105, 0.2512],
[0.6307, 0.5403, 0.8033],
[0.7781, 0.4966, 0.8888],
[0.5570, 0.7127, 0.0339],
[0.1151, 0.8780, 0.0671],
[0.5173, 0.8126, 0.3861],
[0.4992, 0.5970, 0.0498],
[0.7595, 0.3198, 0.4828]])
>>> b
tensor([[0.7016, 0.9966, 0.5778],
[0.1164, 0.7253, 0.1315],
[0.7898, 0.5141, 0.8525],
[0.5273, 0.0228, 0.8944],
[0.1633, 0.8798, 0.7698],
[0.1208, 0.7704, 0.9297],
[0.8620, 0.6643, 0.9220],
[0.9273, 0.7530, 0.1844],
[0.7666, 0.0043, 0.8979],
[0.4878, 0.8685, 0.2580]])
>>> c1 = torch.linalg.lstsq(a, b).solution.t()
>>> c2 = torch.linalg.lstsq(a, b).solution.t()
>>> c1 == c2
tensor([[False, False, False],
[False, False, False],
[ True, True, False]])
>>> c11 = torch.lstsq(b, a)[0][:a.size(-1)].t()
<stdin>:1: UserWarning: torch.lstsq is deprecated in favor of torch.linalg.lstsq and will be removed in a future PyTorch release.
torch.linalg.lstsq has reversed arguments and does not return the QR decomposition in the returned tuple (although it returns other information about the problem).
To get the qr decomposition consider using torch.linalg.qr.
The returned solution in torch.lstsq stored the residuals of the solution in the last m - n columns of the returned value whenever m > n. In torch.linalg.lstsq, the residuals in the field 'residuals' of the returned named tuple.
The unpacking of the solution, as in
X, _ = torch.lstsq(B, A).solution[:A.size(1)]
should be replaced with
X = torch.linalg.lstsq(A, B).solution (Triggered internally at ../aten/src/ATen/native/BatchLinearAlgebra.cpp:3668.)
>>> c22 = torch.lstsq(b, a)[0][:a.size(-1)].t()
>>> c11 == c22
tensor([[True, True, True],
[True, True, True],
[True, True, True]])
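The deprecation warning above points to torch.linalg.qr. For a tall matrix with full column rank, one possible workaround is to solve the least-squares problem from a QR factorization directly: with A = QR (reduced QR), minimising ||AX - B|| reduces to the upper-triangular system RX = Q^T B. This is a swapped-in technique, not what torch.linalg.lstsq does internally. `lstsq_via_qr` is a hypothetical name, it does no column pivoting (so it assumes A has full column rank), and `torch.linalg.solve_triangular` needs PyTorch 1.11 or later (on 1.10, `torch.triangular_solve` is the older equivalent, with its arguments in the reverse order).

```python
import torch


def lstsq_via_qr(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    """Least-squares solution of A X ~= B, assuming A is tall (m >= n) with full column rank."""
    # Reduced QR: Q is m x n with orthonormal columns, R is n x n upper triangular.
    Q, R = torch.linalg.qr(A, mode="reduced")
    # Minimising ||A X - B|| is equivalent to solving R X = Q^T B.
    return torch.linalg.solve_triangular(R, Q.transpose(-2, -1) @ B, upper=True)


torch.manual_seed(2020)
a = torch.rand(10, 3, dtype=torch.float64)
b = torch.rand(10, 3, dtype=torch.float64)

x_qr = lstsq_via_qr(a, b)
# Cross-check against the SVD-based driver of torch.linalg.lstsq.
x_ref = torch.linalg.lstsq(a, b, driver="gelsd").solution
print(torch.allclose(x_qr, x_ref))
```

Whether this path is bitwise reproducible on a given setup still has to be checked, for example by calling it twice on the same inputs and comparing the results with torch.equal.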
Versions
PyTorch version: 1.10.0+cu102
Is debug build: False
CUDA used to build PyTorch: 10.2
ROCM used to build PyTorch: N/A
OS: Ubuntu 20.04.3 LTS (x86_64)
GCC version: (Ubuntu 8.4.0-3ubuntu2) 8.4.0
Clang version: Could not collect
CMake version: version 3.16.3
Libc version: glibc-2.31
Python version: 3.8.10 (default, Jun 2 2021, 10:49:15) [GCC 9.4.0] (64-bit runtime)
Python platform: Linux-5.8.0-49-generic-x86_64-with-glibc2.29
Is CUDA available: True
CUDA runtime version: Could not collect
GPU models and configuration:
GPU 0: GeForce RTX 2080 Ti
GPU 1: GeForce RTX 2080 Ti
Nvidia driver version: 450.102.04
cuDNN version: /usr/lib/x86_64-linux-gnu/libcudnn.so.7.6.5
HIP runtime version: N/A
MIOpen runtime version: N/A
Versions of relevant libraries:
[pip3] numpy==1.20.0
[pip3] pytorch-lightning==1.5.6
[pip3] pytorch-metric-learning==1.0.0
[pip3] torch==1.10.0
[pip3] torch-cluster==1.5.9
[pip3] torch-geometric==2.0.4
[pip3] torch-points-kernels==0.7.0
[pip3] torch-points3d==1.3.0
[pip3] torch-scatter==2.0.9
[pip3] torch-sparse==0.6.12
[pip3] torch-spline-conv==1.2.1
[pip3] torchaudio==0.10.0
[pip3] torchfile==0.1.0
[pip3] torchmetrics==0.6.2
[pip3] torchnet==0.0.4
[pip3] torchvision==0.11.1
[conda] numpy 1.19.5 pypi_0 pypi
[conda] pytorch-metric-learning 0.9.99 pypi_0 pypi
[conda] torch 1.10.0 pypi_0 pypi
[conda] torch-cluster 1.5.9 pypi_0 pypi
[conda] torch-geometric 1.7.2 pypi_0 pypi
[conda] torch-points-kernels 0.6.10 pypi_0 pypi
[conda] torch-points3d 1.3.0 pypi_0 pypi
[conda] torch-scatter 2.0.9 pypi_0 pypi
[conda] torch-sparse 0.6.12 pypi_0 pypi
[conda] torchfile 0.1.0 pypi_0 pypi
[conda] torchnet 0.0.4 pypi_0 pypi
[conda] torchvision 0.11.1 pypi_0 pypi
cc @mruberry @kurtamohler @jianyuh @nikitaved @pearu @walterddr @IvanYashchuk @xwang233 @lezcano