Description:
I observed that bucket.elements is not correctly cleared in the function __reduce_and_partition_ipg_grads when using ZeRO Stage 3.
Specifically, I printed the value of bucket.elements along with the number of parameters in the bucket every time __reduce_and_partition_ipg_grads is called. The output is shown below:
The output shows that bucket.elements is not reset after the first reduction. Because the counter keeps growing, the bucket presumably looks full as soon as the next parameter is added, so every subsequent reduction contains only a single parameter.
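To make the failure mode concrete, here is a minimal simulation. It is a simplified model, not DeepSpeed's code: the SimpleBucket class, the flush and simulate helpers, and the flush rule (flush when the incoming parameter would overflow a non-empty bucket) are hypothetical stand-ins for the ZeRO Stage 3 bucketing logic. The only difference between the two runs is whether a flush also resets elements.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class SimpleBucket:
    """Hypothetical stand-in for a gradient-reduction bucket."""
    params: List[int] = field(default_factory=list)  # numel of each queued param
    elements: int = 0                                # running element count


def flush(bucket: SimpleBucket, reset_elements: bool) -> int:
    """'Reduce' the bucket and return how many params it contained."""
    n = len(bucket.params)
    bucket.params.clear()        # mirrors params_in_bucket.clear()
    if reset_elements:
        bucket.elements = 0      # the workaround: reset the counter too
    return n


def simulate(sizes: List[int], capacity: int, reset_elements: bool) -> List[int]:
    """Return the number of params in each reduction."""
    bucket = SimpleBucket()
    reductions = []
    for numel in sizes:
        # Assumed flush rule: flush if this param would overflow a non-empty bucket.
        if bucket.elements + numel > capacity and bucket.elements > 0:
            reductions.append(flush(bucket, reset_elements))
        bucket.params.append(numel)
        bucket.elements += numel
    if bucket.params:
        reductions.append(flush(bucket, reset_elements))  # final flush
    return reductions


sizes = [100] * 10
print(simulate(sizes, capacity=300, reset_elements=False))  # -> [3, 1, 1, 1, 1, 1, 1, 1]
print(simulate(sizes, capacity=300, reset_elements=True))   # -> [3, 3, 3, 1]
```

With the stale counter, only the first reduction is full; after that every parameter triggers its own reduction, which matches the pattern described above.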
To fix the issue, I added an explicit line to reset bucket.elements to 0:
```python
params_in_bucket.clear()
bucket.elements = 0  # added: explicitly reset the element counter
```
After applying this change, the reduction behavior returned to normal, as shown below:

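This suggests that params_in_bucket.clear() only empties the list of parameters and does not reset the bucket's element counter (bucket.elements), leaving the bucket state inconsistent after each reduction.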
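One way to rule out this class of bug is to reset all bucket state in a single method rather than clearing the parameter list and the counter separately. The sketch below uses a hypothetical GradBucket class to illustrate that design choice; it is not DeepSpeed's actual bucket type.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class GradBucket:
    """Hypothetical bucket whose counter is kept in sync with its contents."""
    params: List[str] = field(default_factory=list)
    elements: int = 0

    def add(self, name: str, numel: int) -> None:
        # Update the contents and the counter together.
        self.params.append(name)
        self.elements += numel

    def reset(self) -> None:
        # Clear the contents and the element counter in one place so
        # they cannot disagree after a reduction.
        self.params.clear()
        self.elements = 0


bucket = GradBucket()
bucket.add("layer.weight", 4096)
bucket.add("layer.bias", 64)
print(bucket.elements)  # -> 4160
bucket.reset()
print(bucket.params, bucket.elements)  # -> [] 0
```

Callers then invoke reset() after a reduction instead of touching params and elements individually.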
DeepSpeed Config:
```python
# args and total_steps are defined elsewhere in the training script (not shown)
ds_config = {
    "train_batch_size": args.batch_size,
    "bf16": {
        "enabled": True
    },
    "zero_optimization": {
        "stage": 3,
        "reduce_bucket_size": 5e8,
        "offload_optimizer": {
            "device": "cpu",
            "pin_memory": True
        },
    },
    "scheduler": {
        "type": "WarmupCosineLR",
        "params": {
            "total_num_steps": total_steps,
            "warmup_num_steps": int(args.warmup * total_steps)
        }
    },
    "optimizer": {
        "type": "AdamW",
        "params": {
            "lr": args.lr,
            "betas": [0.9, 0.999],
            "eps": 1e-8,
            "weight_decay": args.weight_decay
        }
    },
    "gradient_accumulation_steps": 1,
    "gradient_clipping": 1.0,
    "logging": {
        "level": "info"
    },
}
```
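For scale: DeepSpeed documents reduce_bucket_size as a number of elements, so the configured bucket can hold a large amount of gradient data per reduction. The snippet below only does that arithmetic; the 2-byte element size is an assumption that gradients are communicated in bf16 (per the bf16 setting above).

```python
reduce_bucket_size = int(5e8)   # elements, from the config above
bytes_per_element = 2           # assumption: gradients communicated as bf16
bucket_bytes = reduce_bucket_size * bytes_per_element
print(f"{bucket_bytes / 1e9:.1f} GB ({bucket_bytes / 2**30:.2f} GiB)")  # -> 1.0 GB (0.93 GiB)
```

With the stale counter described above, this capacity goes unused after the first reduction, since each later reduction carries only a single parameter.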
Version Info:
DeepSpeed: 0.17.2+da60a878