Mirror of https://github.com/deepspeedai/DeepSpeed.git (synced 2025-10-20 15:33:51 +08:00)
Enabling Muon Optimizer in DeepSpeed (#7509)
Authorship: @pengdurice and @PKUWZP
Related Issue: #7438

# Introduction

[Muon](https://arxiv.org/abs/2502.16982), a new optimizer that has recently attracted the community's attention, shows promising results in training large language models. Adding the Muon optimizer to DeepSpeed, a popular OSS framework for large-scale training and inference, is critically important for DeepSpeed users and developers. There has been a [PR](https://github.com/deepspeedai/DeepSpeed/pull/7454) attempting the adoption (huge thanks to @qimcis), which is a good starting point, but more substantial effort is still required to make it fully compatible and working within DeepSpeed. We are publishing this PR to fully enable Muon optimizer capabilities for DeepSpeed.

# Issues and solutions

## Issues

1. With ZeRO stage 1, 2 or 3, the optimizer states are partitioned within the same data parallel group. This means that each process already handles only part of the model parameters, so there is no need for the data-parallel partitioning used in the original [code](https://github.com/KellerJordan/Muon/blob/master/muon.py#L195).
2. The parameters (and the gradients) are flattened to a 1D vector before being used in the optimizer, which breaks the core assumption of the Muon optimizer: it works by orthogonalizing the update for each matrix (dim >= 2).

## Solutions

To solve these issues, we propose this new PR in which:

1. We simplify the Muon code by [removing](https://github.com/deepspeedai/DeepSpeed/compare/master...pengdurice:DeepSpeed:peng-add-muon-v1#diff-c9052994e41caee9ca88363749c10af08655f8019f08dc971c018663d25a3712R22) the partitioning and Muon update logic.
2. We [move](https://github.com/deepspeedai/DeepSpeed/compare/master...pengdurice:DeepSpeed:peng-add-muon-v1#diff-99dcf26ea2876ff5bbf05b5165c4133eaa0d0f36b170685643c2f7e2eb566addR1867) the Muon update into the [get_flat_partition](https://github.com/deepspeedai/DeepSpeed/compare/master...pengdurice:DeepSpeed:peng-add-muon-v1#diff-99dcf26ea2876ff5bbf05b5165c4133eaa0d0f36b170685643c2f7e2eb566addR1848) function of the stage 1 and 2 DeepSpeedZeroOptimizer, where per-parameter gradients are collected before being flattened and used by the optimizer to update the model parameters. Since each parameter is still in its original shape at that point, we can easily apply the Muon update.
3. We also save the momentum buffer into the optimizer's state so that convergence remains smooth after resuming from saved checkpoints.
4. We added comprehensive unit tests to validate the Muon optimizer's correctness and functionality.

# Future directions and roadmap

In the future, several follow-up works are of interest:

- [ ] Create a CPU offload version.
- [ ] Apply Muon to ZeRO Stage 3.
- [ ] Use the highly optimized version of Adam for the Adam part of the MuonWithAuxAdam optimizer.
- [ ] More efficient implementations, e.g. a) add specialized kernels for the Newton-Schulz iteration and Muon updates; b) parallelize updates across parameters (currently, each parameter is updated separately and sequentially).

---------

Co-authored-by: Peng Du <pedu@linkedin.com>
Co-authored-by: pengdurice <pengduhit@gmail.com>
Co-authored-by: Zhipeng Wang <zhipengbayern@gmail.com>
Co-authored-by: Olatunji Ruwase <tunji.ruwase@snowflake.com>
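As a usage illustration (not part of this PR's diff), the sketch below shows how Muon would be enabled from user code under these changes: select the `muon` optimizer type in the DeepSpeed config and tag every parameter with a `use_muon` flag, as the engine and the new unit test expect. The toy model, batch size, and learning rate are placeholders.

```python
import torch
import deepspeed

# Illustrative model; any torch.nn.Module works the same way.
model = torch.nn.Sequential(torch.nn.Linear(128, 128), torch.nn.ReLU(), torch.nn.Linear(128, 10))

# Mark which parameters Muon should handle: matrices (ndim >= 2) use Muon,
# biases and other 1D parameters fall back to the auxiliary Adam.
for p in model.parameters():
    p.use_muon = p.ndim >= 2

ds_config = {
    "train_batch_size": 8,
    "optimizer": {
        "type": "muon",  # selects MuonWithAuxAdam inside DeepSpeed
        "params": {"lr": 0.01}
    },
    "zero_optimization": {"stage": 2},  # stages 1 and 2 are supported; stage 3 is not yet
    "fp16": {"enabled": True}
}

engine, optimizer, _, _ = deepspeed.initialize(config=ds_config,
                                               model=model,
                                               model_parameters=model.parameters())
```

Inside the engine, parameters with `use_muon = True` are routed to the Muon param group and all remaining parameters to the auxiliary Adam group of `MuonWithAuxAdam`.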
@@ -77,9 +77,11 @@ MUADAM_OPTIMIZER = 'muadam'
MUADAMW_OPTIMIZER = 'muadamw'
MUSGD_OPTIMIZER = 'musgd'
LION_OPTIMIZER = 'lion'
MUON_OPTIMIZER = 'muon'

DEEPSPEED_OPTIMIZERS = [
    ADAGRAD_OPTIMIZER, ADAM_OPTIMIZER, ADAMW_OPTIMIZER, LAMB_OPTIMIZER, ONEBIT_ADAM_OPTIMIZER, ONEBIT_LAMB_OPTIMIZER,
    ZERO_ONE_ADAM_OPTIMIZER, MUADAM_OPTIMIZER, MUADAMW_OPTIMIZER, MUSGD_OPTIMIZER, LION_OPTIMIZER
    ZERO_ONE_ADAM_OPTIMIZER, MUADAM_OPTIMIZER, MUADAMW_OPTIMIZER, MUSGD_OPTIMIZER, LION_OPTIMIZER, MUON_OPTIMIZER
]

# extra optimizer parameters for adam/adamw
@@ -44,12 +44,13 @@ from deepspeed.module_inject.layers import GatherReplacedLayerParams, configure_
from deepspeed.runtime.config import DEEPSPEED_OPTIMIZERS, \
    ADAGRAD_OPTIMIZER, ADAM_OPTIMIZER, ADAMW_OPTIMIZER, LAMB_OPTIMIZER, ONEBIT_ADAM_OPTIMIZER, ONEBIT_LAMB_OPTIMIZER, \
    TORCH_ADAM_PARAM, ADAM_W_MODE, ADAM_W_MODE_DEFAULT, ZERO_ONE_ADAM_OPTIMIZER, MUADAM_OPTIMIZER, MUADAMW_OPTIMIZER, \
    MUSGD_OPTIMIZER, LION_OPTIMIZER
    MUSGD_OPTIMIZER, LION_OPTIMIZER, MUON_OPTIMIZER

from deepspeed.runtime.model_checkpointing.constants import ValidationMode, \
    CHECKPOINT_TAG_VALIDATION, CHECKPOINT_WRITER, CHECKPOINT_SERIALIZATION

from deepspeed.runtime.dataloader import DeepSpeedDataLoader
from deepspeed.runtime.zero.muon.muon_optimizer import MuonWithAuxAdam
from deepspeed.runtime.constants import \
    ROUTE_TRAIN, ROUTE_PREDICT, ROUTE_EVAL, \
    PLD_THETA, PLD_GAMMA, BFLOAT16, FP16, AMP, GRADIENT_ACCUMULATION_STEPS, \
@@ -1574,6 +1575,29 @@ class DeepSpeedEngine(Module):
            except ImportError:
                logger.error("Install mup to use MuSGD optimizer")
            optimizer = MuSGD(model_parameters, **optimizer_parameters)
        elif self.optimizer_name() == MUON_OPTIMIZER:
            zero_stage = self.zero_optimization_stage()
            assert zero_stage <= ZeroStageEnum.gradients, "Muon optimizer is not yet compatible with ZeRO Stage 3"
            if not all([hasattr(p, 'use_muon') for p in model_parameters]):
                msg = "Muon optimizer is used, but the use_muon attribute is NOT configured for some of the model parameters, " \
                      "please set by `param.use_muon = True / False` for all params"
                logger.error(msg)
            muon_params = [p for p in model_parameters if p.use_muon]
            non_muon_params = [p for p in model_parameters if not p.use_muon]
            param_groups = []
            if muon_params:
                accepted_parameters = dict()
                for key in ["lr", "momentum", "weight_decay"]:
                    if key in optimizer_parameters:
                        accepted_parameters[key] = optimizer_parameters[key]
                param_groups.append(dict(params=muon_params, use_muon=True, **accepted_parameters))
            if non_muon_params:
                accepted_parameters = dict()
                for key in ["lr", "betas", "eps", "weight_decay"]:
                    if key in optimizer_parameters:
                        accepted_parameters[key] = optimizer_parameters[key]
                param_groups.append(dict(params=non_muon_params, use_muon=False, **accepted_parameters))
            optimizer = MuonWithAuxAdam(param_groups)
        else:
            torch_optimizer = getattr(torch.optim, self.optimizer_name())
            optimizer = torch_optimizer(model_parameters, **optimizer_parameters)
deepspeed/runtime/zero/muon/__init__.py (new file, 4 lines)
@@ -0,0 +1,4 @@
# Copyright (c) 2025 Peng Du and Zhipeng Wang
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
deepspeed/runtime/zero/muon/muon_optimizer.py (new file, 48 lines)
@@ -0,0 +1,48 @@
# Copyright (c) 2025 Peng Du and Zhipeng Wang
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
try:
    from deepspeed.runtime.zero.muon.original_muon import MuonWithAuxAdam as BaseMuonWithAuxAdam
    from deepspeed.runtime.zero.muon.original_muon import adam_update
except ImportError:
    pass


class MuonWithAuxAdam(BaseMuonWithAuxAdam):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        for group in self.param_groups:
            if group["use_muon"]:
                # we move the muon update part to the deepspeed's optimizer since the parameter here is a flat version
                # thus not suitable for muon update
                for p in group["params"]:
                    p.mul_(1 - group["lr"] * group["weight_decay"])
                    p.add_(p.grad.reshape(p.shape), alpha=-group["lr"])
            else:
                for p in group["params"]:
                    if p.grad is None:
                        # continue
                        p.grad = torch.zeros_like(p)  # Force synchronization
                    state = self.state[p]
                    if len(state) == 0:
                        state["exp_avg"] = torch.zeros_like(p)
                        state["exp_avg_sq"] = torch.zeros_like(p)
                        state["step"] = 0
                    state["step"] += 1
                    update = adam_update(p.grad, state["exp_avg"], state["exp_avg_sq"], state["step"], group["betas"],
                                         group["eps"])
                    p.mul_(1 - group["lr"] * group["weight_decay"])
                    p.add_(update, alpha=-group["lr"])

        return loss
deepspeed/runtime/zero/muon/original_muon.py (new file, 323 lines)
@@ -0,0 +1,323 @@
# Copyright (c) 2024 Keller Jordan
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
"""
MIT License

Copyright (c) 2024 Keller Jordan

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

"""

import torch
import deepspeed.comm as dist  # replace torch's distributed package with deepspeed.comm to resolve deepspeed check

def zeropower_via_newtonschulz5(G, steps: int):
    """
    Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
    quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
    of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
    zero even beyond the point where the iteration no longer converges all the way to one everywhere
    on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
    where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
    performance at all relative to UV^T, where USV^T = G is the SVD.
    """
    assert G.ndim >= 2  # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng
    a, b, c = (3.4445, -4.7750, 2.0315)
    X = G.bfloat16()
    if G.size(-2) > G.size(-1):
        X = X.mT

    # Ensure spectral norm is at most 1
    X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
    # Perform the NS iterations
    for _ in range(steps):
        A = X @ X.mT
        B = b * A + c * A @ A  # quintic computation strategy adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
        X = a * X + B @ X

    if G.size(-2) > G.size(-1):
        X = X.mT
    return X


def muon_update(grad, momentum, beta=0.95, ns_steps=5, nesterov=True):
    momentum.lerp_(grad, 1 - beta)
    update = grad.lerp_(momentum, beta) if nesterov else momentum
    if update.ndim == 4:  # for the case of conv filters
        update = update.view(len(update), -1)
    update = zeropower_via_newtonschulz5(update, steps=ns_steps)
    update *= max(1, grad.size(-2) / grad.size(-1))**0.5
    return update

class Muon(torch.optim.Optimizer):
    """
    Muon - MomentUm Orthogonalized by Newton-schulz

    https://kellerjordan.github.io/posts/muon/

    Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
    processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
    matrix. For efficient orthogonalization we use a Newton-Schulz iteration, which has the
    advantage that it can be stably run in bfloat16 on the GPU.

    Muon should only be used for hidden weight layers. The input embedding, final output layer,
    and any internal gains or biases should be optimized using a standard method such as AdamW.
    Hidden convolutional weights can be trained using Muon by viewing them as 2D and then
    collapsing their last 3 dimensions.

    Arguments:
        lr: The learning rate, in units of spectral norm per update.
        weight_decay: The AdamW-style weight decay.
        momentum: The momentum. A value of 0.95 here is usually fine.
    """

    def __init__(self, params, lr=0.02, weight_decay=0, momentum=0.95):
        defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum)
        assert isinstance(params, list) and len(params) >= 1 and isinstance(params[0], torch.nn.Parameter)
        params = sorted(params, key=lambda x: x.size(), reverse=True)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):

        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params = group["params"]
            params_pad = params + [torch.empty_like(params[-1])
                                   ] * (dist.get_world_size() - len(params) % dist.get_world_size())
            for base_i in range(len(params))[::dist.get_world_size()]:
                if base_i + dist.get_rank() < len(params):
                    p = params[base_i + dist.get_rank()]
                    if p.grad is None:
                        # continue
                        p.grad = torch.zeros_like(p)  # Force synchronization
                    state = self.state[p]
                    if len(state) == 0:
                        state["momentum_buffer"] = torch.zeros_like(p)
                    update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"])
                    p.mul_(1 - group["lr"] * group["weight_decay"])
                    p.add_(update.reshape(p.shape), alpha=-group["lr"])
                dist.all_gather(params_pad[base_i:base_i + dist.get_world_size()],
                                params_pad[base_i + dist.get_rank()])

        return loss

class SingleDeviceMuon(torch.optim.Optimizer):
    """
    Muon variant for usage in non-distributed settings.
    """

    def __init__(self, params, lr=0.02, weight_decay=0, momentum=0.95):
        defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):

        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    # continue
                    p.grad = torch.zeros_like(p)  # Force synchronization
                state = self.state[p]
                if len(state) == 0:
                    state["momentum_buffer"] = torch.zeros_like(p)
                update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"])
                p.mul_(1 - group["lr"] * group["weight_decay"])
                p.add_(update.reshape(p.shape), alpha=-group["lr"])

        return loss

def adam_update(grad, buf1, buf2, step, betas, eps):
    buf1.lerp_(grad, 1 - betas[0])
    buf2.lerp_(grad.square(), 1 - betas[1])
    buf1c = buf1 / (1 - betas[0]**step)
    buf2c = buf2 / (1 - betas[1]**step)
    return buf1c / (buf2c.sqrt() + eps)

class MuonWithAuxAdam(torch.optim.Optimizer):
    """
    Distributed Muon variant that can be used for all parameters in the network, since it runs an
    internal AdamW for the parameters that are not compatible with Muon. The user must manually
    specify which parameters shall be optimized with Muon and which with Adam by passing in a
    list of param_groups with the `use_muon` flag set.

    The point of this class is to allow the user to have a single optimizer in their code, rather
    than having both a Muon and an Adam which each need to be stepped.

    You can see an example usage below:

    https://github.com/KellerJordan/modded-nanogpt/blob/master/records/052525_MuonWithAuxAdamExample/b01550f9-03d8-4a9c-86fe-4ab434f1c5e0.txt#L470
    ```
    hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n]
    embed_params = [p for n, p in model.named_parameters() if "embed" in n]
    scalar_params = [p for p in model.parameters() if p.ndim < 2]
    head_params = [model.lm_head.weight]

    from muon import MuonWithAuxAdam
    adam_groups = [dict(params=head_params, lr=0.22), dict(params=embed_params, lr=0.6), dict(params=scalar_params, lr=0.04)]
    adam_groups = [dict(**g, betas=(0.8, 0.95), eps=1e-10, use_muon=False) for g in adam_groups]
    muon_group = dict(params=hidden_matrix_params, lr=0.05, momentum=0.95, use_muon=True)
    param_groups = [*adam_groups, muon_group]
    optimizer = MuonWithAuxAdam(param_groups)
    ```
    """

    def __init__(self, param_groups):
        for group in param_groups:
            assert "use_muon" in group
            if group["use_muon"]:
                group["params"] = sorted(group["params"], key=lambda x: x.size(), reverse=True)
                # defaults
                group["lr"] = group.get("lr", 0.02)
                group["momentum"] = group.get("momentum", 0.95)
                group["weight_decay"] = group.get("weight_decay", 0)
                assert set(group.keys()) == set(["params", "lr", "momentum", "weight_decay", "use_muon"])
            else:
                # defaults
                group["lr"] = group.get("lr", 3e-4)
                group["betas"] = group.get("betas", (0.9, 0.95))
                group["eps"] = group.get("eps", 1e-10)
                group["weight_decay"] = group.get("weight_decay", 0)
                assert set(group.keys()) == set(["params", "lr", "betas", "eps", "weight_decay", "use_muon"])
        super().__init__(param_groups, dict())

    @torch.no_grad()
    def step(self, closure=None):

        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if group["use_muon"]:
                params = group["params"]
                params_pad = params + [torch.empty_like(params[-1])
                                       ] * (dist.get_world_size() - len(params) % dist.get_world_size())
                for base_i in range(len(params))[::dist.get_world_size()]:
                    if base_i + dist.get_rank() < len(params):
                        p = params[base_i + dist.get_rank()]
                        if p.grad is None:
                            # continue
                            p.grad = torch.zeros_like(p)  # Force synchronization
                        state = self.state[p]
                        if len(state) == 0:
                            state["momentum_buffer"] = torch.zeros_like(p)
                        update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"])
                        p.mul_(1 - group["lr"] * group["weight_decay"])
                        p.add_(update.reshape(p.shape), alpha=-group["lr"])
                    dist.all_gather(params_pad[base_i:base_i + dist.get_world_size()],
                                    params_pad[base_i + dist.get_rank()])
            else:
                for p in group["params"]:
                    if p.grad is None:
                        # continue
                        p.grad = torch.zeros_like(p)  # Force synchronization
                    state = self.state[p]
                    if len(state) == 0:
                        state["exp_avg"] = torch.zeros_like(p)
                        state["exp_avg_sq"] = torch.zeros_like(p)
                        state["step"] = 0
                    state["step"] += 1
                    update = adam_update(p.grad, state["exp_avg"], state["exp_avg_sq"], state["step"], group["betas"],
                                         group["eps"])
                    p.mul_(1 - group["lr"] * group["weight_decay"])
                    p.add_(update, alpha=-group["lr"])

        return loss

class SingleDeviceMuonWithAuxAdam(torch.optim.Optimizer):
    """
    Non-distributed variant of MuonWithAuxAdam.
    """

    def __init__(self, param_groups):
        for group in param_groups:
            assert "use_muon" in group
            if group["use_muon"]:
                # defaults
                group["lr"] = group.get("lr", 0.02)
                group["momentum"] = group.get("momentum", 0.95)
                group["weight_decay"] = group.get("weight_decay", 0)
                assert set(group.keys()) == set(["params", "lr", "momentum", "weight_decay", "use_muon"])
            else:
                # defaults
                group["lr"] = group.get("lr", 3e-4)
                group["betas"] = group.get("betas", (0.9, 0.95))
                group["eps"] = group.get("eps", 1e-10)
                group["weight_decay"] = group.get("weight_decay", 0)
                assert set(group.keys()) == set(["params", "lr", "betas", "eps", "weight_decay", "use_muon"])
        super().__init__(param_groups, dict())

    @torch.no_grad()
    def step(self, closure=None):

        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if group["use_muon"]:
                for p in group["params"]:
                    if p.grad is None:
                        # continue
                        p.grad = torch.zeros_like(p)  # Force synchronization
                    state = self.state[p]
                    if len(state) == 0:
                        state["momentum_buffer"] = torch.zeros_like(p)
                    update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"])
                    p.mul_(1 - group["lr"] * group["weight_decay"])
                    p.add_(update.reshape(p.shape), alpha=-group["lr"])
            else:
                for p in group["params"]:
                    if p.grad is None:
                        # continue
                        p.grad = torch.zeros_like(p)  # Force synchronization
                    state = self.state[p]
                    if len(state) == 0:
                        state["exp_avg"] = torch.zeros_like(p)
                        state["exp_avg_sq"] = torch.zeros_like(p)
                        state["step"] = 0
                    state["step"] += 1
                    update = adam_update(p.grad, state["exp_avg"], state["exp_avg_sq"], state["step"], group["betas"],
                                         group["eps"])
                    p.mul_(1 - group["lr"] * group["weight_decay"])
                    p.add_(update, alpha=-group["lr"])

        return loss
@@ -32,7 +32,7 @@ from deepspeed.git_version_info import version

from deepspeed.runtime.constants import PIPE_REPLICATED
from deepspeed.accelerator import get_accelerator

from deepspeed.runtime.zero.muon.original_muon import muon_update
from deepspeed.checkpoint.constants import (DS_VERSION, GROUP_PADDINGS, PARTITION_COUNT, LOSS_SCALER,
                                            SINGLE_PARTITION_OF_FP32_GROUPS, BASE_OPTIMIZER_STATE,
                                            BASE_OPTIMIZER_STATE_STEP, CLIP_GRAD, ZERO_STAGE, PARAM_SLICE_MAPPINGS)
@@ -561,7 +561,7 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):

        # will store the averaged gradients required by this partition
        self.averaged_gradients = {}

        self.all_grad_tensors = {}
        # For cpu_offload, will store the averaged gradients required by this partition
        self.offload_gradient_dict = {}

@@ -850,25 +850,24 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):

        if self.cpu_offload is False:
            for i, _ in enumerate(self.bit16_groups):

                if i not in self.averaged_gradients or self.averaged_gradients[i] is None:
                if not i in self.averaged_gradients or self.averaged_gradients[i] is None:
                    self.all_grad_tensors[i] = self.get_all_grad_tensors(self.params_in_partition[i],
                                                                         dtype=self.gradient_accumulation_dtype)
                else:
                    avg_new = self.get_all_grad_tensors(self.params_in_partition[i],
                                                        dtype=self.gradient_accumulation_dtype)
                    for accumulated_grad, new_avg_grad in zip(self.all_grad_tensors[i], avg_new):
                        accumulated_grad.add_(new_avg_grad)
                if self.is_gradient_accumulation_boundary:
                    self.averaged_gradients[i] = self.get_flat_partition(
                        self.params_in_partition[i],
                        self.first_offset[i],
                        self.partition_size[i],
                        dtype=self.gradient_accumulation_dtype,
                        device=get_accelerator().current_device_name(),
                        param_group_idx=i,
                        return_tensor_list=True)
                else:
                    avg_new = self.get_flat_partition(self.params_in_partition[i],
                                                      self.first_offset[i],
                                                      self.partition_size[i],
                                                      dtype=self.gradient_accumulation_dtype,
                                                      device=get_accelerator().current_device_name(),
                                                      return_tensor_list=True)

                    for accumulated_grad, new_avg_grad in zip(self.averaged_gradients[i], avg_new):
                        accumulated_grad.add_(new_avg_grad)
                self.all_grad_tensors[i] = None

        self._release_ipg_buffers()

@@ -1847,20 +1846,55 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):

        return total_norm

    # creates a flat fused tensor from the tensor list starting at the first_offset
    # in the first tensor of the list. If there are not enough elements in the tensor
    # list then the flat tensor will be padded with zeros
    def get_flat_partition(self, tensor_list, first_offset, partition_size, dtype, device, return_tensor_list=False):
        flat_tensor_list = []
        current_size = 0

    def get_all_grad_tensors(self, tensor_list, dtype):
        all_grad_tensors = []
        for i, tensor in enumerate(tensor_list):
            grad_accum = self.get_param_gradient_attribute(tensor)
            if grad_accum is None:
                grad_accum = torch.zeros_like(tensor, dtype=dtype)
            all_grad_tensors.append(grad_accum)
        return all_grad_tensors

    # creates a flat fused tensor from the tensor list starting at the first_offset
    # in the first tensor of the list. If there are not enough elements in the tensor
    # list then the flat tensor will be padded with zeros
    def get_flat_partition(self,
                           tensor_list,
                           first_offset,
                           partition_size,
                           dtype,
                           device,
                           param_group_idx,
                           return_tensor_list=False):
        flat_tensor_list = []
        current_size = 0
        # find the flatten copy in the optimizer's state
        flatten_copy = self.optimizer.param_groups[param_group_idx]['params'][0]
        if (not self.optimizer.state[flatten_copy]) and getattr(
                tensor_list[0], 'use_muon', False) and 'muon' in self.optimizer.__class__.__name__.lower():
            self.optimizer.state[flatten_copy] = {}
        if "momentum_buffer" not in self.optimizer.state[flatten_copy] and getattr(
                tensor_list[0], 'use_muon', False) and 'muon' in self.optimizer.__class__.__name__.lower():
            # need to check the total # of elements in the parameters in this group and this partition
            total_size = sum([t.numel() for t in tensor_list])
            flatten_bf_list = [torch.zeros([total_size], dtype=dtype)]  # put on cpu to save space
            self.optimizer.state[flatten_copy]["momentum_buffer"] = self.flatten(flatten_bf_list)

        buffer_idx = 0
        for i, tensor in enumerate(tensor_list):
            grad_accum = self.all_grad_tensors[param_group_idx][i]
            if getattr(tensor, 'use_muon', False) and 'muon' in self.optimizer.__class__.__name__.lower():
                assert tensor.ndim > 1, f"if use muon, then tensor dim > 1, got {tensor.size()}"
                # create a gpu copy
                buffer = torch.narrow(self.optimizer.state[flatten_copy]["momentum_buffer"], 0, buffer_idx,
                                      tensor.numel()).view(tensor.size()).to(device).to(dtype)
                grad_accum = muon_update(grad_accum, buffer, self.optimizer.param_groups[param_group_idx]['momentum'])
                # write back to the cpu copy
                torch.narrow(self.optimizer.state[flatten_copy]["momentum_buffer"], 0, buffer_idx,
                             tensor.numel()).data.copy_(buffer.view(-1).data)
            tensor = grad_accum
            num_elements = tensor.numel()
            buffer_idx += num_elements
            tensor_offset = 0

            # we need to offset to get to the right element
@@ -1979,6 +2013,7 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
            else:
                for k in self.averaged_gradients.keys():
                    self.averaged_gradients[k] = None
                    self.all_grad_tensors[k] = None

            see_memory_usage('After overflow after clearing gradients')

@@ -2040,7 +2075,7 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
            self.free_grad_in_param_list(self.params_in_partition[i])

            self.averaged_gradients[i] = None

            self.all_grad_tensors[i] = None
            self.unscale_and_clip_grads([single_grad_partition], scaled_global_grad_norm)

        self.timers(OPTIMIZER_GRADIENTS_TIMER).stop()
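To make the intent of the `get_flat_partition` change above concrete, here is a small standalone sketch (an illustration only, not part of the diff) of the `muon_update` helper that the ZeRO stage 1/2 optimizer now calls on each still-unflattened gradient: it needs the original 2D shape, and it updates the momentum buffer in place, which is why the buffer is persisted in the optimizer state and written back after every call.

```python
import torch
from deepspeed.runtime.zero.muon.original_muon import muon_update

torch.manual_seed(0)
grad = torch.randn(64, 32)         # per-parameter gradient, still in its original 2D shape
momentum = torch.zeros_like(grad)  # persistent buffer, analogous to the "momentum_buffer" state above

update = muon_update(grad, momentum, beta=0.95, ns_steps=5)
print(update.shape)                     # torch.Size([64, 32]): same shape as the parameter
print(bool(momentum.abs().sum() > 0))   # True: the buffer was updated in place and must be persisted

# A flattened 1D gradient cannot take this path: zeropower_via_newtonschulz5 asserts ndim >= 2,
# which is exactly why this PR applies the Muon update before the gradients are flattened.
try:
    muon_update(torch.randn(64 * 32), torch.zeros(64 * 32), beta=0.95, ns_steps=5)
except AssertionError:
    print("1D gradients are rejected")
```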
@@ -45,6 +45,13 @@ ZERO_SUPPORTED_OPTIMIZERS = [
    DeepSpeedCPULion, FusedLion
]

# Add MuonWithAuxAdam to supported list if muon is installed
try:
    from deepspeed.runtime.muon_optimizer import MuonWithAuxAdam
    ZERO_SUPPORTED_OPTIMIZERS.append(MuonWithAuxAdam)
except ImportError:
    pass

# Add apex FusedAdam to supported list if apex is installed
try:
    import apex

tests/unit/ops/muon/test_muon.py (new file, 84 lines)
@@ -0,0 +1,84 @@
# Copyright (c) 2025 Peng Du and Zhipeng Wang
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import deepspeed
import torch
import pytest

from unit.common import DistributedTest
from unit.simple_model import SimpleModel
from deepspeed.accelerator import get_accelerator
if torch.half not in get_accelerator().supported_dtypes():
    pytest.skip(f"fp16 not supported, valid dtype: {get_accelerator().supported_dtypes()}", allow_module_level=True)

# 'optimizer_type, zero_stage, lr, hidden_dim, nlayer'

muon_configs = []
for optimizer_name in ['muon', 'adam']:
    for stage in [1, 2]:
        for lr in [0.01, 0.05]:
            for model_dim in [32, 128]:
                for nlayer in [5, 10]:
                    muon_configs.append([optimizer_name, stage, lr, model_dim, nlayer])


def set_muon_flag(params):
    for p in params:
        if p.ndim >= 2:
            setattr(p, "use_muon", True)
        else:
            setattr(p, "use_muon", False)


@pytest.mark.parametrize('optimizer_type, zero_stage, lr, hidden_dim, nlayer', muon_configs)
class TestMuonConfigs(DistributedTest):

    def test(self, optimizer_type, zero_stage, lr, hidden_dim, nlayer):
        optimizer_params = {"lr": lr}
        batch_size = 8
        config_dict = {
            "train_batch_size": batch_size,
            "optimizer": {
                "type": optimizer_type,
                "params": optimizer_params
            },
            "gradient_clipping": 1.0,
            "fp16": {
                "enabled": True
            },
            "zero_optimization": {
                "stage": zero_stage,
            }
        }
        # Perform a few training steps to ensure the optimizer works correctly

        model = SimpleModel(hidden_dim=hidden_dim, nlayers=nlayer)
        if 'muon' in optimizer_type:
            set_muon_flag(model.parameters())
        initial_params = [p.clone().cpu() for p in model.parameters()]
        engine, optimizer, _, _ = deepspeed.initialize(
            config=config_dict,
            model=model,
            model_parameters=model.parameters(),
            dist_init_required=False,
        )
        assert optimizer_type in optimizer.optimizer.__class__.__name__.lower(
        ), f"Expected optimizer type {optimizer_type}, got {optimizer.optimizer.__class__.__name__}"
        steps = 5
        for _ in range(steps):
            # Random inputs: (batch_size, hidden_dim)
            x = torch.randn(batch_size, hidden_dim, device=engine.device, dtype=torch.half)
            # Random class labels: (batch_size,)
            y = torch.randint(0, hidden_dim, (batch_size, ), device=engine.device)
            # Forward + loss
            loss = engine(x, y)
            # Backward
            engine.backward(loss)
            engine.step()

        # Verify that parameters have been updated
        after_training = [p.clone().cpu() for p in model.parameters()]
        for initial, final in zip(initial_params, after_training):
            assert not torch.equal(initial.cpu(), final.cpu()), "Parameters should have been updated during training"