Fix allow_mutation_on_saved_tensors for inplace foreach (#145520)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145520
Approved by: https://github.com/albanD
Author: soulitzer
Date: 2025-01-24 13:36:22 -05:00
Committed by: PyTorch MergeBot
Parent: b4fe3c159d
Commit: 9e0ee152e5
2 changed files with 45 additions and 23 deletions
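Before this change, the clone-before-mutate mode under allow_mutation_on_saved_tensors only handled a single mutated tensor per schema argument, so in-place foreach ops such as torch._foreach_add_, which mutate a List[Tensor] argument, were not cloned before mutation. A minimal sketch of the pattern this commit makes work, mirroring the new test below:

import torch

with torch.autograd.graph.allow_mutation_on_saved_tensors():
    a = [torch.tensor(1.0, requires_grad=True) for _ in range(2)]
    b = torch._foreach_exp(a)   # exp saves its outputs for backward
    torch._foreach_add_(b, 1)   # in-place foreach op mutates those saved outputs
    (b[0] + b[1]).backward()    # works: the mode cloned the saved tensors first
    # d(b[0] + b[1]) / da[i] = exp(a[i]); the in-place +1 is additive and
    # does not change the gradient.
    print(a[0].grad, a[1].grad)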

test/test_autograd.py

@@ -12489,6 +12489,18 @@ class TestAllowMutationOnSaved(TestCase):
         with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx:
             pass
 
+    def test_inplace_foreach(self):
+        with torch.autograd.graph.allow_mutation_on_saved_tensors():
+            a = [
+                torch.tensor(1.0, requires_grad=True),
+                torch.tensor(1.0, requires_grad=True),
+            ]
+
+            b = torch._foreach_exp(a)
+            torch._foreach_add_(b, 1)
+            (b[0] + b[1]).backward()
+
+            self.assertEqual([a[0].grad, a[1].grad], torch._foreach_exp(a))
 
 class TestAutogradInferenceMode(TestCase):
     def _is_inference_tensor(self, tensor):
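For contrast, outside the context manager the same mutation is rejected by autograd's version-counter check, because exp needs its (now overwritten) outputs to compute its gradient. A sketch of that baseline behavior:

import torch

a = [torch.tensor(1.0, requires_grad=True), torch.tensor(1.0, requires_grad=True)]
b = torch._foreach_exp(a)
torch._foreach_add_(b, 1)  # bumps the version counter of tensors saved by exp
try:
    (b[0] + b[1]).backward()
except RuntimeError as err:
    # "one of the variables needed for gradient computation has been
    # modified by an inplace operation"
    print(err)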

torch/autograd/graph.py

@@ -661,31 +661,41 @@ class _CloneArgBeforeMutateMode(TorchDispatchMode):
     ) -> Any:
         kwargs = kwargs or {}
 
+        def maybe_clone(t: torch.Tensor) -> None:
+            tid = _get_tid(t)
+            sid = _get_sid(t)
+            ctx = self.ctx
+            if sid in ctx.sid_to_tid:
+                for tid in ctx.sid_to_tid[sid]:
+                    if tid not in ctx.tid_to_weakhandle:
+                        # We know that if tid is in sid_to_tid, then it must also be in
+                        # tid_to_weakhandle. However, it is possible for the tensor to be
+                        # saved at one point, but cleared by backward before it is modified
+                        # in-place. Consider the following example:
+                        #
+                        # >>> a = torch.randn(2, 3, requires_grad=True).clone()
+                        # >>> out = (a**2).sum()
+                        # >>> out.backward()
+                        # >>> a.sin_()
+                        continue
+                    handle = ctx.tid_to_weakhandle[tid]
+                    if handle in ctx.cloned:
+                        # The same exact tensor has been cloned already
+                        continue
+                    ctx.cloned[handle] = ctx.original[handle].clone()
+                    del ctx.original[handle]
+
         for idx, arg in enumerate(func._schema.arguments):
             if arg.alias_info is not None and arg.alias_info.is_write:
-                t = kwargs["out"] if arg.is_out else args[idx]
-                tid = _get_tid(t)
-                sid = _get_sid(t)
-                ctx = self.ctx
-                if sid in ctx.sid_to_tid:
-                    for tid in ctx.sid_to_tid[sid]:
-                        if tid not in ctx.tid_to_weakhandle:
-                            # We know that if tid is in sid_to_tid, then it must also be in
-                            # tid_to_weakhandle. However, it is possible for the tensor to be
-                            # saved at one point, but cleared by backward before it is modified
-                            # in-place. Consider the following example:
-                            #
-                            # >>> a = torch.randn(2, 3, requires_grad=True).clone()
-                            # >>> out = (a**2).sum()
-                            # >>> out.backward()
-                            # >>> a.sin_()
-                            continue
-                        handle = ctx.tid_to_weakhandle[tid]
-                        if handle in ctx.cloned:
-                            # The same exact tensor has been cloned already
-                            continue
-                        ctx.cloned[handle] = ctx.original[handle].clone()
-                        del ctx.original[handle]
+                if arg.is_out:
+                    maybe_clone(kwargs["out"])
+                elif isinstance(args[idx], list):
+                    # Foreach case. (Possible optimization: if most of the
+                    # tensors need to be cloned, use a for each clone?)
+                    for t in args[idx]:
+                        maybe_clone(t)
+                else:
+                    maybe_clone(args[idx])
+
         return func(*args, **kwargs)
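The dispatch loop finds mutated arguments via the op schema: arg.alias_info.is_write marks written-to arguments, and for a foreach op the corresponding entry in args is a list of tensors rather than a single tensor, which is what the new isinstance(args[idx], list) branch covers. A standalone sketch of that schema inspection; PrintMutatedArgs is a hypothetical mode written for illustration only:

import torch
from torch.utils._python_dispatch import TorchDispatchMode

class PrintMutatedArgs(TorchDispatchMode):
    # Hypothetical debugging mode: report which schema arguments an op mutates.
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        for idx, arg in enumerate(func._schema.arguments):
            if arg.alias_info is not None and arg.alias_info.is_write:
                t = kwargs["out"] if arg.is_out else args[idx]
                # For a foreach op, t is a list of tensors; the pre-fix code
                # assumed a single tensor here and could not handle the list.
                print(f"{func}: writes to '{arg.name}' ({type(t).__name__})")
        return func(*args, **kwargs)

with PrintMutatedArgs():
    xs = [torch.ones(2), torch.ones(2)]
    torch._foreach_add_(xs, 1)  # writes to 'self', a list of tensors
    torch.ones(2).add_(1)       # writes to 'self', a single Tensor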