[TPU] optimize the all-reduce performance (#15903)

Signed-off-by: Chengji Yao <chengjiyao@google.com>
Authored by Chengji Yao on 2025-04-02 17:25:14 -07:00, committed by GitHub
parent 1b84eff03a
commit 01b6113659
3 changed files with 16 additions and 2 deletions

View File

@@ -22,6 +22,8 @@ if current_platform.is_tpu():
     import torch_xla.core.xla_model as xm
     import torch_xla.runtime as xr
     from torch_xla._internal import pjrt
+    from torch_xla.distributed.xla_multiprocessing import (
+        create_optimized_replica_groups)
 
     if USE_RAY:
         from vllm.executor import ray_utils
@@ -79,9 +81,12 @@ class TpuCommunicator(DeviceCommunicatorBase):
             pjrt.initialize_multiprocess(local_rank, local_world_size)
             xr._init_world_size_ordinal()
 
+        self.groups = create_optimized_replica_groups()
+
     def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
-        return xm.all_reduce(xm.REDUCE_SUM, input_)
+        # TODO: Remove the groups specification once the XLA compiler
+        # supports auto-reordering the ring order for all-reduce.
+        return xm.all_reduce(xm.REDUCE_SUM, input_, groups=self.groups)
 
     def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
         assert dim == -1, "TPUs only support dim=-1 for all-gather."
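
For context, a minimal sketch (not part of this diff) of the torch_xla calls the new path relies on. It assumes a multi-chip TPU process whose XLA runtime has already been initialized, e.g. via pjrt.initialize_multiprocess as above; the helper name demo_optimized_all_reduce is hypothetical.

import torch
import torch_xla.core.xla_model as xm
from torch_xla.distributed.xla_multiprocessing import (
    create_optimized_replica_groups)

def demo_optimized_all_reduce() -> torch.Tensor:
    # Replica groups whose ring order is tuned for the TPU topology; the
    # communicator now caches these in __init__ (see the TODO above).
    groups = create_optimized_replica_groups()
    x = torch.ones(1024, device=xm.xla_device())
    # Same call the patched TpuCommunicator.all_reduce issues.
    reduced = xm.all_reduce(xm.REDUCE_SUM, x, groups=groups)
    xm.mark_step()  # materialize the lazy XLA graph
    return reduced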

View File

@@ -119,11 +119,13 @@ def all_reduce_fake(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
 if supports_custom_op():
+    from vllm.platforms import current_platform
     direct_register_custom_op(
         op_name="all_reduce",
         op_func=all_reduce,
         mutates_args=[],
         fake_impl=all_reduce_fake,
+        dispatch_key=current_platform.dispatch_key,
     )
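
To illustrate why the dispatch key matters, here is a hedged sketch (not from this commit) that registers a toy op the same way; my_tpu_op and my_tpu_op_fake are hypothetical, and the import location of direct_register_custom_op (vllm.utils) is an assumption.

import torch
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op  # import location assumed

def my_tpu_op(x: torch.Tensor) -> torch.Tensor:
    return x + 1

def my_tpu_op_fake(x: torch.Tensor) -> torch.Tensor:
    # Shape/dtype-only implementation used during tracing and compilation.
    return torch.empty_like(x)

direct_register_custom_op(
    op_name="my_tpu_op",
    op_func=my_tpu_op,
    mutates_args=[],
    fake_impl=my_tpu_op_fake,
    # On TPU, current_platform.dispatch_key selects the XLA backend, so the
    # registered op can be dispatched on XLA tensors rather than only on CUDA.
    dispatch_key=current_platform.dispatch_key,
)
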
@@ -219,7 +221,8 @@ class GroupCoordinator:
                 self.cpu_group, 1 << 22, 6)
 
         from vllm.platforms import current_platform
-        self.use_custom_op_call = current_platform.is_cuda_alike()
+        self.use_custom_op_call = (current_platform.is_cuda_alike()
+                                   or current_platform.is_tpu())
 
     @property
     def first_rank(self):
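
The effect of this flag, in a hedged sketch (the real GroupCoordinator differs in details, and the torch.ops.vllm namespace is an assumption): with use_custom_op_call set, the reduction is routed through the custom op registered above, whose signature (tensor, group_name) follows the fake impl, so it stays a single traceable node on TPU as well.

import torch

class GroupCoordinatorSketch:
    # Illustrative only; mirrors the branch the flag controls.

    def __init__(self, unique_name, device_communicator, use_custom_op_call):
        self.unique_name = unique_name
        self.device_communicator = device_communicator
        self.use_custom_op_call = use_custom_op_call

    def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
        if self.use_custom_op_call:
            # Dispatch through the registered custom op (now also on TPU).
            return torch.ops.vllm.all_reduce(input_,
                                             group_name=self.unique_name)
        # Eager fallback: call the device communicator directly.
        return self.device_communicator.all_reduce(input_)
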

View File

@@ -84,6 +84,12 @@ class TPUWorker:
 
     def init_device(self):
         os.environ["PJRT_DEVICE"] = "TPU"
+        # Note: The XLA compiler currently uses a 2D ring strategy on a 1D
+        # ring by mistake. The XLA TPU compiler flag
+        # `xla_tpu_force_1d_allreduce_at_chunk_count` is a temporary
+        # workaround; it will be removed once the XLA compiler bug is fixed.
+        os.environ["LIBTPU_INIT_ARGS"] = (
+            "--xla_tpu_force_1d_allreduce_at_chunk_count=1")
         torch.set_grad_enabled(False)
         torch.set_default_dtype(self.model_config.dtype)
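
One practical note, as a sketch under the assumption that a user may already export LIBTPU_INIT_ARGS (the diff simply overwrites the variable): the flag must be in the environment before libtpu is initialized, and appending preserves any flags the user has set. append_libtpu_flag is a hypothetical helper.

import os

_FLAG = "--xla_tpu_force_1d_allreduce_at_chunk_count=1"

def append_libtpu_flag() -> None:
    # Must run before the TPU runtime (libtpu) is initialized.
    existing = os.environ.get("LIBTPU_INIT_ARGS", "")
    if _FLAG not in existing.split():
        os.environ["LIBTPU_INIT_ARGS"] = (existing + " " + _FLAG).strip()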