[inductor][scheduler] reorder scheduler nodes after fusion to reduce peak memory (#134874)

**Motivations**:
A topological order of the scheduler nodes that optimizes buffer liveness can reduce peak memory utilization. This has been observed and studied before, e.g., [here](https://arxiv.org/pdf/1910.02653) and [here](https://proceedings.mlr.press/v202/steiner23a/steiner23a.pdf).

**Solutions**:
1. Implement a peak memory estimator via liveness analysis.
2. Implement a few memory-aware topological sorting algorithms and pick the order with the lowest estimated peak memory (a toy sketch of the idea follows this list).
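
For intuition, here is a minimal, self-contained sketch of the liveness-based estimation and of why the order matters. It is an illustration only, not the `estimate_peak_memory` / `topological_sort_*` code added in `torch/_inductor/memory.py`; the toy graph, node names, and sizes are made up. A buffer is counted as allocated when its producing node runs and freed after its last consumer runs, and different topological orders of the same graph can give very different peaks.

```python
from typing import Dict, List


def estimate_peak(order: List[str],
                  out_size: Dict[str, int],
                  reads: Dict[str, List[str]]) -> int:
    """Liveness-based peak memory estimate for one execution order."""
    # last time step at which each buffer is read (buffers never read stay live)
    last_use = {buf: t for t, node in enumerate(order) for buf in reads[node]}
    live = peak = 0
    for t, node in enumerate(order):
        live += out_size[node]              # allocate this node's output buffer
        peak = max(peak, live)
        for buf in reads[node]:             # free inputs after their last use
            if last_use[buf] == t:
                live -= out_size[buf]
    return peak


# toy graph: a -> c, b -> d, (c, d) -> e; output buffer sizes in MB (made up)
out_size = {"a": 100, "b": 100, "c": 1, "d": 1, "e": 1}
reads = {"a": [], "b": [], "c": ["a"], "d": ["b"], "e": ["c", "d"]}

print(estimate_peak(["a", "b", "c", "d", "e"], out_size, reads))  # 201
print(estimate_peak(["a", "c", "b", "d", "e"], out_size, reads))  # 102
```

The second order frees the large buffer `a` before the equally large `b` is allocated, so its estimated peak is roughly half that of the first order; the reordering pass searches for such orders among the candidates produced by the sorting algorithms.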

**Results**:
On some models we can reduce the peak memory significantly:
| model | batch size | peak memory (baseline) | peak memory (new) | ratio (baseline / new) |
|:-----------------------------:|:----------:|:--------------------:|:---------------:|:-----:|
| alexnet                       | 128        |         1.17         |       0.99      | 1.19  |
| vgg16                         | 64         |         4.10         |       3.57      | 1.15  |
| DebertaV2ForQuestionAnswering | 1          |         11.60        |      10.56      | 1.10  |

In the presence of compiler-based AC (activation checkpointing), peak memory can be further reduced:
| model | batch size | peak memory (baseline) | peak memory (new) | ratio (baseline / new) |
|:------------------------------:|:----------:|:--------------------:|:---------------:|:-----:|
| AlbertForMaskedLM              | 4          |         6.87         |       6.43      | 1.07  |
| AlbertForQuestionAnswering     | 4          |         8.69         |       7.76      | 1.12  |
| MobileBertForQuestionAnswering | 128        |         4.67         |       3.90      | 1.20  |

[Here](https://fb.workplace.com/groups/1075192433118967/posts/1499920537312819/?comment_id=1499938843977655&reply_comment_id=1499951630643043) is an internal use case.

**Other info:**
* Neutral model runtime: since the reordering happens after fusion, the generated kernels are unchanged and the memory saving comes _for free_.
* Minimal compile time overhead: the algorithms are linear in the number of edges of the inductor graph. For all HuggingFace benchmark models, the additional compile time is less than 1 second.
* No peak memory regression by construction: a new order is adopted only if the estimator reports lower peak memory than the default order. The estimator does not account for operators' working memory, but for large models the working memory should be negligible; we have not observed any significant regressions in our tests.
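
To try the pass, the new config flag added in this PR can be enabled either through the `TORCHINDUCTOR_REORDER_FOR_PEAK_MEMORY=1` environment variable or directly on the inductor config module. A small sketch (the toy model below is only for illustration):

```python
import torch
import torch._inductor.config as inductor_config

# equivalent to setting TORCHINDUCTOR_REORDER_FOR_PEAK_MEMORY=1 in the environment
inductor_config.reorder_for_peak_memory = True

device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Sequential(torch.nn.Linear(1024, 1024), torch.nn.ReLU()).to(device)
compiled = torch.compile(model)
out = compiled(torch.randn(64, 1024, device=device))
```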

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134874
Approved by: https://github.com/yf225
Xuan Zhang
2024-09-21 16:28:38 +00:00
committed by PyTorch MergeBot
parent fb4670a1f9
commit 03957efa5d
5 changed files with 779 additions and 1 deletion


@@ -1403,7 +1403,7 @@ test_linux_aarch64() {
inductor/test_pattern_matcher inductor/test_perf inductor/test_profiler inductor/test_select_algorithm inductor/test_smoke \
inductor/test_split_cat_fx_passes inductor/test_standalone_compile inductor/test_torchinductor \
inductor/test_torchinductor_codegen_dynamic_shapes inductor/test_torchinductor_dynamic_shapes \
--shard "$SHARD_NUMBER" "$NUM_TEST_SHARDS" --verbose
--shard "$SHARD_NUMBER" "$NUM_TEST_SHARDS" --verbose inductor/test_memory
}
if ! [[ "${BUILD_ENVIRONMENT}" == *libtorch* || "${BUILD_ENVIRONMENT}" == *-bazel-* ]]; then


@@ -0,0 +1,205 @@
# Owner(s): ["module: inductor"]
from unittest import mock
import torch
from torch._C import FileCheck
from torch._dynamo.utils import same
from torch._inductor import config, memory
from torch._inductor.test_case import TestCase
from torch._inductor.utils import run_and_get_triton_code
from torch.testing._internal.inductor_utils import HAS_GPU
class Foo(torch.nn.Module):
"""
The default compiled graph is
graph():
...
%op0 : [num_users=2] = call_function[...](args = (%primals_2, %primals_1), ...)
%op1 : [num_users=2] = call_function[...](args = (%primals_2, %primals_3), ...)
%op2 : [num_users=1] = call_function[...](args = (%op0, %primals_4), ...)
%op3 : [num_users=1] = call_function[...](args = (%op1, %primals_5), ...)
%op4 : [num_users=1] = call_function[...](args = (%op2,), ...)
%op5 : [num_users=1] = call_function[...](args = (%op3,), ...)
%op6_op7 : [num_users=1] = call_function[...](args = (%op5, %op4), ...)
"""
def __init__(self):
super().__init__()
self.w1 = torch.nn.Parameter(torch.ones(1, 10))
self.w2 = torch.nn.Parameter(torch.ones(1, 1))
self.w3 = torch.nn.Parameter(torch.ones(10, 1))
self.w4 = torch.nn.Parameter(torch.ones(1, 10))
def forward(self, x):
t1 = torch.matmul(x, self.w1)
t2 = torch.matmul(x, self.w2)
t3 = torch.matmul(t1, self.w3)
t4 = torch.matmul(t2, self.w4)
return t3.sum() + t4.sum()
class TestOperatorReorderForPeakMemory(TestCase):
def setUp(self):
super().setUp()
self.model = Foo().to("cuda")
self.inputs = torch.ones((2048, 1), device="cuda")
self.orig_reorder_method = memory.reorder_for_peak_memory
@mock.patch.object(config, "reorder_for_peak_memory", True)
def test_reorder_peak_memory(self):
outp_corr = self.model(self.inputs)
compiled_model = torch.compile(self.model)
code = run_and_get_triton_code(compiled_model, self.inputs)
(
FileCheck()
.check("def call(args):")
.check("buf1 = ")
.check("buf0 = ")
.check("buf2 = ")
.check("buf4 = ")
.check("buf3 = ")
.check("buf5 = ")
.check("buf7 = ")
.run(code)
)
# check for correctness
outp = compiled_model(self.inputs)
self.assertTrue(same(outp, outp_corr))
@mock.patch.object(config, "reorder_for_peak_memory", True)
def test_reorder_peak_memory_lpmf(self):
outp_corr = self.model(self.inputs)
def reorder_with_only_lpmf(
nodes,
name_to_buf,
name_to_fused_node,
graph_inputs,
graph_outputs,
methods=None,
):
return self.orig_reorder_method(
nodes,
name_to_buf,
name_to_fused_node,
graph_inputs,
graph_outputs,
methods=[memory.topological_sort_lpmf],
)
with mock.patch.object(
memory, "reorder_for_peak_memory", reorder_with_only_lpmf
):
compiled_model = torch.compile(self.model)
code = run_and_get_triton_code(compiled_model, self.inputs)
(
FileCheck()
.check("def call(args):")
.check("buf1 = ")
.check("buf0 = ")
.check("buf2 = ")
.check("buf4 = ")
.check("buf3 = ")
.check("buf5 = ")
.check("buf7 = ")
.run(code)
)
# check for correctness
outp = compiled_model(self.inputs)
self.assertTrue(same(outp, outp_corr))
@mock.patch.object(config, "reorder_for_peak_memory", True)
def test_reorder_peak_memory_bfs(self):
outp_corr = self.model(self.inputs)
def reorder_with_only_bfs(
nodes,
name_to_buf,
name_to_fused_node,
graph_inputs,
graph_outputs,
methods=None,
):
return self.orig_reorder_method(
nodes,
name_to_buf,
name_to_fused_node,
graph_inputs,
graph_outputs,
methods=[memory.topological_sort_bfs],
)
with mock.patch.object(
memory, "reorder_for_peak_memory", reorder_with_only_bfs
):
compiled_model = torch.compile(self.model)
code = run_and_get_triton_code(compiled_model, self.inputs)
(
FileCheck()
.check("def call(args):")
.check("buf0 = ")
.check("buf2 = ")
.check("buf1 = ")
.check("buf4 = ")
.check("buf3 = ")
.check("buf5 = ")
.check("buf7 = ")
.run(code)
)
# check for correctness
outp = compiled_model(self.inputs)
self.assertTrue(same(outp, outp_corr))
@mock.patch.object(config, "reorder_for_peak_memory", True)
def test_reorder_peak_memory_dfs(self):
outp_corr = self.model(self.inputs)
def reorder_with_only_dfs(
nodes,
name_to_buf,
name_to_fused_node,
graph_inputs,
graph_outputs,
methods=None,
):
return self.orig_reorder_method(
nodes,
name_to_buf,
name_to_fused_node,
graph_inputs,
graph_outputs,
methods=[memory.topological_sort_dfs],
)
with mock.patch.object(
memory, "reorder_for_peak_memory", reorder_with_only_dfs
):
compiled_model = torch.compile(self.model)
code = run_and_get_triton_code(compiled_model, self.inputs)
(
FileCheck()
.check("def call(args):")
.check("buf0 = ")
.check("buf2 = ")
.check("buf4 = ")
.check("buf1 = ")
.check("buf3 = ")
.check("buf5 = ")
.check("buf7 = ")
.run(code)
)
# check for correctness
outp = compiled_model(self.inputs)
self.assertTrue(same(outp, outp_corr))
if __name__ == "__main__":
from torch._inductor.test_case import run_tests
if HAS_GPU:
run_tests()


@@ -259,6 +259,9 @@ reorder_for_compute_comm_overlap_passes = [
"raise_comms",
]
# enable operator reordering for peak memory optimization
reorder_for_peak_memory = os.environ.get("TORCHINDUCTOR_REORDER_FOR_PEAK_MEMORY") == "1"
# runtime estimation function for ops
# for built-in estimation function, pass in "default"; for user-defined estimation function, pass in the function handle
estimate_op_runtime = "default"

torch/_inductor/memory.py (new file, 555 lines)

@@ -0,0 +1,555 @@
from __future__ import annotations
import dataclasses
import heapq
import logging
from typing import Callable, Dict, List, Set, Tuple, TYPE_CHECKING, Union
from torch.utils._ordered_set import OrderedSet
from .ir import MultiOutputLayout
from .utils import get_dtype_size
from .virtualized import V
if TYPE_CHECKING:
from .dependencies import Dep
from .scheduler import BaseSchedulerNode, SchedulerBuffer
torch_log = logging.getLogger(__name__)
@dataclasses.dataclass
class MemoryPlanningInfoForBuffer:
size_alloc: int = 0
size_free: int = 0
succ_nodes: OrderedSet[BaseSchedulerNode] = dataclasses.field(
default_factory=OrderedSet
)
outdegree: int = 0 # this is used only in topological_sort_lpmf
@dataclasses.dataclass
class MemoryPlanningInfoForNode:
pred_buffers: List[Union[SchedulerBuffer, FreeableInputBuffer]] = dataclasses.field(
default_factory=list
)
pred_nodes: List[BaseSchedulerNode] = dataclasses.field(default_factory=list)
succ_nodes: List[BaseSchedulerNode] = dataclasses.field(default_factory=list)
indegree: int = 0
index: int = 0
size: int = 0
memory_to_free: int = 0 # this is used only in topological_sort_lpmf
size_with_reads: int = 0 # this is used only in topological_sort_dfs
@dataclasses.dataclass
class FreeableInputBuffer:
dep: Dep
mpi_buffer: MemoryPlanningInfoForBuffer = dataclasses.field(
default_factory=MemoryPlanningInfoForBuffer
)
def get_name(self) -> str:
return self.dep.name
def __hash__(self) -> int:
return hash(self.dep.name)
def dep_size_hint(dep: Dep) -> int:
res = 0
try:
if not dep.has_unbacked_symbols():
res = dep.numbytes_hint()
except KeyError:
# In at least one test (test/inductor/test_torchbind.py) we
# create a StarDep that doesn't exist in the graph and calling
# `has_unbacked_symbols()` throws an error.
pass
return res
def get_freeable_input_buf(
nodes: List[BaseSchedulerNode],
graph_inputs: Set[str],
) -> Dict[str, FreeableInputBuffer]:
"""
Create and keep track of all input buffers that can be freed during the program
"""
name_to_input_buf: Dict[str, FreeableInputBuffer] = {}
for node in nodes:
for dep in node.read_writes.reads:
if (
dep.name in graph_inputs
and not dep.name.startswith("primals_")
and dep.name not in name_to_input_buf
):
name_to_input_buf[dep.name] = FreeableInputBuffer(dep)
name_to_input_buf[dep.name].mpi_buffer.size_free = dep_size_hint(dep)
return name_to_input_buf
def compute_size_for_scheduler_buffer(name_to_buf: Dict[str, SchedulerBuffer]) -> None:
"""
Compute the size of each scheduler buffer, including (1) memory allocated when
it is created and (2) memory deallocated when it is freed.
We specially handle the case of MultiOutputLayout.
Consider the following case:
buf0 = some_ops_with_multi_outputs(...)
buf1 = buf0[0] # assume 10 bytes
buf2 = buf0[1] # assume 20 bytes
In such cases,
buf0: at creation, 30 bytes allocated, when deleted, 0 bytes freed
buf1: at creation, 0 bytes allocated, when deleted, 10 bytes freed
buf2: at creation, 0 bytes allocated, when deleted, 20 bytes freed
"""
from .scheduler import BaseSchedulerNode, OutputNode
# compute the size of SchedulerBuffer without MultiOutputLayout layout
for sched_buf in name_to_buf.values():
if not isinstance(sched_buf.node.layout, MultiOutputLayout):
sched_buf.mpi_buffer = MemoryPlanningInfoForBuffer()
sched_buf.mpi_buffer.size_alloc = V.graph.sizevars.size_hint(
sched_buf.node.get_numel(), fallback=0
) * get_dtype_size(sched_buf.node.get_dtype())
sched_buf.mpi_buffer.size_free = sched_buf.mpi_buffer.size_alloc
# compute the size of SchedulerBuffer with MultiOutputLayout layout
for sched_buf in name_to_buf.values():
if isinstance(sched_buf.node.layout, MultiOutputLayout):
sched_buf.mpi_buffer = MemoryPlanningInfoForBuffer()
for user in sched_buf.users:
if isinstance(user.node, OutputNode):
continue
assert isinstance(user.node, BaseSchedulerNode)
for buf in user.node.get_outputs():
sched_buf.mpi_buffer.size_alloc += buf.mpi_buffer.size_alloc
buf.mpi_buffer.size_alloc = 0
def map_successor_nodes_with_predecessor_buffers(
nodes: List[BaseSchedulerNode],
name_to_input_buf: Dict[str, FreeableInputBuffer],
name_to_buf: Dict[str, SchedulerBuffer],
) -> None:
"""
For scheduling and memory estimation, for each scheduler node, we maintain
a list of its dependency buffers (SchedulerBuffer and FreeableInputBuffer).
This is similar to node.read_writes.reads, which is a list of Dep.
Conversely, for each SchedulerBuffer / FreeableInputBuffer, assign its successor nodes.
A buffer's successor nodes determine when the buffer can be freed.
"""
for node in nodes:
node.mpi_node = MemoryPlanningInfoForNode()
node.mpi_node.pred_buffers = []
for dep_name in {dep.name for dep in node.unmet_dependencies}:
sched_buf = name_to_buf.get(dep_name)
if sched_buf:
node.mpi_node.pred_buffers.append(sched_buf)
sched_buf.mpi_buffer.succ_nodes.add(node)
for dep_name in {
dep.name for dep in node.read_writes.reads - node.unmet_dependencies
}:
input_buf = name_to_input_buf.get(dep_name)
if input_buf:
node.mpi_node.pred_buffers.append(input_buf)
input_buf.mpi_buffer.succ_nodes.add(node)
def estimate_peak_memory(
nodes: List[BaseSchedulerNode],
name_to_input_buf: Dict[str, FreeableInputBuffer],
graph_outputs: Set[str],
) -> Tuple[int, List[int]]:
"""
Given a list of nodes in their execution order, estimate the peak memory by
keeping track of the liveness of SchedulerBuffers and FreeableInputBuffers.
Returns:
int: peak memory
List[int]: memory usage at each node.
"""
# map each scheduler buffer to its size, start time, and end time
@dataclasses.dataclass
class BufferInfo:
buffer: Union[SchedulerBuffer, FreeableInputBuffer]
size_alloc: int
size_free: int
start_time: int
end_time: int
name_to_buf_info: Dict[str, BufferInfo] = {}
node_name_to_time: Dict[str, int] = {}
# assign start_time
for buf_name, input_buf in name_to_input_buf.items():
name_to_buf_info[buf_name] = BufferInfo(
input_buf,
input_buf.mpi_buffer.size_free,
input_buf.mpi_buffer.size_free,
0,
0,
)
for t, node in enumerate(nodes):
node_name_to_time[node.get_name()] = t
for sched_buf in node.get_outputs():
name_to_buf_info[sched_buf.get_name()] = BufferInfo(
sched_buf,
sched_buf.mpi_buffer.size_alloc,
sched_buf.mpi_buffer.size_free,
t,
t,
)
# assign end_time
for buf_name, buf_info in name_to_buf_info.items():
succ_node_time = [
node_name_to_time[succ_node.get_name()]
for succ_node in buf_info.buffer.mpi_buffer.succ_nodes
if succ_node.get_name() in node_name_to_time
]
if succ_node_time:
buf_info.end_time = max(succ_node_time)
# the end time of output buffers should be at the end of the horizon
for buf_name in graph_outputs:
if buf_name in name_to_buf_info:
name_to_buf_info[buf_name].end_time = len(nodes) - 1
# incremental memory changes at each time period
memory = [0 for _ in range(len(nodes) + 1)]
# for each buffer, update memory when created and when freed
for buf_name, buf_info in name_to_buf_info.items():
memory[buf_info.start_time] += buf_info.size_alloc
memory[buf_info.end_time + 1] -= buf_info.size_free
# get peak memory by computing the cumulative memory usage
max_memory = 0
cur_memory = 0
memories_at_nodes = []
for t in range(len(nodes) + 1):
cur_memory += memory[t]
memories_at_nodes.append(cur_memory)
max_memory = max(max_memory, cur_memory)
return (max_memory, memories_at_nodes)
def assign_predcessor_and_successor_nodes_to_nodes(
nodes: List[BaseSchedulerNode], name_to_fused_node: Dict[str, BaseSchedulerNode]
) -> None:
"""
Assign to each scheduler node its predecessor and successor nodes.
"""
from .scheduler import SchedulerBuffer
for node in nodes:
node.mpi_node.pred_nodes = list(
{
name_to_fused_node[pred_buffer.defining_op.get_name()]
for pred_buffer in node.mpi_node.pred_buffers
if (
isinstance(pred_buffer, SchedulerBuffer)
and pred_buffer.defining_op.get_name() in name_to_fused_node
)
}
)
node.mpi_node.succ_nodes = list(
{
succ_node
for buffer in node.get_outputs()
for succ_node in buffer.mpi_buffer.succ_nodes
}
)
def topological_sort_lpmf(
nodes: List[BaseSchedulerNode],
name_to_input_buf: Dict[str, FreeableInputBuffer],
name_to_buf: Dict[str, SchedulerBuffer],
graph_outputs: Set[str],
) -> List[BaseSchedulerNode]:
"""
A BFS-based greedy topological order. LPMF stands for "Least Peak Memory First".
The algorithm maintains the max memory so far.
At every iteration, for each schedulable node, it computes:
- how much memory needs to be allocated for the output buffers of this node;
- how much memory can be freed as a result of executing this node.
This gives us two values for each node:
(1) mem1: memory during the execution of the node;
(2) mem2: memory after executing the node, after some input buffers are freed.
The greedy approach selects as follows:
(i) if there are nodes whose mem1 values are below the max memory so far,
then pick the node with the lowest mem2 value;
(ii) otherwise, pick the one with the lowest mem1 value.
"""
# compute nodes' number of unmet dependencies (for schedulability)
# initialize the list of nodes ready to be scheduled
nodes_to_schedule: Set[BaseSchedulerNode] = set()
for node in nodes:
# note that .unmet_dependencies could have deps with the same name
# and in that case, it should only be counted once
node.mpi_node.indegree = len(node.mpi_node.pred_nodes)
if node.mpi_node.indegree == 0:
nodes_to_schedule.add(node)
# compute buffers' number of unmet successors (used to decide when to free)
for buf in list(name_to_buf.values()) + list(name_to_input_buf.values()):
buf.mpi_buffer.outdegree = len(buf.mpi_buffer.succ_nodes)
if buf.get_name() in graph_outputs:
buf.mpi_buffer.outdegree += 1
# initialize memory estimations
live_memory = sum(
input_buf.mpi_buffer.size_free for input_buf in name_to_input_buf.values()
)
# this is the total output memory, which is a lower bound for peak memory
output_memory = sum(
name_to_buf[buf_name].mpi_buffer.size_free
for buf_name in graph_outputs
if buf_name in name_to_buf
)
max_memory = max(live_memory, output_memory)
# compute the amount of memory that is allocated when a node is scheduled
# and the amount of memory that can be freed when a node is scheduled
for i, node in enumerate(nodes):
node.mpi_node.index = i # keep track of the original order
node.mpi_node.size = sum(
buffer.mpi_buffer.size_alloc for buffer in node.get_outputs()
)
node.mpi_node.memory_to_free = 0
# 1. if a buffer read by this node is last used by this node
# then the buffer can be freed
for buf in node.mpi_node.pred_buffers:
if buf.mpi_buffer.outdegree == 1:
node.mpi_node.memory_to_free += buf.mpi_buffer.size_free
# 2. if a buffer written by this node is used internally and
# not needed afterwards, it can be freed
for buf in node.get_outputs():
if buf.mpi_buffer.outdegree == 0:
node.mpi_node.memory_to_free += buf.mpi_buffer.size_free
# schedule nodes one at a time
schedule: List[BaseSchedulerNode] = []
num_iters: int = 0
while num_iters < len(nodes) and nodes_to_schedule:
# select a node to schedule:
selected_node = min(
nodes_to_schedule,
key=lambda node: (
max(live_memory + node.mpi_node.size, max_memory),
node.mpi_node.size - node.mpi_node.memory_to_free,
node.mpi_node.index,
),
)
nodes_to_schedule.remove(selected_node)
schedule.append(selected_node)
num_iters += 1
# update memory usage
live_memory += selected_node.mpi_node.size
max_memory = max(max_memory, live_memory)
live_memory -= selected_node.mpi_node.memory_to_free
# update successor nodes and nodes_to_schedule
for succ_node in selected_node.mpi_node.succ_nodes:
assert succ_node.mpi_node.indegree > 0
succ_node.mpi_node.indegree -= 1
if succ_node.mpi_node.indegree == 0:
nodes_to_schedule.add(succ_node)
# update predecessor nodes
for buf in selected_node.mpi_node.pred_buffers:
assert buf.mpi_buffer.outdegree > 0
buf.mpi_buffer.outdegree -= 1
if buf.mpi_buffer.outdegree == 1:
for succ_node in buf.mpi_buffer.succ_nodes:
succ_node.mpi_node.memory_to_free += buf.mpi_buffer.size_free
if num_iters > len(nodes):
raise RuntimeError("Failed to schedule, while loop ran too long for lpmf")
return schedule
def topological_sort_bfs(nodes: List[BaseSchedulerNode]) -> List[BaseSchedulerNode]:
"""
A BFS topological sort that selects nodes whose dependencies are executed the
earliest. This follows a FIFO idea. Specifically, at every iteration, for each node
that is schedulable, we gather the order in which its predecessor nodes are executed,
and this sorted list of execution orders of predecessor nodes defines the priority.
We select the node whose predecessor nodes are executed the earliest. The FIFO
idea aims to reduce the liveness duration of the buffers created.
"""
@dataclasses.dataclass
class HeapElement:
priority: List[int]
node: BaseSchedulerNode
def __lt__(self, other: HeapElement) -> bool:
if self.priority == other.priority:
return self.node.mpi_node.index < other.node.mpi_node.index
return self.priority < other.priority
def _node_priority(node: BaseSchedulerNode) -> List[int]:
assert node.mpi_node.indegree == 0
ids = sorted(
{pred_node.mpi_node.index for pred_node in node.mpi_node.pred_nodes}
)
ids.append(node.mpi_node.index)
return ids
# compute nodes' number of unmet dependencies (for schedulability)
# initialize the list of nodes ready to be scheduled
nodes_to_schedule: List[HeapElement] = []
for t, node in enumerate(nodes):
node.mpi_node.index = t
# note that .unmet_dependencies could have deps with the same name
# and in that case, it should only be counted once
node.mpi_node.indegree = len(node.mpi_node.pred_nodes)
if node.mpi_node.indegree == 0:
heapq.heappush(nodes_to_schedule, HeapElement(_node_priority(node), node))
# schedule nodes one at a time
schedule: List[BaseSchedulerNode] = []
num_iters: int = 0
while num_iters < len(nodes) and nodes_to_schedule:
# select a node to schedule
selected_node = heapq.heappop(nodes_to_schedule).node
selected_node.mpi_node.index = len(schedule)
schedule.append(selected_node)
num_iters += 1
# update successor nodes and nodes_to_schedule
for succ_node in selected_node.mpi_node.succ_nodes:
assert succ_node.mpi_node.indegree > 0
succ_node.mpi_node.indegree -= 1
if succ_node.mpi_node.indegree == 0:
heapq.heappush(
nodes_to_schedule,
HeapElement(_node_priority(succ_node), succ_node),
)
if num_iters > len(nodes):
raise RuntimeError("Failed to schedule, while loop ran too long for bfs")
return schedule
def topological_sort_dfs(nodes: List[BaseSchedulerNode]) -> List[BaseSchedulerNode]:
"""
This is a DFS topological sort. The setup is similar to `topological_sort_schedule`
in scheduler.py. The difference is the order nodes are visited in the outer loop.
In `topological_sort_schedule`, nodes are visited in their original order.
In this function, nodes are visited based on their priority -- for each node, we
compute the total memory of all buffers it reads from or writes to, and we visit
the nodes in ascending order of this priority.
"""
seen: OrderedSet[BaseSchedulerNode] = OrderedSet()
name_to_node: Dict[str, BaseSchedulerNode] = dict()
result: List[BaseSchedulerNode] = []
def visit(n: BaseSchedulerNode) -> None:
if n not in seen:
seen.add(n)
dep_nodes = [
name_to_node[dep.name]
for dep in n.unmet_dependencies
if dep.name in name_to_node
]
for node in sorted(
dep_nodes, key=lambda x: (x.mpi_node.size_with_reads, x.mpi_node.index)
):
visit(node)
result.append(n)
for node in nodes:
for name in node.get_buffer_names():
name_to_node[name] = node
for t, node in enumerate(nodes):
node.mpi_node.index = t
node.mpi_node.size = sum(
buffer.mpi_buffer.size_alloc for buffer in node.get_outputs()
)
node.mpi_node.size_with_reads = node.mpi_node.size + sum(
pred_buf.mpi_buffer.size_free for pred_buf in node.mpi_node.pred_buffers
)
for node in sorted(
nodes, key=lambda x: (x.mpi_node.size_with_reads, x.mpi_node.index)
):
visit(node)
return result
def reorder_for_peak_memory(
nodes: List[BaseSchedulerNode],
name_to_buf: Dict[str, SchedulerBuffer],
name_to_fused_node: Dict[str, BaseSchedulerNode],
graph_inputs: Set[str],
graph_outputs: Set[str],
methods: List[Callable[..., List[BaseSchedulerNode]]] = [ # noqa: B006
topological_sort_lpmf,
topological_sort_bfs,
topological_sort_dfs,
],
) -> List[BaseSchedulerNode]:
"""
Try a few heuristic-based topological sort algorithms, and pick the one whose
resulting topological order has the lowest estimated peak memory.
"""
torch_log.warning("Reordering for peak memory -- %d nodes", len(nodes))
@dataclasses.dataclass
class PeakMemoryResult:
order: List[BaseSchedulerNode]
peak_memory: int
# preparation -- as nodes are scheduled one at a time, these help
# keep track of when a buffer can be freed, and when a node can be scheduled
name_to_input_buf: Dict[str, FreeableInputBuffer] = get_freeable_input_buf(
nodes, graph_inputs
)
compute_size_for_scheduler_buffer(name_to_buf)
map_successor_nodes_with_predecessor_buffers(nodes, name_to_input_buf, name_to_buf)
assign_predcessor_and_successor_nodes_to_nodes(nodes, name_to_fused_node)
# keep track of the peak memory estimates of different methods
peak_memory_diff_methods: List[PeakMemoryResult] = []
# the default
estimated_peak_memory, _ = estimate_peak_memory(
nodes, name_to_input_buf, graph_outputs
)
peak_memory_diff_methods.append(PeakMemoryResult(nodes, estimated_peak_memory))
torch_log.warning("Baseline peak memory: %d", estimated_peak_memory)
# other methods
for method in methods:
try:
if method == topological_sort_lpmf:
order = method(nodes, name_to_input_buf, name_to_buf, graph_outputs)
else:
order = method(nodes)
assert len(order) == len(nodes)
peak_memory, _ = estimate_peak_memory(
order, name_to_input_buf, graph_outputs
)
peak_memory_diff_methods.append(PeakMemoryResult(order, peak_memory))
torch_log.warning("%s peak memory: %d", method.__name__, peak_memory)
except Exception as e:
torch_log.error("Failed to reorder for %s: %s", method.__name__, e)
# get the optimal one
best_result = min(peak_memory_diff_methods, key=lambda x: x.peak_memory)
return best_result.order


@@ -48,6 +48,7 @@ from .comm_analysis import estimate_nccl_collective_runtime
from .dependencies import Dep, MemoryDep, StarDep, WeakDep
from .ir import ComputedBuffer, MultiOutput, MultiOutputLayout
from .loop_body import LoopBody
from .memory import MemoryPlanningInfoForBuffer, MemoryPlanningInfoForNode
from .runtime.runtime_utils import green_text, red_text
from .sizevars import SimplifyIndexing
from .utils import (
@@ -77,6 +78,9 @@ class SchedulerBuffer:
node: ir.Buffer
defining_op: BaseSchedulerNode
users: List[NodeUser] = dataclasses.field(default_factory=list)
mpi_buffer: MemoryPlanningInfoForBuffer = dataclasses.field(
default_factory=MemoryPlanningInfoForBuffer
)
def __hash__(self) -> int:
return hash(self.node.name)
@@ -167,6 +171,7 @@ class BaseSchedulerNode:
# .min_order = .max_order = X if this node is X-th node in `self.scheduler.nodes`.
min_order: int
max_order: int
mpi_node: MemoryPlanningInfoForNode
def __init__(self, scheduler: Scheduler) -> None:
self.scheduler: Scheduler = scheduler
@@ -1804,6 +1809,16 @@ class Scheduler:
if config._pre_fusion_custom_pass is not None:
self.nodes = config._pre_fusion_custom_pass(self.nodes)
self.nodes = self.fuse_nodes(self.nodes)
if config.reorder_for_peak_memory:
from .memory import reorder_for_peak_memory
self.nodes = reorder_for_peak_memory(
self.nodes,
self.name_to_buf,
self.name_to_fused_node,
set(V.graph.graph_inputs.keys()),
set(V.graph.get_output_names()),
)
self.merge_loops()
self.finalize_multi_template_buffers()
if config.reorder_for_compute_comm_overlap: