From b5c4f46bb9ede8dc6adf11975c93b9f285d9ed67 Mon Sep 17 00:00:00 2001
From: Han Qi
Date: Wed, 1 Oct 2025 21:32:55 +0000
Subject: [PATCH] Add functions to set up PrivateUse1 as a python backend device. (#157859)

Fixes #156052 and #156444.

This PR sets up the privateuseone key in Python so it can be used as a python backend for pytorch: after calling `setup_privateuseone_for_python_backend('npy')`, one can use a subclass with that device to hold arbitrary python data as "device data" and use `torch.library` to register ops that take that Tensor.

Changes done in this PR:

1. Register a vanilla Device Guard: I extended NoOpDeviceGuard to allow a device index of 0 and to not raise errors when event-related functions are accessed. Without those changes, I would get errors when calling backward. (The CPU backend uses NoOpDeviceGuard just fine, although there seems to be special treatment of CPU in the autograd engine.)
2. Tensor subclasses may now omit `__torch_dispatch__` if the device is not CUDA or CPU. The comment on the check suggests it was there to avoid segfaults when calling into ops that expect a storage; here we have a different device, so we will not call into those ops.
3. A python function that invokes the other incantations needed to set up the PrivateUse1 backend.

This took inspiration from https://github.com/bdhirsh/pytorch_open_registration_example and https://github.com/tinygrad/tinygrad/blob/master/extra/torch_backend/wrapped_tensor.cpp; great thanks to @bdhirsh and @geohot.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157859
Approved by: https://github.com/albanD
---
 build_variables.bzl | 1 +
 c10/core/impl/DeviceGuardImplInterface.cpp | 10 +-
 c10/core/impl/DeviceGuardImplInterface.h | 3 +
 ...PrivateUse1BackendTest.test_backend_simple | 0
 test/run_test.py | 1 +
 test/test_privateuseone_python_backend.py | 147 +++++++++++++
 torch/_C/__init__.pyi.in | 1 +
 torch/_C/_acc/__init__.pyi | 15 ++
 torch/_subclasses/meta_utils.py | 12 +-
 torch/csrc/Module.cpp | 2 +
 torch/csrc/acc/Module.cpp | 196 ++++++++++++++++++
 torch/csrc/acc/Module.h | 8 +
 torch/utils/backend_registration.py | 80 ++++++-
 13 files changed, 468 insertions(+), 8 deletions(-)
 create mode 100644 test/dynamo_skips/PrivateUse1BackendTest.test_backend_simple
 create mode 100644 test/test_privateuseone_python_backend.py
 create mode 100644 torch/_C/_acc/__init__.pyi
 create mode 100644 torch/csrc/acc/Module.cpp
 create mode 100644 torch/csrc/acc/Module.h

diff --git a/build_variables.bzl b/build_variables.bzl
index ecd1e8b79f65..e4dd849be4fe 100644
--- a/build_variables.bzl
+++ b/build_variables.bzl
@@ -897,6 +897,7 @@ libtorch_python_core_sources = [ "torch/csrc/Stream.cpp", "torch/csrc/Event.cpp", "torch/csrc/TypeInfo.cpp", + "torch/csrc/acc/Module.cpp", "torch/csrc/api/src/python/init.cpp", "torch/csrc/autograd/functions/init.cpp", "torch/csrc/autograd/init.cpp",
diff --git a/c10/core/impl/DeviceGuardImplInterface.cpp b/c10/core/impl/DeviceGuardImplInterface.cpp
index 52f6f5e8c13a..428ea63c0415 100644
--- a/c10/core/impl/DeviceGuardImplInterface.cpp
+++ b/c10/core/impl/DeviceGuardImplInterface.cpp
@@ -9,16 +9,22 @@ std::array< static_cast(DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES)> device_guard_impl_registry; -DeviceGuardImplRegistrar::DeviceGuardImplRegistrar( +void registerDeviceGuard( DeviceType type, const DeviceGuardImplInterface* impl) { device_guard_impl_registry[static_cast(type)].store(impl); } +DeviceGuardImplRegistrar::DeviceGuardImplRegistrar( + DeviceType type, + const DeviceGuardImplInterface* impl) { +
registerDeviceGuard(type, impl); +} + namespace { thread_local std::unique_ptr tls_fake_device_guard = nullptr; -} +} // namespace void ensureCUDADeviceGuardSet() { constexpr auto cuda_idx = static_cast(DeviceType::CUDA); diff --git a/c10/core/impl/DeviceGuardImplInterface.h b/c10/core/impl/DeviceGuardImplInterface.h index fc8c367f75e8..e1efa53035b1 100644 --- a/c10/core/impl/DeviceGuardImplInterface.h +++ b/c10/core/impl/DeviceGuardImplInterface.h @@ -368,6 +368,9 @@ inline const DeviceGuardImplInterface* getDeviceGuardImpl(DeviceType type) { return p; } +void C10_API +registerDeviceGuard(DeviceType type, const DeviceGuardImplInterface* impl); + inline bool hasDeviceGuardImpl(DeviceType type) { return device_guard_impl_registry[static_cast(type)].load(); } diff --git a/test/dynamo_skips/PrivateUse1BackendTest.test_backend_simple b/test/dynamo_skips/PrivateUse1BackendTest.test_backend_simple new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/run_test.py b/test/run_test.py index d8bd65c4e95c..745e62331f30 100755 --- a/test/run_test.py +++ b/test/run_test.py @@ -284,6 +284,7 @@ RUN_PARALLEL_BLOCKLIST = [ # temporarily sets a global config "test_autograd_fallback", "inductor/test_compiler_bisector", + "test_privateuseone_python_backend", ] + FSDP_TEST # Test files that should always be run serially with other test files, diff --git a/test/test_privateuseone_python_backend.py b/test/test_privateuseone_python_backend.py new file mode 100644 index 000000000000..b767933f0c54 --- /dev/null +++ b/test/test_privateuseone_python_backend.py @@ -0,0 +1,147 @@ +# Owner(s): ["module: PrivateUse1"] +import numpy as np + +import torch +import torch._C +from torch.testing._internal.common_utils import run_tests, TestCase +from torch.utils.backend_registration import _setup_privateuseone_for_python_backend + + +_setup_privateuseone_for_python_backend("npy") + +aten = torch.ops.aten + + +# NOTE: From https://github.com/albanD/subclass_zoo/blob/main/new_device.py +# but using torch.library instead of `__torch_dispatch__` +class MyDeviceTensor(torch.Tensor): + @staticmethod + def __new__(cls, size, dtype, raw_data=None, requires_grad=False): + # Use a meta Tensor here to be used as the wrapper + res = torch._C._acc.create_empty_tensor(size, dtype) + res.__class__ = MyDeviceTensor + return res + + def __init__(self, size, dtype, raw_data=None, requires_grad=False): + # Store any provided user raw_data + self.raw_data = raw_data + + def __repr__(self): + return "MyDeviceTensor" + str(self.raw_data) + + __str__ = __repr__ + + +def wrap(arr, shape, dtype): + # hard code float32 for tests + return MyDeviceTensor(shape, dtype, arr) + + +def unwrap(arr): + return arr.raw_data + + +# Add some ops +@torch.library.impl("aten::add.Tensor", "privateuseone") +def add(t1, t2): + out = unwrap(t1) + unwrap(t2) + return wrap(out, out.shape, torch.float32) + + +@torch.library.impl("aten::mul.Tensor", "privateuseone") +def mul(t1, t2): + # If unsure what should be the result's properties, you can + # use the super_fn (can be useful for type promotion) + out = unwrap(t1) * unwrap(t2) + return wrap(out, out.shape, torch.float32) + + +@torch.library.impl("aten::detach", "privateuseone") +def detach(self): + out = unwrap(self) + return wrap(out, out.shape, torch.float32) + + +@torch.library.impl("aten::empty_strided", "privateuseone") +def empty_strided( + size, stride, *, dtype=None, layout=None, device=None, pin_memory=None +): + out = np.empty(size) + return wrap(out, out.shape, torch.float32) + + 
+@torch.library.impl("aten::_copy_from", "privateuseone") +def _copy_from(a, b): + if a.device.type == "npy": + npy_data = unwrap(a) + else: + npy_data = a.numpy() + b.raw_data = npy_data + + +@torch.library.impl("aten::view", "privateuseone") +def _view(a, b): + ans = unwrap(a) + return wrap(ans, a.shape, a.dtype) + + +@torch.library.impl("aten::empty.memory_format", "privateuseone") +def empty_memory_format( + size, *, dtype=None, layout=None, device=None, pin_memory=None, memory_format=None +): + ans = np.empty(size) + return wrap(ans, ans.shape, torch.float32) + + +@torch.library.impl("aten::sum", "privateuseone") +def sum_int_list(*args, **kwargs): + ans = unwrap(args[0]).sum() + return wrap(ans, ans.shape, torch.float32) + + +@torch.library.impl("aten::ones_like", "privateuseone") +def ones_like( + self, *, dtype=None, layout=None, device=None, pin_memory=None, memory_format=None +): + ans = np.ones_like(unwrap(self)) + return wrap(ans, ans.shape, torch.float32) + + +@torch.library.impl("aten::expand", "privateuseone") +def expand(self, size, *, implicit=False): + ans = np.broadcast_to(self.raw_data, size) + return wrap(ans, ans.shape, torch.float32) + + +@torch.library.impl("aten::as_strided", "privateuseone") +def as_strided(self, size, stride, storage_offset=None): + ans = np.lib.stride_tricks.as_strided(self.raw_data, size, stride) + return wrap(ans, ans.shape, torch.float32) + + +class PrivateUse1BackendTest(TestCase): + @classmethod + def setupClass(cls): + pass + + def test_accessing_is_pinned(self): + a_cpu = torch.randn((2, 2)) + # Assert this don't throw: + _ = a_cpu.is_pinned() + + def test_backend_simple(self): + a_cpu = torch.randn((2, 2)) + b_cpu = torch.randn((2, 2)) + # Assert this don't throw: + a = a_cpu.to("privateuseone") + b = b_cpu.to("privateuseone") + + a.requires_grad = True + b.requires_grad = True + c = (a + b).sum() + c.backward() + self.assertTrue(np.allclose(a.grad.raw_data, np.ones((2, 2)))) + + +if __name__ == "__main__": + run_tests() diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index 147dc9a86524..7e29cf9fa218 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -26,6 +26,7 @@ import numpy import torch from torch import inf, SymInt, Tensor from torch._C import ( + _acc, _aoti, _cpu, _dynamo, diff --git a/torch/_C/_acc/__init__.pyi b/torch/_C/_acc/__init__.pyi new file mode 100644 index 000000000000..aa17e5cb2190 --- /dev/null +++ b/torch/_C/_acc/__init__.pyi @@ -0,0 +1,15 @@ +from torch import Tensor +from torch.types import _dtype, _int, Device + +# Defined in torch/csrc/acc/Module.cpp +class PrivateUse1Hooks: + def has_primary_context(self, device_index: _int) -> bool: ... + def is_built(self) -> bool: ... + def is_avaible(self) -> bool: ... + +class DeviceGuard: + def type_(self) -> Device: ... + +def register_python_privateuseone_device_guard(guard: DeviceGuard) -> bool: ... +def register_python_privateuseone_hook(hook: PrivateUse1Hooks) -> bool: ... +def create_empty_tensor(shape: tuple[_int, ...], dtype: _dtype) -> Tensor: ... diff --git a/torch/_subclasses/meta_utils.py b/torch/_subclasses/meta_utils.py index b73ee9abfc33..c447ffb5d736 100644 --- a/torch/_subclasses/meta_utils.py +++ b/torch/_subclasses/meta_utils.py @@ -887,13 +887,15 @@ class MetaConverter(Generic[_TensorT]): f"__meta_utils_unknown_tensor{len(self.tensor_memo)}" ) - # This indicates you set no_dispatch() before calling into this - # function. 
This is an error: we may be creating fake tensors and - # will perform operations on them which need fake tensor mode to - # be active. You will segfault if you are in a no_dispatch() block. + msg = ( + " This indicates you set no_dispatch() before calling into this" + " function. This is an error: we may be creating fake tensors and" + " will perform operations on them which need fake tensor mode to" + " be active. You will segfault if you are in a no_dispatch() block." + ) assert not torch._C._dispatch_tls_local_exclude_set().has( torch._C.DispatchKey.Python - ) + ), msg self.arg_cnt += 1 # When we make as_strided calls, we end up generating a guard diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp index 316c35ecff75..b7ec94f4357b 100644 --- a/torch/csrc/Module.cpp +++ b/torch/csrc/Module.cpp @@ -56,6 +56,7 @@ #include #include #include +#include #include #include #include @@ -2097,6 +2098,7 @@ PyObject* initModule() { torch::cpu::initModule(module); torch::accelerator::initModule(module); torch::instruction_counter::initModule(module); + torch::acc::initModule(module); torch::initVerboseBindings(module); ASSERT_TRUE(THPStorage_init(module)); torch::functionalization::initModule(module); diff --git a/torch/csrc/acc/Module.cpp b/torch/csrc/acc/Module.cpp new file mode 100644 index 000000000000..6360d0430bf8 --- /dev/null +++ b/torch/csrc/acc/Module.cpp @@ -0,0 +1,196 @@ +#include + +#include +#include + +#include +#include +#include +#include + +namespace py = pybind11; + +namespace torch::acc { + +// python hook interface +struct PythonHooks final : public at::PrivateUse1HooksInterface { + using at::PrivateUse1HooksInterface::PrivateUse1HooksInterface; + bool hasPrimaryContext(c10::DeviceIndex device_index) const override { + PYBIND11_OVERRIDE_PURE_NAME( + bool, + at::PrivateUse1HooksInterface, + "has_primary_context", + hasPrimaryContext, + device_index); + } + + bool isBuilt() const override { + PYBIND11_OVERRIDE_PURE_NAME( + bool, at::PrivateUse1HooksInterface, "is_built", isBuilt, ); + } + + bool isAvailable() const override { + PYBIND11_OVERRIDE_PURE_NAME( + bool, at::PrivateUse1HooksInterface, "is_available", isBuilt, ); + } + + // TODO(qihqi): these is not supported from python yet + const at::Generator& getDefaultGenerator( + c10::DeviceIndex device_index) const override { + return at::PrivateUse1HooksInterface::getDefaultGenerator(device_index); + } + + at::Generator getNewGenerator( + c10::DeviceIndex device_index = -1) const override { + return at::PrivateUse1HooksInterface::getNewGenerator(device_index); + } + + at::Device getDeviceFromPtr(void* data) const override { + return at::PrivateUse1HooksInterface::getDeviceFromPtr(data); + } + + bool isPinnedPtr(const void* data) const override { + return at::PrivateUse1HooksInterface::isPinnedPtr(data); + } + + at::Allocator* getPinnedMemoryAllocator() const override { + return at::PrivateUse1HooksInterface::getPinnedMemoryAllocator(); + } +}; + +struct PythonDeviceGuard final : public c10::impl::DeviceGuardImplInterface { + using c10::impl::DeviceGuardImplInterface::DeviceGuardImplInterface; + + c10::DeviceType type() const override { + PYBIND11_OVERRIDE_PURE_NAME( + c10::DeviceType, c10::impl::DeviceGuardImplInterface, "type_", type, ); + } + + // TODO(qihqi): figure out if those are even useful + // to python or not + c10::Device exchangeDevice(c10::Device device) const override { + return getDevice(); + } + c10::Device getDevice() const override { + return c10::Device(type(), 0); + } + void setDevice(c10::Device 
device) const override {} + void uncheckedSetDevice(c10::Device device) const noexcept override {} + c10::Stream getStream(c10::Device) const noexcept override { + // no-op + return c10::Stream(c10::Stream::DEFAULT, getDevice()); + } + + c10::Stream getNewStream(c10::Device, int priority = 0) const override { + // no-op + (void)priority; + return c10::Stream(c10::Stream::DEFAULT, getDevice()); + } + + c10::Stream exchangeStream(c10::Stream) const noexcept override { + // no-op + return c10::Stream(c10::Stream::DEFAULT, getDevice()); + } + c10::DeviceIndex deviceCount() const noexcept override { + return 1; + } + + // TODO(qihqi): support Event-related functions + void record( + void** /*event*/, + const c10::Stream& /*stream*/, + const c10::DeviceIndex /*device_index*/, + const c10::EventFlag /*flag*/) const override {} + void block(void* /*event*/, const c10::Stream& /*stream*/) const override {} + bool queryEvent(void* /*event*/) const override { + return true; + } + void destroyEvent(void* /*event*/, const c10::DeviceIndex /*device_index*/) + const noexcept override {} + + // Stream-related functions + bool queryStream(const c10::Stream& /*stream*/) const override { + return true; + } + void synchronizeStream(const c10::Stream& /*stream*/) const override {} +}; + +namespace { + +bool registerPythonPrivateUse1Hook(const py::object& hook) { + if (at::isPrivateUse1HooksRegistered()) { + return false; + } + hook.inc_ref(); + at::RegisterPrivateUse1HooksInterface( + hook.cast()); + return true; +} + +bool registerPythonPrivateUse1DeviceGuard(const py::object& guard) { + if (c10::impl::hasDeviceGuardImpl(c10::DeviceType::PrivateUse1)) { + return false; + } + guard.inc_ref(); + c10::impl::registerDeviceGuard( + c10::DeviceType::PrivateUse1, + guard.cast()); + return true; +} + +at::Tensor createEmptyTensor( + const std::vector& shape, + c10::ScalarType dtype) { + c10::Storage storage{ + c10::Storage::use_byte_size_t{}, + 0, + c10::GetAllocator(c10::kMeta), + true, + }; + + c10::Device device(c10::DeviceType::PrivateUse1, 0); + storage.set_data_ptr_noswap(at::DataPtr{nullptr, device}); + c10::DispatchKeySet key_set({c10::DispatchKey::PrivateUse1}); + at::Tensor tensor = at::detail::make_tensor( + std::move(storage), key_set, c10::scalarTypeToTypeMeta(dtype)); + + std::vector strides(shape.size()); + int64_t size = 1; + for (auto i = strides.size(); i > 0; --i) { + strides[i - 1] = size; + size *= shape[i - 1]; + } + + tensor.unsafeGetTensorImpl()->set_sizes_and_strides(shape, strides, 0); + return tensor; +} +} // namespace + +void initModule(PyObject* module) { + auto py_module = py::reinterpret_borrow(module); + auto _acc = + py_module.def_submodule("_acc", "classes related to custom accelerators"); + + py::class_( + _acc.ptr(), "PrivateUse1Hooks") + .def(py::init<>()) + .def( + "has_primary_context", + &at::PrivateUse1HooksInterface::hasPrimaryContext) + .def("is_built", &at::PrivateUse1HooksInterface::isBuilt) + .def("is_available", &at::PrivateUse1HooksInterface::isAvailable); + + py::class_( + _acc.ptr(), "DeviceGuard") + .def(py::init<>()) + .def("type_", &c10::impl::DeviceGuardImplInterface::type); + + _acc.def( + "register_python_privateuseone_hook", ®isterPythonPrivateUse1Hook); + _acc.def( + "register_python_privateuseone_device_guard", + ®isterPythonPrivateUse1DeviceGuard); + _acc.def("create_empty_tensor", &createEmptyTensor); +} + +} // namespace torch::acc diff --git a/torch/csrc/acc/Module.h b/torch/csrc/acc/Module.h new file mode 100644 index 000000000000..7fe776e2f783 --- 
/dev/null +++ b/torch/csrc/acc/Module.h @@ -0,0 +1,8 @@ +#pragma once +#include + +namespace torch::acc { +// PyMethodDef* python_functions(); +void initModule(PyObject* module); + +} // namespace torch::acc diff --git a/torch/utils/backend_registration.py b/torch/utils/backend_registration.py index a4fcd949ee90..b54bd25f1016 100644 --- a/torch/utils/backend_registration.py +++ b/torch/utils/backend_registration.py @@ -6,7 +6,10 @@ from torch._C import _get_privateuse1_backend_name, _rename_privateuse1_backend from torch.overrides import handle_torch_function, has_torch_function_unary -__all__ = ["rename_privateuse1_backend", "generate_methods_for_privateuse1_backend"] +__all__ = [ + "rename_privateuse1_backend", + "generate_methods_for_privateuse1_backend", +] # TODO: Should use `torch._C._get_privateuse1_backend_name()` to get # renamed-backend name for `privateuse1`, but the func will cause an @@ -438,3 +441,78 @@ def _get_custom_mod_func(func_name: str): message += f"BackendModule needs to have the following API's:\n `{func_name}(*args, **kwargs)`. \n" raise RuntimeError(message) return function + + +class _DummyBackendModule: + def is_initialized(self): + return True + + def is_available(self): + return True + + def current_device(self): + return 0 + + def _is_in_bad_fork(self): + return False + + def manual_seed_all(self, seed: int): + pass + + def device_count(self): + return 1 + + +class _DummyPrivateUse1Hook(torch._C._acc.PrivateUse1Hooks): + def is_available(self): + return True + + def has_primary_context(self, dev_id): + return True + + def is_built(self): + return True + + +class _DummyDeviceGuard(torch._C._acc.DeviceGuard): + def type_(self): + return torch._C._autograd.DeviceType.PrivateUse1 + + +def _setup_privateuseone_for_python_backend( + rename=None, backend_module=None, hook=None, device_guard=None +): + """This function will prepare the PrivateUse1 dispatch key to be used as a python backend. + + WARNING: this API is experimental and might change without notice. + + Formally, this registers things that Pytorch expects a registered backend + in C++ to have: including device guards, hooks, and backend modules and what not. + + after this call, one can use `torch.library` to write Ops for this dispatch key + and expect it to behave like a backend registered in C++. + + See the unit test at test/test_privateuseone_python_backend.py for more details. + + Args: + rename: str | None, if passed in, we will rename privateuseone backend to + the name given. + backend_module: object | None, if passed in None, we will use DummyBackendModule + hook: object | None, if passed in None, we will use DummyPrivateUse1Hook + device_guard: object | None, if passed in None, we will use DummyDeviceGuard + """ + # NOTE: the ordering of which these functions are called is important. + if rename is not None: + torch.utils.rename_privateuse1_backend(rename) + else: + rename = "privateuseone" + torch.utils.generate_methods_for_privateuse1_backend() + if backend_module is None: + backend_module = _DummyBackendModule() + if hook is None: + hook = _DummyPrivateUse1Hook() + if device_guard is None: + device_guard = _DummyDeviceGuard() + torch._register_device_module(rename, backend_module) + torch._C._acc.register_python_privateuseone_hook(hook) + torch._C._acc.register_python_privateuseone_device_guard(device_guard)
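
For reference, the end-to-end flow described above looks roughly like the sketch below. It is condensed from test/test_privateuseone_python_backend.py; the `npy` rename, the `NumpyTensor` class name, and the numpy-backed `raw_data` field are illustrative choices taken from that test rather than a fixed API, and the helper itself is experimental.

# Minimal sketch (assumes this patch is applied). NumpyTensor and the
# numpy-backed raw_data field mirror the choices made in the new test.
import numpy as np
import torch
from torch.utils.backend_registration import _setup_privateuseone_for_python_backend

# Registers the dummy device guard, hooks, and backend module, renames
# privateuseone to "npy", and generates the usual Tensor methods.
_setup_privateuseone_for_python_backend("npy")


class NumpyTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, size, dtype, raw_data=None, requires_grad=False):
        # create_empty_tensor builds a storage-less wrapper tensor on the
        # PrivateUse1 device; the subclass carries the real payload.
        res = torch._C._acc.create_empty_tensor(size, dtype)
        res.__class__ = NumpyTensor
        return res

    def __init__(self, size, dtype, raw_data=None, requires_grad=False):
        self.raw_data = raw_data  # arbitrary python data held as "device data"


# Backend ops are plain python functions registered against the
# "privateuseone" dispatch key via torch.library.
@torch.library.impl("aten::add.Tensor", "privateuseone")
def add(t1, t2):
    out = t1.raw_data + t2.raw_data
    return NumpyTensor(out.shape, torch.float32, out)


a = NumpyTensor((2, 2), torch.float32, np.ones((2, 2)))
b = NumpyTensor((2, 2), torch.float32, np.ones((2, 2)))
print((a + b).raw_data)  # array of 2.0s, computed by the python kernel above

The full test additionally registers kernels such as _copy_from, empty_strided, view, sum, and ones_like so that `.to("privateuseone")` and autograd work end to end.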