Mirror of https://github.com/vllm-project/vllm.git, synced 2025-10-20 14:53:52 +08:00
[torch.compile] improve allreduce registration (#9061)
@@ -265,24 +265,21 @@ class CustomAllreduce:
     def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]:
         # when custom allreduce is disabled, this will be None
-        if self.disabled:
+        if self.disabled or not self.should_custom_ar(input):
             return None
         if self._IS_CAPTURING:
             if torch.cuda.is_current_stream_capturing():
-                if self.should_custom_ar(input):
-                    return self.all_reduce_reg(input)
+                return self.all_reduce_reg(input)
             else:
-                if self.should_custom_ar(input):
-                    # if warm up, mimic the allocation pattern
-                    # since custom allreduce is out-of-place
-                    return torch.empty_like(input)
+                # if warm up, mimic the allocation pattern
+                # since custom allreduce is out-of-place
+                return torch.empty_like(input)
         else:
             # note: outside of cuda graph context,
             # custom allreduce incurs a cost of cudaMemcpy, which should
             # be small(<=1% of overall latency) compared to the performance
             # gains of using custom kernels
-            if self.should_custom_ar(input):
-                return self.all_reduce_unreg(input)
-
-        return None
+            return self.all_reduce_unreg(input)
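The branch above hinges on torch.cuda.is_current_stream_capturing(): inside a CUDA graph capture the registered kernel is recorded, while during warm-up the method only mimics the allocation of the out-of-place output. Below is a minimal, hypothetical sketch of that pattern; toy_out_of_place_op, is_capturing_phase, and the x * 2 body are illustrative stand-ins, not vLLM code (the torch.cuda graph APIs used are real).

import torch

def toy_out_of_place_op(x: torch.Tensor, is_capturing_phase: bool) -> torch.Tensor:
    if is_capturing_phase:
        if torch.cuda.is_current_stream_capturing():
            # the real work gets recorded into the CUDA graph
            return x * 2
        # warm-up run: only mimic the allocation pattern of the
        # out-of-place op so capture sees the same memory behavior
        return torch.empty_like(x)
    # eager path outside of any CUDA graph context
    return x * 2

if torch.cuda.is_available():
    x = torch.ones(4, device="cuda")
    # warm up on a side stream, as recommended before CUDA graph capture
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        toy_out_of_place_op(x, is_capturing_phase=True)
    torch.cuda.current_stream().wait_stream(s)

    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        y = toy_out_of_place_op(x, is_capturing_phase=True)  # recorded into g
    g.replay()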
@@ -105,7 +105,7 @@ if supports_custom_op():
         group = _groups[group_name]()
         if group is None:
             raise ValueError(f"Group {group_name} is destroyed.")
-        group._all_reduce(tensor)
+        group._all_reduce_in_place(tensor)
 
     @inplace_all_reduce.register_fake
     def _(tensor: torch.Tensor, group_name: str) -> None:
@@ -118,7 +118,7 @@ if supports_custom_op():
         group = _groups[group_name]()
         if group is None:
             raise ValueError(f"Group {group_name} is destroyed.")
-        return group._all_reduce(tensor)
+        return group._all_reduce_out_place(tensor)
 
     @outplace_all_reduce.register_fake
     def _(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
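These two registered ops are what make the collective visible to torch.compile: the in-place variant mutates its argument and returns None, while the out-of-place variant returns a fresh tensor whose fake implementation only has to propagate shape and dtype. A hedged sketch of the same pattern follows, using a hypothetical "toy" namespace and op names (none of these are vLLM's); it relies on the torch.library.custom_op API from PyTorch 2.4+, which supports_custom_op() appears to gate.

import torch
from torch.library import custom_op

@custom_op("toy::inplace_double", mutates_args=["tensor"])
def inplace_double(tensor: torch.Tensor) -> None:
    tensor.mul_(2)          # mutates its input, returns nothing

@custom_op("toy::outplace_double", mutates_args=())
def outplace_double(tensor: torch.Tensor) -> torch.Tensor:
    return tensor * 2       # functional: returns a new tensor

@outplace_double.register_fake
def _(tensor: torch.Tensor) -> torch.Tensor:
    # shape/dtype propagation only; no real computation during tracing
    return torch.empty_like(tensor)

x = torch.ones(4)
torch.ops.toy.inplace_double(x)           # x is now all 2s
y = torch.ops.toy.outplace_double(x)      # y is a new tensor, all 4s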
@@ -338,14 +338,17 @@ class GroupCoordinator:
             return input_
 
         if not supports_custom_op():
-            return self._all_reduce(input_)
+            self._all_reduce_in_place(input_)
+            return input_
 
         if self.tpu_communicator is not None and \
             not self.tpu_communicator.disabled:
             # TPU handles Dynamo with its own logic.
-            return self._all_reduce(input_)
+            return self.tpu_communicator.all_reduce(input_)
 
-        if self.ca_comm is not None and self.ca_comm.should_custom_ar(input_):
+        if self.ca_comm is not None and \
+            not self.ca_comm.disabled and \
+            self.ca_comm.should_custom_ar(input_):
             return torch.ops.vllm.outplace_all_reduce(
                 input_, group_name=self.unique_name)
         else:
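To illustrate why the dispatch above routes through a registered op name rather than calling the communicator directly: torch.compile can then trace the collective as a single opaque node and use the fake implementation for shape propagation. A hypothetical follow-up to the "toy" ops sketched earlier (it assumes that sketch has been run so the ops exist; fused_step is an illustrative name):

import torch

@torch.compile(fullgraph=True)
def fused_step(t: torch.Tensor) -> torch.Tensor:
    # the custom op appears as one opaque node in the captured graph
    return torch.ops.toy.outplace_double(t) + 1

print(fused_step(torch.ones(4)))   # tensor([3., 3., 3., 3.])

At runtime the compiled graph still calls the op's eager implementation; only tracing goes through the fake.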
@@ -353,25 +356,15 @@ class GroupCoordinator:
                                               group_name=self.unique_name)
             return input_
 
-    def _all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
-        """
-        The actual all-reduce implementation.
-
-        NOTE: This operation will be applied in-place or out-of-place.
-        Always assume this function modifies its input, but use the return
-        value as the output.
-        """
+    def _all_reduce_out_place(self, input_: torch.Tensor) -> torch.Tensor:
         ca_comm = self.ca_comm
+        assert ca_comm is not None
+        assert not ca_comm.disabled
+        out = ca_comm.custom_all_reduce(input_)
+        assert out is not None
+        return out
 
-        # For TPUs, use TPU communicator.
-        tpu_comm = self.tpu_communicator
-        if tpu_comm is not None and not tpu_comm.disabled:
-            return tpu_comm.all_reduce(input_)
-
-        if ca_comm is not None:
-            out = ca_comm.custom_all_reduce(input_)
-            if out is not None:
-                return out
+    def _all_reduce_in_place(self, input_: torch.Tensor) -> None:
         pynccl_comm = self.pynccl_comm
         if (pynccl_comm is not None and not pynccl_comm.disabled):
             pynccl_comm.all_reduce(input_)
@@ -380,7 +373,6 @@ class GroupCoordinator:
             ipex.distributed.all_reduce(input_, group=self.device_group)
         else:
             torch.distributed.all_reduce(input_, group=self.device_group)
-        return input_
 
     def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
         world_size = self.world_size
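The split above, together with the removal of the trailing return input_, pins down the contract the custom-op layer depends on: the out-of-place path hands back a new tensor, while the in-place path mutates its argument and returns None. A hedged single-process illustration with plain stand-in functions (not the vLLM methods):

import torch

def all_reduce_out_place(t: torch.Tensor) -> torch.Tensor:
    # stands in for ca_comm.custom_all_reduce, which writes to a fresh buffer
    return t.clone()

def all_reduce_in_place(t: torch.Tensor) -> None:
    # stands in for pynccl / torch.distributed.all_reduce, which overwrite t
    t.add_(0)

x = torch.ones(4)
y = all_reduce_out_place(x)
assert y is not x                        # caller receives a new tensor
assert all_reduce_in_place(x) is None    # caller keeps using x afterwards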