diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 393b8a1852..e506689dc3 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -1,9 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 -from functools import lru_cache +import os +from functools import lru_cache, wraps from typing import TYPE_CHECKING, Dict, List, Optional import torch +from amdsmi import (amdsmi_get_gpu_asic_info, amdsmi_get_processor_handles, + amdsmi_init, amdsmi_shut_down) import vllm.envs as envs from vllm.logger import init_logger @@ -53,6 +56,41 @@ _ROCM_PARTIALLY_SUPPORTED_MODELS: Dict[str, str] = { "by setting `VLLM_USE_TRITON_FLASH_ATTN=0`") } +# Prevent use of clashing `{CUDA/HIP}_VISIBLE_DEVICES`` +if "HIP_VISIBLE_DEVICES" in os.environ: + val = os.environ["HIP_VISIBLE_DEVICES"] + if cuda_val := os.environ.get("CUDA_VISIBLE_DEVICES", None): + assert val == cuda_val + else: + os.environ["CUDA_VISIBLE_DEVICES"] = val + +# AMDSMI utils +# Note that NVML is not affected by `{CUDA/HIP}_VISIBLE_DEVICES`, +# all the related functions work on real physical device ids. +# the major benefit of using AMDSMI is that it will not initialize CUDA + + +def with_amdsmi_context(fn): + + @wraps(fn) + def wrapper(*args, **kwargs): + amdsmi_init() + try: + return fn(*args, **kwargs) + finally: + amdsmi_shut_down() + + return wrapper + + +def device_id_to_physical_device_id(device_id: int) -> int: + if "CUDA_VISIBLE_DEVICES" in os.environ: + device_ids = os.environ["CUDA_VISIBLE_DEVICES"].split(",") + physical_device_id = device_ids[device_id] + return int(physical_device_id) + else: + return device_id + class RocmPlatform(Platform): _enum = PlatformEnum.ROCM @@ -96,13 +134,12 @@ class RocmPlatform(Platform): return DeviceCapability(major=major, minor=minor) @classmethod + @with_amdsmi_context @lru_cache(maxsize=8) def get_device_name(cls, device_id: int = 0) -> str: - # NOTE: When using V1 this function is called when overriding the - # engine args. Calling torch.cuda.get_device_name(device_id) here - # will result in the ROCm context being initialized before other - # processes can be created. - return "AMD" + physical_device_id = device_id_to_physical_device_id(device_id) + handle = amdsmi_get_processor_handles()[physical_device_id] + return amdsmi_get_gpu_asic_info(handle)["market_name"] @classmethod def get_device_total_memory(cls, device_id: int = 0) -> int: