
Conversation

@sfc-gh-truwase (Collaborator) commented on Aug 25, 2025

Enabled via `stage=0`, which corresponds to DDP.
Remove hardwired path to `bf16_optimizer`.
Enable `torch.autocast` for DDP training.
Enable native mixed precision DDP for bfloat16.
Update `torch.autocast` and native mixed precision unit tests.

[image: https://github.com/user-attachments/assets/92904cdc-e312-46a4-943f-011eb5ab146a]
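
For context, here is a minimal sketch of a non-ZeRO (`stage=0`) setup using `torch.autocast`-based mixed precision. The `torch_autocast` config block and its key names are assumptions carried over from the original autocast work, not verified against this PR; the other config values are placeholders:

```python
# Sketch: non-ZeRO (stage=0, i.e. plain DDP) training with torch.autocast
# mixed precision. The "torch_autocast" block is assumed from the autocast PR;
# check the DeepSpeed config docs for the exact schema.
import torch
import deepspeed

model = torch.nn.Linear(512, 512)

ds_config = {
    "train_batch_size": 8,
    "optimizer": {"type": "Adam", "params": {"lr": 1e-3}},
    "zero_optimization": {"stage": 0},                         # stage 0 == DDP, no ZeRO
    "torch_autocast": {"enabled": True, "dtype": "bfloat16"},  # autocast-based mixed precision
}

engine, optimizer, _, _ = deepspeed.initialize(
    model=model, model_parameters=model.parameters(), config=ds_config)
```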

@sfc-gh-truwase (Collaborator, Author) commented:

@tohtana I have extended your `torch.autocast` PR to the non-ZeRO case and also added some docs. Can you please review?

@sfc-gh-truwase (Collaborator, Author) commented:

@stas00 FYI.

@sfc-gh-truwase changed the title from "Enable non-ZeRO run" to "Enable non-ZeRO mode" on Aug 25, 2025
@PKUWZP (Collaborator) commented on Aug 26, 2025

@sfc-gh-truwase Can you check why the cpu-torch unit tests are failing? Also, can we say that for non-ZeRO training we fully rely on `torch.autocast`?

@Antlera (Collaborator) commented on Aug 26, 2025

The cpu-torch CI fails due to the strict check in `engine._do_optimizer_sanity_check`:

elif model_dtype == grad_accum_dtype:
    if model_dtype == torch.bfloat16:
        if self.pipeline_parallelism:
            # PP path: warn, but allow the BF16 optimizer wrapper
            logger.warning(...)
            return BFLOAT16
        else:
            # non-PP path: BF16 gradient accumulation is rejected outright
            raise NotImplementedError(...)

It only allows BF16 gradient accumulation when pipeline parallelism is enabled.
I suggest relaxing this to a warning for non-ZeRO/non-PP cases instead of raising NotImplementedError, as sketched below.
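
A minimal sketch of that suggestion as a hypothetical standalone helper (the function name, the warning text, and the `BFLOAT16` placeholder are illustrative, not the actual engine code):

```python
import logging
import torch

logger = logging.getLogger(__name__)
BFLOAT16 = "bf16"  # stand-in for the engine's optimizer-selection constant

def relaxed_bf16_accum_check(model_dtype, grad_accum_dtype, pipeline_parallelism):
    # Mirror of the check above, but warn instead of raising when BF16
    # gradient accumulation is requested without pipeline parallelism.
    if model_dtype == grad_accum_dtype and model_dtype == torch.bfloat16:
        if not pipeline_parallelism:
            logger.warning("BF16 gradient accumulation without ZeRO/PP is experimental")
        return BFLOAT16
    return None
```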

@sfc-gh-truwase (Collaborator, Author) commented:

@Antlera and @PKUWZP I found the cause but have been a bit delayed in pushing a fix. Thanks for reviewing.

@sfc-gh-truwase (Collaborator, Author) commented:

> Also, can we say that for non-ZeRO training we fully rely on `torch.autocast`?

@PKUWZP, great question. No, non-ZeRO training can also use native mixed precision. Below is how this PR expands DeepSpeed's mixed precision training options:

[image: DeepSpeed mixed precision training options]
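
For illustration, a minimal sketch of the native BF16 path for non-ZeRO training (the `bf16` and `zero_optimization` blocks are standard DeepSpeed config keys; the other values are placeholders):

```python
# Sketch: non-ZeRO (stage=0) training with DeepSpeed's native bf16 mixed
# precision instead of torch.autocast. Values are illustrative.
import torch
import deepspeed

model = torch.nn.Linear(512, 512)

ds_config = {
    "train_batch_size": 8,
    "optimizer": {"type": "Adam", "params": {"lr": 1e-3}},
    "zero_optimization": {"stage": 0},   # plain DDP, no ZeRO partitioning
    "bf16": {"enabled": True},           # native bfloat16 mixed precision
}

engine, optimizer, _, _ = deepspeed.initialize(
    model=model, model_parameters=model.parameters(), config=ds_config)
```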

@stas00 (Collaborator) left a comment:


Thank you for making stage=0 work, Tunji.

The PR is looking good; I added a few minor suggestions.

@sfc-gh-truwase merged commit 889f0ea into master on Aug 27, 2025 (12 checks passed).
@sfc-gh-truwase deleted the sfc-gh-truwase/disable_zero branch on August 27, 2025 18:07.
delock pushed a commit that referenced this pull request on Sep 1, 2025.
mauryaavinash95 pushed a commit to DataStates/DeepSpeed that referenced this pull request on Oct 4, 2025.