Add pyhccl (#503)

This is the first step toward supporting `trl vllm serve` on Ascend NPU
(https://github.com/vllm-project/vllm-ascend/issues/459).
Note that this PR works correctly only once
https://github.com/vllm-project/vllm/pull/16464 has been merged into vLLM.

---------

Signed-off-by: hzji210@gmail.com <hzji210@gmail.com>
Huazhong Ji
2025-04-17 14:57:52 +08:00
committed by GitHub
parent 64fdf4cbef
commit c3d1a3782a
8 changed files with 589 additions and 1 deletion

View File

@ -0,0 +1,111 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
# Adapted from vllm/tests/basic_correctness/test_basic_correctness.py
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import multiprocessing
import os
import torch
import torch_npu # noqa: F401
from vllm.distributed.parallel_state import (get_world_group,
init_distributed_environment)
from vllm.utils import update_environment_variables
from vllm_ascend.distributed.device_communicators.pyhccl import \
PyHcclCommunicator


def distributed_run(fn, world_size):
number_of_processes = world_size
processes: list[multiprocessing.Process] = []
for i in range(number_of_processes):
env: dict[str, str] = {}
env['RANK'] = str(i)
env['LOCAL_RANK'] = str(i)
env['WORLD_SIZE'] = str(number_of_processes)
env['LOCAL_WORLD_SIZE'] = str(number_of_processes)
env['MASTER_ADDR'] = 'localhost'
env['MASTER_PORT'] = '12345'
p = multiprocessing.Process(target=fn, args=(env, ))
processes.append(p)
p.start()
for p in processes:
p.join()
for p in processes:
assert p.exitcode == 0


def worker_fn_wrapper(fn):
# `multiprocessing.Process` cannot accept environment variables directly
# so we need to pass the environment variables as arguments
# and update the environment variables in the function
def wrapped_fn(env):
update_environment_variables(env)
local_rank = os.environ['LOCAL_RANK']
device = torch.device(f"npu:{local_rank}")
torch.npu.set_device(device)
init_distributed_environment(backend="hccl")
fn()
return wrapped_fn


@worker_fn_wrapper
def worker_fn():
    pyhccl_comm = PyHcclCommunicator(get_world_group().cpu_group,
                                     device=get_world_group().device)
    tensor = torch.ones(16, 1024, 1024,
                        dtype=torch.float32).npu(pyhccl_comm.rank)
    tensor = pyhccl_comm.all_reduce(tensor)
    torch.npu.synchronize()
    assert torch.all(tensor == pyhccl_comm.world_size).cpu().item()


# def test_pyhccl():
#     distributed_run(worker_fn, 2)


@worker_fn_wrapper
def broadcast_worker_fn():
# Test broadcast for every root rank.
# Essentially this is an all-gather operation.
pyhccl_comm = PyHcclCommunicator(get_world_group().cpu_group,
device=get_world_group().device)
recv_tensors = [
torch.empty(16,
1024,
1024,
dtype=torch.float32,
device=pyhccl_comm.device)
for i in range(pyhccl_comm.world_size)
]
recv_tensors[pyhccl_comm.rank] = torch.ones(
16, 1024, 1024, dtype=torch.float32,
device=pyhccl_comm.device) * pyhccl_comm.rank
for i in range(pyhccl_comm.world_size):
pyhccl_comm.broadcast(recv_tensors[i], src=i)
# the broadcast op might be launched in a different stream
# need to synchronize to make sure the tensor is ready
torch.npu.synchronize()
assert torch.all(recv_tensors[i] == i).cpu().item()


# def test_pyhccl_broadcast():
#     distributed_run(broadcast_worker_fn, 4)
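
The test entry points above are left commented out, presumably because CI cannot provide multiple NPUs yet. As a hedged sketch, once multi-NPU runners exist they could be gated as follows, assuming `torch.npu.device_count()` behaves like its CUDA counterpart (not verified here):

# hypothetical gate for a future multi-NPU CI job; not part of this PR
import pytest

@pytest.mark.skipif(torch.npu.device_count() < 2,
                    reason="pyhccl all_reduce needs at least two NPUs")
def test_pyhccl():
    distributed_run(worker_fn, 2)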

View File

@ -0,0 +1,30 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
# Adapted from vllm/tests/basic_correctness/test_basic_correctness.py
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import torch
import torch_npu # noqa: F401
from vllm_ascend.distributed.device_communicators.pyhccl_wrapper import \
HCCLLibrary


def test_hcclGetUniqueId():
torch.npu.set_device(0)
lib = HCCLLibrary()
unique_id = lib.hcclGetUniqueId()
assert unique_id is not None

View File

@ -30,5 +30,6 @@ class NPUCommunicator(DeviceCommunicatorBase):
device_group: Optional[ProcessGroup] = None,
unique_name: str = ""):
super().__init__(cpu_group, device, device_group, unique_name)
# TODO(hz): Refer to CudaCommunicator's implementation to integrate PyHcclCommunicator
# init device according to rank
self.device = torch.npu.current_device()
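
The TODO above refers to CudaCommunicator, which constructs a PyNcclCommunicator from the CPU group at init time. A hedged sketch of the analogous wiring, assuming DeviceCommunicatorBase keeps the `cpu_group` it receives (as it does upstream); the attribute name `pyhccl_comm` is an assumption and not part of this PR:

# hypothetical follow-up inside NPUCommunicator.__init__; not in this PR
from vllm_ascend.distributed.device_communicators.pyhccl import \
    PyHcclCommunicator

self.pyhccl_comm = PyHcclCommunicator(
    group=self.cpu_group,  # pyhccl requires a non-HCCL (e.g. gloo) group
    device=self.device,
)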

View File

@ -0,0 +1,166 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Optional, Union
import torch
import torch.distributed as dist
import torch_npu # noqa: F401
from torch.distributed import ProcessGroup, ReduceOp
from vllm.distributed.utils import StatelessProcessGroup
from vllm.logger import logger
from vllm_ascend.distributed.device_communicators.pyhccl_wrapper import (
HCCLLibrary, aclrtStream_t, buffer_type, hcclComm_t, hcclDataTypeEnum,
hcclRedOpTypeEnum, hcclUniqueId)
from vllm_ascend.utils import current_stream


class PyHcclCommunicator:

    def __init__(
self,
group: Union[ProcessGroup, StatelessProcessGroup],
device: Union[int, str, torch.device],
library_path: Optional[str] = None,
):
"""
Args:
group: the process group to work on. If None, it will use the
default process group.
device: the device to bind the PyHcclCommunicator to. If None,
it will be bind to f"npu:{local_rank}".
library_path: the path to the HCCL library. If None, it will
use the default library path.
It is the caller's responsibility to make sure each communicator
is bind to a unique device.
"""
if not isinstance(group, StatelessProcessGroup):
assert dist.is_initialized()
assert dist.get_backend(group) != dist.Backend.HCCL, (
"PyHcclCommunicator should be attached to a non-HCCL group.")
# note: this rank is the rank in the group
self.rank = dist.get_rank(group)
self.world_size = dist.get_world_size(group)
else:
self.rank = group.rank
self.world_size = group.world_size
self.group = group
# if world_size == 1, no need to create communicator
if self.world_size == 1:
self.available = False
self.disabled = True
return
try:
self.hccl = HCCLLibrary(library_path)
except Exception:
# disable because of missing HCCL library
# e.g. in a non-NPU environment
self.available = False
self.disabled = True
return
self.available = True
self.disabled = False
logger.info("vLLM is using pyhccl")
if isinstance(device, int):
device = torch.device(f"npu:{device}")
elif isinstance(device, str):
device = torch.device(device)
# now `device` is a `torch.device` object
assert isinstance(device, torch.device)
self.device = device
if self.rank == 0:
# get the unique id from HCCL
with torch.npu.device(device):
self.unique_id = self.hccl.hcclGetUniqueId()
else:
# construct an empty unique id
self.unique_id = hcclUniqueId()
if not isinstance(group, StatelessProcessGroup):
tensor = torch.ByteTensor(list(self.unique_id.internal))
ranks = dist.get_process_group_ranks(group)
# arg `src` in `broadcast` is the global rank
dist.broadcast(tensor, src=ranks[0], group=group)
byte_list = tensor.tolist()
for i, byte in enumerate(byte_list):
self.unique_id.internal[i] = byte
else:
self.unique_id = group.broadcast_obj(self.unique_id, src=0)
# hccl communicator and stream will use this device
# `torch.npu.device` is a context manager that changes the
# current npu device to the specified one
with torch.npu.device(device):
self.comm: hcclComm_t = self.hccl.hcclCommInitRank(
self.world_size, self.unique_id, self.rank)
stream = current_stream()
# A small all_reduce for warmup.
data = torch.zeros(1, device=device)
self.all_reduce(data)
stream.synchronize()
del data

    def all_reduce(self,
in_tensor: torch.Tensor,
op: ReduceOp = ReduceOp.SUM,
stream=None) -> torch.Tensor:
if self.disabled:
return None
# hccl communicator created on a specific device
# will only work on tensors on the same device
# otherwise it will cause "illegal memory access"
assert in_tensor.device == self.device, (
f"this hccl communicator is created to work on {self.device}, "
f"but the input tensor is on {in_tensor.device}")
out_tensor = torch.empty_like(in_tensor)
if stream is None:
stream = current_stream()
self.hccl.hcclAllReduce(buffer_type(in_tensor.data_ptr()),
buffer_type(out_tensor.data_ptr()),
in_tensor.numel(),
hcclDataTypeEnum.from_torch(in_tensor.dtype),
hcclRedOpTypeEnum.from_torch(op), self.comm,
aclrtStream_t(stream.npu_stream))
return out_tensor

    def broadcast(self, tensor: torch.Tensor, src: int, stream=None):
if self.disabled:
return
assert tensor.device == self.device, (
f"this hccl communicator is created to work on {self.device}, "
f"but the input tensor is on {tensor.device}")
if stream is None:
stream = current_stream()
        # HcclBroadcast uses a single buffer: the root rank sends from it
        # and every other rank receives into it, so the same pointer is
        # passed in both cases.
        buffer = buffer_type(tensor.data_ptr())
self.hccl.hcclBroadcast(buffer, tensor.numel(),
hcclDataTypeEnum.from_torch(tensor.dtype), src,
self.comm, aclrtStream_t(stream.npu_stream))
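
For reference, a minimal end-to-end sketch of PyHcclCommunicator outside vLLM's init path. It assumes a two-process job launched with torchrun on a host with two NPUs; per the assertion in __init__, the communicator is attached to a non-HCCL (here gloo) process group:

# pyhccl_demo.py (hypothetical); run with:
#   torchrun --nproc_per_node=2 pyhccl_demo.py
import os

import torch
import torch.distributed as dist
import torch_npu  # noqa: F401

from vllm_ascend.distributed.device_communicators.pyhccl import \
    PyHcclCommunicator

dist.init_process_group(backend="gloo")  # non-HCCL group, as required
local_rank = int(os.environ["LOCAL_RANK"])
torch.npu.set_device(local_rank)

comm = PyHcclCommunicator(dist.group.WORLD, device=f"npu:{local_rank}")
x = torch.ones(4, device=f"npu:{local_rank}")
y = comm.all_reduce(x)  # out-of-place: a new tensor is returned
torch.npu.synchronize()
assert torch.all(y == comm.world_size)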

View File

@ -0,0 +1,253 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import ctypes
import platform
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
import torch
from torch.distributed import ReduceOp
from vllm.logger import logger
from vllm_ascend.utils import find_hccl_library
# === export types and functions from hccl to Python ===
# for the original hccl definition, please check
# https://github.com/EternalLied/cann-hccl-new/blob/64ec6ce2923319caa5df8c3c531e06bdc148ce9c/inc/hccl/hccl.h#L90
# https://github.com/EternalLied/cann-hccl-new/blob/64ec6ce2923319caa5df8c3c531e06bdc148ce9c/inc/hccl/hccl_types.h#L48
hcclResult_t = ctypes.c_int
hcclComm_t = ctypes.c_void_p


class hcclUniqueId(ctypes.Structure):
    _fields_ = [("internal", ctypes.c_byte * 4108)]


aclrtStream_t = ctypes.c_void_p
buffer_type = ctypes.c_void_p
hcclDataType_t = ctypes.c_int


class hcclDataTypeEnum:
hcclInt8 = 0
hcclInt16 = 1
hcclInt32 = 2
hcclFloat16 = 3
hcclFloat32 = 4
hcclInt64 = 5
hcclUint64 = 6
hcclUint8 = 7
hcclUint16 = 8
hcclUint32 = 9
hcclFloat64 = 10
hcclBfloat16 = 11
hcclInt128 = 12

    @classmethod
def from_torch(cls, dtype: torch.dtype) -> int:
if dtype == torch.int8:
return cls.hcclInt8
if dtype == torch.uint8:
return cls.hcclUint8
if dtype == torch.int32:
return cls.hcclInt32
if dtype == torch.int64:
return cls.hcclInt64
if dtype == torch.float16:
return cls.hcclFloat16
if dtype == torch.float32:
return cls.hcclFloat32
if dtype == torch.float64:
return cls.hcclFloat64
if dtype == torch.bfloat16:
return cls.hcclBfloat16
raise ValueError(f"Unsupported dtype: {dtype}")


hcclRedOp_t = ctypes.c_int


class hcclRedOpTypeEnum:
hcclSum = 0
hcclProd = 1
hcclMax = 2
hcclMin = 3

    @classmethod
def from_torch(cls, op: ReduceOp) -> int:
if op == ReduceOp.SUM:
return cls.hcclSum
if op == ReduceOp.PRODUCT:
return cls.hcclProd
if op == ReduceOp.MAX:
return cls.hcclMax
if op == ReduceOp.MIN:
return cls.hcclMin
raise ValueError(f"Unsupported op: {op}")


@dataclass
class Function:
name: str
restype: Any
argtypes: List[Any]


class HCCLLibrary:
exported_functions = [
# const char* HcclGetErrorString(HcclResult code);
Function("HcclGetErrorString", ctypes.c_char_p, [hcclResult_t]),
# HcclResult HcclGetRootInfo(HcclRootInfo *rootInfo);
Function("HcclGetRootInfo", hcclResult_t,
[ctypes.POINTER(hcclUniqueId)]),
# HcclResult HcclCommInitRootInfo(
# uint32_t nRanks, const HcclRootInfo *rootInfo, uint32_t rank, HcclComm *comm);
# note that HcclComm is a pointer type, so the last argument is a pointer to a pointer
Function("HcclCommInitRootInfo", hcclResult_t, [
ctypes.c_int,
ctypes.POINTER(hcclUniqueId),
ctypes.c_int,
ctypes.POINTER(hcclComm_t),
]),
# HcclResult HcclAllReduce(
# void *sendBuf, void *recvBuf, uint64_t count,
# HcclDataType dataType, HcclReduceOp op, HcclComm comm,
# aclrtStream stream);
Function("HcclAllReduce", hcclResult_t, [
buffer_type,
buffer_type,
ctypes.c_size_t,
hcclDataType_t,
hcclRedOp_t,
hcclComm_t,
aclrtStream_t,
]),
# HcclResult HcclBroadcast(
# void *buf, uint64_t count,
# HcclDataType dataType, uint32_t root,
# HcclComm comm, aclrtStream stream);
Function("HcclBroadcast", hcclResult_t, [
buffer_type,
ctypes.c_size_t,
hcclDataType_t,
ctypes.c_int,
hcclComm_t,
aclrtStream_t,
]),
# HcclResult HcclCommDestroy(HcclComm comm);
Function("HcclCommDestroy", hcclResult_t, [hcclComm_t]),
]
    # class attribute to store the mapping from the path to the library
    # to avoid loading the same library multiple times
    path_to_library_cache: Dict[str, Any] = {}

    # class attribute to store the mapping from library path
    # to the corresponding dictionary of bound functions
    path_to_dict_mapping: Dict[str, Dict[str, Any]] = {}

    def __init__(self, so_file: Optional[str] = None):
        so_file = so_file or find_hccl_library()

        try:
            # load the library only once per path (see the cache above)
            if so_file not in HCCLLibrary.path_to_library_cache:
                lib = ctypes.CDLL(so_file)
                HCCLLibrary.path_to_library_cache[so_file] = lib
            self.lib = HCCLLibrary.path_to_library_cache[so_file]
except Exception as e:
            logger.error(
                "Failed to load HCCL library from %s. "
                "This is expected if you are not running on Ascend NPUs. "
                "Otherwise, the hccl library might not exist, be corrupted, "
                "or not support the current platform %s. "
                "If you already have the library, please set the "
                "environment variable HCCL_SO_PATH "
                "to point to the correct hccl library path.", so_file,
                platform.platform())
raise e
if so_file not in HCCLLibrary.path_to_dict_mapping:
_funcs: Dict[str, Any] = {}
for func in HCCLLibrary.exported_functions:
f = getattr(self.lib, func.name)
f.restype = func.restype
f.argtypes = func.argtypes
_funcs[func.name] = f
HCCLLibrary.path_to_dict_mapping[so_file] = _funcs
self._funcs = HCCLLibrary.path_to_dict_mapping[so_file]

    def hcclGetErrorString(self, result: hcclResult_t) -> str:
return self._funcs["HcclGetErrorString"](result).decode("utf-8")

    def HCCL_CHECK(self, result: hcclResult_t) -> None:
if result != 0:
error_str = self.hcclGetErrorString(result)
raise RuntimeError(f"HCCL error: {error_str}")

    def hcclGetUniqueId(self) -> hcclUniqueId:
unique_id = hcclUniqueId()
self.HCCL_CHECK(self._funcs["HcclGetRootInfo"](
ctypes.byref(unique_id)))
return unique_id

    def hcclCommInitRank(self, world_size: int, unique_id: hcclUniqueId,
rank: int) -> hcclComm_t:
comm = hcclComm_t()
self.HCCL_CHECK(self._funcs["HcclCommInitRootInfo"](
world_size, ctypes.byref(unique_id), rank, ctypes.byref(comm)))
return comm

    def hcclAllReduce(self, sendbuff: buffer_type, recvbuff: buffer_type,
count: int, datatype: int, op: int, comm: hcclComm_t,
stream: aclrtStream_t) -> None:
# `datatype` actually should be `hcclDataType_t`
# and `op` should be `hcclRedOp_t`
# both are aliases of `ctypes.c_int`
# when we pass int to a function, it will be converted to `ctypes.c_int`
# by ctypes automatically
self.HCCL_CHECK(self._funcs["HcclAllReduce"](sendbuff, recvbuff, count,
datatype, op, comm,
stream))

    def hcclBroadcast(self, buf: buffer_type, count: int, datatype: int,
root: int, comm: hcclComm_t,
stream: aclrtStream_t) -> None:
self.HCCL_CHECK(self._funcs["HcclBroadcast"](buf, count, datatype,
root, comm, stream))

    def hcclCommDestroy(self, comm: hcclComm_t) -> None:
self.HCCL_CHECK(self._funcs["HcclCommDestroy"](comm))


__all__ = [
"HCCLLibrary",
"hcclDataTypeEnum",
"hcclRedOpTypeEnum",
"hcclUniqueId",
"hcclComm_t",
"aclrtStream_t",
"buffer_type",
]
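
A hedged single-process smoke test of the raw wrapper, assuming HcclCommInitRootInfo accepts nRanks=1 (not verified against the CANN documentation):

import torch
import torch_npu  # noqa: F401

from vllm_ascend.distributed.device_communicators.pyhccl_wrapper import \
    HCCLLibrary

torch.npu.set_device(0)
lib = HCCLLibrary()                # locates libhccl.so via find_hccl_library()
root_info = lib.hcclGetUniqueId()  # wraps HcclGetRootInfo
comm = lib.hcclCommInitRank(1, root_info, 0)  # single-rank communicator
lib.hcclCommDestroy(comm)          # wraps HcclCommDestroy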

View File

@ -46,6 +46,8 @@ env_variables: Dict[str, Callable[[], Any]] = {
# Used for disaggregated prefilling
"HCCN_PATH":
lambda: os.getenv("HCCN_PATH", "/usr/local/Ascend/driver/tools/hccn_tool"),
"HCCL_SO_PATH":
lambda: os.environ.get("HCCL_SO_PATH", None),
"PROMPT_DEVICE_ID":
lambda: os.getenv("PROMPT_DEVICE_ID", None),
"DECODE_DEVICE_ID":
@ -53,7 +55,7 @@ env_variables: Dict[str, Callable[[], Any]] = {
"LLMDATADIST_COMM_PORT":
lambda: os.getenv("LLMDATADIST_COMM_PORT", "26000"),
"LLMDATADIST_SYNC_CACHE_WAIT_TIME":
lambda: os.getenv("LLMDATADIST_SYNC_CACHE_WAIT_TIME", "5000")
lambda: os.getenv("LLMDATADIST_SYNC_CACHE_WAIT_TIME", "5000"),
}
# end-env-vars-definition

View File

@ -17,8 +17,11 @@
# limitations under the License.
#
import torch
import torch_npu # noqa: F401
from vllm.logger import logger
import vllm_ascend.envs as envs
def try_register_lib(lib_name: str, lib_info: str = ""):
import importlib
@ -33,6 +36,28 @@ def try_register_lib(lib_name: str, lib_info: str = ""):
pass


def find_hccl_library() -> str:
    """
    We either use the library file specified by the `HCCL_SO_PATH`
    environment variable, or we rely on the `libhccl.so` shipped with the
    CANN toolkit: after `torch_npu` is imported, it can be found by
    `ctypes` automatically.
    """
so_file = envs.HCCL_SO_PATH
# manually load the hccl library
if so_file:
logger.info("Found hccl from environment variable HCCL_SO_PATH=%s",
so_file)
else:
if torch.version.cann is not None:
so_file = "libhccl.so"
else:
raise ValueError("HCCL only supports Ascend NPU backends.")
logger.info("Found hccl from library %s", so_file)
return so_file
_current_stream = None
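
Finally, how the new HCCL_SO_PATH escape hatch is meant to be used, with a deliberately made-up path. This assumes vllm_ascend.envs resolves attributes lazily through the lambdas defined above, as vLLM's own envs module does:

import os
os.environ["HCCL_SO_PATH"] = "/opt/custom/libhccl.so"  # hypothetical path

from vllm_ascend.utils import find_hccl_library
print(find_hccl_library())  # "/opt/custom/libhccl.so", logged via logger.info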