Files
pytorch/test/test_privateuseone_python_backend.py
Han Qi 1310d6a1f9 Add functions to setup PrivateUse1 as a python backend device. (#157859)
Fixes #156052 and #156444.

This PR sets up the PrivateUse1 key in Python so it can be used as a Python backend for PyTorch.
Meaning that, after calling `_setup_privateuseone_for_python_backend('npy')`, one can use a Tensor subclass with that device to hold arbitrary Python data as "device data", and use `torch.library` to register ops that take that Tensor.

Changes done in this PR:

1. Register a vanilla Device Guard: I extended NoOpDeviceGuard to allow a device index of 0 and to not raise errors when event-related functions are accessed. Without these changes, calling backward raises errors. (The CPU backend uses NoOpDeviceGuard just fine, although there seems to be special treatment of CPU in the autograd engine.)
2. Allow a Tensor subclass to not have `__torch_dispatch__` if the device is neither CUDA nor CPU. The comment on the check suggests it was there to avoid a segfault when calling into ops that expect a storage; since we are on a different device here, those ops are never reached.
3. A Python function that invokes the other incantations to set up the PrivateUse1 backend (the resulting usage pattern is sketched below).
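
For context, the resulting usage pattern looks roughly like this (a minimal sketch distilled from the test file below; the op body is elided):

```python
import torch
from torch.utils.backend_registration import _setup_privateuseone_for_python_backend

# Register "npy" as the name of the PrivateUse1 device.
_setup_privateuseone_for_python_backend("npy")

# Ops for the backend can now be implemented in pure Python via torch.library.
@torch.library.impl("aten::add.Tensor", "privateuseone")
def add(t1, t2):
    ...  # operate on the Python payload held by a Tensor subclass
```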

This took inspiration from https://github.com/bdhirsh/pytorch_open_registration_example and https://github.com/tinygrad/tinygrad/blob/master/extra/torch_backend/wrapped_tensor.cpp; many thanks to @bdhirsh and @geohot.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157859
Approved by: https://github.com/albanD
2025-09-30 08:39:36 +00:00

148 lines
4.0 KiB
Python

# Owner(s): ["module: PrivateUse1"]
import numpy as np

import torch
import torch._C
from torch.testing._internal.common_utils import run_tests, TestCase
from torch.utils.backend_registration import _setup_privateuseone_for_python_backend
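
# Set up the PrivateUse1 dispatch key as a Python backend named "npy",
# installing the supporting pieces (device guard, etc.).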
_setup_privateuseone_for_python_backend("npy")

aten = torch.ops.aten


# NOTE: From https://github.com/albanD/subclass_zoo/blob/main/new_device.py
# but using torch.library instead of `__torch_dispatch__`
class MyDeviceTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, size, dtype, raw_data=None, requires_grad=False):
        # Use a meta Tensor here to be used as the wrapper
        res = torch._C._acc.create_empty_tensor(size, dtype)
        res.__class__ = MyDeviceTensor
        return res

    def __init__(self, size, dtype, raw_data=None, requires_grad=False):
        # Store any provided user raw_data
        self.raw_data = raw_data

    def __repr__(self):
        return "MyDeviceTensor" + str(self.raw_data)

    __str__ = __repr__

def wrap(arr, shape, dtype):
    # hard code float32 for tests
    return MyDeviceTensor(shape, dtype, arr)


def unwrap(arr):
    return arr.raw_data

# Add some ops
@torch.library.impl("aten::add.Tensor", "privateuseone")
def add(t1, t2):
    out = unwrap(t1) + unwrap(t2)
    return wrap(out, out.shape, torch.float32)
@torch.library.impl("aten::mul.Tensor", "privateuseone")
def mul(t1, t2):
# If unsure what should be the result's properties, you can
# use the super_fn (can be useful for type promotion)
out = unwrap(t1) * unwrap(t2)
return wrap(out, out.shape, torch.float32)
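
# Autograd dispatches `aten::detach` internally, so the backend needs an
# implementation for the backward() call in the test below to work.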
@torch.library.impl("aten::detach", "privateuseone")
def detach(self):
out = unwrap(self)
return wrap(out, out.shape, torch.float32)
@torch.library.impl("aten::empty_strided", "privateuseone")
def empty_strided(
size, stride, *, dtype=None, layout=None, device=None, pin_memory=None
):
out = np.empty(size)
return wrap(out, out.shape, torch.float32)
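
# `aten::_copy_from` is the hook used when copying data between this backend
# and other devices, e.g. by the `.to("privateuseone")` calls in the tests.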
@torch.library.impl("aten::_copy_from", "privateuseone")
def _copy_from(a, b):
if a.device.type == "npy":
npy_data = unwrap(a)
else:
npy_data = a.numpy()
b.raw_data = npy_data
@torch.library.impl("aten::view", "privateuseone")
def _view(a, b):
ans = unwrap(a)
return wrap(ans, a.shape, a.dtype)
@torch.library.impl("aten::empty.memory_format", "privateuseone")
def empty_memory_format(
size, *, dtype=None, layout=None, device=None, pin_memory=None, memory_format=None
):
ans = np.empty(size)
return wrap(ans, ans.shape, torch.float32)
@torch.library.impl("aten::sum", "privateuseone")
def sum_int_list(*args, **kwargs):
ans = unwrap(args[0]).sum()
return wrap(ans, ans.shape, torch.float32)
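
# The remaining ops are what the autograd engine needs for the
# `(a + b).sum().backward()` call in the test below: ones_like seeds the
# initial gradient and expand broadcasts it back through the sum.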
@torch.library.impl("aten::ones_like", "privateuseone")
def ones_like(
self, *, dtype=None, layout=None, device=None, pin_memory=None, memory_format=None
):
ans = np.ones_like(unwrap(self))
return wrap(ans, ans.shape, torch.float32)
@torch.library.impl("aten::expand", "privateuseone")
def expand(self, size, *, implicit=False):
ans = np.broadcast_to(self.raw_data, size)
return wrap(ans, ans.shape, torch.float32)
@torch.library.impl("aten::as_strided", "privateuseone")
def as_strided(self, size, stride, storage_offset=None):
ans = np.lib.stride_tricks.as_strided(self.raw_data, size, stride)
return wrap(ans, ans.shape, torch.float32)

class PrivateUse1BackendTest(TestCase):
    @classmethod
    def setUpClass(cls):
        pass

    def test_accessing_is_pinned(self):
        a_cpu = torch.randn((2, 2))
        # Assert this doesn't throw:
        _ = a_cpu.is_pinned()

    def test_backend_simple(self):
        a_cpu = torch.randn((2, 2))
        b_cpu = torch.randn((2, 2))
        # Assert these don't throw:
        a = a_cpu.to("privateuseone")
        b = b_cpu.to("privateuseone")
        a.requires_grad = True
        b.requires_grad = True
        c = (a + b).sum()
        c.backward()
        self.assertTrue(np.allclose(a.grad.raw_data, np.ones((2, 2))))


if __name__ == "__main__":
    run_tests()