https://github.com/vllm-project/vllm.git
Remove dependency on CuPy (#2152)
@@ -12,4 +12,3 @@ fastapi
 uvicorn[standard]
 pydantic == 1.10.13 # Required for OpenAI server.
 aioprometheus[starlette]
-cupy-cuda12x # Required for CUDA graphs. CUDA 11.8 users should install cupy-cuda11x instead. # FIXME: Fix this in setup.py.
@@ -17,7 +17,7 @@ from vllm.sequence import (SamplerOutput, Sequence, SequenceGroup,
                            SequenceOutput, SequenceStatus)
 from vllm.transformers_utils.tokenizer import (detokenize_incrementally,
                                                get_tokenizer)
-from vllm.utils import Counter, get_open_port
+from vllm.utils import Counter
 
 if ray:
     from ray.air.util.torch_dist import init_torch_dist_process_group
@@ -190,7 +190,6 @@ class LLMEngine:
                          ))
        self._run_workers(
            "init_model",
-           cupy_port=get_open_port(),
            get_all_outputs=True,
        )
        self._run_workers(
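The dropped cupy_port=get_open_port() argument existed only to give CuPy's NCCL backend a rendezvous port. As a standalone illustration (not the vLLM implementation, which may differ), a helper like get_open_port is typically built by binding port 0 and letting the OS hand back a free ephemeral port:

import socket


def get_open_port() -> int:
    # Bind to port 0 so the kernel picks an unused ephemeral port, then
    # report which port it chose. Illustrative sketch only.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("", 0))
        return s.getsockname()[1]


print(get_open_port())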
@@ -1,10 +1,8 @@
 import torch
 
-from vllm.model_executor.parallel_utils import cupy_utils
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_world_size,
     get_tensor_model_parallel_group,
-    is_custom_nccl_enabled_for_all_reduce,
 )
 
 
@@ -17,12 +15,8 @@ def tensor_model_parallel_all_reduce(input_):
     if get_tensor_model_parallel_world_size() == 1:
         return input_
     # All-reduce.
-    if is_custom_nccl_enabled_for_all_reduce():
-        # TODO: support multiple parallel groups.
-        cupy_utils.all_reduce(input_)
-    else:
-        torch.distributed.all_reduce(input_,
-                                     group=get_tensor_model_parallel_group())
+    torch.distributed.all_reduce(input_,
+                                 group=get_tensor_model_parallel_group())
     return input_
 
 
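With the CuPy branch gone, tensor_model_parallel_all_reduce simply reduces every rank's partial result with one torch.distributed.all_reduce call on the tensor-parallel group. A minimal standalone sketch of that behavior (not vLLM code; the single-process gloo bootstrap and the explicit group argument are illustrative only, since vLLM resolves the group via get_tensor_model_parallel_group()):

import torch
import torch.distributed as dist


def tensor_model_parallel_all_reduce(input_: torch.Tensor,
                                     group=None) -> torch.Tensor:
    # Each rank holds a partial sum; the reduction happens in place.
    if dist.get_world_size(group) == 1:
        return input_
    dist.all_reduce(input_, group=group)
    return input_


if __name__ == "__main__":
    dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29500",
                            rank=0, world_size=1)
    x = torch.ones(4)
    print(tensor_model_parallel_all_reduce(x))  # world size 1: returned as-is
    dist.destroy_process_group()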
@@ -1,115 +0,0 @@
-"""CuPy utilities for all-reduce.
-
-We use CuPy all-reduce instead of torch.distributed.all_reduce when capturing
-CUDA graphs, because torch.distributed.all_reduce causes errors when capturing
-CUDA graphs.
-
-TODO: Remove this file when torch.distributed.all_reduce is fixed.
-"""
-import contextlib
-
-import torch
-from torch.distributed import ReduceOp
-
-try:
-    import cupy
-    from cupyx.distributed import NCCLBackend
-    from cupy.cuda import nccl
-except ImportError as e:
-    cupy = e
-    nccl = None
-
-    class NCCLBackend:
-        ...
-
-
-_OP_MAPPING = {
-    ReduceOp.SUM: "sum",
-    ReduceOp.PRODUCT: "prod",
-    ReduceOp.MIN: "min",
-    ReduceOp.MAX: "max",
-}
-
-
-class NCCLBackendWithBFloat16(NCCLBackend):
-    # This is enough to add bfloat16 support for most operations,
-    # but broadcast will fail (will require changes in compiled
-    # cupy code).
-    def _get_nccl_dtype_and_count(self, array, count=None):
-        nccl_dtype, count = super()._get_nccl_dtype_and_count(array, count)
-        torch_dtype = getattr(array, "_torch_dtype", None)
-        if torch_dtype is torch.bfloat16:
-            nccl_dtype = nccl.NCCL_BFLOAT16
-        return nccl_dtype, count
-
-
-_NCCL_BACKEND = None
-_WORLD_SIZE = 0
-
-
-def is_initialized() -> bool:
-    """Returns whether the NCCL backend is initialized."""
-    return _NCCL_BACKEND is not None
-
-
-@contextlib.contextmanager
-def set_cupy_stream(stream: torch.cuda.Stream) -> None:
-    """Set the cuda stream for communication"""
-    cupy_stream = cupy.cuda.ExternalStream(stream.cuda_stream,
-                                           stream.device_index)
-    with cupy_stream:
-        yield
-
-
-def init_process_group(world_size: int, rank: int, host: str,
-                       port: int) -> None:
-    """Initializes the CuPy NCCL backend.
-
-    # TODO: handle NCCL timeouts.
-    """
-    assert not is_initialized()
-
-    if isinstance(cupy, Exception):
-        raise ImportError(
-            "NCCLBackend is not available. Please install cupy.") from cupy
-
-    # TODO(woosuk): Create TP and PP process groups for CuPy.
-    global _NCCL_BACKEND
-    global _WORLD_SIZE
-    assert world_size > 0, f"{world_size=} should be a positive integer"
-    assert 0 <= rank < world_size, (
-        f"{rank=} should be a integer between [0, {world_size})")
-
-    cupy.cuda.runtime.setDevice(torch.cuda.current_device())
-    _NCCL_BACKEND = NCCLBackendWithBFloat16(world_size, rank, host, port)
-    _WORLD_SIZE = world_size
-
-
-def all_reduce(input_: torch.Tensor, op=ReduceOp.SUM) -> None:
-    """All-reduces the input tensor across the process group."""
-    assert input_.is_cuda, f"{input_} should be a cuda tensor"
-    # Hack to support bfloat16
-    torch_dtype = input_.dtype
-    if torch_dtype is torch.bfloat16:
-        # We need to view as float16, otherwise
-        # cupy will fail. This will not change
-        # the underlying data.
-        input_ = input_.view(torch.float16)
-    cupy_input = cupy.asarray(input_)
-    cupy_input._torch_dtype = torch_dtype  # pylint: disable=protected-access
-    _NCCL_BACKEND.all_reduce(in_array=cupy_input,
-                             out_array=cupy_input,
-                             op=_OP_MAPPING[op])
-
-
-def destroy_process_group() -> None:
-    """Destroys the NCCL backend."""
-    global _NCCL_BACKEND
-    global _WORLD_SIZE
-    _NCCL_BACKEND = None
-    _WORLD_SIZE = 0
-
-
-def get_world_size() -> int:
-    """Returns the world size."""
-    return _WORLD_SIZE
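The deleted helper worked around CuPy's missing bfloat16 support by viewing tensors as float16 before handing them to NCCL and tagging the original dtype on the array. A small standalone sketch (independent of CuPy) of why that is safe: torch.Tensor.view(dtype) reinterprets the same storage without copying, so the 16-bit payload is untouched as long as the reduction is later told to treat it as bfloat16.

import torch

x = torch.randn(4, dtype=torch.bfloat16)
y = x.view(torch.float16)  # same bytes, different dtype label, no copy

assert y.data_ptr() == x.data_ptr()            # shares the original storage
assert torch.equal(y.view(torch.bfloat16), x)  # round trip restores the values
print(x, y)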
@@ -3,12 +3,9 @@
 # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py
 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
 """Tensor and pipeline parallel groups."""
-import contextlib
 
 import torch
 
-from vllm.model_executor.parallel_utils import cupy_utils
-
 # Tensor model parallel group that the current rank belongs to.
 _TENSOR_MODEL_PARALLEL_GROUP = None
 # Pipeline model parallel group that the current rank belongs to.
@@ -180,37 +177,3 @@ def destroy_model_parallel():
     _PIPELINE_MODEL_PARALLEL_GROUP = None
     global _PIPELINE_GLOBAL_RANKS
     _PIPELINE_GLOBAL_RANKS = None
-
-    # Destroy the cupy states if any.
-    cupy_utils.destroy_process_group()
-
-
-# Whether to use cupy for nccl all reduce.
-# We use cupy for all reduce when using CUDA graph, because torch.distributed
-# is not well supported by CUDA graph.
-_ENABLE_CUPY_FOR_ALL_REDUCE = False
-
-
-@contextlib.contextmanager
-def with_custom_nccl_for_all_reduce():
-    """use custom nccl instead of torch.distributed for all reduce"""
-    tp_size = get_tensor_model_parallel_world_size()
-    if tp_size == 1:
-        # No-op.
-        # NOTE(woosuk): We don't initialize CuPy when tp_size is 1.
-        yield
-    else:
-        global _ENABLE_CUPY_FOR_ALL_REDUCE
-        old = _ENABLE_CUPY_FOR_ALL_REDUCE
-        _ENABLE_CUPY_FOR_ALL_REDUCE = True
-
-        stream = torch.cuda.current_stream()
-        with cupy_utils.set_cupy_stream(stream):
-            yield
-        _ENABLE_CUPY_FOR_ALL_REDUCE = old
-
-
-def is_custom_nccl_enabled_for_all_reduce():
-    """check if custom nccl is enabled for all reduce"""
-    global _ENABLE_CUPY_FOR_ALL_REDUCE
-    return _ENABLE_CUPY_FOR_ALL_REDUCE
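The removed with_custom_nccl_for_all_reduce followed a common pattern: a contextmanager that flips a module-level flag for the duration of a block and restores the previous value on exit. A generic standalone sketch of that pattern (names are illustrative, and the try/finally is an addition the original code did not have):

import contextlib

_ENABLE_CUSTOM_ALL_REDUCE = False  # illustrative module-level switch


@contextlib.contextmanager
def use_custom_all_reduce():
    global _ENABLE_CUSTOM_ALL_REDUCE
    old = _ENABLE_CUSTOM_ALL_REDUCE
    _ENABLE_CUSTOM_ALL_REDUCE = True
    try:
        yield
    finally:
        # Restore the previous value even if the body raises.
        _ENABLE_CUSTOM_ALL_REDUCE = old


with use_custom_all_reduce():
    assert _ENABLE_CUSTOM_ALL_REDUCE is True
assert _ENABLE_CUSTOM_ALL_REDUCE is False
print("flag restored")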
@@ -8,8 +8,6 @@ import torch.nn as nn
 from vllm.config import ModelConfig, ParallelConfig, SchedulerConfig
 from vllm.logger import init_logger
 from vllm.model_executor import get_model, InputMetadata, SamplingMetadata
-from vllm.model_executor.parallel_utils.parallel_state import (
-    with_custom_nccl_for_all_reduce)
 from vllm.sampling_params import SamplingParams, SamplingType
 from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
 
@@ -459,8 +457,18 @@ class CUDAGraphRunner:
         # Run the model once without capturing the graph.
         # This is to make sure that the captured graph does not include the
         # kernel launches for initial benchmarking (e.g., Triton autotune).
-        with with_custom_nccl_for_all_reduce():
-            self.model(
+        self.model(
+            input_ids,
+            positions,
+            kv_caches,
+            input_metadata,
+        )
+        torch.cuda.synchronize()
+
+        # Capture the graph.
+        self.graph = torch.cuda.CUDAGraph()
+        with torch.cuda.graph(self.graph, pool=memory_pool):
+            hidden_states = self.model(
                 input_ids,
                 positions,
                 kv_caches,
@@ -468,20 +476,6 @@
             )
         torch.cuda.synchronize()
 
-        # Capture the graph.
-        # NOTE(woosuk): Python 3.8 does not support multi-line with statements.
-        # https://stackoverflow.com/questions/31039022/python-multi-line-with-statement
-        self.graph = torch.cuda.CUDAGraph()
-        with torch.cuda.graph(self.graph, pool=memory_pool):  # noqa: SIM117
-            with with_custom_nccl_for_all_reduce():
-                hidden_states = self.model(
-                    input_ids,
-                    positions,
-                    kv_caches,
-                    input_metadata,
-                )
-        torch.cuda.synchronize()
-
         # Save the input and output buffers.
         self.input_buffers = {
             "input_ids": input_ids,
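The hunks above keep the general CUDA graph recipe: run the model once eagerly so one-time kernel launches (e.g. Triton autotuning) are not recorded, then capture, then replay against fixed input buffers. A standalone sketch of that recipe with a placeholder model (not vLLM's CUDAGraphRunner; requires a CUDA device):

import torch

if torch.cuda.is_available():
    model = torch.nn.Linear(16, 16).cuda()
    static_input = torch.zeros(8, 16, device="cuda")

    # Warmup on a side stream so lazy initialization is not captured.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        model(static_input)
    torch.cuda.current_stream().wait_stream(s)

    # Capture: kernels launched inside this block are recorded, not executed.
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_output = model(static_input)

    # Replay: refill the captured input buffer in place, then launch the graph.
    static_input.copy_(torch.randn(8, 16, device="cuda"))
    graph.replay()
    torch.cuda.synchronize()
    print(static_output.shape)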
@@ -8,7 +8,6 @@ import torch.distributed
 from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
                          SchedulerConfig)
 from vllm.model_executor import set_random_seed
-from vllm.model_executor.parallel_utils import cupy_utils
 from vllm.model_executor.parallel_utils.parallel_state import (
     initialize_model_parallel)
 from vllm.sequence import SamplerOutput, SequenceGroupMetadata
@@ -47,7 +46,7 @@ class Worker:
         self.cache_events = None
         self.gpu_cache = None
 
-    def init_model(self, cupy_port: Optional[int] = None):
+    def init_model(self) -> None:
         # torch.distributed.all_reduce does not free the input tensor until
         # the synchronization point. This causes the memory usage to grow
         # as the number of all_reduce calls increases. This env var disables
@@ -71,7 +70,7 @@ class Worker:
 
         # Initialize the distributed environment.
         _init_distributed_environment(self.parallel_config, self.rank,
-                                      cupy_port, self.distributed_init_method)
+                                      self.distributed_init_method)
 
         # Initialize the model.
         set_random_seed(self.model_config.seed)
@@ -165,7 +164,6 @@
 def _init_distributed_environment(
     parallel_config: ParallelConfig,
     rank: int,
-    cupy_port: Optional[int],
     distributed_init_method: Optional[str] = None,
 ) -> None:
     """Initialize the distributed environment."""
@@ -188,29 +186,8 @@ def _init_distributed_environment(
         init_method=distributed_init_method,
     )
 
-    if cupy_utils.is_initialized():
-        cupy_world_size = cupy_utils.get_world_size()
-        if cupy_world_size != parallel_config.world_size:
-            raise RuntimeError(
-                "cupy.distributed is already initialized but the cupy world "
-                "size does not match parallel_config.world_size "
-                f"({cupy_world_size} vs. {parallel_config.world_size}).")
-    elif parallel_config.world_size > 1:
-        # NOTE(woosuk): We don't initialize CuPy process group when world size
-        # is 1.
-        # TODO(woosuk): Support multi-node connection.
-        cupy_utils.init_process_group(
-            world_size=parallel_config.world_size,
-            rank=rank,
-            host="localhost",
-            port=cupy_port,
-        )
-
-    if parallel_config.world_size > 1:
-        # A small all_reduce for warmup.
-        torch.distributed.all_reduce(torch.zeros(1).cuda())
-        cupy_utils.all_reduce(torch.zeros(1).cuda())
-
+    # A small all_reduce for warmup.
+    torch.distributed.all_reduce(torch.zeros(1).cuda())
     initialize_model_parallel(parallel_config.tensor_parallel_size,
                               parallel_config.pipeline_parallel_size)
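After the change, _init_distributed_environment only has to set up torch.distributed, do a small warmup all_reduce, and let initialize_model_parallel carve the world into tensor- and pipeline-parallel groups. A standalone sketch of that flow (not vLLM code: the single-process gloo bootstrap, CPU warmup tensor, and the grouping helper are illustrative assumptions):

import torch
import torch.distributed as dist


def build_tensor_parallel_groups(world_size: int, tensor_parallel_size: int):
    """Partition ranks into contiguous tensor-parallel groups."""
    assert world_size % tensor_parallel_size == 0
    groups = []
    for start in range(0, world_size, tensor_parallel_size):
        ranks = list(range(start, start + tensor_parallel_size))
        # Every rank must call new_group for every group, even ones it is not in.
        groups.append(dist.new_group(ranks))
    return groups


if __name__ == "__main__":
    dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29501",
                            rank=0, world_size=1)
    dist.all_reduce(torch.zeros(1))  # a small all_reduce for warmup
    tp_groups = build_tensor_parallel_groups(world_size=1,
                                             tensor_parallel_size=1)
    print(f"created {len(tp_groups)} tensor-parallel group(s)")
    dist.destroy_process_group()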