async tp allreduce (#7115)

Signed-off-by: inkcherry <mingzhi.liu@intel.com>
Signed-off-by: Logan Adams <loadams@microsoft.com>
Signed-off-by: Olatunji Ruwase <olruwase@microsoft.com>
Signed-off-by: Shaik Raza Sikander <srsikander@habana.ai>
Signed-off-by: Max Kovalenko <mkovalenko@habana.ai>
Signed-off-by: shaomin <wukon1992@gmail.com>
Signed-off-by: Stas Bekman <stas@stason.org>
Signed-off-by: siqi <siqi@tecorigin.com>
Signed-off-by: Wei Wu <wuwei211x@gmail.com>
Signed-off-by: ShellyNR <shelly.nahir@live.biu.ac.il>
Signed-off-by: Lai, Yejing <yejing.lai@intel.com>
Signed-off-by: Hongwei <hongweichen@microsoft.com>
Signed-off-by: Liang Cheng <astarxp777@gmail.com>
Signed-off-by: A-transformer <astarxp777@gmail.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: A-transformer <cl5743590921@gmail.com>
Co-authored-by: Raza Sikander <srsikander@habana.ai>
Co-authored-by: Max Kovalenko <mkovalenko@habana.ai>
Co-authored-by: wukong1992 <wukong1992@users.noreply.github.com>
Co-authored-by: shaomin <wukon1992@gmail.com>
Co-authored-by: Hongwei Chen <33092912+hwchen2017@users.noreply.github.com>
Co-authored-by: loadams <loadams@users.noreply.github.com>
Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
Co-authored-by: siqi654321 <siqi202311@163.com>
Co-authored-by: siqi <siqi@tecorigin.com>
Co-authored-by: Wei Wu <45323446+U-rara@users.noreply.github.com>
Co-authored-by: Masahiro Tanaka <81312776+tohtana@users.noreply.github.com>
Co-authored-by: Shelly Nahir <73890534+ShellyNR@users.noreply.github.com>
Co-authored-by: snahir <snahir@habana.ai>
Co-authored-by: Yejing-Lai <yejing.lai@intel.com>
Co-authored-by: A-transformer <astarxp777@gmail.com>
Co-authored-by: Ma, Guokai <guokai.ma@gmail.com>
Author: inkcherry
Date: 2025-03-29 06:48:17 +08:00 (committed by GitHub)
Parent: f355b9eadf
Commit: b8cc1eb078
5 changed files with 78 additions and 20 deletions

View File

@ -366,7 +366,7 @@ def init_inference(model, config=None, **kwargs):
return engine
def tp_model_init(model, tp_size, dtype):
def tp_model_init(model, tp_size, dtype, config=None, **kwargs):
"""
Initialize the model for tensor parallelism.
@ -379,8 +379,9 @@ def tp_model_init(model, tp_size, dtype):
torch.nn.Module: The initialized model with tensor parallelism.
"""
# avoid re-entry
assert not hasattr(
model, 'ds_autotp_parsed'), "ds_autotp_parsed' attribute already exists in the model, re-entry is not allowed."
if hasattr(model, 'ds_autotp_parsed'):
logger.warning("'ds_autotp_parsed' attribute already exists in the model, re-entry is not allowed.")
return
set_autotp_mode(training=True)
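
With this change, calling the helper a second time logs a warning and returns early instead of tripping an assertion. A minimal usage sketch (it assumes a torch model and an initialized distributed environment already exist; the tp_size and dtype values are illustrative):

import torch
import deepspeed

# first call shards the model for tensor-parallel training and returns it
model = deepspeed.tp_model_init(model, tp_size=4, dtype=torch.bfloat16)

# a second call on the same model now logs a warning and returns early instead of asserting
deepspeed.tp_model_init(model, tp_size=4, dtype=torch.bfloat16)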

View File

@ -80,6 +80,35 @@ class RowParallel(torch.autograd.Function):
return None, grad_output, None
class AsyncColumnParallel(torch.autograd.Function):
@staticmethod
def forward(ctx: Any, group: dist.ProcessGroup, input: torch.Tensor, weight, bias) -> torch.Tensor:
"""
Forward pass.
"""
ctx.use_bias = bias is not None
ctx.group = group
output = torch.matmul(input, weight.transpose(-1, -2))
if bias is not None:
output += bias
ctx.save_for_backward(input, weight)
return output
@staticmethod
def backward(ctx: Any, grad_output: torch.Tensor) -> Tuple[None, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
input, weight = ctx.saved_tensors
grad_input = grad_output.matmul(weight)
handle = dist.all_reduce(grad_input.contiguous(), group=ctx.group, async_op=True)
grad_weight = grad_output.view(-1, grad_output.shape[-1]).t().matmul(input.view(-1, input.shape[-1]))
grad_bias = grad_output.sum(0) if ctx.use_bias else None
handle.wait()
return None, grad_input, grad_weight, grad_bias
class ColumnParallel(torch.autograd.Function):
"""
Custom autograd function for column-wise parallelism.
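
AsyncColumnParallel.backward above launches the grad_input allreduce with async_op=True and overlaps it with the weight/bias gradient GEMMs, waiting only before the reduced gradient is returned. A minimal standalone sketch of that overlap pattern (not the DeepSpeed implementation itself; it assumes torch.distributed is already initialized and group is a valid process group):

import torch
import torch.distributed as dist

def overlapped_linear_backward(grad_output, input, weight, group):
    # kick off the allreduce of the input gradient without blocking
    grad_input = grad_output.matmul(weight)
    handle = dist.all_reduce(grad_input, group=group, async_op=True)

    # these GEMMs execute while the collective is still in flight
    grad_output_2d = grad_output.reshape(-1, grad_output.shape[-1])
    grad_weight = grad_output_2d.t().matmul(input.reshape(-1, input.shape[-1]))
    grad_bias = grad_output_2d.sum(0)

    handle.wait()  # ensure grad_input is fully reduced before it is consumed
    return grad_input, grad_weight, grad_bias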
@ -124,11 +153,17 @@ class TensorParallel_Layer(nn.Module, ABC):
support_training (bool): Flag indicating whether the layer supports training (default: False).
name (Optional[str]): The name of the layer, if provided.
"""
##### Initialize Parameter List #####
# keep_module_on_host is used to keep the module on the host. Checkpoints are loaded to the host first (in some
# cases it can be done from the disk even to prevent filling host's memory), thus no need to create a new copy.
# keep_module_on_host determines whether to keep the module on the host.
# Checkpoints are first loaded to the host (sometimes directly from disk to avoid filling host memory),
# so an additional copy is unnecessary.
keep_module_on_host: bool = False
##### Runtime Parameter List #####
tp_overlap_comm: bool = False
""" Whether to overlap communication with computation. Currently, only allreduce supports overlap. """
def __init__(self, mp_group: Optional[dist.ProcessGroup], **kwargs: Any):
"""
Initializes the TensorParallel_Layer with optional model parallelism group and layer name.
@ -260,6 +295,13 @@ class TensorParallel_Layer(nn.Module, ABC):
return cloned_tensor
def configure_tensor_parallel_runtime(config):
runtime_keys = ['tp_overlap_comm']
for key in runtime_keys:
if hasattr(config, key):
setattr(TensorParallel_Layer, key, getattr(config, key))
class GatherReplacedLayerParams:
"""
A context manager for gathering parameters of a replaced layer, enabling partitioning and gathering functionality
@ -406,11 +448,15 @@ class LinearLayer(TensorParallel_Layer):
self.config_tp_params(self.bias)
def forward(self, input):
if getattr(self, 'mp_group', None) is not None:
input = ColumnParallel.apply(self.mp_group, input)
output = torch.matmul(input, self.weight.transpose(-1, -2))
if self.bias is not None:
output += self.bias
if not self.__class__.tp_overlap_comm:
if getattr(self, 'mp_group', None) is not None:
input = ColumnParallel.apply(self.mp_group, input)
output = torch.matmul(input, self.weight.transpose(-1, -2))
if self.bias is not None:
output += self.bias
else:
output = AsyncColumnParallel.apply(self.mp_group, input, self.weight, self.bias)
return output
@torch.no_grad()
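
Together, configure_tensor_parallel_runtime and the class-level tp_overlap_comm flag decide which branch LinearLayer.forward takes: the existing ColumnParallel path, or the new AsyncColumnParallel path, which performs the matmul inside its own autograd function so the backward allreduce can be overlapped. A small sketch of how the runtime flag is applied, assuming the names from this module are in scope (the SimpleNamespace is a hypothetical stand-in for the parsed tensor-parallel config that the engine normally passes in):

from types import SimpleNamespace

# hypothetical stand-in for the parsed TPTrainingConfig
tp_cfg = SimpleNamespace(tp_overlap_comm=True)

configure_tensor_parallel_runtime(tp_cfg)

# the flag lives on the class, so every LinearLayer instance now takes the async branch
assert TensorParallel_Layer.tp_overlap_comm is True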

View File

@ -37,8 +37,7 @@ from deepspeed.runtime.fp16.unfused_optimizer import FP16_UnfusedOptimizer
from deepspeed.runtime.bf16_optimizer import BF16_Optimizer
from deepspeed.linear.optimized_linear import LoRAOptimizedLinear
from deepspeed.module_inject.layers import GatherReplacedLayerParams
from deepspeed.module_inject.layers import GatherReplacedLayerParams, configure_tensor_parallel_runtime
from deepspeed.runtime.config import DEEPSPEED_OPTIMIZERS, \
ADAGRAD_OPTIMIZER, ADAM_OPTIMIZER, ADAMW_OPTIMIZER, LAMB_OPTIMIZER, ONEBIT_ADAM_OPTIMIZER, ONEBIT_LAMB_OPTIMIZER, \
TORCH_ADAM_PARAM, ADAM_W_MODE, ADAM_W_MODE_DEFAULT, ZERO_ONE_ADAM_OPTIMIZER, MUADAM_OPTIMIZER, MUADAMW_OPTIMIZER, \
@ -248,7 +247,7 @@ class DeepSpeedEngine(Module):
self._configure_with_arguments(args, mpu)
self._do_sanity_check()
if self.autotp_size() > 1:
self._configure_tensor_parallel_states(model)
self._configure_tensor_parallel(model, self.tensor_parallel_config())
see_memory_usage(f"DeepSpeed Engine: After args sanity test", force=self.memory_breakdown())
if mpu is not None:
if self.elasticity_enabled():
@ -416,6 +415,10 @@ class DeepSpeedEngine(Module):
else:
p.ds_offload = False
def _configure_tensor_parallel(self, model, tp_config):
self._configure_tensor_parallel_states(model)
configure_tensor_parallel_runtime(tp_config)
def _configure_tensor_parallel_states(self, model):
"""
Configures the tensor parallel states for the model.
@ -423,7 +426,6 @@ class DeepSpeedEngine(Module):
and registering a pre-hook to ensure that the Dataloader inputs are consistent across ranks.
"""
self._set_client_model(model)
# sanity check
# currently, the compatibility between 'autotp' and 'zero > 1' has not been validated
assert self.zero_optimization_stage(
@ -902,6 +904,9 @@ class DeepSpeedEngine(Module):
def zero_ignore_unused_parameters(self):
return self._config.zero_config.ignore_unused_parameters
def tensor_parallel_config(self):
return self._config.tensor_parallel_config
def autotp_size(self):
return self._config.tensor_parallel_config.autotp_size
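
From the user's perspective the call sequence is unchanged: deepspeed.initialize parses the config, and when autotp is enabled the engine now also pushes the runtime flags into the layer classes. A rough sketch of inspecting the new accessor (model and ds_config are assumed to be defined, with ds_config enabling autotp as in the config sketch and tests below):

import deepspeed

engine, _, _, _ = deepspeed.initialize(model=model,
                                       model_parameters=model.parameters(),
                                       config=ds_config)

if engine.autotp_size() > 1:
    tp_cfg = engine.tensor_parallel_config()      # accessor added above
    print("tp_overlap_comm:", tp_cfg.tp_overlap_comm)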

View File

@ -47,6 +47,9 @@ class TPTrainingConfig(DeepSpeedConfigModel):
In automatic tensor-parallelism training, this sets the tensor parallel size ('tensor_parallel_size').
When set to 0, automatic tensor parallelism is disabled.
"""
tp_overlap_comm: bool = False
""" Whether to overlap communication with computation. Currently, only allreduce supports overlap. """
tensor_parallel: TPConfig = Field({}, alias="tp")
"""
Configuration for tensor parallelism used to split the model across several
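
In the DeepSpeed config, the new flag sits next to autotp_size under "tensor_parallel" (the test cases below use exactly this shape). A minimal, illustrative training config sketch (the other values are placeholders):

ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "optimizer": {
        "type": "Adam",
        "params": {"lr": 1e-4}
    },
    "tensor_parallel": {
        "autotp_size": 4,            # enable automatic tensor parallelism across 4 ranks
        "tp_overlap_comm": True      # new: overlap the allreduce with computation
    },
    "zero_optimization": {
        "stage": 0
    }
}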

View File

@ -163,11 +163,12 @@ def process_linear_layer(hidden_dim, input):
@pytest.mark.sequential
@pytest.mark.parametrize("tp_size", [2, 4])
@pytest.mark.parametrize("tp_overlap_comm", [True, False])
class TestTpLayerFwdBwd(DistributedTest):
world_size = 4
reuse_dist_env = True
def testRowParallel(self, tp_size: int):
def testRowParallel(self, tp_size: int, tp_overlap_comm: bool):
skip_on_device()
hidden_dim = 128
batch_size_per_device = 1
@ -182,7 +183,8 @@ class TestTpLayerFwdBwd(DistributedTest):
}
},
"tensor_parallel": {
"autotp_size": tp_size
"autotp_size": tp_size,
"tp_overlap_comm": tp_overlap_comm
},
"zero_optimization": {
"stage": 0,
@ -214,9 +216,9 @@ class TestTpLayerFwdBwd(DistributedTest):
torch_grad = torch.chunk(torch_linear.weight.grad, tp_size, dim=1)[groups.get_tensor_model_parallel_rank()]
assert torch.allclose(linear.weight.grad, torch_grad.to(get_accelerator().current_device()), atol=1e-3)
assert torch.allclose(out, torch_out.to(get_accelerator().current_device()), atol=1e-3)
assert torch.allclose(out, torch_out.to(get_accelerator().current_device()), atol=1e-2)
def testColumnParallel(self, tp_size: int):
def testColumnParallel(self, tp_size: int, tp_overlap_comm: bool):
skip_on_device()
hidden_dim = 128
batch_size_per_device = 1
@ -231,7 +233,8 @@ class TestTpLayerFwdBwd(DistributedTest):
}
},
"tensor_parallel": {
"autotp_size": tp_size
"autotp_size": tp_size,
"tp_overlap_comm": tp_overlap_comm
},
"zero_optimization": {
"stage": 0,
@ -266,7 +269,7 @@ class TestTpLayerFwdBwd(DistributedTest):
assert torch.allclose(linear.weight.grad, torch_grad.to(get_accelerator().current_device()), atol=1e-3)
assert torch.allclose(cur_device_out.to(get_accelerator().current_device()).contiguous(),
out.contiguous(),
atol=1e-3)
atol=1e-2)
@pytest.mark.sequential