DeepCompile: Use min_cut_rematerialization for partitioning joint graphs (#7609)

# Motivation

PyTorch provides `min_cut_rematerialization_partition()` to partition a
joint graph while respecting recomputation annotation. That algorithm
forms a data-flow-like graph from the joint graph, adds to edges weights
from some recomputation-cost-related heuristics and applies the min-cut
algorithm to determine which nodes to recompute. Users can force
recomputation of a node by annotating its `node.meta["recompute"]` to
MUST_RECOMPUTE or PREFER_RECOMPUTE, as is implemented in [1].

While originally designed for activation checkpointing,
min_cut_rematerialization can also be used to recompute param aliases.
When partitioning a joint graph, we don't want to save for backward the
gathered parameters and values computed from them via aliasing ops, as
that essentially means the gathered parameter will be saved. Instead of
customizing the partitioner or patching `choose_saved_values_set`, we
can achieve that by annotating such nodes to be MUST_RECOMPUTE.

Both eager and inductor backends can use min_cut_rematerialization
easily. The eager backend can use min-cut by customizing the
partition_fn for `aot_module_simplified`, and is already using that for
graphs with activation checkpointing enabled. The inductor backend uses
that algorithm since torch 2.0.0 [2] and is still the default after the
inductor partitioner is made configurable a few weeks ago [3].

That approach also helps DeepCompile + torch autocast nicely. When
autocast is enabled, downcasted parameters are preferred to be
recomputed. It suffices to mark such casting nodes as must-recompute.

[1]
https://github.com/pytorch/pytorch/blob/main/torch/_functorch/partitioners.py#L1813
[2]
https://github.com/pytorch/pytorch/blob/v2.0.0/torch/_inductor/compile_fx.py#L459
[3] https://github.com/pytorch/pytorch/pull/157580

# Proposal

Motivated by the flexibility and the requirement for optimizing
DeepCompile + autocast, I propose to switch to the min-cut-based
partitioner for both backends. This PR implements that switch, cleans up
dead code and also recomputes downcasted parameters in the backward.

# Preliminary Evaluation

Here's a summary of the tests using
https://gist.github.com/eternalNight/3c2cf8c703f1e9e7742d3b7f9e1edae3 on
a 8x RTX 5090 node.

| Configuration | Base Time (ms) | Base Mem (GB) | Time with this PR
(ms) | Mem with this PR (GB) |

|---------------------|----------------|---------------|------------------------|-----------------------|
| eager + autocast | 551.92 | 12.07 | 571.24 | 9.96 |
| eager + bf16 | 419.87 | 9.47 | 445.76 | 7.30 |
| inductor + autocast | 546.97 | 12.84 | 570.09 | 13.04 |
| inductor + bf16 | 444.03 | 10.01 | 444.70 | 10.19 |

## Reduced memory with eager backend

The initial goal of this PR is to reduce peak memory usage when torch
autocast is enabled. That is achieved according to the first row of the
table, but in two different ways simultaneously.

1. Downcasted parameters during forward are throwed away and recomputed
(by the fused cast + allgather) in the backward pass.
2. Without this PR, `fast_free_schedule` will arange most allgather at
the beginning of the graph. That leads to a even higher peak during
forward, but is no longer seen with PR.
3. By diffing the graphs passed to `add_z3_gather_release`, I noticed
that recomputations selected by min-cut is slightly different (that test
script has activation checkpointing enabled for the LLM module). That
can also impact computation time and memory usage.

Here's the shape of memory usage before this PR with eager backend +
torch autocast. eager + BF16 shows similar shapes. Numbers reported in
the table are peak during forward. The peak memory usage during backend
reduces ~0.7GB in both cases.

<img width="1482" height="629" alt="image"
src="https://github.com/user-attachments/assets/7e7ec859-9a04-4ddd-ba37-c2d475a81058"
/>

After this PR:

<img width="1482" height="453" alt="image"
src="https://github.com/user-attachments/assets/f15c71b8-f823-4aa5-801a-a36188c5e866"
/>

## Similar memory with inductor backend

Unlike eager backend, the inductor backend uses similar memory with or
without this PR. The memory usage pattern is as follows, which requires
further analysis.

Before this PR:

<img width="1070" height="613" alt="image"
src="https://github.com/user-attachments/assets/317b9a58-d4ef-459f-ac7b-67ef2318a9de"
/>

After this PR:

<img width="911" height="536" alt="image"
src="https://github.com/user-attachments/assets/7e737a81-cf27-402c-aeea-dfe661043fc1"
/>

Signed-off-by: Junjie Mao <junjie.mao@linux.alibaba.com>
This commit is contained in:
Junjie Mao
2025-10-03 11:39:38 +08:00
committed by GitHub
parent 9cbd3edd0d
commit 2a76988958
3 changed files with 71 additions and 131 deletions

View File

@ -16,6 +16,7 @@ try:
import torch._dynamo
from functorch.compile import make_boxed_func
from torch._functorch.aot_autograd import aot_module_simplified
from torch._functorch.partitioners import min_cut_rematerialization_partition
from torch._subclasses.fake_tensor import unset_fake_temporarily
from torch._subclasses.fake_tensor import is_fake
except ImportError:
@ -367,17 +368,16 @@ def make_backend(backend, compile_config, compile_kwargs={}):
return compiler_fn
partition_fn = get_wrapped_partitioner(z3_partition, param_indices, min_cut_rematerialization_partition)
aot_mod = aot_module_simplified(gm,
real_inputs,
fw_compiler=make_compiler_fn(make_fw_graph),
bw_compiler=make_compiler_fn(make_bw_graph),
partition_fn=get_wrapped_partitioner(param_indices))
partition_fn=partition_fn)
return torch._dynamo.optimize(**compile_kwargs)(aot_mod)
elif backend == "inductor":
patch_create_aot_dispatcher_function(graph_id, z3_partition, make_fw_graph, make_bw_graph, real_inputs,
param_indices, param_manager)
from .partitioner import get_wrapped_choose_saved_values_set
torch._functorch.partitioners.choose_saved_values_set = get_wrapped_choose_saved_values_set(param_indices)
return torch._inductor.compile(gm, real_inputs)

