Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
[Compiled Autograd] Introduce BackwardState capture (#120382)
This adds support for backward hooks that are *both*: 1) interior to the graph; and 2) dynamically generated (e.g. lambdas). We do this by creating a BackwardState object that is used to register the hooks in the forward, then populated by Dynamo *after* the forward runs.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120382
Approved by: https://github.com/xmfan
committed by PyTorch MergeBot
parent c016ffed5b
commit 01ec8df6d8
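For context, the pattern this change enables is captured by the new `test_intermediate_hook_with_closure` test added in this diff: a lambda hook that closes over Python state is registered on an intermediate tensor inside a compiled function, and the backward pass runs under compiled autograd. Below is a condensed, self-contained version of that test; the object and variable names are illustrative.

```python
# Condensed from the test added in this PR: a dynamically generated hook
# (a lambda closing over `obj` and a local) on an intermediate tensor,
# executed under compiled autograd.
import dataclasses
import functools

import torch
from torch._dynamo import compiled_autograd


@dataclasses.dataclass
class CustomObj:
    val: torch.Tensor


def fn(x, obj):
    y = x.sin()
    closure_var = y + 1
    # Hook is interior to the graph and created at trace time
    y.register_hook(lambda grad: grad + obj.val + closure_var)
    return y.sin()


opt = torch.compile(fn, fullgraph=True)
x_eager = torch.ones(4, requires_grad=True)
x_compiled = torch.ones(4, requires_grad=True)
obj = CustomObj(torch.tensor(88))

fn(x_eager, obj).sum().backward()  # eager reference
with compiled_autograd.enable(functools.partial(torch.compile, fullgraph=True)):
    opt(x_compiled, obj).sum().backward()
assert torch.allclose(x_eager.grad, x_compiled.grad)
```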
@ -59,19 +59,15 @@ class BackwardHigherOrderOpTests(torch._dynamo.test_case.TestCase):
|
||||
out = make_fx(_multiply_invoke)(x)
|
||||
self.assertEqual(out(x), torch.tensor([0.25, 0.25]))
|
||||
actual = normalize_gm(out.print_readable(False))
|
||||
|
||||
expected = """\
|
||||
self.assertExpectedInline(
|
||||
actual,
|
||||
"""\
|
||||
class _multiply_invoke(torch.nn.Module):
|
||||
def forward(self, grad_1: "f32[2]"):
|
||||
trace_wrapped: "f32[2]" = torch__dynamo__trace_wrapped_higher_order_op_self_invoke(grad_1); grad_1 = None
|
||||
assert_1: "f32[2]" = torch__dynamo__trace_wrapped_higher_order_op__assert_meta(trace_wrapped, (2,), (1,), torch.float32); trace_wrapped = None
|
||||
detach: "f32[2]" = torch.ops.aten.detach.default(assert_1); assert_1 = None
|
||||
detach_1: "f32[2]" = torch.ops.aten.detach.default(detach); detach = None
|
||||
detach_2: "f32[2]" = torch.ops.aten.detach.default(detach_1); detach_1 = None
|
||||
detach_3: "f32[2]" = torch.ops.aten.detach.default(detach_2); detach_2 = None
|
||||
return detach_3
|
||||
"""
|
||||
self.assertExpectedInline(actual, expected)
|
||||
return trace_wrapped
|
||||
""",
|
||||
)
|
||||
|
||||
def test_invoke_make_bw(self):
|
||||
x = torch.tensor([0.5, 0.5], requires_grad=True)
|
||||
@ -86,14 +82,15 @@ class _multiply_invoke(torch.nn.Module):
|
||||
self.assertEqual(out(x.grad), torch.tensor([4.0, 4.0]))
|
||||
actual = normalize_gm(out.print_readable(False))
|
||||
|
||||
expected = """\
|
||||
self.assertExpectedInline(
|
||||
actual,
|
||||
"""\
|
||||
class _multiply_invoke(torch.nn.Module):
|
||||
def forward(self, grad_1: "f32[2]"):
|
||||
trace_wrapped: "f32[2]" = torch__dynamo__trace_wrapped_higher_order_op_self_invoke(grad_1); grad_1 = None
|
||||
assert_1: "f32[2]" = torch__dynamo__trace_wrapped_higher_order_op__assert_meta(trace_wrapped, (2,), (1,), torch.float32); trace_wrapped = None
|
||||
return assert_1
|
||||
"""
|
||||
self.assertExpectedInline(actual, expected)
|
||||
return trace_wrapped
|
||||
""",
|
||||
)
|
||||
|
||||
def test_invoke_in_pt2_compiled_autograd(self):
|
||||
graph = None
|
||||
|
@ -33,6 +33,11 @@ def global_hook_2(grad):
|
||||
h0 = None
|
||||
|
||||
|
||||
class ClassWithVal:
|
||||
def __init__(self, val):
|
||||
self.val = val
|
||||
|
||||
|
||||
class HooksTests(torch._dynamo.test_case.TestCase):
|
||||
def test_tensor_only_register_hook_in_graph_lambda(self):
|
||||
def fn(x):
|
||||
@ -517,10 +522,6 @@ class HooksTests(torch._dynamo.test_case.TestCase):
|
||||
def test_register_hook_partial_guarding(
|
||||
self,
|
||||
):
|
||||
class SomePyClass:
|
||||
def __init__(self, val):
|
||||
self.val = val
|
||||
|
||||
def some_hook(grad, *, obj):
|
||||
return grad + obj.val
|
||||
|
||||
@ -533,8 +534,9 @@ class HooksTests(torch._dynamo.test_case.TestCase):
|
||||
return (z,)
|
||||
|
||||
mod = MyMod()
|
||||
obj1 = SomePyClass(88)
|
||||
obj2 = SomePyClass(99)
|
||||
obj1 = ClassWithVal(torch.tensor(88))
|
||||
obj2 = ClassWithVal(torch.tensor(99))
|
||||
obj3 = ClassWithVal(11)
|
||||
cnt = torch._dynamo.testing.CompileCounter()
|
||||
|
||||
x0 = torch.ones(4, requires_grad=True)
|
||||
@ -543,34 +545,23 @@ class HooksTests(torch._dynamo.test_case.TestCase):
|
||||
with compiled_autograd.enable(compiler_fn):
|
||||
torch.compile(mod, backend=cnt, fullgraph=True)(x0, obj1)
|
||||
torch.compile(mod, backend=cnt, fullgraph=True)(x1, obj1)
|
||||
self.assertEqual(cnt.frame_count, 1)
|
||||
# New obj forces recompile (for now)
|
||||
# TODO(jansel): this behavior is bad, we should fix it so it doesn't happen
|
||||
torch.compile(mod, backend=cnt, fullgraph=True)(x0, obj2)
|
||||
self.assertEqual(cnt.frame_count, 2)
|
||||
torch.compile(mod, backend=cnt, fullgraph=True)(x0, obj3)
|
||||
self.assertEqual(cnt.frame_count, 1)
|
||||
|
||||
def test_hook_with_closure(self):
|
||||
class SomePyClass:
|
||||
def __init__(self, val):
|
||||
self.val = val
|
||||
|
||||
cnt = torch._dynamo.testing.CompileCounter()
|
||||
|
||||
def fn(x, obj):
|
||||
y = x.mul(2)
|
||||
|
||||
def hook1(grad):
|
||||
return grad + obj.val
|
||||
|
||||
x.register_hook(hook1)
|
||||
z = y.mul(3)
|
||||
y = x.sin()
|
||||
x.register_hook(lambda grad: grad + obj.val)
|
||||
z = y.sin()
|
||||
return z
|
||||
|
||||
opt = torch.compile(fn, backend=cnt, fullgraph=True)
|
||||
|
||||
obj1 = SomePyClass(88)
|
||||
obj2 = SomePyClass(99)
|
||||
cnt_fw = torch._dynamo.testing.CompileCounter()
|
||||
cnt_bw = torch._dynamo.testing.CompileCounter()
|
||||
opt = torch.compile(fn, backend=cnt_fw, fullgraph=True)
|
||||
|
||||
obj1 = ClassWithVal(torch.tensor(88))
|
||||
obj2 = ClassWithVal(torch.tensor(99))
|
||||
x0 = torch.ones(4, requires_grad=True)
|
||||
x1 = torch.ones(4, requires_grad=True)
|
||||
x2 = torch.ones(4, requires_grad=True)
|
||||
@ -579,11 +570,72 @@ class HooksTests(torch._dynamo.test_case.TestCase):
|
||||
fn(x1, obj2).sum().backward()
|
||||
|
||||
with compiled_autograd.enable(
|
||||
functools.partial(torch.compile, backend="eager")
|
||||
functools.partial(torch.compile, backend=cnt_bw, fullgraph=True)
|
||||
):
|
||||
opt(x2, obj1).sum().backward()
|
||||
opt(x3, obj2).sum().backward()
|
||||
self.assertEqual(cnt.frame_count, 1)
|
||||
self.assertEqual(cnt_fw.frame_count, 1)
|
||||
self.assertEqual(cnt_bw.frame_count, 1)
|
||||
|
||||
self.assertEqual(x0.grad, x2.grad)
|
||||
self.assertEqual(x1.grad, x3.grad)
|
||||
|
||||
def test_intermediate_hook_with_closure_eager(self):
|
||||
def fn(x, obj):
|
||||
y = x.sin()
|
||||
y.register_hook(lambda grad: grad + obj.val)
|
||||
z = y.sin()
|
||||
return z
|
||||
|
||||
cnt_fw = torch._dynamo.testing.CompileCounter()
|
||||
cnt_bw = torch._dynamo.testing.CompileCounter()
|
||||
opt = torch.compile(fn, backend=cnt_fw, fullgraph=True)
|
||||
|
||||
obj1 = ClassWithVal(torch.tensor(88))
|
||||
obj2 = ClassWithVal(torch.tensor(99))
|
||||
x0 = torch.ones(4, requires_grad=True)
|
||||
x1 = torch.ones(4, requires_grad=True)
|
||||
x2 = torch.ones(4, requires_grad=True)
|
||||
x3 = torch.ones(4, requires_grad=True)
|
||||
fn(x0, obj1).sum().backward()
|
||||
fn(x1, obj2).sum().backward()
|
||||
|
||||
with compiled_autograd.enable(
|
||||
functools.partial(torch.compile, backend=cnt_bw, fullgraph=True)
|
||||
):
|
||||
opt(x2, obj1).sum().backward()
|
||||
opt(x3, obj2).sum().backward()
|
||||
self.assertEqual(cnt_fw.frame_count, 1)
|
||||
self.assertEqual(cnt_bw.frame_count, 1)
|
||||
|
||||
self.assertEqual(x0.grad, x2.grad)
|
||||
self.assertEqual(x1.grad, x3.grad)
|
||||
|
||||
def test_intermediate_hook_with_closure_aot(self):
|
||||
def fn(x, obj):
|
||||
y = x.sin()
|
||||
y.register_hook(lambda grad: grad + obj.val)
|
||||
z = y.sin()
|
||||
return z
|
||||
|
||||
cnt_bw = torch._dynamo.testing.CompileCounter()
|
||||
opt = torch.compile(fn, backend="aot_eager", fullgraph=True)
|
||||
|
||||
obj1 = ClassWithVal(torch.tensor(88))
|
||||
obj2 = ClassWithVal(torch.tensor(99))
|
||||
x0 = torch.ones(4, requires_grad=True)
|
||||
x1 = torch.ones(4, requires_grad=True)
|
||||
x2 = torch.ones(4, requires_grad=True)
|
||||
x3 = torch.ones(4, requires_grad=True)
|
||||
fn(x0, obj1).sum().backward()
|
||||
fn(x1, obj2).sum().backward()
|
||||
|
||||
with compiled_autograd.enable(
|
||||
functools.partial(torch.compile, backend=cnt_bw, fullgraph=True)
|
||||
):
|
||||
opt(x2, obj1).sum().backward()
|
||||
opt(x3, obj2).sum().backward()
|
||||
self.assertEqual(cnt_bw.frame_count, 1)
|
||||
|
||||
self.assertEqual(x0.grad, x2.grad)
|
||||
self.assertEqual(x1.grad, x3.grad)
|
||||
@ -624,7 +676,7 @@ class HooksTests(torch._dynamo.test_case.TestCase):
|
||||
|
||||
comp_out = comp_mod(x1)
|
||||
|
||||
self.assertEqual(cnts.frame_count, 2)
|
||||
self.assertEqual(cnts.frame_count, 1)
|
||||
comp_out[0].backward(torch.ones(4))
|
||||
self.assertEqual(x0.grad, x1.grad)
|
||||
|
||||
|
@ -133,6 +133,13 @@ uniform_qconfig_8bit = QConfig(
|
||||
qconfig_dict = {"object_type": [(torch.nn.Linear, uniform_qconfig_8bit)]}
|
||||
|
||||
|
||||
def closure_adder(val):
|
||||
def inner(x):
|
||||
return torch.sin(x + val)
|
||||
|
||||
return inner
|
||||
|
||||
|
||||
class MiscTests(torch._dynamo.test_case.TestCase):
|
||||
def test_get_cache_entry(self):
|
||||
def f(x):
|
||||
@ -467,6 +474,25 @@ class MiscTests(torch._dynamo.test_case.TestCase):
|
||||
cleanup_op("mylib::foo")
|
||||
del lib
|
||||
|
||||
def test_closure_recompiles(self):
|
||||
cnt = CompileCounter()
|
||||
|
||||
def fn(x, other_fn):
|
||||
return other_fn(x + 1) - 1
|
||||
|
||||
opt = torch.compile(fn, backend=cnt, fullgraph=True)
|
||||
|
||||
x = torch.randn(8)
|
||||
for f in (
|
||||
closure_adder(5),
|
||||
closure_adder(5),
|
||||
closure_adder(torch.randn(8)),
|
||||
closure_adder(torch.randn(8)),
|
||||
):
|
||||
self.assertEqual(opt(x, f), fn(x, f))
|
||||
|
||||
self.assertEqual(cnt.frame_count, 2)
|
||||
|
||||
def test_generate_trivial_abstract_impl(self):
|
||||
try:
|
||||
lib = torch.library.Library("mylib", "FRAGMENT")
|
||||
|
@ -1,11 +1,45 @@
|
||||
# Owner(s): ["oncall: pt2"]
|
||||
import dataclasses
|
||||
import functools
|
||||
|
||||
import torch
|
||||
from torch._dynamo import compiled_autograd
|
||||
from torch._dynamo.test_case import run_tests, TestCase
|
||||
from torch.testing._internal.common_utils import IS_MACOS
|
||||
from torch.testing._internal.inductor_utils import HAS_CPU
|
||||
|
||||
|
||||
class DistributedPatternTests(TestCase):
|
||||
def test_intermediate_hook_with_closure(self):
|
||||
@dataclasses.dataclass
|
||||
class CustomObj:
|
||||
val: torch.Tensor
|
||||
|
||||
def fn(x, obj):
|
||||
y = x.sin()
|
||||
closure_var = y + 1
|
||||
y.register_hook(lambda grad: grad + obj.val + closure_var)
|
||||
z = y.sin()
|
||||
return z
|
||||
|
||||
opt = torch.compile(fn, fullgraph=True)
|
||||
|
||||
obj1 = CustomObj(torch.tensor(88))
|
||||
obj2 = CustomObj(torch.tensor(99))
|
||||
x0 = torch.ones(4, requires_grad=True)
|
||||
x1 = torch.ones(4, requires_grad=True)
|
||||
x2 = torch.ones(4, requires_grad=True)
|
||||
x3 = torch.ones(4, requires_grad=True)
|
||||
fn(x0, obj1).sum().backward()
|
||||
fn(x1, obj2).sum().backward()
|
||||
|
||||
with compiled_autograd.enable(functools.partial(torch.compile, fullgraph=True)):
|
||||
opt(x2, obj1).sum().backward()
|
||||
opt(x3, obj2).sum().backward()
|
||||
|
||||
self.assertEqual(x0.grad, x2.grad)
|
||||
self.assertEqual(x1.grad, x3.grad)
|
||||
|
||||
@torch.no_grad()
|
||||
def test_storage_resize_zero(self):
|
||||
@torch.compile(fullgraph=True)
|
||||
|
@ -1,11 +1,14 @@
|
||||
import torch
|
||||
from torch._C import DispatchKey
|
||||
from torch._higher_order_ops.utils import autograd_not_implemented
|
||||
|
||||
from torch._ops import HigherOrderOperator
|
||||
from torch._subclasses import FakeTensorMode
|
||||
from torch.fx.experimental._backward_state import BackwardState
|
||||
|
||||
from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_tensor_tree
|
||||
from torch.utils._python_dispatch import _get_current_dispatch_mode
|
||||
from torch.utils._pytree import tree_map_only
|
||||
|
||||
|
||||
__all__ = ["trace_wrapped"]
|
||||
@ -41,10 +44,12 @@ __all__ = ["trace_wrapped"]
|
||||
# compiled autograd do we inline into the function.
|
||||
|
||||
|
||||
def trace_wrapped(*args, fn):
|
||||
return _trace_wrapped_op(*args, fn=fn)
|
||||
def trace_wrapped(*args, **kwargs):
|
||||
with torch.no_grad():
|
||||
return _trace_wrapped_op(*args, **kwargs)
|
||||
|
||||
|
||||
# TODO(jansel): need to ensure this does not get DCEed
|
||||
_trace_wrapped_op = HigherOrderOperator("trace_wrapped")
|
||||
|
||||
|
||||
@ -56,54 +61,51 @@ def _assert_meta(grad, size, stride, dtype):
|
||||
|
||||
|
||||
@_trace_wrapped_op.py_impl(ProxyTorchDispatchMode)
|
||||
def inner_trace(mode, *args, fn):
|
||||
import torch
|
||||
def inner_trace(mode, *args, bw_state=None, **kwargs):
|
||||
def self_invoke(*args, **dyn_kwargs):
|
||||
with torch.no_grad():
|
||||
return _trace_wrapped_op(*args, **dyn_kwargs, **kwargs)
|
||||
|
||||
assert len(args) == 1
|
||||
grad = args[0]
|
||||
assert isinstance(grad, torch.Tensor)
|
||||
def unwrap_proxies(x):
|
||||
if isinstance(x, torch.Tensor):
|
||||
return mode.tracer.unwrap_proxy(x)
|
||||
if isinstance(x, (list, tuple)):
|
||||
return type(x)(map(unwrap_proxies, x))
|
||||
if x is None:
|
||||
return None
|
||||
raise AssertionError(f"unhandled type: {type(x)}")
|
||||
|
||||
def self_invoke(*args):
|
||||
return _trace_wrapped_op(*args, fn=fn)
|
||||
|
||||
proxy_args = (mode.tracer.unwrap_proxy(grad),)
|
||||
out_proxy = mode.tracer.create_proxy(
|
||||
"call_function", self_invoke, proxy_args, {}, name="trace_wrapped"
|
||||
)
|
||||
grad = torch.zeros_like(grad)
|
||||
grad = track_tensor_tree(grad, out_proxy, constant=None, tracer=mode.tracer)
|
||||
|
||||
# We have a little shortcut here, wherein we DO NOT yet run a meta func, and so
|
||||
# we take on an assumption that input and output meta matches. As such, we must introduce
|
||||
# a runtime assert
|
||||
proxy_args = (
|
||||
mode.tracer.unwrap_proxy(grad),
|
||||
grad.size(),
|
||||
grad.stride(),
|
||||
grad.dtype,
|
||||
)
|
||||
proxy_kwargs = {}
|
||||
if bw_state is not None:
|
||||
assert isinstance(bw_state, BackwardState) and bw_state.proxy is not None
|
||||
proxy_kwargs["bw_state"] = bw_state.proxy
|
||||
out_proxy = mode.tracer.create_proxy(
|
||||
"call_function",
|
||||
_assert_meta,
|
||||
proxy_args,
|
||||
{},
|
||||
name="assert",
|
||||
self_invoke,
|
||||
unwrap_proxies(args),
|
||||
proxy_kwargs,
|
||||
name="trace_wrapped",
|
||||
)
|
||||
grad = torch.empty_like(grad)
|
||||
grad = track_tensor_tree(grad, out_proxy, constant=None, tracer=mode.tracer)
|
||||
|
||||
if args[0] is None:
|
||||
grad = args[1] # module backward hooks
|
||||
else:
|
||||
grad = args[0] # other backward hooks
|
||||
grad = tree_map_only(torch.Tensor, torch.empty_like, grad)
|
||||
track_tensor_tree(grad, out_proxy, constant=None, tracer=mode.tracer)
|
||||
return grad
|
||||
|
||||
|
||||
@_trace_wrapped_op.py_impl(FakeTensorMode)
|
||||
def inner_fake(*args, fn):
|
||||
def inner_fake(*args, **kwargs):
|
||||
raise RuntimeError("This op should never be invoked here")
|
||||
|
||||
|
||||
@_trace_wrapped_op.py_impl(DispatchKey.CompositeExplicitAutograd)
|
||||
def _trace_wrapped_op_dense(*args, fn):
|
||||
def _trace_wrapped_op_dense(*args, fn, **kwargs):
|
||||
mode = _get_current_dispatch_mode()
|
||||
assert mode is None, "Mode should never be enabled for CPU/CUDA key"
|
||||
return fn(*args)
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
|
||||
_trace_wrapped_op.py_impl(DispatchKey.Autograd)(
|
||||
@ -112,7 +114,7 @@ _trace_wrapped_op.py_impl(DispatchKey.Autograd)(
|
||||
|
||||
|
||||
@_trace_wrapped_op.py_functionalize_impl
|
||||
def _trace_wrapped_functionalized(ctx, *args, fn):
|
||||
def _trace_wrapped_functionalized(ctx, *args, **kwargs):
|
||||
unwrapped_args = ctx.unwrap_tensors(args)
|
||||
with ctx.redispatch_to_next():
|
||||
return ctx.wrap_tensors(_trace_wrapped_op(*unwrapped_args, fn=fn))
|
||||
return ctx.wrap_tensors(_trace_wrapped_op(*unwrapped_args, **kwargs))
|
||||
|
@ -10,6 +10,7 @@ from torch._logging import getArtifactLogger, trace_structured
|
||||
from torch._prims_common import clone_preserve_strides
|
||||
from torch._subclasses import FakeTensorMode
|
||||
from torch.fx import GraphModule
|
||||
from torch.fx.experimental._backward_state import BackwardState
|
||||
from torch.fx.experimental.proxy_tensor import (
|
||||
decompose,
|
||||
disable_autocast_cache,
|
||||
@ -223,6 +224,13 @@ class AutogradCompilerInstance:
|
||||
assert len(tensors) == len(proxies)
|
||||
track_tensor_tree(tensors, proxies, constant=None, tracer=self.fx_tracer)
|
||||
|
||||
def bind_backward_state(self, index: int):
|
||||
assert self.hooks_proxy is not None
|
||||
proxy = self.hooks_proxy[index] # type: ignore[index]
|
||||
bw_state = BackwardState()
|
||||
track_tensor_tree(bw_state, proxy, constant=None, tracer=self.fx_tracer)
|
||||
return bw_state
|
||||
|
||||
|
||||
compiled_autograd_enabled = False
|
||||
|
||||
|
@ -74,3 +74,7 @@ def call_backward(backward_fn, saved_tensors, *args):
|
||||
|
||||
def untyped_storage_size(x: torch.Tensor):
|
||||
return x.untyped_storage().size()
|
||||
|
||||
|
||||
def call_hook_from_backward_state(*args, bw_state, hook_name: str, **kwargs):
|
||||
return getattr(bw_state, hook_name)(*args, **kwargs)
|
||||
|
@ -30,6 +30,7 @@ from torch._guards import (
|
||||
)
|
||||
from torch._utils_internal import signpost_event
|
||||
from torch.fx._lazy_graph_module import _make_graph_module # type: ignore[attr-defined]
|
||||
from torch.fx.experimental._backward_state import BackwardState
|
||||
from torch.fx.experimental.sym_node import SymNode
|
||||
from torch.fx.experimental.symbolic_shapes import free_symbols, is_symbolic, ShapeEnv
|
||||
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
|
||||
@ -60,6 +61,7 @@ from .mutation_guard import is_dynamic_nn_module
|
||||
from .side_effects import SideEffects
|
||||
from .source import (
|
||||
AttrSource,
|
||||
BackwardStateSource,
|
||||
ConstantSource,
|
||||
GlobalStateSource,
|
||||
is_constant_source,
|
||||
@ -87,7 +89,13 @@ from .utils import (
|
||||
same,
|
||||
)
|
||||
from .variables.base import VariableTracker
|
||||
from .variables.builder import GraphArg, TrackedFake, VariableBuilder, wrap_fx_proxy
|
||||
from .variables.builder import (
|
||||
BackwardStateGraphArg,
|
||||
GraphArg,
|
||||
TrackedFake,
|
||||
VariableBuilder,
|
||||
wrap_fx_proxy,
|
||||
)
|
||||
from .variables.nn_module import NNModuleVariable
|
||||
from .variables.tensor import (
|
||||
NumpyNdarrayVariable,
|
||||
@ -380,6 +388,29 @@ class OutputGraph(Checkpointable[OutputGraphState]):
|
||||
] = []
|
||||
self.random_values_var = None
|
||||
|
||||
# Use to pass values to backward hooks when using compiled autograd
|
||||
self.backward_state: Dict[str, VariableTracker] = {}
|
||||
self.backward_state_proxy: Optional[torch.fx.Proxy] = None
|
||||
self.backward_state_var: Optional[str] = None
|
||||
|
||||
def add_backward_state_hook(self, hook: VariableTracker):
|
||||
name = f"hook{len(self.backward_state)}"
|
||||
assert name not in self.backward_state
|
||||
self.backward_state[name] = hook
|
||||
return name, self.get_backward_state_proxy()
|
||||
|
||||
def get_backward_state_proxy(self):
|
||||
if self.backward_state_proxy is None:
|
||||
if self.export:
|
||||
unimplemented("backward_state does not support export")
|
||||
self.backward_state_proxy = self.root_tracer.create_graph_input(
|
||||
"dynamo_backward_state", BackwardState, source=BackwardStateSource()
|
||||
)
|
||||
self.backward_state_proxy.node.meta["grapharg"] = BackwardStateGraphArg()
|
||||
self.backward_state_proxy.node.meta["example_value"] = BackwardState()
|
||||
self.backward_state_var = self.new_var()
|
||||
return self.backward_state_proxy
|
||||
|
||||
# This gets its own helper function so guards DEBUG logs are more informative
|
||||
def init_ambient_guards(self):
|
||||
# Register a SHAPE_ENV guard to make sure we setup shape guards
|
||||
@ -924,6 +955,7 @@ class OutputGraph(Checkpointable[OutputGraphState]):
|
||||
and all(isinstance(x, TensorVariable) for x in stack_values)
|
||||
and len(set(stack_values)) == len(stack_values)
|
||||
and self.side_effects.is_empty()
|
||||
and not self.backward_state
|
||||
):
|
||||
append_prefix_insts()
|
||||
# optimization to generate better code in a common case
|
||||
@ -934,10 +966,7 @@ class OutputGraph(Checkpointable[OutputGraphState]):
|
||||
else:
|
||||
graph_output_var = self.new_var("graph_out")
|
||||
pass1 = PyCodegen(tx, root, graph_output_var)
|
||||
self.side_effects.codegen_hooks(pass1)
|
||||
self.side_effects.codegen_save_tempvars(pass1)
|
||||
pass1.restore_stack(stack_values, value_from_source=not tx.export)
|
||||
self.side_effects.codegen_update_mutated(pass1)
|
||||
self.codegen_suffix(tx, stack_values, pass1)
|
||||
|
||||
# one more time now that we have established tempvars
|
||||
pass2 = PyCodegen(
|
||||
@ -946,10 +975,7 @@ class OutputGraph(Checkpointable[OutputGraphState]):
|
||||
graph_output_var,
|
||||
tempvars={val: None for val, count in pass1.uses.items() if count > 1},
|
||||
)
|
||||
self.side_effects.codegen_hooks(pass2)
|
||||
self.side_effects.codegen_save_tempvars(pass2)
|
||||
pass2.restore_stack(stack_values, value_from_source=not tx.export)
|
||||
self.side_effects.codegen_update_mutated(pass2)
|
||||
self.codegen_suffix(tx, stack_values, pass2)
|
||||
|
||||
output = []
|
||||
if count_calls(self.graph) != 0 or len(pass2.graph_outputs) != 0:
|
||||
@ -969,6 +995,18 @@ class OutputGraph(Checkpointable[OutputGraphState]):
|
||||
[PyCodegen(tx).create_store(var) for var in reversed(restore_vars)]
|
||||
)
|
||||
|
||||
def codegen_suffix(self, tx, stack_values, cg):
|
||||
if self.backward_state:
|
||||
assert not self.export
|
||||
for name, val in self.backward_state.items():
|
||||
cg(val)
|
||||
cg.append_output(cg.create_load(self.backward_state_var))
|
||||
cg.store_attr(name)
|
||||
self.side_effects.codegen_hooks(cg)
|
||||
self.side_effects.codegen_save_tempvars(cg)
|
||||
cg.restore_stack(stack_values, value_from_source=not tx.export)
|
||||
self.side_effects.codegen_update_mutated(cg)
|
||||
|
||||
def cleanup_graph(self):
|
||||
"""
|
||||
Remove "creation_timestamp" from node meta
|
||||
@ -1243,11 +1281,15 @@ class OutputGraph(Checkpointable[OutputGraphState]):
|
||||
if not node.users:
|
||||
recheck_placeholders.append(node)
|
||||
else:
|
||||
if not node.users:
|
||||
if not node.users and not isinstance(
|
||||
node.meta["grapharg"], BackwardStateGraphArg
|
||||
):
|
||||
remove_unused(node)
|
||||
else:
|
||||
# Register the free symbols as uses
|
||||
arg = node.meta["grapharg"]
|
||||
if isinstance(arg, BackwardStateGraphArg):
|
||||
continue
|
||||
fake = (
|
||||
arg.fake_tensor if arg.fake_tensor is not None else arg.example
|
||||
)
|
||||
|
@ -485,6 +485,15 @@ class ShapeEnvSource(Source):
|
||||
return GuardSource.SHAPE_ENV
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class BackwardStateSource(Source):
|
||||
def name(self):
|
||||
return ""
|
||||
|
||||
def guard_source(self):
|
||||
return GuardSource.BACKWARD_STATE
|
||||
|
||||
|
||||
def is_from_local_source(source: Source, *, allow_cell_or_freevar=True):
|
||||
if isinstance(source, ChainedSource):
|
||||
return is_from_local_source(
|
||||
|
@ -27,6 +27,7 @@ from torch._guards import GuardSource, TracingContext
|
||||
from torch._ops import HigherOrderOperator
|
||||
from torch._streambase import _EventBase, _StreamBase
|
||||
from torch._subclasses.fake_tensor import FakeTensor, is_fake, maybe_get_fake_mode
|
||||
from torch.fx.experimental._backward_state import BackwardState
|
||||
from torch.fx.experimental.symbolic_shapes import (
|
||||
_constrain_range_for_size,
|
||||
DimDynamic,
|
||||
@ -213,6 +214,24 @@ class GraphArg:
|
||||
return self.source.name() == other.source.name()
|
||||
|
||||
|
||||
class BackwardStateGraphArg(GraphArg):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
source=None,
|
||||
_example=BackwardState(),
|
||||
is_unspecialized=False,
|
||||
fake_tensor=None,
|
||||
is_tensor=False,
|
||||
)
|
||||
|
||||
def reconstruct(self, codegen):
|
||||
assert codegen.tx.output.backward_state_var
|
||||
codegen.load_import_from(BackwardState.__module__, "BackwardState")
|
||||
codegen.call_function(0, True)
|
||||
codegen.dup_top()
|
||||
codegen.store(codegen.tx.output.backward_state_var)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class FrameStateSizeEntry:
|
||||
scalar: Optional[int]
|
||||
|
@ -1443,18 +1443,19 @@ class ExportTracepointHigherOrderVariable(TorchHigherOrderOperatorVariable):
|
||||
|
||||
|
||||
class TraceWrappedHigherOrderOperatorVariable(TorchHigherOrderOperatorVariable):
|
||||
"""
|
||||
Handles torch._dynamo._trace_wrapped_higher_order_op.inner_trace
|
||||
by unwrapping the higher order op and inlining through it. This op
|
||||
is created by dynamo to survive through AotAutograd, then unwrapped
|
||||
here in the call to dynamo from compiled autograd.
|
||||
"""
|
||||
|
||||
def call_function(
|
||||
self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
|
||||
) -> "VariableTracker":
|
||||
from . import TensorVariable
|
||||
|
||||
assert "fn" in kwargs
|
||||
fn = kwargs["fn"]
|
||||
assert len(args) == 1
|
||||
grad = args[0]
|
||||
assert isinstance(grad, TensorVariable)
|
||||
|
||||
return fn.call_function(tx, args, {})
|
||||
kwargs = dict(kwargs)
|
||||
fn = kwargs.pop("fn")
|
||||
return fn.call_function(tx, args, kwargs)
|
||||
|
||||
|
||||
class AutogradFunctionApplyVariable(VariableTracker):
|
||||
|
@ -10,6 +10,7 @@ from typing import Dict, List
|
||||
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
|
||||
|
||||
from ..bytecode_transformation import create_call_method
|
||||
from ..external_utils import call_hook_from_backward_state
|
||||
|
||||
try:
|
||||
import numpy as np
|
||||
@ -769,7 +770,7 @@ class TensorVariable(VariableTracker):
|
||||
"register_post_accumulate_grad_hook", *args, **kwargs
|
||||
)
|
||||
|
||||
def _method_register_hook(self, name, hook):
|
||||
def _method_register_hook(self, name: str, hook: VariableTracker):
|
||||
# Note - do not arbitrarily add hooks here - make sure they match the same contract
|
||||
# see [On tensor.register_hook]
|
||||
from ..symbolic_convert import InstructionTranslator
|
||||
@ -797,14 +798,22 @@ class TensorVariable(VariableTracker):
|
||||
"Compilation of intermediate hooks requires compiled autograd"
|
||||
)
|
||||
|
||||
# This wraps our user provided fn with a function that intercedes and
|
||||
# uses our `invoke` higher order op to record a hook invocation in bwd graph.
|
||||
fn = functools.partial(trace_wrapped, fn=hook.guard_as_python_constant())
|
||||
hook_name, bw_state_proxy = tx.output.add_backward_state_hook(hook)
|
||||
|
||||
def _register_hook_trampoline(tensor):
|
||||
hook_callable = getattr(tensor, name)
|
||||
hook_callable(fn)
|
||||
return tensor
|
||||
def _register_hook_trampoline(tensor, bw_state):
|
||||
register_hook = getattr(tensor, name)
|
||||
register_hook(
|
||||
functools.partial(
|
||||
trace_wrapped,
|
||||
fn=call_hook_from_backward_state,
|
||||
bw_state=bw_state,
|
||||
hook_name=hook_name,
|
||||
)
|
||||
)
|
||||
# TODO(jansel): returning None here is wrong, it should be
|
||||
# RemovableHandle, but we need some extra work to support
|
||||
# this properly.
|
||||
return None
|
||||
|
||||
from .builder import wrap_fx_proxy
|
||||
|
||||
@ -813,7 +822,7 @@ class TensorVariable(VariableTracker):
|
||||
tx.output.create_proxy(
|
||||
"call_function",
|
||||
_register_hook_trampoline,
|
||||
(self.as_proxy(),),
|
||||
(self.as_proxy(), bw_state_proxy),
|
||||
{},
|
||||
),
|
||||
)
|
||||
|
@ -13,7 +13,6 @@ from functools import wraps
|
||||
from typing import Any, List, Optional
|
||||
|
||||
import torch
|
||||
import torch.utils._pytree as pytree
|
||||
import torch.utils.dlpack
|
||||
from torch import Tensor
|
||||
from torch._dynamo.utils import lazy_format_graph_code
|
||||
@ -21,6 +20,7 @@ from torch._guards import detect_fake_mode, tracing, TracingContext
|
||||
from torch._logging import getArtifactLogger, trace_structured
|
||||
from torch._prims_common import CUDARngStateHelper
|
||||
from torch._subclasses import FakeTensor
|
||||
from torch.fx.experimental._backward_state import BackwardState
|
||||
from torch.fx.experimental.proxy_tensor import is_sym_node
|
||||
from torch.fx.experimental.symbolic_shapes import fx_placeholder_vals
|
||||
from .. import config
|
||||
@ -158,10 +158,6 @@ def aot_dispatch_autograd(
|
||||
)
|
||||
|
||||
# Copied from aot_dispatch_autograd_graph.
|
||||
traced_tangents = pytree.tree_map(
|
||||
lambda x: x.detach().contiguous() if isinstance(x, Tensor) else x,
|
||||
fw_metadata.traced_tangents,
|
||||
)
|
||||
disable_amp = torch._C._is_any_autocast_enabled()
|
||||
|
||||
if aot_config.enable_log:
|
||||
@ -419,6 +415,11 @@ def aot_dispatch_autograd(
|
||||
|
||||
saved_context = TracingContext.try_get()
|
||||
|
||||
backward_state_indices = [
|
||||
idx for idx, x in enumerate(flat_args) if isinstance(x, BackwardState)
|
||||
]
|
||||
assert len(backward_state_indices) <= 1
|
||||
|
||||
class CompiledFunction(torch.autograd.Function):
|
||||
compiled_fw = compiled_fw_func
|
||||
compiled_bw = compiled_bw_func
|
||||
@ -434,6 +435,10 @@ def aot_dispatch_autograd(
|
||||
@staticmethod
|
||||
def forward(ctx, *deduped_flat_tensor_args):
|
||||
args = deduped_flat_tensor_args
|
||||
if backward_state_indices:
|
||||
bw_state = args[backward_state_indices[0]]
|
||||
assert isinstance(bw_state, BackwardState)
|
||||
ctx._compiled_autograd_backward_state = bw_state
|
||||
|
||||
marked_dirty_inps = []
|
||||
for i in fw_metadata.mutated_graph_handled_indices_seen_by_autograd:
|
||||
@ -457,8 +462,6 @@ def aot_dispatch_autograd(
|
||||
|
||||
num_outputs = CompiledFunction.metadata.num_outputs
|
||||
num_outputs_aliased = CompiledFunction.metadata.num_outputs_aliased
|
||||
num_intermediate_bases = CompiledFunction.metadata.num_intermediate_bases
|
||||
num_symints_saved_for_bw = CompiledFunction.num_symints_saved_for_bw
|
||||
num_mutated_runtime_inps = (
|
||||
CompiledFunction.metadata.num_mutated_inp_runtime_indices
|
||||
)
|
||||
@ -742,6 +745,9 @@ Got grad_output types: {str(grad_output_types)}"""
|
||||
symints = ctx._get_compiled_autograd_symints()
|
||||
assert len(symints) == len(ctx.symints)
|
||||
all_args[: len(symints)] = symints
|
||||
if backward_state_indices:
|
||||
assert ctx._compiled_autograd_backward_state.proxy is not None
|
||||
all_args.append(ctx._compiled_autograd_backward_state)
|
||||
context = torch._C._DisableAutocast if disable_amp else nullcontext
|
||||
with context():
|
||||
out = normalize_as_list(bw_module(*all_args))
|
||||
@ -749,6 +755,9 @@ Got grad_output types: {str(grad_output_types)}"""
|
||||
CompiledFunction.metadata, out
|
||||
)
|
||||
return tuple(out)
|
||||
assert (
|
||||
not backward_state_indices
|
||||
), "BackwardState requires CompiledAutograd"
|
||||
ctx.maybe_clear_saved_tensors()
|
||||
if CompiledFunction.compiled_bw is None:
|
||||
context = torch._C._DisableAutocast if disable_amp else nullcontext
|
||||
|
@ -10,9 +10,19 @@ from typing import Any, Callable, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.utils._pytree as pytree
|
||||
from torch.fx.experimental._backward_state import BackwardState
|
||||
from torch.fx.experimental.proxy_tensor import py_sym_types
|
||||
|
||||
KNOWN_TYPES = [torch.Tensor, int, str, float, bool, type(None)] + list(py_sym_types)
|
||||
KNOWN_TYPES = [
|
||||
torch.Tensor,
|
||||
BackwardState,
|
||||
int,
|
||||
str,
|
||||
float,
|
||||
bool,
|
||||
type(None),
|
||||
*py_sym_types,
|
||||
]
|
||||
|
||||
original_zip = zip
|
||||
|
||||
|
@ -5,6 +5,7 @@ from torch.fx.experimental.sym_node import magic_methods, method_to_operator
|
||||
from torch.fx.experimental.symbolic_shapes import (
|
||||
hint_int, free_symbols, is_symbol_binding_fx_node, find_symbol_binding_fx_nodes
|
||||
)
|
||||
from torch.fx.experimental._backward_state import BackwardState
|
||||
import torch
|
||||
import torch.fx as fx
|
||||
import operator
|
||||
@ -21,6 +22,7 @@ from .compile_utils import fx_graph_cse, get_aten_target
|
||||
from . import config
|
||||
import functools
|
||||
|
||||
|
||||
AOT_PARTITIONER_DEBUG = config.debug_partitioner
|
||||
|
||||
|
||||
@ -124,6 +126,9 @@ def _is_bwd_seed_offset(node):
|
||||
def _is_fwd_seed_offset(node):
|
||||
return node.op == "placeholder" and ("fwd_seed" in node.target or "fwd_base_offset" in node.target)
|
||||
|
||||
def _is_backward_state(node):
|
||||
return node.op == "placeholder" and isinstance(node.meta.get("val"), BackwardState)
|
||||
|
||||
|
||||
def _extract_fwd_bwd_outputs(joint_module: fx.GraphModule, *, num_fwd_outputs):
|
||||
outputs = pytree.arg_tree_leaves(*(node.args for node in joint_module.graph.nodes if node.op == 'output'))
|
||||
@ -132,38 +137,49 @@ def _extract_fwd_bwd_outputs(joint_module: fx.GraphModule, *, num_fwd_outputs):
|
||||
return fwd_outputs, bwd_outputs
|
||||
|
||||
|
||||
def _remove_by_name(saved_values, name):
|
||||
for saved_value in saved_values:
|
||||
if saved_value.name == name:
|
||||
saved_values.remove(saved_value)
|
||||
break
|
||||
|
||||
def _placeholders(nodes):
|
||||
# Avoid making an entire pass over the graph if we only care about the input placeholders
|
||||
result = []
|
||||
for node in nodes:
|
||||
if node.op == 'placeholder':
|
||||
result.append(node)
|
||||
else:
|
||||
break # placeholders are all at the start of graph
|
||||
return result
|
||||
|
||||
|
||||
def _extract_fwd_bwd_modules(joint_module: fx.GraphModule, saved_values, saved_sym_nodes, *, num_fwd_outputs):
|
||||
fwd_outputs, bwd_outputs = _extract_fwd_bwd_outputs(joint_module, num_fwd_outputs=num_fwd_outputs)
|
||||
primal_inputs = list(filter(_is_primal, joint_module.graph.nodes))
|
||||
tangent_inputs = list(filter(_is_tangent, joint_module.graph.nodes))
|
||||
fwd_seed_offset_inputs = list(filter(_is_fwd_seed_offset, joint_module.graph.nodes))
|
||||
bwd_seed_offset_inputs = list(filter(_is_bwd_seed_offset, joint_module.graph.nodes))
|
||||
placeholders = _placeholders(joint_module.graph.nodes)
|
||||
primal_inputs = [*filter(_is_primal, placeholders)]
|
||||
tangent_inputs = [*filter(_is_tangent, placeholders)]
|
||||
fwd_seed_offset_inputs = [*filter(_is_fwd_seed_offset, placeholders)]
|
||||
bwd_seed_offset_inputs = [*filter(_is_bwd_seed_offset, placeholders)]
|
||||
backward_state_inputs = [*filter(_is_backward_state, placeholders)]
|
||||
|
||||
# Construct the forward module
|
||||
# Keep symints separate from tensors, passed between fwd/bwd graphs, and in the right order.
|
||||
fwd_graph = _extract_graph_with_inputs_outputs(
|
||||
joint_module.graph,
|
||||
primal_inputs + fwd_seed_offset_inputs,
|
||||
fwd_outputs + saved_values + saved_sym_nodes
|
||||
)
|
||||
bwd_graph = _extract_graph_with_inputs_outputs(
|
||||
joint_module.graph,
|
||||
saved_sym_nodes + saved_values + tangent_inputs + bwd_seed_offset_inputs,
|
||||
bwd_outputs
|
||||
)
|
||||
|
||||
# This is to filter out saved values that don't actually end up being used by the backwards pass
|
||||
for node in bwd_graph.nodes:
|
||||
if node.op == 'placeholder' and not node.users:
|
||||
for saved_value in saved_values:
|
||||
if saved_value.name == node.name:
|
||||
saved_values.remove(saved_value)
|
||||
break
|
||||
for node in _placeholders(bwd_graph.nodes):
|
||||
assert node.op == 'placeholder'
|
||||
# This is to filter out saved values that don't actually end up being used by the backwards pass
|
||||
if not node.users:
|
||||
_remove_by_name(saved_values, node.name)
|
||||
_remove_by_name(saved_sym_nodes, node.name)
|
||||
elif _is_backward_state(node):
|
||||
# BackwardState is saved directly
|
||||
_remove_by_name(saved_values, node.name)
|
||||
assert backward_state_inputs
|
||||
|
||||
for saved_sym in saved_sym_nodes:
|
||||
if saved_sym.name == node.name:
|
||||
saved_sym_nodes.remove(saved_sym)
|
||||
break
|
||||
|
||||
# Now that we have the finalized list of saved values, we need to ensure
|
||||
# we propagate all symbols which are referenced by backwards inputs.
|
||||
@ -216,7 +232,7 @@ def _extract_fwd_bwd_modules(joint_module: fx.GraphModule, saved_values, saved_s
|
||||
)
|
||||
bwd_graph = _extract_graph_with_inputs_outputs(
|
||||
joint_module.graph,
|
||||
saved_sym_nodes + saved_values + tangent_inputs + bwd_seed_offset_inputs,
|
||||
saved_sym_nodes + saved_values + tangent_inputs + bwd_seed_offset_inputs + backward_state_inputs,
|
||||
bwd_outputs
|
||||
)
|
||||
|
||||
@ -863,7 +879,7 @@ def min_cut_rematerialization_partition(
|
||||
if is_sym_node(node):
|
||||
weight = sym_node_size(node)
|
||||
elif is_non_tensor_node:
|
||||
weight = math.inf
|
||||
weight = 0 if isinstance(node.meta.get("val"), BackwardState) else math.inf
|
||||
else:
|
||||
weight = get_node_weight(node)
|
||||
|
||||
|
@ -86,6 +86,7 @@ class GuardSource(enum.Enum):
|
||||
SHAPE_ENV = 6
|
||||
LOCAL_FSDP_MODULE = 7
|
||||
GLOBAL_FSDP_MODULE = 8
|
||||
BACKWARD_STATE = 9
|
||||
|
||||
def is_fsdp_module(self) -> bool:
|
||||
return self in (GuardSource.GLOBAL_FSDP_MODULE, GuardSource.LOCAL_FSDP_MODULE)
|
||||
|
@ -546,9 +546,10 @@ class WrapperCodeGen(CodeGen):
|
||||
with self.prefix.indent():
|
||||
if config.triton.debug_sync_graph:
|
||||
self.prefix.writeline(V.graph.device_ops.synchronize())
|
||||
inp_len = len(V.graph.graph_inputs.keys())
|
||||
if inp_len != 0:
|
||||
lhs = f"{', '.join(V.graph.graph_inputs.keys())}{'' if inp_len != 1 else ','}"
|
||||
if V.graph.graph_inputs:
|
||||
lhs = ", ".join(V.graph.graph_input_names)
|
||||
if len(V.graph.graph_input_names) == 1:
|
||||
lhs += ","
|
||||
self.prefix.writeline(f"{lhs} = args")
|
||||
self.prefix.writeline("args.clear()")
|
||||
|
||||
|
@ -18,6 +18,7 @@ from torch._decomp import get_decompositions
|
||||
from torch._dynamo.utils import defake, dynamo_timed
|
||||
from torch._logging import LazyString, trace_structured
|
||||
from torch._subclasses.fake_tensor import FakeTensor
|
||||
from torch.fx.experimental._backward_state import BackwardState
|
||||
from torch.fx.experimental.sym_node import magic_methods, method_to_operator
|
||||
from torch.fx.experimental.symbolic_shapes import has_free_symbols, ShapeEnv, SymTypes
|
||||
from torch.utils._mode_utils import no_dispatch
|
||||
@ -236,6 +237,7 @@ class GraphLowering(torch.fx.Interpreter):
|
||||
self.reuse_shape_env = True
|
||||
self._shape_env = shape_env
|
||||
self.sizevars = SizeVarAllocator(shape_env)
|
||||
self.graph_input_names: List[str] = []
|
||||
self.graph_inputs: Dict[str, TensorBox] = {}
|
||||
self.graph_inputs_original: Dict[str, InputBuffer] = {}
|
||||
self.device_types: Set[str] = (
|
||||
@ -718,6 +720,7 @@ class GraphLowering(torch.fx.Interpreter):
|
||||
|
||||
def placeholder(self, target: str, args, kwargs):
|
||||
example = super().placeholder(target, args, kwargs)
|
||||
self.graph_input_names.append(target)
|
||||
if isinstance(example, SymTypes):
|
||||
expr = example.node.expr
|
||||
self.graph_inputs[target] = expr
|
||||
@ -726,6 +729,10 @@ class GraphLowering(torch.fx.Interpreter):
|
||||
expr = sympy.sympify(example)
|
||||
self.graph_inputs[target] = expr
|
||||
return expr
|
||||
if isinstance(example, BackwardState):
|
||||
# Ignored arg, must be unused
|
||||
# Alternately we could filter this out in AotAutograd
|
||||
return None
|
||||
assert isinstance(example, torch.Tensor), example
|
||||
# todo(chilli): We can remove the last check once we turn buffers into
|
||||
# static shape tensors. That's a hack to workaround Inductor believing
|
||||
|
@ -131,8 +131,7 @@ PyObject* to_py_size(const std::vector<c10::SymInt>& size) {
|
||||
|
||||
} // namespace
|
||||
|
||||
namespace torch {
|
||||
namespace autograd {
|
||||
namespace torch::autograd {
|
||||
|
||||
// NOTE: this function is written in a way that assumes it's only called for
|
||||
// backward; it's used by engine.cpp. This is responsible for forwarding a call
|
||||
@ -237,6 +236,7 @@ auto PyNode::compiled_apply(
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
_backward_idx.has_value(),
|
||||
"indices should already be set by compiled_args, called before apply_with_saved");
|
||||
TORCH_INTERNAL_ASSERT(!_backward_state_idx.has_value());
|
||||
THPObjectPtr r(PyObject_CallMethod(
|
||||
*compiler,
|
||||
"proxy_call_backward",
|
||||
@ -302,7 +302,7 @@ void PyNode::compiled_args(CompiledNodeArgs& args) {
|
||||
throw_python_error();
|
||||
TORCH_CHECK(
|
||||
PyTuple_CheckExact(pykey.get()),
|
||||
"_compiled_autograd_key shoud return tuple of ints");
|
||||
"_compiled_autograd_key should return tuple of ints");
|
||||
auto size = PyTuple_GET_SIZE(pykey.get());
|
||||
TORCH_INTERNAL_ASSERT(size > 0);
|
||||
// first value is unique ID of the AotAutograd graph
|
||||
@ -345,6 +345,13 @@ void PyNode::compiled_args(CompiledNodeArgs& args) {
|
||||
PyObject* backward(PyObject_GetAttr(forward_cls.get(), backward_name));
|
||||
_backward_idx =
|
||||
args.add_backward(c10::SafePyObject(backward, getPyInterpreter()));
|
||||
|
||||
PyObject* bw_state = f->compiled_autograd_backward_state;
|
||||
if (args.cond(bw_state != nullptr)) {
|
||||
Py_INCREF(bw_state);
|
||||
_backward_state_idx = args.add_backward_state(
|
||||
c10::SafePyObject(bw_state, getPyInterpreter()));
|
||||
}
|
||||
}
|
||||
|
||||
variable_list PyNode::apply_with_saved(
|
||||
@ -361,7 +368,23 @@ variable_list PyNode::apply_with_saved(
|
||||
f->compiled_autograd_tracing = true;
|
||||
variable_list result;
|
||||
if (!compiled_autograd_should_lift()) {
|
||||
result = apply(variable_list(inputs));
|
||||
if (_backward_state_idx.has_value()) {
|
||||
PyObject* r = PyObject_CallMethod(
|
||||
saved.get_py_compiler(),
|
||||
"bind_backward_state",
|
||||
"i",
|
||||
*_backward_state_idx);
|
||||
if (r == nullptr) {
|
||||
throw python_error();
|
||||
}
|
||||
THPObjectPtr prior(f->compiled_autograd_backward_state);
|
||||
f->compiled_autograd_backward_state = r;
|
||||
result = apply(variable_list(inputs));
|
||||
Py_CLEAR(f->compiled_autograd_backward_state);
|
||||
f->compiled_autograd_backward_state = prior.release();
|
||||
} else {
|
||||
result = apply(variable_list(inputs));
|
||||
}
|
||||
} else {
|
||||
result = compiled_apply(variable_list(inputs), saved.get_py_compiler());
|
||||
}
|
||||
@ -445,8 +468,7 @@ variable_list PyNode::to_variable_list(
|
||||
return results;
|
||||
}
|
||||
|
||||
} // namespace autograd
|
||||
} // namespace torch
|
||||
} // namespace torch::autograd
|
||||
|
||||
// Traverse and clear are required for supporting Python's GC cycle handling.
|
||||
static int THPFunction_traverse(THPFunction* self, visitproc visit, void* arg) {
|
||||
@ -455,6 +477,7 @@ static int THPFunction_traverse(THPFunction* self, visitproc visit, void* arg) {
|
||||
Py_VISIT(self->to_save);
|
||||
Py_VISIT(self->non_differentiable);
|
||||
Py_VISIT(self->dirty_tensors);
|
||||
Py_VISIT(self->compiled_autograd_backward_state);
|
||||
Py_VISIT(self->saved_for_forward);
|
||||
return 0;
|
||||
}
|
||||
@ -469,6 +492,7 @@ static int THPFunction_clear(THPFunction* self) {
|
||||
Py_CLEAR(self->to_save);
|
||||
Py_CLEAR(self->non_differentiable);
|
||||
Py_CLEAR(self->dirty_tensors);
|
||||
Py_CLEAR(self->compiled_autograd_backward_state);
|
||||
Py_CLEAR(self->saved_for_forward);
|
||||
|
||||
self->output_info.clear();
|
||||
@ -1504,6 +1528,33 @@ PyObject* THPFunction_get_compiled_autograd_symints(
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
PyObject* THPFunction_get_compiled_autograd_backward_state(
|
||||
PyObject* _self,
|
||||
void* _unused) {
|
||||
HANDLE_TH_ERRORS
|
||||
auto self = (THPFunction*)_self;
|
||||
PyObject* bw_state = self->compiled_autograd_backward_state;
|
||||
if (bw_state == nullptr) {
|
||||
bw_state = Py_None;
|
||||
}
|
||||
Py_INCREF(bw_state);
|
||||
return bw_state;
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
int THPFunction_set_compiled_autograd_backward_state(
|
||||
PyObject* _self,
|
||||
PyObject* bw_state,
|
||||
void* _unused) {
|
||||
HANDLE_TH_ERRORS
|
||||
auto self = (THPFunction*)_self;
|
||||
TORCH_INTERNAL_ASSERT(self->compiled_autograd_backward_state == nullptr);
|
||||
Py_INCREF(bw_state);
|
||||
self->compiled_autograd_backward_state = bw_state;
|
||||
return 0;
|
||||
END_HANDLE_TH_ERRORS_RET(-1)
|
||||
}
|
||||
|
||||
PyObject* THPFunction_raw_saved_tensors(THPFunction* self, void* _unused) {
|
||||
HANDLE_TH_ERRORS
|
||||
// User tries to access saved variables after they have been freed
|
||||
@ -1684,6 +1735,11 @@ static struct PyGetSetDef THPFunction_properties[] = {
|
||||
(setter)THPFunction_set_materialize_non_diff_grads,
|
||||
nullptr,
|
||||
nullptr},
|
||||
{"_compiled_autograd_backward_state",
|
||||
(getter)THPFunction_get_compiled_autograd_backward_state,
|
||||
(setter)THPFunction_set_compiled_autograd_backward_state,
|
||||
nullptr,
|
||||
nullptr},
|
||||
{nullptr}};
|
||||
|
||||
// NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables)
|
||||
|
@ -56,6 +56,10 @@ struct PyNode : public Node {
|
||||
// The AutogradCompilerCall::hooks idx corresponding to this node's backward
|
||||
std::optional<int> _backward_idx;
|
||||
|
||||
// The AutogradCompilerCall::hooks idx corresponding to this node's
|
||||
// backward_state
|
||||
std::optional<int> _backward_state_idx;
|
||||
|
||||
// NOLINTNEXTLINE(bugprone-exception-escape)
|
||||
~PyNode() override {
|
||||
// Can't use THPObjectPtr as a field in this class; destructor won't take
|
||||
@ -121,6 +125,7 @@ struct THPFunction {
|
||||
// This is enabled by compiled autograd as a way to signal to AotAutograd it
|
||||
// should call the original FX graph rather than compiling.
|
||||
bool compiled_autograd_tracing;
|
||||
PyObject* compiled_autograd_backward_state;
|
||||
std::vector<c10::SymInt> compiled_autograd_symints;
|
||||
|
||||
std::vector<torch::autograd::VariableInfo> output_info;
|
||||
|
@ -375,6 +375,10 @@ class CompiledNodeArgs {
|
||||
return _compiler.emplace_hook(std::move(obj));
|
||||
}
|
||||
|
||||
int add_backward_state(c10::SafePyObject&& obj) {
|
||||
return _compiler.emplace_hook(std::move(obj));
|
||||
}
|
||||
|
||||
void add_tensor_pre_hook(c10::SafePyObject&& obj, int index) {
|
||||
auto fn_id = _compiler.emplace_hook(std::move(obj));
|
||||
collect_size(fn_id);
|
||||
|
torch/fx/experimental/_backward_state.py (new file, 27 lines)
@ -0,0 +1,27 @@
import torch.fx


class BackwardState:
    """
    BackwardState is used to pass Python hooks from the forwards pass
    into the backwards pass in Dynamo+Compiled Autograd.

    It is created by TorchDynamo and has special handling there.
    Dynamo will pass an empty BackwardState to the forwards, then populate
    members on it (via setattr) only after the forwards graph is finished.
    Later on, in CompiledAutograd we will inline and add the needed guards
    on the BackwardState.

    BackwardState is identified and has special handling in AOTAutograd.
    During AOTAutograd:
        1) BackwardState is an input to the forwards graph
        2) It must only be used in the backwards
        3) It will be empty in the forwards
        4) In the forwards we add a wrapper to save it
        5) In the backwards it becomes an input
        6) There can only be one per graph

    BackwardState requires CompiledAutograd.
    """

    proxy: torch.fx.Proxy
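A minimal sketch (not part of the diff) of the runtime contract described in the docstring above, assuming the `call_hook_from_backward_state` helper added to `torch/_dynamo/external_utils.py` in this PR: Dynamo stores hooks on the BackwardState by name after the forward graph is built, and the backward graph retrieves and calls them by that name.

```python
# Hedged illustration: populate a BackwardState by hand the way Dynamo's
# generated bytecode would, then dispatch to the stored hook by name using
# the helper added in torch/_dynamo/external_utils.py.
import torch
from torch._dynamo.external_utils import call_hook_from_backward_state
from torch.fx.experimental._backward_state import BackwardState

bw_state = BackwardState()
# Dynamo does this via generated bytecode after the forward runs;
# "hook0" follows the hook{N} naming used by OutputGraph.add_backward_state_hook.
setattr(bw_state, "hook0", lambda grad: grad * 2)

grad = torch.ones(4)
out = call_hook_from_backward_state(grad, bw_state=bw_state, hook_name="hook0")
assert torch.equal(out, grad * 2)
```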
@ -33,6 +33,7 @@ from torch.utils._python_dispatch import (
|
||||
TorchDispatchMode,
|
||||
)
|
||||
|
||||
from ._backward_state import BackwardState
|
||||
from .sym_node import SymNode
|
||||
from ._sym_dispatch_mode import SymDispatchMode
|
||||
from torch.fx import Proxy
|
||||
@ -42,6 +43,7 @@ from torch.utils.weak import WeakTensorKeyDictionary, WeakIdKeyDictionary, _Weak
|
||||
from torch._ops import unset_mode_pre_dispatch, _set_mode_pre_dispatch, _get_dispatch_mode_pre_dispatch
|
||||
|
||||
__all__ = ["PythonKeyTracer", "dispatch_trace", "make_fx", "DecompositionInterpreter", "py_sym_types", "get_innermost_proxy_mode"]
|
||||
|
||||
aten = torch.ops.aten
|
||||
prim = torch.ops.prim
|
||||
|
||||
@ -139,6 +141,8 @@ def extract_val(val):
|
||||
return val
|
||||
elif isinstance(val, torch.ScriptObject):
|
||||
return val
|
||||
elif isinstance(val, BackwardState):
|
||||
return val
|
||||
elif isinstance(val, (list, tuple)):
|
||||
return val.__class__([extract_val(x) for x in val])
|
||||
elif isinstance(val, torch.Tensor):
|
||||
@ -232,6 +236,9 @@ def track_tensor_tree(inner_res, proxy_res, *, constant, tracer):
|
||||
# example use case: triton_kernel_wrapper takes arguments as kwargs
|
||||
for key, val in e.items():
|
||||
wrap_with_proxy(val, proxy[key], None)
|
||||
elif isinstance(e, BackwardState):
|
||||
set_meta(proxy, e)
|
||||
e.proxy = proxy
|
||||
else:
|
||||
# intentionally pass on primitives
|
||||
pass
|
||||
|
@ -286,6 +286,8 @@ def _iterate_exprs(val: Union[SymInt, torch.Tensor]) -> Iterable[sympy.Basic]:
|
||||
raise AssertionError(f"cannot extract sympy expressions from {val} {type(val)}")
|
||||
|
||||
def free_symbols(val: Union[SymInt, torch.Tensor]) -> Set[sympy.Symbol]:
|
||||
if val is None:
|
||||
return set()
|
||||
itr = _iterate_exprs(val)
|
||||
# we need at least 1 to call union, so we hand code the identity
|
||||
try:
|
||||
|