[TPU] optimize the all-reduce performance (#15903)
Signed-off-by: Chengji Yao <chengjiyao@google.com>
```diff
@@ -22,6 +22,8 @@ if current_platform.is_tpu():
     import torch_xla.core.xla_model as xm
     import torch_xla.runtime as xr
     from torch_xla._internal import pjrt
+    from torch_xla.distributed.xla_multiprocessing import (
+        create_optimized_replica_groups)
 
     if USE_RAY:
         from vllm.executor import ray_utils
```
```diff
@@ -79,9 +81,12 @@ class TpuCommunicator(DeviceCommunicatorBase):
 
             pjrt.initialize_multiprocess(local_rank, local_world_size)
             xr._init_world_size_ordinal()
+        self.groups = create_optimized_replica_groups()
 
     def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
-        return xm.all_reduce(xm.REDUCE_SUM, input_)
+        # TODO: Remove the groups specification after XLA compiler can support
+        # auto-reordering the ring order for all-reduce.
+        return xm.all_reduce(xm.REDUCE_SUM, input_, groups=self.groups)
 
     def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
         assert dim == -1, "TPUs only support dim=-1 for all-gather."
```
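For context on the torch_xla API used here: `xm.all_reduce` accepts an optional `groups` argument (a list of replica-id lists) that pins the participant ordering, and `create_optimized_replica_groups()` asks torch_xla for an ordering suited to the TPU interconnect instead of leaving the ring order entirely to the compiler. Below is a minimal standalone sketch of that pattern, not vLLM code; the `sum_across_chips` helper and the per-process entry point are illustrative assumptions.

```python
# Illustrative sketch only (not vLLM code): an all-reduce with explicit
# replica groups, mirroring the pattern TpuCommunicator now uses.
import torch
import torch_xla.core.xla_model as xm
from torch_xla.distributed.xla_multiprocessing import (
    create_optimized_replica_groups)


def sum_across_chips(tensor: torch.Tensor) -> torch.Tensor:
    # Ask torch_xla for a replica ordering tuned to the TPU topology
    # rather than letting the compiler choose the ring order.
    groups = create_optimized_replica_groups()
    return xm.all_reduce(xm.REDUCE_SUM, tensor, groups=groups)


def per_process_main() -> None:
    # Intended to run inside each torch_xla worker process
    # (e.g. launched via xmp.spawn); shown here only for illustration.
    x = torch.ones(4, device=xm.xla_device())
    y = sum_across_chips(x)
    xm.mark_step()  # flush the lazy XLA graph
    print(y)
```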
```diff
@@ -119,11 +119,13 @@ def all_reduce_fake(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
 
 
 if supports_custom_op():
+    from vllm.platforms import current_platform
     direct_register_custom_op(
         op_name="all_reduce",
         op_func=all_reduce,
         mutates_args=[],
         fake_impl=all_reduce_fake,
+        dispatch_key=current_platform.dispatch_key,
     )
 
 
```
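`direct_register_custom_op` is vLLM's thin wrapper around `torch.library`; passing `dispatch_key=current_platform.dispatch_key` registers the real kernel under the platform's dispatch key (e.g. XLA on TPU) while the fake impl keeps `torch.ops.vllm.all_reduce(tensor, group_name)` traceable. A minimal sketch of the same mechanism using `torch.library` directly, with a made-up `demo::double` op and the CPU key standing in for the platform key (assumes PyTorch >= 2.4):

```python
import torch
from torch.library import Library

# Hypothetical "demo" namespace and op, used only for this sketch.
demo_lib = Library("demo", "FRAGMENT")
demo_lib.define("double(Tensor x) -> Tensor")


def double_impl(x: torch.Tensor) -> torch.Tensor:
    return x * 2


def double_fake(x: torch.Tensor) -> torch.Tensor:
    # Shape/dtype-only implementation used while tracing (torch.compile).
    return torch.empty_like(x)


# Real kernel registered under a dispatch key; "CPU" stands in for
# current_platform.dispatch_key ("CUDA", "XLA", ...).
demo_lib.impl("double", double_impl, "CPU")
torch.library.register_fake("demo::double", double_fake, lib=demo_lib)

print(torch.ops.demo.double(torch.ones(3)))  # tensor([2., 2., 2.])
```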
```diff
@@ -219,7 +221,8 @@ class GroupCoordinator:
                 self.cpu_group, 1 << 22, 6)
 
         from vllm.platforms import current_platform
-        self.use_custom_op_call = current_platform.is_cuda_alike()
+        self.use_custom_op_call = (current_platform.is_cuda_alike()
+                                   or current_platform.is_tpu())
 
     @property
     def first_rank(self):
```
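The effect of this flag is that TPU collectives now go through the registered `torch.ops.vllm.all_reduce` custom op (and hence through the fake impl when traced) instead of calling the communicator directly. Below is a simplified sketch of how such a flag typically gates the call path; it is not the exact `GroupCoordinator` code, and `unique_name` / `device_communicator` are assumed attribute names.

```python
import torch


class GroupCoordinatorSketch:
    """Simplified stand-in for the relevant part of GroupCoordinator."""

    def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
        if self.use_custom_op_call:
            # Route through the registered custom op so the graph tracer
            # (torch.compile / Dynamo) sees a single opaque, traceable op.
            return torch.ops.vllm.all_reduce(input_,
                                             group_name=self.unique_name)
        # Eager fallback: call the device communicator directly.
        return self.device_communicator.all_reduce(input_)
```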
```diff
@@ -84,6 +84,12 @@ class TPUWorker:
 
     def init_device(self):
         os.environ["PJRT_DEVICE"] = "TPU"
+        # Note: Currently the XLA compiler wrongly uses 2D ring strategy on 1D
+        # ring, the xla tpu compiler flag
+        # `xla_tpu_force_1d_allreduce_at_chunk_count` is a temporary solution to
+        # fix this. It will be removed after the bug in XLA compiler is fixed.
+        os.environ["LIBTPU_INIT_ARGS"] = (
+            "--xla_tpu_force_1d_allreduce_at_chunk_count=1")
         torch.set_grad_enabled(False)
         torch.set_default_dtype(self.model_config.dtype)
 
```
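One caveat worth noting: the assignment above overwrites any `LIBTPU_INIT_ARGS` the user may have exported before launching. A small hypothetical variant (not what the commit does) that appends the flag instead:

```python
import os

# Hypothetical append-instead-of-overwrite variant of the env setup above.
FLAG = "--xla_tpu_force_1d_allreduce_at_chunk_count=1"
existing = os.environ.get("LIBTPU_INIT_ARGS", "")
if FLAG not in existing:
    os.environ["LIBTPU_INIT_ARGS"] = (existing + " " + FLAG).strip()
```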