fix keep_module_on_host (#7112)

Reapply https://github.com/deepspeedai/DeepSpeed/pull/6846.
FYI @oelayan7

---------

Signed-off-by: inkcherry <mingzhi.liu@intel.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Author: inkcherry
Date: 2025-03-11 00:57:34 +08:00
Committed by: GitHub
Parent: c1acd49cdf
Commit: 8ec1af5f5c
2 changed files with 59 additions and 38 deletions
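For context, `keep_module_on_host` is the AutoTP inference option this commit repairs: when enabled, sharded weights stay in host memory instead of being copied to the accelerator. A minimal usage sketch follows; the model name and tensor-parallel degree are illustrative assumptions, not part of this commit.

# Hedged sketch: exercising keep_module_on_host through DeepSpeed AutoTP
# inference. Model name and tp_size are illustrative assumptions.
import torch
import deepspeed
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("facebook/opt-1.3b")

# keep_module_on_host=True keeps the sharded weights in host memory;
# checkpoints are loaded to the host first, so no accelerator copy is made.
engine = deepspeed.init_inference(
    model,
    tensor_parallel={"tp_size": 2},
    dtype=torch.float16,
    keep_module_on_host=True,
)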

deepspeed/module_inject/auto_tp.py

@@ -11,7 +11,7 @@ from .replace_policy import replace_policies
 from typing import Optional
 import torch
 from deepspeed import comm as dist
-from .layers import LinearAllreduce, LinearLayer, LmHeadLinearAllreduce, Yuan_LinearAllreduce, Yuan_LinearLayer, GateUpPack_LinearLayer, Conv_LinearALlreduce, fused_LinearLayer, conv_LinearLayer
+from .layers import *
 from deepspeed.accelerator import get_accelerator
 from .fusedqkv_utils import require_tp_fused_qkvw
 from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list
@@ -211,7 +211,7 @@ class AutoTP():
         self.orig_layer_impl = orig_layer_impl
         self.linear_policies = None
         self.conv_linear_layer = False
-        self.keep_module_on_host = keep_module_on_host
+        TensorParallel_Layer.set_keep_module_on_host(keep_module_on_host)

     def in_module_list(module, module_list):
         for item in module_list:
@@ -350,10 +350,7 @@
         # and avoid any complex shard-related logic.
         if getattr(child, "replaced", False) == True:
             return

-        device_name = 'cpu' if self.keep_module_on_host else get_accelerator().current_device_name()
-        # keep_module_on_host is used to keep the module on the host. Checkpoints are loaded to the host first (in some
-        # cases it can be done from the disk even to prevent filling host's memory), thus no need to create a new copy.
-        return_new_copy = not self.keep_module_on_host
         weight_shape = child.weight.shape
         mp_replace = ReplaceWithTensorSlicing(mp_group=self.mp_group)
         # For TP layer skip, e.g., MoE gate, deepseek low rank layer skip
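The AutoTP change above replaces a per-instance attribute with a flag stored on the layer class itself, set once via a classmethod. A self-contained sketch of that pattern, using toy class names rather than the DeepSpeed ones:

# Toy illustration of the class-level flag pattern adopted above;
# the names are hypothetical stand-ins for the DeepSpeed classes.
class TPLayerBase:
    keep_module_on_host: bool = False  # shared by every layer instance

    @classmethod
    def set_keep_module_on_host(cls, value: bool) -> None:
        cls.keep_module_on_host = value


class ToyLinear(TPLayerBase):
    def target_device(self) -> str:
        # Each instance reads the shared flag; AutoTP no longer has to
        # thread keep_module_on_host through every layer constructor.
        return "cpu" if self.__class__.keep_module_on_host else "accelerator"


TPLayerBase.set_keep_module_on_host(True)
assert ToyLinear().target_device() == "cpu"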

deepspeed/module_inject/layers.py

@@ -17,6 +17,11 @@ from deepspeed.runtime.tensor_parallel import AUTOTP_MODE
 from copy import deepcopy
 from typing import Union

+__all__ = [
+    "TensorParallel_Layer", "LinearAllreduce", "LinearLayer", "LmHeadLinearAllreduce", "Yuan_LinearAllreduce",
+    "Yuan_LinearLayer", "GateUpPack_LinearLayer", "Conv_LinearALlreduce", "fused_LinearLayer", "conv_LinearLayer"
+]
+
 DEEPSPEED_AUTOTP_MODE = AUTOTP_MODE.INFERENCE
 DS_IS_REPLACED_MODULE = 'ds_is_replaced_module'
 DS_TENSOR_MODEL_PARALLEL = 'tensor_model_parallel'
@@ -43,26 +48,6 @@ def set_autotp_mode(training=False):
         DEEPSPEED_AUTOTP_MODE = AUTOTP_MODE.INFERENCE


-def move(tensor, device):
-    # TODO: consider the timing of deletion
-    # to save host resources when DP > 1.
-    if tensor.is_meta:
-        # Keep tensor in meta device if tensor is meta.
-        return tensor
-    else:
-        # Using new tensors help in freeing memory (after split for example) was done before by calling clone().
-        # Using copy=True instead of clone() will help in case of cpu --> cpu.
-        # Otherwise to() will not create a new copy for the view of the full tensor, and it will not be de-referenced.
-        cloned_tensor = tensor.to(device, copy=True)
-
-        # free the memory of the original tensor to reduce memory peak
-        # Equivalent to directly deleting the tensor reference outside the function.
-        # see https://github.com/microsoft/DeepSpeed/pull/4353
-        tensor.data = torch.empty(0, device=tensor.device)
-        return cloned_tensor


 class RowParallel(torch.autograd.Function):
     """
     A custom autograd function for performing row-wise parallelism.
@@ -140,6 +125,10 @@ class TensorParallel_Layer(nn.Module, ABC):
         name (Optional[str]): The name of the layer, if provided.
     """

+    # keep_module_on_host is used to keep the module on the host. Checkpoints are loaded to the host first (in some
+    # cases it can even be done from disk, to avoid filling host memory), thus there is no need to create a new copy.
+    keep_module_on_host: bool = False
+
     def __init__(self, mp_group: Optional[dist.ProcessGroup], **kwargs: Any):
         """
         Initializes the TensorParallel_Layer with optional model parallelism group and layer name.
@@ -163,6 +152,16 @@ class TensorParallel_Layer(nn.Module, ABC):
         if kwargs.get('name') is not None:
             self.name = kwargs.get('name')  # Set the layer name if provided.

+    @classmethod
+    def set_keep_module_on_host(cls, value: bool):
+        """
+        Set the class-level keep_module_on_host flag.
+
+        Args:
+            value (bool): The new value for keep_module_on_host.
+        """
+        cls.keep_module_on_host = value
+
     @abstractmethod
     def forward(self, input):
         """
@@ -235,6 +234,31 @@
                                                    in_features, out_features, self.bias is not None, dtype)
         return extra_repr_str

+    def move(self, tensor):
+        # TODO: consider the timing of deletion
+        # to save host resources when DP > 1.
+        # keep_module_on_host is used to keep the module on the host. Checkpoints are loaded to the host first (in some
+        # cases it can even be done from disk, to avoid filling host memory), thus there is no need to create a new copy.
+        if tensor.is_meta:
+            # Keep the tensor on the meta device if it is meta.
+            return tensor
+        else:
+            device = 'cpu' if self.__class__.keep_module_on_host else get_accelerator().current_device_name()
+            return_new_copy = not self.__class__.keep_module_on_host
+
+            # Creating new tensors helps free memory (after a split, for example); this was previously done by calling clone().
+            # Using copy=True instead of clone() also helps in the cpu --> cpu case:
+            # otherwise to() will not create a new copy for the view of the full tensor, and it will not be de-referenced.
+            cloned_tensor = tensor.to(device, copy=return_new_copy)
+
+            if return_new_copy:
+                # Free the memory of the original tensor to reduce the memory peak.
+                # Equivalent to directly deleting the tensor reference outside the function.
+                # See https://github.com/microsoft/DeepSpeed/pull/4353
+                tensor.data = torch.empty(0, device=tensor.device)
+            return cloned_tensor


 class GatherReplacedLayerParams:
     """
@@ -349,7 +373,7 @@ class LinearAllreduce(TensorParallel_Layer):
                 return
             _partition = torch.chunk(param, self.tp_world_size, dim=-1)[self.tp_index]

-            _partition = move(_partition, get_accelerator().current_device_name()).detach()
+            _partition = self.move(_partition).detach()

             params_list[idx].data = _partition
@@ -363,7 +387,7 @@
                                                  self.name),
                                       dim=1)[self.tp_index]

-            _partition = move(_partition, get_accelerator().current_device_name()).detach()
+            _partition = self.move(_partition).detach()

             params_list[idx].data = _partition
@@ -414,7 +438,7 @@ class LinearLayer(TensorParallel_Layer):
             # split bias if provided
             _partition = torch.chunk(param, self.tp_world_size, dim=0)[self.tp_index]

-            _partition = move(_partition, get_accelerator().current_device_name()).detach()
+            _partition = self.move(_partition).detach()

             params_list[idx].data = _partition
@@ -429,7 +453,7 @@
                                                  self.name),
                                       dim=0)[self.tp_index]

-            _partition = move(_partition, get_accelerator().current_device_name()).detach()
+            _partition = self.move(_partition).detach()

             params_list[idx].data = _partition
@@ -475,7 +499,7 @@ class fused_LinearLayer(LinearLayer):
             _partition = prepare_tp_fused_qkvw(self.fused_module.module, param, self.tp_world_size, self.tp_index)

-            _partition = move(_partition, get_accelerator().current_device_name()).detach()
+            _partition = self.move(_partition).detach()

             params_list[idx].data = _partition
@@ -492,13 +516,13 @@ class conv_LinearLayer(LinearLayer):
         weight, bias = params_list[0], params_list[1]
         _partition = weight.data.split(get_shard_size_list(weight.shape[0], self.tp_world_size, self.name),
                                        dim=1)[self.tp_index]
-        _partition = move(_partition, get_accelerator().current_device_name()).detach()
+        _partition = self.move(_partition).detach()
         weight.data = _partition

         if bias is not None:
             _partition = bias.data.split(get_shard_size_list(weight.shape[1], self.tp_world_size, self.name),
                                          dim=0)[self.tp_index]
-            _partition = move(_partition, get_accelerator().current_device_name()).detach()
+            _partition = self.move(_partition).detach()
             bias.data = _partition
@@ -522,9 +546,9 @@ class Yuan_LinearLayer(LinearLayer):
     def _tp_partition(self, params_list):
         weight, bias = shard_value_with_share_qk(params_list[0].data, params_list[1], self.tp_index,
                                                  self.tp_world_size, True)
-        params_list[0].data = move(weight, get_accelerator().current_device_name()).detach()
+        params_list[0].data = self.move(weight).detach()
         if bias is not None:
-            params_list[1].data = move(bias, get_accelerator().current_device_name()).detach()
+            params_list[1].data = self.move(bias).detach()


 class GateUpPack_LinearLayer(LinearLayer):
@@ -532,9 +556,9 @@ class GateUpPack_LinearLayer(LinearLayer):
     @torch.no_grad()
     def _tp_partition(self, params_list):
         weight, bias = shard_chunk_mlp(params_list[0].data, params_list[1], self.tp_index, self.tp_world_size)
-        params_list[0].data = move(weight, device=get_accelerator().current_device_name()).detach()
+        params_list[0].data = self.move(weight).detach()
         if bias is not None:
-            params_list[1].data = move(bias, device=get_accelerator().current_device_name()).detach()
+            params_list[1].data = self.move(bias).detach()


 class Conv_LinearALlreduce(LinearAllreduce):
@@ -549,7 +573,7 @@ class Conv_LinearALlreduce(LinearAllreduce):
             _partition = param.split(get_shard_size_list(param.shape[0], self.tp_world_size, self.name),
                                      dim=1)[self.tp_index]

-            _partition = move(_partition, get_accelerator().current_device_name()).detach()
+            _partition = self.move(_partition).detach()

             params_list[idx].data = _partition
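Taken together, every _tp_partition above now follows the same flow: shard the parameter, route it through self.move (which consults the class-level flag), detach, and write the result back into param.data. A condensed toy version of that flow; the class name, tp values, and CPU fallback are illustrative assumptions rather than DeepSpeed code:

# Toy layer showing the post-commit _tp_partition flow; not the
# DeepSpeed class. Falls back to CPU when no CUDA device is present.
import torch

class ToyTPLinear:
    keep_module_on_host: bool = False  # class-level, as in this commit
    tp_world_size: int = 2
    tp_index: int = 0

    def move(self, tensor: torch.Tensor) -> torch.Tensor:
        on_host = self.__class__.keep_module_on_host
        device = "cpu" if on_host or not torch.cuda.is_available() else "cuda"
        # Copy only when actually relocating; on-host mode keeps the view.
        return tensor.to(device, copy=not on_host)

    @torch.no_grad()
    def _tp_partition(self, param: torch.Tensor) -> None:
        shard = torch.chunk(param, self.tp_world_size, dim=-1)[self.tp_index]
        param.data = self.move(shard).detach()

layer = ToyTPLinear()
weight = torch.randn(8, 8)
layer._tp_partition(weight)
assert weight.shape == (8, 4)  # sharded along the last dimension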