# Motivation
This PR adds device-agnostic C++ accelerator APIs.
# Additional Context
This PR is a reland. The original was reverted because `torch.Event` did not support the MPS backend; that issue was fixed in https://github.com/pytorch/pytorch/pull/142468. The previous commit is f84e533a2c
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138677
Approved by: https://github.com/albanD, https://github.com/EikanWang
ghstack dependencies: #143171, #133572
#include <torch/csrc/DeviceAccelerator.h>
#include <torch/csrc/utils/device_lazy_init.h>

namespace torch::accelerator {

void initModule(PyObject* module) {
  auto m = py::handle(module).cast<py::module>();

  m.def("_accelerator_getAccelerator", []() {
    // If no accelerator is currently available, raise an exception.
    return c10::Device(at::getAccelerator(true).value());
  });

  m.def("_accelerator_deviceCount", []() {
    auto device_type = at::accelerator::getAccelerator(false);
    torch::utils::maybe_initialize_device(device_type);
    return at::accelerator::deviceCount();
  });

  m.def("_accelerator_setDeviceIndex", [](c10::DeviceIndex device_index) {
    // If device index is negative, no-op
    if (device_index < 0) {
      return;
    }
    const auto device_type = at::accelerator::getAccelerator(true).value();
    torch::utils::maybe_initialize_device(device_type);
    at::accelerator::setDeviceIndex(device_index);
  });

  m.def("_accelerator_getDeviceIndex", []() {
    const auto device_type = at::accelerator::getAccelerator(true).value();
    torch::utils::maybe_initialize_device(device_type);
    return at::accelerator::getDeviceIndex();
  });

  m.def("_accelerator_setStream", [](c10::Stream stream) {
    const auto device_type = at::accelerator::getAccelerator(true).value();
    torch::utils::maybe_initialize_device(device_type);
    // Set the current device to the device of stream
    if (at::accelerator::getDeviceIndex() != stream.device_index()) {
      at::accelerator::setDeviceIndex(stream.device_index());
    }
    at::accelerator::setCurrentStream(stream);
  });

  m.def("_accelerator_getStream", [](c10::DeviceIndex device_index) {
    const auto device_type = at::accelerator::getAccelerator(true).value();
    torch::utils::maybe_initialize_device(device_type);
    return at::accelerator::getCurrentStream(device_index);
  });

  m.def("_accelerator_synchronizeDevice", [](c10::DeviceIndex device_index) {
    const auto device_type = at::accelerator::getAccelerator(true).value();
    if (torch::utils::is_device_lazy_init_supported(device_type) &&
        !torch::utils::is_device_initialized(device_type)) {
      return;
    }
    torch::utils::maybe_initialize_device(device_type);
    {
      py::gil_scoped_release no_gil;
      at::accelerator::synchronizeDevice(device_index);
    }
  });
}

} // namespace torch::accelerator
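
For reference, here is a minimal sketch of driving the device-agnostic C++ API that these bindings wrap. It only uses the `at::getAccelerator` / `at::accelerator::*` calls that appear in the file above; the `ATen/DeviceAccelerator.h` header and the surrounding build/link setup are assumptions and are not part of this PR.

// Minimal usage sketch (not part of this PR). Assumes the at::accelerator
// functions used by the bindings above are declared in ATen/DeviceAccelerator.h.
#include <ATen/DeviceAccelerator.h>
#include <iostream>

int main() {
  // getAccelerator(/*checked=*/false) returns an empty optional instead of
  // throwing when no accelerator backend is available.
  auto acc = at::getAccelerator(/*checked=*/false);
  if (!acc.has_value()) {
    std::cout << "no accelerator available\n";
    return 0;
  }
  std::cout << "accelerator: " << c10::Device(*acc) << "\n";

  // Device-agnostic equivalents of the per-backend device/stream APIs.
  const auto count = at::accelerator::deviceCount();
  std::cout << "device count: " << static_cast<int>(count) << "\n";

  if (count > 0) {
    at::accelerator::setDeviceIndex(0);
    const auto stream = at::accelerator::getCurrentStream(/*device_index=*/0);
    at::accelerator::setCurrentStream(stream);
    // Block the host until all work queued on device 0 has finished.
    at::accelerator::synchronizeDevice(/*device_index=*/0);
  }
  return 0;
}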