[MPS] Torch 2.5.x and nightlies use 50% more memory and are 60% slower than 2.4.1 when running Stable Diffusion #139389

@Vargol

🐛 Describe the bug

I've seen a lot of people mentioning this in various forums but have not seen an issue raised here, so here we go. I'm reporting this as an MPS issue as I've not been able to test on NVIDIA, AMD, etc.

Running various Stable Diffusion and newer transformer-based DiT models using 2.5.x and nightly releases of PyTorch shows a very significant drop in performance compared to running them in the same environment using PyTorch 2.4.1.

For example, for a standard SDXL run via Diffusers using 2.4.1, the python binary reports using 9.5 GB and runs at 5.7 seconds per iteration; under 2.5.1 or nightly it reports 14.9 GB and runs at 8.5 s/it.
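As a quick sanity check on the SDXL figures quoted above, the relative changes can be worked out directly. This is only arithmetic on the reported numbers, assuming the memory figures are in GB and the timings are seconds per iteration; the variable names are illustrative.

```python
# Figures quoted in the report for the SDXL run
# (assumed units: GB and seconds per iteration).
mem_241_gb, mem_251_gb = 9.5, 14.9
sec_per_it_241, sec_per_it_251 = 5.7, 8.5


def pct_increase(old: float, new: float) -> float:
    """Relative increase of `new` over `old`, in percent."""
    return (new - old) / old * 100.0


mem_increase = pct_increase(mem_241_gb, mem_251_gb)
time_increase = pct_increase(sec_per_it_241, sec_per_it_251)
# Throughput (iterations per second) is the reciprocal of seconds per iteration.
throughput_drop = (1.0 - sec_per_it_241 / sec_per_it_251) * 100.0

print(f"memory: +{mem_increase:.1f}%")              # memory: +56.8%
print(f"time per iteration: +{time_increase:.1f}%")  # time per iteration: +49.1%
print(f"throughput: -{throughput_drop:.1f}%")        # throughput: -32.9%
```

So for this particular run the quoted numbers work out to roughly 57% more memory and about 49% more time per iteration, which is the same ballpark as the rounder figures in the title.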

SD3.5 goes from running on a 24 GB M3 without PYTORCH_MPS_HIGH_WATERMARK_RATIO set to failing unless it is set to 0.0, and it uses 37 GB compared to 27.7 GB at the point where the runs first start iterating (2.5.1 takes ages to start iterating, and usage will go higher once it kicks in properly).
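For anyone trying to reproduce the SD3.5 case, here is a minimal sketch of setting the variable from Python rather than from the shell. It assumes the variable has to be in the environment before torch is first imported, which is the safe ordering; the value 0.0 is the one mentioned above and, as far as I know, disables the allocator's upper limit entirely, so the system can be pushed into heavy swapping.

```python
import os

# Assumption: set this before the first `import torch` so the MPS
# allocator picks it up when it initialises.
# "0.0" is the value mentioned in the report (no high-watermark limit).
os.environ["PYTORCH_MPS_HIGH_WATERMARK_RATIO"] = "0.0"

print(os.environ["PYTORCH_MPS_HIGH_WATERMARK_RATIO"])

# import torch  # import torch (and diffusers) only after the variable is set
```

The shell equivalent is to prefix the command, for example `PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 python sd35_test.py`, where the script name is just a placeholder.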

Here is the Diffusers SDXL script I used, as that'll be the smallest and fastest model. It's a little messy as it was a test script for something else, but I wanted to give you the exact same script, so you'll have to forgive the unused imports and variables :-)

from diffusers import StableDiffusionXLPipeline, AutoencoderKL, EulerAncestralDiscreteScheduler
import torch
from torch import mps
import time

torch.mps.set_per_process_memory_fraction(0.0)

prompt = "Emily Booth as Sappho, cold color palette, vivid colors, detailed, 8k, 35mm photo, Kodachrome, Lomography, highly detailed"
negative_prompt = "painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured"

prompt="Nina Hagen as a 1960s (stingray puppet)++, supermarionation++ strings visible ,  ,  . vivid colors, detailed, 8k, 35mm photo, Kodachrome, Lomography, highly detailed"

negative_prompt="painting, drawing, illustration, glitch, mutated, cross-eyed, ugly, disfigured"

isteps=30
height=1024
width=1024
seed = 432773521 

vae = AutoencoderKL.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0",
                                    subfolder='vae',
                                    torch_dtype=torch.bfloat16,
                                    force_upcast=False).to('mps')


pipe = StableDiffusionXLPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", vae=vae,
    torch_dtype=torch.bfloat16, variant="fp16").to('mps')

#pipe.enable_vae_tiling()
#pipe.enable_attention_slicing()

pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)

generator =  torch.Generator("mps").manual_seed(seed);

image = pipe(prompt=prompt, negative_prompt=negative_prompt,
               height=height, width=width, 
               num_inference_steps=10,
               guidance_scale=8,
               generator=generator,
               ).images[0]

image.save(f'bfloat16.png')
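To compare the two torch versions with something more repeatable than watching the process size, a small helper along these lines could be dropped into the script above. It is a sketch, not part of the original script: `mps_memory_gb` and `timed` are hypothetical helper names, and the figures from the `torch.mps` allocator counters are not the same thing as the process memory the python binary reports, so the absolute values will differ.

```python
import time


def mps_memory_gb():
    """Return MPS allocator statistics in GB, or None if torch/MPS is unavailable."""
    try:
        import torch
    except ImportError:
        return None
    if not torch.backends.mps.is_available():
        return None
    torch.mps.synchronize()  # wait for queued GPU work so the numbers are settled
    return {
        # memory currently occupied by tensors
        "current_allocated": torch.mps.current_allocated_memory() / 1e9,
        # total memory the Metal driver has allocated for this process
        "driver_allocated": torch.mps.driver_allocated_memory() / 1e9,
    }


def timed(fn, *args, **kwargs):
    """Call fn and return (result, elapsed_seconds)."""
    start = time.perf_counter()
    result = fn(*args, **kwargs)
    return result, time.perf_counter() - start
```

Usage would be to wrap the `pipe(...)` call in a function, pass it to `timed`, divide the elapsed time by `num_inference_steps` for a rough seconds-per-iteration figure (it also includes text encoding and VAE decoding), and print `mps_memory_gb()` straight afterwards under each torch version.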

The environment is a straightforward diffusers venv:

python -m venv sd35
cd sd35
. bin/activate
pip install diffusers accelerate transformers

The torch version was switched by doing an uninstall and reinstall:

pip uninstall torch torchvision 
pip install torch==2.5.1 torchvision==0.20.1

or

pip uninstall torch torchvision 
pip install torch==2.4.1 torchvision==0.19.1
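When swapping versions back and forth like this, it is easy to lose track of which build is active in the venv. A small check using only the standard library, run inside the activated environment, might look like the following; `installed_version` is a hypothetical helper name.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_version(package: str):
    """Return the installed version string of `package`, or None if it is not installed."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None


# Confirm which torch / torchvision pair the venv currently holds.
for name in ("torch", "torchvision"):
    print(name, installed_version(name))
```

Running this before each benchmark run makes it clear whether the numbers belong to the 2.4.1 or the 2.5.1 install.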

Versions

python lib/python3.11/site-packages/torch/utils/collect_env.py
Collecting environment information...
PyTorch version: 2.5.1
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 15.0.1 (arm64)
GCC version: Could not collect
Clang version: 16.0.0 (clang-1600.0.26.4)
CMake version: version 3.29.5
Libc version: N/A

Python version: 3.11.10 (main, Sep 7 2024, 08:05:54) [Clang 16.0.0 (clang-1600.0.26.3)] (64-bit runtime)
Python platform: macOS-15.0.1-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Apple M3

Versions of relevant libraries:
[pip3] numpy==2.1.2
[pip3] torch==2.5.1
[pip3] torchvision==0.20.1
[conda] Could not collect

cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @kulinseth @albanD @malfet @DenisVieriu97 @jhavukainen

Labels

high priority
module: memory usage (PyTorch is using more memory than it should, or it is leaking memory)
module: mps (Related to Apple Metal Performance Shaders framework)
triaged (This issue has been looked at a team member, and triaged and prioritized into an appropriate module)
