Add torch.accelerator.device_index as accelerator's device switch context (#148864)
# Motivation

We propose adding support for the Python `with` statement on `torch.accelerator.device_index` to enable device-switching functionality. This enhancement simplifies writing device-agnostic code and benefits all accelerators. Its device-specific counterparts include [`torch.cuda.device`](00199acdb8/torch/cuda/__init__.py (L482)) and [`torch.cuda._DeviceGuard`](00199acdb8/torch/cuda/__init__.py (L469)).

**Design Philosophy**

It accepts either an `int` or `None` as input. When `None` is passed, no device switch is performed. Supporting `None` is important for compatibility, since it is common to encounter `None` values coming from `torch.device.index`. With this PR, we can therefore write:

```python
src = 0
dst = 1
# Set src as the current device
torch.accelerator.set_device_index(src)
with torch.accelerator.device_index(dst):
    # Inside the with statement, dst is the current device
    assert torch.accelerator.get_device_index() == dst
# Here the current device is src again
assert torch.accelerator.get_device_index() == src
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148864
Approved by: https://github.com/albanD
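The `None` case matters in practice because device indices extracted from `torch.device` objects are often `None`. A minimal sketch of that interaction (assumes an accelerator-enabled build, e.g. CUDA; illustrative only):

```python
import torch

dev = torch.device("cuda")  # no index specified
assert dev.index is None

# Passing the possibly-None index straight through is a safe no-op,
# so device-agnostic code needs no special case before entering the context.
with torch.accelerator.device_index(dev.index):
    pass  # the current device is left unchanged
```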
Committed by PyTorch MergeBot
Parent: f38dae76ee
Commit: 33c75cae0a
**aten/src/ATen/DeviceAccelerator.cpp**

```diff
@@ -76,7 +76,7 @@ c10::DeviceIndex deviceCount() {
     return static_cast<c10::DeviceIndex>(0);
   }
   c10::impl::VirtualGuardImpl impl(device_type.value());
-  return static_cast<c10::DeviceIndex>(impl.deviceCount());
+  return impl.deviceCount();
 }
 
 void setDeviceIndex(c10::DeviceIndex device_index) {
@@ -88,7 +88,7 @@ void setDeviceIndex(c10::DeviceIndex device_index) {
 c10::DeviceIndex getDeviceIndex() {
   const auto device_type = getAccelerator(true).value();
   c10::impl::VirtualGuardImpl impl(device_type);
-  return static_cast<c10::DeviceIndex>(impl.getDevice().index());
+  return impl.getDevice().index();
 }
 
 void setCurrentStream(c10::Stream stream) {
@@ -115,6 +115,21 @@ void synchronizeDevice(c10::DeviceIndex device_index) {
   // impl.synchronizeDevice should can be safely called from any device
   impl.synchronizeDevice(device_index);
 }
 
+c10::DeviceIndex exchangeDevice(c10::DeviceIndex device_index) {
+  const auto device_type = getAccelerator(true).value();
+  c10::impl::VirtualGuardImpl impl(device_type);
+  return impl.exchangeDevice({device_type, device_index}).index();
+}
+
+c10::DeviceIndex maybeExchangeDevice(c10::DeviceIndex device_index) {
+  const auto device_type = getAccelerator(true).value();
+  c10::impl::VirtualGuardImpl impl(device_type);
+  // Avoid creating a new context if the context for the given device_index
+  // is not initialized.
+  impl.uncheckedSetDevice({device_type, device_index});
+  return impl.getDevice().index();
+}
+
 // NOLINTEND(bugprone-unchecked-optional-access)
 } // namespace at::accelerator
```
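The practical difference between the two new functions: `exchangeDevice` may initialize a context on the target device, while `maybeExchangeDevice` deliberately avoids that. A rough Python-side sketch of the same property the new C++ test checks below, using the internal helper `torch._C._cuda_hasPrimaryContext` (a private API, shown only for illustration; assumes a CUDA build with at least two GPUs and no prior context on device 1):

```python
import torch

if torch.cuda.is_available() and torch.cuda.device_count() > 1:
    # maybeExchangeDevice switches without forcing a primary context...
    torch._C._accelerator_maybeExchangeDevice(1)
    assert not torch._C._cuda_hasPrimaryContext(1)
    # ...while exchangeDevice initializes one on the target device.
    torch._C._accelerator_exchangeDevice(1)
    assert torch._C._cuda_hasPrimaryContext(1)
```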
**aten/src/ATen/DeviceAccelerator.h**

```diff
@@ -63,6 +63,15 @@ TORCH_API c10::Stream getCurrentStream(c10::DeviceIndex device_index);
 // on the given device index has been completed.
 TORCH_API void synchronizeDevice(c10::DeviceIndex device_index);
 
+// Set the current device index to the given device_index and return the
+// original device index that was active before the change.
+TORCH_API c10::DeviceIndex exchangeDevice(c10::DeviceIndex device_index);
+
+// Set the current device index to the given device_index. Avoid creating a new
+// context if the context for device_index is not initialized. Return the
+// original device index that was active before the change.
+TORCH_API c10::DeviceIndex maybeExchangeDevice(c10::DeviceIndex device_index);
+
 } // namespace at::accelerator
 
 namespace at {
```
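Both declarations return the previously active index, which is what enables a save/switch/restore pattern in a single call pair. A hypothetical helper built on the new Python bindings (`run_on` is illustrative, not part of this PR):

```python
import torch

def run_on(device_index: int, fn):
    # Switch and remember the previous index in one call.
    prev = torch._C._accelerator_exchangeDevice(device_index)
    try:
        return fn()
    finally:
        # Restore without creating a fresh context on the previous device.
        torch._C._accelerator_maybeExchangeDevice(prev)
```

The public `torch.accelerator.device_index` context manager added below is essentially this pattern wrapped in `__enter__`/`__exit__`.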
**aten/src/ATen/test/CMakeLists.txt**

```diff
@@ -64,6 +64,7 @@ list(APPEND ATen_CUDA_TEST_SRCS
   ${CMAKE_CURRENT_SOURCE_DIR}/cuda_device_test.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/cuda_distributions_test.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/cuda_dlconvertor_test.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/cuda_exchange_device_test.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/cuda_generator_test.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/cuda_half_test.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/cuda_integer_divider_test.cu
```

**aten/src/ATen/test/cuda_exchange_device_test.cpp** (new file, 31 lines)
```diff
@@ -0,0 +1,31 @@
+#include <gtest/gtest.h>
+
+#include <ATen/DeviceAccelerator.h>
+#include <ATen/cuda/CUDAContext.h>
+
+
+TEST(CudaExchangeDeviceTest, checkPrimaryContext) {
+  if (!at::cuda::is_available()) {
+    return;
+  }
+
+  ASSERT_FALSE(at::cuda::hasPrimaryContext(0));
+  at::cuda::MaybeExchangeDevice(0);
+  ASSERT_FALSE(at::cuda::hasPrimaryContext(0));
+  at::accelerator::maybeExchangeDevice(0);
+  ASSERT_FALSE(at::cuda::hasPrimaryContext(0));
+
+  if (at::cuda::device_count() > 1) {
+    ASSERT_FALSE(at::cuda::hasPrimaryContext(1));
+    at::cuda::ExchangeDevice(1);
+    ASSERT_TRUE(at::cuda::hasPrimaryContext(1));
+  }
+
+  ASSERT_FALSE(at::cuda::hasPrimaryContext(0));
+  at::cuda::MaybeExchangeDevice(0);
+  ASSERT_FALSE(at::cuda::hasPrimaryContext(0));
+  at::accelerator::maybeExchangeDevice(0);
+  ASSERT_FALSE(at::cuda::hasPrimaryContext(0));
+  at::accelerator::exchangeDevice(0);
+  ASSERT_TRUE(at::cuda::hasPrimaryContext(0));
+}
```
**docs/source/accelerator.rst**

```diff
@@ -17,3 +17,4 @@ torch.accelerator
     set_stream
     current_stream
     synchronize
+    device_index
```
**test/test_accelerator.py**

```diff
@@ -81,6 +81,24 @@ class TestAccelerator(TestCase):
         ):
             torch.accelerator.current_stream(other_device)
 
+    def test_device_context_manager(self):
+        prev_device = torch.accelerator.current_device_index()
+        with torch.accelerator.device_index(None):
+            self.assertEqual(torch.accelerator.current_device_index(), prev_device)
+        self.assertEqual(torch.accelerator.current_device_index(), prev_device)
+        with torch.accelerator.device_index(0):
+            self.assertEqual(torch.accelerator.current_device_index(), 0)
+        self.assertEqual(torch.accelerator.current_device_index(), prev_device)
+
+    @unittest.skipIf(not TEST_MULTIACCELERATOR, "only one accelerator detected")
+    def test_multi_device_context_manager(self):
+        src_device = 0
+        dst_device = 1
+        torch.accelerator.set_device_index(src_device)
+        with torch.accelerator.device_index(dst_device):
+            self.assertEqual(torch.accelerator.current_device_index(), dst_device)
+        self.assertEqual(torch.accelerator.current_device_index(), src_device)
+
     def test_stream_context_manager(self):
         prev_stream = torch.accelerator.current_stream()
         with torch.Stream() as s:
```
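Because each context-manager instance stores the index it swapped out on entry, nested use restores correctly level by level. A sketch of that behavior (assumes at least two accelerator devices; not one of the PR's tests):

```python
import torch

torch.accelerator.set_device_index(0)
with torch.accelerator.device_index(1):
    assert torch.accelerator.current_device_index() == 1
    with torch.accelerator.device_index(0):
        assert torch.accelerator.current_device_index() == 0
    # The inner exit restores device 1, not the outermost device.
    assert torch.accelerator.current_device_index() == 1
assert torch.accelerator.current_device_index() == 0
```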
**test/test_cuda.py**

```diff
@@ -1094,6 +1094,24 @@ class TestCuda(TestCase):
 
         self.assertNotEqual(try_realloc.data_ptr(), data_ptr)
 
+    def test_device_context_manager(self):
+        prev_device = torch.cuda.current_device()
+        with torch.accelerator.device_index(None):
+            self.assertEqual(torch.cuda.current_device(), prev_device)
+        self.assertEqual(torch.cuda.current_device(), prev_device)
+        with torch.accelerator.device_index(0):
+            self.assertEqual(torch.cuda.current_device(), 0)
+        self.assertEqual(torch.cuda.current_device(), prev_device)
+
+    @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
+    def test_multi_device_context_manager(self):
+        src_device = 0
+        dst_device = 1
+        torch.cuda.set_device(src_device)
+        with torch.accelerator.device_index(dst_device):
+            self.assertEqual(torch.cuda.current_device(), 1)
+        self.assertEqual(torch.cuda.current_device(), src_device)
+
     def test_stream_context_manager(self):
         prev_stream = torch.cuda.current_stream()
         with torch.cuda.Stream() as stream:
```
**test/test_xpu.py**

```diff
@@ -315,6 +315,24 @@ if __name__ == "__main__":
         with self.assertRaisesRegex(RuntimeError, "The device index is out of range"):
             torch.accelerator.current_stream(torch.accelerator.device_count())
 
+    def test_device_context_manager(self):
+        prev_device = torch.xpu.current_device()
+        with torch.accelerator.device_index(None):
+            self.assertEqual(torch.xpu.current_device(), prev_device)
+        self.assertEqual(torch.xpu.current_device(), prev_device)
+        with torch.accelerator.device_index(0):
+            self.assertEqual(torch.xpu.current_device(), 0)
+        self.assertEqual(torch.xpu.current_device(), prev_device)
+
+    @unittest.skipIf(not TEST_MULTIXPU, "only one GPU detected")
+    def test_device_context_manager_with_set_device(self):
+        src_device = 0
+        dst_device = 1
+        torch.xpu.set_device(src_device)
+        with torch.accelerator.device_index(dst_device):
+            self.assertEqual(torch.xpu.current_device(), 1)
+        self.assertEqual(torch.xpu.current_device(), src_device)
+
     def test_stream_context_manager(self):
         prev_stream = torch.xpu.current_stream()
         with torch.xpu.Stream() as stream:
```
**torch/_C/__init__.pyi.in**

```diff
@@ -2289,6 +2289,8 @@ def _accelerator_getDeviceIndex() -> _int: ...
 def _accelerator_setStream(Stream) -> None: ...
 def _accelerator_getStream(device_index: _int) -> Stream: ...
 def _accelerator_synchronizeDevice(device_index: _int) -> None: ...
+def _accelerator_exchangeDevice(device_index: _int) -> _int: ...
+def _accelerator_maybeExchangeDevice(device_index: _int) -> _int: ...
 
 # Defined in torch/csrc/jit/python/python_tracer.cpp
 class TracingState:
```
**torch/accelerator/__init__.py**

```diff
@@ -2,7 +2,7 @@ r"""
 This package introduces support for the current :ref:`accelerator<accelerators>` in python.
 """
 
-from typing import Optional
+from typing import Literal, Optional
 from typing_extensions import deprecated
 
 import torch
@@ -16,6 +16,7 @@ __all__ = [
     "current_device_index",
     "current_stream",
     "device_count",
+    "device_index",
     "is_available",
     "set_device_idx",  # deprecated
     "set_device_index",
```
```diff
@@ -189,3 +190,41 @@ def synchronize(device: _device_t = None, /) -> None:
     """
     device_index = _get_device_index(device, True)
     torch._C._accelerator_synchronizeDevice(device_index)
+
+
+class device_index:
+    r"""Context manager to set the current device index for the current :ref:`accelerator<accelerators>`.
+
+    Temporarily changes the current device index to the specified value for the duration
+    of the context, and automatically restores the previous device index when exiting
+    the context.
+
+    Args:
+        device (Optional[int]): a given device index to temporarily set. If None,
+            no device index switching occurs.
+
+    Examples:
+
+        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
+        >>> # Set device 0 as the current device temporarily
+        >>> with torch.accelerator.device_index(0):
+        ...     # Code here runs with device 0 as the current device
+        ...     pass
+        >>> # Original device is now restored
+        >>> # No-op when None is passed
+        >>> with torch.accelerator.device_index(None):
+        ...     # No device switching occurs
+        ...     pass
+    """
+
+    def __init__(self, device: Optional[int], /) -> None:
+        self.idx = device
+        self.prev_idx = -1
+
+    def __enter__(self) -> None:
+        if self.idx is not None:
+            self.prev_idx = torch._C._accelerator_exchangeDevice(self.idx)
+
+    def __exit__(self, *args: object) -> Literal[False]:
+        if self.idx is not None:
+            torch._C._accelerator_maybeExchangeDevice(self.prev_idx)
+        return False
```
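For comparison, a functionally equivalent `contextlib` formulation (a sketch only; the PR uses an explicit class, which also lets `__exit__` be typed as `Literal[False]` so exceptions are never suppressed):

```python
from contextlib import contextmanager
from typing import Iterator, Optional

import torch

@contextmanager
def _device_index(device: Optional[int]) -> Iterator[None]:
    # None means "do not switch", mirroring the class-based version.
    if device is None:
        yield
        return
    prev = torch._C._accelerator_exchangeDevice(device)
    try:
        yield
    finally:
        torch._C._accelerator_maybeExchangeDevice(prev)
```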
**torch/csrc/DeviceAccelerator.cpp**

```diff
@@ -60,6 +60,18 @@ void initModule(PyObject* module) {
       at::accelerator::synchronizeDevice(device_index);
     }
   });
+
+  m.def("_accelerator_exchangeDevice", [](c10::DeviceIndex device_index) {
+    const auto device_type = at::accelerator::getAccelerator(true).value();
+    torch::utils::maybe_initialize_device(device_type);
+    return at::accelerator::exchangeDevice(device_index);
+  });
+
+  m.def("_accelerator_maybeExchangeDevice", [](c10::DeviceIndex device_index) {
+    const auto device_type = at::accelerator::getAccelerator(true).value();
+    torch::utils::maybe_initialize_device(device_type);
+    return at::accelerator::maybeExchangeDevice(device_index);
+  });
 }
 
 } // namespace torch::accelerator
```
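Since both bindings call `torch::utils::maybe_initialize_device` before exchanging, merely entering the context manager with a non-`None` index should trigger the accelerator's lazy initialization. A quick check of that expectation (assumes a fresh process on a CUDA build; `torch.cuda.is_initialized` is the public probe):

```python
import torch

assert not torch.cuda.is_initialized()
with torch.accelerator.device_index(0):
    pass  # entering the context lazily initializes the CUDA state
assert torch.cuda.is_initialized()
```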