import copy
from collections import OrderedDict
from typing import cast, TypedDict

import numpy as np

import torch
from torch.distributed._tools.mem_tracker import (
    _MemRefType,
    _ModMemStats,
    _ModState,
    MemTracker,
)
from torch.distributed._tools.runtime_estimator import RuntimeEstimator
from torch.distributed._tools.sac_estimator import SACEstimator, SACTradeOffStats


class ModOrder(TypedDict):
    fw_pre_order: list[str]
    bw_pre_order: list[str]
    fw_post_order: list[str]
    bw_post_order: list[str]


class ModRuntime(TypedDict):
    fw: float
    bw: float


class ModStats(TypedDict):
    fqn: str
    # per-module params
    param_per_module: int
    # per-module grads
    grad_per_module: int
    # total accumulated gradients up to and including this module
    grad_total: int
    # per-module fw activation size (excluding input and output)
    act_fw_per_module: int
    # per-module bw activation size during peak_bw
    act_bw_per_module: int
    # per-module activation grad size during peak_bw
    act_grad_per_module: int
    # total activation size up to but excluding the current module;
    # includes input of the current module (i.e., output of previous module)
    act_total: int
    # Inputs to the module
    input_per_module: int
    # Outputs of the module
    output_per_module: int
    # Total fw run-time of the module
    fw_runtime_per_module: float
    # Total bw run-time of the module
    bw_runtime_per_module: float
    # Is this module a leaf module
    is_leaf: bool
    # Total ac run-time of the module
    sac_runtime: float
    # Total ac_memory for the module
    sac_memory: int
    # Number of piecewise-linear functions used for approximating the ac tradeoff curve
    n_segments: int
    # Slopes of the piecewise-linear functions
    slopes: list[float]
    # Intercepts of the piecewise-linear functions
    intercepts: list[float]
    # X breakpoints of the piecewise-linear functions
    breakpoints: list[float]
    # Original trade-off curve
    tradeoff_curve: OrderedDict[float, float]


class ModuleInfo(TypedDict):
    mod_order: ModOrder
    mod_stats: list[ModStats]
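

# Illustrative sketch, not part of the original module: a hypothetical ModStats
# entry showing the expected shape of the TypedDict. All values are made up for
# demonstration only; real entries are produced by aggregate_stats below, with
# sizes in bytes and runtimes in ms (matching get_peak_memory_runtime_baseline).
def _example_mod_stats() -> ModStats:
    return {
        "fqn": "model.layers.0",  # hypothetical module FQN
        "param_per_module": 4 * 2**20,
        "grad_per_module": 4 * 2**20,
        "grad_total": 8 * 2**20,
        "act_fw_per_module": 16 * 2**20,
        "act_bw_per_module": 16 * 2**20,
        "act_grad_per_module": 2 * 2**20,
        "act_total": 64 * 2**20,
        "input_per_module": 1 * 2**20,
        "output_per_module": 1 * 2**20,
        "fw_runtime_per_module": 1.5,
        "bw_runtime_per_module": 3.0,
        "is_leaf": False,
        "sac_runtime": 4.5,
        "sac_memory": 16 * 2**20,
        "n_segments": 2,
        "slopes": [0.5, 2.0],
        "intercepts": [0.0, -1.5],
        "breakpoints": [0.0, 0.6, 1.0],
        "tradeoff_curve": OrderedDict([(0.1, 0.05), (0.5, 0.4), (1.0, 1.0)]),
    }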


def aggregate_stats(
    model: torch.nn.Module,
    mem_tracker: MemTracker,
    runtime_estimator: RuntimeEstimator,
    sac_estimator: SACEstimator,
    dev: torch.device,
) -> ModuleInfo:
    """
    Collect modulewise stats for a given model, including memory, runtime, and AC tradeoff stats.

    Args:
        model: nn.Module object
        mem_tracker: MemTracker object with memory stats
        runtime_estimator: RuntimeEstimator object with runtime stats
        sac_estimator: SACEstimator object with AC tradeoff stats
        dev: device the model was run on (used to extract memory stats from MemTracker)

    Returns:
        ModuleInfo: A dictionary with module order and module stats.
    """

    # Memory stats
    mod_mem_stats: dict[torch.nn.Module, _ModMemStats] = dict(
        copy.deepcopy(mem_tracker.memory_tracking)
    )

    # Runtime stats
    mod_runtime_stats: dict[str, ModRuntime] = {
        fqn: {"fw": v["fw"], "bw": v["bw"]}
        for fqn, v in runtime_estimator.mod_runtimes.items()
    }

    # Module order
    mod_order: ModOrder = {
        "fw_pre_order": list(runtime_estimator.mod_fw_pre_order),
        "bw_pre_order": list(runtime_estimator.mod_bw_pre_order),
        "fw_post_order": list(runtime_estimator.mod_fw_post_order),
        "bw_post_order": list(runtime_estimator.mod_bw_post_order),
    }

    # Selective Activation Checkpointing stats
    sac_estimator.pwlf_sac_tradeoff_curve()
    mod_sac_tradeoff_stats: dict[str, SACTradeOffStats] = copy.deepcopy(
        sac_estimator.sac_mod_tradeoff_stats
    )

    module_info: ModuleInfo = {
        "mod_order": mod_order,
        "mod_stats": [],
    }

    for mod in model.modules():
        if mod_mem_stat := mod_mem_stats.get(mod):
            if tradeoff_stats := mod_sac_tradeoff_stats.get(mod_mem_stat.mod_fqn):
                sac_runtime = tradeoff_stats.sac_runtime
                sac_memory = tradeoff_stats.sac_memory
                n_segments = tradeoff_stats.n_segments
                slopes = tradeoff_stats.slopes
                intercepts = tradeoff_stats.intercepts
                breakpoints = tradeoff_stats.fit_breaks
                tradeoff_curve = tradeoff_stats.tradeoff_curve
                is_leaf = False
            else:
                sac_runtime = sac_memory = n_segments = 0
                slopes = intercepts = breakpoints = []
                tradeoff_curve: OrderedDict[float, float] = OrderedDict()  # type: ignore[no-redef]
                is_leaf = True
            mod_stat: ModStats = {
                "fqn": mod_mem_stat.mod_fqn,
                "param_per_module": mod_mem_stat.parameter_mem,
                # gradients are assumed to take the same amount of memory as the parameters
                "grad_per_module": mod_mem_stat.parameter_mem,
                "grad_total": mod_mem_stat.snapshots[_ModState.PRE_BW][-1][dev][
                    _MemRefType.GRAD
                ],
                "act_fw_per_module": max(
                    0,
                    mod_mem_stat.snapshots[_ModState.POST_FW][-1][dev][_MemRefType.ACT]
                    - mod_mem_stat.snapshots[_ModState.PRE_FW][-1][dev][_MemRefType.ACT]
                    - mod_mem_stat.output_mem,
                ),
                "act_bw_per_module": max(
                    0,
                    mod_mem_stat.snapshots[_ModState.PEAK_BW][-1][dev][_MemRefType.ACT],
                ),
                "act_grad_per_module": (
                    mod_mem_stat.snapshots[_ModState.PEAK_BW][-1][dev][_MemRefType.TEMP]
                    - mod_mem_stat.snapshots[_ModState.PRE_BW][-1][dev][
                        _MemRefType.TEMP
                    ]
                ),
                "act_total": mod_mem_stat.snapshots[_ModState.POST_FW][-1][dev][
                    _MemRefType.ACT
                ],
                "input_per_module": mod_mem_stat.input_mem,
                "output_per_module": mod_mem_stat.output_mem,
                "fw_runtime_per_module": mod_runtime_stats[mod_mem_stat.mod_fqn]["fw"],
                "bw_runtime_per_module": mod_runtime_stats[mod_mem_stat.mod_fqn]["bw"],
                "is_leaf": is_leaf,
                "sac_runtime": sac_runtime,
                "sac_memory": sac_memory,
                "n_segments": n_segments,
                "slopes": slopes,
                "intercepts": intercepts,
                "breakpoints": breakpoints,
                "tradeoff_curve": tradeoff_curve,
            }
            module_info["mod_stats"].append(mod_stat)

    return module_info
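

# Illustrative sketch, not part of the original module: how aggregate_stats is
# meant to be called. It assumes the model has already been run (forward and
# backward) under MemTracker, RuntimeEstimator, and SACEstimator so that their
# internal stats are populated; that setup is tool-specific and not shown here.
def _example_aggregate_stats(
    model: torch.nn.Module,
    mem_tracker: MemTracker,
    runtime_estimator: RuntimeEstimator,
    sac_estimator: SACEstimator,
) -> ModuleInfo:
    if torch.cuda.is_available():
        # assume the profiled run used the current CUDA device
        dev = torch.device("cuda", torch.cuda.current_device())
    else:
        dev = torch.device("cpu")
    return aggregate_stats(model, mem_tracker, runtime_estimator, sac_estimator, dev)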


class Node(ModStats):
    index: int  # index according to forward pre-order
    pos_fw_post_order: int  # index according to forward post-order


class Graph:
    def __init__(self, n: int) -> None:
        self.nodes: list[Node] = []
        self.name2node: dict[str, Node] = {}
        self.ad_matrix = np.zeros((n, n))
        self.fw_post_order: list[str] = []

    def add_node(self, node: Node) -> None:
        self.nodes.append(node)
        self.name2node[node["fqn"]] = node


def parse_module_info(module_info: ModuleInfo) -> Graph:
    """
    Parse module info and create a graph (tree) of modules. The graph will be
    used by the MILP solver to find optimal SAC and/or FSDP configurations.
    """
    mod_stats = module_info["mod_stats"]
    fw_pre_order = module_info["mod_order"]["fw_pre_order"]
    # assertion and number of nodes
    assert len(mod_stats) == len(fw_pre_order)
    n_nodes = len(mod_stats)

    # create graph
    g = Graph(n_nodes)
    g.fw_post_order = module_info["mod_order"]["fw_post_order"]

    # sort the modules by pre-order and add them to the graph
    module_info["mod_stats"] = sorted(
        mod_stats, key=lambda x: fw_pre_order.index(x["fqn"])
    )
    for i, one_mod_stats in enumerate(module_info["mod_stats"]):
        node: Node = cast(Node, one_mod_stats)
        node["index"] = i
        node["pos_fw_post_order"] = g.fw_post_order.index(node["fqn"])
        g.add_node(node)

    # set up ancestor-descendant matrix
    for i in range(n_nodes):
        for j in range(i, n_nodes):
            if is_self_or_submodule(g.nodes[j]["fqn"], g.nodes[i]["fqn"]):
                g.ad_matrix[i][j] = 1
            else:
                break

    return g
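

# Worked example of the ancestor-descendant matrix built above (module names are
# hypothetical). Because nodes are sorted by forward pre-order, every descendant
# of node i appears contiguously right after i, so the inner loop can stop at the
# first non-descendant. For a model with FQNs ["model", "model.a", "model.b"]:
#
#     ad_matrix = [[1, 1, 1],    # "model" is an ancestor of (or equal to) all three
#                  [0, 1, 0],    # "model.a" covers only itself
#                  [0, 0, 1]]    # "model.b" covers only itself
#
# i.e., ad_matrix[i][j] == 1 iff node j is node i itself or one of its submodules.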


def is_self_or_submodule(name_descendant: str, name_ancestor: str) -> bool:
    """
    Check if name_descendant is a submodule of name_ancestor, or if they are the same.
    """
    return name_descendant == name_ancestor or name_descendant.startswith(
        name_ancestor + "."
    )


def is_submodule(name_descendant: str, name_ancestor: str) -> bool:
    """
    Check if name_descendant is a submodule of name_ancestor, but not the same.
    """
    return name_descendant.startswith(name_ancestor + ".")
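

# For example (module names are hypothetical):
#   is_self_or_submodule("encoder", "encoder")        -> True
#   is_self_or_submodule("encoder.layer1", "encoder") -> True
#   is_submodule("encoder", "encoder")                -> False
#   is_submodule("encoder.layer1", "encoder")         -> True
#   is_submodule("decoder.layer1", "encoder")         -> False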


def display_bytes(b: int, unit: str = "MiB") -> str:
    """
    Return a string that represents the given number of bytes in the desired unit.
    """
    if unit == "KiB":
        return f"{b / 2**10:.2f} KiB"
    if unit == "MiB":
        return f"{b / 2**20:.2f} MiB"
    if unit == "GiB":
        return f"{b / 2**30:.2f} GiB"
    return f"{b:.2f} bytes"


def get_peak_memory_runtime_baseline(graph: Graph) -> tuple[int, float]:
    """
    Get the baseline peak memory and runtime.
    Baseline here means there is no FSDP or AC.
    Memory includes the parameters, gradients, activations, and activation gradients.
    Memory does not include optimizer states, embedding tables, etc.

    Returns:
        int: peak memory in bytes
        float: compute time in ms
    """
    # P_1: parameter memory of the root module (i.e., the whole model)
    P_1 = graph.nodes[0]["param_per_module"]
    num_nodes = len(graph.nodes)
    peak_mem = 0
    for i in range(num_nodes):
        # TG_i: gradients accumulated up to and including module i
        TG_i = graph.nodes[i]["grad_total"]
        # AG_i: activation gradients alive at module i's backward peak
        AG_i = graph.nodes[i]["act_grad_per_module"]
        # TA_i: total activation memory up to (and including the input of) module i
        TA_i = graph.nodes[i]["act_total"]
        peak_mem = max(peak_mem, P_1 + TG_i + AG_i + TA_i)
    # graph.nodes[0] is the root module, whose runtime covers the whole model
    compute_time = (
        graph.nodes[0]["fw_runtime_per_module"]
        + graph.nodes[0]["bw_runtime_per_module"]
    )
    return (peak_mem, compute_time)
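

# Illustrative sketch, not part of the original module: a hypothetical end-to-end
# use of the helpers above. `module_info` is assumed to come from aggregate_stats
# (defined above); everything else only uses functions defined in this file.
def _example_baseline_report(module_info: ModuleInfo) -> str:
    graph = parse_module_info(module_info)
    peak_mem, compute_time = get_peak_memory_runtime_baseline(graph)
    return (
        f"baseline peak memory: {display_bytes(peak_mem, 'GiB')}, "
        f"compute time: {compute_time:.2f} ms"
    )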