#include #include namespace torch::accelerator { void initModule(PyObject* module) { auto m = py::handle(module).cast(); m.def("_accelerator_getAccelerator", []() { // If no accelerator is currently available, raise an exception. return c10::Device(at::getAccelerator(true).value()); }); m.def("_accelerator_deviceCount", []() { auto device_type = at::accelerator::getAccelerator(false); torch::utils::maybe_initialize_device(device_type); return at::accelerator::deviceCount(); }); m.def("_accelerator_setDeviceIndex", [](c10::DeviceIndex device_index) { // If device index is negative, no-op if (device_index < 0) { return; } const auto device_type = at::accelerator::getAccelerator(true).value(); torch::utils::maybe_initialize_device(device_type); at::accelerator::setDeviceIndex(device_index); }); m.def("_accelerator_getDeviceIndex", []() { const auto device_type = at::accelerator::getAccelerator(true).value(); torch::utils::maybe_initialize_device(device_type); return at::accelerator::getDeviceIndex(); }); m.def("_accelerator_setStream", [](c10::Stream stream) { const auto device_type = at::accelerator::getAccelerator(true).value(); torch::utils::maybe_initialize_device(device_type); // Set the current device to the device of stream if (at::accelerator::getDeviceIndex() != stream.device_index()) { at::accelerator::setDeviceIndex(stream.device_index()); } at::accelerator::setCurrentStream(stream); }); m.def("_accelerator_getStream", [](c10::DeviceIndex device_index) { const auto device_type = at::accelerator::getAccelerator(true).value(); torch::utils::maybe_initialize_device(device_type); return at::accelerator::getCurrentStream(device_index); }); m.def("_accelerator_synchronizeDevice", [](c10::DeviceIndex device_index) { const auto device_type = at::accelerator::getAccelerator(true).value(); if (torch::utils::is_device_lazy_init_supported(device_type) && !torch::utils::is_device_initialized(device_type)) { return; } torch::utils::maybe_initialize_device(device_type); { py::gil_scoped_release no_gil; at::accelerator::synchronizeDevice(device_index); } }); } } // namespace torch::accelerator