🐛 Bug
matmul in a ScriptModule requires much more memory than in a regular (non-JIT) module.
To Reproduce
Steps to reproduce the behavior: Run the following script.
import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc = nn.Linear(768, 100000)

    def forward(self, x):
        x = x * 2
        x = x * 2
        x = x * 2
        x = self.fc(x)
        return x

# trace on cpu
m = Net()
x = torch.randn(10, 128, 768)
jm = torch.jit.trace(m, x)
jm = jm.to("cuda")

for i in range(0, 2):
    x = torch.randn(10, 128, 768, requires_grad=True).to("cuda")
    y = jm(x)
    tgt = torch.randn_like(y)
    y.backward(tgt)

I checked the GPU memory usage at the end of the script (by inserting input()).
nvidia-smi said
| 0 155985 C python 8109MiB |
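For a cross-check that does not depend on nvidia-smi, the caching-allocator counters can be read at the same point (a minimal sketch of my own, assuming torch.cuda.max_memory_allocated is available in this build; nvidia-smi additionally counts the CUDA context and cached blocks):

torch.cuda.synchronize()
# Bytes currently held by tensors vs. the peak since the start of the process
print("allocated: %.0f MiB" % (torch.cuda.memory_allocated() / 2**20))
print("peak:      %.0f MiB" % (torch.cuda.max_memory_allocated() / 2**20))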
The non-JIT version requires much less memory:

import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc = nn.Linear(768, 100000)

    def forward(self, x):
        x = x * 2
        x = x * 2
        x = x * 2
        x = self.fc(x)
        return x

m = Net().to("cuda")

for i in range(0, 2):
    x = torch.randn(10, 128, 768, requires_grad=True).to("cuda")
    y = m(x)
    tgt = torch.randn_like(y)
    y.backward(tgt)
nvidia-smi showed
| 0 156120 C python 2739MiB |
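For reference, the graph recorded by the tracer for jm (from the first script) can be printed directly; this is only how I inspected it, assuming traced modules expose .graph and .code like other ScriptModules:

# Show what torch.jit.trace captured for Net.forward: three element-wise
# multiplications followed by the linear layer. At runtime the graph executor
# packages the differentiable part into the DifferentiableGraphOp mentioned below.
print(jm.graph)
print(jm.code)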
The memory usage increases during backward. More precisely, the matmul in LinearAlgebra.cpp is called inside a DifferentiableGraphOp and uses a lot of memory. The memory usage increases further on the second call of backward (a small instrumentation sketch is shown below).
I thought this could be related to #18862, but the latest nightly build 1.2.0.dev20190605 does not fix the problem.
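To make the per-iteration growth visible, the loop can be instrumented with the allocator's peak counter (again a sketch of my own, not part of the original repro; torch.cuda.reset_max_memory_allocated is assumed to exist in this build):

for i in range(0, 2):
    torch.cuda.reset_max_memory_allocated()
    x = torch.randn(10, 128, 768, requires_grad=True).to("cuda")
    y = jm(x)
    tgt = torch.randn_like(y)
    y.backward(tgt)
    torch.cuda.synchronize()
    # Peak memory reached during this iteration's forward + backward
    print("iter %d: peak %.0f MiB" % (i, torch.cuda.max_memory_allocated() / 2**20))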
Expected behavior
I expect the JIT version to require almost the same amount of memory as the non-JIT version.
Environment
I created a new conda environment and installed the latest nightly build (1.2.0.dev20190605).
PyTorch version: 1.2.0.dev20190605
Is debug build: No
CUDA used to build PyTorch: 9.0.176
OS: CentOS Linux release 7.5.1804 (Core)
GCC version: (GCC) 4.8.5 20150623 (Red Hat 4.8.5-28)
CMake version: version 3.14.4
Python version: 3.7
Is CUDA available: Yes
CUDA runtime version: 9.2.88
GPU models and configuration:
GPU 0: Tesla P100-PCIE-16GB
GPU 1: Tesla P100-PCIE-16GB
Nvidia driver version: 396.26
cuDNN version: Could not collect
Versions of relevant libraries:
[pip] numpy==1.16.4
[pip] torch==1.2.0.dev20190605
[conda] blas 1.0 mkl
[conda] mkl 2019.4 243
[conda] mkl_fft 1.0.12 py37ha843d7b_0
[conda] mkl_random 1.0.2 py37hd81dba3_0
[conda] pytorch-nightly 1.2.0.dev20190605 py3.7_cuda9.0.176_cudnn7.5.1_0 pytorch