diff --git a/test/test_cuda_multigpu.py b/test/test_cuda_multigpu.py
index a0125700a3a4..3dfc82fbf312 100644
--- a/test/test_cuda_multigpu.py
+++ b/test/test_cuda_multigpu.py
@@ -11,7 +11,7 @@ import tempfile
 import threading
 import unittest
 from itertools import chain, repeat
-from typing import NamedTuple
+from typing import NamedTuple, Union
 
 import torch
 import torch.cuda.comm as comm
@@ -1005,22 +1005,32 @@ class TestCudaMultiGPU(TestCase):
 
     # Verifies that mem_get_info works, including when called for a different device
     def test_mem_get_info(self):
-        def _test(idx):
-            before_free_bytes, before_available_bytes = torch.cuda.mem_get_info(idx)
+        def _test(device: Union[str, int, torch.device]):
+            # Prevent PyTorch from reusing the allocated memory
+            torch.cuda.empty_cache()
+            torch.cuda.synchronize()
+            before_free_bytes, before_available_bytes = torch.cuda.mem_get_info(device)
             # increasing to 8MB to force acquiring a new block and overcome blocksize differences across platforms
-            t = torch.randn(1024 * 1024 * 8, device="cuda:" + str(idx))
+            t = torch.randn(1024 * 1024 * 8, device=device)
             if IS_JETSON:
                 # w/o syncing, mem_get_info will run before memory allocated has actually increased.
                 # This race condition causes consistent failure
                 torch.cuda.synchronize()
-            after_free_bytes, after_available_bytes = torch.cuda.mem_get_info(idx)
+            after_free_bytes, after_available_bytes = torch.cuda.mem_get_info(device)
 
             self.assertLess(after_free_bytes, before_free_bytes)
             self.assertEqual(before_available_bytes, after_available_bytes)
 
+        # Test calls with different device representations
         _test(0)
+        _test(torch.device("cuda"))
+        _test(torch.device("cuda:0"))
+        _test("cuda")
+        _test("cuda:0")
         if TEST_MULTIGPU:
             _test(1)
+            _test(torch.device("cuda:1"))
+            _test("cuda:1")
 
     # Test that wrap_with_cuda_memory_check successfully detects leak
     def test_cuda_memory_leak_detection(self):
diff --git a/torch/cuda/memory.py b/torch/cuda/memory.py
index 236a7bcfd2c2..1726cbe439dc 100644
--- a/torch/cuda/memory.py
+++ b/torch/cuda/memory.py
@@ -684,9 +684,9 @@ def mem_get_info(device: Union[Device, int] = None) -> Tuple[int, int]:
     r"""Return the global free and total GPU memory for a given device using cudaMemGetInfo.
 
     Args:
-        device (torch.device or int, optional): selected device. Returns
+        device (torch.device or int or str, optional): selected device. Returns
             statistic for the current device, given by :func:`~torch.cuda.current_device`,
-            if :attr:`device` is ``None`` (default).
+            if :attr:`device` is ``None`` (default) or if the device index is not specified.
 
     .. note::
         See :ref:`cuda-memory-management` for more
@@ -694,7 +694,8 @@ def mem_get_info(device: Union[Device, int] = None) -> Tuple[int, int]:
     """
     if device is None:
         device = torch.cuda.current_device()
-    device = _get_device_index(device)
+    # optional=True allows `device = torch.device('cuda')` for which device.index is None
+    device = _get_device_index(device, optional=True)
     return torch.cuda.cudart().cudaMemGetInfo(device)
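
With this patch, `torch.cuda.mem_get_info` accepts the same device spellings the updated test exercises: an integer index, a device string, or a `torch.device` with or without an index. A minimal usage sketch, assuming a CUDA-capable machine (the printed sizes are illustrative, not part of the change):

```python
import torch

# Each spelling below resolves to device 0. An index-less device such as
# torch.device("cuda") now works because _get_device_index(device, optional=True)
# falls back to the current device when device.index is None.
for device in (0, "cuda", "cuda:0", torch.device("cuda"), torch.device("cuda:0")):
    free_bytes, total_bytes = torch.cuda.mem_get_info(device)
    print(f"{str(device):>10}: {free_bytes / 2**20:.0f} MiB free / {total_bytes / 2**20:.0f} MiB total")
```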