Describe the bug
When using pipeline parallelism (i.e. PipelineEngine), and gradient_accumulation_steps > 1, it appears that the gradients across micro batches are summed instead of averaged.
Screenshots

Gray: micro_batch_size_per_gpu=8, gradient_accumulation_steps=4
Purple: micro_batch_size_per_gpu=2, gradient_accumulation_steps=16
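Global batch size is the same (8 × 4 = 2 × 16 = 32 samples per optimizer step on each data-parallel rank), but the purple run's grad norm is 4 times higher, consistent with summing (16 / 4 = 4) instead of averaging. Every training setting except micro batch size and GAS is identical between the runs.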
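As a sanity check on the factor of 4, here is a minimal pure-Python sketch. It assumes every micro batch produces the same mean gradient `g` (a hypothetical 2-element vector), so the only difference between the two runs is how the micro-batch gradients are combined. Summing gives a norm ratio of 16 / 4 = 4; averaging gives a ratio of 1.

```python
import math


def grad_norm(vec):
    """Euclidean (L2) norm of a gradient vector."""
    return math.sqrt(sum(x * x for x in vec))


def accumulate(micro_grads, average):
    """Combine per-micro-batch gradients by summing, or by averaging if average=True."""
    total = [sum(parts) for parts in zip(*micro_grads)]
    if average:
        total = [x / len(micro_grads) for x in total]
    return total


# Assumption: each micro batch yields the same mean gradient g. The entries are
# hypothetical and chosen to be exact in floating point (norm 1.25).
g = [0.75, -1.0]

gray = [g] * 4     # gradient_accumulation_steps=4
purple = [g] * 16  # gradient_accumulation_steps=16

summed_ratio = grad_norm(accumulate(purple, average=False)) / grad_norm(accumulate(gray, average=False))
averaged_ratio = grad_norm(accumulate(purple, average=True)) / grad_norm(accumulate(gray, average=True))
print(summed_ratio, averaged_ratio)  # prints 4.0 1.0
```

With real, noisy micro-batch gradients the ratio would only be approximately 4; the sketch isolates the effect of the combination rule.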
Additional Info
I only noticed this after implementing grad norm logging into a custom optimizer that I'm using.
I'm using DeepSpeed as the pipeline parallelism backend for my own training code, which is complex in places, so it's possible I'm doing something wrong or have broken the DeepSpeed integration. However, looking at the DeepSpeed code, I can't find where it would average the gradients with pipeline parallelism and GAS > 1:
- In DeepSpeedEngine, the loss is scaled by GAS here. However, this is disabled in PipelineEngine by this line.
- Searching PipelineEngine for _scale_loss_by_gas, I find 3 uses:
  - One is here, which is only called during eval.
  - The other 2 are here, which are called during training inside a no_grad block, only to compute the loss metrics for logging.
Looking at the code for the forward and backward pipeline stage implementations, I can't find any scaling of the loss or gradients either.
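For reference, the expected behavior can be illustrated with a toy, hypothetical one-parameter model (prediction `w * x`, mean-squared-error loss per micro batch), written in plain Python rather than with DeepSpeed. Dividing each micro-batch loss by GAS before accumulating, which is conceptually what the DeepSpeedEngine loss scaling described above does, makes the accumulated gradient equal the full-batch gradient; skipping that division yields GAS times that value. The sketch assumes equally sized micro batches.

```python
import math


def micro_batch_grad(w, xs, ys, loss_scale=1.0):
    """d/dw of loss_scale * mean((w*x - y)**2) over one micro batch."""
    return loss_scale * sum(2 * x * (w * x - y) for x, y in zip(xs, ys)) / len(xs)


def accumulated_grad(w, xs, ys, micro_batch_size, scale_by_gas):
    """Accumulate gradients over equally sized micro batches, optionally scaling each loss by 1/GAS."""
    gas = len(xs) // micro_batch_size
    loss_scale = 1.0 / gas if scale_by_gas else 1.0
    grad = 0.0
    for i in range(gas):
        batch = slice(i * micro_batch_size, (i + 1) * micro_batch_size)
        grad += micro_batch_grad(w, xs[batch], ys[batch], loss_scale)
    return grad


# Hypothetical data: 8 samples split into 4 micro batches of 2 (GAS=4).
xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ys = [2.0 * x + 1.0 for x in xs]
w = 0.5

full = micro_batch_grad(w, xs, ys)  # one large batch, no accumulation
averaged = accumulated_grad(w, xs, ys, micro_batch_size=2, scale_by_gas=True)
summed = accumulated_grad(w, xs, ys, micro_batch_size=2, scale_by_gas=False)

print(math.isclose(averaged, full))    # prints True: scaled accumulation matches the large batch
print(math.isclose(summed, 4 * full))  # prints True: unscaled accumulation is GAS times larger
```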
To Reproduce
I am using DeepSpeed in the context of diffusion-pipe, which I develop. To reproduce, you would need to follow the installation instructions there, choose a model, make a dataset, and run the training yourself, making sure to use the GenericOptim optimizer, since it is the only one with grad norm logging. Unfortunately, there isn't an easy way to provide a minimal reproduction.
Expected behavior
Gradients are averaged across micro batches when gradient_accumulation_steps > 1, so training behaves the same as it would with a larger micro batch size and no GAS.
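Until this is resolved, one possible workaround, offered only as an unverified sketch, is to divide the loss by GAS inside the loss function handed to the pipeline. This assumes the last pipeline stage calls the module's loss function as `loss_fn(outputs, labels)` once per micro batch; `scale_loss_by_gas` and `mse` are hypothetical helpers, not DeepSpeed APIs. If the training-time loss metrics are also divided by GAS, as the code reading above suggests, the logged loss values would presumably shrink by the same factor.

```python
def scale_loss_by_gas(loss_fn, gradient_accumulation_steps):
    """Hypothetical wrapper: divide each micro-batch loss by GAS so that the
    summed micro-batch gradients become an average."""
    def scaled_loss_fn(outputs, labels):
        return loss_fn(outputs, labels) / gradient_accumulation_steps
    return scaled_loss_fn


def mse(outputs, labels):
    """Hypothetical stand-in loss on plain floats (a real loss would return a torch tensor)."""
    return sum((o - l) ** 2 for o, l in zip(outputs, labels)) / len(outputs)


scaled_mse = scale_loss_by_gas(mse, gradient_accumulation_steps=4)
print(mse([1.0, 3.0], [0.0, 1.0]), scaled_mse([1.0, 3.0], [0.0, 1.0]))  # prints 2.5 0.625
```

The wrapped function would presumably be passed as the `loss_fn` argument of `PipelineModule`; division works the same way on tensors as on floats.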
ds_report output
--------------------------------------------------
DeepSpeed C++/CUDA extension op report
--------------------------------------------------
NOTE: Ops not installed will be just-in-time (JIT) compiled at
runtime if needed. Op compatibility means that your system
meet the required dependencies to JIT install the op.
--------------------------------------------------
JIT compiled ops requires ninja
ninja .................. [OKAY]
--------------------------------------------------
op name ................ installed .. compatible
--------------------------------------------------
[WARNING] async_io requires the dev libaio .so object and headers but these were not found.
[WARNING] async_io: please install the libaio-dev package with apt
[WARNING] If libaio is already installed (perhaps from source), try setting the CFLAGS and LDFLAGS environment variables to where it can be found.
async_io ............... [NO] ....... [NO]
fused_adam ............. [NO] ....... [OKAY]
cpu_adam ............... [NO] ....... [OKAY]
cpu_adagrad ............ [NO] ....... [OKAY]
cpu_lion ............... [NO] ....... [OKAY]
dc ..................... [NO] ....... [OKAY]
[WARNING] Please specify the CUTLASS repo directory as environment variable $CUTLASS_PATH
evoformer_attn ......... [NO] ....... [NO]
[WARNING] FP Quantizer is using an untested triton version (3.5.1), only 2.3.(0, 1) and 3.0.0 are known to be compatible with these kernels
fp_quantizer ........... [NO] ....... [NO]
fused_lamb ............. [NO] ....... [OKAY]
fused_lion ............. [NO] ....... [OKAY]
gds .................... [NO] ....... [NO]
transformer_inference .. [NO] ....... [OKAY]
inference_core_ops ..... [NO] ....... [OKAY]
cutlass_ops ............ [NO] ....... [OKAY]
quantizer .............. [NO] ....... [OKAY]
ragged_device_ops ...... [NO] ....... [OKAY]
ragged_ops ............. [NO] ....... [OKAY]
random_ltd ............. [NO] ....... [OKAY]
[WARNING] sparse_attn requires a torch version >= 1.5 and < 2.0 but detected 2.9
[WARNING] using untested triton version (3.5.1), only 1.0.0 is known to be compatible
sparse_attn ............ [NO] ....... [NO]
spatial_inference ...... [NO] ....... [OKAY]
transformer ............ [NO] ....... [OKAY]
stochastic_transformer . [NO] ....... [OKAY]
utils .................. [NO] ....... [OKAY]
--------------------------------------------------
DeepSpeed general environment info:
torch install path ............... ['/home/anon/miniconda3/envs/diffusion-pipe/lib/python3.12/site-packages/torch']
torch version .................... 2.9.1+cu128
deepspeed install path ........... ['/home/anon/miniconda3/envs/diffusion-pipe/lib/python3.12/site-packages/deepspeed']
deepspeed info ................... 0.18.2, unknown, unknown
torch cuda version ............... 12.8
torch hip version ................ None
nvcc version ..................... 12.8
deepspeed wheel compiled w. ...... torch 2.9, cuda 12.8
shared memory (/dev/shm) size .... 125.77 GB
System info (please complete the following information):
- Ubuntu 24.04
- 4x4090
- Python 3.12
Launcher context
deepspeed launcher