# Motivation

PyTorch provides `min_cut_rematerialization_partition()` to partition a joint graph while respecting recomputation annotations. The algorithm forms a data-flow-like graph from the joint graph, assigns edge weights based on recomputation-cost heuristics, and applies a min-cut algorithm to determine which nodes to recompute. Users can force recomputation of a node by annotating its `node.meta["recompute"]` with MUST_RECOMPUTE or PREFER_RECOMPUTE, as implemented in [1].

While originally designed for activation checkpointing, min-cut rematerialization can also be used to recompute parameter aliases. When partitioning a joint graph, we don't want to save for backward the gathered parameters or any values computed from them via aliasing ops, as that effectively saves the gathered parameter itself. Instead of customizing the partitioner or patching `choose_saved_values_set`, we can achieve this by annotating such nodes as MUST_RECOMPUTE.

Both the eager and inductor backends can adopt min-cut rematerialization easily. The eager backend can use it by customizing the `partition_fn` passed to `aot_module_simplified`, and already does so for graphs with activation checkpointing enabled. The inductor backend has used this algorithm since torch 2.0.0 [2], and it remains the default after the inductor partitioner was made configurable a few weeks ago [3].

This approach also benefits DeepCompile + torch autocast nicely. When autocast is enabled, downcasted parameters should preferably be recomputed; it suffices to mark such casting nodes as must-recompute.

[1] https://github.com/pytorch/pytorch/blob/main/torch/_functorch/partitioners.py#L1813
[2] https://github.com/pytorch/pytorch/blob/v2.0.0/torch/_inductor/compile_fx.py#L459
[3] https://github.com/pytorch/pytorch/pull/157580
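To make the mechanism concrete, here is a minimal sketch (not this PR's implementation) of both halves: tagging joint-graph nodes as must-recompute, and handing the min-cut partitioner to an AOT Autograd backend. It relies on private PyTorch APIs (`torch._functorch.partitioners`, `torch._dynamo.backends.common`) whose names and metadata conventions can drift across releases; the `ac_graph_id` line follows the convention used by activation checkpointing and may not be required on every torch version.

```python
import torch
from torch._dynamo.backends.common import aot_autograd
from torch._functorch.aot_autograd import make_boxed_func
from torch._functorch.partitioners import min_cut_rematerialization_partition
from torch.utils.checkpoint import CheckpointPolicy


def recompute_downcasts_partition(joint_gm, joint_inputs, **kwargs):
    # Tag every dtype-cast node so the min-cut partitioner recomputes it in
    # backward instead of saving its output (illustrative heuristic only).
    for node in joint_gm.graph.nodes:
        if node.target == torch.ops.prims.convert_element_type.default:
            node.meta["recompute"] = CheckpointPolicy.MUST_RECOMPUTE
            # Checkpoint-region id expected by some torch releases; a single
            # region suffices for this sketch.
            node.meta["ac_graph_id"] = 0
    return min_cut_rematerialization_partition(joint_gm, joint_inputs, **kwargs)


# A do-nothing compiler pair: the point is only that partition_fn decides
# what is saved for backward vs. recomputed.
backend = aot_autograd(
    fw_compiler=lambda gm, example_inputs: make_boxed_func(gm.forward),
    bw_compiler=lambda gm, example_inputs: make_boxed_func(gm.forward),
    partition_fn=recompute_downcasts_partition,
)

model = torch.nn.Linear(64, 64)
fn = torch.compile(model, backend=backend)
with torch.autocast("cpu", dtype=torch.bfloat16):
    fn(torch.randn(8, 64)).sum().backward()
```

DeepCompile wires the equivalent hook in by patching `AotAutograd.__init__`, as the file below shows.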
# Proposal

Motivated by this flexibility and by the need to optimize DeepCompile + autocast, I propose switching to the min-cut-based partitioner for both backends. This PR implements that switch, cleans up dead code, and recomputes downcasted parameters in the backward pass.

# Preliminary Evaluation

Here is a summary of tests using https://gist.github.com/eternalNight/3c2cf8c703f1e9e7742d3b7f9e1edae3 on an 8x RTX 5090 node.

| Configuration | Base Time (ms) | Base Mem (GB) | Time with this PR (ms) | Mem with this PR (GB) |
|---------------------|----------------|---------------|------------------------|-----------------------|
| eager + autocast | 551.92 | 12.07 | 571.24 | 9.96 |
| eager + bf16 | 419.87 | 9.47 | 445.76 | 7.30 |
| inductor + autocast | 546.97 | 12.84 | 570.09 | 13.04 |
| inductor + bf16 | 444.03 | 10.01 | 444.70 | 10.19 |

## Reduced memory with eager backend

The initial goal of this PR is to reduce peak memory usage when torch autocast is enabled. The first row of the table shows that this is achieved, through three effects acting simultaneously:

1. Parameters downcasted during forward are thrown away and recomputed (by the fused cast + allgather) in the backward pass.
2. Without this PR, `fast_free_schedule` arranges most allgathers at the beginning of the graph. That leads to an even higher peak during forward, which is no longer seen with this PR.
3. Diffing the graphs passed to `add_z3_gather_release` shows that the recomputations selected by min-cut are slightly different (the test script has activation checkpointing enabled for the LLM module). That can also affect computation time and memory usage.

Here is the shape of memory usage before this PR with the eager backend + torch autocast; eager + bf16 shows similar shapes. The numbers reported in the table are the peaks during forward. Peak memory usage during backward decreases by ~0.7 GB in both cases.

<img width="1482" height="629" alt="image" src="https://github.com/user-attachments/assets/7e7ec859-9a04-4ddd-ba37-c2d475a81058" />

After this PR:

<img width="1482" height="453" alt="image" src="https://github.com/user-attachments/assets/f15c71b8-f823-4aa5-801a-a36188c5e866" />

## Similar memory with inductor backend

Unlike the eager backend, the inductor backend uses similar memory with or without this PR. The memory usage pattern is as follows and requires further analysis.

Before this PR:

<img width="1070" height="613" alt="image" src="https://github.com/user-attachments/assets/317b9a58-d4ef-459f-ac7b-67ef2318a9de" />

After this PR:

<img width="911" height="536" alt="image" src="https://github.com/user-attachments/assets/7e737a81-cf27-402c-aeea-dfe661043fc1" />

Signed-off-by: Junjie Mao <junjie.mao@linux.alibaba.com>
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
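
# AOT Autograd and inductor internals are private APIs that can move between
# torch releases; tolerating import failures keeps this module importable.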
try:
    import torch.utils._pytree as pytree
    from torch._functorch.aot_autograd import create_aot_dispatcher_function
    from torch._inductor.lowering import register_lowering, fallbacks, add_needs_realized_inputs
    from torch._inductor.ir import TensorBox, FallbackKernel, Layout, IRNode
    from torch._inductor.virtualized import V
    from torch._inductor.scheduler import Scheduler

    original_create_aot_dispatcher_function = create_aot_dispatcher_function
except ImportError:
    pass

from deepspeed.utils.torch import required_torch_version
from .util import get_input_nodes
from .graph_param import DSGraphParamManager
from .partitioner import get_wrapped_partitioner

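# patch_compiler() wraps an AOT Autograd fw/bw compiler so that DeepCompile's
# graph rewriting (dc_compiler) runs first. With ZeRO-3 partitioning, inputs
# that correspond to DeepSpeed parameters are replaced by zero-element fake
# tensors, so inductor does not validate against the materialized parameter
# size recorded during the first trace.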
def patch_compiler(original_compiler, dc_compiler, z3_partition: bool, graph_id, graph_param_manager, bwd: bool):

    def wrapped_compiler(gm, fake_inputs):
        mod_graph = dc_compiler(gm, fake_inputs)

        # For symint case
        if mod_graph is None:
            return None

        if z3_partition:
            # Inductor validates input size estimated by the first trace, where ds tensor is materialized.
            # We need to patch the input tensors to avoid the validation error.
            patched_inputs = []
            if bwd:
                param_nodes_bw, _ = graph_param_manager[graph_id].get_bwd_mapping(gm.graph)
                param_names = [n.name for n in param_nodes_bw]
            else:
                param_names = graph_param_manager[graph_id].param_names
            input_nodes = get_input_nodes(gm.graph)

            for in_node, in_v in zip(input_nodes, fake_inputs):
                ds_param = in_node.name in param_names
                if ds_param:
                    from torch._subclasses.fake_tensor import is_fake
                    from torch._dynamo.utils import to_fake_tensor
                    assert is_fake(in_v), f"Input {in_v} should be fake tensor"
                    patched_inputs.append(
                        to_fake_tensor(torch.empty([0], dtype=in_v.dtype, device=in_v.device), in_v.fake_mode))
                else:
                    patched_inputs.append(in_v)

            patched_inputs = tuple(patched_inputs)
        else:
            patched_inputs = fake_inputs

        return original_compiler(gm, patched_inputs)

    return wrapped_compiler

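# wrap_partition_fn() runs DeepCompile's partitioner wrapper around the
# backend-selected partition_fn (the min-cut partitioner after this PR) and
# then rewrites parameter placeholders in both resulting graphs to
# zero-element tensors, consistent with the patched runtime inputs above.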
def wrap_partition_fn(partition_fn, real_inputs, param_indices):

    def wrapped_partition_fn(*args, **kwargs):

        fn = get_wrapped_partitioner(True, param_indices, partition_fn=partition_fn)
        fw_module, bw_module = fn(*args, **kwargs)

        # get parameter names
        pm = DSGraphParamManager(fw_module.graph, real_inputs, param_indices)

        def fix_placeholder_meta(graph):
            for n in graph.nodes:
                if n.op == "placeholder" and n.name in pm.param_names:
                    n.meta["val"] = torch.empty([0], dtype=n.meta["val"].dtype, device=n.meta["val"].device)

        fix_placeholder_meta(fw_module.graph)
        fix_placeholder_meta(bw_module.graph)

        return fw_module, bw_module

    return wrapped_partition_fn

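# patch_create_aot_dispatcher_function() monkey-patches AotAutograd.__init__
# so that every aot_autograd backend constructed afterwards receives
# DeepCompile's wrapped compilers and, under ZeRO-3 partitioning, the wrapped
# partition function. The original __init__ is stashed in __original_init so
# repeated patching remains idempotent.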
def patch_create_aot_dispatcher_function(graph_id: int, z3_partition: bool, make_fw_graph, make_bw_graph, real_inputs,
                                         param_indices, param_manager):

    from torch._dynamo.backends.common import AotAutograd
    import functools

    def patch_aotautograd():
        # Unpatch if it was already patched
        if hasattr(AotAutograd, "__original_init"):
            AotAutograd.__init__ = AotAutograd.__original_init

        original_init = AotAutograd.__init__

        @functools.wraps(original_init)
        def patched_init(self, **kwargs):
            kwargs["fw_compiler"] = patch_compiler(kwargs["fw_compiler"],
                                                   make_fw_graph,
                                                   z3_partition,
                                                   graph_id,
                                                   param_manager,
                                                   bwd=False)
            kwargs["bw_compiler"] = patch_compiler(kwargs["bw_compiler"],
                                                   make_bw_graph,
                                                   z3_partition,
                                                   graph_id,
                                                   param_manager,
                                                   bwd=True)
            kwargs["inference_compiler"] = kwargs["fw_compiler"]

            if z3_partition:
                kwargs["partition_fn"] = wrap_partition_fn(kwargs["partition_fn"], real_inputs, param_indices)

            original_init(self, **kwargs)

        AotAutograd.__original_init = original_init
        AotAutograd.__init__ = patched_init

    patch_aotautograd()

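# register_custom_ops() lowers the dc.* custom ops as inductor fallback
# kernels whose buffers are excluded from inductor's buffer reuse, so their
# memory is not handed to unrelated ops. force_free_input additionally emits
# "<arg> = None" right after the kernel call (see CustomDCKernel.codegen) so
# the input buffer can be freed as early as possible, and the scheduler's
# dead-node elimination is turned into a no-op so inductor keeps nodes it
# would otherwise treat as dead.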
def register_custom_ops():

    def fallback_handler_no_reuse(kernel,
                                  never_reuse_input,
                                  never_reuse_output,
                                  force_free_input,
                                  add_to_fallback_set=True):
        if add_to_fallback_set:
            fallbacks.add(kernel)

        def handler(*args, **kwargs):

            def wrap_tensors(x):
                out = TensorBox.create(x) if isinstance(x, torch._inductor.ir.IRNode) else x
                if out is not None and never_reuse_output:
                    V.graph.never_reuse_buffers.add(out.get_name())
                return out

            class CustomDCKernel(FallbackKernel):

                def __init__(self, op, *args, **kwargs):
                    super().__init__(op, *args, **kwargs)

                    def add_to_never_reuse(x):
                        if isinstance(x, IRNode):
                            assert hasattr(x, "get_name"), f"x doesn't have get_name {x.__class__}"
                            V.graph.never_reuse_buffers.add(x.get_name())

                    if never_reuse_input:
                        pytree.tree_map(add_to_never_reuse, args)

                def get_var_name_for_arg(self, arg: str):
                    if arg.isidentifier():
                        return arg

                    import re
                    match = re.match(r"reinterpret_tensor\((\w+),", arg)
                    if match:
                        return match.group(1)
                    return None

                def codegen(self, wrapper):
                    if not force_free_input:
                        return super().codegen(wrapper)

                    kernel = self.op_overload
                    self.codegen_comment(wrapper)
                    args = [*self.codegen_args(), *self.codegen_kwargs()]

                    if required_torch_version(min_version=2.8):
                        V.graph.wrapper_code.generate_fallback_kernel(self)
                    else:
                        V.graph.wrapper_code.generate_fallback_kernel(self, args)

                    if isinstance(self.layout, Layout):
                        self.codegen_size_asserts(wrapper)

                    var_name = self.get_var_name_for_arg(args[0])
                    if var_name:
                        wrapper.writeline(f"{var_name} = None")

                    self.codegen_unbacked_symbol_defs(wrapper)

            kernel_cls = CustomDCKernel if force_free_input else FallbackKernel
            return pytree.tree_map(wrap_tensors, kernel_cls.create(kernel, *args, **kwargs))

        return handler

    def register_fallback_no_reuse(op_overload,
                                   never_reuse_input=False,
                                   never_reuse_output=False,
                                   force_free_input=False):
        add_needs_realized_inputs(op_overload)
        return register_lowering(op_overload, type_promotion_kind=None)(fallback_handler_no_reuse(
            op_overload,
            never_reuse_input=never_reuse_input,
            never_reuse_output=never_reuse_output,
            force_free_input=force_free_input))

    # Inductor tries to reuse output buffer when possible. We need to disable this behavior for some custom ops.
    # -> It seems that memory region is still reused in some cases. So we clone the inputs for some ops.
    register_fallback_no_reuse(torch.ops.dc.allgather_param.default, never_reuse_input=False, never_reuse_output=True)
    register_fallback_no_reuse(torch.ops.dc.wait_allgather.default, never_reuse_input=True, never_reuse_output=True)
    register_fallback_no_reuse(torch.ops.dc.release_param.default, never_reuse_input=True, never_reuse_output=False)
    register_fallback_no_reuse(torch.ops.dc.reduce_grad.default,
                               never_reuse_input=True,
                               never_reuse_output=True,
                               force_free_input=True)
    register_fallback_no_reuse(torch.ops.dc.free_tensors.default, never_reuse_input=True, never_reuse_output=True)

    if not hasattr(Scheduler, "is_dc_patched") or not Scheduler.is_dc_patched:
        Scheduler.is_dc_patched = True
        Scheduler.dead_node_elimination = lambda _: None