Take CUDA_VISIBLE_DEVICES into account for nvml calls (#94568)

Fixes #94472

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94568
Approved by: https://github.com/ngimel
This commit is contained in:
Johan Nordberg
2023-02-15 17:50:12 +00:00
committed by PyTorch MergeBot
parent ea657726d9
commit dc4f2af6f6
2 changed files with 18 additions and 4 deletions

View File

@ -642,6 +642,20 @@ def _device_count_nvml() -> int:
return -1
return len(visible_devices)
def _get_nvml_device_index(device: Optional[Union[int, Device]]) -> int:
r"""Returns the NVML index of the device, taking CUDA_VISIBLE_DEVICES into account."""
idx = _get_device_index(device, optional=True)
visible_devices = _parse_visible_devices()
if type(visible_devices[0]) is str:
uuids = _raw_device_uuid_nvml()
if uuids is None:
raise RuntimeError("Can't get device UUIDs")
visible_devices = _transform_uuid_to_ordinals(cast(List[str], visible_devices), uuids)
idx_map = {idx: real_idx for idx, real_idx in enumerate(cast(List[int], visible_devices))}
if idx not in idx_map:
raise RuntimeError(f"device {idx} is not visible (CUDA_VISIBLE_DEVICES={visible_devices})")
return idx_map[idx]
@lru_cache(maxsize=1)
def device_count() -> int:
r"""Returns the number of GPUs available."""
@ -789,7 +803,7 @@ def memory_usage(device: Optional[Union[Device, int]] = None) -> int:
pynvml.nvmlInit()
except NVMLError_DriverNotLoaded as e:
raise RuntimeError("cuda driver can't be loaded, is cuda enabled?") from e
device = _get_device_index(device, optional=True)
device = _get_nvml_device_index(device)
handle = pynvml.nvmlDeviceGetHandleByIndex(device)
return pynvml.nvmlDeviceGetUtilizationRates(handle).memory
@ -815,7 +829,7 @@ def utilization(device: Optional[Union[Device, int]] = None) -> int:
pynvml.nvmlInit()
except NVMLError_DriverNotLoaded as e:
raise RuntimeError("cuda driver can't be loaded, is cuda enabled?") from e
device = _get_device_index(device, optional=True)
device = _get_nvml_device_index(device)
handle = pynvml.nvmlDeviceGetHandleByIndex(device)
return pynvml.nvmlDeviceGetUtilizationRates(handle).gpu

View File

@ -5,7 +5,7 @@ import warnings
from typing import Any, Dict, Union, Tuple
import torch
from . import is_initialized, _get_device_index, _lazy_init
from . import is_initialized, _get_device_index, _lazy_init, _get_nvml_device_index
from ._utils import _dummy_type
from ._memory_viz import segments as _segments, memory as _memory
@ -587,7 +587,7 @@ def list_gpu_processes(device: Union[Device, int] = None) -> str:
pynvml.nvmlInit()
except NVMLError_DriverNotLoaded:
return ("cuda driver can't be loaded, is cuda enabled?")
device = _get_device_index(device, optional=True)
device = _get_nvml_device_index(device)
handle = pynvml.nvmlDeviceGetHandleByIndex(device)
procs = pynvml.nvmlDeviceGetComputeRunningProcesses(handle)
lines = []