Files
vllm/vllm/platforms/cuda.py
2025-03-21 03:14:19 +00:00

457 lines
18 KiB
Python

# SPDX-License-Identifier: Apache-2.0
"""Code inside this file can safely assume cuda platform, e.g. importing
pynvml. However, it should not initialize cuda context.
"""
import os
from functools import wraps
from typing import (TYPE_CHECKING, Callable, List, Optional, Tuple, TypeVar,
Union)
import torch
from typing_extensions import ParamSpec
# import custom ops, trigger op registration
import vllm._C # noqa
import vllm.envs as envs
from vllm.fa_utils import get_flash_attn_version
from vllm.logger import init_logger
from vllm.utils import import_pynvml
from .interface import DeviceCapability, Platform, PlatformEnum, _Backend
if TYPE_CHECKING:
from vllm.config import VllmConfig
else:
VllmConfig = None
logger = init_logger(__name__)
# Type variables used by `with_nvml_context` to carry the decorated
# callable's parameter list (_P) and return type (_R) through the wrapper.
_P = ParamSpec("_P")
_R = TypeVar("_R")
# NVML bindings, obtained through vllm.utils.import_pynvml().
pynvml = import_pynvml()
# PyTorch 2.5 uses cuDNN SDPA by default, which causes crashes on some models;
# see https://github.com/huggingface/diffusers/issues/9704 for details.
# This runs at import time and flips a global torch backend flag.
torch.backends.cuda.enable_cudnn_sdp(False)
def device_id_to_physical_device_id(device_id: int) -> int:
    """Translate a logical CUDA device index into a physical device index.

    If ``CUDA_VISIBLE_DEVICES`` is set, ``device_id`` indexes into its
    comma-separated entries and the selected entry is parsed with ``int``.
    If the variable is not set, ``device_id`` is returned unchanged.

    Raises:
        RuntimeError: if ``CUDA_VISIBLE_DEVICES`` is set to the empty string.
        IndexError: if ``device_id`` is out of range for the visible list.
        ValueError: if the selected entry is not an integer.
    """
    visible = os.environ.get("CUDA_VISIBLE_DEVICES")
    if visible is None:
        return device_id

    entries = visible.split(",")
    # An empty variable splits into [""]: GPUs have been explicitly hidden.
    if entries == [""]:
        msg = (
            "CUDA_VISIBLE_DEVICES is set to empty string, which means"
            " GPU support is disabled. If you are using ray, please unset"
            " the environment variable `CUDA_VISIBLE_DEVICES` inside the"
            " worker/actor. "
            "Check https://github.com/vllm-project/vllm/issues/8402 for"
            " more information.")
        raise RuntimeError(msg)

    return int(entries[device_id])
def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:
    """Decorate ``fn`` so every call runs inside its own NVML session.

    ``pynvml.nvmlInit()`` is called before ``fn`` runs and
    ``pynvml.nvmlShutdown()`` after it returns or raises. The wrapper keeps
    ``fn``'s metadata via ``functools.wraps``.
    """

    def _nvml_scoped(*args: _P.args, **kwargs: _P.kwargs) -> _R:
        pynvml.nvmlInit()
        try:
            result = fn(*args, **kwargs)
        finally:
            # Always release the NVML session, even if ``fn`` raised.
            pynvml.nvmlShutdown()
        return result

    return wraps(fn)(_nvml_scoped)
