This reverts commit 2ad5c25cfc603c3656e6699d6137419dbb009495. Reverted https://github.com/pytorch/pytorch/pull/152932 on behalf of https://github.com/ZainRizvi due to Very sorry but this is still breaking internally. @albanD would you be able to help get this past the finish line? D78496124 has more details on the failure and the workaround might be to do something like what's in D78684669. To validate the fixes internally, you can follow the instructions here to ghimport the changes: https://fburl.com/fixing-ghfirst-reverts ([comment](https://github.com/pytorch/pytorch/pull/138222#issuecomment-3100195370))
#include <torch/csrc/DeviceAccelerator.h>
#include <torch/csrc/utils/device_lazy_init.h>

namespace torch::accelerator {

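// Register the device-generic accelerator bindings (the _accelerator_*
// functions) on the given Python module; the torch.accelerator Python API is
// built on top of them.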
void initModule(PyObject* module) {
  auto m = py::handle(module).cast<py::module>();

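  // Return the accelerator device available in this build, if any.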
m.def("_accelerator_getAccelerator", []() -> std::optional<c10::Device> {
|
|
// If no accelerator was available at compile time, return None.
|
|
auto acc = at::getAccelerator(false);
|
|
if (acc.has_value()) {
|
|
return acc.value();
|
|
} else {
|
|
return std::nullopt;
|
|
}
|
|
});
|
|
|
|
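  // Make `device_index` the current device for the accelerator backend,
  // initializing the backend lazily if needed.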
m.def("_accelerator_setDeviceIndex", [](c10::DeviceIndex device_index) {
|
|
// If device index is negative, no-op
|
|
if (device_index < 0) {
|
|
return;
|
|
}
|
|
const auto device_type = at::accelerator::getAccelerator(true).value();
|
|
torch::utils::maybe_initialize_device(device_type);
|
|
at::accelerator::setDeviceIndex(device_index);
|
|
});
|
|
|
|
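  // Return the accelerator's current device index, initializing the backend
  // lazily if needed.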
m.def("_accelerator_getDeviceIndex", []() {
|
|
const auto device_type = at::accelerator::getAccelerator(true).value();
|
|
torch::utils::maybe_initialize_device(device_type);
|
|
return at::accelerator::getDeviceIndex();
|
|
});
|
|
|
|
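  // Make `stream` the current stream, switching the current device to the
  // stream's device first when they differ.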
m.def("_accelerator_setStream", [](c10::Stream stream) {
|
|
const auto device_type = at::accelerator::getAccelerator(true).value();
|
|
torch::utils::maybe_initialize_device(device_type);
|
|
// Set the current device to the device of stream
|
|
if (at::accelerator::getDeviceIndex() != stream.device_index()) {
|
|
at::accelerator::setDeviceIndex(stream.device_index());
|
|
}
|
|
at::accelerator::setCurrentStream(stream);
|
|
});
|
|
|
|
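  // Return the current stream of the given device.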
m.def("_accelerator_getStream", [](c10::DeviceIndex device_index) {
|
|
const auto device_type = at::accelerator::getAccelerator(true).value();
|
|
torch::utils::maybe_initialize_device(device_type);
|
|
return at::accelerator::getCurrentStream(device_index);
|
|
});
|
|
|
|
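  // Wait for all work queued on the given device to finish. If the backend
  // supports lazy initialization and has not been initialized yet, there is
  // nothing to synchronize, so return early rather than initializing it.
  // The GIL is released while waiting.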
m.def("_accelerator_synchronizeDevice", [](c10::DeviceIndex device_index) {
|
|
const auto device_type = at::accelerator::getAccelerator(true).value();
|
|
if (torch::utils::is_device_lazy_init_supported(device_type) &&
|
|
!torch::utils::is_device_initialized(device_type)) {
|
|
return;
|
|
}
|
|
torch::utils::maybe_initialize_device(device_type);
|
|
{
|
|
py::gil_scoped_release no_gil;
|
|
at::accelerator::synchronizeDevice(device_index);
|
|
}
|
|
});
|
|
|
|
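  // Set the current device index and return the previous one.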
m.def("_accelerator_exchangeDevice", [](c10::DeviceIndex device_index) {
|
|
const auto device_type = at::accelerator::getAccelerator(true).value();
|
|
torch::utils::maybe_initialize_device(device_type);
|
|
return at::accelerator::exchangeDevice(device_index);
|
|
});
|
|
|
|
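  // Exchange the current device index via at::accelerator::maybeExchangeDevice
  // and return the previous index; the "maybe" variant may skip the underlying
  // device switch when it is not required.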
m.def("_accelerator_maybeExchangeDevice", [](c10::DeviceIndex device_index) {
|
|
const auto device_type = at::accelerator::getAccelerator(true).value();
|
|
torch::utils::maybe_initialize_device(device_type);
|
|
return at::accelerator::maybeExchangeDevice(device_index);
|
|
});
|
|
}
|
|
|
|
} // namespace torch::accelerator
|