Allow torch.cuda.memory.mem_get_info to take a device str argument with an unspecified device index. (#132616)
`torch.cuda.memory.mem_get_info` accepts device strings under its current type hints, but `device = torch.device('cuda')` leaves `device.index = None`, which breaks downstream code. Setting `optional=True` substitutes the default device index in such cases.

Fixes #132583

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132616
Approved by: https://github.com/soulitzer
Committed by: PyTorch MergeBot
Parent: 7100c36c8a
Commit: 6753ee127c
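A minimal reproduction of the failure mode, assuming a CUDA-enabled build of PyTorch (the exact downstream error before this fix may vary by version):

    import torch

    # A bare "cuda" device carries no index:
    assert torch.device("cuda").index is None

    # Before this change the unresolved index could flow into cudaMemGetInfo;
    # with optional=True the current device index is substituted instead.
    free_bytes, total_bytes = torch.cuda.mem_get_info("cuda")
    print(f"free={free_bytes} total={total_bytes}")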
--- a/test/test_cuda_multigpu.py
+++ b/test/test_cuda_multigpu.py
@@ -11,7 +11,7 @@ import tempfile
 import threading
 import unittest
 from itertools import chain, repeat
-from typing import NamedTuple
+from typing import NamedTuple, Union
 
 import torch
 import torch.cuda.comm as comm
@@ -1005,22 +1005,32 @@ class TestCudaMultiGPU(TestCase):
 
     # Verifies that mem_get_info works, including when called for a different device
     def test_mem_get_info(self):
-        def _test(idx):
-            before_free_bytes, before_available_bytes = torch.cuda.mem_get_info(idx)
+        def _test(device: Union[str, int, torch.device]):
+            # Prevent PyTorch from reusing the allocated memory
+            torch.cuda.empty_cache()
+            torch.cuda.synchronize()
+            before_free_bytes, before_available_bytes = torch.cuda.mem_get_info(device)
             # increasing to 8MB to force acquiring a new block and overcome blocksize differences across platforms
-            t = torch.randn(1024 * 1024 * 8, device="cuda:" + str(idx))
+            t = torch.randn(1024 * 1024 * 8, device=device)
             if IS_JETSON:
                 # w/o syncing, mem_get_info will run before memory allocated has actually increased.
                 # This race condition causes consistent failure
                 torch.cuda.synchronize()
-            after_free_bytes, after_available_bytes = torch.cuda.mem_get_info(idx)
+            after_free_bytes, after_available_bytes = torch.cuda.mem_get_info(device)
 
             self.assertLess(after_free_bytes, before_free_bytes)
             self.assertEqual(before_available_bytes, after_available_bytes)
 
+        # Test calls with different device representations
         _test(0)
+        _test(torch.device("cuda"))
+        _test(torch.device("cuda:0"))
+        _test("cuda")
+        _test("cuda:0")
         if TEST_MULTIGPU:
             _test(1)
+            _test(torch.device("cuda:1"))
+            _test("cuda:1")
 
     # Test that wrap_with_cuda_memory_check successfully detects leak
     def test_cuda_memory_leak_detection(self):
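As a quick sanity check outside the test harness, all of the spellings exercised above should resolve to the same device on a single-GPU machine (a sketch; assumes CUDA is available):

    import torch

    if torch.cuda.is_available():
        for dev in (0, "cuda", "cuda:0", torch.device("cuda"), torch.device("cuda:0")):
            free, total = torch.cuda.mem_get_info(dev)
            print(dev, free, total)  # each spelling queries the same device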
--- a/torch/cuda/memory.py
+++ b/torch/cuda/memory.py
@@ -684,9 +684,9 @@ def mem_get_info(device: Union[Device, int] = None) -> Tuple[int, int]:
     r"""Return the global free and total GPU memory for a given device using cudaMemGetInfo.
 
     Args:
-        device (torch.device or int, optional): selected device. Returns
+        device (torch.device or int or str, optional): selected device. Returns
             statistic for the current device, given by :func:`~torch.cuda.current_device`,
-            if :attr:`device` is ``None`` (default).
+            if :attr:`device` is ``None`` (default) or if the device index is not specified.
 
     .. note::
         See :ref:`cuda-memory-management` for more
@@ -694,7 +694,8 @@ def mem_get_info(device: Union[Device, int] = None) -> Tuple[int, int]:
     """
     if device is None:
         device = torch.cuda.current_device()
-    device = _get_device_index(device)
+    # optional=True allows `device = torch.device('cuda')` for which device.index is None
+    device = _get_device_index(device, optional=True)
     return torch.cuda.cudart().cudaMemGetInfo(device)
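For intuition, here is a simplified sketch of the index resolution that `optional=True` enables; it is not the actual implementation of the private helper `torch.cuda._utils._get_device_index`, which performs additional validation:

    import torch

    def resolve_index(device) -> int:
        # Accept int, str, or torch.device, mirroring mem_get_info's inputs.
        if isinstance(device, int):
            return device
        if isinstance(device, str):
            device = torch.device(device)  # parses "cuda" and "cuda:N"
        if device.index is not None:
            return device.index
        # A bare "cuda" has index=None; optional=True falls back to the
        # current device (requires an initialized CUDA context).
        return torch.cuda.current_device()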