An improved heuristic for operator reordering for peak memory + debugging logs (#161810)
Revisiting the idea in https://github.com/pytorch/pytorch/pull/140195.

For the lpmf algorithm in the memory reorder pass, when all currently schedulable nodes are quite large, it is beneficial to switch the scheduling strategy: instead of using size as the selection criterion, we choose the node most likely to unlock further nodes for scheduling, by analyzing its successor nodes. For an internal use case, we observe up to a 20 GiB difference in peak memory; before and after memory snapshots are shown below. More information can be found in [D81270682](https://www.internalfb.com/diff/D81270682) (internal only).

[Memory snapshot screenshot: before vs. after comparison (internal only)]

In addition, this PR adds the ability to upload the graph to tlparse for offline debugging. The JSON format is consistent with the simulator [here](https://fburl.com/code/3l3d3qi4) (internal only).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161810
Approved by: https://github.com/yf225
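To make the successor-based criterion concrete, here is a minimal sketch of the selection rule from the diff below, rewritten over toy node records. The `ToyNode` class and the sample data are illustrative, not part of the PR:

```python
from dataclasses import dataclass, field

@dataclass
class ToyNode:
    index: int  # position in the original schedule
    size: int   # bytes allocated when this node is scheduled
    succ_indices: list[int] = field(default_factory=list)  # indices of successors

def pick_by_successors(candidates: list[ToyNode], num_nodes: int) -> ToyNode:
    # Prefer the node whose earliest successor comes soonest in the original
    # order: scheduling it is most likely to unlock new schedulable nodes.
    return min(
        candidates,
        key=lambda n: min(n.succ_indices, default=num_nodes),
    )

# Both candidates are large, so size is a poor tie-breaker; node 2's earliest
# successor (index 3) precedes node 0's (index 7), so node 2 is scheduled first.
candidates = [
    ToyNode(index=0, size=4 << 30, succ_indices=[7]),
    ToyNode(index=2, size=5 << 30, succ_indices=[3, 9]),
]
print(pick_by_successors(candidates, num_nodes=10).index)  # -> 2
```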
Committed by: PyTorch MergeBot
Parent: a94ddd9b00
Commit: ddc5107601
@@ -10,8 +10,9 @@ from torch._environment import is_fbcode
 from torch._utils_internal import signpost_event
 from torch.utils._ordered_set import OrderedSet
 
 from . import config
 from .ir import MultiOutputLayout, NoneLayout
-from .utils import get_dtype_size
+from .utils import get_dtype_size, is_nonfreeable_buffers
 from .virtualized import V
 
+
@@ -92,14 +93,7 @@ def get_freeable_input_buf(
     for node in nodes:
         for dep in node.read_writes.reads:
             if dep.name in graph_inputs:
-                dep_name = dep.name
-                # Subgraphs have a prefix for the name, cleanup the prefix
-                # before checking for known strings.
-                if V.graph.name:
-                    dep_name = dep_name.removeprefix(V.graph.name + "_")
-                if not dep_name.startswith(
-                    ("primals_", "arg", "fwd_rng_state", "bwd_rng_state")
-                ):
+                if not is_nonfreeable_buffers(dep):
                     dep_name_to_succ_nodes[dep.name].add(node)
                     dep_name_to_size[dep.name] = _dep_size_hint(dep)
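The removed prefix check moves into the new `is_nonfreeable_buffers` helper imported above. Its body is not shown in this diff, but a plausible sketch, assuming it simply wraps the logic that used to be inlined here, is:

```python
# A plausible sketch of the helper, assuming it wraps the prefix check that
# was previously inlined in get_freeable_input_buf (not shown in this diff).
def is_nonfreeable_buffers(dep) -> bool:
    dep_name = dep.name
    # Subgraphs have a prefix for the name; clean up the prefix
    # before checking for known strings.
    if V.graph.name:
        dep_name = dep_name.removeprefix(V.graph.name + "_")
    # Parameters, primals, and RNG state are owned by the caller
    # and must not be treated as freeable.
    return dep_name.startswith(
        ("primals_", "arg", "fwd_rng_state", "bwd_rng_state")
    )
```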
@@ -574,6 +568,7 @@ def topological_sort_lpmf(
         elif buf_name in name_to_freeable_input_buf:
             output_memory += name_to_freeable_input_buf[buf_name].mpi_buffer.size_free
     max_memory = max(live_memory, output_memory)
+    memory_gap = max_memory - live_memory
 
     # compute the amount of memory that is allocated when a node is scheduled
     # and the amount of memory that can be freed when a node is scheduled
@@ -589,17 +584,33 @@ def topological_sort_lpmf(
 
     # schedule nodes one at a time
     schedule: list[BaseSchedulerNode] = []
+    size_threshold = config.size_threshold_for_succ_based_strategy
     num_iters: int = 0
    while num_iters < len(nodes) and nodes_to_schedule:
         # select a node to schedule:
-        selected_node = min(
-            nodes_to_schedule,
-            key=lambda node: (
-                max(live_memory + node.mpi_node.size, max_memory),
-                node.mpi_node.size - node_info[node]["memory_to_free"],
-                node.mpi_node.index,
-            ),
-        )
+        if (
+            size_threshold > 0
+            and min(node.mpi_node.size for node in nodes_to_schedule) > size_threshold
+        ):
+            selected_node = min(
+                nodes_to_schedule,
+                key=lambda node: min(
+                    (
+                        succ_node.mpi_node.index
+                        for succ_node in node.mpi_node.succ_nodes
+                    ),
+                    default=len(nodes),
+                ),
+            )
+        else:
+            selected_node = min(
+                nodes_to_schedule,
+                key=lambda node: (
+                    node.mpi_node.size if node.mpi_node.size > memory_gap else 0,
+                    node.mpi_node.size - node_info[node]["memory_to_free"],
+                    node.mpi_node.index,
+                ),
+            )
         nodes_to_schedule.remove(selected_node)
         schedule.append(selected_node)
         num_iters += 1
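The fallback key above replaces the old `max(live_memory + node.mpi_node.size, max_memory)` term with a `memory_gap` test: a node small enough to fit in the gap between live memory and the running peak cannot raise the peak, so all such nodes compare equal on the first component and are ranked by net freed memory instead. A toy illustration, with made-up numbers:

```python
# Toy illustration of the memory_gap criterion (all numbers are made up).
max_memory = 100   # running peak so far
live_memory = 70   # memory currently live
memory_gap = max_memory - live_memory  # = 30

for size in (10, 25, 40):
    # First component of the fallback key: 0 while the node fits in the gap.
    first_key = size if size > memory_gap else 0
    print(size, first_key)
# 10 -> 0   (cannot raise the peak)
# 25 -> 0   (cannot raise the peak)
# 40 -> 40  (would push the peak to 110)
```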
@@ -608,6 +619,7 @@ def topological_sort_lpmf(
         live_memory += selected_node.mpi_node.size
         max_memory = max(max_memory, live_memory)
         live_memory -= node_info[selected_node]["memory_to_free"]
+        memory_gap = max_memory - live_memory
 
         # update successor nodes and nodes_to_schedule
         for succ_node in selected_node.mpi_node.succ_nodes:
@@ -887,6 +899,16 @@ def reorder_for_peak_memory(
             graph_outputs,
         )
 
+    # export graph for simulator if needed
+    if config.reorder_for_peak_memory_debug:
+        export_graph_for_simulator(
+            nodes,
+            name_to_freeable_input_buf,
+            name_to_fused_node,
+            graph_inputs,
+            graph_outputs,
+        )
+
     # Validate planning info before proceeding with reordering
     try:
         validate_graph_acyclic(nodes)
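Both new code paths are gated on the Inductor config knobs referenced in this diff, `size_threshold_for_succ_based_strategy` and `reorder_for_peak_memory_debug`. A sketch of how one might enable them; the threshold value and `my_model` are illustrative, not defaults from the PR:

```python
# Sketch: enabling the new behavior via the Inductor config knobs referenced
# in this diff. The 1 GiB threshold is an illustrative value, not a default.
import torch
import torch._inductor.config as inductor_config

# Switch to successor-based selection when every schedulable node
# exceeds this size (the diff checks `size_threshold > 0`, so 0 disables it).
inductor_config.size_threshold_for_succ_based_strategy = 1 << 30  # 1 GiB

# Dump the memory-planning graph to tlparse for offline debugging.
inductor_config.reorder_for_peak_memory_debug = True

compiled = torch.compile(my_model)  # my_model: any torch.nn.Module (hypothetical)
```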
@@ -937,3 +959,112 @@ def reorder_for_peak_memory(
     best_result = min(peak_memory_diff_methods, key=lambda x: x.peak_memory)
 
     return best_result.order
+
+
+def export_graph_for_simulator(
+    nodes: list[BaseSchedulerNode],
+    name_to_freeable_input_buf: dict[str, FreeableInputBuffer],
+    name_to_fused_node: dict[str, BaseSchedulerNode],
+    graph_inputs: OrderedSet[str],
+    graph_outputs: OrderedSet[str],
+) -> None:
+    """
+    This is for debugging purposes. It dumps a JSON file that records graph information.
+    The graph can then be used in a simulator: https://fburl.com/code/3l3d3qi4
+    """
+
+    class ORMBuffer(TypedDict):
+        name: str
+        size_alloc: int
+        size_free: int
+        size: int  # for backward compatibility
+        is_input: bool
+        is_output: bool
+        deps: list[str]
+        unmet_deps: list[str]
+
+    class ORMNode(TypedDict):
+        name: str
+        buffer_names: list[str]
+
+    class ORMGraph(TypedDict):
+        nodes: list[ORMNode]
+        buffers: list[ORMBuffer]
+
+    orm_buffers: list[ORMBuffer] = []
+    orm_nodes: list[ORMNode] = []
+
+    # get orm buffers for freeable input buffers
+    for buf_name, input_buf in name_to_freeable_input_buf.items():
+        orm_buf_input_buffer: ORMBuffer = {
+            "name": buf_name,
+            "size_alloc": input_buf.mpi_buffer.size_free,
+            "size_free": input_buf.mpi_buffer.size_free,
+            "size": input_buf.mpi_buffer.size_free,
+            "is_input": True,
+            "is_output": buf_name in graph_outputs,
+            "deps": [],
+            "unmet_deps": [],
+        }
+        orm_buffers.append(orm_buf_input_buffer)
+
+    # get orm buffers for scheduler buffers
+    name_to_buf: dict[str, SchedulerBuffer] = {
+        buf.get_name(): buf for node in nodes for buf in node.get_outputs()
+    }  # need to reassign due to possible node pruning
+    for buf_name, sched_buf in name_to_buf.items():
+        if sched_buf.defining_op is None:
+            continue
+        deps = [
+            pred_buf.get_name()
+            for pred_buf in name_to_fused_node[
+                sched_buf.defining_op.get_name()
+            ].mpi_node.pred_buffers
+        ]
+        orm_buf_scheduler_buffer: ORMBuffer = {
+            "name": buf_name,
+            "size_alloc": sched_buf.mpi_buffer.size_alloc,
+            "size_free": sched_buf.mpi_buffer.size_free,
+            "size": sched_buf.mpi_buffer.size_free,
+            "is_input": False,
+            "is_output": buf_name in graph_outputs,
+            "deps": deps,
+            "unmet_deps": [
+                buf_name for buf_name in deps if buf_name not in graph_inputs
+            ],
+        }
+        orm_buffers.append(orm_buf_scheduler_buffer)
+
+    # get orm nodes
+    for node in nodes:
+        orm_node: ORMNode = {
+            "name": node.get_name(),
+            "buffer_names": list(node.get_buffer_names()),
+        }
+        orm_nodes.append(orm_node)
+
+    # create the graph object
+    g: ORMGraph = {
+        "nodes": orm_nodes,
+        "buffers": orm_buffers,
+    }
+
+    # dump the graph
+    import json
+    import os
+
+    import torch
+    from functorch.compile import get_graph_being_compiled
+
+    name = os.path.splitext(get_graph_being_compiled())[0] + "_fused"
+
+    g_str = json.dumps(g, indent=2)
+
+    torch._logging.trace_structured(
+        "artifact",
+        metadata_fn=lambda: {
+            "name": name,
+            "encoding": "string",
+        },
+        payload_fn=lambda: g_str,
+    )
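For reference, here is the shape of the JSON payload this function emits, hand-built to match the `ORMBuffer`/`ORMNode`/`ORMGraph` TypedDicts above. Buffer and node names are invented for illustration:

```python
import json

# Illustrative only: a minimal ORMGraph-shaped payload matching the
# TypedDicts above; names like "op0" and "buf0" are made up.
g = {
    "nodes": [
        {"name": "op0", "buffer_names": ["buf0"]},
    ],
    "buffers": [
        {
            "name": "arg0_1",   # a freeable graph input
            "size_alloc": 4096,
            "size_free": 4096,
            "size": 4096,       # kept for backward compatibility
            "is_input": True,
            "is_output": False,
            "deps": [],
            "unmet_deps": [],
        },
        {
            "name": "buf0",     # produced by op0, returned as a graph output
            "size_alloc": 8192,
            "size_free": 8192,
            "size": 8192,
            "is_input": False,
            "is_output": True,
            "deps": ["arg0_1"],
            "unmet_deps": [],   # "arg0_1" is a graph input, hence already met
        },
    ],
}
print(json.dumps(g, indent=2))
```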