[bugfix] add supports_v1 platform interface (#15417)
Signed-off-by: Joe Runde <Joseph.Runde@ibm.com>
vllm/engine/arg_utils.py
@@ -1666,9 +1666,8 @@ class EngineArgs:
             _raise_or_fallback(feature_name=name, recommend_to_remove=True)
             return False
 
-        # No support for device type other than CUDA, AMD (experiemntal) or
-        # TPU (experimental) so far.
-        if not (current_platform.is_cuda_alike() or current_platform.is_tpu()):
+        # Platforms must decide if they can support v1 for this model
+        if not current_platform.supports_v1(model_config=model_config):
             _raise_or_fallback(
                 feature_name=f"device type={current_platform.device_type}",
                 recommend_to_remove=False)
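For context, the hunk above delegates the V1-eligibility decision to the active platform instead of hard-coding CUDA/TPU checks. A condensed, paraphrased sketch of that control flow follows; it is not the verbatim vLLM implementation, and the raise-versus-fallback behavior of _raise_or_fallback is assumed from its call sites:

    # Hedged paraphrase of the V1 oracle's platform check; assumes the
    # envs.is_set / envs.VLLM_USE_V1 semantics used elsewhere in arg_utils.
    from vllm import envs
    from vllm.platforms import current_platform

    def _platform_allows_v1(model_config) -> bool:
        if current_platform.supports_v1(model_config=model_config):
            return True
        feature_name = f"device type={current_platform.device_type}"
        if envs.is_set("VLLM_USE_V1") and envs.VLLM_USE_V1:
            # The user explicitly requested V1: fail loudly.
            raise NotImplementedError(
                f"VLLM_USE_V1=1 is not supported with {feature_name}.")
        # Otherwise fall back to the V0 engine.
        return False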
vllm/platforms/cuda.py
@@ -20,8 +20,9 @@ from vllm.utils import import_pynvml
 from .interface import DeviceCapability, Platform, PlatformEnum, _Backend
 
 if TYPE_CHECKING:
-    from vllm.config import VllmConfig
+    from vllm.config import ModelConfig, VllmConfig
 else:
+    ModelConfig = None
     VllmConfig = None
 
 logger = init_logger(__name__)
@@ -303,6 +304,10 @@ class CudaPlatformBase(Platform):
     def supports_fp8(cls) -> bool:
         return cls.has_device_capability(89)
 
+    @classmethod
+    def supports_v1(cls, model_config: ModelConfig) -> bool:
+        return True
+
 
 # NVML utils
 # Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
vllm/platforms/interface.py
@@ -12,9 +12,10 @@ import torch
 from vllm.logger import init_logger
 
 if TYPE_CHECKING:
-    from vllm.config import VllmConfig
+    from vllm.config import ModelConfig, VllmConfig
     from vllm.utils import FlexibleArgumentParser
 else:
+    ModelConfig = None
     VllmConfig = None
     FlexibleArgumentParser = None
 
@@ -371,6 +372,13 @@ class Platform:
                 or parallel_config.distributed_executor_backend
                 == "external_launcher")
 
+    @classmethod
+    def supports_v1(cls, model_config: ModelConfig) -> bool:
+        """Returns whether the current platform can support v1 for the supplied
+        model configuration.
+        """
+        return False
+
 
 class UnspecifiedPlatform(Platform):
     _enum = PlatformEnum.UNSPECIFIED
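The base-class default above keeps V1 opt-in: a platform advertises support by overriding the classmethod. As an illustration, a hypothetical out-of-tree platform could gate V1 per model like this; the class name, its gating policy, and the is_multimodal_model check are illustrative assumptions, not part of this commit:

    from typing import TYPE_CHECKING

    from vllm.platforms.interface import Platform, PlatformEnum

    if TYPE_CHECKING:
        from vllm.config import ModelConfig

    class MyAcceleratorPlatform(Platform):  # hypothetical plugin platform
        _enum = PlatformEnum.OOT
        device_type = "myaccel"

        @classmethod
        def supports_v1(cls, model_config: "ModelConfig") -> bool:
            # Example policy only: enable V1 for text-only models
            # (assumes ModelConfig.is_multimodal_model is available).
            return not model_config.is_multimodal_model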
vllm/platforms/rocm.py
@@ -12,8 +12,9 @@ from vllm.logger import init_logger
 from .interface import DeviceCapability, Platform, PlatformEnum, _Backend
 
 if TYPE_CHECKING:
-    from vllm.config import VllmConfig
+    from vllm.config import ModelConfig, VllmConfig
 else:
+    ModelConfig = None
     VllmConfig = None
 
 logger = init_logger(__name__)
@@ -249,3 +250,8 @@ class RocmPlatform(Platform):
             return torch.float8_e4m3fnuz
         else:
             return torch.float8_e4m3fn
+
+    @classmethod
+    def supports_v1(cls, model_config: ModelConfig) -> bool:
+        # V1 support on AMD gpus is experimental
+        return True
vllm/platforms/tpu.py
@@ -10,8 +10,9 @@ from vllm.logger import init_logger
 from .interface import Platform, PlatformEnum, _Backend
 
 if TYPE_CHECKING:
-    from vllm.config import VllmConfig
+    from vllm.config import ModelConfig, VllmConfig
 else:
+    ModelConfig = None
     VllmConfig = None
 
 logger = init_logger(__name__)
@@ -127,3 +128,8 @@ class TpuPlatform(Platform):
     @classmethod
     def use_all_gather(cls) -> bool:
         return True
+
+    @classmethod
+    def supports_v1(cls, model_config: ModelConfig) -> bool:
+        # V1 support on TPU is experimental
+        return True
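Taken together, the mechanism is plain classmethod overriding: CUDA opts in unconditionally, ROCm and TPU opt in as experimental, and any platform without an override inherits the conservative False and lands in the _raise_or_fallback path. A self-contained toy illustration of that dispatch (standalone stubs, not the real vLLM classes):

    class Platform:  # stub mirroring the new default
        @classmethod
        def supports_v1(cls, model_config) -> bool:
            return False

    class TpuPlatform(Platform):  # stub mirroring the TPU override
        @classmethod
        def supports_v1(cls, model_config) -> bool:
            return True  # experimental

    class SomeOtherPlatform(Platform):  # no override: default applies
        pass

    assert TpuPlatform.supports_v1(None) is True
    assert SomeOtherPlatform.supports_v1(None) is False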