mirror of
https://github.com/deepspeedai/DeepSpeed.git
synced 2025-10-20 15:33:51 +08:00
Improve leaf module interface (enable via config, relax matching criteria, add document, etc.) (#7604)
This PR improves the usability of the leaf module feature. Here are the changes: - Allow enabling the leaf module via both the DeepSpeed config and APIs. - Relax matching criteria to support class-based matching. - Support multiple ways of specifying the target module: class, class name (with or without package name), module name, or suffix. - Add documentation to the training guide, including config snippets and explanations of default behavior. - Add default classes (e.g., Mixtral, Qwen2/Qwen3) that automatically enable the leaf module feature. (Welcoming requests to add more classes) --------- Signed-off-by: Masahiro Tanaka <mtanaka@anyscale.com> Co-authored-by: Olatunji Ruwase <tunji.ruwase@snowflake.com>
This commit is contained in:
@ -76,6 +76,7 @@ from deepspeed.runtime.sparse_tensor import SparseTensor
|
||||
from deepspeed.runtime import lr_schedules
|
||||
from deepspeed.utils import groups
|
||||
from deepspeed.utils import logger, log_dist, log_dist_once, instrument_w_nvtx
|
||||
from deepspeed.utils.z3_leaf_module import apply_zero_leaf_module_config
|
||||
from deepspeed.utils.timer import NoopTimer, ThroughputTimer, SynchronizedWallClockTimer, \
|
||||
FORWARD_MICRO_TIMER, BACKWARD_MICRO_TIMER, BACKWARD_INNER_MICRO_TIMER, BACKWARD_REDUCE_MICRO_TIMER, \
|
||||
STEP_MICRO_TIMER, \
|
||||
@ -1293,6 +1294,7 @@ class DeepSpeedEngine(Module):
|
||||
|
||||
def _configure_distributed_model(self, model):
|
||||
self._set_client_model(model)
|
||||
apply_zero_leaf_module_config(self.module, getattr(self._config.zero_config, "leaf_module", None))
|
||||
is_zero_init_model = self.zero_optimization_partition_weights() and any(
|
||||
[hasattr(param, "ds_id") for param in self.module.parameters()])
|
||||
|
||||
|
@ -11,6 +11,7 @@ from deepspeed.runtime.config_utils import get_scalar_param, pp_int, DeepSpeedCo
|
||||
from deepspeed.utils import logger
|
||||
from .offload_config import DeepSpeedZeroOffloadParamConfig, DeepSpeedZeroOffloadOptimizerConfig, OffloadDeviceEnum
|
||||
from deepspeed.runtime.zenflow.zenflow_config import ZenFlowConfig
|
||||
from .leaf_module_config import DeepSpeedZeroLeafModuleConfig
|
||||
|
||||
# ZeRO optimization. By default, this optimization is not enabled.
|
||||
# Users have to configure the desired optimization (0 means disabled) in params.json as below example:
|
||||
@ -356,6 +357,11 @@ class DeepSpeedZeroConfig(DeepSpeedConfigModel):
|
||||
Enable internal sanity checks, which could be useful for debugging
|
||||
"""
|
||||
|
||||
leaf_module: DeepSpeedZeroLeafModuleConfig = Field(default_factory=DeepSpeedZeroLeafModuleConfig)
|
||||
"""
|
||||
Configuration for modules that should be treated as ZeRO3 leaf modules.
|
||||
"""
|
||||
|
||||
# Validators
|
||||
@model_validator(mode="after")
|
||||
def overlap_comm_valid(self):
|
||||
|
52
deepspeed/runtime/zero/leaf_module_config.py
Normal file
52
deepspeed/runtime/zero/leaf_module_config.py
Normal file
@ -0,0 +1,52 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# DeepSpeed Team
|
||||
|
||||
from typing import List
|
||||
from pydantic import Field, model_validator
|
||||
|
||||
from deepspeed.runtime.config_utils import DeepSpeedConfigModel
|
||||
|
||||
DEFAULT_LEAF_MODULE_CLASSES: List[str] = [
|
||||
"transformers.models.mixtral.modeling_mixtral.MixtralSparseMoeBlock",
|
||||
"transformers.models.qwen2_moe.modeling_qwen2_moe.Qwen2MoeSparseMoeBlock",
|
||||
"transformers.models.qwen3_moe.modeling_qwen3_moe.Qwen3MoeSparseMoeBlock",
|
||||
]
|
||||
DEFAULT_LEAF_MODULE_NAMES: List[str] = []
|
||||
DEFAULT_LEAF_MODULE_NAME_SUFFIXES: List[str] = []
|
||||
|
||||
|
||||
class DeepSpeedZeroLeafModuleConfig(DeepSpeedConfigModel):
|
||||
"""Configuration for ZeRO leaf modules that should bypass hook installation."""
|
||||
|
||||
classes: List[str] = Field(default_factory=lambda: list(DEFAULT_LEAF_MODULE_CLASSES))
|
||||
names: List[str] = Field(default_factory=lambda: list(DEFAULT_LEAF_MODULE_NAMES))
|
||||
name_suffixes: List[str] = Field(default_factory=lambda: list(DEFAULT_LEAF_MODULE_NAME_SUFFIXES))
|
||||
|
||||
@model_validator(mode="before")
|
||||
def _coerce_container_types(cls, values):
|
||||
if values is None:
|
||||
return {}
|
||||
if isinstance(values, dict):
|
||||
coerced = dict(values)
|
||||
for key in ("classes", "names", "name_suffixes"):
|
||||
if key in coerced and isinstance(coerced[key], str):
|
||||
coerced[key] = [coerced[key]]
|
||||
return coerced
|
||||
raise TypeError("leaf_module configuration must be a mapping of fields to values")
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _validate_entries(self):
|
||||
normalized_classes = [str(cls) for cls in self.classes]
|
||||
normalized_names = [str(name) for name in self.names]
|
||||
normalized_suffixes = [str(suffix) for suffix in self.name_suffixes]
|
||||
|
||||
deduped_classes = list(dict.fromkeys(normalized_classes))
|
||||
deduped_names = list(dict.fromkeys(normalized_names))
|
||||
deduped_suffixes = list(dict.fromkeys(normalized_suffixes))
|
||||
|
||||
object.__setattr__(self, "classes", deduped_classes)
|
||||
object.__setattr__(self, "names", deduped_names)
|
||||
object.__setattr__(self, "name_suffixes", deduped_suffixes)
|
||||
return self
|
@ -17,7 +17,7 @@ from .tensor_fragment import safe_set_full_fp32_param, safe_set_full_optimizer_s
|
||||
from .tensor_fragment import safe_get_local_fp32_param, safe_get_local_grad, safe_get_local_optimizer_state
|
||||
from .tensor_fragment import safe_set_local_fp32_param, safe_set_local_grad, safe_set_local_optimizer_state
|
||||
from .tensor_fragment import safe_update_full_grad_vectorized
|
||||
from .z3_leaf_module import set_z3_leaf_modules, unset_z3_leaf_modules, get_z3_leaf_modules, z3_leaf_module, z3_leaf_parameter, set_z3_leaf_module
|
||||
from .z3_leaf_module import set_z3_leaf_modules, unset_z3_leaf_modules, get_z3_leaf_modules, z3_leaf_module, z3_leaf_parameter, set_z3_leaf_module, set_z3_leaf_modules_by_name, set_z3_leaf_modules_by_suffix
|
||||
from .mixed_precision_linkage import link_hp_params, lazy_init_hp_params_optimizer_state
|
||||
from deepspeed.runtime.dataloader import RepeatingLoader
|
||||
from .numa import get_numactl_cmd
|
||||
|
@ -4,7 +4,12 @@
|
||||
# DeepSpeed Team
|
||||
|
||||
import torch
|
||||
from typing import List, Type, Union
|
||||
from typing import List, Tuple, Type, Union, Optional, TYPE_CHECKING
|
||||
|
||||
from .logging import logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from deepspeed.runtime.zero.leaf_module_config import DeepSpeedZeroLeafModuleConfig
|
||||
|
||||
|
||||
def z3_leaf_module(model: torch.nn.Module) -> bool:
|
||||
@ -44,50 +49,201 @@ def set_z3_leaf_module(model: torch.nn.Module, flag: bool):
|
||||
model._z3_leaf = flag
|
||||
|
||||
|
||||
def _do_set_z3_leaf_modules(model: torch.nn.Module, leaf_module_classes: Union[List[Type], List[str]],
|
||||
flag: bool) -> List[torch.nn.Module]:
|
||||
assert all(isinstance(module_class, (type, str) ) for module_class in leaf_module_classes), \
|
||||
def _fully_qualified_class_name(module: torch.nn.Module) -> str:
|
||||
cls = module.__class__
|
||||
return f"{cls.__module__}.{cls.__qualname__}"
|
||||
|
||||
|
||||
def _do_set_z3_leaf_modules(model: torch.nn.Module,
|
||||
leaf_module_classes: Union[List[Type], List[str]],
|
||||
flag: bool,
|
||||
raise_if_not_found: bool = True) -> List[torch.nn.Module]:
|
||||
assert all(isinstance(module_class, (type, str)) for module_class in leaf_module_classes), \
|
||||
f'leaf_module_classes must be a list of types or names, got {leaf_module_classes}'
|
||||
|
||||
leaf_modules = []
|
||||
leaf_modules: List[torch.nn.Module] = []
|
||||
|
||||
def _set_z3_leaf_flag(model: torch.nn.Module):
|
||||
def _set_z3_leaf_flag(module_instance: torch.nn.Module):
|
||||
nonlocal leaf_modules
|
||||
for module in leaf_module_classes:
|
||||
if (isinstance(module, type) and model.__class__ == module) or \
|
||||
(isinstance(module, str) and model.__class__.__name__ == module):
|
||||
model._z3_leaf = flag
|
||||
leaf_modules.append(model)
|
||||
if isinstance(module, type) and isinstance(module_instance, module):
|
||||
module_instance._z3_leaf = flag
|
||||
leaf_modules.append(module_instance)
|
||||
break
|
||||
|
||||
if isinstance(module, str):
|
||||
if (module_instance.__class__.__name__ == module
|
||||
or _fully_qualified_class_name(module_instance) == module):
|
||||
module_instance._z3_leaf = flag
|
||||
leaf_modules.append(module_instance)
|
||||
break
|
||||
|
||||
model.apply(_set_z3_leaf_flag)
|
||||
|
||||
if len(leaf_modules) == 0:
|
||||
if len(leaf_modules) == 0 and raise_if_not_found:
|
||||
raise ValueError(f'No modules of type {leaf_module_classes} found in model {model}')
|
||||
|
||||
return leaf_modules
|
||||
|
||||
|
||||
def set_z3_leaf_modules(model: torch.nn.Module, leaf_module_classes: Union[List[Type],
|
||||
List[str]]) -> List[torch.nn.Module]:
|
||||
def set_z3_leaf_modules_by_name(model: torch.nn.Module,
|
||||
module_names: List[str],
|
||||
flag: bool = True,
|
||||
raise_if_not_found: bool = True) -> Tuple[List[torch.nn.Module], List[str]]:
|
||||
"""Sets a leaf flag for modules referenced by their names in ``model.named_modules()``.
|
||||
Args:
|
||||
model (torch.nn.Module): The model containing the modules to update.
|
||||
module_names (List[str]): Module names as returned by ``named_modules()``.
|
||||
flag (bool): Desired flag state.
|
||||
raise_if_not_found (bool): Whether to raise when no module matches a provided name.
|
||||
Returns:
|
||||
Tuple[List[torch.nn.Module], List[str]]: Matched modules and missing module names.
|
||||
"""
|
||||
modules_by_name = dict(model.named_modules())
|
||||
leaf_modules: List[torch.nn.Module] = []
|
||||
missing: List[str] = []
|
||||
|
||||
for name in module_names:
|
||||
module = modules_by_name.get(name)
|
||||
if module is None:
|
||||
missing.append(name)
|
||||
continue
|
||||
module._z3_leaf = flag
|
||||
leaf_modules.append(module)
|
||||
|
||||
if missing and raise_if_not_found:
|
||||
raise ValueError(f'No modules named {missing} found in model {model}')
|
||||
|
||||
return leaf_modules, missing
|
||||
|
||||
|
||||
def set_z3_leaf_modules_by_suffix(model: torch.nn.Module,
|
||||
module_name_suffixes: List[str],
|
||||
flag: bool = True,
|
||||
raise_if_not_found: bool = True) -> Tuple[List[torch.nn.Module], List[str]]:
|
||||
"""Sets a leaf flag for modules referenced by suffixes of ``model.named_modules()`` names."""
|
||||
modules_by_name = dict(model.named_modules())
|
||||
leaf_modules: List[torch.nn.Module] = []
|
||||
missing: List[str] = []
|
||||
seen_ids = set()
|
||||
|
||||
for suffix in module_name_suffixes:
|
||||
matched = False
|
||||
for name, module in modules_by_name.items():
|
||||
if name.endswith(suffix):
|
||||
module._z3_leaf = flag
|
||||
module_id = id(module)
|
||||
if module_id not in seen_ids:
|
||||
seen_ids.add(module_id)
|
||||
leaf_modules.append(module)
|
||||
matched = True
|
||||
if not matched:
|
||||
missing.append(suffix)
|
||||
|
||||
if missing and raise_if_not_found:
|
||||
raise ValueError(f'No modules matching suffixes {missing} found in model {model}')
|
||||
|
||||
return leaf_modules, missing
|
||||
|
||||
|
||||
def set_z3_leaf_modules(model: torch.nn.Module,
|
||||
leaf_module_classes: Union[List[Type], List[str]],
|
||||
raise_if_not_found: bool = True) -> List[torch.nn.Module]:
|
||||
"""Sets a flag within a module in `model` to instruct ZeRO3 to stop setting hooks recursively when it encounters a module class listed in `leaf_module_classes`.
|
||||
This is particularly useful in the context of Mixture of Experts (MoE) models. In MoE models, the computation order of experts varies across forward passes. This variability can disrupt ZeRO3's functionality, as ZeRO3 relies on tracking the computation order of modules to prefetch parameters efficiently. By designating a module as a 'leaf' node, ZeRO3 will prefetch parameters for all child modules upon entering the module.
|
||||
Another scenario where this functionality is beneficial is in models with excessively fine-grained nested modules, where it helps to avoid the overhead associated with hooks.
|
||||
Args:
|
||||
model (torch.nn.Module): The model to which the leaf module flag will be applied.
|
||||
leaf_module_classes (Union[List[Type], List[str]]): A list of module classes that should be flagged as 'leaf' modules.
|
||||
raise_if_not_found (bool): Whether to raise a ``ValueError`` when none of the provided classes
|
||||
match a module inside ``model``.
|
||||
Returns:
|
||||
List[torch.nn.Module]: A list of modules that match the module classes in `leaf_module_classes`.
|
||||
"""
|
||||
return _do_set_z3_leaf_modules(model, leaf_module_classes, True)
|
||||
return _do_set_z3_leaf_modules(model, leaf_module_classes, True, raise_if_not_found)
|
||||
|
||||
|
||||
def unset_z3_leaf_modules(model: torch.nn.Module, leaf_module_classes: List[Type]) -> List[torch.nn.Module]:
|
||||
def unset_z3_leaf_modules(model: torch.nn.Module,
|
||||
leaf_module_classes: List[Type],
|
||||
raise_if_not_found: bool = True) -> List[torch.nn.Module]:
|
||||
"""Unsets a flag within a module in `model` to instruct ZeRO3 to resume setting hooks recursively when it encounters a module class listed in `leaf_module_classes`.
|
||||
See `set_z3_leaf_modules` for more details.
|
||||
Args:
|
||||
model (torch.nn.Module): The model to which the leaf module flag will be applied.
|
||||
leaf_module_classes (Union[List[Type], List[str]]): A list of module classes that should be flagged as 'leaf' modules.
|
||||
raise_if_not_found (bool): Whether to raise a ``ValueError`` when none of the provided classes
|
||||
match a module inside ``model``.
|
||||
Returns:
|
||||
List[torch.nn.Module]: A list of modules that match the module classes in `leaf_module_classes`.
|
||||
"""
|
||||
return _do_set_z3_leaf_modules(model, leaf_module_classes, False)
|
||||
return _do_set_z3_leaf_modules(model, leaf_module_classes, False, raise_if_not_found)
|
||||
|
||||
|
||||
def apply_zero_leaf_module_config(model: torch.nn.Module,
|
||||
leaf_cfg: Optional["DeepSpeedZeroLeafModuleConfig"]) -> List[torch.nn.Module]:
|
||||
"""Apply ZeRO leaf module configuration to ``model``.
|
||||
|
||||
Args:
|
||||
model (torch.nn.Module): Root module to update.
|
||||
leaf_cfg (DeepSpeedZeroLeafModuleConfig | None): Parsed configuration. If ``None``
|
||||
no changes are applied.
|
||||
|
||||
Returns:
|
||||
List[torch.nn.Module]: Modules flagged as leaves.
|
||||
"""
|
||||
if leaf_cfg is None:
|
||||
return []
|
||||
|
||||
from deepspeed.runtime.zero.leaf_module_config import (
|
||||
DEFAULT_LEAF_MODULE_CLASSES,
|
||||
DEFAULT_LEAF_MODULE_NAMES,
|
||||
DEFAULT_LEAF_MODULE_NAME_SUFFIXES,
|
||||
)
|
||||
|
||||
matched_modules: List[torch.nn.Module] = []
|
||||
matched_ids = set()
|
||||
|
||||
customized_classes = leaf_cfg.classes != DEFAULT_LEAF_MODULE_CLASSES
|
||||
customized_names = leaf_cfg.names != DEFAULT_LEAF_MODULE_NAMES
|
||||
customized_suffixes = leaf_cfg.name_suffixes != DEFAULT_LEAF_MODULE_NAME_SUFFIXES
|
||||
|
||||
if leaf_cfg.classes:
|
||||
class_matched = set_z3_leaf_modules(model, leaf_cfg.classes, raise_if_not_found=False)
|
||||
for module in class_matched:
|
||||
module_id = id(module)
|
||||
if module_id not in matched_ids:
|
||||
matched_ids.add(module_id)
|
||||
matched_modules.append(module)
|
||||
|
||||
if leaf_cfg.names:
|
||||
name_matched, missing_names = set_z3_leaf_modules_by_name(model,
|
||||
leaf_cfg.names,
|
||||
flag=True,
|
||||
raise_if_not_found=False)
|
||||
for module in name_matched:
|
||||
module_id = id(module)
|
||||
if module_id not in matched_ids:
|
||||
matched_ids.add(module_id)
|
||||
matched_modules.append(module)
|
||||
|
||||
if missing_names and customized_names:
|
||||
logger.warning(f"ZeRO leaf module configuration contains unknown module names: {missing_names}")
|
||||
|
||||
if leaf_cfg.name_suffixes:
|
||||
suffix_matched, missing_suffixes = set_z3_leaf_modules_by_suffix(model,
|
||||
leaf_cfg.name_suffixes,
|
||||
flag=True,
|
||||
raise_if_not_found=False)
|
||||
for module in suffix_matched:
|
||||
module_id = id(module)
|
||||
if module_id not in matched_ids:
|
||||
matched_ids.add(module_id)
|
||||
matched_modules.append(module)
|
||||
|
||||
if missing_suffixes and customized_suffixes:
|
||||
logger.warning(f"ZeRO leaf module configuration contains unmatched module suffixes: {missing_suffixes}")
|
||||
|
||||
if not matched_modules and (customized_classes or customized_names or customized_suffixes):
|
||||
logger.warning("ZeRO leaf module configuration did not match any modules; hooks will be applied as usual")
|
||||
|
||||
return matched_modules
|
||||
|
@ -73,6 +73,84 @@ Each configuration works as follows:
|
||||
.. autofunction:: deepspeed.runtime.torch_autocast.has_autocast_dtype
|
||||
|
||||
|
||||
Configuring ZeRO Leaf Modules
|
||||
-----------------------------
|
||||
|
||||
ZeRO-3 relies on module execution order to gather partitioned parameters.
|
||||
When models select submodules dynamically (for example, MoE routers), different data-parallel ranks may gather different sets of parameters, which can cause the all-gather collective to deadlock.
|
||||
To avoid this problem, you can designate the parent of dynamically activated submodules (e.g., MoE experts) as a "leaf" module.
|
||||
When a module is marked as a leaf, ZeRO gathers all of its descendants immediately and stops inserting hooks beneath it.
|
||||
|
||||
Programmatic API
|
||||
================
|
||||
|
||||
Use :func:`deepspeed.utils.set_z3_leaf_modules` to flag modules by class, class
|
||||
name, or both. Optionally combine with
|
||||
:func:`deepspeed.utils.set_z3_leaf_modules_by_name` to target specific entries
|
||||
from ``model.named_modules()`` or
|
||||
:func:`deepspeed.utils.set_z3_leaf_modules_by_suffix` to match suffixes of those
|
||||
names.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from deepspeed.utils import (
|
||||
set_z3_leaf_modules,
|
||||
set_z3_leaf_modules_by_name,
|
||||
set_z3_leaf_modules_by_suffix,
|
||||
)
|
||||
|
||||
# Match by class or subclass
|
||||
set_z3_leaf_modules(model, [CustomMoEBlock])
|
||||
|
||||
# Match by fully qualified class name
|
||||
set_z3_leaf_modules(model, ["my_package.layers.CustomMoEBlock"])
|
||||
|
||||
# Match by module name returned from model.named_modules()
|
||||
set_z3_leaf_modules_by_name(model, ["transformer.layers.0.experts"])
|
||||
|
||||
# Match by suffix of names returned from model.named_modules()
|
||||
set_z3_leaf_modules_by_suffix(model, ["experts"])
|
||||
|
||||
Configuration in DeepSpeed config
|
||||
=================================
|
||||
|
||||
The same behavior can be controlled from the DeepSpeed config. Add a
|
||||
``leaf_module`` block to ``zero_optimization`` specifying either classes,
|
||||
module names, or name suffixes (or any combination). By default DeepSpeed marks
|
||||
several Hugging Face MoE blocks—including Mixtral and Qwen MoE sparse blocks so
|
||||
that they behave well with ZeRO3.
|
||||
|
||||
The default class list currently contains:
|
||||
|
||||
* ``transformers.models.mixtral.modeling_mixtral.MixtralSparseMoeBlock``
|
||||
* ``transformers.models.qwen2_moe.modeling_qwen2_moe.Qwen2MoeSparseMoeBlock``
|
||||
* ``transformers.models.qwen3_moe.modeling_qwen3_moe.Qwen3MoeSparseMoeBlock``
|
||||
|
||||
.. code-block:: json
|
||||
|
||||
{
|
||||
"train_micro_batch_size_per_gpu": 1,
|
||||
"zero_optimization": {
|
||||
"stage": 3,
|
||||
"leaf_module": {
|
||||
"classes": ["my_package.layers.CustomMoEBlock"],
|
||||
"names": ["transformer.layers.0.experts"],
|
||||
"name_suffixes": ["experts"]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
``names`` must match exactly what ``model.named_modules()`` produces. The
|
||||
``name_suffixes`` field compares each suffix against the end of those same
|
||||
module paths, making it convenient to apply a rule across repeated structures.
|
||||
Entries in ``classes`` may be either bare class names (for example,
|
||||
``MixtralSparseMoeBlock``) or fully qualified dotted paths; both forms are
|
||||
accepted.
|
||||
|
||||
You can mix and match the API and configuration approaches; all referenced
|
||||
modules are flagged before ZeRO installs its hooks.
|
||||
|
||||
|
||||
Model Saving
|
||||
------------
|
||||
.. autofunction:: deepspeed.DeepSpeedEngine.save_16bit_model
|
||||
|
@ -11,7 +11,11 @@ from unit.common import DistributedTest, preferred_dtype
|
||||
from unit.simple_model import random_dataloader
|
||||
|
||||
import deepspeed
|
||||
from deepspeed.utils import set_z3_leaf_modules, unset_z3_leaf_modules, get_z3_leaf_modules, z3_leaf_module
|
||||
from deepspeed.utils import set_z3_leaf_modules, unset_z3_leaf_modules, get_z3_leaf_modules, z3_leaf_module, \
|
||||
set_z3_leaf_modules_by_name, set_z3_leaf_modules_by_suffix
|
||||
from deepspeed.runtime.zero.config import DeepSpeedZeroConfig
|
||||
from deepspeed.runtime.zero.leaf_module_config import (DEFAULT_LEAF_MODULE_CLASSES, DEFAULT_LEAF_MODULE_NAMES,
|
||||
DEFAULT_LEAF_MODULE_NAME_SUFFIXES)
|
||||
from deepspeed.accelerator import get_accelerator
|
||||
from torch import nn
|
||||
import time
|
||||
@ -82,6 +86,142 @@ class FineGrainedBlock(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
class BaseLeafModule(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super(BaseLeafModule, self).__init__()
|
||||
|
||||
|
||||
class SubLeafModule(BaseLeafModule):
|
||||
|
||||
def __init__(self, hidden_dim):
|
||||
super(SubLeafModule, self).__init__()
|
||||
self.proj = nn.Linear(hidden_dim, hidden_dim)
|
||||
|
||||
def forward(self, x):
|
||||
return self.proj(x)
|
||||
|
||||
|
||||
class WrapperLeafModule(nn.Module):
|
||||
|
||||
def __init__(self, hidden_dim):
|
||||
super(WrapperLeafModule, self).__init__()
|
||||
self.child = SubLeafModule(hidden_dim)
|
||||
|
||||
def forward(self, x):
|
||||
return self.child(x)
|
||||
|
||||
|
||||
def test_set_leaf_modules_with_fully_qualified_name():
|
||||
hidden_dim = 16
|
||||
model = WrapperLeafModule(hidden_dim)
|
||||
fq_name = f"{SubLeafModule.__module__}.{SubLeafModule.__qualname__}"
|
||||
|
||||
matched = set_z3_leaf_modules(model, [fq_name])
|
||||
|
||||
assert len(matched) == 1
|
||||
assert matched[0] is model.child
|
||||
assert z3_leaf_module(model.child)
|
||||
assert not z3_leaf_module(model)
|
||||
|
||||
|
||||
def test_set_leaf_modules_no_raise_when_missing():
|
||||
hidden_dim = 16
|
||||
model = WrapperLeafModule(hidden_dim)
|
||||
|
||||
matched = set_z3_leaf_modules(model, ["NonExistentClass"], raise_if_not_found=False)
|
||||
|
||||
assert matched == []
|
||||
assert not z3_leaf_module(model.child)
|
||||
|
||||
|
||||
def test_set_leaf_modules_by_name():
|
||||
hidden_dim = 16
|
||||
model = WrapperLeafModule(hidden_dim)
|
||||
|
||||
matched, missing = set_z3_leaf_modules_by_name(model, ["child"])
|
||||
|
||||
assert matched == [model.child]
|
||||
assert missing == []
|
||||
assert z3_leaf_module(model.child)
|
||||
|
||||
|
||||
def test_set_leaf_modules_by_name_missing():
|
||||
hidden_dim = 16
|
||||
model = WrapperLeafModule(hidden_dim)
|
||||
|
||||
matched, missing = set_z3_leaf_modules_by_name(model, ["missing"], raise_if_not_found=False)
|
||||
|
||||
assert matched == []
|
||||
assert missing == ["missing"]
|
||||
|
||||
|
||||
def test_set_leaf_modules_by_suffix():
|
||||
hidden_dim = 16
|
||||
model = WrapperLeafModule(hidden_dim)
|
||||
|
||||
matched, missing = set_z3_leaf_modules_by_suffix(model, ["child"])
|
||||
|
||||
assert missing == []
|
||||
assert matched == [model.child]
|
||||
assert z3_leaf_module(model.child)
|
||||
|
||||
|
||||
def test_set_leaf_modules_by_suffix_missing():
|
||||
hidden_dim = 16
|
||||
model = WrapperLeafModule(hidden_dim)
|
||||
|
||||
matched, missing = set_z3_leaf_modules_by_suffix(model, ["missing"], raise_if_not_found=False)
|
||||
|
||||
assert matched == []
|
||||
assert missing == ["missing"]
|
||||
|
||||
|
||||
def test_zero_leaf_module_default_config():
|
||||
config = DeepSpeedZeroConfig()
|
||||
assert config.leaf_module.classes == DEFAULT_LEAF_MODULE_CLASSES
|
||||
assert config.leaf_module.names == DEFAULT_LEAF_MODULE_NAMES
|
||||
assert config.leaf_module.name_suffixes == DEFAULT_LEAF_MODULE_NAME_SUFFIXES
|
||||
|
||||
|
||||
def test_zero_leaf_module_custom_config():
|
||||
payload = {
|
||||
"leaf_module": {
|
||||
"classes": ["custom.module.CustomClass"],
|
||||
"names": ["transformer.layer"],
|
||||
"name_suffixes": ["experts"]
|
||||
}
|
||||
}
|
||||
|
||||
config = DeepSpeedZeroConfig(**payload)
|
||||
|
||||
assert config.leaf_module.classes == ["custom.module.CustomClass"]
|
||||
assert config.leaf_module.names == ["transformer.layer"]
|
||||
assert config.leaf_module.name_suffixes == ["experts"]
|
||||
|
||||
|
||||
def test_zero_leaf_module_string_coercion():
|
||||
payload = {"leaf_module": {"classes": "my.Class", "names": "submodule", "name_suffixes": "tail"}}
|
||||
|
||||
config = DeepSpeedZeroConfig(**payload)
|
||||
|
||||
assert config.leaf_module.classes == ["my.Class"]
|
||||
assert config.leaf_module.names == ["submodule"]
|
||||
assert config.leaf_module.name_suffixes == ["tail"]
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Requires Hugging Face transformers; run manually when validating defaults.")
|
||||
def test_default_leaf_module_classes_exist():
|
||||
import importlib
|
||||
|
||||
from deepspeed.runtime.zero.leaf_module_config import DEFAULT_LEAF_MODULE_CLASSES
|
||||
|
||||
for cls_path in DEFAULT_LEAF_MODULE_CLASSES:
|
||||
module_name, _, class_name = cls_path.rpartition('.')
|
||||
module = importlib.import_module(module_name)
|
||||
assert hasattr(module, class_name), f"Expected {class_name} in {module_name}"
|
||||
|
||||
|
||||
class modelWithFineGrainedBlock(nn.Module):
|
||||
|
||||
def __init__(self, hidden_dim, num_block):
|
||||
@ -123,10 +263,7 @@ class TestSetZ3LeafModule(DistributedTest):
|
||||
world_size = 2
|
||||
reuse_dist_env = True
|
||||
|
||||
def _test_set_z3_leaf_modules(self, cls, requires_grad):
|
||||
hidden_dim = 128
|
||||
|
||||
# `stage3_max_reuse_distance` is set to 0 to cause an error if the module is not set as a leaf module
|
||||
def _create_zero_config(self, hidden_dim, leaf_module=None):
|
||||
config_dict = {
|
||||
"train_micro_batch_size_per_gpu": 1,
|
||||
"steps_per_print": 1,
|
||||
@ -143,11 +280,20 @@ class TestSetZ3LeafModule(DistributedTest):
|
||||
"stage3_max_reuse_distance": 0,
|
||||
}
|
||||
}
|
||||
if leaf_module is not None:
|
||||
config_dict["zero_optimization"]["leaf_module"] = leaf_module
|
||||
|
||||
if preferred_dtype() is torch.float16:
|
||||
config_dict["fp16"] = {"enabled": True}
|
||||
elif preferred_dtype() is torch.bfloat16:
|
||||
config_dict["bf16"] = {"enabled": True}
|
||||
|
||||
return config_dict
|
||||
|
||||
def _test_set_z3_leaf_modules(self, cls, requires_grad):
|
||||
hidden_dim = 128
|
||||
config_dict = self._create_zero_config(hidden_dim)
|
||||
|
||||
model = cls(hidden_dim)
|
||||
|
||||
assert not z3_leaf_module(model)
|
||||
@ -181,6 +327,17 @@ class TestSetZ3LeafModule(DistributedTest):
|
||||
"Expected only one module to be unset as a leaf module"
|
||||
assert len(get_z3_leaf_modules(model)) == 0, "Expected there is no leaf module"
|
||||
|
||||
def test_set_leaf_modules_with_subclass(self):
|
||||
hidden_dim = 32
|
||||
model = WrapperLeafModule(hidden_dim)
|
||||
|
||||
leaf_modules = set_z3_leaf_modules(model, [BaseLeafModule])
|
||||
|
||||
assert len(leaf_modules) == 1, "Expected the subclass instance to be marked as leaf"
|
||||
assert leaf_modules[0] is model.child, "Expected the subclass instance to be returned"
|
||||
assert z3_leaf_module(model.child), "Expected subclass instance flagged as leaf"
|
||||
assert not z3_leaf_module(model), "Expected wrapper module to remain non-leaf"
|
||||
|
||||
def test_set_no_match_class(self):
|
||||
hidden_dim = 128
|
||||
model = ChooseModuleByCounter(hidden_dim)
|
||||
@ -190,6 +347,25 @@ class TestSetZ3LeafModule(DistributedTest):
|
||||
except ValueError as e:
|
||||
pass
|
||||
|
||||
def test_leaf_module_enabled_via_config(self):
|
||||
hidden_dim = 128
|
||||
leaf_class_fqn = f"{ChooseModuleByCounter.__module__}.{ChooseModuleByCounter.__qualname__}"
|
||||
config_dict = self._create_zero_config(hidden_dim,
|
||||
leaf_module={
|
||||
"classes": [leaf_class_fqn],
|
||||
"name_suffixes": ["linears"]
|
||||
})
|
||||
|
||||
model = ChooseModuleByCounter(hidden_dim)
|
||||
assert not z3_leaf_module(model)
|
||||
|
||||
run_model(model, config_dict, hidden_dim, preferred_dtype(), True)
|
||||
|
||||
assert z3_leaf_module(model)
|
||||
modules_by_name = dict(model.named_modules())
|
||||
assert "linears" in modules_by_name
|
||||
assert z3_leaf_module(modules_by_name["linears"])
|
||||
|
||||
|
||||
@pytest.mark.parametrize("module_granularity_threshold", [0, 100, 12100, 10000000])
|
||||
class TestZ3LeafOptimization(DistributedTest):
|
||||
|
Reference in New Issue
Block a user