An improved heuristic for operator reordering for peak memory + debugging logs (#161810)

Revisiting the idea in https://github.com/pytorch/pytorch/pull/140195

For the lpmf algorithm in the memory reordering pass, when every node that is currently schedulable is quite large (above a configurable size threshold), it is beneficial to switch the scheduling strategy: instead of using size as the criterion, we choose the node most likely to unlock more nodes for scheduling, by analyzing successor nodes and picking the node whose earliest successor appears first in the original order.
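As a minimal, self-contained sketch of this switch (using a stand-in `Node` class rather than the real scheduler types; `size_threshold` plays the role of `config.size_threshold_for_succ_based_strategy`, and the size-based key is simplified relative to the actual pass):

```python
from dataclasses import dataclass, field


@dataclass
class Node:
    index: int                       # position in the original order
    size: int                        # bytes allocated when scheduled
    memory_to_free: int = 0          # bytes freed once scheduled
    succ_nodes: list["Node"] = field(default_factory=list)


def select_node(schedulable: list[Node], num_nodes: int, size_threshold: int) -> Node:
    # When every candidate is larger than the threshold, size stops being an
    # informative criterion, so switch to the successor-based strategy:
    # pick the node whose earliest successor comes first, since it unblocks
    # downstream nodes soonest.
    if size_threshold > 0 and min(n.size for n in schedulable) > size_threshold:
        return min(
            schedulable,
            key=lambda n: min((s.index for s in n.succ_nodes), default=num_nodes),
        )
    # Simplified size-based key (the real pass also folds in a term based on
    # the gap between live memory and the running peak).
    return min(schedulable, key=lambda n: (n.size - n.memory_to_free, n.index))


# Toy example: both candidates exceed the threshold, so the
# successor-based rule fires and picks b, whose successor comes earlier.
a = Node(index=0, size=100)
b = Node(index=1, size=100)
a.succ_nodes = [Node(index=5, size=1)]
b.succ_nodes = [Node(index=2, size=1)]
print(select_node([a, b], num_nodes=6, size_threshold=10).index)  # -> 1
```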

For an internal use case, we observe up to a 20 GiB difference in peak memory; the before and after memory snapshots are shown below. More information can be found in [D81270682](https://www.internalfb.com/diff/D81270682) (internal only).

<img width="348" height="227" alt="image" src="https://github.com/user-attachments/assets/fb71e840-1508-44ed-bc9d-5eb4d364607d" />

In addition, this adds the functionality to upload the graph to tlparse for offline debugging. The JSON format is consistent with the simulator [here](https://fburl.com/code/3l3d3qi4) (internal only); a hypothetical example payload is sketched below.
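The schema follows the `ORMGraph`/`ORMNode`/`ORMBuffer` TypedDicts in the diff below; the node/buffer names and sizes here are invented for illustration:

```python
import json

# Hypothetical example of the dumped artifact (values are made up).
graph = {
    "nodes": [
        {"name": "op0", "buffer_names": ["buf0"]},
        {"name": "op1", "buffer_names": ["buf1"]},
    ],
    "buffers": [
        {
            "name": "arg0_1",    # a freeable graph input
            "size_alloc": 4096,
            "size_free": 4096,
            "size": 4096,        # kept for backward compatibility
            "is_input": True,
            "is_output": False,
            "deps": [],
            "unmet_deps": [],
        },
        {
            "name": "buf0",      # produced by op0
            "size_alloc": 8192,
            "size_free": 8192,
            "size": 8192,
            "is_input": False,
            "is_output": False,
            "deps": ["arg0_1"],
            "unmet_deps": [],    # deps that are not graph inputs
        },
    ],
}
print(json.dumps(graph, indent=2))
```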

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161810
Approved by: https://github.com/yf225
Author: Xuan Zhang
Date: 2025-09-12 15:00:36 -07:00
Committed by: PyTorch MergeBot
Parent: a94ddd9b00
Commit: ddc5107601
3 changed files with 168 additions and 17 deletions


```diff
@@ -10,8 +10,9 @@ from torch._environment import is_fbcode
 from torch._utils_internal import signpost_event
 from torch.utils._ordered_set import OrderedSet
 from . import config
 from .ir import MultiOutputLayout, NoneLayout
-from .utils import get_dtype_size
+from .utils import get_dtype_size, is_nonfreeable_buffers
 from .virtualized import V
@@ -92,14 +93,7 @@ def get_freeable_input_buf(
     for node in nodes:
         for dep in node.read_writes.reads:
             if dep.name in graph_inputs:
-                dep_name = dep.name
-                # Subgraphs have a prefix for the name, cleanup the prefix
-                # before checking for known strings.
-                if V.graph.name:
-                    dep_name = dep_name.removeprefix(V.graph.name + "_")
-                if not dep_name.startswith(
-                    ("primals_", "arg", "fwd_rng_state", "bwd_rng_state")
-                ):
+                if not is_nonfreeable_buffers(dep):
                     dep_name_to_succ_nodes[dep.name].add(node)
                     dep_name_to_size[dep.name] = _dep_size_hint(dep)
@@ -574,6 +568,7 @@ def topological_sort_lpmf(
         elif buf_name in name_to_freeable_input_buf:
             output_memory += name_to_freeable_input_buf[buf_name].mpi_buffer.size_free
     max_memory = max(live_memory, output_memory)
+    memory_gap = max_memory - live_memory
 
     # compute the amount of memory that is allocated when a node is scheduled
     # and the amount of memory that can be freed when a node is scheduled
@@ -589,17 +584,33 @@ def topological_sort_lpmf(
     # schedule nodes one at a time
     schedule: list[BaseSchedulerNode] = []
+    size_threshold = config.size_threshold_for_succ_based_strategy
     num_iters: int = 0
     while num_iters < len(nodes) and nodes_to_schedule:
         # select a node to schedule:
-        selected_node = min(
-            nodes_to_schedule,
-            key=lambda node: (
-                max(live_memory + node.mpi_node.size, max_memory),
-                node.mpi_node.size - node_info[node]["memory_to_free"],
-                node.mpi_node.index,
-            ),
-        )
+        if (
+            size_threshold > 0
+            and min(node.mpi_node.size for node in nodes_to_schedule) > size_threshold
+        ):
+            selected_node = min(
+                nodes_to_schedule,
+                key=lambda node: min(
+                    (
+                        succ_node.mpi_node.index
+                        for succ_node in node.mpi_node.succ_nodes
+                    ),
+                    default=len(nodes),
+                ),
+            )
+        else:
+            selected_node = min(
+                nodes_to_schedule,
+                key=lambda node: (
+                    node.mpi_node.size if node.mpi_node.size > memory_gap else 0,
+                    node.mpi_node.size - node_info[node]["memory_to_free"],
+                    node.mpi_node.index,
+                ),
+            )
         nodes_to_schedule.remove(selected_node)
         schedule.append(selected_node)
         num_iters += 1
@@ -608,6 +619,7 @@ def topological_sort_lpmf(
         live_memory += selected_node.mpi_node.size
         max_memory = max(max_memory, live_memory)
         live_memory -= node_info[selected_node]["memory_to_free"]
+        memory_gap = max_memory - live_memory
 
         # update successor nodes and nodes_to_schedule
         for succ_node in selected_node.mpi_node.succ_nodes:
@@ -887,6 +899,16 @@ def reorder_for_peak_memory(
         graph_outputs,
     )
 
+    # export graph for simulator if needed
+    if config.reorder_for_peak_memory_debug:
+        export_graph_for_simulator(
+            nodes,
+            name_to_freeable_input_buf,
+            name_to_fused_node,
+            graph_inputs,
+            graph_outputs,
+        )
+
     # Validate planning info before proceeding with reordering
     try:
         validate_graph_acyclic(nodes)
@@ -937,3 +959,112 @@ def reorder_for_peak_memory(
     best_result = min(peak_memory_diff_methods, key=lambda x: x.peak_memory)
     return best_result.order
+
+
+def export_graph_for_simulator(
+    nodes: list[BaseSchedulerNode],
+    name_to_freeable_input_buf: dict[str, FreeableInputBuffer],
+    name_to_fused_node: dict[str, BaseSchedulerNode],
+    graph_inputs: OrderedSet[str],
+    graph_outputs: OrderedSet[str],
+) -> None:
+    """
+    This is for debugging purposes. It will dump a json file that records graph information.
+    The graph can then be used in a simulator: https://fburl.com/code/3l3d3qi4
+    """
+
+    class ORMBuffer(TypedDict):
+        name: str
+        size_alloc: int
+        size_free: int
+        size: int  # for backward compatibility
+        is_input: bool
+        is_output: bool
+        deps: list[str]
+        unmet_deps: list[str]
+
+    class ORMNode(TypedDict):
+        name: str
+        buffer_names: list[str]
+
+    class ORMGraph(TypedDict):
+        nodes: list[ORMNode]
+        buffers: list[ORMBuffer]
+
+    orm_buffers: list[ORMBuffer] = []
+    orm_nodes: list[ORMNode] = []
+
+    # get orm buffers for freeable input buffers
+    for buf_name, input_buf in name_to_freeable_input_buf.items():
+        orm_buf_input_buffer: ORMBuffer = {
+            "name": buf_name,
+            "size_alloc": input_buf.mpi_buffer.size_free,
+            "size_free": input_buf.mpi_buffer.size_free,
+            "size": input_buf.mpi_buffer.size_free,
+            "is_input": True,
+            "is_output": buf_name in graph_outputs,
+            "deps": [],
+            "unmet_deps": [],
+        }
+        orm_buffers.append(orm_buf_input_buffer)
+
+    # get orm buffers for scheduler buffers
+    name_to_buf: dict[str, SchedulerBuffer] = {
+        buf.get_name(): buf for node in nodes for buf in node.get_outputs()
+    }  # need to reassign due to probably node pruning
+    for buf_name, sched_buf in name_to_buf.items():
+        if sched_buf.defining_op is None:
+            continue
+        deps = [
+            pred_buf.get_name()
+            for pred_buf in name_to_fused_node[
+                sched_buf.defining_op.get_name()
+            ].mpi_node.pred_buffers
+        ]
+        orm_buf_scheduler_buffer: ORMBuffer = {
+            "name": buf_name,
+            "size_alloc": sched_buf.mpi_buffer.size_alloc,
+            "size_free": sched_buf.mpi_buffer.size_free,
+            "size": sched_buf.mpi_buffer.size_free,
+            "is_input": False,
+            "is_output": buf_name in graph_outputs,
+            "deps": deps,
+            "unmet_deps": [
+                buf_name for buf_name in deps if buf_name not in graph_inputs
+            ],
+        }
+        orm_buffers.append(orm_buf_scheduler_buffer)
+
+    # get orm nodes
+    for node in nodes:
+        orm_node: ORMNode = {
+            "name": node.get_name(),
+            "buffer_names": list(node.get_buffer_names()),
+        }
+        orm_nodes.append(orm_node)
+
+    # create the graph object
+    g: ORMGraph = {
+        "nodes": orm_nodes,
+        "buffers": orm_buffers,
+    }
+
+    # dump the graph
+    import json
+    import os
+
+    import torch
+    from functorch.compile import get_graph_being_compiled
+
+    name = os.path.splitext(get_graph_being_compiled())[0] + "_fused"
+    g_str = json.dumps(g, indent=2)
+    torch._logging.trace_structured(
+        "artifact",
+        metadata_fn=lambda: {
+            "name": name,
+            "encoding": "string",
+        },
+        payload_fn=lambda: g_str,
+    )
```
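The simulator itself is internal-only; purely as an illustration of what such a dump enables, here is a rough, hypothetical replay that estimates peak memory from the exported JSON (`simulate_peak_memory` and its freeing semantics are assumptions for this sketch, not the internal tool's logic):

```python
import json


def simulate_peak_memory(path: str) -> int:
    """Replay the dumped schedule and return an estimated peak memory."""
    with open(path) as f:
        g = json.load(f)
    buffers = {b["name"]: b for b in g["buffers"]}

    # Last step at which each buffer is still read by a scheduled node.
    last_use: dict[str, int] = {}
    for step, node in enumerate(g["nodes"]):
        for out_name in node["buffer_names"]:
            for dep in buffers.get(out_name, {}).get("deps", []):
                if dep in buffers:
                    last_use[dep] = step

    # Graph inputs are live from the start; outputs are never freed.
    live = sum(b["size_alloc"] for b in g["buffers"] if b["is_input"])
    peak = live
    for step, node in enumerate(g["nodes"]):
        # Allocate this node's outputs, then free buffers past their last use.
        for out_name in node["buffer_names"]:
            live += buffers.get(out_name, {}).get("size_alloc", 0)
        peak = max(peak, live)
        for name, last in last_use.items():
            if last == step and not buffers[name]["is_output"]:
                live -= buffers[name]["size_free"]
    return peak
```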