Extend __new__ on subclasses to set custom_device and custom_strides
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77970 Approved by: https://github.com/Chillee
This commit is contained in:
committed by PyTorch MergeBot
parent 678213ead2
commit 98e0816986
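The mechanism behind the change: torch.Tensor._make_subclass gains dispatch_strides and dispatch_device flags (see the gen_pyi hunk below), so a subclass's __new__ can opt its instances into routing stride/device queries through __torch_dispatch__. Below is a minimal sketch of the pattern, modeled on the FakeTensor changes in this diff; DeviceVirtualTensor and fake_device are illustrative names, and the super().__torch_dispatch__ fallthrough assumes the PR-era default implementation on torch.Tensor:

    import torch

    class DeviceVirtualTensor(torch.Tensor):
        @staticmethod
        def __new__(cls, elem, fake_device):
            # dispatch_device=True (added by this PR) makes .device queries
            # dispatch as torch.ops.prim.device instead of reading the
            # device of the underlying storage
            r = torch.Tensor._make_subclass(
                cls, elem, elem.requires_grad, dispatch_device=True
            )
            r.fake_device = torch.device(fake_device)
            return r

        @classmethod
        def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
            kwargs = kwargs if kwargs else {}
            # Answer device queries with the virtualized device
            if func == torch.ops.prim.device.default:
                return args[0].fake_device
            # Everything else runs normally (PR-era default implementation)
            return super().__torch_dispatch__(func, types, args, kwargs)

        __torch_function__ = torch._C._disabled_torch_function_impl

    # Storage lives on meta, but .device reports the virtual device:
    t = DeviceVirtualTensor(torch.empty(2, 2, device="meta"), "cuda:0")
    assert t.device == torch.device("cuda:0")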
@@ -7,7 +7,6 @@ from torch.testing._internal.jit_utils import RUN_CUDA
 import unittest
 from torch._subclasses import FakeTensor

 class FakeTensorTest(TestCase):
     def test_basic(self):
         x = FakeTensor.from_tensor(torch.empty(2, 2, device="cpu"))
@@ -39,6 +38,10 @@ class FakeTensorTest(TestCase):
         z = FakeTensor.from_tensor(torch.rand([4, 4], device="cpu"))
         self.assertRaises(Exception, lambda: torch.lerp(x, y, z))

+    def test_dispatch_device(self):
+        x = FakeTensor.from_tensor(torch.rand([4, 4]))
+        self.assertEqual(x.device.type, "cpu")
 def contains_type(type: torch._C.Type, maybe_contained_type: torch._C.Type):
     return maybe_contained_type.isSubtypeOf(type) or any(
@@ -614,7 +614,8 @@ def gen_pyi(
        ],
        "as_subclass": ["def as_subclass(self, cls: Tensor) -> Tensor: ..."],
        "_make_subclass": [
-           "def _make_subclass(cls, data: Tensor, require_grad: _bool = False) -> Tensor: ..."
+           "def _make_subclass(cls, data: Tensor, require_grad: _bool = False, dispatch_strides: _bool=False,"
+           " dispatch_device: _bool=False) -> Tensor: ..."
        ],
        "__getitem__": ["def __getitem__(self, {}) -> Tensor: ...".format(INDICES)],
        "__setitem__": [
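Since adjacent Python string literals concatenate, the two added lines above produce a single stub; the regenerated .pyi entry should read roughly as follows (reconstructed by joining the strings in the diff):

    def _make_subclass(cls, data: Tensor, require_grad: _bool = False, dispatch_strides: _bool=False, dispatch_device: _bool=False) -> Tensor: ...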
@@ -1,10 +1,8 @@
 import torch

-from torch._subclasses.base_tensor import BaseTensor
 from torch._subclasses.fake_tensor import FakeTensor, _device_not_kwarg_ops

 __all__ = [
-    "BaseTensor",
     "FakeTensor",
     "_device_not_kwarg_ops",
 ]
@@ -1,19 +0,0 @@
-import torch
-
-# Ideally, tensor subclasses would inherit directly from Tensor.
-# This is just our staging ground for applying behavior that hasn't yet made it
-# into the core Tensor class but that we would like to apply by default.
-class BaseTensor(torch.Tensor):
-    # See https://github.com/pytorch/pytorch/pull/73727 ; this is necessary
-    # to ensure that super().__new__ calls can cooperate with each other
-    @staticmethod
-    def __new__(cls, elem, *, requires_grad=None):
-        if requires_grad is None:
-            return super().__new__(cls, elem)  # type: ignore[call-arg]
-        else:
-            return cls._make_subclass(cls, elem, requires_grad)
-
-    # If __torch_dispatch__ is defined (which it will be for all our examples)
-    # the default torch function implementation (which preserves subclasses)
-    # typically must be disabled
-    __torch_function__ = torch._C._disabled_torch_function_impl
@@ -1,6 +1,5 @@
 import torch

-from torch._subclasses import BaseTensor
 from torch.utils._pytree import tree_map
 from functools import partial
 from torch.fx.operator_schemas import normalize_function
@@ -27,12 +26,12 @@ _device_not_kwarg_ops = (
 # which tracks devices that would have been used.


-class FakeTensor(BaseTensor):
+class FakeTensor(torch.Tensor):
     fake_device: torch.device

     @staticmethod
     def __new__(cls, elem, device):
-        return super().__new__(cls, elem)
+        return torch.Tensor._make_subclass(cls, elem, elem.requires_grad, dispatch_device=True)

     def __init__(self, elem, device: Union[torch.device, str]):
         # elem does not need to be recorded, because FakeTensor *is a* elem
@@ -46,16 +45,22 @@ class FakeTensor(BaseTensor):
         existing_device = t.device
         return FakeTensor(t.to(device="meta"), existing_device)

-    @property
-    def device(self):
-        return self.fake_device
+    # TODO: resolve error in default __repr__
+    def __repr__(self):
+        return f"FakeTensor({self.fake_device})"

     @classmethod
     def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
         kwargs = kwargs if kwargs else {}

+        # This class virtualizes .device() calls; we need to short-circuit
+        # them here instead of dispatching device again, or we would recurse forever
+        if func == torch.ops.prim.device.default:
+            assert len(args) == 1 and isinstance(args[0], FakeTensor)
+            return args[0].fake_device

         # Run the original computation
         r = super().__torch_dispatch__(func, types, args, kwargs)

         def wrap(e, device):
@@ -140,3 +145,5 @@ class FakeTensor(BaseTensor):
         assert common_device is not None, f"Could not find common device for {func}"

         return common_device
+
+    __torch_function__ = torch._C._disabled_torch_function_impl
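Taken together with the new test_dispatch_device test, the round trip works like this: FakeTensor.from_tensor records the original device, moves the data to the meta device, and prim.device dispatch then reports the recorded device. A quick usage sketch, assuming the torch._subclasses module as of this commit:

    import torch
    from torch._subclasses import FakeTensor

    x = FakeTensor.from_tensor(torch.rand([4, 4]))
    print(x.fake_device)   # device(type='cpu'), recorded at construction
    print(x.device.type)   # 'cpu', answered by __torch_dispatch__, not by the meta storage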