Throw a nice error when SubTensor.__torch_dispatch__() returns the wrong type for detach()

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77655

Approved by: https://github.com/albanD
This commit is contained in:
Joel Benjamin Schlosser
2022-05-18 15:54:39 -04:00
committed by PyTorch MergeBot
parent 8881d7ac6c
commit 0794d59d76
2 changed files with 42 additions and 1 deletions

View File

@ -11,6 +11,7 @@ from torch.testing._internal.common_utils import (
TestCase, run_tests, parametrize, subtest, instantiate_parametrized_tests)
from torch.testing._internal.common_subclass import subclass_db, DiagTensorBelow
from torch.testing._internal.logging_tensor import LoggingTensor
from torch.utils._pytree import tree_map
from unittest import expectedFailure
# The current test methodology in this file is to test a variety of real use cases
@ -204,6 +205,40 @@ class TestSubclass(TestCase):
self.assertFalse(m.has_uninitialized_params())
self.assertIsInstance(m.param, tensor_cls)
def test_non_rewrapping_torch_dispatch_subclass_as_parameter_throws_for_detach(self):
# Define a subclass that does not rewrap for any function in its __torch_dispatch__ impl.
class NonRewrappingTensor(torch.Tensor):
@staticmethod
def __new__(
cls, t: torch.Tensor
):
r = super(NonRewrappingTensor, cls)._make_wrapper_subclass(
cls, t.shape, dtype=t.dtype, requires_grad=t.requires_grad, device=t.device)
return r
def __init__(self, t) -> None:
self.tensor: torch.Tensor = t
__torch_function__ = torch._C._disabled_torch_function_impl
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
def unwrap(e) -> torch.Tensor:
if isinstance(e, NonRewrappingTensor):
t = e.tensor
return t
else:
return e
r = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))
# Return an unwrapped tensor no longer of original subclass type.
return r
with self.assertRaisesRegex(RuntimeError, r"requires that detach\(\) returns an instance of the same type"):
param = nn.Parameter(NonRewrappingTensor(torch.randn(3)))
instantiate_parametrized_tests(TestSubclass)
if __name__ == '__main__':

View File

@ -31,13 +31,19 @@ class Parameter(torch.Tensor, metaclass=_ParameterMeta):
def __new__(cls, data=None, requires_grad=True):
if data is None:
data = torch.empty(0)
if type(data) is torch.Tensor:
if type(data) is torch.Tensor or type(data) is Parameter:
# For ease of BC maintenance, keep this path for standard Tensor.
# Eventually (tm), we should change the behavior for standard Tensor to match.
return torch.Tensor._make_subclass(cls, data, requires_grad)
# Path for custom tensors: set a flag on the instance to indicate parameter-ness.
t = data.detach().requires_grad_(requires_grad)
if type(t) is not type(data):
raise RuntimeError(f"Creating a Parameter from an instance of type {type(data).__name__} "
"requires that detach() returns an instance of the same type, but return "
f"type {type(t).__name__} was found instead. To use the type as a "
"Parameter, please correct the detach() semantics defined by "
"its __torch_dispatch__() implementation.")
t._is_param = True
return t