mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[inductor] Estimate peak memory allocfree and applying to reordering collectives (#160113)
1. Applying @eellison idea from https://github.com/pytorch/pytorch/pull/146562#discussion_r2059363672 for estimate_peak_memory: ``` """ Alternative version of estimate_peak_memory, that respects the fact, that every SchedulerNode has multiple phases: 1. alloc ( outputs ) 2. run_kernel 3. dealloc last_use buffers estimate_peak_memory collapses memory into one value: size_alloc - size_free While peak memory happens after alloc. Duplicating the code to not migrate all callsites at once, In future usages of estimate_peak_memory will migrate to this version. """ ``` - Applying this in `reorder_communication_preserving_peak_memory` pass. 2. Buffers during reordering can change deallocation point, if candidate and group to swap both are users of the f_input_buf and group contains last_use_snode. - Addressing this tracking the last_use_snode for each buffer and recomputing current memory respecting the change in size_free (group_node after reordering is not the last user of the buffer and its size_free -= buffer_size, while candidate becomes the last user and candidate.size_free += buffer_size). 4. Adding env var `PYTORCH_REORDER_COLLECTIVES_LIMIT` for ablation to limit number of collectives to reorder. What is after this PR: Iterative recomputation of memory estimations matches full memory estimations. Active memory is not regressing a lot, but reserved memory is significantly regressed. Investigation and fix of "reserved" memory will be in following PRs. BASELINE (bucketing AG and RS): active: 32Gb reserved: 34Gb ``` [rank0]:[titan] 2025-08-11 11:28:36,798 - root - INFO - step: 1 loss: 12.2722 grad_norm: 4.2192 active_memory: 24.66GiB(25.96%) reserved_memory: 25.38GiB(26.72%) tps: 99 tflops: 5.71 mfu: 0.58% [rank0]:[titan] 2025-08-11 11:28:38,640 - root - INFO - step: 2 loss: 13.1738 grad_norm: 50.5566 active_memory: 32.14GiB(33.83%) reserved_memory: 34.21GiB(36.01%) tps: 4,448 tflops: 257.63 mfu: 26.05% [rank0]:[titan] 2025-08-11 11:28:40,029 - root - INFO - step: 3 loss: 15.6866 grad_norm: 80.0862 active_memory: 32.14GiB(33.83%) reserved_memory: 34.21GiB(36.01%) tps: 5,900 tflops: 341.72 mfu: 34.55% [rank0]:[titan] 2025-08-11 11:28:41,423 - root - INFO - step: 4 loss: 13.4853 grad_norm: 7.8538 active_memory: 32.14GiB(33.83%) reserved_memory: 34.21GiB(36.01%) tps: 5,881 tflops: 340.57 mfu: 34.44% [rank0]:[titan] 2025-08-11 11:28:42,820 - root - INFO - step: 5 loss: 16.1191 grad_norm: 53.2481 active_memory: 32.14GiB(33.83%) reserved_memory: 34.21GiB(36.01%) tps: 5,867 tflops: 339.77 mfu: 34.35% ``` REORDER: active: 32Gb reserved: 36Gb ``` [rank0]:[titan] 2025-08-11 11:34:32,772 - root - INFO - step: 1 loss: 12.2490 grad_norm: 4.1944 active_memory: 24.66GiB(25.96%) reserved_memory: 26.81GiB(28.22%) tps: 85 tflops: 4.90 mfu: 0.50% [rank0]:[titan] 2025-08-11 11:34:35,329 - root - INFO - step: 2 loss: 13.1427 grad_norm: 39.5942 active_memory: 32.14GiB(33.83%) reserved_memory: 36.40GiB(38.31%) tps: 3,205 tflops: 185.61 mfu: 18.77% [rank0]:[titan] 2025-08-11 11:34:36,770 - root - INFO - step: 3 loss: 14.6084 grad_norm: 51.0743 active_memory: 32.14GiB(33.83%) reserved_memory: 36.40GiB(38.31%) tps: 5,688 tflops: 329.44 mfu: 33.31% [rank0]:[titan] 2025-08-11 11:34:38,197 - root - INFO - step: 4 loss: 13.6181 grad_norm: 8.1122 active_memory: 32.14GiB(33.83%) reserved_memory: 36.40GiB(38.31%) tps: 5,744 tflops: 332.68 mfu: 33.64% [rank0]:[titan] 2025-08-11 11:34:39,821 - root - INFO - step: 5 loss: 15.8913 grad_norm: 59.8510 active_memory: 32.14GiB(33.83%) reserved_memory: 36.40GiB(38.31%) tps: 5,046 tflops: 292.22 mfu: 29.55% ``` REORDER + SINK_WAITS_ITERATIVE: active: 35Gb reserved: 41Gb ``` [rank0]:[titan] 2025-08-11 11:31:36,119 - root - INFO - step: 1 loss: 12.2646 grad_norm: 4.1282 active_memory: 27.60GiB(29.05%) reserved_memory: 32.49GiB(34.20%) tps: 173 tflops: 10.00 mfu: 1.01% [rank0]:[titan] 2025-08-11 11:31:37,452 - root - INFO - step: 2 loss: 13.2353 grad_norm: 42.4234 active_memory: 35.08GiB(36.92%) reserved_memory: 41.62GiB(43.80%) tps: 6,152 tflops: 356.26 mfu: 36.02% [rank0]:[titan] 2025-08-11 11:31:38,780 - root - INFO - step: 3 loss: 13.8205 grad_norm: 24.0156 active_memory: 35.08GiB(36.92%) reserved_memory: 41.62GiB(43.80%) tps: 6,169 tflops: 357.29 mfu: 36.13% [rank0]:[titan] 2025-08-11 11:31:40,106 - root - INFO - step: 4 loss: 13.1033 grad_norm: 9.1167 active_memory: 35.08GiB(36.92%) reserved_memory: 41.62GiB(43.80%) tps: 6,183 tflops: 358.10 mfu: 36.21% [rank0]:[titan] 2025-08-11 11:31:41,443 - root - INFO - step: 5 loss: 16.3530 grad_norm: 51.8118 active_memory: 35.08GiB(36.92%) reserved_memory: 41.62GiB(43.80%) tps: 6,130 tflops: 355.03 mfu: 35.90% ``` Differential Revision: [D80718143](https://our.internmc.facebook.com/intern/diff/D80718143) Pull Request resolved: https://github.com/pytorch/pytorch/pull/160113 Approved by: https://github.com/wconstab, https://github.com/eellison Co-authored-by: eellison <elias.ellison@gmail.com>
This commit is contained in:
committed by
PyTorch MergeBot
parent
639b8cc51d
commit
db44de4c0d
@ -4,6 +4,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import heapq
|
import heapq
|
||||||
import importlib
|
import importlib
|
||||||
|
import itertools
|
||||||
import logging
|
import logging
|
||||||
import operator
|
import operator
|
||||||
import sys
|
import sys
|
||||||
@ -23,8 +24,15 @@ from .dependencies import WeakDep
|
|||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .ir import IRNode, Operation
|
from .ir import IRNode, Operation
|
||||||
|
from .scheduler import SchedulerBuffer
|
||||||
|
|
||||||
from .memory import estimate_peak_memory, FreeableInputBuffer, get_freeable_input_buf
|
from .memory import (
|
||||||
|
estimate_peak_memory,
|
||||||
|
estimate_peak_memory_allocfree,
|
||||||
|
FreeableInputBuffer,
|
||||||
|
get_freeable_input_buf,
|
||||||
|
SNodeMemory,
|
||||||
|
)
|
||||||
from .utils import (
|
from .utils import (
|
||||||
contains_collective,
|
contains_collective,
|
||||||
contains_wait,
|
contains_wait,
|
||||||
@ -188,6 +196,46 @@ def _is_fake_dep(d):
|
|||||||
return isinstance(d, WeakDep) and d.is_fake
|
return isinstance(d, WeakDep) and d.is_fake
|
||||||
|
|
||||||
|
|
||||||
|
def _group_names(gns: list[BaseSchedulerNode]) -> str:
|
||||||
|
return "~".join([gn.get_name() for gn in gns])
|
||||||
|
|
||||||
|
|
||||||
|
def _initialize_memory_tracking(snodes, graph_inputs, graph_outputs):
|
||||||
|
"""Initialize memory tracking data structures"""
|
||||||
|
name_to_freeable_input_buf = get_freeable_input_buf(snodes, graph_inputs)
|
||||||
|
peak_memory, snodes_curr_memory, snodes_allocfree, buf_to_snode_last_use = (
|
||||||
|
estimate_peak_memory_allocfree(
|
||||||
|
snodes, name_to_freeable_input_buf, graph_outputs
|
||||||
|
)
|
||||||
|
)
|
||||||
|
_curr_memory = dict(zip(snodes, snodes_curr_memory))
|
||||||
|
_curr_memory[None] = (0, 0)
|
||||||
|
return (
|
||||||
|
peak_memory,
|
||||||
|
_curr_memory,
|
||||||
|
snodes_allocfree,
|
||||||
|
buf_to_snode_last_use,
|
||||||
|
name_to_freeable_input_buf,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _initialize_double_linked_list(
|
||||||
|
snodes: list[BaseSchedulerNode],
|
||||||
|
) -> tuple[
|
||||||
|
dict[BaseSchedulerNode, Optional[BaseSchedulerNode]],
|
||||||
|
dict[BaseSchedulerNode, Optional[BaseSchedulerNode]],
|
||||||
|
BaseSchedulerNode,
|
||||||
|
]:
|
||||||
|
"""Create double-linked list structure from snodes"""
|
||||||
|
_prev = {}
|
||||||
|
_next = {}
|
||||||
|
for i, snode in enumerate(snodes):
|
||||||
|
_prev[snode] = snodes[i - 1] if i > 0 else None
|
||||||
|
_next[snode] = snodes[i + 1] if i < len(snodes) - 1 else None
|
||||||
|
_head = snodes[0]
|
||||||
|
return _prev, _next, _head
|
||||||
|
|
||||||
|
|
||||||
def _reorder_communication_preserving_peak_memory_internal(
|
def _reorder_communication_preserving_peak_memory_internal(
|
||||||
snodes: list[BaseSchedulerNode],
|
snodes: list[BaseSchedulerNode],
|
||||||
) -> tuple[list[BaseSchedulerNode], dict[BaseSchedulerNode, ReorderInfo]]:
|
) -> tuple[list[BaseSchedulerNode], dict[BaseSchedulerNode, ReorderInfo]]:
|
||||||
@ -211,20 +259,22 @@ def _reorder_communication_preserving_peak_memory_internal(
|
|||||||
# heuristic to avoid degenerating to quadratic time
|
# heuristic to avoid degenerating to quadratic time
|
||||||
graph_inputs: OrderedSet[str] = OrderedSet(V.graph.graph_inputs.keys())
|
graph_inputs: OrderedSet[str] = OrderedSet(V.graph.graph_inputs.keys())
|
||||||
graph_outputs: OrderedSet[str] = OrderedSet(V.graph.get_output_names())
|
graph_outputs: OrderedSet[str] = OrderedSet(V.graph.get_output_names())
|
||||||
name_to_freeable_input_buf: dict[str, FreeableInputBuffer] = get_freeable_input_buf(
|
(
|
||||||
snodes, graph_inputs
|
peak_memory,
|
||||||
)
|
_curr_memory,
|
||||||
peak_memory, curr_memory = estimate_peak_memory(
|
snodes_allocfree,
|
||||||
snodes, name_to_freeable_input_buf, graph_outputs
|
buf_to_snode_last_use,
|
||||||
)
|
name_to_freeable_input_buf,
|
||||||
runtimes = {snode: estimate_op_runtime(snode) for snode in snodes}
|
) = _initialize_memory_tracking(snodes, graph_inputs, graph_outputs)
|
||||||
_curr_memory = dict(zip(snodes, curr_memory))
|
runtimes: dict[BaseSchedulerNode, float] = {
|
||||||
_curr_memory[None] = 0 # type: ignore[index]
|
snode: estimate_op_runtime(snode) for snode in snodes
|
||||||
|
}
|
||||||
# debug stats
|
# debug stats
|
||||||
stats: dict[BaseSchedulerNode, ReorderInfo] = {}
|
stats: dict[BaseSchedulerNode, ReorderInfo] = {}
|
||||||
|
|
||||||
def exposed_communication_time(collective_snode, remaining_snodes):
|
def exposed_communication_time(
|
||||||
|
collective_snode: BaseSchedulerNode, remaining_snodes: list[BaseSchedulerNode]
|
||||||
|
) -> float:
|
||||||
# assumes a linear schedule and computes the overlap of the collective with the remaining nodes
|
# assumes a linear schedule and computes the overlap of the collective with the remaining nodes
|
||||||
comm_time = estimate_op_runtime(collective_snode)
|
comm_time = estimate_op_runtime(collective_snode)
|
||||||
compute_time = 0.0
|
compute_time = 0.0
|
||||||
@ -236,7 +286,7 @@ def _reorder_communication_preserving_peak_memory_internal(
|
|||||||
# we can ignore it. Otherwise, it's the end of the road for overlap opportunities
|
# we can ignore it. Otherwise, it's the end of the road for overlap opportunities
|
||||||
break
|
break
|
||||||
|
|
||||||
def accumulate_time(_snode):
|
def accumulate_time(_snode: BaseSchedulerNode) -> None:
|
||||||
nonlocal compute_time
|
nonlocal compute_time
|
||||||
compute_time += runtimes[_snode]
|
compute_time += runtimes[_snode]
|
||||||
|
|
||||||
@ -245,18 +295,11 @@ def _reorder_communication_preserving_peak_memory_internal(
|
|||||||
|
|
||||||
total_moves = 0
|
total_moves = 0
|
||||||
|
|
||||||
# Dicts to keep track of "next" and "previous" as double-linked structure during grouping
|
_prev, _next, _head = _initialize_double_linked_list(snodes)
|
||||||
_prev: dict[Optional[BaseSchedulerNode], Optional[BaseSchedulerNode]] = {}
|
|
||||||
_next: dict[Optional[BaseSchedulerNode], Optional[BaseSchedulerNode]] = {}
|
|
||||||
for i, snode in enumerate(snodes):
|
|
||||||
_prev[snode] = snodes[i - 1] if i > 0 else None
|
|
||||||
_next[snode] = snodes[i + 1] if i < len(snodes) - 1 else None
|
|
||||||
_curr_memory = dict(zip(snodes, curr_memory))
|
|
||||||
_curr_memory[None] = 0 # type: ignore[index]
|
|
||||||
|
|
||||||
_head = snodes[0]
|
def _group_nodes(
|
||||||
|
head: Optional[BaseSchedulerNode], tail: Optional[BaseSchedulerNode]
|
||||||
def _group_nodes(head, tail):
|
) -> list[BaseSchedulerNode]:
|
||||||
ret = []
|
ret = []
|
||||||
n = head
|
n = head
|
||||||
while True:
|
while True:
|
||||||
@ -264,37 +307,167 @@ def _reorder_communication_preserving_peak_memory_internal(
|
|||||||
ret.append(n)
|
ret.append(n)
|
||||||
if n == tail:
|
if n == tail:
|
||||||
break
|
break
|
||||||
n = _next[n]
|
n = _next[n] # type: ignore[index]
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
def _group_names(head, tail):
|
def _perform_double_linked_list_swap(candidate, group_head, group_tail):
|
||||||
ret = ""
|
# swap (candidate, group_head...group_tail)
|
||||||
for n in _group_nodes(head, tail):
|
# Before:
|
||||||
if ret:
|
# candidate_prev -0-> candidate -1-> group_head...group_tail -2-> group_tail_next
|
||||||
ret += "~"
|
# After:
|
||||||
ret += n.get_name()
|
# candidate_prev -0-> group_head...group_tail -1-> candidate -2-> group_tail_next
|
||||||
return ret
|
# 0
|
||||||
|
candidate_prev = _prev[candidate]
|
||||||
|
if candidate_prev:
|
||||||
|
_next[candidate_prev] = group_head
|
||||||
|
_prev[group_head] = candidate_prev
|
||||||
|
|
||||||
|
# 2
|
||||||
|
group_tail_next = _next[group_tail]
|
||||||
|
if group_tail_next:
|
||||||
|
_prev[group_tail_next] = candidate
|
||||||
|
_next[candidate] = group_tail_next
|
||||||
|
|
||||||
|
# 1
|
||||||
|
_prev[candidate] = group_tail
|
||||||
|
_next[group_tail] = candidate
|
||||||
|
|
||||||
|
nonlocal _head
|
||||||
|
if _head == candidate:
|
||||||
|
_head = group_head
|
||||||
|
|
||||||
|
def _calculate_potential_peak_memory(
|
||||||
|
candidate, group_ns, group_n_to_bufs_after_swap_dealloc_by_candidate
|
||||||
|
):
|
||||||
|
# Caching calculations of memory for group nodes and candidate,
|
||||||
|
# to apply without recalculation after swap.
|
||||||
|
_post_alloc_update: dict[BaseSchedulerNode, int] = {}
|
||||||
|
potential_peak: int = 0
|
||||||
|
if not group_n_to_bufs_after_swap_dealloc_by_candidate:
|
||||||
|
# Not accounting for buffers last use change
|
||||||
|
potential_peak = max(
|
||||||
|
group_peak_memory - candidate_delta_mem,
|
||||||
|
_curr_memory[group_tail][1]
|
||||||
|
- candidate_delta_mem
|
||||||
|
+ candidate_allocfree.size_alloc,
|
||||||
|
)
|
||||||
|
return potential_peak, _post_alloc_update
|
||||||
|
|
||||||
|
# If candidate will be after group, the starting memory level of group nodes
|
||||||
|
# changes to the -(candidate.size_alloc - candidate.size_free)
|
||||||
|
mem_after_reorder_delta: int = -candidate_delta_mem
|
||||||
|
for gn in gns:
|
||||||
|
gn_post_alloc_mem = _curr_memory[gn][0] + mem_after_reorder_delta
|
||||||
|
_post_alloc_update[gn] = gn_post_alloc_mem
|
||||||
|
potential_peak = max(potential_peak, gn_post_alloc_mem)
|
||||||
|
|
||||||
|
bufs = group_n_to_bufs_after_swap_dealloc_by_candidate.get(gn, None)
|
||||||
|
if bufs is not None:
|
||||||
|
for buf in bufs:
|
||||||
|
# Candidate will deallocate those buffers
|
||||||
|
mem_after_reorder_delta += buf.mpi_buffer.size_free
|
||||||
|
|
||||||
|
candidate_mem_post_alloc = (
|
||||||
|
_curr_memory[group_tail][1]
|
||||||
|
+ mem_after_reorder_delta
|
||||||
|
+ candidate_allocfree.size_alloc
|
||||||
|
)
|
||||||
|
_post_alloc_update[candidate] = candidate_mem_post_alloc
|
||||||
|
potential_peak = max(potential_peak, candidate_mem_post_alloc)
|
||||||
|
return potential_peak, _post_alloc_update
|
||||||
|
|
||||||
|
def _update_memory_tracking_after_swap(
|
||||||
|
candidate,
|
||||||
|
gns,
|
||||||
|
group_n_to_bufs_after_swap_dealloc_by_candidate,
|
||||||
|
_post_alloc_update,
|
||||||
|
):
|
||||||
|
if not group_n_to_bufs_after_swap_dealloc_by_candidate:
|
||||||
|
for gn in gns:
|
||||||
|
cm = _curr_memory[gn]
|
||||||
|
_curr_memory[gn] = (
|
||||||
|
cm[0] - candidate_delta_mem,
|
||||||
|
cm[1] - candidate_delta_mem,
|
||||||
|
)
|
||||||
|
_candidate_post_alloc_mem = (
|
||||||
|
_curr_memory[group_tail][1] + candidate_allocfree.size_alloc
|
||||||
|
)
|
||||||
|
_candidate_post_free_mem = (
|
||||||
|
_candidate_post_alloc_mem - candidate_allocfree.size_free
|
||||||
|
)
|
||||||
|
_curr_memory[candidate] = (
|
||||||
|
_candidate_post_alloc_mem,
|
||||||
|
_candidate_post_free_mem,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Candidate becomes last use of some bufs
|
||||||
|
for (
|
||||||
|
gn,
|
||||||
|
bufs,
|
||||||
|
) in group_n_to_bufs_after_swap_dealloc_by_candidate.items():
|
||||||
|
for buf in bufs:
|
||||||
|
buf_to_snode_last_use[buf] = candidate
|
||||||
|
|
||||||
|
size_free_to_move_to_candidate_sum: int = 0
|
||||||
|
for n in gns:
|
||||||
|
_gn_post_alloc_mem: int = _post_alloc_update[n]
|
||||||
|
size_free_to_move_to_candidate: int = sum(
|
||||||
|
buf.mpi_buffer.size_free
|
||||||
|
for buf in group_n_to_bufs_after_swap_dealloc_by_candidate[n]
|
||||||
|
)
|
||||||
|
size_free_to_move_to_candidate_sum += size_free_to_move_to_candidate
|
||||||
|
# group node does not deallocate this after swap
|
||||||
|
snodes_allocfree[n].size_free -= size_free_to_move_to_candidate
|
||||||
|
gn_post_free_mem: int = _gn_post_alloc_mem - snodes_allocfree[n].size_free
|
||||||
|
_curr_memory[n] = (_gn_post_alloc_mem, gn_post_free_mem)
|
||||||
|
_candidate_post_alloc_mem = _post_alloc_update[candidate]
|
||||||
|
snodes_allocfree[candidate].size_free += size_free_to_move_to_candidate_sum
|
||||||
|
candidate_post_free_mem = (
|
||||||
|
_candidate_post_alloc_mem - snodes_allocfree[candidate].size_free
|
||||||
|
)
|
||||||
|
_curr_memory[candidate] = (
|
||||||
|
_candidate_post_alloc_mem,
|
||||||
|
candidate_post_free_mem,
|
||||||
|
)
|
||||||
|
|
||||||
|
debug_num_collectives_to_reorder: Optional[int] = (
|
||||||
|
config.reorder_iterative_debug_limit_to_reorder
|
||||||
|
)
|
||||||
|
|
||||||
|
num_processed_collectives: int = 0
|
||||||
curr = _head
|
curr = _head
|
||||||
|
debug_iterative_memory_recompute = config.reorder_iterative_debug_memory_recompute
|
||||||
|
iterative_recompute_error = False
|
||||||
|
|
||||||
while _next[curr] is not None:
|
while _next[curr] is not None:
|
||||||
|
if iterative_recompute_error:
|
||||||
|
break
|
||||||
if contains_collective(curr):
|
if contains_collective(curr):
|
||||||
reorder_info = stats[curr] = ReorderInfo()
|
if debug_num_collectives_to_reorder is not None and (
|
||||||
reorder_info.initial_exposed = reorder_info.final_exposed = (
|
num_processed_collectives >= debug_num_collectives_to_reorder
|
||||||
exposed_communication_time(curr, _group_nodes(_next[curr], None))
|
):
|
||||||
|
break
|
||||||
|
num_processed_collectives += 1
|
||||||
|
|
||||||
|
info = stats[curr] = ReorderInfo()
|
||||||
|
info.initial_exposed = info.final_exposed = exposed_communication_time(
|
||||||
|
curr, _group_nodes(_next[curr], None)
|
||||||
)
|
)
|
||||||
|
|
||||||
candidate = _prev[curr]
|
candidate = _prev[curr]
|
||||||
group_head = curr
|
group_head = curr
|
||||||
group_tail = curr
|
group_tail = curr
|
||||||
group_peak_memory = _curr_memory[curr]
|
group_peak_memory = _curr_memory[curr][0] # post_alloc memory
|
||||||
while candidate is not None:
|
while candidate is not None:
|
||||||
if contains_collective(candidate):
|
if contains_collective(candidate):
|
||||||
reorder_info.limiting_factor = "collective ordering"
|
info.limiting_factor = "collective ordering"
|
||||||
break
|
break
|
||||||
|
|
||||||
|
gns: list[BaseSchedulerNode] = _group_nodes(group_head, group_tail)
|
||||||
group = GroupedSchedulerNode(
|
group = GroupedSchedulerNode(
|
||||||
curr.scheduler,
|
curr.scheduler,
|
||||||
_group_nodes(group_head, group_tail),
|
gns,
|
||||||
temp_grouping=True,
|
temp_grouping=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -314,7 +487,9 @@ def _reorder_communication_preserving_peak_memory_internal(
|
|||||||
|
|
||||||
if data_dep is not None:
|
if data_dep is not None:
|
||||||
|
|
||||||
def is_groupable(candidate):
|
def is_groupable(
|
||||||
|
candidate: BaseSchedulerNode,
|
||||||
|
) -> tuple[bool, Optional[str]]:
|
||||||
# preserve ordering
|
# preserve ordering
|
||||||
if contains_collective(candidate):
|
if contains_collective(candidate):
|
||||||
return False, "contains_collective"
|
return False, "contains_collective"
|
||||||
@ -323,73 +498,106 @@ def _reorder_communication_preserving_peak_memory_internal(
|
|||||||
return False, "contains_gemm_like"
|
return False, "contains_gemm_like"
|
||||||
return True, None
|
return True, None
|
||||||
|
|
||||||
is_grp, grp_reason = is_groupable(candidate)
|
is_groupable_result, grouping_reason = is_groupable(candidate)
|
||||||
if is_grp:
|
if is_groupable_result:
|
||||||
group_head = candidate
|
group_head = candidate
|
||||||
group_peak_memory = max(
|
group_peak_memory = max(
|
||||||
group_peak_memory, _curr_memory[candidate]
|
group_peak_memory, _curr_memory[candidate][0]
|
||||||
)
|
)
|
||||||
reorder_info.grouped += 1
|
info.grouped += 1
|
||||||
reorder_info.grouped_info = _group_names(group_head, group_tail)
|
info.grouped_info = _group_names(gns)
|
||||||
candidate = _prev[candidate]
|
candidate = _prev[candidate]
|
||||||
continue
|
continue
|
||||||
else:
|
else:
|
||||||
msg = (
|
msg = (
|
||||||
f"data dependency {data_dep}(dep_names:{list(data_deps.keys())})"
|
f"data dependency {data_dep}(dep_names:{list(data_deps.keys())})"
|
||||||
f"\n candidate:{candidate.get_name()}(os:{[candidate.get_buffer_names()]})"
|
f"\n candidate:{candidate.get_name()}(outs:{[candidate.get_buffer_names()]})"
|
||||||
f"dep on {_group_names(group_head, group_tail)}"
|
f"dep on {_group_names(gns)}"
|
||||||
f"\n non_group_reason:{grp_reason}"
|
f"\n non_group_reason:{grouping_reason}"
|
||||||
)
|
)
|
||||||
reorder_info.limiting_factor = msg
|
info.limiting_factor = msg
|
||||||
break
|
break
|
||||||
|
|
||||||
delta_memory_candidate = (
|
candidate_allocfree: SNodeMemory = snodes_allocfree[candidate]
|
||||||
_curr_memory[candidate] - _curr_memory[_prev[candidate]] # type: ignore[index]
|
candidate_delta_mem: int = (
|
||||||
|
candidate_allocfree.size_alloc - candidate_allocfree.size_free
|
||||||
|
)
|
||||||
|
# candidate and one of group nodes are successors of the same buffer
|
||||||
|
# and last use of the buffer happen in group nodes.
|
||||||
|
# This last use deallocates it.
|
||||||
|
# If we swap [candidate [group]] to [[group] candidate],
|
||||||
|
# candidate becomes the last use
|
||||||
|
# and deallocated this buffer instead of group node.
|
||||||
|
# we need to update size_free accordingly to group_node and candidate,
|
||||||
|
# and recalculate post_alloc, post_free for them.
|
||||||
|
#
|
||||||
|
# Buf that changes its last use snode,
|
||||||
|
# after swap will be deallocated only by candidate,
|
||||||
|
# while before it was deallocated by group node.
|
||||||
|
group_n_to_bufs_after_swap_dealloc_by_candidate: dict[
|
||||||
|
BaseSchedulerNode, list[Union[FreeableInputBuffer, Any]]
|
||||||
|
] = defaultdict(list)
|
||||||
|
for (
|
||||||
|
buf,
|
||||||
|
snode_last_use,
|
||||||
|
) in buf_to_snode_last_use.items():
|
||||||
|
succ_nodes = buf.mpi_buffer.succ_nodes
|
||||||
|
if candidate not in succ_nodes:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not any(gn == snode_last_use for gn in gns):
|
||||||
|
continue
|
||||||
|
|
||||||
|
group_n_to_bufs_after_swap_dealloc_by_candidate[
|
||||||
|
snode_last_use
|
||||||
|
].append(buf)
|
||||||
|
|
||||||
|
potential_peak, _post_alloc_update = _calculate_potential_peak_memory(
|
||||||
|
candidate, gns, group_n_to_bufs_after_swap_dealloc_by_candidate
|
||||||
)
|
)
|
||||||
|
|
||||||
if group_peak_memory - delta_memory_candidate > peak_memory:
|
if potential_peak > peak_memory:
|
||||||
reorder_info.limiting_factor = "peak memory"
|
info.limiting_factor = (
|
||||||
|
f"peak memory new:{potential_peak} vs base:{peak_memory}"
|
||||||
|
)
|
||||||
break
|
break
|
||||||
|
info.moves += 1
|
||||||
reorder_info.moves += 1
|
|
||||||
total_moves += 1
|
total_moves += 1
|
||||||
|
|
||||||
mem_deltas = {}
|
_perform_double_linked_list_swap(candidate, group_head, group_tail)
|
||||||
for n in [candidate, *_group_nodes(group_head, group_tail)]:
|
|
||||||
mem_deltas[n] = _curr_memory[n] - _curr_memory[_prev[n]] # type: ignore[index]
|
|
||||||
# swap (candidate, group_head...group_tail)
|
|
||||||
# Before:
|
|
||||||
# candidate_prev -0-> candidate -1-> group_head...group_tail -2-> group_tail_next
|
|
||||||
# After:
|
|
||||||
# candidate_prev -0-> group_head...group_tail -1-> candidate -2-> group_tail_next
|
|
||||||
# 0
|
|
||||||
candidate_prev = _prev[candidate]
|
|
||||||
if candidate_prev:
|
|
||||||
_next[candidate_prev] = group_head
|
|
||||||
_prev[group_head] = candidate_prev
|
|
||||||
|
|
||||||
# 2
|
info.final_exposed = exposed_communication_time(
|
||||||
group_tail_next = _next[group_tail]
|
|
||||||
if group_tail_next:
|
|
||||||
_prev[group_tail_next] = candidate
|
|
||||||
_next[candidate] = group_tail_next
|
|
||||||
|
|
||||||
# 1
|
|
||||||
_prev[candidate] = group_tail
|
|
||||||
_next[group_tail] = candidate
|
|
||||||
|
|
||||||
if _head == candidate:
|
|
||||||
_head = group_head
|
|
||||||
|
|
||||||
reorder_info.final_exposed = exposed_communication_time(
|
|
||||||
curr, _group_nodes(_next[curr], None)
|
curr, _group_nodes(_next[curr], None)
|
||||||
)
|
)
|
||||||
# Recompute curr_memory
|
|
||||||
_prev_curr_memory = _curr_memory[_prev[group_head]] # type: ignore[index]
|
_update_memory_tracking_after_swap(
|
||||||
for n in _group_nodes(group_head, candidate):
|
candidate,
|
||||||
_curr_memory[n] = _prev_curr_memory = (
|
gns,
|
||||||
_prev_curr_memory + mem_deltas[n]
|
group_n_to_bufs_after_swap_dealloc_by_candidate,
|
||||||
|
_post_alloc_update,
|
||||||
|
)
|
||||||
|
|
||||||
|
if debug_iterative_memory_recompute:
|
||||||
|
# Compare iteratively recomputed memory data
|
||||||
|
# with full run of estimate_peak_memory
|
||||||
|
|
||||||
|
from .comms_debug import _debug_iterative_memory_recompute
|
||||||
|
|
||||||
|
iterative_recompute_error = _debug_iterative_memory_recompute(
|
||||||
|
candidate,
|
||||||
|
gns,
|
||||||
|
_group_names(gns),
|
||||||
|
_group_nodes(_head, None),
|
||||||
|
name_to_freeable_input_buf,
|
||||||
|
graph_outputs,
|
||||||
|
peak_memory,
|
||||||
|
_curr_memory,
|
||||||
|
snodes_allocfree,
|
||||||
|
"reorder_communication_preserving_peak_memory",
|
||||||
|
group_n_to_bufs_after_swap_dealloc_by_candidate,
|
||||||
)
|
)
|
||||||
|
if iterative_recompute_error:
|
||||||
|
break
|
||||||
candidate = _prev[group_head]
|
candidate = _prev[group_head]
|
||||||
curr = _next[curr] # type: ignore[assignment]
|
curr = _next[curr] # type: ignore[assignment]
|
||||||
|
|
||||||
@ -415,15 +623,15 @@ def _reorder_communication_preserving_peak_memory_internal(
|
|||||||
rows = [
|
rows = [
|
||||||
[
|
[
|
||||||
node_summary(snode),
|
node_summary(snode),
|
||||||
node_reorder_info.initial_exposed,
|
node_info.initial_exposed,
|
||||||
node_reorder_info.final_exposed,
|
node_info.final_exposed,
|
||||||
node_reorder_info.improvement,
|
node_info.improvement,
|
||||||
node_reorder_info.limiting_factor,
|
node_info.limiting_factor,
|
||||||
node_reorder_info.moves,
|
node_info.moves,
|
||||||
node_reorder_info.grouped,
|
node_info.grouped,
|
||||||
node_reorder_info.grouped_info,
|
node_info.grouped_info,
|
||||||
]
|
]
|
||||||
for snode, node_reorder_info in node_stats.items()
|
for snode, node_info in node_stats.items()
|
||||||
]
|
]
|
||||||
if importlib.util.find_spec("tabulate"):
|
if importlib.util.find_spec("tabulate"):
|
||||||
from tabulate import tabulate
|
from tabulate import tabulate
|
||||||
@ -441,7 +649,7 @@ def _reorder_communication_preserving_peak_memory_internal(
|
|||||||
|
|
||||||
new_snodes = _group_nodes(_head, None)
|
new_snodes = _group_nodes(_head, None)
|
||||||
assert len(new_snodes) == original_snodes_num
|
assert len(new_snodes) == original_snodes_num
|
||||||
new_peak_memory, curr_memory = estimate_peak_memory(
|
new_peak_memory, _, _, _ = estimate_peak_memory_allocfree(
|
||||||
new_snodes, name_to_freeable_input_buf, graph_outputs
|
new_snodes, name_to_freeable_input_buf, graph_outputs
|
||||||
)
|
)
|
||||||
reorder_log_str += f"\n peak_memory_before:{peak_memory}"
|
reorder_log_str += f"\n peak_memory_before:{peak_memory}"
|
||||||
@ -657,24 +865,21 @@ def _sink_waits_iterative_internal(
|
|||||||
return snodes, {}
|
return snodes, {}
|
||||||
graph_inputs: OrderedSet[str] = OrderedSet(V.graph.graph_inputs.keys())
|
graph_inputs: OrderedSet[str] = OrderedSet(V.graph.graph_inputs.keys())
|
||||||
graph_outputs: OrderedSet[str] = OrderedSet(V.graph.get_output_names())
|
graph_outputs: OrderedSet[str] = OrderedSet(V.graph.get_output_names())
|
||||||
name_to_freeable_input_buf: dict[str, FreeableInputBuffer] = get_freeable_input_buf(
|
(
|
||||||
snodes, graph_inputs
|
peak_memory,
|
||||||
)
|
_curr_memory,
|
||||||
peak_memory, curr_memory = estimate_peak_memory(
|
snodes_allocfree,
|
||||||
snodes, name_to_freeable_input_buf, graph_outputs
|
buf_to_snode_last_use,
|
||||||
)
|
name_to_freeable_input_buf,
|
||||||
|
) = _initialize_memory_tracking(snodes, graph_inputs, graph_outputs)
|
||||||
|
|
||||||
|
_prev, _next, _head = _initialize_double_linked_list(snodes)
|
||||||
|
|
||||||
stats: dict[BaseSchedulerNode, SinkWaitInfo] = {}
|
stats: dict[BaseSchedulerNode, SinkWaitInfo] = {}
|
||||||
_prev: dict[Optional[BaseSchedulerNode], Optional[BaseSchedulerNode]] = {}
|
|
||||||
_next: dict[Optional[BaseSchedulerNode], Optional[BaseSchedulerNode]] = {}
|
|
||||||
_head = snodes[0]
|
|
||||||
for i, snode in enumerate(snodes):
|
|
||||||
_prev[snode] = snodes[i - 1] if i > 0 else None
|
|
||||||
_next[snode] = snodes[i + 1] if i < len(snodes) - 1 else None
|
|
||||||
_curr_memory = dict(zip(snodes, curr_memory))
|
|
||||||
_curr_memory[None] = 0 # type: ignore[index]
|
|
||||||
|
|
||||||
def _group_nodes(head, tail):
|
def _group_nodes(
|
||||||
|
head: Optional[BaseSchedulerNode], tail: Optional[BaseSchedulerNode]
|
||||||
|
) -> list[BaseSchedulerNode]:
|
||||||
ret = []
|
ret = []
|
||||||
n = head
|
n = head
|
||||||
while True:
|
while True:
|
||||||
@ -682,21 +887,125 @@ def _sink_waits_iterative_internal(
|
|||||||
ret.append(n)
|
ret.append(n)
|
||||||
if n == tail:
|
if n == tail:
|
||||||
break
|
break
|
||||||
n = _next[n]
|
n = _next[n] # type: ignore[index]
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
def _group_names(head, tail):
|
def _calculate_potential_peak_memory(
|
||||||
ret = ""
|
candidate, group_ns, group_n_to_bufs_after_swap_dealloc_instead_of_candidate
|
||||||
for n in _group_nodes(head, tail):
|
):
|
||||||
if ret:
|
pre_group_mem = (
|
||||||
ret += "~"
|
_curr_memory[group_head][0] - snodes_allocfree[group_head].size_alloc
|
||||||
ret += n.get_name()
|
)
|
||||||
return ret
|
# Stash memory tracing updates to not recompute them after swap
|
||||||
|
_post_alloc_update: dict[BaseSchedulerNode, int] = {}
|
||||||
|
_size_free_delta_update: dict[BaseSchedulerNode, int] = {}
|
||||||
|
|
||||||
|
potential_peak = 0
|
||||||
|
if not group_n_to_bufs_after_swap_dealloc_instead_of_candidate:
|
||||||
|
# Not accounting for buffers liveliness change
|
||||||
|
potential_peak = max(
|
||||||
|
group_peak_memory + candidate_delta_mem,
|
||||||
|
pre_group_mem + candidate_allocfree.size_alloc,
|
||||||
|
)
|
||||||
|
return potential_peak, _post_alloc_update, _size_free_delta_update
|
||||||
|
|
||||||
|
candidate_post_alloc = pre_group_mem + candidate_allocfree.size_alloc
|
||||||
|
_post_alloc_update[candidate] = candidate_post_alloc
|
||||||
|
potential_peak = candidate_post_alloc
|
||||||
|
candidate_size_free_to_move = sum(
|
||||||
|
buf.mpi_buffer.size_free # type: ignore[attr-defined]
|
||||||
|
for buf in itertools.chain.from_iterable(
|
||||||
|
group_n_to_bufs_after_swap_dealloc_instead_of_candidate.values()
|
||||||
|
)
|
||||||
|
)
|
||||||
|
_size_free_delta_update[candidate] = -candidate_size_free_to_move
|
||||||
|
delta_mem = candidate_delta_mem + candidate_size_free_to_move
|
||||||
|
for gn in gns:
|
||||||
|
gn_post_alloc = _curr_memory[gn][0] + delta_mem
|
||||||
|
_post_alloc_update[gn] = gn_post_alloc
|
||||||
|
potential_peak = max(potential_peak, gn_post_alloc)
|
||||||
|
gn_size_free_to_add = 0
|
||||||
|
if gn in group_n_to_bufs_after_swap_dealloc_instead_of_candidate:
|
||||||
|
bufs = group_n_to_bufs_after_swap_dealloc_instead_of_candidate[gn]
|
||||||
|
for buf in bufs:
|
||||||
|
gn_size_free_to_add += buf.mpi_buffer.size_free
|
||||||
|
_size_free_delta_update[gn] = gn_size_free_to_add
|
||||||
|
delta_mem -= gn_size_free_to_add
|
||||||
|
return potential_peak, _post_alloc_update, _size_free_delta_update
|
||||||
|
|
||||||
|
def _perform_double_linked_list_swap(candidate, group_head, group_tail):
|
||||||
|
# group_head_prev -0-> candidate -1-> group_head...group_tail -2-> candidate_next
|
||||||
|
# 0:
|
||||||
|
group_head_prev = _prev[group_head]
|
||||||
|
if group_head_prev:
|
||||||
|
_next[group_head_prev] = candidate
|
||||||
|
_prev[candidate] = group_head_prev
|
||||||
|
|
||||||
|
# 2:
|
||||||
|
candidate_next = _next[candidate]
|
||||||
|
if candidate_next:
|
||||||
|
_prev[candidate_next] = group_tail
|
||||||
|
_next[group_tail] = candidate_next
|
||||||
|
|
||||||
|
# 1:
|
||||||
|
_prev[group_head] = candidate
|
||||||
|
_next[candidate] = group_head
|
||||||
|
nonlocal _head
|
||||||
|
if group_head == _head:
|
||||||
|
_head = candidate
|
||||||
|
|
||||||
|
def _update_memory_tracking_after_swap(
|
||||||
|
candidate,
|
||||||
|
gns,
|
||||||
|
group_n_to_bufs_after_swap_dealloc_instead_of_candidate,
|
||||||
|
_post_alloc_update,
|
||||||
|
_size_free_delta_update,
|
||||||
|
):
|
||||||
|
group_head = gns[0]
|
||||||
|
pre_group_mem = (
|
||||||
|
_curr_memory[group_head][0] - snodes_allocfree[group_head].size_alloc
|
||||||
|
)
|
||||||
|
if not group_n_to_bufs_after_swap_dealloc_instead_of_candidate:
|
||||||
|
candidate_post_alloc = pre_group_mem + candidate_allocfree.size_alloc
|
||||||
|
_curr_memory[candidate] = (
|
||||||
|
candidate_post_alloc,
|
||||||
|
candidate_post_alloc - candidate_allocfree.size_free,
|
||||||
|
)
|
||||||
|
for gn in gns:
|
||||||
|
cm = _curr_memory[gn]
|
||||||
|
_curr_memory[gn] = (
|
||||||
|
cm[0] + candidate_delta_mem,
|
||||||
|
cm[1] + candidate_delta_mem,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
for n in [candidate, *gns]:
|
||||||
|
post_alloc = _post_alloc_update[n]
|
||||||
|
snodes_allocfree[n].size_free += _size_free_delta_update[n]
|
||||||
|
_curr_memory[n] = (
|
||||||
|
post_alloc,
|
||||||
|
post_alloc - snodes_allocfree[n].size_free,
|
||||||
|
)
|
||||||
|
|
||||||
curr = snodes[-1]
|
curr = snodes[-1]
|
||||||
|
|
||||||
processed_waits = OrderedSet() # type: ignore[var-annotated]
|
processed_waits = OrderedSet() # type: ignore[var-annotated]
|
||||||
|
debug_iterative_memory_recompute = config.reorder_iterative_debug_memory_recompute
|
||||||
|
debug_num_sink_waits_to_reorder: Optional[int] = (
|
||||||
|
config.sink_waits_iterative_debug_limit_to_sink
|
||||||
|
)
|
||||||
|
|
||||||
|
iterative_recompute_error = False
|
||||||
|
|
||||||
while _prev[curr] is not None:
|
while _prev[curr] is not None:
|
||||||
|
if iterative_recompute_error:
|
||||||
|
break
|
||||||
|
if (
|
||||||
|
debug_num_sink_waits_to_reorder is not None
|
||||||
|
and len(processed_waits) >= debug_num_sink_waits_to_reorder
|
||||||
|
):
|
||||||
|
break
|
||||||
|
|
||||||
if contains_wait(curr) and curr not in processed_waits:
|
if contains_wait(curr) and curr not in processed_waits:
|
||||||
processed_waits.add(curr)
|
processed_waits.add(curr)
|
||||||
info = stats[curr] = SinkWaitInfo()
|
info = stats[curr] = SinkWaitInfo()
|
||||||
@ -704,11 +1013,14 @@ def _sink_waits_iterative_internal(
|
|||||||
wait_snode = curr
|
wait_snode = curr
|
||||||
group_head = curr
|
group_head = curr
|
||||||
group_tail = curr
|
group_tail = curr
|
||||||
group_peak_memory = _curr_memory[curr]
|
group_peak_memory = _curr_memory[curr][0]
|
||||||
while candidate is not None:
|
while candidate is not None:
|
||||||
|
if iterative_recompute_error:
|
||||||
|
break
|
||||||
|
gns: list[BaseSchedulerNode] = _group_nodes(group_head, group_tail)
|
||||||
group = GroupedSchedulerNode(
|
group = GroupedSchedulerNode(
|
||||||
wait_snode.scheduler,
|
wait_snode.scheduler,
|
||||||
_group_nodes(group_head, group_tail),
|
gns,
|
||||||
temp_grouping=True,
|
temp_grouping=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -753,15 +1065,15 @@ def _sink_waits_iterative_internal(
|
|||||||
if is_grp:
|
if is_grp:
|
||||||
group_tail = candidate
|
group_tail = candidate
|
||||||
group_peak_memory = max(
|
group_peak_memory = max(
|
||||||
group_peak_memory, _curr_memory[candidate]
|
group_peak_memory, _curr_memory[candidate][0]
|
||||||
)
|
)
|
||||||
info.grouped += 1
|
info.grouped += 1
|
||||||
info.grouped_info = _group_names(group_head, group_tail)
|
info.grouped_info = _group_names(gns)
|
||||||
candidate = _next[candidate]
|
candidate = _next[candidate]
|
||||||
continue
|
continue
|
||||||
elif (data_dep is None) and both_contain_comms:
|
elif (data_dep is None) and both_contain_comms:
|
||||||
info.limiting_factor = (
|
info.limiting_factor = (
|
||||||
f"collective ordering {_group_names(group_head, group_tail)}"
|
f"collective ordering {_group_names(gns)}"
|
||||||
f" with candidate:{candidate.get_name()}"
|
f" with candidate:{candidate.get_name()}"
|
||||||
)
|
)
|
||||||
break
|
break
|
||||||
@ -769,49 +1081,89 @@ def _sink_waits_iterative_internal(
|
|||||||
info.limiting_factor = (
|
info.limiting_factor = (
|
||||||
f"data dependency {data_dep}(dep_names:{list(data_deps.keys())})"
|
f"data dependency {data_dep}(dep_names:{list(data_deps.keys())})"
|
||||||
f"\n candidate:{candidate.get_name()}(os:{[candidate.get_buffer_names()]})"
|
f"\n candidate:{candidate.get_name()}(os:{[candidate.get_buffer_names()]})"
|
||||||
f"dep on {_group_names(group_head, group_tail)}"
|
f"dep on {gns}"
|
||||||
f"\n outs:{[o.get_name() for o in group_outs]}"
|
f"\n outs:{[o.get_name() for o in group_outs]}"
|
||||||
f"\n non_group_reason:{grp_reason}"
|
f"\n non_group_reason:{grp_reason}"
|
||||||
)
|
)
|
||||||
break
|
break
|
||||||
candidate_delta_memory = (
|
candidate_allocfree: SNodeMemory = snodes_allocfree[candidate]
|
||||||
_curr_memory[candidate] - _curr_memory[_prev[candidate]] # type: ignore[index]
|
candidate_delta_mem = (
|
||||||
|
candidate_allocfree.size_alloc - candidate_allocfree.size_free
|
||||||
)
|
)
|
||||||
if group_peak_memory + candidate_delta_memory > peak_memory:
|
# [group] candidate -> candidate [group]
|
||||||
info.limiting_factor = "peak_memory"
|
# Check for buffers with successors in group and candidate last successor
|
||||||
|
#
|
||||||
|
# Buf that changes its last use snode,
|
||||||
|
# It was deallocated by candidate,
|
||||||
|
# but after swap it will be deallocated by group node.
|
||||||
|
group_n_to_bufs_after_swap_dealloc_instead_of_candidate: dict[
|
||||||
|
BaseSchedulerNode, list[Union[FreeableInputBuffer, SchedulerBuffer]]
|
||||||
|
] = defaultdict(list)
|
||||||
|
for (
|
||||||
|
buf,
|
||||||
|
snode_last_use,
|
||||||
|
) in buf_to_snode_last_use.items():
|
||||||
|
succ_nodes = buf.mpi_buffer.succ_nodes
|
||||||
|
if snode_last_use != candidate: # noqa: E711
|
||||||
|
continue
|
||||||
|
# candidate is last use of buf
|
||||||
|
last_succ_gn = None
|
||||||
|
for gn in gns:
|
||||||
|
if gn in succ_nodes:
|
||||||
|
last_succ_gn = gn
|
||||||
|
if last_succ_gn is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# gn has successors of buf that after potential swap will become
|
||||||
|
# last use of buf and start deallocating buf instead of candidate
|
||||||
|
group_n_to_bufs_after_swap_dealloc_instead_of_candidate[
|
||||||
|
last_succ_gn
|
||||||
|
].append(buf)
|
||||||
|
|
||||||
|
potential_peak, _post_alloc_update, _size_free_delta_update = (
|
||||||
|
_calculate_potential_peak_memory(
|
||||||
|
candidate,
|
||||||
|
gns,
|
||||||
|
group_n_to_bufs_after_swap_dealloc_instead_of_candidate,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if potential_peak > peak_memory:
|
||||||
|
info.limiting_factor = (
|
||||||
|
f"peak memory new:{potential_peak} vs base:{peak_memory}"
|
||||||
|
)
|
||||||
break
|
break
|
||||||
|
|
||||||
info.moves += 1
|
info.moves += 1
|
||||||
info.moves_info += f"+{candidate.get_name()}"
|
info.moves_info += f"+{candidate.get_name()}"
|
||||||
|
|
||||||
# group_head_prev -0-> candidate -1-> group_head...group_tail -2-> candidate_next
|
_perform_double_linked_list_swap(candidate, group_head, group_tail)
|
||||||
mem_deltas = {}
|
|
||||||
for n in [candidate, *_group_nodes(group_head, group_tail)]:
|
|
||||||
mem_deltas[n] = _curr_memory[n] - _curr_memory[_prev[n]] # type: ignore[index]
|
|
||||||
# 0:
|
|
||||||
group_head_prev = _prev[group_head]
|
|
||||||
if group_head_prev:
|
|
||||||
_next[group_head_prev] = candidate
|
|
||||||
_prev[candidate] = group_head_prev
|
|
||||||
|
|
||||||
# 2:
|
_update_memory_tracking_after_swap(
|
||||||
candidate_next = _next[candidate]
|
candidate,
|
||||||
if candidate_next:
|
gns,
|
||||||
_prev[candidate_next] = group_tail
|
group_n_to_bufs_after_swap_dealloc_instead_of_candidate,
|
||||||
_next[group_tail] = candidate_next
|
_post_alloc_update,
|
||||||
|
_size_free_delta_update,
|
||||||
|
)
|
||||||
|
|
||||||
# 1:
|
if debug_iterative_memory_recompute:
|
||||||
_prev[group_head] = candidate
|
from .comms_debug import _debug_iterative_memory_recompute
|
||||||
_next[candidate] = group_head
|
|
||||||
if group_head == _head:
|
|
||||||
_head = candidate
|
|
||||||
|
|
||||||
# Recompute curr_memory
|
iterative_recompute_error = _debug_iterative_memory_recompute(
|
||||||
_prev_curr_memory = _curr_memory[_prev[candidate]] # type: ignore[index]
|
candidate,
|
||||||
for n in _group_nodes(candidate, group_tail):
|
gns,
|
||||||
_curr_memory[n] = _prev_curr_memory = (
|
_group_names(gns),
|
||||||
_prev_curr_memory + mem_deltas[n]
|
_group_nodes(_head, None),
|
||||||
|
name_to_freeable_input_buf,
|
||||||
|
graph_outputs,
|
||||||
|
peak_memory,
|
||||||
|
_curr_memory,
|
||||||
|
snodes_allocfree,
|
||||||
|
"sink_waits_iterative",
|
||||||
|
group_n_to_bufs_after_swap_dealloc_instead_of_candidate,
|
||||||
)
|
)
|
||||||
|
if iterative_recompute_error:
|
||||||
|
break
|
||||||
|
|
||||||
candidate = _next[group_tail]
|
candidate = _next[group_tail]
|
||||||
curr = _prev[curr] # type: ignore[assignment]
|
curr = _prev[curr] # type: ignore[assignment]
|
||||||
@ -850,11 +1202,11 @@ def _sink_waits_iterative_internal(
|
|||||||
overlap_log.info(log_str)
|
overlap_log.info(log_str)
|
||||||
new_snodes = _group_nodes(_head, None)
|
new_snodes = _group_nodes(_head, None)
|
||||||
assert len(new_snodes) == original_snodes_num
|
assert len(new_snodes) == original_snodes_num
|
||||||
new_peak_memory, curr_memory = estimate_peak_memory(
|
new_peak_memory, _, _, _ = estimate_peak_memory_allocfree(
|
||||||
new_snodes, name_to_freeable_input_buf, graph_outputs
|
new_snodes, name_to_freeable_input_buf, graph_outputs
|
||||||
)
|
)
|
||||||
log_str += f"\n peak_memory_before:{peak_memory}"
|
log_str += f"\n sink_waits_iterative peak_memory_before:{peak_memory}"
|
||||||
log_str += f"\n peak_memory_after:{new_peak_memory}"
|
log_str += f"\n sink_waits_iterative peak_memory_after:{new_peak_memory}"
|
||||||
trace_structured(
|
trace_structured(
|
||||||
"artifact",
|
"artifact",
|
||||||
metadata_fn=lambda: {
|
metadata_fn=lambda: {
|
||||||
|
112
torch/_inductor/comms_debug.py
Normal file
112
torch/_inductor/comms_debug.py
Normal file
@ -0,0 +1,112 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING, Union
|
||||||
|
|
||||||
|
from torch._logging import trace_structured
|
||||||
|
|
||||||
|
from .memory import estimate_peak_memory_allocfree
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from torch.utils._ordered_set import OrderedSet
|
||||||
|
|
||||||
|
from .memory import FreeableInputBuffer, SNodeMemory
|
||||||
|
from .scheduler import BaseSchedulerNode, SchedulerBuffer
|
||||||
|
|
||||||
|
|
||||||
|
def _debug_iterative_memory_recompute(
|
||||||
|
candidate: BaseSchedulerNode,
|
||||||
|
gns: list[BaseSchedulerNode],
|
||||||
|
group_names: str,
|
||||||
|
snodes: list[BaseSchedulerNode],
|
||||||
|
name_to_freeable_input_buf: dict[str, FreeableInputBuffer],
|
||||||
|
graph_outputs: OrderedSet[str],
|
||||||
|
peak_memory: int,
|
||||||
|
iter_curr_memory: dict[BaseSchedulerNode, tuple[int, int]],
|
||||||
|
snodes_allocfree: dict[BaseSchedulerNode, SNodeMemory],
|
||||||
|
tlparse_name: str,
|
||||||
|
gn_to_bufs_last_use: dict[
|
||||||
|
BaseSchedulerNode, list[Union[FreeableInputBuffer, SchedulerBuffer]]
|
||||||
|
],
|
||||||
|
) -> bool:
|
||||||
|
iterative_recompute_error = False
|
||||||
|
candidate_allocfree = snodes_allocfree[candidate]
|
||||||
|
est_peak_memory, snodes_curr_memory, snodes_allocfree, _ = (
|
||||||
|
estimate_peak_memory_allocfree(
|
||||||
|
snodes, name_to_freeable_input_buf, graph_outputs
|
||||||
|
)
|
||||||
|
)
|
||||||
|
est_curr_memory = dict(zip(snodes, snodes_curr_memory))
|
||||||
|
iter_cm = iter_curr_memory[candidate]
|
||||||
|
new_cm = est_curr_memory[candidate]
|
||||||
|
log = ""
|
||||||
|
if est_peak_memory > peak_memory:
|
||||||
|
log = "ITERATIVE PEAK DOES NOT MATCH"
|
||||||
|
iterative_recompute_error = True
|
||||||
|
if iter_cm != new_cm:
|
||||||
|
log = "ITERATIVE CURR MEMORY CANDIDATE DOES NOT MATCH"
|
||||||
|
iterative_recompute_error = True
|
||||||
|
for i, gn in enumerate(gns):
|
||||||
|
iter_gnm = iter_curr_memory[gn]
|
||||||
|
new_gnm = est_curr_memory[gn]
|
||||||
|
if iter_gnm != new_gnm:
|
||||||
|
log = f"ITERATIVE GN CURR MEMORY DOES NOT MATCH:{gn.get_name()}"
|
||||||
|
iterative_recompute_error = True
|
||||||
|
if iterative_recompute_error:
|
||||||
|
log += (
|
||||||
|
f"\nCANDIDATE:{candidate.get_name()}"
|
||||||
|
f"\nGROUP:{group_names}"
|
||||||
|
f"\nPEAK_MEMORY_BEFORE:{peak_memory}"
|
||||||
|
f"\nPEAK_MEMORY_AFTER_SWAP:{est_peak_memory}"
|
||||||
|
f"\nCANDIDATE:{candidate.debug_str()}"
|
||||||
|
f"\nCANDIDATE_ITER_CURR_MEMORY:{iter_cm}"
|
||||||
|
f"\nCANDIDATE_NEW__CURR_MEMORY:{new_cm}"
|
||||||
|
f"\nCANDIDATE_ITER_ALLOCFREE:{candidate_allocfree}"
|
||||||
|
f"\nCANDIDATE_NEW_ALLOCFREE:{snodes_allocfree[candidate]}"
|
||||||
|
)
|
||||||
|
peak_log = ""
|
||||||
|
for i, (pre, post) in enumerate(snodes_curr_memory):
|
||||||
|
if est_peak_memory == pre:
|
||||||
|
n = snodes[i]
|
||||||
|
peak_log = (
|
||||||
|
f"\nNEW_PEAK:{est_peak_memory}(BASE:{peak_memory})"
|
||||||
|
f" @ SNODE[{i}/{len(snodes)}]:{n.get_name()} {n.debug_str()}"
|
||||||
|
)
|
||||||
|
break
|
||||||
|
group_log = ""
|
||||||
|
for i, gn in enumerate(gns):
|
||||||
|
iter_gnm = iter_curr_memory[gn]
|
||||||
|
new_gnm = est_curr_memory[gn]
|
||||||
|
group_log += (
|
||||||
|
f"\nGROUP_NODE[{i}]:{gn.debug_str()}"
|
||||||
|
f"\nGROUP_NODE[{i}] ITER_GNM[{gn.get_name()}]:{iter_gnm}"
|
||||||
|
f"\nGROUP_NODE[{i}] ESTM_GNM[{gn.get_name()}]:{new_gnm}"
|
||||||
|
f"\nGROUP_NODE[{i}] ITER_allocfree:{snodes_allocfree[gn]}"
|
||||||
|
f"\nGROUP_NODE[{i}] ESTM_allocfree:{snodes_allocfree[gn]}"
|
||||||
|
)
|
||||||
|
log += peak_log
|
||||||
|
log += group_log
|
||||||
|
log += f"\nGN_TO_BUFS_LAST_USE:{gn_to_bufs_last_use}"
|
||||||
|
log += "\n\n".join(
|
||||||
|
[
|
||||||
|
(
|
||||||
|
f"\nSNODE[{i}]\n{n.debug_str()}"
|
||||||
|
f"\nITER_cur_mem:{iter_curr_memory[n]}"
|
||||||
|
f"\nESTM_cur_mem:{est_curr_memory[n]}"
|
||||||
|
f"\nITER_allocfree:{snodes_allocfree[n]}"
|
||||||
|
f"\nESTM_allocfree:{snodes_allocfree[n]}"
|
||||||
|
)
|
||||||
|
for i, n in enumerate(snodes)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
tname = f"{tlparse_name}_ITERATIVE_RECOMPUTE_ERROR"
|
||||||
|
print(f"{tname}:\n{log}")
|
||||||
|
trace_structured(
|
||||||
|
"artifact",
|
||||||
|
metadata_fn=lambda: {
|
||||||
|
"name": tname,
|
||||||
|
"encoding": "string",
|
||||||
|
},
|
||||||
|
payload_fn=lambda: log,
|
||||||
|
)
|
||||||
|
return iterative_recompute_error
|
@ -389,6 +389,16 @@ reorder_prefetch_limit: Optional[int] = None
|
|||||||
# enable operator reordering for peak memory optimization
|
# enable operator reordering for peak memory optimization
|
||||||
reorder_for_peak_memory = True
|
reorder_for_peak_memory = True
|
||||||
|
|
||||||
|
reorder_iterative_debug_memory_recompute: bool = False
|
||||||
|
reorder_iterative_debug_limit_to_reorder: Optional[int] = (
|
||||||
|
None
|
||||||
|
if (env_str := os.getenv("PYTORCH_REORDER_COLLECTIVES_LIMIT")) is None
|
||||||
|
else int(env_str)
|
||||||
|
)
|
||||||
|
sink_waits_iterative_debug_limit_to_sink: Optional[int] = (
|
||||||
|
None if (env_str := os.getenv("PYTORCH_SINK_WAITS_LIMIT")) is None else int(env_str)
|
||||||
|
)
|
||||||
|
|
||||||
bucket_all_gathers_fx: Literal["none", "all", "only_fsdp"] = "none"
|
bucket_all_gathers_fx: Literal["none", "all", "only_fsdp"] = "none"
|
||||||
# By default torch._inductor.fx_passes.bucketing.bucket_size_determinator is used
|
# By default torch._inductor.fx_passes.bucketing.bucket_size_determinator is used
|
||||||
bucket_all_gathers_fx_bucket_size_determinator: Optional[Callable[[int], int]] = None
|
bucket_all_gathers_fx_bucket_size_determinator: Optional[Callable[[int], int]] = None
|
||||||
|
@ -4,7 +4,7 @@ import collections
|
|||||||
import dataclasses
|
import dataclasses
|
||||||
import heapq
|
import heapq
|
||||||
import logging
|
import logging
|
||||||
from typing import Callable, TYPE_CHECKING, TypedDict, Union
|
from typing import Callable, Optional, TYPE_CHECKING, TypedDict, Union
|
||||||
|
|
||||||
from torch._environment import is_fbcode
|
from torch._environment import is_fbcode
|
||||||
from torch._utils_internal import signpost_event
|
from torch._utils_internal import signpost_event
|
||||||
@ -76,7 +76,7 @@ def get_freeable_input_buf(
|
|||||||
Create and keep track of all input buffers that can be freed during the program
|
Create and keep track of all input buffers that can be freed during the program
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A dictionary containing all freeble input buffers, keyed by their names.
|
A dictionary containing all freeable input buffers, keyed by their names.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _dep_size_hint(dep: Dep) -> int:
|
def _dep_size_hint(dep: Dep) -> int:
|
||||||
@ -315,7 +315,11 @@ def compute_memory_timeline(
|
|||||||
nodes: list[BaseSchedulerNode],
|
nodes: list[BaseSchedulerNode],
|
||||||
name_to_freeable_input_buf: dict[str, FreeableInputBuffer],
|
name_to_freeable_input_buf: dict[str, FreeableInputBuffer],
|
||||||
graph_outputs: OrderedSet[str],
|
graph_outputs: OrderedSet[str],
|
||||||
) -> tuple[list[BufferInfo], dict[BaseSchedulerNode, int]]:
|
) -> tuple[
|
||||||
|
list[BufferInfo],
|
||||||
|
dict[BaseSchedulerNode, int],
|
||||||
|
dict[Union[FreeableInputBuffer, SchedulerBuffer], BaseSchedulerNode],
|
||||||
|
]:
|
||||||
"""
|
"""
|
||||||
Compute buffer allocation and deallocation sizes and map their
|
Compute buffer allocation and deallocation sizes and map their
|
||||||
lifetime to the node schedule
|
lifetime to the node schedule
|
||||||
@ -329,15 +333,33 @@ def compute_memory_timeline(
|
|||||||
|
|
||||||
# get buffers' size and liveliness information
|
# get buffers' size and liveliness information
|
||||||
buf_info_list: list[BufferInfo] = []
|
buf_info_list: list[BufferInfo] = []
|
||||||
|
buf_to_snode_last_use: dict[
|
||||||
|
Union[FreeableInputBuffer, SchedulerBuffer], BaseSchedulerNode
|
||||||
|
] = {}
|
||||||
|
|
||||||
|
def _get_end_step_and_snode(
|
||||||
|
buf: Union[FreeableInputBuffer, SchedulerBuffer],
|
||||||
|
) -> tuple[int, Optional[BaseSchedulerNode]]:
|
||||||
|
max_step: int = -1
|
||||||
|
max_step_snode: Optional[BaseSchedulerNode] = None
|
||||||
|
succ_nodes = buf.mpi_buffer.succ_nodes
|
||||||
|
if succ_nodes:
|
||||||
|
for succ_node in succ_nodes:
|
||||||
|
step = node_to_step[succ_node]
|
||||||
|
if step > max_step:
|
||||||
|
max_step = step
|
||||||
|
max_step_snode = succ_node
|
||||||
|
assert max_step_snode is not None
|
||||||
|
return max_step, max_step_snode
|
||||||
|
|
||||||
# 1. for freeable input buffers
|
# 1. for freeable input buffers
|
||||||
for buf_name, input_buf in name_to_freeable_input_buf.items():
|
for buf_name, input_buf in name_to_freeable_input_buf.items():
|
||||||
end_step = (
|
end_step = -1
|
||||||
len(nodes) - 1
|
if buf_name not in graph_outputs:
|
||||||
if buf_name in graph_outputs
|
end_step, end_step_snode = _get_end_step_and_snode(input_buf)
|
||||||
else max(
|
assert end_step_snode is not None
|
||||||
node_to_step[succ_node] for succ_node in input_buf.mpi_buffer.succ_nodes
|
buf_to_snode_last_use[input_buf] = end_step_snode
|
||||||
)
|
|
||||||
)
|
|
||||||
buf_info_list.append(
|
buf_info_list.append(
|
||||||
BufferInfo(
|
BufferInfo(
|
||||||
input_buf,
|
input_buf,
|
||||||
@ -354,17 +376,17 @@ def compute_memory_timeline(
|
|||||||
# note: it is possible for a non-graph-output sched_buf to have no succ_nodes and
|
# note: it is possible for a non-graph-output sched_buf to have no succ_nodes and
|
||||||
# to be only used by its defining op (e.g., due to fusion when all consumers of
|
# to be only used by its defining op (e.g., due to fusion when all consumers of
|
||||||
# the buffer are fused with its defining op). In such cases, end_step is step.
|
# the buffer are fused with its defining op). In such cases, end_step is step.
|
||||||
end_step = (
|
buf_name = sched_buf.get_name()
|
||||||
len(nodes) - 1
|
end_step = -1
|
||||||
if sched_buf.get_name() in graph_outputs
|
if buf_name not in graph_outputs:
|
||||||
else max(
|
end_step, end_step_snode = _get_end_step_and_snode(sched_buf)
|
||||||
[
|
if end_step == -1:
|
||||||
node_to_step[succ_node]
|
end_step = step
|
||||||
for succ_node in sched_buf.mpi_buffer.succ_nodes
|
buf_to_snode_last_use[sched_buf] = node
|
||||||
],
|
else:
|
||||||
default=step,
|
assert end_step_snode is not None
|
||||||
)
|
buf_to_snode_last_use[sched_buf] = end_step_snode
|
||||||
)
|
|
||||||
buf_info_list.append(
|
buf_info_list.append(
|
||||||
BufferInfo(
|
BufferInfo(
|
||||||
sched_buf,
|
sched_buf,
|
||||||
@ -375,7 +397,7 @@ def compute_memory_timeline(
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
return buf_info_list, node_to_step
|
return buf_info_list, node_to_step, buf_to_snode_last_use
|
||||||
|
|
||||||
|
|
||||||
def estimate_peak_memory(
|
def estimate_peak_memory(
|
||||||
@ -392,7 +414,7 @@ def estimate_peak_memory(
|
|||||||
List[int]: memory usage at each node (or each step).
|
List[int]: memory usage at each node (or each step).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
buf_info_list, _ = compute_memory_timeline(
|
buf_info_list, _, _ = compute_memory_timeline(
|
||||||
nodes, name_to_freeable_input_buf, graph_outputs
|
nodes, name_to_freeable_input_buf, graph_outputs
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -416,6 +438,73 @@ def estimate_peak_memory(
|
|||||||
return (max_memory, memories_at_nodes)
|
return (max_memory, memories_at_nodes)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class SNodeMemory:
|
||||||
|
size_alloc: int
|
||||||
|
size_free: int
|
||||||
|
|
||||||
|
|
||||||
|
def estimate_peak_memory_allocfree(
|
||||||
|
nodes: list[BaseSchedulerNode],
|
||||||
|
name_to_freeable_input_buf: dict[str, FreeableInputBuffer],
|
||||||
|
graph_outputs: OrderedSet[str],
|
||||||
|
) -> tuple[
|
||||||
|
int,
|
||||||
|
list[tuple[int, int]],
|
||||||
|
dict[BaseSchedulerNode, SNodeMemory],
|
||||||
|
dict[Union[FreeableInputBuffer, SchedulerBuffer], BaseSchedulerNode],
|
||||||
|
]:
|
||||||
|
"""
|
||||||
|
Alternative version of estimate_peak_memory, that respects the fact,
|
||||||
|
that every SchedulerNode has multiple phases:
|
||||||
|
1. alloc ( outputs )
|
||||||
|
2. run_kernel
|
||||||
|
3. dealloc last_use buffers
|
||||||
|
estimate_peak_memory collapses memory into one value: size_alloc - size_free
|
||||||
|
While peak memory happens after alloc.
|
||||||
|
|
||||||
|
Duplicating the code to not migrate all callsites at once,
|
||||||
|
In future usages of estimate_peak_memory will migrate to this version.
|
||||||
|
"""
|
||||||
|
|
||||||
|
buf_info_list, _, buf_to_snode_last_use = compute_memory_timeline(
|
||||||
|
nodes, name_to_freeable_input_buf, graph_outputs
|
||||||
|
)
|
||||||
|
|
||||||
|
# incremental memory changes at each step
|
||||||
|
step_idx_allocfree = [SNodeMemory(0, 0) for _ in range(len(nodes))]
|
||||||
|
|
||||||
|
# for each buffer, update memory when created and when freed
|
||||||
|
for buf_info in buf_info_list:
|
||||||
|
step_idx_allocfree[buf_info.start_step].size_alloc += buf_info.size_alloc
|
||||||
|
if buf_info.end_step != -1:
|
||||||
|
step_idx_allocfree[buf_info.end_step].size_free += buf_info.size_free
|
||||||
|
|
||||||
|
snodes_allocfree = {}
|
||||||
|
for i, node in enumerate(nodes):
|
||||||
|
snodes_allocfree[node] = step_idx_allocfree[i]
|
||||||
|
|
||||||
|
max_memory = 0
|
||||||
|
cur_memory = 0
|
||||||
|
snodes_curr_memory = []
|
||||||
|
for t in range(len(nodes)):
|
||||||
|
alloc = step_idx_allocfree[t].size_alloc
|
||||||
|
free = step_idx_allocfree[t].size_free
|
||||||
|
cur_memory += alloc
|
||||||
|
post_alloc = cur_memory
|
||||||
|
max_memory = max(max_memory, cur_memory)
|
||||||
|
cur_memory -= free
|
||||||
|
post_free = cur_memory
|
||||||
|
snodes_curr_memory.append((post_alloc, post_free))
|
||||||
|
|
||||||
|
return (
|
||||||
|
max_memory,
|
||||||
|
snodes_curr_memory,
|
||||||
|
snodes_allocfree,
|
||||||
|
buf_to_snode_last_use,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def topological_sort_lpmf(
|
def topological_sort_lpmf(
|
||||||
nodes: list[BaseSchedulerNode],
|
nodes: list[BaseSchedulerNode],
|
||||||
name_to_freeable_input_buf: dict[str, FreeableInputBuffer],
|
name_to_freeable_input_buf: dict[str, FreeableInputBuffer],
|
||||||
@ -429,7 +518,7 @@ def topological_sort_lpmf(
|
|||||||
Buffer memory optimization for video codec application modeled in Simulink
|
Buffer memory optimization for video codec application modeled in Simulink
|
||||||
https://www.cs.york.ac.uk/rts/docs/DAC-1964-2006/PAPERS/2006/DAC06/PDFFILES/P0689.PDF
|
https://www.cs.york.ac.uk/rts/docs/DAC-1964-2006/PAPERS/2006/DAC06/PDFFILES/P0689.PDF
|
||||||
|
|
||||||
The algorithm maintain the max memory so far.
|
The algorithm maintains the max memory so far.
|
||||||
At every iteration, for each scheduleable node, it computes:
|
At every iteration, for each scheduleable node, it computes:
|
||||||
- how much memory needs to be allocated for the output buffers of this node;
|
- how much memory needs to be allocated for the output buffers of this node;
|
||||||
- how much memory can be freed as a result of executing this node.
|
- how much memory can be freed as a result of executing this node.
|
||||||
|
@ -2160,6 +2160,12 @@ class Scheduler:
|
|||||||
OrderedSet(V.graph.get_output_names()),
|
OrderedSet(V.graph.get_output_names()),
|
||||||
)
|
)
|
||||||
if config.reorder_for_compute_comm_overlap:
|
if config.reorder_for_compute_comm_overlap:
|
||||||
|
if not config.reorder_for_peak_memory:
|
||||||
|
from .memory import assign_memory_planning_info_for_scheduler_buffers
|
||||||
|
|
||||||
|
assign_memory_planning_info_for_scheduler_buffers(
|
||||||
|
self.nodes, self.name_to_buf
|
||||||
|
)
|
||||||
from torch._logging import trace_structured
|
from torch._logging import trace_structured
|
||||||
|
|
||||||
trace_structured(
|
trace_structured(
|
||||||
@ -2556,7 +2562,7 @@ class Scheduler:
|
|||||||
)
|
)
|
||||||
|
|
||||||
graph_outputs: OrderedSet[str] = OrderedSet(V.graph.get_output_names())
|
graph_outputs: OrderedSet[str] = OrderedSet(V.graph.get_output_names())
|
||||||
buf_info_list, _ = compute_memory_timeline(
|
buf_info_list, _, _ = compute_memory_timeline(
|
||||||
self.nodes,
|
self.nodes,
|
||||||
name_to_freeable_input_buf,
|
name_to_freeable_input_buf,
|
||||||
graph_outputs,
|
graph_outputs,
|
||||||
|
Reference in New Issue
Block a user