Fixes #156052 and #156444.

This PR sets up the privateuseone key in Python so it can be used as a Python backend for PyTorch. That means that, after calling `setup_privateuseone_for_python_backend('npy')`, one can use a Tensor subclass on that device to hold arbitrary Python data as "device data" and use `torch.library` to register ops that take such Tensors.

Changes done in this PR:

1. Register a vanilla device guard: NoOpDeviceGuard is extended to allow a device index of 0 and to not raise errors when event-related functions are accessed. Without this, calling backward raises errors. (The CPU backend uses NoOpDeviceGuard just fine, although there seems to be special treatment of CPU in the autograd engine.)
2. Tensor subclasses may omit `__torch_dispatch__` if the device is not CUDA or CPU. The comment on that check suggests it exists to avoid segfaults when calling into ops that expect a storage; since we are on a different device, those ops are never reached.
3. A Python function that invokes the remaining incantations to set up the privateuseone backend.

This took inspiration from https://github.com/bdhirsh/pytorch_open_registration_example and https://github.com/tinygrad/tinygrad/blob/master/extra/torch_backend/wrapped_tensor.cpp; great thanks to @bdhirsh and @geohot.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157859
Approved by: https://github.com/albanD
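In a nutshell, the resulting workflow looks like the sketch below. This is a minimal, illustrative sketch distilled from the test file that follows: `NumpyTensor` is a hypothetical name, only a single op is registered, and the `_setup_privateuseone_for_python_backend` and `torch._C._acc.create_empty_tensor` calls are the ones used in that test.

```python
import torch
from torch.utils.backend_registration import _setup_privateuseone_for_python_backend

# Rename the PrivateUse1 key to "npy" and install the Python-backend plumbing.
_setup_privateuseone_for_python_backend("npy")


class NumpyTensor(torch.Tensor):
    # Wrapper tensor whose "device data" is an arbitrary Python object (e.g. a numpy array).
    @staticmethod
    def __new__(cls, size, dtype, raw_data=None, requires_grad=False):
        res = torch._C._acc.create_empty_tensor(size, dtype)  # wrapper tensor on the "npy" device
        res.__class__ = NumpyTensor
        return res

    def __init__(self, size, dtype, raw_data=None, requires_grad=False):
        self.raw_data = raw_data  # the arbitrary Python payload


# Register an implementation of an aten op for the PrivateUse1 dispatch key.
@torch.library.impl("aten::add.Tensor", "privateuseone")
def add(t1, t2):
    out = t1.raw_data + t2.raw_data
    return NumpyTensor(out.shape, torch.float32, out)
```

The full test below exercises the same pattern end to end, including copies to and from the CPU and a backward pass.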
# Owner(s): ["module: PrivateUse1"]
import numpy as np

import torch
import torch._C
from torch.testing._internal.common_utils import run_tests, TestCase
from torch.utils.backend_registration import _setup_privateuseone_for_python_backend


_setup_privateuseone_for_python_backend("npy")

aten = torch.ops.aten


# NOTE: From https://github.com/albanD/subclass_zoo/blob/main/new_device.py
# but using torch.library instead of `__torch_dispatch__`
class MyDeviceTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, size, dtype, raw_data=None, requires_grad=False):
        # Use a meta Tensor here to be used as the wrapper
        res = torch._C._acc.create_empty_tensor(size, dtype)
        res.__class__ = MyDeviceTensor
        return res

    def __init__(self, size, dtype, raw_data=None, requires_grad=False):
        # Store any provided user raw_data
        self.raw_data = raw_data

    def __repr__(self):
        return "MyDeviceTensor" + str(self.raw_data)

    __str__ = __repr__


def wrap(arr, shape, dtype):
    # hard code float32 for tests
    return MyDeviceTensor(shape, dtype, arr)


def unwrap(arr):
    return arr.raw_data


# Add some ops
@torch.library.impl("aten::add.Tensor", "privateuseone")
def add(t1, t2):
    out = unwrap(t1) + unwrap(t2)
    return wrap(out, out.shape, torch.float32)


@torch.library.impl("aten::mul.Tensor", "privateuseone")
def mul(t1, t2):
    # If unsure what should be the result's properties, you can
    # use the super_fn (can be useful for type promotion)
    out = unwrap(t1) * unwrap(t2)
    return wrap(out, out.shape, torch.float32)


@torch.library.impl("aten::detach", "privateuseone")
def detach(self):
    out = unwrap(self)
    return wrap(out, out.shape, torch.float32)


@torch.library.impl("aten::empty_strided", "privateuseone")
def empty_strided(
    size, stride, *, dtype=None, layout=None, device=None, pin_memory=None
):
    out = np.empty(size)
    return wrap(out, out.shape, torch.float32)


@torch.library.impl("aten::_copy_from", "privateuseone")
def _copy_from(a, b):
    if a.device.type == "npy":
        npy_data = unwrap(a)
    else:
        npy_data = a.numpy()
    b.raw_data = npy_data


@torch.library.impl("aten::view", "privateuseone")
def _view(a, b):
    ans = unwrap(a)
    return wrap(ans, a.shape, a.dtype)


@torch.library.impl("aten::empty.memory_format", "privateuseone")
def empty_memory_format(
    size, *, dtype=None, layout=None, device=None, pin_memory=None, memory_format=None
):
    ans = np.empty(size)
    return wrap(ans, ans.shape, torch.float32)


@torch.library.impl("aten::sum", "privateuseone")
def sum_int_list(*args, **kwargs):
    ans = unwrap(args[0]).sum()
    return wrap(ans, ans.shape, torch.float32)


@torch.library.impl("aten::ones_like", "privateuseone")
def ones_like(
    self, *, dtype=None, layout=None, device=None, pin_memory=None, memory_format=None
):
    ans = np.ones_like(unwrap(self))
    return wrap(ans, ans.shape, torch.float32)


@torch.library.impl("aten::expand", "privateuseone")
def expand(self, size, *, implicit=False):
    ans = np.broadcast_to(self.raw_data, size)
    return wrap(ans, ans.shape, torch.float32)


@torch.library.impl("aten::as_strided", "privateuseone")
def as_strided(self, size, stride, storage_offset=None):
    ans = np.lib.stride_tricks.as_strided(self.raw_data, size, stride)
    return wrap(ans, ans.shape, torch.float32)


class PrivateUse1BackendTest(TestCase):
    @classmethod
    def setUpClass(cls):
        pass

    def test_accessing_is_pinned(self):
        a_cpu = torch.randn((2, 2))
        # Assert this doesn't throw:
        _ = a_cpu.is_pinned()

    def test_backend_simple(self):
        a_cpu = torch.randn((2, 2))
        b_cpu = torch.randn((2, 2))
        # Assert this doesn't throw:
        a = a_cpu.to("privateuseone")
        b = b_cpu.to("privateuseone")

        a.requires_grad = True
        b.requires_grad = True
        c = (a + b).sum()
        c.backward()
        self.assertTrue(np.allclose(a.grad.raw_data, np.ones((2, 2))))


if __name__ == "__main__":
    run_tests()