Compare commits

...

1 Commits

Author SHA1 Message Date
3679753af5 Reduce Scatter Plumbing
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
2025-02-28 16:33:52 +00:00
3 changed files with 94 additions and 0 deletions

View File

@ -61,6 +61,40 @@ class DeviceCommunicatorBase:
input_size[dim + 1:])
return output_tensor
def reduce_scatter(self,
input_: torch.Tensor,
dim: int = -1) -> torch.Tensor:
world_size = self.world_size
# Bypass the function if we are using only 1 GPU.
if world_size == 1:
return input_
assert -input_.dim() <= dim < input_.dim(), (
f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
if dim < 0:
# Convert negative dim to positive.
dim += input_.dim()
# Note: This will produce an incorrect answer if we don't make
# the input_tensor contiguous. Possible bug in reduce_scatter_tensor?
input_tensor = input_.movedim(0, dim).contiguous()
assert input_tensor.shape[0] % world_size == 0
chunk_size = input_tensor.shape[0] // world_size
output_shape = (chunk_size, ) + input_tensor.shape[1:]
output_tensor = torch.empty(output_shape,
dtype=input_tensor.dtype,
device=input_tensor.device)
# Perform reduce-scatter operation
torch.distributed.reduce_scatter_tensor(output_tensor,
input_tensor,
group=self.device_group)
# Reshape before returning
return output_tensor.movedim(0, dim).contiguous()
def gather(self,
input_: torch.Tensor,
dst: int = 0,

View File

@ -70,6 +70,31 @@ class CudaCommunicator(DeviceCommunicatorBase):
torch.distributed.all_reduce(out, group=self.device_group)
return out
def reduce_scatter(self, input_: torch.Tensor, dim: int = -1):
world_size = self.world_size
pynccl_comm = self.pynccl_comm
assert pynccl_comm is not None
if dim < 0:
# Convert negative dim to positive.
dim += input_.dim()
# Note: This will produce an incorrect answer if we don't make
# the input_tensor contiguous. Possible bug in reduce_scatter_tensor?
input_tensor = input_.movedim(0, dim).contiguous()
assert input_tensor.shape[0] % world_size == 0
chunk_size = input_tensor.shape[0] // world_size
output_shape = (chunk_size, ) + input_tensor.shape[1:]
output = torch.empty(output_shape,
dtype=input_tensor.dtype,
device=input_tensor.device)
pynccl_comm.reduce_scatter(output, input_)
# Reshape before returning
return output.movedim(0, dim).contiguous()
def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None:
"""Sends a tensor to the destination rank in a non-blocking way"""
"""NOTE: `dst` is the local rank of the destination rank."""

View File

@ -114,10 +114,26 @@ def all_reduce(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
return group._all_reduce_out_place(tensor)
def reduce_scatter(tensor: torch.Tensor, dim: int, world_size: int,
group_name: str) -> torch.Tensor:
assert group_name in _groups, f"Group {group_name} is not found."
group = _groups[group_name]()
if group is None:
raise ValueError(f"Group {group_name} is destroyed.")
return group.reduce_scatter(tensor, dim)
def all_reduce_fake(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
return torch.empty_like(tensor)
def reduce_scatter_fake(tensor: torch.Tensor, dim: int, world_size: int,
group_name: str) -> torch.Tensor:
new_shape = list(tensor.shape)
new_shape[dim] = tensor.shape[dim] // world_size
return torch.empty(new_shape, dtype=tensor.dtype, device=tensor.device)
if supports_custom_op():
direct_register_custom_op(
op_name="all_reduce",
@ -126,6 +142,13 @@ if supports_custom_op():
fake_impl=all_reduce_fake,
)
direct_register_custom_op(
op_name="reduce_scatter",
op_func=reduce_scatter,
mutates_args=[],
fake_impl=reduce_scatter_fake,
)
class GroupCoordinator:
"""
@ -322,6 +345,18 @@ class GroupCoordinator:
return self.device_communicator.all_gather(input_, dim)
def reduce_scatter(self,
input_: torch.Tensor,
dim: int = -1) -> torch.Tensor:
world_size = self.world_size
# Bypass the function if we are using only 1 GPU.
if world_size == 1:
return input_
assert -input_.dim() <= dim < input_.dim(), (
f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
return self.device_communicator.reduce_scatter(input_, dim)
def gather(self,
input_: torch.Tensor,
dst: int = 0,