Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 12:54:11 +08:00)
Fix allow_mutation_on_saved_tensors for inplace foreach (#145520)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145520
Approved by: https://github.com/albanD

Committed by: PyTorch MergeBot
Parent: b4fe3c159d
Commit: 9e0ee152e5
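For context, here is a minimal standalone sketch of the pattern this commit fixes. It mirrors the new test in the diff below and assumes a PyTorch build that includes this patch; before the patch, the clone hook did not handle the List[Tensor] arguments of foreach in-place ops.

    import torch

    # Under allow_mutation_on_saved_tensors, tensors saved for backward are
    # cloned lazily right before an in-place op overwrites them. The old hook
    # only inspected single-Tensor mutable arguments, so foreach in-place ops,
    # which mutate a List[Tensor] argument, were not handled.
    with torch.autograd.graph.allow_mutation_on_saved_tensors():
        a = [torch.tensor(1.0, requires_grad=True) for _ in range(2)]
        b = torch._foreach_exp(a)   # exp saves its outputs for backward
        torch._foreach_add_(b, 1)   # in-place foreach mutation of saved tensors
        (b[0] + b[1]).backward()
        print([a[0].grad, a[1].grad])  # equals torch._foreach_exp(a)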
@@ -12489,6 +12489,18 @@ class TestAllowMutationOnSaved(TestCase):
         with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx:
             pass
 
+    def test_inplace_foreach(self):
+        with torch.autograd.graph.allow_mutation_on_saved_tensors():
+            a = [
+                torch.tensor(1.0, requires_grad=True),
+                torch.tensor(1.0, requires_grad=True),
+            ]
+            b = torch._foreach_exp(a)
+            torch._foreach_add_(b, 1)
+            (b[0] + b[1]).backward()
+
+            self.assertEqual([a[0].grad, a[1].grad], torch._foreach_exp(a))
+
 
 class TestAutogradInferenceMode(TestCase):
     def _is_inference_tensor(self, tensor):
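The expected value in the assert follows from exp's backward formula: exp saves its output for backward, and d(exp(a))/da = exp(a). Because the context clones b before torch._foreach_add_(b, 1) mutates it, each a[i].grad comes out as exp(a[i]); without the clone, backward would read the mutated buffers and produce exp(a[i]) + 1 instead.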
@@ -661,31 +661,41 @@ class _CloneArgBeforeMutateMode(TorchDispatchMode):
     ) -> Any:
         kwargs = kwargs or {}
 
+        def maybe_clone(t: torch.Tensor) -> None:
+            tid = _get_tid(t)
+            sid = _get_sid(t)
+            ctx = self.ctx
+            if sid in ctx.sid_to_tid:
+                for tid in ctx.sid_to_tid[sid]:
+                    if tid not in ctx.tid_to_weakhandle:
+                        # We know that if tid is in sid_to_tid, then it must also be in
+                        # tid_to_weakhandle. However, it is possible for the tensor to be
+                        # saved at one point, but cleared by backward before it is modified
+                        # in-place. Consider the following example:
+                        #
+                        # >>> a = torch.randn(2, 3, requires_grad=True).clone()
+                        # >>> out = (a**2).sum()
+                        # >>> out.backward()
+                        # >>> a.sin_()
+                        continue
+                    handle = ctx.tid_to_weakhandle[tid]
+                    if handle in ctx.cloned:
+                        # The same exact tensor has been cloned already
+                        continue
+                    ctx.cloned[handle] = ctx.original[handle].clone()
+                    del ctx.original[handle]
+
         for idx, arg in enumerate(func._schema.arguments):
             if arg.alias_info is not None and arg.alias_info.is_write:
-                t = kwargs["out"] if arg.is_out else args[idx]
-                tid = _get_tid(t)
-                sid = _get_sid(t)
-                ctx = self.ctx
-                if sid in ctx.sid_to_tid:
-                    for tid in ctx.sid_to_tid[sid]:
-                        if tid not in ctx.tid_to_weakhandle:
-                            # We know that if tid is in sid_to_tid, then it must also be in
-                            # tid_to_weakhandle. However, it is possible for the tensor to be
-                            # saved at one point, but cleared by backward before it is modified
-                            # in-place. Consider the following example:
-                            #
-                            # >>> a = torch.randn(2, 3, requires_grad=True).clone()
-                            # >>> out = (a**2).sum()
-                            # >>> out.backward()
-                            # >>> a.sin_()
-                            continue
-                        handle = ctx.tid_to_weakhandle[tid]
-                        if handle in ctx.cloned:
-                            # The same exact tensor has been cloned already
-                            continue
-                        ctx.cloned[handle] = ctx.original[handle].clone()
-                        del ctx.original[handle]
+                if arg.is_out:
+                    maybe_clone(kwargs["out"])
+                elif isinstance(args[idx], list):
+                    # Foreach case. (Possible optimization: if most of the
+                    # tensors need to be cloned, use a for each clone?)
+                    for t in args[idx]:
+                        maybe_clone(t)
+                else:
+                    maybe_clone(args[idx])
 
         return func(*args, **kwargs)
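To illustrate the mechanism the refactored code relies on, here is a minimal, hypothetical TorchDispatchMode (the LogMutationsMode name is invented for illustration; this is a sketch, not PyTorch's implementation) that applies the same schema alias-info check to detect mutated arguments, including the List[Tensor] arguments of foreach ops:

    import torch
    from torch.utils._python_dispatch import TorchDispatchMode

    class LogMutationsMode(TorchDispatchMode):
        def __torch_dispatch__(self, func, types, args=(), kwargs=None):
            kwargs = kwargs or {}
            for idx, arg in enumerate(func._schema.arguments):
                # Same check as _CloneArgBeforeMutateMode: an argument is
                # written to when its schema alias info is marked is_write.
                if arg.alias_info is not None and arg.alias_info.is_write:
                    t = kwargs["out"] if arg.is_out else args[idx]
                    # Foreach ops mutate a list of tensors, not a single one.
                    n = len(t) if isinstance(t, list) else 1
                    print(f"{func} mutates argument {idx} ({n} tensor(s))")
            return func(*args, **kwargs)

    with LogMutationsMode():
        xs = [torch.ones(2), torch.ones(2)]
        torch._foreach_add_(xs, 1)  # reports a List[Tensor] mutation
        torch.zeros(2).add_(1)      # reports a single-Tensor mutation

Routing both shapes through a single maybe_clone helper is what lets the foreach case reuse the existing lazy-clone bookkeeping instead of duplicating it.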