
Methods that solve systems of linear equations are memory-inefficient (batch-wise broadcasting) #49252

Description

@nikitaved

🐛 Bug

Certain methods defined in BatchedLinearAlgebra.cpp, such as

torch.solve
torch.cholesky_solve
torch.triangular_solve
torch.lu_solve

solve the system of linear equations AX = B and support batching with broadcasting over the batch dimensions.
The problem is that the broadcasted matrices are fully materialized: the call to cloneBatchedColumnMajor, made before the LAPACK drivers are applied, allocates a complete copy of the expanded tensor. Since the LAPACK drivers operate on one matrix of the batch at a time, this materialization could be avoided with a better iteration over the batch and by eliminating the call to cloneBatchedColumnMajor.
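To illustrate where the memory goes: expanding A over the batch dimensions is free, because expand returns a view whose broadcast dimensions have stride 0, and it is only the subsequent clone (which is effectively what cloneBatchedColumnMajor performs) that materializes every copy. A minimal sketch with small sizes:

```python
import torch

A = torch.rand(10, 10)

# Broadcasting over batch dims via a view costs no extra memory:
# the two batch dimensions get stride 0 and alias the same storage.
A_view = A.view(1, 1, 10, 10).expand(4, 3, 10, 10)
assert A_view.stride()[:2] == (0, 0)
assert A_view.data_ptr() == A.data_ptr()

# Materializing the broadcast (as cloneBatchedColumnMajor effectively
# does before calling the LAPACK drivers) allocates 4 * 3 real copies:
A_full = A_view.clone()
assert A_full.numel() == 4 * 3 * 10 * 10
```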

The example below shows what happens in practice:

To Reproduce

In [1]: import torch

In [2]: b = torch.rand(1000, 1000, 1000, 1) # 1000 x 1000 1000-dimensional rhs

In [3]: A = torch.rand(1000, 1000) # the system

In [4]: torch.solve(b, A.view(1, 1, 1000, 1000)) # we want A batch dims to broadcast over b batch dims
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-4-18e8d42bee82> in <module>
----> 1 torch.solve(b, A.view(1, 1, 1000, 1000)) # we want A batch dims to broadcast over b batch dims

RuntimeError: [enforce fail at CPUAllocator.cpp:67] . DefaultCPUAllocator: can't allocate memory: you tried to allocate 4000000000000 bytes. Error code 12 (Cannot allocate memory)

In [5]: sol = torch.rand(1000, 1000, 1000, 1) # that should be the shape of the solution

In [6]: sol_larger = torch.rand(1000, 1000, 1000, 10)

In [7]: del b, A, sol, sol_larger

In [8]: A = torch.rand(1000, 1000, 1000, 1000) # the size torch.solve tries to allocate
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-8-1b3bd2cd8b95> in <module>
----> 1 A = torch.rand(1000, 1000, 1000, 1000)

RuntimeError: [enforce fail at CPUAllocator.cpp:67] . DefaultCPUAllocator: can't allocate memory: you tried to allocate 4000000000000 bytes. Error code 12 (Cannot allocate memory)
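The 4000000000000 bytes in both error messages is exactly the cost of materializing A over b's batch dimensions:

```python
# b has batch shape (1000, 1000); broadcasting A to it materializes
# 1000 * 1000 copies of a 1000 x 1000 float32 matrix (4 bytes each).
nbytes = 1000 * 1000 * 1000 * 1000 * 4
assert nbytes == 4_000_000_000_000  # matches the allocator error above
```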


Expected behavior

The broadcasting of A over b should not materialize redundant copies.
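One possible shape of the fix, sketched in Python (the helper name `solve_broadcast` is hypothetical, and `torch.linalg.solve` on single matrices stands in for the per-matrix LAPACK driver): expand both operands to stride-0 views, then iterate over the broadcast batch indices so that only one matrix-sized problem is live at a time.

```python
import itertools
import torch

def solve_broadcast(B, A):
    """Hypothetical sketch: solve A X = B, broadcasting over batch dims
    without materializing the expanded copies of A (or B)."""
    batch = torch.broadcast_shapes(A.shape[:-2], B.shape[:-2])
    # expand() only adjusts strides (stride 0 for broadcast dims), no copy
    A_exp = A.expand(*batch, *A.shape[-2:])
    B_exp = B.expand(*batch, *B.shape[-2:])
    out = torch.empty(*batch, *B.shape[-2:], dtype=B.dtype, device=B.device)
    # solve one driver-sized problem at a time, per batch index
    for idx in itertools.product(*(range(s) for s in batch)):
        out[idx] = torch.linalg.solve(A_exp[idx], B_exp[idx])
    return out
```

On the reproducer above, this keeps peak memory at roughly one solution tensor plus one matrix-sized workspace, instead of the 4 TB expanded copy of A.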

Environment

PyTorch version: 1.8.0a0+056961b
Is debug build: False
CUDA used to build PyTorch: 10.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 18.04.5 LTS (x86_64)
GCC version: (crosstool-NG 1.23.0.449-a04d0) 7.3.0
Clang version: 8.0.0 (tags/RELEASE_800/final)
CMake version: version 3.18.2

Python version: 3.6 (64-bit runtime)
Is CUDA available: True
CUDA runtime version: 10.1.243
GPU models and configuration: 
GPU 0: TITAN RTX
GPU 1: TITAN RTX

Nvidia driver version: 450.51.06
cuDNN version: /usr/local/cuda-10.2.89/targets/x86_64-linux/lib/libcudnn.so.7
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] numpy==1.19.1
[pip3] torch==1.8.0a0
[conda] blas                      1.0                         mkl  
[conda] cudatoolkit               10.1.243             h6bb024c_0  
[conda] magma-cuda101             2.5.1                         1    pytorch
[conda] mkl                       2020.2                      256  
[conda] mkl-include               2020.2                      256  
[conda] mkl-service               2.3.0            py36he904b0f_0  
[conda] mkl_fft                   1.2.0            py36h23d657b_0  
[conda] mkl_random                1.1.1            py36h0573a6f_0  
[conda] numpy                     1.19.1           py36hbc911f0_0  
[conda] numpy-base                1.19.1           py36hfa32c7d_0  
[conda] torch                     1.8.0a0                   dev_0    <develop>

cc @jianyuh @nikitaved @pearu @mruberry @heitorschueroff @walterddr


Labels

    module: linear algebra -- issues related to specialized linear algebra operations in PyTorch; includes matrix multiply (matmul)
    module: memory usage -- PyTorch is using more memory than it should, or it is leaking memory
    triaged -- this issue has been looked at by a team member, and triaged and prioritized into an appropriate module
