Add functions to set up PrivateUse1 as a Python backend device. (#157859)

Fixes #156052 and #156444.

This PR sets up the PrivateUse1 dispatch key in Python so it can be used as a Python backend for PyTorch.
After calling `_setup_privateuseone_for_python_backend('npy')`, one can use a Tensor subclass on that device to hold arbitrary Python data as "device data", and use `torch.library` to register ops that take such Tensors.
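
For illustration, a minimal sketch of the intended flow, condensed from the new test file below (the test registers a larger op set; `NpyTensor` is an illustrative name):

```python
import torch
from torch.utils.backend_registration import _setup_privateuseone_for_python_backend

_setup_privateuseone_for_python_backend("npy")

class NpyTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, size, dtype, raw_data=None):
        # The wrapper tensor owns no real storage; raw_data is the "device data".
        res = torch._C._acc.create_empty_tensor(size, dtype)
        res.__class__ = NpyTensor
        return res

    def __init__(self, size, dtype, raw_data=None):
        self.raw_data = raw_data

@torch.library.impl("aten::add.Tensor", "privateuseone")
def add(t1, t2):
    out = t1.raw_data + t2.raw_data
    return NpyTensor(out.shape, torch.float32, out)
```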

Changes done in this PR:

1. Register a vanilla device guard: I extended NoOpDeviceGuard to allow a device index of 0 and to not raise errors when event-related functions are accessed. Without those changes, calling backward would raise errors. (The CPU backend uses NoOpDeviceGuard just fine, although there seems to be special treatment of CPU in the autograd engine.)
2. Allow a Tensor subclass to omit `__torch_dispatch__` when the device is neither CUDA nor CPU. The comment on the check suggests it was added to avoid segfaults when calling into ops that expect a storage; since we are on a different device, we will not call into those ops.
3. Add a Python function that performs the remaining incantations to set up the PrivateUse1 backend.

This took inspiration from https://github.com/bdhirsh/pytorch_open_registration_example and https://github.com/tinygrad/tinygrad/blob/master/extra/torch_backend/wrapped_tensor.cpp; many thanks to @bdhirsh and @geohot.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157859
Approved by: https://github.com/albanD
Commit 1310d6a1f9 (parent 7f4c3e7d2f)
Author: Han Qi
Date: 2025-09-30 08:39:32 +00:00
Committed by: PyTorch MergeBot
13 changed files with 469 additions and 9 deletions

build_variables.bzl

@@ -897,6 +897,7 @@ libtorch_python_core_sources = [
"torch/csrc/Stream.cpp",
"torch/csrc/Event.cpp",
"torch/csrc/TypeInfo.cpp",
"torch/csrc/acc/Module.cpp",
"torch/csrc/api/src/python/init.cpp",
"torch/csrc/autograd/functions/init.cpp",
"torch/csrc/autograd/init.cpp",

c10/core/impl/DeviceGuardImplInterface.cpp

@@ -9,16 +9,22 @@ std::array<
    std::atomic<const DeviceGuardImplInterface*>,
    static_cast<size_t>(DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES)>
    device_guard_impl_registry;
-DeviceGuardImplRegistrar::DeviceGuardImplRegistrar(
+void registerDeviceGuard(
    DeviceType type,
    const DeviceGuardImplInterface* impl) {
  device_guard_impl_registry[static_cast<size_t>(type)].store(impl);
}
+
+DeviceGuardImplRegistrar::DeviceGuardImplRegistrar(
+    DeviceType type,
+    const DeviceGuardImplInterface* impl) {
+  registerDeviceGuard(type, impl);
+}

namespace {
thread_local std::unique_ptr<DeviceGuardImplInterface> tls_fake_device_guard =
    nullptr;
-}
+} // namespace
void ensureCUDADeviceGuardSet() {
constexpr auto cuda_idx = static_cast<std::size_t>(DeviceType::CUDA);

c10/core/impl/DeviceGuardImplInterface.h

@@ -368,6 +368,9 @@ inline const DeviceGuardImplInterface* getDeviceGuardImpl(DeviceType type) {
  return p;
}
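
// Installs `impl` as the DeviceGuardImplInterface for `type`; exposed so
// that torch::acc can register a Python-backed guard for PrivateUse1.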
void C10_API
registerDeviceGuard(DeviceType type, const DeviceGuardImplInterface* impl);

inline bool hasDeviceGuardImpl(DeviceType type) {
  return device_guard_impl_registry[static_cast<size_t>(type)].load();

test/run_test.py

@@ -284,6 +284,7 @@ RUN_PARALLEL_BLOCKLIST = [
# temporarily sets a global config
"test_autograd_fallback",
"inductor/test_compiler_bisector",
"test_privateuseone_python_backend",
] + FSDP_TEST
# Test files that should always be run serially with other test files,

test/test_privateuseone_python_backend.py (new file)

@@ -0,0 +1,147 @@
# Owner(s): ["module: PrivateUse1"]
import numpy as np

import torch
import torch._C
from torch.testing._internal.common_utils import run_tests, TestCase
from torch.utils.backend_registration import _setup_privateuseone_for_python_backend

_setup_privateuseone_for_python_backend("npy")

aten = torch.ops.aten


# NOTE: From https://github.com/albanD/subclass_zoo/blob/main/new_device.py
# but using torch.library instead of `__torch_dispatch__`
class MyDeviceTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, size, dtype, raw_data=None, requires_grad=False):
        # Use a meta Tensor here to be used as the wrapper
        res = torch._C._acc.create_empty_tensor(size, dtype)
        res.__class__ = MyDeviceTensor
        return res

    def __init__(self, size, dtype, raw_data=None, requires_grad=False):
        # Store any provided user raw_data
        self.raw_data = raw_data

    def __repr__(self):
        return "MyDeviceTensor" + str(self.raw_data)

    __str__ = __repr__
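

# Helpers to move between raw numpy arrays ("device data") and wrapper tensors.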
def wrap(arr, shape, dtype):
    # hard code float32 for tests
    return MyDeviceTensor(shape, dtype, arr)


def unwrap(arr):
    return arr.raw_data


# Add some ops
@torch.library.impl("aten::add.Tensor", "privateuseone")
def add(t1, t2):
    out = unwrap(t1) + unwrap(t2)
    return wrap(out, out.shape, torch.float32)


@torch.library.impl("aten::mul.Tensor", "privateuseone")
def mul(t1, t2):
    # If unsure what the result's properties should be, you can use a
    # super_fn-style fallback as in the subclass_zoo example (can be
    # useful for type promotion)
    out = unwrap(t1) * unwrap(t2)
    return wrap(out, out.shape, torch.float32)


@torch.library.impl("aten::detach", "privateuseone")
def detach(self):
    out = unwrap(self)
    return wrap(out, out.shape, torch.float32)


@torch.library.impl("aten::empty_strided", "privateuseone")
def empty_strided(
    size, stride, *, dtype=None, layout=None, device=None, pin_memory=None
):
    out = np.empty(size)
    return wrap(out, out.shape, torch.float32)


@torch.library.impl("aten::_copy_from", "privateuseone")
def _copy_from(a, b):
    if a.device.type == "npy":
        npy_data = unwrap(a)
    else:
        npy_data = a.numpy()
    b.raw_data = npy_data


@torch.library.impl("aten::view", "privateuseone")
def _view(a, size):
    ans = unwrap(a)
    return wrap(ans, a.shape, a.dtype)


@torch.library.impl("aten::empty.memory_format", "privateuseone")
def empty_memory_format(
    size, *, dtype=None, layout=None, device=None, pin_memory=None, memory_format=None
):
    ans = np.empty(size)
    return wrap(ans, ans.shape, torch.float32)


@torch.library.impl("aten::sum", "privateuseone")
def sum_int_list(*args, **kwargs):
    ans = unwrap(args[0]).sum()
    return wrap(ans, ans.shape, torch.float32)


@torch.library.impl("aten::ones_like", "privateuseone")
def ones_like(
    self, *, dtype=None, layout=None, device=None, pin_memory=None, memory_format=None
):
    ans = np.ones_like(unwrap(self))
    return wrap(ans, ans.shape, torch.float32)


@torch.library.impl("aten::expand", "privateuseone")
def expand(self, size, *, implicit=False):
    ans = np.broadcast_to(self.raw_data, size)
    return wrap(ans, ans.shape, torch.float32)


@torch.library.impl("aten::as_strided", "privateuseone")
def as_strided(self, size, stride, storage_offset=None):
    ans = np.lib.stride_tricks.as_strided(self.raw_data, size, stride)
    return wrap(ans, ans.shape, torch.float32)


class PrivateUse1BackendTest(TestCase):
    @classmethod
    def setUpClass(cls):
        pass

    def test_accessing_is_pinned(self):
        a_cpu = torch.randn((2, 2))
        # Assert this doesn't throw:
        _ = a_cpu.is_pinned()

    def test_backend_simple(self):
        a_cpu = torch.randn((2, 2))
        b_cpu = torch.randn((2, 2))
        # Assert these don't throw:
        a = a_cpu.to("privateuseone")
        b = b_cpu.to("privateuseone")
        a.requires_grad = True
        b.requires_grad = True
        c = (a + b).sum()
        c.backward()
        self.assertTrue(np.allclose(a.grad.raw_data, np.ones((2, 2))))


if __name__ == "__main__":
    run_tests()


@@ -26,6 +26,7 @@ import numpy
import torch
from torch import inf, SymInt, Tensor
from torch._C import (
    _acc,
    _aoti,
    _cpu,
    _dynamo,

torch/_C/_acc.pyi (new file)

@@ -0,0 +1,15 @@
from torch import Tensor
from torch.types import _dtype, _int, Device

# Defined in torch/csrc/acc/Module.cpp

class PrivateUse1Hooks:
    def has_primary_context(self, device_index: _int) -> bool: ...
    def is_built(self) -> bool: ...
    def is_available(self) -> bool: ...

class DeviceGuard:
    def type_(self) -> Device: ...

def register_python_privateuseone_device_guard(guard: DeviceGuard) -> bool: ...
def register_python_privateuseone_hook(hook: PrivateUse1Hooks) -> bool: ...
def create_empty_tensor(shape: tuple[_int, ...], dtype: _dtype) -> Tensor: ...

torch/_subclasses/meta_utils.py

@@ -887,13 +887,15 @@ class MetaConverter(Generic[_TensorT]):
                    f"__meta_utils_unknown_tensor{len(self.tensor_memo)}"
                )
-            # This indicates you set no_dispatch() before calling into this
-            # function. This is an error: we may be creating fake tensors and
-            # will perform operations on them which need fake tensor mode to
-            # be active. You will segfault if you are in a no_dispatch() block.
+            msg = (
+                " This indicates you set no_dispatch() before calling into this"
+                " function. This is an error: we may be creating fake tensors and"
+                " will perform operations on them which need fake tensor mode to"
+                " be active. You will segfault if you are in a no_dispatch() block."
+            )
            assert not torch._C._dispatch_tls_local_exclude_set().has(
                torch._C.DispatchKey.Python
-            )
+            ), msg
            self.arg_cnt += 1
            # When we make as_strided calls, we end up generating a guard

torch/csrc/Module.cpp

@@ -56,6 +56,7 @@
#include <torch/csrc/Stream.h>
#include <torch/csrc/THP.h>
#include <torch/csrc/TypeInfo.h>
#include <torch/csrc/acc/Module.h>
#include <torch/csrc/api/include/torch/python/init.h>
#include <torch/csrc/autograd/generated/python_return_types.h>
#include <torch/csrc/autograd/python_cpp_function.h>
@@ -2097,6 +2098,7 @@ PyObject* initModule() {
  torch::cpu::initModule(module);
  torch::accelerator::initModule(module);
  torch::instruction_counter::initModule(module);
  torch::acc::initModule(module);
  torch::initVerboseBindings(module);
  ASSERT_TRUE(THPStorage_init(module));
  torch::functionalization::initModule(module);

torch/csrc/acc/Module.cpp (new file, 196 lines)

@@ -0,0 +1,196 @@
#include <torch/csrc/acc/Module.h>
#include <ATen/ATen.h>
#include <torch/extension.h>
#include <ATen/detail/PrivateUse1HooksInterface.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <torch/csrc/utils/pybind.h>
namespace py = pybind11;
namespace torch::acc {
// Python hook interface: a trampoline that forwards PrivateUse1HooksInterface
// virtual calls to Python overrides
struct PythonHooks final : public at::PrivateUse1HooksInterface {
  using at::PrivateUse1HooksInterface::PrivateUse1HooksInterface;
  bool hasPrimaryContext(c10::DeviceIndex device_index) const override {
    PYBIND11_OVERRIDE_PURE_NAME(
        bool,
        at::PrivateUse1HooksInterface,
        "has_primary_context",
        hasPrimaryContext,
        device_index);
  }
  bool isBuilt() const override {
    PYBIND11_OVERRIDE_PURE_NAME(
        bool, at::PrivateUse1HooksInterface, "is_built", isBuilt, );
  }
  bool isAvailable() const override {
    PYBIND11_OVERRIDE_PURE_NAME(
        bool, at::PrivateUse1HooksInterface, "is_available", isAvailable, );
  }
  // TODO(qihqi): these are not supported from Python yet
  const at::Generator& getDefaultGenerator(
      c10::DeviceIndex device_index) const override {
    return at::PrivateUse1HooksInterface::getDefaultGenerator(device_index);
  }
  at::Generator getNewGenerator(
      c10::DeviceIndex device_index = -1) const override {
    return at::PrivateUse1HooksInterface::getNewGenerator(device_index);
  }
  at::Device getDeviceFromPtr(void* data) const override {
    return at::PrivateUse1HooksInterface::getDeviceFromPtr(data);
  }
  bool isPinnedPtr(const void* data) const override {
    return at::PrivateUse1HooksInterface::isPinnedPtr(data);
  }
  at::Allocator* getPinnedMemoryAllocator() const override {
    return at::PrivateUse1HooksInterface::getPinnedMemoryAllocator();
  }
};

struct PythonDeviceGuard final : public c10::impl::DeviceGuardImplInterface {
  using c10::impl::DeviceGuardImplInterface::DeviceGuardImplInterface;
  c10::DeviceType type() const override {
    PYBIND11_OVERRIDE_PURE_NAME(
        c10::DeviceType, c10::impl::DeviceGuardImplInterface, "type_", type, );
  }
  // TODO(qihqi): figure out whether these are even useful
  // to Python or not
  c10::Device exchangeDevice(c10::Device device) const override {
    return getDevice();
  }
  c10::Device getDevice() const override {
    return c10::Device(type(), 0);
  }
  void setDevice(c10::Device device) const override {}
  void uncheckedSetDevice(c10::Device device) const noexcept override {}
  c10::Stream getStream(c10::Device) const noexcept override {
    // no-op
    return c10::Stream(c10::Stream::DEFAULT, getDevice());
  }
  c10::Stream getNewStream(c10::Device, int priority = 0) const override {
    // no-op
    (void)priority;
    return c10::Stream(c10::Stream::DEFAULT, getDevice());
  }
  c10::Stream exchangeStream(c10::Stream) const noexcept override {
    // no-op
    return c10::Stream(c10::Stream::DEFAULT, getDevice());
  }
  c10::DeviceIndex deviceCount() const noexcept override {
    return 1;
  }
  // TODO(qihqi): support Event-related functions
  void record(
      void** /*event*/,
      const c10::Stream& /*stream*/,
      const c10::DeviceIndex /*device_index*/,
      const c10::EventFlag /*flag*/) const override {}
  void block(void* /*event*/, const c10::Stream& /*stream*/) const override {}
  bool queryEvent(void* /*event*/) const override {
    return true;
  }
  void destroyEvent(void* /*event*/, const c10::DeviceIndex /*device_index*/)
      const noexcept override {}
  // Stream-related functions
  bool queryStream(const c10::Stream& /*stream*/) const override {
    return true;
  }
  void synchronizeStream(const c10::Stream& /*stream*/) const override {}
};

namespace {
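// NOTE: both helpers below keep the Python object alive for the lifetime of
// the process (inc_ref with no matching dec_ref), since the C++ registries
// store raw pointers.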
bool registerPythonPrivateUse1Hook(const py::object& hook) {
  if (at::isPrivateUse1HooksRegistered()) {
    return false;
  }
  hook.inc_ref();
  at::RegisterPrivateUse1HooksInterface(
      hook.cast<at::PrivateUse1HooksInterface*>());
  return true;
}

bool registerPythonPrivateUse1DeviceGuard(const py::object& guard) {
  if (c10::impl::hasDeviceGuardImpl(c10::DeviceType::PrivateUse1)) {
    return false;
  }
  guard.inc_ref();
  c10::impl::registerDeviceGuard(
      c10::DeviceType::PrivateUse1,
      guard.cast<c10::impl::DeviceGuardImplInterface*>());
  return true;
}

at::Tensor createEmptyTensor(
    const std::vector<int64_t>& shape,
    c10::ScalarType dtype) {
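  // The wrapper tensor owns no real memory: build a zero-byte Storage backed
  // by the Meta allocator and a null DataPtr tagged with the PrivateUse1
  // device.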
  c10::Storage storage{
      c10::Storage::use_byte_size_t{},
      0,
      c10::GetAllocator(c10::kMeta),
      true,
  };
  c10::Device device(c10::DeviceType::PrivateUse1, 0);
  storage.set_data_ptr_noswap(at::DataPtr{nullptr, device});
  c10::DispatchKeySet key_set({c10::DispatchKey::PrivateUse1});
  at::Tensor tensor = at::detail::make_tensor<at::TensorImpl>(
      std::move(storage), key_set, c10::scalarTypeToTypeMeta(dtype));
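  // Fill in contiguous (row-major) strides: the innermost dimension has
  // stride 1, and each outer stride is the product of the inner sizes.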
  std::vector<int64_t> strides(shape.size());
  int64_t size = 1;
  for (auto i = strides.size(); i > 0; --i) {
    strides[i - 1] = size;
    size *= shape[i - 1];
  }
  tensor.unsafeGetTensorImpl()->set_sizes_and_strides(shape, strides, 0);
  return tensor;
}
} // namespace

void initModule(PyObject* module) {
  auto py_module = py::reinterpret_borrow<py::module>(module);
  auto _acc =
      py_module.def_submodule("_acc", "classes related to custom accelerators");
  py::class_<at::PrivateUse1HooksInterface, PythonHooks>(
      _acc.ptr(), "PrivateUse1Hooks")
      .def(py::init<>())
      .def(
          "has_primary_context",
          &at::PrivateUse1HooksInterface::hasPrimaryContext)
      .def("is_built", &at::PrivateUse1HooksInterface::isBuilt)
      .def("is_available", &at::PrivateUse1HooksInterface::isAvailable);
  py::class_<c10::impl::DeviceGuardImplInterface, PythonDeviceGuard>(
      _acc.ptr(), "DeviceGuard")
      .def(py::init<>())
      .def("type_", &c10::impl::DeviceGuardImplInterface::type);
  _acc.def(
      "register_python_privateuseone_hook", &registerPythonPrivateUse1Hook);
  _acc.def(
      "register_python_privateuseone_device_guard",
      &registerPythonPrivateUse1DeviceGuard);
  _acc.def("create_empty_tensor", &createEmptyTensor);
}
} // namespace torch::acc

torch/csrc/acc/Module.h (new file, 8 lines)

@@ -0,0 +1,8 @@
#pragma once
#include <torch/csrc/python_headers.h>
namespace torch::acc {
// PyMethodDef* python_functions();
void initModule(PyObject* module);
} // namespace torch::acc

torch/utils/backend_registration.py

@@ -6,7 +6,10 @@ from torch._C import _get_privateuse1_backend_name, _rename_privateuse1_backend
from torch.overrides import handle_torch_function, has_torch_function_unary
__all__ = ["rename_privateuse1_backend", "generate_methods_for_privateuse1_backend"]
__all__ = [
"rename_privateuse1_backend",
"generate_methods_for_privateuse1_backend",
]
# TODO: Should use `torch._C._get_privateuse1_backend_name()` to get
# renamed-backend name for `privateuse1`, but the func will cause an
@@ -438,3 +441,78 @@
        message += f"BackendModule needs to have the following API's:\n `{func_name}(*args, **kwargs)`. \n"
        raise RuntimeError(message)
    return function


class _DummyBackendModule:
    def is_initialized(self):
        return True

    def is_available(self):
        return True

    def current_device(self):
        return 0

    def _is_in_bad_fork(self):
        return False

    def manual_seed_all(self, seed: int):
        pass

    def device_count(self):
        return 1


class _DummyPrivateUse1Hook(torch._C._acc.PrivateUse1Hooks):
    def is_available(self):
        return True

    def has_primary_context(self, dev_id):
        return True

    def is_built(self):
        return True


class _DummyDeviceGuard(torch._C._acc.DeviceGuard):
    def type_(self):
        return torch._C._autograd.DeviceType.PrivateUse1


def _setup_privateuseone_for_python_backend(
    rename=None, backend_module=None, hook=None, device_guard=None
):
    """Prepare the PrivateUse1 dispatch key to be used as a Python backend.

    WARNING: this API is experimental and might change without notice.

    Formally, this registers the things that PyTorch expects a backend
    registered in C++ to have: device guards, hooks, backend modules, and so
    on. After this call, one can use `torch.library` to write ops for this
    dispatch key and expect it to behave like a backend registered in C++.

    See the unit test at test/test_privateuseone_python_backend.py for more details.

    Args:
        rename: str | None, if passed in, we will rename the privateuseone
            backend to the given name.
        backend_module: object | None, if None, we will use _DummyBackendModule
        hook: object | None, if None, we will use _DummyPrivateUse1Hook
        device_guard: object | None, if None, we will use _DummyDeviceGuard
    """
    # NOTE: the order in which these functions are called is important.
    if rename is not None:
        torch.utils.rename_privateuse1_backend(rename)
    else:
        rename = "privateuseone"
    torch.utils.generate_methods_for_privateuse1_backend()
    if backend_module is None:
        backend_module = _DummyBackendModule()
    if hook is None:
        hook = _DummyPrivateUse1Hook()
    if device_guard is None:
        device_guard = _DummyDeviceGuard()
    torch._register_device_module(rename, backend_module)
    torch._C._acc.register_python_privateuseone_hook(hook)
    torch._C._acc.register_python_privateuseone_device_guard(device_guard)
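
A minimal usage sketch of overriding the defaults (`MyHook` is an illustrative name; the bound classes and the setup function are the ones defined above):

```python
import torch
from torch.utils.backend_registration import _setup_privateuseone_for_python_backend

class MyHook(torch._C._acc.PrivateUse1Hooks):
    def is_built(self):
        return True

    def is_available(self):
        return True

    def has_primary_context(self, device_index):
        return device_index == 0

# Rename the backend to "npy" and install a custom hook; the dummy
# backend module and device guard fill in the rest.
_setup_privateuseone_for_python_backend(rename="npy", hook=MyHook())
```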