diff --git a/test/test_autograd.py b/test/test_autograd.py
index 2a00b620d78e..6fd359fbc943 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -12489,6 +12489,18 @@ class TestAllowMutationOnSaved(TestCase):
         with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx:
             pass
 
+    def test_inplace_foreach(self):
+        with torch.autograd.graph.allow_mutation_on_saved_tensors():
+            a = [
+                torch.tensor(1.0, requires_grad=True),
+                torch.tensor(1.0, requires_grad=True),
+            ]
+            b = torch._foreach_exp(a)
+            torch._foreach_add_(b, 1)
+            (b[0] + b[1]).backward()
+
+            self.assertEqual([a[0].grad, a[1].grad], torch._foreach_exp(a))
+
 
 class TestAutogradInferenceMode(TestCase):
     def _is_inference_tensor(self, tensor):
diff --git a/torch/autograd/graph.py b/torch/autograd/graph.py
index d2312331663f..549f31d349e0 100644
--- a/torch/autograd/graph.py
+++ b/torch/autograd/graph.py
@@ -661,31 +661,41 @@ class _CloneArgBeforeMutateMode(TorchDispatchMode):
     ) -> Any:
         kwargs = kwargs or {}
 
+        def maybe_clone(t: torch.Tensor) -> None:
+            tid = _get_tid(t)
+            sid = _get_sid(t)
+            ctx = self.ctx
+            if sid in ctx.sid_to_tid:
+                for tid in ctx.sid_to_tid[sid]:
+                    if tid not in ctx.tid_to_weakhandle:
+                        # We know that if tid is in sid_to_tid, then it must also be in
+                        # tid_to_weakhandle. However, it is possible for the tensor to be
+                        # saved at one point, but cleared by backward before it is modified
+                        # in-place. Consider the following example:
+                        #
+                        # >>> a = torch.randn(2, 3, requires_grad=True).clone()
+                        # >>> out = (a**2).sum()
+                        # >>> out.backward()
+                        # >>> a.sin_()
+                        continue
+                    handle = ctx.tid_to_weakhandle[tid]
+                    if handle in ctx.cloned:
+                        # The same exact tensor has been cloned already
+                        continue
+                    ctx.cloned[handle] = ctx.original[handle].clone()
+                    del ctx.original[handle]
+
         for idx, arg in enumerate(func._schema.arguments):
             if arg.alias_info is not None and arg.alias_info.is_write:
-                t = kwargs["out"] if arg.is_out else args[idx]
-                tid = _get_tid(t)
-                sid = _get_sid(t)
-                ctx = self.ctx
-                if sid in ctx.sid_to_tid:
-                    for tid in ctx.sid_to_tid[sid]:
-                        if tid not in ctx.tid_to_weakhandle:
-                            # We know that if tid is in sid_to_tid, then it must also be in
-                            # tid_to_weakhandle. However, it is possible for the tensor to be
-                            # saved at one point, but cleared by backward before it is modified
-                            # in-place. Consider the following example:
-                            #
-                            # >>> a = torch.randn(2, 3, requires_grad=True).clone()
-                            # >>> out = (a**2).sum()
-                            # >>> out.backward()
-                            # >>> a.sin_()
-                            continue
-                        handle = ctx.tid_to_weakhandle[tid]
-                        if handle in ctx.cloned:
-                            # The same exact tensor has been cloned already
-                            continue
-                        ctx.cloned[handle] = ctx.original[handle].clone()
-                        del ctx.original[handle]
+                if arg.is_out:
+                    maybe_clone(kwargs["out"])
+                elif isinstance(args[idx], list):
+                    # Foreach case. (Possible optimization: if most of the
+                    # tensors need to be cloned, use a for each clone?)
+                    for t in args[idx]:
+                        maybe_clone(t)
+                else:
+                    maybe_clone(args[idx])
 
         return func(*args, **kwargs)
 
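For reference, the snippet below is a minimal standalone sketch of the behavior the patch covers, restating the new test_inplace_foreach test outside the test harness. It uses only names that already appear in the diff (torch._foreach_exp, torch._foreach_add_, allow_mutation_on_saved_tensors); the tensor values are illustrative, and this is not part of the patch itself.

import torch

a = [
    torch.tensor(1.0, requires_grad=True),
    torch.tensor(1.0, requires_grad=True),
]

with torch.autograd.graph.allow_mutation_on_saved_tensors():
    b = torch._foreach_exp(a)   # exp saves its outputs b[i] for backward
    torch._foreach_add_(b, 1)   # in-place foreach op mutates those saved outputs
    (b[0] + b[1]).backward()    # still correct: saved values were cloned before mutation

print([a[0].grad, a[1].grad])   # each gradient equals exp(1.0)

Before this change, only the plain single-tensor mutation path was intercepted; the refactor into maybe_clone lets the mode walk the tensor list of a foreach op and clone each saved tensor before it is mutated in-place.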