mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 23:03:52 +08:00
Compare commits
1 Commits
v0.11.0rc5
...
reduce_sca
Author | SHA1 | Date | |
---|---|---|---|
3679753af5 |
@ -61,6 +61,40 @@ class DeviceCommunicatorBase:
|
||||
input_size[dim + 1:])
|
||||
return output_tensor
|
||||
|
||||
def reduce_scatter(self,
|
||||
input_: torch.Tensor,
|
||||
dim: int = -1) -> torch.Tensor:
|
||||
world_size = self.world_size
|
||||
# Bypass the function if we are using only 1 GPU.
|
||||
if world_size == 1:
|
||||
return input_
|
||||
assert -input_.dim() <= dim < input_.dim(), (
|
||||
f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
|
||||
|
||||
if dim < 0:
|
||||
# Convert negative dim to positive.
|
||||
dim += input_.dim()
|
||||
|
||||
# Note: This will produce an incorrect answer if we don't make
|
||||
# the input_tensor contiguous. Possible bug in reduce_scatter_tensor?
|
||||
input_tensor = input_.movedim(0, dim).contiguous()
|
||||
|
||||
assert input_tensor.shape[0] % world_size == 0
|
||||
chunk_size = input_tensor.shape[0] // world_size
|
||||
output_shape = (chunk_size, ) + input_tensor.shape[1:]
|
||||
|
||||
output_tensor = torch.empty(output_shape,
|
||||
dtype=input_tensor.dtype,
|
||||
device=input_tensor.device)
|
||||
|
||||
# Perform reduce-scatter operation
|
||||
torch.distributed.reduce_scatter_tensor(output_tensor,
|
||||
input_tensor,
|
||||
group=self.device_group)
|
||||
|
||||
# Reshape before returning
|
||||
return output_tensor.movedim(0, dim).contiguous()
|
||||
|
||||
def gather(self,
|
||||
input_: torch.Tensor,
|
||||
dst: int = 0,
|
||||
|
@ -70,6 +70,31 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
||||
torch.distributed.all_reduce(out, group=self.device_group)
|
||||
return out
|
||||
|
||||
def reduce_scatter(self, input_: torch.Tensor, dim: int = -1):
|
||||
world_size = self.world_size
|
||||
pynccl_comm = self.pynccl_comm
|
||||
assert pynccl_comm is not None
|
||||
if dim < 0:
|
||||
# Convert negative dim to positive.
|
||||
dim += input_.dim()
|
||||
|
||||
# Note: This will produce an incorrect answer if we don't make
|
||||
# the input_tensor contiguous. Possible bug in reduce_scatter_tensor?
|
||||
input_tensor = input_.movedim(0, dim).contiguous()
|
||||
|
||||
assert input_tensor.shape[0] % world_size == 0
|
||||
chunk_size = input_tensor.shape[0] // world_size
|
||||
output_shape = (chunk_size, ) + input_tensor.shape[1:]
|
||||
|
||||
output = torch.empty(output_shape,
|
||||
dtype=input_tensor.dtype,
|
||||
device=input_tensor.device)
|
||||
|
||||
pynccl_comm.reduce_scatter(output, input_)
|
||||
|
||||
# Reshape before returning
|
||||
return output.movedim(0, dim).contiguous()
|
||||
|
||||
def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None:
|
||||
"""Sends a tensor to the destination rank in a non-blocking way"""
|
||||
"""NOTE: `dst` is the local rank of the destination rank."""
|
||||
|
@ -114,10 +114,26 @@ def all_reduce(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
|
||||
return group._all_reduce_out_place(tensor)
|
||||
|
||||
|
||||
def reduce_scatter(tensor: torch.Tensor, dim: int, world_size: int,
|
||||
group_name: str) -> torch.Tensor:
|
||||
assert group_name in _groups, f"Group {group_name} is not found."
|
||||
group = _groups[group_name]()
|
||||
if group is None:
|
||||
raise ValueError(f"Group {group_name} is destroyed.")
|
||||
return group.reduce_scatter(tensor, dim)
|
||||
|
||||
|
||||
def all_reduce_fake(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
|
||||
return torch.empty_like(tensor)
|
||||
|
||||
|
||||
def reduce_scatter_fake(tensor: torch.Tensor, dim: int, world_size: int,
|
||||
group_name: str) -> torch.Tensor:
|
||||
new_shape = list(tensor.shape)
|
||||
new_shape[dim] = tensor.shape[dim] // world_size
|
||||
return torch.empty(new_shape, dtype=tensor.dtype, device=tensor.device)
|
||||
|
||||
|
||||
if supports_custom_op():
|
||||
direct_register_custom_op(
|
||||
op_name="all_reduce",
|
||||
@ -126,6 +142,13 @@ if supports_custom_op():
|
||||
fake_impl=all_reduce_fake,
|
||||
)
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="reduce_scatter",
|
||||
op_func=reduce_scatter,
|
||||
mutates_args=[],
|
||||
fake_impl=reduce_scatter_fake,
|
||||
)
|
||||
|
||||
|
||||
class GroupCoordinator:
|
||||
"""
|
||||
@ -322,6 +345,18 @@ class GroupCoordinator:
|
||||
|
||||
return self.device_communicator.all_gather(input_, dim)
|
||||
|
||||
def reduce_scatter(self,
|
||||
input_: torch.Tensor,
|
||||
dim: int = -1) -> torch.Tensor:
|
||||
world_size = self.world_size
|
||||
# Bypass the function if we are using only 1 GPU.
|
||||
if world_size == 1:
|
||||
return input_
|
||||
assert -input_.dim() <= dim < input_.dim(), (
|
||||
f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
|
||||
|
||||
return self.device_communicator.reduce_scatter(input_, dim)
|
||||
|
||||
def gather(self,
|
||||
input_: torch.Tensor,
|
||||
dst: int = 0,
|
||||
|
Reference in New Issue
Block a user