Files
DeepSpeed/deepspeed/utils/z3_leaf_module.py
Masahiro Tanaka 7d9a2f2bf3 Improve leaf module interface (enable via config, relax matching criteria, add document, etc.) (#7604)
This PR improves the usability of the leaf module feature.

Here are the changes:
- Allow enabling the leaf module via both the DeepSpeed config and APIs.
- Relax matching criteria to support class-based matching.
- Support multiple ways of specifying the target module: class, class
name (with or without package name), module name, or suffix.
- Add documentation to the training guide, including config snippets and
explanations of default behavior.
- Add default classes (e.g., Mixtral, Qwen2/Qwen3) that automatically
enable the leaf module feature. (Welcoming requests to add more classes)

---------

Signed-off-by: Masahiro Tanaka <mtanaka@anyscale.com>
Co-authored-by: Olatunji Ruwase <tunji.ruwase@snowflake.com>
2025-10-03 09:45:28 +00:00

250 lines
11 KiB
Python

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import torch
from typing import List, Tuple, Type, Union, Optional, TYPE_CHECKING
from .logging import logger
if TYPE_CHECKING:
from deepspeed.runtime.zero.leaf_module_config import DeepSpeedZeroLeafModuleConfig
def z3_leaf_module(model: torch.nn.Module) -> bool:
"""Returns whether a module in `model` has been flagged as a 'leaf' module.
See `set_z3_leaf_modules` for more details.
Args:
model (torch.nn.Module): The model to which the leaf module flag will be applied.
Returns:
bool: Whether the module has been flagged as a 'leaf' module.
"""
return hasattr(model, '_z3_leaf') and model._z3_leaf
def z3_leaf_parameter(model: torch.nn.Parameter) -> bool:
"""Returns whether a parameter belongs to a leaf module.
See `set_z3_leaf_modules` for more details.
Args:
model (torch.nn.Parameter): The parameter to which the leaf module flag will be applied.
Returns:
bool: Whether the parameter belongs to a leaf module.
"""
return hasattr(model, 'ds_z3_leaf_module')
def get_z3_leaf_modules(model: torch.nn.Module) -> List[torch.nn.Module]:
"""Returns a list of modules in `model` that have been flagged as 'leaf' modules.
See `set_z3_leaf_modules` for more details.
Args:
model (torch.nn.Module): The model to which the leaf module flag will be applied.
Returns:
List[torch.nn.Module]: A list of modules that have been flagged as 'leaf' modules.
"""
return [module for module in model.modules() if z3_leaf_module(module)]
def set_z3_leaf_module(model: torch.nn.Module, flag: bool):
model._z3_leaf = flag
def _fully_qualified_class_name(module: torch.nn.Module) -> str:
cls = module.__class__
return f"{cls.__module__}.{cls.__qualname__}"
def _do_set_z3_leaf_modules(model: torch.nn.Module,
leaf_module_classes: Union[List[Type], List[str]],
flag: bool,
raise_if_not_found: bool = True) -> List[torch.nn.Module]:
assert all(isinstance(module_class, (type, str)) for module_class in leaf_module_classes), \
f'leaf_module_classes must be a list of types or names, got {leaf_module_classes}'
leaf_modules: List[torch.nn.Module] = []
def _set_z3_leaf_flag(module_instance: torch.nn.Module):
nonlocal leaf_modules
for module in leaf_module_classes:
if isinstance(module, type) and isinstance(module_instance, module):
module_instance._z3_leaf = flag
leaf_modules.append(module_instance)
break
if isinstance(module, str):
if (module_instance.__class__.__name__ == module
or _fully_qualified_class_name(module_instance) == module):
module_instance._z3_leaf = flag
leaf_modules.append(module_instance)
break
model.apply(_set_z3_leaf_flag)
if len(leaf_modules) == 0 and raise_if_not_found:
raise ValueError(f'No modules of type {leaf_module_classes} found in model {model}')
return leaf_modules
def set_z3_leaf_modules_by_name(model: torch.nn.Module,
module_names: List[str],
flag: bool = True,
raise_if_not_found: bool = True) -> Tuple[List[torch.nn.Module], List[str]]:
"""Sets a leaf flag for modules referenced by their names in ``model.named_modules()``.
Args:
model (torch.nn.Module): The model containing the modules to update.
module_names (List[str]): Module names as returned by ``named_modules()``.
flag (bool): Desired flag state.
raise_if_not_found (bool): Whether to raise when no module matches a provided name.
Returns:
Tuple[List[torch.nn.Module], List[str]]: Matched modules and missing module names.
"""
modules_by_name = dict(model.named_modules())
leaf_modules: List[torch.nn.Module] = []
missing: List[str] = []
for name in module_names:
module = modules_by_name.get(name)
if module is None:
missing.append(name)
continue
module._z3_leaf = flag
leaf_modules.append(module)
if missing and raise_if_not_found:
raise ValueError(f'No modules named {missing} found in model {model}')
return leaf_modules, missing
def set_z3_leaf_modules_by_suffix(model: torch.nn.Module,
module_name_suffixes: List[str],
flag: bool = True,
raise_if_not_found: bool = True) -> Tuple[List[torch.nn.Module], List[str]]:
"""Sets a leaf flag for modules referenced by suffixes of ``model.named_modules()`` names."""
modules_by_name = dict(model.named_modules())
leaf_modules: List[torch.nn.Module] = []
missing: List[str] = []
seen_ids = set()
for suffix in module_name_suffixes:
matched = False
for name, module in modules_by_name.items():
if name.endswith(suffix):
module._z3_leaf = flag
module_id = id(module)
if module_id not in seen_ids:
seen_ids.add(module_id)
leaf_modules.append(module)
matched = True
if not matched:
missing.append(suffix)
if missing and raise_if_not_found:
raise ValueError(f'No modules matching suffixes {missing} found in model {model}')
return leaf_modules, missing
def set_z3_leaf_modules(model: torch.nn.Module,
leaf_module_classes: Union[List[Type], List[str]],
raise_if_not_found: bool = True) -> List[torch.nn.Module]:
"""Sets a flag within a module in `model` to instruct ZeRO3 to stop setting hooks recursively when it encounters a module class listed in `leaf_module_classes`.
This is particularly useful in the context of Mixture of Experts (MoE) models. In MoE models, the computation order of experts varies across forward passes. This variability can disrupt ZeRO3's functionality, as ZeRO3 relies on tracking the computation order of modules to prefetch parameters efficiently. By designating a module as a 'leaf' node, ZeRO3 will prefetch parameters for all child modules upon entering the module.
Another scenario where this functionality is beneficial is in models with excessively fine-grained nested modules, where it helps to avoid the overhead associated with hooks.
Args:
model (torch.nn.Module): The model to which the leaf module flag will be applied.
leaf_module_classes (Union[List[Type], List[str]]): A list of module classes that should be flagged as 'leaf' modules.
raise_if_not_found (bool): Whether to raise a ``ValueError`` when none of the provided classes
match a module inside ``model``.
Returns:
List[torch.nn.Module]: A list of modules that match the module classes in `leaf_module_classes`.
"""
return _do_set_z3_leaf_modules(model, leaf_module_classes, True, raise_if_not_found)
def unset_z3_leaf_modules(model: torch.nn.Module,
leaf_module_classes: List[Type],
raise_if_not_found: bool = True) -> List[torch.nn.Module]:
"""Unsets a flag within a module in `model` to instruct ZeRO3 to resume setting hooks recursively when it encounters a module class listed in `leaf_module_classes`.
See `set_z3_leaf_modules` for more details.
Args:
model (torch.nn.Module): The model to which the leaf module flag will be applied.
leaf_module_classes (Union[List[Type], List[str]]): A list of module classes that should be flagged as 'leaf' modules.
raise_if_not_found (bool): Whether to raise a ``ValueError`` when none of the provided classes
match a module inside ``model``.
Returns:
List[torch.nn.Module]: A list of modules that match the module classes in `leaf_module_classes`.
"""
return _do_set_z3_leaf_modules(model, leaf_module_classes, False, raise_if_not_found)
def apply_zero_leaf_module_config(model: torch.nn.Module,
leaf_cfg: Optional["DeepSpeedZeroLeafModuleConfig"]) -> List[torch.nn.Module]:
"""Apply ZeRO leaf module configuration to ``model``.
Args:
model (torch.nn.Module): Root module to update.
leaf_cfg (DeepSpeedZeroLeafModuleConfig | None): Parsed configuration. If ``None``
no changes are applied.
Returns:
List[torch.nn.Module]: Modules flagged as leaves.
"""
if leaf_cfg is None:
return []
from deepspeed.runtime.zero.leaf_module_config import (
DEFAULT_LEAF_MODULE_CLASSES,
DEFAULT_LEAF_MODULE_NAMES,
DEFAULT_LEAF_MODULE_NAME_SUFFIXES,
)
matched_modules: List[torch.nn.Module] = []
matched_ids = set()
customized_classes = leaf_cfg.classes != DEFAULT_LEAF_MODULE_CLASSES
customized_names = leaf_cfg.names != DEFAULT_LEAF_MODULE_NAMES
customized_suffixes = leaf_cfg.name_suffixes != DEFAULT_LEAF_MODULE_NAME_SUFFIXES
if leaf_cfg.classes:
class_matched = set_z3_leaf_modules(model, leaf_cfg.classes, raise_if_not_found=False)
for module in class_matched:
module_id = id(module)
if module_id not in matched_ids:
matched_ids.add(module_id)
matched_modules.append(module)
if leaf_cfg.names:
name_matched, missing_names = set_z3_leaf_modules_by_name(model,
leaf_cfg.names,
flag=True,
raise_if_not_found=False)
for module in name_matched:
module_id = id(module)
if module_id not in matched_ids:
matched_ids.add(module_id)
matched_modules.append(module)
if missing_names and customized_names:
logger.warning(f"ZeRO leaf module configuration contains unknown module names: {missing_names}")
if leaf_cfg.name_suffixes:
suffix_matched, missing_suffixes = set_z3_leaf_modules_by_suffix(model,
leaf_cfg.name_suffixes,
flag=True,
raise_if_not_found=False)
for module in suffix_matched:
module_id = id(module)
if module_id not in matched_ids:
matched_ids.add(module_id)
matched_modules.append(module)
if missing_suffixes and customized_suffixes:
logger.warning(f"ZeRO leaf module configuration contains unmatched module suffixes: {missing_suffixes}")
if not matched_modules and (customized_classes or customized_names or customized_suffixes):
logger.warning("ZeRO leaf module configuration did not match any modules; hooks will be applied as usual")
return matched_modules