[BE][accelerator] formalize API name {current,set}_device_{idx => index} (#140542)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140542
Approved by: https://github.com/guangyey, https://github.com/albanD
This commit is contained in:
Xuehai Pan
2024-12-11 22:37:22 +08:00
committed by PyTorch MergeBot
parent 82aaf64422
commit fb02b40d27
4 changed files with 46 additions and 31 deletions

View File

@ -10,7 +10,9 @@ torch.accelerator
device_count
is_available
current_accelerator
set_device_index
set_device_idx
current_device_index
current_device_idx
set_stream
current_stream

View File

@ -27,23 +27,23 @@ class TestAccelerator(TestCase):
with self.assertRaisesRegex(
ValueError, "doesn't match the current accelerator"
):
torch.accelerator.set_device_idx("cpu")
torch.accelerator.set_device_index("cpu")
@unittest.skipIf(not TEST_MULTIACCELERATOR, "only one accelerator detected")
def test_generic_multi_device_behavior(self):
orig_device = torch.accelerator.current_device_idx()
orig_device = torch.accelerator.current_device_index()
target_device = (orig_device + 1) % torch.accelerator.device_count()
torch.accelerator.set_device_idx(target_device)
self.assertEqual(target_device, torch.accelerator.current_device_idx())
torch.accelerator.set_device_idx(orig_device)
self.assertEqual(orig_device, torch.accelerator.current_device_idx())
torch.accelerator.set_device_index(target_device)
self.assertEqual(target_device, torch.accelerator.current_device_index())
torch.accelerator.set_device_index(orig_device)
self.assertEqual(orig_device, torch.accelerator.current_device_index())
s1 = torch.Stream(target_device)
torch.accelerator.set_stream(s1)
self.assertEqual(target_device, torch.accelerator.current_device_idx())
self.assertEqual(target_device, torch.accelerator.current_device_index())
torch.accelerator.synchronize(orig_device)
self.assertEqual(target_device, torch.accelerator.current_device_idx())
self.assertEqual(target_device, torch.accelerator.current_device_index())
def test_generic_stream_behavior(self):
s1 = torch.Stream()

View File

@ -2,11 +2,27 @@ r"""
This package introduces support for the current :ref:`accelerator<accelerators>` in python.
"""
from typing_extensions import deprecated
import torch
from ._utils import _device_t, _get_device_index
__all__ = [
"current_accelerator",
"current_device_idx", # deprecated
"current_device_index",
"current_stream",
"device_count",
"is_available",
"set_device_idx", # deprecated
"set_device_index",
"set_stream",
"synchronize",
]
def device_count() -> int:
r"""Return the number of current :ref:`accelerator<accelerators>` available.
@ -37,7 +53,7 @@ def current_accelerator() -> torch.device:
torch.device: return the current accelerator as :class:`torch.device`.
.. note:: The index of the returned :class:`torch.device` will be ``None``, please use
:func:`torch.accelerator.current_device_idx` to know the current index being used.
:func:`torch.accelerator.current_device_index` to know the current index being used.
And ensure to use :func:`torch.accelerator.is_available` to check if there is an available
accelerator. If there is no available accelerator, this function will raise an exception.
@ -58,7 +74,7 @@ def current_accelerator() -> torch.device:
return torch._C._accelerator_getAccelerator()
def current_device_idx() -> int:
def current_device_index() -> int:
r"""Return the index of a currently selected device for the current :ref:`accelerator<accelerators>`.
Returns:
@ -67,7 +83,13 @@ def current_device_idx() -> int:
return torch._C._accelerator_getDeviceIndex()
def set_device_idx(device: _device_t, /) -> None:
current_device_idx = deprecated(
"Use `current_device_index` instead.",
category=FutureWarning,
)(current_device_index)
def set_device_index(device: _device_t, /) -> None:
r"""Set the current device index to a given device.
Args:
@ -80,13 +102,19 @@ def set_device_idx(device: _device_t, /) -> None:
torch._C._accelerator_setDeviceIndex(device_index)
set_device_idx = deprecated(
"Use `set_device_index` instead.",
category=FutureWarning,
)(set_device_index)
def current_stream(device: _device_t = None, /) -> torch.Stream:
r"""Return the currently selected stream for a given device.
Args:
device (:class:`torch.device`, str, int, optional): a given device that must match the current
:ref:`accelerator<accelerators>` device type. If not given,
use :func:`torch.accelerator.current_device_idx` by default.
use :func:`torch.accelerator.current_device_index` by default.
Returns:
torch.Stream: the currently selected stream for a given device.
@ -112,7 +140,7 @@ def synchronize(device: _device_t = None, /) -> None:
Args:
device (:class:`torch.device`, str, int, optional): device for which to synchronize. It must match
the current :ref:`accelerator<accelerators>` device type. If not given,
use :func:`torch.accelerator.current_device_idx` by default.
use :func:`torch.accelerator.current_device_index` by default.
.. note:: This function is a no-op if the current :ref:`accelerator<accelerators>` is not initialized.
@ -131,15 +159,3 @@ def synchronize(device: _device_t = None, /) -> None:
"""
device_index = _get_device_index(device, True)
torch._C._accelerator_synchronizeDevice(device_index)
__all__ = [
"current_accelerator",
"current_device_idx",
"current_stream",
"device_count",
"is_available",
"set_device_idx",
"set_stream",
"synchronize",
]

View File

@ -1,10 +1,7 @@
from typing import Optional, Union
from typing import Optional
import torch
from torch import device as _device
_device_t = Union[_device, str, int, None]
from torch.types import Device as _device_t
def _get_device_index(device: _device_t, optional: bool = False) -> int:
@ -24,5 +21,5 @@ def _get_device_index(device: _device_t, optional: bool = False) -> int:
raise ValueError(
f"Expected a torch.device with a specified index or an integer, but got:{device}"
)
return torch.accelerator.current_device_idx()
return torch.accelerator.current_device_index()
return device_index