Allow torch.cuda.memory.mem_get_info to take a device str argument with an unspecified device index. (#132616)

`torch.cuda.memory.mem_get_info` already accepts device strings under its current type hints. However, a string without an index such as `'cuda'` becomes `device = torch.device('cuda')` with `device.index = None`, which causes a downstream error in `_get_device_index`. Passing `optional=True` makes `_get_device_index` fall back to the default device index in such cases.
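
A minimal sketch of the call forms this change makes behave uniformly (assuming a CUDA-capable build where `cuda:0` is the default device; these mirror the calls exercised by the test below):

```python
import torch

# Each of these now resolves to a concrete device index internally.
free, total = torch.cuda.mem_get_info()                      # None -> current device
free, total = torch.cuda.mem_get_info(0)                     # plain int index
free, total = torch.cuda.mem_get_info("cuda:0")              # str with index
free, total = torch.cuda.mem_get_info("cuda")                # str without index -> default index
free, total = torch.cuda.mem_get_info(torch.device("cuda"))  # device.index is None -> default index
print(f"{free} bytes free of {total} bytes total")
```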

Fixes #132583

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132616
Approved by: https://github.com/soulitzer
Author: Tobias Ringwald
Date: 2024-08-06 13:19:46 +00:00
Committed by: PyTorch MergeBot
Parent: 7100c36c8a
Commit: 6753ee127c
2 changed files with 19 additions and 8 deletions

test/test_cuda_multigpu.py

@@ -11,7 +11,7 @@ import tempfile
 import threading
 import unittest
 from itertools import chain, repeat
-from typing import NamedTuple
+from typing import NamedTuple, Union
 
 import torch
 import torch.cuda.comm as comm
@@ -1005,22 +1005,32 @@ class TestCudaMultiGPU(TestCase):
     # Verifies that mem_get_info works, including when called for a different device
     def test_mem_get_info(self):
-        def _test(idx):
-            before_free_bytes, before_available_bytes = torch.cuda.mem_get_info(idx)
+        def _test(device: Union[str, int, torch.device]):
+            # Prevent PyTorch from reusing the allocated memory
+            torch.cuda.empty_cache()
+            torch.cuda.synchronize()
+            before_free_bytes, before_available_bytes = torch.cuda.mem_get_info(device)
             # increasing to 8MB to force acquiring a new block and overcome blocksize differences across platforms
-            t = torch.randn(1024 * 1024 * 8, device="cuda:" + str(idx))
+            t = torch.randn(1024 * 1024 * 8, device=device)
             if IS_JETSON:
                 # w/o syncing, mem_get_info will run before memory allocated has actually increased.
                 # This race condition causes consistent failure
                 torch.cuda.synchronize()
-            after_free_bytes, after_available_bytes = torch.cuda.mem_get_info(idx)
+            after_free_bytes, after_available_bytes = torch.cuda.mem_get_info(device)
             self.assertLess(after_free_bytes, before_free_bytes)
             self.assertEqual(before_available_bytes, after_available_bytes)
 
+        # Test calls with different device representations
         _test(0)
+        _test(torch.device("cuda"))
+        _test(torch.device("cuda:0"))
+        _test("cuda")
+        _test("cuda:0")
         if TEST_MULTIGPU:
             _test(1)
+            _test(torch.device("cuda:1"))
+            _test("cuda:1")
 
     # Test that wrap_with_cuda_memory_check successfully detects leak
     def test_cuda_memory_leak_detection(self):

torch/cuda/memory.py

@@ -684,9 +684,9 @@ def mem_get_info(device: Union[Device, int] = None) -> Tuple[int, int]:
     r"""Return the global free and total GPU memory for a given device using cudaMemGetInfo.
 
     Args:
-        device (torch.device or int, optional): selected device. Returns
+        device (torch.device or int or str, optional): selected device. Returns
             statistic for the current device, given by :func:`~torch.cuda.current_device`,
-            if :attr:`device` is ``None`` (default).
+            if :attr:`device` is ``None`` (default) or if the device index is not specified.
 
     .. note::
         See :ref:`cuda-memory-management` for more
@@ -694,7 +694,8 @@ def mem_get_info(device: Union[Device, int] = None) -> Tuple[int, int]:
     """
     if device is None:
         device = torch.cuda.current_device()
-    device = _get_device_index(device)
+    # optional=True allows `device = torch.device('cuda')` for which device.index is None
+    device = _get_device_index(device, optional=True)
     return torch.cuda.cudart().cudaMemGetInfo(device)
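
For context, a hedged sketch of what `optional=True` changes in the private helper `_get_device_index` (imported by this module from `torch.cuda._utils`), assuming a CUDA-capable build:

```python
import torch
from torch.cuda._utils import _get_device_index

dev = torch.device("cuda")  # device type given, index left unspecified
assert dev.index is None

# Without optional=True, an index-less device is rejected with a ValueError.
# With optional=True, the helper falls back to the current device's index:
idx = _get_device_index(dev, optional=True)
print(idx)  # e.g. 0 on a single-GPU machine
```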