Conversation

@rraminen rraminen (Contributor) commented Oct 23, 2025

This PR fixes an issue in `deepspeed/runtime/fp16/fused_optimizer.py` where the gradient overflow handling logic exited the function too early, producing incorrect forward-pass and loss calculations in certain FP16 training scenarios.

The `return self.overflow` and `self.timers.log(OVERFLOW_TIMERS)` calls are now moved inside the `if self.overflow:` block, so the function returns early only when an actual overflow is detected.
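The control flow of the fix can be sketched as follows. This is a minimal illustration with hypothetical names (`FusedOptimizerSketch`, `log_overflow_timers`, and `stepped` are stand-ins, not the actual DeepSpeed API), showing why an unconditional early return skipped the parameter update even when no overflow occurred:

```python
OVERFLOW_TIMERS = ["overflow_check"]  # placeholder timer names, not the real constants


class FusedOptimizerSketch:
    """Simplified stand-in for the fused FP16 optimizer's step logic."""

    def __init__(self, overflow):
        self.overflow = overflow
        self.stepped = False  # records whether the parameter update ran

    def log_overflow_timers(self):
        # stand-in for self.timers.log(OVERFLOW_TIMERS)
        pass

    def step(self):
        if self.overflow:
            # Fixed behavior: log the overflow timers and return early
            # ONLY when an overflow was actually detected. Before the fix,
            # these two lines sat outside the if-block, so every call
            # returned here and the update below never executed.
            self.log_overflow_timers()
            return self.overflow
        # No overflow: proceed with the real optimizer step.
        self.stepped = True
        return self.overflow


opt = FusedOptimizerSketch(overflow=False)
opt.step()
print(opt.stepped)  # prints True: the update now runs when there is no overflow
```

With the buggy placement, `stepped` would remain `False` in every call, which matches the reported symptom of wrong forward-pass and loss values.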

Origin of the error: 889f0ea

cc: @jithunnair-amd

@eternalNight (Contributor)

Thanks! This should fix #7632.

@rraminen (Contributor, Author)

Hi @tjruwase, could you please help in reviewing this PR?

@rraminen (Contributor, Author)

Hi @tohtana, could you please help in reviewing this PR?

@tohtana tohtana (Collaborator) left a comment

@rraminen Sorry for my late response! Approved it.

@tohtana tohtana enabled auto-merge (squash) October 31, 2025 01:40
@tohtana tohtana merged commit d56e847 into deepspeedai:master Oct 31, 2025
12 checks passed
rraminen added a commit to rraminen/DeepSpeed that referenced this pull request Dec 1, 2025
…edai#7645)

Co-authored-by: Olatunji Ruwase <[email protected]>
Signed-off-by: rraminen <[email protected]>
