Files
pytorch/torch/csrc/acc/Module.cpp
Han Qi b5c4f46bb9 Add functions to setup PrivateUse1 as a python backend device. (#157859)
Fixes #156052 and #156444.

This PR sets up the privateuseone key in Python so that it can be used as a Python backend for PyTorch.
This means that, after calling `setup_privateuseone_for_python_backend('npy')`, one can use a tensor subclass with that device to hold arbitrary Python data as "device data" and use `torch.library` to register ops that take that Tensor.

Changes done in this PR:

1. Register a vanilla device guard: I extended NoOpDeviceGuard to allow a device index of 0 and to not raise errors when event-related functions are accessed. Without those changes, calling backward raises errors. (The CPU backend uses NoOpDeviceGuard just fine, although there seems to be special treatment of CPU in the autograd engine.)
2. Tensor subclasses are allowed to not define `__torch_dispatch__` if the device is not CUDA or CPU. The comment on the check suggests it was there to avoid a segfault when calling into ops that expect a storage. Here we have a different device, so we will not call into those ops.
3. A Python function that invokes the other incantations to set up the privateuseone backend.

This took inspiration from https://github.com/bdhirsh/pytorch_open_registration_example and https://github.com/tinygrad/tinygrad/blob/master/extra/torch_backend/wrapped_tensor.cpp; great thanks to @bdhirsh and @geohot.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157859
Approved by: https://github.com/albanD
2025-10-01 21:32:59 +00:00

197 lines
6.0 KiB
C++

#include <torch/csrc/acc/Module.h>
#include <ATen/ATen.h>
#include <torch/extension.h>
#include <ATen/detail/PrivateUse1HooksInterface.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <torch/csrc/utils/pybind.h>
namespace py = pybind11;
namespace torch::acc {
// Trampoline that lets a Python subclass implement the PrivateUse1 hooks.
//
// `hasPrimaryContext`, `isBuilt` and `isAvailable` are dispatched to the
// Python methods `has_primary_context`, `is_built` and `is_available` via
// PYBIND11_OVERRIDE_PURE_NAME. Every other hook in this struct forwards to
// the at::PrivateUse1HooksInterface implementation.
struct PythonHooks final : public at::PrivateUse1HooksInterface {
  using at::PrivateUse1HooksInterface::PrivateUse1HooksInterface;

  // Dispatches to Python `has_primary_context(device_index)`.
  bool hasPrimaryContext(c10::DeviceIndex device_index) const override {
    PYBIND11_OVERRIDE_PURE_NAME(
        bool,
        at::PrivateUse1HooksInterface,
        "has_primary_context",
        hasPrimaryContext,
        device_index);
  }

  // Dispatches to Python `is_built()`.
  bool isBuilt() const override {
    PYBIND11_OVERRIDE_PURE_NAME(
        bool, at::PrivateUse1HooksInterface, "is_built", isBuilt, );
  }

  // Dispatches to Python `is_available()`. The C++ function name handed to
  // the macro must name this method, not `isBuilt`.
  bool isAvailable() const override {
    PYBIND11_OVERRIDE_PURE_NAME(
        bool, at::PrivateUse1HooksInterface, "is_available", isAvailable, );
  }

  // TODO(qihqi): these are not supported from Python yet; each one forwards
  // to the base-class implementation.
  const at::Generator& getDefaultGenerator(
      c10::DeviceIndex device_index) const override {
    return at::PrivateUse1HooksInterface::getDefaultGenerator(device_index);
  }

  at::Generator getNewGenerator(
      c10::DeviceIndex device_index = -1) const override {
    return at::PrivateUse1HooksInterface::getNewGenerator(device_index);
  }

  at::Device getDeviceFromPtr(void* data) const override {
    return at::PrivateUse1HooksInterface::getDeviceFromPtr(data);
  }

  bool isPinnedPtr(const void* data) const override {
    return at::PrivateUse1HooksInterface::isPinnedPtr(data);
  }

  at::Allocator* getPinnedMemoryAllocator() const override {
    return at::PrivateUse1HooksInterface::getPinnedMemoryAllocator();
  }
};
struct PythonDeviceGuard final : public c10::impl::DeviceGuardImplInterface {
using c10::impl::DeviceGuardImplInterface::DeviceGuardImplInterface;
c10::DeviceType type() const override {
PYBIND11_OVERRIDE_PURE_NAME(
c10::DeviceType, c10::impl::DeviceGuardImplInterface, "type_", type, );
}
// TODO(qihqi): figure out if those are even useful
// to python or not
c10::Device exchangeDevice(c10::Device device) const override {
return getDevice();
}
c10::Device getDevice() const override {
return c10::Device(type(), 0);
}
void setDevice(c10::Device device) const override {}
void uncheckedSetDevice(c10::Device device) const noexcept override {}
c10::Stream getStream(c10::Device) const noexcept override {
// no-op
return c10::Stream(c10::Stream::DEFAULT, getDevice());
}
c10::Stream getNewStream(c10::Device, int priority = 0) const override {
// no-op
(void)priority;
return c10::Stream(c10::Stream::DEFAULT, getDevice());
}
c10::Stream exchangeStream(c10::Stream) const noexcept override {
// no-op
return c10::Stream(c10::Stream::DEFAULT, getDevice());
}
c10::DeviceIndex deviceCount() const noexcept override {
return 1;
}
// TODO(qihqi): support Event-related functions
void record(
void** /*event*/,
const c10::Stream& /*stream*/,
const c10::DeviceIndex /*device_index*/,
const c10::EventFlag /*flag*/) const override {}
void block(void* /*event*/, const c10::Stream& /*stream*/) const override {}
bool queryEvent(void* /*event*/) const override {
return true;
}
void destroyEvent(void* /*event*/, const c10::DeviceIndex /*device_index*/)
const noexcept override {}
// Stream-related functions
bool queryStream(const c10::Stream& /*stream*/) const override {
return true;
}
void synchronizeStream(const c10::Stream& /*stream*/) const override {}
};
namespace {
// Registers the Python object `hook` as the PrivateUse1 hooks implementation.
//
// Returns false, leaving `hook` untouched, when
// at::isPrivateUse1HooksRegistered() already reports a registration;
// returns true after handing the hooks to
// at::RegisterPrivateUse1HooksInterface.
bool registerPythonPrivateUse1Hook(const py::object& hook) {
  if (at::isPrivateUse1HooksRegistered()) {
    return false;
  }
  // Convert and register before taking the extra reference: if the cast (or
  // the registration) throws, no reference has been taken yet, so nothing is
  // leaked on the error path.
  auto* hooks = hook.cast<at::PrivateUse1HooksInterface*>();
  at::RegisterPrivateUse1HooksInterface(hooks);
  // Only a raw pointer was handed over, so take a reference that is never
  // released here to keep the Python object alive.
  hook.inc_ref();
  return true;
}
// Registers the Python object `guard` as the device guard implementation for
// c10::DeviceType::PrivateUse1.
//
// Returns false, leaving `guard` untouched, when a guard implementation is
// already present for PrivateUse1; returns true after registering.
bool registerPythonPrivateUse1DeviceGuard(const py::object& guard) {
  if (c10::impl::hasDeviceGuardImpl(c10::DeviceType::PrivateUse1)) {
    return false;
  }
  // Convert before taking the extra reference: if the cast throws because
  // `guard` is not a DeviceGuard instance, no reference has been taken yet,
  // so nothing is leaked on the error path.
  auto* impl = guard.cast<c10::impl::DeviceGuardImplInterface*>();
  c10::impl::registerDeviceGuard(c10::DeviceType::PrivateUse1, impl);
  // Only a raw pointer was handed over, so take a reference that is never
  // released here to keep the Python object alive.
  guard.inc_ref();
  return true;
}
// Builds a PrivateUse1 tensor of the given `shape` and `dtype` that carries
// no data: its storage is created with a byte size of 0 and its data pointer
// is null, tagged with device PrivateUse1:0.
//
// The strides are the row-major (contiguous) strides of `shape`. A dimension
// of extent 0 is counted as extent 1 while accumulating strides, so a
// zero-sized dimension does not zero out the strides of the outer dimensions
// (shape [2, 0, 3] yields strides [3, 3, 1]).
at::Tensor createEmptyTensor(
    const std::vector<int64_t>& shape,
    c10::ScalarType dtype) {
  // Zero-byte storage obtained through the meta allocator.
  c10::Storage storage{
      c10::Storage::use_byte_size_t{},
      0,
      c10::GetAllocator(c10::kMeta),
      true,
  };
  // Re-tag the (null) data pointer so the storage reports PrivateUse1:0.
  const c10::Device device(c10::DeviceType::PrivateUse1, 0);
  storage.set_data_ptr_noswap(at::DataPtr{nullptr, device});

  c10::DispatchKeySet key_set({c10::DispatchKey::PrivateUse1});
  at::Tensor tensor = at::detail::make_tensor<at::TensorImpl>(
      std::move(storage), key_set, c10::scalarTypeToTypeMeta(dtype));

  // Walk the dimensions innermost-first, accumulating the running product
  // of extents as the stride of the next outer dimension.
  std::vector<int64_t> strides(shape.size());
  int64_t running = 1;
  for (auto i = strides.size(); i > 0; --i) {
    strides[i - 1] = running;
    const int64_t extent = shape[i - 1];
    running *= (extent == 0 ? 1 : extent);
  }
  tensor.unsafeGetTensorImpl()->set_sizes_and_strides(shape, strides, 0);
  return tensor;
}
} // namespace
// Creates the `_acc` submodule on `module` and populates it with:
//   - class `PrivateUse1Hooks` (overridable via the PythonHooks trampoline),
//   - class `DeviceGuard` (overridable via the PythonDeviceGuard trampoline),
//   - `register_python_privateuseone_hook`,
//   - `register_python_privateuseone_device_guard`,
//   - `create_empty_tensor`.
void initModule(PyObject* module) {
  auto parent = py::reinterpret_borrow<py::module>(module);
  auto acc_module =
      parent.def_submodule("_acc", "classes related to custom accelerators");

  // Hooks interface, with the three Python-overridable queries exposed.
  py::class_<at::PrivateUse1HooksInterface, PythonHooks> hooks_cls(
      acc_module.ptr(), "PrivateUse1Hooks");
  hooks_cls.def(py::init<>());
  hooks_cls.def(
      "has_primary_context",
      &at::PrivateUse1HooksInterface::hasPrimaryContext);
  hooks_cls.def("is_built", &at::PrivateUse1HooksInterface::isBuilt);
  hooks_cls.def("is_available", &at::PrivateUse1HooksInterface::isAvailable);

  // Device guard interface; only `type_` is exposed to Python.
  py::class_<c10::impl::DeviceGuardImplInterface, PythonDeviceGuard> guard_cls(
      acc_module.ptr(), "DeviceGuard");
  guard_cls.def(py::init<>());
  guard_cls.def("type_", &c10::impl::DeviceGuardImplInterface::type);

  // Free functions.
  acc_module.def(
      "register_python_privateuseone_hook", &registerPythonPrivateUse1Hook);
  acc_module.def(
      "register_python_privateuseone_device_guard",
      &registerPythonPrivateUse1DeviceGuard);
  acc_module.def("create_empty_tensor", &createEmptyTensor);
}
} // namespace torch::acc