Description
🐛 Describe the bug
Bug Description
When a simple LSTM model runs repeatedly in a loop on the MPS backend, memory usage increases steadily and eventually triggers an out-of-memory error. This happens even though torch.mps.empty_cache() is called after every iteration. The error occurs after approximately 15,666 iterations with a batch size of 16 and a hidden size of 256.
Reproduction Steps
Run the following code to reproduce the issue:
import torch
import torch.nn as nn
import platform


class LSTMModel(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers=1, batch_first=True):
        super(LSTMModel, self).__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers=num_layers, batch_first=batch_first)

    def forward(self, x, hidden):
        output, hidden = self.lstm(x, hidden)
        return output, hidden


def check_memory_leak():
    input_size = 256
    hidden_size = 256
    batch_size = 16
    sequence_length = 10
    num_iterations = 100000  # Set a high number to check for memory leaks
    # Use MPS if available
    device = "mps" if torch.backends.mps.is_available() else "cpu"
    # Model initialization
    model = LSTMModel(input_size, hidden_size).to(device)
    # Input data and hidden state initialization
    x = torch.randn(batch_size, sequence_length, input_size).to(device)
    hidden = (
        torch.zeros(1, batch_size, hidden_size).to(device),
        torch.zeros(1, batch_size, hidden_size).to(device),
    )
    print("Starting memory check...")
    for i in range(num_iterations):
        with torch.no_grad():
            output, hidden = model(x, hidden)
        # Clear MPS memory cache
        torch.mps.empty_cache()
        print(f"Iteration {i + 1}/{num_iterations}: Completed")


if __name__ == "__main__":
    print("PyTorch Version:", torch.__version__)
    print("Python Version:", platform.python_version())
    print("Platform:", platform.system(), platform.release())
    print("MPS Available:", torch.backends.mps.is_available())
    print("MPS Built:", torch.backends.mps.is_built())
    check_memory_leak()
Expected Behavior
Memory usage should remain stable or properly recycle after clearing the cache with torch.mps.empty_cache().
Observed Behavior
The program crashes with an Out of Memory error after ~15,666 iterations. The error message is as follows:
RuntimeError: MPS backend out of memory (MPS allocated: 24.00 MB, other allocations: 27.18 GB, max allowed: 27.20 GB). Tried to allocate 16.00 KB on private pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure).
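To put the numbers in the error message in perspective, here is a rough back-of-the-envelope estimate of the average growth per iteration. It is only a sketch: it assumes the "other allocations" figure started near zero and grew roughly linearly, and that the sizes in the message use binary units (1 GB = 1024 MB). Neither assumption has been verified, so the result should be read as an approximate upper bound.

```python
# Figures copied from the error message and the iteration count in this report.
OTHER_ALLOCATIONS_GB = 27.18
ITERATIONS_BEFORE_OOM = 15_666

# Assumption: the message reports sizes in binary units.
MB_PER_GB = 1024


def average_growth_mb_per_iteration(total_gb, iterations):
    """Average memory growth per iteration, assuming growth from ~0 and a linear trend."""
    return total_gb * MB_PER_GB / iterations


growth = average_growth_mb_per_iteration(OTHER_ALLOCATIONS_GB, ITERATIONS_BEFORE_OOM)
print(f"average growth: {growth:.2f} MB per iteration")
```

Under these assumptions the growth works out to a little under 2 MB per iteration, which is far more than the 16.00 KB allocation that finally failed.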
Environment Information
MacBook Air 15 M3(24GB)
PyTorch Version: 2.5.1
Python Version: 3.12.2
Platform: Darwin 24.3.0
MPS Available: True
MPS Built: True
Additional Context
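This issue may be related to the MPS backend’s memory management while handling LSTM computations. The error message attributes almost all of the memory to "other allocations" (27.18 GB) rather than to MPS-allocated memory (24.00 MB), which may explain why torch.mps.empty_cache() does not release it in this scenario. The problem persists even when torch.no_grad() is used.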
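For further debugging, the growth could be tracked with something like the sketch below. The helper names (sample_mps_memory, leak_rate, step_fn) are hypothetical and not part of PyTorch. The sketch relies on torch.mps.current_allocated_memory() and torch.mps.driver_allocated_memory(), and it assumes that the driver figure reflects the "other allocations" reported in the error message, which I have not confirmed.

```python
def leak_rate(samples):
    """Least-squares slope (bytes per iteration) of (iteration, bytes) samples."""
    n = len(samples)
    mean_x = sum(i for i, _ in samples) / n
    mean_y = sum(b for _, b in samples) / n
    var_x = sum((i - mean_x) ** 2 for i, _ in samples)
    if var_x == 0:
        return 0.0  # a single sample (or identical iterations) has no slope
    cov = sum((i - mean_x) * (b - mean_y) for i, b in samples)
    return cov / var_x


def sample_mps_memory(step_fn, num_iterations=2000, every=100):
    """Call step_fn repeatedly and record MPS memory counters every `every` iterations.

    Returns two lists of (iteration, bytes) samples: tensor memory tracked by
    PyTorch's MPS allocator, and total memory reported by the Metal driver.
    """
    import torch  # imported lazily so leak_rate can be used without PyTorch

    tensor_samples, driver_samples = [], []
    for i in range(num_iterations):
        step_fn()
        torch.mps.empty_cache()
        if i % every == 0:
            torch.mps.synchronize()  # wait for queued GPU work before reading counters
            tensor_samples.append((i, torch.mps.current_allocated_memory()))
            driver_samples.append((i, torch.mps.driver_allocated_memory()))
    return tensor_samples, driver_samples
```

For the reproduction above, step_fn would wrap the model(x, hidden) call inside torch.no_grad(). A leak_rate close to zero for the tensor samples combined with a clearly positive rate for the driver samples would suggest that the growth happens outside the memory that empty_cache() can release.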
Request
Could you please investigate the memory leak issue in the MPS backend for LSTM models? Let me know if further debugging or testing is needed.
Versions
Collecting environment information...
PyTorch version: 2.5.1
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A
OS: macOS 15.3 (arm64)
GCC version: Could not collect
Clang version: 15.0.0 (clang-1500.3.9.4)
CMake version: version 3.30.3
Libc version: N/A
Python version: 3.12.2 | packaged by conda-forge | (main, Feb 16 2024, 20:54:21) [Clang 16.0.6 ] (64-bit runtime)
Python platform: macOS-15.3-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Apple M3
Versions of relevant libraries:
[pip3] efficientnet_pytorch==0.7.1
[pip3] numpy==1.26.4
[pip3] segmentation_models_pytorch==0.4.0
[pip3] torch==2.5.1
[pip3] torchaudio==2.4.1
[pip3] torchvision==0.19.1
[conda] efficientnet-pytorch 0.7.1 pypi_0 pypi
[conda] numpy 2.2.1 pypi_0 pypi
[conda] numpy-base 1.26.4 py312he047099_0
[conda] segmentation-models-pytorch 0.4.0 pypi_0 pypi
[conda] torch 2.5.1 pypi_0 pypi
[conda] torchaudio 2.4.1 py312_cpu pytorch
[conda] torchvision 0.19.1 py312_cpu pytorch
cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @mikaylagawarecki @kulinseth @albanD @malfet @DenisVieriu97 @jhavukainen