Make sure requires_grad is propagated for all backends

The if statement is not strictly necessary, but it avoids calling this function when we don't need to.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/76256
Approved by: https://github.com/ezyang, https://github.com/soulitzer
commit e4d5801e36
parent 920876d693
committed by PyTorch MergeBot
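For context, a minimal runnable sketch of the behaviour the new test below checks (my_tensor/new_tensor are illustrative names; it assumes a torch build where meta tensors can be deep-copied, as exercised in the diff):

from copy import deepcopy

import torch

# A leaf tensor on a device that takes the clone() path of Tensor.__deepcopy__
# ('meta' is one of the devices listed in the diff below).
my_tensor = torch.rand(2, requires_grad=True, device='meta')

new_tensor = deepcopy(my_tensor)

# With this change, the copy keeps the requires_grad flag of the original.
assert new_tensor.requires_grad == my_tensor.requires_grad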
--- a/test/test_serialization.py
+++ b/test/test_serialization.py
@@ -20,8 +20,8 @@ from torch._utils_internal import get_file_path_2
 from torch._utils import _rebuild_tensor
 from torch.serialization import check_module_version_greater_or_equal
 
-from torch.testing._internal.common_utils import TestCase, IS_WINDOWS, \
-    TEST_DILL, run_tests, download_file, BytesIOContext, TemporaryFileName
+from torch.testing._internal.common_utils import TestCase, IS_WINDOWS, TEST_DILL, \
+    run_tests, download_file, BytesIOContext, TemporaryFileName, parametrize, instantiate_parametrized_tests
 from torch.testing._internal.common_device_type import instantiate_device_type_tests
 from torch.testing._internal.common_dtype import all_types_and_complex_and
 
@@ -948,8 +948,18 @@ class TestSubclassSerialization(TestCase):
         self.assertEqual(new_tensor.elem, my_tensor.elem)
         self.assertEqual(new_tensor.foo, foo_val)
 
+    @parametrize('requires_grad', (True, False))
+    def test_cloned_deepcopy(self, requires_grad):
+        my_tensor = torch.rand(2, requires_grad=requires_grad, device='meta')
+
+        new_tensor = deepcopy(my_tensor)
+
+        self.assertEqual(new_tensor.requires_grad, my_tensor.requires_grad)
+
+
 
 instantiate_device_type_tests(TestBothSerialization, globals())
+instantiate_parametrized_tests(TestSubclassSerialization)
 
 if __name__ == '__main__':
     run_tests()
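As an aside, a minimal sketch of how the parametrization helpers imported above are used (ExampleTest and its parameter values are made up for illustration; parametrize, instantiate_parametrized_tests and run_tests are the torch.testing._internal.common_utils helpers seen in the diff):

import torch
from torch.testing._internal.common_utils import (
    TestCase, parametrize, instantiate_parametrized_tests, run_tests)


class ExampleTest(TestCase):
    # @parametrize registers one variant of the test per parameter value.
    @parametrize('flag', (True, False))
    def test_flag(self, flag):
        t = torch.zeros(1, requires_grad=flag)
        self.assertEqual(t.requires_grad, flag)


# Without this call the parametrized variants are not generated on the class,
# which is why the diff adds instantiate_parametrized_tests(TestSubclassSerialization).
instantiate_parametrized_tests(ExampleTest)

if __name__ == '__main__':
    run_tests()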
--- a/torch/_tensor.py
+++ b/torch/_tensor.py
@@ -95,6 +95,7 @@ class Tensor(torch._C._TensorBase):
             # does accurate alias tracking; however, the code below
             # doesn't work because of
             # https://github.com/pytorch/pytorch/issues/47442
+            # Update the test in test_serialization if you remove 'meta' from here
             if self.is_sparse or self.device.type in ['lazy', 'xla', 'mlc', 'ort', 'meta', 'hpu'] or \
                     (type(self) is not Tensor and self.data_ptr() == 0):
                 new_tensor = self.clone()
@@ -150,7 +151,8 @@ class Tensor(torch._C._TensorBase):
                         new_tensor = new_tensor.conj_physical()
                     if self.is_neg():
                         new_tensor = new_tensor.neg()
-            new_tensor.requires_grad = self.requires_grad
+            if self.requires_grad:
+                new_tensor.requires_grad_()
             if self.grad is not None:
                 new_tensor.grad = self.grad.__deepcopy__(memo)
 
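For reference, a small runnable sketch (plain CPU tensors, illustrative names) of the two spellings involved in the second hunk: the old property assignment and the new requires_grad_() call which, per the commit title, is the form that propagates the flag for all backends; the surrounding if, as the commit message notes, only avoids an unneeded call.

import torch

src = torch.zeros(3, requires_grad=True)

# Old spelling: restore the flag on the copy through the property setter.
copy = torch.zeros(3)
copy.requires_grad = src.requires_grad
assert copy.requires_grad

# New spelling: call the in-place method, and only when there is something to set.
other = torch.zeros(3)
if src.requires_grad:
    other.requires_grad_()
assert other.requires_grad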