pytorch/torch/_functorch/config.py
IvanKobzarev 4439255148 [aotd] Support saved tensors hooks in aot_autograd (#150032)
https://github.com/pytorch/pytorch/issues/148222

Goal:

At the moment, autograd saved tensors hooks run in eager mode after the compiled forward.
They are executed at the same time for all saved tensors.
Hooks can be used to reduce the amount of memory used for saved tensors, for example by quantizing them or offloading them to CPU.
Running all hooks together after the forward is suboptimal for peak memory.
A better solution is to put the hooks in the graph, as close as possible to the last usage of the tensor.

The goal is to get user-specified autograd saved tensors hooks into the graph.
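
For reference, the eager-mode behavior described above can be sketched with plain Python callables. Only torch.autograd.graph.saved_tensors_hooks is the actual API here; the hook names and the choice of bfloat16 as the packed dtype are illustrative assumptions, not part of this change.

```python
import torch


def pack_hook(x):
    # Called when autograd saves a tensor for backward: remember the original
    # dtype and keep a lower-precision copy (bfloat16 is an illustrative choice).
    return x.dtype, x.to(torch.bfloat16)


def unpack_hook(packed):
    # Called when backward needs the saved tensor: restore the original dtype.
    dtype, x = packed
    return x.to(dtype)


x = torch.randn(8, 8, requires_grad=True)
with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
    y = (x * x).sum()
y.backward()
```

The gradient is computed from the unpacked (round-tripped) values, so it matches 2 * x only up to bfloat16 precision.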

Logic:

UX:
The user specifies with torch.autograd.graph.saved_tensors_hooks(pack_gm, unpack_gm),
where pack_gm and unpack_gm are torch.fx.GraphModule instances.
AotAutograd then retraces those graph modules, applying decompositions and functionalization, and inlines the resulting graphs into the forward epilogue and backward prologue.

The user may want to use control logic in the hooks, for example applying quantization only for specific dtypes and sizes.

This is also possible: the user can put that logic into a torch.fx.wrap function and use symbolic tracing to make a GraphModule.

In that case AotAutograd caching works only if the user explicitly sets the "user_cache_hash" metadata on the torch.fx.wrap call_function node.

If this metadata is set, the aot_autograd cache can use the saved cache artifact.
If the metadata is not set, the cache is bypassed.
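
A minimal sketch of that pattern, assuming a hypothetical wrapped helper maybe_downcast and a hypothetical size threshold. The "user_cache_hash" key comes from the text above; the value format used here is an assumption.

```python
import torch
import torch.fx


@torch.fx.wrap
def maybe_downcast(x):
    # Control logic lives inside the wrapped function; symbolic tracing records
    # a single call_function node instead of tracing through the branch.
    if x.dtype == torch.float32 and x.numel() >= 1024:
        return x.to(torch.bfloat16)
    return x


def pack(x):
    return maybe_downcast(x)


pack_gm = torch.fx.symbolic_trace(pack)

for node in pack_gm.graph.nodes:
    if node.op == "call_function" and node.target is maybe_downcast:
        # Key name taken from the description above; the string value is an
        # assumed placeholder that the user would change with the hook logic.
        node.meta["user_cache_hash"] = "maybe_downcast_v1"
```

The branch is evaluated at run time on real tensors, so the same GraphModule handles both large float32 inputs and everything else.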

Dynamo:
Dynamo traces the pack and unpack hooks, installs them as subgraphs, and explicitly adds them to the output_graph. (Those subgraphs are not used in the graph, so by default they would not be copied into the result.)

The complexity here is that at this point we do not have example inputs for the hooks.
We trace pack_hook with some Tensor from the inputs.
The resulting subgraphs are added to the hashing of the AotAutograd cache.

In AotAutograd we retrace the graph with the true saved tensors coming from the partitioner.

Backwards Compatibility:
Current hooks are executed in eager mode and not all of them will be traceable, so we only try to put into the graph those hooks explicitly marked by the user with an annotation (@_inlineable_saved_tensors_hooks).
For other hooks, or if compiled autograd is enabled, we keep the existing logic.

Recompilations:
Hooks are guarded with a lambda guard matching the function id, to cause recompilation if the user reruns the compiled function with different hooks.

Aot_autograd:
After the partitioner has prepared the forward and backward modules, we trace the pack and unpack hook graphs prepared by Dynamo and inline them into the epilogue of forward and the prologue of backward. Forward outputs and backward inputs are changed, transparently for the user.

We do not try to place the hooks close to the last usage etc., relying on inductor to do this optimization.

```
INFO: TRACED GRAPH
 ===== Forward graph pre saved_tensors_hooks inlining 3 =====
 /data/users/ivankobzarev/a/pytorch/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "Sym(s0)", primals_2: "Sym(s1)", primals_3: "f32[s0, s1][s1, 1]cuda:0"):
         # File: /data/users/ivankobzarev/a/pytorch/test/functorch/test_aotdispatch.py:6660 in simple_fn, code: x = x + 1
        add: "f32[s0, s1][s1, 1]cuda:0" = torch.ops.aten.add.Tensor(primals_3, 1);  primals_3 = None

         # File: /data/users/ivankobzarev/a/pytorch/test/functorch/test_aotdispatch.py:6661 in simple_fn, code: x = SAF.apply(x)
        view: "f32[s0, s1][s1, 1]cuda:0" = torch.ops.aten.view.default(add, [primals_1, primals_2])
        return (view, add, primals_1, primals_2)

INFO: TRACED GRAPH
 ===== Backward graph pre saved_tensors_hooks inlining 3 =====
 /data/users/ivankobzarev/a/pytorch/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "Sym(s0)", primals_2: "Sym(s1)", primals_3: "f32[s0, s1][s1, 1]cuda:0"):
         # File: /data/users/ivankobzarev/a/pytorch/test/functorch/test_aotdispatch.py:6660 in simple_fn, code: x = x + 1
        add: "f32[s0, s1][s1, 1]cuda:0" = torch.ops.aten.add.Tensor(primals_3, 1);  primals_3 = None

         # File: /data/users/ivankobzarev/a/pytorch/test/functorch/test_aotdispatch.py:6661 in simple_fn, code: x = SAF.apply(x)
        view: "f32[s0, s1][s1, 1]cuda:0" = torch.ops.aten.view.default(add, [primals_1, primals_2])
        return (view, add, primals_1, primals_2)

INFO: TRACED GRAPH
 ===== saved_tensors_pack_hook add 3 =====
 /data/users/ivankobzarev/a/pytorch/torch/fx/_lazy_graph_module.py class pack_float8(torch.nn.Module):
    def forward(self, x_1: "f32[s0, s1][s1, 1]cuda:0"):
        # No stacktrace found for following nodes
        _to_copy: "f8e4m3fn[s0, s1][s1, 1]cuda:0" = torch.ops.aten._to_copy.default(x_1, dtype = torch.float8_e4m3fn);  x_1 = None
        return (torch.float32, _to_copy)

INFO: TRACED GRAPH
 ===== saved_tensors_unpack_hook add 3 =====
 <eval_with_key>.22 from /data/users/ivankobzarev/a/pytorch/torch/fx/experimental/proxy_tensor.py:1225 in wrapped class pack_float8(torch.nn.Module):
    def forward(self, x_1: "f32[s0, s1][s1, 1]cuda:0"):
        # No stacktrace found for following nodes
        _to_copy: "f8e4m3fn[s0, s1][s1, 1]cuda:0" = torch.ops.aten._to_copy.default(x_1, dtype = torch.float8_e4m3fn);  x_1 = None
        return (torch.float32, _to_copy)

INFO: TRACED GRAPH
 ===== Forward graph 3 =====
 /data/users/ivankobzarev/a/pytorch/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "Sym(s0)", primals_2: "Sym(s1)", primals_3: "f32[s0, s1][s1, 1]cuda:0"):
         # File: /data/users/ivankobzarev/a/pytorch/test/functorch/test_aotdispatch.py:6660 in simple_fn, code: x = x + 1
        add: "f32[s0, s1][s1, 1]cuda:0" = torch.ops.aten.add.Tensor(primals_3, 1);  primals_3 = None

        # No stacktrace found for following nodes
        _to_copy: "f8e4m3fn[s0, s1][s1, 1]cuda:0" = torch.ops.aten._to_copy.default(add, dtype = torch.float8_e4m3fn)

         # File: /data/users/ivankobzarev/a/pytorch/test/functorch/test_aotdispatch.py:6661 in simple_fn, code: x = SAF.apply(x)
        view: "f32[s0, s1][s1, 1]cuda:0" = torch.ops.aten.view.default(add, [primals_1, primals_2]);  add = None
        return (view, _to_copy, primals_1, primals_2)

INFO: TRACED GRAPH
 ===== Backward graph 3 =====
 <eval_with_key>.21 class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "Sym(s0)", primals_2: "Sym(s1)", add_packed_2: "f8e4m3fn[s0, s1][s1, 1]cuda:0", tangents_1: "f32[s0, s1][s1, 1]cuda:0"):
        # No stacktrace found for following nodes
        _to_copy: "f32[s0, s1][s1, 1]cuda:0" = torch.ops.aten._to_copy.default(add_packed_2, dtype = torch.float32);  add_packed_2 = None

         # File: /data/users/ivankobzarev/a/pytorch/test/functorch/test_aotdispatch.py:6661 in simple_fn, code: x = SAF.apply(x)
        add_7: "f32[s0, s1][s1, 1]cuda:0" = torch.ops.aten.add.Tensor(tangents_1, _to_copy);  tangents_1 = _to_copy = None
        return (None, None, add_7)

```

Differential Revision: [D72187044](https://our.internmc.facebook.com/intern/diff/D72187044)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150032
Approved by: https://github.com/bdhirsh
2025-05-22 14:09:38 +00:00


# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""
Global flags for aot autograd
"""
import os
import sys
from typing import Literal, Optional, TYPE_CHECKING
from torch.utils._config_module import Config, install_config_module
# Converts torch rng ops to their functional philox rng equivalents. Note that
# we functionalize only CUDA rng ops today.
functionalize_rng_ops = False
# can be useful for debugging if we are incorrectly creating meta fake tensors
fake_tensor_allow_meta = os.environ.get("FAKE_ALLOW_META", "1") != "0"
# Enables optional asserts in hotpath code to check for errors. If
# you are seeing weird accuracy problems, try turning this on.
# This is currently off by default as it will harm tracing time,
# but it is on by default for aot_eager.
debug_assert = False
debug_partitioner = os.environ.get("AOT_PARTITIONER_DEBUG", "0") != "0"
# See # NOTE [Export custom triton op]
decompose_custom_triton_ops = True
static_weight_shapes = True
# See https://github.com/pytorch/pytorch/issues/141881
# Tells partitioner that parameters are free to save for backward.
treat_parameters_as_free_to_save = True
# Applies CSE to the graph before partitioning
cse = True
from torch._environment import is_fbcode
enable_autograd_cache: bool = Config(
    justknob="pytorch/remote_cache:enable_local_autograd_cache",
    env_name_force="TORCHINDUCTOR_AUTOGRAD_CACHE",
    default=True,
)
autograd_cache_allow_custom_autograd_functions: bool = Config(
    env_name_force="TORCHINDUCTOR_AUTOGRAD_CACHE_ALLOW_CUSTOM_AUTOGRAD", default=False
)
# For now, this is just for enabling unit testing in test_aot_autograd_cache.py
# We will either make this the default with AOTAutogradCache, or
# we'll just use it in the precompile flow. So there's no
# need to add env vars or make it configurable
bundled_autograd_cache: bool = False
def remote_autograd_cache_default() -> Optional[bool]:
    if os.environ.get("TORCHINDUCTOR_AUTOGRAD_REMOTE_CACHE") == "1":
        return True
    if os.environ.get("TORCHINDUCTOR_AUTOGRAD_REMOTE_CACHE") == "0":
        return False
    return None
enable_remote_autograd_cache = remote_autograd_cache_default()
# When AOTAutograd regenerates aliased graph outputs,
# attempt to use functionalization's view-replay logic
# before falling back to the autograd engine's view replay or as_strided.
# This can have some perf implications
# (although for many models this will not matter).
# (1) If you have many view ops chained together, replaying all of them
# at runtime can have more overhead compared to a single as_strided call
# (2) If you are doing training, AsStridedBackward is quite slow,
# and the individual view op backward formulas will likely be faster.
# (3) Some backends like XLA do not support as_strided
# Temporary hack: disable this flag for internal
# (needed to fix an internal issue while avoiding bumping XLA pin)
# eventually: either default this config to false completely
# once XLA pin update works,
# or default config to true and fix relevant bugs
# View replay is currently not compatible with AOTAutogradCache, since
# FunctionalTensors are not serializable. We'll need to make them
# serializable before enabling warm cache with this config turned on.
view_replay_for_aliased_outputs = not is_fbcode()
# Restricts the amount of computation AOTAutograd can do.
# NB: We have essentially disabled this heuristic now. However, this is kept
# here for now in case it's useful. Setting it low can artificially reduce the
# amount of recomputation AOTAutograd performs, although not in any kind of
# principled way.
max_dist_from_bw = 1000
# Bans recomputation of nodes that are reading from nodes that are far before
# the current node
ban_recompute_used_far_apart = True
# Breaks up long chain of fusible ops, as otherwise we can have an arbitrarily
# long chain of recomputation in the backwards pass.
ban_recompute_long_fusible_chains = True
# Bans recomputation of nodes that must be materialized in the backwards pass
# (used by a non-fusible node)
ban_recompute_materialized_backward = True
# Chooses to ban recomputation of nodes based off an allowlist. Setting it to
# False changes it to use a denylist. Main change is on operators like
# sort/pool/stuff that isn't cheap enough to be fusible for free but also isn't
# that expensive
ban_recompute_not_in_allowlist = True
# Chooses to ban recomputation of reductions. This is generally a good idea, as
# the result of reductions is generally very small but recomputing reductions in
# a fusion can be expensive.
ban_recompute_reductions = True
# Prevents the partitioner from ever saving views (i.e. always recompute them).
# Generally a good idea since views are free to recompute.
recompute_views = False
# By default, the partitioner is purely trying to optimize for runtime (although
# it should always use less memory than eager)
# This knob controls the partitioner to make that tradeoff for you, choosing the
# fastest option that saves less activations than the memory budget.
# Specifically, 0.0 corresponds to the activation memory from applying
# activation checkpointing to the full compiled region, and 1.0 corresponds to
# the activation memory from the default runtime-optimized strategy. So, 0.4
# would result in a strategy that saves 40% of the activations compared to the
# default strategy.
# It solves a 0-1 knapsack to find the minimum recompute necessary to stay below
# the activation memory budget.
# NOTE: This *cannot* be treated as
activation_memory_budget = 1.0
# This controls how we estimate the runtime when deciding what the cheapest
# operators to recompute are. The 3 options are
# "flops": Bases it off of the flop count provided by torch.utils.flop_counter
# "profile": Benchmarks each operator to come up with a runtime
# "testing": Returns 1 for everything
activation_memory_budget_runtime_estimator = "flops"
# This controls the solver used for the 0-1 knapsack. By default we use a
# quantized DP solution ("dp"). The other approaches are a "greedy" and a "ilp"
# (which has a scipy dependency).
activation_memory_budget_solver = "dp"
# This dumps out a SVG visualization of the expected runtime vs. activation
# memory tradeoffs for all memory budget values from 0 to 1 in increments of
# 0.5. See an example here:
# https://github.com/pytorch/pytorch/pull/126320#discussion_r1625104015
visualize_memory_budget_pareto = (
    os.environ.get("PARTITIONER_MEMORY_BUDGET_PARETO", "0") == "1"
)
# This controls the directory in which to dump the SVG plot with the pareto
# frontier of the activation checkpointing memory-vs-runtime tradeoffs.
memory_budget_pareto_dir = os.environ.get("PARTITIONER_MEMORY_BUDGET_PARETO_DIR")
# Sets all of the ban_recompute heuristics to False except ban_recompute_reductions
# Generally, this will probably result in some memory improvement, but at the
# cost of some performance
aggressive_recomputation = False
# If FakeTensor.data_ptr() should error.
# This option is independent of AOTAutograd and torch.compile, but our policy
# is to turn it off during torch.compile.
fake_tensor_allow_unsafe_data_ptr_access = True
# Unlifts effect tokens from the inputs/outputs in the traced graph and instead
# inserts make_token/sink_token calls in the graph to create tokens and then
# sink them at the end. Note that this means the graph is no longer functional
# which may lead to silent errors unless the backend knows how to handle the
# tokens.
unlift_effect_tokens = False
# NOTE: [The default layout constraint for custom operators.]
# This must be the name of one of the layout constraint tags
# (that is, one of {"needs_exact_strides", "needs_fixed_stride_order", "flexible_layout"}),
# If the custom op does not have a layout constraint tag already
# then we assume the following applies.
#
# This config is respected by Inductor and we recommend other backends also
# respect it.
# This config is in torch._functorch and not torch._inductor because it affects
# ProxyTensor tracing.
custom_op_default_layout_constraint: Literal[
    "needs_exact_strides", "needs_fixed_stride_order", "flexible_layout"
] = "needs_exact_strides"
# Run aot eager decomp partition with CrossRefFakeMode
# options = False, "all", "custom_ops"
fake_tensor_crossref = False
# This mode specifies that we should also keep track of the real
# tensor along with the fake tensor, and do real compute. While
# seemingly this eliminates the whole point of fake tensors, there are
# two obvious use cases for it:
#
# 1. When users call item()/other data dependent operations,
# if we propagate_real_tensors we are able to determine what
# the true value is and keep going.
#
# 2. It can be useful for testing, when you want to see if the fake
# and real tensors agree with each other. (Note that there are
# currently known inaccuracies in how we clone real tensors, that
# would have to be tightened up for this to be useful in this
# case.)
#
# Note that fake tensors are typically understood to be cheap to store
# indefinitely, so we tend to hold on to them longer than we would
# hold onto the real tensors. So we also support you explicitly
# deallocating the real tensor associated with a fake tensor, at which
# point we will stop propagating real tensors.
#
# One more thing: when you provide a real tensor to fakeify, we will
# clone it, so that we can safely perform mutations on it if necessary.
# This will increase live memory usage. This could potentially be
# optimized by using COW. We also currently do not faithfully
# maintain autograd metadata on the real tensor; this is fine because
# AOTAutograd will only use the fake tensor to determine leafness/etc
# of tensors in question.
fake_tensor_propagate_real_tensors = False
# This controls whether we collect donated buffer. This flag must be set
# False if a user wants to retain_graph=True for backward.
donated_buffer = False if is_fbcode() else True
# Controls the default graph output format used by draw_graph
# Supported formats are defined here https://graphviz.org/docs/outputs/
torch_compile_graph_format = os.environ.get("TORCH_COMPILE_GRAPH_FORMAT", "svg")
# Valid only if fake_tensor_propagate_real_tensors = True; if a fake-real
# kernel mismatch is detected, bypasses by making a fake kernel from the
# real tensor outputs.
generate_fake_kernels_from_real_mismatches = False
# CUDAGraph-safe run_with_rng functionalization.
# TODO: turn on by default
graphsafe_rng_functionalization = True
# Error on BypassAOTAutogradCache instead of just a warning
# Used for tests
strict_autograd_cache = False
# Note [Recomputing collectives in the partitioner]
# The purpose of this config is as follows:
# - We have many passes in the compiler (min-cut partitioning, DCE, etc)
# which can reorder, delete, or duplicate nodes in the graph
# - If any of these passes reorder/delete/duplicate a collective
# in a setting where the compiler is being run independently on multiple
# ranks, we run the risk that the compiler will make a different decision on
# different ranks, resulting in a NCCL hang when using torch.compile
# To handle this, we will (by default) ensure that collectives are not modified
# by the compiler.
#
# A few examples:
# - don't dead-code-eliminate collectives
# (in case they are dead on rank i but not rank j)
# - don't recompute collectives in partitioning
# (in case we recompute on rank i but not rank j)
#
# Today this flag **must** be set to false, but eventually
# we want the option to set it to true.
# In order to potentially optimize collectives, we'll need the compiler
# to broadcast information across ranks at compile time to ensure
# that any decisions on collectives are made consistently.
unsafe_allow_optimization_of_collectives = False
# See Note [AOTAutograd Tangent Subclassness for mutated inputs]
# TODO(ivankobzarev): Remove this config, being able to deduce it compile time.
disable_guess_zero_tangent_for_mutated_input_subclass = False
# See Note [Tangents memory format]
# By default tangents strideness is guessed to be contiguous,
# At runtime non contiguous tangents will be coerced to be contiguous.
# This config changes this guess for tangents strides to be the same as outputs.
# TODO(ivankobzarev): Remove this config once extra memory usage is investigated.
guess_tangent_strides_as_outputs = False
# This is a temporary config to ensure all ranks take the same decision in the partitioner
# it will ultimately be removed once we share size_hints across ranks through compiler collectives
_broadcast_rank0_decision = False
# By default apply inlined saved_tensors_hooks only for "donated" buffers.
# "donated" buffers are invisible to the user, they are intermediates of the forward graph.
# Applying saved tensors hooks for memory optimizations only for intermediates
# guarantees that original saved tensors could be deallocated.
# This config allows saved_tensors_hooks to be applied to **all** saved tensors,
# that could include inputs, parameters, outputs.
# "donated" - applied only to saved intermediates of the graph
# "no_static" - applied to all saved but not "static"
# (this includes parameters and user marked as static)
# "all" - no filtering, everything saved for backward.
saved_tensors_hooks_filtering_mode = "donated"
if TYPE_CHECKING:
    from torch.utils._config_typing import *  # noqa: F401, F403
# adds patch, save_config, invalid config checks, etc
install_config_module(sys.modules[__name__])