mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Preserve Python dispatch keys upon copy_tensor_metadata_except_version_counter
Whether or not this is a reasonable operation to do in the presence of subclasses is a good question in and of itself, but this fixes an obvious invariant violation, which is that if a Tensor reports that it is a tensor subclass, it had better have the Python dispatch key. Previously, the dispatch key would have gotten unconditionally cleared; now we preserve what ever the original bit was. Signed-off-by: Edward Z. Yang <ezyangfb.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/75644 Approved by: https://github.com/albanD
This commit is contained in:
committed by
PyTorch MergeBot
parent
b09769992f
commit
2772870860
@ -706,6 +706,32 @@ $6 = torch._ops.aten.add_.Tensor($1, $5)''')
|
||||
x.neg()
|
||||
self.assertEqual(called, [torch.ops.aten.neg.default])
|
||||
|
||||
def test_set_data(self):
|
||||
called = 0
|
||||
|
||||
class SubTensor(torch.Tensor):
|
||||
__torch_function__ = torch._C._disabled_torch_function_impl
|
||||
|
||||
@classmethod
|
||||
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
|
||||
nonlocal called
|
||||
called += 1
|
||||
return super().__torch_dispatch__(func, types, args, kwargs)
|
||||
|
||||
x = SubTensor(torch.empty(2))
|
||||
x.data
|
||||
self.assertEqual(called, 1)
|
||||
x.data = torch.empty(2)
|
||||
self.assertEqual(called, 1)
|
||||
x.data
|
||||
self.assertEqual(called, 2)
|
||||
self.assertIs(type(x), SubTensor)
|
||||
x.set_(torch.empty(2))
|
||||
self.assertEqual(called, 3)
|
||||
x.data
|
||||
self.assertEqual(called, 4)
|
||||
self.assertIs(type(x), SubTensor)
|
||||
|
||||
def test_construct_int_tensor(self):
|
||||
class SubTensor(torch.Tensor):
|
||||
pass
|
||||
|
Reference in New Issue
Block a user