[Compiled Autograd] Introduce BackwardState capture (#120382)

This adds support for backward hooks that are *both*:
1) Interior to the graph; and
2) Dynamically generated (e.g. lambdas)

We do this by creating a BackwardState object that is used to register the hooks in the forward pass and is then populated by dynamo *after* the forward pass has run.
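
A minimal sketch of what this enables, adapted from the tests added in this PR (the CustomObj wrapper and the concrete values are illustrative): a lambda hook registered on an intermediate tensor, compiled with fullgraph=True, produces the same gradients as eager once compiled autograd inlines the hook through the BackwardState.

import dataclasses
import functools

import torch
from torch._dynamo import compiled_autograd


@dataclasses.dataclass
class CustomObj:
    val: torch.Tensor


def fn(x, obj):
    y = x.sin()
    closure_var = y + 1
    # The hook is interior to the graph (it is on an intermediate tensor) and
    # is a dynamically generated lambda closing over obj and closure_var.
    y.register_hook(lambda grad: grad + obj.val + closure_var)
    return y.sin()


opt = torch.compile(fn, fullgraph=True)
obj = CustomObj(torch.tensor(88.0))

x_eager = torch.ones(4, requires_grad=True)
x_compiled = torch.ones(4, requires_grad=True)

fn(x_eager, obj).sum().backward()

# Under compiled autograd, the hook is invoked via the BackwardState that
# dynamo populates after the forward graph has finished running.
with compiled_autograd.enable(functools.partial(torch.compile, fullgraph=True)):
    opt(x_compiled, obj).sum().backward()

assert torch.allclose(x_eager.grad, x_compiled.grad)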

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120382
Approved by: https://github.com/xmfan
Author: Jason Ansel
Date: 2024-02-27 15:49:16 -08:00
Committed by: PyTorch MergeBot
Parent: c016ffed5b
Commit: 01ec8df6d8
24 changed files with 499 additions and 151 deletions

View File

@ -59,19 +59,15 @@ class BackwardHigherOrderOpTests(torch._dynamo.test_case.TestCase):
out = make_fx(_multiply_invoke)(x)
self.assertEqual(out(x), torch.tensor([0.25, 0.25]))
actual = normalize_gm(out.print_readable(False))
expected = """\
self.assertExpectedInline(
actual,
"""\
class _multiply_invoke(torch.nn.Module):
def forward(self, grad_1: "f32[2]"):
trace_wrapped: "f32[2]" = torch__dynamo__trace_wrapped_higher_order_op_self_invoke(grad_1); grad_1 = None
assert_1: "f32[2]" = torch__dynamo__trace_wrapped_higher_order_op__assert_meta(trace_wrapped, (2,), (1,), torch.float32); trace_wrapped = None
detach: "f32[2]" = torch.ops.aten.detach.default(assert_1); assert_1 = None
detach_1: "f32[2]" = torch.ops.aten.detach.default(detach); detach = None
detach_2: "f32[2]" = torch.ops.aten.detach.default(detach_1); detach_1 = None
detach_3: "f32[2]" = torch.ops.aten.detach.default(detach_2); detach_2 = None
return detach_3
"""
self.assertExpectedInline(actual, expected)
return trace_wrapped
""",
)
def test_invoke_make_bw(self):
x = torch.tensor([0.5, 0.5], requires_grad=True)
@ -86,14 +82,15 @@ class _multiply_invoke(torch.nn.Module):
self.assertEqual(out(x.grad), torch.tensor([4.0, 4.0]))
actual = normalize_gm(out.print_readable(False))
expected = """\
self.assertExpectedInline(
actual,
"""\
class _multiply_invoke(torch.nn.Module):
def forward(self, grad_1: "f32[2]"):
trace_wrapped: "f32[2]" = torch__dynamo__trace_wrapped_higher_order_op_self_invoke(grad_1); grad_1 = None
assert_1: "f32[2]" = torch__dynamo__trace_wrapped_higher_order_op__assert_meta(trace_wrapped, (2,), (1,), torch.float32); trace_wrapped = None
return assert_1
"""
self.assertExpectedInline(actual, expected)
return trace_wrapped
""",
)
def test_invoke_in_pt2_compiled_autograd(self):
graph = None

View File

@ -33,6 +33,11 @@ def global_hook_2(grad):
h0 = None
class ClassWithVal:
def __init__(self, val):
self.val = val
class HooksTests(torch._dynamo.test_case.TestCase):
def test_tensor_only_register_hook_in_graph_lambda(self):
def fn(x):
@ -517,10 +522,6 @@ class HooksTests(torch._dynamo.test_case.TestCase):
def test_register_hook_partial_guarding(
self,
):
class SomePyClass:
def __init__(self, val):
self.val = val
def some_hook(grad, *, obj):
return grad + obj.val
@ -533,8 +534,9 @@ class HooksTests(torch._dynamo.test_case.TestCase):
return (z,)
mod = MyMod()
obj1 = SomePyClass(88)
obj2 = SomePyClass(99)
obj1 = ClassWithVal(torch.tensor(88))
obj2 = ClassWithVal(torch.tensor(99))
obj3 = ClassWithVal(11)
cnt = torch._dynamo.testing.CompileCounter()
x0 = torch.ones(4, requires_grad=True)
@ -543,34 +545,23 @@ class HooksTests(torch._dynamo.test_case.TestCase):
with compiled_autograd.enable(compiler_fn):
torch.compile(mod, backend=cnt, fullgraph=True)(x0, obj1)
torch.compile(mod, backend=cnt, fullgraph=True)(x1, obj1)
self.assertEqual(cnt.frame_count, 1)
# New obj forces recompile (for now)
# TODO(jansel): this behavior is bad, we should fix it so it doesn't happen
torch.compile(mod, backend=cnt, fullgraph=True)(x0, obj2)
self.assertEqual(cnt.frame_count, 2)
torch.compile(mod, backend=cnt, fullgraph=True)(x0, obj3)
self.assertEqual(cnt.frame_count, 1)
def test_hook_with_closure(self):
class SomePyClass:
def __init__(self, val):
self.val = val
cnt = torch._dynamo.testing.CompileCounter()
def fn(x, obj):
y = x.mul(2)
def hook1(grad):
return grad + obj.val
x.register_hook(hook1)
z = y.mul(3)
y = x.sin()
x.register_hook(lambda grad: grad + obj.val)
z = y.sin()
return z
opt = torch.compile(fn, backend=cnt, fullgraph=True)
obj1 = SomePyClass(88)
obj2 = SomePyClass(99)
cnt_fw = torch._dynamo.testing.CompileCounter()
cnt_bw = torch._dynamo.testing.CompileCounter()
opt = torch.compile(fn, backend=cnt_fw, fullgraph=True)
obj1 = ClassWithVal(torch.tensor(88))
obj2 = ClassWithVal(torch.tensor(99))
x0 = torch.ones(4, requires_grad=True)
x1 = torch.ones(4, requires_grad=True)
x2 = torch.ones(4, requires_grad=True)
@ -579,11 +570,72 @@ class HooksTests(torch._dynamo.test_case.TestCase):
fn(x1, obj2).sum().backward()
with compiled_autograd.enable(
functools.partial(torch.compile, backend="eager")
functools.partial(torch.compile, backend=cnt_bw, fullgraph=True)
):
opt(x2, obj1).sum().backward()
opt(x3, obj2).sum().backward()
self.assertEqual(cnt.frame_count, 1)
self.assertEqual(cnt_fw.frame_count, 1)
self.assertEqual(cnt_bw.frame_count, 1)
self.assertEqual(x0.grad, x2.grad)
self.assertEqual(x1.grad, x3.grad)
def test_intermediate_hook_with_closure_eager(self):
def fn(x, obj):
y = x.sin()
y.register_hook(lambda grad: grad + obj.val)
z = y.sin()
return z
cnt_fw = torch._dynamo.testing.CompileCounter()
cnt_bw = torch._dynamo.testing.CompileCounter()
opt = torch.compile(fn, backend=cnt_fw, fullgraph=True)
obj1 = ClassWithVal(torch.tensor(88))
obj2 = ClassWithVal(torch.tensor(99))
x0 = torch.ones(4, requires_grad=True)
x1 = torch.ones(4, requires_grad=True)
x2 = torch.ones(4, requires_grad=True)
x3 = torch.ones(4, requires_grad=True)
fn(x0, obj1).sum().backward()
fn(x1, obj2).sum().backward()
with compiled_autograd.enable(
functools.partial(torch.compile, backend=cnt_bw, fullgraph=True)
):
opt(x2, obj1).sum().backward()
opt(x3, obj2).sum().backward()
self.assertEqual(cnt_fw.frame_count, 1)
self.assertEqual(cnt_bw.frame_count, 1)
self.assertEqual(x0.grad, x2.grad)
self.assertEqual(x1.grad, x3.grad)
def test_intermediate_hook_with_closure_aot(self):
def fn(x, obj):
y = x.sin()
y.register_hook(lambda grad: grad + obj.val)
z = y.sin()
return z
cnt_bw = torch._dynamo.testing.CompileCounter()
opt = torch.compile(fn, backend="aot_eager", fullgraph=True)
obj1 = ClassWithVal(torch.tensor(88))
obj2 = ClassWithVal(torch.tensor(99))
x0 = torch.ones(4, requires_grad=True)
x1 = torch.ones(4, requires_grad=True)
x2 = torch.ones(4, requires_grad=True)
x3 = torch.ones(4, requires_grad=True)
fn(x0, obj1).sum().backward()
fn(x1, obj2).sum().backward()
with compiled_autograd.enable(
functools.partial(torch.compile, backend=cnt_bw, fullgraph=True)
):
opt(x2, obj1).sum().backward()
opt(x3, obj2).sum().backward()
self.assertEqual(cnt_bw.frame_count, 1)
self.assertEqual(x0.grad, x2.grad)
self.assertEqual(x1.grad, x3.grad)
@ -624,7 +676,7 @@ class HooksTests(torch._dynamo.test_case.TestCase):
comp_out = comp_mod(x1)
self.assertEqual(cnts.frame_count, 2)
self.assertEqual(cnts.frame_count, 1)
comp_out[0].backward(torch.ones(4))
self.assertEqual(x0.grad, x1.grad)

View File

@ -133,6 +133,13 @@ uniform_qconfig_8bit = QConfig(
qconfig_dict = {"object_type": [(torch.nn.Linear, uniform_qconfig_8bit)]}
def closure_adder(val):
def inner(x):
return torch.sin(x + val)
return inner
class MiscTests(torch._dynamo.test_case.TestCase):
def test_get_cache_entry(self):
def f(x):
@ -467,6 +474,25 @@ class MiscTests(torch._dynamo.test_case.TestCase):
cleanup_op("mylib::foo")
del lib
def test_closure_recompiles(self):
cnt = CompileCounter()
def fn(x, other_fn):
return other_fn(x + 1) - 1
opt = torch.compile(fn, backend=cnt, fullgraph=True)
x = torch.randn(8)
for f in (
closure_adder(5),
closure_adder(5),
closure_adder(torch.randn(8)),
closure_adder(torch.randn(8)),
):
self.assertEqual(opt(x, f), fn(x, f))
self.assertEqual(cnt.frame_count, 2)
def test_generate_trivial_abstract_impl(self):
try:
lib = torch.library.Library("mylib", "FRAGMENT")

View File

@ -1,11 +1,45 @@
# Owner(s): ["oncall: pt2"]
import dataclasses
import functools
import torch
from torch._dynamo import compiled_autograd
from torch._dynamo.test_case import run_tests, TestCase
from torch.testing._internal.common_utils import IS_MACOS
from torch.testing._internal.inductor_utils import HAS_CPU
class DistributedPatternTests(TestCase):
def test_intermediate_hook_with_closure(self):
@dataclasses.dataclass
class CustomObj:
val: torch.Tensor
def fn(x, obj):
y = x.sin()
closure_var = y + 1
y.register_hook(lambda grad: grad + obj.val + closure_var)
z = y.sin()
return z
opt = torch.compile(fn, fullgraph=True)
obj1 = CustomObj(torch.tensor(88))
obj2 = CustomObj(torch.tensor(99))
x0 = torch.ones(4, requires_grad=True)
x1 = torch.ones(4, requires_grad=True)
x2 = torch.ones(4, requires_grad=True)
x3 = torch.ones(4, requires_grad=True)
fn(x0, obj1).sum().backward()
fn(x1, obj2).sum().backward()
with compiled_autograd.enable(functools.partial(torch.compile, fullgraph=True)):
opt(x2, obj1).sum().backward()
opt(x3, obj2).sum().backward()
self.assertEqual(x0.grad, x2.grad)
self.assertEqual(x1.grad, x3.grad)
@torch.no_grad()
def test_storage_resize_zero(self):
@torch.compile(fullgraph=True)

View File

@ -1,11 +1,14 @@
import torch
from torch._C import DispatchKey
from torch._higher_order_ops.utils import autograd_not_implemented
from torch._ops import HigherOrderOperator
from torch._subclasses import FakeTensorMode
from torch.fx.experimental._backward_state import BackwardState
from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_tensor_tree
from torch.utils._python_dispatch import _get_current_dispatch_mode
from torch.utils._pytree import tree_map_only
__all__ = ["trace_wrapped"]
@ -41,10 +44,12 @@ __all__ = ["trace_wrapped"]
# compiled autograd do we inline into the function.
def trace_wrapped(*args, fn):
return _trace_wrapped_op(*args, fn=fn)
def trace_wrapped(*args, **kwargs):
with torch.no_grad():
return _trace_wrapped_op(*args, **kwargs)
# TODO(jansel): need to ensure this does not get DCEed
_trace_wrapped_op = HigherOrderOperator("trace_wrapped")
@ -56,54 +61,51 @@ def _assert_meta(grad, size, stride, dtype):
@_trace_wrapped_op.py_impl(ProxyTorchDispatchMode)
def inner_trace(mode, *args, fn):
import torch
def inner_trace(mode, *args, bw_state=None, **kwargs):
def self_invoke(*args, **dyn_kwargs):
with torch.no_grad():
return _trace_wrapped_op(*args, **dyn_kwargs, **kwargs)
assert len(args) == 1
grad = args[0]
assert isinstance(grad, torch.Tensor)
def unwrap_proxies(x):
if isinstance(x, torch.Tensor):
return mode.tracer.unwrap_proxy(x)
if isinstance(x, (list, tuple)):
return type(x)(map(unwrap_proxies, x))
if x is None:
return None
raise AssertionError(f"unhandled type: {type(x)}")
def self_invoke(*args):
return _trace_wrapped_op(*args, fn=fn)
proxy_args = (mode.tracer.unwrap_proxy(grad),)
out_proxy = mode.tracer.create_proxy(
"call_function", self_invoke, proxy_args, {}, name="trace_wrapped"
)
grad = torch.zeros_like(grad)
grad = track_tensor_tree(grad, out_proxy, constant=None, tracer=mode.tracer)
# We have a little shortcut here, wherein we DO NOT yet run a meta func, and so
# we take on an assumption that input and output meta matches. As such, we must introduce
# a runtime assert
proxy_args = (
mode.tracer.unwrap_proxy(grad),
grad.size(),
grad.stride(),
grad.dtype,
)
proxy_kwargs = {}
if bw_state is not None:
assert isinstance(bw_state, BackwardState) and bw_state.proxy is not None
proxy_kwargs["bw_state"] = bw_state.proxy
out_proxy = mode.tracer.create_proxy(
"call_function",
_assert_meta,
proxy_args,
{},
name="assert",
self_invoke,
unwrap_proxies(args),
proxy_kwargs,
name="trace_wrapped",
)
grad = torch.empty_like(grad)
grad = track_tensor_tree(grad, out_proxy, constant=None, tracer=mode.tracer)
if args[0] is None:
grad = args[1] # module backward hooks
else:
grad = args[0] # other backward hooks
grad = tree_map_only(torch.Tensor, torch.empty_like, grad)
track_tensor_tree(grad, out_proxy, constant=None, tracer=mode.tracer)
return grad
@_trace_wrapped_op.py_impl(FakeTensorMode)
def inner_fake(*args, fn):
def inner_fake(*args, **kwargs):
raise RuntimeError("This op should never be invoked here")
@_trace_wrapped_op.py_impl(DispatchKey.CompositeExplicitAutograd)
def _trace_wrapped_op_dense(*args, fn):
def _trace_wrapped_op_dense(*args, fn, **kwargs):
mode = _get_current_dispatch_mode()
assert mode is None, "Mode should never be enabled for CPU/CUDA key"
return fn(*args)
return fn(*args, **kwargs)
_trace_wrapped_op.py_impl(DispatchKey.Autograd)(
@ -112,7 +114,7 @@ _trace_wrapped_op.py_impl(DispatchKey.Autograd)(
@_trace_wrapped_op.py_functionalize_impl
def _trace_wrapped_functionalized(ctx, *args, fn):
def _trace_wrapped_functionalized(ctx, *args, **kwargs):
unwrapped_args = ctx.unwrap_tensors(args)
with ctx.redispatch_to_next():
return ctx.wrap_tensors(_trace_wrapped_op(*unwrapped_args, fn=fn))
return ctx.wrap_tensors(_trace_wrapped_op(*unwrapped_args, **kwargs))

View File

@ -10,6 +10,7 @@ from torch._logging import getArtifactLogger, trace_structured
from torch._prims_common import clone_preserve_strides
from torch._subclasses import FakeTensorMode
from torch.fx import GraphModule
from torch.fx.experimental._backward_state import BackwardState
from torch.fx.experimental.proxy_tensor import (
decompose,
disable_autocast_cache,
@ -223,6 +224,13 @@ class AutogradCompilerInstance:
assert len(tensors) == len(proxies)
track_tensor_tree(tensors, proxies, constant=None, tracer=self.fx_tracer)
def bind_backward_state(self, index: int):
assert self.hooks_proxy is not None
proxy = self.hooks_proxy[index] # type: ignore[index]
bw_state = BackwardState()
track_tensor_tree(bw_state, proxy, constant=None, tracer=self.fx_tracer)
return bw_state
compiled_autograd_enabled = False

View File

@ -74,3 +74,7 @@ def call_backward(backward_fn, saved_tensors, *args):
def untyped_storage_size(x: torch.Tensor):
return x.untyped_storage().size()
def call_hook_from_backward_state(*args, bw_state, hook_name: str, **kwargs):
return getattr(bw_state, hook_name)(*args, **kwargs)

View File

@ -30,6 +30,7 @@ from torch._guards import (
)
from torch._utils_internal import signpost_event
from torch.fx._lazy_graph_module import _make_graph_module # type: ignore[attr-defined]
from torch.fx.experimental._backward_state import BackwardState
from torch.fx.experimental.sym_node import SymNode
from torch.fx.experimental.symbolic_shapes import free_symbols, is_symbolic, ShapeEnv
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
@ -60,6 +61,7 @@ from .mutation_guard import is_dynamic_nn_module
from .side_effects import SideEffects
from .source import (
AttrSource,
BackwardStateSource,
ConstantSource,
GlobalStateSource,
is_constant_source,
@ -87,7 +89,13 @@ from .utils import (
same,
)
from .variables.base import VariableTracker
from .variables.builder import GraphArg, TrackedFake, VariableBuilder, wrap_fx_proxy
from .variables.builder import (
BackwardStateGraphArg,
GraphArg,
TrackedFake,
VariableBuilder,
wrap_fx_proxy,
)
from .variables.nn_module import NNModuleVariable
from .variables.tensor import (
NumpyNdarrayVariable,
@ -380,6 +388,29 @@ class OutputGraph(Checkpointable[OutputGraphState]):
] = []
self.random_values_var = None
# Used to pass values to backward hooks when using compiled autograd
self.backward_state: Dict[str, VariableTracker] = {}
self.backward_state_proxy: Optional[torch.fx.Proxy] = None
self.backward_state_var: Optional[str] = None
def add_backward_state_hook(self, hook: VariableTracker):
name = f"hook{len(self.backward_state)}"
assert name not in self.backward_state
self.backward_state[name] = hook
return name, self.get_backward_state_proxy()
def get_backward_state_proxy(self):
if self.backward_state_proxy is None:
if self.export:
unimplemented("backward_state does not support export")
self.backward_state_proxy = self.root_tracer.create_graph_input(
"dynamo_backward_state", BackwardState, source=BackwardStateSource()
)
self.backward_state_proxy.node.meta["grapharg"] = BackwardStateGraphArg()
self.backward_state_proxy.node.meta["example_value"] = BackwardState()
self.backward_state_var = self.new_var()
return self.backward_state_proxy
# This gets its own helper function so guards DEBUG logs are more informative
def init_ambient_guards(self):
# Register a SHAPE_ENV guard to make sure we setup shape guards
@ -924,6 +955,7 @@ class OutputGraph(Checkpointable[OutputGraphState]):
and all(isinstance(x, TensorVariable) for x in stack_values)
and len(set(stack_values)) == len(stack_values)
and self.side_effects.is_empty()
and not self.backward_state
):
append_prefix_insts()
# optimization to generate better code in a common case
@ -934,10 +966,7 @@ class OutputGraph(Checkpointable[OutputGraphState]):
else:
graph_output_var = self.new_var("graph_out")
pass1 = PyCodegen(tx, root, graph_output_var)
self.side_effects.codegen_hooks(pass1)
self.side_effects.codegen_save_tempvars(pass1)
pass1.restore_stack(stack_values, value_from_source=not tx.export)
self.side_effects.codegen_update_mutated(pass1)
self.codegen_suffix(tx, stack_values, pass1)
# one more time now that we have established tempvars
pass2 = PyCodegen(
@ -946,10 +975,7 @@ class OutputGraph(Checkpointable[OutputGraphState]):
graph_output_var,
tempvars={val: None for val, count in pass1.uses.items() if count > 1},
)
self.side_effects.codegen_hooks(pass2)
self.side_effects.codegen_save_tempvars(pass2)
pass2.restore_stack(stack_values, value_from_source=not tx.export)
self.side_effects.codegen_update_mutated(pass2)
self.codegen_suffix(tx, stack_values, pass2)
output = []
if count_calls(self.graph) != 0 or len(pass2.graph_outputs) != 0:
@ -969,6 +995,18 @@ class OutputGraph(Checkpointable[OutputGraphState]):
[PyCodegen(tx).create_store(var) for var in reversed(restore_vars)]
)
def codegen_suffix(self, tx, stack_values, cg):
if self.backward_state:
assert not self.export
for name, val in self.backward_state.items():
cg(val)
cg.append_output(cg.create_load(self.backward_state_var))
cg.store_attr(name)
self.side_effects.codegen_hooks(cg)
self.side_effects.codegen_save_tempvars(cg)
cg.restore_stack(stack_values, value_from_source=not tx.export)
self.side_effects.codegen_update_mutated(cg)
def cleanup_graph(self):
"""
Remove "creation_timestamp" from node meta
@ -1243,11 +1281,15 @@ class OutputGraph(Checkpointable[OutputGraphState]):
if not node.users:
recheck_placeholders.append(node)
else:
if not node.users:
if not node.users and not isinstance(
node.meta["grapharg"], BackwardStateGraphArg
):
remove_unused(node)
else:
# Register the free symbols as uses
arg = node.meta["grapharg"]
if isinstance(arg, BackwardStateGraphArg):
continue
fake = (
arg.fake_tensor if arg.fake_tensor is not None else arg.example
)

View File

@ -485,6 +485,15 @@ class ShapeEnvSource(Source):
return GuardSource.SHAPE_ENV
@dataclasses.dataclass(frozen=True)
class BackwardStateSource(Source):
def name(self):
return ""
def guard_source(self):
return GuardSource.BACKWARD_STATE
def is_from_local_source(source: Source, *, allow_cell_or_freevar=True):
if isinstance(source, ChainedSource):
return is_from_local_source(

View File

@ -27,6 +27,7 @@ from torch._guards import GuardSource, TracingContext
from torch._ops import HigherOrderOperator
from torch._streambase import _EventBase, _StreamBase
from torch._subclasses.fake_tensor import FakeTensor, is_fake, maybe_get_fake_mode
from torch.fx.experimental._backward_state import BackwardState
from torch.fx.experimental.symbolic_shapes import (
_constrain_range_for_size,
DimDynamic,
@ -213,6 +214,24 @@ class GraphArg:
return self.source.name() == other.source.name()
class BackwardStateGraphArg(GraphArg):
def __init__(self):
super().__init__(
source=None,
_example=BackwardState(),
is_unspecialized=False,
fake_tensor=None,
is_tensor=False,
)
def reconstruct(self, codegen):
assert codegen.tx.output.backward_state_var
codegen.load_import_from(BackwardState.__module__, "BackwardState")
codegen.call_function(0, True)
codegen.dup_top()
codegen.store(codegen.tx.output.backward_state_var)
@dataclasses.dataclass
class FrameStateSizeEntry:
scalar: Optional[int]

View File

@ -1443,18 +1443,19 @@ class ExportTracepointHigherOrderVariable(TorchHigherOrderOperatorVariable):
class TraceWrappedHigherOrderOperatorVariable(TorchHigherOrderOperatorVariable):
"""
Handles torch._dynamo._trace_wrapped_higher_order_op.inner_trace
by unwrapping the higher order op and inlining through it. This op
is created by dynamo to survive through AotAutograd, then unwrapped
here in the call to dynamo from compiled autograd.
"""
def call_function(
self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
) -> "VariableTracker":
from . import TensorVariable
assert "fn" in kwargs
fn = kwargs["fn"]
assert len(args) == 1
grad = args[0]
assert isinstance(grad, TensorVariable)
return fn.call_function(tx, args, {})
kwargs = dict(kwargs)
fn = kwargs.pop("fn")
return fn.call_function(tx, args, kwargs)
class AutogradFunctionApplyVariable(VariableTracker):

View File

@ -10,6 +10,7 @@ from typing import Dict, List
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
from ..bytecode_transformation import create_call_method
from ..external_utils import call_hook_from_backward_state
try:
import numpy as np
@ -769,7 +770,7 @@ class TensorVariable(VariableTracker):
"register_post_accumulate_grad_hook", *args, **kwargs
)
def _method_register_hook(self, name, hook):
def _method_register_hook(self, name: str, hook: VariableTracker):
# Note - do not arbitrarily add hooks here - make sure they match the same contract
# see [On tensor.register_hook]
from ..symbolic_convert import InstructionTranslator
@ -797,14 +798,22 @@ class TensorVariable(VariableTracker):
"Compilation of intermediate hooks requires compiled autograd"
)
# This wraps our user provided fn with a function that intercedes and
# uses our `invoke` higher order op to record a hook invocation in bwd graph.
fn = functools.partial(trace_wrapped, fn=hook.guard_as_python_constant())
hook_name, bw_state_proxy = tx.output.add_backward_state_hook(hook)
def _register_hook_trampoline(tensor):
hook_callable = getattr(tensor, name)
hook_callable(fn)
return tensor
def _register_hook_trampoline(tensor, bw_state):
register_hook = getattr(tensor, name)
register_hook(
functools.partial(
trace_wrapped,
fn=call_hook_from_backward_state,
bw_state=bw_state,
hook_name=hook_name,
)
)
# TODO(jansel): returning None here is wrong, it should be
# RemovableHandle, but we need some extra work to support
# this properly.
return None
from .builder import wrap_fx_proxy
@ -813,7 +822,7 @@ class TensorVariable(VariableTracker):
tx.output.create_proxy(
"call_function",
_register_hook_trampoline,
(self.as_proxy(),),
(self.as_proxy(), bw_state_proxy),
{},
),
)

View File

@ -13,7 +13,6 @@ from functools import wraps
from typing import Any, List, Optional
import torch
import torch.utils._pytree as pytree
import torch.utils.dlpack
from torch import Tensor
from torch._dynamo.utils import lazy_format_graph_code
@ -21,6 +20,7 @@ from torch._guards import detect_fake_mode, tracing, TracingContext
from torch._logging import getArtifactLogger, trace_structured
from torch._prims_common import CUDARngStateHelper
from torch._subclasses import FakeTensor
from torch.fx.experimental._backward_state import BackwardState
from torch.fx.experimental.proxy_tensor import is_sym_node
from torch.fx.experimental.symbolic_shapes import fx_placeholder_vals
from .. import config
@ -158,10 +158,6 @@ def aot_dispatch_autograd(
)
# Copied from aot_dispatch_autograd_graph.
traced_tangents = pytree.tree_map(
lambda x: x.detach().contiguous() if isinstance(x, Tensor) else x,
fw_metadata.traced_tangents,
)
disable_amp = torch._C._is_any_autocast_enabled()
if aot_config.enable_log:
@ -419,6 +415,11 @@ def aot_dispatch_autograd(
saved_context = TracingContext.try_get()
backward_state_indices = [
idx for idx, x in enumerate(flat_args) if isinstance(x, BackwardState)
]
assert len(backward_state_indices) <= 1
class CompiledFunction(torch.autograd.Function):
compiled_fw = compiled_fw_func
compiled_bw = compiled_bw_func
@ -434,6 +435,10 @@ def aot_dispatch_autograd(
@staticmethod
def forward(ctx, *deduped_flat_tensor_args):
args = deduped_flat_tensor_args
if backward_state_indices:
bw_state = args[backward_state_indices[0]]
assert isinstance(bw_state, BackwardState)
ctx._compiled_autograd_backward_state = bw_state
marked_dirty_inps = []
for i in fw_metadata.mutated_graph_handled_indices_seen_by_autograd:
@ -457,8 +462,6 @@ def aot_dispatch_autograd(
num_outputs = CompiledFunction.metadata.num_outputs
num_outputs_aliased = CompiledFunction.metadata.num_outputs_aliased
num_intermediate_bases = CompiledFunction.metadata.num_intermediate_bases
num_symints_saved_for_bw = CompiledFunction.num_symints_saved_for_bw
num_mutated_runtime_inps = (
CompiledFunction.metadata.num_mutated_inp_runtime_indices
)
@ -742,6 +745,9 @@ Got grad_output types: {str(grad_output_types)}"""
symints = ctx._get_compiled_autograd_symints()
assert len(symints) == len(ctx.symints)
all_args[: len(symints)] = symints
if backward_state_indices:
assert ctx._compiled_autograd_backward_state.proxy is not None
all_args.append(ctx._compiled_autograd_backward_state)
context = torch._C._DisableAutocast if disable_amp else nullcontext
with context():
out = normalize_as_list(bw_module(*all_args))
@ -749,6 +755,9 @@ Got grad_output types: {str(grad_output_types)}"""
CompiledFunction.metadata, out
)
return tuple(out)
assert (
not backward_state_indices
), "BackwardState requires CompiledAutograd"
ctx.maybe_clear_saved_tensors()
if CompiledFunction.compiled_bw is None:
context = torch._C._DisableAutocast if disable_amp else nullcontext

View File

@ -10,9 +10,19 @@ from typing import Any, Callable, List, Optional, Tuple
import torch
import torch.utils._pytree as pytree
from torch.fx.experimental._backward_state import BackwardState
from torch.fx.experimental.proxy_tensor import py_sym_types
KNOWN_TYPES = [torch.Tensor, int, str, float, bool, type(None)] + list(py_sym_types)
KNOWN_TYPES = [
torch.Tensor,
BackwardState,
int,
str,
float,
bool,
type(None),
*py_sym_types,
]
original_zip = zip

View File

@ -5,6 +5,7 @@ from torch.fx.experimental.sym_node import magic_methods, method_to_operator
from torch.fx.experimental.symbolic_shapes import (
hint_int, free_symbols, is_symbol_binding_fx_node, find_symbol_binding_fx_nodes
)
from torch.fx.experimental._backward_state import BackwardState
import torch
import torch.fx as fx
import operator
@ -21,6 +22,7 @@ from .compile_utils import fx_graph_cse, get_aten_target
from . import config
import functools
AOT_PARTITIONER_DEBUG = config.debug_partitioner
@ -124,6 +126,9 @@ def _is_bwd_seed_offset(node):
def _is_fwd_seed_offset(node):
return node.op == "placeholder" and ("fwd_seed" in node.target or "fwd_base_offset" in node.target)
def _is_backward_state(node):
return node.op == "placeholder" and isinstance(node.meta.get("val"), BackwardState)
def _extract_fwd_bwd_outputs(joint_module: fx.GraphModule, *, num_fwd_outputs):
outputs = pytree.arg_tree_leaves(*(node.args for node in joint_module.graph.nodes if node.op == 'output'))
@ -132,38 +137,49 @@ def _extract_fwd_bwd_outputs(joint_module: fx.GraphModule, *, num_fwd_outputs):
return fwd_outputs, bwd_outputs
def _remove_by_name(saved_values, name):
for saved_value in saved_values:
if saved_value.name == name:
saved_values.remove(saved_value)
break
def _placeholders(nodes):
# Avoid making an entire pass over the graph if we only care about the input placeholders
result = []
for node in nodes:
if node.op == 'placeholder':
result.append(node)
else:
break # placeholders are all at the start of graph
return result
def _extract_fwd_bwd_modules(joint_module: fx.GraphModule, saved_values, saved_sym_nodes, *, num_fwd_outputs):
fwd_outputs, bwd_outputs = _extract_fwd_bwd_outputs(joint_module, num_fwd_outputs=num_fwd_outputs)
primal_inputs = list(filter(_is_primal, joint_module.graph.nodes))
tangent_inputs = list(filter(_is_tangent, joint_module.graph.nodes))
fwd_seed_offset_inputs = list(filter(_is_fwd_seed_offset, joint_module.graph.nodes))
bwd_seed_offset_inputs = list(filter(_is_bwd_seed_offset, joint_module.graph.nodes))
placeholders = _placeholders(joint_module.graph.nodes)
primal_inputs = [*filter(_is_primal, placeholders)]
tangent_inputs = [*filter(_is_tangent, placeholders)]
fwd_seed_offset_inputs = [*filter(_is_fwd_seed_offset, placeholders)]
bwd_seed_offset_inputs = [*filter(_is_bwd_seed_offset, placeholders)]
backward_state_inputs = [*filter(_is_backward_state, placeholders)]
# Construct the forward module
# Keep symints separate from tensors, passed between fwd/bwd graphs, and in the right order.
fwd_graph = _extract_graph_with_inputs_outputs(
joint_module.graph,
primal_inputs + fwd_seed_offset_inputs,
fwd_outputs + saved_values + saved_sym_nodes
)
bwd_graph = _extract_graph_with_inputs_outputs(
joint_module.graph,
saved_sym_nodes + saved_values + tangent_inputs + bwd_seed_offset_inputs,
bwd_outputs
)
for node in _placeholders(bwd_graph.nodes):
assert node.op == 'placeholder'
# This is to filter out saved values that don't actually end up being used by the backwards pass
for node in bwd_graph.nodes:
if node.op == 'placeholder' and not node.users:
for saved_value in saved_values:
if saved_value.name == node.name:
saved_values.remove(saved_value)
break
if not node.users:
_remove_by_name(saved_values, node.name)
_remove_by_name(saved_sym_nodes, node.name)
elif _is_backward_state(node):
# BackwardState is saved directly
_remove_by_name(saved_values, node.name)
assert backward_state_inputs
for saved_sym in saved_sym_nodes:
if saved_sym.name == node.name:
saved_sym_nodes.remove(saved_sym)
break
# Now that we have the finalized list of saved values, we need to ensure
# we propagate all symbols which are referenced by backwards inputs.
@ -216,7 +232,7 @@ def _extract_fwd_bwd_modules(joint_module: fx.GraphModule, saved_values, saved_s
)
bwd_graph = _extract_graph_with_inputs_outputs(
joint_module.graph,
saved_sym_nodes + saved_values + tangent_inputs + bwd_seed_offset_inputs,
saved_sym_nodes + saved_values + tangent_inputs + bwd_seed_offset_inputs + backward_state_inputs,
bwd_outputs
)
@ -863,7 +879,7 @@ def min_cut_rematerialization_partition(
if is_sym_node(node):
weight = sym_node_size(node)
elif is_non_tensor_node:
weight = math.inf
weight = 0 if isinstance(node.meta.get("val"), BackwardState) else math.inf
else:
weight = get_node_weight(node)

View File

@ -86,6 +86,7 @@ class GuardSource(enum.Enum):
SHAPE_ENV = 6
LOCAL_FSDP_MODULE = 7
GLOBAL_FSDP_MODULE = 8
BACKWARD_STATE = 9
def is_fsdp_module(self) -> bool:
return self in (GuardSource.GLOBAL_FSDP_MODULE, GuardSource.LOCAL_FSDP_MODULE)

View File

@ -546,9 +546,10 @@ class WrapperCodeGen(CodeGen):
with self.prefix.indent():
if config.triton.debug_sync_graph:
self.prefix.writeline(V.graph.device_ops.synchronize())
inp_len = len(V.graph.graph_inputs.keys())
if inp_len != 0:
lhs = f"{', '.join(V.graph.graph_inputs.keys())}{'' if inp_len != 1 else ','}"
if V.graph.graph_inputs:
lhs = ", ".join(V.graph.graph_input_names)
if len(V.graph.graph_input_names) == 1:
lhs += ","
self.prefix.writeline(f"{lhs} = args")
self.prefix.writeline("args.clear()")

View File

@ -18,6 +18,7 @@ from torch._decomp import get_decompositions
from torch._dynamo.utils import defake, dynamo_timed
from torch._logging import LazyString, trace_structured
from torch._subclasses.fake_tensor import FakeTensor
from torch.fx.experimental._backward_state import BackwardState
from torch.fx.experimental.sym_node import magic_methods, method_to_operator
from torch.fx.experimental.symbolic_shapes import has_free_symbols, ShapeEnv, SymTypes
from torch.utils._mode_utils import no_dispatch
@ -236,6 +237,7 @@ class GraphLowering(torch.fx.Interpreter):
self.reuse_shape_env = True
self._shape_env = shape_env
self.sizevars = SizeVarAllocator(shape_env)
self.graph_input_names: List[str] = []
self.graph_inputs: Dict[str, TensorBox] = {}
self.graph_inputs_original: Dict[str, InputBuffer] = {}
self.device_types: Set[str] = (
@ -718,6 +720,7 @@ class GraphLowering(torch.fx.Interpreter):
def placeholder(self, target: str, args, kwargs):
example = super().placeholder(target, args, kwargs)
self.graph_input_names.append(target)
if isinstance(example, SymTypes):
expr = example.node.expr
self.graph_inputs[target] = expr
@ -726,6 +729,10 @@ class GraphLowering(torch.fx.Interpreter):
expr = sympy.sympify(example)
self.graph_inputs[target] = expr
return expr
if isinstance(example, BackwardState):
# Ignored arg, must be unused
# Alternately we could filter this out in AotAutograd
return None
assert isinstance(example, torch.Tensor), example
# todo(chilli): We can remove the last check once we turn buffers into
# static shape tensors. That's a hack to workaround Inductor believing

View File

@ -131,8 +131,7 @@ PyObject* to_py_size(const std::vector<c10::SymInt>& size) {
} // namespace
namespace torch {
namespace autograd {
namespace torch::autograd {
// NOTE: this function is written in a way that assumes it's only called for
// backward; it's used by engine.cpp. This is responsible for forwarding a call
@ -237,6 +236,7 @@ auto PyNode::compiled_apply(
TORCH_INTERNAL_ASSERT(
_backward_idx.has_value(),
"indices should already be set by compiled_args, called before apply_with_saved");
TORCH_INTERNAL_ASSERT(!_backward_state_idx.has_value());
THPObjectPtr r(PyObject_CallMethod(
*compiler,
"proxy_call_backward",
@ -302,7 +302,7 @@ void PyNode::compiled_args(CompiledNodeArgs& args) {
throw_python_error();
TORCH_CHECK(
PyTuple_CheckExact(pykey.get()),
"_compiled_autograd_key shoud return tuple of ints");
"_compiled_autograd_key should return tuple of ints");
auto size = PyTuple_GET_SIZE(pykey.get());
TORCH_INTERNAL_ASSERT(size > 0);
// first value is unique ID of the AotAutograd graph
@ -345,6 +345,13 @@ void PyNode::compiled_args(CompiledNodeArgs& args) {
PyObject* backward(PyObject_GetAttr(forward_cls.get(), backward_name));
_backward_idx =
args.add_backward(c10::SafePyObject(backward, getPyInterpreter()));
PyObject* bw_state = f->compiled_autograd_backward_state;
if (args.cond(bw_state != nullptr)) {
Py_INCREF(bw_state);
_backward_state_idx = args.add_backward_state(
c10::SafePyObject(bw_state, getPyInterpreter()));
}
}
variable_list PyNode::apply_with_saved(
@ -361,7 +368,23 @@ variable_list PyNode::apply_with_saved(
f->compiled_autograd_tracing = true;
variable_list result;
if (!compiled_autograd_should_lift()) {
if (_backward_state_idx.has_value()) {
PyObject* r = PyObject_CallMethod(
saved.get_py_compiler(),
"bind_backward_state",
"i",
*_backward_state_idx);
if (r == nullptr) {
throw python_error();
}
THPObjectPtr prior(f->compiled_autograd_backward_state);
f->compiled_autograd_backward_state = r;
result = apply(variable_list(inputs));
Py_CLEAR(f->compiled_autograd_backward_state);
f->compiled_autograd_backward_state = prior.release();
} else {
result = apply(variable_list(inputs));
}
} else {
result = compiled_apply(variable_list(inputs), saved.get_py_compiler());
}
@ -445,8 +468,7 @@ variable_list PyNode::to_variable_list(
return results;
}
} // namespace autograd
} // namespace torch
} // namespace torch::autograd
// Traverse and clear are required for supporting Python's GC cycle handling.
static int THPFunction_traverse(THPFunction* self, visitproc visit, void* arg) {
@ -455,6 +477,7 @@ static int THPFunction_traverse(THPFunction* self, visitproc visit, void* arg) {
Py_VISIT(self->to_save);
Py_VISIT(self->non_differentiable);
Py_VISIT(self->dirty_tensors);
Py_VISIT(self->compiled_autograd_backward_state);
Py_VISIT(self->saved_for_forward);
return 0;
}
@ -469,6 +492,7 @@ static int THPFunction_clear(THPFunction* self) {
Py_CLEAR(self->to_save);
Py_CLEAR(self->non_differentiable);
Py_CLEAR(self->dirty_tensors);
Py_CLEAR(self->compiled_autograd_backward_state);
Py_CLEAR(self->saved_for_forward);
self->output_info.clear();
@ -1504,6 +1528,33 @@ PyObject* THPFunction_get_compiled_autograd_symints(
END_HANDLE_TH_ERRORS
}
PyObject* THPFunction_get_compiled_autograd_backward_state(
PyObject* _self,
void* _unused) {
HANDLE_TH_ERRORS
auto self = (THPFunction*)_self;
PyObject* bw_state = self->compiled_autograd_backward_state;
if (bw_state == nullptr) {
bw_state = Py_None;
}
Py_INCREF(bw_state);
return bw_state;
END_HANDLE_TH_ERRORS
}
int THPFunction_set_compiled_autograd_backward_state(
PyObject* _self,
PyObject* bw_state,
void* _unused) {
HANDLE_TH_ERRORS
auto self = (THPFunction*)_self;
TORCH_INTERNAL_ASSERT(self->compiled_autograd_backward_state == nullptr);
Py_INCREF(bw_state);
self->compiled_autograd_backward_state = bw_state;
return 0;
END_HANDLE_TH_ERRORS_RET(-1)
}
PyObject* THPFunction_raw_saved_tensors(THPFunction* self, void* _unused) {
HANDLE_TH_ERRORS
// User tries to access saved variables after they have been freed
@ -1684,6 +1735,11 @@ static struct PyGetSetDef THPFunction_properties[] = {
(setter)THPFunction_set_materialize_non_diff_grads,
nullptr,
nullptr},
{"_compiled_autograd_backward_state",
(getter)THPFunction_get_compiled_autograd_backward_state,
(setter)THPFunction_set_compiled_autograd_backward_state,
nullptr,
nullptr},
{nullptr}};
// NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables)

View File

@ -56,6 +56,10 @@ struct PyNode : public Node {
// The AutogradCompilerCall::hooks idx corresponding to this node's backward
std::optional<int> _backward_idx;
// The AutogradCompilerCall::hooks idx corresponding to this node's
// backward_state
std::optional<int> _backward_state_idx;
// NOLINTNEXTLINE(bugprone-exception-escape)
~PyNode() override {
// Can't use THPObjectPtr as a field in this class; destructor won't take
@ -121,6 +125,7 @@ struct THPFunction {
// This is enabled by compiled autograd as a way to signal to AotAutograd it
// should call the original FX graph rather than compiling.
bool compiled_autograd_tracing;
PyObject* compiled_autograd_backward_state;
std::vector<c10::SymInt> compiled_autograd_symints;
std::vector<torch::autograd::VariableInfo> output_info;

View File

@ -375,6 +375,10 @@ class CompiledNodeArgs {
return _compiler.emplace_hook(std::move(obj));
}
int add_backward_state(c10::SafePyObject&& obj) {
return _compiler.emplace_hook(std::move(obj));
}
void add_tensor_pre_hook(c10::SafePyObject&& obj, int index) {
auto fn_id = _compiler.emplace_hook(std::move(obj));
collect_size(fn_id);

View File

@ -0,0 +1,27 @@
import torch.fx
class BackwardState:
"""
BackwardState is used to pass Python hooks from the forwards pass
into the backwards pass in Dynamo+Compiled Autograd.
It is created by TorchDynamo and has special handling there.
Dynamo will pass an empty BackwardState to the forwards, then populate
members on it (via setattr) only after the forwards graph is finished.
Later on, in CompiledAutograd we will inline and add the needed guards
on the BackwardState.
BackwardState is identified and has special handling in AOTAutograd.
During AOTAutograd:
1) BackwardState is an input to the forwards graph
2) It must only be used in the backwards
3) It will be empty in the forwards
4) In the forwards we add a wrapper to save it
5) In the backwards it becomes an input
6) There can only be one per graph
BackwardState requires CompiledAutograd.
"""
proxy: torch.fx.Proxy

View File

@ -33,6 +33,7 @@ from torch.utils._python_dispatch import (
TorchDispatchMode,
)
from ._backward_state import BackwardState
from .sym_node import SymNode
from ._sym_dispatch_mode import SymDispatchMode
from torch.fx import Proxy
@ -42,6 +43,7 @@ from torch.utils.weak import WeakTensorKeyDictionary, WeakIdKeyDictionary, _Weak
from torch._ops import unset_mode_pre_dispatch, _set_mode_pre_dispatch, _get_dispatch_mode_pre_dispatch
__all__ = ["PythonKeyTracer", "dispatch_trace", "make_fx", "DecompositionInterpreter", "py_sym_types", "get_innermost_proxy_mode"]
aten = torch.ops.aten
prim = torch.ops.prim
@ -139,6 +141,8 @@ def extract_val(val):
return val
elif isinstance(val, torch.ScriptObject):
return val
elif isinstance(val, BackwardState):
return val
elif isinstance(val, (list, tuple)):
return val.__class__([extract_val(x) for x in val])
elif isinstance(val, torch.Tensor):
@ -232,6 +236,9 @@ def track_tensor_tree(inner_res, proxy_res, *, constant, tracer):
# example use case: triton_kernel_wrapper takes arguments as kwargs
for key, val in e.items():
wrap_with_proxy(val, proxy[key], None)
elif isinstance(e, BackwardState):
set_meta(proxy, e)
e.proxy = proxy
else:
# intentionally pass on primitives
pass

View File

@ -286,6 +286,8 @@ def _iterate_exprs(val: Union[SymInt, torch.Tensor]) -> Iterable[sympy.Basic]:
raise AssertionError(f"cannot extract sympy expressions from {val} {type(val)}")
def free_symbols(val: Union[SymInt, torch.Tensor]) -> Set[sympy.Symbol]:
if val is None:
return set()
itr = _iterate_exprs(val)
# we need at least 1 to call union, so we hand code the identity
try: