[torch.compile] improve allreduce registration (#9061)

This commit is contained in:
youkaichao
2024-10-04 16:43:50 -07:00
committed by GitHub
parent cc90419e89
commit 663874e048
2 changed files with 21 additions and 32 deletions

View File

@@ -265,24 +265,21 @@ class CustomAllreduce:
def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]:
# when custom allreduce is disabled, this will be None
if self.disabled:
if self.disabled or not self.should_custom_ar(input):
return None
if self._IS_CAPTURING:
if torch.cuda.is_current_stream_capturing():
if self.should_custom_ar(input):
return self.all_reduce_reg(input)
return self.all_reduce_reg(input)
else:
if self.should_custom_ar(input):
# if warm up, mimic the allocation pattern
# since custom allreduce is out-of-place
return torch.empty_like(input)
# if warm up, mimic the allocation pattern
# since custom allreduce is out-of-place
return torch.empty_like(input)
else:
# note: outside of cuda graph context,
# custom allreduce incurs a cost of cudaMemcpy, which should
# be small(<=1% of overall latency) compared to the performance
# gains of using custom kernels
if self.should_custom_ar(input):
return self.all_reduce_unreg(input)
return self.all_reduce_unreg(input)
return None

View File

@@ -105,7 +105,7 @@ if supports_custom_op():
group = _groups[group_name]()
if group is None:
raise ValueError(f"Group {group_name} is destroyed.")
group._all_reduce(tensor)
group._all_reduce_in_place(tensor)
@inplace_all_reduce.register_fake
def _(tensor: torch.Tensor, group_name: str) -> None:
@@ -118,7 +118,7 @@ if supports_custom_op():
group = _groups[group_name]()
if group is None:
raise ValueError(f"Group {group_name} is destroyed.")
return group._all_reduce(tensor)
return group._all_reduce_out_place(tensor)
@outplace_all_reduce.register_fake
def _(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
@@ -338,14 +338,17 @@ class GroupCoordinator:
return input_
if not supports_custom_op():
return self._all_reduce(input_)
self._all_reduce_in_place(input_)
return input_
if self.tpu_communicator is not None and \
not self.tpu_communicator.disabled:
# TPU handles Dynamo with its own logic.
return self._all_reduce(input_)
return self.tpu_communicator.all_reduce(input_)
if self.ca_comm is not None and self.ca_comm.should_custom_ar(input_):
if self.ca_comm is not None and \
not self.ca_comm.disabled and \
self.ca_comm.should_custom_ar(input_):
return torch.ops.vllm.outplace_all_reduce(
input_, group_name=self.unique_name)
else:
@@ -353,25 +356,15 @@ class GroupCoordinator:
group_name=self.unique_name)
return input_
def _all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
"""
The actual all-reduce implementation.
NOTE: This operation will be applied in-place or out-of-place.
Always assume this function modifies its input, but use the return
value as the output.
"""
def _all_reduce_out_place(self, input_: torch.Tensor) -> torch.Tensor:
ca_comm = self.ca_comm
assert ca_comm is not None
assert not ca_comm.disabled
out = ca_comm.custom_all_reduce(input_)
assert out is not None
return out
# For TPUs, use TPU communicator.
tpu_comm = self.tpu_communicator
if tpu_comm is not None and not tpu_comm.disabled:
return tpu_comm.all_reduce(input_)
if ca_comm is not None:
out = ca_comm.custom_all_reduce(input_)
if out is not None:
return out
def _all_reduce_in_place(self, input_: torch.Tensor) -> None:
pynccl_comm = self.pynccl_comm
if (pynccl_comm is not None and not pynccl_comm.disabled):
pynccl_comm.all_reduce(input_)
@@ -380,7 +373,6 @@ class GroupCoordinator:
ipex.distributed.all_reduce(input_, group=self.device_group)
else:
torch.distributed.all_reduce(input_, group=self.device_group)
return input_
def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
world_size = self.world_size