🐛 Bug
torch.solve on GPU tensors fails when the batch size of the tensors is greater than 65535, and throws RuntimeError: CUDA error: invalid configuration argument. Interestingly, if you run torch.solve multiple times, it will occasionally pass (but with wrong results). See the reproducing code below.
To Reproduce
Steps to reproduce the behavior:
import torch

def test(batch_size):
    mat = torch.randn(batch_size, 12, 12)
    vec = torch.randn(batch_size, 12, 1)
    res, _ = torch.solve(vec, mat)
    print("CPU res: {}".format(torch.norm(torch.bmm(mat, res) - vec)))
    mat = mat.cuda()
    vec = vec.cuda()
    res, _ = torch.solve(vec, mat)
    res, _ = torch.solve(vec, mat)
    print("GPU res: {}".format(torch.norm(torch.bmm(mat, res) - vec)))

test(batch_size=65535)
print()
test(batch_size=65536)

It will output:
CPU res: 0.07957195490598679
GPU res: 0.09418904036283493
CPU res: 0.06375715881586075
Traceback (most recent call last):
File "test.py", line 19, in <module>
test(batch_size=65536)
File "test.py", line 14, in test
res, _ = torch.solve(vec, mat)
RuntimeError: CUDA error: invalid configuration argument
If you execute it multiple times like this:
import torch

def test(batch_size):
    mat = torch.randn(batch_size, 12, 12)
    vec = torch.randn(batch_size, 12, 1)
    res, _ = torch.solve(vec, mat)
    print("CPU res: {}".format(torch.norm(torch.bmm(mat, res) - vec)))
    mat = mat.cuda()
    vec = vec.cuda()
    res, _ = torch.solve(vec, mat)
    try:
        res, _ = torch.solve(vec, mat)
    except:
        pass
    print("GPU res: {}".format(torch.norm(torch.bmm(mat, res) - vec)))

test(batch_size=65535)
print()
test(batch_size=65536)

You will get:
CPU res: 0.03482796251773834
GPU res: 0.03578644245862961
CPU res: 0.0608956404030323
GPU res: 3199.743896484375
Environment
Collecting environment information...
PyTorch version: 1.1.0
Is debug build: No
CUDA used to build PyTorch: 10.0.130
OS: Ubuntu 16.04.6 LTS
GCC version: (Ubuntu 5.4.0-6ubuntu1~16.04.11) 5.4.0 20160609
CMake version: version 3.5.1
Python version: 3.7
Is CUDA available: Yes
CUDA runtime version: Could not collect
GPU models and configuration: GPU 0: TITAN Xp
Nvidia driver version: 418.39
cuDNN version: /usr/local/cuda-10.1/targets/x86_64-linux/lib/libcudnn.so.7.5.1
Versions of relevant libraries:
[pip] numpy==1.16.4
[pip] torch==1.1.0
[pip] torchvision==0.3.0
[conda] blas 1.0 mkl
[conda] mkl 2019.4 243
[conda] mkl_fft 1.0.12 py37ha843d7b_0
[conda] mkl_random 1.0.2 py37hd81dba3_0
[conda] pytorch 1.1.0 py3.7_cuda10.0.130_cudnn7.5.1_0 pytorch
[conda] torchvision 0.3.0 py37_cu10.0.130_1 pytorch
Additional notes
For now, I use the snippet below to work around the bug, but the runtime becomes similar to running on the CPU only.
def GPU_solve(As, bs):
    batch_size, N, N = As.shape
    # Output buffer for the solutions and a scratch buffer for the LU factors.
    Ws = As.new(size=(*As.shape[:2], 1))
    smb = 65535  # largest batch size that torch.solve currently handles on the GPU
    temp_LR = As.new(size=(smb, *As.shape[1:]))
    # Solve the systems in chunks of at most smb matrices.
    for i in range(batch_size // smb + 1):
        start = smb * i
        end = smb * (i + 1) if i < batch_size // smb else batch_size
        if start >= end:
            # Skip the empty tail chunk when batch_size is a multiple of smb.
            continue
        torch.solve(bs[start:end], As[start:end], out=(Ws[start:end], temp_LR[0:end - start]))
    return Ws

I suspect this issue is related to this issue.
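For reference, here is a more compact sketch of the same chunking idea using torch.split. The helper name chunked_gpu_solve and the default chunk size of 65535 are my own (the latter just mirrors the limit observed above), so treat it as a sketch rather than a drop-in fix:

import torch

def chunked_gpu_solve(As, bs, max_batch=65535):
    # Solve each chunk of at most max_batch systems separately, then
    # concatenate the per-chunk solutions along the batch dimension.
    sols = []
    for A_chunk, b_chunk in zip(As.split(max_batch), bs.split(max_batch)):
        sol, _ = torch.solve(b_chunk, A_chunk)
        sols.append(sol)
    return torch.cat(sols, dim=0)

Usage follows the argument order of GPU_solve above, e.g. res = chunked_gpu_solve(mat, vec).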