pytorch/torch/csrc/DeviceAccelerator.cpp
Yu, Guangye 33c75cae0a Add torch.accelerator.device_index as accelerator's device switch context (#148864)
# Motivation
We propose adding support for the Python `with` statement on `torch.accelerator.device_index` to enable device-switching functionality. This enhancement simplifies writing device-agnostic code and benefits all accelerators. Its device-specific counterparts include [`torch.cuda.device`](00199acdb8/torch/cuda/__init__.py (L482)) and [`torch.cuda._DeviceGuard`](00199acdb8/torch/cuda/__init__.py (L469)).

**Design Philosophy**
It accepts either an `int` or `None` as input. When `None` is passed, no device switch is performed. Supporting `None` is important for compatibility, since an index obtained from `torch.device.index` can itself be `None` (see the sketch after the example below).

Therefore, with this PR, we can write the following:

```python
src = 0
dst = 1
# Set the current device to src
torch.accelerator.set_device_index(src)
with torch.accelerator.device_index(dst):
    # Inside the with statement, the current device is dst
    assert torch.accelerator.get_device_index() == dst
# After the with statement exits, the current device is restored to src
assert torch.accelerator.get_device_index() == src
```
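Because `None` is also accepted, device-agnostic code can pass a possibly-`None` index straight through. A minimal sketch of the no-op case (assuming an accelerator such as CUDA is available, and using the same `set_device_index`/`get_device_index` names as the example above):

```python
import torch

torch.accelerator.set_device_index(0)
# torch.device("cuda") carries no index, so .index is None.
maybe_idx = torch.device("cuda").index
with torch.accelerator.device_index(maybe_idx):
    # None means no device switch was performed.
    assert torch.accelerator.get_device_index() == 0
# Still on device 0 after exiting.
assert torch.accelerator.get_device_index() == 0
```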

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148864
Approved by: https://github.com/albanD
2025-04-25 09:45:25 +00:00

78 lines
2.8 KiB
C++

#include <torch/csrc/DeviceAccelerator.h>
#include <torch/csrc/utils/device_lazy_init.h>

namespace torch::accelerator {

void initModule(PyObject* module) {
  auto m = py::handle(module).cast<py::module>();

  m.def("_accelerator_getAccelerator", []() -> std::optional<c10::Device> {
    // If no accelerator was available at compile time, return None.
    auto acc = at::getAccelerator(false);
    if (acc.has_value()) {
      return acc.value();
    } else {
      return std::nullopt;
    }
  });

  m.def("_accelerator_setDeviceIndex", [](c10::DeviceIndex device_index) {
    // If the device index is negative, this is a no-op.
    if (device_index < 0) {
      return;
    }
    const auto device_type = at::accelerator::getAccelerator(true).value();
    // Lazily initialize the device module (e.g. the CUDA context) if needed.
    torch::utils::maybe_initialize_device(device_type);
    at::accelerator::setDeviceIndex(device_index);
  });

  m.def("_accelerator_getDeviceIndex", []() {
    const auto device_type = at::accelerator::getAccelerator(true).value();
    torch::utils::maybe_initialize_device(device_type);
    return at::accelerator::getDeviceIndex();
  });

  m.def("_accelerator_setStream", [](c10::Stream stream) {
    const auto device_type = at::accelerator::getAccelerator(true).value();
    torch::utils::maybe_initialize_device(device_type);
    // Set the current device to the device of the stream.
    if (at::accelerator::getDeviceIndex() != stream.device_index()) {
      at::accelerator::setDeviceIndex(stream.device_index());
    }
    at::accelerator::setCurrentStream(stream);
  });

  m.def("_accelerator_getStream", [](c10::DeviceIndex device_index) {
    const auto device_type = at::accelerator::getAccelerator(true).value();
    torch::utils::maybe_initialize_device(device_type);
    return at::accelerator::getCurrentStream(device_index);
  });

  m.def("_accelerator_synchronizeDevice", [](c10::DeviceIndex device_index) {
    const auto device_type = at::accelerator::getAccelerator(true).value();
    // If the device supports lazy initialization and has not been
    // initialized yet, there is nothing to synchronize, so return early
    // rather than initializing the device just to wait on it.
    if (torch::utils::is_device_lazy_init_supported(device_type) &&
        !torch::utils::is_device_initialized(device_type)) {
      return;
    }
    torch::utils::maybe_initialize_device(device_type);
    {
      // Synchronization may block for a long time; release the GIL.
      py::gil_scoped_release no_gil;
      at::accelerator::synchronizeDevice(device_index);
    }
  });

  m.def("_accelerator_exchangeDevice", [](c10::DeviceIndex device_index) {
    // Switch to device_index and return the previous device index.
    const auto device_type = at::accelerator::getAccelerator(true).value();
    torch::utils::maybe_initialize_device(device_type);
    return at::accelerator::exchangeDevice(device_index);
  });

  m.def("_accelerator_maybeExchangeDevice", [](c10::DeviceIndex device_index) {
    // Variant of _accelerator_exchangeDevice that forwards to
    // at::accelerator::maybeExchangeDevice, which may avoid an
    // unnecessary device switch.
    const auto device_type = at::accelerator::getAccelerator(true).value();
    torch::utils::maybe_initialize_device(device_type);
    return at::accelerator::maybeExchangeDevice(device_index);
  });
}

} // namespace torch::accelerator
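
For context, these private bindings back the public `torch.accelerator` Python API. A minimal sketch of how they surface, assuming `initModule` registers them on `torch._C` (binding names as defined above) and that at least one accelerator device is present:

```python
import torch

# None here means no accelerator was compiled in.
acc = torch._C._accelerator_getAccelerator()
if acc is not None:
    idx = torch._C._accelerator_getDeviceIndex()
    # exchangeDevice switches to the given index and returns the previous
    # one -- exactly the pair a with-statement device guard needs in order
    # to restore the original device on exit.
    prev = torch._C._accelerator_exchangeDevice(idx)
    assert prev == idx
    torch._C._accelerator_maybeExchangeDevice(prev)  # restore
```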