mirror of
https://github.com/deepspeedai/DeepSpeed.git
synced 2025-10-20 15:33:51 +08:00
fix keep_module_on_host (#7112)
Reapply https://github.com/deepspeedai/DeepSpeed/pull/6846. FYI @oelayan7 --------- Signed-off-by: inkcherry <mingzhi.liu@intel.com> Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
This commit is contained in:
@ -11,7 +11,7 @@ from .replace_policy import replace_policies
|
||||
from typing import Optional
|
||||
import torch
|
||||
from deepspeed import comm as dist
|
||||
from .layers import LinearAllreduce, LinearLayer, LmHeadLinearAllreduce, Yuan_LinearAllreduce, Yuan_LinearLayer, GateUpPack_LinearLayer, Conv_LinearALlreduce, fused_LinearLayer, conv_LinearLayer
|
||||
from .layers import *
|
||||
from deepspeed.accelerator import get_accelerator
|
||||
from .fusedqkv_utils import require_tp_fused_qkvw
|
||||
from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list
|
||||
@ -211,7 +211,7 @@ class AutoTP():
|
||||
self.orig_layer_impl = orig_layer_impl
|
||||
self.linear_policies = None
|
||||
self.conv_linear_layer = False
|
||||
self.keep_module_on_host = keep_module_on_host
|
||||
TensorParallel_Layer.set_keep_module_on_host(keep_module_on_host)
|
||||
|
||||
def in_module_list(module, module_list):
|
||||
for item in module_list:
|
||||
@ -350,10 +350,7 @@ class AutoTP():
|
||||
# and avoid any complex shard-related logic.
|
||||
if getattr(child, "replaced", False) == True:
|
||||
return
|
||||
device_name = 'cpu' if self.keep_module_on_host else get_accelerator().current_device_name()
|
||||
# keep_module_on_host is used to keep the module on the host. Checkpoints are loaded to the host first (in some
|
||||
# cases it can be done from the disk even to prevent filling host's memory), thus no need to create a new copy.
|
||||
return_new_copy = not self.keep_module_on_host
|
||||
|
||||
weight_shape = child.weight.shape
|
||||
mp_replace = ReplaceWithTensorSlicing(mp_group=self.mp_group)
|
||||
# For TP layer skip, e.g., MoE gate, deepseek low rank layer skip
|
||||
|
@ -17,6 +17,11 @@ from deepspeed.runtime.tensor_parallel import AUTOTP_MODE
|
||||
from copy import deepcopy
|
||||
from typing import Union
|
||||
|
||||
__all__ = [
|
||||
"TensorParallel_Layer", "LinearAllreduce", "LinearLayer", "LmHeadLinearAllreduce", "Yuan_LinearAllreduce",
|
||||
"Yuan_LinearLayer", "GateUpPack_LinearLayer", "Conv_LinearALlreduce", "fused_LinearLayer", "conv_LinearLayer"
|
||||
]
|
||||
|
||||
DEEPSPEED_AUTOTP_MODE = AUTOTP_MODE.INFERENCE
|
||||
DS_IS_REPLACED_MODULE = 'ds_is_replaced_module'
|
||||
DS_TENSOR_MODEL_PARALLEL = 'tensor_model_parallel'
|
||||
@ -43,26 +48,6 @@ def set_autotp_mode(training=False):
|
||||
DEEPSPEED_AUTOTP_MODE = AUTOTP_MODE.INFERENCE
|
||||
|
||||
|
||||
def move(tensor, device):
|
||||
# TODO: consider the timing of deletion
|
||||
# to save host resources when DP > 1。
|
||||
|
||||
if tensor.is_meta:
|
||||
# Keep tensor in meta device if tensor is meta.
|
||||
return tensor
|
||||
else:
|
||||
# Using new tensors help in freeing memory (after split for example) was done before by calling clone().
|
||||
# Using copy=True instead of clone() will help in case of cpu --> cpu.
|
||||
# Otherwise to() will not create a new copy for the view of the full tensor, and it will not be de-referenced.
|
||||
cloned_tensor = tensor.to(device, copy=True)
|
||||
|
||||
# free the memory of the original tensor to reduce memory peak
|
||||
# Equivalent to directly deleting the tensor reference outside the function.
|
||||
# see https://github.com/microsoft/DeepSpeed/pull/4353
|
||||
tensor.data = torch.empty(0, device=tensor.device)
|
||||
return cloned_tensor
|
||||
|
||||
|
||||
class RowParallel(torch.autograd.Function):
|
||||
"""
|
||||
A custom autograd function for performing row-wise parallelism.
|
||||
@ -140,6 +125,10 @@ class TensorParallel_Layer(nn.Module, ABC):
|
||||
name (Optional[str]): The name of the layer, if provided.
|
||||
"""
|
||||
|
||||
# keep_module_on_host is used to keep the module on the host. Checkpoints are loaded to the host first (in some
|
||||
# cases it can be done from the disk even to prevent filling host's memory), thus no need to create a new copy.
|
||||
keep_module_on_host: bool = False
|
||||
|
||||
def __init__(self, mp_group: Optional[dist.ProcessGroup], **kwargs: Any):
|
||||
"""
|
||||
Initializes the TensorParallel_Layer with optional model parallelism group and layer name.
|
||||
@ -163,6 +152,16 @@ class TensorParallel_Layer(nn.Module, ABC):
|
||||
if kwargs.get('name') is not None:
|
||||
self.name = kwargs.get('name') # Set the layer name if provided.
|
||||
|
||||
@classmethod
|
||||
def set_keep_module_on_host(cls, value: bool):
|
||||
"""
|
||||
Set the static variable keep_module_on_host.
|
||||
|
||||
Args:
|
||||
value (bool): The new value for keep_module_on_host.
|
||||
"""
|
||||
cls.keep_module_on_host = value
|
||||
|
||||
@abstractmethod
|
||||
def forward(self, input):
|
||||
"""
|
||||
@ -235,6 +234,31 @@ class TensorParallel_Layer(nn.Module, ABC):
|
||||
in_features, out_features, self.bias is not None, dtype)
|
||||
return extra_repr_str
|
||||
|
||||
def move(self, tensor):
|
||||
# TODO: consider the timing of deletion
|
||||
# to save host resources when DP > 1。
|
||||
|
||||
# keep_module_on_host is used to keep the module on the host. Checkpoints are loaded to the host first (in some
|
||||
# cases it can be done from the disk even to prevent filling host's memory), thus no need to create a new copy.
|
||||
if tensor.is_meta:
|
||||
# Keep tensor in meta device if tensor is meta.
|
||||
return tensor
|
||||
else:
|
||||
device = 'cpu' if self.__class__.keep_module_on_host else get_accelerator().current_device_name()
|
||||
return_new_copy = not self.__class__.keep_module_on_host
|
||||
|
||||
# Using new tensors help in freeing memory (after split for example) was done before by calling clone().
|
||||
# Using copy=True instead of clone() will help in case of cpu --> cpu.
|
||||
# Otherwise to() will not create a new copy for the view of the full tensor, and it will not be de-referenced.
|
||||
cloned_tensor = tensor.to(device, copy=return_new_copy)
|
||||
|
||||
if return_new_copy:
|
||||
# free the memory of the original tensor to reduce memory peak
|
||||
# Equivalent to directly deleting the tensor reference outside the function.
|
||||
# see https://github.com/microsoft/DeepSpeed/pull/4353
|
||||
tensor.data = torch.empty(0, device=tensor.device)
|
||||
return cloned_tensor
|
||||
|
||||
|
||||
class GatherReplacedLayerParams:
|
||||
"""
|
||||
@ -349,7 +373,7 @@ class LinearAllreduce(TensorParallel_Layer):
|
||||
return
|
||||
_partition = torch.chunk(param, self.tp_world_size, dim=-1)[self.tp_index]
|
||||
|
||||
_partition = move(_partition, get_accelerator().current_device_name()).detach()
|
||||
_partition = self.move(_partition).detach()
|
||||
|
||||
params_list[idx].data = _partition
|
||||
|
||||
@ -363,7 +387,7 @@ class LinearAllreduce(TensorParallel_Layer):
|
||||
self.name),
|
||||
dim=1)[self.tp_index]
|
||||
|
||||
_partition = move(_partition, get_accelerator().current_device_name()).detach()
|
||||
_partition = self.move(_partition).detach()
|
||||
params_list[idx].data = _partition
|
||||
|
||||
|
||||
@ -414,7 +438,7 @@ class LinearLayer(TensorParallel_Layer):
|
||||
#split bias if provide
|
||||
_partition = torch.chunk(param, self.tp_world_size, dim=0)[self.tp_index]
|
||||
|
||||
_partition = move(_partition, get_accelerator().current_device_name()).detach()
|
||||
_partition = self.move(_partition).detach()
|
||||
|
||||
params_list[idx].data = _partition
|
||||
|
||||
@ -429,7 +453,7 @@ class LinearLayer(TensorParallel_Layer):
|
||||
self.name),
|
||||
dim=0)[self.tp_index]
|
||||
|
||||
_partition = move(_partition, get_accelerator().current_device_name()).detach()
|
||||
_partition = self.move(_partition).detach()
|
||||
|
||||
params_list[idx].data = _partition
|
||||
|
||||
@ -475,7 +499,7 @@ class fused_LinearLayer(LinearLayer):
|
||||
|
||||
_partition = prepare_tp_fused_qkvw(self.fused_module.module, param, self.tp_world_size, self.tp_index)
|
||||
|
||||
_partition = move(_partition, get_accelerator().current_device_name()).detach()
|
||||
_partition = self.move(_partition).detach()
|
||||
|
||||
params_list[idx].data = _partition
|
||||
|
||||
@ -492,13 +516,13 @@ class conv_LinearLayer(LinearLayer):
|
||||
weight, bias = params_list[0], params_list[1]
|
||||
_partition = weight.data.split(get_shard_size_list(weight.shape[0], self.tp_world_size, self.name),
|
||||
dim=1)[self.tp_index]
|
||||
_partition = move(_partition, get_accelerator().current_device_name()).detach()
|
||||
_partition = self.move(_partition).detach()
|
||||
weight.data = _partition
|
||||
|
||||
if bias is not None:
|
||||
_partition = bias.data.split(get_shard_size_list(weight.shape[1], self.tp_world_size, self.name),
|
||||
dim=0)[self.tp_index]
|
||||
_partition = move(_partition, get_accelerator().current_device_name()).detach()
|
||||
_partition = self.move(_partition).detach()
|
||||
|
||||
bias.data = _partition
|
||||
|
||||
@ -522,9 +546,9 @@ class Yuan_LinearLayer(LinearLayer):
|
||||
def _tp_partition(self, params_list):
|
||||
weight, bias = shard_value_with_share_qk(params_list[0].data, params_list[1], self.tp_index,
|
||||
self.tp_world_size, True)
|
||||
params_list[0].data = move(weight, get_accelerator().current_device_name()).detach()
|
||||
params_list[0].data = self.move(weight).detach()
|
||||
if bias is not None:
|
||||
params_list[1].data = move(bias, get_accelerator().current_device_name()).detach()
|
||||
params_list[1].data = self.move(bias).detach()
|
||||
|
||||
|
||||
class GateUpPack_LinearLayer(LinearLayer):
|
||||
@ -532,9 +556,9 @@ class GateUpPack_LinearLayer(LinearLayer):
|
||||
@torch.no_grad()
|
||||
def _tp_partition(self, params_list):
|
||||
weight, bias = shard_chunk_mlp(params_list[0].data, params_list[1], self.tp_index, self.tp_world_size)
|
||||
params_list[0].data = move(weight, device=get_accelerator().current_device_name()).detach()
|
||||
params_list[0].data = self.move(weight).detach()
|
||||
if bias is not None:
|
||||
params_list[1].data = move(bias, device=get_accelerator().current_device_name()).detach()
|
||||
params_list[1].data = self.move(bias).detach()
|
||||
|
||||
|
||||
class Conv_LinearALlreduce(LinearAllreduce):
|
||||
@ -549,7 +573,7 @@ class Conv_LinearALlreduce(LinearAllreduce):
|
||||
_partition = param.split(get_shard_size_list(param.shape[0], self.tp_world_size, self.name),
|
||||
dim=1)[self.tp_index]
|
||||
|
||||
_partition = move(_partition, get_accelerator().current_device_name()).detach()
|
||||
_partition = self.move(_partition).detach()
|
||||
|
||||
params_list[idx].data = _partition
|
||||
|
||||
|
Reference in New Issue
Block a user