Enabling Muon Optimizer in DeepSpeed #7509
Conversation
Adds Muon dependencies to the setup.py file.
Authorship: @pengdurice and @PKUWZP
Related Issue: #7438

# Introduction

[Muon](https://arxiv.org/abs/2502.16982), a new optimizer that has recently attracted the community's attention, shows promising results in training large language models. Adding the Muon optimizer to DeepSpeed, a popular OSS framework for large-scale training and inference, is critically important for DeepSpeed users and developers. An earlier [PR](#7454) attempted the adoption (huge thanks to @qimcis) and is a good starting point, but substantial additional work was needed to make Muon fully compatible with DeepSpeed. This PR fully enables Muon optimizer capabilities in DeepSpeed.

# Issues and solutions

## Issues

1. With ZeRO stage 1, 2, or 3, the optimizer states are partitioned within the same data-parallel group. Each process therefore already handles only a part of the model parameters, so there is no need for the data-parallel partitioning used in the original [code](https://github.com/KellerJordan/Muon/blob/master/muon.py#L195).
2. The parameters (and the gradients) are flattened into a 1D vector before being passed to the optimizer, which breaks the central assumption of Muon: it works by orthogonalizing the update of each matrix-shaped parameter (dim >= 2).

## Solutions

To solve these issues, this PR:

1. Simplifies the Muon code by [removing](master...pengdurice:DeepSpeed:peng-add-muon-v1#diff-c9052994e41caee9ca88363749c10af08655f8019f08dc971c018663d25a3712R22) the partitioning and Muon update logic.
2. [Moves](master...pengdurice:DeepSpeed:peng-add-muon-v1#diff-99dcf26ea2876ff5bbf05b5165c4133eaa0d0f36b170685643c2f7e2eb566addR1867) the Muon update into the [get_flat_partition](master...pengdurice:DeepSpeed:peng-add-muon-v1#diff-99dcf26ea2876ff5bbf05b5165c4133eaa0d0f36b170685643c2f7e2eb566addR1848) function of the stage 1 and 2 DeepSpeedZeroOptimizer, where per-parameter gradients are collected before being flattened and used by the optimizer to update the model parameters. Since each parameter is still in its original shape at that point, the Muon update can be applied directly (a minimal sketch of this per-parameter update appears below).
3. Saves the momentum buffer into the optimizer's state so that training converges smoothly after resuming from a saved checkpoint.
4. Adds comprehensive unit tests to validate the Muon optimizer's correctness and functionality.

# Future directions and roadmap

Several follow-up works are of interest:

- [ ] Create a CPU-offload version.
- [ ] Apply Muon to stage 3.
- [ ] Use the highly optimized version of Adam for the Adam part of the MuonWithAuxAdam optimizer.
- [ ] More efficient implementations, e.g. a) add specialized kernels for the Newton-Schulz iteration and Muon updates; b) parallelize updates across parameters (currently, each parameter is updated separately and sequentially).

---------

Co-authored-by: Peng Du <[email protected]>
Co-authored-by: pengdurice <[email protected]>
Co-authored-by: Zhipeng Wang <[email protected]>
Co-authored-by: Olatunji Ruwase <[email protected]>
Signed-off-by: Ma, Guokai <[email protected]>
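For concreteness, here is a minimal, self-contained sketch of the per-parameter Muon update that solution 2 applies before the gradients are flattened. It follows the reference Muon implementation (Newton-Schulz orthogonalization plus heavy-ball momentum); the function names, coefficient values, and shape-dependent scaling are taken from that reference and may differ in detail from the code merged in this PR.

```python
import torch


def zeropower_via_newtonschulz5(G: torch.Tensor, steps: int = 5) -> torch.Tensor:
    """Approximately orthogonalize a 2D matrix via a quintic Newton-Schulz iteration.

    Coefficients follow the reference Muon implementation; the iteration drives
    the singular values of G toward 1 without an explicit SVD.
    """
    assert G.ndim == 2, "Muon orthogonalizes matrix-shaped (dim >= 2) parameters"
    a, b, c = 3.4445, -4.7750, 2.0315
    X = G.float()  # the reference implementation runs this in bfloat16 for speed
    if G.size(0) > G.size(1):
        X = X.T  # iterate on the smaller Gram matrix
    X = X / (X.norm() + 1e-7)  # bound the spectral norm by 1
    for _ in range(steps):
        A = X @ X.T
        B = b * A + c * (A @ A)
        X = a * X + B @ X
    if G.size(0) > G.size(1):
        X = X.T
    return X.to(G.dtype)


@torch.no_grad()  # optimizer-style in-place update
def muon_update(param: torch.Tensor, grad: torch.Tensor, momentum_buf: torch.Tensor,
                lr: float = 0.02, momentum: float = 0.95) -> None:
    """One Muon step for a single parameter that is still in its original 2D shape."""
    momentum_buf.mul_(momentum).add_(grad)            # heavy-ball momentum buffer
    update = grad.add(momentum_buf, alpha=momentum)   # Nesterov-style lookahead
    update = zeropower_via_newtonschulz5(update)      # orthogonalize the matrix update
    # Scale so the update magnitude is roughly independent of the matrix shape.
    update *= max(1.0, param.size(0) / param.size(1)) ** 0.5
    param.add_(update, alpha=-lr)
```

Because the orthogonalization is defined on the matrix shape, it cannot be applied after the gradients have been flattened into a 1D partition; this is why the update is hooked in at get_flat_partition, where each gradient still has its original shape.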
The initialization of DeepCompile+Z1/2 now fails due to the change introduced in #7509. This PR resolves the issue by:

- Adding an argument to optimizer.get_flat_partition
- Skipping the entire allreduce function in the engine

---------

Signed-off-by: Masahiro Tanaka <[email protected]>
@PKUWZP is MuonClip (used by Kimi K2, https://arxiv.org/abs/2507.20534) also of future interest? From the paper, clipping is essential to avoid exploding attention logits when the model is very large.
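For reference (not part of this PR), a toy illustration of the QK-Clip mechanism described in that report, as I understand it: after the Muon step, the query/key projection weights are rescaled whenever the observed maximum attention logit exceeds a threshold. The threshold value and the even split of the scaling between Q and K below are assumptions for illustration only.

```python
import torch


def qk_clip_(w_q: torch.Tensor, w_k: torch.Tensor, max_logit: float, tau: float = 100.0) -> None:
    """Toy QK-Clip: if the observed max attention logit exceeds tau, shrink the
    query/key projections so that the q @ k.T logits scale back to the threshold."""
    if max_logit > tau:
        gamma = tau / max_logit
        # sqrt split so the logit (a q.k product) is scaled by gamma overall
        w_q.mul_(gamma ** 0.5)
        w_k.mul_(gamma ** 0.5)
```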
The original Muon optimizer PR (#7509) requires the user to explicitly set the `use_muon` flag on `model.parameters()`, as shown in the test https://github.com/deepspeedai/DeepSpeed/blob/master/tests/unit/ops/muon/test_muon.py#L27. This PR integrates the setting of `use_muon` into DeepSpeed before engine initialization, which makes the Muon optimizer easier to use: the user only needs to change the optimizer in `config.json` from `AdamW` to `Muon`, with no code changes. It resolves issue #7552.

---------

Signed-off-by: Ma, Guokai <[email protected]>
Co-authored-by: Olatunji Ruwase <[email protected]>
Co-authored-by: Masahiro Tanaka <[email protected]>
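For example, a minimal launch sketch under the new behavior. The `params` keys shown here (`lr`, `momentum`) and their values are assumptions for illustration; check the DeepSpeed Muon config documentation for the exact accepted fields.

```python
import torch
import deepspeed

# Hypothetical model; any torch.nn.Module works.
model = torch.nn.Sequential(
    torch.nn.Linear(512, 512),
    torch.nn.ReLU(),
    torch.nn.Linear(512, 10),
)

# The only change relative to an AdamW setup is the optimizer "type" field.
ds_config = {
    "train_batch_size": 8,
    "optimizer": {
        "type": "Muon",
        "params": {"lr": 2e-2, "momentum": 0.95},
    },
    "zero_optimization": {"stage": 2},
}

engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config=ds_config,
)
```

With this PR, the `use_muon` flags are set on the parameters automatically during engine initialization, so no model-side changes are needed.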
How can I use Muon with ZeRO stage 3? As far as I can see, it has not yet been applied to stage 3.
Hi, the team is working on enabling Muon for stage 3. Stay tuned. Thanks!