
[BUG] save_checkpoint race when consolidating NVMe offloaded tensors → FileExistsError #7549


Description

@H1manshu21

Describe the bug
When save_checkpoint is called (via accelerator.save_state) from all ranks in a multi-GPU run with ZeRO Stage 3 NVMe offloading enabled, DeepSpeed attempts to consolidate the per-rank NVMe offload directories into a single shared offloaded_tensors destination using shutil.copytree. Because all ranks perform this consolidation concurrently, two processes race to create/copy into the same destination directory and one of them fails with FileExistsError: [Errno 17] File exists. The docstring correctly states that all ranks must call save_checkpoint, but the consolidation step is not robust against concurrent creation of the same destination directory.
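
For context, a minimal sketch of the failure mode, assuming the consolidation logic resembles the pattern below (consolidate_offload_dir and the directory names are illustrative, not DeepSpeed's actual code):

    import shutil

    def consolidate_offload_dir(rank_offload_dir: str, shared_dest: str) -> None:
        # Every rank runs this concurrently against the same shared_dest.
        # shutil.copytree() creates shared_dest itself and, without
        # dirs_exist_ok=True, raises FileExistsError if another rank
        # created it a moment earlier.
        shutil.copytree(rank_offload_dir, shared_dest)

    # Rank 0: consolidate_offload_dir("/nvme/offload/rank_0", "/ckpt/offloaded_tensors")
    # Rank 1: consolidate_offload_dir("/nvme/offload/rank_1", "/ckpt/offloaded_tensors")
    #   -> whichever rank arrives second hits FileExistsError: [Errno 17] File exists

Note that guarding the copy with an os.path.exists() check does not close the window, since the check and the copy are not atomic across processes.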

To Reproduce
Steps to reproduce the behavior:

  1. Launch a distributed run with 2 GPUs (e.g. accelerate launch --config_file default_config.yaml test.py).
  2. In your training script (executed by all ranks), call accelerator.save_state(output_dir="/path/to/test_checkpoint") (a minimal script sketch is shown after the note below).
  3. Observe FileExistsError during the consolidation step when DeepSpeed attempts to copy NVMe offload files into the shared target.

NOTE: stage3_gather_16bit_weights_on_model_save is set to true.
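
For reference, a minimal sketch of a script that reproduces this on 2 GPUs, assuming ZeRO Stage 3 with NVMe parameter/optimizer offload and stage3_gather_16bit_weights_on_model_save: true are configured through default_config.yaml; the model and paths below are placeholders:

    # test.py -- run with: accelerate launch --config_file default_config.yaml test.py
    import torch
    from accelerate import Accelerator

    accelerator = Accelerator()  # DeepSpeed ZeRO-3 + NVMe offload come from default_config.yaml

    model = torch.nn.Linear(4096, 4096)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    model, optimizer = accelerator.prepare(model, optimizer)

    # One dummy step so there is parameter/optimizer state offloaded to NVMe.
    loss = model(torch.randn(8, 4096, device=accelerator.device)).sum()
    accelerator.backward(loss)
    optimizer.step()
    optimizer.zero_grad()

    # All ranks call save_state(); the NVMe consolidation races inside save_checkpoint.
    accelerator.save_state(output_dir="/path/to/test_checkpoint")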

Expected behavior
DeepSpeed should produce a single consolidated offloaded_tensors directory without raising FileExistsError, even when all ranks call save_checkpoint concurrently, and should then be able to save the model checkpoint.
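
One possible direction, sketched below rather than proposed as the actual patch (consolidate_offload_dir_safe is a hypothetical helper): make the copy tolerant of a destination that another rank has already created, e.g. via dirs_exist_ok=True (Python >= 3.8), which assumes the per-rank offload files have non-colliding names; alternatively, the consolidation could be serialized so that only one process creates the destination before the others copy into it.

    import shutil

    def consolidate_offload_dir_safe(rank_offload_dir: str, shared_dest: str) -> None:
        # Each rank copies its own offload directory into the shared destination.
        # dirs_exist_ok=True tolerates a destination created first by another rank,
        # which is exactly the concurrent-creation race reported above.
        shutil.copytree(rank_offload_dir, shared_dest, dirs_exist_ok=True)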

ds_report output

himanshu:~/test_checkpoint$ ds_report
[2025-09-09 12:46:49,291] [INFO] [real_accelerator.py:254:get_accelerator] Setting ds_accelerator to cuda (auto detect)
--------------------------------------------------
DeepSpeed C++/CUDA extension op report
--------------------------------------------------
NOTE: Ops not installed will be just-in-time (JIT) compiled at
      runtime if needed. Op compatibility means that your system
      meet the required dependencies to JIT install the op.
--------------------------------------------------
JIT compiled ops requires ninja
ninja .................. [OKAY]
--------------------------------------------------
op name ................ installed .. compatible
--------------------------------------------------
async_io ............... [NO] ....... [OKAY]
fused_adam ............. [NO] ....... [OKAY]
cpu_adam ............... [NO] ....... [OKAY]
cpu_adagrad ............ [NO] ....... [OKAY]
cpu_lion ............... [NO] ....... [OKAY]
dc ..................... [NO] ....... [OKAY]
 [WARNING]  Please specify the CUTLASS repo directory as environment variable $CUTLASS_PATH
evoformer_attn ......... [NO] ....... [NO]
fp_quantizer ........... [NO] ....... [OKAY]
fused_lamb ............. [NO] ....... [OKAY]
fused_lion ............. [NO] ....... [OKAY]
x86_64-linux-gnu-gcc -Wno-unused-result -Wsign-compare -DNDEBUG -g -fwrapv -O2 -Wall -g -fstack-protector-strong -Wformat -Werror=format-security -g -fwrapv -O2 -g -fstack-protector-strong -Wformat -Werror=format-security -Wdate-time -D_FORTIFY_SOURCE=2 -fPIC -c /tmp/tmp7l04eczs/test.c -o /tmp/tmp7l04eczs/test.o
x86_64-linux-gnu-gcc /tmp/tmp7l04eczs/test.o -L/usr/local/cuda -L/usr/local/cuda/lib64 -lcufile -o /tmp/tmp7l04eczs/a.out
x86_64-linux-gnu-gcc -Wno-unused-result -Wsign-compare -DNDEBUG -g -fwrapv -O2 -Wall -g -fstack-protector-strong -Wformat -Werror=format-security -g -fwrapv -O2 -g -fstack-protector-strong -Wformat -Werror=format-security -Wdate-time -D_FORTIFY_SOURCE=2 -fPIC -c /tmp/tmp1wt4a4se/test.c -o /tmp/tmp1wt4a4se/test.o
x86_64-linux-gnu-gcc /tmp/tmp1wt4a4se/test.o -laio -o /tmp/tmp1wt4a4se/a.out
gds .................... [NO] ....... [OKAY]
transformer_inference .. [NO] ....... [OKAY]
inference_core_ops ..... [NO] ....... [OKAY]
cutlass_ops ............ [NO] ....... [OKAY]
quantizer .............. [NO] ....... [OKAY]
ragged_device_ops ...... [NO] ....... [OKAY]
ragged_ops ............. [NO] ....... [OKAY]
random_ltd ............. [NO] ....... [OKAY]
 [WARNING]  sparse_attn requires a torch version >= 1.5 and < 2.0 but detected 2.6
 [WARNING]  using untested triton version (3.2.0), only 1.0.0 is known to be compatible
sparse_attn ............ [NO] ....... [NO]
spatial_inference ...... [NO] ....... [OKAY]
transformer ............ [NO] ....... [OKAY]
stochastic_transformer . [NO] ....... [OKAY]
--------------------------------------------------
DeepSpeed general environment info:
torch install path ............... ['/home/himanshu/.local/lib/python3.10/site-packages/torch']
torch version .................... 2.6.0+cu124
deepspeed install path ........... ['/home/himanshu/.local/lib/python3.10/site-packages/deepspeed']
deepspeed info ................... 0.16.9, unknown, unknown
torch cuda version ............... 12.4
torch hip version ................ None
nvcc version ..................... 12.3
deepspeed wheel compiled w. ...... torch 2.6, cuda 12.4
shared memory (/dev/shm) size .... 1.01 TB

Screenshots

Image

System info (please complete the following information):

  • OS: Ubuntu 22.04
  • Hardware: single machine with 6x A6000 GPUs
  • Python version: 3.10

Launcher context
Are you launching your experiment with the deepspeed launcher, MPI, or something else?
Hugging Face Accelerate launcher.
Command: accelerate launch --config_file default_config.yaml test.py

Docker context
Are you using a specific docker image that you can share?
No


Labels

bug (Something isn't working), training