Add functions to setup PrivateUse1 as a python backend device. (#157859)
Fixes #156052 and #156444. This PR sets up the PrivateUse1 key in Python so it can be used as a Python backend for PyTorch. Meaning that, after calling `setup_privateuseone_for_python_backend('npy')`, one can use a tensor subclass on that device to hold arbitrary Python data as "device data", and use `torch.library` to register ops that take such Tensors.

Changes done in this PR:
1. Register a vanilla device guard: I extended NoOpDeviceGuard to allow a device index of 0 and to not raise errors when event-related functions are accessed. Without those changes, calling backward raises errors. (The CPU backend uses NoOpDeviceGuard just fine, although there seems to be special treatment of CPU in the autograd engine.)
2. The tensor-subclass check now allows not having `__torch_dispatch__` if the device is not CUDA or CPU. The comment on the check suggests it was there to avoid a segfault when calling into ops that expect a storage; here we are on a different device, so those ops are never reached.
3. A Python function that invokes the other incantations needed to set up the PrivateUse1 backend.

This took inspiration from https://github.com/bdhirsh/pytorch_open_registration_example and https://github.com/tinygrad/tinygrad/blob/master/extra/torch_backend/wrapped_tensor.cpp; great thanks to @bdhirsh and @geohot.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157859
Approved by: https://github.com/albanD
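The end-to-end flow, condensed from the new test file in the diff below into a rough sketch. The `"npy"` rename and the numpy-backed "device data" come from the test; the `NpyTensor` name is mine (the test calls it `MyDeviceTensor`), and none of this is a fixed public API:

```python
import numpy as np
import torch
from torch.utils.backend_registration import _setup_privateuseone_for_python_backend

# Prepare the PrivateUse1 dispatch key as a Python backend named "npy".
_setup_privateuseone_for_python_backend("npy")


class NpyTensor(torch.Tensor):
    """Wrapper tensor whose "device data" is an arbitrary Python payload (a numpy array here)."""

    @staticmethod
    def __new__(cls, size, dtype, raw_data=None, requires_grad=False):
        # A storage-less PrivateUse1 tensor created on the C++ side acts as the wrapper.
        res = torch._C._acc.create_empty_tensor(size, dtype)
        res.__class__ = NpyTensor
        return res

    def __init__(self, size, dtype, raw_data=None, requires_grad=False):
        self.raw_data = raw_data


# Backend kernels are plain Python functions registered with torch.library
# against the "privateuseone" dispatch key.
@torch.library.impl("aten::add.Tensor", "privateuseone")
def add(t1, t2):
    out = t1.raw_data + t2.raw_data
    return NpyTensor(out.shape, torch.float32, out)
```

With enough of these kernels registered (as in the full test file), `tensor.to("privateuseone")`, arithmetic, and `backward()` all route through the Python backend.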
@@ -897,6 +897,7 @@ libtorch_python_core_sources = [
    "torch/csrc/Stream.cpp",
    "torch/csrc/Event.cpp",
    "torch/csrc/TypeInfo.cpp",
    "torch/csrc/acc/Module.cpp",
    "torch/csrc/api/src/python/init.cpp",
    "torch/csrc/autograd/functions/init.cpp",
    "torch/csrc/autograd/init.cpp",
@@ -9,16 +9,22 @@ std::array<
    static_cast<size_t>(DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES)>
    device_guard_impl_registry;

DeviceGuardImplRegistrar::DeviceGuardImplRegistrar(
void registerDeviceGuard(
    DeviceType type,
    const DeviceGuardImplInterface* impl) {
  device_guard_impl_registry[static_cast<size_t>(type)].store(impl);
}

DeviceGuardImplRegistrar::DeviceGuardImplRegistrar(
    DeviceType type,
    const DeviceGuardImplInterface* impl) {
  registerDeviceGuard(type, impl);
}

namespace {
thread_local std::unique_ptr<DeviceGuardImplInterface> tls_fake_device_guard =
    nullptr;
}
} // namespace

void ensureCUDADeviceGuardSet() {
  constexpr auto cuda_idx = static_cast<std::size_t>(DeviceType::CUDA);
@@ -368,6 +368,9 @@ inline const DeviceGuardImplInterface* getDeviceGuardImpl(DeviceType type) {
  return p;
}

void C10_API
registerDeviceGuard(DeviceType type, const DeviceGuardImplInterface* impl);

inline bool hasDeviceGuardImpl(DeviceType type) {
  return device_guard_impl_registry[static_cast<size_t>(type)].load();
}
@@ -284,6 +284,7 @@ RUN_PARALLEL_BLOCKLIST = [
    # temporarily sets a global config
    "test_autograd_fallback",
    "inductor/test_compiler_bisector",
    "test_privateuseone_python_backend",
] + FSDP_TEST

# Test files that should always be run serially with other test files,
test/test_privateuseone_python_backend.py (new file, 147 lines)
@@ -0,0 +1,147 @@
# Owner(s): ["module: PrivateUse1"]
import numpy as np

import torch
import torch._C
from torch.testing._internal.common_utils import run_tests, TestCase
from torch.utils.backend_registration import _setup_privateuseone_for_python_backend


_setup_privateuseone_for_python_backend("npy")

aten = torch.ops.aten


# NOTE: From https://github.com/albanD/subclass_zoo/blob/main/new_device.py
# but using torch.library instead of `__torch_dispatch__`
class MyDeviceTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, size, dtype, raw_data=None, requires_grad=False):
        # Use a meta Tensor here to be used as the wrapper
        res = torch._C._acc.create_empty_tensor(size, dtype)
        res.__class__ = MyDeviceTensor
        return res

    def __init__(self, size, dtype, raw_data=None, requires_grad=False):
        # Store any provided user raw_data
        self.raw_data = raw_data

    def __repr__(self):
        return "MyDeviceTensor" + str(self.raw_data)

    __str__ = __repr__


def wrap(arr, shape, dtype):
    # hard code float32 for tests
    return MyDeviceTensor(shape, dtype, arr)


def unwrap(arr):
    return arr.raw_data


# Add some ops
@torch.library.impl("aten::add.Tensor", "privateuseone")
def add(t1, t2):
    out = unwrap(t1) + unwrap(t2)
    return wrap(out, out.shape, torch.float32)


@torch.library.impl("aten::mul.Tensor", "privateuseone")
def mul(t1, t2):
    # If unsure what should be the result's properties, you can
    # use the super_fn (can be useful for type promotion)
    out = unwrap(t1) * unwrap(t2)
    return wrap(out, out.shape, torch.float32)


@torch.library.impl("aten::detach", "privateuseone")
def detach(self):
    out = unwrap(self)
    return wrap(out, out.shape, torch.float32)


@torch.library.impl("aten::empty_strided", "privateuseone")
def empty_strided(
    size, stride, *, dtype=None, layout=None, device=None, pin_memory=None
):
    out = np.empty(size)
    return wrap(out, out.shape, torch.float32)


@torch.library.impl("aten::_copy_from", "privateuseone")
def _copy_from(a, b):
    if a.device.type == "npy":
        npy_data = unwrap(a)
    else:
        npy_data = a.numpy()
    b.raw_data = npy_data


@torch.library.impl("aten::view", "privateuseone")
def _view(a, b):
    ans = unwrap(a)
    return wrap(ans, a.shape, a.dtype)


@torch.library.impl("aten::empty.memory_format", "privateuseone")
def empty_memory_format(
    size, *, dtype=None, layout=None, device=None, pin_memory=None, memory_format=None
):
    ans = np.empty(size)
    return wrap(ans, ans.shape, torch.float32)


@torch.library.impl("aten::sum", "privateuseone")
def sum_int_list(*args, **kwargs):
    ans = unwrap(args[0]).sum()
    return wrap(ans, ans.shape, torch.float32)


@torch.library.impl("aten::ones_like", "privateuseone")
def ones_like(
    self, *, dtype=None, layout=None, device=None, pin_memory=None, memory_format=None
):
    ans = np.ones_like(unwrap(self))
    return wrap(ans, ans.shape, torch.float32)


@torch.library.impl("aten::expand", "privateuseone")
def expand(self, size, *, implicit=False):
    ans = np.broadcast_to(self.raw_data, size)
    return wrap(ans, ans.shape, torch.float32)


@torch.library.impl("aten::as_strided", "privateuseone")
def as_strided(self, size, stride, storage_offset=None):
    ans = np.lib.stride_tricks.as_strided(self.raw_data, size, stride)
    return wrap(ans, ans.shape, torch.float32)


class PrivateUse1BackendTest(TestCase):
    @classmethod
    def setupClass(cls):
        pass

    def test_accessing_is_pinned(self):
        a_cpu = torch.randn((2, 2))
        # Assert this don't throw:
        _ = a_cpu.is_pinned()

    def test_backend_simple(self):
        a_cpu = torch.randn((2, 2))
        b_cpu = torch.randn((2, 2))
        # Assert this don't throw:
        a = a_cpu.to("privateuseone")
        b = b_cpu.to("privateuseone")

        a.requires_grad = True
        b.requires_grad = True
        c = (a + b).sum()
        c.backward()
        self.assertTrue(np.allclose(a.grad.raw_data, np.ones((2, 2))))


if __name__ == "__main__":
    run_tests()
@@ -26,6 +26,7 @@ import numpy
import torch
from torch import inf, SymInt, Tensor
from torch._C import (
    _acc,
    _aoti,
    _cpu,
    _dynamo,
torch/_C/_acc/__init__.pyi (new file, 15 lines)
@@ -0,0 +1,15 @@
from torch import Tensor
from torch.types import _dtype, _int, Device

# Defined in torch/csrc/acc/Module.cpp
class PrivateUse1Hooks:
    def has_primary_context(self, device_index: _int) -> bool: ...
    def is_built(self) -> bool: ...
    def is_avaible(self) -> bool: ...

class DeviceGuard:
    def type_(self) -> Device: ...

def register_python_privateuseone_device_guard(guard: DeviceGuard) -> bool: ...
def register_python_privateuseone_hook(hook: PrivateUse1Hooks) -> bool: ...
def create_empty_tensor(shape: tuple[_int, ...], dtype: _dtype) -> Tensor: ...
@@ -887,13 +887,15 @@ class MetaConverter(Generic[_TensorT]):
                    f"__meta_utils_unknown_tensor{len(self.tensor_memo)}"
                )

        # This indicates you set no_dispatch() before calling into this
        # function. This is an error: we may be creating fake tensors and
        # will perform operations on them which need fake tensor mode to
        # be active. You will segfault if you are in a no_dispatch() block.
        msg = (
            " This indicates you set no_dispatch() before calling into this"
            " function. This is an error: we may be creating fake tensors and"
            " will perform operations on them which need fake tensor mode to"
            " be active. You will segfault if you are in a no_dispatch() block."
        )
        assert not torch._C._dispatch_tls_local_exclude_set().has(
            torch._C.DispatchKey.Python
        )
        ), msg
        self.arg_cnt += 1

        # When we make as_strided calls, we end up generating a guard
@@ -56,6 +56,7 @@
#include <torch/csrc/Stream.h>
#include <torch/csrc/THP.h>
#include <torch/csrc/TypeInfo.h>
#include <torch/csrc/acc/Module.h>
#include <torch/csrc/api/include/torch/python/init.h>
#include <torch/csrc/autograd/generated/python_return_types.h>
#include <torch/csrc/autograd/python_cpp_function.h>
@@ -2097,6 +2098,7 @@ PyObject* initModule() {
  torch::cpu::initModule(module);
  torch::accelerator::initModule(module);
  torch::instruction_counter::initModule(module);
  torch::acc::initModule(module);
  torch::initVerboseBindings(module);
  ASSERT_TRUE(THPStorage_init(module));
  torch::functionalization::initModule(module);
torch/csrc/acc/Module.cpp (new file, 196 lines)
@@ -0,0 +1,196 @@
#include <torch/csrc/acc/Module.h>

#include <ATen/ATen.h>
#include <torch/extension.h>

#include <ATen/detail/PrivateUse1HooksInterface.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <torch/csrc/utils/pybind.h>

namespace py = pybind11;

namespace torch::acc {

// python hook interface
struct PythonHooks final : public at::PrivateUse1HooksInterface {
  using at::PrivateUse1HooksInterface::PrivateUse1HooksInterface;
  bool hasPrimaryContext(c10::DeviceIndex device_index) const override {
    PYBIND11_OVERRIDE_PURE_NAME(
        bool,
        at::PrivateUse1HooksInterface,
        "has_primary_context",
        hasPrimaryContext,
        device_index);
  }

  bool isBuilt() const override {
    PYBIND11_OVERRIDE_PURE_NAME(
        bool, at::PrivateUse1HooksInterface, "is_built", isBuilt, );
  }

  bool isAvailable() const override {
    PYBIND11_OVERRIDE_PURE_NAME(
        bool, at::PrivateUse1HooksInterface, "is_available", isBuilt, );
  }

  // TODO(qihqi): these is not supported from python yet
  const at::Generator& getDefaultGenerator(
      c10::DeviceIndex device_index) const override {
    return at::PrivateUse1HooksInterface::getDefaultGenerator(device_index);
  }

  at::Generator getNewGenerator(
      c10::DeviceIndex device_index = -1) const override {
    return at::PrivateUse1HooksInterface::getNewGenerator(device_index);
  }

  at::Device getDeviceFromPtr(void* data) const override {
    return at::PrivateUse1HooksInterface::getDeviceFromPtr(data);
  }

  bool isPinnedPtr(const void* data) const override {
    return at::PrivateUse1HooksInterface::isPinnedPtr(data);
  }

  at::Allocator* getPinnedMemoryAllocator() const override {
    return at::PrivateUse1HooksInterface::getPinnedMemoryAllocator();
  }
};

struct PythonDeviceGuard final : public c10::impl::DeviceGuardImplInterface {
  using c10::impl::DeviceGuardImplInterface::DeviceGuardImplInterface;

  c10::DeviceType type() const override {
    PYBIND11_OVERRIDE_PURE_NAME(
        c10::DeviceType, c10::impl::DeviceGuardImplInterface, "type_", type, );
  }

  // TODO(qihqi): figure out if those are even useful
  // to python or not
  c10::Device exchangeDevice(c10::Device device) const override {
    return getDevice();
  }
  c10::Device getDevice() const override {
    return c10::Device(type(), 0);
  }
  void setDevice(c10::Device device) const override {}
  void uncheckedSetDevice(c10::Device device) const noexcept override {}
  c10::Stream getStream(c10::Device) const noexcept override {
    // no-op
    return c10::Stream(c10::Stream::DEFAULT, getDevice());
  }

  c10::Stream getNewStream(c10::Device, int priority = 0) const override {
    // no-op
    (void)priority;
    return c10::Stream(c10::Stream::DEFAULT, getDevice());
  }

  c10::Stream exchangeStream(c10::Stream) const noexcept override {
    // no-op
    return c10::Stream(c10::Stream::DEFAULT, getDevice());
  }
  c10::DeviceIndex deviceCount() const noexcept override {
    return 1;
  }

  // TODO(qihqi): support Event-related functions
  void record(
      void** /*event*/,
      const c10::Stream& /*stream*/,
      const c10::DeviceIndex /*device_index*/,
      const c10::EventFlag /*flag*/) const override {}
  void block(void* /*event*/, const c10::Stream& /*stream*/) const override {}
  bool queryEvent(void* /*event*/) const override {
    return true;
  }
  void destroyEvent(void* /*event*/, const c10::DeviceIndex /*device_index*/)
      const noexcept override {}

  // Stream-related functions
  bool queryStream(const c10::Stream& /*stream*/) const override {
    return true;
  }
  void synchronizeStream(const c10::Stream& /*stream*/) const override {}
};

namespace {

bool registerPythonPrivateUse1Hook(const py::object& hook) {
  if (at::isPrivateUse1HooksRegistered()) {
    return false;
  }
  hook.inc_ref();
  at::RegisterPrivateUse1HooksInterface(
      hook.cast<PrivateUse1HooksInterface*>());
  return true;
}

bool registerPythonPrivateUse1DeviceGuard(const py::object& guard) {
  if (c10::impl::hasDeviceGuardImpl(c10::DeviceType::PrivateUse1)) {
    return false;
  }
  guard.inc_ref();
  c10::impl::registerDeviceGuard(
      c10::DeviceType::PrivateUse1,
      guard.cast<c10::impl::DeviceGuardImplInterface*>());
  return true;
}

at::Tensor createEmptyTensor(
    const std::vector<int64_t>& shape,
    c10::ScalarType dtype) {
  c10::Storage storage{
      c10::Storage::use_byte_size_t{},
      0,
      c10::GetAllocator(c10::kMeta),
      true,
  };

  c10::Device device(c10::DeviceType::PrivateUse1, 0);
  storage.set_data_ptr_noswap(at::DataPtr{nullptr, device});
  c10::DispatchKeySet key_set({c10::DispatchKey::PrivateUse1});
  at::Tensor tensor = at::detail::make_tensor<at::TensorImpl>(
      std::move(storage), key_set, c10::scalarTypeToTypeMeta(dtype));

  std::vector<int64_t> strides(shape.size());
  int64_t size = 1;
  for (auto i = strides.size(); i > 0; --i) {
    strides[i - 1] = size;
    size *= shape[i - 1];
  }

  tensor.unsafeGetTensorImpl()->set_sizes_and_strides(shape, strides, 0);
  return tensor;
}
} // namespace

void initModule(PyObject* module) {
  auto py_module = py::reinterpret_borrow<py::module>(module);
  auto _acc =
      py_module.def_submodule("_acc", "classes related to custom accelerators");

  py::class_<at::PrivateUse1HooksInterface, PythonHooks>(
      _acc.ptr(), "PrivateUse1Hooks")
      .def(py::init<>())
      .def(
          "has_primary_context",
          &at::PrivateUse1HooksInterface::hasPrimaryContext)
      .def("is_built", &at::PrivateUse1HooksInterface::isBuilt)
      .def("is_available", &at::PrivateUse1HooksInterface::isAvailable);

  py::class_<c10::impl::DeviceGuardImplInterface, PythonDeviceGuard>(
      _acc.ptr(), "DeviceGuard")
      .def(py::init<>())
      .def("type_", &c10::impl::DeviceGuardImplInterface::type);

  _acc.def(
      "register_python_privateuseone_hook", &registerPythonPrivateUse1Hook);
  _acc.def(
      "register_python_privateuseone_device_guard",
      &registerPythonPrivateUse1DeviceGuard);
  _acc.def("create_empty_tensor", &createEmptyTensor);
}

} // namespace torch::acc
torch/csrc/acc/Module.h (new file, 8 lines)
@@ -0,0 +1,8 @@
#pragma once
#include <torch/csrc/python_headers.h>

namespace torch::acc {
// PyMethodDef* python_functions();
void initModule(PyObject* module);

} // namespace torch::acc
@@ -6,7 +6,10 @@ from torch._C import _get_privateuse1_backend_name, _rename_privateuse1_backend
from torch.overrides import handle_torch_function, has_torch_function_unary


__all__ = ["rename_privateuse1_backend", "generate_methods_for_privateuse1_backend"]
__all__ = [
    "rename_privateuse1_backend",
    "generate_methods_for_privateuse1_backend",
]

# TODO: Should use `torch._C._get_privateuse1_backend_name()` to get
# renamed-backend name for `privateuse1`, but the func will cause an
@@ -438,3 +441,78 @@ def _get_custom_mod_func(func_name: str):
        message += f"BackendModule needs to have the following API's:\n `{func_name}(*args, **kwargs)`. \n"
        raise RuntimeError(message)
    return function


class _DummyBackendModule:
    def is_initialized(self):
        return True

    def is_available(self):
        return True

    def current_device(self):
        return 0

    def _is_in_bad_fork(self):
        return False

    def manual_seed_all(self, seed: int):
        pass

    def device_count(self):
        return 1


class _DummyPrivateUse1Hook(torch._C._acc.PrivateUse1Hooks):
    def is_available(self):
        return True

    def has_primary_context(self, dev_id):
        return True

    def is_built(self):
        return True


class _DummyDeviceGuard(torch._C._acc.DeviceGuard):
    def type_(self):
        return torch._C._autograd.DeviceType.PrivateUse1


def _setup_privateuseone_for_python_backend(
    rename=None, backend_module=None, hook=None, device_guard=None
):
    """This function will prepare the PrivateUse1 dispatch key to be used as a python backend.

    WARNING: this API is experimental and might change without notice.

    Formally, this registers things that Pytorch expects a registered backend
    in C++ to have: including device guards, hooks, and backend modules and what not.

    after this call, one can use `torch.library` to write Ops for this dispatch key
    and expect it to behave like a backend registered in C++.

    See the unit test at test/test_privateuseone_python_backend.py for more details.

    Args:
        rename: str | None, if passed in, we will rename privateuseone backend to
            the name given.
        backend_module: object | None, if passed in None, we will use DummyBackendModule
        hook: object | None, if passed in None, we will use DummyPrivateUse1Hook
        device_guard: object | None, if passed in None, we will use DummyDeviceGuard
    """
    # NOTE: the ordering of which these functions are called is important.
    if rename is not None:
        torch.utils.rename_privateuse1_backend(rename)
    else:
        rename = "privateuseone"
    torch.utils.generate_methods_for_privateuse1_backend()
    if backend_module is None:
        backend_module = _DummyBackendModule()
    if hook is None:
        hook = _DummyPrivateUse1Hook()
    if device_guard is None:
        device_guard = _DummyDeviceGuard()
    torch._register_device_module(rename, backend_module)
    torch._C._acc.register_python_privateuseone_hook(hook)
    torch._C._acc.register_python_privateuseone_device_guard(device_guard)