Gradients not averaged by GAS when using DeepSpeed + model_accepts_loss_kwargs=True (Qwen3, Llama3, etc.) #45305

@florian6973

System Info

  • transformers version: 5.3.0
  • Platform: Linux-5.14.0-427.33.1.el9_4.x86_64-x86_64-with-glibc2.34
  • Python version: 3.11.15
  • Huggingface_hub version: 1.6.0
  • Safetensors version: 0.7.0
  • Accelerate version: 1.13.0
  • Accelerate config: - compute_environment: LOCAL_MACHINE
    - distributed_type: MULTI_GPU
    - mixed_precision: bf16
    - use_cpu: False
    - debug: False
    - num_processes: 2
    - machine_rank: 0
    - num_machines: 1
    - gpu_ids: all
    - rdzv_backend: static
    - same_network: True
    - main_training_function: main
    - enable_cpu_affinity: True
    - downcast_bf16: no
    - tpu_use_cluster: False
    - tpu_use_sudo: False
    - tpu_env: []
  • DeepSpeed version: 0.18.9
  • PyTorch version (accelerator?): 2.6.0+cu124 (CUDA)
  • Using distributed or parallel set-up in script?: Yes
  • Using GPU in script?: Yes
  • GPU type: NVIDIA L40S

Who can help?

@SunMarc, @3outeille, @ArthurZucker

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Bug description

When using the Hugging Face Trainer with DeepSpeed and any modern model where model_accepts_loss_kwargs=True (e.g., Qwen3 and Llama3, both of which I tested, and most recent Hugging Face models), gradients are summed instead of averaged across gradient accumulation steps (GAS). This produces grad norms roughly GAS× higher than training the same model without DeepSpeed, leading to divergent training dynamics.
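To make the GAS× factor concrete, here is a minimal, framework-free sketch. It uses plain Python, a toy linear model, and made-up micro-batches; none of it comes from the Trainer or DeepSpeed code. It accumulates gradients over 8 micro-batches with and without dividing each contribution by GAS and compares the resulting gradient norms.

```python
import math

GAS = 8  # gradient accumulation steps (same value as in the reproduction script)


def micro_batch_grad(w, b, xs, ys):
    """Gradient of mean((w*x + b - y)**2) over one micro-batch w.r.t. (w, b)."""
    n = len(xs)
    dw = sum(2 * (w * x + b - y) * x for x, y in zip(xs, ys)) / n
    db = sum(2 * (w * x + b - y) for x, y in zip(xs, ys)) / n
    return dw, db


def accumulated_grad_norm(w, b, micro_batches, divide_by_gas):
    """Sum per-micro-batch gradients, optionally scaling each one by 1/GAS."""
    acc_w = acc_b = 0.0
    scale = 1.0 / len(micro_batches) if divide_by_gas else 1.0
    for xs, ys in micro_batches:
        dw, db = micro_batch_grad(w, b, xs, ys)
        acc_w += scale * dw
        acc_b += scale * db
    return math.hypot(acc_w, acc_b)


# Toy data: GAS micro-batches of two samples each (illustrative values only).
micro_batches = [([1.0 + i, 2.0 + i], [0.5 * i, 1.5 * i]) for i in range(GAS)]

averaged = accumulated_grad_norm(0.3, 0.1, micro_batches, divide_by_gas=True)
summed = accumulated_grad_norm(0.3, 0.1, micro_batches, divide_by_gas=False)
ratio = summed / averaged
print(f"averaged={averaged:.4f} summed={summed:.4f} ratio={ratio:.2f}")
```

The ratio equals GAS regardless of the data, because skipping the 1/GAS factor rescales the whole accumulated gradient uniformly.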

Root cause

The issue is a conflict between two code paths in Trainer.training_step() (trainer.py, lines 1925–1933 in v5.3.0):

# Line 1925-1927: Conditional loss normalization
if (not self.model_accepts_loss_kwargs or num_items_in_batch is None) and self.compute_loss_func is None:
    loss = loss / self.current_gradient_accumulation_steps
 
# Line 1931-1933: Unconditional scale_wrt_gas=False for DeepSpeed (added by PR #35808)
if self.accelerator.distributed_type == DistributedType.DEEPSPEED:
    kwargs["scale_wrt_gas"] = False
 
self.accelerator.backward(loss, **kwargs)

PR #35808 added scale_wrt_gas=False to prevent double-scaling when the Trainer already divides loss by GAS. This fix is correct when model_accepts_loss_kwargs=False (Trainer divides on line 1927, and scale_wrt_gas=False prevents DeepSpeed from dividing again).

However, for modern models where model_accepts_loss_kwargs=True:

  1. Trainer skips dividing (the line 1925 condition is False once num_items_in_batch is provided)
  2. DeepSpeed's gradient scaling is disabled (scale_wrt_gas=False, so the _backward_prologue_per_tensor hook is a no-op)
  3. Nobody divides by GAS → gradients are summed, not averaged

Trace of the bug

| Step | model_accepts_loss_kwargs=True (Qwen3, etc.) | model_accepts_loss_kwargs=False (older models) |
| --- | --- | --- |
| Trainer divides loss by GAS? | ❌ No (line 1925 skipped) | ✅ Yes (line 1927) |
| scale_wrt_gas=False passed? | ✅ Yes (line 1932) | ✅ Yes (line 1932) |
| DeepSpeed hook divides grads? | ❌ No (_scale_wrt_gas=False) | ❌ No (_scale_wrt_gas=False) |
| Net GAS scaling | ❌ None: gradients are summed | ✅ 1/GAS from Trainer (correct) |
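The trace can also be written as a small model of the control flow. This is only a sketch of the logic as described in this report, not the actual Trainer or DeepSpeed code. The helper names and the num_items_in_batch value of 1024 are hypothetical, and it assumes DeepSpeed divides by GAS exactly when scale_wrt_gas is True.

```python
from fractions import Fraction


def trainer_divides_by_gas(model_accepts_loss_kwargs, num_items_in_batch, compute_loss_func=None):
    """Mirror of the condition quoted from training_step() in this report."""
    return (not model_accepts_loss_kwargs or num_items_in_batch is None) and compute_loss_func is None


def net_gas_scale(gas, model_accepts_loss_kwargs, num_items_in_batch, scale_wrt_gas):
    """Net factor applied to the accumulated gradient on the DeepSpeed path.

    Assumption taken from this report: DeepSpeed divides by GAS only when
    scale_wrt_gas is True.
    """
    scale = Fraction(1)
    if trainer_divides_by_gas(model_accepts_loss_kwargs, num_items_in_batch):
        scale /= gas
    if scale_wrt_gas:
        scale /= gas
    return scale


GAS = 8
current_new = net_gas_scale(GAS, True, 1024, scale_wrt_gas=False)     # modern models today
current_old = net_gas_scale(GAS, False, 1024, scale_wrt_gas=False)    # older models today
double_scaled = net_gas_scale(GAS, False, 1024, scale_wrt_gas=True)   # case PR #35808 guards against
print(current_new, current_old, double_scaled)  # prints: 1 1/8 1/64
```

Under these assumptions, only the older-model column ends up with the intended 1/GAS factor.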

Minimal reproduction

Important: Run each configuration as a separate process so the model is freshly loaded each time.

"""
min_repro.py — Run separately for each configuration:
  python min_repro.py --use_ds false
  deepspeed --num_gpus=1 min_repro.py --use_ds true
 
Compare grad_norm at equivalent steps between the two runs.
With GAS=8 and the bug present, DS grad_norm will be ~8x higher.
"""
import argparse, json, os, random, tempfile
import torch
from datasets import Dataset
from transformers import (
    AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments, set_seed,
)
 
DS_CONFIG = {
    "bf16": {"enabled": "auto"},
    "zero_optimization": {"stage": 3,
        "offload_optimizer": {"device": "cpu", "pin_memory": True},
        "offload_param": {"device": "cpu", "pin_memory": True},
        "overlap_comm": True, "contiguous_gradients": True,
        "reduce_bucket_size": "auto",
        "stage3_prefetch_bucket_size": "auto",
        "stage3_param_persistence_threshold": "auto"},
    "gradient_accumulation_steps": "auto",
    "gradient_clipping": "auto",
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
}
 
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--use_ds", type=str, default="false")
    parser.add_argument("--local_rank", type=int, default=-1)
    args = parser.parse_args()
    use_ds = args.use_ds.lower() in ("true", "1", "yes")
 
    set_seed(42)
 
    model_name = "Qwen/Qwen3-0.6B-Base"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
 
    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
    model.resize_token_embeddings(len(tokenizer))
 
    # Random tokens so loss does not trivially collapse
    random.seed(42)
    vocab, seq_len, n = len(tokenizer), 128, 200
    data = {
        "input_ids": [[random.randint(0, vocab-1) for _ in range(seq_len)] for _ in range(n)],
        "labels":    [[random.randint(0, vocab-1) for _ in range(seq_len)] for _ in range(n)],
    }
    dataset = Dataset.from_dict(data)
 
    ds_path = None
    if use_ds:
        ds_path = os.path.join(tempfile.mkdtemp(), "ds.json")
        with open(ds_path, "w") as f:
            json.dump(DS_CONFIG, f)
 
    training_args = TrainingArguments(
        output_dir=f"./test_gas_{'ds' if use_ds else 'no_ds'}",
        per_device_train_batch_size=1,
        gradient_accumulation_steps=8,
        max_steps=20, logging_steps=1,
        learning_rate=5e-5, lr_scheduler_type="cosine",
        bf16=True, deepspeed=ds_path,
        report_to="none", save_strategy="no",
        max_grad_norm=0.0,  # Disable clipping so raw grad norms are visible
    )
    trainer = Trainer(model=model, args=training_args,
                      train_dataset=dataset, processing_class=tokenizer)
 
    print(f"\nDeepSpeed={use_ds}, model_accepts_loss_kwargs={trainer.model_accepts_loss_kwargs}")
    trainer.train()
 
if __name__ == "__main__":
    main()

Run as two separate invocations and compare the grad_norm column:

python min_repro.py --use_ds false          # baseline
deepspeed --num_gpus=1 min_repro.py --use_ds true   # ~8x higher grad_norm expected
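Comparing the two runs by eye is error-prone, so a small helper can compute the per-step ratio. The sketch below assumes each run's trainer.state.log_history (a list of dicts with "step" and "grad_norm" keys) has been saved somehow, for example with json.dump after trainer.train(). The function name and the sample values are hypothetical placeholders, not measurements.

```python
from statistics import median


def grad_norm_ratio(baseline_history, deepspeed_history):
    """Median per-step ratio of DeepSpeed grad_norm to baseline grad_norm.

    Both arguments are lists of dicts shaped like trainer.state.log_history;
    entries without a "grad_norm" key (e.g. the final summary) are skipped.
    """
    base = {e["step"]: e["grad_norm"] for e in baseline_history if "grad_norm" in e}
    ds = {e["step"]: e["grad_norm"] for e in deepspeed_history if "grad_norm" in e}
    shared = sorted(base.keys() & ds.keys())
    if not shared:
        raise ValueError("no common steps with grad_norm in both histories")
    return median(ds[s] / base[s] for s in shared)


# Synthetic placeholder logs, NOT measurements: they only show the call shape.
baseline = [
    {"step": 1, "grad_norm": 2.0},
    {"step": 2, "grad_norm": 1.5},
    {"step": 2, "train_runtime": 10.0},  # summary entry without grad_norm
]
deepspeed = [{"step": 1, "grad_norm": 16.0}, {"step": 2, "grad_norm": 12.0}]
observed = grad_norm_ratio(baseline, deepspeed)
print(observed)
```

A median close to gradient_accumulation_steps would be consistent with the missing averaging described here; a value near 1 would not.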

Evidence from real training

Training Qwen3-1.7B-Base with identical configs (GAS=8, lr=2e-5, cosine schedule, bf16), with and without DeepSpeed ZeRO-3:

  • train/grad_norm: DeepSpeed run shows ~5-7× higher grad norms (compressed from theoretical 8× by max_grad_norm=1.0 clipping)
  • train/loss: DeepSpeed run plateaus ~0.5 higher. The model underfits because the effective LR is up to ~8× too large (even with norm clipping, the unaveraged gradients still produce larger updates)
  • train/learning_rate: Schedules diverge because the optimizer behaves differently under inflated gradients

The ratio of grad norms is consistent with gradient_accumulation_steps=8 (clipping reduces the theoretical 8× to ~5-7×), which indicates that no GAS averaging is happening in the DeepSpeed path.

Expected behavior

Gradient norms and loss curves should be identical (up to numerical precision) regardless of whether DeepSpeed is enabled, given the same effective batch size and hyperparameters.

The fix should condition scale_wrt_gas on whether the Trainer already normalized the loss:

# Proposed fix in training_step():
if self.accelerator.distributed_type == DistributedType.DEEPSPEED:
    if (not self.model_accepts_loss_kwargs or num_items_in_batch is None) and self.compute_loss_func is None:
        # Trainer already divided by GAS above → disable DeepSpeed's scaling
        kwargs["scale_wrt_gas"] = False
    else:
        # Trainer did NOT divide → let DeepSpeed handle it
        kwargs["scale_wrt_gas"] = True

Alternatively, always divide in the Trainer and always set scale_wrt_gas=False:

# Always normalize, regardless of model_accepts_loss_kwargs
loss = loss / self.current_gradient_accumulation_steps
 
if self.accelerator.distributed_type == DistributedType.DEEPSPEED:
    kwargs["scale_wrt_gas"] = False
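As a quick sanity check on the first proposal, the sketch below enumerates the four cases under the same assumptions this report makes: the Trainer divides exactly when the quoted condition holds, DeepSpeed divides exactly when scale_wrt_gas=True, and no compute_loss_func is set. The helper names and the value 1024 are hypothetical. This models the control flow only and does not exercise the Trainer itself.

```python
from fractions import Fraction
from itertools import product


def trainer_divided(model_accepts_loss_kwargs, num_items_in_batch):
    """Whether the Trainer already divided the loss by GAS (compute_loss_func assumed None)."""
    return not model_accepts_loss_kwargs or num_items_in_batch is None


def proposed_scale_wrt_gas(model_accepts_loss_kwargs, num_items_in_batch):
    """First proposal: let DeepSpeed scale only if the Trainer did not divide."""
    return not trainer_divided(model_accepts_loss_kwargs, num_items_in_batch)


def net_scale(gas, model_accepts_loss_kwargs, num_items_in_batch):
    """Net GAS factor on the gradient when the first proposal is applied."""
    scale = Fraction(1)
    if trainer_divided(model_accepts_loss_kwargs, num_items_in_batch):
        scale /= gas
    if proposed_scale_wrt_gas(model_accepts_loss_kwargs, num_items_in_batch):
        scale /= gas  # assumption from this report: DeepSpeed divides when True
    return scale


GAS = 8
results = {
    (accepts, n): net_scale(GAS, accepts, n)
    for accepts, n in product([True, False], [None, 1024])
}
for case, scale in results.items():
    print(case, scale)
```

In this model every combination ends up with exactly one division by GAS, which is the invariant both proposals aim for.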

Workaround

Until this is fixed, users can subclass Trainer:

class FixedGASTrainer(Trainer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.model_accepts_loss_kwargs = False

This forces the Trainer to always divide the loss by GAS (line 1927, provided no compute_loss_func is set), which, combined with the existing scale_wrt_gas=False, produces correct gradient averaging.

Related issues

Affected models

Any model where model_accepts_loss_kwargs=True, which includes most models added since late 2024: Qwen3, Gemma3, Llama 3, Phi-4, and any model whose forward() accepts **kwargs or has accepts_loss_kwargs = True.

LLM Disclaimer: I used Anthropic Claude 4.6 Opus Extended to structure the bug report based on my observations
