torch.solve in GPU fails when batch > 65535 #21643

@YurongYou

🐛 Bug

torch.solve on GPU tensors fails when the batch size exceeds 65535, raising RuntimeError: CUDA error: invalid configuration argument. Interestingly, if you call torch.solve multiple times, it occasionally passes, but returns wrong results. See the reproducing code below.

To Reproduce

Steps to reproduce the behavior:

import torch

def test(batch_size):
    mat = torch.randn(batch_size, 12, 12)
    vec = torch.randn(batch_size, 12, 1)

    res, _ = torch.solve(vec, mat)
    print("CPU res: {}".format(torch.norm(torch.bmm(mat,res) - vec)))

    mat = mat.cuda()
    vec = vec.cuda()

    res, _ = torch.solve(vec, mat)
    res, _ = torch.solve(vec, mat)
    print("GPU res: {}".format(torch.norm(torch.bmm(mat,res) - vec)))

test(batch_size=65535)
print()
test(batch_size=65536)

It will output

CPU res: 0.07957195490598679
GPU res: 0.09418904036283493

CPU res: 0.06375715881586075
Traceback (most recent call last):
  File "test.py", line 19, in <module>
    test(batch_size=65536)
  File "test.py", line 14, in test
    res, _ = torch.solve(vec, mat)
RuntimeError: CUDA error: invalid configuration argument

If you call torch.solve multiple times and ignore the exception, like this:

import torch

def test(batch_size):
    mat = torch.randn(batch_size, 12, 12)
    vec = torch.randn(batch_size, 12, 1)

    res, _ = torch.solve(vec, mat)
    print("CPU res: {}".format(torch.norm(torch.bmm(mat,res) - vec)))

    mat = mat.cuda()
    vec = vec.cuda()

    res, _ = torch.solve(vec, mat)
    try:
        res, _ = torch.solve(vec, mat)
    except RuntimeError:
        # Ignore the CUDA error so that the result of the first call gets checked.
        pass
    print("GPU res: {}".format(torch.norm(torch.bmm(mat,res) - vec)))

test(batch_size=65535)
print()
test(batch_size=65536)

You will get

CPU res: 0.03482796251773834
GPU res: 0.03578644245862961

CPU res: 0.0608956404030323
GPU res: 3199.743896484375
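
The printed value is a single norm over the whole batch, so it does not show which batch entries are wrong. A minimal sketch of a per-entry residual check follows; the helper name per_sample_residual is hypothetical and not part of the original report, and it only uses torch.bmm and Tensor.norm, so it runs on CPU or GPU tensors alike.

```python
import torch

def per_sample_residual(mat, res, vec):
    """Return the norm of (mat @ res - vec) for each batch entry separately.

    mat: (B, N, N), res: (B, N, K), vec: (B, N, K). Returns a tensor of shape (B,).
    """
    diff = torch.bmm(mat, res) - vec
    # Flatten each (N, K) residual into a vector and take its 2-norm.
    return diff.flatten(start_dim=1).norm(dim=1)
```

Applied to the GPU result from the second script, listing the indices where this residual is large (for example with torch.nonzero(per_sample_residual(mat, res, vec) > 1e-2)) might help narrow down whether the wrong entries cluster near or beyond index 65535. The 1e-2 threshold is an arbitrary assumption.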

Environment

Collecting environment information...
PyTorch version: 1.1.0
Is debug build: No
CUDA used to build PyTorch: 10.0.130

OS: Ubuntu 16.04.6 LTS
GCC version: (Ubuntu 5.4.0-6ubuntu1~16.04.11) 5.4.0 20160609
CMake version: version 3.5.1

Python version: 3.7
Is CUDA available: Yes
CUDA runtime version: Could not collect
GPU models and configuration: GPU 0: TITAN Xp
Nvidia driver version: 418.39
cuDNN version: /usr/local/cuda-10.1/targets/x86_64-linux/lib/libcudnn.so.7.5.1

Versions of relevant libraries:
[pip] numpy==1.16.4
[pip] torch==1.1.0
[pip] torchvision==0.3.0
[conda] blas                      1.0                         mkl
[conda] mkl                       2019.4                      243
[conda] mkl_fft                   1.0.12           py37ha843d7b_0
[conda] mkl_random                1.0.2            py37hd81dba3_0
[conda] pytorch                   1.1.0           py3.7_cuda10.0.130_cudnn7.5.1_0    pytorch
[conda] torchvision               0.3.0           py37_cu10.0.130_1    pytorch

Additional notes

For now, I use this snippet to work around the bug, but the runtime becomes similar to running on the CPU only.

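import torch

def GPU_solve(As, bs):
    # Solve As @ Ws = bs in chunks of at most 65535 matrices.
    batch_size, N, _ = As.shape
    Ws = As.new_empty((batch_size, N, 1))
    smb = 65535  # largest batch size that works on GPU
    temp_LR = As.new_empty((smb, N, N))  # scratch buffer for the LU output
    # Stepping by smb avoids an empty last chunk when batch_size is a multiple of smb.
    for start in range(0, batch_size, smb):
        end = min(start + smb, batch_size)
        torch.solve(bs[start:end], As[start:end], out=(Ws[start:end], temp_LR[0:end-start]))
    return Ws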
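
The index arithmetic of this kind of chunking can be checked without a GPU. The sketch below is only an illustration: chunk_bounds is a hypothetical helper, not part of PyTorch or of the original snippet. The edge case to watch is a batch size that is an exact multiple of the chunk size, where a loop over batch_size // smb + 1 iterations would end with an empty slice.

```python
def chunk_bounds(batch_size, max_chunk=65535):
    """Return (start, end) pairs covering range(batch_size) in chunks of at most max_chunk."""
    return [(start, min(start + max_chunk, batch_size))
            for start in range(0, batch_size, max_chunk)]

print(chunk_bounds(65536))   # [(0, 65535), (65535, 65536)]
print(chunk_bounds(131070))  # [(0, 65535), (65535, 131070)]
```

Stepping the range by max_chunk never produces an empty chunk, so no special case is needed for the last iteration.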

I suspect this is related to this issue

Labels

module: dependency bug (Problem is not caused by us, but caused by an upstream library we use)
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