class CudaPlatformBase(Platform):
    """Behaviour shared by the NVML-based and torch-based CUDA platforms.

    The device-query hooks (``get_device_capability``, ``get_device_name``,
    ``get_device_total_memory``, ``is_full_nvlink``) raise
    ``NotImplementedError`` here and are supplied by concrete subclasses.
    """
    _enum = PlatformEnum.CUDA
    device_name: str = "cuda"
    device_type: str = "cuda"
    dispatch_key: str = "CUDA"
    ray_device_key: str = "GPU"
    device_control_env_var: str = "CUDA_VISIBLE_DEVICES"

    @classmethod
    def get_device_capability(cls,
                              device_id: int = 0
                              ) -> Optional[DeviceCapability]:
        """Return the compute capability of ``device_id`` (subclass hook)."""
        raise NotImplementedError

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the name of ``device_id`` (subclass hook)."""
        raise NotImplementedError

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return the total memory of ``device_id`` (subclass hook)."""
        raise NotImplementedError

    @classmethod
    def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
        """Return False (with a warning) if eager mode is enforced, else True.
        """
        if enforce_eager:
            logger.warning(
                "To see benefits of async output processing, enable CUDA "
                "graph. Since, enforce-eager is enabled, async output "
                "processor cannot be used")
            return False
        return True

    @classmethod
    def is_full_nvlink(cls, device_ids: List[int]) -> bool:
        """Return whether the given GPUs are all NVLink-connected (hook)."""
        raise NotImplementedError

    @classmethod
    def log_warnings(cls):
        """Emit platform warnings; a no-op unless a subclass overrides it."""
        pass

    @classmethod
    def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
        """Fill in CUDA-specific settings on ``vllm_config`` in place.

        - Resolves ``parallel_config.worker_cls == "auto"`` to a concrete
          worker class depending on multi-step scheduling, speculative
          decoding and ``envs.VLLM_USE_V1``.
        - Defaults the KV-cache block size to 16, and forces it to 64 for MLA
          models when FlashMLA is selected (or no backend is set) and
          ``is_flashmla_supported()`` reports support.
        - Disables CUDA graphs and forces eager mode when data parallelism
          (``data_parallel_size > 1``) is combined with CUDA graphs.

        Raises:
            NotImplementedError: if multi-step scheduling is requested on V1.
        """
        parallel_config = vllm_config.parallel_config
        scheduler_config = vllm_config.scheduler_config
        compilation_config = vllm_config.compilation_config
        model_config = vllm_config.model_config

        if parallel_config.worker_cls == "auto":
            if scheduler_config.is_multi_step:
                if envs.VLLM_USE_V1:
                    raise NotImplementedError(
                        "Multi-step scheduling is not supported (and not "
                        "needed) on vLLM V1. Please launch without "
                        "--num-scheduler-steps.")
                else:
                    parallel_config.worker_cls = (
                        "vllm.worker.multi_step_worker.MultiStepWorker")
            elif vllm_config.speculative_config:
                if envs.VLLM_USE_V1:
                    parallel_config.worker_cls = (
                        "vllm.v1.worker.gpu_worker.Worker")
                else:
                    parallel_config.worker_cls = (
                        "vllm.spec_decode.spec_decode_worker.create_spec_worker"
                    )
                    parallel_config.sd_worker_cls = (
                        "vllm.worker.worker.Worker")
            else:
                if envs.VLLM_USE_V1:
                    parallel_config.worker_cls = (
                        "vllm.v1.worker.gpu_worker.Worker")
                else:
                    parallel_config.worker_cls = "vllm.worker.worker.Worker"

        cache_config = vllm_config.cache_config
        if cache_config and cache_config.block_size is None:
            cache_config.block_size = 16

        # TODO(lucas): handle this more gracefully
        # Note: model_config may be None during testing
        if model_config is not None and model_config.use_mla:
            # if `VLLM_ATTENTION_BACKEND` is not set and we are using MLA, then
            # we default to FlashMLA backend, so we need to force the blocksize
            # here
            use_flashmla = (envs.VLLM_ATTENTION_BACKEND is None
                            or envs.VLLM_ATTENTION_BACKEND == "FLASHMLA")
            from vllm.attention.ops.flashmla import is_flashmla_supported
            # cache_config may be None (guarded above as well); only adjust
            # the block size when there is a cache config to update.
            if (use_flashmla and cache_config is not None
                    and is_flashmla_supported()[0]
                    and cache_config.block_size != 64):
                cache_config.block_size = 64
                logger.info(
                    "Forcing kv cache block size to 64 for FlashMLA backend.")

        if (parallel_config.data_parallel_size > 1
                and compilation_config.use_cudagraph):
            logger.info(
                "Data Parallel: Forcing enforce eager to be True since DP is "
                "currently not supported with CUDA Graphs.")
            # model_config may be None during testing (see note above).
            if model_config is not None:
                model_config.enforce_eager = True
            compilation_config.use_cudagraph = False

    @classmethod
    def get_current_memory_usage(cls,
                                 device: Optional[torch.types.Device] = None
                                 ) -> float:
        """Return the bytes currently allocated by torch on ``device``.

        The peak-memory counter is reset first, so the subsequent
        ``max_memory_allocated`` reading reflects current rather than
        historical peak usage.
        """
        torch.cuda.reset_peak_memory_stats(device)
        return torch.cuda.max_memory_allocated(device)

    @classmethod
    def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
                             kv_cache_dtype, block_size, use_v1,
                             use_mla) -> str:
        """Return the dotted import path of the attention backend to use.

        MLA models get FlashMLA when it is usable (block size 64, Triton MLA
        not explicitly selected, and ``is_flashmla_supported()`` reports
        support); otherwise they get Triton MLA. Non-MLA models prefer the V1
        Triton / Flash Attention backends when ``use_v1`` allows, then an
        explicitly selected FlashInfer / XFormers backend, and finally
        FlashAttention with an XFormers fallback.

        Raises:
            ValueError: if, on the non-MLA path, ``selected_backend`` is set
                to a backend this method does not handle.
        """
        if use_mla:
            # TODO(lucas): refactor to be more concise
            # we should probably consider factoring out V1 here
            want_flashmla = (selected_backend != _Backend.TRITON_MLA
                             and block_size == 64)
            if want_flashmla:
                from vllm.attention.backends.flashmla import (
                    is_flashmla_supported)
                # Query once: element 0 is the support flag, element 1 the
                # reason it is unsupported.
                flashmla_support = is_flashmla_supported()
                if flashmla_support[0]:
                    if use_v1:
                        logger.info_once(
                            "Using FlashMLA backend on V1 engine.")
                        return ("vllm.v1.attention.backends.mla."
                                "flashmla.FlashMLABackend")
                    logger.info("Using FlashMLA backend.")
                    return ("vllm.attention.backends."
                            "flashmla.FlashMLABackend")
                logger.warning(
                    "FlashMLA backend is not supported due to %s",
                    flashmla_support[1])
                # Fall back to Triton MLA rather than falling through to the
                # non-MLA backend selection below.
            if use_v1:
                logger.info_once("Using Triton MLA backend on V1 engine.")
                return ("vllm.v1.attention.backends.mla."
                        "triton_mla.TritonMLABackend")
            logger.info("Using Triton MLA backend.")
            return "vllm.attention.backends.triton_mla.TritonMLABackend"

        if use_v1:
            if selected_backend == _Backend.TRITON_ATTN_VLLM_V1:
                logger.info_once("Using Triton backend on V1 engine.")
                return ("vllm.v1.attention.backends."
                        "triton_attn.TritonAttentionBackend")
            if cls.has_device_capability(80):
                logger.info_once("Using Flash Attention backend on V1 engine.")
                return ("vllm.v1.attention.backends."
                        "flash_attn.FlashAttentionBackend")

        if selected_backend == _Backend.FLASHINFER:
            logger.info("Using FlashInfer backend.")
            return "vllm.attention.backends.flashinfer.FlashInferBackend"
        elif selected_backend == _Backend.XFORMERS:
            logger.info("Using XFormers backend.")
            return "vllm.attention.backends.xformers.XFormersBackend"
        elif selected_backend == _Backend.FLASH_ATTN:
            pass
        elif selected_backend:
            raise ValueError(
                f"Invalid attention backend for {cls.device_name}, "
                f"with use_v1: {use_v1} use_mla: {use_mla}")

        target_backend = _Backend.FLASH_ATTN
        if not cls.has_device_capability(80):
            # Volta and Turing NVIDIA GPUs.
            logger.info(
                "Cannot use FlashAttention-2 backend for Volta and Turing "
                "GPUs.")
            target_backend = _Backend.XFORMERS
        elif dtype not in (torch.float16, torch.bfloat16):
            logger.info(
                "Cannot use FlashAttention-2 backend for dtype other than "
                "torch.float16 or torch.bfloat16.")
            target_backend = _Backend.XFORMERS
        elif block_size % 16 != 0:
            logger.info(
                "Cannot use FlashAttention-2 backend for block size not "
                "divisible by 16.")
            target_backend = _Backend.XFORMERS

        # FlashAttn is valid for the model, checking if the package is
        # installed.
        if target_backend == _Backend.FLASH_ATTN:
            try:
                import vllm.vllm_flash_attn  # noqa: F401
                from vllm.attention.backends.flash_attn import (  # noqa: F401
                    FlashAttentionBackend)

                supported_sizes = (
                    FlashAttentionBackend.get_supported_head_sizes())
                if head_size not in supported_sizes:
                    logger.info(
                        "Cannot use FlashAttention-2 backend for head size %d.",
                        head_size)
                    target_backend = _Backend.XFORMERS
                fp8_kv_cache = (kv_cache_dtype is not None
                                and kv_cache_dtype.startswith("fp8"))
                if (fp8_kv_cache and get_flash_attn_version() != 3):
                    logger.info(
                        "Cannot use FlashAttention-2 backend for FP8 KV cache."
                    )
                    logger.warning(
                        "Please use FlashInfer backend with FP8 KV Cache for "
                        "better performance by setting environment variable "
                        "VLLM_ATTENTION_BACKEND=FLASHINFER")
                    target_backend = _Backend.XFORMERS
            except ImportError:
                logger.info(
                    "Cannot use FlashAttention-2 backend because the "
                    "vllm.vllm_flash_attn package is not found. "
                    "Make sure that vllm_flash_attn was built and installed "
                    "(on by default).")
                target_backend = _Backend.XFORMERS

        if target_backend == _Backend.XFORMERS:
            logger.info("Using XFormers backend.")
            return "vllm.attention.backends.xformers.XFormersBackend"

        logger.info("Using Flash Attention backend.")
        return "vllm.attention.backends.flash_attn.FlashAttentionBackend"

    @classmethod
    def get_punica_wrapper(cls) -> str:
        """Return the import path of the GPU Punica (LoRA) wrapper."""
        return "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU"

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        """Return the import path of the CUDA device communicator class."""
        return "vllm.distributed.device_communicators.cuda_communicator.CudaCommunicator"  # noqa

    @classmethod
    def supports_fp8(cls) -> bool:
        """FP8 is reported as supported on devices with capability >= 8.9."""
        return cls.has_device_capability(89)
