Description
🐛 Bug
Memory usage increases steadily when forward is called repeatedly on a torch.jit.ScriptModule.
To Reproduce
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Net(torch.jit.ScriptModule):
    def __init__(self):
        super(Net, self).__init__()
        self.conv = nn.Conv2d(1, 1, 1)  # 1x1 conv keeps the 10x10 spatial size
        self.lin = nn.Linear(100, 1)    # 1 channel * 10 * 10 = 100 features

    @torch.jit.script_method
    def forward(self, x):
        x = F.relu(self.conv(x))
        x = x.view(x.shape[0], -1)  # flatten to (batch, 100)
        x = F.relu(self.lin(x))
        return x


m = Net()
x = torch.ones((1, 1, 10, 10))

# Memory usage grows steadily while this loop runs.
for i in range(1000000):
    m(x)
```
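If memory_profiler is not available, growth can also be checked from inside the script. The sketch below is a hypothetical, stdlib-only helper (the names `current_rss_bytes`, `sample_rss` and `growth_bytes` are illustrative and not part of PyTorch). It assumes a Unix system: on Linux it reads resident memory from `/proc/self/statm`, and elsewhere it falls back to the peak RSS from `resource.getrusage`, which can only grow and therefore overstates current usage.

```python
import os
import resource
import sys


def current_rss_bytes():
    """Resident set size of this process in bytes (assumes Linux /proc)."""
    try:
        with open("/proc/self/statm") as f:
            resident_pages = int(f.read().split()[1])
        return resident_pages * os.sysconf("SC_PAGE_SIZE")
    except (OSError, ValueError, IndexError):
        # Fallback: peak RSS, reported in bytes on macOS and KiB on Linux/BSD.
        peak = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
        return peak if sys.platform == "darwin" else peak * 1024


def sample_rss(fn, iterations, every):
    """Call fn() `iterations` times and record RSS after every `every` calls."""
    if every <= 0:
        raise ValueError("every must be positive")
    samples = []
    for i in range(1, iterations + 1):
        fn()
        if i % every == 0:
            samples.append(current_rss_bytes())
    return samples


def growth_bytes(samples):
    """Last sample minus first; a steadily positive value suggests a leak."""
    return samples[-1] - samples[0] if len(samples) >= 2 else 0
```

For the module in the reproduction script, `samples = sample_rss(lambda: m(x), 100000, 10000)` records ten readings. A steadily increasing sequence (and a large `growth_bytes(samples)`) matches the behavior reported here, while a flat one would not.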
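Expected behavior
Memory usage should stay roughly flat after the first few iterations, since every call receives the same input and its output is discarded.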
Environment
PyTorch version: 1.1.0
Is debug build: No
CUDA used to build PyTorch: None
OS: Ubuntu 18.04.1 LTS
GCC version: (Ubuntu 7.3.0-16ubuntu3) 7.3.0
CMake version: version 3.10.2
Python version: 3.6
Is CUDA available: No
CUDA runtime version: No CUDA
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
Versions of relevant libraries:
[pip3] numpy==1.14.2
[pip3] numpydoc==0.7.0
[pip3] torch==1.1.0
[pip3] torchvision==0.3.0
[conda] blas 1.0 mkl
[conda] mkl 2018.0.1 h19d6760_4
[conda] mkl-dnn 0.14 2 intel
[conda] mkl-service 1.1.2 py36h17a0993_4
[conda] pytorch-cpu 1.1.0 py3.6_cpu_0 pytorch
[conda] torchvision-cpu 0.3.0 py36_cuNone_1 pytorch
Additional context
Memory usage was recorded with memory_profiler by running `mprof run script.py`; the resulting mprof plot is not included in this text.
