[Platform][Kernel] platform-specific kernel loading (#25823)

Signed-off-by: Hank <hcc.mayday@gmail.com>
This commit is contained in:
Hank_
2025-10-05 19:25:15 +08:00
committed by GitHub
parent 3303cfb4ac
commit 17edd8a807
4 changed files with 27 additions and 11 deletions

View File

@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
from typing import TYPE_CHECKING, Optional, Union
import torch
@ -13,16 +12,8 @@ from vllm.scalar_type import ScalarType
logger = init_logger(__name__)
if not current_platform.is_tpu() and not current_platform.is_xpu():
try:
import vllm._C
except ImportError as e:
logger.warning("Failed to import from vllm._C with %r", e)
supports_moe_ops = False
with contextlib.suppress(ImportError):
import vllm._moe_C # noqa: F401
supports_moe_ops = True
current_platform.import_core_kernels()
supports_moe_ops = current_platform.try_import_moe_kernels()
if TYPE_CHECKING:

View File

@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
import enum
import os
import platform
@ -163,6 +164,22 @@ class Platform:
else:
return device_id
@classmethod
def import_core_kernels(cls) -> None:
""" Import any platform-specific C kernels. """
try:
import vllm._C # noqa: F401
except ImportError as e:
logger.warning("Failed to import from vllm._C: %r", e)
@classmethod
def try_import_moe_kernels(cls) -> bool:
""" Import any platform-specific MoE kernels. """
with contextlib.suppress(ImportError):
import vllm._moe_C # noqa: F401
return True
return False
@classmethod
def get_vit_attn_backend(cls, head_size: int,
dtype: torch.dtype) -> "_Backend":

View File

@ -47,6 +47,10 @@ class TpuPlatform(Platform):
"TPU_CHIPS_PER_HOST_BOUNDS", "TPU_HOST_BOUNDS"
]
@classmethod
def import_core_kernels(cls) -> None:
pass
@classmethod
def get_attn_backend_cls(cls, selected_backend: "_Backend", head_size: int,
dtype: torch.dtype, kv_cache_dtype: Optional[str],

View File

@ -34,6 +34,10 @@ class XPUPlatform(Platform):
dist_backend: str = "ccl" # ccl | xccl
device_control_env_var: str = "ZE_AFFINITY_MASK"
@classmethod
def import_core_kernels(cls) -> None:
pass
@classmethod
def get_attn_backend_cls(cls, selected_backend: "_Backend", head_size: int,
dtype: torch.dtype, kv_cache_dtype: Optional[str],