Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
[BE] Add a test to ensure grads are never inplaced into accidentally (#143612)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143612 Approved by: https://github.com/soulitzer
Committed by: PyTorch MergeBot
Parent: 2daa666591
Commit: 4e29e4aa63
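Background, not part of the commit itself: PyTorch bumps a tensor's internal _version counter on every in-place write, and the test added below records params[0].grad._version before stepping and asserts it never changes. A minimal standalone sketch of how that counter behaves:

import torch

# In-place mutation bumps the tensor's internal version counter.
g = torch.zeros(3)
v0 = g._version
g.add_(1.0)                # in-place op: counter increments
assert g._version == v0 + 1

# Out-of-place arithmetic leaves the original tensor's counter untouched.
h = torch.zeros(3)
v0 = h._version
result = h + 1.0           # allocates a new tensor; h itself is not mutated
assert h._version == v0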
@@ -1331,6 +1331,44 @@ class TestOptimRenewed(TestCase):
             optimizer.step(closure)
             self.assertEqual(old_param, params[0])
 
+    @optims(optim_db, dtypes=[torch.float32])
+    def test_grads_are_never_inplaced_into(self, device, dtype, optim_info):
+        optim_cls = optim_info.optim_cls
+        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
+            device, dtype, optim_info
+        )
+        param = torch.randn((5, 1), device=device, dtype=dtype, requires_grad=True)
+
+        def closure():
+            return torch.tensor([1], device=device, dtype=dtype)
+
+        for optim_input in all_optim_inputs:
+            kwargs = optim_input.kwargs
+
+            if kwargs.get("differentiable", False):
+                params = [param.clone()]
+            else:
+                params = [param]
+
+            optimizer = optim_cls(params, **kwargs)
+            if optim_info.only_supports_sparse_grads:
+                # Intentionally construct a multidimensional empty v for the sparse grad
+                # Single dim v passes the test while multidim correctly repros the issue
+                # https://github.com/pytorch/pytorch/issues/82486
+                i = torch.empty((1, 0), device=device, dtype=dtype)
+                v = torch.empty((0, 1), device=device, dtype=dtype)
+                params[0].grad = torch.sparse_coo_tensor(
+                    i, v, (5, 1), device=device, dtype=dtype
+                )
+            else:
+                params[0].grad = torch.rand_like(params[0])
+
+            old_version = params[0].grad._version
+
+            for _ in range(5):
+                optimizer.step(closure)
+                self.assertEqual(params[0].grad._version, old_version)
+
     @optims(optim_db, dtypes=[torch.float32])
     def test_optimizer_can_be_printed(self, device, dtype, optim_info):
         optim_cls = optim_info.optim_cls
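The same invariant can be spot-checked by hand outside the test harness. A minimal sketch using a stock torch.optim.SGD with default arguments (illustrative only, not part of this PR):

import torch

param = torch.randn(5, 1, requires_grad=True)
param.grad = torch.rand_like(param)
old_version = param.grad._version

opt = torch.optim.SGD([param], lr=0.1)
for _ in range(5):
    opt.step()

# The optimizer may read the gradient but must never write into it in place,
# so the version counter of the grad tensor should be unchanged.
assert param.grad._version == old_version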