[ROCm] amdsmi library integration (#119182)

Adds monitoring support for ROCm using amdsmi in place of pynvml.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119182
Approved by: https://github.com/jeffdaily, https://github.com/malfet, https://github.com/xw285cornell
Jack Taylor
2024-05-09 18:21:38 +00:00
committed by PyTorch MergeBot
parent 0e419b9146
commit 85447c41e3
9 changed files with 274 additions and 129 deletions
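
As a rough sketch of the swap this change makes, the monitoring calls map roughly one-to-one between the two libraries. The snippet below uses only the amdsmi and pynvml calls that appear in the diffs that follow; it is illustrative, and dictionary keys and units may differ across ROCm and driver releases.

    import torch

    if torch.version.hip:
        import amdsmi  # Python bindings shipped with ROCm (see install_amdsmi.sh below)

        amdsmi.amdsmi_init()
        handle = amdsmi.amdsmi_get_processor_handles()[0]  # first GPU
        busy = amdsmi.amdsmi_get_gpu_activity(handle)["gfx_activity"]  # GPU busy %
        vram = amdsmi.amdsmi_get_gpu_vram_usage(handle)["vram_used"]   # VRAM currently in use
    else:
        import pynvml

        pynvml.nvmlInit()
        handle = pynvml.nvmlDeviceGetHandleByIndex(0)
        busy = pynvml.nvmlDeviceGetUtilizationRates(handle).gpu
        vram = pynvml.nvmlDeviceGetMemoryInfo(handle).used

    print(f"busy={busy} vram={vram}")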

View File

@@ -77,6 +77,9 @@ RUN rm install_rocm.sh
COPY ./common/install_rocm_magma.sh install_rocm_magma.sh
RUN bash ./install_rocm_magma.sh
RUN rm install_rocm_magma.sh
COPY ./common/install_amdsmi.sh install_amdsmi.sh
RUN bash ./install_amdsmi.sh
RUN rm install_amdsmi.sh
ENV PATH /opt/rocm/bin:$PATH
ENV PATH /opt/rocm/hcc/bin:$PATH
ENV PATH /opt/rocm/hip/bin:$PATH

View File

@@ -0,0 +1,5 @@
#!/bin/bash
set -ex
cd /opt/rocm/share/amd_smi && pip install .
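
A quick, hypothetical sanity check that the bindings installed by this script are importable (not part of the PR; it only uses amdsmi calls that appear elsewhere in this diff):

    # Verify the amdsmi Python package installed from /opt/rocm/share/amd_smi.
    import amdsmi  # type: ignore[import]

    amdsmi.amdsmi_init()
    print(len(amdsmi.amdsmi_get_processor_handles()), "GPU(s) visible to amdsmi")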

View File

@@ -59,7 +59,8 @@ install_ubuntu() {
rocm-libs \
rccl \
rocprofiler-dev \
roctracer-dev
roctracer-dev \
amd-smi-lib
if [[ $(ver $ROCM_VERSION) -ge $(ver 6.1) ]]; then
DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated rocm-llvm-dev
@@ -143,7 +144,8 @@ install_centos() {
rocm-libs \
rccl \
rocprofiler-dev \
roctracer-dev
roctracer-dev \
amd-smi-lib
# precompiled miopen kernels; search for all unversioned packages
# if search fails it will abort this script; use true to avoid case where search fails

View File

@@ -78,6 +78,11 @@ ENV MAGMA_HOME /opt/rocm/magma
ENV LANG C.UTF-8
ENV LC_ALL C.UTF-8
# Install amdsmi
COPY ./common/install_amdsmi.sh install_amdsmi.sh
RUN bash ./install_amdsmi.sh
RUN rm install_amdsmi.sh
# (optional) Install non-default CMake version
ARG CMAKE_VERSION
COPY ./common/install_cmake.sh install_cmake.sh

View File

@@ -4217,7 +4217,10 @@ class TestCudaMallocAsync(TestCase):
@unittest.skipIf(TEST_PYNVML, "pynvml is not available")
def test_nvml_get_handler(self):
self.assertTrue(torch.cuda._get_pynvml_handler() is not None)
if not torch.version.hip:
self.assertTrue(torch.cuda._get_pynvml_handler() is not None)
else:
self.assertTrue(torch.cuda._get_amdsmi_handler() is not None)
@unittest.skipIf(TEST_PYNVML, "pynvml is not available")
def test_temperature(self):

View File

@@ -2,28 +2,10 @@
import datetime
import json
import signal
import sys
import time
from typing import Any, Dict, List
import psutil # type: ignore[import]
import pynvml # type: ignore[import]
# ROCm does not currently have the rocm_smi module installed to a pythonic location.
# Must import from ROCm installation path.
# Cannot use the high-level rocm_smi cmdline module due to its use of exit().
# Must use the lower-level ctypes wrappers exposed through rsmiBindings.
sys.path.append("/opt/rocm/libexec/rocm_smi")
try:
from ctypes import byref, c_uint32, c_uint64
from rsmiBindings import ( # type: ignore[import]
rocmsmi,
rsmi_process_info_t,
rsmi_status_t,
)
except ImportError as e:
pass
def get_processes_running_python_tests() -> List[Any]:
@@ -76,78 +58,42 @@ def get_per_process_gpu_info(handle: Any) -> List[Dict[str, Any]]:
return per_process_info
def rocm_ret_ok(ret: int) -> Any:
return ret == rsmi_status_t.RSMI_STATUS_SUCCESS
def rocm_list_devices() -> List[int]:
num = c_uint32(0)
ret = rocmsmi.rsmi_num_monitor_devices(byref(num))
if rocm_ret_ok(ret):
return list(range(num.value))
return []
def rocm_get_mem_use(device: int) -> float:
memoryUse = c_uint64()
memoryTot = c_uint64()
ret = rocmsmi.rsmi_dev_memory_usage_get(device, 0, byref(memoryUse))
if rocm_ret_ok(ret):
ret = rocmsmi.rsmi_dev_memory_total_get(device, 0, byref(memoryTot))
if rocm_ret_ok(ret):
return float(memoryUse.value) / float(memoryTot.value)
return 0.0
def rocm_get_gpu_use(device: int) -> float:
percent = c_uint32()
ret = rocmsmi.rsmi_dev_busy_percent_get(device, byref(percent))
if rocm_ret_ok(ret):
return float(percent.value)
return 0.0
def rocm_get_pid_list() -> List[Any]:
num_items = c_uint32()
ret = rocmsmi.rsmi_compute_process_info_get(None, byref(num_items))
if rocm_ret_ok(ret):
buff_sz = num_items.value + 10
procs = (rsmi_process_info_t * buff_sz)()
procList = []
ret = rocmsmi.rsmi_compute_process_info_get(byref(procs), byref(num_items))
for i in range(num_items.value):
procList.append(procs[i].process_id)
return procList
return []
def rocm_get_per_process_gpu_info() -> List[Dict[str, Any]]:
def rocm_get_per_process_gpu_info(handle: Any) -> List[Dict[str, Any]]:
processes = amdsmi.amdsmi_get_gpu_process_list(handle)
per_process_info = []
for pid in rocm_get_pid_list():
proc = rsmi_process_info_t()
ret = rocmsmi.rsmi_compute_process_info_by_pid_get(int(pid), byref(proc))
if rocm_ret_ok(ret):
info = {"pid": pid, "gpu_memory": proc.vram_usage}
per_process_info.append(info)
for p in processes:
proc_info = amdsmi.amdsmi_get_gpu_process_info(handle, p)
info = {
"pid": proc_info["pid"],
"gpu_memory": proc_info["memory_usage"]["vram_mem"],
}
per_process_info.append(info)
return per_process_info
if __name__ == "__main__":
handle = None
try:
pynvml.nvmlInit()
handle = pynvml.nvmlDeviceGetHandleByIndex(0)
except pynvml.NVMLError:
import pynvml # type: ignore[import]
try:
pynvml.nvmlInit()
handle = pynvml.nvmlDeviceGetHandleByIndex(0)
except pynvml.NVMLError:
pass
except ModuleNotFoundError:
# no pynvml available, probably because not cuda
pass
rsmi_handles = []
try:
ret = rocmsmi.rsmi_init(0)
rsmi_handles = rocm_list_devices()
except Exception:
# no rocmsmi available, probably because not rocm
import amdsmi # type: ignore[import]
try:
amdsmi.amdsmi_init()
amdsmi_handle = amdsmi.amdsmi_get_processor_handles()[0]
except amdsmi.AmdSmiException:
pass
except ModuleNotFoundError:
# no amdsmi is available
pass
kill_now = False
@@ -171,17 +117,16 @@ if __name__ == "__main__":
gpu_utilization = pynvml.nvmlDeviceGetUtilizationRates(handle)
stats["total_gpu_utilization"] = gpu_utilization.gpu
stats["total_gpu_mem_utilization"] = gpu_utilization.memory
if rsmi_handles:
stats["per_process_gpu_info"] = rocm_get_per_process_gpu_info()
# There are 1 to 4 GPUs in use; these values may sum > 1.0.
gpu_utilization = 0.0
gpu_memory = 0.0
for dev in rsmi_handles:
gpu_utilization += rocm_get_gpu_use(dev)
gpu_memory += rocm_get_mem_use(dev)
stats["total_gpu_utilization"] = gpu_utilization
stats["total_gpu_mem_utilization"] = gpu_memory
if amdsmi_handle is not None:
stats["per_process_gpu_info"] = rocm_get_per_process_gpu_info(
amdsmi_handle
)
stats["total_gpu_utilization"] = amdsmi.amdsmi_get_gpu_activity(
amdsmi_handle
)["gfx_activity"]
stats["total_gpu_mem_utilization"] = amdsmi.amdsmi_get_gpu_activity(
amdsmi_handle
)["umc_activity"]
except Exception as e:
stats = {
"time": datetime.datetime.utcnow().isoformat("T") + "Z",

View File

@@ -2430,7 +2430,10 @@ torch_non_c_binding_in_graph_functions = dict.fromkeys(
"torch.cpu.synchronize",
"torch.cuda._check_capability",
"torch.cuda._check_cubins",
"torch.cuda._device_count_amdsmi",
"torch.cuda._device_count_nvml",
"torch.cuda._get_amdsmi_handler",
"torch.cuda._get_amdsmi_device_index",
"torch.cuda._get_device",
"torch.cuda._get_generator",
"torch.cuda._get_nvml_device_index",
@@ -2461,7 +2464,9 @@ torch_non_c_binding_in_graph_functions = dict.fromkeys(
"torch.cuda._memory_viz.trace",
"torch.cuda._nvml_based_avail",
"torch.cuda._parse_visible_devices",
"torch.cuda._raw_device_count_amdsmi",
"torch.cuda._raw_device_count_nvml",
"torch.cuda._raw_device_uuid_amdsmi",
"torch.cuda._raw_device_uuid_nvml",
"torch.cuda._register_triton_kernels",
"torch.cuda._set_rng_state_offset",

View File

@@ -53,9 +53,18 @@ _device_t = Union[_device, str, int, None]
_HAS_PYNVML = False
_PYNVML_ERR = None
try:
import pynvml # type: ignore[import]
try:
import pynvml # type: ignore[import]
_HAS_PYNVML = True
_HAS_PYNVML = True
except ModuleNotFoundError:
pass
try:
import amdsmi # type: ignore[import]
_HAS_PYNVML = True
except ModuleNotFoundError:
pass
except ImportError as err:
_PYNVML_ERR = err # sometimes a lib is installed but the import fails for some other reason, so we log the error for later
@@ -563,7 +572,9 @@ def set_stream(stream: Stream):
def _parse_visible_devices() -> Union[List[int], List[str]]:
r"""Parse CUDA_VISIBLE_DEVICES environment variable."""
var = os.getenv("CUDA_VISIBLE_DEVICES")
var = os.getenv(
"CUDA_VISIBLE_DEVICES" if not torch.version.hip else "HIP_VISIBLE_DEVICES"
)
if var is None:
return list(range(64))
@@ -609,6 +620,16 @@ def _parse_visible_devices() -> Union[List[int], List[str]]:
return rc
def _raw_device_count_amdsmi() -> int:
try:
amdsmi.amdsmi_init()
except amdsmi.AmdSmiException as e:
warnings.warn(f"Can't initialize amdsmi - Error code: {e.err_code}")
return -1
socket_handles = amdsmi.amdsmi_get_processor_handles()
return len(socket_handles)
def _raw_device_count_nvml() -> int:
r"""Return number of devices as reported by NVML or negative value if NVML discovery/initialization failed."""
from ctypes import byref, c_int, CDLL
@@ -627,6 +648,36 @@ def _raw_device_count_nvml() -> int:
return dev_count.value
def _raw_device_uuid_amdsmi() -> Optional[List[str]]:
from ctypes import byref, c_int, c_void_p, CDLL, create_string_buffer
try:
amdsmi.amdsmi_init()
except amdsmi.AmdSmiException:
warnings.warn("Can't initialize amdsmi")
return None
try:
socket_handles = amdsmi.amdsmi_get_processor_handles()
dev_count = len(socket_handles)
except amdsmi.AmdSmiException:
warnings.warn("Can't get amdsmi device count")
return None
uuids: List[str] = []
for idx in range(dev_count):
try:
handler = amdsmi.amdsmi_get_processor_handles()[idx]
except amdsmi.AmdSmiException:
warnings.warn("Cannot get amd device handler")
return None
try:
uuid = amdsmi.amdsmi_get_gpu_device_uuid(handler)
except amdsmi.AmdSmiException:
warnings.warn("Cannot get uuid for amd device")
return None
uuids.append(str(uuid))
return uuids
def _raw_device_uuid_nvml() -> Optional[List[str]]:
r"""Return list of device UUID as reported by NVML or None if NVM discovery/initialization failed."""
from ctypes import byref, c_int, c_void_p, CDLL, create_string_buffer
@@ -686,6 +737,28 @@ def _transform_uuid_to_ordinals(candidates: List[str], uuids: List[str]) -> List
return rc
def _device_count_amdsmi() -> int:
visible_devices = _parse_visible_devices()
if not visible_devices:
return 0
try:
if type(visible_devices[0]) is str:
return -1
else:
raw_cnt = _raw_device_count_amdsmi()
if raw_cnt <= 0:
return raw_cnt
# Trim the list up to a maximum available device
for idx, val in enumerate(visible_devices):
if cast(int, val) >= raw_cnt:
return idx
except OSError:
return -1
except AttributeError:
return -1
return len(visible_devices)
def _device_count_nvml() -> int:
r"""Return number of devices as reported by NVML taking CUDA_VISIBLE_DEVICES into account.
@@ -750,7 +823,7 @@ def device_count() -> int:
if _cached_device_count is not None:
return _cached_device_count
# bypass _device_count_nvml() if rocm (not supported)
nvml_count = -1 if torch.version.hip else _device_count_nvml()
nvml_count = _device_count_amdsmi() if torch.version.hip else _device_count_nvml()
r = torch._C._cuda_getDeviceCount() if nvml_count < 0 else nvml_count
# NB: Do not cache the device count prior to CUDA initialization, because
# the number of devices can change due to changes to CUDA_VISIBLE_DEVICES
@@ -908,6 +981,68 @@ def _get_pynvml_handler(device: Optional[Union[Device, int]] = None):
return handle
def _get_amdsmi_handler(device: Optional[Union[Device, int]] = None):
if not _HAS_PYNVML:
raise ModuleNotFoundError(
"amdsmi does not seem to be installed or it can't be imported."
) from _PYNVML_ERR
try:
amdsmi.amdsmi_init()
except amdsmi.AmdSmiException as e:
raise RuntimeError(
"amdsmi driver can't be loaded, requires >=ROCm5.6 installation"
) from e
device = _get_amdsmi_device_index(device)
handle = amdsmi.amdsmi_get_processor_handles()[device]
return handle
def _get_amdsmi_device_index(device: Optional[Union[int, Device]]) -> int:
r"""Return the amdsmi index of the device, taking HIP_VISIBLE_DEVICES into account."""
idx = _get_device_index(device, optional=True)
visible_devices = _parse_visible_devices()
if type(visible_devices[0]) is str:
raise RuntimeError("HIP_VISIBLE_DEVICES should be indices and not strings")
idx_map = dict(enumerate(cast(List[int], visible_devices)))
if idx not in idx_map:
raise RuntimeError(
f"device {idx} is not visible (HIP_VISIBLE_DEVICES={visible_devices})"
)
return idx_map[idx]
def _get_amdsmi_memory_usage(device: Optional[Union[Device, int]] = None) -> int:
handle = _get_amdsmi_handler()
device = _get_amdsmi_device_index(device)
return amdsmi.amdsmi_get_gpu_vram_usage(handle)["vram_used"]
def _get_amdsmi_utilization(device: Optional[Union[Device, int]] = None) -> int:
handle = _get_amdsmi_handler()
device = _get_amdsmi_device_index(device)
handle = amdsmi.amdsmi_get_processor_handles()[device]
return amdsmi.amdsmi_get_gpu_activity(handle)["gfx_activity"]
def _get_amdsmi_temperature(device: Optional[Union[Device, int]] = None) -> int:
handle = _get_amdsmi_handler(device)
return amdsmi.amdsmi_get_temp_metric(
handle,
amdsmi.AmdSmiTemperatureType.JUNCTION,
amdsmi.AmdSmiTemperatureMetric.CURRENT,
)
def _get_amdsmi_power_draw(device: Optional[Union[Device, int]] = None) -> int:
handle = _get_amdsmi_handler(device)
return amdsmi.amdsmi_get_power_info(handle)["average_socket_power"]
def _get_amdsmi_clock_rate(device: Optional[Union[Device, int]] = None) -> int:
handle = _get_amdsmi_handler(device)
return amdsmi.amdsmi_get_clock_info(handle, amdsmi.AmdSmiClkType.GFX)["cur_clk"]
def memory_usage(device: Optional[Union[Device, int]] = None) -> int:
r"""Return the percent of time over the past sample period during which global (device)
memory was being read or written as given by `nvidia-smi`.
@@ -920,11 +1055,13 @@ def memory_usage(device: Optional[Union[Device, int]] = None) -> int:
Warning: Each sample period may be between 1 second and 1/6 second,
depending on the product being queried.
"""
handle = _get_pynvml_handler()
device = _get_nvml_device_index(device)
handle = pynvml.nvmlDeviceGetHandleByIndex(device)
return pynvml.nvmlDeviceGetUtilizationRates(handle).memory
if not torch.version.hip:
handle = _get_pynvml_handler()
device = _get_nvml_device_index(device)
handle = pynvml.nvmlDeviceGetHandleByIndex(device)
return pynvml.nvmlDeviceGetUtilizationRates(handle).memory
else:
return _get_amdsmi_memory_usage(device)
def utilization(device: Optional[Union[Device, int]] = None) -> int:
@@ -939,10 +1076,13 @@ def utilization(device: Optional[Union[Device, int]] = None) -> int:
Warning: Each sample period may be between 1 second and 1/6 second,
depending on the product being queried.
"""
handle = _get_pynvml_handler(device)
device = _get_nvml_device_index(device)
handle = pynvml.nvmlDeviceGetHandleByIndex(device)
return pynvml.nvmlDeviceGetUtilizationRates(handle).gpu
if not torch.version.hip:
handle = _get_pynvml_handler(device)
device = _get_nvml_device_index(device)
handle = pynvml.nvmlDeviceGetHandleByIndex(device)
return pynvml.nvmlDeviceGetUtilizationRates(handle).gpu
else:
return _get_amdsmi_utilization(device)
def temperature(device: Optional[Union[Device, int]] = None) -> int:
@@ -958,9 +1098,12 @@ def temperature(device: Optional[Union[Device, int]] = None) -> int:
Warning: Each sample period may be between 1 second and 1/6 second,
depending on the product being queried.
"""
handle = _get_pynvml_handler(device)
# 0 refers to the temperature sensor for the GPU die.
return pynvml.nvmlDeviceGetTemperature(handle, 0)
if not torch.version.hip:
handle = _get_pynvml_handler(device)
# 0 refers to the temperature sensor for the GPU die.
return pynvml.nvmlDeviceGetTemperature(handle, 0)
else:
return _get_amdsmi_temperature(device)
def power_draw(device: Optional[Union[Device, int]] = None) -> int:
@@ -975,8 +1118,11 @@ def power_draw(device: Optional[Union[Device, int]] = None) -> int:
Warning: Each sample period may be between 1 second and 1/6 second,
depending on the product being queried.
"""
handle = _get_pynvml_handler(device)
return pynvml.nvmlDeviceGetPowerUsage(handle)
if not torch.version.hip:
handle = _get_pynvml_handler(device)
return pynvml.nvmlDeviceGetPowerUsage(handle)
else:
return _get_amdsmi_power_draw(device)
def clock_rate(device: Optional[Union[Device, int]] = None) -> int:
@@ -990,8 +1136,11 @@ def clock_rate(device: Optional[Union[Device, int]] = None) -> int:
Warning: Each sample period may be between 1 second and 1/6 second,
depending on the product being queried.
"""
handle = _get_pynvml_handler(device)
return pynvml.nvmlDeviceGetClockInfo(handle, 1)
if not torch.version.hip:
handle = _get_pynvml_handler(device)
return pynvml.nvmlDeviceGetClockInfo(handle, 1)
else:
return _get_amdsmi_clock_rate(device)
def _get_device(device: Union[int, str, torch.device]) -> torch.device:
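
With the branches above in place, the public torch.cuda monitoring helpers can be called the same way on ROCm and CUDA builds. A minimal usage sketch follows; note that some quantities (e.g. power) are reported in backend-specific units, so compare readings only within one backend.

    import torch

    # On ROCm builds (torch.version.hip) these now route through amdsmi;
    # on CUDA builds they continue to use pynvml, as implemented above.
    dev = 0
    print("device count:", torch.cuda.device_count())
    print("utilization %:", torch.cuda.utilization(dev))
    print("memory usage:", torch.cuda.memory_usage(dev))
    print("temperature:", torch.cuda.temperature(dev))
    print("power draw:", torch.cuda.power_draw(dev))
    print("clock rate:", torch.cuda.clock_rate(dev))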

View File

@@ -15,7 +15,13 @@ from torch import _C
from torch.types import Device
from .._utils import _dummy_type
from . import _get_device_index, _get_nvml_device_index, _lazy_init, is_initialized
from . import (
_get_amdsmi_device_index,
_get_device_index,
_get_nvml_device_index,
_lazy_init,
is_initialized,
)
from ._memory_viz import memory as _memory, segments as _segments
@@ -609,26 +615,48 @@ def list_gpu_processes(device: Union[Device, int] = None) -> str:
printout for the current device, given by :func:`~torch.cuda.current_device`,
if :attr:`device` is ``None`` (default).
"""
try:
import pynvml # type: ignore[import]
except ModuleNotFoundError:
return "pynvml module not found, please install pynvml"
from pynvml import NVMLError_DriverNotLoaded
if not torch.version.hip:
try:
import pynvml # type: ignore[import]
except ModuleNotFoundError:
return "pynvml module not found, please install pynvml"
from pynvml import NVMLError_DriverNotLoaded
try:
pynvml.nvmlInit()
except NVMLError_DriverNotLoaded:
return "cuda driver can't be loaded, is cuda enabled?"
device = _get_nvml_device_index(device)
handle = pynvml.nvmlDeviceGetHandleByIndex(device)
procs = pynvml.nvmlDeviceGetComputeRunningProcesses(handle)
else:
try:
import amdsmi # type: ignore[import]
except ModuleNotFoundError:
return "amdsmi module not found, please install amdsmi"
try:
amdsmi.amdsmi_init() # type: ignore[attr-defined]
except amdsmi.AmdSmiException: # type: ignore[attr-defined]
return "amdsmi driver can't be loaded, is ROCm installed?"
device = _get_amdsmi_device_index(device)
handle = amdsmi.amdsmi_get_processor_handles()[device] # type: ignore[attr-defined]
procs = amdsmi.amdsmi_get_gpu_process_list(handle) # type: ignore[attr-defined]
try:
pynvml.nvmlInit()
except NVMLError_DriverNotLoaded:
return "cuda driver can't be loaded, is cuda enabled?"
device = _get_nvml_device_index(device)
handle = pynvml.nvmlDeviceGetHandleByIndex(device)
procs = pynvml.nvmlDeviceGetComputeRunningProcesses(handle)
lines = []
lines.append(f"GPU:{device}")
if len(procs) == 0:
lines.append("no processes are running")
for p in procs:
mem = p.usedGpuMemory / (1024 * 1024)
lines.append(f"process {p.pid:>10d} uses {mem:>12.3f} MB GPU memory")
if not torch.version.hip:
mem = p.usedGpuMemory / (1024 * 1024)
pid = p.pid
else:
proc_info = amdsmi.amdsmi_get_gpu_process_info(handle, p) # type: ignore[possibly-undefined]
mem = proc_info["memory_usage"]["vram_mem"] / (1024 * 1024)
pid = proc_info["pid"]
lines.append(f"process {pid:>10d} uses {mem:>12.3f} MB GPU memory")
return "\n".join(lines)