Authorship: @pengdurice and @PKUWZP

Related Issue: #7438

# Introduction

[Muon](https://arxiv.org/abs/2502.16982), a new optimizer that has recently attracted the community's attention, shows promising results in training large language models. Adding the Muon optimizer to DeepSpeed, a popular open-source framework for large-scale training and inference, is therefore important for DeepSpeed users and developers. An earlier [PR](https://github.com/deepspeedai/DeepSpeed/pull/7454) attempted the adoption (huge thanks to @qimcis) and is a good starting point, but substantial additional work is needed to make Muon fully compatible with and functional inside DeepSpeed. This PR fully enables the Muon optimizer capabilities for DeepSpeed.

# Issues and solutions

## Issues

1. With ZeRO stage 1, 2, or 3, the optimizer states are partitioned within the same data parallel group. Each process therefore already handles only a part of the model parameters, so there is no need for the data-parallel partitioning logic used in the original [code](https://github.com/KellerJordan/Muon/blob/master/muon.py#L195).
2. The parameters (and their gradients) are flattened into a 1D vector before being passed to the optimizer, which breaks the core assumption of Muon: it works by orthogonalizing the update of each matrix-shaped parameter (dim >= 2).

## Solutions

To solve these issues, this PR:

1. Simplifies the Muon code by [removing](https://github.com/deepspeedai/DeepSpeed/compare/master...pengdurice:DeepSpeed:peng-add-muon-v1#diff-c9052994e41caee9ca88363749c10af08655f8019f08dc971c018663d25a3712R22) the partitioning and Muon update logic.
2. [Moves](https://github.com/deepspeedai/DeepSpeed/compare/master...pengdurice:DeepSpeed:peng-add-muon-v1#diff-99dcf26ea2876ff5bbf05b5165c4133eaa0d0f36b170685643c2f7e2eb566addR1867) the Muon update into the [get_flat_partition](https://github.com/deepspeedai/DeepSpeed/compare/master...pengdurice:DeepSpeed:peng-add-muon-v1#diff-99dcf26ea2876ff5bbf05b5165c4133eaa0d0f36b170685643c2f7e2eb566addR1848) function of the stage 1 and 2 DeepSpeedZeroOptimizer, where per-parameter gradients are collected before being flattened and consumed by the optimizer to update the model parameters. Because each parameter still has its original shape at that point, the Muon update can be applied directly (illustrative sketches are included at the end of this description).
3. Saves the momentum buffer into the optimizer's state so that training converges smoothly after resuming from a saved checkpoint.
4. Adds comprehensive unit tests to validate the Muon optimizer's correctness and functionality.

# Future directions and roadmap

Several follow-up items are of interest:

- [ ] Create a CPU-offload version.
- [ ] Apply Muon to stage 3.
- [ ] Use the highly optimized Adam implementation for the Adam part of the MuonWithAuxAdam optimizer.
- [ ] More efficient implementations, e.g. (a) add specialized kernels for the Newton-Schulz iteration and Muon updates; (b) parallelize updates across parameters (currently each parameter is updated separately and sequentially).

---------

Co-authored-by: Peng Du <pedu@linkedin.com>
Co-authored-by: pengdurice <pengduhit@gmail.com>
Co-authored-by: Zhipeng Wang <zhipengbayern@gmail.com>
Co-authored-by: Olatunji Ruwase <tunji.ruwase@snowflake.com>
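
# Illustrative sketches (not part of the diff)

For readers unfamiliar with Muon, below is a minimal sketch of the Newton-Schulz orthogonalization that Muon applies to each matrix-shaped update, loosely following the reference implementation at https://github.com/KellerJordan/Muon. The function name and the quintic coefficients here are taken as assumptions for illustration; the PR itself reuses the upstream logic rather than this exact code.

```python
import torch


def newton_schulz_orthogonalize(G: torch.Tensor, steps: int = 5) -> torch.Tensor:
    """Approximately orthogonalize a matrix update via a quintic Newton-Schulz iteration.

    Sketch for illustration only; coefficients are the commonly cited quintic
    constants and should be treated as an assumption here.
    """
    assert G.ndim >= 2, "Muon only applies to matrix-shaped (dim >= 2) parameters"
    a, b, c = 3.4445, -4.7750, 2.0315
    X = G.to(torch.bfloat16)
    transposed = G.size(-2) > G.size(-1)
    if transposed:  # work in the wide orientation so the Gram matrix stays small
        X = X.mT
    # Normalize so the spectral norm is at most ~1, which the iteration requires.
    X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
    for _ in range(steps):
        A = X @ X.mT
        B = b * A + c * (A @ A)
        X = a * X + B @ X
    if transposed:
        X = X.mT
    return X.to(G.dtype)
```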
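And here is a hedged sketch of the per-parameter Muon step described in Solutions 2 and 3: it runs on a gradient that still has its original shape (as it is when `get_flat_partition` collects it, before flattening), and it keeps the momentum buffer in the optimizer state so checkpoint resume stays smooth. `muon_update` is a hypothetical helper, not the exact code added to DeepSpeedZeroOptimizer, and the default `lr`/`momentum` values are illustrative; it reuses `newton_schulz_orthogonalize` from the sketch above.

```python
import math
import torch


def muon_update(param: torch.Tensor,
                grad: torch.Tensor,
                state: dict,
                lr: float = 0.02,
                momentum: float = 0.95,
                ns_steps: int = 5) -> None:
    """Apply one Muon step to a single, still-unflattened matrix parameter."""
    # The momentum buffer is stored in the optimizer state, so it is saved and
    # restored with checkpoints and convergence is smooth after resuming.
    if "momentum_buffer" not in state:
        state["momentum_buffer"] = torch.zeros_like(grad)
    buf = state["momentum_buffer"]
    buf.mul_(momentum).add_(grad)

    # Orthogonalize the momentum (collapsed to 2D for conv-style weights),
    # then rescale by the aspect ratio as in the reference implementation.
    update = newton_schulz_orthogonalize(buf.reshape(buf.size(0), -1), steps=ns_steps)
    update = update.reshape(param.shape)
    scale = math.sqrt(max(1.0, param.size(-2) / param.size(-1)))
    param.add_(update, alpha=-lr * scale)
```

Because this step happens before the gradients are flattened into the 1D partition buffer, the downstream ZeRO stage 1/2 flattening and partitioning machinery does not need to change.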