Mirror of https://github.com/vllm-project/vllm.git, last synced 2025-10-20 14:53:52 +08:00.
[Platform][Kernel] platform-specific kernel loading (#25823)
Signed-off-by: Hank <hcc.mayday@gmail.com>
This commit is contained in:
@ -1,7 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import contextlib
|
||||
from typing import TYPE_CHECKING, Optional, Union
|
||||
|
||||
import torch
|
||||
@ -13,16 +12,8 @@ from vllm.scalar_type import ScalarType
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
if not current_platform.is_tpu() and not current_platform.is_xpu():
|
||||
try:
|
||||
import vllm._C
|
||||
except ImportError as e:
|
||||
logger.warning("Failed to import from vllm._C with %r", e)
|
||||
|
||||
supports_moe_ops = False
|
||||
with contextlib.suppress(ImportError):
|
||||
import vllm._moe_C # noqa: F401
|
||||
supports_moe_ops = True
|
||||
current_platform.import_core_kernels()
|
||||
supports_moe_ops = current_platform.try_import_moe_kernels()
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
|
@ -1,5 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import contextlib
|
||||
import enum
|
||||
import os
|
||||
import platform
|
||||
@ -163,6 +164,22 @@ class Platform:
|
||||
else:
|
||||
return device_id
|
||||
|
||||
@classmethod
def import_core_kernels(cls) -> None:
    """Import the platform-specific core C kernels (``vllm._C``).

    Platforms without a compiled extension override this with a no-op.
    An ``ImportError`` is logged as a warning rather than re-raised, so
    the module can still be imported when the extension is unavailable.
    """
    try:
        import vllm._C  # noqa: F401
    except ImportError as e:
        # Lazy %-style args keep formatting off the happy path.
        logger.warning("Failed to import from vllm._C: %r", e)
|
||||
|
||||
@classmethod
def try_import_moe_kernels(cls) -> bool:
    """Attempt to import the platform-specific MoE kernels.

    Returns:
        True if ``vllm._moe_C`` was imported successfully,
        False if the extension is not available.
    """
    try:
        import vllm._moe_C  # noqa: F401
    except ImportError:
        return False
    return True
|
||||
|
||||
@classmethod
|
||||
def get_vit_attn_backend(cls, head_size: int,
|
||||
dtype: torch.dtype) -> "_Backend":
|
||||
|
@ -47,6 +47,10 @@ class TpuPlatform(Platform):
|
||||
"TPU_CHIPS_PER_HOST_BOUNDS", "TPU_HOST_BOUNDS"
|
||||
]
|
||||
|
||||
@classmethod
def import_core_kernels(cls) -> None:
    """Intentional no-op: skips the base-class ``vllm._C`` import,
    which this platform does not use."""
    pass
|
||||
|
||||
@classmethod
|
||||
def get_attn_backend_cls(cls, selected_backend: "_Backend", head_size: int,
|
||||
dtype: torch.dtype, kv_cache_dtype: Optional[str],
|
||||
|
@ -34,6 +34,10 @@ class XPUPlatform(Platform):
|
||||
dist_backend: str = "ccl" # ccl | xccl
|
||||
device_control_env_var: str = "ZE_AFFINITY_MASK"
|
||||
|
||||
@classmethod
def import_core_kernels(cls) -> None:
    """Intentional no-op: skips the base-class ``vllm._C`` import,
    which this platform does not use."""
    pass
|
||||
|
||||
@classmethod
|
||||
def get_attn_backend_cls(cls, selected_backend: "_Backend", head_size: int,
|
||||
dtype: torch.dtype, kv_cache_dtype: Optional[str],
|
||||
|
Reference in New Issue
Block a user