[XPU] Delay BF16 check to worker init for spawn compatibility (#22979)
Signed-off-by: chzhang <chaojun.zhang@intel.com>
This commit is contained in:
@ -518,6 +518,26 @@ class CudaPlatformBase(Platform):
|
||||
supported = True
|
||||
return supported
|
||||
|
||||
@classmethod
|
||||
def check_if_supports_dtype(cls, torch_dtype: torch.dtype):
|
||||
if torch_dtype == torch.bfloat16: # noqa: SIM102
|
||||
if not cls.has_device_capability(80):
|
||||
capability = cls.get_device_capability()
|
||||
gpu_name = cls.get_device_name()
|
||||
|
||||
if capability is None:
|
||||
compute_str = "does not have a compute capability"
|
||||
else:
|
||||
version_str = capability.as_version_str()
|
||||
compute_str = f"has compute capability {version_str}"
|
||||
|
||||
raise ValueError(
|
||||
"Bfloat16 is only supported on GPUs "
|
||||
"with compute capability of at least 8.0. "
|
||||
f"Your {gpu_name} GPU {compute_str}. "
|
||||
"You can use float16 instead by explicitly setting the "
|
||||
"`dtype` flag in CLI, for example: --dtype=half.")
|
||||
|
||||
|
||||
# NVML utils
|
||||
# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
|
||||
|
@ -572,6 +572,13 @@ class Platform:
|
||||
"""
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def check_if_supports_dtype(cls, torch_dtype: torch.dtype):
|
||||
"""
|
||||
Check if the dtype is supported by the current platform.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class UnspecifiedPlatform(Platform):
|
||||
_enum = PlatformEnum.UNSPECIFIED
|
||||
|
@ -462,3 +462,23 @@ class RocmPlatform(Platform):
|
||||
def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str,
|
||||
model_config: "ModelConfig") -> bool:
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def check_if_supports_dtype(cls, torch_dtype: torch.dtype):
|
||||
if torch_dtype == torch.bfloat16: # noqa: SIM102
|
||||
if not cls.has_device_capability(80):
|
||||
capability = cls.get_device_capability()
|
||||
gpu_name = cls.get_device_name()
|
||||
|
||||
if capability is None:
|
||||
compute_str = "does not have a compute capability"
|
||||
else:
|
||||
version_str = capability.as_version_str()
|
||||
compute_str = f"has compute capability {version_str}"
|
||||
|
||||
raise ValueError(
|
||||
"Bfloat16 is only supported on GPUs "
|
||||
"with compute capability of at least 8.0. "
|
||||
f"Your {gpu_name} GPU {compute_str}. "
|
||||
"You can use float16 instead by explicitly setting the "
|
||||
"`dtype` flag in CLI, for example: --dtype=half.")
|
||||
|
@ -97,13 +97,6 @@ class XPUPlatform(Platform):
|
||||
from vllm.config import CompilationLevel
|
||||
vllm_config.compilation_config.level = CompilationLevel.NO_COMPILATION # noqa: E501
|
||||
|
||||
# Instances created using VllmConfig() typically have model_config as
|
||||
# None by default. The modification involves adding a check to prevent
|
||||
# potential null exceptions check and update model config.
|
||||
if model_config is not None and model_config.dtype == torch.bfloat16 \
|
||||
and not cls.device_support_bf16():
|
||||
model_config.dtype = torch.float16
|
||||
|
||||
# lazy import to avoid circular import
|
||||
from vllm.config import CUDAGraphMode
|
||||
compilation_config = vllm_config.compilation_config
|
||||
@ -162,30 +155,11 @@ class XPUPlatform(Platform):
|
||||
torch.xpu.reset_peak_memory_stats(device)
|
||||
return torch.xpu.max_memory_allocated(device)
|
||||
|
||||
@classmethod
|
||||
def device_support_bf16(cls) -> bool:
|
||||
device_name = cls.get_device_name().lower()
|
||||
if cls.is_client_gpu_a770():
|
||||
logger.warning("Intel Arc A770 have bfloat16 accuracy known issue,"
|
||||
" fallback to float16")
|
||||
return False
|
||||
else:
|
||||
logger.info(
|
||||
"Device name %s supports bfloat16. Please file an issue "
|
||||
"if you encounter any accuracy problems with bfloat16.",
|
||||
device_name)
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def is_data_center_gpu(cls) -> bool:
|
||||
device_name = cls.get_device_name().lower()
|
||||
return device_name.count("data center gpu") > 0
|
||||
|
||||
@classmethod
|
||||
def is_client_gpu_a770(cls) -> bool:
|
||||
device_name = cls.get_device_name().lower()
|
||||
return device_name.count("a770") > 0
|
||||
|
||||
@classmethod
|
||||
def get_device_communicator_cls(cls) -> str:
|
||||
return "vllm.distributed.device_communicators.xpu_communicator.XpuCommunicator" # noqa
|
||||
@ -197,3 +171,14 @@ class XPUPlatform(Platform):
|
||||
@classmethod
|
||||
def device_count(cls) -> int:
|
||||
return torch.xpu.device_count()
|
||||
|
||||
@classmethod
|
||||
def check_if_supports_dtype(cls, torch_dtype: torch.dtype):
|
||||
if torch_dtype == torch.bfloat16: # noqa: SIM102
|
||||
device_name = cls.get_device_name().lower()
|
||||
# client gpu a770
|
||||
if device_name.count("a770") > 0:
|
||||
raise ValueError(
|
||||
"Intel Arc A770 have bfloat16 accuracy known issue. "
|
||||
"You can use float16 instead by explicitly setting the "
|
||||
"`dtype` flag in CLI, for example: --dtype=half.")
|
||||
|
@ -167,7 +167,7 @@ class Worker(WorkerBase):
|
||||
self.device = torch.device(f"cuda:{self.local_rank}")
|
||||
current_platform.set_device(self.device)
|
||||
|
||||
_check_if_gpu_supports_dtype(self.model_config.dtype)
|
||||
current_platform.check_if_supports_dtype(self.model_config.dtype)
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
@ -612,23 +612,3 @@ def init_worker_distributed_environment(
|
||||
parallel_config.pipeline_parallel_size)
|
||||
|
||||
ensure_kv_transfer_initialized(vllm_config)
|
||||
|
||||
|
||||
def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
|
||||
# Check if the GPU supports the dtype.
|
||||
if torch_dtype == torch.bfloat16: # noqa: SIM102
|
||||
if not current_platform.has_device_capability(80):
|
||||
capability = current_platform.get_device_capability()
|
||||
gpu_name = current_platform.get_device_name()
|
||||
|
||||
if capability is None:
|
||||
compute_str = "does not have a compute capability"
|
||||
else:
|
||||
version_str = capability.as_version_str()
|
||||
compute_str = f"has compute capability {version_str}"
|
||||
|
||||
raise ValueError(
|
||||
"Bfloat16 is only supported on GPUs with compute capability "
|
||||
f"of at least 8.0. Your {gpu_name} GPU {compute_str}. "
|
||||
"You can use float16 instead by explicitly setting the "
|
||||
"`dtype` flag in CLI, for example: --dtype=half.")
|
||||
|
@ -145,6 +145,7 @@ class XPUWorker(Worker):
|
||||
):
|
||||
self.device = torch.device(f"xpu:{self.local_rank}")
|
||||
current_platform.set_device(self.device)
|
||||
current_platform.check_if_supports_dtype(self.model_config.dtype)
|
||||
torch.xpu.empty_cache()
|
||||
self.init_gpu_memory = torch.xpu.get_device_properties(
|
||||
self.local_rank).total_memory
|
||||
|
Reference in New Issue
Block a user