Respect ROCR_VISIBLE_DEVICES on AMD GPU device discovery (#142292)

Reland of #140320 after failing test on trunk. Fixes potential environment clobbering in test, makes ROCr+HIP devices (if specified together) more robust to index errors.

Fixes #140318

Pull Request resolved: https://github.com/pytorch/pytorch/pull/142292
Approved by: https://github.com/jataylo, https://github.com/huydhn, https://github.com/jeffdaily

Co-authored-by: Jack Taylor <108682042+jataylo@users.noreply.github.com>
Co-authored-by: Jeff Daily <jeff.daily@amd.com>
This commit is contained in:
Tal Ben-Nun
2024-12-25 02:37:09 +00:00
committed by PyTorch MergeBot
parent 7013be0094
commit c0d710634f
2 changed files with 21 additions and 1 deletions

View File

@ -3317,6 +3317,8 @@ print(f"{torch.cuda.device_count()}")
{"CUDA_VISIBLE_DEVICES": "0", "HIP_VISIBLE_DEVICES": None},
{"CUDA_VISIBLE_DEVICES": None, "HIP_VISIBLE_DEVICES": "0"},
{"CUDA_VISIBLE_DEVICES": "0,1,2,3", "HIP_VISIBLE_DEVICES": "0"},
{"ROCR_VISIBLE_DEVICES": "1,2,3", "HIP_VISIBLE_DEVICES": "0"},
{"ROCR_VISIBLE_DEVICES": "0", "HIP_VISIBLE_DEVICES": None},
]
for env_config in custom_envs:

View File

@ -648,7 +648,25 @@ def _parse_visible_devices() -> Union[List[int], List[str]]:
if torch.version.hip:
hip_devices = os.getenv("HIP_VISIBLE_DEVICES")
if hip_devices is not None:
rocr_devices = os.getenv("ROCR_VISIBLE_DEVICES")
# You must take care if both HIP and ROCR env vars are set as they have
# different meanings. Both env vars accept either a list of ints or a
# list of UUIDs. The ROCR env var is processed first which then reduces
# the number of GPUs that HIP can select from.
if rocr_devices is not None:
rocr_count = len(rocr_devices.split(","))
if hip_devices is not None:
# sanity check if both env vars are set
if len(hip_devices.split(",")) > rocr_count:
raise RuntimeError(
"HIP_VISIBLE_DEVICES contains more devices than ROCR_VISIBLE_DEVICES"
)
# HIP_VISIBLE_DEVICES is preferred over ROCR_VISIBLE_DEVICES
var = hip_devices
else:
return list(range(rocr_count))
elif hip_devices is not None:
var = hip_devices
if var is None: