-
Notifications
You must be signed in to change notification settings - Fork 4.7k
fix: engine initializes optimizer attributes at the beginning #7410
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
As in `destroy`, `self.optimizer` is called, but the error
out calling to `destroy` can happen in `__init__`, even before
optimizer and scheduler is configured. So we need to move
`self.optimizer` to the top to avoid triggering another
exception.
e.g.:
```logs
File "deepspeed/runtime/engine.py", line 453, in _configure_tensor_parallel_states
assert self.zero_optimization_stage(
AssertionError: Currently, the compatibility between 'autotp' and 'zero_stage = 3' has not been validated
Exception ignored in: <function DeepSpeedEngine.__del__ at 0x1516c0610820>
Traceback (most recent call last):
File "deepspeed/runtime/engine.py", line 509, in __del__
self.destroy()
File "deepspeed/runtime/engine.py", line 512, in destroy
if self.optimizer is not None and hasattr(self.optimizer, 'destroy'):
File "deepspeed/runtime/engine.py", line 621, in __getattr__
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
AttributeError: 'DeepSpeedEngine' object has no attribute 'optimizer'
```
Signed-off-by: Hollow Man <[email protected]>
|
@HollowMan6, thanks for the PR. It seems this is to help with graceful exit in the event of failure during initialization. Is that correct? |
Yes, especially when the optimizer is not initialized. |
|
@HollowMan6, got it. Is it possible to add a unit test? |
I'm not quite sure if it's feasible to test, as we just need to make sure |
…eedai#7410) As in `destroy`, `self.optimizer` is called, but the error out calling to `destroy` can happen in `__init__`, even before optimizer and scheduler is configured. So we need to move `self.optimizer` to the top to avoid triggering another exception. e.g.: ```logs File "deepspeed/runtime/engine.py", line 453, in _configure_tensor_parallel_states assert self.zero_optimization_stage( AssertionError: Currently, the compatibility between 'autotp' and 'zero_stage = 3' has not been validated Exception ignored in: <function DeepSpeedEngine.__del__ at 0x1516c0610820> Traceback (most recent call last): File "deepspeed/runtime/engine.py", line 509, in __del__ self.destroy() File "deepspeed/runtime/engine.py", line 512, in destroy if self.optimizer is not None and hasattr(self.optimizer, 'destroy'): File "deepspeed/runtime/engine.py", line 621, in __getattr__ raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") AttributeError: 'DeepSpeedEngine' object has no attribute 'optimizer' ``` Signed-off-by: Hollow Man <[email protected]>
…eedai#7410) As in `destroy`, `self.optimizer` is called, but the error out calling to `destroy` can happen in `__init__`, even before optimizer and scheduler is configured. So we need to move `self.optimizer` to the top to avoid triggering another exception. e.g.: ```logs File "deepspeed/runtime/engine.py", line 453, in _configure_tensor_parallel_states assert self.zero_optimization_stage( AssertionError: Currently, the compatibility between 'autotp' and 'zero_stage = 3' has not been validated Exception ignored in: <function DeepSpeedEngine.__del__ at 0x1516c0610820> Traceback (most recent call last): File "deepspeed/runtime/engine.py", line 509, in __del__ self.destroy() File "deepspeed/runtime/engine.py", line 512, in destroy if self.optimizer is not None and hasattr(self.optimizer, 'destroy'): File "deepspeed/runtime/engine.py", line 621, in __getattr__ raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") AttributeError: 'DeepSpeedEngine' object has no attribute 'optimizer' ``` Signed-off-by: Hollow Man <[email protected]>
As in
destroy,self.optimizeris called, but the error out calling todestroycan happen in__init__, even before optimizer and scheduler is configured. So we need to moveself.optimizerto the top to avoid triggering another exception.e.g.: