
matmul in a ScriptModule requires much more memory #21406

@nict-wisdom

Description


🐛 Bug

matmul in a ScriptModule requires much more memory than in a regular (non-JIT) module.

To Reproduce

Steps to reproduce the behavior: Run the following script.

import torch
import torch.nn as nn


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc = nn.Linear(768, 100000)

    def forward(self, x):
        x = x * 2
        x = x * 2
        x = x * 2
        x = self.fc(x)
        return x


# Trace on CPU, then move the ScriptModule to the GPU.
m = Net()
x = torch.randn(10, 128, 768)
jm = torch.jit.trace(m, x)
jm = jm.to("cuda")

# Two forward/backward passes; memory grows on the second backward.
for i in range(2):
    x = torch.randn(10, 128, 768, requires_grad=True).to("cuda")
    y = jm(x)
    tgt = torch.randn_like(y)
    y.backward(tgt)

I checked the GPU memory usage at the end of the script (by inserting input() to keep the process alive).
nvidia-smi reported:

|    0    155985      C   python                                      8109MiB |
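
For a scriptable measurement, the caching allocator's counters can be read at the same point instead of eyeballing nvidia-smi (a minimal sketch, appended to the script above; note that nvidia-smi also counts the CUDA context and cached-but-free blocks, so its figure is higher):

torch.cuda.synchronize()
# Allocator statistics; these are lower than nvidia-smi's per-process column.
print(f"allocated: {torch.cuda.memory_allocated() / 2**20:.0f} MiB")
print(f"peak:      {torch.cuda.max_memory_allocated() / 2**20:.0f} MiB")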

The equivalent non-JIT version requires much less memory:

import torch
import torch.nn as nn


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc = nn.Linear(768, 100000)

    def forward(self, x):
        x = x * 2
        x = x * 2
        x = x * 2
        x = self.fc(x)
        return x


m = Net().to("cuda")

# Same two forward/backward passes as above, without torch.jit.trace.
for i in range(2):
    x = torch.randn(10, 128, 768, requires_grad=True).to("cuda")
    y = m(x)
    tgt = torch.randn_like(y)
    y.backward(tgt)

nvidia-smi showed:

|    0    156120      C   python                                      2739MiB |

The memory usage increases during backward.
More precisely, matmul in LinearAlgebra.cpp is called from DifferentiableGraphOp and uses a large amount of memory.
The memory usage grows further on the second call of backward.
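
To confirm the per-iteration growth, the loop can be bracketed with peak-memory bookkeeping (a minimal sketch using the allocator counters; jm is the traced module from the first script):

for i in range(2):
    # Reset the peak counter so each iteration is measured separately.
    torch.cuda.reset_max_memory_allocated()
    x = torch.randn(10, 128, 768, requires_grad=True).to("cuda")
    y = jm(x)
    tgt = torch.randn_like(y)
    y.backward(tgt)
    torch.cuda.synchronize()
    # Per the observation above, the peak printed for i == 1 is larger than for i == 0.
    print(f"iter {i}: peak {torch.cuda.max_memory_allocated() / 2**20:.0f} MiB")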

I thought this could be related to #18862, but the latest nightly build 1.2.0.dev20190605 didn't solve the problem.

Expected behavior

I expect the JIT version to require almost the same amount of memory as the non-JIT version.

Environment

I created a new conda environment and installed the latest nightly build (1.2.0.dev20190605).

PyTorch version: 1.2.0.dev20190605
Is debug build: No
CUDA used to build PyTorch: 9.0.176

OS: CentOS Linux release 7.5.1804 (Core)
GCC version: (GCC) 4.8.5 20150623 (Red Hat 4.8.5-28)
CMake version: version 3.14.4

Python version: 3.7
Is CUDA available: Yes
CUDA runtime version: 9.2.88
GPU models and configuration:
GPU 0: Tesla P100-PCIE-16GB
GPU 1: Tesla P100-PCIE-16GB

Nvidia driver version: 396.26
cuDNN version: Could not collect

Versions of relevant libraries:
[pip] numpy==1.16.4
[pip] torch==1.2.0.dev20190605
[conda] blas 1.0 mkl
[conda] mkl 2019.4 243
[conda] mkl_fft 1.0.12 py37ha843d7b_0
[conda] mkl_random 1.0.2 py37hd81dba3_0
[conda] pytorch-nightly 1.2.0.dev20190605 py3.7_cuda9.0.176_cudnn7.5.1_0 pytorch


Labels

high priority, oncall: jit, triaged
