[aotd] Support mutations of the same input in fw and bw (#155354)

Original issue: https://github.com/pytorch/pytorch/issues/154820

The issue happens when the same input is mutated both in forward AND in backward.

AOTD emitted the copy_ after joint_function tracing, so a single fx node corresponded to the side effects of both mutations (the one in forward and the one in backward).
The partitioner could then place that node in either forward or backward.

The fix:

1/ Introduce joint_function.handle, which allows setting a "post_forward" callback so that the state of the inputs can be inspected right after forward.

We do not want to apply a mutation after the joint graph if it was already applied in forward. For that we need a "mutation_counter" and to memorize the counter value at which the forward mutation was applied, as sketched below.
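A minimal sketch of the idea; `post_forward` follows the description above, while `mutation_counter(...)` is a hypothetical accessor standing in for the counter exposed in this PR (the actual AOTD plumbing differs):

    # Hypothetical sketch: record each input's mutation counter right after
    # forward, so that after the full joint trace we can tell whether a
    # pending input mutation was already applied in forward.
    fw_mutation_versions: dict[int, int] = {}

    def post_forward(inputs):
        # Installed via joint_function.handle; runs once forward is traced.
        for i, inp in enumerate(inputs):
            fw_mutation_versions[i] = mutation_counter(inp)  # hypothetical accessor

    def already_applied_in_forward(i, inp) -> bool:
        # If the counter has not moved since forward, the mutation seen at the
        # end of the joint trace is the forward one; do not re-apply it there.
        return mutation_counter(inp) == fw_mutation_versions[i]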

2/ Expose mutation_counter to Python.

We want to keep the invariant that copy_ nodes exist only at the end of the joint graph.
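For intuition, autograd already exposes an analogous per-tensor counter to Python as `Tensor._version` (the counter used in this PR lives on the functionalization side, but it behaves similarly):

    import torch

    x = torch.zeros(3)
    v0 = x._version            # bumped on every in-place mutation
    x.add_(1)
    assert x._version == v0 + 1
    x.copy_(torch.ones(3))     # copy_ counts as a mutation too
    assert x._version == v0 + 2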

3/ Memorize the mutation_counter and the state of the inputs after forward, using the post_forward handle.
Emit the post_forward mutations only after the joint graph is fully traced.

Add a "must_be_in_forward" tag to the post_forward mutations (similar to the existing "must_be_in_backward") to keep them in forward; see the sketch below.
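A hedged sketch of the tagging, reusing the `partitioner_tag` meta key that the partitioner already consults for "must_be_in_backward" (`joint_graph` and `post_forward_mutation_nodes` are hypothetical bookkeeping names):

    # Hypothetical sketch: tag the re-emitted forward mutations so the
    # partitioner cannot move them into the backward graph.
    import torch

    for node in joint_graph.nodes:  # joint_graph: the traced joint fx graph
        if (
            node.target is torch.ops.aten.copy_.default
            and node in post_forward_mutation_nodes  # hypothetical set
        ):
            node.meta["partitioner_tag"] = "must_be_in_forward"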

4/ Ban recomputation of the source of the mutation, since recompute can apply the same op (e.g. add) in both forward and backward.
For this, set MUST_SAVE on the source of the mutation in forward.
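A sketch of this step, assuming the partitioner's existing `node.meta["recompute"]` / `CheckpointPolicy` mechanism is the vehicle (`fw_mutation_copy_node` is a hypothetical name for the copy_ node of the forward mutation):

    # Hypothetical sketch: force the value written by the forward mutation to
    # be saved, so the recompute heuristic cannot duplicate its producer into
    # both forward and backward.
    from torch.utils.checkpoint import CheckpointPolicy

    src_node = fw_mutation_copy_node.args[1]  # the tensor copied into the input
    src_node.meta["recompute"] = CheckpointPolicy.MUST_SAVE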

proxy_tensor changes:

By default, proxy tensor updates the tensor_tracker on every operation, so applied mutations end up chained one after another.
But we want each such copy_ to be independent and applied directly to the primals.
For this, introduce a context manager that disables the tensor_tracker update while the forward mutations are added.
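Sketched usage from the AOTD side (the context manager is the one added in the diff below; `primal` and `value_after_forward` are hypothetical):

    # Hypothetical sketch: emit the forward mutation after joint tracing with
    # tensor_tracker updates disabled, so the traced copy_ reads the original
    # primal proxy instead of chaining onto later mutations.
    from torch.fx.experimental.proxy_tensor import (
        _proxy_tensor_disable_update_tensor_tracker,
    )

    with _proxy_tensor_disable_update_tensor_tracker():
        primal.copy_(value_after_forward)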

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155354
Approved by: https://github.com/bdhirsh

diff --git a/torch/fx/experimental/proxy_tensor.py b/torch/fx/experimental/proxy_tensor.py
@@ -11,6 +11,7 @@ import functools
 import inspect
 import logging
 import operator
+import threading
 import typing
 import typing_extensions
 import weakref
@@ -179,7 +180,7 @@ def is_sym_node(node: _HasMeta) -> bool:
     return "val" in node.meta and isinstance(node.meta["val"], py_sym_types)


-@overload
+@overload  # type: ignore[no-overload-impl]
 def set_proxy_slot(obj: Tensor, tracer: _ProxyTracer, proxy: _ProxyTensor) -> None: ...
@@ -195,7 +196,66 @@ def set_proxy_slot(
 ) -> None: ...


-def set_proxy_slot(
+class _DisableUpdateTensorTracker(threading.local):
+    value: bool = False
+
+
+_disable_update_tensor_tracker_tls = _DisableUpdateTensorTracker()
+
+
+def _is_proxy_tensor_update_tensor_tracker_disabled() -> bool:
+    """
+    Returns the current state of disabling the tensor tracker update.
+    """
+    return _disable_update_tensor_tracker_tls.value
+
+
+@contextmanager
+def _proxy_tensor_disable_update_tensor_tracker() -> Generator[None, None, None]:
+    """
+    NOTE [Do not clobber inplace ops]
+    By default the tensor_tracker is updated on every operation;
+    this chains every operation on the same FakeTensor.
+    For example, with several consecutive mutable operations:
+
+        def f(x, y, z):
+            x.copy_(y)
+            x.copy_(z)
+            return x
+
+    the default graph result is:
+
+        def f_graph(x, y, z):
+            x_1 = x.copy_(y)
+            x_2 = x_1.copy_(z)
+            return x_2
+
+    This chaining simplifies fx passes and helps to prevent reordering.
+    But in some cases we want those nodes to be disconnected,
+    e.g. when splitting the joint graph into forward and backward.
+    If the first inplace op happened in forward and the second in backward,
+    we want them to be placed correctly after the split.
+    Enabling this context manager for copy_ results in:
+
+        def f_graph_2(x, y, z):
+            x_1 = x.copy_(y)
+            x_2 = x.copy_(z)
+            return x
+
+    The copy_ results x_1 and x_2 have no users in the graph.
+    The reason this behavior is not enabled for all inplace ops is that
+    some fx passes (e.g. fx quantization) rely on chaining inplace ops
+    like add_ in their fusion passes.
+    We could revisit enabling this logic for all inplace ops in the future.
+    """
+    orig_value = _disable_update_tensor_tracker_tls.value
+    _disable_update_tensor_tracker_tls.value = True
+    try:
+        yield
+    finally:
+        _disable_update_tensor_tracker_tls.value = orig_value
+
+
+def set_proxy_slot(  # type: ignore[no-redef]
     obj: Union[PySymType, _AnyScriptObjectType, Tensor],
     tracer: _ProxyTracer,
     proxy: object,
@@ -205,7 +265,9 @@ def set_proxy_slot(
         # We DO want to clobber proxies whenever we run an inplace operation
         # on a tensor, and it affects the metadata on the proxy.
         assert isinstance(proxy, _ProxyTensor)
-        tracer.tensor_tracker[obj] = proxy
+        # see NOTE [Do not clobber inplace ops]
+        if not _is_proxy_tensor_update_tensor_tracker_disabled():
+            tracer.tensor_tracker[obj] = proxy
     elif isinstance(obj, (_AnyScriptObject)):
         # We DO want to clobber proxies, with a similar rationale as for tensors.
         assert isinstance(proxy, Proxy)
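
A small end-to-end illustration of the NOTE above, using the (private) context manager under make_fx; on a build that includes this change, both copy_ nodes should read the original placeholder x, and their results should have no users:

    import torch
    from torch.fx.experimental.proxy_tensor import (
        make_fx,
        _proxy_tensor_disable_update_tensor_tracker,
    )

    def f(x, y, z):
        with _proxy_tensor_disable_update_tensor_tracker():
            x.copy_(y)  # result not written back to the tensor tracker...
            x.copy_(z)  # ...so this copy_ also reads the original x proxy
        return x

    gm = make_fx(f)(torch.zeros(3), torch.ones(3), torch.full((3,), 2.0))
    print(gm.graph)  # expect copy_(x, y) and copy_(x, z), unchained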