mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[ROCm] fix get_device_name for rocm (#13438)
Signed-off-by: Divakar Verma <divakar.verma@amd.com>
This commit is contained in:
@@ -1,9 +1,12 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from functools import lru_cache
|
||||
import os
|
||||
from functools import lru_cache, wraps
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
from amdsmi import (amdsmi_get_gpu_asic_info, amdsmi_get_processor_handles,
|
||||
amdsmi_init, amdsmi_shut_down)
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.logger import init_logger
|
||||
@@ -53,6 +56,41 @@ _ROCM_PARTIALLY_SUPPORTED_MODELS: Dict[str, str] = {
|
||||
"by setting `VLLM_USE_TRITON_FLASH_ATTN=0`")
|
||||
}
|
||||
|
||||
# Prevent use of clashing `{CUDA/HIP}_VISIBLE_DEVICES``
if "HIP_VISIBLE_DEVICES" in os.environ:
    hip_devices = os.environ["HIP_VISIBLE_DEVICES"]
    cuda_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
    if cuda_devices:
        # Both are set: they must agree, otherwise device selection
        # would be ambiguous downstream.
        assert hip_devices == cuda_devices
    else:
        # Only the HIP variant is set (or CUDA's is empty): mirror it so
        # code that reads CUDA_VISIBLE_DEVICES sees the same selection.
        os.environ["CUDA_VISIBLE_DEVICES"] = hip_devices
|
||||
|
||||
# AMDSMI utils
# Note that amdsmi is not affected by `{CUDA/HIP}_VISIBLE_DEVICES`,
# all the related functions work on real physical device ids.
# the major benefit of using AMDSMI is that it will not initialize CUDA
||||
def with_amdsmi_context(fn):
    """Run *fn* inside an amdsmi session.

    Each call is bracketed by ``amdsmi_init()`` / ``amdsmi_shut_down()``;
    the shutdown runs in a ``finally`` clause, so the library is released
    even when *fn* raises.
    """

    @wraps(fn)
    def _amdsmi_scoped(*args, **kwargs):
        amdsmi_init()
        try:
            return fn(*args, **kwargs)
        finally:
            amdsmi_shut_down()

    return _amdsmi_scoped
|
||||
|
||||
|
||||
def device_id_to_physical_device_id(device_id: int) -> int:
    """Translate a logical device index into a physical device id.

    When ``CUDA_VISIBLE_DEVICES`` is set, the logical index selects an
    entry from that comma-separated list; otherwise logical and physical
    ids coincide and *device_id* is returned unchanged.
    """
    visible = os.environ.get("CUDA_VISIBLE_DEVICES")
    if visible is None:
        return device_id
    # NOTE: an empty or too-short list raises (ValueError/IndexError),
    # matching the original lookup-then-int behavior.
    return int(visible.split(",")[device_id])
|
||||
|
||||
|
||||
class RocmPlatform(Platform):
|
||||
_enum = PlatformEnum.ROCM
|
||||
@@ -96,13 +134,12 @@ class RocmPlatform(Platform):
|
||||
return DeviceCapability(major=major, minor=minor)
|
||||
|
||||
@classmethod
@with_amdsmi_context
@lru_cache(maxsize=8)
def get_device_name(cls, device_id: int = 0) -> str:
    """Return the market name of the GPU at logical index *device_id*.

    The name is queried through amdsmi instead of
    ``torch.cuda.get_device_name(device_id)``: when using V1 this is
    called while overriding the engine args, and touching torch.cuda
    here would initialize the ROCm context before other processes can
    be created.  ``lru_cache`` keeps repeated lookups for the same
    device cheap; ``with_amdsmi_context`` brackets the (cache-missing)
    call with amdsmi init/shutdown.
    """
    # The old stub `return "AMD"` made the amdsmi lookup below
    # unreachable; the stub is removed so the real market name is
    # returned.
    physical_device_id = device_id_to_physical_device_id(device_id)
    handle = amdsmi_get_processor_handles()[physical_device_id]
    return amdsmi_get_gpu_asic_info(handle)["market_name"]
|
||||
|
||||
@classmethod
|
||||
def get_device_total_memory(cls, device_id: int = 0) -> int:
|
||||
|
Reference in New Issue
Block a user