mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Make sure requires_grad is propagated for all backend
The if statement is not strictly necessary but that avoid having to call this function if we don't need it. Pull Request resolved: https://github.com/pytorch/pytorch/pull/76256 Approved by: https://github.com/ezyang, https://github.com/soulitzer
This commit is contained in:
committed by
PyTorch MergeBot
parent
920876d693
commit
e4d5801e36
@ -20,8 +20,8 @@ from torch._utils_internal import get_file_path_2
|
||||
from torch._utils import _rebuild_tensor
|
||||
from torch.serialization import check_module_version_greater_or_equal
|
||||
|
||||
from torch.testing._internal.common_utils import TestCase, IS_WINDOWS, \
|
||||
TEST_DILL, run_tests, download_file, BytesIOContext, TemporaryFileName
|
||||
from torch.testing._internal.common_utils import TestCase, IS_WINDOWS, TEST_DILL, \
|
||||
run_tests, download_file, BytesIOContext, TemporaryFileName, parametrize, instantiate_parametrized_tests
|
||||
from torch.testing._internal.common_device_type import instantiate_device_type_tests
|
||||
from torch.testing._internal.common_dtype import all_types_and_complex_and
|
||||
|
||||
@ -948,8 +948,18 @@ class TestSubclassSerialization(TestCase):
|
||||
self.assertEqual(new_tensor.elem, my_tensor.elem)
|
||||
self.assertEqual(new_tensor.foo, foo_val)
|
||||
|
||||
@parametrize('requires_grad', (True, False))
|
||||
def test_cloned_deepcopy(self, requires_grad):
|
||||
my_tensor = torch.rand(2, requires_grad=requires_grad, device='meta')
|
||||
|
||||
new_tensor = deepcopy(my_tensor)
|
||||
|
||||
self.assertEqual(new_tensor.requires_grad, my_tensor.requires_grad)
|
||||
|
||||
|
||||
|
||||
instantiate_device_type_tests(TestBothSerialization, globals())
|
||||
instantiate_parametrized_tests(TestSubclassSerialization)
|
||||
|
||||
if __name__ == '__main__':
|
||||
run_tests()
|
||||
|
@ -95,6 +95,7 @@ class Tensor(torch._C._TensorBase):
|
||||
# does accurate alias tracking; however, the code below
|
||||
# doesn't work because of
|
||||
# https://github.com/pytorch/pytorch/issues/47442
|
||||
# Update the test in test_serialization if you remove 'meta' from here
|
||||
if self.is_sparse or self.device.type in ['lazy', 'xla', 'mlc', 'ort', 'meta', 'hpu'] or \
|
||||
(type(self) is not Tensor and self.data_ptr() == 0):
|
||||
new_tensor = self.clone()
|
||||
@ -150,7 +151,8 @@ class Tensor(torch._C._TensorBase):
|
||||
new_tensor = new_tensor.conj_physical()
|
||||
if self.is_neg():
|
||||
new_tensor = new_tensor.neg()
|
||||
new_tensor.requires_grad = self.requires_grad
|
||||
if self.requires_grad:
|
||||
new_tensor.requires_grad_()
|
||||
if self.grad is not None:
|
||||
new_tensor.grad = self.grad.__deepcopy__(memo)
|
||||
|
||||
|
Reference in New Issue
Block a user