Enabling Muon Optimizer in DeepSpeed (#7509)

Authorship: @pengdurice and @PKUWZP 

Related Issue: #7438

# Introduction

[Muon](https://arxiv.org/abs/2502.16982), a new optimizer that has recently
attracted the community's attention, shows promising results in training
large language models. Adding the Muon optimizer to DeepSpeed, a popular OSS
framework for large-scale training and inference, is therefore important for
DeepSpeed users and developers. An earlier
[PR](https://github.com/deepspeedai/DeepSpeed/pull/7454) attempted the
adoption (huge thanks to @qimcis) and is a good starting point, but
substantial additional work was needed to make Muon fully compatible with
DeepSpeed. This PR fully enables the Muon optimizer in DeepSpeed.

# Issues and solutions
## Issues
1. With ZeRO stage 1, 2, or 3, the optimizer states are partitioned within
the same data-parallel group. Each process therefore already handles only a
slice of the model parameters, so there is no need for the data-parallel
sharding logic in the original
[code](https://github.com/KellerJordan/Muon/blob/master/muon.py#L195).
2. The parameters (and their gradients) are flattened into 1D vectors before
being handed to the optimizer, which breaks the core assumption of Muon: it
works by orthogonalizing the update of each matrix-shaped parameter
(dim >= 2). A minimal sketch of this mismatch is shown right after this list.
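
For illustration, here is a minimal sketch of the mismatch, adapted from the vendored `zeropower_via_newtonschulz5` shipped with this PR (the `orthogonalize` name is ours for illustration only): the orthogonalization step requires matrix-shaped inputs, so a flattened 1D partition cannot be processed.

```python
import torch

def orthogonalize(G: torch.Tensor, steps: int = 5) -> torch.Tensor:
    # Mirrors the vendored zeropower_via_newtonschulz5 (Newton-Schulz iteration).
    assert G.ndim >= 2  # a flattened 1D partition fails right here
    a, b, c = (3.4445, -4.7750, 2.0315)
    X = G.bfloat16()
    if G.size(-2) > G.size(-1):
        X = X.mT
    X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)  # cap spectral norm at ~1
    for _ in range(steps):
        A = X @ X.mT
        X = a * X + (b * A + c * A @ A) @ X
    if G.size(-2) > G.size(-1):
        X = X.mT
    return X

W = torch.randn(64, 32)        # a hidden weight matrix keeps its 2D shape
orthogonalize(W)               # fine: the update is orthogonalized per matrix
# orthogonalize(W.flatten())   # AssertionError: a 1D flat partition has no
#                              # matrix structure left to orthogonalize
```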

## Solutions
To solve these issues, this PR does the following (a short usage sketch follows the list):
1. We simplify the Muon code by
[removing](https://github.com/deepspeedai/DeepSpeed/compare/master...pengdurice:DeepSpeed:peng-add-muon-v1#diff-c9052994e41caee9ca88363749c10af08655f8019f08dc971c018663d25a3712R22)
the distributed partitioning and Muon update logic.

2. We
[move](https://github.com/deepspeedai/DeepSpeed/compare/master...pengdurice:DeepSpeed:peng-add-muon-v1#diff-99dcf26ea2876ff5bbf05b5165c4133eaa0d0f36b170685643c2f7e2eb566addR1867)
the Muon update into the
[get_flat_partition](https://github.com/deepspeedai/DeepSpeed/compare/master...pengdurice:DeepSpeed:peng-add-muon-v1#diff-99dcf26ea2876ff5bbf05b5165c4133eaa0d0f36b170685643c2f7e2eb566addR1848)
function of the stage 1 and 2 `DeepSpeedZeroOptimizer`, where per-parameter
gradients are collected before being flattened and passed to the optimizer to
update the model parameters. Since each parameter is still in its original
shape at that point, we can apply the Muon update directly.
3. We also save the momentum buffer in the optimizer's state so that training
converges smoothly after resuming from a saved checkpoint.
4. We added comprehensive unit tests to validate the Muon optimizer's
correctness and functionality.
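
A minimal usage sketch, modeled on the unit test added in this PR: mark every parameter with a `use_muon` flag and select the `muon` optimizer type in the DeepSpeed config. `MyModel` and the hyperparameter values below are placeholders, not part of the PR.

```python
import deepspeed

model = MyModel()  # placeholder: any torch.nn.Module

# Muon applies to matrix-shaped (ndim >= 2) weights; all other parameters
# are handled by the auxiliary Adam inside MuonWithAuxAdam.
for p in model.parameters():
    p.use_muon = p.ndim >= 2

ds_config = {
    "train_batch_size": 8,
    "optimizer": {"type": "muon", "params": {"lr": 0.01}},
    "fp16": {"enabled": True},
    "zero_optimization": {"stage": 2},  # stage 1 or 2; stage 3 is not yet supported
}

engine, optimizer, _, _ = deepspeed.initialize(
    config=ds_config,
    model=model,
    model_parameters=model.parameters(),
)
```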

# Future directions and roadmap
In the future, several follow-up items are of interest:
- [ ] Create a CPU offload version.
- [ ] Apply Muon to ZeRO Stage 3.
- [ ] Use the highly optimized version of Adam for the Adam part of the
MuonWithAuxAdam optimizer.
- [ ] More efficient implementations, e.g. (a) add specialized kernels for the
Newton-Schulz iteration and Muon updates; (b) parallelize updates across
parameters (currently each parameter is updated separately and sequentially).

---------

Co-authored-by: Peng Du <pedu@linkedin.com>
Co-authored-by: pengdurice <pengduhit@gmail.com>
Co-authored-by: Zhipeng Wang <zhipengbayern@gmail.com>
Co-authored-by: Olatunji Ruwase <tunji.ruwase@snowflake.com>
This commit is contained in: Zhipeng Wang, 2025-08-26 18:34:35 -07:00, committed by GitHub
parent e4662faffd · commit 66ad278048
8 changed files with 551 additions and 24 deletions

View File

@@ -77,9 +77,11 @@ MUADAM_OPTIMIZER = 'muadam'
MUADAMW_OPTIMIZER = 'muadamw'
MUSGD_OPTIMIZER = 'musgd'
LION_OPTIMIZER = 'lion'
MUON_OPTIMIZER = 'muon'
DEEPSPEED_OPTIMIZERS = [
ADAGRAD_OPTIMIZER, ADAM_OPTIMIZER, ADAMW_OPTIMIZER, LAMB_OPTIMIZER, ONEBIT_ADAM_OPTIMIZER, ONEBIT_LAMB_OPTIMIZER,
ZERO_ONE_ADAM_OPTIMIZER, MUADAM_OPTIMIZER, MUADAMW_OPTIMIZER, MUSGD_OPTIMIZER, LION_OPTIMIZER
ZERO_ONE_ADAM_OPTIMIZER, MUADAM_OPTIMIZER, MUADAMW_OPTIMIZER, MUSGD_OPTIMIZER, LION_OPTIMIZER, MUON_OPTIMIZER
]
# extra optimizer parameters for adam/adamw

View File

@@ -44,12 +44,13 @@ from deepspeed.module_inject.layers import GatherReplacedLayerParams, configure_
from deepspeed.runtime.config import DEEPSPEED_OPTIMIZERS, \
ADAGRAD_OPTIMIZER, ADAM_OPTIMIZER, ADAMW_OPTIMIZER, LAMB_OPTIMIZER, ONEBIT_ADAM_OPTIMIZER, ONEBIT_LAMB_OPTIMIZER, \
TORCH_ADAM_PARAM, ADAM_W_MODE, ADAM_W_MODE_DEFAULT, ZERO_ONE_ADAM_OPTIMIZER, MUADAM_OPTIMIZER, MUADAMW_OPTIMIZER, \
MUSGD_OPTIMIZER, LION_OPTIMIZER
MUSGD_OPTIMIZER, LION_OPTIMIZER, MUON_OPTIMIZER
from deepspeed.runtime.model_checkpointing.constants import ValidationMode, \
CHECKPOINT_TAG_VALIDATION, CHECKPOINT_WRITER, CHECKPOINT_SERIALIZATION
from deepspeed.runtime.dataloader import DeepSpeedDataLoader
from deepspeed.runtime.zero.muon.muon_optimizer import MuonWithAuxAdam
from deepspeed.runtime.constants import \
ROUTE_TRAIN, ROUTE_PREDICT, ROUTE_EVAL, \
PLD_THETA, PLD_GAMMA, BFLOAT16, FP16, AMP, GRADIENT_ACCUMULATION_STEPS, \
@@ -1574,6 +1575,29 @@ class DeepSpeedEngine(Module):
except ImportError:
logger.error("Install mup to use MuSGD optimizer")
optimizer = MuSGD(model_parameters, **optimizer_parameters)
elif self.optimizer_name() == MUON_OPTIMIZER:
zero_stage = self.zero_optimization_stage()
assert zero_stage <= ZeroStageEnum.gradients, "Muon optimizer is not yet compatible with ZeRO Stage 3"
if not all([hasattr(p, 'use_muon') for p in model_parameters]):
msg = "Muon optimizer is used, but the use_muon attribute is NOT configured for some of the model parameters, " \
"please set by `param.use_muon = True / False` for all params"
logger.error(msg)
muon_params = [p for p in model_parameters if p.use_muon]
non_muon_params = [p for p in model_parameters if not p.use_muon]
param_groups = []
if muon_params:
accepted_parameters = dict()
for key in ["lr", "momentum", "weight_decay"]:
if key in optimizer_parameters:
accepted_parameters[key] = optimizer_parameters[key]
param_groups.append(dict(params=muon_params, use_muon=True, **accepted_parameters))
if non_muon_params:
accepted_parameters = dict()
for key in ["lr", "betas", "eps", "weight_decay"]:
if key in optimizer_parameters:
accepted_parameters[key] = optimizer_parameters[key]
param_groups.append(dict(params=non_muon_params, use_muon=False, **accepted_parameters))
optimizer = MuonWithAuxAdam(param_groups)
else:
torch_optimizer = getattr(torch.optim, self.optimizer_name())
optimizer = torch_optimizer(model_parameters, **optimizer_parameters)

View File

@@ -0,0 +1,4 @@
# Copyright (c) 2025 Peng Du and Zhipeng Wang
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team

View File

@@ -0,0 +1,48 @@
# Copyright (c) 2025 Peng Du and Zhipeng Wang
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import torch
try:
from deepspeed.runtime.zero.muon.original_muon import MuonWithAuxAdam as BaseMuonWithAuxAdam
from deepspeed.runtime.zero.muon.original_muon import adam_update
except ImportError:
pass
class MuonWithAuxAdam(BaseMuonWithAuxAdam):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@torch.no_grad()
def step(self, closure=None):
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()
for group in self.param_groups:
if group["use_muon"]:
# we move the Muon update into DeepSpeed's ZeRO optimizer, since the parameter here is a flattened 1D copy
# and thus not suitable for the Muon update
for p in group["params"]:
p.mul_(1 - group["lr"] * group["weight_decay"])
p.add_(p.grad.reshape(p.shape), alpha=-group["lr"])
else:
for p in group["params"]:
if p.grad is None:
# continue
p.grad = torch.zeros_like(p) # Force synchronization
state = self.state[p]
if len(state) == 0:
state["exp_avg"] = torch.zeros_like(p)
state["exp_avg_sq"] = torch.zeros_like(p)
state["step"] = 0
state["step"] += 1
update = adam_update(p.grad, state["exp_avg"], state["exp_avg_sq"], state["step"], group["betas"],
group["eps"])
p.mul_(1 - group["lr"] * group["weight_decay"])
p.add_(update, alpha=-group["lr"])
return loss

View File

@@ -0,0 +1,323 @@
# Copyright (c) 2024 Keller Jordan
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
"""
MIT License
Copyright (c) 2024 Keller Jordan
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""
import torch
import deepspeed.comm as dist # replace torch's distributed package with deepspeed.comm to resolve deepspeed check
def zeropower_via_newtonschulz5(G, steps: int):
"""
Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
zero even beyond the point where the iteration no longer converges all the way to one everywhere
on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
performance at all relative to UV^T, where USV^T = G is the SVD.
"""
assert G.ndim >= 2 # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng
a, b, c = (3.4445, -4.7750, 2.0315)
X = G.bfloat16()
if G.size(-2) > G.size(-1):
X = X.mT
# Ensure spectral norm is at most 1
X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
# Perform the NS iterations
for _ in range(steps):
A = X @ X.mT
B = b * A + c * A @ A # quintic computation strategy adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
X = a * X + B @ X
if G.size(-2) > G.size(-1):
X = X.mT
return X
def muon_update(grad, momentum, beta=0.95, ns_steps=5, nesterov=True):
momentum.lerp_(grad, 1 - beta)
update = grad.lerp_(momentum, beta) if nesterov else momentum
if update.ndim == 4: # for the case of conv filters
update = update.view(len(update), -1)
update = zeropower_via_newtonschulz5(update, steps=ns_steps)
update *= max(1, grad.size(-2) / grad.size(-1))**0.5
return update
class Muon(torch.optim.Optimizer):
"""
Muon - MomentUm Orthogonalized by Newton-schulz
https://kellerjordan.github.io/posts/muon/
Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
matrix. For efficient orthogonalization we use a Newton-Schulz iteration, which has the
advantage that it can be stably run in bfloat16 on the GPU.
Muon should only be used for hidden weight layers. The input embedding, final output layer,
and any internal gains or biases should be optimized using a standard method such as AdamW.
Hidden convolutional weights can be trained using Muon by viewing them as 2D and then
collapsing their last 3 dimensions.
Arguments:
lr: The learning rate, in units of spectral norm per update.
weight_decay: The AdamW-style weight decay.
momentum: The momentum. A value of 0.95 here is usually fine.
"""
def __init__(self, params, lr=0.02, weight_decay=0, momentum=0.95):
defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum)
assert isinstance(params, list) and len(params) >= 1 and isinstance(params[0], torch.nn.Parameter)
params = sorted(params, key=lambda x: x.size(), reverse=True)
super().__init__(params, defaults)
@torch.no_grad()
def step(self, closure=None):
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()
for group in self.param_groups:
params = group["params"]
params_pad = params + [torch.empty_like(params[-1])
] * (dist.get_world_size() - len(params) % dist.get_world_size())
for base_i in range(len(params))[::dist.get_world_size()]:
if base_i + dist.get_rank() < len(params):
p = params[base_i + dist.get_rank()]
if p.grad is None:
# continue
p.grad = torch.zeros_like(p) # Force synchronization
state = self.state[p]
if len(state) == 0:
state["momentum_buffer"] = torch.zeros_like(p)
update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"])
p.mul_(1 - group["lr"] * group["weight_decay"])
p.add_(update.reshape(p.shape), alpha=-group["lr"])
dist.all_gather(params_pad[base_i:base_i + dist.get_world_size()],
params_pad[base_i + dist.get_rank()])
return loss
class SingleDeviceMuon(torch.optim.Optimizer):
"""
Muon variant for usage in non-distributed settings.
"""
def __init__(self, params, lr=0.02, weight_decay=0, momentum=0.95):
defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum)
super().__init__(params, defaults)
@torch.no_grad()
def step(self, closure=None):
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()
for group in self.param_groups:
for p in group["params"]:
if p.grad is None:
# continue
p.grad = torch.zeros_like(p) # Force synchronization
state = self.state[p]
if len(state) == 0:
state["momentum_buffer"] = torch.zeros_like(p)
update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"])
p.mul_(1 - group["lr"] * group["weight_decay"])
p.add_(update.reshape(p.shape), alpha=-group["lr"])
return loss
def adam_update(grad, buf1, buf2, step, betas, eps):
buf1.lerp_(grad, 1 - betas[0])
buf2.lerp_(grad.square(), 1 - betas[1])
buf1c = buf1 / (1 - betas[0]**step)
buf2c = buf2 / (1 - betas[1]**step)
return buf1c / (buf2c.sqrt() + eps)
class MuonWithAuxAdam(torch.optim.Optimizer):
"""
Distributed Muon variant that can be used for all parameters in the network, since it runs an
internal AdamW for the parameters that are not compatible with Muon. The user must manually
specify which parameters shall be optimized with Muon and which with Adam by passing in a
list of param_groups with the `use_muon` flag set.
The point of this class is to allow the user to have a single optimizer in their code, rather
than having both a Muon and an Adam which each need to be stepped.
You can see an example usage below:
https://github.com/KellerJordan/modded-nanogpt/blob/master/records/052525_MuonWithAuxAdamExample/b01550f9-03d8-4a9c-86fe-4ab434f1c5e0.txt#L470
```
hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n]
embed_params = [p for n, p in model.named_parameters() if "embed" in n]
scalar_params = [p for p in model.parameters() if p.ndim < 2]
head_params = [model.lm_head.weight]
from muon import MuonWithAuxAdam
adam_groups = [dict(params=head_params, lr=0.22), dict(params=embed_params, lr=0.6), dict(params=scalar_params, lr=0.04)]
adam_groups = [dict(**g, betas=(0.8, 0.95), eps=1e-10, use_muon=False) for g in adam_groups]
muon_group = dict(params=hidden_matrix_params, lr=0.05, momentum=0.95, use_muon=True)
param_groups = [*adam_groups, muon_group]
optimizer = MuonWithAuxAdam(param_groups)
```
"""
def __init__(self, param_groups):
for group in param_groups:
assert "use_muon" in group
if group["use_muon"]:
group["params"] = sorted(group["params"], key=lambda x: x.size(), reverse=True)
# defaults
group["lr"] = group.get("lr", 0.02)
group["momentum"] = group.get("momentum", 0.95)
group["weight_decay"] = group.get("weight_decay", 0)
assert set(group.keys()) == set(["params", "lr", "momentum", "weight_decay", "use_muon"])
else:
# defaults
group["lr"] = group.get("lr", 3e-4)
group["betas"] = group.get("betas", (0.9, 0.95))
group["eps"] = group.get("eps", 1e-10)
group["weight_decay"] = group.get("weight_decay", 0)
assert set(group.keys()) == set(["params", "lr", "betas", "eps", "weight_decay", "use_muon"])
super().__init__(param_groups, dict())
@torch.no_grad()
def step(self, closure=None):
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()
for group in self.param_groups:
if group["use_muon"]:
params = group["params"]
params_pad = params + [torch.empty_like(params[-1])
] * (dist.get_world_size() - len(params) % dist.get_world_size())
for base_i in range(len(params))[::dist.get_world_size()]:
if base_i + dist.get_rank() < len(params):
p = params[base_i + dist.get_rank()]
if p.grad is None:
# continue
p.grad = torch.zeros_like(p) # Force synchronization
state = self.state[p]
if len(state) == 0:
state["momentum_buffer"] = torch.zeros_like(p)
update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"])
p.mul_(1 - group["lr"] * group["weight_decay"])
p.add_(update.reshape(p.shape), alpha=-group["lr"])
dist.all_gather(params_pad[base_i:base_i + dist.get_world_size()],
params_pad[base_i + dist.get_rank()])
else:
for p in group["params"]:
if p.grad is None:
# continue
p.grad = torch.zeros_like(p) # Force synchronization
state = self.state[p]
if len(state) == 0:
state["exp_avg"] = torch.zeros_like(p)
state["exp_avg_sq"] = torch.zeros_like(p)
state["step"] = 0
state["step"] += 1
update = adam_update(p.grad, state["exp_avg"], state["exp_avg_sq"], state["step"], group["betas"],
group["eps"])
p.mul_(1 - group["lr"] * group["weight_decay"])
p.add_(update, alpha=-group["lr"])
return loss
class SingleDeviceMuonWithAuxAdam(torch.optim.Optimizer):
"""
Non-distributed variant of MuonWithAuxAdam.
"""
def __init__(self, param_groups):
for group in param_groups:
assert "use_muon" in group
if group["use_muon"]:
# defaults
group["lr"] = group.get("lr", 0.02)
group["momentum"] = group.get("momentum", 0.95)
group["weight_decay"] = group.get("weight_decay", 0)
assert set(group.keys()) == set(["params", "lr", "momentum", "weight_decay", "use_muon"])
else:
# defaults
group["lr"] = group.get("lr", 3e-4)
group["betas"] = group.get("betas", (0.9, 0.95))
group["eps"] = group.get("eps", 1e-10)
group["weight_decay"] = group.get("weight_decay", 0)
assert set(group.keys()) == set(["params", "lr", "betas", "eps", "weight_decay", "use_muon"])
super().__init__(param_groups, dict())
@torch.no_grad()
def step(self, closure=None):
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()
for group in self.param_groups:
if group["use_muon"]:
for p in group["params"]:
if p.grad is None:
# continue
p.grad = torch.zeros_like(p) # Force synchronization
state = self.state[p]
if len(state) == 0:
state["momentum_buffer"] = torch.zeros_like(p)
update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"])
p.mul_(1 - group["lr"] * group["weight_decay"])
p.add_(update.reshape(p.shape), alpha=-group["lr"])
else:
for p in group["params"]:
if p.grad is None:
# continue
p.grad = torch.zeros_like(p) # Force synchronization
state = self.state[p]
if len(state) == 0:
state["exp_avg"] = torch.zeros_like(p)
state["exp_avg_sq"] = torch.zeros_like(p)
state["step"] = 0
state["step"] += 1
update = adam_update(p.grad, state["exp_avg"], state["exp_avg_sq"], state["step"], group["betas"],
group["eps"])
p.mul_(1 - group["lr"] * group["weight_decay"])
p.add_(update, alpha=-group["lr"])
return loss

View File

@@ -32,7 +32,7 @@ from deepspeed.git_version_info import version
from deepspeed.runtime.constants import PIPE_REPLICATED
from deepspeed.accelerator import get_accelerator
from deepspeed.runtime.zero.muon.original_muon import muon_update
from deepspeed.checkpoint.constants import (DS_VERSION, GROUP_PADDINGS, PARTITION_COUNT, LOSS_SCALER,
SINGLE_PARTITION_OF_FP32_GROUPS, BASE_OPTIMIZER_STATE,
BASE_OPTIMIZER_STATE_STEP, CLIP_GRAD, ZERO_STAGE, PARAM_SLICE_MAPPINGS)
@@ -561,7 +561,7 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
# will store the averaged gradients required by this partition
self.averaged_gradients = {}
self.all_grad_tensors = {}
# For cpu_offload, will store the averaged gradients required by this partition
self.offload_gradient_dict = {}
@@ -850,25 +850,24 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
if self.cpu_offload is False:
for i, _ in enumerate(self.bit16_groups):
if i not in self.averaged_gradients or self.averaged_gradients[i] is None:
if not i in self.averaged_gradients or self.averaged_gradients[i] is None:
self.all_grad_tensors[i] = self.get_all_grad_tensors(self.params_in_partition[i],
dtype=self.gradient_accumulation_dtype)
else:
avg_new = self.get_all_grad_tensors(self.params_in_partition[i],
dtype=self.gradient_accumulation_dtype)
for accumulated_grad, new_avg_grad in zip(self.all_grad_tensors[i], avg_new):
accumulated_grad.add_(new_avg_grad)
if self.is_gradient_accumulation_boundary:
self.averaged_gradients[i] = self.get_flat_partition(
self.params_in_partition[i],
self.first_offset[i],
self.partition_size[i],
dtype=self.gradient_accumulation_dtype,
device=get_accelerator().current_device_name(),
param_group_idx=i,
return_tensor_list=True)
else:
avg_new = self.get_flat_partition(self.params_in_partition[i],
self.first_offset[i],
self.partition_size[i],
dtype=self.gradient_accumulation_dtype,
device=get_accelerator().current_device_name(),
return_tensor_list=True)
for accumulated_grad, new_avg_grad in zip(self.averaged_gradients[i], avg_new):
accumulated_grad.add_(new_avg_grad)
self.all_grad_tensors[i] = None
self._release_ipg_buffers()
@@ -1847,20 +1846,55 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
return total_norm
# creates a flat fused tensor from the tensor list starting at the first_offset
# in the first tensor of the list. If there are not enough elements in the tensor
# list then the flat tensor will be padded with zeros
def get_flat_partition(self, tensor_list, first_offset, partition_size, dtype, device, return_tensor_list=False):
flat_tensor_list = []
current_size = 0
def get_all_grad_tensors(self, tensor_list, dtype):
all_grad_tensors = []
for i, tensor in enumerate(tensor_list):
grad_accum = self.get_param_gradient_attribute(tensor)
if grad_accum is None:
grad_accum = torch.zeros_like(tensor, dtype=dtype)
all_grad_tensors.append(grad_accum)
return all_grad_tensors
# creates a flat fused tensor from the tensor list starting at the first_offset
# in the first tensor of the list. If there are not enough elements in the tensor
# list then the flat tensor will be padded with zeros
def get_flat_partition(self,
tensor_list,
first_offset,
partition_size,
dtype,
device,
param_group_idx,
return_tensor_list=False):
flat_tensor_list = []
current_size = 0
# find the flatten copy in the optimizer's state
flatten_copy = self.optimizer.param_groups[param_group_idx]['params'][0]
if (not self.optimizer.state[flatten_copy]) and getattr(
tensor_list[0], 'use_muon', False) and 'muon' in self.optimizer.__class__.__name__.lower():
self.optimizer.state[flatten_copy] = {}
if "momentum_buffer" not in self.optimizer.state[flatten_copy] and getattr(
tensor_list[0], 'use_muon', False) and 'muon' in self.optimizer.__class__.__name__.lower():
# need to check the total # of elements in the parameters in this group and this partition
total_size = sum([t.numel() for t in tensor_list])
flatten_bf_list = [torch.zeros([total_size], dtype=dtype)] # put on cpu to save space
self.optimizer.state[flatten_copy]["momentum_buffer"] = self.flatten(flatten_bf_list)
buffer_idx = 0
for i, tensor in enumerate(tensor_list):
grad_accum = self.all_grad_tensors[param_group_idx][i]
if getattr(tensor, 'use_muon', False) and 'muon' in self.optimizer.__class__.__name__.lower():
assert tensor.ndim > 1, f"if use muon, then tensor dim > 1, got {tensor.size()}"
# create a gpu copy
buffer = torch.narrow(self.optimizer.state[flatten_copy]["momentum_buffer"], 0, buffer_idx,
tensor.numel()).view(tensor.size()).to(device).to(dtype)
grad_accum = muon_update(grad_accum, buffer, self.optimizer.param_groups[param_group_idx]['momentum'])
# write back to the cpu copy
torch.narrow(self.optimizer.state[flatten_copy]["momentum_buffer"], 0, buffer_idx,
tensor.numel()).data.copy_(buffer.view(-1).data)
tensor = grad_accum
num_elements = tensor.numel()
buffer_idx += num_elements
tensor_offset = 0
# we need to offset to get to the right element
@@ -1979,6 +2013,7 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
else:
for k in self.averaged_gradients.keys():
self.averaged_gradients[k] = None
self.all_grad_tensors[k] = None
see_memory_usage('After overflow after clearing gradients')
@@ -2040,7 +2075,7 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
self.free_grad_in_param_list(self.params_in_partition[i])
self.averaged_gradients[i] = None
self.all_grad_tensors[i] = None
self.unscale_and_clip_grads([single_grad_partition], scaled_global_grad_norm)
self.timers(OPTIMIZER_GRADIENTS_TIMER).stop()

View File

@@ -45,6 +45,13 @@ ZERO_SUPPORTED_OPTIMIZERS = [
DeepSpeedCPULion, FusedLion
]
# Add MuonWithAuxAdam to supported list if muon is installed
try:
from deepspeed.runtime.muon_optimizer import MuonWithAuxAdam
ZERO_SUPPORTED_OPTIMIZERS.append(MuonWithAuxAdam)
except ImportError:
pass
# Add apex FusedAdam to supported list if apex is installed
try:
import apex

View File

@@ -0,0 +1,84 @@
# Copyright (c) 2025 Peng Du and Zhipeng Wang
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import deepspeed
import torch
import pytest
from unit.common import DistributedTest
from unit.simple_model import SimpleModel
from deepspeed.accelerator import get_accelerator
if torch.half not in get_accelerator().supported_dtypes():
pytest.skip(f"fp16 not supported, valid dtype: {get_accelerator().supported_dtypes()}", allow_module_level=True)
# 'optimizer_type, zero_stage, lr, hidden_dim, nlayer'
muon_configs = []
for optimizer_name in ['muon', 'adam']:
for stage in [1, 2]:
for lr in [0.01, 0.05]:
for model_dim in [32, 128]:
for nlayer in [5, 10]:
muon_configs.append([optimizer_name, stage, lr, model_dim, nlayer])
def set_muon_flag(params):
for p in params:
if p.ndim >= 2:
setattr(p, "use_muon", True)
else:
setattr(p, "use_muon", False)
@pytest.mark.parametrize('optimizer_type, zero_stage, lr, hidden_dim, nlayer', muon_configs)
class TestMuonConfigs(DistributedTest):
def test(self, optimizer_type, zero_stage, lr, hidden_dim, nlayer):
optimizer_params = {"lr": lr}
batch_size = 8
config_dict = {
"train_batch_size": batch_size,
"optimizer": {
"type": optimizer_type,
"params": optimizer_params
},
"gradient_clipping": 1.0,
"fp16": {
"enabled": True
},
"zero_optimization": {
"stage": zero_stage,
}
}
# Perform a few training steps to ensure the optimizer works correctly
model = SimpleModel(hidden_dim=hidden_dim, nlayers=nlayer)
if 'muon' in optimizer_type:
set_muon_flag(model.parameters())
initial_params = [p.clone().cpu() for p in model.parameters()]
engine, optimizer, _, _ = deepspeed.initialize(
config=config_dict,
model=model,
model_parameters=model.parameters(),
dist_init_required=False,
)
assert optimizer_type in optimizer.optimizer.__class__.__name__.lower(
), f"Expected optimizer type {optimizer_type}, got {optimizer.optimizer.__class__.__name__}"
steps = 5
for _ in range(steps):
# Random inputs: (batch_size, hidden_dim)
x = torch.randn(batch_size, hidden_dim, device=engine.device, dtype=torch.half)
# Random class labels: (batch_size,)
y = torch.randint(0, hidden_dim, (batch_size, ), device=engine.device)
# Forward + loss
loss = engine(x, y)
# Backward
engine.backward(loss)
engine.step()
# Verify that parameters have been updated
after_training = [p.clone().cpu() for p in model.parameters()]
for initial, final in zip(initial_params, after_training):
assert not torch.equal(initial.cpu(), final.cpu()), "Parameters should have been updated during training"