Make Muon optimizer easier to enable (#7555)

The original Muon optimizer PR
(https://github.com/deepspeedai/DeepSpeed/pull/7509) requires the user to
explicitly set the `use_muon` flag on each parameter in `model.parameters()`,
as shown in the test:
https://github.com/deepspeedai/DeepSpeed/blob/master/tests/unit/ops/muon/test_muon.py#L27
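
For reference, a minimal sketch of that manual step, mirroring the `set_muon_flag`
helper this PR removes from the test (the model and `ds_config.json` below are
illustrative):

```python
import torch
import deepspeed

def set_muon_flag(params):
    # Mark matrix-shaped (ndim >= 2) parameters for Muon; all others get use_muon=False.
    for p in params:
        setattr(p, "use_muon", p.ndim >= 2)

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 8))
set_muon_flag(model.parameters())  # explicit user step required by #7509
engine, optimizer, _, _ = deepspeed.initialize(model=model, config="ds_config.json")
```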

This PR integrates the setting of `use_muon` into DeepSpeed before engine
initialization, which makes the Muon optimizer easier to use: the user only
needs to change the optimizer in `config.json` from `AdamW` to `Muon`, with no
code changes required. It resolves the following issue:
https://github.com/deepspeedai/DeepSpeed/issues/7552
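
With this change, enabling Muon is a pure config switch. A minimal sketch of the
new flow; only the optimizer `type` matters here, and the batch size, learning
rate, and ZeRO stage are example values:

```python
import torch
import deepspeed

# Illustrative DeepSpeed config: switching the optimizer "type" from "AdamW"
# to "Muon" is the only change needed to enable the Muon optimizer.
config = {
    "train_batch_size": 8,
    "optimizer": {
        "type": "Muon",
        "params": {"lr": 1e-3}
    },
    "zero_optimization": {"stage": 1}
}

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 8))
# deepspeed.initialize now tags each parameter with use_muon automatically
# (see set_optimizer_flags in the diff below); no manual flagging is needed.
engine, optimizer, _, _ = deepspeed.initialize(model=model, config=config)
```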

---------

Signed-off-by: Ma, Guokai <guokai.ma@intel.com>
Co-authored-by: Olatunji Ruwase <tunji.ruwase@snowflake.com>
Co-authored-by: Masahiro Tanaka <81312776+tohtana@users.noreply.github.com>
Author: Ma, Guokai
Date: 2025-09-17 21:52:11 +08:00
Committed by: GitHub
Parent: aa539c6dd5
Commit: 2585881ae9
2 changed files with 12 additions and 11 deletions

deepspeed/__init__.py

@@ -28,7 +28,7 @@ from . import module_inject
 from .accelerator import get_accelerator
 from .constants import TORCH_DISTRIBUTED_DEFAULT_PORT
 from .runtime.engine import DeepSpeedEngine, DeepSpeedOptimizerCallable, DeepSpeedSchedulerCallable
-from .runtime.engine import ADAM_OPTIMIZER, LAMB_OPTIMIZER
+from .runtime.engine import ADAM_OPTIMIZER, LAMB_OPTIMIZER, MUON_OPTIMIZER
 from .runtime.hybrid_engine import DeepSpeedHybridEngine
 from .runtime.pipe.engine import PipelineEngine
 from .inference.engine import InferenceEngine
@@ -66,6 +66,15 @@ __git_branch__ = git_branch
 dist = None
+def set_optimizer_flags(config_class, model):
+    if config_class.optimizer_name == MUON_OPTIMIZER:
+        for p in model.parameters():
+            if p.ndim >= 2:
+                setattr(p, "use_muon", True)
+            else:
+                setattr(p, "use_muon", False)
 def initialize(args=None,
                model: torch.nn.Module = None,
                optimizer: Optional[Union[Optimizer, DeepSpeedOptimizerCallable]] = None,
@@ -177,6 +186,7 @@ def initialize(args=None,
     assert config is not None, "DeepSpeed requires --deepspeed_config to specify configuration file"
     if not isinstance(model, PipelineModule):
         config_class = DeepSpeedConfig(config, mpu, mesh_device=mesh_device)
+        set_optimizer_flags(config_class, model)
         if config_class.hybrid_engine.enabled:
             engine = DeepSpeedHybridEngine(args=args,
                                            model=model,
@@ -206,6 +216,7 @@ def initialize(args=None,
         assert mpu is None, "mpu must be None with pipeline parallelism"
         mpu = model.mpu()
         config_class = DeepSpeedConfig(config, mpu)
+        set_optimizer_flags(config_class, model)
         engine = PipelineEngine(args=args,
                                 model=model,
                                 optimizer=optimizer,

tests/unit/ops/muon/test_muon.py

@@ -24,14 +24,6 @@ for optimizer_name in ['muon', 'adam']:
                     muon_configs.append([optimizer_name, stage, lr, model_dim, nlayer])
-def set_muon_flag(params):
-    for p in params:
-        if p.ndim >= 2:
-            setattr(p, "use_muon", True)
-        else:
-            setattr(p, "use_muon", False)
 @pytest.mark.parametrize('optimizer_type, zero_stage, lr, hidden_dim, nlayer', muon_configs)
 class TestMuonConfigs(DistributedTest):
@@ -55,8 +47,6 @@ class TestMuonConfigs(DistributedTest):
         # Perform a few training steps to ensure the optimizer works correctly
         model = SimpleModel(hidden_dim=hidden_dim, nlayers=nlayer)
-        if 'muon' in optimizer_type:
-            set_muon_flag(model.parameters())
         initial_params = [p.clone().cpu() for p in model.parameters()]
         engine, optimizer, _, _ = deepspeed.initialize(
             config=config_dict,