mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-25 16:14:55 +08:00 
			
		
		
		
	Fixes #156052 and #156444. This PR sets up the privateuseone key in Python so it can be used as a Python backend for PyTorch. This means that, after calling `setup_privateuseone_for_python_backend('npy')`, one can use a subclass with that device to hold arbitrary Python data as "device data" and use `torch.library` to register ops that take that Tensor. Changes done in this PR: 1. Register a vanilla Device Guard: I extended NoOpDeviceGuard to allow device index 0 and to not raise errors when event-related functions are accessed. Without those changes, calling backward raised errors. (The CPU backend uses NoOpDeviceGuard just fine, although there seems to be special treatment of CPU in the autograd engine.) 2. Tensor subclasses are allowed to omit `__torch_dispatch__` if the device is not CUDA or CPU. The comment on that check suggests it was there to avoid segfaults when calling into ops that expect a storage. Here we have a different device, so those ops will not be called. 3. A Python function that invokes the remaining incantations to set up the privateuseone backend. This took inspiration from https://github.com/bdhirsh/pytorch_open_registration_example and https://github.com/tinygrad/tinygrad/blob/master/extra/torch_backend/wrapped_tensor.cpp; great thanks to @bdhirsh and @geohot. Pull Request resolved: https://github.com/pytorch/pytorch/pull/157859 Approved by: https://github.com/albanD
		
			
				
	
	
		
			148 lines
		
	
	
		
			4.0 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			148 lines
		
	
	
		
			4.0 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # Owner(s): ["module: PrivateUse1"]
 | |
| import numpy as np
 | |
| 
 | |
| import torch
 | |
| import torch._C
 | |
| from torch.testing._internal.common_utils import run_tests, TestCase
 | |
| from torch.utils.backend_registration import _setup_privateuseone_for_python_backend
 | |
| 
 | |
| 
 | |
# Set up the privateuseone device as a Python backend named "npy" (the
# kernels below compare `tensor.device.type` against "npy"). This runs once,
# at import time, as a module-level side effect.
_setup_privateuseone_for_python_backend("npy")

# Shorthand for the ATen operator namespace (currently unused in this file).
aten = torch.ops.aten
 | |
| 
 | |
| 
 | |
# NOTE: From https://github.com/albanD/subclass_zoo/blob/main/new_device.py
# but using torch.library instead of `__torch_dispatch__`
class MyDeviceTensor(torch.Tensor):
    """Tensor subclass whose actual data is an arbitrary Python object.

    The tensor object itself only carries metadata (size, dtype). The payload
    lives in ``raw_data`` (NumPy arrays/scalars in this file) and is read and
    written by the ``torch.library`` kernels registered for "privateuseone".
    """

    @staticmethod
    def __new__(cls, size, dtype, raw_data=None, requires_grad=False):
        # Build the wrapper tensor that carries size/dtype; the payload is
        # attached in __init__.
        # NOTE(review): create_empty_tensor is assumed to allocate a tensor on
        # the privateuseone device without usable storage (an older comment
        # called it a "meta Tensor") -- confirm against torch._C._acc.
        res = torch._C._acc.create_empty_tensor(size, dtype)
        # Use ``cls`` rather than a hard-coded class so subclasses keep their type.
        res.__class__ = cls
        # ``requires_grad`` used to be accepted but silently dropped.
        if requires_grad:
            res.requires_grad_(True)
        return res

    def __init__(self, size, dtype, raw_data=None, requires_grad=False):
        # size/dtype/requires_grad were already consumed by __new__; only the
        # payload needs to be stored here.
        self.raw_data = raw_data

    def __repr__(self):
        return "MyDeviceTensor" + str(self.raw_data)

    __str__ = __repr__
 | |
| 
 | |
| 
 | |
def wrap(arr, shape, dtype):
    """Wrap the Python payload ``arr`` in a ``MyDeviceTensor``.

    ``shape`` and ``dtype`` become the wrapper tensor's metadata. Most kernels
    in this file pass ``torch.float32``; this helper forwards whatever dtype it
    is given.
    """
    return MyDeviceTensor(shape, dtype, raw_data=arr)
 | |
| 
 | |
| 
 | |
def unwrap(arr):
    """Return the Python payload held by a wrapper tensor (its ``raw_data``)."""
    payload = arr.raw_data
    return payload
 | |
| 
 | |
| 
 | |
