Fix kv_cache_dtype handling for out-of-tree HPU plugin (#21302)

Signed-off-by: Konrad Zawora <kzawora@habana.ai>
Signed-off-by: Chendi.Xue <chendi.xue@intel.com>
Co-authored-by: Chendi.Xue <chendi.xue@intel.com>
Author: Konrad Zawora
Date: 2025-07-22 08:35:14 +02:00 (committed by GitHub)
Parent: 6e5b5ca580
Commit: c17231e827

5 changed files with 30 additions and 16 deletions


@@ -1352,22 +1352,8 @@ class EngineArgs:
         # No Fp8 KV cache so far.
         if self.kv_cache_dtype != "auto":
-            fp8_attention = self.kv_cache_dtype.startswith("fp8")
-            will_use_fa = (
-                current_platform.is_cuda()
-                and not envs.is_set("VLLM_ATTENTION_BACKEND")
-            ) or envs.VLLM_ATTENTION_BACKEND == "FLASH_ATTN_VLLM_V1"
-            supported = False
-            if (current_platform.is_rocm()
-                    or (current_platform.is_cuda()
-                        and current_platform.is_device_capability(100))
-                    or current_platform.is_tpu()):
-                supported = True
-            elif fp8_attention and will_use_fa:
-                from vllm.attention.utils.fa_utils import (
-                    flash_attn_supports_fp8)
-                supported = flash_attn_supports_fp8()
+            supported = current_platform.is_kv_cache_dtype_supported(
+                self.kv_cache_dtype)
             if not supported:
                 _raise_or_fallback(feature_name="--kv-cache-dtype",
                                    recommend_to_remove=False)
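
The engine-side special-casing of CUDA, ROCm, and TPU is gone: the argument validator now asks the active platform directly, which is what lets an out-of-tree platform such as the HPU plugin answer for itself. A minimal sketch of the resulting call-site pattern, assuming an installed vLLM; the dtype value is only illustrative and not part of this commit:

    # Sketch only: how the simplified check behaves after this change.
    from vllm.platforms import current_platform

    kv_cache_dtype = "fp8_e4m3"  # illustrative value
    if kv_cache_dtype != "auto":
        # Whatever platform is active (in-tree or a plugin) decides support.
        if not current_platform.is_kv_cache_dtype_supported(kv_cache_dtype):
            print("--kv-cache-dtype is not supported on this platform")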


@@ -586,6 +586,19 @@ class NonNvmlCudaPlatform(CudaPlatformBase):
             " not found. Assuming no NVLink available.")
         return False
 
+    @classmethod
+    def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str) -> bool:
+        fp8_attention = kv_cache_dtype.startswith("fp8")
+        will_use_fa = (not envs.is_set("VLLM_ATTENTION_BACKEND")
+                       ) or envs.VLLM_ATTENTION_BACKEND == "FLASH_ATTN_VLLM_V1"
+        supported = False
+        if cls.is_device_capability(100):
+            supported = True
+        elif fp8_attention and will_use_fa:
+            from vllm.attention.utils.fa_utils import flash_attn_supports_fp8
+            supported = flash_attn_supports_fp8()
+        return supported
+
 # Autodetect either NVML-enabled or non-NVML platform
 # based on whether NVML is available.


@@ -543,6 +543,13 @@ class Platform:
         """
         raise RuntimeError(f"Unsupported torch distributed backend: {backend}")
 
+    @classmethod
+    def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str) -> bool:
+        """
+        Returns if the kv_cache_dtype is supported by the current platform.
+        """
+        return False
+
 
 class UnspecifiedPlatform(Platform):
     _enum = PlatformEnum.UNSPECIFIED
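
This base-class default of False is what the commit title is about: the decision now lives in the Platform interface, so an out-of-tree platform only needs to override the hook. A hypothetical sketch of such an override, assuming Platform and PlatformEnum are importable from vllm.platforms and that OOT is the out-of-tree enum member; the class and module names are illustrative and not taken from the actual HPU plugin:

    # Hypothetical out-of-tree platform plugin (names are illustrative).
    from vllm.platforms import Platform, PlatformEnum


    class MyAcceleratorPlatform(Platform):
        _enum = PlatformEnum.OOT  # assumed out-of-tree enum member

        @classmethod
        def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str) -> bool:
            # Advertise fp8 KV-cache support instead of inheriting the
            # base-class default of False.
            return kv_cache_dtype.startswith("fp8")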


@@ -454,3 +454,7 @@ class RocmPlatform(Platform):
     @classmethod
     def device_count(cls) -> int:
         return cuda_device_count_stateless()
+
+    @classmethod
+    def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str) -> bool:
+        return True


@@ -190,6 +190,10 @@ class TpuPlatform(Platform):
                 and params.sampling_type == SamplingType.RANDOM_SEED):
             raise ValueError("Torch XLA does not support per-request seed.")
 
+    @classmethod
+    def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str) -> bool:
+        return True
+
 
 try:
     from tpu_commons.platforms import TpuPlatform as TpuCommonsPlatform
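
For completeness, an out-of-tree platform such as the HPU plugin reaches current_platform through vLLM's platform-plugin entry point; a hypothetical packaging sketch under that assumption, with all package and module names illustrative:

    # Hypothetical plugin packaging (illustrative names; assumes the
    # "vllm.platform_plugins" entry-point group for out-of-tree platforms).
    from setuptools import setup

    setup(
        name="my-accelerator-plugin",
        entry_points={
            "vllm.platform_plugins": [
                "my_accelerator = my_plugin:register_platform",
            ],
        },
    )

Here my_plugin.register_platform() would return the fully qualified class name of a platform like the sketch above, or None when the device is not present, so that the overridden is_kv_cache_dtype_supported() is the one the engine consults.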