mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Add torch.accelerator.device_index as accelerator's device switch context (#148864)
# Motivation We propose adding support for the Python `with` statement on `torch.accelerator.device_index` to enable device-switching functionality. This enhancement simplifies writing device-agnostic code and provides benefits across all accelerators. Its device-specific counterparts include [`torch.cuda.device`](00199acdb8/torch/cuda/__init__.py#L482) and [`torch.cuda._DeviceGuard`](00199acdb8/torch/cuda/__init__.py#L469). **Design Philosophy** It accepts either an `int` or `None` as input. When `None` is passed, no device switch is performed. Supporting `None` is important for compatibility, as it is possible to encounter `None` values from `torch.device.index`. Therefore, with this PR, we can write code like this: ```python src = 0 dst = 1 # Set src to current device torch.accelerator.set_device_index(src) with torch.accelerator.device_index(dst): # Inside the with statement, dst is the current device assert torch.accelerator.get_device_index() == dst # Here the current device should be src again assert torch.accelerator.get_device_index() == src ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/148864 Approved by: https://github.com/albanD
This commit is contained in:
committed by
PyTorch MergeBot
parent
f38dae76ee
commit
33c75cae0a
@ -76,7 +76,7 @@ c10::DeviceIndex deviceCount() {
|
||||
return static_cast<c10::DeviceIndex>(0);
|
||||
}
|
||||
c10::impl::VirtualGuardImpl impl(device_type.value());
|
||||
return static_cast<c10::DeviceIndex>(impl.deviceCount());
|
||||
return impl.deviceCount();
|
||||
}
|
||||
|
||||
void setDeviceIndex(c10::DeviceIndex device_index) {
|
||||
@ -88,7 +88,7 @@ void setDeviceIndex(c10::DeviceIndex device_index) {
|
||||
c10::DeviceIndex getDeviceIndex() {
|
||||
const auto device_type = getAccelerator(true).value();
|
||||
c10::impl::VirtualGuardImpl impl(device_type);
|
||||
return static_cast<c10::DeviceIndex>(impl.getDevice().index());
|
||||
return impl.getDevice().index();
|
||||
}
|
||||
|
||||
void setCurrentStream(c10::Stream stream) {
|
||||
@ -115,6 +115,21 @@ void synchronizeDevice(c10::DeviceIndex device_index) {
|
||||
// impl.synchronizeDevice should can be safely called from any device
|
||||
impl.synchronizeDevice(device_index);
|
||||
}
|
||||
|
||||
// Set the current accelerator device to device_index and return the device
// index that was active before the switch.
c10::DeviceIndex exchangeDevice(c10::DeviceIndex device_index) {
  // getAccelerator(true) checks that an accelerator is available, so .value()
  // is safe here (the surrounding file suppresses
  // bugprone-unchecked-optional-access for this reason).
  const auto device_type = getAccelerator(true).value();
  c10::impl::VirtualGuardImpl impl(device_type);
  // VirtualGuardImpl::exchangeDevice returns the previously current Device;
  // we only surface its index.
  return impl.exchangeDevice({device_type, device_index}).index();
}
|
||||
|
||||
// Set the current accelerator device to device_index without forcing a new
// device context into existence, and return the device index that was active
// before the change (matching the contract documented in the header).
c10::DeviceIndex maybeExchangeDevice(c10::DeviceIndex device_index) {
  const auto device_type = getAccelerator(true).value();
  c10::impl::VirtualGuardImpl impl(device_type);
  // BUGFIX: capture the currently active index *before* switching. The header
  // documents that this function returns the original device index, but the
  // previous implementation queried getDevice() after uncheckedSetDevice and
  // therefore returned the new index instead of the original one.
  const auto orig_index = impl.getDevice().index();
  // Avoid creating a new context if the context for the given device_index
  // is not initialized.
  impl.uncheckedSetDevice({device_type, device_index});
  return orig_index;
}
|
||||
// NOLINTEND(bugprone-unchecked-optional-access)
|
||||
|
||||
} // namespace at::accelerator
|
||||
|
@ -63,6 +63,15 @@ TORCH_API c10::Stream getCurrentStream(c10::DeviceIndex device_index);
|
||||
// on the given device index has been completed.
|
||||
TORCH_API void synchronizeDevice(c10::DeviceIndex device_index);
|
||||
|
||||
// Set the current device index to the given device_index and return the
|
||||
// original device index that was active before the change.
|
||||
TORCH_API c10::DeviceIndex exchangeDevice(c10::DeviceIndex device_index);
|
||||
|
||||
// Set the current device index to the given device_index. Avoid creating a new
|
||||
// context if the context for device_index is not initialized. Return the
|
||||
// original device index that was active before the change.
|
||||
TORCH_API c10::DeviceIndex maybeExchangeDevice(c10::DeviceIndex device_index);
|
||||
|
||||
} // namespace at::accelerator
|
||||
|
||||
namespace at {
|
||||
|
@ -64,6 +64,7 @@ list(APPEND ATen_CUDA_TEST_SRCS
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/cuda_device_test.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/cuda_distributions_test.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/cuda_dlconvertor_test.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/cuda_exchange_device_test.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/cuda_generator_test.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/cuda_half_test.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/cuda_integer_divider_test.cu
|
||||
|
31
aten/src/ATen/test/cuda_exchange_device_test.cpp
Normal file
31
aten/src/ATen/test/cuda_exchange_device_test.cpp
Normal file
@ -0,0 +1,31 @@
|
||||
#include <gtest/gtest.h>

#include <ATen/DeviceAccelerator.h>
#include <ATen/cuda/CUDAContext.h>

// Verifies that the "maybe" exchange variants switch the current device
// without initializing a CUDA primary context, while the plain exchange
// variants do create one.
TEST(CudaExchangeDeviceTest, checkPrimaryContext) {
  // Silently pass when no CUDA device is present.
  if (!at::cuda::is_available()) {
    return;
  }

  // Neither the CUDA-specific nor the accelerator-generic "maybe" variant
  // should force primary-context creation on device 0.
  ASSERT_FALSE(at::cuda::hasPrimaryContext(0));
  at::cuda::MaybeExchangeDevice(0);
  ASSERT_FALSE(at::cuda::hasPrimaryContext(0));
  at::accelerator::maybeExchangeDevice(0);
  ASSERT_FALSE(at::cuda::hasPrimaryContext(0));

  if (at::cuda::device_count() > 1) {
    // The plain CUDA exchange is expected to initialize device 1's context.
    ASSERT_FALSE(at::cuda::hasPrimaryContext(1));
    at::cuda::ExchangeDevice(1);
    ASSERT_TRUE(at::cuda::hasPrimaryContext(1));
  }

  // Device 0 must still be untouched by the "maybe" variants...
  ASSERT_FALSE(at::cuda::hasPrimaryContext(0));
  at::cuda::MaybeExchangeDevice(0);
  ASSERT_FALSE(at::cuda::hasPrimaryContext(0));
  at::accelerator::maybeExchangeDevice(0);
  ASSERT_FALSE(at::cuda::hasPrimaryContext(0));
  // ...until the accelerator-generic non-"maybe" exchange initializes it.
  at::accelerator::exchangeDevice(0);
  ASSERT_TRUE(at::cuda::hasPrimaryContext(0));
}
|
@ -17,3 +17,4 @@ torch.accelerator
|
||||
set_stream
|
||||
current_stream
|
||||
synchronize
|
||||
device_index
|
||||
|
@ -81,6 +81,24 @@ class TestAccelerator(TestCase):
|
||||
):
|
||||
torch.accelerator.current_stream(other_device)
|
||||
|
||||
def test_device_context_manager(self):
    # device_index(None) must be a no-op: the current device index is
    # unchanged both inside and after the context.
    prev_device = torch.accelerator.current_device_index()
    with torch.accelerator.device_index(None):
        self.assertEqual(torch.accelerator.current_device_index(), prev_device)
    self.assertEqual(torch.accelerator.current_device_index(), prev_device)
    # An explicit index takes effect inside the context and is rolled back
    # on exit.
    with torch.accelerator.device_index(0):
        self.assertEqual(torch.accelerator.current_device_index(), 0)
    self.assertEqual(torch.accelerator.current_device_index(), prev_device)
|
||||
|
||||
@unittest.skipIf(not TEST_MULTIACCELERATOR, "only one accelerator detected")
def test_multi_device_context_manager(self):
    src_device = 0
    dst_device = 1
    # Make src_device current, then verify the context manager switches to
    # dst_device inside the block and restores src_device afterwards.
    torch.accelerator.set_device_index(src_device)
    with torch.accelerator.device_index(dst_device):
        self.assertEqual(torch.accelerator.current_device_index(), dst_device)
    self.assertEqual(torch.accelerator.current_device_index(), src_device)
|
||||
|
||||
def test_stream_context_manager(self):
|
||||
prev_stream = torch.accelerator.current_stream()
|
||||
with torch.Stream() as s:
|
||||
|
@ -1094,6 +1094,24 @@ class TestCuda(TestCase):
|
||||
|
||||
self.assertNotEqual(try_realloc.data_ptr(), data_ptr)
|
||||
|
||||
def test_device_context_manager(self):
    # device_index(None) must be a no-op: the current CUDA device stays the
    # same inside and after the context.
    prev_device = torch.cuda.current_device()
    with torch.accelerator.device_index(None):
        self.assertEqual(torch.cuda.current_device(), prev_device)
    self.assertEqual(torch.cuda.current_device(), prev_device)
    # An explicit index applies inside the context and is restored on exit.
    with torch.accelerator.device_index(0):
        self.assertEqual(torch.cuda.current_device(), 0)
    self.assertEqual(torch.cuda.current_device(), prev_device)
|
||||
|
||||
@unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
def test_multi_device_context_manager(self):
    # Make src_device current, then verify the accelerator-generic context
    # manager switches to dst_device and restores src_device on exit.
    src_device = 0
    dst_device = 1
    torch.cuda.set_device(src_device)
    with torch.accelerator.device_index(dst_device):
        # Compare against the named constant rather than the literal 1.
        self.assertEqual(torch.cuda.current_device(), dst_device)
    # BUGFIX: the original asserted torch.cuda.set_device() == src_device;
    # set_device requires a device argument (so this raised TypeError) and
    # returns None, meaning restoration was never actually checked. Query the
    # current device instead.
    self.assertEqual(torch.cuda.current_device(), src_device)
|
||||
|
||||
def test_stream_context_manager(self):
|
||||
prev_stream = torch.cuda.current_stream()
|
||||
with torch.cuda.Stream() as stream:
|
||||
|
@ -315,6 +315,24 @@ if __name__ == "__main__":
|
||||
with self.assertRaisesRegex(RuntimeError, "The device index is out of range"):
|
||||
torch.accelerator.current_stream(torch.accelerator.device_count())
|
||||
|
||||
def test_device_context_manager(self):
    # device_index(None) must be a no-op: the current XPU device stays the
    # same inside and after the context.
    prev_device = torch.xpu.current_device()
    with torch.accelerator.device_index(None):
        self.assertEqual(torch.xpu.current_device(), prev_device)
    self.assertEqual(torch.xpu.current_device(), prev_device)
    # An explicit index applies inside the context and is restored on exit.
    with torch.accelerator.device_index(0):
        self.assertEqual(torch.xpu.current_device(), 0)
    self.assertEqual(torch.xpu.current_device(), prev_device)
|
||||
|
||||
@unittest.skipIf(not TEST_MULTIXPU, "only one GPU detected")
def test_device_context_manager_with_set_device(self):
    # Verify torch.accelerator.device_index interoperates with a prior
    # explicit torch.xpu.set_device call: dst inside the block, src after.
    src_device = 0
    dst_device = 1
    torch.xpu.set_device(src_device)
    with torch.accelerator.device_index(dst_device):
        # Compare against the named constant rather than the literal 1.
        self.assertEqual(torch.xpu.current_device(), dst_device)
    # BUGFIX: the original asserted torch.xpu.set_device() == src_device;
    # set_device requires a device argument (so this raised TypeError) and
    # returns None, meaning restoration was never actually checked. Query the
    # current device instead.
    self.assertEqual(torch.xpu.current_device(), src_device)
|
||||
|
||||
def test_stream_context_manager(self):
|
||||
prev_stream = torch.xpu.current_stream()
|
||||
with torch.xpu.Stream() as stream:
|
||||
|
@ -2289,6 +2289,8 @@ def _accelerator_getDeviceIndex() -> _int: ...
|
||||
def _accelerator_setStream(Stream) -> None: ...
|
||||
def _accelerator_getStream(device_index: _int) -> Stream: ...
|
||||
def _accelerator_synchronizeDevice(device_index: _int) -> None: ...
|
||||
def _accelerator_exchangeDevice(device_index: _int) -> _int: ...
|
||||
def _accelerator_maybeExchangeDevice(device_index: _int) -> _int: ...
|
||||
|
||||
# Defined in torch/csrc/jit/python/python_tracer.cpp
|
||||
class TracingState:
|
||||
|
@ -2,7 +2,7 @@ r"""
|
||||
This package introduces support for the current :ref:`accelerator<accelerators>` in python.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
from typing import Literal, Optional
|
||||
from typing_extensions import deprecated
|
||||
|
||||
import torch
|
||||
@ -16,6 +16,7 @@ __all__ = [
|
||||
"current_device_index",
|
||||
"current_stream",
|
||||
"device_count",
|
||||
"device_index",
|
||||
"is_available",
|
||||
"set_device_idx", # deprecated
|
||||
"set_device_index",
|
||||
@ -189,3 +190,41 @@ def synchronize(device: _device_t = None, /) -> None:
|
||||
"""
|
||||
device_index = _get_device_index(device, True)
|
||||
torch._C._accelerator_synchronizeDevice(device_index)
|
||||
|
||||
|
||||
class device_index:
    r"""Context manager that temporarily switches the current device index of the
    current :ref:`accelerator<accelerators>`.

    On entry the current device index is changed to ``device`` and the previous
    index is remembered; on exit the remembered index is restored. Passing
    ``None`` turns the whole context manager into a no-op, which is convenient
    because indices obtained from ``torch.device.index`` may be ``None``.

    Args:
        device (Optional[int]): the device index to switch to for the duration
            of the context. If ``None``, no device switching occurs.

    Examples:

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
        >>> # Set device 0 as the current device temporarily
        >>> with torch.accelerator.device_index(0):
        ...     # Code here runs with device 0 as the current device
        ...     pass
        >>> # Original device is now restored
        >>> # No-op when None is passed
        >>> with torch.accelerator.device_index(None):
        ...     # No device switching occurs
        ...     pass
    """

    def __init__(self, device: Optional[int], /) -> None:
        # None marks this instance as a no-op context manager.
        self.idx = device
        # Sentinel; overwritten in __enter__ whenever a switch actually happens.
        self.prev_idx = -1

    def __enter__(self) -> None:
        # Guard clause: nothing to do for the no-op case.
        if self.idx is None:
            return
        # Remember whichever index was active so __exit__ can restore it.
        self.prev_idx = torch._C._accelerator_exchangeDevice(self.idx)

    def __exit__(self, *args: object) -> Literal[False]:
        if self.idx is not None:
            # "maybe" variant: restoring must not create a fresh device context.
            torch._C._accelerator_maybeExchangeDevice(self.prev_idx)
        # Never suppress exceptions raised inside the with-block.
        return False
|
||||
|
@ -60,6 +60,18 @@ void initModule(PyObject* module) {
|
||||
at::accelerator::synchronizeDevice(device_index);
|
||||
}
|
||||
});
|
||||
|
||||
m.def("_accelerator_exchangeDevice", [](c10::DeviceIndex device_index) {
|
||||
const auto device_type = at::accelerator::getAccelerator(true).value();
|
||||
torch::utils::maybe_initialize_device(device_type);
|
||||
return at::accelerator::exchangeDevice(device_index);
|
||||
});
|
||||
|
||||
m.def("_accelerator_maybeExchangeDevice", [](c10::DeviceIndex device_index) {
|
||||
const auto device_type = at::accelerator::getAccelerator(true).value();
|
||||
torch::utils::maybe_initialize_device(device_type);
|
||||
return at::accelerator::maybeExchangeDevice(device_index);
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace torch::accelerator
|
||||
|
Reference in New Issue
Block a user