Using `fsdp.set_` for the in-place update of `unsharded_param` causes difficult-to-debug errors when enabling Traceable FSDP2 on TorchTune models. In this PR, we change it to use `fsdp.copy_`, which fixes the error and also strictly follows eager semantics: if the user explicitly stores an alias of the `unsharded_param` while their module code runs, that alias is updated correctly when the `unsharded_param` is `copy_`-ed into, whereas if we just swap out the `unsharded_param` storage via `set_`, that user-saved alias is not updated. This PR also implements the graph pass to remove the resizes and the copy when they form a resize_(full) -> copy_ -> resize_(0) pattern.

Test commands:
- `pytest -rA test/distributed/_composable/fsdp/test_fully_shard_compile.py::TestFullyShardCompile::test_transformer_backend_inductor`
- `pytest -rA test/distributed/_composable/fsdp/test_fully_shard_compile.py::TestFullyShardCompile::test_nested_fully_shard_backend_inductor`
- `pytest -rA test/distributed/_composable/fsdp/test_fully_shard_compile.py::TestFullyShardCompile::test_trace_fsdp_copy_`
- `pytest -rA test/dynamo/test_repros.py::ReproTests::test_partitioner_cse_respects_mutation_boundaries`
- `pytest -rA test/dynamo/test_repros.py::ReproTests::test_fsdp_set_input_mutation_applied_when_input_gets_no_gradients`
- `pytest -rA test/inductor/test_pattern_matcher.py::TestPatternMatcher::test_mutation_op_matching`
- `python test/inductor/test_distributed_patterns.py DistributedPatternTests.test_fake_distributed_aot_eager`
- `PYTORCH_OPINFO_SAMPLE_INPUT_INDEX=1 PYTORCH_TEST_WITH_CROSSREF=1 python test/functorch/test_aotdispatch.py TestEagerFusionOpInfoCPU.test_aot_autograd_exhaustive_norm_cpu_float32`
- `python test/distributed/test_inductor_collectives.py TestCollectivesInductor.test_backwards`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133730
Approved by: https://github.com/bdhirsh
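The pattern the new graph pass looks for is roughly the following (an illustrative sketch; the op spellings here are placeholders, not the exact nodes Traceable FSDP2 emits):

    # per unsharded parameter in the traced graph
    resize_storage_bytes_(unsharded_param, full_numel)   # resize_(full)
    fsdp.copy_(unsharded_param, all_gather_output)       # in-place copy into the param
    resize_storage_bytes_(unsharded_param, 0)            # resize_(0)

When these three ops appear back-to-back with no reader of `unsharded_param` in between, the copy can never be observed, so the pass drops all three.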
# mypy: allow-untyped-defs
import copy
import functools
import heapq
import itertools
import logging
import math
import operator
import os
from collections import defaultdict
from dataclasses import dataclass, replace
from typing import Callable, Dict, List, Optional, Set, Tuple, TYPE_CHECKING, Union

import torch
import torch._inductor.inductor_prims
import torch.fx as fx
import torch.utils._pytree as pytree
from torch.fx.experimental._backward_state import BackwardState
from torch.fx.experimental.proxy_tensor import is_sym_node, py_sym_types
from torch.fx.experimental.sym_node import magic_methods, method_to_operator
from torch.fx.experimental.symbolic_shapes import (
    find_symbol_binding_fx_nodes,
    free_symbols,
    hint_int,
    is_symbol_binding_fx_node,
)
from torch.fx.passes import graph_drawer
from torch.utils.checkpoint import CheckpointPolicy

from . import config
from ._aot_autograd.logging_utils import get_aot_graph_name
from ._aot_autograd.utils import is_with_effects
from .compile_utils import fx_graph_cse, get_aten_target


if TYPE_CHECKING:
    import sympy


AOT_PARTITIONER_DEBUG = config.debug_partitioner
log = logging.getLogger(__name__)

aten = torch.ops.aten
prims = torch.ops.prims


@dataclass
class OpTypes:
    """Class for keeping track of different operator categories"""

    fusible_ops: Set[Callable]
    compute_intensive_ops: Set[Callable]
    random_ops: Set[Callable]
    view_ops: Set[Callable]
    recomputable_ops: Set[Callable]

    def is_fusible(self, node: fx.Node):
        return get_aten_target(node) in self.fusible_ops

    def is_compute_intensive(self, node: fx.Node):
        return get_aten_target(node) in self.compute_intensive_ops

    def is_random(self, node: fx.Node):
        return get_aten_target(node) in self.random_ops

    def is_view(self, node: fx.Node):
        return get_aten_target(node) in self.view_ops

    def is_recomputable(self, node: fx.Node):
        return get_aten_target(node) in self.recomputable_ops


@dataclass
class NodeInfo:
    # Be careful about iterating over these explicitly, as their order may not
    # be deterministic
    inputs: List[fx.Node]
    _required_fw_nodes: Set[fx.Node]
    required_bw_nodes: Set[fx.Node]
    unclaimed_nodes: Set[fx.Node]
    fw_order: Dict[fx.Node, int]

    @functools.cached_property
    def required_fw_nodes(self) -> List[fx.Node]:
        return sorted(
            (n for n in self._required_fw_nodes), key=lambda n: self.fw_order[n]
        )

    def is_required_fw(self, n: fx.Node) -> bool:
        return n in self._required_fw_nodes

    def is_required_bw(self, n: fx.Node) -> bool:
        return n in self.required_bw_nodes

    def is_unclaimed(self, n: fx.Node) -> bool:
        return n in self.unclaimed_nodes

    def get_fw_order(self, n: fx.Node) -> int:
        assert n in self._required_fw_nodes, f"Node {n} not in fw nodes!"
        return self.fw_order[n]


@dataclass
class MinCutOptions:
    ban_if_used_far_apart: bool
    ban_if_long_fusible_chains: bool
    ban_if_materialized_backward: bool
    ban_if_not_in_allowlist: bool
    ban_if_reduction: bool


def must_recompute(node: fx.Node) -> bool:
    return node.meta.get("recompute", None) in [
        CheckpointPolicy.MUST_RECOMPUTE,
        CheckpointPolicy.PREFER_RECOMPUTE,
    ]


def has_recomputable_ops(fx_g: fx.GraphModule) -> bool:
    found = False
    for node in fx_g.graph.nodes:
        if must_recompute(node):
            return True
    return False


def has_recomputable_rng_ops(fx_g: fx.GraphModule) -> bool:
    for node in fx_g.graph.nodes:
        if (
            must_recompute(node)
            and hasattr(node.target, "tags")
            and torch.Tag.nondeterministic_seeded in node.target.tags
        ):
            return True
    return False


def sym_node_size(node: fx.Node) -> int:
    if isinstance(node.meta["val"], (torch.SymInt, torch.SymBool)):
        return 1
    assert isinstance(node.meta["val"], torch.SymFloat)
    return 4


class InvalidNodeBase:
    def __repr__(self):
        return "Invalid Node"


InvalidNode = InvalidNodeBase()


def _extract_graph_with_inputs_outputs(
    joint_graph: fx.Graph,
    inputs: List[fx.Node],
    outputs: List[fx.Node],
    subgraph: Optional[str] = None,
) -> fx.Graph:
    """
    Given a graph, extracts out a subgraph that takes the specified nodes as
    inputs and returns the specified outputs.

    This includes specifying non-placeholder nodes as inputs.

    The general strategy is to initialize all inputs with proxies as we
    encounter them, and trace through the graph, only keeping values which take
    in valid proxies. Then, all dead code is eliminated.
    """
    new_graph = fx.Graph()
    env = {}

    # Add new placeholder nodes in the order specified by the inputs
    for node in inputs:
        new_node = new_graph.placeholder(node.name)
        # Can't use node_copy here as we may be turning previous call_function into placeholders
        new_node.meta = node.meta
        env[node] = new_node

    for node in joint_graph.nodes:
        if _must_be_in_backward(node) and subgraph != "backward":
            env[node] = InvalidNode  # type: ignore[assignment]
            continue

        if node in env:
            # Node must be one of our inputs. (Any member of env which wasn't an
            # input to start must have been created by this loop and won't be in
            # joint_graph.nodes).
            continue
        elif node.op == "placeholder":
            env[node] = InvalidNode  # type: ignore[assignment]
        elif node.op == "call_function":
            all_args = pytree.arg_tree_leaves(*node.args, **node.kwargs)
            all_args = [
                isinstance(env[x], InvalidNodeBase)
                for x in all_args
                if isinstance(x, fx.Node)
            ]
            if any(all_args):
                env[node] = InvalidNode  # type: ignore[assignment]
                continue
            env[node] = new_graph.node_copy(node, lambda x: env[x])
        elif node.op == "get_attr":
            env[node] = new_graph.node_copy(node, lambda x: env[x])
        elif node.op == "output":
            pass
    output_values = []
    for x in outputs:
        if isinstance(x, fx.Node):
            if x not in env:
                raise RuntimeError(f"Node {x} couldn't be found in env")
            assert not isinstance(
                env[x], InvalidNodeBase
            ), f"Node {x} was invalid, but is output"
            output_values.append(env[x])
        else:
            output_values.append(x)
    new_graph.output(tuple(output_values))

    new_graph.eliminate_dead_code()
    new_graph.lint()
    return new_graph
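
# Illustrative sketch (added for exposition, not part of the original module):
# given a graph traced from  f(a, b): c = a + b; d = c * b; return d,  asking
# for the intermediate `c` (plus `b`) as inputs and `d` as the output,
#
#     g = torch.fx.symbolic_trace(f).graph
#     add_node = next(n for n in g.nodes if n.name == "add")   # default fx names assumed
#     b_node = next(n for n in g.nodes if n.name == "b")
#     mul_node = next(n for n in g.nodes if n.name == "mul")
#     sub = _extract_graph_with_inputs_outputs(g, [add_node, b_node], [mul_node])
#
# produces a graph whose placeholders are `add` and `b` and whose only compute
# node is the `mul`; the original `a` placeholder and the `add` computation are
# dropped. This is how the partitioner slices the joint graph into halves.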


def _is_primal(node: fx.Node) -> bool:
    return (
        node.op == "placeholder"
        and "tangents" not in str(node.target)
        and not _is_bwd_seed_offset(node)
        and not _is_fwd_seed_offset(node)
    )


def _is_tangent(node: fx.Node) -> bool:
    return node.op == "placeholder" and "tangents" in str(node.target)


def _is_bwd_seed_offset(node: fx.Node) -> bool:
    return node.op == "placeholder" and (
        "bwd_seed" in str(node.target) or "bwd_base_offset" in str(node.target)
    )


def _is_fwd_seed_offset(node: fx.Node) -> bool:
    return node.op == "placeholder" and (
        "fwd_seed" in str(node.target) or "fwd_base_offset" in str(node.target)
    )


def _is_backward_state(node: fx.Node) -> bool:
    return node.op == "placeholder" and isinstance(node.meta.get("val"), BackwardState)


def _has_tag_is_backward(node: fx.Node) -> bool:
    return node.meta.get("partitioner_tag", None) == "is_backward"


def _has_tag_must_be_in_backward(node: fx.Node) -> bool:
    return node.meta.get("partitioner_tag", None) == "must_be_in_backward"


def _must_be_in_backward(node: fx.Node) -> bool:
    return _has_tag_must_be_in_backward(node) or (
        _has_tag_is_backward(node) and is_with_effects(node)
    )


def _extract_fwd_bwd_outputs(
    joint_module: fx.GraphModule, *, num_fwd_outputs
) -> Tuple[List[fx.Node], List[fx.Node]]:
    outputs = pytree.arg_tree_leaves(
        *(node.args for node in joint_module.graph.find_nodes(op="output"))
    )
    fwd_outputs = outputs[:num_fwd_outputs]
    bwd_outputs = outputs[num_fwd_outputs:]
    return fwd_outputs, bwd_outputs


def _remove_by_name(saved_values: List[fx.Node], name: str):
    for saved_value in saved_values:
        if saved_value.name == name:
            saved_values.remove(saved_value)
            break


def _extract_fwd_bwd_modules(
    joint_module: fx.GraphModule,
    saved_values: List[fx.Node],
    saved_sym_nodes: List[fx.Node],
    *,
    num_fwd_outputs: int,
) -> Tuple[fx.GraphModule, fx.GraphModule]:
    fwd_outputs, bwd_outputs = _extract_fwd_bwd_outputs(
        joint_module, num_fwd_outputs=num_fwd_outputs
    )
    placeholders = joint_module.graph.find_nodes(op="placeholder")
    primal_inputs = [*filter(_is_primal, placeholders)]
    tangent_inputs = [*filter(_is_tangent, placeholders)]
    fwd_seed_offset_inputs = [*filter(_is_fwd_seed_offset, placeholders)]
    bwd_seed_offset_inputs = [*filter(_is_bwd_seed_offset, placeholders)]
    backward_state_inputs = [*filter(_is_backward_state, placeholders)]

    bwd_graph = _extract_graph_with_inputs_outputs(
        joint_module.graph,
        saved_sym_nodes + saved_values + tangent_inputs + bwd_seed_offset_inputs,
        bwd_outputs,
        "backward",
    )

    for node in bwd_graph.find_nodes(op="placeholder"):
        # This is to filter out saved values that don't actually end up being used by the backwards pass
        if not node.users:
            _remove_by_name(saved_values, node.name)
            _remove_by_name(saved_sym_nodes, node.name)
        elif _is_backward_state(node):
            # BackwardState is saved directly
            _remove_by_name(saved_values, node.name)
            assert backward_state_inputs

    # Now that we have the finalized list of saved values, we need to ensure
    # we propagate all symbols which are referenced by backwards inputs.
    # These are not directly used in the graph but are required for downstream
    # sizevar assignment
    saved_symbols: Set[sympy.Symbol] = set()
    saved_sym_nodes_binding = []
    saved_sym_nodes_derived = []

    # Some symbols may already be bound in the directly saved_sym_nodes,
    # keep track of them so we don't re-bind them
    for node in saved_sym_nodes:
        symbol = is_symbol_binding_fx_node(node)
        if symbol:
            saved_symbols.add(symbol)
            saved_sym_nodes_binding.append(node)
        else:
            saved_sym_nodes_derived.append(node)

    # Now go through all of the prospective backward inputs and track any
    # other symbols we need to bind
    symbol_bindings = find_symbol_binding_fx_nodes(joint_module.graph)
    for node in itertools.chain(saved_sym_nodes_derived, saved_values, tangent_inputs):
        if "val" not in node.meta:
            continue
        new_symbols = free_symbols(node.meta["val"]) - saved_symbols
        # NB: Deterministic order please!
        for s in sorted(new_symbols, key=lambda s: s.name):
            # NB: For well formed graphs, the symbol should always be present,
            # but we also have ways to produce ill-formed graphs, e.g., direct
            # make_fx usages, so don't choke in this case
            if s not in symbol_bindings:
                continue
            saved_sym_nodes_binding.append(symbol_bindings[s])
        saved_symbols |= new_symbols

    # Update saved_sym_nodes that are now reordered to have all bindings at
    # front. This can also be used later on to figure out the position of saved
    # sym nodes in the output of fwd graph.
    saved_sym_nodes.clear()
    saved_sym_nodes.extend(saved_sym_nodes_binding + saved_sym_nodes_derived)

    # Now, we re-generate the fwd/bwd graphs.
    # NB: This might increase compilation time, but I doubt it matters
    fwd_graph = _extract_graph_with_inputs_outputs(
        joint_module.graph,
        primal_inputs + fwd_seed_offset_inputs,
        fwd_outputs + saved_values + saved_sym_nodes,
        "forward",
    )
    bwd_graph = _extract_graph_with_inputs_outputs(
        joint_module.graph,
        saved_sym_nodes
        + saved_values
        + tangent_inputs
        + bwd_seed_offset_inputs
        + backward_state_inputs,
        bwd_outputs,
        "backward",
    )

    fwd_module = fx._lazy_graph_module._make_graph_module(joint_module, fwd_graph)
    bwd_module = fx._lazy_graph_module._make_graph_module(joint_module, bwd_graph)
    return fwd_module, bwd_module


def default_partition(
    joint_module: fx.GraphModule, _joint_inputs, *, num_fwd_outputs
) -> Tuple[fx.GraphModule, fx.GraphModule]:
    """
    Partitions the :attr:`joint_module` in a manner that closely resembles the
    behavior observed in the original ``.forward()`` and ``.backward()`` of the
    callable, i.e., the resulting forward graph contains those operators that
    are executed in the original ``.forward()`` callable passed to
    :func:`aot_function`.

    The default partitioner collects the operators that are between the forward
    inputs and the forward outputs. This helps in finding the tensors which have
    to be stashed for the backward pass. These stashed tensors become the output
    of the generated forward graph. The remaining operators are then placed in
    the backward graph.

    .. warning::
        This API is experimental and likely to change.

    Args:
        joint_module(fx.GraphModule): The joint forward and backward graph. This
            is the result of AOT Autograd tracing.

    Returns:
        Returns the generated forward and backward Fx graph modules.
    """
    if has_recomputable_ops(joint_module):
        return min_cut_rematerialization_partition(
            joint_module, _joint_inputs, num_fwd_outputs=num_fwd_outputs
        )
    primal_inputs = list(filter(_is_primal, joint_module.graph.nodes))
    fwd_seed_offset_inputs = list(filter(_is_fwd_seed_offset, joint_module.graph.nodes))
    inputs = primal_inputs + fwd_seed_offset_inputs
    fwd_outputs, bwd_outputs = _extract_fwd_bwd_outputs(
        joint_module, num_fwd_outputs=num_fwd_outputs
    )
    forward_only_graph = _extract_graph_with_inputs_outputs(
        joint_module.graph, inputs, fwd_outputs, "forward"
    )
    forward_node_names = {
        node.name for node in forward_only_graph.nodes if node.op != "output"
    }
    saved_values = []
    saved_sym_nodes = []

    for node in joint_module.graph.nodes:
        if node.name not in forward_node_names:
            continue
        if is_sym_node(node):
            # Symints must be kept separate from tensors so that PythonFunction only calls
            # save_for_backward on tensors and stashes symints in autograd .ctx
            saved_sym_nodes.append(node)
        elif "tensor_meta" not in node.meta and node.op == "call_function":
            # Since we can't save tuple of tensor values, we need to flatten out what we're saving
            users = node.users
            assert all(user.target == operator.getitem for user in users)
            saved_values.extend(users)
        else:
            backward_usages = [
                n for n in node.users if n.name not in forward_node_names
            ]
            if "tensor_meta" in node.meta and all(
                is_sym_node(n) for n in backward_usages
            ):
                # If we have a tensor in the forward, where only its sizes/strides are needed in the backward,
                # and not the actual tensor data,
                # then it will be a lot cheaper to save only the sizes/strides, and not the actual tensor.
                #
                # Note that saving the tensor could also cause compilation problems:
                # If the user mutated an input in the forward and uses its sizes/strides in the backward,
                # then we would be obligated to clone the input before saving it to appease autograd.
                # (This is how we originally found this bug).
                saved_sym_nodes.extend(backward_usages)
            else:
                saved_values.append(node)
    saved_values = list(dict.fromkeys(saved_values).keys())
    saved_sym_nodes = list(dict.fromkeys(saved_sym_nodes).keys())

    return _extract_fwd_bwd_modules(
        joint_module,
        saved_values,
        saved_sym_nodes=saved_sym_nodes,
        num_fwd_outputs=num_fwd_outputs,
    )
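
# Illustrative usage sketch (added for exposition, not part of the original
# module). It assumes the `functorch.compile` shim that exposes `aot_function`
# and accepts these partitioners via `partition_fn`; the printing compilers
# below are placeholders.
def _example_default_partition_usage():
    from functorch.compile import aot_function

    def f(x, w):
        return (x @ w).relu().sum()

    def print_graph(gm, sample_inputs):
        # The partitioner decides which intermediates the forward graph must
        # return so the backward graph can consume them.
        gm.graph.print_tabular()
        return gm

    compiled = aot_function(
        f,
        fw_compiler=print_graph,
        bw_compiler=print_graph,
        partition_fn=default_partition,
    )
    out = compiled(torch.randn(4, 4, requires_grad=True), torch.randn(4, 4))
    out.backward()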


INT_INF = int(1e6)


def _tensor_nbytes(numel: int, dtype) -> int:
    return numel * dtype.itemsize


def _size_of(node: fx.Node) -> int:
    def object_nbytes(x) -> int:
        if not isinstance(x, torch.Tensor):
            return 0
        return _tensor_nbytes(hint_int(x.numel(), fallback=4096), x.dtype)

    if "val" in node.meta:
        val = node.meta["val"]
        if isinstance(val, py_sym_types):
            return 1
        # NB: The fallback values here are meaningless, maybe we should respect
        # torch._inductor.config.unbacked_symint_fallback (but this is a
        # layering violation)
        elif isinstance(val, (list, tuple)):
            return sum(object_nbytes(n) for n in val)
        elif isinstance(val, dict):
            return sum(object_nbytes(n) for _, n in val.items())
        elif isinstance(val, torch.Tensor):
            return object_nbytes(val)

        raise RuntimeError(f"Unknown metadata type {type(val)} on node {node}")
    if node.op == "get_attr":
        return 0
    raise RuntimeError(
        f"Node {node} didn't have `val` metadata; we should always have `val` metadata on the nodes."
    )


# Used for some investigative purposes
def _count_ops(graph: fx.Graph):
    from collections import defaultdict

    cnt: Dict[str, int] = defaultdict(int)
    for node in graph.nodes:
        if node.op == "call_function":
            cnt[node.target.__name__] += 1
    print(sorted(cnt.items(), key=lambda x: x[1], reverse=True))


@functools.lru_cache(None)
def pointwise_ops():
    ops = []
    for attr_name in dir(torch.ops.aten):
        opoverloadpacket = getattr(torch.ops.aten, attr_name)
        if not isinstance(opoverloadpacket, torch._ops.OpOverloadPacket):
            continue

        for overload in opoverloadpacket.overloads():
            op_overload = getattr(opoverloadpacket, overload)
            if torch.Tag.pointwise in op_overload.tags:
                # currently aot autograd uses packet not overload
                ops.append(opoverloadpacket)
                break

    return ops


def sort_depths(args, depth_map: Dict[fx.Node, int]) -> List[Tuple[fx.Node, int]]:
    arg_depths = {
        arg: depth_map[arg] for arg in args if isinstance(arg, torch.fx.node.Node)
    }
    return sorted(arg_depths.items(), key=lambda x: x[1], reverse=True)


def reordering_to_mimic_autograd_engine(gm: fx.GraphModule) -> fx.GraphModule:
    """
    This pass finds the first bwd node in the graph (by looking at users of
    tangents) and then reorders the graph by walking from this node to all the
    way to the end of the graph. At each op in this traversal, we insert this op
    in a new graph and try to bring only the relevant subgraph from the other
    non-bwd edges relevant for this op. This closely mimics the behavior of
    the autograd engine.

    Why is this pass required in the first place?

    This is an artifact of how partitioners work today. The starting point of
    partitioner is a joint graph, which is fwd and then bwd graph. In the case
    of checkpointing, we keep portions of fwd graph in their original place in
    the joint graph, while obtaining a bwd graph. As a result, the resulting bwd
    graph has copies of recomputed fwd subgraphs followed by the original bwd
    graph. If we run this naively, this leads to bad memory footprint, because
    the fwd subgraphs are live for way longer duration than necessary. This pass
    reorders the operations such that we prioritize the ops for the original bwd
    graph while only realizing those ops from the fwd graph that are necessary
    at any given point in the graph.
    """

    new_graph = fx.Graph()
    env: Dict[fx.Node, fx.Node] = {}

    # Add new placeholder nodes in the order specified by the inputs
    for node in gm.graph.find_nodes(op="placeholder"):
        env[node] = new_graph.node_copy(node, lambda x: env[x])

    order = {}
    for idx, node in enumerate(gm.graph.nodes):
        order[node] = idx

    def insert_node_in_graph(node):
        cur_nodes = [node]
        insertable_nodes = set()
        while len(cur_nodes) > 0:
            node = cur_nodes.pop()
            if node in insertable_nodes or node in env:
                continue
            insertable_nodes.add(node)

            # Bias traversal towards the nodes that have higher depth - prioritizes
            # critical path first.
            cur_nodes += node.all_input_nodes

        insertable_nodes = sorted(insertable_nodes, key=lambda n: order[n])
        for node in insertable_nodes:
            env[node] = new_graph.node_copy(node, lambda x: env[x])

    # Find first bwd node in the graph
    tangent_inputs = list(filter(_is_tangent, gm.graph.nodes))
    first_node_in_bwd = None
    minimum_order = math.inf
    for tangent in tangent_inputs:
        for user in tangent.users:
            if order[user] < minimum_order:
                minimum_order = order[user]
                first_node_in_bwd = user

    # If gradInp does not depend upon gradOut, we may not find any nodes in the "backwards pass"
    if first_node_in_bwd is None:
        return gm

    # Build the graph op-by-op by starting from the node all the way to the end
    for node in list(gm.graph.nodes)[order[first_node_in_bwd] :]:
        insert_node_in_graph(node)

    # The output node is already built by the traversal.
    new_gm = torch.fx.GraphModule(gm, new_graph)
    return new_gm


def functionalize_rng_ops(
    joint_module: fx.GraphModule,
    fw_module: fx.GraphModule,
    bw_module: fx.GraphModule,
    num_sym_nodes: int,
) -> Tuple[fx.GraphModule, fx.GraphModule]:
    # During user-driven activation checkpointing, we have to ensure that a rng
    # op in fwd yields the same output as the recomputed rng op in the bwd. To
    # do this, we use functionalize wrappers to wrap the random ops and share
    # rng state between the fwd and bwd graphs.

    # There are 3 main steps to do this
    # Step 1 - Construct a mapping of rng node between the fwd and its counterpart in bwd.
    # Step 2 - Modify the fwd pass such that
    #   1) Replace rand with run_and_save_rng_state wrapper
    #   2) Replace the users of the original op with the output[1] of this op.
    #   3) Collect all the rng_state - output[0] of each op, and make them
    #   output nodes. Special care needs to be taken here because fwd outputs
    #   has symints at the very end.
    # Step 3 - Modify the bwd pass such that
    #   1) Add the input nodes just before the tangents for the stashed rng states
    #   2) Replace rand with run_with_rng_state wrappers
    #   3) Use the stashed states as inputs to these ops

    # Unique id to generate name
    uid = itertools.count()

    def get_rng_ops(gmod):
        random_nodes = {}
        for node in gmod.graph.nodes:
            if (
                node.op == "call_function"
                and hasattr(node.target, "tags")
                and torch.Tag.nondeterministic_seeded in node.target.tags
            ):
                random_nodes[node.name] = node
        return random_nodes

    def get_device(node):
        """
        Check the example value of the node outputs to find the device type.
        """
        if "val" not in node.meta:
            return None

        candidates = node.meta["val"]
        if not isinstance(candidates, tuple):
            candidates = (candidates,)

        for candidate in candidates:
            if isinstance(candidate, torch.Tensor):
                if candidate.device.type == "cuda":
                    return "cuda"

        return "cpu"

    def get_sample_rng_state(device):
        if device == "cuda":
            return torch.cuda.get_rng_state()
        return torch.get_rng_state()

    # Step 1 - Construct a mapping of rng node between the fwd and its counterpart in bwd.
    joint_graph_rng_ops = get_rng_ops(joint_module)
    fw_graph_rng_ops = get_rng_ops(fw_module)
    bw_graph_rng_ops = get_rng_ops(bw_module)
    recomputable_rng_ops_map = {}
    for node in joint_module.graph.nodes:
        if (
            must_recompute(node)
            and hasattr(node.target, "tags")
            and torch.Tag.nondeterministic_seeded in node.target.tags
        ):
            base_node = joint_graph_rng_ops[node.name]
            fw_node = fw_graph_rng_ops[node.name]
            bw_node = bw_graph_rng_ops[node.name]
            recomputable_rng_ops_map[base_node] = {"fwd": fw_node, "bwd": bw_node}

    run_and_save_rng = torch._prims.rng_prims.run_and_save_rng_state
    run_with_rng_state = torch._prims.rng_prims.run_with_rng_state
    bw_tangent_start_node = None
    for node in bw_module.graph.find_nodes(op="placeholder"):
        if "tangent" in node.name:
            bw_tangent_start_node = node
            break
    if bw_tangent_start_node is None:
        raise RuntimeError(
            "Couldn't find tangent node in graph inputs. This is unexpected, please file a bug if you see this"
        )

    fw_rng_state_outputs = []
    for base_node, node_pair in recomputable_rng_ops_map.items():
        # Step 2 - Modify the fwd pass such that
        fw_node = node_pair["fwd"]
        bw_node = node_pair["bwd"]
        fw_graph = fw_module.graph
        with fw_graph.inserting_before(fw_node):
            functional_fw_node = fw_graph.create_node(
                "call_function",
                run_and_save_rng,
                args=(fw_node.target, *fw_node.args),
                kwargs=fw_node.kwargs,
            )
            state = fw_graph.create_node(
                "call_function",
                operator.getitem,
                args=(functional_fw_node, 0),
                kwargs={},
            )
            rng_output = fw_graph.create_node(
                "call_function",
                operator.getitem,
                args=(
                    functional_fw_node,
                    1,
                ),
                kwargs={},
            )
            fw_node.replace_all_uses_with(rng_output)
            fw_graph.erase_node(fw_node)
            fw_rng_state_outputs.append(state)

        # Step 3 - Modify the bwd pass such that
        bw_graph = bw_module.graph
        with bw_graph.inserting_before(bw_tangent_start_node):
            state_name = f"rng_state_output_{next(uid)}"
            bw_rng_state_node = bw_graph.placeholder(state_name)
            bw_rng_state_node.meta["val"] = get_sample_rng_state(get_device(fw_node))

        with bw_graph.inserting_before(bw_node):
            rng_output = bw_graph.create_node(
                "call_function",
                run_with_rng_state,
                args=(bw_rng_state_node, bw_node.target, *bw_node.args),
                kwargs=bw_node.kwargs,
            )

            bw_node.replace_all_uses_with(rng_output)
            bw_graph.erase_node(bw_node)

    # Add the rng states in the output of the fwd graph. AOT Autograd assumes
    # that symints are at the end of forward graph outputs. So, insert the new
    # rng states accordingly.
    fw_output_node = next(iter(fw_module.graph.find_nodes(op="output")))
    fw_outputs = fw_output_node.args[0]
    sym_node_start_idx = len(fw_outputs) - num_sym_nodes
    outputs = (
        fw_outputs[:sym_node_start_idx]
        + tuple(fw_rng_state_outputs)
        + fw_outputs[sym_node_start_idx:]
    )
    fw_module.graph.output(outputs)
    fw_module.graph.erase_node(fw_output_node)
    fw_module.recompile()
    bw_module.recompile()
    return fw_module, bw_module
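
# Illustrative before/after sketch for the rewrite above (added for exposition;
# op spellings abbreviated):
#
#   forward (before):   out = aten.rand.default([4], ...)
#   forward (after):    tmp = run_and_save_rng_state(aten.rand.default, [4], ...)
#                       state, out = tmp[0], tmp[1]
#                       # `state` is appended to the forward outputs, just
#                       # before the symint outputs.
#
#   backward (before):  out = aten.rand.default([4], ...)   # recomputed op
#   backward (after):   out = run_with_rng_state(state, aten.rand.default, [4], ...)
#                       # `state` arrives as a new placeholder inserted before
#                       # the tangent inputs, so both passes see the same randomness.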


def cleanup_recompute_tags(joint_module: fx.GraphModule) -> fx.GraphModule:
    """
    If there are two consecutive checkpointed blocks with no operator in
    between, we would still want to stash the tensor at the boundary of
    checkpointed blocks. The following pass makes the last output node
    non-recomputable to allow for that.
    """
    for node in joint_module.graph.nodes:
        if must_recompute(node):
            for user in node.users:
                if (
                    must_recompute(user)
                    and user.meta["ac_graph_id"] > node.meta["ac_graph_id"]
                ):
                    node.meta["recompute"] = CheckpointPolicy.MUST_SAVE
    return joint_module


def solve_min_cut(
    joint_graph: fx.Graph,
    node_info: NodeInfo,
    min_cut_options: MinCutOptions,
    dont_ban=None,
):
    if dont_ban is None:
        dont_ban = set()
    op_types = get_default_op_list()

    if AOT_PARTITIONER_DEBUG:
        joint_module_ops = {
            str(node.target._overloadpacket)
            for node in joint_graph.nodes
            if node.op == "call_function" and hasattr(node.target, "_overloadpacket")
        }
        ops_ignored = joint_module_ops - {str(i) for i in op_types.recomputable_ops}
        print("Ops banned from re-materialization: ", ops_ignored)
        print()

    def can_fuse_into_auto_functionalized(a, b):
        if b.target != torch.ops.higher_order.auto_functionalized:
            return False
        mutable_op = b.args[0]
        (
            mutable_arg_names,
            _,
        ) = torch._higher_order_ops.auto_functionalize.get_mutable_args(mutable_op)
        for name in mutable_arg_names:
            arg = b.kwargs[name]
            if a is arg:
                return True
            if isinstance(arg, list):
                if a in arg:
                    return True
        return False

    def can_fuse_into_triton_kernel_wrapper_functional(a, b):
        if b.target != torch.ops.higher_order.triton_kernel_wrapper_functional:
            return False
        mutable_arg_names = b.kwargs["tensors_to_clone"]
        for name in mutable_arg_names:
            arg = b.kwargs["kwargs"][name]
            if a is arg:
                return True
        return False

    def is_fusible(a, b):
        # We can perform "memory fusion" into a cat, but cat cannot be a
        # producer to a fusion
        if get_aten_target(b) == aten.cat:
            return True
        if can_fuse_into_auto_functionalized(a, b):
            return True
        if can_fuse_into_triton_kernel_wrapper_functional(a, b):
            return True
        return op_types.is_fusible(a) and op_types.is_fusible(b)

    try:
        import networkx as nx
    except ImportError as e:
        raise RuntimeError(
            "Need networkx installed to perform smart recomputation " "heuristics"
        ) from e

    def is_materialized_backwards(node):
        if op_types.is_view(node):
            return False
        cur_nodes = {node}
        while len(cur_nodes) > 0:
            cur = cur_nodes.pop()
            for user in cur.users:
                if not node_info.is_required_fw(user) and not is_fusible(cur, user):
                    return True
                if op_types.is_view(user):
                    cur_nodes.add(user)

        return False

    def should_ban_recomputation(node):
        if node.op != "call_function":
            return False
        if node.target == operator.getitem:
            return False
        if node.meta.get("recompute", None) == CheckpointPolicy.MUST_SAVE:
            return True
        if config.recompute_views and op_types.is_view(node):
            return False
        if node.target in [aten.lift_fresh_copy.default, aten.lift_fresh.default]:
            return False

        if min_cut_options.ban_if_not_in_allowlist:
            if not op_types.is_recomputable(node):
                return True
        else:
            if op_types.is_random(node) or op_types.is_compute_intensive(node):
                return True

        # If a node *must* be materialized in the backwards pass, then we
        # should never recompute it. This is a pretty subtle point. In
        # general, the assumption we make is that recomputing a node in the
        # backwards pass is "free". However, if a node must be materialized
        # in the backwards pass, then recomputing it is never free.
        if min_cut_options.ban_if_materialized_backward and is_materialized_backwards(
            node
        ):
            log.info("materialized backwards: %s %s", node, tuple(node.users))
            return True

        # Arbitrary hack that sometimes seems to help things. The above
        # modification appears to have made this heuristic a lot less critical
        # for performance.
        # NB: As of PR #121692, this hack no longer seems necessary.
        if node.dist_from_bw < 1000 and node.dist_from_bw > config.max_dist_from_bw:
            return True

        # If the output of an op is 4x smaller (arbitrary choice),
        # then we don't allow recomputation. The idea here is that for
        # things like reductions, saving the output of the reduction is very
        # cheap/small, and it makes sure we don't do things like recompute
        # normalizations in the backwards.
        if min_cut_options.ban_if_reduction:
            input_tensors_size = sum(
                _size_of(i) for i in node.args if isinstance(i, fx.Node)
            )
            output_size = _size_of(node)
            return output_size * 4 < input_tensors_size
        return False

    def is_materialized(node):
        if node.op == "placeholder":
            return True

        return not all(is_fusible(node, user) for user in node.users)

    def get_node_weight(node) -> float:
        mem_sz = _size_of(node)
        if config.recompute_views and op_types.is_view(node):
            # If `config.recompute_views=True`, we don't save views. This is generally
            # a good idea since views are free to recompute, and it makes it a bit simpler
            # to analyze.
            # NB: If they're not free to recompute (e.g. nested tensors)... I
            # think we should modify checks for view_ops to `is_view` and check
            # that. Basically, with nested tensors, `aten.view` is not a "view
            # op".
            return math.inf

        if isinstance(node.meta["val"], py_sym_types):
            # We never want to save symfloats
            if not isinstance(node.meta["val"], torch.SymInt):
                return INT_INF

        # Heuristic to bias towards nodes closer to the backwards pass
        # Complete guess about current value
        mem_sz = int(mem_sz * (1.1 ** max(min(node.dist_from_bw, 100), 1)))
        if is_materialized(node):
            return mem_sz
        else:
            return mem_sz * 2

    nx_graph = nx.DiGraph()
    banned_nodes = set()

    def ban_recomputation_if_allowed(node):
        if op_types.is_view(node):
            return False
        if node in dont_ban:
            return False
        # This bans recomputation of the node unless we've been forced not to by
        # user annotation
        if must_recompute(node):
            return False

        if "val" in node.meta and isinstance(node.meta["val"], torch.SymFloat):
            return False

        banned_nodes.add(node)
        # A node will only ever be recomputed if there is a path from an
        # ancestor of this node to the backwards path through this node that
        # doesn't go through any saved value. If this node is saved, then that
        # condition is not possible.
        nx_graph.add_edge("source", node.name + "_in", capacity=math.inf)
        return True

    for node in joint_graph.nodes:
        if node.op == "output":
            continue

        if node in node_info.required_bw_nodes:
            if node not in node_info.inputs:
                nx_graph.add_edge(node.name + "_in", "sink", capacity=math.inf)
                continue
            # If someone saves an input for backward as-is and backward
            # returns that tensor as-is as a grad input, then the node x would
            # be both a required_bw_node and an input. In this case we
            # (1) connect x_in to the source, (2) x_out to the sink, and
            # (3) assign the proper weight to the x_in-x_out edge, so that
            # x would be part of cut nodes. A case where this happens is if
            # NestedTensor saves a offset tensor as part of the singleton int
            # in sizes.
            nx_graph.add_edge(node.name + "_out", "sink", capacity=math.inf)

        if must_recompute(node):
            # If user explicitly says they want to recompute a node, we honor it
            # by adding an inf-capacity edge from X_in to the sink.
            # This way, X_in node is guaranteed to be part of the subgraph that contains "sink"
            # after the cut, thus guaranteeing that X op will be recomputed.
            nx_graph.add_edge(node.name + "_in", "sink", capacity=math.inf)
            continue

        if _is_primal(node) or _is_fwd_seed_offset(node):
            ban_recomputation_if_allowed(node)

        # If a node can't be recomputed (too expensive or involves randomness),
        # we prevent it from being recomputed by adding an inf edge to the source
        # We only need to ban nodes in the fw pass, as those are the only ones that would be recomputed.
        if node_info.is_required_fw(node) and should_ban_recomputation(node):
            ban_recomputation_if_allowed(node)

        # Checks if a node is actually a tuple. Can be simplified to just an isinstance check if we always use faketensors.
        is_non_tensor_node = (
            "val" not in node.meta and "tensor_meta" not in node.meta
        ) or ("val" in node.meta and not isinstance(node.meta["val"], torch.Tensor))

        if is_sym_node(node):
            weight = float(sym_node_size(node))
        elif is_non_tensor_node:
            weight = (
                0.0 if isinstance(node.meta.get("val"), BackwardState) else math.inf
            )
        else:
            weight = get_node_weight(node)
        # Creates the weights on the "node" edge
        nx_graph.add_edge(node.name + "_in", node.name + "_out", capacity=weight)
        for user in node.users:
            nx_graph.add_edge(node.name + "_out", user.name + "_in", capacity=math.inf)

    # todo(chilli): This is the most questionable of the 3 heuristics for banning recompute.
    # Some example models to look at where this helps perf: poolformer_m36,
    # mixer_b16_224, cait_m36_384

    # The "rough" idea here is that if you have some node that is used by both a
    # node nearby downstream as well as a node far downstream, if we recompute
    # both of the downstream nodes, we're unlikely to be able to fuse both
    # downstream nodes together.

    # Thus, we shouldn't aim to recompute far downstream nodes that depend on
    # this node. That intuition of "far downstream" is captured by whether
    # there's an unfusible op along the chain somewhere

    # It could probably be improved by properly analyzing what's going on in the
    # backwards pass instead of only relying on whether it's unfusible in the
    # forwards.

    def find_first_unfusible(start_nodes: List[fx.Node], max_range: int) -> int:
        """
        Finds the first unfusible node in the chain of nodes starting from
        `start_nodes` and returns its position.
        """
        sorted_nodes: List[Tuple[int, fx.Node, bool]] = []
        for n in start_nodes:
            heapq.heappush(sorted_nodes, (node_info.get_fw_order(n), n, True))

        while len(sorted_nodes) > 0:
            _, node, node_is_fusible = heapq.heappop(sorted_nodes)
            if not node_is_fusible:
                return node_info.get_fw_order(node)
            for user in node.users:
                if node_info.is_required_fw(user):
                    if node_info.get_fw_order(user) > max_range:
                        continue
                    heapq.heappush(
                        sorted_nodes,
                        (node_info.get_fw_order(user), user, is_fusible(node, user)),
                    )
        return max_range

    if min_cut_options.ban_if_used_far_apart:
        for used_node in node_info.required_fw_nodes:
            orders = [
                node_info.get_fw_order(user)
                for user in used_node.users
                if node_info.is_required_fw(user)
            ]
            fw_users = [
                user for user in used_node.users if node_info.is_required_fw(user)
            ]
            if len(orders) > 0:
                first_unfusible_use = find_first_unfusible(fw_users, max(orders))
                for user in tuple(used_node.users):
                    if (
                        node_info.is_required_fw(user)
                        and node_info.get_fw_order(user) > first_unfusible_use
                        and is_fusible(used_node, user)
                    ):
                        if user in banned_nodes:
                            continue
                        log.info(
                            "used above/below fusible %s:(%s) -> %s -> %s:(%s)",
                            used_node,
                            node_info.get_fw_order(used_node),
                            first_unfusible_use,
                            user,
                            node_info.get_fw_order(user),
                        )
                        ban_recomputation_if_allowed(user)

    # This heuristic is fairly straightforward. The idea is that although it is
    # cheap to recompute bandwidth-bound ops, we don't want to end up in a situation
    # where we have a long chain of pointwise ops from the beginning to the end
    # of the model (like say, residual connections)

    # todo: I'm not totally sure why this heuristic matters. It's possible that this is
    # working around Inductor fusion decisions, or that it's a patch over
    # suboptimal partitioning decisions

    # Some models it improves perf on are cait_m36_384, mixer_b16_224, poolformer_m36

    if min_cut_options.ban_if_long_fusible_chains:
        visited = set()
        for start_node in joint_graph.nodes:
            if not node_info.is_required_fw(start_node):
                continue
            fusible = [(node_info.get_fw_order(start_node), start_node)]
            start_order = node_info.get_fw_order(start_node)
            while len(fusible) > 0:
                _, cur = heapq.heappop(fusible)
                if cur in visited:
                    continue
                visited.add(cur)
                # 100 is arbitrary choice to try and prevent degenerate cases
                if (
                    node_info.get_fw_order(cur) > start_order + 100
                    and len(fusible) == 0
                ):
                    log.info(
                        "too long %s %s %s %s",
                        cur,
                        start_node,
                        node_info.get_fw_order(cur),
                        node_info.get_fw_order(start_node),
                    )
                    ban_recomputation_if_allowed(cur)
                    break

                for user in cur.users:
                    if (
                        node_info.is_required_fw(user)
                        and is_fusible(cur, user)
                        and user not in banned_nodes
                    ):
                        heapq.heappush(fusible, (node_info.get_fw_order(user), user))

    try:
        cut_value, partition = nx.minimum_cut(nx_graph, "source", "sink")
    except Exception:
        print("Failed to compute min-cut on following graph:")
        print("\n".join(nx.readwrite.edgelist.generate_edgelist(nx_graph)))
        visualize_min_cut_graph(nx_graph)
        raise

    reachable, non_reachable = partition
    cutset: Set[Tuple[str, str]] = set()
    for u, nbrs in ((n, nx_graph[n]) for n in reachable):
        cutset.update((u, v) for v in nbrs if v in non_reachable)

    cut_nodes = set()
    for node_in, node_out in cutset:
        assert node_in[:-3] == node_out[:-4]
        node_name = node_in[:-3]
        cut_nodes.add(node_name)

    name_to_node = get_name_to_node(joint_graph)
    # To make this stuff deterministic
    node_idx = {node: idx for idx, node in enumerate(joint_graph.nodes)}
    saved_values = sorted(
        (name_to_node[node] for node in cut_nodes), key=lambda x: node_idx[x]
    )
    return saved_values, banned_nodes
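
# Illustrative note on the flow network built above (added for exposition):
# every fx node X is split into an X_in -> X_out edge whose capacity is the
# cost of saving X, while edges between different nodes get infinite capacity:
#
#   source --inf--> banned_in --...--> X_in --weight--> X_out --inf--> user_in --...--> sink
#
# A finite min-cut can therefore only sever X_in -> X_out edges, and the nodes
# whose internal edge is cut are exactly the activations chosen to be saved
# for the backward pass.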


def visualize_min_cut_graph(nx_graph):
    import networkx as nx
    import pydot

    dot_format = nx.nx_pydot.to_pydot(nx_graph).to_string()
    dot_graph = pydot.graph_from_dot_data(dot_format)[0]
    for edge in dot_graph.get_edges():
        weight = nx_graph[edge.get_source()][edge.get_destination()]["capacity"]
        # Set edge label to weight
        edge.set_label(str(weight))
        # Color edges with weight 'inf' as red
        if weight == float("inf"):
            edge.set_color("red")
    print("Visualizing the failed graph to min_cut_failed.svg")
    dot_graph.write_svg("min_cut_failed.svg")


def get_default_op_list() -> OpTypes:
    default_recomputable_ops: List[Callable] = [
        aten.add,
        aten.sub,
        aten.div,
        aten.atan2,
        aten.mul,
        aten.max,
        aten.min,
        aten.pow,
        aten.remainder,
        aten.fmod,
        aten.__and__,
        aten.__or__,
        aten.__xor__,
        aten.__lshift__,
        aten.__rshift__,
        aten.eq,
        aten.ne,
        aten.ge,
        aten.gt,
        aten.le,
        aten.lt,
        aten.abs,
        aten.bitwise_not,
        aten.ceil,
        aten.floor,
        aten.frac,
        aten.neg,
        aten.relu,
        aten.round,
        aten.silu,
        aten.trunc,
        aten.log,
        aten.log10,
        aten.log1p,
        aten.log2,
        aten.lgamma,
        aten.exp,
        aten.expm1,
        aten.erf,
        aten.erfc,
        aten.cos,
        aten.acos,
        aten.cosh,
        aten.sin,
        aten.asin,
        aten.sinh,
        aten.tan,
        aten.atan,
        aten.tanh,
        aten.atanh,
        aten.sqrt,
        aten.rsqrt,
        aten.reciprocal,
        aten.sigmoid,
        aten.softplus,
        aten.threshold,
        aten.threshold_backward,
        aten.clamp,
        aten.where,
        aten.lerp,
        aten.addcmul,
        aten.gelu,
        aten.gelu_backward,
        aten.sum,
        aten.mean,
        aten._grad_sum_to_size,
        aten.sum_to_size,
        aten.amax,
        aten.to,
        aten.type_as,
        operator.getitem,
        aten.squeeze,
        aten.unsqueeze,
        aten.rsub,
        aten._to_copy,
    ]  # noqa: E501,B950
    recomputable_view_ops = [aten.squeeze, aten.unsqueeze, aten.alias]
    recomputable_view_ops += [
        aten.view,
        aten.slice,
        aten.t,
        prims.broadcast_in_dim,
        aten.expand,
        aten.as_strided,
        aten.permute,
        aten.select,
    ]
    view_ops = recomputable_view_ops
    default_recomputable_ops += [
        prims.div,
        prims.convert_element_type,
        aten.clone,
        aten._to_copy,
        aten.full_like,
        prims.var,
        prims.sum,
        aten.var,
        aten.std,
        prims.broadcast_in_dim,
        aten.select,
        aten._unsafe_view,
        aten.view,
        aten.expand,
        aten.slice,
        aten.reshape,
        aten.broadcast_tensors,
        aten.scalar_tensor,
        aten.ones,
        aten.new_zeros,
        aten.lift_fresh_copy,
        aten.arange,
        aten.triu,
        aten.var_mean,
        aten.isinf,
        aten.any,
        aten.full,
        aten.as_strided,
        aten.zeros,
        aten.empty,
        aten.empty_like,
        aten.argmax,
        aten.maximum,
        prims.iota,
        prims._low_memory_max_pool2d_offsets_to_indices,
    ]  # noqa: E501,B950
    # Natalia said that we should allow recomputing indexing :)
    default_recomputable_ops += [aten.index, aten.gather]
    default_recomputable_ops += view_ops

    default_recomputable_ops += pointwise_ops()

    default_recomputable_ops += [
        aten.zeros_like,
    ]

    default_recomputable_ops += [method_to_operator(m) for m in magic_methods]
    recomputable_ops = set(default_recomputable_ops)

    random_ops = [aten.native_dropout, aten.rand_like, aten.randn_like]
    compute_intensive_ops = [
        aten.mm,
        aten.convolution,
        aten.convolution_backward,
        aten.bmm,
        aten.addmm,
        aten._scaled_dot_product_flash_attention,
        aten._scaled_dot_product_efficient_attention,
        aten._flash_attention_forward,
        aten._efficient_attention_forward,
        aten.upsample_bilinear2d,
        aten._scaled_mm,
    ]  # noqa: E501,B950

    fusible_ops = recomputable_ops | set(random_ops)
    return OpTypes(
        set(fusible_ops),
        set(compute_intensive_ops),
        set(random_ops),
        set(view_ops),
        set(recomputable_ops),
    )


def get_name_to_node(graph: fx.Graph):
    name_to_node = {}
    for node in graph.nodes:
        name_to_node[node.name] = node
    return name_to_node


def greedy_knapsack(
    memory: List[float], runtimes: List[float], max_memory: float
) -> Tuple[float, List[int], List[int]]:
    n = len(runtimes)
    items = list(range(n))

    # Sort items based on the ratio of runtime to memory in descending order
    items = sorted(items, key=lambda i: runtimes[i] / memory[i], reverse=True)

    total_memory = 0.0
    total_runtime = 0.0
    items_to_save = []
    items_to_allow_recomputing = []

    for i in items:
        if total_memory + memory[i] <= max_memory:
            total_memory += memory[i]
            total_runtime += runtimes[i]
            items_to_save.append(i)
        else:
            items_to_allow_recomputing.append(i)
    return total_runtime, items_to_save, items_to_allow_recomputing
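
# Illustrative sketch (added for exposition, not part of the original module):
# a tiny run of the greedy solver above. With a budget of 1.0 it keeps the
# items with the best runtime-per-byte ratio and marks the rest recomputable.
def _example_greedy_knapsack_usage():
    memory = [0.6, 0.5, 0.2]    # normalized sizes of candidate activations
    runtimes = [3.0, 1.0, 1.5]  # estimated cost to recompute each one
    total_runtime, saved, recompute = greedy_knapsack(memory, runtimes, max_memory=1.0)
    # ratios are [5.0, 2.0, 7.5], so items are visited in order [2, 0, 1];
    # items 2 and 0 fit (0.2 + 0.6 <= 1.0), item 1 would overflow the budget.
    assert saved == [2, 0] and recompute == [1]
    assert total_runtime == 4.5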


def ilp_knapsack(
    memory: List[float], runtimes: List[float], max_memory: float
) -> Tuple[float, List[int], List[int]]:
    import numpy as np

    try:
        from scipy.optimize import Bounds, LinearConstraint, milp
    except ImportError:
        raise RuntimeError(
            "To use the ILP for memory budget checkpointing you need to install scipy"
        ) from None

    np_memory = np.array(memory)
    np_runtimes = np.array(runtimes)
    c = -np_runtimes  # type: ignore[operator]

    memory_constraint = LinearConstraint(A=np_memory, ub=np.array(max_memory))
    constraints = [memory_constraint]

    integrality = np.ones_like(c)
    res = milp(
        c=c, constraints=constraints, integrality=integrality, bounds=Bounds(0, 1)
    )
    if not res.success:
        raise RuntimeError("Somehow scipy solving failed")

    items_to_save = []
    items_to_allow_recomputing = []
    for idx, i in enumerate(res.x):
        if i == 1:
            items_to_save.append(idx)
        else:
            items_to_allow_recomputing.append(idx)
    return -res.fun, items_to_save, items_to_allow_recomputing


def dp_knapsack(
    memory: List[float], runtimes: List[float], max_memory: float
) -> Tuple[float, List[int], List[int]]:
    # Scaling factor to convert floating point weights to integers
    S = 10000

    # Quantize the memory weights
    quantized_memory = torch.tensor(
        [int(round(m * S)) for m in memory], dtype=torch.long, device="cpu"
    )
    runtimes = torch.tensor(runtimes, dtype=torch.float32, device="cpu")

    # Quantized pseudopolynomial DP for 0-1 Knapsack
    quantized_max_memory = int(round(max_memory * S))

    n = len(memory)

    # Initialize the DP table
    # TODO(chilli): I think if needed, this memory can be optimized with sliding
    # window trick + Hirschberg trick:
    # https://codeforces.com/blog/entry/47247?#comment-316200
    dp = torch.zeros(
        (n + 1, quantized_max_memory + 1), dtype=torch.float32, device="cpu"
    )

    for i in range(1, n + 1):
        current_memory = quantized_memory[i - 1]
        current_runtime = runtimes[i - 1]

        # Copy the previous row
        dp[i, :] = dp[i - 1, :]

        # Update dp[i, j] for all j >= current_memory
        if current_memory == 0:
            dp[i, :] = dp[i - 1, :] + current_runtime
        else:
            dp[i, current_memory:] = torch.maximum(
                dp[i - 1, current_memory:],
                dp[i - 1, :-current_memory] + current_runtime,
            )

    # Backtrack to find the items included in the knapsack
    saved_items = []
    recomputable_items = []
    j: int = quantized_max_memory
    for i in range(n, 0, -1):
        if dp[i][j] != dp[i - 1][j]:
            saved_items.append(i - 1)  # Include this item (indexing from 0)
            j -= int(quantized_memory[i - 1].item())
        else:
            recomputable_items.append(i - 1)

    saved_items.reverse()  # To get items in the order they were added

    # The maximum runtime that can be achieved within the max_memory constraint
    max_runtime = dp[n][quantized_max_memory].item()

    return max_runtime, saved_items, recomputable_items
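
# Illustrative note on the quantization above (added for exposition): with
# S = 10000, a memory list like [0.15, 0.049999] becomes the integer weights
# [1500, 500], and max_memory = 0.2 becomes 2000, so the DP table has shape
# (n + 1, 2001) and the backtracking walks it in integer units of 1/S.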
|
|
|
|
|
|
def _optimize_runtime_with_given_memory(
|
|
memory: List[float],
|
|
runtimes: List[float],
|
|
max_memory: float,
|
|
) -> Tuple[float, List[int], List[int]]:
|
|
SOLVER = config.activation_memory_budget_solver
|
|
if SOLVER == "greedy":
|
|
return greedy_knapsack(memory, runtimes, max_memory)
|
|
elif SOLVER == "ilp":
|
|
return ilp_knapsack(memory, runtimes, max_memory)
|
|
elif SOLVER == "dp":
|
|
return dp_knapsack(memory, runtimes, max_memory)
|
|
else:
|
|
raise RuntimeError(f"Not aware of memory budget knapsack solver: {SOLVER}")
|
|
|
|
|
|
from torch.utils._mode_utils import no_dispatch
|
|
|
|
|
|
def estimate_runtime(node):
|
|
RUNTIME_MODE = config.activation_memory_budget_runtime_estimator
|
|
|
|
def materialize_arg(x):
|
|
if isinstance(x, fx.Node) and isinstance(x.meta["val"], torch.Tensor):
|
|
shape = list(x.meta["val"].shape)
|
|
|
|
def realize_symbol(d):
|
|
return hint_int(d, fallback=4096)
|
|
|
|
shape = [realize_symbol(s) for s in shape]
|
|
return x.meta["val"].new_empty_strided(
|
|
shape, stride=x.meta["tensor_meta"].stride
|
|
)
|
|
elif isinstance(x, fx.Node) and isinstance(x.meta["val"], torch.SymInt):
|
|
return hint_int(x.meta["val"], fallback=4096)
|
|
elif isinstance(x, fx.Node) and isinstance(x.meta["val"], torch.SymFloat):
|
|
return 1.0
|
|
elif isinstance(x, fx.Node) and isinstance(x.meta["val"], torch.SymBool):
|
|
return True
|
|
else:
|
|
return x
|
|
|
|
if RUNTIME_MODE == "testing":
|
|
return 1
|
|
|
|
elif RUNTIME_MODE == "profile":
|
|
with no_dispatch():
|
|
from torch._inductor.runtime.benchmarking import benchmarker
|
|
|
|
args, kwargs = pytree.tree_map(materialize_arg, (node.args, node.kwargs))
|
|
ms = benchmarker.benchmark_gpu(lambda: node.target(*args, **kwargs))
|
|
return ms
|
|
|
|
elif RUNTIME_MODE == "flops":
|
|
# todo(chilli): Normalize this to also return ms
|
|
from torch.utils.flop_counter import FlopCounterMode
|
|
|
|
args, kwargs = pytree.tree_map(materialize_arg, (node.args, node.kwargs))
|
|
with FlopCounterMode(display=False) as mode:
|
|
node.target(*args, **kwargs)
|
|
counted_flops = mode.get_total_flops()
|
|
return max(counted_flops, 1)
|
|
else:
|
|
raise RuntimeError(f"Not aware of runtime estimator: {RUNTIME_MODE}")
|
|
|
|
|
|
def choose_saved_values_set(
|
|
joint_graph: fx.Graph, node_info: NodeInfo, memory_budget=1
|
|
) -> List[fx.Node]:
|
|
if memory_budget > 1 or memory_budget < 0:
|
|
raise RuntimeError(
|
|
f"The valid ranges for memory budget are 0 <= m <= 1. The provided value is {memory_budget}"
|
|
)
|
|
min_cut_options = MinCutOptions(
|
|
ban_if_used_far_apart=config.ban_recompute_used_far_apart,
|
|
ban_if_long_fusible_chains=config.ban_recompute_long_fusible_chains,
|
|
ban_if_materialized_backward=config.ban_recompute_materialized_backward,
|
|
ban_if_not_in_allowlist=config.ban_recompute_not_in_allowlist,
|
|
ban_if_reduction=config.ban_recompute_reductions,
|
|
)
|
|
|
|
if config.aggressive_recomputation:
|
|
min_cut_options = replace(
|
|
min_cut_options,
|
|
ban_if_used_far_apart=False,
|
|
ban_if_long_fusible_chains=False,
|
|
ban_if_materialized_backward=False,
|
|
ban_if_not_in_allowlist=False,
|
|
)
|
|
if memory_budget == 0:
|
|
return node_info.inputs
|
|
|
|
runtime_optimized_saved_values, _ = solve_min_cut(
|
|
joint_graph,
|
|
node_info,
|
|
min_cut_options,
|
|
)
|
|
# return runtime_optimized_saved_values
|
|
if memory_budget == 1:
|
|
return runtime_optimized_saved_values
|
|
|
|
def estimate_activations_size(saved_values: List[fx.Node]) -> float:
|
|
return sum(map(_size_of, saved_values)) / 1e9
|
|
|
|
min_act_size = estimate_activations_size(node_info.inputs)
|
|
max_act_size = estimate_activations_size(runtime_optimized_saved_values)
|
|
# The optimized choice is smaller than the inputs anyways
|
|
if max_act_size <= min_act_size:
|
|
return runtime_optimized_saved_values
|
|
|
|
def get_normalized_size(sz):
|
|
return (sz / 1e9) / (max_act_size - min_act_size)
|
|
|
|
def get_mem_ratio(activations: List[fx.Node]):
|
|
return (estimate_activations_size(activations) - min_act_size) / (
|
|
max_act_size - min_act_size
|
|
)
|
|
|
|
    more_aggressive_options = replace(
        min_cut_options,
        ban_if_used_far_apart=False,
        ban_if_long_fusible_chains=False,
        ban_if_materialized_backward=False,
    )
    more_aggressive_saved_values, _ = solve_min_cut(
        joint_graph, node_info, more_aggressive_options
    )
    if get_mem_ratio(more_aggressive_saved_values) < memory_budget:
        return more_aggressive_saved_values

    aggressive_options = replace(
        more_aggressive_options,
        ban_if_not_in_allowlist=False,
    )
    aggressive_recomputation_saved_values, banned_nodes = solve_min_cut(
        joint_graph, node_info, aggressive_options
    )

    if get_mem_ratio(aggressive_recomputation_saved_values) < memory_budget:
        return aggressive_recomputation_saved_values

    from torch._inductor.fx_utils import get_node_storage

    input_storages = {get_node_storage(node) for node in node_info.inputs}

    def get_recomputable_banned_nodes(banned_nodes: List[fx.Node]) -> List[fx.Node]:
        return [
            i
            for i in banned_nodes
            if (
                # Only allow recomputing nodes that are actually required for BW
                i.dist_from_bw < int(1e9)  # type: ignore[attr-defined]
                and get_node_storage(i) not in input_storages
            )
        ]

    recomputable_banned_nodes = get_recomputable_banned_nodes(banned_nodes)

    # default: runtime_optimized_saved_values
    # more aggressive: more_aggressive_saved_values
    # full aggressive: aggressive_recomputation_saved_values

    all_recomputable_banned_nodes = sorted(
        recomputable_banned_nodes, key=_size_of, reverse=True
    )
    if len(all_recomputable_banned_nodes) == 0:
        return node_info.inputs
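    # Per-node costs for the knapsack: the normalized memory cost of keeping the
    # node saved, and the estimated runtime cost of recomputing it instead.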
    memories_banned_nodes = [
        get_normalized_size(_size_of(i)) for i in all_recomputable_banned_nodes
    ]
    runtimes_banned_nodes = [
        estimate_runtime(node) for node in all_recomputable_banned_nodes
    ]
    from torch.utils._mode_utils import no_dispatch

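    # Knapsack step: under the normalized memory budget, decide which of the
    # banned nodes stay saved and which may be recomputed so as to minimize the
    # total recomputation runtime, then re-solve the min-cut with the
    # recomputable ones removed from the ban list.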
    def get_saved_values_knapsack(memory_budget):
        with no_dispatch():
            (
                expected_runtime,
                saved_node_idxs,
                recomputable_node_idxs,
            ) = _optimize_runtime_with_given_memory(
                memories_banned_nodes, runtimes_banned_nodes, max(memory_budget, 0)
            )
        dont_ban = set()
        for idx in recomputable_node_idxs:
            dont_ban.add(all_recomputable_banned_nodes[idx])
        assert dont_ban.issubset(all_recomputable_banned_nodes)

        saved_values, _ = solve_min_cut(
            joint_graph,
            node_info,
            aggressive_options,
            dont_ban,
        )
        return saved_values, expected_runtime

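    # Optionally sweep the budget from 1.0 down to 0.0 in steps of 0.05 and plot
    # recomputation runtime against the achieved memory ratio, to visualize the
    # memory/runtime trade-off for this graph.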
    if config.visualize_memory_budget_pareto:
        options = []
        for sweep_memory_budget in range(100, -1, -5):
            saved_values, expected_runtime = get_saved_values_knapsack(
                sweep_memory_budget / 100
            )
            options.append(
                (
                    sweep_memory_budget,
                    sum(runtimes_banned_nodes) - expected_runtime,
                    get_mem_ratio(saved_values),
                )
            )

        import matplotlib.pyplot as plt

        x_values = [item[2] for item in options]
        y_values = [item[1] for item in options]

        # Plotting the values with updated axis labels and chart title
        plt.figure(figsize=(10, 6))
        plt.plot(x_values, y_values, marker="o")

        # Adding labels for each point
        for i, txt in enumerate(x_values):
            plt.annotate(
                f"{txt:.2f}",
                (txt, y_values[i]),
                textcoords="offset points",
                xytext=(0, 10),
                ha="center",
            )

        plt.xlabel("Memory Budget")
        plt.ylabel("Runtime of Recomputed Components")
        plt.title("Pareto Frontier of Memory Budget vs. Recomputation Runtime")
        plt.grid(True)
        fig = plt.gcf()
        plt.show()
        fig_name = f"memory_budget_pareto_{get_aot_graph_name()}.png"
        fig.savefig(fig_name)
        log.warning("Generated Pareto frontier curve at %s", fig_name)

    # todo(chilli): The estimate doesn't align exactly with the actual memory -
    # actual is usually lower than estimated. I'm guessing (though quite unsure)
    # that's because the estimate only includes tensors we explicitly banned from
    # recompute, but there may be other tensors that we choose to save.

    return get_saved_values_knapsack(memory_budget=memory_budget)[0]


def min_cut_rematerialization_partition(
    joint_module: fx.GraphModule,
    _joint_inputs,
    compiler="inductor",
    *,
    num_fwd_outputs,
) -> Tuple[fx.GraphModule, fx.GraphModule]:
    """
    Partitions the joint graph such that the backward recomputes the forward.
    Recomputing helps in trading off memory bandwidth with computation.

    To create the fwd and bwd graphs, we copy the joint graph, manually set the
    outputs to just the original forward or backward outputs, and then run the
    resulting graphs through dead code elimination.

    .. warning::
        This API is experimental and likely to change.

    Args:
        joint_module(fx.GraphModule): The joint forward and backward graph. This
            is the result of AOT Autograd tracing.
        _joint_inputs: The inputs to the joint graph. This is unused.
        compiler: This option determines the default set of recomputable ops.
            Currently, there are two options: ``nvfuser`` and ``inductor``.
        num_fwd_outputs: The number of outputs from the forward graph.

    Returns:
        Returns the generated forward and backward Fx graph modules.
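
    Example (illustrative; in practice AOT Autograd invokes this as its
    ``partition_fn`` rather than it being called directly)::

        fw_module, bw_module = min_cut_rematerialization_partition(
            joint_module, joint_inputs, num_fwd_outputs=num_fwd_outputs
        )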
"""
|
|
|
|
joint_module.graph.eliminate_dead_code()
|
|
joint_module.recompile()
|
|
|
|
fx_g = joint_module.graph
|
|
|
|
# add the CSE pass
|
|
if config.cse:
|
|
cse_graph = fx_graph_cse(fx_g)
|
|
joint_module.graph = cse_graph
|
|
joint_graph = joint_module.graph
|
|
|
|
graph_has_recomputable_ops = has_recomputable_ops(joint_module)
|
|
graph_has_recomputable_rng_ops = has_recomputable_rng_ops(joint_module)
|
|
if graph_has_recomputable_ops:
|
|
joint_module = cleanup_recompute_tags(joint_module)
|
|
|
|
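    # Classify every joint-graph node: required_bw_nodes are the tangent inputs,
    # anything that must live in the backward, and their users; required_fw_nodes
    # are the nodes reachable in a forward-only extraction of the graph;
    # everything else is "unclaimed" and free to be placed by the partitioner.
    # fw_order records each forward node's position in the original ordering.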
    def classify_nodes(joint_module):
        name_to_node = get_name_to_node(joint_module.graph)
        required_bw_nodes = set()
        for node in joint_module.graph.nodes:
            if node.op == "placeholder" and "tangents" in node.target:
                required_bw_nodes.add(node)
            elif _must_be_in_backward(node):
                required_bw_nodes.add(node)

            if node in required_bw_nodes:
                for user in node.users:
                    required_bw_nodes.add(user)

        primal_inputs = list(filter(_is_primal, joint_module.graph.nodes))
        fwd_seed_offset_inputs = list(
            filter(_is_fwd_seed_offset, joint_module.graph.nodes)
        )
        inputs = primal_inputs + fwd_seed_offset_inputs
        fwd_outputs, bwd_outputs = _extract_fwd_bwd_outputs(
            joint_module, num_fwd_outputs=num_fwd_outputs
        )
        required_bw_nodes.update(
            o for o in bwd_outputs if o is not None and o.op != "output"
        )
        forward_only_graph = _extract_graph_with_inputs_outputs(
            joint_module.graph, inputs, fwd_outputs, "forward"
        )
        required_fw_nodes: Set[fx.Node] = {
            name_to_node[node.name]
            for node in forward_only_graph.nodes
            if node.op != "output"
        }
        unclaimed_nodes = {
            node
            for node in joint_module.graph.nodes
            if node not in required_fw_nodes and node not in required_bw_nodes
        }
        fw_cnt = 0
        fw_order = {}
        for node in joint_module.graph.nodes:
            if node in required_fw_nodes:
                fw_order[node] = fw_cnt
                fw_cnt += 1
        return NodeInfo(
            inputs, required_fw_nodes, required_bw_nodes, unclaimed_nodes, fw_order
        )

    node_info = classify_nodes(joint_module)

    # networkx blows up on graphs with no required backward nodes
    # Since there's nothing to partition anyway, and the default partitioner can "handle"
    # this case, send our graph over to the default partitioner.
    if len(node_info.required_bw_nodes) == 0:
        return default_partition(
            joint_module, _joint_inputs, num_fwd_outputs=num_fwd_outputs
        )

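    # Annotate each node with dist_from_bw: the minimum number of user hops from
    # this node to a node that is not required in the forward (i.e. a backward
    # node). Nodes whose outputs never reach the backward keep the int(1e9)
    # sentinel; the recompute heuristics and get_recomputable_banned_nodes above
    # rely on this field.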
    for node in reversed(joint_module.graph.nodes):
        if node.op == "output":
            node.dist_from_bw = int(1e9)
        elif not node_info.is_required_fw(node):
            node.dist_from_bw = 0
        else:
            node.dist_from_bw = int(1e9)
            for user in node.users:
                node.dist_from_bw = min(node.dist_from_bw, user.dist_from_bw + 1)

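    # The activation memory budget comes from the global config, but it can be
    # overridden per graph by stashing a float under node.meta["memory_budget"]
    # on any node in the joint graph.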
    memory_budget = config.activation_memory_budget
    for node in joint_graph.nodes:
        if isinstance(node.meta.get("memory_budget", None), float):
            memory_budget = node.meta["memory_budget"]
            break
    # print("Memory Budget: ", memory_budget)
    saved_values = choose_saved_values_set(
        joint_graph, node_info, memory_budget=memory_budget
    )
    # Tensors go through save_for_backward, while symints are stashed directly
    # on the autograd ctx, so split the two apart here.
    saved_sym_nodes = list(filter(is_sym_node, saved_values))
    saved_values = list(filter(lambda n: not is_sym_node(n), saved_values))

    # NB: saved_sym_nodes will be mutated to reflect the actual saved symbols
    fw_module, bw_module = _extract_fwd_bwd_modules(
        joint_module,
        saved_values,
        saved_sym_nodes=saved_sym_nodes,
        num_fwd_outputs=num_fwd_outputs,
    )

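    # If any recomputed region contains RNG ops, functionalize them so that the
    # backward's recomputation reproduces the same random values the forward saw.
    # The backward graph is then (unconditionally) reordered to mimic the
    # execution order of the eager autograd engine.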
    if graph_has_recomputable_ops:
        if graph_has_recomputable_rng_ops:
            fw_module, bw_module = functionalize_rng_ops(
                joint_module, fw_module, bw_module, len(saved_sym_nodes)
            )
    bw_module = reordering_to_mimic_autograd_engine(bw_module)

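    # With the partitioner debug flag set, report how much activation memory the
    # chosen saved values should take and which ops end up rematerialized (i.e.
    # appear in both the forward and backward modules).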
    if AOT_PARTITIONER_DEBUG:
        from torch._inductor.fx_utils import get_node_storage

        storages = {get_node_storage(node) for node in saved_values}
        print(
            "Theoretical Activations Stored: ",
            sum(_size_of(i) for i in saved_values) / 1e9,
        )
        sorted_sizes = sorted([(_size_of(i), str(i)) for i in saved_values])
        fw_module_nodes = {
            node.name for node in fw_module.graph.nodes if node.op == "call_function"
        }
        bw_module_nodes = {
            node.name for node in bw_module.graph.nodes if node.op == "call_function"
        }
        remat_nodes = fw_module_nodes & bw_module_nodes

        counts: Dict[str, int] = defaultdict(int)
        for node in fw_module.graph.nodes:
            if node.name in remat_nodes and hasattr(node.target, "_overloadpacket"):
                counts[str(node.target._overloadpacket)] += 1
        print(
            f"# remat/fw/bw: {len(remat_nodes)}/{len(fw_module_nodes)}/{len(bw_module_nodes)}"
        )
        print(
            "Count of Ops Rematerialized: ",
            sorted(counts.items(), key=lambda x: x[1], reverse=True),
        )
    return fw_module, bw_module


def draw_graph(
    traced: torch.fx.GraphModule,
    fname: str,
    figname: str = "fx_graph",
    clear_meta: bool = True,
    prog: Optional[Union[str, List[str]]] = None,
    parse_stack_trace: bool = False,
    dot_graph_shape: Optional[str] = None,
) -> None:
    if clear_meta:
        new_graph = copy.deepcopy(traced.graph)
        traced = fx.GraphModule(traced, new_graph)
        for node in traced.graph.nodes:
            node.meta = {}
    base, ext = os.path.splitext(fname)
    if not ext:
        ext = "." + config.torch_compile_graph_format
    print(f"Writing FX graph to file: {base}{ext}")
    g = graph_drawer.FxGraphDrawer(
        traced,
        figname,
        parse_stack_trace=parse_stack_trace,
        dot_graph_shape=dot_graph_shape,
    )
    x = g.get_main_dot_graph()
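    # pydot exposes one write_<format> method per supported output format
    # (e.g. write_svg, write_png), so pick the writer matching the extension.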
    write_method = getattr(x, "write_" + ext.lstrip("."))
    fname = f"{base}{ext}"
    if prog is None:
        write_method(fname)
    else:
        write_method(fname, prog=prog)