"""batched collective operations for overhead amortization and better
|
|
bandwidth utilization"""
|
|
|
|
import math
|
|
from typing import List
|
|
|
|
import torch
|
|
from torch import Tensor
|
|
import torch.distributed
|
|
from torch.distributed import ProcessGroup
|
|
import torch.nn.functional
|
|
|
|
from deepspeed.utils import instrument_w_nvtx
|
|
from deepspeed.utils.logging import logger
|
|
|
|
if hasattr(torch.distributed, "_reduce_scatter_base"):

    def torch_reduce_scatter_fn(input_tensor: Tensor, output_tensor: Tensor, group):
        # single-call flat-buffer reduce-scatter available in newer pytorch builds
        instrument_w_nvtx(torch.distributed._reduce_scatter_base)(
            output_tensor,
            input_tensor,
            group=group,
        )
else:
    logger.warning(
        "unable to find torch.distributed._reduce_scatter_base. will fall back to "
        "torch.distributed.reduce_scatter which will result in suboptimal performance. "
        "please consider upgrading your pytorch installation.")

    def torch_reduce_scatter_fn(input_tensor: Tensor, output_tensor: Tensor, group):
        # fallback: split the flat input into one chunk per rank and use the
        # list-based reduce_scatter primitive
        input_tensor_lst = list(
            torch.chunk(input_tensor,
                        torch.distributed.get_world_size(group)))
        instrument_w_nvtx(torch.distributed.reduce_scatter)(
            output_tensor,
            input_tensor_lst,
            group=group,
        )


@instrument_w_nvtx
@torch.no_grad()
def reduce_scatter_coalesced(
    tensors: List[Tensor],
    group: ProcessGroup = None,
) -> List[Tensor]:
    """simultaneously reduce-scatter a list of tensors - this can be done more
    efficiently than individual reduce-scatter calls. a usage sketch appears at
    the bottom of this file.

    TODO. see if PyTorch team wants a c++ version of this for ProcessGroupNCCL
    """
    this_rank = torch.distributed.get_rank(group)
    world_sz = torch.distributed.get_world_size(group)

    # carve each flattened tensor into world_sz contiguous partitions; the last
    # partition may be shorter when numel is not divisible by world_sz
    partition_lst_for_each_tensor = [None] * len(tensors)
    for tensor_idx, tensor in enumerate(tensors):
        flattened_tensor = tensor.view(-1)
        chunk_sz = math.ceil(tensor.numel() / world_sz)
        partition_lst_for_each_tensor[tensor_idx] = [
            flattened_tensor[rank * chunk_sz:rank * chunk_sz + chunk_sz]
            for rank in range(0, world_sz)
        ]

    padded_partition_sz_for_each_tensor = tuple(
        math.ceil(t.numel() / world_sz) for t in tensors)

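    # illustration (added, not in the original source): with world_sz=4 and a
    # tensor of 10 elements, chunk_sz = ceil(10 / 4) = 3, the partitions have
    # sizes [3, 3, 3, 1], and the padded partition size is 3; the short last
    # partition is padded when the flat buffer is assembled below.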
    if len(tensors) == 1 and tensors[0].numel() % world_sz == 0:
        # if there's only one tensor being reduced and we don't need to pad
        # we have an opportunity to avoid a memory allocation
        tensor_partition_flat_buffer = tensors[0].view(-1)
    else:
        # interleave tensor partitions such that the correct reduced partitions of each tensor
        # end up at each rank
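        # illustration (added, not in the original source): with world_sz=2 and
        # two tensors A (4 elements) and B (3 elements), the flat buffer is laid
        # out as [A[0:2], B[0:2], A[2:4], B[2:3], <1 element of padding>], so
        # after the reduce-scatter rank 0 holds the reduced [A[0:2], B[0:2]] and
        # rank 1 holds the reduced [A[2:4], B[2:3], <padding>].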
        tensor_partitions_lst_with_padding = []
        for rank in range(world_sz):
            for tensor_idx in range(len(tensors)):
                # add tensor content
                tensor_chunk = partition_lst_for_each_tensor[tensor_idx][rank]
                tensor_partitions_lst_with_padding.append(tensor_chunk)

                # add padding if necessary
                padding_sz = padded_partition_sz_for_each_tensor[tensor_idx] - tensor_chunk.numel()
                if padding_sz > 0:
                    tensor_partitions_lst_with_padding.append(
                        torch.empty(padding_sz,
                                    dtype=tensor_chunk.dtype,
                                    device=tensor_chunk.device))

        tensor_partition_flat_buffer = instrument_w_nvtx(
            torch.cat)(tensor_partitions_lst_with_padding)

    tensor_partition_flat_buffer.div_(world_sz)  # pre-divide so the sum-based reduce-scatter yields an average
    tensor_partition_buffer_for_each_rank: List[Tensor] = torch.chunk(
        tensor_partition_flat_buffer, world_sz)

    # batched reduce-scatter call
    torch_reduce_scatter_fn(tensor_partition_flat_buffer,
                            tensor_partition_buffer_for_each_rank[this_rank],
                            group)

    # reverse procedure of the interleaving done previously, done on the
    # result of the batched reduce-scatter
    output_lst: List[Tensor] = [None] * len(tensors)
    offset = 0
    for tensor_idx in range(len(tensors)):
        output_lst[tensor_idx] = tensor_partition_buffer_for_each_rank[this_rank].narrow(
            0, offset, partition_lst_for_each_tensor[tensor_idx][this_rank].numel())

        offset += padded_partition_sz_for_each_tensor[tensor_idx]

    return output_lst
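

# minimal usage sketch (illustrative addition, not part of the original module).
# assumes launch via `torchrun --nproc_per_node=<N> this_file.py` with CUDA and
# NCCL available; since the flat buffer is pre-divided by the world size, each
# rank receives its partition of the element-wise mean across ranks.
if __name__ == "__main__":
    torch.distributed.init_process_group(backend="nccl")
    rank = torch.distributed.get_rank()
    torch.cuda.set_device(rank % torch.cuda.device_count())

    # two gradient-like tensors of different, non-divisible sizes
    grads = [
        torch.full((10, ), float(rank), device="cuda"),
        torch.ones(7, device="cuda"),
    ]
    shards = reduce_scatter_coalesced(grads)

    # each rank holds one partition per input tensor after the coalesced call
    for idx, shard in enumerate(shards):
        print(f"rank {rank}: tensor {idx} shard of size {shard.numel()}")

    torch.distributed.destroy_process_group()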