mirror of https://github.com/deepspeedai/DeepSpeed.git
synced 2025-10-20 15:33:51 +08:00

Merge branch 'master' into gma/zenflow_binding_study

blogs/deepspeed-superoffload/README_cn.md (new file, 188 lines)
@@ -0,0 +1,188 @@
# SuperOffload: Unleashing the Potential of Large-Scale LLM Training on Superchips

**Full-parameter fine-tuning of GPT-OSS-20B and Qwen3-14B on a single NVIDIA GH200 superchip, and training of Llama3-70B on four NVIDIA GH200 superchips, with training throughput of up to 600 TFLOPS.**

**Authors**

[Xinyu Lian](https://xinyulian.tech/)<sup>1</sup>, [Masahiro Tanaka](https://tohtana.github.io/)<sup>2</sup>, [Olatunji Ruwase](https://www.snowflake.com/en/blog/authors/olatunji--tunji--ruwase/)<sup>3</sup>, [Minjia Zhang](https://minjiazhang.github.io/)<sup>1</sup>

<sup>1</sup>SSAIL Lab, University of Illinois Urbana-Champaign · <sup>2</sup>Anyscale · <sup>3</sup>Snowflake

---

## Table of Contents <!-- omit in toc -->
- [SuperOffload: Unleashing the Potential of Large-Scale LLM Training on Superchips](#superoffload-unleashing-the-potential-of-large-scale-llm-training-on-superchips)
- [Highlights of SuperOffload](#highlights-of-superoffload)
- [Introduction](#introduction)
- [How SuperOffload Works](#how-superoffload-works)
- [1. Speculation-Then-Validation (STV)](#1-speculation-then-validation-stv)
- [2. Heterogeneous Optimizer Computation](#2-heterogeneous-optimizer-computation)
- [3. Superchip-Aware Casting](#3-superchip-aware-casting)
- [4. GraceAdam: Improving Optimizer Efficiency](#4-graceadam-improving-optimizer-efficiency)
- [Lessons and Insights](#lessons-and-insights)
- [Quick Start Guide](#quick-start-guide)
- [Acknowledgements](#acknowledgements)

---

## Highlights of SuperOffload

- Full-parameter fine-tuning of GPT-OSS-20B and Qwen3-14B on **a single GH200**, reaching 600 TFLOPS (sequence length 4K, batch size 4).
- **Multi-GPU training**: Qwen3-30B-A3B and Seed-OSS-36B on two NVIDIA GH200s, and Llama-70B on four NVIDIA GH200s.
- **Training speed**: up to 4x the training throughput of ZeRO-Offload under comparable settings.
- **Higher GPU utilization**: raises GPU utilization from roughly 50% to over 80%.
- **Composability**: works with ZeRO-3 and Ulysses; operational tips such as NUMA binding and MPAM are covered in detail in the tutorial.

---
## Introduction

The emergence of tightly coupled heterogeneous GPU/CPU architectures (a.k.a. superchips), such as NVIDIA GH200, GB200, and AMD MI300A, creates new optimization opportunities for large-scale AI. However, how to fully exploit this new hardware for large-scale LLM training remains under-explored. Existing offloading solutions were designed for conventional, loosely coupled architectures and perform poorly on superchips, suffering from high overhead and low GPU utilization. To close this gap and make LLM training on superchips efficient, we developed and open-sourced **SuperOffload**.

SuperOffload introduces a set of techniques that exploit the Hopper GPU, the Grace CPU, and NVLink-C2C simultaneously for LLM training. Unlike prior offloading solutions that assume a slow GPU-CPU interconnect (e.g., 64 GB/s over PCIe Gen4), SuperOffload takes advantage of much faster interconnects (e.g., 900 GB/s over NVLink-C2C) to improve GPU and CPU utilization and training throughput. With SuperOffload, models such as **GPT-OSS-20B**, **Qwen3-14B**, and **Phi-4** can be fully fine-tuned on a single GH200, achieving up to **600 TFLOPS** of training throughput under modest settings (sequence length 4K, batch size 4). This is up to **4x** higher throughput than prior work such as ZeRO-Offload. SuperOffload also scales to larger models, including Qwen3-30B-A3B and Seed-OSS-36B on two GH200s and Llama-70B on four GH200s.

SuperOffload is built on top of DeepSpeed ZeRO Stage 3 and is available in DeepSpeed [0.18.0](https://github.com/deepspeedai/DeepSpeed/releases/tag/v0.18.0) and later. To ease integration into LLM fine-tuning pipelines, SuperOffload is compatible with Hugging Face Transformers and requires no changes to model code.

<div align="center">
<img src="./images/superoffload_comparison.jpg" alt="SuperOffload system overview" width="90%">
<p align="center"><em>Figure 1: For large-model fine-tuning across different sequence lengths and batch sizes, SuperOffload delivers up to 4x higher throughput than ZeRO-Offload, reaching up to 600 TFLOPS.</em></p>
</div>

---
## How SuperOffload Works

SuperOffload consists of four composable offloading optimizations: (1) speculation-then-validation, (2) GPU/CPU optimizer computation, (3) superchip-aware casting, and (4) the GraceAdam optimizer. We briefly describe each of them below.

### 1. Speculation-Then-Validation (STV)

In most offloading solutions, the optimizer step requires synchronization between the CPU and the GPU to ensure numerical robustness. For example, gradient norm clipping requires computing the global gradient norm, and mixed-precision training requires global NaN and INF checks. These operations force the CPU to wait until it has received all gradients before it can run the optimizer step and update the weights. STV removes this bottleneck by breaking the dependency while preserving training semantics: speculative optimizer computation on the CPU is overlapped with backward propagation on the GPU. When gradient post-processing eventually completes, the speculative optimizer work is committed, discarded, or correctly replayed as appropriate. Because STV validates training stability after the fact, it safely shortens the critical path compared with earlier validate-first approaches. The figure below shows how SuperOffload schedules backward propagation and optimizer computation differently from conventional approaches such as ZeRO-Offload.

<div align="center">
<img src="./images/superoffload_schedule.jpg" alt="Schedule comparison" width="80%">
<p align="center"><em>Figure 2: Previous offloading approaches are constrained by global gradient norm computation and global NaN/INF checks, which keep the optimizer step on the critical path and prevent overlap. SuperOffload resolves this with a speculation-then-validation schedule.</em></p>
</div>

We evaluated the effectiveness of STV by measuring how often speculative optimizer computation is undone during BLOOM-176B pre-training. As shown below, such rollbacks (e.g., due to gradient clipping) rarely occur after the warm-up phase, so the associated overhead is negligible over the full training run. This makes STV practical for accelerating large-scale training.

<div align="center">
<img src="./images/superoffload_rollback.jpg" alt="Gradient clipping data" width="80%">
<p align="center"><em>Figure 3: Red points mark the iterations of BLOOM pre-training in which gradient clipping was triggered; they are rare after the warm-up phase, showing that SuperOffload's STV effectively eliminates the synchronization stalls caused by gradient clipping and NaN/INF checks.</em></p>
</div>
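To make the schedule concrete, below is a minimal, self-contained sketch of the speculate-then-validate idea. It is illustrative only and is not the DeepSpeed implementation; it uses a plain SGD-style update and an in-memory backup, whereas the real system overlaps the speculative CPU updates with the GPU backward pass.

```python
import torch

def stv_optimizer_step(params, grads, lr=1e-3, clip_threshold=1.0):
    # Keep a backup so speculative updates can be rolled back if validation fails.
    backups = [p.detach().clone() for p in params]
    sq_norm = 0.0

    # Speculative phase: apply updates as gradients "arrive", before the
    # global gradient norm or NaN/INF status is known.
    for p, g in zip(params, grads):
        sq_norm += float(g.float().pow(2).sum())
        p.data.add_(g, alpha=-lr)

    # Validation phase: check the global norm once all gradients are in.
    global_norm = sq_norm ** 0.5
    finite = torch.isfinite(torch.tensor(global_norm)).item()
    if finite and global_norm <= clip_threshold:
        return global_norm                      # commit the speculative updates

    # Validation failed: roll back and replay with clipped (or skipped) gradients.
    scale = (clip_threshold / global_norm) if finite else 0.0
    for p, backup, g in zip(params, backups, grads):
        p.data.copy_(backup)
        p.data.add_(g, alpha=-lr * scale)
    return global_norm

# Example: large gradients exceed the clip threshold and trigger a rollback.
params = [torch.zeros(4), torch.zeros(4)]
grads = [torch.full((4,), 10.0), torch.full((4,), 10.0)]
stv_optimizer_step(params, grads)
```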
---

### 2. Heterogeneous Optimizer Computation

Beyond STV, SuperOffload improves optimizer efficiency by partitioning the optimizer computation across the GPU and the CPU. The GPU handles the optimizer computation for the gradients produced in the late stages of backward propagation, while the CPU handles the rest. This partitioning has several benefits: first, the GPU does not sit idle waiting for the CPU to finish the optimizer computation; second, using GPU and CPU compute simultaneously shortens the optimizer computation time; and third, the parameters and gradients involved in the GPU portion of the optimizer computation never have to be transferred between GPU and CPU.
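The sketch below illustrates one way such a split could be expressed. It is not DeepSpeed's implementation; the split ratio, device handling, and overlap with backward are simplified assumptions.

```python
import torch

def build_partitioned_optimizers(params_in_backward_order, gpu_fraction=0.2, lr=1e-4):
    """Split parameters so that those whose gradients arrive last (late in
    backward) are updated on the GPU, and the rest are updated on the CPU."""
    n_gpu = max(1, int(len(params_in_backward_order) * gpu_fraction))
    gpu_params = params_in_backward_order[-n_gpu:]   # stay on the GPU, never transferred
    cpu_params = params_in_backward_order[:-n_gpu]   # assumed already offloaded to the CPU

    gpu_opt = torch.optim.Adam(gpu_params, lr=lr) if gpu_params else None
    cpu_opt = torch.optim.Adam(cpu_params, lr=lr) if cpu_params else None
    return gpu_opt, cpu_opt

# Usage sketch: after backward, step the GPU optimizer immediately and let the
# CPU optimizer step overlap with other work before the next forward pass.
```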
---

### 3. Superchip-Aware Casting

In mixed-precision training with offloading, tensor transfers between the GPU and the CPU require casting between the GPU's low-precision formats (e.g., BF16, FP16) and the CPU's high-precision format (i.e., FP32). To cope with the limited bandwidth of PCIe interconnects, prior offloading solutions transfer tensors in low precision and cast them on the GPU or CPU as appropriate. On superchip architectures this is no longer optimal, because the GPU's compute throughput is roughly 100x that of the CPU while the high-bandwidth interconnect (e.g., NVLink-C2C) makes the transfer cost negligible. As shown in Figure 4, the best strategy on GH200 is to cast tensors on the GPU and transfer them in high precision.

<div align="center">
<img src="./images/superoffload_cast_transfer.jpg" alt="Tensor casting optimization" width="80%">
<p align="center"><em>Figure 4: On GH200, it is more efficient to cast tensors between low and high precision on the GPU and transfer them in high precision.</em></p>
</div>
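A quick way to check this trade-off on your own hardware is to time both strategies directly. The snippet below is a rough measurement sketch: it assumes a CUDA-capable superchip and ignores pinned memory and asynchronous copies, which a real implementation would use.

```python
import time
import torch

def time_strategy(fn, iters=10):
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters

# ~256 MB of BF16 gradients living on the GPU.
grad_bf16 = torch.randn(128 * 1024 * 1024, dtype=torch.bfloat16, device="cuda")

# Strategy A (prior offloading work): transfer in BF16, upcast on the CPU.
t_a = time_strategy(lambda: grad_bf16.to("cpu").float())

# Strategy B (superchip-aware): upcast on the GPU, transfer in FP32.
t_b = time_strategy(lambda: grad_bf16.float().to("cpu"))

print(f"transfer-then-cast: {t_a * 1e3:.1f} ms, cast-then-transfer: {t_b * 1e3:.1f} ms")
```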
---

### 4. GraceAdam: Improving Optimizer Efficiency

Existing offloading solutions for LLM training rely on CPU implementations of the popular Adam optimizer, such as PyTorch Adam and DeepSpeed CPU-Adam. These implementations are not a good fit for superchips because they are not optimized for the Grace CPU architecture. To address this, we created GraceAdam, an efficient Adam implementation designed for the Grace CPU. GraceAdam achieves high performance by exploiting features of the underlying ARM architecture, such as the Scalable Vector Extension (SVE), explicit memory-hierarchy management, and instruction-level parallelism. Figure 5 shows that on a GH200 superchip, GraceAdam is 3x faster than PyTorch Adam and 1.3x faster than CPU-Adam. To our knowledge, GraceAdam is the first open-source Adam optimizer implementation for the Grace CPU.

<div align="center">
<img src="./images/superoffload_grace_adam.png" alt="GraceAdam" width="80%">
<p align="center"><em>Figure 5: Efficient Adam optimizer computation on GH200 with GraceAdam.</em></p>
</div>
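For reference, the per-element update that GraceAdam vectorizes with SVE is the standard Adam recurrence. The plain NumPy version below shows the semantics only, not the optimized kernel.

```python
import numpy as np

def adam_update(p, g, m, v, step, lr=1e-4, beta1=0.9, beta2=0.999, eps=1e-8):
    m[:] = beta1 * m + (1.0 - beta1) * g            # first-moment estimate
    v[:] = beta2 * v + (1.0 - beta2) * g * g        # second-moment estimate
    m_hat = m / (1.0 - beta1 ** step)               # bias correction
    v_hat = v / (1.0 - beta2 ** step)
    p[:] = p - lr * m_hat / (np.sqrt(v_hat) + eps)  # parameter update
```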
## Lessons and Insights

- **NUMA binding:**
  Pair each GPU with its directly attached CPU cores to maximize bandwidth. In DeepSpeed:
  ```bash
  --bind_cores_to_rank
  ```
- **MPAM (Memory System Resource Partitioning and Monitoring):**
  Reduces interference between CPU and GPU workloads.

**How to enable MPAM on NVIDIA superchips**

1. Install the kernel provided by [NVIDIA NV-Kernels](https://github.com/NVIDIA/NV-Kernels/tree/24.04_linux-nvidia-adv-6.11).
2. Check MPAM support:
   ```bash
   grep MPAM /boot/config-$(uname -r)
   ```
   Expected output:
   ```
   CONFIG_ARM64_MPAM=y
   CONFIG_ACPI_MPAM=y
   CONFIG_ARM64_MPAM_DRIVER=y
   CONFIG_ARM64_MPAM_RESCTRL_FS=y
   ```
   Check the resctrl filesystem:
   ```bash
   ls -ld /sys/fs/resctrl
   ```
3. Mount resctrl:
   ```bash
   mount -t resctrl resctrl /sys/fs/resctrl
   ```
4. Create partitions:
   ```bash
   mkdir /sys/fs/resctrl/p1 /sys/fs/resctrl/p2
   ```
5. Assign CPU cores and memory settings:
   ```bash
   /sys/fs/resctrl/p1/cpus_list:
   0-6
   /sys/fs/resctrl/p2/cpus_list:
   7-71
   /sys/fs/resctrl/p1/schemata:
   MB:1=100
   L3:1=ff0
   /sys/fs/resctrl/p2/schemata:
   MB:1=20
   L3:1=f
   ```
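Step 5 can be applied with plain shell redirection. The commands below mirror the values listed above and are only a sketch; adjust the core ranges and the bandwidth/cache masks to your own topology.

```bash
# Partition p1: a small set of cores with full memory bandwidth and most of L3.
echo "0-6" > /sys/fs/resctrl/p1/cpus_list
echo "MB:1=100" > /sys/fs/resctrl/p1/schemata
echo "L3:1=ff0" > /sys/fs/resctrl/p1/schemata

# Partition p2: the remaining cores, capped to a small bandwidth and L3 share.
echo "7-71" > /sys/fs/resctrl/p2/cpus_list
echo "MB:1=20" > /sys/fs/resctrl/p2/schemata
echo "L3:1=f" > /sys/fs/resctrl/p2/schemata
```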
---

## Quick Start Guide

We provide an end-to-end fine-tuning example with SuperOffload in the tutorial [DeepSpeedExamples: SuperOffload](https://github.com/deepspeedai/DeepSpeedExamples/tree/master/training/DeepSpeed-SuperOffload#readme). Add the following switch to your DeepSpeed config (see the tutorial for the full context):

<div align="center">
<img src="./images/superoffload_enable.jpg" alt="Enable SuperOffload" width="60%">
<p align="center"><em>Figure 6: SuperOffload is enabled with a single line in the DeepSpeed config.</em></p>
</div>

Tip: On superchip platforms (e.g., GH200/GB200/MI300A), combining the NUMA binding and MPAM settings from the "Lessons and Insights" section stabilizes bandwidth and improves end-to-end performance.
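For example, a single-node launch that applies the core binding could look like the following; the script name and config file are placeholders (assuming the training script accepts a `--deepspeed_config` argument), and the SuperOffload switch itself goes into the DeepSpeed config as shown in the tutorial.

```bash
deepspeed --bind_cores_to_rank finetune.py --deepspeed_config ds_config.json
```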
---

## Acknowledgements

This work is the result of a close collaboration between the [University of Illinois Urbana-Champaign (UIUC)](https://supercomputing-system-ai-lab.github.io/), [Anyscale](https://www.anyscale.com/), and [Snowflake](https://www.snowflake.com/en/blog/authors/snowflake-ai-research/).

We also sincerely thank William Gropp, Brett Bode, and Gregory H. Bauer from the National Center for Supercomputing Applications, and Dan Ernst, Ian Karlin, Giridhar Chukkapalli, Kurt Rago, and others from NVIDIA for valuable discussions and guidance on MPAM support for the Grace CPU.

Community feedback and contributions are welcome. See the Quick Start Guide above for how to enable SuperOffload, along with examples.

---

## BibTeX <!-- omit in toc -->

```bibtex
@inproceedings{superoffload,
  author    = {Xinyu Lian and Masahiro Tanaka and Olatunji Ruwase and Minjia Zhang},
  title     = "{SuperOffload: Unleashing the Power of Large-Scale LLM Training on Superchips}",
  year      = {2026},
  booktitle = {Proceedings of the 31st ACM International Conference on Architectural Support for Programming Languages and Operating Systems (ASPLOS'26)}
}
```
@ -16,6 +16,7 @@ try:
|
||||
import torch._dynamo
|
||||
from functorch.compile import make_boxed_func
|
||||
from torch._functorch.aot_autograd import aot_module_simplified
|
||||
from torch._functorch.partitioners import min_cut_rematerialization_partition
|
||||
from torch._subclasses.fake_tensor import unset_fake_temporarily
|
||||
from torch._subclasses.fake_tensor import is_fake
|
||||
except ImportError:
|
||||
@ -367,17 +368,16 @@ def make_backend(backend, compile_config, compile_kwargs={}):
|
||||
|
||||
return compiler_fn
|
||||
|
||||
partition_fn = get_wrapped_partitioner(z3_partition, param_indices, min_cut_rematerialization_partition)
|
||||
aot_mod = aot_module_simplified(gm,
|
||||
real_inputs,
|
||||
fw_compiler=make_compiler_fn(make_fw_graph),
|
||||
bw_compiler=make_compiler_fn(make_bw_graph),
|
||||
partition_fn=get_wrapped_partitioner(param_indices))
|
||||
partition_fn=partition_fn)
|
||||
return torch._dynamo.optimize(**compile_kwargs)(aot_mod)
|
||||
elif backend == "inductor":
|
||||
patch_create_aot_dispatcher_function(graph_id, z3_partition, make_fw_graph, make_bw_graph, real_inputs,
|
||||
param_indices, param_manager)
|
||||
from .partitioner import get_wrapped_choose_saved_values_set
|
||||
torch._functorch.partitioners.choose_saved_values_set = get_wrapped_choose_saved_values_set(param_indices)
|
||||
|
||||
return torch._inductor.compile(gm, real_inputs)
|
||||
|
||||
|
@ -20,6 +20,7 @@ except ImportError:
|
||||
from deepspeed.utils.torch import required_torch_version
|
||||
from .util import get_input_nodes
|
||||
from .graph_param import DSGraphParamManager
|
||||
from .partitioner import get_wrapped_partitioner
|
||||
|
||||
|
||||
def patch_compiler(original_compiler, dc_compiler, z3_partition: bool, graph_id, graph_param_manager, bwd: bool):
|
||||
@ -66,7 +67,8 @@ def wrap_partition_fn(partition_fn, real_inputs, param_indices):
|
||||
|
||||
def wrapped_partition_fn(*args, **kwargs):
|
||||
|
||||
fw_module, bw_module = partition_fn(*args, **kwargs)
|
||||
fn = get_wrapped_partitioner(True, param_indices, partition_fn=partition_fn)
|
||||
fw_module, bw_module = fn(*args, **kwargs)
|
||||
|
||||
# get parameter names
|
||||
pm = DSGraphParamManager(fw_module.graph, real_inputs, param_indices)
|
||||
|
@ -3,156 +3,94 @@
|
||||
|
||||
# DeepSpeed Team
|
||||
|
||||
# This file was copied from PyTorch and modified for DeepSpeed.
|
||||
|
||||
from typing import Tuple, List
|
||||
import operator
|
||||
|
||||
import torch
|
||||
from torch.fx import GraphModule, Graph, Node
|
||||
|
||||
try:
|
||||
from torch._functorch.partitioners import is_sym_node, _is_primal, _is_fwd_seed_offset, _extract_fwd_bwd_outputs, _extract_graph_with_inputs_outputs, _extract_fwd_bwd_modules, has_recomputable_ops, min_cut_rematerialization_partition, choose_saved_values_set
|
||||
from torch.utils.checkpoint import CheckpointPolicy
|
||||
from torch._functorch.partitioners import _is_primal
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
from .util import get_no_copy_ops
|
||||
|
||||
_recompute_ops = {torch.ops.aten.t.default}
|
||||
from .util import get_no_copy_ops, is_cast_op
|
||||
|
||||
|
||||
def _find_recompute_nodes(graph: Graph, ds_param_node: Node) -> List[Node]:
|
||||
"""
|
||||
Given a graph and a node that represents a parameter that was allgathered,
|
||||
find all nodes that use the parameter and require recomputation.
|
||||
def _recompute_param_aliases(joint_graph: Graph, param_indices: List[Tuple[int, int, torch.Size]]):
|
||||
"""Recompute nodes aliasing or downcasting any parameter
|
||||
|
||||
In ZeRO3, sharded parameters are gathered before use and the gathered
|
||||
parameters should be freed once they are no longer needed to save GPU
|
||||
memory.
|
||||
|
||||
When DeepCompile is active for ZeRO3, parameter gathering is done by custom
|
||||
passes after the joint graph captured by Dynamo and AOT Autograd is
|
||||
partitioned into fwd and bwd parts. Since the partitioner has no clue about
|
||||
parameter sharding now, the partitioned graphs will save for backward all
|
||||
intermediate activations including those aliasing the gathered parameters.
|
||||
That essentially nullifies the memory reduction that ZeRO3 is designed to
|
||||
bring.
|
||||
|
||||
The solution is to recompute the parameter-aliasing activations in the
|
||||
backward. It is done by marking such nodes as MUST_RECOMPUTE and reusing the
|
||||
min-cut partitioner originally designed for checkpointing. If autocast is
|
||||
enabled, parameter downcasts are also recomputed.
|
||||
|
||||
This cannot be converted to a standalone pass because it must be applied
|
||||
before partitioning the joint graph, but passes run after the partitioning.
|
||||
|
||||
TODO(eternalNight) `min_cut_rematerialization_partition` may recompute more
|
||||
nodes than required for ZeRO3. Need investigate its performance
|
||||
implications.
|
||||
"""
|
||||
no_copy_ops = get_no_copy_ops()
|
||||
recompute_nodes = set()
|
||||
for node in graph.nodes:
|
||||
if node.target in no_copy_ops:
|
||||
if ds_param_node in node.args:
|
||||
recompute_nodes.add(node)
|
||||
if any(a in recompute_nodes for a in node.args):
|
||||
recompute_nodes.add(node)
|
||||
|
||||
return recompute_nodes
|
||||
def need_recompute(n: Node) -> bool:
|
||||
if n.op == "call_function":
|
||||
is_cast, _ = is_cast_op(n)
|
||||
return n.target in no_copy_ops or is_cast
|
||||
return False
|
||||
|
||||
|
||||
def _get_values_from_ds_params(joint_graph, param_indices):
|
||||
primal_inputs = list(filter(_is_primal, joint_graph.nodes))
|
||||
ds_param_inputs = [primal_inputs[arg_idx] for arg_idx, _, _ in param_indices]
|
||||
|
||||
no_copy_ops = get_no_copy_ops()
|
||||
|
||||
ds_param_inputs = set(ds_param_inputs)
|
||||
ds_param_users = {}
|
||||
ds_param_inputs = set([primal_inputs[arg_idx] for arg_idx, _, _ in param_indices])
|
||||
recomputed_nodes = set()
|
||||
|
||||
for node in joint_graph.nodes:
|
||||
if node.target in no_copy_ops and any((a in ds_param_inputs or a in ds_param_users) for a in node.args):
|
||||
for a in node.args:
|
||||
if a in ds_param_inputs:
|
||||
ds_param_users[node] = a
|
||||
elif a in ds_param_users:
|
||||
ds_param_users[node] = ds_param_users[a]
|
||||
# The `ac_graph_id` tag tracks the checkpoint module that a node belongs
|
||||
# to, and is for enforcing the saving of activations at the boundary of
|
||||
# consecutive checkpointed blocks. It starts from 1 and increments by 1
|
||||
# each time a graph module is checkpointed.
|
||||
#
|
||||
# `min_cut_rematerialization_partition` requires every node to have
|
||||
# `ac_graph_id`. If this graph is not checkpointed (and thus
|
||||
# `ac_graph_id` is missing), we tag all nodes to 1 to prevent the
|
||||
# partition function from modifying the recompute tag.
|
||||
node.meta.setdefault("ac_graph_id", 1)
|
||||
|
||||
return ds_param_users
|
||||
# Arguments can be non-tensor types some of which are not hashable. So
|
||||
# we must inspect the type of an argument before checking if it is in
|
||||
# any set.
|
||||
if need_recompute(node) and \
|
||||
any([(isinstance(a, Node) and (a in ds_param_inputs or a in recomputed_nodes)) for a in node.args]):
|
||||
node.meta["recompute"] = CheckpointPolicy.MUST_RECOMPUTE
|
||||
recomputed_nodes.add(node)
|
||||
else:
|
||||
# If checkpointing is not enabled for this graph, assume all
|
||||
# activations required by the backward pass should be saved.
|
||||
node.meta.setdefault("recompute", CheckpointPolicy.MUST_SAVE)
|
||||
|
||||
|
||||
def get_wrapped_choose_saved_values_set(param_indices: List[Tuple[int, int, torch.Size]]):
|
||||
|
||||
def ds_choose_saved_values_set(joint_graph: torch.fx.Graph, node_info, memory_budget=1) -> List[Node]:
|
||||
saved_values = choose_saved_values_set(joint_graph, node_info, memory_budget)
|
||||
ds_param_users = _get_values_from_ds_params(joint_graph, param_indices)
|
||||
|
||||
new_saved_values = []
|
||||
for v in saved_values:
|
||||
if v in ds_param_users:
|
||||
ds_val = ds_param_users[v]
|
||||
if ds_val not in new_saved_values:
|
||||
new_saved_values.append(ds_val)
|
||||
else:
|
||||
new_saved_values.append(v)
|
||||
|
||||
return new_saved_values
|
||||
|
||||
return ds_choose_saved_values_set
|
||||
|
||||
|
||||
def get_wrapped_partitioner(param_indices: List[Tuple[int, int, torch.Size]]):
|
||||
def get_wrapped_partitioner(
|
||||
z3_partition: bool,
|
||||
param_indices: List[Tuple[int, int, torch.Size]],
|
||||
partition_fn,
|
||||
):
|
||||
|
||||
def partition_recompute_ds_params(joint_module: GraphModule, _joint_inputs, *,
|
||||
num_fwd_outputs) -> Tuple[GraphModule, GraphModule]:
|
||||
"""
|
||||
This is basically the same as the default_partition function, but
|
||||
it doesn't save the gathered params and values computed from them.
|
||||
"""
|
||||
if has_recomputable_ops(joint_module):
|
||||
return min_cut_rematerialization_partition(joint_module, _joint_inputs, num_fwd_outputs=num_fwd_outputs)
|
||||
|
||||
primal_inputs = list(filter(_is_primal, joint_module.graph.nodes))
|
||||
fwd_seed_offset_inputs = list(filter(_is_fwd_seed_offset, joint_module.graph.nodes))
|
||||
inputs = primal_inputs + fwd_seed_offset_inputs
|
||||
fwd_outputs, bwd_outputs = _extract_fwd_bwd_outputs(joint_module, num_fwd_outputs=num_fwd_outputs)
|
||||
forward_only_graph = _extract_graph_with_inputs_outputs(joint_module.graph, inputs, fwd_outputs, "forward")
|
||||
forward_node_names = {node.name for node in forward_only_graph.nodes if node.op != "output"}
|
||||
saved_values = []
|
||||
saved_sym_nodes = []
|
||||
|
||||
fwd_inputs = list(filter(_is_primal, forward_only_graph.nodes))
|
||||
ds_param_inputs = [fwd_inputs[arg_idx] for arg_idx, _, _ in param_indices]
|
||||
ds_param_input_names = {node.name for node in ds_param_inputs}
|
||||
|
||||
ds_param_recompute_nodes = set()
|
||||
|
||||
for node in joint_module.graph.nodes:
|
||||
if node.name not in forward_node_names:
|
||||
continue
|
||||
|
||||
if is_sym_node(node):
|
||||
# Symints must be kept separate from tensors so that PythonFunction only calls
|
||||
# save_for_backward on tensors and stashes symints in autograd .ctx
|
||||
saved_sym_nodes.append(node)
|
||||
elif "tensor_meta" not in node.meta and node.op == "call_function":
|
||||
# Since we can't save tuple of tensor values, we need to flatten out what we're saving
|
||||
users = node.users
|
||||
assert all(user.target == operator.getitem for user in users)
|
||||
saved_values.extend(users)
|
||||
else:
|
||||
backward_usages = [n for n in node.users if n.name not in forward_node_names]
|
||||
|
||||
if "tensor_meta" in node.meta and all(is_sym_node(n) for n in backward_usages):
|
||||
# If we have a tensor in the forward, where only its sizes/strides are needed in the backward,
|
||||
# and not the actual tensor data,
|
||||
# then it will be a lot cheaper to save only the sizes/strides, and not the actual tensor.
|
||||
#
|
||||
# Note that saving the tensor could also cause compilation problems:
|
||||
# If the user mutated an input in the forward and uses its sizes/strides in the backward,
|
||||
# then we would be obligated to clone the input before saving it to appease autograd.
|
||||
# (This is how we originally found this bug).
|
||||
saved_sym_nodes.extend(backward_usages)
|
||||
|
||||
if node.name in ds_param_input_names:
|
||||
saved_values.append(node)
|
||||
recompute_nodes = _find_recompute_nodes(joint_module.graph, node)
|
||||
recompute_nodes = [n for n in recompute_nodes if n.name in forward_node_names]
|
||||
for recompute_node in recompute_nodes:
|
||||
ds_param_recompute_nodes.add(recompute_node)
|
||||
|
||||
if len(recompute_nodes) > 0:
|
||||
saved_values.append(node)
|
||||
else:
|
||||
if node not in ds_param_recompute_nodes:
|
||||
saved_values.append(node)
|
||||
saved_values = list(dict.fromkeys(saved_values).keys())
|
||||
saved_sym_nodes = list(dict.fromkeys(saved_sym_nodes).keys())
|
||||
|
||||
f_gm, b_gm = _extract_fwd_bwd_modules(
|
||||
joint_module,
|
||||
saved_values,
|
||||
saved_sym_nodes=saved_sym_nodes,
|
||||
num_fwd_outputs=num_fwd_outputs,
|
||||
)
|
||||
|
||||
return f_gm, b_gm
|
||||
if z3_partition:
|
||||
_recompute_param_aliases(joint_module.graph, param_indices)
|
||||
return partition_fn(joint_module, _joint_inputs, num_fwd_outputs=num_fwd_outputs)
|
||||
|
||||
return partition_recompute_ds_params
|
||||
|
@ -76,6 +76,7 @@ from deepspeed.runtime.sparse_tensor import SparseTensor
|
||||
from deepspeed.runtime import lr_schedules
|
||||
from deepspeed.utils import groups
|
||||
from deepspeed.utils import logger, log_dist, log_dist_once, instrument_w_nvtx
|
||||
from deepspeed.utils.z3_leaf_module import apply_zero_leaf_module_config
|
||||
from deepspeed.utils.timer import NoopTimer, ThroughputTimer, SynchronizedWallClockTimer, \
|
||||
FORWARD_MICRO_TIMER, BACKWARD_MICRO_TIMER, BACKWARD_INNER_MICRO_TIMER, BACKWARD_REDUCE_MICRO_TIMER, \
|
||||
STEP_MICRO_TIMER, \
|
||||
@ -1293,6 +1294,7 @@ class DeepSpeedEngine(Module):
|
||||
|
||||
def _configure_distributed_model(self, model):
|
||||
self._set_client_model(model)
|
||||
apply_zero_leaf_module_config(self.module, getattr(self._config.zero_config, "leaf_module", None))
|
||||
is_zero_init_model = self.zero_optimization_partition_weights() and any(
|
||||
[hasattr(param, "ds_id") for param in self.module.parameters()])
|
||||
|
||||
|
@ -11,6 +11,7 @@ from deepspeed.runtime.config_utils import get_scalar_param, pp_int, DeepSpeedCo
|
||||
from deepspeed.utils import logger
|
||||
from .offload_config import DeepSpeedZeroOffloadParamConfig, DeepSpeedZeroOffloadOptimizerConfig, OffloadDeviceEnum
|
||||
from deepspeed.runtime.zenflow.zenflow_config import ZenFlowConfig
|
||||
from .leaf_module_config import DeepSpeedZeroLeafModuleConfig
|
||||
|
||||
# ZeRO optimization. By default, this optimization is not enabled.
|
||||
# Users have to configure the desired optimization (0 means disabled) in params.json as below example:
|
||||
@ -356,6 +357,11 @@ class DeepSpeedZeroConfig(DeepSpeedConfigModel):
|
||||
Enable internal sanity checks, which could be useful for debugging
|
||||
"""
|
||||
|
||||
leaf_module: DeepSpeedZeroLeafModuleConfig = Field(default_factory=DeepSpeedZeroLeafModuleConfig)
|
||||
"""
|
||||
Configuration for modules that should be treated as ZeRO3 leaf modules.
|
||||
"""
|
||||
|
||||
# Validators
|
||||
@model_validator(mode="after")
|
||||
def overlap_comm_valid(self):
|
||||
|
deepspeed/runtime/zero/leaf_module_config.py (new file, 52 lines)
@@ -0,0 +1,52 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# DeepSpeed Team
|
||||
|
||||
from typing import List
|
||||
from pydantic import Field, model_validator
|
||||
|
||||
from deepspeed.runtime.config_utils import DeepSpeedConfigModel
|
||||
|
||||
DEFAULT_LEAF_MODULE_CLASSES: List[str] = [
|
||||
"transformers.models.mixtral.modeling_mixtral.MixtralSparseMoeBlock",
|
||||
"transformers.models.qwen2_moe.modeling_qwen2_moe.Qwen2MoeSparseMoeBlock",
|
||||
"transformers.models.qwen3_moe.modeling_qwen3_moe.Qwen3MoeSparseMoeBlock",
|
||||
]
|
||||
DEFAULT_LEAF_MODULE_NAMES: List[str] = []
|
||||
DEFAULT_LEAF_MODULE_NAME_SUFFIXES: List[str] = []
|
||||
|
||||
|
||||
class DeepSpeedZeroLeafModuleConfig(DeepSpeedConfigModel):
|
||||
"""Configuration for ZeRO leaf modules that should bypass hook installation."""
|
||||
|
||||
classes: List[str] = Field(default_factory=lambda: list(DEFAULT_LEAF_MODULE_CLASSES))
|
||||
names: List[str] = Field(default_factory=lambda: list(DEFAULT_LEAF_MODULE_NAMES))
|
||||
name_suffixes: List[str] = Field(default_factory=lambda: list(DEFAULT_LEAF_MODULE_NAME_SUFFIXES))
|
||||
|
||||
@model_validator(mode="before")
|
||||
def _coerce_container_types(cls, values):
|
||||
if values is None:
|
||||
return {}
|
||||
if isinstance(values, dict):
|
||||
coerced = dict(values)
|
||||
for key in ("classes", "names", "name_suffixes"):
|
||||
if key in coerced and isinstance(coerced[key], str):
|
||||
coerced[key] = [coerced[key]]
|
||||
return coerced
|
||||
raise TypeError("leaf_module configuration must be a mapping of fields to values")
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _validate_entries(self):
|
||||
normalized_classes = [str(cls) for cls in self.classes]
|
||||
normalized_names = [str(name) for name in self.names]
|
||||
normalized_suffixes = [str(suffix) for suffix in self.name_suffixes]
|
||||
|
||||
deduped_classes = list(dict.fromkeys(normalized_classes))
|
||||
deduped_names = list(dict.fromkeys(normalized_names))
|
||||
deduped_suffixes = list(dict.fromkeys(normalized_suffixes))
|
||||
|
||||
object.__setattr__(self, "classes", deduped_classes)
|
||||
object.__setattr__(self, "names", deduped_names)
|
||||
object.__setattr__(self, "name_suffixes", deduped_suffixes)
|
||||
return self
|
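As a quick illustration of the validators defined above (assuming this module and pydantic are importable), a bare string is coerced into a single-element list and duplicates are dropped while preserving order:

```python
from deepspeed.runtime.zero.leaf_module_config import DeepSpeedZeroLeafModuleConfig

cfg = DeepSpeedZeroLeafModuleConfig(classes="my_package.layers.CustomMoEBlock",
                                    names=["layers.0.experts", "layers.0.experts"])
assert cfg.classes == ["my_package.layers.CustomMoEBlock"]   # string -> list
assert cfg.names == ["layers.0.experts"]                     # duplicates removed
assert cfg.name_suffixes == []                               # default is empty
```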
@ -17,7 +17,7 @@ from .tensor_fragment import safe_set_full_fp32_param, safe_set_full_optimizer_s
|
||||
from .tensor_fragment import safe_get_local_fp32_param, safe_get_local_grad, safe_get_local_optimizer_state
|
||||
from .tensor_fragment import safe_set_local_fp32_param, safe_set_local_grad, safe_set_local_optimizer_state
|
||||
from .tensor_fragment import safe_update_full_grad_vectorized
|
||||
from .z3_leaf_module import set_z3_leaf_modules, unset_z3_leaf_modules, get_z3_leaf_modules, z3_leaf_module, z3_leaf_parameter, set_z3_leaf_module
|
||||
from .z3_leaf_module import set_z3_leaf_modules, unset_z3_leaf_modules, get_z3_leaf_modules, z3_leaf_module, z3_leaf_parameter, set_z3_leaf_module, set_z3_leaf_modules_by_name, set_z3_leaf_modules_by_suffix
|
||||
from .mixed_precision_linkage import link_hp_params, lazy_init_hp_params_optimizer_state
|
||||
from deepspeed.runtime.dataloader import RepeatingLoader
|
||||
from .numa import get_numactl_cmd
|
||||
|
@ -83,19 +83,20 @@ def print_configuration(args, name):
|
||||
logger.info(" {} {} {}".format(arg, dots, getattr(args, arg)))
|
||||
|
||||
|
||||
def log_dist(message, ranks=None, level=logging.INFO, use_logger=True):
|
||||
def get_dist_msg(message, ranks=None):
|
||||
from deepspeed import comm as dist
|
||||
"""Log message when one of following condition meets
|
||||
"""Return a message with rank prefix when one of following conditions is met:
|
||||
|
||||
+ not dist.is_initialized()
|
||||
+ dist.get_rank() in ranks if ranks is not None or ranks = [-1]
|
||||
+ not dist.is_initialized()
|
||||
+ dist.get_rank() in ranks if ranks is not None or ranks = [-1]
|
||||
|
||||
If neither is met, `None` is returned.
|
||||
|
||||
Example: "hello" => "[Rank 0] hello"
|
||||
|
||||
Args:
|
||||
message (str)
|
||||
ranks (list)
|
||||
level (int)
|
||||
use_logger (bool): if `False` ignores the log-levels and always prints
|
||||
|
||||
"""
|
||||
should_log = not dist.is_initialized()
|
||||
ranks = ranks or []
|
||||
@ -104,11 +105,36 @@ def log_dist(message, ranks=None, level=logging.INFO, use_logger=True):
|
||||
should_log = ranks[0] == -1
|
||||
should_log = should_log or (my_rank in set(ranks))
|
||||
if should_log:
|
||||
final_message = "[Rank {}] {}".format(my_rank, message)
|
||||
if use_logger:
|
||||
logger.log(level, final_message)
|
||||
else:
|
||||
print(final_message)
|
||||
return "[Rank {}] {}".format(my_rank, message)
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def log_dist(message, ranks=None, level=logging.INFO):
|
||||
"""Log message when get_dist_msg() deems it should be logged, see its docstring for details.
|
||||
|
||||
Args:
|
||||
message (str)
|
||||
ranks (list)
|
||||
level (int)
|
||||
"""
|
||||
final_message = get_dist_msg(message, ranks)
|
||||
if final_message is not None:
|
||||
logger.log(level, final_message)
|
||||
|
||||
|
||||
def print_dist(message, ranks=None):
|
||||
"""print message when get_dist_msg() deems it should be logged, see its docstring for details.
|
||||
|
||||
Use this function instead of `log_dist` when the log level shouldn't impact whether the message should be printed or not.
|
||||
|
||||
Args:
|
||||
message (str)
|
||||
ranks (list)
|
||||
"""
|
||||
final_message = get_dist_msg(message, ranks)
|
||||
if final_message is not None:
|
||||
print(final_message)
|
||||
|
||||
|
||||
@functools.lru_cache(None)
|
||||
|
@ -5,7 +5,7 @@
|
||||
|
||||
import time
|
||||
from numpy import mean
|
||||
from deepspeed.utils.logging import log_dist
|
||||
from deepspeed.utils.logging import print_dist
|
||||
from deepspeed.accelerator import get_accelerator
|
||||
|
||||
FORWARD_MICRO_TIMER = 'fwd_microstep'
|
||||
@ -149,7 +149,7 @@ class SynchronizedWallClockTimer:
|
||||
string += " | {}: {:.2f}".format(name, elapsed_time)
|
||||
|
||||
# timers logging should be independent of the global log level it's already conditional on wall_clock_breakdown being True, so using use_logger=False will always print the stats
|
||||
log_dist(string, ranks=ranks or [0], use_logger=False)
|
||||
print_dist(string, ranks=ranks or [0])
|
||||
|
||||
def get_mean(self, names, normalizer=1.0, reset=True):
|
||||
"""Get the mean of a group of timers."""
|
||||
|
@ -4,7 +4,12 @@
|
||||
# DeepSpeed Team
|
||||
|
||||
import torch
|
||||
from typing import List, Type, Union
|
||||
from typing import List, Tuple, Type, Union, Optional, TYPE_CHECKING
|
||||
|
||||
from .logging import logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from deepspeed.runtime.zero.leaf_module_config import DeepSpeedZeroLeafModuleConfig
|
||||
|
||||
|
||||
def z3_leaf_module(model: torch.nn.Module) -> bool:
|
||||
@ -44,50 +49,201 @@ def set_z3_leaf_module(model: torch.nn.Module, flag: bool):
|
||||
model._z3_leaf = flag
|
||||
|
||||
|
||||
def _do_set_z3_leaf_modules(model: torch.nn.Module, leaf_module_classes: Union[List[Type], List[str]],
|
||||
flag: bool) -> List[torch.nn.Module]:
|
||||
assert all(isinstance(module_class, (type, str) ) for module_class in leaf_module_classes), \
|
||||
def _fully_qualified_class_name(module: torch.nn.Module) -> str:
|
||||
cls = module.__class__
|
||||
return f"{cls.__module__}.{cls.__qualname__}"
|
||||
|
||||
|
||||
def _do_set_z3_leaf_modules(model: torch.nn.Module,
|
||||
leaf_module_classes: Union[List[Type], List[str]],
|
||||
flag: bool,
|
||||
raise_if_not_found: bool = True) -> List[torch.nn.Module]:
|
||||
assert all(isinstance(module_class, (type, str)) for module_class in leaf_module_classes), \
|
||||
f'leaf_module_classes must be a list of types or names, got {leaf_module_classes}'
|
||||
|
||||
leaf_modules = []
|
||||
leaf_modules: List[torch.nn.Module] = []
|
||||
|
||||
def _set_z3_leaf_flag(model: torch.nn.Module):
|
||||
def _set_z3_leaf_flag(module_instance: torch.nn.Module):
|
||||
nonlocal leaf_modules
|
||||
for module in leaf_module_classes:
|
||||
if (isinstance(module, type) and model.__class__ == module) or \
|
||||
(isinstance(module, str) and model.__class__.__name__ == module):
|
||||
model._z3_leaf = flag
|
||||
leaf_modules.append(model)
|
||||
if isinstance(module, type) and isinstance(module_instance, module):
|
||||
module_instance._z3_leaf = flag
|
||||
leaf_modules.append(module_instance)
|
||||
break
|
||||
|
||||
if isinstance(module, str):
|
||||
if (module_instance.__class__.__name__ == module
|
||||
or _fully_qualified_class_name(module_instance) == module):
|
||||
module_instance._z3_leaf = flag
|
||||
leaf_modules.append(module_instance)
|
||||
break
|
||||
|
||||
model.apply(_set_z3_leaf_flag)
|
||||
|
||||
if len(leaf_modules) == 0:
|
||||
if len(leaf_modules) == 0 and raise_if_not_found:
|
||||
raise ValueError(f'No modules of type {leaf_module_classes} found in model {model}')
|
||||
|
||||
return leaf_modules
|
||||
|
||||
|
||||
def set_z3_leaf_modules(model: torch.nn.Module, leaf_module_classes: Union[List[Type],
|
||||
List[str]]) -> List[torch.nn.Module]:
|
||||
def set_z3_leaf_modules_by_name(model: torch.nn.Module,
|
||||
module_names: List[str],
|
||||
flag: bool = True,
|
||||
raise_if_not_found: bool = True) -> Tuple[List[torch.nn.Module], List[str]]:
|
||||
"""Sets a leaf flag for modules referenced by their names in ``model.named_modules()``.
|
||||
Args:
|
||||
model (torch.nn.Module): The model containing the modules to update.
|
||||
module_names (List[str]): Module names as returned by ``named_modules()``.
|
||||
flag (bool): Desired flag state.
|
||||
raise_if_not_found (bool): Whether to raise when no module matches a provided name.
|
||||
Returns:
|
||||
Tuple[List[torch.nn.Module], List[str]]: Matched modules and missing module names.
|
||||
"""
|
||||
modules_by_name = dict(model.named_modules())
|
||||
leaf_modules: List[torch.nn.Module] = []
|
||||
missing: List[str] = []
|
||||
|
||||
for name in module_names:
|
||||
module = modules_by_name.get(name)
|
||||
if module is None:
|
||||
missing.append(name)
|
||||
continue
|
||||
module._z3_leaf = flag
|
||||
leaf_modules.append(module)
|
||||
|
||||
if missing and raise_if_not_found:
|
||||
raise ValueError(f'No modules named {missing} found in model {model}')
|
||||
|
||||
return leaf_modules, missing
|
||||
|
||||
|
||||
def set_z3_leaf_modules_by_suffix(model: torch.nn.Module,
|
||||
module_name_suffixes: List[str],
|
||||
flag: bool = True,
|
||||
raise_if_not_found: bool = True) -> Tuple[List[torch.nn.Module], List[str]]:
|
||||
"""Sets a leaf flag for modules referenced by suffixes of ``model.named_modules()`` names."""
|
||||
modules_by_name = dict(model.named_modules())
|
||||
leaf_modules: List[torch.nn.Module] = []
|
||||
missing: List[str] = []
|
||||
seen_ids = set()
|
||||
|
||||
for suffix in module_name_suffixes:
|
||||
matched = False
|
||||
for name, module in modules_by_name.items():
|
||||
if name.endswith(suffix):
|
||||
module._z3_leaf = flag
|
||||
module_id = id(module)
|
||||
if module_id not in seen_ids:
|
||||
seen_ids.add(module_id)
|
||||
leaf_modules.append(module)
|
||||
matched = True
|
||||
if not matched:
|
||||
missing.append(suffix)
|
||||
|
||||
if missing and raise_if_not_found:
|
||||
raise ValueError(f'No modules matching suffixes {missing} found in model {model}')
|
||||
|
||||
return leaf_modules, missing
|
||||
|
||||
|
||||
def set_z3_leaf_modules(model: torch.nn.Module,
|
||||
leaf_module_classes: Union[List[Type], List[str]],
|
||||
raise_if_not_found: bool = True) -> List[torch.nn.Module]:
|
||||
"""Sets a flag within a module in `model` to instruct ZeRO3 to stop setting hooks recursively when it encounters a module class listed in `leaf_module_classes`.
|
||||
This is particularly useful in the context of Mixture of Experts (MoE) models. In MoE models, the computation order of experts varies across forward passes. This variability can disrupt ZeRO3's functionality, as ZeRO3 relies on tracking the computation order of modules to prefetch parameters efficiently. By designating a module as a 'leaf' node, ZeRO3 will prefetch parameters for all child modules upon entering the module.
|
||||
Another scenario where this functionality is beneficial is in models with excessively fine-grained nested modules, where it helps to avoid the overhead associated with hooks.
|
||||
Args:
|
||||
model (torch.nn.Module): The model to which the leaf module flag will be applied.
|
||||
leaf_module_classes (Union[List[Type], List[str]]): A list of module classes that should be flagged as 'leaf' modules.
|
||||
raise_if_not_found (bool): Whether to raise a ``ValueError`` when none of the provided classes
|
||||
match a module inside ``model``.
|
||||
Returns:
|
||||
List[torch.nn.Module]: A list of modules that match the module classes in `leaf_module_classes`.
|
||||
"""
|
||||
return _do_set_z3_leaf_modules(model, leaf_module_classes, True)
|
||||
return _do_set_z3_leaf_modules(model, leaf_module_classes, True, raise_if_not_found)
|
||||
|
||||
|
||||
def unset_z3_leaf_modules(model: torch.nn.Module, leaf_module_classes: List[Type]) -> List[torch.nn.Module]:
|
||||
def unset_z3_leaf_modules(model: torch.nn.Module,
|
||||
leaf_module_classes: List[Type],
|
||||
raise_if_not_found: bool = True) -> List[torch.nn.Module]:
|
||||
"""Unsets a flag within a module in `model` to instruct ZeRO3 to resume setting hooks recursively when it encounters a module class listed in `leaf_module_classes`.
|
||||
See `set_z3_leaf_modules` for more details.
|
||||
Args:
|
||||
model (torch.nn.Module): The model to which the leaf module flag will be applied.
|
||||
leaf_module_classes (Union[List[Type], List[str]]): A list of module classes that should be flagged as 'leaf' modules.
|
||||
raise_if_not_found (bool): Whether to raise a ``ValueError`` when none of the provided classes
|
||||
match a module inside ``model``.
|
||||
Returns:
|
||||
List[torch.nn.Module]: A list of modules that match the module classes in `leaf_module_classes`.
|
||||
"""
|
||||
return _do_set_z3_leaf_modules(model, leaf_module_classes, False)
|
||||
return _do_set_z3_leaf_modules(model, leaf_module_classes, False, raise_if_not_found)
|
||||
|
||||
|
||||
def apply_zero_leaf_module_config(model: torch.nn.Module,
|
||||
leaf_cfg: Optional["DeepSpeedZeroLeafModuleConfig"]) -> List[torch.nn.Module]:
|
||||
"""Apply ZeRO leaf module configuration to ``model``.
|
||||
|
||||
Args:
|
||||
model (torch.nn.Module): Root module to update.
|
||||
leaf_cfg (DeepSpeedZeroLeafModuleConfig | None): Parsed configuration. If ``None``
|
||||
no changes are applied.
|
||||
|
||||
Returns:
|
||||
List[torch.nn.Module]: Modules flagged as leaves.
|
||||
"""
|
||||
if leaf_cfg is None:
|
||||
return []
|
||||
|
||||
from deepspeed.runtime.zero.leaf_module_config import (
|
||||
DEFAULT_LEAF_MODULE_CLASSES,
|
||||
DEFAULT_LEAF_MODULE_NAMES,
|
||||
DEFAULT_LEAF_MODULE_NAME_SUFFIXES,
|
||||
)
|
||||
|
||||
matched_modules: List[torch.nn.Module] = []
|
||||
matched_ids = set()
|
||||
|
||||
customized_classes = leaf_cfg.classes != DEFAULT_LEAF_MODULE_CLASSES
|
||||
customized_names = leaf_cfg.names != DEFAULT_LEAF_MODULE_NAMES
|
||||
customized_suffixes = leaf_cfg.name_suffixes != DEFAULT_LEAF_MODULE_NAME_SUFFIXES
|
||||
|
||||
if leaf_cfg.classes:
|
||||
class_matched = set_z3_leaf_modules(model, leaf_cfg.classes, raise_if_not_found=False)
|
||||
for module in class_matched:
|
||||
module_id = id(module)
|
||||
if module_id not in matched_ids:
|
||||
matched_ids.add(module_id)
|
||||
matched_modules.append(module)
|
||||
|
||||
if leaf_cfg.names:
|
||||
name_matched, missing_names = set_z3_leaf_modules_by_name(model,
|
||||
leaf_cfg.names,
|
||||
flag=True,
|
||||
raise_if_not_found=False)
|
||||
for module in name_matched:
|
||||
module_id = id(module)
|
||||
if module_id not in matched_ids:
|
||||
matched_ids.add(module_id)
|
||||
matched_modules.append(module)
|
||||
|
||||
if missing_names and customized_names:
|
||||
logger.warning(f"ZeRO leaf module configuration contains unknown module names: {missing_names}")
|
||||
|
||||
if leaf_cfg.name_suffixes:
|
||||
suffix_matched, missing_suffixes = set_z3_leaf_modules_by_suffix(model,
|
||||
leaf_cfg.name_suffixes,
|
||||
flag=True,
|
||||
raise_if_not_found=False)
|
||||
for module in suffix_matched:
|
||||
module_id = id(module)
|
||||
if module_id not in matched_ids:
|
||||
matched_ids.add(module_id)
|
||||
matched_modules.append(module)
|
||||
|
||||
if missing_suffixes and customized_suffixes:
|
||||
logger.warning(f"ZeRO leaf module configuration contains unmatched module suffixes: {missing_suffixes}")
|
||||
|
||||
if not matched_modules and (customized_classes or customized_names or customized_suffixes):
|
||||
logger.warning("ZeRO leaf module configuration did not match any modules; hooks will be applied as usual")
|
||||
|
||||
return matched_modules
|
||||
|
@ -73,6 +73,84 @@ Each configuration works as follows:
|
||||
.. autofunction:: deepspeed.runtime.torch_autocast.has_autocast_dtype
|
||||
|
||||
|
||||
Configuring ZeRO Leaf Modules
|
||||
-----------------------------
|
||||
|
||||
ZeRO-3 relies on module execution order to gather partitioned parameters.
|
||||
When models select submodules dynamically (for example, MoE routers), different data-parallel ranks may gather different sets of parameters, which can cause the all-gather collective to deadlock.
|
||||
To avoid this problem, you can designate the parent of dynamically activated submodules (e.g., MoE experts) as a "leaf" module.
|
||||
When a module is marked as a leaf, ZeRO gathers all of its descendants immediately and stops inserting hooks beneath it.
|
||||
|
||||
Programmatic API
|
||||
================
|
||||
|
||||
Use :func:`deepspeed.utils.set_z3_leaf_modules` to flag modules by class, class
|
||||
name, or both. Optionally combine with
|
||||
:func:`deepspeed.utils.set_z3_leaf_modules_by_name` to target specific entries
|
||||
from ``model.named_modules()`` or
|
||||
:func:`deepspeed.utils.set_z3_leaf_modules_by_suffix` to match suffixes of those
|
||||
names.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from deepspeed.utils import (
|
||||
set_z3_leaf_modules,
|
||||
set_z3_leaf_modules_by_name,
|
||||
set_z3_leaf_modules_by_suffix,
|
||||
)
|
||||
|
||||
# Match by class or subclass
|
||||
set_z3_leaf_modules(model, [CustomMoEBlock])
|
||||
|
||||
# Match by fully qualified class name
|
||||
set_z3_leaf_modules(model, ["my_package.layers.CustomMoEBlock"])
|
||||
|
||||
# Match by module name returned from model.named_modules()
|
||||
set_z3_leaf_modules_by_name(model, ["transformer.layers.0.experts"])
|
||||
|
||||
# Match by suffix of names returned from model.named_modules()
|
||||
set_z3_leaf_modules_by_suffix(model, ["experts"])
|
||||
|
||||
Configuration in DeepSpeed config
|
||||
=================================
|
||||
|
||||
The same behavior can be controlled from the DeepSpeed config. Add a
|
||||
``leaf_module`` block to ``zero_optimization`` specifying either classes,
|
||||
module names, or name suffixes (or any combination). By default DeepSpeed marks
|
||||
several Hugging Face MoE blocks, including the Mixtral and Qwen MoE sparse blocks, so
that they behave well with ZeRO-3.
|
||||
|
||||
The default class list currently contains:
|
||||
|
||||
* ``transformers.models.mixtral.modeling_mixtral.MixtralSparseMoeBlock``
|
||||
* ``transformers.models.qwen2_moe.modeling_qwen2_moe.Qwen2MoeSparseMoeBlock``
|
||||
* ``transformers.models.qwen3_moe.modeling_qwen3_moe.Qwen3MoeSparseMoeBlock``
|
||||
|
||||
.. code-block:: json
|
||||
|
||||
{
|
||||
"train_micro_batch_size_per_gpu": 1,
|
||||
"zero_optimization": {
|
||||
"stage": 3,
|
||||
"leaf_module": {
|
||||
"classes": ["my_package.layers.CustomMoEBlock"],
|
||||
"names": ["transformer.layers.0.experts"],
|
||||
"name_suffixes": ["experts"]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
``names`` must match exactly what ``model.named_modules()`` produces. The
|
||||
``name_suffixes`` field compares each suffix against the end of those same
|
||||
module paths, making it convenient to apply a rule across repeated structures.
|
||||
Entries in ``classes`` may be either bare class names (for example,
|
||||
``MixtralSparseMoeBlock``) or fully qualified dotted paths; both forms are
|
||||
accepted.
|
||||
|
||||
You can mix and match the API and configuration approaches; all referenced
|
||||
modules are flagged before ZeRO installs its hooks.
|
||||
|
||||
|
||||
Model Saving
|
||||
------------
|
||||
.. autofunction:: deepspeed.DeepSpeedEngine.save_16bit_model
|
||||
|
@ -21,6 +21,8 @@ import deepspeed
|
||||
from deepspeed.accelerator import get_accelerator
|
||||
import deepspeed.comm as dist
|
||||
|
||||
from .util import torch_assert_close
|
||||
|
||||
import pytest
|
||||
from _pytest.outcomes import Skipped
|
||||
from _pytest.fixtures import FixtureLookupError, FixtureFunctionMarker
|
||||
@ -562,6 +564,8 @@ def enable_determinism(seed: int):
|
||||
|
||||
|
||||
def reduce_boolean_flags(flag: bool, op=all) -> bool:
|
||||
if not dist.is_initialized():
|
||||
return flag
|
||||
device = get_accelerator().current_device()
|
||||
tensor_flag = torch.tensor(1 if flag else 0, dtype=torch.int, device=device)
|
||||
world_size = dist.get_world_size()
|
||||
@ -569,3 +573,24 @@ def reduce_boolean_flags(flag: bool, op=all) -> bool:
|
||||
dist.all_gather_into_tensor(tensor_flag_buf, tensor_flag)
|
||||
list_flags = [bool(f) for f in tensor_flag_buf.tolist()]
|
||||
return op(list_flags)
|
||||
|
||||
|
||||
def allclose_on_all_ranks(actual, expected, assert_message=None, **kwargs) -> None:
|
||||
"""
|
||||
Compare two tensors across all ranks.
|
||||
We want to make sure that all ranks succeed or fail together.
|
||||
"""
|
||||
allclose_local = False
|
||||
allclose_global = False
|
||||
mismatch_msg = ""
|
||||
try:
|
||||
torch_assert_close(actual, expected, **kwargs)
|
||||
allclose_local = True
|
||||
allclose_global = reduce_boolean_flags(allclose_local, all)
|
||||
except AssertionError:
|
||||
allclose_global = reduce_boolean_flags(allclose_local, all)
|
||||
mismatch_msg = f"Tensors are not close: {actual=}, {expected=} {kwargs=}"
|
||||
|
||||
if not allclose_global:
|
||||
message = "Tensors are not close on all ranks." if assert_message is None else assert_message
|
||||
raise AssertionError(f"{message} {mismatch_msg}")
|
||||
|
@ -11,7 +11,11 @@ from unit.common import DistributedTest, preferred_dtype
|
||||
from unit.simple_model import random_dataloader
|
||||
|
||||
import deepspeed
|
||||
from deepspeed.utils import set_z3_leaf_modules, unset_z3_leaf_modules, get_z3_leaf_modules, z3_leaf_module
|
||||
from deepspeed.utils import set_z3_leaf_modules, unset_z3_leaf_modules, get_z3_leaf_modules, z3_leaf_module, \
|
||||
set_z3_leaf_modules_by_name, set_z3_leaf_modules_by_suffix
|
||||
from deepspeed.runtime.zero.config import DeepSpeedZeroConfig
|
||||
from deepspeed.runtime.zero.leaf_module_config import (DEFAULT_LEAF_MODULE_CLASSES, DEFAULT_LEAF_MODULE_NAMES,
|
||||
DEFAULT_LEAF_MODULE_NAME_SUFFIXES)
|
||||
from deepspeed.accelerator import get_accelerator
|
||||
from torch import nn
|
||||
import time
|
||||
@ -82,6 +86,142 @@ class FineGrainedBlock(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
class BaseLeafModule(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super(BaseLeafModule, self).__init__()
|
||||
|
||||
|
||||
class SubLeafModule(BaseLeafModule):
|
||||
|
||||
def __init__(self, hidden_dim):
|
||||
super(SubLeafModule, self).__init__()
|
||||
self.proj = nn.Linear(hidden_dim, hidden_dim)
|
||||
|
||||
def forward(self, x):
|
||||
return self.proj(x)
|
||||
|
||||
|
||||
class WrapperLeafModule(nn.Module):
|
||||
|
||||
def __init__(self, hidden_dim):
|
||||
super(WrapperLeafModule, self).__init__()
|
||||
self.child = SubLeafModule(hidden_dim)
|
||||
|
||||
def forward(self, x):
|
||||
return self.child(x)
|
||||
|
||||
|
||||
def test_set_leaf_modules_with_fully_qualified_name():
|
||||
hidden_dim = 16
|
||||
model = WrapperLeafModule(hidden_dim)
|
||||
fq_name = f"{SubLeafModule.__module__}.{SubLeafModule.__qualname__}"
|
||||
|
||||
matched = set_z3_leaf_modules(model, [fq_name])
|
||||
|
||||
assert len(matched) == 1
|
||||
assert matched[0] is model.child
|
||||
assert z3_leaf_module(model.child)
|
||||
assert not z3_leaf_module(model)
|
||||
|
||||
|
||||
def test_set_leaf_modules_no_raise_when_missing():
|
||||
hidden_dim = 16
|
||||
model = WrapperLeafModule(hidden_dim)
|
||||
|
||||
matched = set_z3_leaf_modules(model, ["NonExistentClass"], raise_if_not_found=False)
|
||||
|
||||
assert matched == []
|
||||
assert not z3_leaf_module(model.child)
|
||||
|
||||
|
||||
def test_set_leaf_modules_by_name():
|
||||
hidden_dim = 16
|
||||
model = WrapperLeafModule(hidden_dim)
|
||||
|
||||
matched, missing = set_z3_leaf_modules_by_name(model, ["child"])
|
||||
|
||||
assert matched == [model.child]
|
||||
assert missing == []
|
||||
assert z3_leaf_module(model.child)
|
||||
|
||||
|
||||
def test_set_leaf_modules_by_name_missing():
|
||||
hidden_dim = 16
|
||||
model = WrapperLeafModule(hidden_dim)
|
||||
|
||||
matched, missing = set_z3_leaf_modules_by_name(model, ["missing"], raise_if_not_found=False)
|
||||
|
||||
assert matched == []
|
||||
assert missing == ["missing"]
|
||||
|
||||
|
||||
def test_set_leaf_modules_by_suffix():
|
||||
hidden_dim = 16
|
||||
model = WrapperLeafModule(hidden_dim)
|
||||
|
||||
matched, missing = set_z3_leaf_modules_by_suffix(model, ["child"])
|
||||
|
||||
assert missing == []
|
||||
assert matched == [model.child]
|
||||
assert z3_leaf_module(model.child)
|
||||
|
||||
|
||||
def test_set_leaf_modules_by_suffix_missing():
|
||||
hidden_dim = 16
|
||||
model = WrapperLeafModule(hidden_dim)
|
||||
|
||||
matched, missing = set_z3_leaf_modules_by_suffix(model, ["missing"], raise_if_not_found=False)
|
||||
|
||||
assert matched == []
|
||||
assert missing == ["missing"]
|
||||
|
||||
|
||||
def test_zero_leaf_module_default_config():
|
||||
config = DeepSpeedZeroConfig()
|
||||
assert config.leaf_module.classes == DEFAULT_LEAF_MODULE_CLASSES
|
||||
assert config.leaf_module.names == DEFAULT_LEAF_MODULE_NAMES
|
||||
assert config.leaf_module.name_suffixes == DEFAULT_LEAF_MODULE_NAME_SUFFIXES
|
||||
|
||||
|
||||
def test_zero_leaf_module_custom_config():
|
||||
payload = {
|
||||
"leaf_module": {
|
||||
"classes": ["custom.module.CustomClass"],
|
||||
"names": ["transformer.layer"],
|
||||
"name_suffixes": ["experts"]
|
||||
}
|
||||
}
|
||||
|
||||
config = DeepSpeedZeroConfig(**payload)
|
||||
|
||||
assert config.leaf_module.classes == ["custom.module.CustomClass"]
|
||||
assert config.leaf_module.names == ["transformer.layer"]
|
||||
assert config.leaf_module.name_suffixes == ["experts"]
|
||||
|
||||
|
||||
def test_zero_leaf_module_string_coercion():
|
||||
payload = {"leaf_module": {"classes": "my.Class", "names": "submodule", "name_suffixes": "tail"}}
|
||||
|
||||
config = DeepSpeedZeroConfig(**payload)
|
||||
|
||||
assert config.leaf_module.classes == ["my.Class"]
|
||||
assert config.leaf_module.names == ["submodule"]
|
||||
assert config.leaf_module.name_suffixes == ["tail"]
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Requires Hugging Face transformers; run manually when validating defaults.")
|
||||
def test_default_leaf_module_classes_exist():
|
||||
import importlib
|
||||
|
||||
from deepspeed.runtime.zero.leaf_module_config import DEFAULT_LEAF_MODULE_CLASSES
|
||||
|
||||
for cls_path in DEFAULT_LEAF_MODULE_CLASSES:
|
||||
module_name, _, class_name = cls_path.rpartition('.')
|
||||
module = importlib.import_module(module_name)
|
||||
assert hasattr(module, class_name), f"Expected {class_name} in {module_name}"
|
||||
|
||||
|
||||
class modelWithFineGrainedBlock(nn.Module):
|
||||
|
||||
def __init__(self, hidden_dim, num_block):
|
||||
@ -123,10 +263,7 @@ class TestSetZ3LeafModule(DistributedTest):
|
||||
world_size = 2
|
||||
reuse_dist_env = True
|
||||
|
||||
def _test_set_z3_leaf_modules(self, cls, requires_grad):
|
||||
hidden_dim = 128
|
||||
|
||||
# `stage3_max_reuse_distance` is set to 0 to cause an error if the module is not set as a leaf module
|
||||
def _create_zero_config(self, hidden_dim, leaf_module=None):
|
||||
config_dict = {
|
||||
"train_micro_batch_size_per_gpu": 1,
|
||||
"steps_per_print": 1,
|
||||
@ -143,11 +280,20 @@ class TestSetZ3LeafModule(DistributedTest):
|
||||
"stage3_max_reuse_distance": 0,
|
||||
}
|
||||
}
|
||||
if leaf_module is not None:
|
||||
config_dict["zero_optimization"]["leaf_module"] = leaf_module
|
||||
|
||||
if preferred_dtype() is torch.float16:
|
||||
config_dict["fp16"] = {"enabled": True}
|
||||
elif preferred_dtype() is torch.bfloat16:
|
||||
config_dict["bf16"] = {"enabled": True}
|
||||
|
||||
return config_dict
|
||||
|
||||
def _test_set_z3_leaf_modules(self, cls, requires_grad):
|
||||
hidden_dim = 128
|
||||
config_dict = self._create_zero_config(hidden_dim)
|
||||
|
||||
model = cls(hidden_dim)
|
||||
|
||||
assert not z3_leaf_module(model)
|
||||
@ -181,6 +327,17 @@ class TestSetZ3LeafModule(DistributedTest):
|
||||
"Expected only one module to be unset as a leaf module"
|
||||
assert len(get_z3_leaf_modules(model)) == 0, "Expected there is no leaf module"
|
||||
|
||||
def test_set_leaf_modules_with_subclass(self):
|
||||
hidden_dim = 32
|
||||
model = WrapperLeafModule(hidden_dim)
|
||||
|
||||
leaf_modules = set_z3_leaf_modules(model, [BaseLeafModule])
|
||||
|
||||
assert len(leaf_modules) == 1, "Expected the subclass instance to be marked as leaf"
|
||||
assert leaf_modules[0] is model.child, "Expected the subclass instance to be returned"
|
||||
assert z3_leaf_module(model.child), "Expected subclass instance flagged as leaf"
|
||||
assert not z3_leaf_module(model), "Expected wrapper module to remain non-leaf"
|
||||
|
||||
def test_set_no_match_class(self):
|
||||
hidden_dim = 128
|
||||
model = ChooseModuleByCounter(hidden_dim)
|
||||
@ -190,6 +347,25 @@ class TestSetZ3LeafModule(DistributedTest):
|
||||
except ValueError as e:
|
||||
pass
|
||||
|
||||
def test_leaf_module_enabled_via_config(self):
|
||||
hidden_dim = 128
|
||||
leaf_class_fqn = f"{ChooseModuleByCounter.__module__}.{ChooseModuleByCounter.__qualname__}"
|
||||
config_dict = self._create_zero_config(hidden_dim,
|
||||
leaf_module={
|
||||
"classes": [leaf_class_fqn],
|
||||
"name_suffixes": ["linears"]
|
||||
})
|
||||
|
||||
model = ChooseModuleByCounter(hidden_dim)
|
||||
assert not z3_leaf_module(model)
|
||||
|
||||
run_model(model, config_dict, hidden_dim, preferred_dtype(), True)
|
||||
|
||||
assert z3_leaf_module(model)
|
||||
modules_by_name = dict(model.named_modules())
|
||||
assert "linears" in modules_by_name
|
||||
assert z3_leaf_module(modules_by_name["linears"])
|
||||
|
||||
|
||||
@pytest.mark.parametrize("module_granularity_threshold", [0, 100, 12100, 10000000])
|
||||
class TestZ3LeafOptimization(DistributedTest):
|
||||
|
@ -94,15 +94,15 @@ class no_child_process_in_deepspeed_io:
|
||||
deepspeed.runtime.engine.DeepSpeedEngine.deepspeed_io = self.old_method
|
||||
|
||||
|
||||
def torch_assert_equal(actual, expected, **kwargs):
|
||||
def torch_assert_equal(actual, expected, **kwargs) -> None:
|
||||
"""
|
||||
Compare two tensors or non-tensor numbers for their equality.
|
||||
Add msg=blah to add an additional comment to when assert fails.
|
||||
"""
|
||||
return torch.testing.assert_close(actual, expected, rtol=0.0, atol=0.0, **kwargs)
|
||||
torch.testing.assert_close(actual, expected, rtol=0.0, atol=0.0, **kwargs)
|
||||
|
||||
|
||||
def torch_assert_close(actual, expected, **kwargs):
|
||||
def torch_assert_close(actual, expected, **kwargs) -> None:
|
||||
"""
|
||||
Compare two tensors or non-tensor numbers for their closeness.
|
||||
|
||||
@ -113,7 +113,7 @@ def torch_assert_close(actual, expected, **kwargs):
|
||||
|
||||
The check doesn't assert when `|a - b| <= (atol + rtol * |b|)`
|
||||
"""
|
||||
return torch.testing.assert_close(actual, expected, **kwargs)
|
||||
torch.testing.assert_close(actual, expected, **kwargs)
|
||||
|
||||
|
||||
def torch_assert_dicts_of_tensors_equal(actual, expected, **kwargs):
|
||||
|
@ -12,12 +12,14 @@ from deepspeed.accelerator import get_accelerator
|
||||
from deepspeed.runtime.zero import GatheredParameters
|
||||
|
||||
from unit.simple_model import SimpleModel
|
||||
from unit.common import enable_determinism
|
||||
from unit.common import enable_determinism, allclose_on_all_ranks
|
||||
|
||||
|
||||
@enable_determinism(123)
|
||||
def compare_loss(self, config, dtype, iteration=5, hidden_dim_override=None):
|
||||
hidden_dim = hidden_dim_override if hidden_dim_override is not None else 10
|
||||
|
||||
# the default tolerances of torch.testing.assert_close are too small
|
||||
RTOL = 5e-1
|
||||
ATOL = 1e-2
|
||||
|
||||
@ -56,7 +58,7 @@ def compare_loss(self, config, dtype, iteration=5, hidden_dim_override=None):
|
||||
baseline_loss = baseline_engine(x, y)
|
||||
target_loss = target_engine(x, y)
|
||||
|
||||
assert torch.allclose(baseline_loss, target_loss, rtol=RTOL, atol=ATOL)
|
||||
allclose_on_all_ranks(baseline_loss, target_loss, "Loss values are not close.", rtol=RTOL, atol=ATOL)
|
||||
|
||||
baseline_engine.backward(baseline_loss)
|
||||
target_engine.backward(target_loss)
|
||||
@ -66,7 +68,7 @@ def compare_loss(self, config, dtype, iteration=5, hidden_dim_override=None):
|
||||
|
||||
with GatheredParameters(target_engine.parameters()):
|
||||
for p1, p2 in zip(baseline_engine.parameters(), target_engine.parameters()):
|
||||
assert torch.allclose(p1.to(dtype), p2, rtol=RTOL, atol=ATOL)
|
||||
allclose_on_all_ranks(p1, p2, "Parameters are not equal.", rtol=RTOL, atol=ATOL)
|
||||
|
||||
baseline_engine.destroy()
|
||||
target_engine.destroy()
|
||||
|