-
Notifications
You must be signed in to change notification settings - Fork 4.7k
Relax restrictions of torch.autocast integration #7543
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
+106
−40
Conversation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Signed-off-by: Masahiro Tanaka <[email protected]>
b25172c to
72dde3c
Compare
Signed-off-by: Masahiro Tanaka <[email protected]>
Collaborator
Author
|
Hi @sfc-gh-truwase, I see tests failed but it seems we ran a wrong revision. For example, the log shows But |
Flakes342
pushed a commit
to Flakes342/DeepSpeed
that referenced
this pull request
Sep 9, 2025
This PR relaxes two restrictions on torch.autocast in the DeepSpeed
engine:
1) Nesting torch.autocast
Currently, we do not expect `torch.autocast` to be used outside the
DeepSpeed engine. Here is the current behavior:
- If `torch.autocast` is enabled in the DeepSpeed config and the engine
detects it is also enabled outside, a warning is displayed.
- If it is disabled in the config, the engine raises an error.
This design prevents the following usage:
```python
with torch.autocast(...):
logits = deepspeed_model(...)
loss = criteria_fn(logits)
```
In this case, we also want to apply autocast to `criteria_fn`. With the
current behavior, we would need move `deepspeed_model(...)` outside the
`torch.autocast` context, leading to inconsistent code between DeepSpeed
and non-DeepSpeed setups. (cannot be handled with `enabled` arg of
`torch.autocast`)
Change in this PR:
`torch.autocast` outside the DeepSpeed engine is ignored, and
- If `torch_autocast` is enabled in the config, DeepSpeed will follow
that setting.
- If it is disabled, DeepSpeed falls back to its own mixed-precision
support (or FP32).
In these cases, DeepSpeed engine shows a message to explain the
behavior.
2) Model’s dtype
Previously, DeepSpeed assumed the model’s dtype must be FP32 when
`torch.autocast` was enabled. However, models with lower-precision
parameters (e.g., BF16) can also be used with autocast. For example, if
both the model and `torch.autocast` use BF16, autocast will upcast
precision-sensitive ops as needed.
Change in this PR:
Removed the assertion that restricted the model’s dtype to FP32.
This PR also adds and updates tests to cover these new behaviors.
---------
Signed-off-by: Masahiro Tanaka <[email protected]>
Signed-off-by: Flakes342 <[email protected]>
mauryaavinash95
pushed a commit
to DataStates/DeepSpeed
that referenced
this pull request
Oct 4, 2025
This PR relaxes two restrictions on torch.autocast in the DeepSpeed
engine:
1) Nesting torch.autocast
Currently, we do not expect `torch.autocast` to be used outside the
DeepSpeed engine. Here is the current behavior:
- If `torch.autocast` is enabled in the DeepSpeed config and the engine
detects it is also enabled outside, a warning is displayed.
- If it is disabled in the config, the engine raises an error.
This design prevents the following usage:
```python
with torch.autocast(...):
logits = deepspeed_model(...)
loss = criteria_fn(logits)
```
In this case, we also want to apply autocast to `criteria_fn`. With the
current behavior, we would need move `deepspeed_model(...)` outside the
`torch.autocast` context, leading to inconsistent code between DeepSpeed
and non-DeepSpeed setups. (cannot be handled with `enabled` arg of
`torch.autocast`)
Change in this PR:
`torch.autocast` outside the DeepSpeed engine is ignored, and
- If `torch_autocast` is enabled in the config, DeepSpeed will follow
that setting.
- If it is disabled, DeepSpeed falls back to its own mixed-precision
support (or FP32).
In these cases, DeepSpeed engine shows a message to explain the
behavior.
2) Model’s dtype
Previously, DeepSpeed assumed the model’s dtype must be FP32 when
`torch.autocast` was enabled. However, models with lower-precision
parameters (e.g., BF16) can also be used with autocast. For example, if
both the model and `torch.autocast` use BF16, autocast will upcast
precision-sensitive ops as needed.
Change in this PR:
Removed the assertion that restricted the model’s dtype to FP32.
This PR also adds and updates tests to cover these new behaviors.
---------
Signed-off-by: Masahiro Tanaka <[email protected]>
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
This PR relaxes two restrictions on torch.autocast in the DeepSpeed engine:
Currently, we do not expect
torch.autocastto be used outside the DeepSpeed engine. Here is the current behavior:torch.autocastis enabled in the DeepSpeed config and the engine detects it is also enabled outside, a warning is displayed.This design prevents the following usage:
In this case, we also want to apply autocast to
criteria_fn. With the current behavior, we would need movedeepspeed_model(...)outside thetorch.autocastcontext, leading to inconsistent code between DeepSpeed and non-DeepSpeed setups. (cannot be handled withenabledarg oftorch.autocast)Change in this PR:
torch.autocastoutside the DeepSpeed engine is ignored, andtorch_autocastis enabled in the config, DeepSpeed will follow that setting.In these cases, DeepSpeed engine shows a message to explain the behavior.
Previously, DeepSpeed assumed the model’s dtype must be FP32 when
torch.autocastwas enabled. However, models with lower-precision parameters (e.g., BF16) can also be used with autocast. For example, if both the model andtorch.autocastuse BF16, autocast will upcast precision-sensitive ops as needed.Change in this PR:
Removed the assertion that restricted the model’s dtype to FP32.
This PR also adds and updates tests to cover these new behaviors.