Mirror of https://github.com/vllm-project/vllm.git, synced 2025-10-20 14:53:52 +08:00
[Core] remove cupy dependency (#3625)
@@ -22,10 +22,13 @@ steps:
  working_dir: "/vllm-workspace/tests/distributed"
  num_gpus: 2 # only support 1 or 2 for now.

- label: Distributed Correctness Test
  command: pytest -v -s --forked test_basic_distributed_correctness.py
- label: Distributed Tests
  working_dir: "/vllm-workspace/tests/distributed"
  num_gpus: 2 # only support 1 or 2 for now.
  commands:
  - pytest -v -s --forked test_pynccl.py
  - TEST_DIST_MODEL=facebook/opt-125m pytest -v -s --forked test_basic_distributed_correctness.py
  - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s --forked test_basic_distributed_correctness.py

- label: Engine Test
  command: pytest -v -s engine tokenization test_sequence.py test_config.py
@@ -97,7 +97,7 @@ RUN --mount=type=cache,target=/root/.cache/pip VLLM_USE_PRECOMPILED=1 pip instal

#################### RUNTIME BASE IMAGE ####################
# We used base cuda image because pytorch installs its own cuda libraries.
# However cupy depends on cuda libraries so we had to switch to the runtime image
# However pynccl depends on cuda libraries so we had to switch to the runtime image
# In the future it would be nice to get a container with pytorch and cuda without duplicating cuda
FROM nvidia/cuda:12.1.0-runtime-ubuntu22.04 AS vllm-base

@@ -23,9 +23,6 @@ RUN echo "FA_BRANCH is $FA_BRANCH"
# In that case, we need to use the python reference attention implementation in vllm
ARG BUILD_FA="1"

# whether to build cupy on rocm
ARG BUILD_CUPY="1"

# Install some basic utilities
RUN apt-get update && apt-get install python3 python3-pip -y

@@ -78,23 +75,6 @@ RUN if [ "$BUILD_FA" = "1" ]; then \
RUN if [ "$BASE_IMAGE" = "rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1" ]; then \
    rm -rf /opt/conda/envs/py_3.9/lib/python3.9/site-packages/numpy-1.20.3.dist-info/; fi

# build cupy
RUN if [ "$BUILD_CUPY" = "1" ]; then \
    mkdir -p libs \
    && cd libs \
    && git clone -b hipgraph_enablement --recursive https://github.com/ROCm/cupy.git \
    && cd cupy \
    && pip install mpi4py-mpich \
    && pip install scipy==1.9.3 \
    && pip install cython==0.29.* \
    && env CC=$MPI_HOME/bin/mpicc python -m pip install mpi4py \
    && export CUPY_INSTALL_USE_HIP=1 \
    && export ROCM_HOME=/opt/rocm \
    && export HCC_AMDGPU_TARGET="gfx90a,gfx942,gfx1100" \
    && pip install . \
    && cd ..; \
    fi

COPY ./ /app/vllm

RUN python3 -m pip install --upgrade pip
@@ -14,4 +14,3 @@ prometheus_client >= 0.18.0
pynvml == 11.5.0
triton >= 2.1.0
outlines == 0.0.34
cupy-cuda12x == 12.1.0 # Required for CUDA graphs. CUDA 11.8 users should install cupy-cuda11x instead.
setup.py (6 lines changed)

@@ -306,12 +306,6 @@ def get_requirements() -> List[str]:
    if _is_cuda():
        with open(get_path("requirements.txt")) as f:
            requirements = f.read().strip().split("\n")
        if get_nvcc_cuda_version() <= Version("11.8"):
            # replace cupy-cuda12x with cupy-cuda11x for cuda 11.x
            for i in range(len(requirements)):
                if requirements[i].startswith("cupy-cuda12x"):
                    requirements[i] = "cupy-cuda11x"
                    break
    elif _is_hip():
        with open(get_path("requirements-rocm.txt")) as f:
            requirements = f.read().strip().split("\n")
@@ -1,13 +1,22 @@
"""Compare the outputs of HF and distributed vLLM when using greedy sampling.

Run `pytest tests/distributed/test_basic_distributed_correctness.py --forked`.
vLLM will allocate all the available memory, so we need to run the tests one
by one. The solution is to pass arguments (model name) by environment
variables.
Run:
```sh
TEST_DIST_MODEL=facebook/opt-125m pytest \
    test_basic_distributed_correctness.py
TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest \
    test_basic_distributed_correctness.py
```
"""
import os

import pytest
import torch

MODELS = [
    "facebook/opt-125m",
    "meta-llama/Llama-2-7b-hf",
    os.environ["TEST_DIST_MODEL"],
]
@@ -2,6 +2,8 @@

Run `pytest tests/distributed/test_comm_ops.py --forked`.
"""
import os

import pytest
import ray
import torch
@@ -16,6 +18,12 @@ from vllm.test_utils import (init_test_distributed_environment,
@ray.remote(num_gpus=1, max_calls=1)
def all_reduce_test_worker(tensor_parallel_size: int, rank: int,
                           distributed_init_port: str):
    # it is important to delete the CUDA_VISIBLE_DEVICES environment variable
    # so that each worker can see all the GPUs
    # they will be able to set the device to the correct GPU
    del os.environ["CUDA_VISIBLE_DEVICES"]
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)
    init_test_distributed_environment(1, tensor_parallel_size, rank,
                                      distributed_init_port)
    num_elements = 8
@@ -32,6 +40,12 @@ def all_reduce_test_worker(tensor_parallel_size: int, rank: int,
@ray.remote(num_gpus=1, max_calls=1)
def all_gather_test_worker(tensor_parallel_size: int, rank: int,
                           distributed_init_port: str):
    # it is important to delete the CUDA_VISIBLE_DEVICES environment variable
    # so that each worker can see all the GPUs
    # they will be able to set the device to the correct GPU
    del os.environ["CUDA_VISIBLE_DEVICES"]
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)
    init_test_distributed_environment(1, tensor_parallel_size, rank,
                                      distributed_init_port)
    num_dimensions = 3
@@ -54,6 +68,12 @@ def all_gather_test_worker(tensor_parallel_size: int, rank: int,
@ray.remote(num_gpus=1, max_calls=1)
def broadcast_tensor_dict_test_worker(tensor_parallel_size: int, rank: int,
                                      distributed_init_port: str):
    # it is important to delete the CUDA_VISIBLE_DEVICES environment variable
    # so that each worker can see all the GPUs
    # they will be able to set the device to the correct GPU
    del os.environ["CUDA_VISIBLE_DEVICES"]
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)
    init_test_distributed_environment(1, tensor_parallel_size, rank,
                                      distributed_init_port)
    test_dict = {
tests/distributed/test_pynccl.py (new file, 90 lines)

@@ -0,0 +1,90 @@
import multiprocessing
import os

import pytest
import torch

from vllm.model_executor.parallel_utils.pynccl import (NCCLCommunicator,
                                                       ncclGetUniqueId)


def distributed_run(fn, world_size):
    number_of_processes = world_size
    processes = []
    for i in range(number_of_processes):
        env = os.environ.copy()
        env['RANK'] = str(i)
        env['WORLD_SIZE'] = str(number_of_processes)
        env['MASTER_ADDR'] = 'localhost'
        env['MASTER_PORT'] = '12345'
        p = multiprocessing.Process(target=fn, args=(env, ))
        processes.append(p)
        p.start()

    for p in processes:
        p.join()


def update_env(fn):
    # `multiprocessing.Process` cannot accept environment variables directly
    # so we need to pass the environment variables as arguments
    # and update the environment variables in the function
    def wrapper(env):
        import os
        os.environ.update(env)
        fn()

    return wrapper


@update_env
def worker_fn():
    comm = NCCLCommunicator()
    tensor = torch.ones(16, 1024, 1024, dtype=torch.float32).cuda(comm.rank)
    comm.all_reduce(tensor)
    result = tensor.mean().cpu().item()
    assert result == comm.world_size


@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason="Need at least 2 GPUs to run the test.")
def test_pynccl():
    distributed_run(worker_fn, 2)


@update_env
def worker_fn_with_cudagraph():
    with torch.no_grad():
        graph = torch.cuda.CUDAGraph()
        comm = NCCLCommunicator()
        # run something in the default stream to initialize torch engine
        a = torch.ones((4, 4), device=f'cuda:{comm.rank}')
        torch.cuda.synchronize()
        with torch.cuda.graph(graph, stream=comm.stream):
            # operation during the graph capture is recorded but not executed
            # see https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#creating-a-graph-using-stream-capture # noqa
            comm.all_reduce(a)
        comm.stream.synchronize()
        assert a.mean().cpu().item() == comm.world_size**0
        graph.replay()
        comm.stream.synchronize()
        assert a.mean().cpu().item() == comm.world_size**1


@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason="Need at least 2 GPUs to run the test.")
def test_pynccl_with_cudagraph():
    distributed_run(worker_fn_with_cudagraph, 2)


def test_ncclGetUniqueId():
    unique_id = ncclGetUniqueId()
    # `list(unique_id.internal)` is something like this:
    # [34, -16, 23, 83, 109, -19, 59, 95, 2, 0, -86, 55, 10, -128, 0, 29, 0,
    # 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    # 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    # 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    # 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    # 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
    # as long as the function doesn't raise an exception, we're good
    assert unique_id is not None
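
The `distributed_run`/`update_env` helpers above encode a small launch protocol: spawn one process per rank, ship RANK/WORLD_SIZE/MASTER_ADDR/MASTER_PORT to the child as a function argument, and re-export them inside the child before initializing the process group. Below is a minimal CPU-only sketch of the same pattern, not part of the commit: it uses the gloo backend so it can be tried without GPUs or NCCL, and the helper names `_worker`/`main` and the port number are illustrative only.

```python
import multiprocessing
import os

import torch
import torch.distributed as dist


def _worker(env):
    os.environ.update(env)  # same trick as update_env: env travels as an argument
    dist.init_process_group(backend="gloo")  # reads RANK/WORLD_SIZE/MASTER_ADDR/MASTER_PORT
    t = torch.ones(4)
    dist.all_reduce(t)  # default op is SUM
    assert t[0].item() == dist.get_world_size()
    dist.destroy_process_group()


def main(world_size: int = 2):
    procs = []
    for rank in range(world_size):
        env = os.environ.copy()
        env.update(RANK=str(rank), WORLD_SIZE=str(world_size),
                   MASTER_ADDR="localhost", MASTER_PORT="29501")
        p = multiprocessing.Process(target=_worker, args=(env, ))
        p.start()
        procs.append(p)
    for p in procs:
        p.join()


if __name__ == "__main__":
    main()
```
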
@@ -188,11 +188,9 @@ class RayGPUExecutor(ExecutorBase):
            is_driver_worker=True,
        )

        # FIXME(woosuk): We are not properly initializing cupy NCCL when
        # FIXME(woosuk): We are not properly initializing pynccl when
        # we have multiple nodes.
        self._run_workers("init_device",
                          cupy_port=get_open_port()
                          if not model_config.enforce_eager else None)
        self._run_workers("init_device")
        self._run_workers(
            "load_model",
            max_concurrent_workers=self.parallel_config.
@@ -4,12 +4,12 @@ from typing import Any, Dict, List, Optional, Union
import torch
from torch.distributed import ProcessGroup

from vllm.model_executor.parallel_utils import cupy_utils
from vllm.model_executor.parallel_utils import pynccl_utils
from vllm.model_executor.parallel_utils.custom_all_reduce import (
    custom_all_reduce)
from vllm.model_executor.parallel_utils.parallel_state import (
    get_tensor_model_parallel_group, get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size, is_cupy_nccl_enabled_for_all_reduce)
    get_tensor_model_parallel_world_size, is_pynccl_enabled_for_all_reduce)


def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
@@ -30,9 +30,9 @@ def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
    out = custom_all_reduce(input_)
    if out is not None:
        return out
    if is_cupy_nccl_enabled_for_all_reduce():
    if is_pynccl_enabled_for_all_reduce():
        # TODO: support multiple parallel groups.
        cupy_utils.all_reduce(input_)
        pynccl_utils.all_reduce(input_)
    else:
        torch.distributed.all_reduce(input_,
                                     group=get_tensor_model_parallel_group())
@@ -1,130 +0,0 @@
"""CuPy utilities for all-reduce.

We use CuPy all-reduce instead of torch.distributed.all_reduce when capturing
CUDA graphs, because torch.distributed.all_reduce causes errors when capturing
CUDA graphs.

NOTE: We use CuPy 12.3 since CuPy 13.0 does not support Python 3.8.
TODO: Remove this file when torch.distributed.all_reduce is fixed.
"""
import contextlib

import torch
from torch.distributed import ReduceOp

try:
    import cupy
    from cupy.cuda import nccl
    from cupyx.distributed import NCCLBackend
except ImportError as e:
    cupy = e
    nccl = None

    class NCCLBackend:
        ...


_OP_MAPPING = {
    ReduceOp.SUM: "sum",
    ReduceOp.PRODUCT: "prod",
    ReduceOp.MIN: "min",
    ReduceOp.MAX: "max",
}


class NCCLBackendWithBFloat16(NCCLBackend):
    # This is enough to add bfloat16 support for most operations,
    # but broadcast will fail (will require changes in compiled
    # cupy code).
    def _get_nccl_dtype_and_count(self, array, count=None):
        nccl_dtype, count = super()._get_nccl_dtype_and_count(array, count)
        torch_dtype = getattr(array, "_torch_dtype", None)
        if torch_dtype is torch.bfloat16:
            nccl_dtype = nccl.NCCL_BFLOAT16
        return nccl_dtype, count

    def barrier(self) -> None:
        raise RuntimeError(
            "Currently, CuPy NCCL barrier is not supported since the TCP "
            "store is immediately stopped after the initialization.")


_NCCL_BACKEND = None
_WORLD_SIZE = 0


def is_initialized() -> bool:
    """Returns whether the NCCL backend is initialized."""
    return _NCCL_BACKEND is not None


@contextlib.contextmanager
def set_cupy_stream(stream: torch.cuda.Stream):
    """Set the cuda stream for communication"""
    cupy_stream = cupy.cuda.ExternalStream(stream.cuda_stream,
                                           stream.device_index)
    with cupy_stream:
        yield


def init_process_group(world_size: int, rank: int, host: str,
                       port: int) -> None:
    """Initializes the CuPy NCCL backend.

    # TODO: handle NCCL timeouts.
    """
    assert not is_initialized()

    if isinstance(cupy, Exception):
        raise ImportError(
            "NCCLBackend is not available. Please install cupy.") from cupy

    # TODO(woosuk): Create TP and PP process groups for CuPy.
    global _NCCL_BACKEND
    global _WORLD_SIZE
    assert world_size > 0, f"{world_size=} should be a positive integer"
    assert 0 <= rank < world_size, (
        f"{rank=} should be a integer between [0, {world_size})")

    cupy.cuda.runtime.setDevice(torch.cuda.current_device())
    _NCCL_BACKEND = NCCLBackendWithBFloat16(world_size, rank, host, port)
    _WORLD_SIZE = world_size

    # Stop the TCP store to prevent the deadlock issues at termination time.
    # FIXME(woosuk): This is hacky. Find a more robust solution.
    if rank == 0 and hasattr(_NCCL_BACKEND, "_store"):
        _NCCL_BACKEND._store.stop()


def all_reduce(input_: torch.Tensor, op=ReduceOp.SUM) -> None:
    """All-reduces the input tensor across the process group."""
    assert input_.is_cuda, f"{input_} should be a cuda tensor"
    # Hack to support bfloat16
    torch_dtype = input_.dtype
    if torch_dtype is torch.bfloat16:
        # We need to view as float16, otherwise
        # cupy will fail. This will not change
        # the underlying data.
        input_ = input_.view(torch.float16)
    cupy_input = cupy.asarray(input_)
    cupy_input._torch_dtype = torch_dtype  # pylint: disable=protected-access
    _NCCL_BACKEND.all_reduce(in_array=cupy_input,
                             out_array=cupy_input,
                             op=_OP_MAPPING[op])


def destroy_process_group() -> None:
    """Destroys the NCCL backend."""
    global _NCCL_BACKEND
    global _WORLD_SIZE
    _NCCL_BACKEND = None
    _WORLD_SIZE = 0


def get_world_size() -> int:
    """Returns the world size."""
    return _WORLD_SIZE


def get_nccl_backend():
    return _NCCL_BACKEND
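
Side note on the bfloat16 workaround in the deleted `all_reduce` above: as its comment says, it relies on `torch.Tensor.view(dtype)` reinterpreting the same 2-byte storage without copying or converting values. A tiny CPU-only illustration (not part of the commit):

```python
import torch

x = torch.tensor([1.0, -2.5], dtype=torch.bfloat16)
y = x.view(torch.float16)  # same bytes, only the dtype label changes
# viewing back recovers the original values exactly, so the all-reduce
# could safely run on the reinterpreted buffer
assert y.view(torch.bfloat16).tolist() == x.tolist()
```
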
@@ -7,7 +7,7 @@ import contextlib

import torch

from vllm.model_executor.parallel_utils import cupy_utils
from vllm.model_executor.parallel_utils import pynccl_utils

# Tensor model parallel group that the current rank belongs to.
_TENSOR_MODEL_PARALLEL_GROUP = None
@@ -210,36 +210,36 @@ def destroy_model_parallel():
    global _PIPELINE_GLOBAL_RANKS
    _PIPELINE_GLOBAL_RANKS = None

    # Destroy the cupy states if any.
    cupy_utils.destroy_process_group()
    # Destroy the pynccl states if any.
    pynccl_utils.destroy_process_group()


# Whether to use cupy for nccl all reduce.
# We use cupy for all reduce when using CUDA graph, because torch.distributed
# Whether to use pynccl for nccl all reduce.
# We use pynccl for all reduce when using CUDA graph, because torch.distributed
# is not well supported by CUDA graph.
_ENABLE_CUPY_FOR_ALL_REDUCE = False
_ENABLE_PYNCCL_FOR_ALL_REDUCE = False


@contextlib.contextmanager
def with_cupy_nccl_for_all_reduce():
    """use CuPy nccl instead of torch.distributed for all reduce"""
def with_pynccl_for_all_reduce():
    """use pynccl instead of torch.distributed for all reduce"""
    tp_size = get_tensor_model_parallel_world_size()
    if tp_size == 1:
        # No-op.
        # NOTE(woosuk): We don't initialize CuPy when tp_size is 1.
        # NOTE(woosuk): We don't initialize pynccl when tp_size is 1.
        yield
    else:
        global _ENABLE_CUPY_FOR_ALL_REDUCE
        old = _ENABLE_CUPY_FOR_ALL_REDUCE
        _ENABLE_CUPY_FOR_ALL_REDUCE = True
        global _ENABLE_PYNCCL_FOR_ALL_REDUCE
        old = _ENABLE_PYNCCL_FOR_ALL_REDUCE
        _ENABLE_PYNCCL_FOR_ALL_REDUCE = True

        stream = torch.cuda.current_stream()
        with cupy_utils.set_cupy_stream(stream):
        with pynccl_utils.set_pynccl_stream(stream):
            yield
        _ENABLE_CUPY_FOR_ALL_REDUCE = old
        _ENABLE_PYNCCL_FOR_ALL_REDUCE = old


def is_cupy_nccl_enabled_for_all_reduce():
    """check if CuPy nccl is enabled for all reduce"""
    global _ENABLE_CUPY_FOR_ALL_REDUCE
    return _ENABLE_CUPY_FOR_ALL_REDUCE
def is_pynccl_enabled_for_all_reduce():
    """check if pynccl is enabled for all reduce"""
    global _ENABLE_PYNCCL_FOR_ALL_REDUCE
    return _ENABLE_PYNCCL_FOR_ALL_REDUCE
vllm/model_executor/parallel_utils/pynccl.py (new file, 258 lines)

@@ -0,0 +1,258 @@
# This file is a pure Python wrapper for the NCCL library.
# The main purpose is to use NCCL combined with CUDA graph.
# Before writing this script, we tried the following approach:
# 1. We tried to use `cupy`, it calls NCCL correctly, but `cupy` itself
# often gets stuck when initializing the NCCL communicator.
# 2. We tried to use `torch.distributed`, but `torch.distributed.all_reduce`
# contains many other potential cuda APIs, that are not allowed during
# capturing the CUDA graph. For further details, please check
# https://discuss.pytorch.org/t/pytorch-cudagraph-with-nccl-operation-failed/ .
#
# Another rejected idea is to write a C/C++ binding for NCCL. It is usually
# doable, but we often encounter issues related with nccl versions, and need
# to switch between different versions of NCCL. See
# https://github.com/NVIDIA/nccl/issues/1234 for more details.
# A C/C++ binding is not flexible enough to handle this. It requires
# recompilation of the code every time we want to switch between different
# versions. This current implementation, with a **pure** Python wrapper, is
# more flexible. We can easily switch between different versions of NCCL by
# changing the environment variable `VLLM_NCCL_SO_PATH`, or the `so_file`
# variable in the code.

import ctypes
import datetime
import logging
import os

# ===================== import region =====================
import torch
import torch.distributed as dist
from torch.distributed import ReduceOp

logger = logging.getLogger(__name__)

so_file = os.environ.get("VLLM_NCCL_SO_PATH", "")

# manually load the nccl library
if so_file:
    logger.info(
        f"Loading nccl from environment variable VLLM_NCCL_SO_PATH={so_file}")
else:
    if torch.version.cuda is not None:
        so_file = "libnccl.so"
    elif torch.version.hip is not None:
        so_file = "librccl.so"
    else:
        raise ValueError("NCCL only supports CUDA and ROCm backends.")
    logger.debug(f"Loading nccl from library {so_file}")

try:
    nccl = ctypes.CDLL(so_file)
except Exception as e:
    logger.error(
        f"Failed to load NCCL library from {so_file} ."
        "It is expected if you are not running on NVIDIA/AMD GPUs."
        "Otherwise please set the environment variable VLLM_NCCL_SO_PATH"
        " to point to the correct nccl library path.")
    raise e

# === export types and functions from nccl to Python ===
# for the original nccl definition, please check
# https://github.com/NVIDIA/nccl/blob/master/src/nccl.h.in

ncclResult_t = ctypes.c_int

# equivalent to c declaration:
# ncclResult_t ncclGetVersion(int *version);
_c_ncclGetVersion = nccl.ncclGetVersion
_c_ncclGetVersion.restype = ctypes.c_int
_c_ncclGetVersion.argtypes = [ctypes.POINTER(ctypes.c_int)]


def ncclGetVersion() -> str:
    version = ctypes.c_int()
    result = _c_ncclGetVersion(ctypes.byref(version))
    assert result == 0
    # something like 21903 --> "2.19.3"
    version_str = str(version.value)
    major = version_str[0].lstrip("0")
    minor = version_str[1:3].lstrip("0")
    patch = version_str[3:].lstrip("0")
    return f"{major}.{minor}.{patch}"
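
As a quick worked check of the decoding above (illustrative only, not part of the commit): NCCL reports version 2.19.3 as the integer 21903, and the slicing recovers the dotted string.

```python
# NCCL encodes major*10000 + minor*100 + patch, e.g. 2.19.3 -> 21903
version_str = str(21903)               # "21903"
major = version_str[0].lstrip("0")     # "2"
minor = version_str[1:3].lstrip("0")   # "19"
patch = version_str[3:].lstrip("0")    # "03" -> "3"
assert f"{major}.{minor}.{patch}" == "2.19.3"
```
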
class NcclUniqueId(ctypes.Structure):
    _fields_ = [("internal", ctypes.c_byte * 128)]


# equivalent to c declaration:
# ncclResult_t ncclGetUniqueId(ncclUniqueId* uniqueId);
_c_ncclGetUniqueId = nccl.ncclGetUniqueId
_c_ncclGetUniqueId.restype = ctypes.c_int
_c_ncclGetUniqueId.argtypes = [ctypes.POINTER(NcclUniqueId)]


def ncclGetUniqueId() -> NcclUniqueId:
    unique_id = NcclUniqueId()
    result = _c_ncclGetUniqueId(ctypes.byref(unique_id))
    assert result == 0
    return unique_id


# equivalent to c declaration:
# ncclResult_t ncclCommInitRank(
#   ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank);
# note that ncclComm_t is a pointer type, so the first argument
# is a pointer to a pointer
_c_ncclCommInitRank = nccl.ncclCommInitRank
_c_ncclCommInitRank.restype = ctypes.c_int
_c_ncclCommInitRank.argtypes = [
    ctypes.POINTER(ctypes.c_void_p), ctypes.c_int, NcclUniqueId, ctypes.c_int
]


# enums
class ncclDataType_t(ctypes.c_int):
    ncclInt8 = 0
    ncclChar = 0
    ncclUint8 = 1
    ncclInt32 = 2
    ncclInt = 2
    ncclUint32 = 3
    ncclInt64 = 4
    ncclUint64 = 5
    ncclFloat16 = 6
    ncclHalf = 6
    ncclFloat32 = 7
    ncclFloat = 7
    ncclFloat64 = 8
    ncclDouble = 8
    ncclBfloat16 = 9
    ncclNumTypes = 10

    @classmethod
    def from_torch(cls, dtype: torch.dtype) -> 'ncclDataType_t':
        if dtype == torch.int8:
            return cls.ncclInt8
        if dtype == torch.uint8:
            return cls.ncclUint8
        if dtype == torch.int32:
            return cls.ncclInt32
        if dtype == torch.int64:
            return cls.ncclInt64
        if dtype == torch.float16:
            return cls.ncclFloat16
        if dtype == torch.float32:
            return cls.ncclFloat32
        if dtype == torch.float64:
            return cls.ncclFloat64
        if dtype == torch.bfloat16:
            return cls.ncclBfloat16
        raise ValueError(f"Unsupported dtype: {dtype}")


class ncclRedOp_t(ctypes.c_int):
    ncclSum = 0
    ncclProd = 1
    ncclMax = 2
    ncclMin = 3
    ncclAvg = 4
    ncclNumOps = 5

    @classmethod
    def from_torch(cls, op: ReduceOp) -> 'ncclRedOp_t':
        if op == ReduceOp.SUM:
            return cls.ncclSum
        if op == ReduceOp.PRODUCT:
            return cls.ncclProd
        if op == ReduceOp.MAX:
            return cls.ncclMax
        if op == ReduceOp.MIN:
            return cls.ncclMin
        if op == ReduceOp.AVG:
            return cls.ncclAvg
        raise ValueError(f"Unsupported op: {op}")


# equivalent to c declaration:
# ncclResult_t ncclAllReduce(
#   const void* sendbuff, void* recvbuff, size_t count,
#   ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm,
#   cudaStream_t stream);
# note that cudaStream_t is a pointer type, so the last argument is a pointer
_c_ncclAllReduce = nccl.ncclAllReduce
_c_ncclAllReduce.restype = ctypes.c_int
_c_ncclAllReduce.argtypes = [
    ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, ncclDataType_t,
    ncclRedOp_t, ctypes.c_void_p, ctypes.c_void_p
]

# equivalent to c declaration:
# ncclResult_t ncclCommDestroy(ncclComm_t comm);
_c_ncclCommDestroy = nccl.ncclCommDestroy
_c_ncclCommDestroy.restype = ctypes.c_int
_c_ncclCommDestroy.argtypes = [ctypes.c_void_p]


class NCCLCommunicator:

    def __init__(
        self,
        backend=None,
        init_method=None,
        timeout=datetime.timedelta(seconds=10),
        world_size: int = -1,
        rank: int = -1,
        store=None,
        group_name: str = "",
        pg_options=None,
    ):
        if not dist.is_initialized():
            backend = backend or "nccl"
            assert backend == 'nccl', (
                "only use nccl backend for starting the NCCL communicator")
            dist.init_process_group(backend=backend,
                                    init_method=init_method,
                                    timeout=timeout,
                                    world_size=world_size,
                                    rank=rank,
                                    store=store,
                                    group_name=group_name,
                                    pg_options=pg_options)
        self.world_size = dist.get_world_size()
        self.rank = dist.get_rank()
        torch.cuda.set_device(self.rank)
        if self.rank == 0:
            self.unique_id = ncclGetUniqueId()
        else:
            self.unique_id = NcclUniqueId()
        tensor = torch.ByteTensor(list(self.unique_id.internal)).cuda(
            self.rank)
        dist.broadcast(tensor, src=0)
        byte_list = tensor.cpu().tolist()
        self.unique_id = NcclUniqueId()
        for i, byte in enumerate(byte_list):
            self.unique_id.internal[i] = byte
        self.comm = ctypes.c_void_p()
        result = _c_ncclCommInitRank(ctypes.byref(self.comm), self.world_size,
                                     self.unique_id, self.rank)
        assert result == 0
        self.stream = torch.cuda.Stream(device=f"cuda:{self.rank}")

    def all_reduce(self,
                   tensor: torch.Tensor,
                   op: ReduceOp = ReduceOp.SUM,
                   stream=None):
        if stream is None:
            stream = self.stream
        result = _c_ncclAllReduce(ctypes.c_void_p(tensor.data_ptr()),
                                  ctypes.c_void_p(tensor.data_ptr()),
                                  tensor.numel(),
                                  ncclDataType_t.from_torch(tensor.dtype),
                                  ncclRedOp_t.from_torch(op), self.comm,
                                  ctypes.c_void_p(stream.cuda_stream))
        assert result == 0

    def __del__(self):
        dist.destroy_process_group()
        _c_ncclCommDestroy(self.comm)
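
`NCCLCommunicator.__init__` above distributes rank 0's 128-byte NCCL unique id by copying the ctypes struct into a torch tensor, broadcasting it with `torch.distributed.broadcast`, and rebuilding the struct on the other ranks. A small CPU-only sketch of that byte round trip (illustrative only; `FakeUniqueId` is a stand-in, not a vLLM type, and the real code broadcasts the tensor between ranks instead of copying it locally):

```python
import ctypes

import torch


class FakeUniqueId(ctypes.Structure):  # stand-in for NcclUniqueId
    _fields_ = [("internal", ctypes.c_byte * 128)]


src = FakeUniqueId()
src.internal[0] = 42  # pretend this byte came from ncclGetUniqueId()

# struct -> tensor; in __init__ this tensor is what dist.broadcast ships around
tensor = torch.ByteTensor(list(src.internal))

# tensor -> struct, as done on the receiving ranks
dst = FakeUniqueId()
for i, byte in enumerate(tensor.tolist()):
    dst.internal[i] = byte

assert list(dst.internal) == list(src.internal)
```
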
vllm/model_executor/parallel_utils/pynccl_utils.py (new file, 64 lines)

@@ -0,0 +1,64 @@
import contextlib
import logging
from typing import Optional

import torch
from torch.distributed import ReduceOp

logger = logging.getLogger(__name__)

try:
    from vllm.model_executor.parallel_utils.pynccl import (NCCLCommunicator,
                                                           ncclGetVersion)
    logger.info(f"vLLM is using nccl=={ncclGetVersion()}")
except Exception as e:
    # in non-NVIDIA environments, we can't import the nccl module
    # e.g. when running on machines with AMD GPUs
    logger.info(f"Failed to import NCCL library: {e}")
    logger.info("It is expected if you are not running on NVIDIA GPUs.")
    pass

comm: Optional["NCCLCommunicator"] = None


def is_initialized() -> bool:
    """Returns whether the NCCL backend is initialized."""
    return comm is not None


@contextlib.contextmanager
def set_pynccl_stream(stream: torch.cuda.Stream):
    """Set the cuda stream for communication"""
    try:
        comm.stream = stream
        yield
    finally:
        pass


def init_process_group(world_size: int, rank: int, init_method: str) -> None:
    assert not is_initialized()
    global comm
    comm = NCCLCommunicator(init_method=init_method,
                            world_size=world_size,
                            rank=rank)


def all_reduce(input_: torch.Tensor, op=ReduceOp.SUM) -> None:
    """All-reduces the input tensor across the process group."""
    assert input_.is_cuda, f"{input_} should be a cuda tensor"
    comm.all_reduce(input_, op)


def destroy_process_group() -> None:
    global comm
    comm = None


def get_world_size() -> int:
    """Returns the world size."""
    return comm.world_size


def get_nccl_backend():
    return comm
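
A hedged usage sketch of this module, modeled on how `init_distributed_environment` in worker.py drives it later in this diff (the helper name `init_pynccl_if_needed` is illustrative only; a real run needs one process per rank and at least two GPUs):

```python
import torch

from vllm.model_executor.parallel_utils import pynccl_utils


def init_pynccl_if_needed(world_size: int, rank: int, init_method: str) -> None:
    # pynccl is only set up for multi-GPU runs, mirroring the
    # `parallel_config.world_size > 1` check in worker.py.
    if world_size > 1 and not pynccl_utils.is_initialized():
        pynccl_utils.init_process_group(world_size=world_size,
                                        rank=rank,
                                        init_method=init_method)
        # A small all_reduce for warmup, as the worker does.
        pynccl_utils.all_reduce(torch.zeros(1).cuda())
```
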
@@ -16,10 +16,7 @@ def init_test_distributed_environment(
                                     worker_use_ray=True)
    distributed_init_method = f"tcp://localhost:{distributed_init_port}"
    init_distributed_environment(
        parallel_config,
        rank,
        cupy_port=None,
        distributed_init_method=distributed_init_method)
        parallel_config, rank, distributed_init_method=distributed_init_method)


def multi_process_tensor_parallel(
@@ -15,11 +15,11 @@ from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.parallel_utils import cupy_utils, custom_all_reduce
from vllm.model_executor.parallel_utils import custom_all_reduce, pynccl_utils
from vllm.model_executor.parallel_utils.communication_op import (
    broadcast_tensor_dict)
from vllm.model_executor.parallel_utils.parallel_state import (
    with_cupy_nccl_for_all_reduce)
    with_pynccl_for_all_reduce)
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import (MultiModalData, SamplerOutput, SequenceData,
                           SequenceGroupMetadata)
@@ -764,7 +764,7 @@ class ModelRunner:
        """
        # NOTE(woosuk): This is a hack to ensure that the NCCL backend is never
        # deleted before the CUDA graphs.
        self.cupy_nccl_backend = cupy_utils.get_nccl_backend()
        self.pynccl_backend = pynccl_utils.get_nccl_backend()

        assert not self.model_config.enforce_eager
        logger.info("Capturing the model for CUDA graphs. This may lead to "
@@ -794,11 +794,11 @@ class ModelRunner:
        ]

        # NOTE(woosuk): There are 3 backends for all-reduce: custom all-reduce
        # kernel, CuPy NCCL, and PyTorch NCCL. When using CUDA graph, we use
        # either custom all-reduce kernel or CuPy NCCL. When not using CUDA
        # kernel, pynccl, and PyTorch NCCL. When using CUDA graph, we use
        # either custom all-reduce kernel or pynccl. When not using CUDA
        # graph, we use either custom all-reduce kernel or PyTorch NCCL.
        # We always prioritize using custom all-reduce kernel but fall back
        # to PyTorch or CuPy NCCL if it is disabled or not supported.
        # to PyTorch or pynccl if it is disabled or not supported.
        with custom_all_reduce.capture():
            # NOTE: Capturing the largest batch size first may help reduce the
            # memory usage of CUDA graph.
@@ -846,12 +846,14 @@ class ModelRunner:
        logger.info(f"Graph capturing finished in {elapsed_time:.0f} secs.")

    def __del__(self) -> None:
        # Delete the CUDA graphs before deleting the CuPy NCCL communicator.
        # Delete the CUDA graphs before deleting the pynccl communicator.
        # NOTE(woosuk): This is necessary because otherwise deadlocks can
        # happen.
        # FIXME(woosuk): This is a bit hacky. Find a more robust solution.
        # TODO(youkaichao): when we get enough user feedback that pynccl is
        # more stable than cupy, we can remove this, e.g. in v0.4.1.
        self.graph_runners.clear()
        self.cupy_nccl_backend = None
        self.pynccl_backend = None

    @property
    def vocab_size(self) -> int:
@@ -879,7 +881,7 @@ class CUDAGraphRunner:
        # Run the model once without capturing the graph.
        # This is to make sure that the captured graph does not include the
        # kernel launches for initial benchmarking (e.g., Triton autotune).
        with _maybe_cupy_nccl():
        with _maybe_pynccl():
            self.model(
                input_ids,
                positions,
@@ -894,7 +896,7 @@ class CUDAGraphRunner:
        # https://stackoverflow.com/questions/31039022/python-multi-line-with-statement
        self.graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(self.graph, pool=memory_pool):  # noqa: SIM117
            with _maybe_cupy_nccl():
            with _maybe_pynccl():
                hidden_states = self.model(
                    input_ids,
                    positions,
@@ -947,9 +949,10 @@ class CUDAGraphRunner:


@contextlib.contextmanager
def _maybe_cupy_nccl():
    if cupy_utils.is_initialized() and not custom_all_reduce.is_initialized():
        with with_cupy_nccl_for_all_reduce():
def _maybe_pynccl():
    if pynccl_utils.is_initialized(
    ) and not custom_all_reduce.is_initialized():
        with with_pynccl_for_all_reduce():
            yield
    else:
        yield
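
The NOTE in the graph-capture hunk above describes three all-reduce backends and their priority. Below is a condensed, stand-alone sketch of that selection order (a simplification for illustration, not vLLM code; the real dispatch is spread across communication_op.py and the context managers shown in this diff):

```python
def pick_all_reduce_backend(custom_ar_available: bool,
                            capturing_cuda_graph: bool) -> str:
    # Custom all-reduce kernel is always preferred when it is usable.
    if custom_ar_available:
        return "custom_all_reduce"
    # During CUDA graph capture, torch.distributed is not graph-safe,
    # so the pynccl wrapper is used instead.
    if capturing_cuda_graph:
        return "pynccl"
    # Otherwise, plain PyTorch NCCL handles the all-reduce.
    return "torch.distributed"


assert pick_all_reduce_backend(True, True) == "custom_all_reduce"
assert pick_all_reduce_backend(False, True) == "pynccl"
assert pick_all_reduce_backend(False, False) == "torch.distributed"
```
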
@@ -10,7 +10,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
                         ParallelConfig, SchedulerConfig, VisionLanguageConfig)
from vllm.lora.request import LoRARequest
from vllm.model_executor import set_random_seed
from vllm.model_executor.parallel_utils import cupy_utils
from vllm.model_executor.parallel_utils import pynccl_utils
from vllm.model_executor.parallel_utils.communication_op import (
    broadcast_tensor_dict)
from vllm.model_executor.parallel_utils.custom_all_reduce import init_custom_ar
@@ -75,7 +75,7 @@ class Worker:
        self.cache_engine = None
        self.gpu_cache = None

    def init_device(self, cupy_port: Optional[int] = None) -> None:
    def init_device(self) -> None:
        if self.device_config.device.type == "cuda":
            # torch.distributed.all_reduce does not free the input tensor until
            # the synchronization point. This causes the memory usage to grow
@@ -98,7 +98,7 @@ class Worker:
                f"Not support device type: {self.device_config.device}")
        # Initialize the distributed environment.
        init_distributed_environment(self.parallel_config, self.rank,
                                     cupy_port, self.distributed_init_method)
                                     self.distributed_init_method)
        # Set random seed.
        set_random_seed(self.model_config.seed)

@@ -250,7 +250,6 @@ class Worker:
def init_distributed_environment(
    parallel_config: ParallelConfig,
    rank: int,
    cupy_port: Optional[int],
    distributed_init_method: Optional[str] = None,
) -> None:
    """Initialize the distributed environment."""
@@ -273,28 +272,27 @@ def init_distributed_environment(
            init_method=distributed_init_method,
        )

    if cupy_utils.is_initialized():
        cupy_world_size = cupy_utils.get_world_size()
        if cupy_world_size != parallel_config.world_size:
    if pynccl_utils.is_initialized():
        pynccl_world_size = pynccl_utils.get_world_size()
        if pynccl_world_size != parallel_config.world_size:
            raise RuntimeError(
                "cupy.distributed is already initialized but the cupy world "
                "pynccl is already initialized but the pynccl world "
                "size does not match parallel_config.world_size "
                f"({cupy_world_size} vs. {parallel_config.world_size}).")
    elif (parallel_config.world_size > 1 and cupy_port is not None):
        # NOTE(woosuk): We don't initialize CuPy process group when world size
        f"({pynccl_world_size} vs. {parallel_config.world_size}).")
    elif parallel_config.world_size > 1:
        # NOTE(woosuk): We don't initialize pynccl process group when world size
        # is 1.
        # TODO(woosuk): Support multi-node connection.
        cupy_utils.init_process_group(
        pynccl_utils.init_process_group(
            world_size=parallel_config.world_size,
            rank=rank,
            host="localhost",
            port=cupy_port,
            init_method=distributed_init_method,
        )

    # A small all_reduce for warmup.
    torch.distributed.all_reduce(torch.zeros(1).cuda())
    if cupy_utils.is_initialized():
        cupy_utils.all_reduce(torch.zeros(1).cuda())
    if pynccl_utils.is_initialized():
        pynccl_utils.all_reduce(torch.zeros(1).cuda())
    ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
                                      parallel_config.pipeline_parallel_size)