@tohtana tohtana commented Sep 3, 2025

This PR relaxes two restrictions on torch.autocast in the DeepSpeed engine:

1) Nesting torch.autocast

Currently, we do not expect `torch.autocast` to be used outside the DeepSpeed engine. Here is the current behavior:

- If `torch.autocast` is enabled in the DeepSpeed config and the engine detects it is also enabled outside, a warning is displayed.
- If it is disabled in the config, the engine raises an error.

This design prevents the following usage:

```python
with torch.autocast(...):
    logits = deepspeed_model(...)
    loss = criteria_fn(logits)
```

In this case, we also want to apply autocast to `criteria_fn`. With the current behavior, we would need to move `deepspeed_model(...)` outside the `torch.autocast` context, leading to inconsistent code between DeepSpeed and non-DeepSpeed setups. (This cannot be handled with the `enabled` arg of `torch.autocast`.)

Change in this PR:
`torch.autocast` outside the DeepSpeed engine is now ignored, and:

- If `torch_autocast` is enabled in the config, DeepSpeed will follow that setting.
- If it is disabled, DeepSpeed falls back to its own mixed-precision support (or FP32).

In both cases, the DeepSpeed engine logs a message explaining the behavior.
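
For illustration, here is a minimal sketch of driving this from the DeepSpeed config. The `torch_autocast.enabled` flag comes from the description above; the `dtype` key, the optimizer section, and the toy model are assumptions made for the sake of a runnable example, not the authoritative schema:

```python
import torch
import deepspeed

# Hypothetical minimal config: `torch_autocast.enabled` follows the PR text;
# the `dtype` key and the optimizer section are assumptions for illustration.
ds_config = {
    "train_batch_size": 8,
    "optimizer": {"type": "Adam", "params": {"lr": 1e-3}},
    "torch_autocast": {"enabled": True, "dtype": "bfloat16"},
}

model = torch.nn.Linear(16, 16)
engine, _, _, _ = deepspeed.initialize(
    model=model, model_parameters=model.parameters(), config=ds_config
)

# With this PR, the outer torch.autocast context is ignored: the engine
# follows the config setting instead of warning or raising, while the
# loss computation outside the engine still runs under autocast.
with torch.autocast(device_type="cuda"):
    logits = engine(torch.randn(8, 16, device=engine.device))
    loss = logits.float().pow(2).mean()
engine.backward(loss)
engine.step()
```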

2) Model’s dtype

Previously, DeepSpeed assumed the model’s dtype must be FP32 when `torch.autocast` was enabled. However, models with lower-precision parameters (e.g., BF16) can also be used with autocast. For example, if both the model and `torch.autocast` use BF16, autocast will upcast precision-sensitive ops as needed.
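
A small illustration of that upcasting, independent of DeepSpeed (a sketch assuming a CUDA device; the exact op lists vary across backends and PyTorch versions):

```python
import torch
import torch.nn.functional as F

# BF16 parameters, matching the autocast dtype.
linear = torch.nn.Linear(16, 16, dtype=torch.bfloat16, device="cuda")
x = torch.randn(8, 16, dtype=torch.bfloat16, device="cuda")

with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    y = linear(x)             # matmul stays in bf16
    p = F.softmax(y, dim=-1)  # precision-sensitive op: autocast runs it in fp32

print(y.dtype)  # torch.bfloat16
print(p.dtype)  # torch.float32
```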

Change in this PR:
Removed the assertion that restricted the model’s dtype to FP32.

This PR also adds and updates tests to cover these new behaviors.

Signed-off-by: Masahiro Tanaka <[email protected]>
@tohtana tohtana force-pushed the tohtana/loosen_autocast_assertion branch from b25172c to 72dde3c on September 3, 2025 06:30
@tohtana tohtana changed the title loosen restriction of autocast Relax restrictions of torch.autocast integration Sep 3, 2025
Signed-off-by: Masahiro Tanaka <[email protected]>
@tohtana tohtana commented Sep 3, 2025

Hi @sfc-gh-truwase,

I see the tests failed, but it seems CI ran the wrong revision. For example, the log shows

```
[gw0] [  6%] PASSED tests/unit/runtime/zero/test_zero_autocast.py::TestZeroAutoCast::test_error_autocast_outside_ds[dtype0-1-True]
```

But `test_error_autocast_outside_ds` no longer exists (you can see this in the diff of this PR). Do you have any suggestions for how to fix it?

@tohtana tohtana enabled auto-merge (squash) September 3, 2025 18:50
@tohtana tohtana disabled auto-merge September 3, 2025 19:14
@tohtana tohtana merged commit 66bf2a6 into master Sep 3, 2025
11 of 12 checks passed
@tohtana tohtana deleted the tohtana/loosen_autocast_assertion branch September 3, 2025 19:15
Flakes342 pushed a commit to Flakes342/DeepSpeed that referenced this pull request Sep 9, 2025
mauryaavinash95 pushed a commit to DataStates/DeepSpeed that referenced this pull request Oct 4, 2025