[aotd] Support saved tensors hooks in aot_autograd (#150032)

https://github.com/pytorch/pytorch/issues/148222

Goal:

At the moment, autograd saved tensors hooks run in eager mode after the compiled forward, and they are executed at the same time for all saved tensors.
Hooks can be used to reduce the amount of memory used for saved tensors, e.g. by quantizing them or offloading them to CPU.
Running all hooks at once after the forward is suboptimal for peak memory.
A better solution is to put the hooks in the graph, as close as possible to the last usage of each saved tensor.
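For reference, here is a standard eager-mode use of the existing torch.autograd.graph.saved_tensors_hooks API for CPU offloading (the function names are illustrative, not taken from this PR):

```
import torch

def pack_to_cpu(t):
    # Called when autograd saves a tensor for backward: offload it to CPU.
    return t.to("cpu")

def unpack_from_cpu(t):
    # Called when backward needs the tensor: move it back to the GPU.
    return t.to("cuda")

x = torch.randn(1024, 1024, device="cuda", requires_grad=True)
with torch.autograd.graph.saved_tensors_hooks(pack_to_cpu, unpack_from_cpu):
    # sin() saves its input for backward, so the pack hook fires for it.
    y = (x * 2).sin().sum()
y.backward()  # the unpack hook fires here, in eager
```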

The goal of this PR is to get user-specified autograd saved tensors hooks into the graph.

Logic:

UX:
If the user specifies hooks with torch.autograd.graph.saved_tensors_hooks(pack_gm, unpack_gm), where pack_gm and unpack_gm are torch.fx.GraphModules,
then AotAutograd retraces those graph modules, applying decompositions and functionalization, and inlines the resulting graphs into the forward epilogue and the backward prologue.
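A rough sketch of the intended UX, assuming a CUDA device and a simple bf16-downcast pack/unpack pair (names are illustrative):

```
import torch

def pack_bf16(t):
    return t.to(torch.bfloat16)

def unpack_bf16(t):
    return t.to(torch.float32)

# Hooks provided as torch.fx.GraphModules are eligible for inlining into the
# compiled forward epilogue / backward prologue.
pack_gm = torch.fx.symbolic_trace(pack_bf16)
unpack_gm = torch.fx.symbolic_trace(unpack_bf16)

def fn(x):
    # relu saves its output for backward; that intermediate gets packed.
    return (x + 1).relu() * 2

x = torch.randn(8, 8, device="cuda", requires_grad=True)
compiled = torch.compile(fn, fullgraph=True)
with torch.autograd.graph.saved_tensors_hooks(pack_gm, unpack_gm):
    compiled(x).sum().backward()
```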

The user may want control logic in the hooks, for example applying quantization only for specific dtypes and sizes.

This is also possible: the user can put that logic into a torch.fx.wrap-ed function and use symbolic tracing to produce a GraphModule.

In that case AotAutograd caching only works if the user explicitly sets the "user_cache_hash" metadata on the torch.fx.wrap call_function node.

If this metadata is set, the aot_autograd cache can use the saved cache artifact.
If it is not set, the cache is bypassed.
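A sketch of that pattern, mirroring the helpers added in this PR's tests (the hash strings are arbitrary user-chosen placeholders):

```
import torch

@torch.fx.wrap
def _pack_fp8(x):
    # Control flow is allowed here because torch.fx.wrap keeps this call opaque
    # to symbolic tracing: it stays a single call_function node in the graph.
    if not x.dtype.is_floating_point:
        return x
    return (x.dtype, x.to(torch.float8_e5m2))

@torch.fx.wrap
def _unpack_fp8(packed):
    if isinstance(packed, torch.Tensor):
        return packed
    dtype, t = packed
    return t.to(dtype)

# Thin wrappers so the fx.wrap-ed functions show up as call_function nodes.
def pack_fp8(x):
    return _pack_fp8(x)

def unpack_fp8(packed):
    return _unpack_fp8(packed)

pack_gm = torch.fx.symbolic_trace(pack_fp8)
unpack_gm = torch.fx.symbolic_trace(unpack_fp8)

# Opt in to AOTAutograd caching: without "user_cache_hash" metadata on the
# wrapped call_function nodes the cache is bypassed.
for gm, cache_key in ((pack_gm, "fp8_pack_v1"), (unpack_gm, "fp8_unpack_v1")):
    for node in gm.graph.nodes:
        if node.meta and node.meta.get("is_wrapped", False):
            node.meta["user_cache_hash"] = cache_key
```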

Dynamo:
Dynamo traces the pack and unpack hooks, installs them as subgraphs, and explicitly adds them to the output_graph (those subgraphs are not used by the main graph, so by default they would not be copied into the result).

The complexity here is that at this point we do not have example inputs for the hooks, so we trace pack_hook with some Tensor taken from the inputs.
The resulting subgraphs are included in the hashing of the AotAutograd cache.

In AotAutograd we retrace the hook graphs with the true saved tensors coming from the partitioner.

Backwards Compatibility:
Since existing hooks are executed in eager mode and not all of them are traceable, we only try to put into the graph those hooks explicitly marked by the user with the @_inlineable_saved_tensors_hooks annotation.
For other hooks, or if compiled autograd is enabled, we keep the existing logic.

Recompilations:
Hooks are guarded with a lambda guard that matches the hook function ids, so rerunning the compiled function with different hooks causes a recompilation.
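For example (a rough sketch, not copied verbatim from this PR's tests): compiling the same function under two different GraphModule hook pairs should produce two compilations, because the guard on the hook ids fails on the second run.

```
import torch
from torch._dynamo.testing import CompileCounter

def fn(x):
    return (x + 1).relu() * 2

def pack_bf16(t):
    return t.to(torch.bfloat16)

def unpack_bf16(t):
    return t.to(torch.float32)

def pack_mul2(t):
    return t * 2

def unpack_mul2(t):
    return t / 2

x = torch.randn(4, 4, requires_grad=True)
cnt = CompileCounter()
compiled = torch.compile(fn, backend=cnt, fullgraph=True)

with torch.autograd.graph.saved_tensors_hooks(
    torch.fx.symbolic_trace(pack_bf16), torch.fx.symbolic_trace(unpack_bf16)
):
    compiled(x).sum().backward()

with torch.autograd.graph.saved_tensors_hooks(
    torch.fx.symbolic_trace(pack_mul2), torch.fx.symbolic_trace(unpack_mul2)
):
    compiled(x).sum().backward()

print(cnt.frame_count)  # expected: 2, one compilation per distinct hook pair
```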

Aot_autograd:
After the partitioner has prepared the forward and backward modules, we trace the pack and unpack hook graphs prepared by Dynamo and inline them into the forward epilogue and the backward prologue. Forward outputs and backward inputs change accordingly, transparently to the user.

We do not try to place the hooks close to the last usage of each tensor, relying on Inductor to do that optimization.

```
INFO: TRACED GRAPH
 ===== Forward graph pre saved_tensors_hooks inlining 3 =====
 /data/users/ivankobzarev/a/pytorch/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "Sym(s0)", primals_2: "Sym(s1)", primals_3: "f32[s0, s1][s1, 1]cuda:0"):
         # File: /data/users/ivankobzarev/a/pytorch/test/functorch/test_aotdispatch.py:6660 in simple_fn, code: x = x + 1
        add: "f32[s0, s1][s1, 1]cuda:0" = torch.ops.aten.add.Tensor(primals_3, 1);  primals_3 = None

         # File: /data/users/ivankobzarev/a/pytorch/test/functorch/test_aotdispatch.py:6661 in simple_fn, code: x = SAF.apply(x)
        view: "f32[s0, s1][s1, 1]cuda:0" = torch.ops.aten.view.default(add, [primals_1, primals_2])
        return (view, add, primals_1, primals_2)

INFO: TRACED GRAPH
 ===== Backward graph pre saved_tensors_hooks inlining 3 =====
 /data/users/ivankobzarev/a/pytorch/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "Sym(s0)", primals_2: "Sym(s1)", primals_3: "f32[s0, s1][s1, 1]cuda:0"):
         # File: /data/users/ivankobzarev/a/pytorch/test/functorch/test_aotdispatch.py:6660 in simple_fn, code: x = x + 1
        add: "f32[s0, s1][s1, 1]cuda:0" = torch.ops.aten.add.Tensor(primals_3, 1);  primals_3 = None

         # File: /data/users/ivankobzarev/a/pytorch/test/functorch/test_aotdispatch.py:6661 in simple_fn, code: x = SAF.apply(x)
        view: "f32[s0, s1][s1, 1]cuda:0" = torch.ops.aten.view.default(add, [primals_1, primals_2])
        return (view, add, primals_1, primals_2)

INFO: TRACED GRAPH
 ===== saved_tensors_pack_hook add 3 =====
 /data/users/ivankobzarev/a/pytorch/torch/fx/_lazy_graph_module.py class pack_float8(torch.nn.Module):
    def forward(self, x_1: "f32[s0, s1][s1, 1]cuda:0"):
        # No stacktrace found for following nodes
        _to_copy: "f8e4m3fn[s0, s1][s1, 1]cuda:0" = torch.ops.aten._to_copy.default(x_1, dtype = torch.float8_e4m3fn);  x_1 = None
        return (torch.float32, _to_copy)

INFO: TRACED GRAPH
 ===== saved_tensors_unpack_hook add 3 =====
 <eval_with_key>.22 from /data/users/ivankobzarev/a/pytorch/torch/fx/experimental/proxy_tensor.py:1225 in wrapped class pack_float8(torch.nn.Module):
    def forward(self, x_1: "f32[s0, s1][s1, 1]cuda:0"):
        # No stacktrace found for following nodes
        _to_copy: "f8e4m3fn[s0, s1][s1, 1]cuda:0" = torch.ops.aten._to_copy.default(x_1, dtype = torch.float8_e4m3fn);  x_1 = None
        return (torch.float32, _to_copy)

INFO: TRACED GRAPH
 ===== Forward graph 3 =====
 /data/users/ivankobzarev/a/pytorch/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "Sym(s0)", primals_2: "Sym(s1)", primals_3: "f32[s0, s1][s1, 1]cuda:0"):
         # File: /data/users/ivankobzarev/a/pytorch/test/functorch/test_aotdispatch.py:6660 in simple_fn, code: x = x + 1
        add: "f32[s0, s1][s1, 1]cuda:0" = torch.ops.aten.add.Tensor(primals_3, 1);  primals_3 = None

        # No stacktrace found for following nodes
        _to_copy: "f8e4m3fn[s0, s1][s1, 1]cuda:0" = torch.ops.aten._to_copy.default(add, dtype = torch.float8_e4m3fn)

         # File: /data/users/ivankobzarev/a/pytorch/test/functorch/test_aotdispatch.py:6661 in simple_fn, code: x = SAF.apply(x)
        view: "f32[s0, s1][s1, 1]cuda:0" = torch.ops.aten.view.default(add, [primals_1, primals_2]);  add = None
        return (view, _to_copy, primals_1, primals_2)

INFO: TRACED GRAPH
 ===== Backward graph 3 =====
 <eval_with_key>.21 class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "Sym(s0)", primals_2: "Sym(s1)", add_packed_2: "f8e4m3fn[s0, s1][s1, 1]cuda:0", tangents_1: "f32[s0, s1][s1, 1]cuda:0"):
        # No stacktrace found for following nodes
        _to_copy: "f32[s0, s1][s1, 1]cuda:0" = torch.ops.aten._to_copy.default(add_packed_2, dtype = torch.float32);  add_packed_2 = None

         # File: /data/users/ivankobzarev/a/pytorch/test/functorch/test_aotdispatch.py:6661 in simple_fn, code: x = SAF.apply(x)
        add_7: "f32[s0, s1][s1, 1]cuda:0" = torch.ops.aten.add.Tensor(tangents_1, _to_copy);  tangents_1 = _to_copy = None
        return (None, None, add_7)

```

Differential Revision: [D72187044](https://our.internmc.facebook.com/intern/diff/D72187044)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150032
Approved by: https://github.com/bdhirsh
Author: IvanKobzarev
Date: 2025-05-22 02:54:13 -07:00
Committed by: PyTorch MergeBot
Parent: f12d8d60b1
Commit: 4439255148
18 changed files with 1602 additions and 15 deletions


@ -26,9 +26,9 @@ bool SavedTensorDefaultHooks::is_enabled() {
return !tls.disabled_error_message.has_value();
}
void SavedTensorDefaultHooks::disable(const std::string& message) {
void SavedTensorDefaultHooks::disable(const std::string& message, const bool fail_if_non_empty) {
tls.disabled_error_message = message;
if (!tls.stack.empty()) {
if (fail_if_non_empty && !tls.stack.empty()) {
assertSavedTensorHooksNotDisabled();
}
}
@ -72,9 +72,9 @@ std::pair<SafePyObject, SafePyObject> SavedTensorDefaultHooks::pop_hooks() {
return hooks;
}
std::optional<std::pair<SafePyObject, SafePyObject>> SavedTensorDefaultHooks::get_hooks() {
std::optional<std::pair<SafePyObject, SafePyObject>> SavedTensorDefaultHooks::get_hooks(bool ignore_is_tracing) {
// For tls.is_tracing, see NOTE: [Deferring tensor pack/unpack hooks until runtime]
if (!is_initialized || tls.stack.empty() || tls.is_tracing) {
if (!is_initialized || tls.stack.empty() || (!ignore_is_tracing && tls.is_tracing)) {
return std::nullopt;
}
return tls.stack.top();


@ -36,7 +36,7 @@ struct TORCH_API SavedTensorDefaultHooks {
c10::SafePyObject unpack_hook);
static std::pair<c10::SafePyObject, c10::SafePyObject> pop_hooks();
static std::optional<std::pair<c10::SafePyObject, c10::SafePyObject>>
get_hooks();
get_hooks(bool ignore_is_tracing = false);
static void lazy_initialize();
static const impl::SavedTensorDefaultHooksTLS& get_tls_state();
@ -48,7 +48,9 @@ struct TORCH_API SavedTensorDefaultHooks {
// disabled, then the following will raise an error:
// - Attempting to push_hooks
// - calling disable(message) with a non-zero stack (hooks) size
static void disable(const std::string& error_message);
static void disable(
const std::string& error_message,
const bool fail_if_non_empty = true);
static void enable();
static bool is_enabled();
static const std::optional<std::string>& get_disabled_error_message();


@ -40,6 +40,76 @@ from torch.testing._internal.triton_utils import requires_cuda
from torch.testing._internal.two_tensor import TwoTensor
def saved_tensors_hooks_to_gm(
pack_fn,
unpack_fn,
pack_cache_hash=None,
unpack_cache_hash=None,
symbolic_tracing=True,
inp_fn=None,
):
if symbolic_tracing:
pack_gm = torch.fx.symbolic_trace(pack_fn)
unpack_gm = torch.fx.symbolic_trace(unpack_fn)
else:
from functorch import make_fx
if inp_fn:
inp = inp_fn()
else:
inp = torch.randn(2, 3)
torch._dynamo.mark_dynamic(inp, 0)
torch._dynamo.mark_dynamic(inp, 1)
pack_out = pack_fn(inp)
pack_gm = make_fx(pack_fn)(inp)
unpack_gm = make_fx(unpack_fn)(pack_out)
def set_manual_hash(g, manual_hash):
for node in g.nodes:
if node.meta and node.meta.get("is_wrapped", False):
node.meta["user_cache_hash"] = manual_hash
if pack_cache_hash:
set_manual_hash(pack_gm.graph, pack_cache_hash)
if unpack_cache_hash:
set_manual_hash(unpack_gm.graph, unpack_cache_hash)
return pack_gm, unpack_gm
def amax_to_scale(
amax: torch.Tensor,
float8_dtype: torch.dtype,
round_scales_to_power_of_2: bool = False,
):
amax = amax.to(torch.float64)
res = torch.finfo(float8_dtype).max / torch.clamp(amax, min=1e-12)
res = res.to(torch.float32)
return res
# Must be at module level to use fx.wrap
@torch.fx.wrap
def _pack_fp8_with_scale_wrap(x):
if not x.dtype.is_floating_point:
return x
amax = torch.max(torch.abs(x))
scale = amax_to_scale(amax, torch.float8_e5m2)
x_scaled = x.to(torch.float32) * scale
x_fp8 = x_scaled.to(torch.float8_e5m2)
return x.dtype, scale, x_fp8
@torch.fx.wrap
def _unpack_fp8_with_scale_wrap(x):
if isinstance(x, torch.Tensor):
return x
dtype, scale, x_fp8 = x
y = x_fp8.to(torch.float32) / scale
return y.to(dtype)
@instantiate_parametrized_tests
class AOTAutogradCacheTests(InductorTestCase):
def setUp(self):
@ -1099,6 +1169,188 @@ class AOTAutogradCacheTests(InductorTestCase):
self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 1)
self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1)
@unittest.skipIf(not torch.cuda.is_available(), "CUDA is unavailable")
@unittest.skipIf(not SM80OrLater, "bfloat16, float8")
@inductor_config.patch("fx_graph_remote_cache", False)
@inductor_config.patch("fx_graph_cache", True)
@functorch_config.patch({"enable_autograd_cache": True})
@functorch_config.patch({"activation_memory_budget": 1.0})
@functorch_config.patch({"activation_memory_budget_runtime_estimator": "testing"})
@functorch_config.patch({"saved_tensors_hooks_filtering_mode": "all"})
def test_saved_tensors_hooks_autograd_cache(self):
ctx = torch.autograd.graph.saved_tensors_hooks
device = torch.device("cuda:0")
def pack_cpu(x):
return x.to(device="cpu")
def unpack_cpu(x):
return x.to(device=device)
def pack_cpu2(x):
return x.to(device="cpu")
def unpack_cpu2(x):
return x.to(device=device)
def pack_mul2(x):
return x * 2
def unpack_mul2(x):
return x / 2
# Can not use custom AutogradFunction here,
# Cache bypasses AutogradFunction Ctx usage.
# Can not save in ctx non floating point dtypes.
# For non-symbolic tracing all dtypes and devices and burned in the graph.
def fn(x):
x = x + 1
x = x.sin().cos()
x = x.relu()
x = x.exp()
x = 2 * x
return x
backend = "inductor"
def inp_fn():
x = torch.ones(2, 3, device=device, requires_grad=True)
torch._dynamo.mark_dynamic(x, 0)
torch._dynamo.mark_dynamic(x, 1)
return x
x = inp_fn()
fn_compiled = torch.compile(fn, backend=backend, fullgraph=True)
y = fn_compiled(x)
y.sum().backward()
self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1)
self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1)
with ctx(
*saved_tensors_hooks_to_gm(
pack_cpu,
unpack_cpu,
symbolic_tracing=False,
inp_fn=inp_fn,
pack_cache_hash="cpu_offload",
unpack_cache_hash="cpu_offload",
)
):
x = inp_fn()
y = fn_compiled(x)
y.sum().backward()
self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 2)
self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 2)
with ctx(
*saved_tensors_hooks_to_gm(
pack_cpu2,
unpack_cpu2,
symbolic_tracing=False,
inp_fn=inp_fn,
pack_cache_hash="cpu_offload",
unpack_cache_hash="cpu_offload",
)
):
x = inp_fn()
y = fn_compiled(x)
y.sum().backward()
self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 1)
self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 2)
self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 2)
with ctx(
*saved_tensors_hooks_to_gm(pack_mul2, unpack_mul2, symbolic_tracing=False)
):
x = inp_fn()
y = fn_compiled(x)
y.sum().backward()
self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 1)
self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 3)
self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 3)
@unittest.skipIf(not torch.cuda.is_available(), "CUDA is unavailable")
@unittest.skipIf(not SM80OrLater, "bfloat16, float8")
@inductor_config.patch("fx_graph_remote_cache", False)
@inductor_config.patch("fx_graph_cache", True)
@functorch_config.patch({"enable_autograd_cache": True})
def test_saved_tensors_hooks_autograd_cache_symbolic(self):
def pack_fp8_with_scale(x):
return _pack_fp8_with_scale_wrap(x)
def unpack_fp8_with_scale(packed):
return _unpack_fp8_with_scale_wrap(packed)
ctx = torch.autograd.graph.saved_tensors_hooks
def fn(x):
x = x + 1
# Relu saves bitmask in AutogradContext
x = x.relu()
x = x.relu()
return x
device = torch.device("cuda:0")
backend = "inductor"
def inp_fn():
x = torch.ones(2, 3, device=device, requires_grad=True)
torch._dynamo.mark_dynamic(x, 0)
torch._dynamo.mark_dynamic(x, 1)
return x
x = inp_fn()
fn_compiled = torch.compile(fn, backend=backend, fullgraph=True)
y = fn_compiled(x)
y.sum().backward()
self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1)
self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1)
with ctx(
*saved_tensors_hooks_to_gm(
pack_fp8_with_scale,
unpack_fp8_with_scale,
"fp8_with_scale_dtype_floating_point",
"fp8_with_scale_dtype_floating_point",
)
):
x = inp_fn()
y = fn_compiled(x)
y.sum().backward()
self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 2)
self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 2)
with ctx(
*saved_tensors_hooks_to_gm(
pack_fp8_with_scale,
unpack_fp8_with_scale,
"fp8_with_scale_dtype_floating_point",
"fp8_with_scale_dtype_floating_point",
)
):
x = inp_fn()
y = fn_compiled(x)
y.sum().backward()
self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 1)
self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 2)
self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 2)
with ctx(
*saved_tensors_hooks_to_gm(
pack_fp8_with_scale,
unpack_fp8_with_scale,
"fp8_with_scale_dtype_floating_point_MISS",
"fp8_with_scale_dtype_floating_point_MISS",
)
):
x = inp_fn()
y = fn_compiled(x)
y.sum().backward()
self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 1)
self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 3)
@functorch_config.patch({"bundled_autograd_cache": True})
class AOTAutogradCacheBundledTests(AOTAutogradCacheTests):


@ -309,6 +309,13 @@ y = FakeTensor(..., size=(2,))
'obj_weakref': None
'guarded_class': None
}
global '' AUTOGRAD_SAVED_TENSORS_HOOKS
{
'guard_types': None,
'code': None,
'obj_weakref': None
'guarded_class': None
}
global '' GRAD_MODE
{
'guard_types': None,


@ -710,6 +710,7 @@ TRACE FX call mul from test_logging.py:N in fn (LoggingTests.test_trace_call_pre
self.assertExpectedInline(
munge_shape_guards(record.getMessage()),
"""\
| +- __SHAPE_GUARD__: torch._functorch.aot_autograd.utils.top_saved_tensors_hooks ids == None # #:# in #
+- __SHAPE_GUARD__: L['x'].size()[0] == 2*L['z'].size()[0] # return x + torch.cat([y, z]) # #:# in # #:# in #
+- __SHAPE_GUARD__: L['y'].size()[0] == L['z'].size()[0] # duck sizing added this equality because these variables had the same size 3 (to avoid this specialization, set torch.fx.experimental._config.use_duck_shape = False)
+- __SHAPE_GUARD__: ((2*L['z'].size()[0]) % 3) == 0 # if x.size(0) % 3 == 0: # #:# in # #:# in #
@ -728,6 +729,7 @@ TRACE FX call mul from test_logging.py:N in fn (LoggingTests.test_trace_call_pre
self.assertExpectedInline(
munge_shape_guards(record.getMessage()),
"""\
| +- __SHAPE_GUARD__: torch._functorch.aot_autograd.utils.top_saved_tensors_hooks ids == None # #:# in #
+- __SHAPE_GUARD__: L['x'].size()[0] == 2*L['y'].size()[0] # return any([x.size(0) == y.size(0) * 2]) # #:# in # #:# in #
+- __SHAPE_GUARD__: 2 <= L['y'].size()[0] # return any([x.size(0) == y.size(0) * 2]) # #:# in # (user code shown is first use of this value--the guard itself is not due user code but due to 0/1 specialization in the framework; to avoid specialization try torch._dynamo.mark_unbacked(tensor, dim))""", # noqa: B950
)
@ -747,6 +749,7 @@ TRACE FX call mul from test_logging.py:N in fn (LoggingTests.test_trace_call_pre
self.assertExpectedInline(
munge_shape_guards(record.getMessage()),
"""\
| +- __SHAPE_GUARD__: torch._functorch.aot_autograd.utils.top_saved_tensors_hooks ids == None # #:# in #
+- __SHAPE_GUARD__: L['x'].size()[0] == 2*L['y'].size()[0] # torch._check(x.size(0) == y.size(0) * 2) # #:# in # #:# in #
+- __SHAPE_GUARD__: 3 <= L['y'].size()[0] <= 14 # torch._check(x.size(0) > 5) # #:# in # #:# in # and torch._check(x.size(0) < 30) # #:# in # #:# in #""", # noqa: B950
)


@ -330,7 +330,7 @@ class TestDynamoTimed(TestCase):
'graph_input_count': 1,
'graph_node_count': 3,
'graph_op_count': 1,
'guard_count': 8,
'guard_count': 9,
'has_guarded_code': True,
'inductor_code_gen_cumulative_compile_time_us': 0,
'inductor_compile_time_s': 0.0,


@ -614,3 +614,29 @@ def check_vmap_fallback(test_case, thunk, opinfo, dry_run=False):
print(f"xfail('{opinfo.name}', '{opinfo.variant_test_name}'),")
else:
print(f"xfail('{opinfo.name}'),")
def saved_tensors_hooks_to_gm(
pack_fn, unpack_fn, pack_cache_hash, unpack_cache_hash, symbolic_tracing=True
):
if symbolic_tracing:
pack_gm = torch.fx.symbolic_trace(pack_fn)
unpack_gm = torch.fx.symbolic_trace(unpack_fn)
else:
from torch.functorch import make_fx
inp = torch.randn(2, 3)
torch._dynamo.mark_dynamic(inp, 0)
torch._dynamo.mark_dynamic(inp, 1)
pack_out = pack_fn(inp)
pack_gm = make_fx(pack_fn)(inp)
unpack_gm = make_fx(unpack_fn)(pack_out)
def set_manual_hash(g, manual_hash):
node = next(iter(g.nodes))
node.meta["user_cache_hash"] = manual_hash
set_manual_hash(pack_gm.graph, pack_cache_hash)
set_manual_hash(unpack_gm.graph, unpack_cache_hash)
return pack_gm, unpack_gm


@ -10,12 +10,19 @@ import copy
import itertools
import unittest
import warnings
from contextlib import ContextDecorator, nullcontext
from contextlib import ContextDecorator, ExitStack, nullcontext
from functools import partial, wraps
from typing import Any, Callable, Optional, Union
from unittest.mock import patch
from common_utils import decorate, decorateForModules, skip, skipOps, xfail
from common_utils import (
decorate,
decorateForModules,
saved_tensors_hooks_to_gm,
skip,
skipOps,
xfail,
)
import torch
import torch._dynamo as torchdynamo
@ -56,6 +63,8 @@ from torch.fx.experimental.proxy_tensor import is_sym_node
from torch.fx.experimental.symbolic_shapes import GuardOnDataDependentSymNode, ShapeEnv
from torch.nn.attention.flex_attention import flex_attention
from torch.nn.utils.rnn import PackedSequence
from torch.testing import FileCheck
from torch.testing._internal.common_cuda import SM80OrLater
from torch.testing._internal.common_device_type import (
instantiate_device_type_tests,
ops,
@ -114,6 +123,73 @@ except ImportError:
# NB: numpy is a testing dependency!
def amax_to_scale(
amax: torch.Tensor,
float8_dtype: torch.dtype,
round_scales_to_power_of_2: bool = False,
):
amax = amax.to(torch.float64)
res = torch.finfo(float8_dtype).max / torch.clamp(amax, min=1e-12)
res = res.to(torch.float32)
return res
# Must be at module level to use fx.wrap
@torch.fx.wrap
def _pack_fp8_with_scale_wrap(x):
if not x.dtype.is_floating_point:
return x
amax = torch.max(torch.abs(x))
scale = amax_to_scale(amax, torch.float8_e5m2)
x_scaled = x.to(torch.float32) * scale
x_fp8 = x_scaled.to(torch.float8_e5m2)
return x.dtype, scale, x_fp8
@torch.fx.wrap
def _unpack_fp8_with_scale_wrap(x):
if isinstance(x, torch.Tensor):
return x
dtype, scale, x_fp8 = x
y = x_fp8.to(torch.float32) / scale
return y.to(dtype)
@torch.fx.wrap
def _pack_fp8_wrap(x):
if not x.dtype.is_floating_point:
return x
return (x.dtype, x.to(torch.float8_e5m2))
@torch.fx.wrap
def _unpack_fp8_wrap(x):
if isinstance(x, torch.Tensor):
return x
dtype, tensor = x
return tensor.to(dtype)
def pack_fp8(x):
return _pack_fp8_wrap(x)
def unpack_fp8(packed):
return _unpack_fp8_wrap(packed)
def pack_fp8_with_scale(x):
return _pack_fp8_with_scale_wrap(x)
def unpack_fp8_with_scale(packed):
return _unpack_fp8_with_scale_wrap(packed)
class AOTTestCase(TestCase):
pass
@ -4061,6 +4137,49 @@ def forward(self, tangents_1):
counters.clear()
torch._dynamo.reset()
@unittest.skipIf(not torch.cuda.is_available(), "CUDA is unavailable")
@torch._functorch.config.patch(saved_tensors_hooks_filtering_mode="no_static")
@torch._functorch.config.patch(recompute_views=True)
def test_saved_tensors_hooks_mutations_raise(self):
ctx = torch.autograd.graph.saved_tensors_hooks
device = "cuda"
class SAF(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
ctx.save_for_backward(x)
return x
@staticmethod
def backward(ctx, gx):
(saved_x,) = ctx.saved_tensors
return gx + saved_x
def mutate(x):
return x.mul_(2)
def fn(x):
x = 2 * x
x = SAF.apply(x)
return x
def inp_fn():
x = torch.ones(2, 3, device=device, requires_grad=True)
torch._dynamo.mark_dynamic(x, 0)
torch._dynamo.mark_dynamic(x, 1)
return x
with self.assertRaisesRegex(
AssertionError, "Saved tensors hooks with inputs mutations are not allowed"
):
try:
with ctx(*saved_tensors_hooks_to_gm(mutate, mutate, None, None)):
x = inp_fn()
y = torch.compile(fn, backend="aot_eager", fullgraph=True)(x)
y.sum().backward()
except torch._dynamo.exc.BackendCompilerFailed as e:
raise e.inner_exception from e
def test_mark_activations_dynamic_with_nested(self):
# The flattened tensors of the nested tensor aren't
# marked as activations, but they add some offset
@ -6775,6 +6894,489 @@ metadata incorrectly.
self.assertEqual(1, len(ctx.tangent_strides))
self.assertEqual((128, 4, 16, 1), ctx.tangent_strides[0])
def _test_pack_hooks(
self,
fn,
inp_fn,
hooks,
symbolic_tracing=True,
pre_compile_fn=None,
backend="inductor",
):
ctx = torch.autograd.graph.saved_tensors_hooks
torch._dynamo.reset()
with ExitStack() as stack:
# All hooks in eager to get ref
for hook, _ in hooks:
pack, unpack = hook
stack.enter_context(ctx(pack, unpack))
ref_x = inp_fn()
def _f(t):
if t.dtype.is_floating_point:
return t.detach().clone().requires_grad_()
return t
x = pytree.tree_map_only(torch.Tensor, _f, ref_x)
ref_y = fn(*ref_x)
ref_y.sum().backward()
if pre_compile_fn:
pre_compile_fn()
with ExitStack() as stack:
for hook, inline in hooks:
pack, unpack = hook
if inline:
if symbolic_tracing:
stack.enter_context(
ctx(
*saved_tensors_hooks_to_gm(
pack,
unpack,
"pack_hash",
"unpack_hash",
)
)
)
else:
stack.enter_context(
ctx(
*saved_tensors_hooks_to_gm(
pack, unpack, "pack_hash", "unpack_hash"
)
)
)
else:
stack.enter_context(ctx(pack, unpack))
y = torch.compile(fn, backend=backend, fullgraph=True)(*x)
y.sum().backward()
self.assertEqual(ref_y, y, atol=1e-2, rtol=1e-2)
ref_x_grad = pytree.tree_map_only(torch.Tensor, lambda t: t.grad, ref_x)
x_grad = pytree.tree_map_only(torch.Tensor, lambda t: t.grad, x)
self.assertEqual(ref_x_grad, x_grad, atol=1e-2, rtol=1e-2)
@unittest.skipIf(not torch.cuda.is_available(), "CUDA is unavailable")
@unittest.skipIf(not SM80OrLater, "bfloat16, float8")
@parametrize("saved_tensors_hooks_filtering_mode", ["donated", "no_static", "all"])
def test_saved_tensors_hooks_base(self, saved_tensors_hooks_filtering_mode):
with patch(
"torch._functorch.config.saved_tensors_hooks_filtering_mode",
saved_tensors_hooks_filtering_mode,
):
# y argument is expected to test saving of int tensor,
# to check filtering functionality to not apply hooks for e.g. is_floating_point
class SAF(torch.autograd.Function):
@staticmethod
def forward(ctx, x, y):
ctx.save_for_backward(x, y)
return x
@staticmethod
def backward(ctx, gx):
(saved_x, saved_y) = ctx.saved_tensors
return gx + saved_x + saved_y, None
class AF(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
ctx.save_for_backward(x)
ctx.d1 = x.size(1)
return x
@staticmethod
def backward(ctx, gx):
(saved_x,) = ctx.saved_tensors
d1 = ctx.d1
return gx + saved_x * d1
def fn(x, y):
x = x.relu()
x = x + 1
x = x.relu()
x = 2 * x
x = AF.apply(x)
return x
def simple_fn(x, y):
x = x + 1
x = x.t()
x = x.relu()
x = x.t()
x = SAF.apply(x, y)
return x
device = torch.device("cuda:0")
def inp_fn():
x = torch.ones(2, 2, device=device, requires_grad=True)
torch._dynamo.mark_dynamic(x, 0)
torch._dynamo.mark_dynamic(x, 1)
y = torch.zeros(2, 2, device=device, dtype=torch.int64)
return x, y
def pack_dev_sym_cpu(x):
return x.dtype, x.device, x.size(1), x.cpu()
def unpack_dev_sym_cpu(packed):
dtype, device, dim1, x = packed
x = x.to(device=device)
return x.to(dtype)
def pack_tensor(x):
return x.device, x.cpu()
def unpack_tensor(packed):
device, t_cpu = packed
return t_cpu.to(device)
def pack_bf16(x):
return x.dtype, x.to(dtype=torch.bfloat16)
def unpack_bf16(packed):
dtype, x = packed
return x.to(dtype)
def pack_mul2(x):
return x.dtype, x * 2
def unpack_mul2(x):
dtype, x = x
x = x / 2
return x.to(dtype)
def pack_wrapper_sc(x):
return WrapperSubclass(x)
def unpack_wrapper_sc(x):
return x.a
def pack_wrapper_two_tensor(x):
return TwoTensor(x, x)
def unpack_wrapper_two_tensor(x):
return x.a + x.b
def pack_mul2_eager(x):
return x * 2
def unpack_mul2_eager(x):
return x / 2
def pack_cpu(x):
return x.to(device="cpu")
def unpack_cpu(x):
return x.to(device=device)
for test_fn in [simple_fn, fn]:
self._test_pack_hooks(
test_fn,
inp_fn,
[((pack_cpu, unpack_cpu), True)],
symbolic_tracing=False,
)
self._test_pack_hooks(
test_fn, inp_fn, [((pack_bf16, unpack_bf16), True)]
)
self._test_pack_hooks(
test_fn, inp_fn, [((pack_mul2, unpack_mul2), True)]
)
self._test_pack_hooks(
test_fn, inp_fn, [((pack_tensor, unpack_tensor), True)]
)
self._test_pack_hooks(
test_fn, inp_fn, [((pack_dev_sym_cpu, unpack_dev_sym_cpu), True)]
)
self._test_pack_hooks(
test_fn, inp_fn, [((pack_mul2_eager, unpack_mul2_eager), False)]
)
self._test_pack_hooks(
test_fn,
inp_fn,
[((pack_fp8, unpack_fp8), True)],
)
self._test_pack_hooks(
test_fn,
inp_fn,
[((pack_fp8_with_scale, unpack_fp8_with_scale), True)],
)
# Disable testing of Subclasses for now
# self._test_pack_hooks(test_fn, inp_fn, [(pack_wrapper_sc, unpack_wrapper_sc)])
# self._test_pack_hooks(
# test_fn, inp_fn, [(pack_wrapper_two_tensor, unpack_wrapper_two_tensor)]
# )
@unittest.skipIf(not torch.cuda.is_available(), "CUDA is unavailable")
@unittest.skipIf(not SM80OrLater, "bfloat16, float8")
def test_saved_tensors_hooks_params(self):
lib = torch.library.Library("_test_aotdispatch_lib", "FRAGMENT")
logged_shapes = []
logged_dtypes = []
lib.define("log(Tensor x) -> Tensor")
def log_impl(x):
logged_shapes.append(list(x.shape))
logged_dtypes.append(x.dtype)
return x.clone()
def log_meta(x):
return x.clone()
for backend in ["CPU", "CUDA"]:
lib.impl(
"log",
log_impl,
backend,
)
lib.impl("log", log_meta, "Meta")
def pack_fp8_with_scale_and_log(x):
torch.ops._test_aotdispatch_lib.log(x)
return _pack_fp8_with_scale_wrap(x)
def unpack_fp8_with_scale_and_log(packed):
return _unpack_fp8_with_scale_wrap(packed)
def m_inp_fn():
x = torch.ones(
2, 2, 2, device=device, dtype=torch.float64, requires_grad=True
)
torch._dynamo.mark_dynamic(x, 0)
torch._dynamo.mark_dynamic(x, 1)
return (x,)
class SAF0(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
ctx.save_for_backward(x)
return x
@staticmethod
def backward(ctx, gx):
(saved_x,) = ctx.saved_tensors
return gx + saved_x
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(2, 2)
self.relu = nn.ReLU()
self.fc2 = nn.Linear(2, 2)
def forward(self, x):
x = SAF0.apply(x)
x = x.to(dtype=torch.float32)
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
return x
def _reset_logged():
logged_shapes.clear()
logged_dtypes.clear()
device = torch.device("cuda:0")
m = M().to(device=device)
def _test_m():
self._test_pack_hooks(
m,
m_inp_fn,
[
(
(
pack_fp8_with_scale_and_log,
unpack_fp8_with_scale_and_log,
),
True,
)
],
pre_compile_fn=_reset_logged,
backend="aot_eager",
)
with patch(
"torch._functorch.config.saved_tensors_hooks_filtering_mode", "donated"
):
_reset_logged()
_test_m()
# Check that hooks were not applied to Parameters
# parameters excluded
self.assertFalse([2, 2] in logged_shapes)
self.assertTrue([2, 2, 2] in logged_shapes)
# input excluded
self.assertFalse(torch.float64 in logged_dtypes)
with patch(
"torch._functorch.config.saved_tensors_hooks_filtering_mode", "no_static"
):
_reset_logged()
_test_m()
# Check that hooks were not applied to Parameters
# parameters excluded
self.assertFalse([2, 2] in logged_shapes)
self.assertTrue([2, 2, 2] in logged_shapes)
self.assertTrue(torch.float64 in logged_dtypes)
with patch("torch._functorch.config.saved_tensors_hooks_filtering_mode", "all"):
_reset_logged()
_test_m()
# Check that hooks were applied to all saved tensors
self.assertTrue([2, 2] in logged_shapes)
self.assertTrue([2, 2, 2] in logged_shapes)
self.assertTrue(torch.float64 in logged_dtypes)
@unittest.skipIf(not torch.cuda.is_available(), "CUDA is unavailable")
@unittest.skipIf(not SM80OrLater, "bfloat16, float8")
@torch._functorch.config.patch(saved_tensors_hooks_filtering_mode="all")
def test_saved_tensors_hooks_recompile(self):
ctx = torch.autograd.graph.saved_tensors_hooks
def pack_bf16(x):
return x.to(dtype=torch.bfloat16)
def unpack_bf16(x):
return x.to(dtype=torch.float)
def pack_mul2(x):
return x * 2
def unpack_mul2(x):
return x / 2
def _test(hooks, inline, expected_compile_count):
class SAF(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
ctx.save_for_backward(x)
return x
@staticmethod
def backward(ctx, gx):
(saved_x,) = ctx.saved_tensors
return gx + saved_x
class AF(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
ctx.save_for_backward(x)
ctx.d1 = x.size(1)
return x
@staticmethod
def backward(ctx, gx):
(saved_x,) = ctx.saved_tensors
d1 = ctx.d1
return gx + saved_x * d1
def fn(x):
x = x.relu()
x = x + 1
x = 2 * x
x = AF.apply(x)
return x
device = torch.device("cuda:0")
def inp_fn():
x = torch.ones(2, 3, device=device, requires_grad=True)
torch._dynamo.mark_dynamic(x, 0)
torch._dynamo.mark_dynamic(x, 1)
return x
from torch._dynamo.testing import CompileCounter
cnt = CompileCounter()
x = inp_fn()
y = torch.compile(fn, backend=cnt, fullgraph=True)(x)
y.sum().backward()
def _test_with_hooks(hooks):
with ExitStack() as stack:
pack, unpack = hooks
if inline:
stack.enter_context(
ctx(
*saved_tensors_hooks_to_gm(
pack, unpack, "pack_hash", "unpack_hash"
)
)
)
else:
stack.enter_context(ctx(pack, unpack))
x = inp_fn()
y = torch.compile(fn, backend=cnt, fullgraph=True)(x)
y.sum().backward()
_test_with_hooks(hooks[0])
_test_with_hooks(hooks[1])
self.assertEqual(cnt.frame_count, expected_compile_count)
_test(
((pack_bf16, unpack_bf16), (pack_mul2, unpack_mul2)),
inline=False,
expected_compile_count=1,
)
_test(
((pack_bf16, unpack_bf16), (pack_mul2, unpack_mul2)),
inline=True,
expected_compile_count=3,
)
@torch._functorch.config.patch(donated_buffer=True)
@torch._functorch.config.patch(saved_tensors_hooks_filtering_mode="no_static")
def test_saved_tensors_hooks_donated_buffers(self):
pack_gm, unpack_gm = saved_tensors_hooks_to_gm(
pack_fp8,
unpack_fp8,
"pack_hash",
"unpack_hash",
)
logger_name = "torch._functorch._aot_autograd.jit_compile_runtime_wrappers"
class SAF(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
ctx.save_for_backward(x)
return x
@staticmethod
def backward(ctx, gx):
(saved_x,) = ctx.saved_tensors
return gx + saved_x
def fn(x):
x0 = x
x = SAF.apply(x)
return x0, torch.nn.functional.relu(x)
inp = torch.rand([3, 3], requires_grad=True)
# 1. No donated buffers without hooks, as relu saves input which is also user output.
with self.assertLogs(logger_name, level="INFO") as captured:
out = torch.compile(fn, backend="aot_eager", fullgraph=True, dynamic=False)(
inp
)
out[1].sum().backward()
expected_msg = "bw_donated_idxs=[]"
FileCheck().check(expected_msg).run("\n".join(captured.output))
# 2. Hooks applied for all saved, as we set saved_tensors_hooks_no_filtering=True
# Results of the hooks become donated buffers.
inp = torch.rand([3, 3], requires_grad=True)
with torch.autograd.graph.saved_tensors_hooks(pack_gm, unpack_gm):
with self.assertLogs(logger_name, level="INFO") as captured:
out = torch.compile(
fn, backend="aot_eager", fullgraph=True, dynamic=False
)(inp)
out[1].sum().backward()
expected_msg = "bw_donated_idxs=[0, 1]"
FileCheck().check(expected_msg).run("\n".join(captured.output))
# entries in here don't work and need to be fixed.
# Each one of these is a bug (or needs to be investigated)


@ -116,6 +116,9 @@ def _push_saved_tensors_default_hooks(
unpack_hook: Callable[[Any], torch.Tensor],
) -> None: ...
def _pop_saved_tensors_default_hooks() -> None: ...
def _top_saved_tensors_default_hooks(
ignore_is_tracing: bool,
) -> tuple[Callable[[torch.Tensor], Any], Callable[[Any], torch.Tensor]]: ...
def _unsafe_set_version_counter(
t: tuple[torch.Tensor, ...], prev_version: tuple[int, ...]
) -> None: ...
@ -123,7 +126,7 @@ def _enable_profiler_legacy(config: ProfilerConfig) -> None: ...
def _disable_profiler_legacy() -> list[list[ProfilerEvent]]: ...
def _profiler_type() -> ActiveProfilerType: ...
def _saved_tensors_hooks_enable() -> None: ...
def _saved_tensors_hooks_disable(message: str) -> None: ...
def _saved_tensors_hooks_disable(message: str, fail_if_non_empty=True) -> None: ...
def _saved_tensors_hooks_get_disabled_error_message() -> str | None: ...
def _saved_tensors_hooks_set_tracing(is_tracing: bool) -> bool: ...


@ -1614,6 +1614,33 @@ class GuardBuilder(GuardBuilderBase):
fn, get_verbose_code_parts(code, guard)
)
def AUTOGRAD_SAVED_TENSORS_HOOKS(self, guard: Guard):
get_hooks = torch._functorch._aot_autograd.utils.top_saved_tensors_hooks
are_inline_hooks = (
torch._functorch._aot_autograd.utils.saved_tensors_hooks_are_inlineable
)
def hooks_ids_fn(hooks):
if not are_inline_hooks(hooks):
return None
pack_hook, unpack_hook = hooks
return tuple(map(id, hooks))
guard_hooks_ids = hooks_ids_fn(get_hooks())
code = [
f"torch._functorch.aot_autograd.utils.top_saved_tensors_hooks ids == {guard_hooks_ids}"
]
self._set_guard_export_info(guard, code)
def fn(x):
return guard_hooks_ids == hooks_ids_fn(get_hooks())
self.guard_manager.root.add_lambda_guard(
fn, get_verbose_code_parts(code, guard)
)
def TENSOR_SUBCLASS_METADATA_MATCH(self, guard: Guard):
value = self.get(guard.name)
original_metadata = deepcopy(self.get(guard.name).__tensor_flatten__()[1])


@ -519,6 +519,14 @@ class OutputGraph(OutputGraphGuardsState):
self.compiler_trace_stack = contextlib.ExitStack()
# These are the ambient, currently-global saved_tensor_hooks stashed in autograd,
# that are set for the entire duration of the compiled region.
# This is an invariant today because we graph break on the saved_tensor_hook
# context manager inside a compiled region
self.saved_tensors_hooks_subgraph_names: Optional[list[str]] = (
self.maybe_install_saved_tensors_hooks_subgraphs()
)
def mark_bytecode_tracing_start(self):
self.compiler_trace_stack.enter_context(
dynamo_timed(
@ -598,6 +606,41 @@ class OutputGraph(OutputGraphGuardsState):
self.guards.add(
GlobalStateSource().make_guard(GuardBuilder.FUNCTORCH_STACK_MATCH)
)
if not torch._dynamo.compiled_autograd.in_compiled_autograd_region:
self.guards.add(
GlobalStateSource().make_guard(
GuardBuilder.AUTOGRAD_SAVED_TENSORS_HOOKS
)
)
def maybe_install_saved_tensors_hooks_subgraphs(self) -> Optional[list[str]]:
if torch._dynamo.compiled_autograd.in_compiled_autograd_region:
return None
get_hooks = torch._functorch._aot_autograd.utils.top_saved_tensors_hooks
are_inline_hooks = (
torch._functorch._aot_autograd.utils.saved_tensors_hooks_are_inlineable
)
hooks = get_hooks()
if not are_inline_hooks(hooks):
return None
# If GraphModule provided by user contains fx.wrap,
# We can only rely on user provided cache hash in this case.
# If user did not provide cache hash - then we always bypass cache.
pack_gm, unpack_gm = hooks
pack_subgraph_name = self.install_subgraph(
"saved_tensors_hooks_pack",
torch.fx.GraphModule(self.nn_modules, pack_gm.graph),
)
unpack_subgraph_name = self.install_subgraph(
"saved_tensors_hooks_unpack",
torch.fx.GraphModule(self.nn_modules, unpack_gm.graph),
)
assert pack_subgraph_name == "saved_tensors_hooks_pack_0"
assert unpack_subgraph_name == "saved_tensors_hooks_unpack_0"
return [pack_subgraph_name, unpack_subgraph_name]
def dump_guards_state(self):
return OutputGraphGuardsState(
@ -854,7 +897,7 @@ class OutputGraph(OutputGraphGuardsState):
*names,
**options,
):
if is_dynamic_nn_module(target, self.root_tx.export):
if is_dynamic_nn_module(target, self.export):
# Instead of returning UnspecializedNNModuleVariable, call
# VariableTracker.build so that it is tracked for mutation.
return VariableTracker.build(self.current_tx, target, **options)
@ -1484,6 +1527,14 @@ class OutputGraph(OutputGraphGuardsState):
self.real_value_cache.clear()
gm = _make_graph_module(root, self.graph)
# Saved tensors hooks are not used by the graph.
# GraphModule by default only copies used in the graph submodules.
# Copying them into the result graph manually.
if self.saved_tensors_hooks_subgraph_names:
for subgraph_name in self.saved_tensors_hooks_subgraph_names:
setattr(gm, subgraph_name, getattr(root, subgraph_name))
for register_finalizer in self.register_finalizer_fns:
register_finalizer(gm)


@ -202,6 +202,13 @@ def check_node_safe(node: Node):
# I'd love to use a match statement here, but it wasn't introduced until py3.10
if node.op == "call_function":
if node.meta and node.meta.get("is_wrapped", False):
# This is fx.wrap function
# By default we BypassAOTAutogradCache for unknown functions,
# But if user explicitly specified cache hash - allow to cache it.
if node.meta.get("user_cache_hash", None):
return
if not is_cacheable_function(node.target):
module = getattr(node.target, "__module__", None)
name = getattr(node.target, "__name__", None)
@ -259,6 +266,15 @@ def check_cacheable(gm: torch.fx.GraphModule):
for node in nodes:
check_node_safe(node)
# Saved tensors hooks are globally set subgraphs,
# that are not used explicitly in the main graph.
# They are inlined in aot_autograd graphs.
# Subgraphs are only used for caching logic.
if hasattr(gm, "saved_tensors_hooks_pack_0"):
check_cacheable(gm.saved_tensors_hooks_pack_0) # type: ignore[arg-type]
# We have guarantee of unpack sugraph existance if pack subgraph exists
check_cacheable(gm.saved_tensors_hooks_unpack_0) # type: ignore[arg-type]
def check_metadata_cacheable(metadata: ViewAndMutationMeta):
"""
@ -292,6 +308,27 @@ class AOTAutogradCacheDetails(FxGraphHashDetails):
self.disable_amp = torch._C._is_any_autocast_enabled()
self.deterministic_algorithms = torch.are_deterministic_algorithms_enabled()
self.autograd_config = config.save_config()
self.saved_tensors_hooks_fx_wrap_cache_hashes: tuple[list[str], list[str]] = (
[],
[],
)
if hasattr(gm, "saved_tensors_hooks_pack_0"):
def _add_wrapped_user_cache_hashes(_gm, _l):
for node in _gm.graph.nodes:
if node.meta and node.meta.get("is_wrapped", False):
_l.append(node.meta["user_cache_hash"])
_add_wrapped_user_cache_hashes(
gm.saved_tensors_hooks_pack_0,
self.saved_tensors_hooks_fx_wrap_cache_hashes[0],
)
_add_wrapped_user_cache_hashes(
gm.saved_tensors_hooks_unpack_0,
self.saved_tensors_hooks_fx_wrap_cache_hashes[1],
)
try:
# FXGraphCache has constraints on what can be pickled in its inductor
# config. Check that the gm is cacheable by inductor first,


@ -21,6 +21,7 @@ from contextlib import nullcontext
from typing import Any, Callable, Optional, TYPE_CHECKING
import torch
import torch.utils._pytree as pytree
import torch.utils.dlpack
from torch import Tensor
from torch._dynamo.utils import detect_fake_mode, dynamo_timed, lazy_format_graph_code
@ -34,6 +35,8 @@ from torch.fx.experimental.symbolic_shapes import fx_placeholder_vals
from torch.fx.graph_module import GraphModule
from torch.fx.passes._tensorify_python_scalars import tensorify_python_scalars
from torch.multiprocessing.reductions import StorageWeakRef
from torch.types import py_sym_types
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
from torchgen.utils import dataclass_repr
from .. import config
@ -766,6 +769,463 @@ def run_joint_graph_passes_on_hops(
return joint_gm
def maybe_log_graph(
gm,
graph_name,
aot_config,
structured_log_prefix_fn,
out_structured_logs: Optional[list[str]] = None,
):
if not aot_config.enable_log:
return
aot_graphs_log.debug(
"%s",
lazy_format_graph_code(
f"{graph_name}",
gm,
aot_config.aot_id,
include_stride=True,
include_device=True,
colored=True,
),
)
def gm_str_fn() -> str:
return gm.print_readable(
print_output=False, include_stride=True, include_device=True
)
if out_structured_logs is not None:
out_structured_logs.append(f"{structured_log_prefix_fn()}:{gm_str_fn()}")
else:
trace_structured(
f"{structured_log_prefix_fn()}",
payload_fn=lambda: gm_str_fn(),
)
def create_wrap_fn(fn, args):
from functools import wraps
from torch.fx.experimental.proxy_tensor import maybe_enable_thunkify
from .functional_utils import from_fun, has_data_mutation, to_fun
def assert_no_mutation(t):
assert not has_data_mutation(
t
), "Saved tensors hooks with inputs mutations are not allowed"
@wraps(fn)
def _wrapper(*args):
with maybe_enable_thunkify():
disable_above = torch._C._ExcludeDispatchKeyGuard(
torch._C.DispatchKeySet(torch._C.DispatchKey.Functionalize)
)
with disable_above:
f_args = pytree.tree_map(to_fun, args)
f_outs = fn(*f_args)
pytree.tree_map(assert_no_mutation, f_args)
return pytree.tree_map(from_fun, f_outs)
return _wrapper, args
def prepare_hook_gm(aot_config, fn, args):
from torch._functorch._aot_autograd.dispatch_and_compile_graph import _create_graph
fn, args = create_wrap_fn(fn, args)
gm = _create_graph(fn, args, aot_config=aot_config)
return gm
# Inline Autograd saved_tensors_hooks into epilogue of forward graph
# and prologue of backward graph.
# This changes forward graph outputs and inputs.
# Pack hook can return tensors, sym scalars, constants.
# All tensors to save for backward will be grouped together at front.
# Sym scalars grouped on another end. Constants are inlined in the graph.
def maybe_inline_graph_saved_tensors_hooks(
fw_module,
bw_module,
num_inner_fwd_outputs,
inner_meta,
aot_config,
static_input_indices,
):
if torch._dynamo.compiled_autograd.in_compiled_autograd_region:
return
get_hooks = torch._functorch._aot_autograd.utils.top_saved_tensors_hooks
are_inline_hooks = (
torch._functorch._aot_autograd.utils.saved_tensors_hooks_are_inlineable
)
hooks = get_hooks()
if not are_inline_hooks(hooks):
return
pack_hook_gm, unpack_hook_gm = hooks
structured_logs: list[str] = []
maybe_log_graph(
fw_module,
"Forward graph pre saved_tensors_hooks inlining",
aot_config,
lambda: "aot_forward_graph_pre_saved_tensors_hooks",
structured_logs,
)
maybe_log_graph(
bw_module,
"Backward graph pre saved_tensors_hooks inlining",
aot_config,
lambda: "aot_backward_graph_pre_saved_tensors_hooks",
structured_logs,
)
fw_g = fw_module.graph
bw_g = bw_module.graph
fw_g_names = {node.name for node in fw_g.nodes}
bw_g_names = {node.name for node in bw_g.nodes}
def _gen_unused_name(candidate: str):
c = candidate
i = 0
while c in fw_g_names or c in bw_g_names:
c = f"{candidate}_{i}"
i = i + 1
return c
bw_g_inputs = bw_g.find_nodes(op="placeholder")
fw_out_n = fw_g.output_node()
fw_outs = fw_out_n.args[0] # type: ignore[var-annotated]
fw_outs_inner_set = set(fw_outs[:num_inner_fwd_outputs])
fw_outs_saved_for_bw = fw_outs[num_inner_fwd_outputs:]
fw_outs_packed_tensors = [] # type: ignore[var-annotated]
fw_outs_packed_syms = [] # type: ignore[var-annotated]
# The main use case for saved_tensors_hooks is activation quantization,
# for memory usage optimization.
# Desired behavior is to quantize saved activations to free the original saved tensor.
# Saved nodes may include forward inputs, outputs, parameters.
# They may be held by something else and will not be deallocated after quantization.
# Donated buffers are intermediates in the graph invisible for the user,
# this guarantees that they can be deallocated.
# Using this as a default behavior to select saved nodes to apply hooks.
# There is also a config to apply hooks for all saved nodes without any filtering.
# The plan is to propagate meta about the source of the saved node to the user hook function.
mode = torch._functorch.config.saved_tensors_hooks_filtering_mode
allow_set = None
exclude_set = None
if mode == "donated":
# collect_bw_donated_buffer_idxs requires inner_meta to have num_symints_saved_for_bw
inner_meta.num_symints_saved_for_bw = len(
[n for n in fw_outs_saved_for_bw if is_sym_node(n)]
)
bw_donated_idxs = collect_bw_donated_buffer_idxs(
fw_module,
bw_module,
inner_meta,
)
fw_donated_idxs = [
i - inner_meta.num_symints_saved_for_bw for i in bw_donated_idxs
]
allow_set = {fw_outs_saved_for_bw[i].name for i in fw_donated_idxs}
elif mode == "no_static":
fw_g_inputs = fw_g.find_nodes(op="placeholder")
exclude_set = {fw_g_inputs[i].name for i in static_input_indices}
if (allow_set is not None) and (not allow_set):
# This means we have empty whitelist,
# No donated (intermediate) saved.
# Do not do anything in this case
return
if aot_config.enable_log:
structured_logs.append(f"fw_outs_saved_for_bw:{fw_outs_saved_for_bw}")
structured_logs.append(f"mode:{mode}")
structured_logs.append(f"allow_set:{allow_set}")
structured_logs.append(f"exclude_set:{exclude_set}")
for saved in fw_outs_saved_for_bw:
if ((allow_set is not None) and (saved.name not in allow_set)) or (
(exclude_set is not None) and (saved.name in exclude_set)
):
if isinstance(saved.meta["val"], torch.Tensor):
fw_outs_packed_tensors.append(saved)
continue
val = saved.meta["val"]
if not isinstance(val, torch.Tensor):
continue
pack_out_val = pack_hook_gm(val)
requires_sc_handling = any(
is_traceable_wrapper_subclass(x) for x in pytree.tree_leaves(pack_out_val)
)
if requires_sc_handling:
raise NotImplementedError(
"Tensor subclasses in GraphModule saved tensors hooks are not supported"
"You can workaround it by manually returning subclass's inner tensors"
" in the pack hook, and reconstructing the subclass in the unpack hook"
)
pack_gm = prepare_hook_gm(aot_config, pack_hook_gm, (val,))
pack_g = pack_gm.graph
maybe_log_graph(
pack_gm,
f"saved_tensors_pack_hook {saved.name}",
aot_config,
lambda: f"aot_saved_tensors_hooks_pack {saved.name}",
structured_logs,
)
pack_out_val = pack_gm(val)
# Install pack hook graph as eiplogue of fw_module.
# Saved tensor output becomes input of pack hook graph.
# Replace saved tensor output with pack hook graph output.
# Outputs symbolic scalars, tensors are accumulated separately.
# Then in forward outputs and backward inputs installed in order
# sym_scalars, packed_saved_tensors.
# Keeping all tensors together allows to preserve
# the same identification at runtime,
# updating only number of saved sym_scalars and tensors.
pack_g_inputs = pack_g.find_nodes(op="placeholder")
assert len(pack_g_inputs) == 1
env = {pack_g_inputs[0]: saved}
fw_pack_out_args = None
with fw_g.inserting_before(fw_out_n):
for node in pack_g.nodes:
if node.op == "placeholder":
continue
new_n = fw_g.node_copy(node, lambda n: env[n])
fw_g_names.add(new_n.name)
env[node] = new_n
# Output node is temporarily copied to have remapped arguments.
# Removed in the end.
if node.op == "output":
fw_pack_out_args = new_n.args[0]
fw_g.erase_node(new_n)
env.clear()
assert fw_pack_out_args
fw_outs_bw_ins_node_names = []
for out_idx, _n in enumerate(pytree.tree_leaves(fw_pack_out_args)):
if not isinstance(_n, torch.fx.Node):
fw_outs_bw_ins_node_names.append("")
continue
# This happens when hook is noop and it is either user input or user output.
# Do not do anything with this node.
if _n.op == "placeholder" or _n in fw_outs_inner_set:
# This means the hook returned input primals unchanged
# Do not rename in this case.
n = _n
new_node_name = _n.name
fw_outs_bw_ins_node_names.append(new_node_name)
else:
# We can not specify desired name in node_copy.
# Copying node manually to set specifc name,
# to have matching fw_outs, bw_inputs names.
new_node_name = _gen_unused_name(f"{saved.name}_hook_{out_idx}")
with fw_g.inserting_before(_n):
n = fw_g.create_node(
_n.op,
_n.target,
_n.args,
_n.kwargs,
name=new_node_name,
)
assert n.name == new_node_name
fw_outs_bw_ins_node_names.append(new_node_name)
n.meta = copy.copy(_n.meta)
_n.replace_all_uses_with(n)
fw_g.erase_node(_n)
if isinstance(n.meta["val"], torch.Tensor):
fw_outs_packed_tensors.append(n)
elif is_sym_node(n):
fw_outs_packed_syms.append(n)
# Install unpack hook graph as a prologue of backward graph
# Saved tensors inputs are replaced with packed tensors and packed sym scalars.
# The saved tensors inputs usages in the graph are replaced with unpack hook graph outputs.
unpack_gm = prepare_hook_gm(aot_config, unpack_hook_gm, (pack_out_val,))
unpack_g = unpack_gm.graph
maybe_log_graph(
unpack_gm,
f"saved_tensors_unpack_hook {saved.name}",
aot_config,
lambda: f"aot_saved_tensors_hooks_unpack {saved.name}",
structured_logs,
)
def find_saved_in_bw_inputs(bw_inputs):
for n in bw_inputs:
if n.name == saved.name:
return n
bw_g_input = find_saved_in_bw_inputs(bw_g_inputs)
assert bw_g_input
original_bw_g_input_users = list(bw_g_input.users.keys())
bw_g_input_used_directly = False
# Replace backward graph saved tensor input with copy of pack graph outputs
# All non-Tensor, non-symscalars outputs are constanted.
unpack_g_inputs = unpack_g.find_nodes(op="placeholder")
env = {}
for out_idx, (unp_in_n, out_n, val) in enumerate(
zip(
unpack_g_inputs,
pytree.tree_leaves(fw_pack_out_args),
pytree.tree_leaves(pack_out_val),
)
):
is_sym = isinstance(val, py_sym_types)
if isinstance(val, torch.Tensor) or is_sym:
# We want forward_outputs names to match backward_inputs,
# Potentially backward may already have "{saved.name}_hook_{idx}",
# In this case fx.Graph will add suffix.
new_node_name = fw_outs_bw_ins_node_names[out_idx]
if bw_g_input.name == new_node_name:
env[unp_in_n] = bw_g_input
bw_g_input_used_directly = True
else:
# Backward calling convention: ctx_symints,ctx_saved_tensors
# Inserting packed sym scalars before first saved tensor input.
# Inserting packed tensors before last saved tensor input.
# Saved tensor inputs between them will be removed.
with bw_g.inserting_before(
bw_g_inputs[0]
) if is_sym else bw_g.inserting_before(bw_g_input):
new_n = bw_g.placeholder(new_node_name)
assert new_n.name == new_node_name
new_n.meta = copy.copy(out_n.meta)
env[unp_in_n] = new_n
else:
# Inline values of non-Tensor, non-SymScalars
env[unp_in_n] = val
# Inserting unpack hook after placeholders.
bw_unpack_out_n = None
with bw_g.inserting_before(bw_g_inputs[-1].next):
for node in unpack_g.nodes:
if node.op == "placeholder":
continue
new_n = bw_g.node_copy(node, lambda n: env[n])
bw_g_names.add(new_n.name)
env[node] = new_n
# Temporary insert output, to have remapped by node_copy args.
# Removed in the end.
if node.op == "output":
bw_unpack_out_n = new_n
assert bw_unpack_out_n
_leaves = pytree.tree_leaves(bw_unpack_out_n.args)
assert len(_leaves) == 1
unpack_saved_tensor_n = _leaves[0]
if not bw_g_input_used_directly:
bw_g_input.replace_all_uses_with(unpack_saved_tensor_n)
bw_g.erase_node(bw_g_input)
else:
# Keep usages of bw_g_input in inserted unpacked hook graph.
# Replace other usages of bw_g_input with unpack_saved_tensor_n.
from torch._C import _fx_map_arg
def maybe_replace_node(n):
return unpack_saved_tensor_n if n == bw_g_input else n
for use_node in original_bw_g_input_users:
new_args = _fx_map_arg(use_node.args, maybe_replace_node)
new_kwargs = _fx_map_arg(use_node.kwargs, maybe_replace_node)
assert isinstance(new_args, tuple)
assert isinstance(new_kwargs, dict)
use_node._update_args_kwargs(new_args, new_kwargs)
bw_g.erase_node(bw_unpack_out_n)
# Changing forward graph outputs,
# Inserting packed_tensors and packed_syms on the place of saved tensors.
# Packed sym_scalars are together with saved symints
symint_outs_saved_for_bw = [n for n in fw_outs_saved_for_bw if is_sym_node(n)]
fw_new_outs = pytree.tree_leaves(
(
fw_outs[:num_inner_fwd_outputs],
fw_outs_packed_tensors,
fw_outs_packed_syms,
symint_outs_saved_for_bw,
)
)
fw_out_n.args = (tuple(fw_new_outs),)
# Assert that saved tensors and symints in forward outputs are aligned with backward inputs
_fw_n = num_inner_fwd_outputs
_fw_num_t = len(fw_outs_packed_tensors)
_fw_num_s = len(fw_outs_packed_syms) + len(symint_outs_saved_for_bw)
fw_outs_saved_tensors = fw_new_outs[_fw_n : _fw_n + _fw_num_t]
fw_outs_saved_syms = fw_new_outs[_fw_n + _fw_num_t :]
bw_new_ins = list(bw_g.find_nodes(op="placeholder"))
bw_ins_saved_syms = bw_new_ins[:_fw_num_s]
bw_ins_saved_tensors = bw_new_ins[_fw_num_s : _fw_num_s + _fw_num_t]
fw_t_names = [n.name for n in fw_outs_saved_tensors]
bw_t_names = [n.name for n in bw_ins_saved_tensors]
fw_s_names = [n.name for n in fw_outs_saved_syms]
bw_s_names = [n.name for n in bw_ins_saved_syms]
def _log_structured_logs():
if not aot_config.enable_log:
return
trace_structured(
"artifact",
metadata_fn=lambda: {
"name": "aot_saved_tensors_hooks_graphs",
"encoding": "string",
},
payload_fn=lambda: "\n".join(structured_logs),
)
if aot_config.enable_log:
structured_logs.append(
f"fw_outs[:num_inner_fwd_outputs]:{fw_outs[:num_inner_fwd_outputs]}"
)
structured_logs.append(f"fw_outs_packed_tensors:{fw_outs_packed_tensors}")
structured_logs.append(f"fw_t_names:{fw_t_names}")
structured_logs.append(f"bw_t_names:{bw_t_names}")
structured_logs.append(f"fw_s_names:{fw_s_names}")
structured_logs.append(f"bw_s_names:{bw_s_names}")
structured_logs.append(f"\nfw_g_pre_assert:{fw_g}")
structured_logs.append(f"\nbw_g_pre_assert:{bw_g}")
maybe_log_graph(
fw_module,
"Forward graph after transform pre-assert",
aot_config,
lambda: "aot_forward_graph_pre_assert_saved_tensors_hooks",
structured_logs,
)
maybe_log_graph(
bw_module,
"Backward graph after transform pre-assert",
aot_config,
lambda: "aot_backward_graph_pre_assert_saved_tensors_hooks",
structured_logs,
)
_log_structured_logs()
assert fw_t_names == bw_t_names
assert fw_s_names == bw_s_names
fw_g.lint()
bw_g.lint()
fw_module.recompile()
bw_module.recompile()
def aot_dispatch_autograd(
flat_fn,
flat_args: list[Any],
@ -875,6 +1335,16 @@ def aot_dispatch_autograd(
joint_inputs[1],
)
maybe_inline_graph_saved_tensors_hooks(
fw_module,
bw_module,
num_inner_fwd_outputs,
inner_meta,
aot_config,
fw_metadata.static_input_indices,
)
static_lifetime_input_indices = fw_metadata.static_input_indices
fw_outs = next(iter(fw_module.graph.find_nodes(op="output"))).args[0]
# we only need to bookkeep the symints that are saved for bw, not any symints
# the user forward might have returned in its own output


@ -8,6 +8,7 @@ This module defines runtime wrappers, which, based on previous analysis attempts
"""
import builtins
import collections
import contextlib
import copy
import itertools
import pprint
@ -256,6 +257,23 @@ def maybe_mark_dynamic_helper(t: torch.Tensor, dims: set[int]):
t._dynamo_weak_dynamic_indices = dims.copy() # type: ignore[attr-defined]
def _should_disable_saved_tensors_hooks():
# Compiled autograd is not supported yet, to be added in future.
if torch._dynamo.compiled_autograd.in_compiled_autograd_region:
return False
get_hooks = torch._functorch._aot_autograd.utils.top_saved_tensors_hooks
are_inline_hooks = (
torch._functorch._aot_autograd.utils.saved_tensors_hooks_are_inlineable
)
hooks = get_hooks()
if are_inline_hooks(hooks):
return True
return False
def _create_runtime_wrapper(
compiled_fn,
*,
@ -446,7 +464,15 @@ def _create_runtime_wrapper(
torch._C._set_grad_enabled(runtime_metadata.grad_enabled_mutation)
return ret_outs
return runtime_wrapper
if not (trace_joint and _should_disable_saved_tensors_hooks()):
return runtime_wrapper
# Disabling saved tensors hooks
def _runtime_wrapper(*args, **kwargs):
with _disable_saved_tensors_hooks():
return runtime_wrapper(*args, **kwargs)
return _runtime_wrapper
@dataclass
@ -1809,6 +1835,35 @@ def coerce_to_expected_memory_format(x: torch.Tensor, memory_format: MemoryForma
return restrided
@contextlib.contextmanager
def _disable_saved_tensors_hooks():
error_message = (
"Saved tensors hooks were specialized as GraphModules."
"In this case aot_autograd inlines them in forward and backward graph "
"and disables them during runtime of aot_autograd compiled region."
"If you see this error, that means that there is some unexpected push or pop manipulation "
"during aot_autograd compiled region runtime."
"Compilation with different hooks must result in recompilation."
)
fail_if_non_empty = False
maybe_prev_message = None
try:
maybe_prev_message = (
torch._C._autograd._saved_tensors_hooks_get_disabled_error_message()
)
torch._C._autograd._saved_tensors_hooks_disable(
error_message, fail_if_non_empty
)
yield
finally:
if maybe_prev_message is None:
torch._C._autograd._saved_tensors_hooks_enable()
else:
torch._C._autograd._saved_tensors_hooks_disable(
maybe_prev_message, fail_if_non_empty
)
# This is wrapped in a class just for namespacing purposes
# No need to make it into an actual CompilerWrapper because it doesn't fit the abstract as cleanly
class AOTDispatchAutograd:


@ -500,3 +500,16 @@ def get_cuda_generator_meta_val(device_idx: int):
it is fine to use in the meta.
"""
return torch.cuda.default_generators[device_idx].clone_state()
def top_saved_tensors_hooks():
return torch._C._autograd._top_saved_tensors_default_hooks(True)
def saved_tensors_hooks_are_inlineable(hooks) -> bool:
if not hooks:
return False
pack, unpack = hooks
return isinstance(pack, torch.fx.GraphModule) and isinstance(
unpack, torch.fx.GraphModule
)


@ -290,6 +290,19 @@ guess_tangent_strides_as_outputs = False
# it will untimately be removed once we share size_hints across ranks through compiler collectives
_broadcast_rank0_decision = False
# By default apply inlined saved_tensors_hooks only for "donated" buffers.
# "donated" buffers are invisible to the user, they are intermediates of the forward graph.
# Applying saved tensors hooks for memory optimizations only for intermediates
# guarantees that original saved tensors could be deallocated.
# This config enables saved_tensors_hooks are applied for **all** saved tensors,
# that could include inputs, parameters, outputs.
# "donated" - applied only to saved intermediates of the graph
# "no_static" - applied to all saved but not "static"
# (this includes parameters and user marked as static)
# "all" - no filtering, everything saved for backward.
saved_tensors_hooks_filtering_mode = "donated"
if TYPE_CHECKING:
from torch.utils._config_typing import * # noqa: F401, F403


@ -455,7 +455,11 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject* unused) {
"_saved_tensors_hooks_is_enabled",
at::SavedTensorDefaultHooks::is_enabled);
m.def("_saved_tensors_hooks_enable", at::SavedTensorDefaultHooks::enable);
m.def("_saved_tensors_hooks_disable", at::SavedTensorDefaultHooks::disable);
m.def(
"_saved_tensors_hooks_disable",
at::SavedTensorDefaultHooks::disable,
py::arg("error_message"),
py::arg("fail_if_non_empty") = true);
m.def(
"_saved_tensors_hooks_set_tracing",
at::SavedTensorDefaultHooks::set_tracing);
@ -471,6 +475,27 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject* unused) {
m.def("_pop_saved_tensors_default_hooks", []() {
torch::autograd::PyDefaultSavedVariableHooks::pop_hooks();
});
m.def(
"_top_saved_tensors_default_hooks",
[](bool ignore_is_tracing)
-> std::optional<std::pair<py::function, py::function>> {
auto out = at::SavedTensorDefaultHooks::get_hooks(ignore_is_tracing);
if (!out.has_value()) {
return std::nullopt;
}
auto [pack_hook, unpack_hook] = *out;
// gil for destructor of pack_hook, unpack_hook that decrements
// reference
py::gil_scoped_acquire gil;
return std::make_pair(
py::reinterpret_steal<py::function>(pack_hook.release()),
py::reinterpret_steal<py::function>(unpack_hook.release()));
}
);
m.def("_get_creation_meta", [](const at::Tensor& t) {
auto* meta = torch::autograd::impl::get_view_autograd_meta(t);


@ -52,8 +52,9 @@ SavedVariable::SavedVariable(
TORCH_INTERNAL_ASSERT(!is_leaf_ && is_output);
weak_grad_fn_ = variable.grad_fn();
}
auto maybe_hooks = get_default_hooks();
std::unique_ptr<SavedVariableHooks> maybe_hooks =
at::SavedTensorDefaultHooks::is_enabled() ? get_default_hooks()
: nullptr;
// Avoid wrapped numbers from being leaked to the user
if (maybe_hooks && !variable.unsafeGetTensorImpl()->is_wrapped_number()) {