Mirror of https://github.com/deepspeedai/DeepSpeed.git (synced 2025-10-20 15:33:51 +08:00)
async tp allreduce (#7115)
Signed-off-by: inkcherry <mingzhi.liu@intel.com>
Signed-off-by: Logan Adams <loadams@microsoft.com>
Signed-off-by: Olatunji Ruwase <olruwase@microsoft.com>
Signed-off-by: Shaik Raza Sikander <srsikander@habana.ai>
Signed-off-by: Max Kovalenko <mkovalenko@habana.ai>
Signed-off-by: shaomin <wukon1992@gmail.com>
Signed-off-by: Stas Bekman <stas@stason.org>
Signed-off-by: siqi <siqi@tecorigin.com>
Signed-off-by: Wei Wu <wuwei211x@gmail.com>
Signed-off-by: ShellyNR <shelly.nahir@live.biu.ac.il>
Signed-off-by: Lai, Yejing <yejing.lai@intel.com>
Signed-off-by: Hongwei <hongweichen@microsoft.com>
Signed-off-by: Liang Cheng <astarxp777@gmail.com>
Signed-off-by: A-transformer <astarxp777@gmail.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: A-transformer <cl5743590921@gmail.com>
Co-authored-by: Raza Sikander <srsikander@habana.ai>
Co-authored-by: Max Kovalenko <mkovalenko@habana.ai>
Co-authored-by: wukong1992 <wukong1992@users.noreply.github.com>
Co-authored-by: shaomin <wukon1992@gmail.com>
Co-authored-by: Hongwei Chen <33092912+hwchen2017@users.noreply.github.com>
Co-authored-by: loadams <loadams@users.noreply.github.com>
Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
Co-authored-by: siqi654321 <siqi202311@163.com>
Co-authored-by: siqi <siqi@tecorigin.com>
Co-authored-by: Wei Wu <45323446+U-rara@users.noreply.github.com>
Co-authored-by: Masahiro Tanaka <81312776+tohtana@users.noreply.github.com>
Co-authored-by: Shelly Nahir <73890534+ShellyNR@users.noreply.github.com>
Co-authored-by: snahir <snahir@habana.ai>
Co-authored-by: Yejing-Lai <yejing.lai@intel.com>
Co-authored-by: A-transformer <astarxp777@gmail.com>
Co-authored-by: Ma, Guokai <guokai.ma@gmail.com>
@@ -366,7 +366,7 @@ def init_inference(model, config=None, **kwargs):
     return engine
 
 
-def tp_model_init(model, tp_size, dtype):
+def tp_model_init(model, tp_size, dtype, config=None, **kwargs):
     """
     Initialize the model for tensor parallelism.
 
@@ -379,8 +379,9 @@ def tp_model_init(model, tp_size, dtype):
         torch.nn.Module: The initialized model with tensor parallelism.
     """
     # avoid re-entry
-    assert not hasattr(
-        model, 'ds_autotp_parsed'), "ds_autotp_parsed' attribute already exists in the model, re-entry is not allowed."
+    if hasattr(model, 'ds_autotp_parsed'):
+        logger.warning("ds_autotp_parsed' attribute already exists in the model, re-entry is not allowed.")
+        return
 
     set_autotp_mode(training=True)
 
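For orientation, a hedged usage sketch of the updated entry point. It assumes a distributed launch (e.g. `deepspeed --num_gpus 2 script.py`), a transformer model supported by autoTP, and that `tp_model_init` is exposed at the package level; the argument values are illustrative only, not taken from this commit.

import torch
import deepspeed
from transformers import AutoModelForCausalLM

deepspeed.init_distributed()  # set up the default process group first
model = AutoModelForCausalLM.from_pretrained("gpt2")

# The signature change above lets callers also pass a DeepSpeed config and extra
# kwargs through tp_model_init; the tp_size/dtype values here are placeholders.
model = deepspeed.tp_model_init(model, tp_size=2, dtype=torch.bfloat16)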
@@ -80,6 +80,35 @@ class RowParallel(torch.autograd.Function):
         return None, grad_output, None
 
 
+class AsyncColumnParallel(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx: Any, group: dist.ProcessGroup, input: torch.Tensor, weight, bias) -> torch.Tensor:
+        """
+        Forward pass.
+        """
+        ctx.use_bias = bias is not None
+        ctx.group = group
+        output = torch.matmul(input, weight.transpose(-1, -2))
+        if bias is not None:
+            output += bias
+
+        ctx.save_for_backward(input, weight)
+
+        return output
+
+    @staticmethod
+    def backward(ctx: Any, grad_output: torch.Tensor) -> Tuple[None, torch.Tensor]:
+
+        input, weight = ctx.saved_tensors
+        grad_input = grad_output.matmul(weight)
+        handle = dist.all_reduce(grad_input.contiguous(), group=ctx.group, async_op=True)
+        grad_weight = grad_output.view(-1, grad_output.shape[-1]).t().matmul(input.view(-1, input.shape[-1]))
+        grad_bias = grad_output.sum(0) if ctx.use_bias else None
+        handle.wait()
+        return None, grad_input, grad_weight, grad_bias
+
+
 class ColumnParallel(torch.autograd.Function):
     """
     Custom autograd function for column-wise parallelism.
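The backward pass above hides the tensor-parallel allreduce of the input gradient behind the weight/bias gradient math. A minimal standalone sketch of that overlap pattern with plain torch.distributed (the helper name and shapes are illustrative, not part of this commit; an initialized process group is assumed):

import torch
import torch.distributed as dist


def overlapped_linear_backward(grad_output, saved_input, weight, group=None):
    # The input gradient must be summed across the tensor-parallel ranks.
    grad_input = grad_output.matmul(weight).contiguous()
    handle = dist.all_reduce(grad_input, group=group, async_op=True)

    # The weight and bias gradients do not depend on the reduced result,
    # so these GEMMs run while the collective is still in flight.
    grad_weight = grad_output.flatten(0, -2).t().matmul(saved_input.flatten(0, -2))
    grad_bias = grad_output.sum(0)

    handle.wait()  # block only where the reduced grad_input is actually needed
    return grad_input, grad_weight, grad_bias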
@@ -124,11 +153,17 @@ class TensorParallel_Layer(nn.Module, ABC):
         support_training (bool): Flag indicating whether the layer supports training (default: False).
         name (Optional[str]): The name of the layer, if provided.
     """
+    ##### Initialize Parameter List #####
 
-    # keep_module_on_host is used to keep the module on the host. Checkpoints are loaded to the host first (in some
-    # cases it can be done from the disk even to prevent filling host's memory), thus no need to create a new copy.
+    # keep_module_on_host determines whether to keep the module on the host.
+    # Checkpoints are first loaded to the host (sometimes directly from disk to avoid filling host memory),
+    # so an additional copy is unnecessary.
     keep_module_on_host: bool = False
 
+    ##### Runtime Parameter List #####
+    tp_overlap_comm: bool = False
+    """ Whether to overlap communication with computation. Currently, only allreduce supports overlap. """
+
     def __init__(self, mp_group: Optional[dist.ProcessGroup], **kwargs: Any):
         """
         Initializes the TensorParallel_Layer with optional model parallelism group and layer name.
@@ -260,6 +295,13 @@ class TensorParallel_Layer(nn.Module, ABC):
         return cloned_tensor
 
 
+def configure_tensor_parallel_runtime(config):
+    runtime_keys = ['tp_overlap_comm']
+    for key in runtime_keys:
+        if hasattr(config, key):
+            setattr(TensorParallel_Layer, key, getattr(config, key))
+
+
 class GatherReplacedLayerParams:
     """
     A context manager for gathering parameters of a replaced layer, enabling partitioning and gathering functionality
@@ -406,11 +448,15 @@ class LinearLayer(TensorParallel_Layer):
             self.config_tp_params(self.bias)
 
     def forward(self, input):
-        if getattr(self, 'mp_group', None) is not None:
-            input = ColumnParallel.apply(self.mp_group, input)
-        output = torch.matmul(input, self.weight.transpose(-1, -2))
-        if self.bias is not None:
-            output += self.bias
+        if not self.__class__.tp_overlap_comm:
+            if getattr(self, 'mp_group', None) is not None:
+                input = ColumnParallel.apply(self.mp_group, input)
+            output = torch.matmul(input, self.weight.transpose(-1, -2))
+            if self.bias is not None:
+                output += self.bias
+        else:
+            output = AsyncColumnParallel.apply(self.mp_group, input, self.weight, self.bias)
+
         return output
 
     @torch.no_grad()
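As the forward above shows, the overlap path is selected per class rather than per instance: `configure_tensor_parallel_runtime` copies recognised config keys onto `TensorParallel_Layer`, and `forward` reads `self.__class__.tp_overlap_comm`. A hedged sketch of that propagation (import path taken from the diff; `SimpleNamespace` merely stands in for the parsed TP config object):

from types import SimpleNamespace

from deepspeed.module_inject.layers import TensorParallel_Layer, configure_tensor_parallel_runtime

# Engine side: copy the recognised runtime keys from the TP config onto the layer class.
configure_tensor_parallel_runtime(SimpleNamespace(tp_overlap_comm=True))

# Layer side: every replaced LinearLayer now dispatches to AsyncColumnParallel in
# forward(), because the flag is a class attribute shared by all instances.
assert TensorParallel_Layer.tp_overlap_comm is True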
@@ -37,8 +37,7 @@ from deepspeed.runtime.fp16.unfused_optimizer import FP16_UnfusedOptimizer
 from deepspeed.runtime.bf16_optimizer import BF16_Optimizer
 
 from deepspeed.linear.optimized_linear import LoRAOptimizedLinear
-from deepspeed.module_inject.layers import GatherReplacedLayerParams
-
+from deepspeed.module_inject.layers import GatherReplacedLayerParams, configure_tensor_parallel_runtime
 from deepspeed.runtime.config import DEEPSPEED_OPTIMIZERS, \
     ADAGRAD_OPTIMIZER, ADAM_OPTIMIZER, ADAMW_OPTIMIZER, LAMB_OPTIMIZER, ONEBIT_ADAM_OPTIMIZER, ONEBIT_LAMB_OPTIMIZER, \
     TORCH_ADAM_PARAM, ADAM_W_MODE, ADAM_W_MODE_DEFAULT, ZERO_ONE_ADAM_OPTIMIZER, MUADAM_OPTIMIZER, MUADAMW_OPTIMIZER, \
@@ -248,7 +247,7 @@ class DeepSpeedEngine(Module):
         self._configure_with_arguments(args, mpu)
         self._do_sanity_check()
         if self.autotp_size() > 1:
-            self._configure_tensor_parallel_states(model)
+            self._configure_tensor_parallel(model, self.tensor_parallel_config())
         see_memory_usage(f"DeepSpeed Engine: After args sanity test", force=self.memory_breakdown())
         if mpu is not None:
             if self.elasticity_enabled():
@@ -416,6 +415,10 @@ class DeepSpeedEngine(Module):
             else:
                 p.ds_offload = False
 
+    def _configure_tensor_parallel(self, model, tp_config):
+        self._configure_tensor_parallel_states(model)
+        configure_tensor_parallel_runtime(tp_config)
+
     def _configure_tensor_parallel_states(self, model):
         """
         Configures the tensor parallel states for the model.
@@ -423,7 +426,6 @@ class DeepSpeedEngine(Module):
         and registering a pre-hook to ensure that the Dataloader inputs are consistent across ranks.
         """
         self._set_client_model(model)
-
         # sanity check
         # currently, the compatibility between 'autotp' and 'zero > 1' has not been validated
         assert self.zero_optimization_stage(
@@ -902,6 +904,9 @@ class DeepSpeedEngine(Module):
     def zero_ignore_unused_parameters(self):
         return self._config.zero_config.ignore_unused_parameters
 
+    def tensor_parallel_config(self):
+        return self._config.tensor_parallel_config
+
     def autotp_size(self):
         return self._config.tensor_parallel_config.autotp_size
 
@@ -47,6 +47,9 @@ class TPTrainingConfig(DeepSpeedConfigModel):
     In automatic tensor-parallelism training, 'tensor_parallel_size'
     When set to 0, indicates that it is disabled.
     """
+    tp_overlap_comm: bool = False
+    """ Whether to overlap communication with computation. Currently, only allreduce supports overlap. """
+
     tensor_parallel: TPConfig = Field({}, alias="tp")
     """
     Configuration for tensor parallelism used to split the model across several
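For reference, the new flag sits under the `tensor_parallel` section of the DeepSpeed config, mirroring the updated tests below. A hedged sketch of a training config that turns the overlap on (only the `tensor_parallel` and `zero_optimization` blocks are taken from this diff; the remaining keys are placeholders):

ds_config = {
    "train_micro_batch_size_per_gpu": 1,  # placeholder
    "tensor_parallel": {
        "autotp_size": 4,  # tensor-parallel degree for autoTP training
        "tp_overlap_comm": True  # overlap the gradient allreduce with backward GEMMs
    },
    "zero_optimization": {
        "stage": 0
    }
}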
@@ -163,11 +163,12 @@ def process_linear_layer(hidden_dim, input):
 
 @pytest.mark.sequential
 @pytest.mark.parametrize("tp_size", [2, 4])
+@pytest.mark.parametrize("tp_overlap_comm", [True, False])
 class TestTpLayerFwdBwd(DistributedTest):
     world_size = 4
     reuse_dist_env = True
 
-    def testRowParallel(self, tp_size: int):
+    def testRowParallel(self, tp_size: int, tp_overlap_comm: bool):
         skip_on_device()
         hidden_dim = 128
         batch_size_per_device = 1
@@ -182,7 +183,8 @@ class TestTpLayerFwdBwd(DistributedTest):
                 }
             },
             "tensor_parallel": {
-                "autotp_size": tp_size
+                "autotp_size": tp_size,
+                "tp_overlap_comm": tp_overlap_comm
             },
             "zero_optimization": {
                 "stage": 0,
@@ -214,9 +216,9 @@ class TestTpLayerFwdBwd(DistributedTest):
 
         torch_grad = torch.chunk(torch_linear.weight.grad, tp_size, dim=1)[groups.get_tensor_model_parallel_rank()]
         assert torch.allclose(linear.weight.grad, torch_grad.to(get_accelerator().current_device()), atol=1e-3)
-        assert torch.allclose(out, torch_out.to(get_accelerator().current_device()), atol=1e-3)
+        assert torch.allclose(out, torch_out.to(get_accelerator().current_device()), atol=1e-2)
 
-    def testColumnParallel(self, tp_size: int):
+    def testColumnParallel(self, tp_size: int, tp_overlap_comm: bool):
         skip_on_device()
         hidden_dim = 128
         batch_size_per_device = 1
@@ -231,7 +233,8 @@ class TestTpLayerFwdBwd(DistributedTest):
                 }
             },
             "tensor_parallel": {
-                "autotp_size": tp_size
+                "autotp_size": tp_size,
+                "tp_overlap_comm": tp_overlap_comm
            },
            "zero_optimization": {
                "stage": 0,
@@ -266,7 +269,7 @@ class TestTpLayerFwdBwd(DistributedTest):
         assert torch.allclose(linear.weight.grad, torch_grad.to(get_accelerator().current_device()), atol=1e-3)
         assert torch.allclose(cur_device_out.to(get_accelerator().current_device()).contiguous(),
                               out.contiguous(),
-                              atol=1e-3)
+                              atol=1e-2)
 
 
 @pytest.mark.sequential