View File

@ -20,6 +20,7 @@ except ImportError:
from deepspeed.utils.torch import required_torch_version
from .util import get_input_nodes
from .graph_param import DSGraphParamManager
from .partitioner import get_wrapped_partitioner
def patch_compiler(original_compiler, dc_compiler, z3_partition: bool, graph_id, graph_param_manager, bwd: bool):
@ -66,7 +67,8 @@ def wrap_partition_fn(partition_fn, real_inputs, param_indices):
def wrapped_partition_fn(*args, **kwargs):
fw_module, bw_module = partition_fn(*args, **kwargs)
fn = get_wrapped_partitioner(True, param_indices, partition_fn=partition_fn)
fw_module, bw_module = fn(*args, **kwargs)
# get parameter names
pm = DSGraphParamManager(fw_module.graph, real_inputs, param_indices)

View File

@ -3,156 +3,94 @@
# DeepSpeed Team
# This file was copied from PyTorch and modified for DeepSpeed.
from typing import Tuple, List
import operator
import torch
from torch.fx import GraphModule, Graph, Node
try:
from torch._functorch.partitioners import is_sym_node, _is_primal, _is_fwd_seed_offset, _extract_fwd_bwd_outputs, _extract_graph_with_inputs_outputs, _extract_fwd_bwd_modules, has_recomputable_ops, min_cut_rematerialization_partition, choose_saved_values_set
from torch.utils.checkpoint import CheckpointPolicy
from torch._functorch.partitioners import _is_primal
except ImportError:
pass
from .util import get_no_copy_ops
_recompute_ops = {torch.ops.aten.t.default}
from .util import get_no_copy_ops, is_cast_op
def _find_recompute_nodes(graph: Graph, ds_param_node: Node) -> List[Node]:
"""
Given a graph and a node that represents a parameter that was allgathered,
find all nodes that use the parameter and require recomputation.
def _recompute_param_aliases(joint_graph: Graph, param_indices: List[Tuple[int, int, torch.Size]]):
"""Recompute nodes aliasing or downcasting any parameter
In ZeRO3, sharded parameters are gathered before use and the gathered
parameters should be freed once they are no longer needed to save GPU
memory.
When DeepCompile is active for ZeRO3, parameter gathering is done by custom
passes after the joint graph captured by Dynamo and AOT Autograd is
partitioned into fwd and bwd parts. Since the partitioner has no clue about
parameter sharding now, the partitioned graphs will save for backward all
intermediate activations including those aliasing the gathered parameters.
That essentially nullifies the memory reduction that ZeRO3 is designed to
bring.
The solution is to recompute the parameter-aliasing activations in the
backward. It is done by marking such nodes as MUST_RECOMPUTE and reusing the
min-cut partitioner originally designed for checkpointing. If autocast is
enabled, parameter downcasts are also recomputed.
This cannot be converted to a standalone pass because it must be applied
before partitioning the joint graph, but passes run after the partitioning.
TODO(eternalNight) `min_cut_rematerialization_partition` may recompute more
nodes than required for ZeRO3. Need investigate its performance
implications.
"""
no_copy_ops = get_no_copy_ops()
recompute_nodes = set()
for node in graph.nodes:
if node.target in no_copy_ops:
if ds_param_node in node.args:
recompute_nodes.add(node)
if any(a in recompute_nodes for a in node.args):
recompute_nodes.add(node)
return recompute_nodes
def need_recompute(n: Node) -> bool:
if n.op == "call_function":
is_cast, _ = is_cast_op(n)
return n.target in no_copy_ops or is_cast
return False
def _get_values_from_ds_params(joint_graph, param_indices):
primal_inputs = list(filter(_is_primal, joint_graph.nodes))
ds_param_inputs = [primal_inputs[arg_idx] for arg_idx, _, _ in param_indices]
no_copy_ops = get_no_copy_ops()
ds_param_inputs = set(ds_param_inputs)
ds_param_users = {}
ds_param_inputs = set([primal_inputs[arg_idx] for arg_idx, _, _ in param_indices])
recomputed_nodes = set()
for node in joint_graph.nodes:
if node.target in no_copy_ops and any((a in ds_param_inputs or a in ds_param_users) for a in node.args):
for a in node.args:
if a in ds_param_inputs:
ds_param_users[node] = a
elif a in ds_param_users:
ds_param_users[node] = ds_param_users[a]
# The `ac_graph_id` tag tracks the checkpoint module that a node belongs
# to, and is for enforcing the saving of activations at the boundary of
# consecutive checkpointed blocks. It starts from 1 and increments by 1
# each time a graph module is checkpointed.
#
# `min_cut_rematerialization_partition` requires every node to have
# `ac_graph_id`. If this graph is not checkpointed (and thus
# `ac_graph_id` is missing), we tag all nodes to 1 to prevent the
# partition function from modifying the recompute tag.
node.meta.setdefault("ac_graph_id", 1)
return ds_param_users
# Arguments can be non-tensor types some of which are not hashable. So
# we must inspect the type of an argument before checking if it is in
# any set.
if need_recompute(node) and \
any([(isinstance(a, Node) and (a in ds_param_inputs or a in recomputed_nodes)) for a in node.args]):
node.meta["recompute"] = CheckpointPolicy.MUST_RECOMPUTE
recomputed_nodes.add(node)
else:
# If checkpointing is not enabled for this graph, assume all
# activations required by the backward pass should be saved.
node.meta.setdefault("recompute", CheckpointPolicy.MUST_SAVE)
def get_wrapped_choose_saved_values_set(param_indices: List[Tuple[int, int, torch.Size]]):
def ds_choose_saved_values_set(joint_graph: torch.fx.Graph, node_info, memory_budget=1) -> List[Node]:
saved_values = choose_saved_values_set(joint_graph, node_info, memory_budget)
ds_param_users = _get_values_from_ds_params(joint_graph, param_indices)
new_saved_values = []
for v in saved_values:
if v in ds_param_users:
ds_val = ds_param_users[v]
if ds_val not in new_saved_values:
new_saved_values.append(ds_val)
else:
new_saved_values.append(v)
return new_saved_values
return ds_choose_saved_values_set
def get_wrapped_partitioner(param_indices: List[Tuple[int, int, torch.Size]]):
def get_wrapped_partitioner(
z3_partition: bool,
param_indices: List[Tuple[int, int, torch.Size]],
partition_fn,
):
def partition_recompute_ds_params(joint_module: GraphModule, _joint_inputs, *,
num_fwd_outputs) -> Tuple[GraphModule, GraphModule]:
"""
This is basically the same as the default_partition function, but
it doesn't save the gathered params and values computed from them.
"""
if has_recomputable_ops(joint_module):
return min_cut_rematerialization_partition(joint_module, _joint_inputs, num_fwd_outputs=num_fwd_outputs)
primal_inputs = list(filter(_is_primal, joint_module.graph.nodes))
fwd_seed_offset_inputs = list(filter(_is_fwd_seed_offset, joint_module.graph.nodes))
inputs = primal_inputs + fwd_seed_offset_inputs
fwd_outputs, bwd_outputs = _extract_fwd_bwd_outputs(joint_module, num_fwd_outputs=num_fwd_outputs)
forward_only_graph = _extract_graph_with_inputs_outputs(joint_module.graph, inputs, fwd_outputs, "forward")
forward_node_names = {node.name for node in forward_only_graph.nodes if node.op != "output"}
saved_values = []
saved_sym_nodes = []
fwd_inputs = list(filter(_is_primal, forward_only_graph.nodes))
ds_param_inputs = [fwd_inputs[arg_idx] for arg_idx, _, _ in param_indices]
ds_param_input_names = {node.name for node in ds_param_inputs}
ds_param_recompute_nodes = set()
for node in joint_module.graph.nodes:
if node.name not in forward_node_names:
continue
if is_sym_node(node):
# Symints must be kept separate from tensors so that PythonFunction only calls
# save_for_backward on tensors and stashes symints in autograd .ctx
saved_sym_nodes.append(node)
elif "tensor_meta" not in node.meta and node.op == "call_function":
# Since we can't save tuple of tensor values, we need to flatten out what we're saving
users = node.users
assert all(user.target == operator.getitem for user in users)
saved_values.extend(users)
else:
backward_usages = [n for n in node.users if n.name not in forward_node_names]
if "tensor_meta" in node.meta and all(is_sym_node(n) for n in backward_usages):
# If we have a tensor in the forward, where only its sizes/strides are needed in the backward,
# and not the actual tensor data,
# then it will be a lot cheaper to save only the sizes/strides, and not the actual tensor.
#
# Note that saving the tensor could also cause compilation problems:
# If the user mutated an input in the forward and uses its sizes/strides in the backward,
# then we would be obligated to clone the input before saving it to appease autograd.
# (This is how we originally found this bug).
saved_sym_nodes.extend(backward_usages)
if node.name in ds_param_input_names:
saved_values.append(node)
recompute_nodes = _find_recompute_nodes(joint_module.graph, node)
recompute_nodes = [n for n in recompute_nodes if n.name in forward_node_names]
for recompute_node in recompute_nodes:
ds_param_recompute_nodes.add(recompute_node)
if len(recompute_nodes) > 0:
saved_values.append(node)
else:
if node not in ds_param_recompute_nodes:
saved_values.append(node)
saved_values = list(dict.fromkeys(saved_values).keys())
saved_sym_nodes = list(dict.fromkeys(saved_sym_nodes).keys())
f_gm, b_gm = _extract_fwd_bwd_modules(
joint_module,
saved_values,
saved_sym_nodes=saved_sym_nodes,
num_fwd_outputs=num_fwd_outputs,
)
return f_gm, b_gm
if z3_partition:
_recompute_param_aliases(joint_module.graph, param_indices)
return partition_fn(joint_module, _joint_inputs, num_fwd_outputs=num_fwd_outputs)
return partition_recompute_ds_params