Merge branch 'master' into gma/zenflow_binding_study

This commit is contained in:
Olatunji Ruwase
2025-10-04 08:59:23 -04:00
committed by GitHub
16 changed files with 825 additions and 174 deletions

View File

@ -0,0 +1,188 @@
# SuperOffload: Unleashing the Power of Large-Scale LLM Training on Superchips
**Efficient full-parameter fine-tuning of GPT-OSS-20B and Qwen3-14B on a single NVIDIA GH200 superchip, and Llama3-70B training on four NVIDIA GH200 superchips, with training throughput of up to 600 TFLOPS.**
**Authors**
[Xinyu Lian](https://xinyulian.tech/)<sup>1</sup>, [Masahiro Tanaka](https://tohtana.github.io/)<sup>2</sup>, [Olatunji Ruwase](https://www.snowflake.com/en/blog/authors/olatunji--tunji--ruwase/)<sup>3</sup>, [Minjia Zhang](https://minjiazhang.github.io/)<sup>1</sup>
<sup>1</sup>SSAIL Lab, University of Illinois Urbana-Champaign · <sup>2</sup>Anyscale · <sup>3</sup>Snowflake
---
## Table of Contents <!-- omit in toc -->
- [SuperOffload: Unleashing the Power of Large-Scale LLM Training on Superchips](#superoffload-unleashing-the-power-of-large-scale-llm-training-on-superchips)
- [SuperOffload Highlights](#superoffload-highlights)
- [Introduction](#introduction)
- [How SuperOffload Works](#how-superoffload-works)
  - [1. Speculation-Then-Validation (STV)](#1-speculation-then-validation-stv)
  - [2. Heterogeneous Optimizer Computation](#2-heterogeneous-optimizer-computation)
  - [3. Superchip-Aware Casting](#3-superchip-aware-casting)
  - [4. GraceAdam: Efficient Adam for Grace CPUs](#4-graceadam-efficient-adam-for-grace-cpus)
- [Lessons and Insights](#lessons-and-insights)
- [Quick Start](#quick-start)
- [Acknowledgements](#acknowledgements)
---
## SuperOffload Highlights
- Full-parameter fine-tuning of GPT-OSS-20B and Qwen3-14B on a **single GH200**, reaching 600 TFLOPS (SeqLen=4K, BS=4).
- **Multi-GPU training:** Qwen3-30B-A3B and Seed-OSS-36B on two NVIDIA GH200s; Llama-70B on four NVIDIA GH200s.
- **Training speed:** up to 4x the training throughput of ZeRO-Offload under comparable settings.
- **Higher GPU utilization:** raises GPU utilization from roughly 50% to more than 80%.
- **Composability:** works with ZeRO-3 and Ulysses; practical techniques such as NUMA binding and MPAM are covered in detail in the tutorial.
---
## Introduction
The emergence of tightly coupled heterogeneous GPU/CPU architectures, also known as superchips (e.g., NVIDIA GH200, GB200, and AMD MI300A), creates new optimization opportunities for large-scale AI. However, how to fully exploit this new hardware for large-scale LLM training remains underexplored. Existing offloading solutions were designed for traditional, loosely coupled architectures and perform poorly on superchips, suffering from high overhead and low GPU utilization. To close this gap and enable efficient LLM training on superchips, we developed and open-sourced **SuperOffload**.
SuperOffload introduces a set of techniques that exploit the Hopper GPU, the Grace CPU, and NVLink-C2C simultaneously for LLM training. Unlike prior offloading solutions that assume a slow GPU-CPU interconnect (e.g., PCIe Gen4 at 64 GB/s), SuperOffload takes advantage of much faster interconnects such as NVLink-C2C (900 GB/s) to improve GPU and CPU utilization and training throughput. With SuperOffload, models such as **GPT-OSS-20B**, **Qwen3-14B**, and **Phi-4** can be fully fine-tuned on a single GH200, reaching up to **600 TFLOPS** of training throughput under common settings (sequence length 4K, batch size 4). This is up to **4x** higher throughput than prior work such as ZeRO-Offload. SuperOffload also scales to larger models, including Qwen3-30B-A3B and Seed-OSS-36B on two GH200s and Llama-70B on four GH200s.
SuperOffload is built on top of DeepSpeed ZeRO Stage 3 and is available in DeepSpeed [0.18.0](https://github.com/deepspeedai/DeepSpeed/releases/tag/v0.18.0) and later. To ease integration into LLM fine-tuning pipelines, SuperOffload is compatible with Hugging Face Transformers and requires no changes to model code.
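Because SuperOffload builds on ZeRO-3 and works with unmodified Hugging Face models, wiring it into an existing fine-tuning script looks like any other DeepSpeed setup. The sketch below uses only standard ZeRO-3 offloading config keys; the SuperOffload-specific switch is the single extra config line shown in Figure 6 and in the tutorial, and the model name is only an example:

```python
# Minimal sketch: an unmodified Hugging Face model driven by DeepSpeed ZeRO-3
# with CPU offloading. The SuperOffload switch itself is one extra line in this
# same config; see Figure 6 and the tutorial for the exact key.
import deepspeed
import torch
from transformers import AutoModelForCausalLM

ds_config = {
    "train_micro_batch_size_per_gpu": 4,
    "bf16": {"enabled": True},
    "optimizer": {"type": "AdamW", "params": {"lr": 2e-5}},
    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {"device": "cpu"},
        "offload_param": {"device": "cpu"},
    },
}

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-14B", torch_dtype=torch.bfloat16)
engine, optimizer, _, _ = deepspeed.initialize(model=model,
                                               model_parameters=model.parameters(),
                                               config=ds_config)
```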
<div align="center">
<img src="./images/superoffload_comparison.jpg" alt="SuperOffload system overview" width="90%">
<p align="center"><em>图1在不同序列长度和批次大小的大型模型微调中SuperOffload相比ZeRO-Offload可实现高达4倍的吞吐量提升最高达到600 TFLOPS的吞吐量。</em></p>
</div>
---
## How SuperOffload Works
SuperOffload consists of four composable offloading optimizations: (1) Speculation-Then-Validation, (2) heterogeneous GPU/CPU optimizer computation, (3) superchip-aware casting, and (4) the GraceAdam optimizer. We briefly describe each below.
### 1. Speculation-Then-Validation (STV)
In most offloading solutions, the optimizer step requires synchronization between the CPU and the GPU to guarantee numerical robustness. For example, gradient norm clipping requires computing the global gradient norm, and mixed-precision training requires global NaN/INF checks. These operations force the CPU to wait until it has received all gradients before running the optimizer step and updating the weights. STV avoids this bottleneck by breaking the dependency while preserving training semantics: speculative optimizer computation on the CPU is overlapped with the backward pass on the GPU. When gradient post-processing eventually completes, the speculative optimizer work is committed, discarded, or correctly replayed as appropriate. Because STV validates after the fact rather than before, it safely shortens the critical path compared with prior pre-validation approaches. The figure below shows how SuperOffload schedules the backward pass and optimizer computation differently from conventional approaches such as ZeRO-Offload.
<div align="center">
<img src="./images/superoffload_schedule.jpg" alt="Schedule comparison" width="80%">
<p align="center"><em>图2以往的offloading方法受限于全局梯度范数计算及全局NaN/INF值检查导致优化器步骤暴露在关键路径中且无法实现计算重叠。SuperOffload通过引入推测验证调度机制来解决这一问题。</em></p>
</div>
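To make the commit/discard/replay decision concrete, here is a minimal, self-contained sketch of the idea (not the DeepSpeed implementation; it uses a plain SGD-style update on a single CPU tensor purely for illustration):

```python
# Sketch of Speculation-Then-Validation (STV) on a single offloaded tensor.
# This is NOT DeepSpeed's implementation; it only illustrates how a speculative
# CPU update can later be committed, discarded, or replayed.
import torch

def speculative_step(param_cpu, grad_cpu, lr=1e-3, clip_norm=1.0):
    backup = param_cpu.detach().clone()   # enough state to undo the speculation

    # Speculate: update as soon as this gradient arrives on the CPU, without
    # waiting for the global gradient-norm / NaN checks that need *all* grads.
    param_cpu.sub_(lr * grad_cpu)

    # Validate: in the real schedule these checks finish only after the full
    # backward pass, overlapping with the speculative work above.
    grad_norm = grad_cpu.norm()
    if not torch.isfinite(grad_norm):
        param_cpu.copy_(backup)                                   # discard
    elif grad_norm > clip_norm:
        param_cpu.copy_(backup)                                   # replay, clipped
        param_cpu.sub_(lr * grad_cpu * (clip_norm / grad_norm))
    # else: the common case after warm-up -> commit (nothing to undo)

p = torch.randn(1024)          # FP32 master weights kept on the CPU
g = torch.randn(1024) * 0.01   # gradient streamed from the GPU
speculative_step(p, g)
```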
We evaluated the effectiveness of STV by measuring how often speculative optimizer computation is rolled back during BLOOM-176B pre-training. As shown below, such rollbacks (e.g., caused by gradient clipping) rarely occur after the warm-up phase, so the associated overhead is negligible over the full training run. This makes STV practical for accelerating large-scale training.
<div align="center">
<img src="./images/superoffload_rollback.jpg" alt="Gradient clipping data" width="80%">
<p align="center"><em>图3红色数据点表示BLOOM预训练过程中触发梯度裁剪的时刻——在预热阶段后极少出现这表明SuperOffload的STV机制有效消除了由梯度裁剪和NaN/INF检查引起的同步停顿。
</em></p>
</div>
---
### 2. Heterogeneous Optimizer Computation
SuperOffload improves optimizer efficiency beyond STV by partitioning the optimizer computation across the GPU and the CPU. The GPU handles the optimizer computation for gradients produced in the late stages of the backward pass, while the CPU handles the rest. This partitioning has several benefits: first, the GPU no longer sits idle waiting for the CPU to finish the optimizer computation; second, optimizer computation time shrinks because GPU and CPU compute resources are used simultaneously; and third, the parameters and gradients assigned to the GPU portion of the optimizer computation never need to be transferred between GPU and CPU.
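The sketch below illustrates the split. It is a simplification rather than DeepSpeed's actual scheduling: it uses a plain SGD-style update and a fixed fraction in place of SuperOffload's partitioning logic.

```python
# Simplified sketch of splitting the optimizer step between GPU and CPU.
# Parameters whose gradients are produced late in the backward pass are
# updated on the GPU; the rest follow the usual offloaded CPU path.
import torch

def heterogeneous_step(params_in_grad_arrival_order, gpu_fraction=0.2, lr=1e-3):
    n_total = len(params_in_grad_arrival_order)
    n_gpu = int(n_total * gpu_fraction)
    cpu_part = params_in_grad_arrival_order[:n_total - n_gpu]
    gpu_part = params_in_grad_arrival_order[n_total - n_gpu:]

    for p in gpu_part:
        # Updated directly on the GPU: no parameter/gradient transfer at all.
        p.data.add_(p.grad, alpha=-lr)

    for p in cpu_part:
        # Offloaded path: the gradient moves to the CPU where the FP32
        # optimizer state lives, and the updated value is copied back.
        g_cpu = p.grad.to("cpu", non_blocking=True)
        p_cpu = p.data.to("cpu").float()
        p_cpu.add_(g_cpu.float(), alpha=-lr)
        p.data.copy_(p_cpu.to(p.dtype), non_blocking=True)
```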
---
### 3. Superchip-Aware Casting
In mixed-precision training with offloading, tensor transfers between the GPU and the CPU require casting between the GPU's low-precision formats (e.g., BF16, FP16) and the CPU's high-precision format (i.e., FP32). To cope with the limited bandwidth of PCIe interconnects, prior offloading solutions transfer tensors in low precision and cast them on the GPU or the CPU as needed. On superchips this is no longer optimal: GPU compute throughput is roughly 100x that of the CPU, and the high-bandwidth interconnect (e.g., NVLink-C2C) makes the transfer cost negligible. As shown in Figure 4, the best strategy on GH200 is to cast tensors on the GPU and transfer them in high precision.
<div align="center">
<img src="./images/superoffload_cast_transfer.jpg" alt="Tensor casting optimization" width="80%">
<p align="center"><em>图4GH200在超级芯片上通过GPU进行张量高低精度转换并以高精度格式传输更为高效。</em></p>
</div>
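The difference between the two strategies comes down to a few lines of PyTorch (a sketch assuming a BF16 gradient on the GPU and FP32 optimizer state on the CPU):

```python
# Sketch of the two casting/transfer strategies for a BF16 gradient produced
# on the GPU that must reach FP32 optimizer state held on the CPU.
import torch

grad_bf16 = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)

# PCIe-era strategy: transfer in low precision, then upcast on the CPU.
grad_fp32_cpu = grad_bf16.to("cpu", non_blocking=True).float()

# Superchip-aware strategy (better on GH200): upcast on the GPU, where compute
# is ~100x faster, and let NVLink-C2C absorb the larger FP32 transfer.
grad_fp32_cpu = grad_bf16.float().to("cpu", non_blocking=True)
```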
---
### 4. GraceAdam: Efficient Adam for Grace CPUs
Existing offloading solutions for LLM training rely on CPU implementations of the popular Adam optimizer, such as PyTorch Adam and DeepSpeed CPU-Adam. However, these implementations are not well suited to superchips because they are not optimized for the Grace CPU architecture. To address this, we created GraceAdam, an efficient Adam optimizer implementation designed specifically for Grace CPUs. GraceAdam achieves high performance by exploiting features of the underlying ARM architecture: the Scalable Vector Extension (SVE), explicit memory-hierarchy management, and instruction-level parallelism. Figure 5 shows that on a GH200 superchip, GraceAdam is 3x faster than PyTorch Adam and 1.3x faster than CPU-Adam. To our knowledge, GraceAdam is the first open-source Adam optimizer implementation for Grace CPUs.
<div align="center">
<img src="./images/superoffload_grace_adam.png" alt="GraceAdam" width="80%">
<p align="center"><em>图5使用GraceAdam在GH200上实现高效Adam优化器计算。</em></p>
</div>
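For reference, the per-element update GraceAdam performs is standard Adam. The plain PyTorch sketch below only spells out that math; the actual GraceAdam kernel is an SVE-vectorized C++ implementation for the Grace CPU, not this Python code:

```python
# Reference sketch of the Adam update that GraceAdam implements. GraceAdam
# itself is SVE-vectorized C++ tuned for Grace CPUs; this is only the math.
import torch

def adam_step(p, g, m, v, step, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    m.mul_(beta1).add_(g, alpha=1 - beta1)          # first moment (momentum)
    v.mul_(beta2).addcmul_(g, g, value=1 - beta2)   # second moment (variance)
    m_hat = m / (1 - beta1 ** step)                 # bias correction
    v_hat = v / (1 - beta2 ** step)
    p.addcdiv_(m_hat, v_hat.sqrt().add_(eps), value=-lr)

p = torch.randn(1024)                    # FP32 master weights on the CPU
g = torch.randn(1024) * 0.01             # gradient received from the GPU
m, v = torch.zeros_like(p), torch.zeros_like(p)
adam_step(p, g, m, v, step=1)
```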
## Lessons and Insights
- **NUMA binding**
Pair each GPU with its directly attached CPU to maximize bandwidth. In DeepSpeed:
```bash
--bind_cores_to_rank
```
- **MPAM (Memory System Resource Partitioning and Monitoring)**
Reduces interference between CPU and GPU tasks.
**How to enable MPAM on NVIDIA superchips:**
1. Install the kernel provided by [NVIDIA NV-Kernels](https://github.com/NVIDIA/NV-Kernels/tree/24.04_linux-nvidia-adv-6.11).
2. Check MPAM support:
```bash
grep MPAM /boot/config-$(uname -r)
```
Expected output:
```
CONFIG_ARM64_MPAM=y
CONFIG_ACPI_MPAM=y
CONFIG_ARM64_MPAM_DRIVER=y
CONFIG_ARM64_MPAM_RESCTRL_FS=y
```
Check the resctrl filesystem:
```bash
ls -ld /sys/fs/resctrl
```
3. Mount resctrl:
```bash
mount -t resctrl resctrl /sys/fs/resctrl
```
4. Create partitions:
```bash
mkdir /sys/fs/resctrl/p1 /sys/fs/resctrl/p2
```
5. Configure CPU cores and memory/cache allocation:
```bash
/sys/fs/resctrl/p1/cpus_list:
0-6
/sys/fs/resctrl/p2/cpus_list:
7-71
/sys/fs/resctrl/p1/schemata:
MB:1=100
L3:1=ff0
/sys/fs/resctrl/p2/schemata:
MB:1=20
L3:1=f
```
---
## Quick Start
We provide an end-to-end fine-tuning example for SuperOffload in the tutorial [DeepSpeedExamples: SuperOffload](https://github.com/deepspeedai/DeepSpeedExamples/tree/master/training/DeepSpeed-SuperOffload#readme). Add the following switch to your DeepSpeed configuration (see the tutorial for the full context):
<div align="center">
<img src="./images/superoffload_enable.jpg" alt="Enable SuperOffload" width="60%">
<p align="center"><em>图6通过在DeepSpeed配置中添加单行代码即可启用SuperOffload。</em></p>
</div>
Tip: On superchip platforms (e.g., GH200/GB200/MI300A), combining the NUMA binding and MPAM settings from the "Lessons and Insights" section helps stabilize bandwidth and improves end-to-end performance.
---
## Acknowledgements
This work is the result of a close collaboration between the [University of Illinois Urbana-Champaign (UIUC)](https://supercomputing-system-ai-lab.github.io/), [Anyscale](https://www.anyscale.com/), and [Snowflake](https://www.snowflake.com/en/blog/authors/snowflake-ai-research/).
We also sincerely thank William Gropp, Brett Bode, and Gregory H. Bauer of the National Center for Supercomputing Applications, and Dan Ernst, Ian Karlin, Giridhar Chukkapalli, Kurt Rago, and others at NVIDIA for valuable discussions and guidance on MPAM support for the Grace CPU.
Community feedback and contributions are welcome. See the "Quick Start" section above for how to enable SuperOffload and for examples.
---
## BibTeX <!-- omit in toc -->
```bibtex
@inproceedings{superoffload,
author = {Xinyu Lian and Masahiro Tanaka and Olatunji Ruwase and Minjia Zhang},
title = "{SuperOffload: Unleashing the Power of Large-Scale LLM Training on Superchips}",
year = {2026},
booktitle = {Proceedings of the 31st ACM International Conference on Architectural Support for Programming Languages and Operating Systems (ASPLOS'26)}
}
```

View File

@ -16,6 +16,7 @@ try:
import torch._dynamo
from functorch.compile import make_boxed_func
from torch._functorch.aot_autograd import aot_module_simplified
from torch._functorch.partitioners import min_cut_rematerialization_partition
from torch._subclasses.fake_tensor import unset_fake_temporarily
from torch._subclasses.fake_tensor import is_fake
except ImportError:
@ -367,17 +368,16 @@ def make_backend(backend, compile_config, compile_kwargs={}):
return compiler_fn
partition_fn = get_wrapped_partitioner(z3_partition, param_indices, min_cut_rematerialization_partition)
aot_mod = aot_module_simplified(gm,
real_inputs,
fw_compiler=make_compiler_fn(make_fw_graph),
bw_compiler=make_compiler_fn(make_bw_graph),
partition_fn=get_wrapped_partitioner(param_indices))
partition_fn=partition_fn)
return torch._dynamo.optimize(**compile_kwargs)(aot_mod)
elif backend == "inductor":
patch_create_aot_dispatcher_function(graph_id, z3_partition, make_fw_graph, make_bw_graph, real_inputs,
param_indices, param_manager)
from .partitioner import get_wrapped_choose_saved_values_set
torch._functorch.partitioners.choose_saved_values_set = get_wrapped_choose_saved_values_set(param_indices)
return torch._inductor.compile(gm, real_inputs)

View File

@ -20,6 +20,7 @@ except ImportError:
from deepspeed.utils.torch import required_torch_version
from .util import get_input_nodes
from .graph_param import DSGraphParamManager
from .partitioner import get_wrapped_partitioner
def patch_compiler(original_compiler, dc_compiler, z3_partition: bool, graph_id, graph_param_manager, bwd: bool):
@ -66,7 +67,8 @@ def wrap_partition_fn(partition_fn, real_inputs, param_indices):
def wrapped_partition_fn(*args, **kwargs):
fw_module, bw_module = partition_fn(*args, **kwargs)
fn = get_wrapped_partitioner(True, param_indices, partition_fn=partition_fn)
fw_module, bw_module = fn(*args, **kwargs)
# get parameter names
pm = DSGraphParamManager(fw_module.graph, real_inputs, param_indices)

View File

@ -3,156 +3,94 @@
# DeepSpeed Team
# This file was copied from PyTorch and modified for DeepSpeed.
from typing import Tuple, List
import operator
import torch
from torch.fx import GraphModule, Graph, Node
try:
from torch._functorch.partitioners import is_sym_node, _is_primal, _is_fwd_seed_offset, _extract_fwd_bwd_outputs, _extract_graph_with_inputs_outputs, _extract_fwd_bwd_modules, has_recomputable_ops, min_cut_rematerialization_partition, choose_saved_values_set
from torch.utils.checkpoint import CheckpointPolicy
from torch._functorch.partitioners import _is_primal
except ImportError:
pass
from .util import get_no_copy_ops
_recompute_ops = {torch.ops.aten.t.default}
from .util import get_no_copy_ops, is_cast_op
def _find_recompute_nodes(graph: Graph, ds_param_node: Node) -> List[Node]:
"""
Given a graph and a node that represents a parameter that was allgathered,
find all nodes that use the parameter and require recomputation.
def _recompute_param_aliases(joint_graph: Graph, param_indices: List[Tuple[int, int, torch.Size]]):
"""Recompute nodes aliasing or downcasting any parameter
In ZeRO3, sharded parameters are gathered before use and the gathered
parameters should be freed once they are no longer needed to save GPU
memory.
When DeepCompile is active for ZeRO3, parameter gathering is done by custom
passes after the joint graph captured by Dynamo and AOT Autograd is
partitioned into fwd and bwd parts. Since the partitioner has no clue about
parameter sharding now, the partitioned graphs will save for backward all
intermediate activations including those aliasing the gathered parameters.
That essentially nullifies the memory reduction that ZeRO3 is designed to
bring.
The solution is to recompute the parameter-aliasing activations in the
backward. It is done by marking such nodes as MUST_RECOMPUTE and reusing the
min-cut partitioner originally designed for checkpointing. If autocast is
enabled, parameter downcasts are also recomputed.
This cannot be converted to a standalone pass because it must be applied
before partitioning the joint graph, but passes run after the partitioning.
TODO(eternalNight) `min_cut_rematerialization_partition` may recompute more
nodes than required for ZeRO3. Need investigate its performance
implications.
"""
no_copy_ops = get_no_copy_ops()
recompute_nodes = set()
for node in graph.nodes:
if node.target in no_copy_ops:
if ds_param_node in node.args:
recompute_nodes.add(node)
if any(a in recompute_nodes for a in node.args):
recompute_nodes.add(node)
return recompute_nodes
def need_recompute(n: Node) -> bool:
if n.op == "call_function":
is_cast, _ = is_cast_op(n)
return n.target in no_copy_ops or is_cast
return False
def _get_values_from_ds_params(joint_graph, param_indices):
primal_inputs = list(filter(_is_primal, joint_graph.nodes))
ds_param_inputs = [primal_inputs[arg_idx] for arg_idx, _, _ in param_indices]
no_copy_ops = get_no_copy_ops()
ds_param_inputs = set(ds_param_inputs)
ds_param_users = {}
ds_param_inputs = set([primal_inputs[arg_idx] for arg_idx, _, _ in param_indices])
recomputed_nodes = set()
for node in joint_graph.nodes:
if node.target in no_copy_ops and any((a in ds_param_inputs or a in ds_param_users) for a in node.args):
for a in node.args:
if a in ds_param_inputs:
ds_param_users[node] = a
elif a in ds_param_users:
ds_param_users[node] = ds_param_users[a]
# The `ac_graph_id` tag tracks the checkpoint module that a node belongs
# to, and is for enforcing the saving of activations at the boundary of
# consecutive checkpointed blocks. It starts from 1 and increments by 1
# each time a graph module is checkpointed.
#
# `min_cut_rematerialization_partition` requires every node to have
# `ac_graph_id`. If this graph is not checkpointed (and thus
# `ac_graph_id` is missing), we tag all nodes to 1 to prevent the
# partition function from modifying the recompute tag.
node.meta.setdefault("ac_graph_id", 1)
return ds_param_users
# Arguments can be non-tensor types some of which are not hashable. So
# we must inspect the type of an argument before checking if it is in
# any set.
if need_recompute(node) and \
any([(isinstance(a, Node) and (a in ds_param_inputs or a in recomputed_nodes)) for a in node.args]):
node.meta["recompute"] = CheckpointPolicy.MUST_RECOMPUTE
recomputed_nodes.add(node)
else:
# If checkpointing is not enabled for this graph, assume all
# activations required by the backward pass should be saved.
node.meta.setdefault("recompute", CheckpointPolicy.MUST_SAVE)
def get_wrapped_choose_saved_values_set(param_indices: List[Tuple[int, int, torch.Size]]):
def ds_choose_saved_values_set(joint_graph: torch.fx.Graph, node_info, memory_budget=1) -> List[Node]:
saved_values = choose_saved_values_set(joint_graph, node_info, memory_budget)
ds_param_users = _get_values_from_ds_params(joint_graph, param_indices)
new_saved_values = []
for v in saved_values:
if v in ds_param_users:
ds_val = ds_param_users[v]
if ds_val not in new_saved_values:
new_saved_values.append(ds_val)
else:
new_saved_values.append(v)
return new_saved_values
return ds_choose_saved_values_set
def get_wrapped_partitioner(param_indices: List[Tuple[int, int, torch.Size]]):
def get_wrapped_partitioner(
z3_partition: bool,
param_indices: List[Tuple[int, int, torch.Size]],
partition_fn,
):
def partition_recompute_ds_params(joint_module: GraphModule, _joint_inputs, *,
num_fwd_outputs) -> Tuple[GraphModule, GraphModule]:
"""
This is basically the same as the default_partition function, but
it doesn't save the gathered params and values computed from them.
"""
if has_recomputable_ops(joint_module):
return min_cut_rematerialization_partition(joint_module, _joint_inputs, num_fwd_outputs=num_fwd_outputs)
primal_inputs = list(filter(_is_primal, joint_module.graph.nodes))
fwd_seed_offset_inputs = list(filter(_is_fwd_seed_offset, joint_module.graph.nodes))
inputs = primal_inputs + fwd_seed_offset_inputs
fwd_outputs, bwd_outputs = _extract_fwd_bwd_outputs(joint_module, num_fwd_outputs=num_fwd_outputs)
forward_only_graph = _extract_graph_with_inputs_outputs(joint_module.graph, inputs, fwd_outputs, "forward")
forward_node_names = {node.name for node in forward_only_graph.nodes if node.op != "output"}
saved_values = []
saved_sym_nodes = []
fwd_inputs = list(filter(_is_primal, forward_only_graph.nodes))
ds_param_inputs = [fwd_inputs[arg_idx] for arg_idx, _, _ in param_indices]
ds_param_input_names = {node.name for node in ds_param_inputs}
ds_param_recompute_nodes = set()
for node in joint_module.graph.nodes:
if node.name not in forward_node_names:
continue
if is_sym_node(node):
# Symints must be kept separate from tensors so that PythonFunction only calls
# save_for_backward on tensors and stashes symints in autograd .ctx
saved_sym_nodes.append(node)
elif "tensor_meta" not in node.meta and node.op == "call_function":
# Since we can't save tuple of tensor values, we need to flatten out what we're saving
users = node.users
assert all(user.target == operator.getitem for user in users)
saved_values.extend(users)
else:
backward_usages = [n for n in node.users if n.name not in forward_node_names]
if "tensor_meta" in node.meta and all(is_sym_node(n) for n in backward_usages):
# If we have a tensor in the forward, where only its sizes/strides are needed in the backward,
# and not the actual tensor data,
# then it will be a lot cheaper to save only the sizes/strides, and not the actual tensor.
#
# Note that saving the tensor could also cause compilation problems:
# If the user mutated an input in the forward and uses its sizes/strides in the backward,
# then we would be obligated to clone the input before saving it to appease autograd.
# (This is how we originally found this bug).
saved_sym_nodes.extend(backward_usages)
if node.name in ds_param_input_names:
saved_values.append(node)
recompute_nodes = _find_recompute_nodes(joint_module.graph, node)
recompute_nodes = [n for n in recompute_nodes if n.name in forward_node_names]
for recompute_node in recompute_nodes:
ds_param_recompute_nodes.add(recompute_node)
if len(recompute_nodes) > 0:
saved_values.append(node)
else:
if node not in ds_param_recompute_nodes:
saved_values.append(node)
saved_values = list(dict.fromkeys(saved_values).keys())
saved_sym_nodes = list(dict.fromkeys(saved_sym_nodes).keys())
f_gm, b_gm = _extract_fwd_bwd_modules(
joint_module,
saved_values,
saved_sym_nodes=saved_sym_nodes,
num_fwd_outputs=num_fwd_outputs,
)
return f_gm, b_gm
if z3_partition:
_recompute_param_aliases(joint_module.graph, param_indices)
return partition_fn(joint_module, _joint_inputs, num_fwd_outputs=num_fwd_outputs)
return partition_recompute_ds_params

View File

@ -76,6 +76,7 @@ from deepspeed.runtime.sparse_tensor import SparseTensor
from deepspeed.runtime import lr_schedules
from deepspeed.utils import groups
from deepspeed.utils import logger, log_dist, log_dist_once, instrument_w_nvtx
from deepspeed.utils.z3_leaf_module import apply_zero_leaf_module_config
from deepspeed.utils.timer import NoopTimer, ThroughputTimer, SynchronizedWallClockTimer, \
FORWARD_MICRO_TIMER, BACKWARD_MICRO_TIMER, BACKWARD_INNER_MICRO_TIMER, BACKWARD_REDUCE_MICRO_TIMER, \
STEP_MICRO_TIMER, \
@ -1293,6 +1294,7 @@ class DeepSpeedEngine(Module):
def _configure_distributed_model(self, model):
self._set_client_model(model)
apply_zero_leaf_module_config(self.module, getattr(self._config.zero_config, "leaf_module", None))
is_zero_init_model = self.zero_optimization_partition_weights() and any(
[hasattr(param, "ds_id") for param in self.module.parameters()])

View File

@ -11,6 +11,7 @@ from deepspeed.runtime.config_utils import get_scalar_param, pp_int, DeepSpeedCo
from deepspeed.utils import logger
from .offload_config import DeepSpeedZeroOffloadParamConfig, DeepSpeedZeroOffloadOptimizerConfig, OffloadDeviceEnum
from deepspeed.runtime.zenflow.zenflow_config import ZenFlowConfig
from .leaf_module_config import DeepSpeedZeroLeafModuleConfig
# ZeRO optimization. By default, this optimization is not enabled.
# Users have to configure the desired optimization (0 means disabled) in params.json as below example:
@ -356,6 +357,11 @@ class DeepSpeedZeroConfig(DeepSpeedConfigModel):
Enable internal sanity checks, which could be useful for debugging
"""
leaf_module: DeepSpeedZeroLeafModuleConfig = Field(default_factory=DeepSpeedZeroLeafModuleConfig)
"""
Configuration for modules that should be treated as ZeRO3 leaf modules.
"""
# Validators
@model_validator(mode="after")
def overlap_comm_valid(self):

View File

@ -0,0 +1,52 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from typing import List
from pydantic import Field, model_validator
from deepspeed.runtime.config_utils import DeepSpeedConfigModel
DEFAULT_LEAF_MODULE_CLASSES: List[str] = [
"transformers.models.mixtral.modeling_mixtral.MixtralSparseMoeBlock",
"transformers.models.qwen2_moe.modeling_qwen2_moe.Qwen2MoeSparseMoeBlock",
"transformers.models.qwen3_moe.modeling_qwen3_moe.Qwen3MoeSparseMoeBlock",
]
DEFAULT_LEAF_MODULE_NAMES: List[str] = []
DEFAULT_LEAF_MODULE_NAME_SUFFIXES: List[str] = []
class DeepSpeedZeroLeafModuleConfig(DeepSpeedConfigModel):
"""Configuration for ZeRO leaf modules that should bypass hook installation."""
classes: List[str] = Field(default_factory=lambda: list(DEFAULT_LEAF_MODULE_CLASSES))
names: List[str] = Field(default_factory=lambda: list(DEFAULT_LEAF_MODULE_NAMES))
name_suffixes: List[str] = Field(default_factory=lambda: list(DEFAULT_LEAF_MODULE_NAME_SUFFIXES))
@model_validator(mode="before")
def _coerce_container_types(cls, values):
if values is None:
return {}
if isinstance(values, dict):
coerced = dict(values)
for key in ("classes", "names", "name_suffixes"):
if key in coerced and isinstance(coerced[key], str):
coerced[key] = [coerced[key]]
return coerced
raise TypeError("leaf_module configuration must be a mapping of fields to values")
@model_validator(mode="after")
def _validate_entries(self):
normalized_classes = [str(cls) for cls in self.classes]
normalized_names = [str(name) for name in self.names]
normalized_suffixes = [str(suffix) for suffix in self.name_suffixes]
deduped_classes = list(dict.fromkeys(normalized_classes))
deduped_names = list(dict.fromkeys(normalized_names))
deduped_suffixes = list(dict.fromkeys(normalized_suffixes))
object.__setattr__(self, "classes", deduped_classes)
object.__setattr__(self, "names", deduped_names)
object.__setattr__(self, "name_suffixes", deduped_suffixes)
return self

View File

@ -17,7 +17,7 @@ from .tensor_fragment import safe_set_full_fp32_param, safe_set_full_optimizer_s
from .tensor_fragment import safe_get_local_fp32_param, safe_get_local_grad, safe_get_local_optimizer_state
from .tensor_fragment import safe_set_local_fp32_param, safe_set_local_grad, safe_set_local_optimizer_state
from .tensor_fragment import safe_update_full_grad_vectorized
from .z3_leaf_module import set_z3_leaf_modules, unset_z3_leaf_modules, get_z3_leaf_modules, z3_leaf_module, z3_leaf_parameter, set_z3_leaf_module
from .z3_leaf_module import set_z3_leaf_modules, unset_z3_leaf_modules, get_z3_leaf_modules, z3_leaf_module, z3_leaf_parameter, set_z3_leaf_module, set_z3_leaf_modules_by_name, set_z3_leaf_modules_by_suffix
from .mixed_precision_linkage import link_hp_params, lazy_init_hp_params_optimizer_state
from deepspeed.runtime.dataloader import RepeatingLoader
from .numa import get_numactl_cmd

View File

@ -83,19 +83,20 @@ def print_configuration(args, name):
logger.info(" {} {} {}".format(arg, dots, getattr(args, arg)))
def log_dist(message, ranks=None, level=logging.INFO, use_logger=True):
def get_dist_msg(message, ranks=None):
from deepspeed import comm as dist
"""Log message when one of following condition meets
"""Return a message with rank prefix when one of following conditions is met:
+ not dist.is_initialized()
+ dist.get_rank() in ranks if ranks is not None or ranks = [-1]
+ not dist.is_initialized()
+ dist.get_rank() in ranks if ranks is not None or ranks = [-1]
If neither is met, `None` is returned.
Example: "hello" => "[Rank 0] hello"
Args:
message (str)
ranks (list)
level (int)
use_logger (bool): if `False` ignores the log-levels and always prints
"""
should_log = not dist.is_initialized()
ranks = ranks or []
@ -104,11 +105,36 @@ def log_dist(message, ranks=None, level=logging.INFO, use_logger=True):
should_log = ranks[0] == -1
should_log = should_log or (my_rank in set(ranks))
if should_log:
final_message = "[Rank {}] {}".format(my_rank, message)
if use_logger:
logger.log(level, final_message)
else:
print(final_message)
return "[Rank {}] {}".format(my_rank, message)
else:
return None
def log_dist(message, ranks=None, level=logging.INFO):
"""Log message when get_dist_msg() deems it should be logged, see its docstring for details.
Args:
message (str)
ranks (list)
level (int)
"""
final_message = get_dist_msg(message, ranks)
if final_message is not None:
logger.log(level, final_message)
def print_dist(message, ranks=None):
"""print message when get_dist_msg() deems it should be logged, see its docstring for details.
Use this function instead of `log_dist` when the log level shouldn't impact whether the message should be printed or not.
Args:
message (str)
ranks (list)
"""
final_message = get_dist_msg(message, ranks)
if final_message is not None:
print(final_message)
@functools.lru_cache(None)

View File

@ -5,7 +5,7 @@
import time
from numpy import mean
from deepspeed.utils.logging import log_dist
from deepspeed.utils.logging import print_dist
from deepspeed.accelerator import get_accelerator
FORWARD_MICRO_TIMER = 'fwd_microstep'
@ -149,7 +149,7 @@ class SynchronizedWallClockTimer:
string += " | {}: {:.2f}".format(name, elapsed_time)
# timers logging should be independent of the global log level it's already conditional on wall_clock_breakdown being True, so using use_logger=False will always print the stats
log_dist(string, ranks=ranks or [0], use_logger=False)
print_dist(string, ranks=ranks or [0])
def get_mean(self, names, normalizer=1.0, reset=True):
"""Get the mean of a group of timers."""

View File

@ -4,7 +4,12 @@
# DeepSpeed Team
import torch
from typing import List, Type, Union
from typing import List, Tuple, Type, Union, Optional, TYPE_CHECKING
from .logging import logger
if TYPE_CHECKING:
from deepspeed.runtime.zero.leaf_module_config import DeepSpeedZeroLeafModuleConfig
def z3_leaf_module(model: torch.nn.Module) -> bool:
@ -44,50 +49,201 @@ def set_z3_leaf_module(model: torch.nn.Module, flag: bool):
model._z3_leaf = flag
def _do_set_z3_leaf_modules(model: torch.nn.Module, leaf_module_classes: Union[List[Type], List[str]],
flag: bool) -> List[torch.nn.Module]:
assert all(isinstance(module_class, (type, str) ) for module_class in leaf_module_classes), \
def _fully_qualified_class_name(module: torch.nn.Module) -> str:
cls = module.__class__
return f"{cls.__module__}.{cls.__qualname__}"
def _do_set_z3_leaf_modules(model: torch.nn.Module,
leaf_module_classes: Union[List[Type], List[str]],
flag: bool,
raise_if_not_found: bool = True) -> List[torch.nn.Module]:
assert all(isinstance(module_class, (type, str)) for module_class in leaf_module_classes), \
f'leaf_module_classes must be a list of types or names, got {leaf_module_classes}'
leaf_modules = []
leaf_modules: List[torch.nn.Module] = []
def _set_z3_leaf_flag(model: torch.nn.Module):
def _set_z3_leaf_flag(module_instance: torch.nn.Module):
nonlocal leaf_modules
for module in leaf_module_classes:
if (isinstance(module, type) and model.__class__ == module) or \
(isinstance(module, str) and model.__class__.__name__ == module):
model._z3_leaf = flag
leaf_modules.append(model)
if isinstance(module, type) and isinstance(module_instance, module):
module_instance._z3_leaf = flag
leaf_modules.append(module_instance)
break
if isinstance(module, str):
if (module_instance.__class__.__name__ == module
or _fully_qualified_class_name(module_instance) == module):
module_instance._z3_leaf = flag
leaf_modules.append(module_instance)
break
model.apply(_set_z3_leaf_flag)
if len(leaf_modules) == 0:
if len(leaf_modules) == 0 and raise_if_not_found:
raise ValueError(f'No modules of type {leaf_module_classes} found in model {model}')
return leaf_modules
def set_z3_leaf_modules(model: torch.nn.Module, leaf_module_classes: Union[List[Type],
List[str]]) -> List[torch.nn.Module]:
def set_z3_leaf_modules_by_name(model: torch.nn.Module,
module_names: List[str],
flag: bool = True,
raise_if_not_found: bool = True) -> Tuple[List[torch.nn.Module], List[str]]:
"""Sets a leaf flag for modules referenced by their names in ``model.named_modules()``.
Args:
model (torch.nn.Module): The model containing the modules to update.
module_names (List[str]): Module names as returned by ``named_modules()``.
flag (bool): Desired flag state.
raise_if_not_found (bool): Whether to raise when no module matches a provided name.
Returns:
Tuple[List[torch.nn.Module], List[str]]: Matched modules and missing module names.
"""
modules_by_name = dict(model.named_modules())
leaf_modules: List[torch.nn.Module] = []
missing: List[str] = []
for name in module_names:
module = modules_by_name.get(name)
if module is None:
missing.append(name)
continue
module._z3_leaf = flag
leaf_modules.append(module)
if missing and raise_if_not_found:
raise ValueError(f'No modules named {missing} found in model {model}')
return leaf_modules, missing
def set_z3_leaf_modules_by_suffix(model: torch.nn.Module,
module_name_suffixes: List[str],
flag: bool = True,
raise_if_not_found: bool = True) -> Tuple[List[torch.nn.Module], List[str]]:
"""Sets a leaf flag for modules referenced by suffixes of ``model.named_modules()`` names."""
modules_by_name = dict(model.named_modules())
leaf_modules: List[torch.nn.Module] = []
missing: List[str] = []
seen_ids = set()
for suffix in module_name_suffixes:
matched = False
for name, module in modules_by_name.items():
if name.endswith(suffix):
module._z3_leaf = flag
module_id = id(module)
if module_id not in seen_ids:
seen_ids.add(module_id)
leaf_modules.append(module)
matched = True
if not matched:
missing.append(suffix)
if missing and raise_if_not_found:
raise ValueError(f'No modules matching suffixes {missing} found in model {model}')
return leaf_modules, missing
def set_z3_leaf_modules(model: torch.nn.Module,
leaf_module_classes: Union[List[Type], List[str]],
raise_if_not_found: bool = True) -> List[torch.nn.Module]:
"""Sets a flag within a module in `model` to instruct ZeRO3 to stop setting hooks recursively when it encounters a module class listed in `leaf_module_classes`.
This is particularly useful in the context of Mixture of Experts (MoE) models. In MoE models, the computation order of experts varies across forward passes. This variability can disrupt ZeRO3's functionality, as ZeRO3 relies on tracking the computation order of modules to prefetch parameters efficiently. By designating a module as a 'leaf' node, ZeRO3 will prefetch parameters for all child modules upon entering the module.
Another scenario where this functionality is beneficial is in models with excessively fine-grained nested modules, where it helps to avoid the overhead associated with hooks.
Args:
model (torch.nn.Module): The model to which the leaf module flag will be applied.
leaf_module_classes (Union[List[Type], List[str]]): A list of module classes that should be flagged as 'leaf' modules.
raise_if_not_found (bool): Whether to raise a ``ValueError`` when none of the provided classes
match a module inside ``model``.
Returns:
List[torch.nn.Module]: A list of modules that match the module classes in `leaf_module_classes`.
"""
return _do_set_z3_leaf_modules(model, leaf_module_classes, True)
return _do_set_z3_leaf_modules(model, leaf_module_classes, True, raise_if_not_found)
def unset_z3_leaf_modules(model: torch.nn.Module, leaf_module_classes: List[Type]) -> List[torch.nn.Module]:
def unset_z3_leaf_modules(model: torch.nn.Module,
leaf_module_classes: List[Type],
raise_if_not_found: bool = True) -> List[torch.nn.Module]:
"""Unsets a flag within a module in `model` to instruct ZeRO3 to resume setting hooks recursively when it encounters a module class listed in `leaf_module_classes`.
See `set_z3_leaf_modules` for more details.
Args:
model (torch.nn.Module): The model to which the leaf module flag will be applied.
leaf_module_classes (Union[List[Type], List[str]]): A list of module classes that should be flagged as 'leaf' modules.
raise_if_not_found (bool): Whether to raise a ``ValueError`` when none of the provided classes
match a module inside ``model``.
Returns:
List[torch.nn.Module]: A list of modules that match the module classes in `leaf_module_classes`.
"""
return _do_set_z3_leaf_modules(model, leaf_module_classes, False)
return _do_set_z3_leaf_modules(model, leaf_module_classes, False, raise_if_not_found)
def apply_zero_leaf_module_config(model: torch.nn.Module,
leaf_cfg: Optional["DeepSpeedZeroLeafModuleConfig"]) -> List[torch.nn.Module]:
"""Apply ZeRO leaf module configuration to ``model``.
Args:
model (torch.nn.Module): Root module to update.
leaf_cfg (DeepSpeedZeroLeafModuleConfig | None): Parsed configuration. If ``None``
no changes are applied.
Returns:
List[torch.nn.Module]: Modules flagged as leaves.
"""
if leaf_cfg is None:
return []
from deepspeed.runtime.zero.leaf_module_config import (
DEFAULT_LEAF_MODULE_CLASSES,
DEFAULT_LEAF_MODULE_NAMES,
DEFAULT_LEAF_MODULE_NAME_SUFFIXES,
)
matched_modules: List[torch.nn.Module] = []
matched_ids = set()
customized_classes = leaf_cfg.classes != DEFAULT_LEAF_MODULE_CLASSES
customized_names = leaf_cfg.names != DEFAULT_LEAF_MODULE_NAMES
customized_suffixes = leaf_cfg.name_suffixes != DEFAULT_LEAF_MODULE_NAME_SUFFIXES
if leaf_cfg.classes:
class_matched = set_z3_leaf_modules(model, leaf_cfg.classes, raise_if_not_found=False)
for module in class_matched:
module_id = id(module)
if module_id not in matched_ids:
matched_ids.add(module_id)
matched_modules.append(module)
if leaf_cfg.names:
name_matched, missing_names = set_z3_leaf_modules_by_name(model,
leaf_cfg.names,
flag=True,
raise_if_not_found=False)
for module in name_matched:
module_id = id(module)
if module_id not in matched_ids:
matched_ids.add(module_id)
matched_modules.append(module)
if missing_names and customized_names:
logger.warning(f"ZeRO leaf module configuration contains unknown module names: {missing_names}")
if leaf_cfg.name_suffixes:
suffix_matched, missing_suffixes = set_z3_leaf_modules_by_suffix(model,
leaf_cfg.name_suffixes,
flag=True,
raise_if_not_found=False)
for module in suffix_matched:
module_id = id(module)
if module_id not in matched_ids:
matched_ids.add(module_id)
matched_modules.append(module)
if missing_suffixes and customized_suffixes:
logger.warning(f"ZeRO leaf module configuration contains unmatched module suffixes: {missing_suffixes}")
if not matched_modules and (customized_classes or customized_names or customized_suffixes):
logger.warning("ZeRO leaf module configuration did not match any modules; hooks will be applied as usual")
return matched_modules

View File

@ -73,6 +73,84 @@ Each configuration works as follows:
.. autofunction:: deepspeed.runtime.torch_autocast.has_autocast_dtype
Configuring ZeRO Leaf Modules
-----------------------------
ZeRO-3 relies on module execution order to gather partitioned parameters.
When models select submodules dynamically (for example, MoE routers), different data-parallel ranks may gather different sets of parameters, which can cause the all-gather collective to deadlock.
To avoid this problem, you can designate the parent of dynamically activated submodules (e.g., MoE experts) as a "leaf" module.
When a module is marked as a leaf, ZeRO gathers all of its descendants immediately and stops inserting hooks beneath it.
Programmatic API
================
Use :func:`deepspeed.utils.set_z3_leaf_modules` to flag modules by class, class
name, or both. Optionally combine with
:func:`deepspeed.utils.set_z3_leaf_modules_by_name` to target specific entries
from ``model.named_modules()`` or
:func:`deepspeed.utils.set_z3_leaf_modules_by_suffix` to match suffixes of those
names.
.. code-block:: python
from deepspeed.utils import (
set_z3_leaf_modules,
set_z3_leaf_modules_by_name,
set_z3_leaf_modules_by_suffix,
)
# Match by class or subclass
set_z3_leaf_modules(model, [CustomMoEBlock])
# Match by fully qualified class name
set_z3_leaf_modules(model, ["my_package.layers.CustomMoEBlock"])
# Match by module name returned from model.named_modules()
set_z3_leaf_modules_by_name(model, ["transformer.layers.0.experts"])
# Match by suffix of names returned from model.named_modules()
set_z3_leaf_modules_by_suffix(model, ["experts"])
Configuration in DeepSpeed config
=================================
The same behavior can be controlled from the DeepSpeed config. Add a
``leaf_module`` block to ``zero_optimization`` specifying either classes,
module names, or name suffixes (or any combination). By default DeepSpeed marks
several Hugging Face MoE blocks (including the Mixtral and Qwen MoE sparse blocks) so
that they behave well with ZeRO3.
The default class list currently contains:
* ``transformers.models.mixtral.modeling_mixtral.MixtralSparseMoeBlock``
* ``transformers.models.qwen2_moe.modeling_qwen2_moe.Qwen2MoeSparseMoeBlock``
* ``transformers.models.qwen3_moe.modeling_qwen3_moe.Qwen3MoeSparseMoeBlock``
.. code-block:: json
{
"train_micro_batch_size_per_gpu": 1,
"zero_optimization": {
"stage": 3,
"leaf_module": {
"classes": ["my_package.layers.CustomMoEBlock"],
"names": ["transformer.layers.0.experts"],
"name_suffixes": ["experts"]
}
}
}
``names`` must match exactly what ``model.named_modules()`` produces. The
``name_suffixes`` field compares each suffix against the end of those same
module paths, making it convenient to apply a rule across repeated structures.
Entries in ``classes`` may be either bare class names (for example,
``MixtralSparseMoeBlock``) or fully qualified dotted paths; both forms are
accepted.
You can mix and match the API and configuration approaches; all referenced
modules are flagged before ZeRO installs its hooks.
Model Saving
------------
.. autofunction:: deepspeed.DeepSpeedEngine.save_16bit_model

View File

@ -21,6 +21,8 @@ import deepspeed
from deepspeed.accelerator import get_accelerator
import deepspeed.comm as dist
from .util import torch_assert_close
import pytest
from _pytest.outcomes import Skipped
from _pytest.fixtures import FixtureLookupError, FixtureFunctionMarker
@ -562,6 +564,8 @@ def enable_determinism(seed: int):
def reduce_boolean_flags(flag: bool, op=all) -> bool:
if not dist.is_initialized():
return flag
device = get_accelerator().current_device()
tensor_flag = torch.tensor(1 if flag else 0, dtype=torch.int, device=device)
world_size = dist.get_world_size()
@ -569,3 +573,24 @@ def reduce_boolean_flags(flag: bool, op=all) -> bool:
dist.all_gather_into_tensor(tensor_flag_buf, tensor_flag)
list_flags = [bool(f) for f in tensor_flag_buf.tolist()]
return op(list_flags)
def allclose_on_all_ranks(actual, expected, assert_message=None, **kwargs) -> None:
"""
Compare two tensors across all ranks.
We want to make sure that all ranks succeed or fail together.
"""
allclose_local = False
allclose_global = False
mismatch_msg = ""
try:
torch_assert_close(actual, expected, **kwargs)
allclose_local = True
allclose_global = reduce_boolean_flags(allclose_local, all)
except AssertionError:
allclose_global = reduce_boolean_flags(allclose_local, all)
mismatch_msg = f"Tensors are not close: {actual=}, {expected=} {kwargs=}"
if not allclose_global:
message = "Tensors are not close on all ranks." if assert_message is None else assert_message
raise AssertionError(f"{message} {mismatch_msg}")

View File

@ -11,7 +11,11 @@ from unit.common import DistributedTest, preferred_dtype
from unit.simple_model import random_dataloader
import deepspeed
from deepspeed.utils import set_z3_leaf_modules, unset_z3_leaf_modules, get_z3_leaf_modules, z3_leaf_module
from deepspeed.utils import set_z3_leaf_modules, unset_z3_leaf_modules, get_z3_leaf_modules, z3_leaf_module, \
set_z3_leaf_modules_by_name, set_z3_leaf_modules_by_suffix
from deepspeed.runtime.zero.config import DeepSpeedZeroConfig
from deepspeed.runtime.zero.leaf_module_config import (DEFAULT_LEAF_MODULE_CLASSES, DEFAULT_LEAF_MODULE_NAMES,
DEFAULT_LEAF_MODULE_NAME_SUFFIXES)
from deepspeed.accelerator import get_accelerator
from torch import nn
import time
@ -82,6 +86,142 @@ class FineGrainedBlock(nn.Module):
return x
class BaseLeafModule(nn.Module):
def __init__(self):
super(BaseLeafModule, self).__init__()
class SubLeafModule(BaseLeafModule):
def __init__(self, hidden_dim):
super(SubLeafModule, self).__init__()
self.proj = nn.Linear(hidden_dim, hidden_dim)
def forward(self, x):
return self.proj(x)
class WrapperLeafModule(nn.Module):
def __init__(self, hidden_dim):
super(WrapperLeafModule, self).__init__()
self.child = SubLeafModule(hidden_dim)
def forward(self, x):
return self.child(x)
def test_set_leaf_modules_with_fully_qualified_name():
hidden_dim = 16
model = WrapperLeafModule(hidden_dim)
fq_name = f"{SubLeafModule.__module__}.{SubLeafModule.__qualname__}"
matched = set_z3_leaf_modules(model, [fq_name])
assert len(matched) == 1
assert matched[0] is model.child
assert z3_leaf_module(model.child)
assert not z3_leaf_module(model)
def test_set_leaf_modules_no_raise_when_missing():
hidden_dim = 16
model = WrapperLeafModule(hidden_dim)
matched = set_z3_leaf_modules(model, ["NonExistentClass"], raise_if_not_found=False)
assert matched == []
assert not z3_leaf_module(model.child)
def test_set_leaf_modules_by_name():
hidden_dim = 16
model = WrapperLeafModule(hidden_dim)
matched, missing = set_z3_leaf_modules_by_name(model, ["child"])
assert matched == [model.child]
assert missing == []
assert z3_leaf_module(model.child)
def test_set_leaf_modules_by_name_missing():
hidden_dim = 16
model = WrapperLeafModule(hidden_dim)
matched, missing = set_z3_leaf_modules_by_name(model, ["missing"], raise_if_not_found=False)
assert matched == []
assert missing == ["missing"]
def test_set_leaf_modules_by_suffix():
hidden_dim = 16
model = WrapperLeafModule(hidden_dim)
matched, missing = set_z3_leaf_modules_by_suffix(model, ["child"])
assert missing == []
assert matched == [model.child]
assert z3_leaf_module(model.child)
def test_set_leaf_modules_by_suffix_missing():
hidden_dim = 16
model = WrapperLeafModule(hidden_dim)
matched, missing = set_z3_leaf_modules_by_suffix(model, ["missing"], raise_if_not_found=False)
assert matched == []
assert missing == ["missing"]
def test_zero_leaf_module_default_config():
config = DeepSpeedZeroConfig()
assert config.leaf_module.classes == DEFAULT_LEAF_MODULE_CLASSES
assert config.leaf_module.names == DEFAULT_LEAF_MODULE_NAMES
assert config.leaf_module.name_suffixes == DEFAULT_LEAF_MODULE_NAME_SUFFIXES
def test_zero_leaf_module_custom_config():
payload = {
"leaf_module": {
"classes": ["custom.module.CustomClass"],
"names": ["transformer.layer"],
"name_suffixes": ["experts"]
}
}
config = DeepSpeedZeroConfig(**payload)
assert config.leaf_module.classes == ["custom.module.CustomClass"]
assert config.leaf_module.names == ["transformer.layer"]
assert config.leaf_module.name_suffixes == ["experts"]
def test_zero_leaf_module_string_coercion():
payload = {"leaf_module": {"classes": "my.Class", "names": "submodule", "name_suffixes": "tail"}}
config = DeepSpeedZeroConfig(**payload)
assert config.leaf_module.classes == ["my.Class"]
assert config.leaf_module.names == ["submodule"]
assert config.leaf_module.name_suffixes == ["tail"]
@pytest.mark.skip(reason="Requires Hugging Face transformers; run manually when validating defaults.")
def test_default_leaf_module_classes_exist():
import importlib
from deepspeed.runtime.zero.leaf_module_config import DEFAULT_LEAF_MODULE_CLASSES
for cls_path in DEFAULT_LEAF_MODULE_CLASSES:
module_name, _, class_name = cls_path.rpartition('.')
module = importlib.import_module(module_name)
assert hasattr(module, class_name), f"Expected {class_name} in {module_name}"
class modelWithFineGrainedBlock(nn.Module):
def __init__(self, hidden_dim, num_block):
@ -123,10 +263,7 @@ class TestSetZ3LeafModule(DistributedTest):
world_size = 2
reuse_dist_env = True
def _test_set_z3_leaf_modules(self, cls, requires_grad):
hidden_dim = 128
# `stage3_max_reuse_distance` is set to 0 to cause an error if the module is not set as a leaf module
def _create_zero_config(self, hidden_dim, leaf_module=None):
config_dict = {
"train_micro_batch_size_per_gpu": 1,
"steps_per_print": 1,
@ -143,11 +280,20 @@ class TestSetZ3LeafModule(DistributedTest):
"stage3_max_reuse_distance": 0,
}
}
if leaf_module is not None:
config_dict["zero_optimization"]["leaf_module"] = leaf_module
if preferred_dtype() is torch.float16:
config_dict["fp16"] = {"enabled": True}
elif preferred_dtype() is torch.bfloat16:
config_dict["bf16"] = {"enabled": True}
return config_dict
def _test_set_z3_leaf_modules(self, cls, requires_grad):
hidden_dim = 128
config_dict = self._create_zero_config(hidden_dim)
model = cls(hidden_dim)
assert not z3_leaf_module(model)
@ -181,6 +327,17 @@ class TestSetZ3LeafModule(DistributedTest):
"Expected only one module to be unset as a leaf module"
assert len(get_z3_leaf_modules(model)) == 0, "Expected there is no leaf module"
def test_set_leaf_modules_with_subclass(self):
hidden_dim = 32
model = WrapperLeafModule(hidden_dim)
leaf_modules = set_z3_leaf_modules(model, [BaseLeafModule])
assert len(leaf_modules) == 1, "Expected the subclass instance to be marked as leaf"
assert leaf_modules[0] is model.child, "Expected the subclass instance to be returned"
assert z3_leaf_module(model.child), "Expected subclass instance flagged as leaf"
assert not z3_leaf_module(model), "Expected wrapper module to remain non-leaf"
def test_set_no_match_class(self):
hidden_dim = 128
model = ChooseModuleByCounter(hidden_dim)
@ -190,6 +347,25 @@ class TestSetZ3LeafModule(DistributedTest):
except ValueError as e:
pass
def test_leaf_module_enabled_via_config(self):
hidden_dim = 128
leaf_class_fqn = f"{ChooseModuleByCounter.__module__}.{ChooseModuleByCounter.__qualname__}"
config_dict = self._create_zero_config(hidden_dim,
leaf_module={
"classes": [leaf_class_fqn],
"name_suffixes": ["linears"]
})
model = ChooseModuleByCounter(hidden_dim)
assert not z3_leaf_module(model)
run_model(model, config_dict, hidden_dim, preferred_dtype(), True)
assert z3_leaf_module(model)
modules_by_name = dict(model.named_modules())
assert "linears" in modules_by_name
assert z3_leaf_module(modules_by_name["linears"])
@pytest.mark.parametrize("module_granularity_threshold", [0, 100, 12100, 10000000])
class TestZ3LeafOptimization(DistributedTest):

View File

@ -94,15 +94,15 @@ class no_child_process_in_deepspeed_io:
deepspeed.runtime.engine.DeepSpeedEngine.deepspeed_io = self.old_method
def torch_assert_equal(actual, expected, **kwargs):
def torch_assert_equal(actual, expected, **kwargs) -> None:
"""
Compare two tensors or non-tensor numbers for their equality.
Add msg=blah to add an additional comment to when assert fails.
"""
return torch.testing.assert_close(actual, expected, rtol=0.0, atol=0.0, **kwargs)
torch.testing.assert_close(actual, expected, rtol=0.0, atol=0.0, **kwargs)
def torch_assert_close(actual, expected, **kwargs):
def torch_assert_close(actual, expected, **kwargs) -> None:
"""
Compare two tensors or non-tensor numbers for their closeness.
@ -113,7 +113,7 @@ def torch_assert_close(actual, expected, **kwargs):
The check doesn't assert when `|a - b| <= (atol + rtol * |b|)`
"""
return torch.testing.assert_close(actual, expected, **kwargs)
torch.testing.assert_close(actual, expected, **kwargs)
def torch_assert_dicts_of_tensors_equal(actual, expected, **kwargs):

View File

@ -12,12 +12,14 @@ from deepspeed.accelerator import get_accelerator
from deepspeed.runtime.zero import GatheredParameters
from unit.simple_model import SimpleModel
from unit.common import enable_determinism
from unit.common import enable_determinism, allclose_on_all_ranks
@enable_determinism(123)
def compare_loss(self, config, dtype, iteration=5, hidden_dim_override=None):
hidden_dim = hidden_dim_override if hidden_dim_override is not None else 10
# the default tolerances of torch.testing.assert_close are too small
RTOL = 5e-1
ATOL = 1e-2
@ -56,7 +58,7 @@ def compare_loss(self, config, dtype, iteration=5, hidden_dim_override=None):
baseline_loss = baseline_engine(x, y)
target_loss = target_engine(x, y)
assert torch.allclose(baseline_loss, target_loss, rtol=RTOL, atol=ATOL)
allclose_on_all_ranks(baseline_loss, target_loss, "Loss values are not close.", rtol=RTOL, atol=ATOL)
baseline_engine.backward(baseline_loss)
target_engine.backward(target_loss)
@ -66,7 +68,7 @@ def compare_loss(self, config, dtype, iteration=5, hidden_dim_override=None):
with GatheredParameters(target_engine.parameters()):
for p1, p2 in zip(baseline_engine.parameters(), target_engine.parameters()):
assert torch.allclose(p1.to(dtype), p2, rtol=RTOL, atol=ATOL)
allclose_on_all_ranks(p1, p2, "Parameters are not equal.", rtol=RTOL, atol=ATOL)
baseline_engine.destroy()
target_engine.destroy()