DeepSpeed/deepspeed/runtime/config.py
Zhipeng Wang 66ad278048 Enabling Muon Optimizer in DeepSpeed (#7509)
Authorship: @pengdurice and @PKUWZP 

Related Issue: #7438

# Introduction

[Muon](https://arxiv.org/abs/2502.16982), a new optimizer that has
recently attracted the community's attention, shows promising results in
training large language models. Adding the Muon Optimizer to DeepSpeed,
a popular OSS framework for large-scale training and inference, is
critically important for DeepSpeed users and developers. There has been
a [PR](https://github.com/deepspeedai/DeepSpeed/pull/7454) attempting
the adoption (huge thanks to @qimcis), which is a good starting point.
It still requires substantial effort to make Muon fully compatible with
DeepSpeed. We are publishing this PR to fully enable Muon Optimizer
capabilities for DeepSpeed.

# Issues and solutions
## Issues
1. With ZeRO stage 1, 2 or 3, the optimizer states are partitioned within
the same data parallel group. Each process therefore already handles
only part of the model parameters, and there is no need for the DP
solution used in the
[code](https://github.com/KellerJordan/Muon/blob/master/muon.py#L195).
2. The parameters (and the gradients) are flattened into a 1D vector
before being used in the optimizer, which nullifies the core premise of
the Muon optimizer: it works by orthogonalizing the updates for each
matrix (dim >= 2).
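
To illustrate issue 2, here is a minimal pure-Python sketch of
orthogonalizing an update with a Newton-Schulz iteration. It uses the
classic cubic iteration rather than the tuned variant used by Muon, and
the helper names (`matmul`, `transpose`, `newton_schulz`) are
hypothetical, not DeepSpeed APIs. It suggests why shape matters: on a
2D matrix the iteration drives all singular values toward 1, while on a
flattened gradient (a single row) it can only normalize the vector.

```python
def matmul(a, b):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(a):
    return [list(col) for col in zip(*a)]


def newton_schulz(g, steps=30):
    """Approximate the orthogonal factor of g with X <- 1.5*X - 0.5*X*X^T*X."""
    # Scale by the Frobenius norm so every singular value is at most 1,
    # which keeps the cubic iteration inside its convergence region.
    norm = sum(v * v for row in g for v in row) ** 0.5
    x = [[v / norm for v in row] for row in g]
    for _ in range(steps):
        xxtx = matmul(matmul(x, transpose(x)), x)
        x = [[1.5 * a - 0.5 * b for a, b in zip(r1, r2)] for r1, r2 in zip(x, xxtx)]
    return x


# A 2x2 gradient in its original shape: the result is close to orthogonal.
ortho = newton_schulz([[3.0, 1.0], [0.0, 2.0]])
# The same values flattened to one row: the result is just a unit vector.
flat = newton_schulz([[3.0, 1.0, 0.0, 2.0]])[0]
```

This is why the Muon update has to run while each parameter still has
its original shape.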

## Solutions
To solve the issues, we propose this new PR in which:
1. We simplify the Muon code by
[removing](https://github.com/deepspeedai/DeepSpeed/compare/master...pengdurice:DeepSpeed:peng-add-muon-v1#diff-c9052994e41caee9ca88363749c10af08655f8019f08dc971c018663d25a3712R22)
the partitioning and Muon update logic.
2. We
[move](https://github.com/deepspeedai/DeepSpeed/compare/master...pengdurice:DeepSpeed:peng-add-muon-v1#diff-99dcf26ea2876ff5bbf05b5165c4133eaa0d0f36b170685643c2f7e2eb566addR1867)
the Muon update to the
[get_flat_partition](https://github.com/deepspeedai/DeepSpeed/compare/master...pengdurice:DeepSpeed:peng-add-muon-v1#diff-99dcf26ea2876ff5bbf05b5165c4133eaa0d0f36b170685643c2f7e2eb566addR1848)
function of the stage 1 and 2 DeepSpeedZeroOptimizer, in which
per-parameter gradients are collected before being flattened and used by
the optimizer to update the model parameters. Since each parameter is
still in its original shape, we can easily apply the Muon updates.
3. We also save the momentum buffer into the optimizer's state so that
convergence continues smoothly after resuming from a saved checkpoint.
4. We added comprehensive unit tests to validate the Muon Optimizer's
correctness and functionality.
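
The idea behind solutions 2 and 3 can be sketched roughly as follows.
This is a simplified pure-Python illustration, not the code in this PR:
`flatten_with_matrix_update`, its arguments, and the
exponential-moving-average momentum form are all assumptions made for
the example, and `orthogonalize` stands in for the Newton-Schulz step.

```python
def flatten_with_matrix_update(grads, state, beta=0.95, orthogonalize=None):
    """Update matrix-shaped gradients per parameter, then flatten everything.

    grads: dict mapping a parameter name to a list (1D) or list of rows (2D).
    state: dict holding one momentum buffer per matrix parameter; keeping it
           in the optimizer state is what lets a checkpoint restore it.
    """
    flat = []
    for name, grad in grads.items():
        if isinstance(grad[0], list):  # dim >= 2: still in its original shape
            buf = state.setdefault(name, [[0.0] * len(row) for row in grad])
            for buf_row, grad_row in zip(buf, grad):
                for j, g in enumerate(grad_row):
                    buf_row[j] = beta * buf_row[j] + (1.0 - beta) * g
            # Hypothetical hook: a real implementation would orthogonalize here.
            update = orthogonalize(buf) if orthogonalize else [row[:] for row in buf]
            flat.extend(v for row in update for v in row)
        else:  # 1D parameters (e.g. biases) pass through unchanged
            flat.extend(grad)
    return flat


# Example: one matrix parameter and one bias, flattened after the update.
state = {}
grads = {"w": [[1.0, 2.0], [3.0, 4.0]], "b": [0.5, -0.5]}
flat_partition = flatten_with_matrix_update(grads, state, beta=0.5)
```

Because the matrix update happens before flattening, the downstream
optimizer only ever sees a 1D buffer, as it already expects.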

# Future directions and roadmap
In the future, several follow-up items are of interest:
- [ ] Create a CPU offload version.
- [ ] Apply Muon to Stage 3.
- [ ] Use the highly optimized version of Adam for the Adam part of the
MuonWithAuxAdam optimizer.
- [ ] More efficient implementations, e.g. a) add specialized kernels for
the Newton-Schulz iteration and Muon updates; b) parallelize updates
across parameters (currently, each parameter is updated separately and
sequentially).

---------

Co-authored-by: Peng Du <pedu@linkedin.com>
Co-authored-by: pengdurice <pengduhit@gmail.com>
Co-authored-by: Zhipeng Wang <zhipengbayern@gmail.com>
Co-authored-by: Olatunji Ruwase <tunji.ruwase@snowflake.com>
2025-08-26 18:34:35 -07:00

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import os
from typing import Union
from enum import Enum
import torch
import json
import hjson
import copy
import base64
from .constants import *
from .config_utils import (
    get_scalar_param,
    dict_raise_error_on_duplicate_keys,
    ScientificNotationEncoder,
)
from .zero.config import get_zero_config, ZeroStageEnum
from .activation_checkpointing.config import DeepSpeedActivationCheckpointingConfig
from ..comm.config import DeepSpeedCommsConfig
from ..monitor.config import get_monitor_config
from ..inference.config import WeightQuantConfig
from .precision_config import get_bfloat16_config, get_float16_config
from ..compile.config import CompileConfig
from deepspeed import comm as dist
from deepspeed.runtime.config_utils import DeepSpeedConfigModel
from ..git_version_info import version as __version__
from ..utils import logger
from ..elasticity import (
    elasticity_enabled,
    compute_elastic_config,
    ensure_immutable_elastic_config,
)
from ..elasticity.config import ElasticityConfigError
from ..elasticity.constants import (
    ELASTICITY,
    IGNORE_NON_ELASTIC_BATCH_INFO,
    IGNORE_NON_ELASTIC_BATCH_INFO_DEFAULT,
    MODEL_PARALLEL_SIZE,
    MODEL_PARALLEL_SIZE_DEFAULT,
    NUM_GPUS_PER_NODE,
    NUM_GPUS_PER_NODE_DEFAULT,
)
from ..profiling.config import DeepSpeedFlopsProfilerConfig
from ..autotuning.config import DeepSpeedAutotuningConfig
from ..nebula.config import DeepSpeedNebulaConfig
from ..compression.config import get_compression_config, get_quantize_enabled
from ..compression.constants import *
from .swap_tensor.aio_config import get_aio_config
from .model_checkpointing.config import get_checkpoint_config
from .tensor_parallel import get_tensor_parallel_config
from .data_pipeline.config import get_data_efficiency_enabled, get_data_efficiency_config, get_curriculum_enabled_legacy, get_curriculum_params_legacy
from .data_pipeline.constants import *
from ..utils.config import get_timers_config

TENSOR_CORE_ALIGN_SIZE = 8

ADAGRAD_OPTIMIZER = 'adagrad'
ADAM_OPTIMIZER = 'adam'
ADAMW_OPTIMIZER = 'adamw'
LAMB_OPTIMIZER = 'lamb'
ONEBIT_ADAM_OPTIMIZER = 'onebitadam'
ZERO_ONE_ADAM_OPTIMIZER = 'zerooneadam'
ONEBIT_LAMB_OPTIMIZER = 'onebitlamb'
MUADAM_OPTIMIZER = 'muadam'
MUADAMW_OPTIMIZER = 'muadamw'
MUSGD_OPTIMIZER = 'musgd'
LION_OPTIMIZER = 'lion'
MUON_OPTIMIZER = 'muon'
DEEPSPEED_OPTIMIZERS = [
    ADAGRAD_OPTIMIZER, ADAM_OPTIMIZER, ADAMW_OPTIMIZER, LAMB_OPTIMIZER, ONEBIT_ADAM_OPTIMIZER, ONEBIT_LAMB_OPTIMIZER,
    ZERO_ONE_ADAM_OPTIMIZER, MUADAM_OPTIMIZER, MUADAMW_OPTIMIZER, MUSGD_OPTIMIZER, LION_OPTIMIZER, MUON_OPTIMIZER
]

# extra optimizer parameters for adam/adamw
TORCH_ADAM_PARAM = "torch_adam"

# default to adamw logic for adam/adamw optimizers unless user explicitly opts out
ADAM_W_MODE = "adam_w_mode"
ADAM_W_MODE_DEFAULT = True


class DeepSpeedConfigError(Exception):
    pass


class DtypeEnum(Enum):
    # The torch dtype must always be the first value (so we return torch.dtype)
    fp16 = torch.float16, "torch.float16", "fp16", "float16", "half"
    fp32 = torch.float32, "torch.float32", "fp32", "float32", "float"
    int8 = torch.int8, "torch.int8", "int8"
    bf16 = torch.bfloat16, "torch.bfloat16", "bf16", "bfloat16"

    # Copied from https://stackoverflow.com/a/43210118
    # Allows us to use multiple values for each Enum index and returns first
    # listed value when Enum is called
    def __new__(cls, *values):
        obj = object.__new__(cls)
        # first value is canonical value
        obj._value_ = values[0]
        for other_value in values[1:]:
            cls._value2member_map_[other_value] = obj
        obj._all_values = values
        return obj

    def __repr__(self):
        return "<%s.%s: %s>" % (
            self.__class__.__name__,
            self._name_,
            ", ".join([repr(v) for v in self._all_values]),
        )


def get_pld_enabled(param_dict):
    if PROGRESSIVE_LAYER_DROP in param_dict.keys():
        return get_scalar_param(param_dict[PROGRESSIVE_LAYER_DROP], PLD_ENABLED, PLD_ENABLED_DEFAULT)
    else:
        return False


def get_pld_params(param_dict):
    if PROGRESSIVE_LAYER_DROP in param_dict.keys():
        pld_params = copy.copy(param_dict[PROGRESSIVE_LAYER_DROP])
        pld_params.pop(PLD_ENABLED)
        return pld_params
    else:
        return False


def get_amp_enabled(param_dict):
    if AMP in param_dict.keys():
        return get_scalar_param(param_dict[AMP], AMP_ENABLED, AMP_ENABLED_DEFAULT)
    else:
        return False


def get_amp_params(param_dict):
    if AMP in param_dict.keys():
        amp_params = copy.copy(param_dict[AMP])
        amp_params.pop(AMP_ENABLED)
        return amp_params
    else:
        return False


def get_torch_autocast_enabled(param_dict):
    if TORCH_AUTOCAST in param_dict.keys():
        return get_scalar_param(param_dict[TORCH_AUTOCAST], TORCH_AUTOCAST_ENABLED, TORCH_AUTOCAST_ENABLED_DEFAULT)
    else:
        return False


def get_torch_autocast_dtype(param_dict):
    if TORCH_AUTOCAST in param_dict:
        if TORCH_AUTOCAST_DTYPE in param_dict[TORCH_AUTOCAST]:
            try:
                return DtypeEnum(param_dict[TORCH_AUTOCAST][TORCH_AUTOCAST_DTYPE]).value
            # Enum lookup raises ValueError (not KeyError) for an unknown value
            except ValueError:
                raise ValueError(
                    f"Invalid dtype for torch autocast: {param_dict[TORCH_AUTOCAST][TORCH_AUTOCAST_DTYPE]}")
    return None


def get_lower_precision_safe_modules(param_dict):
    if TORCH_AUTOCAST in param_dict:
        if TORCH_AUTOCAST_LOWER_PRECISION_SAFE_MODULES in param_dict[TORCH_AUTOCAST]:
            module_names_with_package = param_dict[TORCH_AUTOCAST][TORCH_AUTOCAST_LOWER_PRECISION_SAFE_MODULES]
            if not all(isinstance(module_name, str) for module_name in module_names_with_package):
                raise ValueError(
                    f"Invalid module names for torch autocast: {module_names_with_package}. Expected list of strings.")
            return module_names_with_package
    return None


def get_gradient_accumulation_steps(param_dict):
    return get_scalar_param(param_dict, GRADIENT_ACCUMULATION_STEPS, GRADIENT_ACCUMULATION_STEPS_DEFAULT)


def get_sparse_gradients_enabled(param_dict):
    return get_scalar_param(param_dict, SPARSE_GRADIENTS, SPARSE_GRADIENTS_DEFAULT)


def get_communication_data_type(param_dict,
                                comm_type=COMMUNICATION_DATA_TYPE,
                                comm_data_type_default=COMMUNICATION_DATA_TYPE_DEFAULT):
    val = get_scalar_param(param_dict, comm_type, comm_data_type_default)
    val = val.lower() if val is not None else val
    if val is None:
        return val  # we must determine it by other parameters
    elif val == "fp32":
        return torch.float32
    elif val == "fp16":
        return torch.float16
    elif val == "bf16":
        return torch.bfloat16

    raise ValueError(f"Invalid communication_data_type. Supported data types: ['fp16', 'bf16', 'fp32']. Got: {val}")


def get_prescale_gradients(param_dict):
    return get_scalar_param(param_dict, PRESCALE_GRADIENTS, PRESCALE_GRADIENTS_DEFAULT)


def get_gradient_predivide_factor(param_dict):
    return get_scalar_param(param_dict, GRADIENT_PREDIVIDE_FACTOR, GRADIENT_PREDIVIDE_FACTOR_DEFAULT)


def get_steps_per_print(param_dict):
    return get_scalar_param(param_dict, STEPS_PER_PRINT, STEPS_PER_PRINT_DEFAULT)


def get_disable_allgather(param_dict):
    return get_scalar_param(param_dict, DISABLE_ALLGATHER, DISABLE_ALLGATHER_DEFAULT)


def get_dump_state(param_dict):
    return get_scalar_param(param_dict, DUMP_STATE, DUMP_STATE_DEFAULT)


def get_gradient_clipping(param_dict):
    return get_scalar_param(param_dict, GRADIENT_CLIPPING, GRADIENT_CLIPPING_DEFAULT)


def get_graph_harvesting(param_dict):
    return get_scalar_param(param_dict, GRAPH_HARVESTING, GRAPH_HARVESTING_DEFAULT)


def get_sparse_attention(param_dict):
    if SPARSE_ATTENTION in param_dict.keys():
        sparsity = param_dict[SPARSE_ATTENTION]
        mode = get_sparse_attention_mode(sparsity)

        if mode == SPARSE_DENSE_MODE:
            return get_sparse_dense_config(sparsity)
        elif mode == SPARSE_FIXED_MODE:
            return get_sparse_fixed_config(sparsity)
        elif mode == SPARSE_VARIABLE_MODE:
            return get_sparse_variable_config(sparsity)
        elif mode == SPARSE_BIGBIRD_MODE:
            return get_sparse_bigbird_config(sparsity)
        elif mode == SPARSE_BSLONGFORMER_MODE:
            return get_sparse_bslongformer_config(sparsity)
        else:
            raise NotImplementedError(f"Given sparsity mode, {mode}, has not been implemented yet!")

    else:
        return None


def get_sparse_dense_config(sparsity):
    block = get_scalar_param(sparsity, SPARSE_BLOCK, SPARSE_BLOCK_DEFAULT)
    return {SPARSE_MODE: SPARSE_DENSE_MODE, SPARSE_BLOCK: block}


def get_sparse_fixed_config(sparsity):
    block = get_scalar_param(sparsity, SPARSE_BLOCK, SPARSE_BLOCK_DEFAULT)
    different_layout_per_head = get_scalar_param(
        sparsity,
        SPARSE_DIFFERENT_LAYOUT_PER_HEAD,
        SPARSE_DIFFERENT_LAYOUT_PER_HEAD_DEFAULT,
    )
    num_local_blocks = get_scalar_param(sparsity, SPARSE_NUM_LOCAL_BLOCKS, SPARSE_NUM_LOCAL_BLOCKS_DEFAULT)
    num_global_blocks = get_scalar_param(sparsity, SPARSE_NUM_GLOBAL_BLOCKS, SPARSE_NUM_GLOBAL_BLOCKS_DEFAULT)
    attention = get_scalar_param(sparsity, SPARSE_ATTENTION_TYPE, SPARSE_ATTENTION_TYPE_DEFAULT)
    horizontal_global_attention = get_scalar_param(
        sparsity,
        SPARSE_HORIZONTAL_GLOBAL_ATTENTION,
        SPARSE_HORIZONTAL_GLOBAL_ATTENTION_DEFAULT,
    )
    num_different_global_patterns = get_scalar_param(
        sparsity,
        SPARSE_NUM_DIFFERENT_GLOBAL_PATTERNS,
        SPARSE_NUM_DIFFERENT_GLOBAL_PATTERNS_DEFAULT,
    )

    return {
        SPARSE_MODE: SPARSE_FIXED_MODE,
        SPARSE_BLOCK: block,
        SPARSE_DIFFERENT_LAYOUT_PER_HEAD: different_layout_per_head,
        SPARSE_NUM_LOCAL_BLOCKS: num_local_blocks,
        SPARSE_NUM_GLOBAL_BLOCKS: num_global_blocks,
        SPARSE_ATTENTION_TYPE: attention,
        SPARSE_HORIZONTAL_GLOBAL_ATTENTION: horizontal_global_attention,
        SPARSE_NUM_DIFFERENT_GLOBAL_PATTERNS: num_different_global_patterns,
    }


def get_sparse_variable_config(sparsity):
    block = get_scalar_param(sparsity, SPARSE_BLOCK, SPARSE_BLOCK_DEFAULT)
    different_layout_per_head = get_scalar_param(
        sparsity,
        SPARSE_DIFFERENT_LAYOUT_PER_HEAD,
        SPARSE_DIFFERENT_LAYOUT_PER_HEAD_DEFAULT,
    )
    num_random_blocks = get_scalar_param(sparsity, SPARSE_NUM_RANDOM_BLOCKS, SPARSE_NUM_RANDOM_BLOCKS_DEFAULT)
    local_window_blocks = get_scalar_param(sparsity, SPARSE_LOCAL_WINDOW_BLOCKS, SPARSE_LOCAL_WINDOW_BLOCKS_DEFAULT)
    global_block_indices = get_scalar_param(sparsity, SPARSE_GLOBAL_BLOCK_INDICES, SPARSE_GLOBAL_BLOCK_INDICES_DEFAULT)
    global_block_end_indices = get_scalar_param(
        sparsity,
        SPARSE_GLOBAL_BLOCK_END_INDICES,
        SPARSE_GLOBAL_BLOCK_END_INDICES_DEFAULT,
    )
    attention = get_scalar_param(sparsity, SPARSE_ATTENTION_TYPE, SPARSE_ATTENTION_TYPE_DEFAULT)
    horizontal_global_attention = get_scalar_param(
        sparsity,
        SPARSE_HORIZONTAL_GLOBAL_ATTENTION,
        SPARSE_HORIZONTAL_GLOBAL_ATTENTION_DEFAULT,
    )

    return {
        SPARSE_MODE: SPARSE_VARIABLE_MODE,
        SPARSE_BLOCK: block,
        SPARSE_DIFFERENT_LAYOUT_PER_HEAD: different_layout_per_head,
        SPARSE_NUM_RANDOM_BLOCKS: num_random_blocks,
        SPARSE_LOCAL_WINDOW_BLOCKS: local_window_blocks,
        SPARSE_GLOBAL_BLOCK_INDICES: global_block_indices,
        SPARSE_GLOBAL_BLOCK_END_INDICES: global_block_end_indices,
        SPARSE_ATTENTION_TYPE: attention,
        SPARSE_HORIZONTAL_GLOBAL_ATTENTION: horizontal_global_attention,
    }


def get_sparse_bigbird_config(sparsity):
    block = get_scalar_param(sparsity, SPARSE_BLOCK, SPARSE_BLOCK_DEFAULT)
    different_layout_per_head = get_scalar_param(
        sparsity,
        SPARSE_DIFFERENT_LAYOUT_PER_HEAD,
        SPARSE_DIFFERENT_LAYOUT_PER_HEAD_DEFAULT,
    )
    num_random_blocks = get_scalar_param(sparsity, SPARSE_NUM_RANDOM_BLOCKS, SPARSE_NUM_RANDOM_BLOCKS_DEFAULT)
    num_sliding_window_blocks = get_scalar_param(
        sparsity,
        SPARSE_NUM_SLIDING_WINDOW_BLOCKS,
        SPARSE_NUM_SLIDING_WINDOW_BLOCKS_DEFAULT,
    )
    num_global_blocks = get_scalar_param(sparsity, SPARSE_NUM_GLOBAL_BLOCKS, SPARSE_NUM_GLOBAL_BLOCKS_DEFAULT)

    return {
        SPARSE_MODE: SPARSE_BIGBIRD_MODE,
        SPARSE_BLOCK: block,
        SPARSE_DIFFERENT_LAYOUT_PER_HEAD: different_layout_per_head,
        SPARSE_NUM_RANDOM_BLOCKS: num_random_blocks,
        SPARSE_NUM_SLIDING_WINDOW_BLOCKS: num_sliding_window_blocks,
        SPARSE_NUM_GLOBAL_BLOCKS: num_global_blocks,
    }


def get_sparse_bslongformer_config(sparsity):
    block = get_scalar_param(sparsity, SPARSE_BLOCK, SPARSE_BLOCK_DEFAULT)
    different_layout_per_head = get_scalar_param(
        sparsity,
        SPARSE_DIFFERENT_LAYOUT_PER_HEAD,
        SPARSE_DIFFERENT_LAYOUT_PER_HEAD_DEFAULT,
    )
    num_sliding_window_blocks = get_scalar_param(
        sparsity,
        SPARSE_NUM_SLIDING_WINDOW_BLOCKS,
        SPARSE_NUM_SLIDING_WINDOW_BLOCKS_DEFAULT,
    )
    global_block_indices = get_scalar_param(sparsity, SPARSE_GLOBAL_BLOCK_INDICES, SPARSE_GLOBAL_BLOCK_INDICES_DEFAULT)
    global_block_end_indices = get_scalar_param(
        sparsity,
        SPARSE_GLOBAL_BLOCK_END_INDICES,
        SPARSE_GLOBAL_BLOCK_END_INDICES_DEFAULT,
    )

    return {
        SPARSE_MODE: SPARSE_BSLONGFORMER_MODE,
        SPARSE_BLOCK: block,
        SPARSE_DIFFERENT_LAYOUT_PER_HEAD: different_layout_per_head,
        SPARSE_NUM_SLIDING_WINDOW_BLOCKS: num_sliding_window_blocks,
        SPARSE_GLOBAL_BLOCK_INDICES: global_block_indices,
        SPARSE_GLOBAL_BLOCK_END_INDICES: global_block_end_indices,
    }


def get_sparse_attention_mode(param_dict):
    if SPARSE_MODE in param_dict.keys():
        return param_dict[SPARSE_MODE]
    else:
        return SPARSE_MODE_DEFAULT


def get_sparse_attention_type(param_dict):
    if SPARSE_ATTENTION_TYPE in param_dict.keys():
        return param_dict[SPARSE_ATTENTION_TYPE]
    else:
        return SPARSE_ATTENTION_TYPE_DEFAULT


def get_pipeline_config(param_dict):
    """Parses pipeline engine configuration. """
    default_pipeline = {
        "stages": "auto",
        "partition": "best",
        "seed_layers": False,
        "activation_checkpoint_interval": 0,
        "pipe_partitioned": True,
        "grad_partitioned": True,
    }
    config = default_pipeline
    for key, val in param_dict.get("pipeline", {}).items():
        config[key] = val
    return config


def get_optimizer_name(param_dict):
    if OPTIMIZER in param_dict.keys() and TYPE in param_dict[OPTIMIZER].keys():
        return param_dict[OPTIMIZER][TYPE]
    else:
        return OPTIMIZER_TYPE_DEFAULT


def get_optimizer_params(param_dict):
    if (get_optimizer_name(param_dict) is not None and OPTIMIZER_PARAMS in param_dict[OPTIMIZER].keys()):
        return param_dict[OPTIMIZER][OPTIMIZER_PARAMS]
    else:
        return None


def get_optimizer_gradient_clipping(param_dict):
    optimizer_params = get_optimizer_params(param_dict)
    if optimizer_params is not None and MAX_GRAD_NORM in optimizer_params.keys():
        return optimizer_params[MAX_GRAD_NORM]
    else:
        return None


def get_optimizer_legacy_fusion(param_dict):
    if OPTIMIZER in param_dict.keys() and LEGACY_FUSION in param_dict[OPTIMIZER].keys():
        return param_dict[OPTIMIZER][LEGACY_FUSION]
    else:
        return LEGACY_FUSION_DEFAULT


def get_zero_allow_untested_optimizer(param_dict):
    return get_scalar_param(param_dict, ZERO_ALLOW_UNTESTED_OPTIMIZER, ZERO_ALLOW_UNTESTED_OPTIMIZER_DEFAULT)


def get_zero_force_ds_cpu_optimizer(param_dict):
    return get_scalar_param(param_dict, ZERO_FORCE_DS_CPU_OPTIMIZER, ZERO_FORCE_DS_CPU_OPTIMIZER_DEFAULT)


def get_scheduler_name(param_dict):
    if SCHEDULER in param_dict.keys() and TYPE in param_dict[SCHEDULER].keys():
        return param_dict[SCHEDULER][TYPE]
    else:
        return SCHEDULER_TYPE_DEFAULT


def get_scheduler_params(param_dict):
    if (get_scheduler_name(param_dict) is not None and SCHEDULER_PARAMS in param_dict[SCHEDULER].keys()):
        return param_dict[SCHEDULER][SCHEDULER_PARAMS]
    else:
        return None


def get_train_batch_size(param_dict):
    return get_scalar_param(param_dict, TRAIN_BATCH_SIZE, TRAIN_BATCH_SIZE_DEFAULT)


def get_train_micro_batch_size_per_gpu(param_dict):
    return get_scalar_param(
        param_dict,
        TRAIN_MICRO_BATCH_SIZE_PER_GPU,
        TRAIN_MICRO_BATCH_SIZE_PER_GPU_DEFAULT,
    )


def get_wall_clock_breakdown(param_dict):
    return get_scalar_param(param_dict, WALL_CLOCK_BREAKDOWN, WALL_CLOCK_BREAKDOWN_DEFAULT)


def get_memory_breakdown(param_dict):
    return get_scalar_param(param_dict, MEMORY_BREAKDOWN, MEMORY_BREAKDOWN_DEFAULT)


class HybridEngineConfig(DeepSpeedConfigModel):
    enabled: bool = False
    max_out_tokens: int = 512
    inference_tp_size: int = 1
    release_inference_cache: bool = False
    pin_parameters: bool = True
    tp_gather_partition_size: int = 8


def get_hybrid_engine_config(param_dict):
    hybrid_engine_config_dict = param_dict.get("hybrid_engine", {})
    hybrid_engine_config = HybridEngineConfig(**hybrid_engine_config_dict)
    return hybrid_engine_config


def get_expert_data_topo_config(param_dict):
    return get_scalar_param(param_dict, USE_DATA_BEFORE_EXPERT_PARALLEL, USE_DATA_BEFORE_EXPERT_PARALLEL_DEFAULT)


def get_eigenvalue_config(param_dict):
    if get_quantize_enabled(param_dict):
        param_dict = param_dict[QUANTIZE_TRAINING]
        assert not get_eigenvalue_enabled(param_dict), "Eigenvalue based MoQ is temporarily disabled"
        return (
            get_eigenvalue_enabled(param_dict),
            get_eigenvalue_verbose(param_dict),
            get_eigenvalue_max_iter(param_dict),
            get_eigenvalue_tol(param_dict),
            get_eigenvalue_stability(param_dict),
            get_eigenvalue_gas_boundary_resolution(param_dict),
            get_eigenvalue_layer_name(param_dict),
            get_eigenvalue_layer_num(param_dict),
        )
    else:
        return (
            EIGENVALUE_ENABLED_DEFAULT,
            EIGENVALUE_VERBOSE_DEFAULT,
            EIGENVALUE_MAX_ITER_DEFAULT,
            EIGENVALUE_TOL_DEFAULT,
            EIGENVALUE_STABILITY_DEFAULT,
            EIGENVALUE_GAS_BOUNDARY_RESOLUTION_DEFAULT,
            EIGENVALUE_LAYER_NAME_DEFAULT,
            EIGENVALUE_LAYER_NUM_DEFAULT,
        )


def get_eigenvalue_enabled(param_dict):
    if EIGENVALUE in param_dict.keys():
        return get_scalar_param(param_dict[EIGENVALUE], EIGENVALUE_ENABLED, EIGENVALUE_ENABLED_DEFAULT)
    else:
        return EIGENVALUE_ENABLED_DEFAULT


def get_eigenvalue_verbose(param_dict):
    if EIGENVALUE in param_dict.keys():
        return get_scalar_param(param_dict[EIGENVALUE], EIGENVALUE_VERBOSE, EIGENVALUE_VERBOSE_DEFAULT)
    else:
        return EIGENVALUE_VERBOSE_DEFAULT


def get_eigenvalue_max_iter(param_dict):
    if EIGENVALUE in param_dict.keys():
        return get_scalar_param(param_dict[EIGENVALUE], EIGENVALUE_MAX_ITER, EIGENVALUE_MAX_ITER_DEFAULT)
    else:
        return EIGENVALUE_MAX_ITER_DEFAULT


def get_eigenvalue_tol(param_dict):
    if EIGENVALUE in param_dict.keys():
        return get_scalar_param(param_dict[EIGENVALUE], EIGENVALUE_TOL, EIGENVALUE_TOL_DEFAULT)
    else:
        return EIGENVALUE_TOL_DEFAULT


def get_eigenvalue_stability(param_dict):
    if EIGENVALUE in param_dict.keys():
        return get_scalar_param(param_dict[EIGENVALUE], EIGENVALUE_STABILITY, EIGENVALUE_STABILITY_DEFAULT)
    else:
        return EIGENVALUE_STABILITY_DEFAULT


def get_eigenvalue_gas_boundary_resolution(param_dict):
    if EIGENVALUE in param_dict.keys():
        return get_scalar_param(
            param_dict[EIGENVALUE],
            EIGENVALUE_GAS_BOUNDARY_RESOLUTION,
            EIGENVALUE_GAS_BOUNDARY_RESOLUTION_DEFAULT,
        )
    else:
        return EIGENVALUE_GAS_BOUNDARY_RESOLUTION_DEFAULT


def get_eigenvalue_layer_name(param_dict):
    if EIGENVALUE in param_dict.keys():
        return get_scalar_param(param_dict[EIGENVALUE], EIGENVALUE_LAYER_NAME, EIGENVALUE_LAYER_NAME_DEFAULT)
    else:
        return EIGENVALUE_LAYER_NAME_DEFAULT


def get_eigenvalue_layer_num(param_dict):
    if EIGENVALUE in param_dict.keys():
        return get_scalar_param(param_dict[EIGENVALUE], EIGENVALUE_LAYER_NUM, EIGENVALUE_LAYER_NUM_DEFAULT)
    else:
        return EIGENVALUE_LAYER_NUM_DEFAULT


def get_checkpoint_params(param_dict):
    return param_dict.get(CHECKPOINT, {})


def get_data_types_params(param_dict):
    return param_dict.get(DATA_TYPES, {})


def get_checkpoint_tag_validation_mode(checkpoint_params):
    tag_validation_mode = checkpoint_params.get(CHECKPOINT_TAG_VALIDATION, CHECKPOINT_TAG_VALIDATION_DEFAULT)
    tag_validation_mode = tag_validation_mode.upper()
    if tag_validation_mode in CHECKPOINT_TAG_VALIDATION_MODES:
        return tag_validation_mode
    else:
        raise DeepSpeedConfigError(
            "Checkpoint config contains invalid tag_validation "
            f"value of {tag_validation_mode}, expecting one of {CHECKPOINT_TAG_VALIDATION_MODES}")


def get_checkpoint_parallel_write_pipeline(checkpoint_params):
    par_write_params = checkpoint_params.get(CHECKPOINT_PARALLEL_WRITE, {})
    par_write_pipeline = par_write_params.get(CHECKPOINT_PARALLEL_WRITE_PIPELINE_STAGE,
                                              CHECKPOINT_PARALLEL_WRITE_PIPELINE_STAGE_DEFAULT)
    if par_write_pipeline in [True, False]:
        return par_write_pipeline
    else:
        raise DeepSpeedConfigError("checkpoint::parallel_write::pipeline_stage "
                                   f"value of '{par_write_pipeline}' is invalid, expecting: true or false")


def get_dataloader_drop_last(param_dict):
    return get_scalar_param(param_dict, DATALOADER_DROP_LAST, DATALOADER_DROP_LAST_DEFAULT)


class DeepSpeedConfigWriter:
    '''Write deepspeed config files by modifying basic templates.
    Can be used for quickly changing parameters via command line parameters.'''

    def __init__(self, data=None):
        self.data = data if data is not None else {}

    def add_config(self, key, value):
        self.data[key] = value

    def load_config(self, filename):
        # Use a context manager so the file handle is closed after loading
        with open(filename, "r") as infile:
            self.data = json.load(infile, object_pairs_hook=dict_raise_error_on_duplicate_keys)

    def write_config(self, filename):
        with open(filename, "w") as outfile:
            json.dump(self.data, outfile)


class DeepSpeedConfig(object):

    def __init__(self, config: Union[str, dict], mpu=None, mesh_device=None):
        super(DeepSpeedConfig, self).__init__()
        if isinstance(config, dict):
            self._param_dict = config
        elif os.path.exists(config):
            # Use a context manager so the file handle is closed after loading
            with open(config, "r") as config_file:
                self._param_dict = hjson.load(config_file, object_pairs_hook=dict_raise_error_on_duplicate_keys)
        else:
            try:
                config_decoded = base64.urlsafe_b64decode(config).decode('utf-8')
                self._param_dict = hjson.loads(config_decoded)
            except (UnicodeDecodeError, AttributeError):
                raise ValueError(
                    f"Expected a string path to an existing deepspeed config, or a dictionary or a valid base64. Received: {config}"
                )
        try:
            self.global_rank = dist.get_rank()
            if mpu is not None:
                # Ulysses SP
                if not hasattr(mpu, "get_data_parallel_world_size"):
                    self.world_size = dist.get_world_size() / mpu.get_sequence_parallel_world_size()
                else:
                    self.world_size = mpu.get_data_parallel_world_size()
            elif mesh_device is not None:
                self.world_size = dist.get_world_size(mesh_device.get_group(mesh_dim="data_parallel"))
            else:
                # HF zero.init case where there is no mpu
                if "sequence_parallel_size" in config:
                    self.world_size = dist.get_world_size() / config["sequence_parallel_size"]
                else:
                    self.world_size = dist.get_world_size()
        except:
            self.global_rank = 0
            self.world_size = 1
        logger.info(f"Config mesh_device {mesh_device} world_size = {self.world_size}")
        # If elastic-mode enabled, update compute + update _param_dict
        self.elasticity_enabled = elasticity_enabled(self._param_dict)
        if self.elasticity_enabled:
            logger.info("DeepSpeed elasticity support enabled")
            final_batch_size, valid_gpus, micro_batch_size = compute_elastic_config(
                ds_config=self._param_dict,
                target_deepspeed_version=__version__,
                world_size=self.world_size,
            )

            elastic_dict = self._param_dict[ELASTICITY]

            # Ensure the resource scheduler saw the same elastic config we are using at runtime
            ensure_immutable_elastic_config(runtime_elastic_config_dict=elastic_dict)

            self.elastic_model_parallel_size = elastic_dict.get(MODEL_PARALLEL_SIZE, MODEL_PARALLEL_SIZE_DEFAULT)
            if self.elastic_model_parallel_size < 1:
                raise ElasticityConfigError("Model-Parallel size cannot be less than 1, "
                                            f"given model-parallel size: {self.elastic_model_parallel_size}")

            self.num_gpus_per_node = elastic_dict.get(NUM_GPUS_PER_NODE, NUM_GPUS_PER_NODE_DEFAULT)
            if self.num_gpus_per_node < 1:
                raise ElasticityConfigError("Number of GPUs per node cannot be less than 1, "
                                            f"given number of GPUs per node: {self.num_gpus_per_node}")

            ignore_non_elastic_batch_info = elastic_dict.get(IGNORE_NON_ELASTIC_BATCH_INFO,
                                                             IGNORE_NON_ELASTIC_BATCH_INFO_DEFAULT)

            if not ignore_non_elastic_batch_info:
                batch_params = [
                    TRAIN_BATCH_SIZE,
                    TRAIN_MICRO_BATCH_SIZE_PER_GPU,
                    GRADIENT_ACCUMULATION_STEPS,
                ]
                if any(map(lambda t: t in self._param_dict, batch_params)):
                    raise ElasticityConfigError("One or more batch related parameters were found in your " \
                        f"ds_config ({TRAIN_BATCH_SIZE}, {TRAIN_MICRO_BATCH_SIZE_PER_GPU}, and/or " \
                        f"{GRADIENT_ACCUMULATION_STEPS}). These parameters *will not be used* since " \
                        "elastic training is enabled, which takes control of these parameters. " \
                        "If you want to suppress this error (the parameters will be silently ignored) " \
                        f"please set '{IGNORE_NON_ELASTIC_BATCH_INFO}': true in your elasticity config.")

            # micro_bsz * world_size * gas = total_batch_size
            # gas = total_batch_size // (micro_bsz * world_size)
            gradient_accu_steps = final_batch_size // (micro_batch_size * self.world_size)

            if TRAIN_BATCH_SIZE in self._param_dict:
                logger.warning("[Elasticity] overriding train_batch_size: "
                               f"{self._param_dict[TRAIN_BATCH_SIZE]} -> {final_batch_size}")
            if TRAIN_MICRO_BATCH_SIZE_PER_GPU in self._param_dict:
                logger.warning("[Elasticity] overriding train_micro_batch_size_per_gpu: "
                               f"{self._param_dict[TRAIN_MICRO_BATCH_SIZE_PER_GPU]} -> {micro_batch_size}")
            if GRADIENT_ACCUMULATION_STEPS in self._param_dict:
                logger.warning("[Elasticity] overriding gradient_accumulation_steps: "
                               f"{self._param_dict[GRADIENT_ACCUMULATION_STEPS]} -> {gradient_accu_steps}")

            logger.info(f"[Elasticity] valid GPU counts: {valid_gpus}")

            self._param_dict[TRAIN_BATCH_SIZE] = final_batch_size
            self._param_dict[TRAIN_MICRO_BATCH_SIZE_PER_GPU] = micro_batch_size
            self._param_dict[GRADIENT_ACCUMULATION_STEPS] = gradient_accu_steps

        # Pass a copy so that user json is unmodified, e.g. for logging
        self._initialize_params(copy.copy(self._param_dict))
        self._configure_train_batch_size()
        self._do_sanity_check()

    def _initialize_params(self, param_dict):
        self.train_batch_size = get_train_batch_size(param_dict)
        self.train_micro_batch_size_per_gpu = get_train_micro_batch_size_per_gpu(param_dict)
        self.gradient_accumulation_steps = get_gradient_accumulation_steps(param_dict)
        self.steps_per_print = get_steps_per_print(param_dict)
        self.dump_state = get_dump_state(param_dict)

        self.disable_allgather = get_disable_allgather(param_dict)
        self.communication_data_type = get_communication_data_type(param_dict)
        self.seq_parallel_communication_data_type = get_communication_data_type(
            param_dict, SEQ_PARALLEL_COMMUNICATION_DATA_TYPE, SEQ_PARALLEL_COMMUNICATION_DATA_TYPE_DEFAULT)
        self.prescale_gradients = get_prescale_gradients(param_dict)
        self.gradient_predivide_factor = get_gradient_predivide_factor(param_dict)
        self.sparse_gradients_enabled = get_sparse_gradients_enabled(param_dict)

        self.zero_config = get_zero_config(param_dict)
        self.mics_shard_size = self.zero_config.mics_shard_size
        self.mics_hierarchial_params_gather = self.zero_config.mics_hierarchical_params_gather
        self.zero_optimization_stage = self.zero_config.stage
        self.zero_enabled = self.zero_optimization_stage > 0

        self.activation_checkpointing_config = DeepSpeedActivationCheckpointingConfig(param_dict)

        self.comms_config = DeepSpeedCommsConfig(param_dict)
        self.monitor_config = get_monitor_config(param_dict)

        self.gradient_clipping = get_gradient_clipping(param_dict)
        self.float16_config = get_float16_config(param_dict)
        self.bfloat16_config = get_bfloat16_config(param_dict)
        assert not (self.float16_config.enabled
                    and self.bfloat16_config.enabled), 'bfloat16 and fp16 modes cannot be simultaneously enabled'

        self.amp_enabled = get_amp_enabled(param_dict)
        self.amp_params = get_amp_params(param_dict)

        self.torch_autocast_enabled = get_torch_autocast_enabled(param_dict)
        self.torch_autocast_dtype = get_torch_autocast_dtype(param_dict)
        self.torch_autocast_lower_precision_safe_modules = get_lower_precision_safe_modules(param_dict)

        self.compression_config = get_compression_config(param_dict)
        self.graph_harvesting = get_graph_harvesting(param_dict)

        self.optimizer_name = get_optimizer_name(param_dict)
        if (self.optimizer_name is not None and self.optimizer_name.lower() in DEEPSPEED_OPTIMIZERS):
            self.optimizer_name = self.optimizer_name.lower()

        self.optimizer_params = get_optimizer_params(param_dict)
        self.optimizer_legacy_fusion = get_optimizer_legacy_fusion(param_dict)

        self.zero_allow_untested_optimizer = get_zero_allow_untested_optimizer(param_dict)

        self.zero_force_ds_cpu_optimizer = get_zero_force_ds_cpu_optimizer(param_dict)

        self.scheduler_name = get_scheduler_name(param_dict)
        self.scheduler_params = get_scheduler_params(param_dict)

        self.flops_profiler_config = DeepSpeedFlopsProfilerConfig(param_dict)
        self.wall_clock_breakdown = (get_wall_clock_breakdown(param_dict) | self.flops_profiler_config.enabled)
        self.memory_breakdown = get_memory_breakdown(param_dict)
        self.autotuning_config = DeepSpeedAutotuningConfig(param_dict)

        (
            self.eigenvalue_enabled,
            self.eigenvalue_verbose,
            self.eigenvalue_max_iter,
            self.eigenvalue_tol,
            self.eigenvalue_stability,
            self.eigenvalue_gas_boundary_resolution,
            self.eigenvalue_layer_name,
            self.eigenvalue_layer_num,
        ) = get_eigenvalue_config(param_dict)

        self.use_data_before_expert_parallel_ = get_expert_data_topo_config(param_dict)
        self.hybrid_engine = get_hybrid_engine_config(param_dict)

        self.sparse_attention = get_sparse_attention(param_dict)
        self.pipeline = get_pipeline_config(param_dict)

        self.pld_enabled = get_pld_enabled(param_dict)
        self.pld_params = get_pld_params(param_dict)

        self.curriculum_enabled_legacy = get_curriculum_enabled_legacy(param_dict)
        self.curriculum_params_legacy = get_curriculum_params_legacy(param_dict)

        self.data_efficiency_enabled = get_data_efficiency_enabled(param_dict)
        self.data_efficiency_config = get_data_efficiency_config(param_dict)

        checkpoint_params = get_checkpoint_params(param_dict)
        validation_mode = get_checkpoint_tag_validation_mode(checkpoint_params)
        self.checkpoint_tag_validation_enabled = (validation_mode != ValidationMode.IGNORE)
        self.checkpoint_tag_validation_fail = validation_mode == ValidationMode.FAIL
        self.load_universal_checkpoint = checkpoint_params.get(LOAD_UNIVERSAL_CHECKPOINT,
                                                               LOAD_UNIVERSAL_CHECKPOINT_DEFAULT)

        self.use_node_local_storage = checkpoint_params.get(USE_NODE_LOCAL_STORAGE_CHECKPOINT,
                                                            USE_NODE_LOCAL_STORAGE_CHECKPOINT_DEFAULT)

        data_types_params = get_data_types_params(param_dict)
        self.grad_accum_dtype = data_types_params.get(GRAD_ACCUM_DTYPE, GRAD_ACCUM_DTYPE_DEFAULT)

        par_write_pipe = get_checkpoint_parallel_write_pipeline(checkpoint_params)
        self.checkpoint_parallel_write_pipeline = par_write_pipe

        self.aio_config = get_aio_config(param_dict)

        self.dataloader_drop_last = get_dataloader_drop_last(param_dict)

        self.nebula_config = DeepSpeedNebulaConfig(param_dict)
        self.checkpoint_config = get_checkpoint_config(param_dict)

        self.weight_quantization_config = WeightQuantConfig(
            **param_dict['weight_quantization']) if 'weight_quantization' in param_dict else None

        self.compile_config = CompileConfig(**param_dict.get('compile', {}))

        self.timers_config = get_timers_config(param_dict)
        self.tensor_parallel_config = get_tensor_parallel_config(param_dict)


    def _batch_assertion(self):
        train_batch = self.train_batch_size
        micro_batch = self.train_micro_batch_size_per_gpu
        grad_acc = self.gradient_accumulation_steps
        assert (train_batch > 0), f"Train batch size: {train_batch} has to be greater than 0"
        assert (micro_batch > 0), f"Micro batch size per gpu: {micro_batch} has to be greater than 0"
        assert (grad_acc > 0), f"Gradient accumulation steps: {grad_acc} has to be greater than 0"
        assert train_batch == micro_batch * grad_acc * self.world_size, (
            f"Check batch related parameters. train_batch_size is not equal "
            "to micro_batch_per_gpu * gradient_acc_step * world_size "
            f"{train_batch} != {micro_batch} * {grad_acc} * {self.world_size}")

    def _set_batch_related_parameters(self):
        train_batch = self.train_batch_size
        micro_batch = self.train_micro_batch_size_per_gpu
        grad_acc = self.gradient_accumulation_steps
        #print(f"in: train_batch = {train_batch}, micro_batch={micro_batch}")
        # all values are provided; nothing needs to be set
        if train_batch is not None and micro_batch is not None and grad_acc is not None:
            return
        # gradient_accumulation_steps needs to be set
        elif train_batch is not None and micro_batch is not None:
            grad_acc = train_batch // micro_batch
            grad_acc //= self.world_size
            self.gradient_accumulation_steps = grad_acc
        # train_micro_batch_size_per_gpu needs to be set
        elif train_batch is not None and grad_acc is not None:
            micro_batch = train_batch // self.world_size
            micro_batch //= grad_acc
            self.train_micro_batch_size_per_gpu = micro_batch
        # train_batch_size needs to be set
        elif micro_batch is not None and grad_acc is not None:
            train_batch_size = micro_batch * grad_acc
            train_batch_size *= self.world_size
            self.train_batch_size = train_batch_size
        # gradient_accumulation_steps and train_micro_batch_size_per_gpu need to be set
        elif train_batch is not None:
            self.gradient_accumulation_steps = 1
            self.train_micro_batch_size_per_gpu = train_batch // self.world_size
        # train_batch_size and gradient_accumulation_steps need to be set
        elif micro_batch is not None:
            self.train_batch_size = micro_batch * self.world_size
            self.gradient_accumulation_steps = 1
        # either none of the three parameters is provided or only gradient_accumulation_steps is provided
        else:
            assert False, \
                'Either train_batch_size or train_micro_batch_size_per_gpu needs to be provided'
        #print(f"final: {self.train_batch_size=} {self.train_micro_batch_size_per_gpu=} {self.gradient_accumulation_steps=}")

    def _configure_train_batch_size(self):
        self._set_batch_related_parameters()
        self._batch_assertion()

    def _do_sanity_check(self):
        self._do_error_check()
        self._do_warning_check()

    def print_user_config(self):
        logger.info(" json = {}".format(
            json.dumps(
                self._param_dict,
                sort_keys=True,
                indent=4,
                cls=ScientificNotationEncoder,
                separators=(",", ":"),
            )))

    def print(self, name):
        logger.info("{}:".format(name))
        for arg in sorted(vars(self)):
            if arg != "_param_dict":
                dots = "." * (29 - len(arg))
                logger.info(" {} {} {}".format(arg, dots, getattr(self, arg)))
        self.print_user_config()

    def _do_error_check(self):
        assert (self.train_micro_batch_size_per_gpu
                ), "DeepSpeedConfig: {} is not defined".format(TRAIN_MICRO_BATCH_SIZE_PER_GPU)
        assert (
            self.gradient_accumulation_steps), "DeepSpeedConfig: {} is not defined".format(GRADIENT_ACCUMULATION_STEPS)
        if self.zero_enabled:
            assert (self.zero_optimization_stage
                    <= ZeroStageEnum.max_stage), "DeepSpeedConfig: Maximum supported ZeRO stage is {}".format(
                        ZeroStageEnum.max_stage)
        if self.float16_config.fp16_master_weights_and_grads:
            assert self.zero_enabled and self.zero_optimization_stage == ZeroStageEnum.gradients, "Fp16_master_weights_and_grads is only supported with ZeRO Stage 2 for now."

    def _do_warning_check(self):
        fp16_enabled = self.float16_config.enabled
        vocabulary_size = self._param_dict.get(VOCABULARY_SIZE, VOCABULARY_SIZE_DEFAULT)
        if vocabulary_size and vocabulary_size % TENSOR_CORE_ALIGN_SIZE != 0:
            logger.warning(
                "DeepSpeedConfig: vocabulary size {} is not aligned to {}, may impact tensor core utilization.".format(
                    vocabulary_size, TENSOR_CORE_ALIGN_SIZE))
        if (self.optimizer_params is not None and MAX_GRAD_NORM in self.optimizer_params.keys()
                and self.optimizer_params[MAX_GRAD_NORM] > 0):
            if fp16_enabled:
                if self.global_rank == 0:
                    logger.warning("DeepSpeedConfig: In FP16 mode, DeepSpeed will pass {}:{} to FP16 wrapper".format(
                        MAX_GRAD_NORM, self.optimizer_params[MAX_GRAD_NORM]))
            else:
                if self.global_rank == 0:
                    logger.warning(
                        "DeepSpeedConfig: In FP32 mode, DeepSpeed does not permit MAX_GRAD_NORM ({}) > 0, setting to zero"
                        .format(self.optimizer_params[MAX_GRAD_NORM]))
                self.optimizer_params[MAX_GRAD_NORM] = 0.0
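

# Illustrative sketch, not part of DeepSpeedConfig.
# The batch-size resolution in _set_batch_related_parameters() and the check in
# _batch_assertion() can be exercised in isolation. The helper below is a
# hypothetical, standalone restatement of that logic, assuming the three values
# must satisfy train_batch == micro_batch * grad_acc * world_size.
# The name `resolve_batch_params` and its keyword arguments are assumptions made
# for this example only; it raises ValueError where the class uses assert.
def resolve_batch_params(train_batch=None, micro_batch=None, grad_acc=None, world_size=1):
    """Return (train_batch, micro_batch, grad_acc) with missing values derived."""
    if train_batch is not None and micro_batch is not None and grad_acc is not None:
        pass  # everything supplied; only the consistency checks below apply
    elif train_batch is not None and micro_batch is not None:
        grad_acc = train_batch // micro_batch // world_size
    elif train_batch is not None and grad_acc is not None:
        micro_batch = train_batch // world_size // grad_acc
    elif micro_batch is not None and grad_acc is not None:
        train_batch = micro_batch * grad_acc * world_size
    elif train_batch is not None:
        grad_acc = 1
        micro_batch = train_batch // world_size
    elif micro_batch is not None:
        grad_acc = 1
        train_batch = micro_batch * world_size
    else:
        raise ValueError("Either train_batch or micro_batch needs to be provided")

    # Integer division above can truncate silently, so inconsistent inputs
    # are rejected here, mirroring _batch_assertion().
    if min(train_batch, micro_batch, grad_acc) <= 0:
        raise ValueError("All batch parameters have to be greater than 0")
    if train_batch != micro_batch * grad_acc * world_size:
        raise ValueError(f"{train_batch} != {micro_batch} * {grad_acc} * {world_size}")
    return train_batch, micro_batch, grad_acc


# Usage, assuming 4 data-parallel ranks:
#   resolve_batch_params(train_batch=32, micro_batch=4, world_size=4)  returns (32, 4, 2)
#   resolve_batch_params(micro_batch=4, world_size=4)                  returns (16, 4, 1)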