
Memory Leak in MPS Backend During LSTM Iterations (Out of Memory Error) #145374

@Tyndall-log

Description

🐛 Describe the bug

Bug Description

When running a simple LSTM model on the MPS backend in a loop, memory usage steadily increases and eventually leads to an out-of-memory error. This happens despite clearing the MPS memory cache with torch.mps.empty_cache() after every iteration. With a batch size of 16 and a hidden size of 256, the error occurs after approximately 15,666 iterations.

Reproduction Steps

Run the following code to reproduce the issue:

import torch
import torch.nn as nn
import platform

class LSTMModel(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers=1, batch_first=True):
        super(LSTMModel, self).__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers=num_layers, batch_first=batch_first)

    def forward(self, x, hidden):
        output, hidden = self.lstm(x, hidden)
        return output, hidden

def check_memory_leak():
    input_size = 256
    hidden_size = 256
    batch_size = 16
    sequence_length = 10
    num_iterations = 100000  # Set a high number to check for memory leaks

    # Use MPS if available
    device = "mps" if torch.backends.mps.is_available() else "cpu"

    # Model initialization
    model = LSTMModel(input_size, hidden_size).to(device)

    # Input data and hidden state initialization
    x = torch.randn(batch_size, sequence_length, input_size).to(device)
    hidden = (
        torch.zeros(1, batch_size, hidden_size).to(device),
        torch.zeros(1, batch_size, hidden_size).to(device),
    )

    print("Starting memory check...")
    for i in range(num_iterations):
        with torch.no_grad():
            output, hidden = model(x, hidden)
        
        # Clear MPS memory cache
        torch.mps.empty_cache()
        
        print(f"Iteration {i + 1}/{num_iterations}: Completed")

if __name__ == "__main__":
    print("PyTorch Version:", torch.__version__)
    print("Python Version:", platform.python_version())
    print("Platform:", platform.system(), platform.release())
    print("MPS Available:", torch.backends.mps.is_available())
    print("MPS Built:", torch.backends.mps.is_built())

    check_memory_leak()
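
To make the growth visible, the loop above can be instrumented with the MPS memory statistics PyTorch exposes. A minimal sketch (not part of the original repro; the reporting interval of 1000 iterations is an arbitrary choice):

import torch

def report_mps_memory(step: int) -> None:
    # Bytes currently held by live tensors on the MPS device
    allocated = torch.mps.current_allocated_memory()
    # Total memory the Metal driver has allocated for this process
    driver = torch.mps.driver_allocated_memory()
    print(f"step {step}: allocated={allocated / 1e6:.1f} MB, driver={driver / 1e6:.1f} MB")

# Inside the loop in check_memory_leak(), for example:
#     if (i + 1) % 1000 == 0:
#         report_mps_memory(i + 1)

Since the error message below attributes almost all of the usage to "other allocations" rather than to tensors, driver_allocated_memory() is probably the number to watch.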

Expected Behavior

Memory usage should remain stable across iterations, or cached memory should be released after calling torch.mps.empty_cache().

Observed Behavior

The program crashes with an Out of Memory error after ~15,666 iterations. The error message is as follows:

RuntimeError: MPS backend out of memory (MPS allocated: 24.00 MB, other allocations: 27.18 GB, max allowed: 27.20 GB). Tried to allocate 16.00 KB on private pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure).
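
For reference, the cap mentioned in the error can be lifted with the environment variable named in the message. A minimal sketch, assuming the variable must be set before PyTorch initializes the MPS allocator (i.e. before importing torch); this does not fix the leak, it only removes the upper limit and, as the message warns, may destabilize the system:

import os

# Must be set before the MPS allocator is initialized, i.e. before importing torch.
os.environ["PYTORCH_MPS_HIGH_WATERMARK_RATIO"] = "0.0"

import torch  # noqa: E402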

Environment Information

MacBook Air 15" (M3, 24 GB)

PyTorch Version: 2.5.1
Python Version: 3.12.2
Platform: Darwin 24.3.0
MPS Available: True
MPS Built: True

Additional Context

This issue may be related to how the MPS backend manages memory for LSTM computations. torch.mps.empty_cache() does not appear to release the leaked memory in this scenario, and the problem persists even under torch.no_grad(). A diagnostic variant of the loop is sketched below.
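
One way to narrow this down (a hedged diagnostic sketch reusing the names from check_memory_leak() above, not a fix): re-create the hidden state from zeros every iteration instead of carrying it over. If memory still grows, the carried (h, c) tensors are unlikely to be the cause, and the leak points more toward the LSTM kernel or MPS graph caching.

# Diagnostic variant of the loop in check_memory_leak():
# drop the carried hidden state and start from zeros every iteration.
for i in range(num_iterations):
    hidden = (
        torch.zeros(1, batch_size, hidden_size, device=device),
        torch.zeros(1, batch_size, hidden_size, device=device),
    )
    with torch.no_grad():
        output, hidden = model(x, hidden)
    torch.mps.empty_cache()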

Request

Could you please investigate this memory leak in the MPS backend for LSTM models? I'm happy to do further debugging or testing if that would help.

Versions

Collecting environment information...
PyTorch version: 2.5.1
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 15.3 (arm64)
GCC version: Could not collect
Clang version: 15.0.0 (clang-1500.3.9.4)
CMake version: version 3.30.3
Libc version: N/A

Python version: 3.12.2 | packaged by conda-forge | (main, Feb 16 2024, 20:54:21) [Clang 16.0.6 ] (64-bit runtime)
Python platform: macOS-15.3-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Apple M3

Versions of relevant libraries:
[pip3] efficientnet_pytorch==0.7.1
[pip3] numpy==1.26.4
[pip3] segmentation_models_pytorch==0.4.0
[pip3] torch==2.5.1
[pip3] torchaudio==2.4.1
[pip3] torchvision==0.19.1
[conda] efficientnet-pytorch      0.7.1                    pypi_0    pypi
[conda] numpy                     2.2.1                    pypi_0    pypi
[conda] numpy-base                1.26.4          py312he047099_0  
[conda] segmentation-models-pytorch 0.4.0                    pypi_0    pypi
[conda] torch                     2.5.1                    pypi_0    pypi
[conda] torchaudio                2.4.1                 py312_cpu    pytorch
[conda] torchvision               0.19.1                py312_cpu    pytorch

cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @mikaylagawarecki @kulinseth @albanD @malfet @DenisVieriu97 @jhavukainen

Labels

module: memory usage - PyTorch is using more memory than it should, or it is leaking memory
module: mps - Related to Apple Metal Performance Shaders framework
module: rnn - Issues related to RNN support (LSTM, GRU, etc)
triaged - This issue has been looked at by a team member, and triaged and prioritized into an appropriate module