# NVML utils
# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`;
# all the related functions work on real physical device ids.
# The major benefit of using NVML is that it does not initialize CUDA.
class NvmlCudaPlatform(CudaPlatformBase):
    """CUDA platform that answers device queries through NVML.

    Public queries are wrapped in ``with_nvml_context`` (NVML init/shutdown
    around each call). Logical device ids are mapped to physical ids with
    ``device_id_to_physical_device_id`` before NVML is queried.
    """

    @classmethod
    @with_nvml_context
    def get_device_capability(cls,
                              device_id: int = 0
                              ) -> Optional[DeviceCapability]:
        """Return the compute capability of logical device ``device_id``.

        Returns None if a RuntimeError is raised while resolving or querying
        the device (e.g. ``CUDA_VISIBLE_DEVICES`` set to the empty string).
        """
        try:
            physical_device_id = device_id_to_physical_device_id(device_id)
            handle = pynvml.nvmlDeviceGetHandleByIndex(physical_device_id)
            major, minor = pynvml.nvmlDeviceGetCudaComputeCapability(handle)
            return DeviceCapability(major=major, minor=minor)
        except RuntimeError:
            return None

    @classmethod
    @with_nvml_context
    def has_device_capability(
        cls,
        capability: Union[Tuple[int, int], int],
        device_id: int = 0,
    ) -> bool:
        """Run the base capability check, returning False on RuntimeError."""
        try:
            return super().has_device_capability(capability, device_id)
        except RuntimeError:
            return False

    @classmethod
    @with_nvml_context
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the NVML name of logical device ``device_id``."""
        physical_device_id = device_id_to_physical_device_id(device_id)
        return cls._get_physical_device_name(physical_device_id)

    @classmethod
    @with_nvml_context
    def get_device_uuid(cls, device_id: int = 0) -> str:
        """Return the NVML UUID of logical device ``device_id``."""
        physical_device_id = device_id_to_physical_device_id(device_id)
        handle = pynvml.nvmlDeviceGetHandleByIndex(physical_device_id)
        return pynvml.nvmlDeviceGetUUID(handle)

    @classmethod
    @with_nvml_context
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return NVML's total memory figure for logical device ``device_id``.
        """
        physical_device_id = device_id_to_physical_device_id(device_id)
        handle = pynvml.nvmlDeviceGetHandleByIndex(physical_device_id)
        return int(pynvml.nvmlDeviceGetMemoryInfo(handle).total)

    @classmethod
    @with_nvml_context
    def is_full_nvlink(cls, physical_device_ids: List[int]) -> bool:
        """Query whether the given GPUs are fully connected by NVLink (1 hop).

        Every unordered pair must report ``NVML_P2P_STATUS_OK`` for the
        NVLink P2P capability. An NVML error during the query is logged and
        treated as "not fully connected".

        Args:
            physical_device_ids: physical (not ``CUDA_VISIBLE_DEVICES``-
                relative) device indices.
        """
        from itertools import combinations

        handles = [
            pynvml.nvmlDeviceGetHandleByIndex(i) for i in physical_device_ids
        ]
        # Each unordered pair is checked once, in the same (i < j) order as a
        # nested index loop would visit them.
        for handle, peer_handle in combinations(handles, 2):
            try:
                p2p_status = pynvml.nvmlDeviceGetP2PStatus(
                    handle,
                    peer_handle,
                    pynvml.NVML_P2P_CAPS_INDEX_NVLINK,
                )
                if p2p_status != pynvml.NVML_P2P_STATUS_OK:
                    return False
            except pynvml.NVMLError:
                logger.exception(
                    "NVLink detection failed. This is normal if"
                    " your machine has no NVLink equipped.")
                return False
        return True

    @classmethod
    def _get_physical_device_name(cls, device_id: int = 0) -> str:
        """Return the NVML name of *physical* device ``device_id``.

        Not wrapped in an NVML context itself; the callers in this class
        (``get_device_name`` and ``log_warnings``) are.
        """
        handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
        return pynvml.nvmlDeviceGetName(handle)

    @classmethod
    @with_nvml_context
    def log_warnings(cls):
        """Warn if different GPU models are present without PCI bus ordering.
        """
        device_count: int = pynvml.nvmlDeviceGetCount()
        if device_count > 1:
            device_names = [
                cls._get_physical_device_name(i) for i in range(device_count)
            ]
            if (len(set(device_names)) > 1
                    and os.environ.get("CUDA_DEVICE_ORDER") != "PCI_BUS_ID"):
                logger.warning(
                    "Detected different devices in the system: %s. Please"
                    " make sure to set `CUDA_DEVICE_ORDER=PCI_BUS_ID` to "
                    "avoid unexpected behavior.",
                    ", ".join(device_names),
                )
