Make sure requires_grad is propagated for all backend

The if statement is not strictly necessary but that avoid having to call this function if we don't need it.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76256
Approved by: https://github.com/ezyang, https://github.com/soulitzer
This commit is contained in:
Alban Desmaison
2022-04-25 19:31:24 +00:00
committed by PyTorch MergeBot
parent 920876d693
commit e4d5801e36
2 changed files with 15 additions and 3 deletions

View File

@ -20,8 +20,8 @@ from torch._utils_internal import get_file_path_2
from torch._utils import _rebuild_tensor from torch._utils import _rebuild_tensor
from torch.serialization import check_module_version_greater_or_equal from torch.serialization import check_module_version_greater_or_equal
from torch.testing._internal.common_utils import TestCase, IS_WINDOWS, \ from torch.testing._internal.common_utils import TestCase, IS_WINDOWS, TEST_DILL, \
TEST_DILL, run_tests, download_file, BytesIOContext, TemporaryFileName run_tests, download_file, BytesIOContext, TemporaryFileName, parametrize, instantiate_parametrized_tests
from torch.testing._internal.common_device_type import instantiate_device_type_tests from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_dtype import all_types_and_complex_and from torch.testing._internal.common_dtype import all_types_and_complex_and
@ -948,8 +948,18 @@ class TestSubclassSerialization(TestCase):
self.assertEqual(new_tensor.elem, my_tensor.elem) self.assertEqual(new_tensor.elem, my_tensor.elem)
self.assertEqual(new_tensor.foo, foo_val) self.assertEqual(new_tensor.foo, foo_val)
@parametrize('requires_grad', (True, False))
def test_cloned_deepcopy(self, requires_grad):
my_tensor = torch.rand(2, requires_grad=requires_grad, device='meta')
new_tensor = deepcopy(my_tensor)
self.assertEqual(new_tensor.requires_grad, my_tensor.requires_grad)
instantiate_device_type_tests(TestBothSerialization, globals()) instantiate_device_type_tests(TestBothSerialization, globals())
instantiate_parametrized_tests(TestSubclassSerialization)
if __name__ == '__main__': if __name__ == '__main__':
run_tests() run_tests()

View File

@ -95,6 +95,7 @@ class Tensor(torch._C._TensorBase):
# does accurate alias tracking; however, the code below # does accurate alias tracking; however, the code below
# doesn't work because of # doesn't work because of
# https://github.com/pytorch/pytorch/issues/47442 # https://github.com/pytorch/pytorch/issues/47442
# Update the test in test_serialization if you remove 'meta' from here
if self.is_sparse or self.device.type in ['lazy', 'xla', 'mlc', 'ort', 'meta', 'hpu'] or \ if self.is_sparse or self.device.type in ['lazy', 'xla', 'mlc', 'ort', 'meta', 'hpu'] or \
(type(self) is not Tensor and self.data_ptr() == 0): (type(self) is not Tensor and self.data_ptr() == 0):
new_tensor = self.clone() new_tensor = self.clone()
@ -150,7 +151,8 @@ class Tensor(torch._C._TensorBase):
new_tensor = new_tensor.conj_physical() new_tensor = new_tensor.conj_physical()
if self.is_neg(): if self.is_neg():
new_tensor = new_tensor.neg() new_tensor = new_tensor.neg()
new_tensor.requires_grad = self.requires_grad if self.requires_grad:
new_tensor.requires_grad_()
if self.grad is not None: if self.grad is not None:
new_tensor.grad = self.grad.__deepcopy__(memo) new_tensor.grad = self.grad.__deepcopy__(memo)