# Kernels for the privateuseone ("npy") backend.
@torch.library.impl("aten::add.Tensor", "privateuseone")
def add(t1, t2, *, alpha=1):
    """Elementwise ``t1 + alpha * t2`` on the wrapped payloads.

    ``aten::add.Tensor`` has a keyword-only ``alpha`` scaling factor. The
    previous signature had no ``alpha`` parameter, so a call such as
    ``torch.add(x, y, alpha=2)`` could not reach this kernel. The result is
    wrapped as float32.
    """
    other = unwrap(t2)
    # Skip the multiply for the default so the common path is unchanged.
    if alpha != 1:
        other = alpha * other
    out = unwrap(t1) + other
    return wrap(out, out.shape, torch.float32)
 | |
| 
 | |
| 
 | |
@torch.library.impl("aten::mul.Tensor", "privateuseone")
def mul(t1, t2):
    """Elementwise product of the two wrapped payloads, wrapped as float32."""
    lhs = unwrap(t1)
    rhs = unwrap(t2)
    product = lhs * rhs
    return wrap(product, product.shape, torch.float32)
 | |
| 
 | |
| 
 | |
@torch.library.impl("aten::detach", "privateuseone")
def detach(self):
    """Return a new float32 wrapper around the *same* payload object (no copy)."""
    payload = unwrap(self)
    payload_shape = payload.shape
    return wrap(payload, payload_shape, torch.float32)
 | |
| 
 | |
| 
 | |
@torch.library.impl("aten::empty_strided", "privateuseone")
def empty_strided(
    size, stride, *, dtype=None, layout=None, device=None, pin_memory=None
):
    """Allocate an uninitialized payload of shape ``size``.

    ``stride``, ``dtype``, ``layout``, ``device`` and ``pin_memory`` mirror the
    ATen schema but are ignored. The payload is a fresh C-contiguous NumPy
    array, and the wrapper is always float32, as elsewhere in this file.
    """
    # Allocate float32 so the payload agrees with the wrapper's declared dtype
    # (np.empty alone defaults to float64).
    out = np.empty(size, dtype=np.float32)
    return wrap(out, out.shape, torch.float32)
 | |
| 
 | |
| 
 | |
@torch.library.impl("aten::_copy_from", "privateuseone")
def _copy_from(a, b, non_blocking=False):
    """Copy the values of source ``a`` into destination ``b`` and return ``b``.

    Mirrors ``aten::_copy_from(Tensor self, Tensor dst, bool non_blocking=False)
    -> Tensor``. Because it is registered for the privateuseone key, presumably
    at least one side is an "npy" tensor. ``non_blocking`` is accepted for
    schema compatibility; NumPy copies are always synchronous.
    """
    if a.device.type == "npy":
        # Take a private copy: after copy_, dst must not alias src.
        src = np.array(unwrap(a))
    else:
        # detach(): .numpy() refuses tensors that require grad.
        # cpu(): .numpy() refuses non-CPU tensors.
        # np.array(...) copies, because .numpy() shares memory with `a`.
        src = np.array(a.detach().cpu().numpy())
    if b.device.type == "npy":
        b.raw_data = src
    else:
        # Regular destination (e.g. `x.cpu()`): write into its real storage
        # instead of attaching an attribute that nothing reads.
        b.copy_(torch.from_numpy(src))
    return b
 | |
| 
 | |
| 
 | |
@torch.library.impl("aten::view", "privateuseone")
def _view(a, b):
    """Reshape the payload of ``a`` to the requested size ``b``.

    ``b`` is the target size of ``aten::view(Tensor self, SymInt[] size)`` and
    may contain a single ``-1``. The previous implementation ignored ``b``, so
    e.g. ``x.view(-1)`` kept ``x``'s old shape. NumPy returns a view when
    possible and otherwise copies; real ``view`` would raise instead.
    """
    ans = np.reshape(unwrap(a), b)
    return wrap(ans, ans.shape, a.dtype)
 | |
| 
 | |
| 
 | |
@torch.library.impl("aten::empty.memory_format", "privateuseone")
def empty_memory_format(
    size, *, dtype=None, layout=None, device=None, pin_memory=None, memory_format=None
):
    """Allocate an uninitialized payload of shape ``size``.

    The keyword arguments mirror the ATen schema but are ignored. The wrapper
    is always float32, as elsewhere in this file.
    """
    # Allocate float32 so the payload agrees with the wrapper's declared dtype
    # (np.empty alone defaults to float64).
    ans = np.empty(size, dtype=np.float32)
    return wrap(ans, ans.shape, torch.float32)
 | |
