Improve leaf module interface (enable via config, relax matching criteria, add documentation, etc.) (#7604)

This PR improves the usability of the leaf module feature.

Here are the changes:
- Allow enabling the leaf module via both the DeepSpeed config and APIs.
- Relax matching criteria to support class-based matching.
- Support multiple ways of specifying the target module: class, class
name (with or without package name), module name, or suffix.
- Add documentation to the training guide, including config snippets and
explanations of default behavior.
- Add default classes (e.g., Mixtral, Qwen2/Qwen3) that automatically
enable the leaf module feature. (Welcoming requests to add more classes)

---------

Signed-off-by: Masahiro Tanaka <mtanaka@anyscale.com>
Co-authored-by: Olatunji Ruwase <tunji.ruwase@snowflake.com>
This commit is contained in:
Masahiro Tanaka
2025-10-03 02:45:28 -07:00
committed by GitHub
parent 82a9db7eba
commit 7d9a2f2bf3
7 changed files with 492 additions and 22 deletions

View File

@ -76,6 +76,7 @@ from deepspeed.runtime.sparse_tensor import SparseTensor
from deepspeed.runtime import lr_schedules
from deepspeed.utils import groups
from deepspeed.utils import logger, log_dist, log_dist_once, instrument_w_nvtx
from deepspeed.utils.z3_leaf_module import apply_zero_leaf_module_config
from deepspeed.utils.timer import NoopTimer, ThroughputTimer, SynchronizedWallClockTimer, \
FORWARD_MICRO_TIMER, BACKWARD_MICRO_TIMER, BACKWARD_INNER_MICRO_TIMER, BACKWARD_REDUCE_MICRO_TIMER, \
STEP_MICRO_TIMER, \
@ -1293,6 +1294,7 @@ class DeepSpeedEngine(Module):
def _configure_distributed_model(self, model):
self._set_client_model(model)
apply_zero_leaf_module_config(self.module, getattr(self._config.zero_config, "leaf_module", None))
is_zero_init_model = self.zero_optimization_partition_weights() and any(
[hasattr(param, "ds_id") for param in self.module.parameters()])

View File

@ -11,6 +11,7 @@ from deepspeed.runtime.config_utils import get_scalar_param, pp_int, DeepSpeedCo
from deepspeed.utils import logger
from .offload_config import DeepSpeedZeroOffloadParamConfig, DeepSpeedZeroOffloadOptimizerConfig, OffloadDeviceEnum
from deepspeed.runtime.zenflow.zenflow_config import ZenFlowConfig
from .leaf_module_config import DeepSpeedZeroLeafModuleConfig
# ZeRO optimization. By default, this optimization is not enabled.
# Users have to configure the desired optimization (0 means disabled) in params.json as below example:
@ -356,6 +357,11 @@ class DeepSpeedZeroConfig(DeepSpeedConfigModel):
Enable internal sanity checks, which could be useful for debugging
"""
leaf_module: DeepSpeedZeroLeafModuleConfig = Field(default_factory=DeepSpeedZeroLeafModuleConfig)
"""
Configuration for modules that should be treated as ZeRO3 leaf modules.
"""
# Validators
@model_validator(mode="after")
def overlap_comm_valid(self):

View File

@ -0,0 +1,52 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from typing import List
from pydantic import Field, model_validator
from deepspeed.runtime.config_utils import DeepSpeedConfigModel
# Fully qualified class paths used as the default value of the ``classes``
# field of ``DeepSpeedZeroLeafModuleConfig``.
DEFAULT_LEAF_MODULE_CLASSES: List[str] = [
    "transformers.models.mixtral.modeling_mixtral.MixtralSparseMoeBlock",
    "transformers.models.qwen2_moe.modeling_qwen2_moe.Qwen2MoeSparseMoeBlock",
    "transformers.models.qwen3_moe.modeling_qwen3_moe.Qwen3MoeSparseMoeBlock",
]
# Default value of the ``names`` field: no module names are listed.
DEFAULT_LEAF_MODULE_NAMES: List[str] = []
# Default value of the ``name_suffixes`` field: no suffixes are listed.
DEFAULT_LEAF_MODULE_NAME_SUFFIXES: List[str] = []
class DeepSpeedZeroLeafModuleConfig(DeepSpeedConfigModel):
    """Configuration for ZeRO leaf modules that should bypass hook installation.

    Each field is a list of strings. A bare string supplied for a field is
    wrapped into a one-element list before validation, and after validation
    every list is stringified and de-duplicated while keeping first-seen order.
    """

    # Class names / fully qualified class paths; defaults to a copy of DEFAULT_LEAF_MODULE_CLASSES.
    classes: List[str] = Field(default_factory=lambda: list(DEFAULT_LEAF_MODULE_CLASSES))
    # Module names; defaults to a copy of DEFAULT_LEAF_MODULE_NAMES.
    names: List[str] = Field(default_factory=lambda: list(DEFAULT_LEAF_MODULE_NAMES))
    # Module-name suffixes; defaults to a copy of DEFAULT_LEAF_MODULE_NAME_SUFFIXES.
    name_suffixes: List[str] = Field(default_factory=lambda: list(DEFAULT_LEAF_MODULE_NAME_SUFFIXES))

    @model_validator(mode="before")
    def _coerce_container_types(cls, values):
        """Accept ``None`` (treated as empty) and wrap bare strings into lists."""
        if values is None:
            return {}
        if not isinstance(values, dict):
            raise TypeError("leaf_module configuration must be a mapping of fields to values")

        list_fields = ("classes", "names", "name_suffixes")
        # Build a copy so the caller's mapping is never mutated.
        return {
            key: [value] if key in list_fields and isinstance(value, str) else value
            for key, value in values.items()
        }

    @model_validator(mode="after")
    def _validate_entries(self):
        """Stringify every entry and drop duplicates, preserving order."""
        for field_name in ("classes", "names", "name_suffixes"):
            as_strings = [str(entry) for entry in getattr(self, field_name)]
            # dict.fromkeys keeps the first occurrence of each entry in order.
            object.__setattr__(self, field_name, list(dict.fromkeys(as_strings)))
        return self

View File

@ -17,7 +17,7 @@ from .tensor_fragment import safe_set_full_fp32_param, safe_set_full_optimizer_s
from .tensor_fragment import safe_get_local_fp32_param, safe_get_local_grad, safe_get_local_optimizer_state
from .tensor_fragment import safe_set_local_fp32_param, safe_set_local_grad, safe_set_local_optimizer_state
from .tensor_fragment import safe_update_full_grad_vectorized
from .z3_leaf_module import set_z3_leaf_modules, unset_z3_leaf_modules, get_z3_leaf_modules, z3_leaf_module, z3_leaf_parameter, set_z3_leaf_module
from .z3_leaf_module import set_z3_leaf_modules, unset_z3_leaf_modules, get_z3_leaf_modules, z3_leaf_module, z3_leaf_parameter, set_z3_leaf_module, set_z3_leaf_modules_by_name, set_z3_leaf_modules_by_suffix
from .mixed_precision_linkage import link_hp_params, lazy_init_hp_params_optimizer_state
from deepspeed.runtime.dataloader import RepeatingLoader
from .numa import get_numactl_cmd

View File

@ -4,7 +4,12 @@
# DeepSpeed Team
import torch
from typing import List, Type, Union
from typing import List, Tuple, Type, Union, Optional, TYPE_CHECKING
from .logging import logger
if TYPE_CHECKING:
from deepspeed.runtime.zero.leaf_module_config import DeepSpeedZeroLeafModuleConfig
def z3_leaf_module(model: torch.nn.Module) -> bool:
@ -44,50 +49,201 @@ def set_z3_leaf_module(model: torch.nn.Module, flag: bool):
model._z3_leaf = flag
def _do_set_z3_leaf_modules(model: torch.nn.Module, leaf_module_classes: Union[List[Type], List[str]],
flag: bool) -> List[torch.nn.Module]:
assert all(isinstance(module_class, (type, str) ) for module_class in leaf_module_classes), \
def _fully_qualified_class_name(module: torch.nn.Module) -> str:
cls = module.__class__
return f"{cls.__module__}.{cls.__qualname__}"
def _do_set_z3_leaf_modules(model: torch.nn.Module,
leaf_module_classes: Union[List[Type], List[str]],
flag: bool,
raise_if_not_found: bool = True) -> List[torch.nn.Module]:
assert all(isinstance(module_class, (type, str)) for module_class in leaf_module_classes), \
f'leaf_module_classes must be a list of types or names, got {leaf_module_classes}'
leaf_modules = []
leaf_modules: List[torch.nn.Module] = []
def _set_z3_leaf_flag(model: torch.nn.Module):
def _set_z3_leaf_flag(module_instance: torch.nn.Module):
nonlocal leaf_modules
for module in leaf_module_classes:
if (isinstance(module, type) and model.__class__ == module) or \
(isinstance(module, str) and model.__class__.__name__ == module):
model._z3_leaf = flag
leaf_modules.append(model)
if isinstance(module, type) and isinstance(module_instance, module):
module_instance._z3_leaf = flag
leaf_modules.append(module_instance)
break
if isinstance(module, str):
if (module_instance.__class__.__name__ == module
or _fully_qualified_class_name(module_instance) == module):
module_instance._z3_leaf = flag
leaf_modules.append(module_instance)
break
model.apply(_set_z3_leaf_flag)
if len(leaf_modules) == 0:
if len(leaf_modules) == 0 and raise_if_not_found:
raise ValueError(f'No modules of type {leaf_module_classes} found in model {model}')
return leaf_modules
def set_z3_leaf_modules(model: torch.nn.Module, leaf_module_classes: Union[List[Type],
List[str]]) -> List[torch.nn.Module]:
def set_z3_leaf_modules_by_name(model: torch.nn.Module,
                                module_names: List[str],
                                flag: bool = True,
                                raise_if_not_found: bool = True) -> Tuple[List[torch.nn.Module], List[str]]:
    """Set the leaf flag on modules identified by their exact ``model.named_modules()`` name.

    Args:
        model (torch.nn.Module): The model containing the modules to update.
        module_names (List[str]): Module names as returned by ``named_modules()``.
        flag (bool): Value to store in each matched module's ``_z3_leaf`` attribute.
        raise_if_not_found (bool): Whether to raise when any requested name has no match.

    Returns:
        Tuple[List[torch.nn.Module], List[str]]: The matched modules (one entry per
        matching requested name, in request order) and the names that were not found.

    Raises:
        ValueError: If some names are missing and ``raise_if_not_found`` is true. Modules
            that did match have already been flagged by then.
    """
    name_to_module = dict(model.named_modules())
    found: List[torch.nn.Module] = []
    not_found: List[str] = []

    for requested in module_names:
        target = name_to_module.get(requested)
        if target is not None:
            target._z3_leaf = flag
            found.append(target)
        else:
            not_found.append(requested)

    if raise_if_not_found and not_found:
        raise ValueError(f'No modules named {not_found} found in model {model}')

    return found, not_found
def set_z3_leaf_modules_by_suffix(model: torch.nn.Module,
                                  module_name_suffixes: List[str],
                                  flag: bool = True,
                                  raise_if_not_found: bool = True) -> Tuple[List[torch.nn.Module], List[str]]:
    """Set the leaf flag on modules whose ``model.named_modules()`` name ends with a given suffix.

    The comparison is a plain ``str.endswith`` on the full dotted module name.

    Args:
        model (torch.nn.Module): The model containing the modules to update.
        module_name_suffixes (List[str]): Suffixes to compare against module names.
        flag (bool): Value to store in each matched module's ``_z3_leaf`` attribute.
        raise_if_not_found (bool): Whether to raise when any suffix matches nothing.

    Returns:
        Tuple[List[torch.nn.Module], List[str]]: The matched modules (each listed once,
        even when several suffixes hit it) and the suffixes that matched no module.

    Raises:
        ValueError: If some suffixes are unmatched and ``raise_if_not_found`` is true.
            Modules that did match have already been flagged by then.
    """
    named = dict(model.named_modules())
    flagged: List[torch.nn.Module] = []
    flagged_ids = set()
    unmatched: List[str] = []

    for suffix in module_name_suffixes:
        hits = [candidate for name, candidate in named.items() if name.endswith(suffix)]
        if not hits:
            unmatched.append(suffix)
            continue
        for candidate in hits:
            candidate._z3_leaf = flag
            # Track identity so a module reached through several suffixes is reported once.
            if id(candidate) not in flagged_ids:
                flagged_ids.add(id(candidate))
                flagged.append(candidate)

    if raise_if_not_found and unmatched:
        raise ValueError(f'No modules matching suffixes {unmatched} found in model {model}')

    return flagged, unmatched
def set_z3_leaf_modules(model: torch.nn.Module,
                        leaf_module_classes: Union[List[Type], List[str]],
                        raise_if_not_found: bool = True) -> List[torch.nn.Module]:
    """Sets a flag within a module in `model` to instruct ZeRO3 to stop setting hooks recursively
    when it encounters a module class listed in `leaf_module_classes`.

    This is particularly useful in the context of Mixture of Experts (MoE) models. In MoE models,
    the computation order of experts varies across forward passes. This variability can disrupt
    ZeRO3's functionality, as ZeRO3 relies on tracking the computation order of modules to prefetch
    parameters efficiently. By designating a module as a 'leaf' node, ZeRO3 will prefetch parameters
    for all child modules upon entering the module.

    Another scenario where this functionality is beneficial is in models with excessively
    fine-grained nested modules, where it helps to avoid the overhead associated with hooks.

    Args:
        model (torch.nn.Module): The model to which the leaf module flag will be applied.
        leaf_module_classes (Union[List[Type], List[str]]): A list of module classes (or class
            names) that should be flagged as 'leaf' modules.
        raise_if_not_found (bool): Whether to raise a ``ValueError`` when none of the provided
            classes match a module inside ``model``.

    Returns:
        List[torch.nn.Module]: A list of modules that match the module classes in `leaf_module_classes`.
    """
    # Forward raise_if_not_found; the three-argument call would silently fall back to raising.
    return _do_set_z3_leaf_modules(model, leaf_module_classes, True, raise_if_not_found)
def unset_z3_leaf_modules(model: torch.nn.Module,
                          leaf_module_classes: Union[List[Type], List[str]],
                          raise_if_not_found: bool = True) -> List[torch.nn.Module]:
    """Unsets a flag within a module in `model` to instruct ZeRO3 to resume setting hooks
    recursively when it encounters a module class listed in `leaf_module_classes`.

    See `set_z3_leaf_modules` for more details.

    Args:
        model (torch.nn.Module): The model whose leaf module flags will be cleared.
        leaf_module_classes (Union[List[Type], List[str]]): A list of module classes (or class
            names) whose 'leaf' flag should be cleared.
        raise_if_not_found (bool): Whether to raise a ``ValueError`` when none of the provided
            classes match a module inside ``model``.

    Returns:
        List[torch.nn.Module]: A list of modules that match the module classes in `leaf_module_classes`.
    """
    # Forward raise_if_not_found; the three-argument call would silently fall back to raising.
    return _do_set_z3_leaf_modules(model, leaf_module_classes, False, raise_if_not_found)
def apply_zero_leaf_module_config(model: torch.nn.Module,
                                  leaf_cfg: Optional["DeepSpeedZeroLeafModuleConfig"]) -> List[torch.nn.Module]:
    """Flag modules of ``model`` as ZeRO leaves according to ``leaf_cfg``.

    Classes, names and name suffixes from the configuration are applied in that
    order. Nothing here raises on a miss; instead a warning is logged, and only
    for lists that differ from the built-in defaults.

    Args:
        model (torch.nn.Module): Root module to update.
        leaf_cfg (DeepSpeedZeroLeafModuleConfig | None): Parsed configuration. If ``None``
            no changes are applied.

    Returns:
        List[torch.nn.Module]: Modules flagged as leaves, each listed once.
    """
    if leaf_cfg is None:
        return []

    # Imported here rather than at module scope (the module-level import is TYPE_CHECKING only).
    from deepspeed.runtime.zero.leaf_module_config import (
        DEFAULT_LEAF_MODULE_CLASSES,
        DEFAULT_LEAF_MODULE_NAMES,
        DEFAULT_LEAF_MODULE_NAME_SUFFIXES,
    )

    collected: List[torch.nn.Module] = []
    collected_ids = set()

    def _collect(candidates: List[torch.nn.Module]) -> None:
        # A module can be hit by several criteria; record it only the first time.
        for candidate in candidates:
            if id(candidate) not in collected_ids:
                collected_ids.add(id(candidate))
                collected.append(candidate)

    # Only user-customized lists should produce "nothing matched" warnings.
    customized_classes = leaf_cfg.classes != DEFAULT_LEAF_MODULE_CLASSES
    customized_names = leaf_cfg.names != DEFAULT_LEAF_MODULE_NAMES
    customized_suffixes = leaf_cfg.name_suffixes != DEFAULT_LEAF_MODULE_NAME_SUFFIXES

    if leaf_cfg.classes:
        _collect(set_z3_leaf_modules(model, leaf_cfg.classes, raise_if_not_found=False))

    if leaf_cfg.names:
        by_name, unknown_names = set_z3_leaf_modules_by_name(model,
                                                             leaf_cfg.names,
                                                             flag=True,
                                                             raise_if_not_found=False)
        _collect(by_name)
        if unknown_names and customized_names:
            logger.warning(f"ZeRO leaf module configuration contains unknown module names: {unknown_names}")

    if leaf_cfg.name_suffixes:
        by_suffix, unmatched_suffixes = set_z3_leaf_modules_by_suffix(model,
                                                                      leaf_cfg.name_suffixes,
                                                                      flag=True,
                                                                      raise_if_not_found=False)
        _collect(by_suffix)
        if unmatched_suffixes and customized_suffixes:
            logger.warning(
                f"ZeRO leaf module configuration contains unmatched module suffixes: {unmatched_suffixes}")

    any_customized = customized_classes or customized_names or customized_suffixes
    if any_customized and not collected:
        logger.warning("ZeRO leaf module configuration did not match any modules; hooks will be applied as usual")

    return collected

View File

@ -73,6 +73,84 @@ Each configuration works as follows:
.. autofunction:: deepspeed.runtime.torch_autocast.has_autocast_dtype
Configuring ZeRO Leaf Modules
-----------------------------
ZeRO-3 relies on module execution order to gather partitioned parameters.
When models select submodules dynamically (for example, MoE routers), different data-parallel ranks may gather different sets of parameters, which can cause the all-gather collective to deadlock.
To avoid this problem, you can designate the parent of dynamically activated submodules (e.g., MoE experts) as a "leaf" module.
When a module is marked as a leaf, ZeRO gathers the parameters of all of its descendants as soon as the leaf module is entered and stops inserting hooks beneath it.
Programmatic API
================
Use :func:`deepspeed.utils.set_z3_leaf_modules` to flag modules by class, class
name, or both. Optionally combine with
:func:`deepspeed.utils.set_z3_leaf_modules_by_name` to target specific entries
from ``model.named_modules()`` or
:func:`deepspeed.utils.set_z3_leaf_modules_by_suffix` to match suffixes of those
names.
.. code-block:: python
from deepspeed.utils import (
set_z3_leaf_modules,
set_z3_leaf_modules_by_name,
set_z3_leaf_modules_by_suffix,
)
# Match by class or subclass
set_z3_leaf_modules(model, [CustomMoEBlock])
# Match by fully qualified class name
set_z3_leaf_modules(model, ["my_package.layers.CustomMoEBlock"])
# Match by module name returned from model.named_modules()
set_z3_leaf_modules_by_name(model, ["transformer.layers.0.experts"])
# Match by suffix of names returned from model.named_modules()
set_z3_leaf_modules_by_suffix(model, ["experts"])
Configuration in DeepSpeed config
=================================
The same behavior can be controlled from the DeepSpeed config. Add a
``leaf_module`` block to ``zero_optimization`` specifying either classes,
module names, or name suffixes (or any combination). By default, DeepSpeed marks
several Hugging Face MoE blocks (including the Mixtral and Qwen MoE sparse
blocks) as leaf modules so that they behave well with ZeRO-3.
The default class list currently contains:
* ``transformers.models.mixtral.modeling_mixtral.MixtralSparseMoeBlock``
* ``transformers.models.qwen2_moe.modeling_qwen2_moe.Qwen2MoeSparseMoeBlock``
* ``transformers.models.qwen3_moe.modeling_qwen3_moe.Qwen3MoeSparseMoeBlock``
.. code-block:: json
{
"train_micro_batch_size_per_gpu": 1,
"zero_optimization": {
"stage": 3,
"leaf_module": {
"classes": ["my_package.layers.CustomMoEBlock"],
"names": ["transformer.layers.0.experts"],
"name_suffixes": ["experts"]
}
}
}
``names`` must match exactly what ``model.named_modules()`` produces. The
``name_suffixes`` field compares each suffix against the end of those same
module paths, making it convenient to apply a rule across repeated structures.
Entries in ``classes`` may be either bare class names (for example,
``MixtralSparseMoeBlock``) or fully qualified dotted paths; both forms are
accepted.
You can mix and match the API and configuration approaches; all referenced
modules are flagged before ZeRO installs its hooks.
Model Saving
------------
.. autofunction:: deepspeed.DeepSpeedEngine.save_16bit_model

View File

@ -11,7 +11,11 @@ from unit.common import DistributedTest, preferred_dtype
from unit.simple_model import random_dataloader
import deepspeed
from deepspeed.utils import set_z3_leaf_modules, unset_z3_leaf_modules, get_z3_leaf_modules, z3_leaf_module
from deepspeed.utils import set_z3_leaf_modules, unset_z3_leaf_modules, get_z3_leaf_modules, z3_leaf_module, \
set_z3_leaf_modules_by_name, set_z3_leaf_modules_by_suffix
from deepspeed.runtime.zero.config import DeepSpeedZeroConfig
from deepspeed.runtime.zero.leaf_module_config import (DEFAULT_LEAF_MODULE_CLASSES, DEFAULT_LEAF_MODULE_NAMES,
DEFAULT_LEAF_MODULE_NAME_SUFFIXES)
from deepspeed.accelerator import get_accelerator
from torch import nn
import time
@ -82,6 +86,142 @@ class FineGrainedBlock(nn.Module):
return x
class BaseLeafModule(nn.Module):
    """Empty ``nn.Module`` base class used to exercise subclass-based leaf matching."""

    def __init__(self):
        super().__init__()
class SubLeafModule(BaseLeafModule):
    """``BaseLeafModule`` subclass holding a single square linear projection."""

    def __init__(self, hidden_dim):
        super().__init__()
        self.proj = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, x):
        projected = self.proj(x)
        return projected
class WrapperLeafModule(nn.Module):
    """Module with a single ``SubLeafModule`` registered under the name ``child``."""

    def __init__(self, hidden_dim):
        super().__init__()
        self.child = SubLeafModule(hidden_dim)

    def forward(self, x):
        output = self.child(x)
        return output
def test_set_leaf_modules_with_fully_qualified_name():
    """A fully qualified class path flags the child of that class and leaves the wrapper alone."""
    wrapper = WrapperLeafModule(16)
    qualified_name = ".".join((SubLeafModule.__module__, SubLeafModule.__qualname__))

    flagged = set_z3_leaf_modules(wrapper, [qualified_name])

    assert len(flagged) == 1
    assert flagged[0] is wrapper.child
    assert z3_leaf_module(wrapper.child)
    assert not z3_leaf_module(wrapper)
def test_set_leaf_modules_no_raise_when_missing():
    """With ``raise_if_not_found=False`` an unknown class name yields an empty result and no flag."""
    wrapper = WrapperLeafModule(16)

    flagged = set_z3_leaf_modules(wrapper, ["NonExistentClass"], raise_if_not_found=False)

    assert flagged == []
    assert not z3_leaf_module(wrapper.child)
def test_set_leaf_modules_by_name():
    """An exact ``named_modules()`` name flags that module and reports nothing missing."""
    wrapper = WrapperLeafModule(16)

    flagged, not_found = set_z3_leaf_modules_by_name(wrapper, ["child"])

    assert flagged == [wrapper.child]
    assert not_found == []
    assert z3_leaf_module(wrapper.child)
def test_set_leaf_modules_by_name_missing():
    """An unknown module name is returned in the missing list when raising is disabled."""
    wrapper = WrapperLeafModule(16)

    flagged, not_found = set_z3_leaf_modules_by_name(wrapper, ["missing"], raise_if_not_found=False)

    assert flagged == []
    assert not_found == ["missing"]
def test_set_leaf_modules_by_suffix():
    """A suffix matching one module name flags that module and reports nothing missing."""
    wrapper = WrapperLeafModule(16)

    flagged, unmatched = set_z3_leaf_modules_by_suffix(wrapper, ["child"])

    assert unmatched == []
    assert flagged == [wrapper.child]
    assert z3_leaf_module(wrapper.child)
def test_set_leaf_modules_by_suffix_missing():
    """A suffix that matches nothing is returned in the missing list when raising is disabled."""
    wrapper = WrapperLeafModule(16)

    flagged, unmatched = set_z3_leaf_modules_by_suffix(wrapper, ["missing"], raise_if_not_found=False)

    assert flagged == []
    assert unmatched == ["missing"]
def test_zero_leaf_module_default_config():
    """A default ``DeepSpeedZeroConfig`` carries the module-level default leaf lists."""
    leaf_cfg = DeepSpeedZeroConfig().leaf_module

    assert leaf_cfg.classes == DEFAULT_LEAF_MODULE_CLASSES
    assert leaf_cfg.names == DEFAULT_LEAF_MODULE_NAMES
    assert leaf_cfg.name_suffixes == DEFAULT_LEAF_MODULE_NAME_SUFFIXES
def test_zero_leaf_module_custom_config():
    """Explicit ``leaf_module`` lists replace the defaults verbatim."""
    leaf_module = {
        "classes": ["custom.module.CustomClass"],
        "names": ["transformer.layer"],
        "name_suffixes": ["experts"],
    }

    leaf_cfg = DeepSpeedZeroConfig(leaf_module=leaf_module).leaf_module

    assert leaf_cfg.classes == ["custom.module.CustomClass"]
    assert leaf_cfg.names == ["transformer.layer"]
    assert leaf_cfg.name_suffixes == ["experts"]
def test_zero_leaf_module_string_coercion():
    """Bare strings given for the list fields are wrapped into one-element lists."""
    leaf_module = {"classes": "my.Class", "names": "submodule", "name_suffixes": "tail"}

    leaf_cfg = DeepSpeedZeroConfig(leaf_module=leaf_module).leaf_module

    assert leaf_cfg.classes == ["my.Class"]
    assert leaf_cfg.names == ["submodule"]
    assert leaf_cfg.name_suffixes == ["tail"]
@pytest.mark.skip(reason="Requires Hugging Face transformers; run manually when validating defaults.")
def test_default_leaf_module_classes_exist():
    """Check that every default class path can be imported and names an existing attribute."""
    import importlib
    from deepspeed.runtime.zero.leaf_module_config import DEFAULT_LEAF_MODULE_CLASSES

    for dotted_path in DEFAULT_LEAF_MODULE_CLASSES:
        # Split "pkg.mod.ClassName" into the importable module path and the class name.
        package_path, _, attr_name = dotted_path.rpartition('.')
        imported = importlib.import_module(package_path)
        assert hasattr(imported, attr_name), f"Expected {attr_name} in {package_path}"
class modelWithFineGrainedBlock(nn.Module):
def __init__(self, hidden_dim, num_block):
@ -123,10 +263,7 @@ class TestSetZ3LeafModule(DistributedTest):
world_size = 2
reuse_dist_env = True
def _test_set_z3_leaf_modules(self, cls, requires_grad):
hidden_dim = 128
# `stage3_max_reuse_distance` is set to 0 to cause an error if the module is not set as a leaf module
def _create_zero_config(self, hidden_dim, leaf_module=None):
config_dict = {
"train_micro_batch_size_per_gpu": 1,
"steps_per_print": 1,
@ -143,11 +280,20 @@ class TestSetZ3LeafModule(DistributedTest):
"stage3_max_reuse_distance": 0,
}
}
if leaf_module is not None:
config_dict["zero_optimization"]["leaf_module"] = leaf_module
if preferred_dtype() is torch.float16:
config_dict["fp16"] = {"enabled": True}
elif preferred_dtype() is torch.bfloat16:
config_dict["bf16"] = {"enabled": True}
return config_dict
def _test_set_z3_leaf_modules(self, cls, requires_grad):
hidden_dim = 128
config_dict = self._create_zero_config(hidden_dim)
model = cls(hidden_dim)
assert not z3_leaf_module(model)
@ -181,6 +327,17 @@ class TestSetZ3LeafModule(DistributedTest):
"Expected only one module to be unset as a leaf module"
assert len(get_z3_leaf_modules(model)) == 0, "Expected there is no leaf module"
def test_set_leaf_modules_with_subclass(self):
    """Passing a base class flags the child, which is an instance of a subclass of it."""
    wrapper = WrapperLeafModule(32)

    flagged = set_z3_leaf_modules(wrapper, [BaseLeafModule])

    assert len(flagged) == 1, "Expected the subclass instance to be marked as leaf"
    assert flagged[0] is wrapper.child, "Expected the subclass instance to be returned"
    assert z3_leaf_module(wrapper.child), "Expected subclass instance flagged as leaf"
    assert not z3_leaf_module(wrapper), "Expected wrapper module to remain non-leaf"
def test_set_no_match_class(self):
hidden_dim = 128
model = ChooseModuleByCounter(hidden_dim)
@ -190,6 +347,25 @@ class TestSetZ3LeafModule(DistributedTest):
except ValueError as e:
pass
def test_leaf_module_enabled_via_config(self):
    """Leaf flags requested only through the ``leaf_module`` config block are set after ``run_model``."""
    hidden_dim = 128
    leaf_class_fqn = ".".join((ChooseModuleByCounter.__module__, ChooseModuleByCounter.__qualname__))
    leaf_module_cfg = {"classes": [leaf_class_fqn], "name_suffixes": ["linears"]}
    config_dict = self._create_zero_config(hidden_dim, leaf_module=leaf_module_cfg)

    model = ChooseModuleByCounter(hidden_dim)
    # Nothing is flagged before the model is run with the config.
    assert not z3_leaf_module(model)

    run_model(model, config_dict, hidden_dim, preferred_dtype(), True)

    # The class entry flags the model itself; the suffix entry flags its "linears" submodule.
    assert z3_leaf_module(model)
    named = dict(model.named_modules())
    assert "linears" in named
    assert z3_leaf_module(named["linears"])
@pytest.mark.parametrize("module_granularity_threshold", [0, 100, 12100, 10000000])
class TestZ3LeafOptimization(DistributedTest):