@HollowMan6
Contributor

`destroy` reads `self.optimizer`, but `destroy` can run (via `__del__`)
after an error in `__init__`, before the optimizer and scheduler are
configured. In that case the attribute lookup raises a second
exception. So we need to move the `self.optimizer` initialization to
the top of `__init__` to avoid triggering another exception.

e.g.:
```logs
  File "deepspeed/runtime/engine.py", line 453, in _configure_tensor_parallel_states
    assert self.zero_optimization_stage(
AssertionError: Currently, the compatibility between 'autotp' and 'zero_stage = 3' has not been validated
Exception ignored in: <function DeepSpeedEngine.__del__ at 0x1516c0610820>
Traceback (most recent call last):
  File "deepspeed/runtime/engine.py", line 509, in __del__
    self.destroy()
  File "deepspeed/runtime/engine.py", line 512, in destroy
    if self.optimizer is not None and hasattr(self.optimizer, 'destroy'):
  File "deepspeed/runtime/engine.py", line 621, in __getattr__
    raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
AttributeError: 'DeepSpeedEngine' object has no attribute 'optimizer'
```
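
The failure mode in the traceback above can be reproduced with a small toy class. This is only a sketch: `ToyEngine`, its `set_default_first` flag, and `destroy_after_failed_init` are hypothetical names for illustration and are not DeepSpeed code. It assumes, as in the traceback, a class whose `__getattr__` raises `AttributeError` for unknown names and whose `__del__` calls `destroy`.

```python
class ToyEngine:
    """Hypothetical stand-in for an engine class; not DeepSpeed's actual code."""

    def __init__(self, set_default_first):
        if set_default_first:
            # The ordering proposed here: give destroy() something to read
            # before any code that might raise.
            self.optimizer = None
        self._configure()  # always raises in this toy example
        self.optimizer = object()  # never reached

    def _configure(self):
        raise AssertionError("simulated failure during initialization")

    def __getattr__(self, name):
        # Only called when normal attribute lookup fails.
        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")

    def __del__(self):
        self.destroy()

    def destroy(self):
        if self.optimizer is not None and hasattr(self.optimizer, "destroy"):
            self.optimizer.destroy()


def destroy_after_failed_init(set_default_first):
    """Run a failing __init__, then call destroy() and report what happens."""
    engine = ToyEngine.__new__(ToyEngine)
    try:
        engine.__init__(set_default_first)
    except AssertionError:
        pass  # the original initialization error
    try:
        engine.destroy()
        return "ok"
    except AttributeError as exc:
        return str(exc)
```

With `set_default_first=True`, `destroy` finds `self.optimizer` set to `None` and returns quietly. With `False`, it hits `__getattr__` and raises the second exception. In that case Python also reports an ignored exception on stderr when the object is garbage collected, which matches the "Exception ignored in" line in the log.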

Signed-off-by: Hollow Man <[email protected]>
@sfc-gh-truwase
Collaborator

@HollowMan6, thanks for the PR. It seems this is to help with graceful exit in the event of failure during initialization. Is that correct?

@HollowMan6
Contributor Author

> @HollowMan6, thanks for the PR. It seems this is to help with graceful exit in the event of failure during initialization. Is that correct?

Yes, especially when the optimizer is not initialized.

@sfc-gh-truwase
Collaborator

@HollowMan6, got it. Is it possible to add a unit test?

@HollowMan6
Contributor Author

> @HollowMan6, got it. Is it possible to add a unit test?

I'm not sure a test is feasible: we only need to make sure `self.optimizer = None` is set before any code that might raise an exception, and it's hard to craft a unit test for that situation.
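
For reference, one possible shape for such a check, sketched against a toy class rather than the real engine. Everything here is an assumption for illustration: `ToyEngine`, its `default_first` flag, and `unraisable_errors_from_failed_init` are hypothetical names. The only real API relied on is Python's `sys.unraisablehook` (3.8+), which receives exceptions raised inside `__del__`. Whether this carries over to `DeepSpeedEngine`, whose `__init__` needs a full configuration, is not verified here.

```python
import gc
import sys


class ToyEngine:
    """Hypothetical stand-in; a real test would have to target the actual engine."""

    def __init__(self, default_first=True):
        if default_first:
            self.optimizer = None  # the ordering under test
        raise AssertionError("simulated failure during initialization")

    def __getattr__(self, name):
        raise AttributeError(name)

    def __del__(self):
        self.destroy()

    def destroy(self):
        if self.optimizer is not None and hasattr(self.optimizer, "destroy"):
            self.optimizer.destroy()


def unraisable_errors_from_failed_init(factory):
    """Call factory(), let it fail, and collect exception types raised in __del__."""
    seen = []
    old_hook = sys.unraisablehook
    # Store only the exception type; keeping the hook argument itself could
    # keep the failed object alive.
    sys.unraisablehook = lambda info: seen.append(info.exc_type)
    try:
        try:
            factory()
        except AssertionError:
            pass  # the expected initialization failure
        gc.collect()  # make sure __del__ has run before restoring the hook
    finally:
        sys.unraisablehook = old_hook
    return seen
```

A test could then assert that the list is empty for the fixed ordering, for example `unraisable_errors_from_failed_init(lambda: ToyEngine(default_first=True))`, and contains `AttributeError` when the default is not set first.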

@sfc-gh-truwase merged commit 15f054d into deepspeedai:master Jul 7, 2025
10 checks passed
lpnpcs pushed a commit to lpnpcs/DeepSpeed that referenced this pull request Jul 30, 2025
…eedai#7410)

mauryaavinash95 pushed a commit to DataStates/DeepSpeed that referenced this pull request Oct 4, 2025
…eedai#7410)