class NonNvmlCudaPlatform(CudaPlatformBase):
    """CUDA platform that answers device queries through ``torch.cuda``.

    This is the fallback used when NVML cannot be initialized.
    """

    @classmethod
    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
        """Return the compute capability reported by torch for ``device_id``.
        """
        major, minor = torch.cuda.get_device_capability(device_id)
        return DeviceCapability(major=major, minor=minor)

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the device name reported by torch for ``device_id``."""
        return torch.cuda.get_device_name(device_id)

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return ``total_memory`` from torch's device properties."""
        device_props = torch.cuda.get_device_properties(device_id)
        return device_props.total_memory

    @classmethod
    def is_full_nvlink(cls, physical_device_ids: List[int]) -> bool:
        """Always return False; NVLink topology is not queried here."""
        # This is not inside an ``except`` block, so ``logger.exception``
        # would append a spurious "NoneType: None" traceback; a plain
        # warning carries the same message.
        logger.warning(
            "NVLink detection not possible, as context support was"
            " not found. Assuming no NVLink available.")
        return False
# Choose the NVML-backed platform when NVML can be initialized, otherwise the
# torch.cuda-based one. The probe session is shut down again immediately.
try:
    pynvml.nvmlInit()
except Exception:
    # On Jetson, NVML is not supported.
    nvml_available = False
else:
    nvml_available = True
    pynvml.nvmlShutdown()
CudaPlatform = NvmlCudaPlatform if nvml_available else NonNvmlCudaPlatform
# Emit platform warnings at import time, except when ``pynvml`` is Sphinx's
# autodoc mock (e.g. during documentation builds).
try:
    from sphinx.ext.autodoc.mock import _MockModule
except ImportError:
    # Sphinx is absent, or this Sphinx version has no ``_MockModule``
    # (a plain ImportError, not ModuleNotFoundError).
    _is_mocked_pynvml = False
else:
    _is_mocked_pynvml = isinstance(pynvml, _MockModule)

# Called outside the ``try`` so that an import error raised by
# ``log_warnings`` itself is not mistaken for "Sphinx missing".
if not _is_mocked_pynvml:
    CudaPlatform.log_warnings()