| 
 | |
| 
 | |
@torch.library.impl("aten::sum", "privateuseone")
def sum_int_list(*args, **kwargs):
    """Full reduction of the input payload, wrapped as float32.

    Registered for ``aten::sum`` with no overload name (presumably the default
    full-reduction overload). Only the first positional argument is used; any
    other arguments, e.g. ``dtype``, are ignored.
    """
    source = args[0]
    reduced = unwrap(source).sum()
    return wrap(reduced, reduced.shape, torch.float32)
 | |
| 
 | |
| 
 | |
@torch.library.impl("aten::ones_like", "privateuseone")
def ones_like(
    self, *, dtype=None, layout=None, device=None, pin_memory=None, memory_format=None
):
    """Payload of ones shaped like ``self``'s payload, wrapped as float32.

    The keyword arguments mirror the ATen schema and are ignored.
    """
    template = unwrap(self)
    filled = np.ones_like(template)
    return wrap(filled, filled.shape, torch.float32)
 | |
| 
 | |
| 
 | |
@torch.library.impl("aten::expand", "privateuseone")
def expand(self, size, *, implicit=False):
    """Broadcast the payload to ``size`` as a read-only NumPy view (no copy).

    As with ``Tensor.expand``, a ``-1`` entry keeps the existing size of the
    corresponding (right-aligned) dimension. ``np.broadcast_to`` does not accept
    ``-1``, so those entries are resolved first. ``implicit`` is accepted for
    schema compatibility and ignored.
    """
    data = self.raw_data
    src_shape = np.shape(data)
    target = list(size)
    # Number of new leading dimensions; existing dims align from the right.
    offset = len(target) - len(src_shape)
    for i, dim in enumerate(target):
        j = i - offset
        # -1 is only meaningful for an existing dimension; anything else is
        # left for np.broadcast_to to reject.
        if dim == -1 and 0 <= j < len(src_shape):
            target[i] = src_shape[j]
    ans = np.broadcast_to(data, tuple(target))
    return wrap(ans, ans.shape, torch.float32)
 | |
| 
 | |
| 
 | |
@torch.library.impl("aten::as_strided", "privateuseone")
def as_strided(self, size, stride, storage_offset=None):
    """View the payload with the given ``size`` and element ``stride``.

    ATen strides count elements, while ``np.lib.stride_tricks.as_strided``
    expects bytes, so each stride is scaled by the payload's item size.

    Raises:
        NotImplementedError: if ``storage_offset`` is non-zero. It used to be
            silently ignored, which produced wrong data.
    """
    data = np.asarray(self.raw_data)
    if storage_offset:
        raise NotImplementedError(
            "as_strided with a non-zero storage_offset is not supported by this backend"
        )
    # Element strides -> byte strides. Passing element strides straight
    # through reads misaligned / out-of-range memory for any non-zero stride.
    byte_strides = tuple(s * data.itemsize for s in stride)
    ans = np.lib.stride_tricks.as_strided(data, size, byte_strides)
    return wrap(ans, ans.shape, torch.float32)
 | |
| 
 | |
| 
 | |
class PrivateUse1BackendTest(TestCase):
    """End-to-end checks for the NumPy-backed privateuseone ("npy") backend."""

    # The former `setupClass` hook was misspelled (unittest calls
    # `setUpClass`), so it never ran; it was an empty no-op and is removed.

    def test_accessing_is_pinned(self):
        # Presumably guards against the privateuseone setup breaking
        # is_pinned() on plain CPU tensors -- this must not raise.
        a_cpu = torch.randn((2, 2))
        _ = a_cpu.is_pinned()

    def test_backend_simple(self):
        a_cpu = torch.randn((2, 2))
        b_cpu = torch.randn((2, 2))
        # Moving tensors to the backend must not raise.
        a = a_cpu.to("privateuseone")
        b = b_cpu.to("privateuseone")

        a.requires_grad = True
        b.requires_grad = True
        c = (a + b).sum()
        c.backward()
        # d/da sum(a + b) and d/db sum(a + b) are both all-ones.
        np.testing.assert_allclose(a.grad.raw_data, np.ones((2, 2)))
        np.testing.assert_allclose(b.grad.raw_data, np.ones((2, 2)))


if __name__ == "__main__":
    run_tests()
 |