Files
pytorch/torch/_inductor/memory.py
Xuan Zhang ddc5107601 An improved heuristic for operator reordering for peak memory + debugging logs (#161810)
Revisiting the idea in https://github.com/pytorch/pytorch/pull/140195

For the lpmf algorithm in the memory reorder pass, in some cases, when all the nodes that can be scheduled are quite large, it is beneficial to switch the scheduling strategy: instead of using size as the criterion, we choose a node that can make more nodes schedulable, by analyzing its successor nodes.

For an internal use case, we observe up to a 20 GiB memory difference; here are the before and after memory snapshots. More information can be found in [D81270682](https://www.internalfb.com/diff/D81270682) (internal only).

<img width="348" height="227" alt="image" src="https://github.com/user-attachments/assets/fb71e840-1508-44ed-bc9d-5eb4d364607d" />

In addition, this adds functionality to upload the graph to tlparse for offline debugging. The JSON format is consistent with the simulator [here](https://fburl.com/code/3l3d3qi4) (internal only).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161810
Approved by: https://github.com/yf225
2025-09-13 00:42:32 +00:00

1071 lines
38 KiB
Python

from __future__ import annotations
import collections
import dataclasses
import heapq
import logging
from typing import Callable, Optional, TYPE_CHECKING, TypedDict, Union
from torch._environment import is_fbcode
from torch._utils_internal import signpost_event
from torch.utils._ordered_set import OrderedSet
from . import config
from .ir import MultiOutputLayout, NoneLayout
from .utils import get_dtype_size, is_nonfreeable_buffers
from .virtualized import V
if TYPE_CHECKING:
from .dependencies import Dep
from .scheduler import BaseSchedulerNode, SchedulerBuffer
# Module-level logger used for the memory-planning / reordering diagnostics in this file.
torch_log = logging.getLogger(__name__)
@dataclasses.dataclass
class PeakMemoryResult:
    """A candidate node order together with its estimated peak memory.

    ``reorder_for_peak_memory`` builds one of these for the baseline order and
    one per reordering method, then keeps the one with the lowest
    ``peak_memory``.
    """

    # The node order being evaluated.
    order: list[BaseSchedulerNode]
    # Peak memory of ``order`` as returned by ``estimate_peak_memory``.
    peak_memory: int
    # Producer label: "baseline" or the reordering function's ``__name__``.
    method: str
@dataclasses.dataclass
class MemoryPlanningInfoForBuffer:
    """Memory-planning metadata attached to a buffer as ``mpi_buffer``.

    Used for both SchedulerBuffers and FreeableInputBuffers.
    """

    # Memory charged when the buffer is created.
    size_alloc: int = 0
    # Memory released when the buffer is freed; can differ from size_alloc
    # (e.g. MultiOutputLayout parents/children, see
    # compute_size_for_scheduler_buffer).
    size_free: int = 0
    # Nodes that read this buffer; once all of them have run (and the buffer is
    # not a graph output) it can be freed.
    succ_nodes: OrderedSet[BaseSchedulerNode] = dataclasses.field(
        default_factory=OrderedSet
    )
@dataclasses.dataclass
class MemoryPlanningInfoForNode:
    """Memory-planning metadata attached to a scheduler node as ``mpi_node``."""

    # Position of the node in the node list given to
    # assign_memory_planning_info_for_scheduler_nodes; used as a tie-breaker.
    index: int = 0
    # Sum of size_alloc over the buffers this node outputs.
    size: int = 0
    # Buffers (scheduler or freeable input) whose succ_nodes include this node.
    pred_buffers: OrderedSet[Union[SchedulerBuffer, FreeableInputBuffer]] = (
        dataclasses.field(default_factory=OrderedSet)
    )
    # Nodes that output a buffer this node reads (never the node itself).
    pred_nodes: OrderedSet[BaseSchedulerNode] = dataclasses.field(
        default_factory=OrderedSet
    )
    # Nodes that read a buffer this node outputs (never the node itself).
    succ_nodes: OrderedSet[BaseSchedulerNode] = dataclasses.field(
        default_factory=OrderedSet
    )
@dataclasses.dataclass
class FreeableInputBuffer:
    """A graph input whose memory can be released once all its readers ran.

    Instances are created by ``get_freeable_input_buf``, which fills in
    ``mpi_buffer.size_free`` and ``mpi_buffer.succ_nodes``.
    """

    # Name of the graph input.
    name: str
    # Size and reader information used by the memory planner.
    mpi_buffer: MemoryPlanningInfoForBuffer = dataclasses.field(
        default_factory=MemoryPlanningInfoForBuffer
    )

    def get_name(self) -> str:
        return self.name

    # Defined explicitly so this (non-frozen, eq=True) dataclass stays hashable
    # and can be used as a dict key / set member; hashing uses only ``name``,
    # while ``__eq__`` is still the dataclass-generated field-wise comparison.
    def __hash__(self) -> int:
        return hash(self.name)
def get_freeable_input_buf(
    nodes: list[BaseSchedulerNode],
    graph_inputs: OrderedSet[str],
) -> dict[str, FreeableInputBuffer]:
    """
    Collect the graph inputs whose memory may be released during the program.

    An input qualifies when some node in ``nodes`` reads it and it is not a
    non-freeable buffer. Several deps may share one name, so everything is
    aggregated by name; the size comes from the last dep seen for that name.

    Returns:
        Freeable input buffers keyed by name, in first-read order.
    """
    readers_by_name: dict[str, OrderedSet[BaseSchedulerNode]] = (
        collections.defaultdict(OrderedSet)
    )
    size_by_name: dict[str, int] = {}

    for reader in nodes:
        for dep in reader.read_writes.reads:
            # only graph inputs that are allowed to be freed are tracked
            if dep.name not in graph_inputs or is_nonfreeable_buffers(dep):
                continue
            readers_by_name[dep.name].add(reader)
            size_by_name[dep.name] = V.graph.get_dep_size_hint(dep)

    return {
        input_name: FreeableInputBuffer(
            input_name,
            MemoryPlanningInfoForBuffer(
                size_free=size_by_name[input_name], succ_nodes=readers
            ),
        )
        for input_name, readers in readers_by_name.items()
    }
def compute_size_for_scheduler_buffer(
    name_to_buf: dict[str, SchedulerBuffer],
) -> dict[str, tuple[int, int]]:
    """
    Compute the size of each scheduler buffer, including (1) memory allocated when
    it is created and (2) memory deallocated when it is freed.

    We specially handle the case of MultiOutputLayout.
    Consider the following case:
        buf0 = some_ops_with_multi_outputs(...)
        buf1 = buf0[0] # assume 10 bytes
        buf2 = buf0[1] # assume 20 bytes
    In such cases,
        buf0: at creation, 30 bytes allocated, when deleted, 0 bytes freed
        buf1: at creation, 0 bytes allocated, when deleted, 10 bytes freed
        buf2: at creation, 0 bytes allocated, when deleted, 20 bytes freed

    When an operation mutates a buffer in-place, the scheduler creates a new buffer name
    to track the "before" and "after" states, even though they share the same memory.
    The mutated buffer represents a rename with zero allocation and deallocation cost.
    During dependency tracking, we transfer dependencies from the mutated name back to
    the original buffer, ensuring the original memory is only freed when all aliases
    are done.

    This handles cases where a buffer has multiple non-overlapping aliases - rather than
    trying to assign free costs to individual aliases, we forward all alias dependencies
    to the original buffer.

    Consider:
        buf0 = op0()
        buf1 = mutation_op_(buf0)
        del buf0
        ...
        op(buf1)
        del buf1

    The only memory events are the creation prior to op0, and the deletion following buf1.

    Buffers with a NoneLayout get (0, 0). Any other non-MultiOutputLayout buffer
    uses ``size_hint(numel, fallback=0) * dtype size`` for both values (with
    size_alloc forced to 0 when it is a MultiOutput child, as above).

    Args:
        name_to_buf: all scheduler buffers, keyed by name.

    Returns:
        A dictionary mapping a scheduler buffer name to a tuple of (size_alloc, size_free).
    """
    from .ir import MultiOutput
    from .scheduler import OutputNode

    # name -> (size_alloc, size_free). Also acts as the "already processed" set
    # for MultiOutput children that were sized through their parent.
    sched_buf_to_size: dict[str, tuple[int, int]] = dict()

    def _compute_and_update_buf_size(
        sched_buf: SchedulerBuffer, user_of_MultiOutputLayout: bool = False
    ) -> int:
        # Records (size_alloc, size_free) for sched_buf and returns its size so
        # a MultiOutputLayout parent can sum its children. When
        # user_of_MultiOutputLayout is True the allocation is charged to the
        # parent, so this buffer records size_alloc = 0.
        if sched_buf.get_name() in V.graph.scheduler.mutation_real_name:
            # mutation rename: shares memory with the real buffer, costs nothing
            sched_buf_to_size[sched_buf.get_name()] = (0, 0)
            return 0
        elif isinstance(sched_buf.node.layout, NoneLayout):
            sched_buf_to_size[sched_buf.get_name()] = (0, 0)
            return 0
        elif isinstance(sched_buf.node.layout, MultiOutputLayout):
            # the parent allocates the sum of its MultiOutput children up front
            # and frees nothing itself; the children do the freeing
            size_alloc = 0
            for user in sched_buf.users:
                if isinstance(user.node, OutputNode):
                    continue
                for buf in user.node.get_outputs():
                    if isinstance(buf.node, MultiOutput):
                        size_alloc += _compute_and_update_buf_size(buf, True)
            sched_buf_to_size[sched_buf.get_name()] = (
                0 if user_of_MultiOutputLayout else size_alloc,
                0,
            )
            return size_alloc
        else:
            buf_size = V.graph.sizevars.size_hint(
                sched_buf.node.get_numel(), fallback=0
            ) * get_dtype_size(sched_buf.node.get_dtype())
            sched_buf_to_size[sched_buf.get_name()] = (
                0 if user_of_MultiOutputLayout else buf_size,
                buf_size,
            )
            return buf_size

    for sched_buf in name_to_buf.values():
        # skip if sched_buf is already processed as an user of another SchedulerBuffer
        # whose layout is of the type MultiOutputLayout
        if sched_buf.get_name() not in sched_buf_to_size:
            _compute_and_update_buf_size(sched_buf)

    return sched_buf_to_size
def assign_memory_planning_info_for_scheduler_buffers(
    nodes: list[BaseSchedulerNode],
    name_to_buf: dict[str, SchedulerBuffer],
) -> None:
    """
    Attach a MemoryPlanningInfoForBuffer (sizes + reader nodes) to every buffer.

    A buffer's readers ("successor nodes") decide when it can be freed. Readers
    of a mutation alias are also credited to the real buffer it renames, so the
    real buffer stays alive until every alias is done.
    """
    sizes = compute_size_for_scheduler_buffer(name_to_buf)

    # dep name -> nodes with an unmet dependency on it (several deps can share
    # one name, hence keying by name)
    readers_by_name: dict[str, OrderedSet[BaseSchedulerNode]] = (
        collections.defaultdict(OrderedSet)
    )
    for reader in nodes:
        for dep in reader.unmet_dependencies:
            readers_by_name[dep.name].add(reader)

    # Walk the mutation renames newest-first so that readers propagate
    # transitively along chains of renames.
    rename_pairs = list(V.graph.scheduler.mutation_real_name.items())
    for alias_name, real_name in rename_pairs[::-1]:
        readers_by_name[real_name] |= readers_by_name[alias_name]

    # Buffers nobody reads (e.g. graph outputs) get an empty reader set.
    for buf_name, buf in name_to_buf.items():
        alloc_size, free_size = sizes[buf_name]
        buf.mpi_buffer = MemoryPlanningInfoForBuffer(
            size_alloc=alloc_size,
            size_free=free_size,
            succ_nodes=readers_by_name[buf_name],
        )
def assign_memory_planning_info_for_scheduler_nodes(
    nodes: list[BaseSchedulerNode],
    name_to_fused_node: dict[str, BaseSchedulerNode],
    name_to_buf: dict[str, SchedulerBuffer],
    name_to_freeable_input_buf: dict[str, FreeableInputBuffer],
) -> None:
    """
    Attach a MemoryPlanningInfoForNode to every node in ``nodes``.

    Successors/predecessors are derived from the buffers' ``succ_nodes``
    (so buffer info must be assigned first). A node is never recorded as its
    own predecessor or successor. ``name_to_fused_node`` and ``name_to_buf``
    are accepted for interface compatibility but not used here.
    """
    preds_of: dict[BaseSchedulerNode, OrderedSet[BaseSchedulerNode]] = (
        collections.defaultdict(OrderedSet)
    )
    succs_of: dict[BaseSchedulerNode, OrderedSet[BaseSchedulerNode]] = {}
    pred_bufs_of: dict[
        BaseSchedulerNode, OrderedSet[Union[SchedulerBuffer, FreeableInputBuffer]]
    ] = collections.defaultdict(OrderedSet)

    # Pass 1: each output buffer links its producer to every reader.
    for producer in nodes:
        readers: OrderedSet[BaseSchedulerNode] = OrderedSet()
        for out_buf in producer.get_outputs():
            for reader in out_buf.mpi_buffer.succ_nodes:
                readers.add(reader)
                preds_of[reader].add(producer)
                pred_bufs_of[reader].add(out_buf)
        succs_of[producer] = readers

    # Freeable graph inputs are predecessor buffers of their readers, too.
    for input_buf in name_to_freeable_input_buf.values():
        for reader in input_buf.mpi_buffer.succ_nodes:
            pred_bufs_of[reader].add(input_buf)

    # Pass 2: build the per-node record, dropping self-edges.
    for position, node in enumerate(nodes):
        successors = succs_of[node]
        predecessors = preds_of[node]
        successors.discard(node)
        predecessors.discard(node)
        node.mpi_node = MemoryPlanningInfoForNode(
            index=position,
            size=sum(out_buf.mpi_buffer.size_alloc for out_buf in node.get_outputs()),
            pred_buffers=pred_bufs_of[node],
            pred_nodes=predecessors,
            succ_nodes=successors,
        )
# Lifetime record of one buffer (scheduler buffer or freeable input): its
# sizes plus the steps at which it is created and last used.
@dataclasses.dataclass
class BufferInfo:
    """Lifetime of a buffer within a node order, as built by compute_memory_timeline.

    Steps are positions in the node order passed to compute_memory_timeline.
    """

    buffer: Union[SchedulerBuffer, FreeableInputBuffer]
    # Amount charged when the buffer is created at start_step.
    size_alloc: int
    # Amount released after the buffer's last use at end_step.
    size_free: int
    # Step that creates the buffer (0 for freeable graph inputs).
    start_step: int
    # Step of the last reader; -1 for graph outputs.
    end_step: int
def compute_memory_timeline(
    nodes: list[BaseSchedulerNode],
    name_to_freeable_input_buf: dict[str, FreeableInputBuffer],
    graph_outputs: OrderedSet[str],
) -> tuple[
    list[BufferInfo],
    dict[BaseSchedulerNode, int],
    dict[Union[FreeableInputBuffer, SchedulerBuffer], BaseSchedulerNode],
]:
    """
    Map every buffer's lifetime onto the given node order.

    Freeable graph inputs start at step 0; scheduler buffers start at the step
    of the node that outputs them. A buffer ends at the step of its latest
    reader in ``nodes``. A non-output scheduler buffer without readers (all
    consumers fused into its producer) ends at its own step. Graph outputs get
    ``end_step == -1`` and no last-use entry.

    Returns:
        list[BufferInfo]: one record per buffer, freeable inputs first.
        dict[BaseSchedulerNode, int]: step (position) of each node.
        dict[buffer, BaseSchedulerNode]: last-use node of each non-output buffer.
    """
    node_to_step: dict[BaseSchedulerNode, int] = {
        node: step for step, node in enumerate(nodes)
    }
    timeline: list[BufferInfo] = []
    last_use: dict[
        Union[FreeableInputBuffer, SchedulerBuffer], BaseSchedulerNode
    ] = {}

    def _latest_reader(
        buf: Union[FreeableInputBuffer, SchedulerBuffer],
    ) -> tuple[int, Optional[BaseSchedulerNode]]:
        # (step, node) of the latest-scheduled reader (the first one wins on
        # ties), or (-1, None) when the buffer has no readers.
        readers = buf.mpi_buffer.succ_nodes
        if not readers:
            return -1, None
        latest = max(readers, key=node_to_step.__getitem__)
        return node_to_step[latest], latest

    # freeable graph inputs are live from the very first step
    for input_name, input_buf in name_to_freeable_input_buf.items():
        end_step = -1
        if input_name not in graph_outputs:
            end_step, reader = _latest_reader(input_buf)
            assert reader is not None
            last_use[input_buf] = reader
        input_size = input_buf.mpi_buffer.size_free
        timeline.append(BufferInfo(input_buf, input_size, input_size, 0, end_step))

    # scheduler buffers are live from the step of their producing node
    for step, producer in enumerate(nodes):
        for sched_buf in producer.get_outputs():
            end_step = -1
            if sched_buf.get_name() not in graph_outputs:
                end_step, reader = _latest_reader(sched_buf)
                if reader is None:
                    # only used inside its producer (e.g. consumers fused in)
                    end_step, reader = step, producer
                last_use[sched_buf] = reader
            timeline.append(
                BufferInfo(
                    sched_buf,
                    sched_buf.mpi_buffer.size_alloc,
                    sched_buf.mpi_buffer.size_free,
                    step,
                    end_step,
                )
            )

    return timeline, node_to_step, last_use
def estimate_peak_memory(
    nodes: list[BaseSchedulerNode],
    name_to_freeable_input_buf: dict[str, FreeableInputBuffer],
    graph_outputs: OrderedSet[str],
) -> tuple[int, list[int]]:
    """
    Given a list of nodes in their execution order, estimate the peak memory, by
    keeping track of the liveliness of SchedulerBuffers and FreeableInputBuffers.

    Each buffer adds ``size_alloc`` at its ``start_step`` and releases
    ``size_free`` right after its ``end_step``. compute_memory_timeline reports
    graph outputs with ``end_step == -1``: they stay live until the end and are
    never subtracted (the same convention estimate_peak_memory_allocfree uses).

    Returns:
        int: peak memory
        List[int]: memory usage at each node (or each step); it has
            ``len(nodes) + 1`` entries, the last being the memory still live
            after the final node.
    """
    buf_info_list, _, _ = compute_memory_timeline(
        nodes, name_to_freeable_input_buf, graph_outputs
    )

    # incremental memory changes at each step
    memory = [0 for _ in range(len(nodes) + 1)]

    # for each buffer, update memory when created and when freed
    for buf_info in buf_info_list:
        memory[buf_info.start_step] += buf_info.size_alloc
        # end_step == -1 marks a graph output, which is never freed. Without
        # this guard, end_step + 1 == 0 would subtract it at step 0, so outputs
        # were never counted and earlier steps went negative.
        if buf_info.end_step != -1:
            memory[buf_info.end_step + 1] -= buf_info.size_free

    # get peak memory by computing the cumulative memories
    max_memory = 0
    cur_memory = 0
    memories_at_nodes = []
    for delta in memory:
        cur_memory += delta
        memories_at_nodes.append(cur_memory)
        max_memory = max(max_memory, cur_memory)

    return (max_memory, memories_at_nodes)
@dataclasses.dataclass
class SNodeMemory:
    """Memory allocated and freed at one step of estimate_peak_memory_allocfree."""

    # Total size_alloc of buffers created at this step.
    size_alloc: int
    # Total size_free of buffers whose last use is this step.
    size_free: int
def estimate_peak_memory_allocfree(
    nodes: list[BaseSchedulerNode],
    name_to_freeable_input_buf: dict[str, FreeableInputBuffer],
    graph_outputs: OrderedSet[str],
) -> tuple[
    int,
    list[tuple[int, int]],
    dict[BaseSchedulerNode, SNodeMemory],
    dict[Union[FreeableInputBuffer, SchedulerBuffer], BaseSchedulerNode],
]:
    """
    Estimate peak memory with separate allocation and deallocation phases.

    Unlike estimate_peak_memory, which folds each step into one net change,
    this models every scheduler node as:
        1. allocate its outputs,
        2. run its kernel,
        3. free buffers whose last use is this node.
    The peak is sampled right after the allocation phase. Buffers with
    ``end_step == -1`` (graph outputs) are never freed.

    The code partly duplicates estimate_peak_memory so call sites can migrate
    to this version gradually.

    Returns:
        int: peak memory.
        list[tuple[int, int]]: (memory after alloc, memory after free) per node.
        dict[BaseSchedulerNode, SNodeMemory]: alloc/free amounts per node.
        dict[buffer, BaseSchedulerNode]: last-use node of each non-output buffer.
    """
    timeline, _, buf_to_snode_last_use = compute_memory_timeline(
        nodes, name_to_freeable_input_buf, graph_outputs
    )

    per_step = [SNodeMemory(0, 0) for _ in nodes]
    for record in timeline:
        per_step[record.start_step].size_alloc += record.size_alloc
        if record.end_step != -1:
            per_step[record.end_step].size_free += record.size_free

    snodes_allocfree = dict(zip(nodes, per_step))

    peak = 0
    running = 0
    phases = []
    for step_mem in per_step:
        running += step_mem.size_alloc
        after_alloc = running
        peak = max(peak, after_alloc)
        running -= step_mem.size_free
        phases.append((after_alloc, running))

    return (peak, phases, snodes_allocfree, buf_to_snode_last_use)
def topological_sort_lpmf(
    nodes: list[BaseSchedulerNode],
    name_to_freeable_input_buf: dict[str, FreeableInputBuffer],
    name_to_buf: dict[str, SchedulerBuffer],
    graph_outputs: OrderedSet[str],
) -> list[BaseSchedulerNode]:
    """
    A bfs-based greedy topological order. LPMF stands for "Least Peak Memory First".

    The idea is from this paper:
    Buffer memory optimization for video codec application modeled in Simulink
    https://www.cs.york.ac.uk/rts/docs/DAC-1964-2006/PAPERS/2006/DAC06/PDFFILES/P0689.PDF

    The algorithm maintains the max memory so far.
    At every iteration, for each schedulable node, it computes:
        - how much memory needs to be allocated for the output buffers of this node;
        - how much memory can be freed as a result of executing this node.
    This gives us two values for each node:
        (1) mem1: memory during the execution of the node;
        (2) mem2: memory after executing the node, after some input buffers are freed.
    The greedy approach selects as follows:
        (i) if there are nodes whose mem1 values do not exceed the max memory so far,
            then pick the node with the lowest mem2 value;
        (ii) otherwise, pick the one with the lowest mem1 value.
    Remaining ties are broken by the node's original index.

    Successor-based strategy: when ``config.size_threshold_for_succ_based_strategy``
    is positive and every schedulable node is larger than that threshold, pick
    instead the node whose earliest successor (by original index) comes first;
    nodes without successors rank last.

    Raises:
        RuntimeError: if not every node could be scheduled (e.g. the
            predecessor/successor graph contains a cycle).
    """

    class _NodeState(TypedDict):
        indegree: int
        memory_to_free: int

    class _BufferState(TypedDict):
        outdegree: int

    node_info: dict[BaseSchedulerNode, _NodeState] = dict()
    buf_info: dict[Union[SchedulerBuffer, FreeableInputBuffer], _BufferState] = dict()

    # compute nodes' number of unmet dependencies (for schedulability)
    # initialize the list of nodes ready to be scheduled
    nodes_to_schedule: OrderedSet[BaseSchedulerNode] = OrderedSet()
    for node in nodes:
        node_info[node] = {
            "indegree": len(node.mpi_node.pred_nodes),
            "memory_to_free": 0,
        }
        if node_info[node]["indegree"] == 0:
            nodes_to_schedule.add(node)

    # compute buffers' number of unmet successors (used to decide when to free);
    # graph outputs get one extra phantom successor so they are never freed
    for buf in list(name_to_buf.values()) + list(name_to_freeable_input_buf.values()):
        buf_info[buf] = {
            "outdegree": len(buf.mpi_buffer.succ_nodes)
            + (1 if buf.get_name() in graph_outputs else 0)
        }

    # initialize memory estimations
    live_memory = sum(
        input_buf.mpi_buffer.size_free
        for input_buf in name_to_freeable_input_buf.values()
    )

    # this is the total output memory, which is a lower bound for peak memory
    # we do not include the memory of non freeable input buffers
    output_memory = 0
    for buf_name in graph_outputs:
        if buf_name in name_to_buf:
            output_memory += name_to_buf[buf_name].mpi_buffer.size_free
        elif buf_name in name_to_freeable_input_buf:
            output_memory += name_to_freeable_input_buf[buf_name].mpi_buffer.size_free
    max_memory = max(live_memory, output_memory)
    memory_gap = max_memory - live_memory

    # compute the amount of memory that can be freed when a node is scheduled
    for node in nodes:
        # 1. if a buffer read by this node is last used by this node
        for buf in node.mpi_node.pred_buffers:
            if buf_info[buf]["outdegree"] == 1:
                node_info[node]["memory_to_free"] += buf.mpi_buffer.size_free
        # 2. if a buffer written by this node is used internally and not used later
        for buf in node.get_outputs():
            if buf_info[buf]["outdegree"] == 0:
                node_info[node]["memory_to_free"] += buf.mpi_buffer.size_free

    # schedule nodes one at a time
    schedule: list[BaseSchedulerNode] = []
    size_threshold = config.size_threshold_for_succ_based_strategy
    while nodes_to_schedule and len(schedule) < len(nodes):
        # select a node to schedule:
        if (
            size_threshold > 0
            and min(node.mpi_node.size for node in nodes_to_schedule) > size_threshold
        ):
            # every candidate is large: prefer the one whose earliest successor
            # comes first in the original order
            selected_node = min(
                nodes_to_schedule,
                key=lambda node: min(
                    (
                        succ_node.mpi_node.index
                        for succ_node in node.mpi_node.succ_nodes
                    ),
                    default=len(nodes),
                ),
            )
        else:
            # size > memory_gap means mem1 would exceed the max memory so far
            selected_node = min(
                nodes_to_schedule,
                key=lambda node: (
                    node.mpi_node.size if node.mpi_node.size > memory_gap else 0,
                    node.mpi_node.size - node_info[node]["memory_to_free"],
                    node.mpi_node.index,
                ),
            )
        nodes_to_schedule.remove(selected_node)
        schedule.append(selected_node)

        # update memory usage
        live_memory += selected_node.mpi_node.size
        max_memory = max(max_memory, live_memory)
        live_memory -= node_info[selected_node]["memory_to_free"]
        memory_gap = max_memory - live_memory

        # update successor nodes and nodes_to_schedule
        for succ_node in selected_node.mpi_node.succ_nodes:
            assert node_info[succ_node]["indegree"] > 0
            node_info[succ_node]["indegree"] -= 1
            if node_info[succ_node]["indegree"] == 0:
                nodes_to_schedule.add(succ_node)

        # update predecessor buffers: once only one reader remains, that reader
        # will free the buffer when it runs
        for buf in selected_node.mpi_node.pred_buffers:
            assert buf_info[buf]["outdegree"] > 0
            buf_info[buf]["outdegree"] -= 1
            if buf_info[buf]["outdegree"] == 1:
                for succ_node in buf.mpi_buffer.succ_nodes:
                    node_info[succ_node]["memory_to_free"] += buf.mpi_buffer.size_free

    # The loop stops early only if some nodes never became schedulable; the
    # previous `num_iters > len(nodes)` check could never fire.
    if len(schedule) != len(nodes):
        raise RuntimeError(
            f"Failed to schedule for lpmf: only {len(schedule)} of {len(nodes)} "
            "nodes became schedulable (is the dependency graph cyclic?)"
        )
    return schedule
def topological_sort_bfs(nodes: list[BaseSchedulerNode]) -> list[BaseSchedulerNode]:
    """
    A BFS topological sort that selects nodes whose dependencies are executed the
    earliest. This follows a FIFO idea. Specifically, at every iteration, for each node
    that is schedulable, we gather the order in which its predecessor nodes are executed,
    and this sorted list of execution orders of predecessor nodes defines the priority.
    We select the node whose predecessors nodes are executed the earliest. The FIFO
    idea aims to reduce the liveness duration of buffers created. Ties are broken by
    the node's original index.

    Raises:
        RuntimeError: if not every node could be scheduled (e.g. the
            predecessor/successor graph contains a cycle).
    """

    class NodeInfo(TypedDict):
        indegree: int
        order: int

    node_info: dict[BaseSchedulerNode, NodeInfo] = dict()

    @dataclasses.dataclass
    class NodeWithPriority:
        priority: list[int]
        node: BaseSchedulerNode

        def __lt__(self, other: NodeWithPriority) -> bool:
            if self.priority == other.priority:
                return self.node.mpi_node.index < other.node.mpi_node.index
            return self.priority < other.priority

    def _node_priority(node: BaseSchedulerNode) -> list[int]:
        # priority is the sorted, de-duplicated execution order of predecessors
        assert node_info[node]["indegree"] == 0
        return sorted(
            {node_info[pred_node]["order"] for pred_node in node.mpi_node.pred_nodes}
        )

    # compute nodes' number of unmet dependencies (for schedulability)
    # initialize the list of nodes ready to be scheduled
    nodes_to_schedule: list[NodeWithPriority] = []
    for node in nodes:
        node_info[node] = {"indegree": len(node.mpi_node.pred_nodes), "order": -1}
        if node_info[node]["indegree"] == 0:
            heapq.heappush(
                nodes_to_schedule, NodeWithPriority(_node_priority(node), node)
            )

    # schedule nodes one at a time
    schedule: list[BaseSchedulerNode] = []
    while nodes_to_schedule and len(schedule) < len(nodes):
        selected_node = heapq.heappop(nodes_to_schedule).node
        node_info[selected_node]["order"] = len(schedule)
        schedule.append(selected_node)

        # update successor nodes and nodes_to_schedule
        for succ_node in selected_node.mpi_node.succ_nodes:
            assert node_info[succ_node]["indegree"] > 0
            node_info[succ_node]["indegree"] -= 1
            if node_info[succ_node]["indegree"] == 0:
                heapq.heappush(
                    nodes_to_schedule,
                    NodeWithPriority(_node_priority(succ_node), succ_node),
                )

    # The loop stops early only if some nodes never became schedulable; the
    # previous `num_iters > len(nodes)` check could never fire.
    if len(schedule) != len(nodes):
        raise RuntimeError(
            f"Failed to schedule for bfs: only {len(schedule)} of {len(nodes)} "
            "nodes became schedulable (is the dependency graph cyclic?)"
        )
    return schedule
def topological_sort_dfs(nodes: list[BaseSchedulerNode]) -> list[BaseSchedulerNode]:
    """
    This is a DFS topological sort. The setup is similar to `topological_sort_schedule`
    in scheduler.py. The difference is the order nodes are visited in the outer loop.
    In `topological_sort_schedule`, nodes are visited in their original order.
    In this function, nodes are visited based on their priority -- for each node, we
    compute the total memory of all buffers it reads from or writes to, and we visit
    the nodes in ascending order of this priority (ties broken by original index).

    The traversal uses an explicit stack instead of recursion, so long dependency
    chains do not hit Python's recursion limit; the resulting post-order is the
    same as the recursive formulation.
    """
    seen: set[BaseSchedulerNode] = set()
    name_to_node: dict[str, BaseSchedulerNode] = dict()
    result: list[BaseSchedulerNode] = []
    size_with_reads: dict[BaseSchedulerNode, int] = dict()

    def _priority(n: BaseSchedulerNode) -> tuple[int, int]:
        return (size_with_reads[n], n.mpi_node.index)

    def _deps_in_visit_order(n: BaseSchedulerNode) -> list[BaseSchedulerNode]:
        # producers of n's unmet dependencies, restricted to nodes in `nodes`
        dep_nodes = [
            name_to_node[dep.name]
            for dep in n.unmet_dependencies
            if dep.name in name_to_node
        ]
        return sorted(dep_nodes, key=_priority)

    def visit(root: BaseSchedulerNode) -> None:
        if root in seen:
            return
        seen.add(root)
        # each entry: (node, iterator over its remaining dependencies)
        stack = [(root, iter(_deps_in_visit_order(root)))]
        while stack:
            current, pending = stack[-1]
            for dep_node in pending:
                if dep_node not in seen:
                    seen.add(dep_node)
                    stack.append((dep_node, iter(_deps_in_visit_order(dep_node))))
                    break
            else:
                # all dependencies handled: emit in post-order
                stack.pop()
                result.append(current)

    for node in nodes:
        for name in node.get_buffer_names():
            name_to_node[name] = node
    for node in nodes:
        size_with_reads[node] = node.mpi_node.size + sum(
            pred_buf.mpi_buffer.size_free for pred_buf in node.mpi_node.pred_buffers
        )
    for node in sorted(nodes, key=_priority):
        visit(node)
    return result
def validate_graph_acyclic(nodes: list[BaseSchedulerNode]) -> None:
    """
    Validate that the graph is acyclic by checking predecessor relationships.

    The depth-first search uses an explicit stack instead of recursion, so very
    long dependency chains do not hit Python's recursion limit.

    Raises:
        RuntimeError: If a cycle is detected in the graph
    """
    # DFS coloring scheme for cycle detection:
    # WHITE (0): Node has not been visited yet
    # GRAY (1): Node is currently being processed (on the current DFS path)
    # BLACK (2): Node has been completely processed (finished exploring all its predecessors)
    # A back edge (cycle) is detected when we encounter a GRAY node during DFS traversal
    WHITE, GRAY, BLACK = 0, 1, 2
    color = dict.fromkeys(nodes, WHITE)

    # Start DFS from all unvisited nodes
    for root in nodes:
        if color[root] != WHITE:
            continue
        color[root] = GRAY
        # (node, iterator over its unexplored predecessors); the nodes on the
        # stack form the current DFS path
        stack = [(root, iter(root.mpi_node.pred_nodes))]
        while stack:
            current, pending_preds = stack[-1]
            for pred_node in pending_preds:
                assert pred_node != current
                pred_color = color[pred_node]
                if pred_color == BLACK:
                    continue
                if pred_color == GRAY:
                    path_info = " -> ".join(
                        [entry[0].get_name() for entry in stack]
                        + [pred_node.get_name()]
                    )
                    raise RuntimeError(
                        "Cycle detected in memory planning graph. "
                        f"Path containing cycle (i -> j: j is a dependency of i): {path_info}. "
                        "This indicates invalid dependency relationships in the scheduler graph"
                    )
                color[pred_node] = GRAY
                stack.append((pred_node, iter(pred_node.mpi_node.pred_nodes)))
                break
            else:
                stack.pop()
                color[current] = BLACK
def validate_unique_buffer_names(
    nodes: list[BaseSchedulerNode],
    name_to_buf: dict[str, SchedulerBuffer],
    name_to_freeable_input_buf: dict[str, FreeableInputBuffer],
) -> None:
    """
    Validate that for each node's output buffer, the name_to_buf mapping is correct.

    For each output buffer buf, we should have name_to_buf[buf.get_name()] == buf.
    Also validate that no buffer names overlap with freeable input buffer names.

    Raises:
        RuntimeError: If buffer name mapping is incorrect or names overlap
    """
    for node in nodes:
        for buf in node.get_outputs():
            buf_name = buf.get_name()

            # Check if buffer name exists in the mapping
            if buf_name not in name_to_buf:
                raise RuntimeError(
                    f"{buf_name} from {node.get_name()} is not found in name_to_buf mapping."
                    " This indicates a missing buffer mapping."
                )

            # Check if the mapping points to the correct buffer object
            mapped_buf = name_to_buf[buf_name]
            if mapped_buf != buf:
                # the pieces are implicitly concatenated, so each one must
                # carry its own separating space
                raise RuntimeError(
                    f"Buffer name mapping is incorrect for '{buf_name}'. "
                    f"Expected name_to_buf['{buf_name}'] to be {buf.debug_str()} "
                    f"but got {mapped_buf.debug_str()}. "
                    "This indicates some buffers share the same name"
                )

            # Check if buffer name conflicts with freeable input buffer names
            if buf_name in name_to_freeable_input_buf:
                raise RuntimeError(
                    f"Buffer name conflict detected: '{buf_name}' from node {node.get_name()} "
                    f"is also used as a freeable input buffer name. "
                )
def prepare_planning_info(
    nodes: list[BaseSchedulerNode],
    name_to_buf: dict[str, SchedulerBuffer],
    name_to_fused_node: dict[str, BaseSchedulerNode],
    graph_inputs: OrderedSet[str],
    graph_outputs: OrderedSet[str],
) -> tuple[int, dict[str, FreeableInputBuffer]]:
    """
    Attach memory-planning info to buffers and nodes, then estimate the peak
    memory of the current (baseline) order.

    Buffer info is assigned before node info because node predecessors and
    successors are derived from the buffers' reader sets.

    Returns:
        int: peak memory estimate of ``nodes`` in their current order
        dict[str, FreeableInputBuffer]: freeable graph inputs keyed by name
    """
    freeable_inputs = get_freeable_input_buf(nodes, graph_inputs)
    assign_memory_planning_info_for_scheduler_buffers(nodes, name_to_buf)
    assign_memory_planning_info_for_scheduler_nodes(
        nodes, name_to_fused_node, name_to_buf, freeable_inputs
    )
    baseline_peak = estimate_peak_memory(nodes, freeable_inputs, graph_outputs)[0]
    return baseline_peak, freeable_inputs
def reorder_for_peak_memory(
    nodes: list[BaseSchedulerNode],
    name_to_buf: dict[str, SchedulerBuffer],
    name_to_fused_node: dict[str, BaseSchedulerNode],
    graph_inputs: OrderedSet[str],
    graph_outputs: OrderedSet[str],
    methods: list[Callable[..., list[BaseSchedulerNode]]] = [  # noqa: B006
        topological_sort_lpmf,
        topological_sort_bfs,
        topological_sort_dfs,
    ],
) -> list[BaseSchedulerNode]:
    """
    Return the node order with the lowest estimated peak memory among the
    original order and the orders produced by each function in ``methods``.

    Planning info is attached first and then validated. Validation failures and
    per-method failures are logged; outside fbcode they are re-raised. The peak
    estimate of every candidate is reported through ``signpost_event``.
    """
    torch_log.info("Reordering for peak memory -- %d nodes", len(nodes))
    baseline_peak, freeable_inputs = prepare_planning_info(
        nodes,
        name_to_buf,
        name_to_fused_node,
        graph_inputs,
        graph_outputs,
    )

    if config.reorder_for_peak_memory_debug:
        # dump the graph so it can be replayed in the offline simulator
        export_graph_for_simulator(
            nodes,
            freeable_inputs,
            name_to_fused_node,
            graph_inputs,
            graph_outputs,
        )

    # sanity-check the planning graph before trying any reordering
    try:
        validate_graph_acyclic(nodes)
        validate_unique_buffer_names(nodes, name_to_buf, freeable_inputs)
    except RuntimeError as err:
        torch_log.error("Memory planning validation failed: %s", err)
        if not is_fbcode():  # TODO: remove after ensuring OSS side is safe
            raise

    candidates: list[PeakMemoryResult] = [
        PeakMemoryResult(nodes, baseline_peak, "baseline")
    ]
    torch_log.info("Baseline peak memory: %d", baseline_peak)

    for method in methods:
        try:
            # lpmf needs buffer bookkeeping; the other sorts only need the nodes
            if method == topological_sort_lpmf:
                order = method(nodes, freeable_inputs, name_to_buf, graph_outputs)
            else:
                order = method(nodes)
            assert len(order) == len(nodes)
            peak, _ = estimate_peak_memory(order, freeable_inputs, graph_outputs)
            candidates.append(PeakMemoryResult(order, peak, method.__name__))
            torch_log.info("%s peak memory: %d", method.__name__, peak)
        except Exception as err:
            torch_log.error("Failed to reorder for %s: %s", method.__name__, err)
            if not is_fbcode():  # TODO: remove after ensuring OSS side is safe
                raise

    signpost_event(
        category="inductor",
        name="memory",
        parameters={
            "orm": {result.method: result.peak_memory for result in candidates},
        },
    )

    # min() keeps the first minimum, so ties favour the baseline / earlier methods
    return min(candidates, key=lambda result: result.peak_memory).order
def export_graph_for_simulator(
    nodes: list[BaseSchedulerNode],
    name_to_freeable_input_buf: dict[str, FreeableInputBuffer],
    name_to_fused_node: dict[str, BaseSchedulerNode],
    graph_inputs: OrderedSet[str],
    graph_outputs: OrderedSet[str],
) -> None:
    """
    Debugging aid: emit the memory-planning graph as a structured-trace artifact.

    The JSON payload lists every buffer (freeable graph inputs first, then the
    scheduler buffers produced by ``nodes``) with its alloc/free sizes,
    input/output flags and dependencies, plus every node with the buffers it
    produces. The graph can then be replayed offline in a simulator:
    https://fburl.com/code/3l3d3qi4
    The artifact is named after the graph being compiled, with a "_fused" suffix.
    """

    class ORMBuffer(TypedDict):
        name: str
        size_alloc: int
        size_free: int
        size: int  # for backward compatibility
        is_input: bool
        is_output: bool
        deps: list[str]
        unmet_deps: list[str]

    class ORMNode(TypedDict):
        name: str
        buffer_names: list[str]

    class ORMGraph(TypedDict):
        nodes: list[ORMNode]
        buffers: list[ORMBuffer]

    # freeable graph inputs: both alloc and free equal their size_free
    orm_buffers: list[ORMBuffer] = [
        {
            "name": input_name,
            "size_alloc": input_buf.mpi_buffer.size_free,
            "size_free": input_buf.mpi_buffer.size_free,
            "size": input_buf.mpi_buffer.size_free,
            "is_input": True,
            "is_output": input_name in graph_outputs,
            "deps": [],
            "unmet_deps": [],
        }
        for input_name, input_buf in name_to_freeable_input_buf.items()
    ]

    # rebuild the buffer map from the nodes actually passed in, since some
    # buffers may have been pruned from the global mapping
    produced_bufs: dict[str, SchedulerBuffer] = {
        out_buf.get_name(): out_buf for node in nodes for out_buf in node.get_outputs()
    }
    for out_name, out_buf in produced_bufs.items():
        producer = out_buf.defining_op
        if producer is None:
            continue
        dep_names = [
            pred_buf.get_name()
            for pred_buf in name_to_fused_node[
                producer.get_name()
            ].mpi_node.pred_buffers
        ]
        orm_buffers.append(
            {
                "name": out_name,
                "size_alloc": out_buf.mpi_buffer.size_alloc,
                "size_free": out_buf.mpi_buffer.size_free,
                "size": out_buf.mpi_buffer.size_free,
                "is_input": False,
                "is_output": out_name in graph_outputs,
                "deps": dep_names,
                "unmet_deps": [
                    dep_name for dep_name in dep_names if dep_name not in graph_inputs
                ],
            }
        )

    orm_nodes: list[ORMNode] = [
        {"name": node.get_name(), "buffer_names": list(node.get_buffer_names())}
        for node in nodes
    ]
    graph: ORMGraph = {"nodes": orm_nodes, "buffers": orm_buffers}

    # dump the graph
    import json
    import os

    import torch
    from functorch.compile import get_graph_being_compiled

    artifact_name = os.path.splitext(get_graph_being_compiled())[0] + "_fused"
    payload = json.dumps(graph, indent=2)
    torch._logging.trace_structured(
        "artifact",
        metadata_fn=lambda: {
            "name": artifact_name,
            "encoding": "string",
        },
        payload_fn=lambda: payload,
    )