[aotd] Support mutations of the same input in fw and bw (#155354)
Original issue: https://github.com/pytorch/pytorch/issues/154820

The issue happens when the same input is mutated in the forward AND in the backward. AOTD emitted the copy_ after joint_function tracing, so a single fx node corresponded to the side effects of both mutations (forward and backward). The partitioner could then place that node in either the forward or the backward graph.

The fix:

1/ Introduce joint_function.handle, which allows setting a "post_forward" callback to check the state of the inputs after the forward. We do not want to apply a mutation after the joint graph if it was already applied in the forward. For that we need a "mutation_counter" and to memorize the version of the mutation that was applied in the forward.

2/ Expose mutation_counter to Python. We want to keep the invariant that copy_ exists only at the end of the joint graph.

3/ Memorize the mutation_counter and the state of the inputs after the forward, using the post_forward handle. Emit post_forward mutations after the joint graph is fully traced. Add a "must_be_in_forward" tag for post_forward mutations (similar to the existing "must_be_in_backward") to keep them in the forward.

4/ Ban recomputation of the source of the mutation. Recompute can apply the same op (e.g. add) in both the forward and the backward. For this, set MUST_SAVE for the source of the mutation in the forward.

proxy_tensor changes: By default, proxy tensor updates the tensor_tracker, in which case the applied mutations would be chained. But we want this copy_ to be independent and applied just to the primals. For this, introduce a context manager that disables updates of the tensor_tracker while the forward mutations are being added.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155354
Approved by: https://github.com/bdhirsh
commit 2f94f69b7c (parent 197c1869f5), committed by PyTorch MergeBot
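The failing pattern, as a hedged repro sketch (my reconstruction from the commit message, not the exact reproducer from issue #154820; MutatingFn and all tensor names are illustrative):

import torch

class MutatingFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, state):
        state.add_(1.0)  # mutation of `state` in the forward
        ctx.save_for_backward(state)
        return x * state

    @staticmethod
    def backward(ctx, grad_out):
        (state,) = ctx.saved_tensors
        grad_x = grad_out * state
        state.add_(1.0)  # mutation of the SAME input in the backward
        return grad_x, None

def f(x, state):
    return MutatingFn.apply(x, state).sum()

x = torch.randn(4, requires_grad=True)
state = torch.zeros(4)
f(x, state).backward()

Under torch.compile/AOTDispatch, both mutations of state used to be represented by copy_ nodes emitted after joint tracing, and the partitioner was free to place such a node in either graph. Whether this exact snippet triggers the bug depends on the PyTorch version; it is only meant to illustrate the pattern.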
@@ -11,6 +11,7 @@ import functools
 import inspect
 import logging
 import operator
+import threading
 import typing
 import typing_extensions
 import weakref
@@ -179,7 +180,7 @@ def is_sym_node(node: _HasMeta) -> bool:
     return "val" in node.meta and isinstance(node.meta["val"], py_sym_types)


-@overload
+@overload  # type: ignore[no-overload-impl]
 def set_proxy_slot(obj: Tensor, tracer: _ProxyTracer, proxy: _ProxyTensor) -> None: ...

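For context, a minimal sketch of my reading of why the ignore is needed (g is a hypothetical function): once the real implementation is later redefined with # type: ignore[no-redef], as this diff does for set_proxy_slot, mypy no longer pairs it with the overload series and reports no-overload-impl on the first @overload.

from typing import overload

@overload  # type: ignore[no-overload-impl]
def g(x: int) -> int: ...
@overload
def g(x: str) -> str: ...

# ...intervening definitions, as in the diff below...

def g(x):  # type: ignore[no-redef]
    return x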
@@ -195,7 +196,66 @@ def set_proxy_slot(
 ) -> None: ...


-def set_proxy_slot(
+class _DisableUpdateTensorTracker(threading.local):
+    value: bool = False
+
+
+_disable_update_tensor_tracker_tls = _DisableUpdateTensorTracker()
+
+
+def _is_proxy_tensor_update_tensor_tracker_disabled() -> bool:
+    """
+    Returns the current state of disabling tensor_tracker updates.
+    """
+    return _disable_update_tensor_tracker_tls.value
+
+
+@contextmanager
+def _proxy_tensor_disable_update_tensor_tracker() -> Generator[None, None, None]:
+    """
+    NOTE [Do not clobber inplace ops]
+    By default the tensor_tracker is updated on every operation,
+    which chains every operation on a FakeTensor.
+    For example, with several consecutive mutable operations:
+
+    def f(x, y, z):
+        x.copy_(y)
+        x.copy_(z)
+        return x
+
+    the default graph result is:
+
+    def f_graph(x, y, z):
+        x_1 = x.copy_(y)
+        x_2 = x_1.copy_(z)
+        return x_2
+
+    This chaining simplifies fx passes and helps to prevent reordering.
+    But in some cases we want those nodes to be disconnected,
+    e.g. when splitting the joint graph into forward and backward.
+    If the first inplace op happened in the forward and the second in the backward,
+    we want them to be placed correctly after the split.
+
+    Enabling this context manager for copy_ results in:
+
+    def f_graph_2(x, y, z):
+        x_1 = x.copy_(y)
+        x_2 = x.copy_(z)
+        return x
+
+    The results of copy_, x_1 and x_2, have no users in the graph.
+    The reason this behavior is not enabled for all inplace ops is that
+    some fx passes (e.g. fx quantization) rely on chaining inplace ops
+    like add_ in their fusion passes.
+    We could revisit enabling this logic for all inplace ops in the future.
+    """
+    orig_value = _disable_update_tensor_tracker_tls.value
+    _disable_update_tensor_tracker_tls.value = True
+    try:
+        yield
+    finally:
+        _disable_update_tensor_tracker_tls.value = orig_value
+
+
+def set_proxy_slot(  # type: ignore[no-redef]
     obj: Union[PySymType, _AnyScriptObjectType, Tensor],
     tracer: _ProxyTracer,
     proxy: object,
@@ -205,7 +265,9 @@ def set_proxy_slot(
         # We DO want to clobber proxies whenever we run an inplace operation
         # on a tensor, and it affects the metadata on the proxy.
         assert isinstance(proxy, _ProxyTensor)
-        tracer.tensor_tracker[obj] = proxy
+        # see NOTE [Do not clobber inplace ops]
+        if not _is_proxy_tensor_update_tensor_tracker_disabled():
+            tracer.tensor_tracker[obj] = proxy
     elif isinstance(obj, (_AnyScriptObject)):
         # We DO want to clobber proxies, with a similar rationale as for tensors.
         assert isinstance(proxy, Proxy)
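The guard reads a thread-local flag rather than a plain module-level boolean. A small self-contained sketch of why threading.local fits here (assuming tracing may run on several threads; _Flag is illustrative): an assignment to value on one thread shadows the class default only for that thread.

import threading

class _Flag(threading.local):
    value: bool = False  # shared class default until a thread assigns its own

flag = _Flag()

def worker() -> None:
    # The worker still sees the default: the main thread's write created
    # a per-thread attribute visible only on the main thread.
    assert flag.value is False

flag.value = True
t = threading.Thread(target=worker)
t.start()
t.join()
assert flag.value is True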