[ROCm] fix get_device_name for rocm (#13438)

Signed-off-by: Divakar Verma <divakar.verma@amd.com>
This commit is contained in:
Divakar Verma
2025-02-17 22:07:12 -06:00
committed by GitHub
parent 67ef8f666a
commit 7c7adf81fc

View File

@ -1,9 +1,12 @@
# SPDX-License-Identifier: Apache-2.0
from functools import lru_cache
import os
from functools import lru_cache, wraps
from typing import TYPE_CHECKING, Dict, List, Optional
import torch
from amdsmi import (amdsmi_get_gpu_asic_info, amdsmi_get_processor_handles,
amdsmi_init, amdsmi_shut_down)
import vllm.envs as envs
from vllm.logger import init_logger
@ -53,6 +56,41 @@ _ROCM_PARTIALLY_SUPPORTED_MODELS: Dict[str, str] = {
"by setting `VLLM_USE_TRITON_FLASH_ATTN=0`")
}
# Prevent use of clashing `{CUDA/HIP}_VISIBLE_DEVICES``
if "HIP_VISIBLE_DEVICES" in os.environ:
val = os.environ["HIP_VISIBLE_DEVICES"]
if cuda_val := os.environ.get("CUDA_VISIBLE_DEVICES", None):
assert val == cuda_val
else:
os.environ["CUDA_VISIBLE_DEVICES"] = val
# AMDSMI utils
# Note that NVML is not affected by `{CUDA/HIP}_VISIBLE_DEVICES`,
# all the related functions work on real physical device ids.
# the major benefit of using AMDSMI is that it will not initialize CUDA
def with_amdsmi_context(fn):
@wraps(fn)
def wrapper(*args, **kwargs):
amdsmi_init()
try:
return fn(*args, **kwargs)
finally:
amdsmi_shut_down()
return wrapper
def device_id_to_physical_device_id(device_id: int) -> int:
if "CUDA_VISIBLE_DEVICES" in os.environ:
device_ids = os.environ["CUDA_VISIBLE_DEVICES"].split(",")
physical_device_id = device_ids[device_id]
return int(physical_device_id)
else:
return device_id
class RocmPlatform(Platform):
_enum = PlatformEnum.ROCM
@ -96,13 +134,12 @@ class RocmPlatform(Platform):
return DeviceCapability(major=major, minor=minor)
@classmethod
@with_amdsmi_context
@lru_cache(maxsize=8)
def get_device_name(cls, device_id: int = 0) -> str:
# NOTE: When using V1 this function is called when overriding the
# engine args. Calling torch.cuda.get_device_name(device_id) here
# will result in the ROCm context being initialized before other
# processes can be created.
return "AMD"
physical_device_id = device_id_to_physical_device_id(device_id)
handle = amdsmi_get_processor_handles()[physical_device_id]
return amdsmi_get_gpu_asic_info(handle)["market_name"]
@classmethod
def get_device_total_memory(cls, device_id: int = 0) -> int: