mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Throw a nice error when SubTensor.__torch_dispatch__() returns the wrong type for detach()
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77655 Approved by: https://github.com/albanD
This commit is contained in:
committed by
PyTorch MergeBot
parent
8881d7ac6c
commit
0794d59d76
@ -11,6 +11,7 @@ from torch.testing._internal.common_utils import (
|
||||
TestCase, run_tests, parametrize, subtest, instantiate_parametrized_tests)
|
||||
from torch.testing._internal.common_subclass import subclass_db, DiagTensorBelow
|
||||
from torch.testing._internal.logging_tensor import LoggingTensor
|
||||
from torch.utils._pytree import tree_map
|
||||
from unittest import expectedFailure
|
||||
|
||||
# The current test methodology in this file is to test a variety of real use cases
|
||||
@ -204,6 +205,40 @@ class TestSubclass(TestCase):
|
||||
self.assertFalse(m.has_uninitialized_params())
|
||||
self.assertIsInstance(m.param, tensor_cls)
|
||||
|
||||
def test_non_rewrapping_torch_dispatch_subclass_as_parameter_throws_for_detach(self):
|
||||
|
||||
# Define a subclass that does not rewrap for any function in its __torch_dispatch__ impl.
|
||||
class NonRewrappingTensor(torch.Tensor):
|
||||
@staticmethod
|
||||
def __new__(
|
||||
cls, t: torch.Tensor
|
||||
):
|
||||
r = super(NonRewrappingTensor, cls)._make_wrapper_subclass(
|
||||
cls, t.shape, dtype=t.dtype, requires_grad=t.requires_grad, device=t.device)
|
||||
return r
|
||||
|
||||
def __init__(self, t) -> None:
|
||||
self.tensor: torch.Tensor = t
|
||||
|
||||
__torch_function__ = torch._C._disabled_torch_function_impl
|
||||
|
||||
@classmethod
|
||||
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
|
||||
|
||||
def unwrap(e) -> torch.Tensor:
|
||||
if isinstance(e, NonRewrappingTensor):
|
||||
t = e.tensor
|
||||
return t
|
||||
else:
|
||||
return e
|
||||
|
||||
r = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))
|
||||
# Return an unwrapped tensor no longer of original subclass type.
|
||||
return r
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, r"requires that detach\(\) returns an instance of the same type"):
|
||||
param = nn.Parameter(NonRewrappingTensor(torch.randn(3)))
|
||||
|
||||
instantiate_parametrized_tests(TestSubclass)
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -31,13 +31,19 @@ class Parameter(torch.Tensor, metaclass=_ParameterMeta):
|
||||
def __new__(cls, data=None, requires_grad=True):
|
||||
if data is None:
|
||||
data = torch.empty(0)
|
||||
if type(data) is torch.Tensor:
|
||||
if type(data) is torch.Tensor or type(data) is Parameter:
|
||||
# For ease of BC maintenance, keep this path for standard Tensor.
|
||||
# Eventually (tm), we should change the behavior for standard Tensor to match.
|
||||
return torch.Tensor._make_subclass(cls, data, requires_grad)
|
||||
|
||||
# Path for custom tensors: set a flag on the instance to indicate parameter-ness.
|
||||
t = data.detach().requires_grad_(requires_grad)
|
||||
if type(t) is not type(data):
|
||||
raise RuntimeError(f"Creating a Parameter from an instance of type {type(data).__name__} "
|
||||
"requires that detach() returns an instance of the same type, but return "
|
||||
f"type {type(t).__name__} was found instead. To use the type as a "
|
||||
"Parameter, please correct the detach() semantics defined by "
|
||||
"its __torch_dispatch__() implementation.")
|
||||
t._is_param = True
|
||||
return t
|
||||
|
||||
|
Reference in New Issue
Block a user