Add torch.accelerator.device_index as accelerator's device switch context (#148864)

# Motivation
We propose adding support for the Python `with` statement on `torch.accelerator.device_index` to enable device switching. This simplifies writing device-agnostic code and benefits all accelerators. Its device-specific counterparts include [`torch.cuda.device`](00199acdb8/torch/cuda/__init__.py (L482)) and [`torch.cuda._DeviceGuard`](00199acdb8/torch/cuda/__init__.py (L469)).

**Design Philosophy**
It accepts either an `int` or `None` as input. When `None` is passed, no device switch is performed. Supporting `None` is important for compatibility, since `torch.device.index` can be `None` (for example, `torch.device("cuda")` carries no index).

With this PR, we can write device-agnostic code like this:

```python
src = 0
dst = 1
# Make src the current device
torch.accelerator.set_device_index(src)
with torch.accelerator.device_index(dst):
    # Inside the with block, dst is the current device
    assert torch.accelerator.current_device_index() == dst
# On exit, the previous device (src) is restored
assert torch.accelerator.current_device_index() == src
```
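
Since `None` is accepted, an index taken straight from a `torch.device` also works. A minimal sketch of the no-op path (the behavior matches the `device_index(None)` tests further down):

```python
d = torch.device("cuda")  # no explicit index, so d.index is None
before = torch.accelerator.current_device_index()
with torch.accelerator.device_index(d.index):
    # None means no switch: the current device stays the same
    assert torch.accelerator.current_device_index() == before
assert torch.accelerator.current_device_index() == before
```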

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148864
Approved by: https://github.com/albanD
Author: Yu, Guangye
Date: 2025-04-25 02:12:24 +00:00
Committed by: PyTorch MergeBot
Parent: f38dae76ee
Commit: 33c75cae0a

11 changed files with 167 additions and 3 deletions


```diff
@@ -76,7 +76,7 @@ c10::DeviceIndex deviceCount() {
     return static_cast<c10::DeviceIndex>(0);
   }
   c10::impl::VirtualGuardImpl impl(device_type.value());
-  return static_cast<c10::DeviceIndex>(impl.deviceCount());
+  return impl.deviceCount();
 }
 
 void setDeviceIndex(c10::DeviceIndex device_index) {
@@ -88,7 +88,7 @@ void setDeviceIndex(c10::DeviceIndex device_index) {
 c10::DeviceIndex getDeviceIndex() {
   const auto device_type = getAccelerator(true).value();
   c10::impl::VirtualGuardImpl impl(device_type);
-  return static_cast<c10::DeviceIndex>(impl.getDevice().index());
+  return impl.getDevice().index();
 }
 
 void setCurrentStream(c10::Stream stream) {
@@ -115,6 +115,21 @@ void synchronizeDevice(c10::DeviceIndex device_index) {
   // impl.synchronizeDevice can be safely called from any device
   impl.synchronizeDevice(device_index);
 }
+
+c10::DeviceIndex exchangeDevice(c10::DeviceIndex device_index) {
+  const auto device_type = getAccelerator(true).value();
+  c10::impl::VirtualGuardImpl impl(device_type);
+  return impl.exchangeDevice({device_type, device_index}).index();
+}
+
+c10::DeviceIndex maybeExchangeDevice(c10::DeviceIndex device_index) {
+  const auto device_type = getAccelerator(true).value();
+  c10::impl::VirtualGuardImpl impl(device_type);
+  // Avoid creating a new context if the context for the given device_index
+  // is not initialized.
+  impl.uncheckedSetDevice({device_type, device_index});
+  return impl.getDevice().index();
+}
 
 // NOLINTEND(bugprone-unchecked-optional-access)
 } // namespace at::accelerator
```


```diff
@@ -63,6 +63,15 @@ TORCH_API c10::Stream getCurrentStream(c10::DeviceIndex device_index);
 // on the given device index has been completed.
 TORCH_API void synchronizeDevice(c10::DeviceIndex device_index);
 
+// Set the current device index to the given device_index and return the
+// original device index that was active before the change.
+TORCH_API c10::DeviceIndex exchangeDevice(c10::DeviceIndex device_index);
+
+// Set the current device index to the given device_index. Avoid creating a new
+// context if the context for device_index is not initialized. Return the
+// original device index that was active before the change.
+TORCH_API c10::DeviceIndex maybeExchangeDevice(c10::DeviceIndex device_index);
+
 } // namespace at::accelerator
 
 namespace at {
```
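
Seen from Python, the two new primitives pair up like enter/exit. A minimal sketch using the private `torch._C` bindings added later in this PR (internal hooks, not public API), assuming a second device exists:

```python
import torch

# Switch to device 1; the previous index comes back as the return value.
prev = torch._C._accelerator_exchangeDevice(1)
try:
    assert torch.accelerator.current_device_index() == 1
finally:
    # Restore the previous index; the "maybe" variant avoids initializing
    # a fresh device context if none exists for that index yet.
    torch._C._accelerator_maybeExchangeDevice(prev)
```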


```diff
@@ -64,6 +64,7 @@ list(APPEND ATen_CUDA_TEST_SRCS
   ${CMAKE_CURRENT_SOURCE_DIR}/cuda_device_test.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/cuda_distributions_test.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/cuda_dlconvertor_test.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/cuda_exchange_device_test.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/cuda_generator_test.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/cuda_half_test.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/cuda_integer_divider_test.cu
```


```diff
@@ -0,0 +1,31 @@
+#include <gtest/gtest.h>
+
+#include <ATen/DeviceAccelerator.h>
+#include <ATen/cuda/CUDAContext.h>
+
+TEST(CudaExchangeDeviceTest, checkPrimaryContext) {
+  if (!at::cuda::is_available()) {
+    return;
+  }
+  ASSERT_FALSE(at::cuda::hasPrimaryContext(0));
+  at::cuda::MaybeExchangeDevice(0);
+  ASSERT_FALSE(at::cuda::hasPrimaryContext(0));
+  at::accelerator::maybeExchangeDevice(0);
+  ASSERT_FALSE(at::cuda::hasPrimaryContext(0));
+
+  if (at::cuda::device_count() > 1) {
+    ASSERT_FALSE(at::cuda::hasPrimaryContext(1));
+    at::cuda::ExchangeDevice(1);
+    ASSERT_TRUE(at::cuda::hasPrimaryContext(1));
+  }
+
+  ASSERT_FALSE(at::cuda::hasPrimaryContext(0));
+  at::cuda::MaybeExchangeDevice(0);
+  ASSERT_FALSE(at::cuda::hasPrimaryContext(0));
+  at::accelerator::maybeExchangeDevice(0);
+  ASSERT_FALSE(at::cuda::hasPrimaryContext(0));
+
+  at::accelerator::exchangeDevice(0);
+  ASSERT_TRUE(at::cuda::hasPrimaryContext(0));
+}
```
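
The point of this test is the context-creation contract: `maybeExchangeDevice` must never spin up a CUDA primary context on a device that has none, while `exchangeDevice` is allowed (and here expected) to.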


```diff
@@ -17,3 +17,4 @@ torch.accelerator
     set_stream
     current_stream
     synchronize
+    device_index
```


```diff
@@ -81,6 +81,24 @@ class TestAccelerator(TestCase):
         ):
             torch.accelerator.current_stream(other_device)
 
+    def test_device_context_manager(self):
+        prev_device = torch.accelerator.current_device_index()
+        with torch.accelerator.device_index(None):
+            self.assertEqual(torch.accelerator.current_device_index(), prev_device)
+        self.assertEqual(torch.accelerator.current_device_index(), prev_device)
+        with torch.accelerator.device_index(0):
+            self.assertEqual(torch.accelerator.current_device_index(), 0)
+        self.assertEqual(torch.accelerator.current_device_index(), prev_device)
+
+    @unittest.skipIf(not TEST_MULTIACCELERATOR, "only one accelerator detected")
+    def test_multi_device_context_manager(self):
+        src_device = 0
+        dst_device = 1
+        torch.accelerator.set_device_index(src_device)
+        with torch.accelerator.device_index(dst_device):
+            self.assertEqual(torch.accelerator.current_device_index(), dst_device)
+        self.assertEqual(torch.accelerator.current_device_index(), src_device)
+
     def test_stream_context_manager(self):
         prev_stream = torch.accelerator.current_stream()
         with torch.Stream() as s:
```


```diff
@@ -1094,6 +1094,24 @@ class TestCuda(TestCase):
         self.assertNotEqual(try_realloc.data_ptr(), data_ptr)
 
+    def test_device_context_manager(self):
+        prev_device = torch.cuda.current_device()
+        with torch.accelerator.device_index(None):
+            self.assertEqual(torch.cuda.current_device(), prev_device)
+        self.assertEqual(torch.cuda.current_device(), prev_device)
+        with torch.accelerator.device_index(0):
+            self.assertEqual(torch.cuda.current_device(), 0)
+        self.assertEqual(torch.cuda.current_device(), prev_device)
+
+    @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
+    def test_multi_device_context_manager(self):
+        src_device = 0
+        dst_device = 1
+        torch.cuda.set_device(src_device)
+        with torch.accelerator.device_index(dst_device):
+            self.assertEqual(torch.cuda.current_device(), dst_device)
+        self.assertEqual(torch.cuda.current_device(), src_device)
+
     def test_stream_context_manager(self):
         prev_stream = torch.cuda.current_stream()
         with torch.cuda.Stream() as stream:
```


```diff
@@ -315,6 +315,24 @@ if __name__ == "__main__":
         with self.assertRaisesRegex(RuntimeError, "The device index is out of range"):
             torch.accelerator.current_stream(torch.accelerator.device_count())
 
+    def test_device_context_manager(self):
+        prev_device = torch.xpu.current_device()
+        with torch.accelerator.device_index(None):
+            self.assertEqual(torch.xpu.current_device(), prev_device)
+        self.assertEqual(torch.xpu.current_device(), prev_device)
+        with torch.accelerator.device_index(0):
+            self.assertEqual(torch.xpu.current_device(), 0)
+        self.assertEqual(torch.xpu.current_device(), prev_device)
+
+    @unittest.skipIf(not TEST_MULTIXPU, "only one XPU detected")
+    def test_device_context_manager_with_set_device(self):
+        src_device = 0
+        dst_device = 1
+        torch.xpu.set_device(src_device)
+        with torch.accelerator.device_index(dst_device):
+            self.assertEqual(torch.xpu.current_device(), dst_device)
+        self.assertEqual(torch.xpu.current_device(), src_device)
+
     def test_stream_context_manager(self):
         prev_stream = torch.xpu.current_stream()
         with torch.xpu.Stream() as stream:
```


```diff
@@ -2289,6 +2289,8 @@ def _accelerator_getDeviceIndex() -> _int: ...
 def _accelerator_setStream(Stream) -> None: ...
 def _accelerator_getStream(device_index: _int) -> Stream: ...
 def _accelerator_synchronizeDevice(device_index: _int) -> None: ...
+def _accelerator_exchangeDevice(device_index: _int) -> _int: ...
+def _accelerator_maybeExchangeDevice(device_index: _int) -> _int: ...
 
 # Defined in torch/csrc/jit/python/python_tracer.cpp
 class TracingState:
```


```diff
@@ -2,7 +2,7 @@ r"""
 This package introduces support for the current :ref:`accelerator<accelerators>` in python.
 """
 
-from typing import Optional
+from typing import Literal, Optional
 from typing_extensions import deprecated
 
 import torch
@@ -16,6 +16,7 @@ __all__ = [
     "current_device_index",
     "current_stream",
     "device_count",
+    "device_index",
     "is_available",
     "set_device_idx",  # deprecated
     "set_device_index",
@@ -189,3 +190,41 @@ def synchronize(device: _device_t = None, /) -> None:
     """
     device_index = _get_device_index(device, True)
     torch._C._accelerator_synchronizeDevice(device_index)
+
+
+class device_index:
+    r"""Context manager to set the current device index for the current :ref:`accelerator<accelerators>`.
+
+    Temporarily changes the current device index to the specified value for the duration
+    of the context, and automatically restores the previous device index when exiting
+    the context.
+
+    Args:
+        device (Optional[int]): a given device index to temporarily set. If None,
+            no device index switching occurs.
+
+    Examples:
+
+        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
+        >>> # Set device 0 as the current device temporarily
+        >>> with torch.accelerator.device_index(0):
+        ...     # Code here runs with device 0 as the current device
+        ...     pass
+        >>> # Original device is now restored
+
+        >>> # No-op when None is passed
+        >>> with torch.accelerator.device_index(None):
+        ...     # No device switching occurs
+        ...     pass
+    """
+
+    def __init__(self, device: Optional[int], /) -> None:
+        self.idx = device
+        self.prev_idx = -1
+
+    def __enter__(self) -> None:
+        if self.idx is not None:
+            self.prev_idx = torch._C._accelerator_exchangeDevice(self.idx)
+
+    def __exit__(self, *args: object) -> Literal[False]:
+        if self.idx is not None:
+            torch._C._accelerator_maybeExchangeDevice(self.prev_idx)
+        return False
```
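
Building on the `None` behavior, the context manager composes naturally with `torch.device` for device-agnostic code. A short sketch (the helper `run_on` is ours, purely illustrative):

```python
import torch

def run_on(device: torch.device) -> None:
    # device.index may be None (e.g. torch.device("cuda") with no index);
    # in that case no switch happens and the current device is used.
    with torch.accelerator.device_index(device.index):
        ...  # work launched here targets the requested device
```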


```diff
@@ -60,6 +60,18 @@ void initModule(PyObject* module) {
       at::accelerator::synchronizeDevice(device_index);
     }
   });
+
+  m.def("_accelerator_exchangeDevice", [](c10::DeviceIndex device_index) {
+    const auto device_type = at::accelerator::getAccelerator(true).value();
+    torch::utils::maybe_initialize_device(device_type);
+    return at::accelerator::exchangeDevice(device_index);
+  });
+
+  m.def("_accelerator_maybeExchangeDevice", [](c10::DeviceIndex device_index) {
+    const auto device_type = at::accelerator::getAccelerator(true).value();
+    torch::utils::maybe_initialize_device(device_type);
+    return at::accelerator::maybeExchangeDevice(device_index);
+  });
 }
 
 } // namespace torch::accelerator
```