Add functions to setup PrivateUse1 as a python backend device. (#157859)

Fixes #156052 and #156444.

This PR sets up the PrivateUse1 key in Python so it can be used as a Python backend for PyTorch.
That is, after calling `setup_privateuseone_for_python_backend('npy')`, one can use a Tensor subclass with that device to hold arbitrary Python data as "device data", and use `torch.library` to register ops that take that Tensor.

Changes done in this PR:

1. Register a vanilla device guard: I extended NoOpDeviceGuard to allow a device index of 0 and to not raise errors when event-related functions are accessed. Without those changes, calling backward would raise errors. (The CPU backend uses NoOpDeviceGuard just fine, although there seems to be special treatment of CPU in the autograd engine.)
2. Allow a Tensor subclass to omit `__torch_dispatch__` if the device is not CUDA or CPU. The comment on the check suggests it was there to avoid a segfault when calling into ops that expect a storage; here we have a different device, so those ops will not be called.
3. Add a Python function that performs the remaining incantations to set up the PrivateUse1 backend.
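Putting the pieces together, the intended flow looks roughly like the pseudocode-style sketch below. It is a hypothetical illustration only, not runnable without a torch build containing this PR: the helper `_setup_privateuseone_for_python_backend` comes from the diff below (its exact import path is an assumption), while `NpyTensor` and the wrapped numpy array are made-up names for illustration.

```
# Hypothetical sketch; assumes a torch build that includes this PR.
import numpy as np
import torch
from torch.utils.backend_registration import _setup_privateuseone_for_python_backend

# Rename privateuseone to "npy" and install the dummy device guard,
# hooks, and backend module registered by this PR.
_setup_privateuseone_for_python_backend(rename="npy")

class NpyTensor(torch.Tensor):
    # With change (2) above, a subclass on a non-CPU/CUDA device may omit
    # __torch_dispatch__; the wrapped numpy array is the "device data".
    @staticmethod
    def __new__(cls, arr):
        t = torch.Tensor._make_wrapper_subclass(
            cls, arr.shape, dtype=torch.float32, device="npy"
        )
        t.arr = arr
        return t

# Register an implementation of aten::add for the PrivateUse1 key,
# operating directly on the held numpy data.
lib = torch.library.Library("aten", "IMPL")

def add_impl(a, b):
    return NpyTensor(a.arr + b.arr)

lib.impl("add.Tensor", add_impl, "PrivateUse1")

x = NpyTensor(np.ones(3, dtype=np.float32))
y = x + x  # dispatches to add_impl via the PrivateUse1 key
```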

This took inspiration from https://github.com/bdhirsh/pytorch_open_registration_example and https://github.com/tinygrad/tinygrad/blob/master/extra/torch_backend/wrapped_tensor.cpp; great thanks to @bdhirsh and @geohot.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157859
Approved by: https://github.com/albanD
This commit is contained in:
Han Qi
2025-10-01 21:32:55 +00:00
committed by PyTorch MergeBot
parent 773c6762b8
commit b5c4f46bb9
13 changed files with 468 additions and 8 deletions

@@ -6,7 +6,10 @@ from torch._C import _get_privateuse1_backend_name, _rename_privateuse1_backend
 from torch.overrides import handle_torch_function, has_torch_function_unary
-__all__ = ["rename_privateuse1_backend", "generate_methods_for_privateuse1_backend"]
+__all__ = [
+    "rename_privateuse1_backend",
+    "generate_methods_for_privateuse1_backend",
+]
 # TODO: Should use `torch._C._get_privateuse1_backend_name()` to get
 # renamed-backend name for `privateuse1`, but the func will cause an
@@ -438,3 +441,78 @@ def _get_custom_mod_func(func_name: str):
         message += f"BackendModule needs to have the following API's:\n `{func_name}(*args, **kwargs)`. \n"
         raise RuntimeError(message)
     return function
+
+
+class _DummyBackendModule:
+    def is_initialized(self):
+        return True
+
+    def is_available(self):
+        return True
+
+    def current_device(self):
+        return 0
+
+    def _is_in_bad_fork(self):
+        return False
+
+    def manual_seed_all(self, seed: int):
+        pass
+
+    def device_count(self):
+        return 1
+
+
+class _DummyPrivateUse1Hook(torch._C._acc.PrivateUse1Hooks):
+    def is_available(self):
+        return True
+
+    def has_primary_context(self, dev_id):
+        return True
+
+    def is_built(self):
+        return True
+
+
+class _DummyDeviceGuard(torch._C._acc.DeviceGuard):
+    def type_(self):
+        return torch._C._autograd.DeviceType.PrivateUse1
+
+
+def _setup_privateuseone_for_python_backend(
+    rename=None, backend_module=None, hook=None, device_guard=None
+):
+    """Prepare the PrivateUse1 dispatch key to be used as a Python backend.
+
+    WARNING: this API is experimental and might change without notice.
+
+    Formally, this registers everything PyTorch expects a backend registered
+    in C++ to have: a device guard, hooks, and a backend module.
+    After this call, one can use `torch.library` to write ops for this
+    dispatch key and expect it to behave like a backend registered in C++.
+
+    See the unit test at test/test_privateuseone_python_backend.py for more
+    details.
+
+    Args:
+        rename: str | None, if given, the privateuseone backend is renamed
+            to this name.
+        backend_module: object | None, if None, _DummyBackendModule is used.
+        hook: object | None, if None, _DummyPrivateUse1Hook is used.
+        device_guard: object | None, if None, _DummyDeviceGuard is used.
+    """
+    # NOTE: the order in which these functions are called is important.
+    if rename is not None:
+        torch.utils.rename_privateuse1_backend(rename)
+    else:
+        rename = "privateuseone"
+    torch.utils.generate_methods_for_privateuse1_backend()
+    if backend_module is None:
+        backend_module = _DummyBackendModule()
+    if hook is None:
+        hook = _DummyPrivateUse1Hook()
+    if device_guard is None:
+        device_guard = _DummyDeviceGuard()
+    torch._register_device_module(rename, backend_module)
+    torch._C._acc.register_python_privateuseone_hook(hook)
+    torch._C._acc.register_python_privateuseone_device_guard(device_guard)