mirror of https://github.com/vllm-project/vllm-ascend.git (synced 2025-10-20 13:43:53 +08:00)
Add pyhccl (#503)
This is the first step toward supporting trl vllm serve on Ascend NPU (https://github.com/vllm-project/vllm-ascend/issues/459). This PR only works properly once https://github.com/vllm-project/vllm/pull/16464 is merged into vLLM. --------- Signed-off-by: hzji210@gmail.com <hzji210@gmail.com>
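For orientation, a minimal usage sketch distilled from the new tests below; it assumes vLLM's distributed environment has already been initialized with the hccl backend on at least two NPUs (as the worker_fn_wrapper in the test file sets up) and is illustrative rather than part of the commit itself:

import torch
import torch_npu  # noqa: F401
from vllm.distributed.parallel_state import get_world_group

from vllm_ascend.distributed.device_communicators.pyhccl import \
    PyHcclCommunicator

# Wrap the existing CPU process group; the communicator loads libhccl.so and
# exchanges the HCCL unique id across ranks during construction.
comm = PyHcclCommunicator(get_world_group().cpu_group,
                          device=get_world_group().device)

# Out-of-place all-reduce on the bound NPU: summing ones over all ranks,
# every rank ends up with a tensor filled with world_size.
out = comm.all_reduce(torch.ones(4, device=comm.device))
torch.npu.synchronize()
assert torch.all(out == comm.world_size)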
tests/multicard/test_pyhccl_distributed.py (new file, 111 lines)
@@ -0,0 +1,111 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
# Adapted from vllm/tests/basic_correctness/test_basic_correctness.py
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import multiprocessing
import os

import torch
import torch_npu  # noqa: F401
from vllm.distributed.parallel_state import (get_world_group,
                                             init_distributed_environment)
from vllm.utils import update_environment_variables

from vllm_ascend.distributed.device_communicators.pyhccl import \
    PyHcclCommunicator


def distributed_run(fn, world_size):
    number_of_processes = world_size
    processes: list[multiprocessing.Process] = []
    for i in range(number_of_processes):
        env: dict[str, str] = {}
        env['RANK'] = str(i)
        env['LOCAL_RANK'] = str(i)
        env['WORLD_SIZE'] = str(number_of_processes)
        env['LOCAL_WORLD_SIZE'] = str(number_of_processes)
        env['MASTER_ADDR'] = 'localhost'
        env['MASTER_PORT'] = '12345'
        p = multiprocessing.Process(target=fn, args=(env, ))
        processes.append(p)
        p.start()

    for p in processes:
        p.join()

    for p in processes:
        assert p.exitcode == 0


def worker_fn_wrapper(fn):
    # `multiprocessing.Process` cannot accept environment variables directly
    # so we need to pass the environment variables as arguments
    # and update the environment variables in the function
    def wrapped_fn(env):
        update_environment_variables(env)
        local_rank = os.environ['LOCAL_RANK']
        device = torch.device(f"npu:{local_rank}")
        torch.npu.set_device(device)
        init_distributed_environment(backend="hccl")
        fn()

    return wrapped_fn


@worker_fn_wrapper
def worker_fn():
    pyhccl_comm = PyHcclCommunicator(get_world_group().cpu_group,
                                     device=get_world_group().device)
    tensor = torch.ones(16, 1024, 1024,
                        dtype=torch.float32).npu(pyhccl_comm.rank)
    tensor = pyhccl_comm.all_reduce(tensor)
    torch.npu.synchronize()
    assert torch.all(tensor == pyhccl_comm.world_size).cpu().item()


# def test_pyhccl():
#     distributed_run(worker_fn, 2)


@worker_fn_wrapper
def broadcast_worker_fn():
    # Test broadcast for every root rank.
    # Essentially this is an all-gather operation.
    pyhccl_comm = PyHcclCommunicator(get_world_group().cpu_group,
                                     device=get_world_group().device)
    recv_tensors = [
        torch.empty(16,
                    1024,
                    1024,
                    dtype=torch.float32,
                    device=pyhccl_comm.device)
        for i in range(pyhccl_comm.world_size)
    ]
    recv_tensors[pyhccl_comm.rank] = torch.ones(
        16, 1024, 1024, dtype=torch.float32,
        device=pyhccl_comm.device) * pyhccl_comm.rank

    for i in range(pyhccl_comm.world_size):
        pyhccl_comm.broadcast(recv_tensors[i], src=i)
        # the broadcast op might be launched in a different stream
        # need to synchronize to make sure the tensor is ready
        torch.npu.synchronize()
        assert torch.all(recv_tensors[i] == i).cpu().item()


# def test_pyhccl_broadcast():
#     distributed_run(broadcast_worker_fn, 4)
tests/singlecard/test_pyhccl.py (new file, 30 lines)
@@ -0,0 +1,30 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
# Adapted from vllm/tests/basic_correctness/test_basic_correctness.py
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import torch
import torch_npu  # noqa: F401

from vllm_ascend.distributed.device_communicators.pyhccl_wrapper import \
    HCCLLibrary


def test_hcclGetUniqueId():
    torch.npu.set_device(0)
    lib = HCCLLibrary()
    unique_id = lib.hcclGetUniqueId()
    assert unique_id is not None
@@ -30,5 +30,6 @@ class NPUCommunicator(DeviceCommunicatorBase):
                 device_group: Optional[ProcessGroup] = None,
                 unique_name: str = ""):
        super().__init__(cpu_group, device, device_group, unique_name)
        # TODO(hz): Refer to CudaCommunicator's implementation to integrate PyHcclCommunicator
        # init device according to rank
        self.device = torch.npu.current_device()
vllm_ascend/distributed/device_communicators/pyhccl.py (new file, 166 lines)
@@ -0,0 +1,166 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from typing import Optional, Union

import torch
import torch.distributed as dist
import torch_npu  # noqa: F401
from torch.distributed import ProcessGroup, ReduceOp
from vllm.distributed.utils import StatelessProcessGroup
from vllm.logger import logger

from vllm_ascend.distributed.device_communicators.pyhccl_wrapper import (
    HCCLLibrary, aclrtStream_t, buffer_type, hcclComm_t, hcclDataTypeEnum,
    hcclRedOpTypeEnum, hcclUniqueId)
from vllm_ascend.utils import current_stream


class PyHcclCommunicator:

    def __init__(
        self,
        group: Union[ProcessGroup, StatelessProcessGroup],
        device: Union[int, str, torch.device],
        library_path: Optional[str] = None,
    ):
        """
        Args:
            group: the process group to work on. If None, it will use the
                default process group.
            device: the device to bind the PyHcclCommunicator to. If None,
                it will be bound to f"npu:{local_rank}".
            library_path: the path to the HCCL library. If None, it will
                use the default library path.
        It is the caller's responsibility to make sure each communicator
        is bound to a unique device.
        """

        if not isinstance(group, StatelessProcessGroup):
            assert dist.is_initialized()
            assert dist.get_backend(group) != dist.Backend.HCCL, (
                "PyHcclCommunicator should be attached to a non-HCCL group.")
            # note: this rank is the rank in the group
            self.rank = dist.get_rank(group)
            self.world_size = dist.get_world_size(group)
        else:
            self.rank = group.rank
            self.world_size = group.world_size

        self.group = group

        # if world_size == 1, no need to create communicator
        if self.world_size == 1:
            self.available = False
            self.disabled = True
            return

        try:
            self.hccl = HCCLLibrary(library_path)
        except Exception:
            # disable because of missing HCCL library
            # e.g. in a non-NPU environment
            self.available = False
            self.disabled = True
            return

        self.available = True
        self.disabled = False

        logger.info("vLLM is using pyhccl")

        if isinstance(device, int):
            device = torch.device(f"npu:{device}")
        elif isinstance(device, str):
            device = torch.device(device)
        # now `device` is a `torch.device` object
        assert isinstance(device, torch.device)
        self.device = device

        if self.rank == 0:
            # get the unique id from HCCL
            with torch.npu.device(device):
                self.unique_id = self.hccl.hcclGetUniqueId()
        else:
            # construct an empty unique id
            self.unique_id = hcclUniqueId()

        if not isinstance(group, StatelessProcessGroup):
            tensor = torch.ByteTensor(list(self.unique_id.internal))
            ranks = dist.get_process_group_ranks(group)
            # arg `src` in `broadcast` is the global rank
            dist.broadcast(tensor, src=ranks[0], group=group)
            byte_list = tensor.tolist()
            for i, byte in enumerate(byte_list):
                self.unique_id.internal[i] = byte
        else:
            self.unique_id = group.broadcast_obj(self.unique_id, src=0)

        # hccl communicator and stream will use this device
        # `torch.npu.device` is a context manager that changes the
        # current npu device to the specified one
        with torch.npu.device(device):
            self.comm: hcclComm_t = self.hccl.hcclCommInitRank(
                self.world_size, self.unique_id, self.rank)

            stream = current_stream()
            # A small all_reduce for warmup.
            data = torch.zeros(1, device=device)
            self.all_reduce(data)
            stream.synchronize()
            del data

    def all_reduce(self,
                   in_tensor: torch.Tensor,
                   op: ReduceOp = ReduceOp.SUM,
                   stream=None) -> torch.Tensor:
        if self.disabled:
            return None
        # hccl communicator created on a specific device
        # will only work on tensors on the same device
        # otherwise it will cause "illegal memory access"
        assert in_tensor.device == self.device, (
            f"this hccl communicator is created to work on {self.device}, "
            f"but the input tensor is on {in_tensor.device}")

        out_tensor = torch.empty_like(in_tensor)

        if stream is None:
            stream = current_stream()
        self.hccl.hcclAllReduce(buffer_type(in_tensor.data_ptr()),
                                buffer_type(out_tensor.data_ptr()),
                                in_tensor.numel(),
                                hcclDataTypeEnum.from_torch(in_tensor.dtype),
                                hcclRedOpTypeEnum.from_torch(op), self.comm,
                                aclrtStream_t(stream.npu_stream))
        return out_tensor

    def broadcast(self, tensor: torch.Tensor, src: int, stream=None):
        if self.disabled:
            return
        assert tensor.device == self.device, (
            f"this hccl communicator is created to work on {self.device}, "
            f"but the input tensor is on {tensor.device}")
        if stream is None:
            stream = current_stream()
        if src == self.rank:
            buffer = buffer_type(tensor.data_ptr())
        else:
            buffer = buffer_type(tensor.data_ptr())
        self.hccl.hcclBroadcast(buffer, tensor.numel(),
                                hcclDataTypeEnum.from_torch(tensor.dtype), src,
                                self.comm, aclrtStream_t(stream.npu_stream))
vllm_ascend/distributed/device_communicators/pyhccl_wrapper.py (new file, 253 lines)
@@ -0,0 +1,253 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import ctypes
import platform
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

import torch
from torch.distributed import ReduceOp
from vllm.logger import logger

from vllm_ascend.utils import find_hccl_library

# export types and functions from hccl to Python ===
# for the original hccl definition, please check
# https://github.com/EternalLied/cann-hccl-new/blob/64ec6ce2923319caa5df8c3c531e06bdc148ce9c/inc/hccl/hccl.h#L90
# https://github.com/EternalLied/cann-hccl-new/blob/64ec6ce2923319caa5df8c3c531e06bdc148ce9c/inc/hccl/hccl_types.h#L48

hcclResult_t = ctypes.c_int
hcclComm_t = ctypes.c_void_p


class hcclUniqueId(ctypes.Structure):
    _fields_ = [("internal", ctypes.c_byte * 4108)]


aclrtStream_t = ctypes.c_void_p
buffer_type = ctypes.c_void_p

hcclDataType_t = ctypes.c_int


class hcclDataTypeEnum:
    hcclInt8 = 0
    hcclInt16 = 1
    hcclInt32 = 2
    hcclFloat16 = 3
    hcclFloat32 = 4
    hcclInt64 = 5
    hcclUint64 = 6
    hcclUint8 = 7
    hcclUint16 = 8
    hcclUint32 = 9
    hcclFloat64 = 10
    hcclBfloat16 = 11
    hcclInt128 = 12

    @classmethod
    def from_torch(cls, dtype: torch.dtype) -> int:
        if dtype == torch.int8:
            return cls.hcclInt8
        if dtype == torch.uint8:
            return cls.hcclUint8
        if dtype == torch.int32:
            return cls.hcclInt32
        if dtype == torch.int64:
            return cls.hcclInt64
        if dtype == torch.float16:
            return cls.hcclFloat16
        if dtype == torch.float32:
            return cls.hcclFloat32
        if dtype == torch.float64:
            return cls.hcclFloat64
        if dtype == torch.bfloat16:
            return cls.hcclBfloat16
        raise ValueError(f"Unsupported dtype: {dtype}")


hcclRedOp_t = ctypes.c_int


class hcclRedOpTypeEnum:
    hcclSum = 0
    hcclProd = 1
    hcclMax = 2
    hcclMin = 3

    @classmethod
    def from_torch(cls, op: ReduceOp) -> int:
        if op == ReduceOp.SUM:
            return cls.hcclSum
        if op == ReduceOp.PRODUCT:
            return cls.hcclProd
        if op == ReduceOp.MAX:
            return cls.hcclMax
        if op == ReduceOp.MIN:
            return cls.hcclMin
        raise ValueError(f"Unsupported op: {op}")


@dataclass
class Function:
    name: str
    restype: Any
    argtypes: List[Any]


class HCCLLibrary:
    exported_functions = [
        # const char* HcclGetErrorString(HcclResult code);
        Function("HcclGetErrorString", ctypes.c_char_p, [hcclResult_t]),

        # HcclResult HcclGetRootInfo(HcclRootInfo *rootInfo);
        Function("HcclGetRootInfo", hcclResult_t,
                 [ctypes.POINTER(hcclUniqueId)]),

        # HcclResult HcclCommInitRootInfo(
        #   uint32_t nRanks, const HcclRootInfo *rootInfo, uint32_t rank, HcclComm *comm);
        # note that HcclComm is a pointer type, so the last argument is a pointer to a pointer
        Function("HcclCommInitRootInfo", hcclResult_t, [
            ctypes.c_int,
            ctypes.POINTER(hcclUniqueId),
            ctypes.c_int,
            ctypes.POINTER(hcclComm_t),
        ]),

        # HcclResult HcclAllReduce(
        #   void *sendBuf, void *recvBuf, uint64_t count,
        #   HcclDataType dataType, HcclReduceOp op, HcclComm comm,
        #   aclrtStream stream);
        Function("HcclAllReduce", hcclResult_t, [
            buffer_type,
            buffer_type,
            ctypes.c_size_t,
            hcclDataType_t,
            hcclRedOp_t,
            hcclComm_t,
            aclrtStream_t,
        ]),

        # HcclResult HcclBroadcast(
        #   void *buf, uint64_t count,
        #   HcclDataType dataType, uint32_t root,
        #   HcclComm comm, aclrtStream stream);
        Function("HcclBroadcast", hcclResult_t, [
            buffer_type,
            ctypes.c_size_t,
            hcclDataType_t,
            ctypes.c_int,
            hcclComm_t,
            aclrtStream_t,
        ]),

        # HcclResult HcclCommDestroy(HcclComm comm);
        Function("HcclCommDestroy", hcclResult_t, [hcclComm_t]),
    ]

    # class attribute to store the mapping from the path to the library
    # to avoid loading the same library multiple times
    path_to_library_cache: Dict[str, Any] = {}

    # class attribute to store the mapping from library path
    # to the corresponding dictionary of functions
    path_to_dict_mapping: Dict[str, Dict[str, Any]] = {}

    def __init__(self, so_file: Optional[str] = None):

        so_file = so_file or find_hccl_library()

        try:
            if so_file not in HCCLLibrary.path_to_dict_mapping:
                lib = ctypes.CDLL(so_file)
                HCCLLibrary.path_to_library_cache[so_file] = lib
            self.lib = HCCLLibrary.path_to_library_cache[so_file]
        except Exception as e:
            logger.error(
                "Failed to load HCCL library from %s. "
                "It is expected if you are not running on Ascend NPUs. "
                "Otherwise, the hccl library might not exist, be corrupted "
                "or it does not support the current platform %s. "
                "If you already have the library, please set the "
                "environment variable HCCL_SO_PATH"
                " to point to the correct hccl library path.", so_file,
                platform.platform())
            raise e

        if so_file not in HCCLLibrary.path_to_dict_mapping:
            _funcs: Dict[str, Any] = {}
            for func in HCCLLibrary.exported_functions:
                f = getattr(self.lib, func.name)
                f.restype = func.restype
                f.argtypes = func.argtypes
                _funcs[func.name] = f
            HCCLLibrary.path_to_dict_mapping[so_file] = _funcs
        self._funcs = HCCLLibrary.path_to_dict_mapping[so_file]

    def hcclGetErrorString(self, result: hcclResult_t) -> str:
        return self._funcs["HcclGetErrorString"](result).decode("utf-8")

    def HCCL_CHECK(self, result: hcclResult_t) -> None:
        if result != 0:
            error_str = self.hcclGetErrorString(result)
            raise RuntimeError(f"HCCL error: {error_str}")

    def hcclGetUniqueId(self) -> hcclUniqueId:
        unique_id = hcclUniqueId()
        self.HCCL_CHECK(self._funcs["HcclGetRootInfo"](
            ctypes.byref(unique_id)))
        return unique_id

    def hcclCommInitRank(self, world_size: int, unique_id: hcclUniqueId,
                         rank: int) -> hcclComm_t:
        comm = hcclComm_t()
        self.HCCL_CHECK(self._funcs["HcclCommInitRootInfo"](
            world_size, ctypes.byref(unique_id), rank, ctypes.byref(comm)))
        return comm

    def hcclAllReduce(self, sendbuff: buffer_type, recvbuff: buffer_type,
                      count: int, datatype: int, op: int, comm: hcclComm_t,
                      stream: aclrtStream_t) -> None:
        # `datatype` actually should be `hcclDataType_t`
        # and `op` should be `hcclRedOp_t`
        # both are aliases of `ctypes.c_int`
        # when we pass int to a function, it will be converted to `ctypes.c_int`
        # by ctypes automatically
        self.HCCL_CHECK(self._funcs["HcclAllReduce"](sendbuff, recvbuff, count,
                                                     datatype, op, comm,
                                                     stream))

    def hcclBroadcast(self, buf: buffer_type, count: int, datatype: int,
                      root: int, comm: hcclComm_t,
                      stream: aclrtStream_t) -> None:
        self.HCCL_CHECK(self._funcs["HcclBroadcast"](buf, count, datatype,
                                                     root, comm, stream))

    def hcclCommDestroy(self, comm: hcclComm_t) -> None:
        self.HCCL_CHECK(self._funcs["HcclCommDestroy"](comm))


__all__ = [
    "HCCLLibrary",
    "hcclDataTypeEnum",
    "hcclRedOpTypeEnum",
    "hcclUniqueId",
    "hcclComm_t",
    "aclrtStream_t",
    "buffer_type",
]
@@ -46,6 +46,8 @@ env_variables: Dict[str, Callable[[], Any]] = {
    # Used for disaggregated prefilling
    "HCCN_PATH":
    lambda: os.getenv("HCCN_PATH", "/usr/local/Ascend/driver/tools/hccn_tool"),
    "HCCL_SO_PATH":
    lambda: os.environ.get("HCCL_SO_PATH", None),
    "PROMPT_DEVICE_ID":
    lambda: os.getenv("PROMPT_DEVICE_ID", None),
    "DECODE_DEVICE_ID":
@@ -53,7 +55,7 @@ env_variables: Dict[str, Callable[[], Any]] = {
    "LLMDATADIST_COMM_PORT":
    lambda: os.getenv("LLMDATADIST_COMM_PORT", "26000"),
    "LLMDATADIST_SYNC_CACHE_WAIT_TIME":
    lambda: os.getenv("LLMDATADIST_SYNC_CACHE_WAIT_TIME", "5000")
    lambda: os.getenv("LLMDATADIST_SYNC_CACHE_WAIT_TIME", "5000"),
}

# end-env-vars-definition
@@ -17,8 +17,11 @@
# limitations under the License.
#
import torch
import torch_npu  # noqa: F401
from vllm.logger import logger

import vllm_ascend.envs as envs


def try_register_lib(lib_name: str, lib_info: str = ""):
    import importlib
@@ -33,6 +36,28 @@ def try_register_lib(lib_name: str, lib_info: str = ""):
        pass


def find_hccl_library() -> str:
    """
    We either use the library file specified by the `HCCL_SO_PATH`
    environment variable, or we find the library file brought by PyTorch.
    After importing `torch`, `libhccl.so` can be
    found by `ctypes` automatically.
    """
    so_file = envs.HCCL_SO_PATH

    # manually load the hccl library
    if so_file:
        logger.info("Found hccl from environment variable HCCL_SO_PATH=%s",
                    so_file)
    else:
        if torch.version.cann is not None:
            so_file = "libhccl.so"
        else:
            raise ValueError("HCCL only supports Ascend NPU backends.")
        logger.info("Found hccl from library %s", so_file)
    return so_file


_current_stream = None