Files
pytorch/torch/distributed/_tools/ilp_utils.py
Yuanyuan Chen fbe0d20a17 [2/N] More ruff SIM fixes (#165031)
This is a follow-up to #164695 that applies ruff SIM rules to more files. Most changes simplify dict.get calls, because None is already the default value.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165031
Approved by: https://github.com/mlazos
2025-10-14 14:22:54 +00:00

293 lines
9.9 KiB
Python

import copy
from collections import OrderedDict
from typing import cast, TypedDict
import numpy as np
import torch
from torch.distributed._tools.mem_tracker import (
_MemRefType,
_ModMemStats,
_ModState,
MemTracker,
)
from torch.distributed._tools.runtime_estimator import RuntimeEstimator
from torch.distributed._tools.sac_estimator import SACEstimator, SACTradeOffStats
class ModOrder(TypedDict):
    """Module orderings recorded by the runtime estimator, as lists of module FQNs."""

    # forward pre-order (presumably the order in which modules are entered in forward)
    fw_pre_order: list[str]
    # backward pre-order
    bw_pre_order: list[str]
    # forward post-order (presumably the order in which modules finish forward)
    fw_post_order: list[str]
    # backward post-order
    bw_post_order: list[str]
class ModRuntime(TypedDict):
    """Forward and backward runtime of one module."""

    # forward runtime (presumably in ms -- confirm against RuntimeEstimator)
    fw: float
    # backward runtime (same unit as ``fw``)
    bw: float
class ModStats(TypedDict):
    """
    Memory, runtime and activation-checkpointing (AC) trade-off stats of one module.

    Entries of this type are built by ``aggregate_stats``. Memory fields are
    presumably in bytes and runtimes in ms -- confirm against the trackers.
    """

    # fully qualified name of the module
    fqn: str
    # per-module params
    param_per_module: int
    # per-module grads
    grad_per_module: int
    # total accumulated gradients up to and including this module
    grad_total: int
    # per module fw activation size (excluding input and output)
    act_fw_per_module: int
    # per module bw activation size during peak_bw
    act_bw_per_module: int
    # per module activation grad size during peak_bw
    act_grad_per_module: int
    # total activation size up to but excluding the current module
    # includes input of the current module (i.e., output of previous module)
    act_total: int
    # Inputs to the module
    input_per_module: int
    # Outputs of the module
    output_per_module: int
    # Total fw run-time of the module
    fw_runtime_per_module: float
    # Total bw run-time of the module
    bw_runtime_per_module: float
    # Is this module a leaf module
    is_leaf: bool
    # Total ac run-time of the module
    sac_runtime: float
    # Total ac_memory for the module
    sac_memory: int
    # Number of piecewise-linear functions used for approximating ac tradeoff curve
    n_segments: int
    # Slopes of the piecewise-linear functions
    slopes: list[float]
    # Intercepts of the piecewise-linear functions
    intercepts: list[float]
    # X breakpoints of the piecewise-linear functions
    breakpoints: list[float]
    # Original trade-off curves
    tradeoff_curve: OrderedDict[float, float]
class ModuleInfo(TypedDict):
    """Module orderings together with the per-module stats of a model."""

    # module FQNs in the recorded forward/backward pre- and post-orders
    mod_order: ModOrder
    # one stats entry per tracked module
    mod_stats: list[ModStats]
def aggregate_stats(
    model: torch.nn.Module,
    mem_tracker: MemTracker,
    runtime_estimator: RuntimeEstimator,
    sac_estimator: SACEstimator,
    dev: torch.device,
) -> ModuleInfo:
    """
    Collect modulewise stats for a given model, including memory, runtime, and AC tradeoff stats.

    Only modules that have an entry in ``mem_tracker.memory_tracking`` are
    reported. A module without SAC trade-off stats is flagged ``is_leaf=True``
    and gets empty trade-off fields.

    Args:
        model: nn.Module object
        mem_tracker: MemTracker object with memory stats
        runtime_estimator: RuntimeEstimator object with runtime stats
        sac_estimator: SACEstimator object with AC tradeoff stats
        dev: device the model was run on (used to extract memory stats from MemTracker)

    Returns:
        ModuleInfo: A dictionary with module order and module stats.

    Raises:
        KeyError: if a tracked module has no runtime entry, or a required
            snapshot / device / memory-type entry is missing.
    """
    # Memory stats (deep-copied so the tracker's own records are left untouched)
    mod_mem_stats: dict[torch.nn.Module, _ModMemStats] = dict(
        copy.deepcopy(mem_tracker.memory_tracking)
    )

    # Runtime stats
    mod_runtime_stats: dict[str, ModRuntime] = {
        fqn: {"fw": v["fw"], "bw": v["bw"]}
        for fqn, v in runtime_estimator.mod_runtimes.items()
    }

    # Module order
    mod_order: ModOrder = {
        "fw_pre_order": list(runtime_estimator.mod_fw_pre_order),
        "bw_pre_order": list(runtime_estimator.mod_bw_pre_order),
        "fw_post_order": list(runtime_estimator.mod_fw_post_order),
        "bw_post_order": list(runtime_estimator.mod_bw_post_order),
    }

    # Selective Activation Checkpointing stats; the curve fit has to run before
    # the per-module trade-off stats are read.
    sac_estimator.pwlf_sac_tradeoff_curve()
    mod_sac_tradeoff_stats: dict[str, SACTradeOffStats] = copy.deepcopy(
        sac_estimator.sac_mod_tradeoff_stats
    )

    module_info: ModuleInfo = {
        "mod_order": mod_order,
        "mod_stats": [],
    }

    for mod in model.modules():
        mod_mem_stat = mod_mem_stats.get(mod)
        if not mod_mem_stat:
            # module was not tracked by the MemTracker
            continue
        fqn = mod_mem_stat.mod_fqn

        sac_runtime: float
        sac_memory: int
        n_segments: int
        slopes: list[float]
        intercepts: list[float]
        breakpoints: list[float]
        tradeoff_curve: OrderedDict[float, float]
        if tradeoff_stats := mod_sac_tradeoff_stats.get(fqn):
            sac_runtime = tradeoff_stats.sac_runtime
            sac_memory = tradeoff_stats.sac_memory
            n_segments = tradeoff_stats.n_segments
            slopes = tradeoff_stats.slopes
            intercepts = tradeoff_stats.intercepts
            breakpoints = tradeoff_stats.fit_breaks
            tradeoff_curve = tradeoff_stats.tradeoff_curve
            is_leaf = False
        else:
            # No trade-off stats: report the module as a leaf with empty curves.
            # Each field gets its OWN list so that mutating one of them later
            # cannot silently change the other two.
            sac_runtime, sac_memory, n_segments = 0.0, 0, 0
            slopes, intercepts, breakpoints = [], [], []
            tradeoff_curve = OrderedDict()
            is_leaf = True

        # Latest snapshot of each module state on the requested device
        snapshots = mod_mem_stat.snapshots
        pre_fw = snapshots[_ModState.PRE_FW][-1][dev]
        post_fw = snapshots[_ModState.POST_FW][-1][dev]
        pre_bw = snapshots[_ModState.PRE_BW][-1][dev]
        peak_bw = snapshots[_ModState.PEAK_BW][-1][dev]

        mod_stat: ModStats = {
            "fqn": fqn,
            "param_per_module": mod_mem_stat.parameter_mem,
            # gradient memory is taken to be equal to the parameter memory
            "grad_per_module": mod_mem_stat.parameter_mem,
            "grad_total": pre_bw[_MemRefType.GRAD],
            # activations created inside the module, excluding its output;
            # clamped at zero
            "act_fw_per_module": max(
                0,
                post_fw[_MemRefType.ACT]
                - pre_fw[_MemRefType.ACT]
                - mod_mem_stat.output_mem,
            ),
            "act_bw_per_module": max(0, peak_bw[_MemRefType.ACT]),
            # growth of temporary memory between backward entry and backward peak
            "act_grad_per_module": (
                peak_bw[_MemRefType.TEMP] - pre_bw[_MemRefType.TEMP]
            ),
            "act_total": post_fw[_MemRefType.ACT],
            "input_per_module": mod_mem_stat.input_mem,
            "output_per_module": mod_mem_stat.output_mem,
            "fw_runtime_per_module": mod_runtime_stats[fqn]["fw"],
            "bw_runtime_per_module": mod_runtime_stats[fqn]["bw"],
            "is_leaf": is_leaf,
            "sac_runtime": sac_runtime,
            "sac_memory": sac_memory,
            "n_segments": n_segments,
            "slopes": slopes,
            "intercepts": intercepts,
            "breakpoints": breakpoints,
            "tradeoff_curve": tradeoff_curve,
        }
        module_info["mod_stats"].append(mod_stat)

    return module_info
class Node(ModStats):
    """A ``ModStats`` entry extended with its positions in the forward orders."""

    index: int  # index according to forward pre-order
    pos_fw_post_order: int  # index according to forward post-order
class Graph:
    """Holds module nodes, a name lookup table and an ``n x n`` ancestor-descendant matrix."""

    def __init__(self, n: int) -> None:
        """Create an empty graph whose matrix is sized for ``n`` nodes."""
        shape = (n, n)
        self.nodes: list[Node] = []
        self.name2node: dict[str, Node] = {}
        # starts out all zeros; entries are filled in by whoever builds the graph
        self.ad_matrix = np.zeros(shape)
        self.fw_post_order: list[str] = []

    def add_node(self, node: Node) -> None:
        """Append ``node`` to the node list and register it under its ``"fqn"``."""
        self.nodes.append(node)
        fqn = node["fqn"]
        self.name2node[fqn] = node
def parse_module_info(module_info: ModuleInfo) -> Graph:
    """
    Parse module info and create a graph (tree) of modules. The graph will be
    used by MILP solver to find optimal SAC and/or FSDP configurations.

    Side effects: ``module_info["mod_stats"]`` is replaced by a list sorted in
    forward pre-order, and every stats dict gains the ``"index"`` and
    ``"pos_fw_post_order"`` keys (the same dict objects become the graph nodes).

    Args:
        module_info: module orders and per-module stats.

    Returns:
        Graph: nodes in forward pre-order, with ``ad_matrix[i][j] == 1`` when
        node ``j`` is node ``i`` itself or one of its submodules.

    Raises:
        AssertionError: if the number of stats entries differs from the length
            of the forward pre-order.
        ValueError: if a module FQN is missing from the forward pre-order or
            post-order list.
    """
    mod_stats = module_info["mod_stats"]
    fw_pre_order = module_info["mod_order"]["fw_pre_order"]

    # assertion and number of nodes
    assert len(mod_stats) == len(fw_pre_order), (
        "mod_stats and fw_pre_order must have the same length"
    )
    n_nodes = len(mod_stats)

    # create graph
    g = Graph(n_nodes)
    g.fw_post_order = module_info["mod_order"]["fw_post_order"]

    # sort the modules by pre-order and add them to the graph
    sorted_mod_stats = sorted(
        mod_stats, key=lambda x: fw_pre_order.index(x["fqn"])
    )
    module_info["mod_stats"] = sorted_mod_stats
    # Iterate the SORTED list: node indices and the matrix construction below
    # both rely on the nodes being in forward pre-order.
    for i, one_mod_stats in enumerate(sorted_mod_stats):
        node: Node = cast(Node, one_mod_stats)
        node["index"] = i
        node["pos_fw_post_order"] = g.fw_post_order.index(node["fqn"])
        g.add_node(node)

    # set up ancestor-descendant matrix
    for i in range(n_nodes):
        for j in range(i, n_nodes):
            if is_self_or_submodule(g.nodes[j]["fqn"], g.nodes[i]["fqn"]):
                g.ad_matrix[i][j] = 1
            else:
                # in pre-order a module's descendants directly follow it, so
                # the first non-descendant ends the subtree of node i
                break

    return g
def is_self_or_submodule(name_descendant: str, name_ancestor: str) -> bool:
    """
    Check if name_descendant is a submodule of name_ancestor, or if they are the same.

    The ancestor must be a dot-terminated *prefix* of the descendant's FQN; a
    match in the middle of the name (e.g. ``"b"`` inside ``"a.b.c"``) does not
    make it an ancestor.

    Args:
        name_descendant: FQN of the candidate descendant module
        name_ancestor: FQN of the candidate ancestor module

    Returns:
        bool: True if the names are equal or name_descendant lies below name_ancestor.
    """
    # An empty ancestor name is the unnamed root (nn.Module.named_modules
    # convention), which contains every module.
    if name_descendant == name_ancestor or name_ancestor == "":
        return True
    return name_descendant.startswith(name_ancestor + ".")
def is_submodule(name_descendant: str, name_ancestor: str) -> bool:
    """
    Check if name_descendant is a submodule of name_ancestor, but not the same.

    The ancestor must be a dot-terminated *prefix* of the descendant's FQN; a
    match in the middle of the name (e.g. ``"b"`` inside ``"a.b.c"``) does not
    make it an ancestor.

    Args:
        name_descendant: FQN of the candidate descendant module
        name_ancestor: FQN of the candidate ancestor module

    Returns:
        bool: True if name_descendant lies strictly below name_ancestor.
    """
    # An empty ancestor name is the unnamed root (nn.Module.named_modules
    # convention): every other module is its strict descendant.
    if name_ancestor == "":
        return name_descendant != ""
    return name_descendant.startswith(name_ancestor + ".")
def display_bytes(b: int, unit: str = "MiB") -> str:
    """
    Render a byte count as a string in the requested unit, with two decimals.

    ``unit`` may be "KiB", "MiB" or "GiB"; any other value renders the
    unscaled count labelled "bytes".
    """
    # unknown unit: fall back to the raw count
    if unit not in ("KiB", "MiB", "GiB"):
        return f"{b:.2f} bytes"
    if unit == "GiB":
        return f"{b / 2**30:.2f} GiB"
    if unit == "MiB":
        return f"{b / 2**20:.2f} MiB"
    return f"{b / 2**10:.2f} KiB"
def get_peak_memory_runtime_baseline(graph: Graph) -> tuple[int, float]:
    """
    Compute the baseline peak memory and runtime, i.e. without FSDP or AC.

    The memory estimate covers parameters, gradients, activations and
    activation gradients; it leaves out e.g. optimizer states and embedding
    tables. The first node of the graph supplies both the parameter size and
    the runtimes.

    Returns:
        int: peak memory in bytes
        float: compute time in ms
    """
    root = graph.nodes[0]
    root_params = root["param_per_module"]

    # memory at each module: root params + accumulated grads
    # + activation grads + accumulated activations
    per_node_mem = [
        root_params
        + node["grad_total"]
        + node["act_grad_per_module"]
        + node["act_total"]
        for node in graph.nodes
    ]
    # the leading 0 keeps the result from dropping below zero
    peak_mem = max([0, *per_node_mem])

    compute_time = root["fw_runtime_per_module"] + root["bw_runtime_per_module"]
    return (peak_mem, compute_time)