[Core] remove cupy dependency (#3625)

Authored by youkaichao on 2024-03-27 00:33:26 -07:00; committed by GitHub
parent e66b629c04
commit 8f44facddd
17 changed files with 506 additions and 223 deletions

View File

@ -22,10 +22,13 @@ steps:
working_dir: "/vllm-workspace/tests/distributed"
num_gpus: 2 # only support 1 or 2 for now.
- label: Distributed Correctness Test
command: pytest -v -s --forked test_basic_distributed_correctness.py
- label: Distributed Tests
working_dir: "/vllm-workspace/tests/distributed"
num_gpus: 2 # only support 1 or 2 for now.
commands:
- pytest -v -s --forked test_pynccl.py
- TEST_DIST_MODEL=facebook/opt-125m pytest -v -s --forked test_basic_distributed_correctness.py
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s --forked test_basic_distributed_correctness.py
- label: Engine Test
command: pytest -v -s engine tokenization test_sequence.py test_config.py

View File

@ -97,7 +97,7 @@ RUN --mount=type=cache,target=/root/.cache/pip VLLM_USE_PRECOMPILED=1 pip instal
#################### RUNTIME BASE IMAGE ####################
# We used base cuda image because pytorch installs its own cuda libraries.
# However cupy depends on cuda libraries so we had to switch to the runtime image
# However pynccl depends on cuda libraries so we had to switch to the runtime image
# In the future it would be nice to get a container with pytorch and cuda without duplicating cuda
FROM nvidia/cuda:12.1.0-runtime-ubuntu22.04 AS vllm-base

View File

@ -23,9 +23,6 @@ RUN echo "FA_BRANCH is $FA_BRANCH"
# In that case, we need to use the python reference attention implementation in vllm
ARG BUILD_FA="1"
# whether to build cupy on rocm
ARG BUILD_CUPY="1"
# Install some basic utilities
RUN apt-get update && apt-get install python3 python3-pip -y
@ -78,23 +75,6 @@ RUN if [ "$BUILD_FA" = "1" ]; then \
RUN if [ "$BASE_IMAGE" = "rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1" ]; then \
rm -rf /opt/conda/envs/py_3.9/lib/python3.9/site-packages/numpy-1.20.3.dist-info/; fi
# build cupy
RUN if [ "$BUILD_CUPY" = "1" ]; then \
mkdir -p libs \
&& cd libs \
&& git clone -b hipgraph_enablement --recursive https://github.com/ROCm/cupy.git \
&& cd cupy \
&& pip install mpi4py-mpich \
&& pip install scipy==1.9.3 \
&& pip install cython==0.29.* \
&& env CC=$MPI_HOME/bin/mpicc python -m pip install mpi4py \
&& export CUPY_INSTALL_USE_HIP=1 \
&& export ROCM_HOME=/opt/rocm \
&& export HCC_AMDGPU_TARGET="gfx90a,gfx942,gfx1100" \
&& pip install . \
&& cd ..; \
fi
COPY ./ /app/vllm
RUN python3 -m pip install --upgrade pip

View File

@ -14,4 +14,3 @@ prometheus_client >= 0.18.0
pynvml == 11.5.0
triton >= 2.1.0
outlines == 0.0.34
cupy-cuda12x == 12.1.0 # Required for CUDA graphs. CUDA 11.8 users should install cupy-cuda11x instead.

View File

@ -306,12 +306,6 @@ def get_requirements() -> List[str]:
if _is_cuda():
with open(get_path("requirements.txt")) as f:
requirements = f.read().strip().split("\n")
if get_nvcc_cuda_version() <= Version("11.8"):
# replace cupy-cuda12x with cupy-cuda11x for cuda 11.x
for i in range(len(requirements)):
if requirements[i].startswith("cupy-cuda12x"):
requirements[i] = "cupy-cuda11x"
break
elif _is_hip():
with open(get_path("requirements-rocm.txt")) as f:
requirements = f.read().strip().split("\n")

View File

@ -1,13 +1,22 @@
"""Compare the outputs of HF and distributed vLLM when using greedy sampling.
Run `pytest tests/distributed/test_basic_distributed_correctness.py --forked`.
vLLM will allocate all the available GPU memory, so the tests have to be run
one at a time; the model name is therefore passed in via an environment
variable.
Run:
```sh
TEST_DIST_MODEL=facebook/opt-125m pytest \
test_basic_distributed_correctness.py
TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest \
test_basic_distributed_correctness.py
```
"""
import os
import pytest
import torch
MODELS = [
"facebook/opt-125m",
"meta-llama/Llama-2-7b-hf",
os.environ["TEST_DIST_MODEL"],
]

View File

@ -2,6 +2,8 @@
Run `pytest tests/distributed/test_comm_ops.py --forked`.
"""
import os
import pytest
import ray
import torch
@ -16,6 +18,12 @@ from vllm.test_utils import (init_test_distributed_environment,
@ray.remote(num_gpus=1, max_calls=1)
def all_reduce_test_worker(tensor_parallel_size: int, rank: int,
distributed_init_port: str):
# It is important to delete the CUDA_VISIBLE_DEVICES environment variable
# so that each worker can see all the GPUs and set its device to the
# correct one.
del os.environ["CUDA_VISIBLE_DEVICES"]
device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device)
init_test_distributed_environment(1, tensor_parallel_size, rank,
distributed_init_port)
num_elements = 8
@ -32,6 +40,12 @@ def all_reduce_test_worker(tensor_parallel_size: int, rank: int,
@ray.remote(num_gpus=1, max_calls=1)
def all_gather_test_worker(tensor_parallel_size: int, rank: int,
distributed_init_port: str):
# It is important to delete the CUDA_VISIBLE_DEVICES environment variable
# so that each worker can see all the GPUs and set its device to the
# correct one.
del os.environ["CUDA_VISIBLE_DEVICES"]
device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device)
init_test_distributed_environment(1, tensor_parallel_size, rank,
distributed_init_port)
num_dimensions = 3
@ -54,6 +68,12 @@ def all_gather_test_worker(tensor_parallel_size: int, rank: int,
@ray.remote(num_gpus=1, max_calls=1)
def broadcast_tensor_dict_test_worker(tensor_parallel_size: int, rank: int,
distributed_init_port: str):
# It is important to delete the CUDA_VISIBLE_DEVICES environment variable
# so that each worker can see all the GPUs and set its device to the
# correct one.
del os.environ["CUDA_VISIBLE_DEVICES"]
device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device)
init_test_distributed_environment(1, tensor_parallel_size, rank,
distributed_init_port)
test_dict = {

View File

@ -0,0 +1,90 @@
import multiprocessing
import os
import pytest
import torch
from vllm.model_executor.parallel_utils.pynccl import (NCCLCommunicator,
ncclGetUniqueId)
def distributed_run(fn, world_size):
number_of_processes = world_size
processes = []
for i in range(number_of_processes):
env = os.environ.copy()
env['RANK'] = str(i)
env['WORLD_SIZE'] = str(number_of_processes)
env['MASTER_ADDR'] = 'localhost'
env['MASTER_PORT'] = '12345'
p = multiprocessing.Process(target=fn, args=(env, ))
processes.append(p)
p.start()
for p in processes:
p.join()
def update_env(fn):
# `multiprocessing.Process` cannot accept environment variables directly,
# so we pass them as an argument and update `os.environ` inside the
# wrapped function.
def wrapper(env):
import os
os.environ.update(env)
fn()
return wrapper
@update_env
def worker_fn():
comm = NCCLCommunicator()
tensor = torch.ones(16, 1024, 1024, dtype=torch.float32).cuda(comm.rank)
comm.all_reduce(tensor)
result = tensor.mean().cpu().item()
assert result == comm.world_size
@pytest.mark.skipif(torch.cuda.device_count() < 2,
reason="Need at least 2 GPUs to run the test.")
def test_pynccl():
distributed_run(worker_fn, 2)
@update_env
def worker_fn_with_cudagraph():
with torch.no_grad():
graph = torch.cuda.CUDAGraph()
comm = NCCLCommunicator()
# run something in the default stream to initialize torch engine
a = torch.ones((4, 4), device=f'cuda:{comm.rank}')
torch.cuda.synchronize()
with torch.cuda.graph(graph, stream=comm.stream):
# operation during the graph capture is recorded but not executed
# see https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#creating-a-graph-using-stream-capture # noqa
comm.all_reduce(a)
comm.stream.synchronize()
assert a.mean().cpu().item() == comm.world_size**0
graph.replay()
comm.stream.synchronize()
assert a.mean().cpu().item() == comm.world_size**1
@pytest.mark.skipif(torch.cuda.device_count() < 2,
reason="Need at least 2 GPUs to run the test.")
def test_pynccl_with_cudagraph():
distributed_run(worker_fn_with_cudagraph, 2)
def test_ncclGetUniqueId():
unique_id = ncclGetUniqueId()
# `list(unique_id.internal)` is something like this:
# [34, -16, 23, 83, 109, -19, 59, 95, 2, 0, -86, 55, 10, -128, 0, 29, 0,
# 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
# 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
# 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
# 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
# 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
# as long as the function doesn't raise an exception, we're good
assert unique_id is not None

View File

@ -188,11 +188,9 @@ class RayGPUExecutor(ExecutorBase):
is_driver_worker=True,
)
# FIXME(woosuk): We are not properly initializing cupy NCCL when
# FIXME(woosuk): We are not properly initializing pynccl when
# we have multiple nodes.
self._run_workers("init_device",
cupy_port=get_open_port()
if not model_config.enforce_eager else None)
self._run_workers("init_device")
self._run_workers(
"load_model",
max_concurrent_workers=self.parallel_config.

View File

@ -4,12 +4,12 @@ from typing import Any, Dict, List, Optional, Union
import torch
from torch.distributed import ProcessGroup
from vllm.model_executor.parallel_utils import cupy_utils
from vllm.model_executor.parallel_utils import pynccl_utils
from vllm.model_executor.parallel_utils.custom_all_reduce import (
custom_all_reduce)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, is_cupy_nccl_enabled_for_all_reduce)
get_tensor_model_parallel_world_size, is_pynccl_enabled_for_all_reduce)
def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
@ -30,9 +30,9 @@ def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
out = custom_all_reduce(input_)
if out is not None:
return out
if is_cupy_nccl_enabled_for_all_reduce():
if is_pynccl_enabled_for_all_reduce():
# TODO: support multiple parallel groups.
cupy_utils.all_reduce(input_)
pynccl_utils.all_reduce(input_)
else:
torch.distributed.all_reduce(input_,
group=get_tensor_model_parallel_group())

View File

@ -1,130 +0,0 @@
"""CuPy utilities for all-reduce.
We use CuPy all-reduce instead of torch.distributed.all_reduce when capturing
CUDA graphs, because torch.distributed.all_reduce causes errors when capturing
CUDA graphs.
NOTE: We use CuPy 12.3 since CuPy 13.0 does not support Python 3.8.
TODO: Remove this file when torch.distributed.all_reduce is fixed.
"""
import contextlib
import torch
from torch.distributed import ReduceOp
try:
import cupy
from cupy.cuda import nccl
from cupyx.distributed import NCCLBackend
except ImportError as e:
cupy = e
nccl = None
class NCCLBackend:
...
_OP_MAPPING = {
ReduceOp.SUM: "sum",
ReduceOp.PRODUCT: "prod",
ReduceOp.MIN: "min",
ReduceOp.MAX: "max",
}
class NCCLBackendWithBFloat16(NCCLBackend):
# This is enough to add bfloat16 support for most operations,
# but broadcast will fail (will require changes in compiled
# cupy code).
def _get_nccl_dtype_and_count(self, array, count=None):
nccl_dtype, count = super()._get_nccl_dtype_and_count(array, count)
torch_dtype = getattr(array, "_torch_dtype", None)
if torch_dtype is torch.bfloat16:
nccl_dtype = nccl.NCCL_BFLOAT16
return nccl_dtype, count
def barrier(self) -> None:
raise RuntimeError(
"Currently, CuPy NCCL barrier is not supported since the TCP "
"store is immediately stopped after the initialization.")
_NCCL_BACKEND = None
_WORLD_SIZE = 0
def is_initialized() -> bool:
"""Returns whether the NCCL backend is initialized."""
return _NCCL_BACKEND is not None
@contextlib.contextmanager
def set_cupy_stream(stream: torch.cuda.Stream):
"""Set the cuda stream for communication"""
cupy_stream = cupy.cuda.ExternalStream(stream.cuda_stream,
stream.device_index)
with cupy_stream:
yield
def init_process_group(world_size: int, rank: int, host: str,
port: int) -> None:
"""Initializes the CuPy NCCL backend.
# TODO: handle NCCL timeouts.
"""
assert not is_initialized()
if isinstance(cupy, Exception):
raise ImportError(
"NCCLBackend is not available. Please install cupy.") from cupy
# TODO(woosuk): Create TP and PP process groups for CuPy.
global _NCCL_BACKEND
global _WORLD_SIZE
assert world_size > 0, f"{world_size=} should be a positive integer"
assert 0 <= rank < world_size, (
f"{rank=} should be a integer between [0, {world_size})")
cupy.cuda.runtime.setDevice(torch.cuda.current_device())
_NCCL_BACKEND = NCCLBackendWithBFloat16(world_size, rank, host, port)
_WORLD_SIZE = world_size
# Stop the TCP store to prevent the deadlock issues at termination time.
# FIXME(woosuk): This is hacky. Find a more robust solution.
if rank == 0 and hasattr(_NCCL_BACKEND, "_store"):
_NCCL_BACKEND._store.stop()
def all_reduce(input_: torch.Tensor, op=ReduceOp.SUM) -> None:
"""All-reduces the input tensor across the process group."""
assert input_.is_cuda, f"{input_} should be a cuda tensor"
# Hack to support bfloat16
torch_dtype = input_.dtype
if torch_dtype is torch.bfloat16:
# We need to view as float16, otherwise
# cupy will fail. This will not change
# the underlying data.
input_ = input_.view(torch.float16)
cupy_input = cupy.asarray(input_)
cupy_input._torch_dtype = torch_dtype # pylint: disable=protected-access
_NCCL_BACKEND.all_reduce(in_array=cupy_input,
out_array=cupy_input,
op=_OP_MAPPING[op])
def destroy_process_group() -> None:
"""Destroys the NCCL backend."""
global _NCCL_BACKEND
global _WORLD_SIZE
_NCCL_BACKEND = None
_WORLD_SIZE = 0
def get_world_size() -> int:
"""Returns the world size."""
return _WORLD_SIZE
def get_nccl_backend():
return _NCCL_BACKEND

View File

@ -7,7 +7,7 @@ import contextlib
import torch
from vllm.model_executor.parallel_utils import cupy_utils
from vllm.model_executor.parallel_utils import pynccl_utils
# Tensor model parallel group that the current rank belongs to.
_TENSOR_MODEL_PARALLEL_GROUP = None
@ -210,36 +210,36 @@ def destroy_model_parallel():
global _PIPELINE_GLOBAL_RANKS
_PIPELINE_GLOBAL_RANKS = None
# Destroy the cupy states if any.
cupy_utils.destroy_process_group()
# Destroy the pynccl states if any.
pynccl_utils.destroy_process_group()
# Whether to use cupy for nccl all reduce.
# We use cupy for all reduce when using CUDA graph, because torch.distributed
# Whether to use pynccl for nccl all reduce.
# We use pynccl for all reduce when using CUDA graph, because torch.distributed
# is not well supported by CUDA graph.
_ENABLE_CUPY_FOR_ALL_REDUCE = False
_ENABLE_PYNCCL_FOR_ALL_REDUCE = False
@contextlib.contextmanager
def with_cupy_nccl_for_all_reduce():
"""use CuPy nccl instead of torch.distributed for all reduce"""
def with_pynccl_for_all_reduce():
"""use pynccl instead of torch.distributed for all reduce"""
tp_size = get_tensor_model_parallel_world_size()
if tp_size == 1:
# No-op.
# NOTE(woosuk): We don't initialize CuPy when tp_size is 1.
# NOTE(woosuk): We don't initialize pynccl when tp_size is 1.
yield
else:
global _ENABLE_CUPY_FOR_ALL_REDUCE
old = _ENABLE_CUPY_FOR_ALL_REDUCE
_ENABLE_CUPY_FOR_ALL_REDUCE = True
global _ENABLE_PYNCCL_FOR_ALL_REDUCE
old = _ENABLE_PYNCCL_FOR_ALL_REDUCE
_ENABLE_PYNCCL_FOR_ALL_REDUCE = True
stream = torch.cuda.current_stream()
with cupy_utils.set_cupy_stream(stream):
with pynccl_utils.set_pynccl_stream(stream):
yield
_ENABLE_CUPY_FOR_ALL_REDUCE = old
_ENABLE_PYNCCL_FOR_ALL_REDUCE = old
def is_cupy_nccl_enabled_for_all_reduce():
"""check if CuPy nccl is enabled for all reduce"""
global _ENABLE_CUPY_FOR_ALL_REDUCE
return _ENABLE_CUPY_FOR_ALL_REDUCE
def is_pynccl_enabled_for_all_reduce():
"""check if pynccl is enabled for all reduce"""
global _ENABLE_PYNCCL_FOR_ALL_REDUCE
return _ENABLE_PYNCCL_FOR_ALL_REDUCE

View File

@ -0,0 +1,258 @@
# This file is a pure Python wrapper for the NCCL library.
# The main purpose is to use NCCL combined with CUDA graph.
# Before writing this script, we tried the following approach:
# 1. We tried to use `cupy`; it calls NCCL correctly, but `cupy` itself
# often gets stuck when initializing the NCCL communicator.
# 2. We tried to use `torch.distributed`, but `torch.distributed.all_reduce`
# calls many other CUDA APIs that are not allowed while capturing a
# CUDA graph. For further details, please check
# https://discuss.pytorch.org/t/pytorch-cudagraph-with-nccl-operation-failed/ .
#
# Another rejected idea is to write a C/C++ binding for NCCL. It is usually
# doable, but we often run into issues related to NCCL versions and need
# to switch between different versions of NCCL. See
# https://github.com/NVIDIA/nccl/issues/1234 for more details.
# A C/C++ binding is not flexible enough to handle this. It requires
# recompilation of the code every time we want to switch between different
# versions. The current implementation, with a **pure** Python wrapper, is
# more flexible. We can easily switch between different versions of NCCL by
# changing the environment variable `VLLM_NCCL_SO_PATH`, or the `so_file`
# variable in the code.
import ctypes
import datetime
import logging
import os
# ===================== import region =====================
import torch
import torch.distributed as dist
from torch.distributed import ReduceOp
logger = logging.getLogger(__name__)
so_file = os.environ.get("VLLM_NCCL_SO_PATH", "")
# manually load the nccl library
if so_file:
logger.info(
f"Loading nccl from environment variable VLLM_NCCL_SO_PATH={so_file}")
else:
if torch.version.cuda is not None:
so_file = "libnccl.so"
elif torch.version.hip is not None:
so_file = "librccl.so"
else:
raise ValueError("NCCL only supports CUDA and ROCm backends.")
logger.debug(f"Loading nccl from library {so_file}")
try:
nccl = ctypes.CDLL(so_file)
except Exception as e:
logger.error(
f"Failed to load NCCL library from {so_file} ."
"It is expected if you are not running on NVIDIA/AMD GPUs."
"Otherwise please set the environment variable VLLM_NCCL_SO_PATH"
" to point to the correct nccl library path.")
raise e
# === export types and functions from nccl to Python ===
# for the original nccl definition, please check
# https://github.com/NVIDIA/nccl/blob/master/src/nccl.h.in
ncclResult_t = ctypes.c_int
# equivalent to c declaration:
# ncclResult_t ncclGetVersion(int *version);
_c_ncclGetVersion = nccl.ncclGetVersion
_c_ncclGetVersion.restype = ctypes.c_int
_c_ncclGetVersion.argtypes = [ctypes.POINTER(ctypes.c_int)]
def ncclGetVersion() -> str:
version = ctypes.c_int()
result = _c_ncclGetVersion(ctypes.byref(version))
assert result == 0
# something like 21903 --> "2.19.3"
version_str = str(version.value)
major = version_str[0].lstrip("0")
minor = version_str[1:3].lstrip("0")
patch = version_str[3:].lstrip("0")
return f"{major}.{minor}.{patch}"
class NcclUniqueId(ctypes.Structure):
_fields_ = [("internal", ctypes.c_byte * 128)]
# equivalent to c declaration:
# ncclResult_t ncclGetUniqueId(ncclUniqueId* uniqueId);
_c_ncclGetUniqueId = nccl.ncclGetUniqueId
_c_ncclGetUniqueId.restype = ctypes.c_int
_c_ncclGetUniqueId.argtypes = [ctypes.POINTER(NcclUniqueId)]
def ncclGetUniqueId() -> NcclUniqueId:
unique_id = NcclUniqueId()
result = _c_ncclGetUniqueId(ctypes.byref(unique_id))
assert result == 0
return unique_id
# equivalent to c declaration:
# ncclResult_t ncclCommInitRank(
# ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank);
# note that ncclComm_t is a pointer type, so the first argument
# is a pointer to a pointer
_c_ncclCommInitRank = nccl.ncclCommInitRank
_c_ncclCommInitRank.restype = ctypes.c_int
_c_ncclCommInitRank.argtypes = [
ctypes.POINTER(ctypes.c_void_p), ctypes.c_int, NcclUniqueId, ctypes.c_int
]
# enums
class ncclDataType_t(ctypes.c_int):
ncclInt8 = 0
ncclChar = 0
ncclUint8 = 1
ncclInt32 = 2
ncclInt = 2
ncclUint32 = 3
ncclInt64 = 4
ncclUint64 = 5
ncclFloat16 = 6
ncclHalf = 6
ncclFloat32 = 7
ncclFloat = 7
ncclFloat64 = 8
ncclDouble = 8
ncclBfloat16 = 9
ncclNumTypes = 10
@classmethod
def from_torch(cls, dtype: torch.dtype) -> 'ncclDataType_t':
if dtype == torch.int8:
return cls.ncclInt8
if dtype == torch.uint8:
return cls.ncclUint8
if dtype == torch.int32:
return cls.ncclInt32
if dtype == torch.int64:
return cls.ncclInt64
if dtype == torch.float16:
return cls.ncclFloat16
if dtype == torch.float32:
return cls.ncclFloat32
if dtype == torch.float64:
return cls.ncclFloat64
if dtype == torch.bfloat16:
return cls.ncclBfloat16
raise ValueError(f"Unsupported dtype: {dtype}")
class ncclRedOp_t(ctypes.c_int):
ncclSum = 0
ncclProd = 1
ncclMax = 2
ncclMin = 3
ncclAvg = 4
ncclNumOps = 5
@classmethod
def from_torch(cls, op: ReduceOp) -> 'ncclRedOp_t':
if op == ReduceOp.SUM:
return cls.ncclSum
if op == ReduceOp.PRODUCT:
return cls.ncclProd
if op == ReduceOp.MAX:
return cls.ncclMax
if op == ReduceOp.MIN:
return cls.ncclMin
if op == ReduceOp.AVG:
return cls.ncclAvg
raise ValueError(f"Unsupported op: {op}")
# equivalent to c declaration:
# ncclResult_t ncclAllReduce(
# const void* sendbuff, void* recvbuff, size_t count,
# ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm,
# cudaStream_t stream);
# note that cudaStream_t is a pointer type, so the last argument is a pointer
_c_ncclAllReduce = nccl.ncclAllReduce
_c_ncclAllReduce.restype = ctypes.c_int
_c_ncclAllReduce.argtypes = [
ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, ncclDataType_t,
ncclRedOp_t, ctypes.c_void_p, ctypes.c_void_p
]
# equivalent to c declaration:
# ncclResult_t ncclCommDestroy(ncclComm_t comm);
_c_ncclCommDestroy = nccl.ncclCommDestroy
_c_ncclCommDestroy.restype = ctypes.c_int
_c_ncclCommDestroy.argtypes = [ctypes.c_void_p]
class NCCLCommunicator:
def __init__(
self,
backend=None,
init_method=None,
timeout=datetime.timedelta(seconds=10),
world_size: int = -1,
rank: int = -1,
store=None,
group_name: str = "",
pg_options=None,
):
if not dist.is_initialized():
backend = backend or "nccl"
assert backend == 'nccl', (
"only use nccl backend for starting the NCCL communicator")
dist.init_process_group(backend=backend,
init_method=init_method,
timeout=timeout,
world_size=world_size,
rank=rank,
store=store,
group_name=group_name,
pg_options=pg_options)
self.world_size = dist.get_world_size()
self.rank = dist.get_rank()
torch.cuda.set_device(self.rank)
if self.rank == 0:
self.unique_id = ncclGetUniqueId()
else:
self.unique_id = NcclUniqueId()
tensor = torch.ByteTensor(list(self.unique_id.internal)).cuda(
self.rank)
dist.broadcast(tensor, src=0)
byte_list = tensor.cpu().tolist()
self.unique_id = NcclUniqueId()
for i, byte in enumerate(byte_list):
self.unique_id.internal[i] = byte
self.comm = ctypes.c_void_p()
result = _c_ncclCommInitRank(ctypes.byref(self.comm), self.world_size,
self.unique_id, self.rank)
assert result == 0
self.stream = torch.cuda.Stream(device=f"cuda:{self.rank}")
def all_reduce(self,
tensor: torch.Tensor,
op: ReduceOp = ReduceOp.SUM,
stream=None):
if stream is None:
stream = self.stream
result = _c_ncclAllReduce(ctypes.c_void_p(tensor.data_ptr()),
ctypes.c_void_p(tensor.data_ptr()),
tensor.numel(),
ncclDataType_t.from_torch(tensor.dtype),
ncclRedOp_t.from_torch(op), self.comm,
ctypes.c_void_p(stream.cuda_stream))
assert result == 0
def __del__(self):
dist.destroy_process_group()
_c_ncclCommDestroy(self.comm)
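As a brief aside (not part of the commit), the ctypes pattern above is easy to sanity-check on its own. A minimal sketch, assuming an NVIDIA machine where `libnccl.so.2` can be found by the dynamic loader:
```python
# Load NCCL via ctypes and query its version, mirroring the wrapper above.
import ctypes

nccl = ctypes.CDLL("libnccl.so.2")

# ncclResult_t ncclGetVersion(int *version);
nccl.ncclGetVersion.restype = ctypes.c_int
nccl.ncclGetVersion.argtypes = [ctypes.POINTER(ctypes.c_int)]

version = ctypes.c_int()
assert nccl.ncclGetVersion(ctypes.byref(version)) == 0  # 0 == ncclSuccess
print(version.value)  # e.g. 21903, which the wrapper formats as "2.19.3"
```
Because the library is resolved at runtime, switching NCCL versions only requires pointing the loader (or `VLLM_NCCL_SO_PATH`) at a different shared object, which is exactly the flexibility the header comment argues for.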

View File

@ -0,0 +1,64 @@
import contextlib
import logging
from typing import Optional
import torch
from torch.distributed import ReduceOp
logger = logging.getLogger(__name__)
try:
from vllm.model_executor.parallel_utils.pynccl import (NCCLCommunicator,
ncclGetVersion)
logger.info(f"vLLM is using nccl=={ncclGetVersion()}")
except Exception as e:
# in non-NVIDIA environments, we can't import the nccl module
# e.g. when running on machines with AMD GPUs
logger.info(f"Failed to import NCCL library: {e}")
logger.info("It is expected if you are not running on NVIDIA GPUs.")
pass
comm: Optional["NCCLCommunicator"] = None
def is_initialized() -> bool:
"""Returns whether the NCCL backend is initialized."""
return comm is not None
@contextlib.contextmanager
def set_pynccl_stream(stream: torch.cuda.Stream):
"""Set the cuda stream for communication"""
try:
comm.stream = stream
yield
finally:
pass
def init_process_group(world_size: int, rank: int, init_method: str) -> None:
assert not is_initialized()
global comm
comm = NCCLCommunicator(init_method=init_method,
world_size=world_size,
rank=rank)
def all_reduce(input_: torch.Tensor, op=ReduceOp.SUM) -> None:
"""All-reduces the input tensor across the process group."""
assert input_.is_cuda, f"{input_} should be a cuda tensor"
comm.all_reduce(input_, op)
def destroy_process_group() -> None:
global comm
comm = None
def get_world_size() -> int:
"""Returns the world size."""
return comm.world_size
def get_nccl_backend():
return comm
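A brief aside (not part of the commit): `pynccl_utils` is a thin module-level singleton around `NCCLCommunicator`. A minimal single-process sketch of the intended call order, assuming one NVIDIA GPU and that port 29500 is free; note that vLLM itself only initializes pynccl when the world size is greater than 1:
```python
# Hypothetical single-process smoke test of pynccl_utils; a SUM all-reduce
# over a 1-rank group leaves the tensor unchanged.
import torch
from vllm.model_executor.parallel_utils import pynccl_utils

pynccl_utils.init_process_group(world_size=1,
                                rank=0,
                                init_method="tcp://localhost:29500")
assert pynccl_utils.is_initialized()

x = torch.ones(4, device="cuda")
torch.cuda.synchronize()    # make sure x is ready before reducing
pynccl_utils.all_reduce(x)  # runs on the communicator's dedicated stream
torch.cuda.synchronize()    # wait for that stream before reading x
assert x.mean().item() == 1.0

pynccl_utils.destroy_process_group()
```
In the multi-GPU case, the worker changes later in this diff delegate this setup to `init_distributed_environment`, so the sequence above is normally never written by hand.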

View File

@ -16,10 +16,7 @@ def init_test_distributed_environment(
worker_use_ray=True)
distributed_init_method = f"tcp://localhost:{distributed_init_port}"
init_distributed_environment(
parallel_config,
rank,
cupy_port=None,
distributed_init_method=distributed_init_method)
parallel_config, rank, distributed_init_method=distributed_init_method)
def multi_process_tensor_parallel(

View File

@ -15,11 +15,11 @@ from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.parallel_utils import cupy_utils, custom_all_reduce
from vllm.model_executor.parallel_utils import custom_all_reduce, pynccl_utils
from vllm.model_executor.parallel_utils.communication_op import (
broadcast_tensor_dict)
from vllm.model_executor.parallel_utils.parallel_state import (
with_cupy_nccl_for_all_reduce)
with_pynccl_for_all_reduce)
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import (MultiModalData, SamplerOutput, SequenceData,
SequenceGroupMetadata)
@ -764,7 +764,7 @@ class ModelRunner:
"""
# NOTE(woosuk): This is a hack to ensure that the NCCL backend is never
# deleted before the CUDA graphs.
self.cupy_nccl_backend = cupy_utils.get_nccl_backend()
self.pynccl_backend = pynccl_utils.get_nccl_backend()
assert not self.model_config.enforce_eager
logger.info("Capturing the model for CUDA graphs. This may lead to "
@ -794,11 +794,11 @@ class ModelRunner:
]
# NOTE(woosuk): There are 3 backends for all-reduce: custom all-reduce
# kernel, CuPy NCCL, and PyTorch NCCL. When using CUDA graph, we use
# either custom all-reduce kernel or CuPy NCCL. When not using CUDA
# kernel, pynccl, and PyTorch NCCL. When using CUDA graph, we use
# either custom all-reduce kernel or pynccl. When not using CUDA
# graph, we use either custom all-reduce kernel or PyTorch NCCL.
# We always prioritize using custom all-reduce kernel but fall back
# to PyTorch or CuPy NCCL if it is disabled or not supported.
# to PyTorch or pynccl if it is disabled or not supported.
with custom_all_reduce.capture():
# NOTE: Capturing the largest batch size first may help reduce the
# memory usage of CUDA graph.
@ -846,12 +846,14 @@ class ModelRunner:
logger.info(f"Graph capturing finished in {elapsed_time:.0f} secs.")
def __del__(self) -> None:
# Delete the CUDA graphs before deleting the CuPy NCCL communicator.
# Delete the CUDA graphs before deleting the pynccl communicator.
# NOTE(woosuk): This is necessary because otherwise deadlocks can
# happen.
# FIXME(woosuk): This is a bit hacky. Find a more robust solution.
# TODO(youkaichao): when we get enough user feedback that pynccl is
# more stable than cupy, we can remove this, e.g. in v0.4.1.
self.graph_runners.clear()
self.cupy_nccl_backend = None
self.pynccl_backend = None
@property
def vocab_size(self) -> int:
@ -879,7 +881,7 @@ class CUDAGraphRunner:
# Run the model once without capturing the graph.
# This is to make sure that the captured graph does not include the
# kernel launches for initial benchmarking (e.g., Triton autotune).
with _maybe_cupy_nccl():
with _maybe_pynccl():
self.model(
input_ids,
positions,
@ -894,7 +896,7 @@ class CUDAGraphRunner:
# https://stackoverflow.com/questions/31039022/python-multi-line-with-statement
self.graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(self.graph, pool=memory_pool): # noqa: SIM117
with _maybe_cupy_nccl():
with _maybe_pynccl():
hidden_states = self.model(
input_ids,
positions,
@ -947,9 +949,10 @@ class CUDAGraphRunner:
@contextlib.contextmanager
def _maybe_cupy_nccl():
if cupy_utils.is_initialized() and not custom_all_reduce.is_initialized():
with with_cupy_nccl_for_all_reduce():
def _maybe_pynccl():
if pynccl_utils.is_initialized(
) and not custom_all_reduce.is_initialized():
with with_pynccl_for_all_reduce():
yield
else:
yield

View File

@ -10,7 +10,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig, VisionLanguageConfig)
from vllm.lora.request import LoRARequest
from vllm.model_executor import set_random_seed
from vllm.model_executor.parallel_utils import cupy_utils
from vllm.model_executor.parallel_utils import pynccl_utils
from vllm.model_executor.parallel_utils.communication_op import (
broadcast_tensor_dict)
from vllm.model_executor.parallel_utils.custom_all_reduce import init_custom_ar
@ -75,7 +75,7 @@ class Worker:
self.cache_engine = None
self.gpu_cache = None
def init_device(self, cupy_port: Optional[int] = None) -> None:
def init_device(self) -> None:
if self.device_config.device.type == "cuda":
# torch.distributed.all_reduce does not free the input tensor until
# the synchronization point. This causes the memory usage to grow
@ -98,7 +98,7 @@ class Worker:
f"Not support device type: {self.device_config.device}")
# Initialize the distributed environment.
init_distributed_environment(self.parallel_config, self.rank,
cupy_port, self.distributed_init_method)
self.distributed_init_method)
# Set random seed.
set_random_seed(self.model_config.seed)
@ -250,7 +250,6 @@ class Worker:
def init_distributed_environment(
parallel_config: ParallelConfig,
rank: int,
cupy_port: Optional[int],
distributed_init_method: Optional[str] = None,
) -> None:
"""Initialize the distributed environment."""
@ -273,28 +272,27 @@ def init_distributed_environment(
init_method=distributed_init_method,
)
if cupy_utils.is_initialized():
cupy_world_size = cupy_utils.get_world_size()
if cupy_world_size != parallel_config.world_size:
if pynccl_utils.is_initialized():
pynccl_world_size = pynccl_utils.get_world_size()
if pynccl_world_size != parallel_config.world_size:
raise RuntimeError(
"cupy.distributed is already initialized but the cupy world "
"pynccl is already initialized but the pynccl world "
"size does not match parallel_config.world_size "
f"({cupy_world_size} vs. {parallel_config.world_size}).")
elif (parallel_config.world_size > 1 and cupy_port is not None):
# NOTE(woosuk): We don't initialize CuPy process group when world size
f"({pynccl_world_size} vs. {parallel_config.world_size}).")
elif parallel_config.world_size > 1:
# NOTE(woosuk): We don't initialize pynccl process group when world size
# is 1.
# TODO(woosuk): Support multi-node connection.
cupy_utils.init_process_group(
pynccl_utils.init_process_group(
world_size=parallel_config.world_size,
rank=rank,
host="localhost",
port=cupy_port,
init_method=distributed_init_method,
)
# A small all_reduce for warmup.
torch.distributed.all_reduce(torch.zeros(1).cuda())
if cupy_utils.is_initialized():
cupy_utils.all_reduce(torch.zeros(1).cuda())
if pynccl_utils.is_initialized():
pynccl_utils.all_reduce(torch.zeros(1).cuda())
ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size)