@tohtana tohtana commented Nov 20, 2025

The DeepSpeed optimizer always creates fp32 master params/gradients/optimizer states.
However, we sometimes want to keep them in lower precision, given torch.autocast support.
This PR allows lower-precision master params/grads/optimizer states when bf16/fp16 is enabled.

DeepSpeed currently accepts the fp16_master_weights_and_gradients option under the fp16 section (not documented) with ZeRO-1/2. This PR extends the feature to bf16 and also to ZeRO-3.
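
For reference, the existing option sits under the fp16 section roughly as follows (a minimal sketch based on the description above; other fp16 keys are omitted):

        "fp16": {
            "enabled": true,
            "fp16_master_weights_and_gradients": true
        }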

In the bf16 section, we can now set two new items: bf16_master_weights_and_grads and bf16_optimizer_states.
Similarly to fp16_master_weights_and_gradients, bf16_master_weights_and_grads keeps master parameters and gradients in bf16, and bf16_optimizer_states keeps the optimizer states in bf16 as well. Here is an example configuration:

        "bf16": {
            "enabled": true,
            "bf16_master_weights_and_grads": true,
            "bf16_optimizer_states": true
        }

Note that the combination bf16_master_weights_and_grads==True and bf16_optimizer_states==False is supported only with CPU offloading (see the sketch below). Also, there is no fp16_optimizer_states option, as it would not be practical. More details are described in config-json.md.
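
A minimal sketch of that CPU-offloading combination; the zero_optimization values here are illustrative (ZeRO-2 with the optimizer offloaded to CPU) rather than taken from this PR:

        "bf16": {
            "enabled": true,
            "bf16_master_weights_and_grads": true,
            "bf16_optimizer_states": false
        },
        "zero_optimization": {
            "stage": 2,
            "offload_optimizer": {
                "device": "cpu"
            }
        }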

Previously, torch.autocast support (the torch_autocast section in the config) was not compatible with bf16/fp16 enabled, but the combination is now accepted.
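
For illustration, combining the two might look like the following (a sketch only; the torch_autocast keys shown here, enabled and dtype, are assumed from the existing torch.autocast support and may differ from the exact schema):

        "torch_autocast": {
            "enabled": true,
            "dtype": "bfloat16"
        },
        "bf16": {
            "enabled": true,
            "bf16_master_weights_and_grads": true,
            "bf16_optimizer_states": true
        }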

This PR also adds test cases for these configurations as well as for the combination with torch.autocast.

Signed-off-by: Masahiro Tanaka <[email protected]>
@tohtana tohtana marked this pull request as ready for review November 20, 2025 02:51
@tohtana tohtana enabled auto-merge (squash) December 4, 2025 03:33
@tohtana tohtana merged commit 39a682d into deepspeedai:master Dec 4, 2025
11 checks passed