mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add functions to setup PrivateUse1 as a python backend device. (#157859)
Fixes #156052 and #156444. This PR setup the privateuseone key in Python to be used as a python backend for pytorch. Meaning that, after calling `setup_privateuseone_for_python_backend('npy')`, one can use a subclass to with that device to hold arbitrary python data as "device data" and use `torch.library` to register ops that takes that Tensor. Changes done in this PR: 1. Register an vanilla Device Guard: I extended NoOpDeviceGuard to have allow device index of 0 and to not raise errors when event related functions are accessed. If I don't do those, when calling backward I would get errors. (CPU backend uses NoOpDeviceGuard just fine, although there seems to be special treatment of CPU in the autograd engine. 2. Tensor subclass allows not having `__torch_dispatch__` if the device is not CUDA or CPU. The comment of the check suggests it was to avoid segfault when calling into ops that expects a storage. Here we have a different device so will not call into those ops. 3. python function that invokes the other incantations to setup the privateusekey backend. This took inspiration of https://github.com/bdhirsh/pytorch_open_registration_example and https://github.com/tinygrad/tinygrad/blob/master/extra/torch_backend/wrapped_tensor.cpp; great thanks to @bdhirsh and @geohot. Pull Request resolved: https://github.com/pytorch/pytorch/pull/157859 Approved by: https://github.com/albanD
This commit is contained in:
@ -6,7 +6,10 @@ from torch._C import _get_privateuse1_backend_name, _rename_privateuse1_backend
|
||||
from torch.overrides import handle_torch_function, has_torch_function_unary
|
||||
|
||||
|
||||
__all__ = ["rename_privateuse1_backend", "generate_methods_for_privateuse1_backend"]
|
||||
__all__ = [
|
||||
"rename_privateuse1_backend",
|
||||
"generate_methods_for_privateuse1_backend",
|
||||
]
|
||||
|
||||
# TODO: Should use `torch._C._get_privateuse1_backend_name()` to get
|
||||
# renamed-backend name for `privateuse1`, but the func will cause an
|
||||
@ -438,3 +441,78 @@ def _get_custom_mod_func(func_name: str):
|
||||
message += f"BackendModule needs to have the following API's:\n `{func_name}(*args, **kwargs)`. \n"
|
||||
raise RuntimeError(message)
|
||||
return function
|
||||
|
||||
|
||||
class _DummyBackendModule:
|
||||
def is_initialized(self):
|
||||
return True
|
||||
|
||||
def is_available(self):
|
||||
return True
|
||||
|
||||
def current_device(self):
|
||||
return 0
|
||||
|
||||
def _is_in_bad_fork(self):
|
||||
return False
|
||||
|
||||
def manual_seed_all(self, seed: int):
|
||||
pass
|
||||
|
||||
def device_count(self):
|
||||
return 1
|
||||
|
||||
|
||||
class _DummyPrivateUse1Hook(torch._C._acc.PrivateUse1Hooks):
|
||||
def is_available(self):
|
||||
return True
|
||||
|
||||
def has_primary_context(self, dev_id):
|
||||
return True
|
||||
|
||||
def is_built(self):
|
||||
return True
|
||||
|
||||
|
||||
class _DummyDeviceGuard(torch._C._acc.DeviceGuard):
|
||||
def type_(self):
|
||||
return torch._C._autograd.DeviceType.PrivateUse1
|
||||
|
||||
|
||||
def _setup_privateuseone_for_python_backend(
|
||||
rename=None, backend_module=None, hook=None, device_guard=None
|
||||
):
|
||||
"""This function will prepare the PrivateUse1 dispatch key to be used as a python backend.
|
||||
|
||||
WARNING: this API is experimental and might change without notice.
|
||||
|
||||
Formally, this registers things that Pytorch expects a registered backend
|
||||
in C++ to have: including device guards, hooks, and backend modules and what not.
|
||||
|
||||
after this call, one can use `torch.library` to write Ops for this dispatch key
|
||||
and expect it to behave like a backend registered in C++.
|
||||
|
||||
See the unit test at test/test_privateuseone_python_backend.py for more details.
|
||||
|
||||
Args:
|
||||
rename: str | None, if passed in, we will rename privateuseone backend to
|
||||
the name given.
|
||||
backend_module: object | None, if passed in None, we will use DummyBackendModule
|
||||
hook: object | None, if passed in None, we will use DummyPrivateUse1Hook
|
||||
device_guard: object | None, if passed in None, we will use DummyDeviceGuard
|
||||
"""
|
||||
# NOTE: the ordering of which these functions are called is important.
|
||||
if rename is not None:
|
||||
torch.utils.rename_privateuse1_backend(rename)
|
||||
else:
|
||||
rename = "privateuseone"
|
||||
torch.utils.generate_methods_for_privateuse1_backend()
|
||||
if backend_module is None:
|
||||
backend_module = _DummyBackendModule()
|
||||
if hook is None:
|
||||
hook = _DummyPrivateUse1Hook()
|
||||
if device_guard is None:
|
||||
device_guard = _DummyDeviceGuard()
|
||||
torch._register_device_module(rename, backend_module)
|
||||
torch._C._acc.register_python_privateuseone_hook(hook)
|
||||
torch._C._acc.register_python_privateuseone_device_guard(device_guard)
|
||||
|
Reference in New Issue
Block a user