FSDP2 - High memory usage with LORA #3474

@byi8220

Description

System Info

- `Accelerate` version: 1.6.0.dev0
- Platform: Linux-6.8.0-52-generic-x86_64-with-glibc2.35
- `accelerate` bash location: /root/workspace/lora-repro/repro/bin/accelerate
- Python version: 3.11.11
- Numpy version: 2.2.4
- PyTorch version (GPU?): 2.6.0+cu124 (True)
- PyTorch XPU available: False
- PyTorch NPU available: False
- PyTorch MLU available: False
- PyTorch SDAA available: False
- PyTorch MUSA available: False
- System RAM: 503.51 GB
- GPU type: NVIDIA A40
- `Accelerate` default config:
        Not found

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • One of the scripts in the examples/ folder of Accelerate or an officially supported no_trainer script in the examples folder of the transformers repo (such as run_no_trainer_glue.py)
  • My own task or dataset (give details below)

Reproduction

Original context in: #3394 (comment)

When attempting to prepare a LoRA model with accelerate, the memory used is significantly higher than when running the model normally. Loading and running Qwen/Qwen2.5-1.5B with a LoRA adapter peaks at around 3700 MiB allocated. However, when preparing the same model with accelerate in order to distribute it across 2 GPUs, allocation skyrockets to over 9000 MiB on each GPU, with even more memory reserved but unused.
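
For reference, the peak figures quoted above can be read straight from the CUDA allocator (the repro below prints the full torch.cuda.memory_summary() instead); a minimal sketch:

import torch

# Peak memory on the current device since the last reset_peak_memory_stats() call, in MiB
peak_alloc_mib = torch.cuda.max_memory_allocated() / 2**20
peak_reserved_mib = torch.cuda.max_memory_reserved() / 2**20
print(f"peak allocated: {peak_alloc_mib:.0f} MiB, peak reserved: {peak_reserved_mib:.0f} MiB")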

Reproduction code:

# Load model directly
from transformers import AutoTokenizer, AutoModelForCausalLM

import torch
from peft import get_peft_model, LoraConfig, TaskType, PeftModel
from accelerate import Accelerator, FullyShardedDataParallelPlugin

torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()

fsdp_plugin = FullyShardedDataParallelPlugin(
    fsdp_version=2
)
accelerator = Accelerator(fsdp_plugin=fsdp_plugin)

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B")
max_seq_length = 2048

dtype = torch.bfloat16
device_index = torch.cuda.current_device()
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-1.5B", # Note, warning about sliding window attn is not real: https://github.com/huggingface/transformers/pull/36316
    torch_dtype=dtype,
)

lora_config = LoraConfig(
    r = 32,
    lora_alpha = 128,
    target_modules = "all-linear",
    lora_dropout = 0,
    bias = "none",
    task_type = TaskType.CAUSAL_LM,
)

print(model.dtype)
# Get LoRA and setup model
model = get_peft_model(model, lora_config)

# Make sure we don't require gradients on non-lora params
with torch.no_grad():
    for name, param in model.named_parameters():
        if ".lora_A." in name or ".lora_B." in name:
            param.requires_grad_(True)
        else:
            param.requires_grad_(False)

model.gradient_checkpointing_enable()
model.enable_input_require_grads()

optimizer = torch.optim.Adam(model.parameters())

model.train()

print(f"[Device {device_index}] Memory summary before accelerator prepare")
print(torch.cuda.memory_summary())

model, optimizer = accelerator.prepare(model, optimizer)

torch.cuda.empty_cache()
print(f"[Device {device_index}] Memory summary after accelerator prepare")
print(torch.cuda.memory_summary())

print(f"[Device {device_index}] Is accelerator FSDP2? {accelerator.is_fsdp2}")

sample = "Lorem ipsum dolor sit amet consectetur adipiscing elit. Quisque faucibus ex sapien vitae pellentesque sem placerat."


enc = tokenizer(sample)
tokens = torch.tensor(enc.input_ids).to(model.device)
attn_mask = torch.tensor(enc.attention_mask).to(model.device)
input_ids = tokens[None, :-1]
labels = tokens[None, 1:]

out = model(input_ids, labels=labels)
loss = out['loss']

accelerator.backward(loss)
optimizer.step()
optimizer.zero_grad()

torch.cuda.empty_cache()
print(f"[Device {device_index}] Memory summary after one backward pass")
print(torch.cuda.memory_summary())

Full results when running python lora-repro.py > no_accelerate.txt: https://gist.github.com/byi8220/8f69f92e64620aaf69417295ff41a9ac

Full results when running NCCL_P2P_DISABLE=1 accelerate launch --fsdp_version=2 --fsdp_cpu_ram_efficient_loading=true --mixed_precision=bf16 lora-repro.py > with_accelerate.txt: https://gist.github.com/byi8220/f4214fb70dbef20ba353f58af496c317
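
As a sanity check (not part of the original script), a few lines like the following could be added after accelerator.prepare() to confirm whether FSDP2 actually sharded the base weights; this assumes PyTorch 2.6's public torch.distributed.tensor.DTensor:

from torch.distributed.tensor import DTensor

# Count parameters that ended up as sharded DTensors vs. plain tensors,
# and how many parameter elements are actually stored on this rank.
sharded, unsharded, local_numel = 0, 0, 0
for name, p in model.named_parameters():
    if isinstance(p, DTensor):
        sharded += 1
        local_numel += p.to_local().numel()
    else:
        unsharded += 1
        local_numel += p.numel()
print(f"[Device {device_index}] sharded params: {sharded}, unsharded: {unsharded}, "
      f"local elements: {local_numel / 1e6:.1f}M")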

Expected behavior

Expected that when using accelerate with FSDP2, GPU memory usage on both GPUs is significantly lower (preferably lower than in the single-device case), since the base model's parameters should be sharded across the two devices.
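
For a rough sense of scale (an estimate, not a measurement): with bf16 weights, the ~1.5B-parameter base model holds about 3 GiB of parameters, so fully sharding it across 2 GPUs should leave roughly 1.4 GiB of parameter memory per device, before activations, LoRA gradients, and optimizer state:

# Back-of-the-envelope estimate of per-GPU parameter memory under full sharding
num_params = 1.5e9        # approximate parameter count of Qwen/Qwen2.5-1.5B
bytes_per_param = 2       # bf16
num_gpus = 2
per_gpu_gib = num_params * bytes_per_param / num_gpus / 2**30
print(f"~{per_gpu_gib:.2f} GiB of sharded parameters per GPU")  # ~1.40 GiB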
