Extend __new__ on subclasses to set custom_device and custom_strides

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77970

Approved by: https://github.com/Chillee
This commit is contained in:
Elias Ellison
2022-05-31 07:02:18 -07:00
committed by PyTorch MergeBot
parent 678213ead2
commit 98e0816986
5 changed files with 19 additions and 29 deletions

View File

@ -7,7 +7,6 @@ from torch.testing._internal.jit_utils import RUN_CUDA
import unittest
from torch._subclasses import FakeTensor
class FakeTensorTest(TestCase):
def test_basic(self):
x = FakeTensor.from_tensor(torch.empty(2, 2, device="cpu"))
@ -39,6 +38,10 @@ class FakeTensorTest(TestCase):
z = FakeTensor.from_tensor(torch.rand([4, 4], device="cpu"))
self.assertRaises(Exception, lambda: torch.lerp(x, y, z))
def test_dispatch_device(self):
x = FakeTensor.from_tensor(torch.rand([4, 4]))
self.assertEqual(x.device.type, "cpu")
def contains_type(type: torch._C.Type, maybe_contained_type: torch._C.Type):
return maybe_contained_type.isSubtypeOf(type) or any(

View File

@ -614,7 +614,8 @@ def gen_pyi(
],
"as_subclass": ["def as_subclass(self, cls: Tensor) -> Tensor: ..."],
"_make_subclass": [
"def _make_subclass(cls, data: Tensor, require_grad: _bool = False) -> Tensor: ..."
"def _make_subclass(cls, data: Tensor, require_grad: _bool = False, dispatch_strides: _bool=False,"
" dispatch_device: _bool=False) -> Tensor: ..."
],
"__getitem__": ["def __getitem__(self, {}) -> Tensor: ...".format(INDICES)],
"__setitem__": [

View File

@ -1,10 +1,8 @@
import torch
from torch._subclasses.base_tensor import BaseTensor
from torch._subclasses.fake_tensor import FakeTensor, _device_not_kwarg_ops
__all__ = [
"BaseTensor",
"FakeTensor",
"_device_not_kwarg_ops",
]

View File

@ -1,19 +0,0 @@
import torch
# Ideally, tensor subclasses would would inherit directly from Tensor.
# This is just our staging ground for applying behavior that hasn't yet made it
# into the core Tensor class but that we would like to apply by default.
class BaseTensor(torch.Tensor):
# See https://github.com/pytorch/pytorch/pull/73727 ; this is necessary
# to ensure that super().__new__ can cooperate with each other
@staticmethod
def __new__(cls, elem, *, requires_grad=None):
if requires_grad is None:
return super().__new__(cls, elem) # type: ignore[call-arg]
else:
return cls._make_subclass(cls, elem, requires_grad)
# If __torch_dispatch__ is defined (which it will be for all our examples)
# the default torch function implementation (which preserves subclasses)
# typically must be disabled
__torch_function__ = torch._C._disabled_torch_function_impl

View File

@ -1,6 +1,5 @@
import torch
from torch._subclasses import BaseTensor
from torch.utils._pytree import tree_map
from functools import partial
from torch.fx.operator_schemas import normalize_function
@ -27,12 +26,12 @@ _device_not_kwarg_ops = (
# which tracks devices that would have been used.
class FakeTensor(BaseTensor):
class FakeTensor(torch.Tensor):
fake_device: torch.device
@staticmethod
def __new__(cls, elem, device):
return super().__new__(cls, elem)
return torch.Tensor._make_subclass(cls, elem, elem.requires_grad, dispatch_device=True)
def __init__(self, elem, device: Union[torch.device, str]):
# elem does not need to be recorded, because FakeTensor *is a* elem
@ -46,16 +45,22 @@ class FakeTensor(BaseTensor):
existing_device = t.device
return FakeTensor(t.to(device="meta"), existing_device)
@property
def device(self):
return self.fake_device
# TODO: resolve error in default __repr__
def __repr__(self):
return f"FakeTensor({self.fake_device})"
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
kwargs = kwargs if kwargs else {}
# This classes virtualizes .device() calls, need to short-circuit
# it insteead of calling device again or we would keep on recurring
if func == torch.ops.prim.device.default:
assert len(args) == 1 and isinstance(args[0], FakeTensor)
return args[0].fake_device
# Run the original computation
r = super().__torch_dispatch__(func, types, args, kwargs)
def wrap(e, device):
@ -140,3 +145,5 @@ class FakeTensor(BaseTensor):
assert common_device is not None, f"Could not find common device for {func}"
return common_device
__torch_function__ = torch._C._disabled_torch_function_impl