[inductor] Add estimate_peak_memory_allocfree and apply it to collective reordering (#160113)

1. Applying @eellison's idea from https://github.com/pytorch/pytorch/pull/146562#discussion_r2059363672 to `estimate_peak_memory`:
```
    """
    Alternative version of estimate_peak_memory, that respects the fact,
    that every SchedulerNode has multiple phases:
    1. alloc ( outputs )
    2. run_kernel
    3. dealloc last_use buffers
    estimate_peak_memory collapses memory into one value: size_alloc - size_free
    While peak memory happens after alloc.

    Duplicating the code to not migrate all callsites at once,
    In future usages of estimate_peak_memory will migrate to this version.
    """
```

- Applying this in the `reorder_communication_preserving_peak_memory` pass.
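
To make the phase distinction concrete, below is a minimal, self-contained sketch of the allocfree-style accounting (illustrative only, not the actual inductor implementation; `SNodeMemory` mirrors the dataclass added in this PR):
```
from dataclasses import dataclass

@dataclass
class SNodeMemory:
    size_alloc: int
    size_free: int

def estimate_peak_memory_allocfree_sketch(
    per_node: list[SNodeMemory],
) -> tuple[int, list[tuple[int, int]]]:
    peak = 0
    cur = 0
    curr_memory: list[tuple[int, int]] = []
    for m in per_node:
        cur += m.size_alloc      # phase 1: allocate this node's outputs
        post_alloc = cur
        peak = max(peak, cur)    # the peak is sampled here, after alloc
        cur -= m.size_free       # phase 3: free last-use buffers
        post_free = cur
        curr_memory.append((post_alloc, post_free))
    return peak, curr_memory

peak, mem = estimate_peak_memory_allocfree_sketch(
    [SNodeMemory(4, 0), SNodeMemory(8, 8), SNodeMemory(0, 4)]
)
assert peak == 12 and mem == [(4, 4), (12, 4), (4, 0)]
```
With the collapsed per-node value (size_alloc - size_free) the middle node looks memory-neutral, so the old estimate would report a peak of 4 instead of the true 12.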

2. Reordering can change a buffer's deallocation point: this happens when the candidate and the group being swapped are both users of the same freeable input buffer and the group contains its last_use_snode.

- This is addressed by tracking the last_use_snode for each buffer and recomputing current memory to account for the change in size_free (after reordering, the group node is no longer the last user of the buffer, so its size_free decreases by the buffer size, while the candidate becomes the last user and its size_free increases by the buffer size).
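
A sketch of this bookkeeping with illustrative numbers (names and sizes are made up; the pass itself keeps this state in `snodes_allocfree` and `_curr_memory`):
```
MiB = 1024 * 1024
# Before the swap the order is [candidate, group_node]; `buf` (1 MiB) is used
# by both, and group_node is its last user (so group_node frees it).
size_alloc = {"candidate": 2 * MiB, "group_node": 0}
size_free = {"candidate": 0, "group_node": 1 * MiB}

# After swapping to [group_node, candidate] the candidate is the last user of
# `buf`, so the freeing responsibility moves from the group node to it.
size_free["group_node"] -= 1 * MiB
size_free["candidate"] += 1 * MiB

# Recompute (post_alloc, post_free) along the new order, starting from the
# memory level right before the group (here: just `buf`, 1 MiB, already live).
cur, curr_memory = 1 * MiB, {}
for name in ("group_node", "candidate"):
    cur += size_alloc[name]
    post_alloc = cur
    cur -= size_free[name]
    curr_memory[name] = (post_alloc, cur)

assert curr_memory == {"group_node": (1 * MiB, 1 * MiB),
                       "candidate": (3 * MiB, 2 * MiB)}
```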

3. Adding the env var `PYTORCH_REORDER_COLLECTIVES_LIMIT` for ablations, to limit the number of collectives that get reordered.
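
A hypothetical usage sketch (the config attribute name comes from this PR's config change; the env var must be set before `torch._inductor.config` is first imported, which in practice means before compiling anything). The diff also adds a `PYTORCH_SINK_WAITS_LIMIT` counterpart for the sink-waits pass.
```
import os

# Reorder at most 2 collectives in this run; equivalently, export the variable
# in the launching shell, e.g. `PYTORCH_REORDER_COLLECTIVES_LIMIT=2 torchrun train.py`
# (train.py is just a placeholder script name here).
os.environ["PYTORCH_REORDER_COLLECTIVES_LIMIT"] = "2"

import torch._inductor.config as inductor_config

# The env var is materialized into this debug knob; None means "no limit".
print(inductor_config.reorder_iterative_debug_limit_to_reorder)  # -> 2
```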

Status after this PR:

The iterative recomputation of memory estimates matches the full memory estimation.
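
The consistency check behind this claim is gated by the new `reorder_iterative_debug_memory_recompute` config flag; a simplified sketch of the idea (the real check lives in `_debug_iterative_memory_recompute` and also compares per-node alloc/free sizes):
```
from torch._inductor.memory import estimate_peak_memory_allocfree

def iterative_matches_full(snodes, name_to_freeable_input_buf, graph_outputs,
                           peak_memory, iter_curr_memory):
    # Re-run the full estimator on the reordered schedule and compare it with
    # the incrementally maintained (post_alloc, post_free) values and peak.
    est_peak, snodes_curr_memory, _, _ = estimate_peak_memory_allocfree(
        snodes, name_to_freeable_input_buf, graph_outputs
    )
    full_curr_memory = dict(zip(snodes, snodes_curr_memory))
    return est_peak <= peak_memory and all(
        iter_curr_memory[n] == full_curr_memory[n] for n in snodes
    )
```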

Active memory does not regress much, but reserved memory regresses significantly.

Investigating and fixing the reserved-memory regression will follow in subsequent PRs.

BASELINE (bucketing AG and RS): active: 32 GiB, reserved: 34 GiB
```
[rank0]:[titan] 2025-08-11 11:28:36,798 - root - INFO - step:  1  loss: 12.2722  grad_norm:  4.2192  active_memory: 24.66GiB(25.96%)  reserved_memory: 25.38GiB(26.72%)  tps: 99  tflops: 5.71  mfu: 0.58%
[rank0]:[titan] 2025-08-11 11:28:38,640 - root - INFO - step:  2  loss: 13.1738  grad_norm: 50.5566  active_memory: 32.14GiB(33.83%)  reserved_memory: 34.21GiB(36.01%)  tps: 4,448  tflops: 257.63  mfu: 26.05%
[rank0]:[titan] 2025-08-11 11:28:40,029 - root - INFO - step:  3  loss: 15.6866  grad_norm: 80.0862  active_memory: 32.14GiB(33.83%)  reserved_memory: 34.21GiB(36.01%)  tps: 5,900  tflops: 341.72  mfu: 34.55%
[rank0]:[titan] 2025-08-11 11:28:41,423 - root - INFO - step:  4  loss: 13.4853  grad_norm:  7.8538  active_memory: 32.14GiB(33.83%)  reserved_memory: 34.21GiB(36.01%)  tps: 5,881  tflops: 340.57  mfu: 34.44%
[rank0]:[titan] 2025-08-11 11:28:42,820 - root - INFO - step:  5  loss: 16.1191  grad_norm: 53.2481  active_memory: 32.14GiB(33.83%)  reserved_memory: 34.21GiB(36.01%)  tps: 5,867  tflops: 339.77  mfu: 34.35%
```
REORDER: active: 32 GiB, reserved: 36 GiB
```
[rank0]:[titan] 2025-08-11 11:34:32,772 - root - INFO - step:  1  loss: 12.2490  grad_norm:  4.1944  active_memory: 24.66GiB(25.96%)  reserved_memory: 26.81GiB(28.22%)  tps: 85  tflops: 4.90  mfu: 0.50%
[rank0]:[titan] 2025-08-11 11:34:35,329 - root - INFO - step:  2  loss: 13.1427  grad_norm: 39.5942  active_memory: 32.14GiB(33.83%)  reserved_memory: 36.40GiB(38.31%)  tps: 3,205  tflops: 185.61  mfu: 18.77%
[rank0]:[titan] 2025-08-11 11:34:36,770 - root - INFO - step:  3  loss: 14.6084  grad_norm: 51.0743  active_memory: 32.14GiB(33.83%)  reserved_memory: 36.40GiB(38.31%)  tps: 5,688  tflops: 329.44  mfu: 33.31%
[rank0]:[titan] 2025-08-11 11:34:38,197 - root - INFO - step:  4  loss: 13.6181  grad_norm:  8.1122  active_memory: 32.14GiB(33.83%)  reserved_memory: 36.40GiB(38.31%)  tps: 5,744  tflops: 332.68  mfu: 33.64%
[rank0]:[titan] 2025-08-11 11:34:39,821 - root - INFO - step:  5  loss: 15.8913  grad_norm: 59.8510  active_memory: 32.14GiB(33.83%)  reserved_memory: 36.40GiB(38.31%)  tps: 5,046  tflops: 292.22  mfu: 29.55%
```

REORDER + SINK_WAITS_ITERATIVE: active: 35 GiB, reserved: 41 GiB
```
[rank0]:[titan] 2025-08-11 11:31:36,119 - root - INFO - step:  1  loss: 12.2646  grad_norm:  4.1282  active_memory: 27.60GiB(29.05%)  reserved_memory: 32.49GiB(34.20%)  tps: 173  tflops: 10.00  mfu: 1.01%
[rank0]:[titan] 2025-08-11 11:31:37,452 - root - INFO - step:  2  loss: 13.2353  grad_norm: 42.4234  active_memory: 35.08GiB(36.92%)  reserved_memory: 41.62GiB(43.80%)  tps: 6,152  tflops: 356.26  mfu: 36.02%
[rank0]:[titan] 2025-08-11 11:31:38,780 - root - INFO - step:  3  loss: 13.8205  grad_norm: 24.0156  active_memory: 35.08GiB(36.92%)  reserved_memory: 41.62GiB(43.80%)  tps: 6,169  tflops: 357.29  mfu: 36.13%
[rank0]:[titan] 2025-08-11 11:31:40,106 - root - INFO - step:  4  loss: 13.1033  grad_norm:  9.1167  active_memory: 35.08GiB(36.92%)  reserved_memory: 41.62GiB(43.80%)  tps: 6,183  tflops: 358.10  mfu: 36.21%
[rank0]:[titan] 2025-08-11 11:31:41,443 - root - INFO - step:  5  loss: 16.3530  grad_norm: 51.8118  active_memory: 35.08GiB(36.92%)  reserved_memory: 41.62GiB(43.80%)  tps: 6,130  tflops: 355.03  mfu: 35.90%
```

Differential Revision: [D80718143](https://our.internmc.facebook.com/intern/diff/D80718143)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160113
Approved by: https://github.com/wconstab, https://github.com/eellison

Co-authored-by: eellison <elias.ellison@gmail.com>


@ -4,6 +4,7 @@ from __future__ import annotations
import heapq import heapq
import importlib import importlib
import itertools
import logging import logging
import operator import operator
import sys import sys
@ -23,8 +24,15 @@ from .dependencies import WeakDep
if TYPE_CHECKING: if TYPE_CHECKING:
from .ir import IRNode, Operation from .ir import IRNode, Operation
from .scheduler import SchedulerBuffer
from .memory import estimate_peak_memory, FreeableInputBuffer, get_freeable_input_buf from .memory import (
estimate_peak_memory,
estimate_peak_memory_allocfree,
FreeableInputBuffer,
get_freeable_input_buf,
SNodeMemory,
)
from .utils import ( from .utils import (
contains_collective, contains_collective,
contains_wait, contains_wait,
@ -188,6 +196,46 @@ def _is_fake_dep(d):
return isinstance(d, WeakDep) and d.is_fake return isinstance(d, WeakDep) and d.is_fake
def _group_names(gns: list[BaseSchedulerNode]) -> str:
return "~".join([gn.get_name() for gn in gns])
def _initialize_memory_tracking(snodes, graph_inputs, graph_outputs):
"""Initialize memory tracking data structures"""
name_to_freeable_input_buf = get_freeable_input_buf(snodes, graph_inputs)
peak_memory, snodes_curr_memory, snodes_allocfree, buf_to_snode_last_use = (
estimate_peak_memory_allocfree(
snodes, name_to_freeable_input_buf, graph_outputs
)
)
_curr_memory = dict(zip(snodes, snodes_curr_memory))
_curr_memory[None] = (0, 0)
return (
peak_memory,
_curr_memory,
snodes_allocfree,
buf_to_snode_last_use,
name_to_freeable_input_buf,
)
def _initialize_double_linked_list(
snodes: list[BaseSchedulerNode],
) -> tuple[
dict[BaseSchedulerNode, Optional[BaseSchedulerNode]],
dict[BaseSchedulerNode, Optional[BaseSchedulerNode]],
BaseSchedulerNode,
]:
"""Create double-linked list structure from snodes"""
_prev = {}
_next = {}
for i, snode in enumerate(snodes):
_prev[snode] = snodes[i - 1] if i > 0 else None
_next[snode] = snodes[i + 1] if i < len(snodes) - 1 else None
_head = snodes[0]
return _prev, _next, _head
def _reorder_communication_preserving_peak_memory_internal( def _reorder_communication_preserving_peak_memory_internal(
snodes: list[BaseSchedulerNode], snodes: list[BaseSchedulerNode],
) -> tuple[list[BaseSchedulerNode], dict[BaseSchedulerNode, ReorderInfo]]: ) -> tuple[list[BaseSchedulerNode], dict[BaseSchedulerNode, ReorderInfo]]:
@ -211,20 +259,22 @@ def _reorder_communication_preserving_peak_memory_internal(
# heuristic to avoid degenerating to quadratic time # heuristic to avoid degenerating to quadratic time
graph_inputs: OrderedSet[str] = OrderedSet(V.graph.graph_inputs.keys()) graph_inputs: OrderedSet[str] = OrderedSet(V.graph.graph_inputs.keys())
graph_outputs: OrderedSet[str] = OrderedSet(V.graph.get_output_names()) graph_outputs: OrderedSet[str] = OrderedSet(V.graph.get_output_names())
name_to_freeable_input_buf: dict[str, FreeableInputBuffer] = get_freeable_input_buf( (
snodes, graph_inputs peak_memory,
) _curr_memory,
peak_memory, curr_memory = estimate_peak_memory( snodes_allocfree,
snodes, name_to_freeable_input_buf, graph_outputs buf_to_snode_last_use,
) name_to_freeable_input_buf,
runtimes = {snode: estimate_op_runtime(snode) for snode in snodes} ) = _initialize_memory_tracking(snodes, graph_inputs, graph_outputs)
_curr_memory = dict(zip(snodes, curr_memory)) runtimes: dict[BaseSchedulerNode, float] = {
_curr_memory[None] = 0 # type: ignore[index] snode: estimate_op_runtime(snode) for snode in snodes
}
# debug stats # debug stats
stats: dict[BaseSchedulerNode, ReorderInfo] = {} stats: dict[BaseSchedulerNode, ReorderInfo] = {}
def exposed_communication_time(collective_snode, remaining_snodes): def exposed_communication_time(
collective_snode: BaseSchedulerNode, remaining_snodes: list[BaseSchedulerNode]
) -> float:
# assumes a linear schedule and computes the overlap of the collective with the remaining nodes # assumes a linear schedule and computes the overlap of the collective with the remaining nodes
comm_time = estimate_op_runtime(collective_snode) comm_time = estimate_op_runtime(collective_snode)
compute_time = 0.0 compute_time = 0.0
@ -236,7 +286,7 @@ def _reorder_communication_preserving_peak_memory_internal(
# we can ignore it. Otherwise, it's the end of the road for overlap opportunities # we can ignore it. Otherwise, it's the end of the road for overlap opportunities
break break
def accumulate_time(_snode): def accumulate_time(_snode: BaseSchedulerNode) -> None:
nonlocal compute_time nonlocal compute_time
compute_time += runtimes[_snode] compute_time += runtimes[_snode]
@ -245,18 +295,11 @@ def _reorder_communication_preserving_peak_memory_internal(
total_moves = 0 total_moves = 0
# Dicts to keep track of "next" and "previous" as double-linked structure during grouping _prev, _next, _head = _initialize_double_linked_list(snodes)
_prev: dict[Optional[BaseSchedulerNode], Optional[BaseSchedulerNode]] = {}
_next: dict[Optional[BaseSchedulerNode], Optional[BaseSchedulerNode]] = {}
for i, snode in enumerate(snodes):
_prev[snode] = snodes[i - 1] if i > 0 else None
_next[snode] = snodes[i + 1] if i < len(snodes) - 1 else None
_curr_memory = dict(zip(snodes, curr_memory))
_curr_memory[None] = 0 # type: ignore[index]
_head = snodes[0] def _group_nodes(
head: Optional[BaseSchedulerNode], tail: Optional[BaseSchedulerNode]
def _group_nodes(head, tail): ) -> list[BaseSchedulerNode]:
ret = [] ret = []
n = head n = head
while True: while True:
@ -264,37 +307,167 @@ def _reorder_communication_preserving_peak_memory_internal(
ret.append(n) ret.append(n)
if n == tail: if n == tail:
break break
n = _next[n] n = _next[n] # type: ignore[index]
return ret return ret
def _group_names(head, tail): def _perform_double_linked_list_swap(candidate, group_head, group_tail):
ret = "" # swap (candidate, group_head...group_tail)
for n in _group_nodes(head, tail): # Before:
if ret: # candidate_prev -0-> candidate -1-> group_head...group_tail -2-> group_tail_next
ret += "~" # After:
ret += n.get_name() # candidate_prev -0-> group_head...group_tail -1-> candidate -2-> group_tail_next
return ret # 0
candidate_prev = _prev[candidate]
if candidate_prev:
_next[candidate_prev] = group_head
_prev[group_head] = candidate_prev
# 2
group_tail_next = _next[group_tail]
if group_tail_next:
_prev[group_tail_next] = candidate
_next[candidate] = group_tail_next
# 1
_prev[candidate] = group_tail
_next[group_tail] = candidate
nonlocal _head
if _head == candidate:
_head = group_head
def _calculate_potential_peak_memory(
candidate, group_ns, group_n_to_bufs_after_swap_dealloc_by_candidate
):
# Caching calculations of memory for group nodes and candidate,
# to apply without recalculation after swap.
_post_alloc_update: dict[BaseSchedulerNode, int] = {}
potential_peak: int = 0
if not group_n_to_bufs_after_swap_dealloc_by_candidate:
# Not accounting for buffers last use change
potential_peak = max(
group_peak_memory - candidate_delta_mem,
_curr_memory[group_tail][1]
- candidate_delta_mem
+ candidate_allocfree.size_alloc,
)
return potential_peak, _post_alloc_update
# If candidate will be after group, the starting memory level of group nodes
# changes to the -(candidate.size_alloc - candidate.size_free)
mem_after_reorder_delta: int = -candidate_delta_mem
for gn in gns:
gn_post_alloc_mem = _curr_memory[gn][0] + mem_after_reorder_delta
_post_alloc_update[gn] = gn_post_alloc_mem
potential_peak = max(potential_peak, gn_post_alloc_mem)
bufs = group_n_to_bufs_after_swap_dealloc_by_candidate.get(gn, None)
if bufs is not None:
for buf in bufs:
# Candidate will deallocate those buffers
mem_after_reorder_delta += buf.mpi_buffer.size_free
candidate_mem_post_alloc = (
_curr_memory[group_tail][1]
+ mem_after_reorder_delta
+ candidate_allocfree.size_alloc
)
_post_alloc_update[candidate] = candidate_mem_post_alloc
potential_peak = max(potential_peak, candidate_mem_post_alloc)
return potential_peak, _post_alloc_update
def _update_memory_tracking_after_swap(
candidate,
gns,
group_n_to_bufs_after_swap_dealloc_by_candidate,
_post_alloc_update,
):
if not group_n_to_bufs_after_swap_dealloc_by_candidate:
for gn in gns:
cm = _curr_memory[gn]
_curr_memory[gn] = (
cm[0] - candidate_delta_mem,
cm[1] - candidate_delta_mem,
)
_candidate_post_alloc_mem = (
_curr_memory[group_tail][1] + candidate_allocfree.size_alloc
)
_candidate_post_free_mem = (
_candidate_post_alloc_mem - candidate_allocfree.size_free
)
_curr_memory[candidate] = (
_candidate_post_alloc_mem,
_candidate_post_free_mem,
)
return
# Candidate becomes last use of some bufs
for (
gn,
bufs,
) in group_n_to_bufs_after_swap_dealloc_by_candidate.items():
for buf in bufs:
buf_to_snode_last_use[buf] = candidate
size_free_to_move_to_candidate_sum: int = 0
for n in gns:
_gn_post_alloc_mem: int = _post_alloc_update[n]
size_free_to_move_to_candidate: int = sum(
buf.mpi_buffer.size_free
for buf in group_n_to_bufs_after_swap_dealloc_by_candidate[n]
)
size_free_to_move_to_candidate_sum += size_free_to_move_to_candidate
# group node does not deallocate this after swap
snodes_allocfree[n].size_free -= size_free_to_move_to_candidate
gn_post_free_mem: int = _gn_post_alloc_mem - snodes_allocfree[n].size_free
_curr_memory[n] = (_gn_post_alloc_mem, gn_post_free_mem)
_candidate_post_alloc_mem = _post_alloc_update[candidate]
snodes_allocfree[candidate].size_free += size_free_to_move_to_candidate_sum
candidate_post_free_mem = (
_candidate_post_alloc_mem - snodes_allocfree[candidate].size_free
)
_curr_memory[candidate] = (
_candidate_post_alloc_mem,
candidate_post_free_mem,
)
debug_num_collectives_to_reorder: Optional[int] = (
config.reorder_iterative_debug_limit_to_reorder
)
num_processed_collectives: int = 0
curr = _head curr = _head
debug_iterative_memory_recompute = config.reorder_iterative_debug_memory_recompute
iterative_recompute_error = False
while _next[curr] is not None: while _next[curr] is not None:
if iterative_recompute_error:
break
if contains_collective(curr): if contains_collective(curr):
reorder_info = stats[curr] = ReorderInfo() if debug_num_collectives_to_reorder is not None and (
reorder_info.initial_exposed = reorder_info.final_exposed = ( num_processed_collectives >= debug_num_collectives_to_reorder
exposed_communication_time(curr, _group_nodes(_next[curr], None)) ):
break
num_processed_collectives += 1
info = stats[curr] = ReorderInfo()
info.initial_exposed = info.final_exposed = exposed_communication_time(
curr, _group_nodes(_next[curr], None)
) )
candidate = _prev[curr] candidate = _prev[curr]
group_head = curr group_head = curr
group_tail = curr group_tail = curr
group_peak_memory = _curr_memory[curr] group_peak_memory = _curr_memory[curr][0] # post_alloc memory
while candidate is not None: while candidate is not None:
if contains_collective(candidate): if contains_collective(candidate):
reorder_info.limiting_factor = "collective ordering" info.limiting_factor = "collective ordering"
break break
gns: list[BaseSchedulerNode] = _group_nodes(group_head, group_tail)
group = GroupedSchedulerNode( group = GroupedSchedulerNode(
curr.scheduler, curr.scheduler,
_group_nodes(group_head, group_tail), gns,
temp_grouping=True, temp_grouping=True,
) )
@ -314,7 +487,9 @@ def _reorder_communication_preserving_peak_memory_internal(
if data_dep is not None: if data_dep is not None:
def is_groupable(candidate): def is_groupable(
candidate: BaseSchedulerNode,
) -> tuple[bool, Optional[str]]:
# preserve ordering # preserve ordering
if contains_collective(candidate): if contains_collective(candidate):
return False, "contains_collective" return False, "contains_collective"
@ -323,73 +498,106 @@ def _reorder_communication_preserving_peak_memory_internal(
return False, "contains_gemm_like" return False, "contains_gemm_like"
return True, None return True, None
is_grp, grp_reason = is_groupable(candidate) is_groupable_result, grouping_reason = is_groupable(candidate)
if is_grp: if is_groupable_result:
group_head = candidate group_head = candidate
group_peak_memory = max( group_peak_memory = max(
group_peak_memory, _curr_memory[candidate] group_peak_memory, _curr_memory[candidate][0]
) )
reorder_info.grouped += 1 info.grouped += 1
reorder_info.grouped_info = _group_names(group_head, group_tail) info.grouped_info = _group_names(gns)
candidate = _prev[candidate] candidate = _prev[candidate]
continue continue
else: else:
msg = ( msg = (
f"data dependency {data_dep}(dep_names:{list(data_deps.keys())})" f"data dependency {data_dep}(dep_names:{list(data_deps.keys())})"
f"\n candidate:{candidate.get_name()}(os:{[candidate.get_buffer_names()]})" f"\n candidate:{candidate.get_name()}(outs:{[candidate.get_buffer_names()]})"
f"dep on {_group_names(group_head, group_tail)}" f"dep on {_group_names(gns)}"
f"\n non_group_reason:{grp_reason}" f"\n non_group_reason:{grouping_reason}"
) )
reorder_info.limiting_factor = msg info.limiting_factor = msg
break break
delta_memory_candidate = ( candidate_allocfree: SNodeMemory = snodes_allocfree[candidate]
_curr_memory[candidate] - _curr_memory[_prev[candidate]] # type: ignore[index] candidate_delta_mem: int = (
candidate_allocfree.size_alloc - candidate_allocfree.size_free
)
# candidate and one of group nodes are successors of the same buffer
# and last use of the buffer happen in group nodes.
# This last use deallocates it.
# If we swap [candidate [group]] to [[group] candidate],
# candidate becomes the last use
# and deallocated this buffer instead of group node.
# we need to update size_free accordingly to group_node and candidate,
# and recalculate post_alloc, post_free for them.
#
# Buf that changes its last use snode,
# after swap will be deallocated only by candidate,
# while before it was deallocated by group node.
group_n_to_bufs_after_swap_dealloc_by_candidate: dict[
BaseSchedulerNode, list[Union[FreeableInputBuffer, Any]]
] = defaultdict(list)
for (
buf,
snode_last_use,
) in buf_to_snode_last_use.items():
succ_nodes = buf.mpi_buffer.succ_nodes
if candidate not in succ_nodes:
continue
if not any(gn == snode_last_use for gn in gns):
continue
group_n_to_bufs_after_swap_dealloc_by_candidate[
snode_last_use
].append(buf)
potential_peak, _post_alloc_update = _calculate_potential_peak_memory(
candidate, gns, group_n_to_bufs_after_swap_dealloc_by_candidate
) )
if group_peak_memory - delta_memory_candidate > peak_memory: if potential_peak > peak_memory:
reorder_info.limiting_factor = "peak memory" info.limiting_factor = (
f"peak memory new:{potential_peak} vs base:{peak_memory}"
)
break break
info.moves += 1
reorder_info.moves += 1
total_moves += 1 total_moves += 1
mem_deltas = {} _perform_double_linked_list_swap(candidate, group_head, group_tail)
for n in [candidate, *_group_nodes(group_head, group_tail)]:
mem_deltas[n] = _curr_memory[n] - _curr_memory[_prev[n]] # type: ignore[index]
# swap (candidate, group_head...group_tail)
# Before:
# candidate_prev -0-> candidate -1-> group_head...group_tail -2-> group_tail_next
# After:
# candidate_prev -0-> group_head...group_tail -1-> candidate -2-> group_tail_next
# 0
candidate_prev = _prev[candidate]
if candidate_prev:
_next[candidate_prev] = group_head
_prev[group_head] = candidate_prev
# 2 info.final_exposed = exposed_communication_time(
group_tail_next = _next[group_tail]
if group_tail_next:
_prev[group_tail_next] = candidate
_next[candidate] = group_tail_next
# 1
_prev[candidate] = group_tail
_next[group_tail] = candidate
if _head == candidate:
_head = group_head
reorder_info.final_exposed = exposed_communication_time(
curr, _group_nodes(_next[curr], None) curr, _group_nodes(_next[curr], None)
) )
# Recompute curr_memory
_prev_curr_memory = _curr_memory[_prev[group_head]] # type: ignore[index] _update_memory_tracking_after_swap(
for n in _group_nodes(group_head, candidate): candidate,
_curr_memory[n] = _prev_curr_memory = ( gns,
_prev_curr_memory + mem_deltas[n] group_n_to_bufs_after_swap_dealloc_by_candidate,
_post_alloc_update,
)
if debug_iterative_memory_recompute:
# Compare iteratively recomputed memory data
# with full run of estimate_peak_memory
from .comms_debug import _debug_iterative_memory_recompute
iterative_recompute_error = _debug_iterative_memory_recompute(
candidate,
gns,
_group_names(gns),
_group_nodes(_head, None),
name_to_freeable_input_buf,
graph_outputs,
peak_memory,
_curr_memory,
snodes_allocfree,
"reorder_communication_preserving_peak_memory",
group_n_to_bufs_after_swap_dealloc_by_candidate,
) )
if iterative_recompute_error:
break
candidate = _prev[group_head] candidate = _prev[group_head]
curr = _next[curr] # type: ignore[assignment] curr = _next[curr] # type: ignore[assignment]
@ -415,15 +623,15 @@ def _reorder_communication_preserving_peak_memory_internal(
rows = [ rows = [
[ [
node_summary(snode), node_summary(snode),
node_reorder_info.initial_exposed, node_info.initial_exposed,
node_reorder_info.final_exposed, node_info.final_exposed,
node_reorder_info.improvement, node_info.improvement,
node_reorder_info.limiting_factor, node_info.limiting_factor,
node_reorder_info.moves, node_info.moves,
node_reorder_info.grouped, node_info.grouped,
node_reorder_info.grouped_info, node_info.grouped_info,
] ]
for snode, node_reorder_info in node_stats.items() for snode, node_info in node_stats.items()
] ]
if importlib.util.find_spec("tabulate"): if importlib.util.find_spec("tabulate"):
from tabulate import tabulate from tabulate import tabulate
@ -441,7 +649,7 @@ def _reorder_communication_preserving_peak_memory_internal(
new_snodes = _group_nodes(_head, None) new_snodes = _group_nodes(_head, None)
assert len(new_snodes) == original_snodes_num assert len(new_snodes) == original_snodes_num
new_peak_memory, curr_memory = estimate_peak_memory( new_peak_memory, _, _, _ = estimate_peak_memory_allocfree(
new_snodes, name_to_freeable_input_buf, graph_outputs new_snodes, name_to_freeable_input_buf, graph_outputs
) )
reorder_log_str += f"\n peak_memory_before:{peak_memory}" reorder_log_str += f"\n peak_memory_before:{peak_memory}"
@ -657,24 +865,21 @@ def _sink_waits_iterative_internal(
return snodes, {} return snodes, {}
graph_inputs: OrderedSet[str] = OrderedSet(V.graph.graph_inputs.keys()) graph_inputs: OrderedSet[str] = OrderedSet(V.graph.graph_inputs.keys())
graph_outputs: OrderedSet[str] = OrderedSet(V.graph.get_output_names()) graph_outputs: OrderedSet[str] = OrderedSet(V.graph.get_output_names())
name_to_freeable_input_buf: dict[str, FreeableInputBuffer] = get_freeable_input_buf( (
snodes, graph_inputs peak_memory,
) _curr_memory,
peak_memory, curr_memory = estimate_peak_memory( snodes_allocfree,
snodes, name_to_freeable_input_buf, graph_outputs buf_to_snode_last_use,
) name_to_freeable_input_buf,
) = _initialize_memory_tracking(snodes, graph_inputs, graph_outputs)
_prev, _next, _head = _initialize_double_linked_list(snodes)
stats: dict[BaseSchedulerNode, SinkWaitInfo] = {} stats: dict[BaseSchedulerNode, SinkWaitInfo] = {}
_prev: dict[Optional[BaseSchedulerNode], Optional[BaseSchedulerNode]] = {}
_next: dict[Optional[BaseSchedulerNode], Optional[BaseSchedulerNode]] = {}
_head = snodes[0]
for i, snode in enumerate(snodes):
_prev[snode] = snodes[i - 1] if i > 0 else None
_next[snode] = snodes[i + 1] if i < len(snodes) - 1 else None
_curr_memory = dict(zip(snodes, curr_memory))
_curr_memory[None] = 0 # type: ignore[index]
def _group_nodes(head, tail): def _group_nodes(
head: Optional[BaseSchedulerNode], tail: Optional[BaseSchedulerNode]
) -> list[BaseSchedulerNode]:
ret = [] ret = []
n = head n = head
while True: while True:
@ -682,21 +887,125 @@ def _sink_waits_iterative_internal(
ret.append(n) ret.append(n)
if n == tail: if n == tail:
break break
n = _next[n] n = _next[n] # type: ignore[index]
return ret return ret
def _group_names(head, tail): def _calculate_potential_peak_memory(
ret = "" candidate, group_ns, group_n_to_bufs_after_swap_dealloc_instead_of_candidate
for n in _group_nodes(head, tail): ):
if ret: pre_group_mem = (
ret += "~" _curr_memory[group_head][0] - snodes_allocfree[group_head].size_alloc
ret += n.get_name() )
return ret # Stash memory tracing updates to not recompute them after swap
_post_alloc_update: dict[BaseSchedulerNode, int] = {}
_size_free_delta_update: dict[BaseSchedulerNode, int] = {}
potential_peak = 0
if not group_n_to_bufs_after_swap_dealloc_instead_of_candidate:
# Not accounting for buffers liveliness change
potential_peak = max(
group_peak_memory + candidate_delta_mem,
pre_group_mem + candidate_allocfree.size_alloc,
)
return potential_peak, _post_alloc_update, _size_free_delta_update
candidate_post_alloc = pre_group_mem + candidate_allocfree.size_alloc
_post_alloc_update[candidate] = candidate_post_alloc
potential_peak = candidate_post_alloc
candidate_size_free_to_move = sum(
buf.mpi_buffer.size_free # type: ignore[attr-defined]
for buf in itertools.chain.from_iterable(
group_n_to_bufs_after_swap_dealloc_instead_of_candidate.values()
)
)
_size_free_delta_update[candidate] = -candidate_size_free_to_move
delta_mem = candidate_delta_mem + candidate_size_free_to_move
for gn in gns:
gn_post_alloc = _curr_memory[gn][0] + delta_mem
_post_alloc_update[gn] = gn_post_alloc
potential_peak = max(potential_peak, gn_post_alloc)
gn_size_free_to_add = 0
if gn in group_n_to_bufs_after_swap_dealloc_instead_of_candidate:
bufs = group_n_to_bufs_after_swap_dealloc_instead_of_candidate[gn]
for buf in bufs:
gn_size_free_to_add += buf.mpi_buffer.size_free
_size_free_delta_update[gn] = gn_size_free_to_add
delta_mem -= gn_size_free_to_add
return potential_peak, _post_alloc_update, _size_free_delta_update
def _perform_double_linked_list_swap(candidate, group_head, group_tail):
# group_head_prev -0-> candidate -1-> group_head...group_tail -2-> candidate_next
# 0:
group_head_prev = _prev[group_head]
if group_head_prev:
_next[group_head_prev] = candidate
_prev[candidate] = group_head_prev
# 2:
candidate_next = _next[candidate]
if candidate_next:
_prev[candidate_next] = group_tail
_next[group_tail] = candidate_next
# 1:
_prev[group_head] = candidate
_next[candidate] = group_head
nonlocal _head
if group_head == _head:
_head = candidate
def _update_memory_tracking_after_swap(
candidate,
gns,
group_n_to_bufs_after_swap_dealloc_instead_of_candidate,
_post_alloc_update,
_size_free_delta_update,
):
group_head = gns[0]
pre_group_mem = (
_curr_memory[group_head][0] - snodes_allocfree[group_head].size_alloc
)
if not group_n_to_bufs_after_swap_dealloc_instead_of_candidate:
candidate_post_alloc = pre_group_mem + candidate_allocfree.size_alloc
_curr_memory[candidate] = (
candidate_post_alloc,
candidate_post_alloc - candidate_allocfree.size_free,
)
for gn in gns:
cm = _curr_memory[gn]
_curr_memory[gn] = (
cm[0] + candidate_delta_mem,
cm[1] + candidate_delta_mem,
)
return
for n in [candidate, *gns]:
post_alloc = _post_alloc_update[n]
snodes_allocfree[n].size_free += _size_free_delta_update[n]
_curr_memory[n] = (
post_alloc,
post_alloc - snodes_allocfree[n].size_free,
)
curr = snodes[-1] curr = snodes[-1]
processed_waits = OrderedSet() # type: ignore[var-annotated] processed_waits = OrderedSet() # type: ignore[var-annotated]
debug_iterative_memory_recompute = config.reorder_iterative_debug_memory_recompute
debug_num_sink_waits_to_reorder: Optional[int] = (
config.sink_waits_iterative_debug_limit_to_sink
)
iterative_recompute_error = False
while _prev[curr] is not None: while _prev[curr] is not None:
if iterative_recompute_error:
break
if (
debug_num_sink_waits_to_reorder is not None
and len(processed_waits) >= debug_num_sink_waits_to_reorder
):
break
if contains_wait(curr) and curr not in processed_waits: if contains_wait(curr) and curr not in processed_waits:
processed_waits.add(curr) processed_waits.add(curr)
info = stats[curr] = SinkWaitInfo() info = stats[curr] = SinkWaitInfo()
@ -704,11 +1013,14 @@ def _sink_waits_iterative_internal(
wait_snode = curr wait_snode = curr
group_head = curr group_head = curr
group_tail = curr group_tail = curr
group_peak_memory = _curr_memory[curr] group_peak_memory = _curr_memory[curr][0]
while candidate is not None: while candidate is not None:
if iterative_recompute_error:
break
gns: list[BaseSchedulerNode] = _group_nodes(group_head, group_tail)
group = GroupedSchedulerNode( group = GroupedSchedulerNode(
wait_snode.scheduler, wait_snode.scheduler,
_group_nodes(group_head, group_tail), gns,
temp_grouping=True, temp_grouping=True,
) )
@ -753,15 +1065,15 @@ def _sink_waits_iterative_internal(
if is_grp: if is_grp:
group_tail = candidate group_tail = candidate
group_peak_memory = max( group_peak_memory = max(
group_peak_memory, _curr_memory[candidate] group_peak_memory, _curr_memory[candidate][0]
) )
info.grouped += 1 info.grouped += 1
info.grouped_info = _group_names(group_head, group_tail) info.grouped_info = _group_names(gns)
candidate = _next[candidate] candidate = _next[candidate]
continue continue
elif (data_dep is None) and both_contain_comms: elif (data_dep is None) and both_contain_comms:
info.limiting_factor = ( info.limiting_factor = (
f"collective ordering {_group_names(group_head, group_tail)}" f"collective ordering {_group_names(gns)}"
f" with candidate:{candidate.get_name()}" f" with candidate:{candidate.get_name()}"
) )
break break
@ -769,49 +1081,89 @@ def _sink_waits_iterative_internal(
info.limiting_factor = ( info.limiting_factor = (
f"data dependency {data_dep}(dep_names:{list(data_deps.keys())})" f"data dependency {data_dep}(dep_names:{list(data_deps.keys())})"
f"\n candidate:{candidate.get_name()}(os:{[candidate.get_buffer_names()]})" f"\n candidate:{candidate.get_name()}(os:{[candidate.get_buffer_names()]})"
f"dep on {_group_names(group_head, group_tail)}" f"dep on {gns}"
f"\n outs:{[o.get_name() for o in group_outs]}" f"\n outs:{[o.get_name() for o in group_outs]}"
f"\n non_group_reason:{grp_reason}" f"\n non_group_reason:{grp_reason}"
) )
break break
candidate_delta_memory = ( candidate_allocfree: SNodeMemory = snodes_allocfree[candidate]
_curr_memory[candidate] - _curr_memory[_prev[candidate]] # type: ignore[index] candidate_delta_mem = (
candidate_allocfree.size_alloc - candidate_allocfree.size_free
) )
if group_peak_memory + candidate_delta_memory > peak_memory: # [group] candidate -> candidate [group]
info.limiting_factor = "peak_memory" # Check for buffers with successors in group and candidate last successor
#
# Buf that changes its last use snode,
# It was deallocated by candidate,
# but after swap it will be deallocated by group node.
group_n_to_bufs_after_swap_dealloc_instead_of_candidate: dict[
BaseSchedulerNode, list[Union[FreeableInputBuffer, SchedulerBuffer]]
] = defaultdict(list)
for (
buf,
snode_last_use,
) in buf_to_snode_last_use.items():
succ_nodes = buf.mpi_buffer.succ_nodes
if snode_last_use != candidate: # noqa: E711
continue
# candidate is last use of buf
last_succ_gn = None
for gn in gns:
if gn in succ_nodes:
last_succ_gn = gn
if last_succ_gn is None:
continue
# gn has successors of buf that after potential swap will become
# last use of buf and start deallocating buf instead of candidate
group_n_to_bufs_after_swap_dealloc_instead_of_candidate[
last_succ_gn
].append(buf)
potential_peak, _post_alloc_update, _size_free_delta_update = (
_calculate_potential_peak_memory(
candidate,
gns,
group_n_to_bufs_after_swap_dealloc_instead_of_candidate,
)
)
if potential_peak > peak_memory:
info.limiting_factor = (
f"peak memory new:{potential_peak} vs base:{peak_memory}"
)
break break
info.moves += 1 info.moves += 1
info.moves_info += f"+{candidate.get_name()}" info.moves_info += f"+{candidate.get_name()}"
# group_head_prev -0-> candidate -1-> group_head...group_tail -2-> candidate_next _perform_double_linked_list_swap(candidate, group_head, group_tail)
mem_deltas = {}
for n in [candidate, *_group_nodes(group_head, group_tail)]:
mem_deltas[n] = _curr_memory[n] - _curr_memory[_prev[n]] # type: ignore[index]
# 0:
group_head_prev = _prev[group_head]
if group_head_prev:
_next[group_head_prev] = candidate
_prev[candidate] = group_head_prev
# 2: _update_memory_tracking_after_swap(
candidate_next = _next[candidate] candidate,
if candidate_next: gns,
_prev[candidate_next] = group_tail group_n_to_bufs_after_swap_dealloc_instead_of_candidate,
_next[group_tail] = candidate_next _post_alloc_update,
_size_free_delta_update,
)
# 1: if debug_iterative_memory_recompute:
_prev[group_head] = candidate from .comms_debug import _debug_iterative_memory_recompute
_next[candidate] = group_head
if group_head == _head:
_head = candidate
# Recompute curr_memory iterative_recompute_error = _debug_iterative_memory_recompute(
_prev_curr_memory = _curr_memory[_prev[candidate]] # type: ignore[index] candidate,
for n in _group_nodes(candidate, group_tail): gns,
_curr_memory[n] = _prev_curr_memory = ( _group_names(gns),
_prev_curr_memory + mem_deltas[n] _group_nodes(_head, None),
name_to_freeable_input_buf,
graph_outputs,
peak_memory,
_curr_memory,
snodes_allocfree,
"sink_waits_iterative",
group_n_to_bufs_after_swap_dealloc_instead_of_candidate,
) )
if iterative_recompute_error:
break
candidate = _next[group_tail] candidate = _next[group_tail]
curr = _prev[curr] # type: ignore[assignment] curr = _prev[curr] # type: ignore[assignment]
@ -850,11 +1202,11 @@ def _sink_waits_iterative_internal(
overlap_log.info(log_str) overlap_log.info(log_str)
new_snodes = _group_nodes(_head, None) new_snodes = _group_nodes(_head, None)
assert len(new_snodes) == original_snodes_num assert len(new_snodes) == original_snodes_num
new_peak_memory, curr_memory = estimate_peak_memory( new_peak_memory, _, _, _ = estimate_peak_memory_allocfree(
new_snodes, name_to_freeable_input_buf, graph_outputs new_snodes, name_to_freeable_input_buf, graph_outputs
) )
log_str += f"\n peak_memory_before:{peak_memory}" log_str += f"\n sink_waits_iterative peak_memory_before:{peak_memory}"
log_str += f"\n peak_memory_after:{new_peak_memory}" log_str += f"\n sink_waits_iterative peak_memory_after:{new_peak_memory}"
trace_structured( trace_structured(
"artifact", "artifact",
metadata_fn=lambda: { metadata_fn=lambda: {


@ -0,0 +1,112 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Union
from torch._logging import trace_structured
from .memory import estimate_peak_memory_allocfree
if TYPE_CHECKING:
from torch.utils._ordered_set import OrderedSet
from .memory import FreeableInputBuffer, SNodeMemory
from .scheduler import BaseSchedulerNode, SchedulerBuffer
def _debug_iterative_memory_recompute(
candidate: BaseSchedulerNode,
gns: list[BaseSchedulerNode],
group_names: str,
snodes: list[BaseSchedulerNode],
name_to_freeable_input_buf: dict[str, FreeableInputBuffer],
graph_outputs: OrderedSet[str],
peak_memory: int,
iter_curr_memory: dict[BaseSchedulerNode, tuple[int, int]],
snodes_allocfree: dict[BaseSchedulerNode, SNodeMemory],
tlparse_name: str,
gn_to_bufs_last_use: dict[
BaseSchedulerNode, list[Union[FreeableInputBuffer, SchedulerBuffer]]
],
) -> bool:
iterative_recompute_error = False
candidate_allocfree = snodes_allocfree[candidate]
est_peak_memory, snodes_curr_memory, snodes_allocfree, _ = (
estimate_peak_memory_allocfree(
snodes, name_to_freeable_input_buf, graph_outputs
)
)
est_curr_memory = dict(zip(snodes, snodes_curr_memory))
iter_cm = iter_curr_memory[candidate]
new_cm = est_curr_memory[candidate]
log = ""
if est_peak_memory > peak_memory:
log = "ITERATIVE PEAK DOES NOT MATCH"
iterative_recompute_error = True
if iter_cm != new_cm:
log = "ITERATIVE CURR MEMORY CANDIDATE DOES NOT MATCH"
iterative_recompute_error = True
for i, gn in enumerate(gns):
iter_gnm = iter_curr_memory[gn]
new_gnm = est_curr_memory[gn]
if iter_gnm != new_gnm:
log = f"ITERATIVE GN CURR MEMORY DOES NOT MATCH:{gn.get_name()}"
iterative_recompute_error = True
if iterative_recompute_error:
log += (
f"\nCANDIDATE:{candidate.get_name()}"
f"\nGROUP:{group_names}"
f"\nPEAK_MEMORY_BEFORE:{peak_memory}"
f"\nPEAK_MEMORY_AFTER_SWAP:{est_peak_memory}"
f"\nCANDIDATE:{candidate.debug_str()}"
f"\nCANDIDATE_ITER_CURR_MEMORY:{iter_cm}"
f"\nCANDIDATE_NEW__CURR_MEMORY:{new_cm}"
f"\nCANDIDATE_ITER_ALLOCFREE:{candidate_allocfree}"
f"\nCANDIDATE_NEW_ALLOCFREE:{snodes_allocfree[candidate]}"
)
peak_log = ""
for i, (pre, post) in enumerate(snodes_curr_memory):
if est_peak_memory == pre:
n = snodes[i]
peak_log = (
f"\nNEW_PEAK:{est_peak_memory}(BASE:{peak_memory})"
f" @ SNODE[{i}/{len(snodes)}]:{n.get_name()} {n.debug_str()}"
)
break
group_log = ""
for i, gn in enumerate(gns):
iter_gnm = iter_curr_memory[gn]
new_gnm = est_curr_memory[gn]
group_log += (
f"\nGROUP_NODE[{i}]:{gn.debug_str()}"
f"\nGROUP_NODE[{i}] ITER_GNM[{gn.get_name()}]:{iter_gnm}"
f"\nGROUP_NODE[{i}] ESTM_GNM[{gn.get_name()}]:{new_gnm}"
f"\nGROUP_NODE[{i}] ITER_allocfree:{snodes_allocfree[gn]}"
f"\nGROUP_NODE[{i}] ESTM_allocfree:{snodes_allocfree[gn]}"
)
log += peak_log
log += group_log
log += f"\nGN_TO_BUFS_LAST_USE:{gn_to_bufs_last_use}"
log += "\n\n".join(
[
(
f"\nSNODE[{i}]\n{n.debug_str()}"
f"\nITER_cur_mem:{iter_curr_memory[n]}"
f"\nESTM_cur_mem:{est_curr_memory[n]}"
f"\nITER_allocfree:{snodes_allocfree[n]}"
f"\nESTM_allocfree:{snodes_allocfree[n]}"
)
for i, n in enumerate(snodes)
]
)
tname = f"{tlparse_name}_ITERATIVE_RECOMPUTE_ERROR"
print(f"{tname}:\n{log}")
trace_structured(
"artifact",
metadata_fn=lambda: {
"name": tname,
"encoding": "string",
},
payload_fn=lambda: log,
)
return iterative_recompute_error


@ -389,6 +389,16 @@ reorder_prefetch_limit: Optional[int] = None
# enable operator reordering for peak memory optimization # enable operator reordering for peak memory optimization
reorder_for_peak_memory = True reorder_for_peak_memory = True
reorder_iterative_debug_memory_recompute: bool = False
reorder_iterative_debug_limit_to_reorder: Optional[int] = (
None
if (env_str := os.getenv("PYTORCH_REORDER_COLLECTIVES_LIMIT")) is None
else int(env_str)
)
sink_waits_iterative_debug_limit_to_sink: Optional[int] = (
None if (env_str := os.getenv("PYTORCH_SINK_WAITS_LIMIT")) is None else int(env_str)
)
bucket_all_gathers_fx: Literal["none", "all", "only_fsdp"] = "none" bucket_all_gathers_fx: Literal["none", "all", "only_fsdp"] = "none"
# By default torch._inductor.fx_passes.bucketing.bucket_size_determinator is used # By default torch._inductor.fx_passes.bucketing.bucket_size_determinator is used
bucket_all_gathers_fx_bucket_size_determinator: Optional[Callable[[int], int]] = None bucket_all_gathers_fx_bucket_size_determinator: Optional[Callable[[int], int]] = None


@ -4,7 +4,7 @@ import collections
import dataclasses import dataclasses
import heapq import heapq
import logging import logging
from typing import Callable, TYPE_CHECKING, TypedDict, Union from typing import Callable, Optional, TYPE_CHECKING, TypedDict, Union
from torch._environment import is_fbcode from torch._environment import is_fbcode
from torch._utils_internal import signpost_event from torch._utils_internal import signpost_event
@ -76,7 +76,7 @@ def get_freeable_input_buf(
Create and keep track of all input buffers that can be freed during the program Create and keep track of all input buffers that can be freed during the program
Returns: Returns:
A dictionary containing all freeble input buffers, keyed by their names. A dictionary containing all freeable input buffers, keyed by their names.
""" """
def _dep_size_hint(dep: Dep) -> int: def _dep_size_hint(dep: Dep) -> int:
@ -315,7 +315,11 @@ def compute_memory_timeline(
nodes: list[BaseSchedulerNode], nodes: list[BaseSchedulerNode],
name_to_freeable_input_buf: dict[str, FreeableInputBuffer], name_to_freeable_input_buf: dict[str, FreeableInputBuffer],
graph_outputs: OrderedSet[str], graph_outputs: OrderedSet[str],
) -> tuple[list[BufferInfo], dict[BaseSchedulerNode, int]]: ) -> tuple[
list[BufferInfo],
dict[BaseSchedulerNode, int],
dict[Union[FreeableInputBuffer, SchedulerBuffer], BaseSchedulerNode],
]:
""" """
Compute buffer allocation and deallocation sizes and map their Compute buffer allocation and deallocation sizes and map their
lifetime to the node schedule lifetime to the node schedule
@ -329,15 +333,33 @@ def compute_memory_timeline(
# get buffers' size and liveliness information # get buffers' size and liveliness information
buf_info_list: list[BufferInfo] = [] buf_info_list: list[BufferInfo] = []
buf_to_snode_last_use: dict[
Union[FreeableInputBuffer, SchedulerBuffer], BaseSchedulerNode
] = {}
def _get_end_step_and_snode(
buf: Union[FreeableInputBuffer, SchedulerBuffer],
) -> tuple[int, Optional[BaseSchedulerNode]]:
max_step: int = -1
max_step_snode: Optional[BaseSchedulerNode] = None
succ_nodes = buf.mpi_buffer.succ_nodes
if succ_nodes:
for succ_node in succ_nodes:
step = node_to_step[succ_node]
if step > max_step:
max_step = step
max_step_snode = succ_node
assert max_step_snode is not None
return max_step, max_step_snode
# 1. for freeable input buffers # 1. for freeable input buffers
for buf_name, input_buf in name_to_freeable_input_buf.items(): for buf_name, input_buf in name_to_freeable_input_buf.items():
end_step = ( end_step = -1
len(nodes) - 1 if buf_name not in graph_outputs:
if buf_name in graph_outputs end_step, end_step_snode = _get_end_step_and_snode(input_buf)
else max( assert end_step_snode is not None
node_to_step[succ_node] for succ_node in input_buf.mpi_buffer.succ_nodes buf_to_snode_last_use[input_buf] = end_step_snode
)
)
buf_info_list.append( buf_info_list.append(
BufferInfo( BufferInfo(
input_buf, input_buf,
@ -354,17 +376,17 @@ def compute_memory_timeline(
# note: it is possible for a non-graph-output sched_buf to have no succ_nodes and # note: it is possible for a non-graph-output sched_buf to have no succ_nodes and
# to be only used by its defining op (e.g., due to fusion when all consumers of # to be only used by its defining op (e.g., due to fusion when all consumers of
# the buffer are fused with its defining op). In such cases, end_step is step. # the buffer are fused with its defining op). In such cases, end_step is step.
end_step = ( buf_name = sched_buf.get_name()
len(nodes) - 1 end_step = -1
if sched_buf.get_name() in graph_outputs if buf_name not in graph_outputs:
else max( end_step, end_step_snode = _get_end_step_and_snode(sched_buf)
[ if end_step == -1:
node_to_step[succ_node] end_step = step
for succ_node in sched_buf.mpi_buffer.succ_nodes buf_to_snode_last_use[sched_buf] = node
], else:
default=step, assert end_step_snode is not None
) buf_to_snode_last_use[sched_buf] = end_step_snode
)
buf_info_list.append( buf_info_list.append(
BufferInfo( BufferInfo(
sched_buf, sched_buf,
@ -375,7 +397,7 @@ def compute_memory_timeline(
) )
) )
return buf_info_list, node_to_step return buf_info_list, node_to_step, buf_to_snode_last_use
def estimate_peak_memory( def estimate_peak_memory(
@ -392,7 +414,7 @@ def estimate_peak_memory(
List[int]: memory usage at each node (or each step). List[int]: memory usage at each node (or each step).
""" """
buf_info_list, _ = compute_memory_timeline( buf_info_list, _, _ = compute_memory_timeline(
nodes, name_to_freeable_input_buf, graph_outputs nodes, name_to_freeable_input_buf, graph_outputs
) )
@ -416,6 +438,73 @@ def estimate_peak_memory(
return (max_memory, memories_at_nodes) return (max_memory, memories_at_nodes)
@dataclasses.dataclass
class SNodeMemory:
size_alloc: int
size_free: int
def estimate_peak_memory_allocfree(
nodes: list[BaseSchedulerNode],
name_to_freeable_input_buf: dict[str, FreeableInputBuffer],
graph_outputs: OrderedSet[str],
) -> tuple[
int,
list[tuple[int, int]],
dict[BaseSchedulerNode, SNodeMemory],
dict[Union[FreeableInputBuffer, SchedulerBuffer], BaseSchedulerNode],
]:
"""
Alternative version of estimate_peak_memory, that respects the fact,
that every SchedulerNode has multiple phases:
1. alloc ( outputs )
2. run_kernel
3. dealloc last_use buffers
estimate_peak_memory collapses memory into one value: size_alloc - size_free
While peak memory happens after alloc.
Duplicating the code to not migrate all callsites at once,
In future usages of estimate_peak_memory will migrate to this version.
"""
buf_info_list, _, buf_to_snode_last_use = compute_memory_timeline(
nodes, name_to_freeable_input_buf, graph_outputs
)
# incremental memory changes at each step
step_idx_allocfree = [SNodeMemory(0, 0) for _ in range(len(nodes))]
# for each buffer, update memory when created and when freed
for buf_info in buf_info_list:
step_idx_allocfree[buf_info.start_step].size_alloc += buf_info.size_alloc
if buf_info.end_step != -1:
step_idx_allocfree[buf_info.end_step].size_free += buf_info.size_free
snodes_allocfree = {}
for i, node in enumerate(nodes):
snodes_allocfree[node] = step_idx_allocfree[i]
max_memory = 0
cur_memory = 0
snodes_curr_memory = []
for t in range(len(nodes)):
alloc = step_idx_allocfree[t].size_alloc
free = step_idx_allocfree[t].size_free
cur_memory += alloc
post_alloc = cur_memory
max_memory = max(max_memory, cur_memory)
cur_memory -= free
post_free = cur_memory
snodes_curr_memory.append((post_alloc, post_free))
return (
max_memory,
snodes_curr_memory,
snodes_allocfree,
buf_to_snode_last_use,
)
def topological_sort_lpmf( def topological_sort_lpmf(
nodes: list[BaseSchedulerNode], nodes: list[BaseSchedulerNode],
name_to_freeable_input_buf: dict[str, FreeableInputBuffer], name_to_freeable_input_buf: dict[str, FreeableInputBuffer],
@ -429,7 +518,7 @@ def topological_sort_lpmf(
Buffer memory optimization for video codec application modeled in Simulink Buffer memory optimization for video codec application modeled in Simulink
https://www.cs.york.ac.uk/rts/docs/DAC-1964-2006/PAPERS/2006/DAC06/PDFFILES/P0689.PDF https://www.cs.york.ac.uk/rts/docs/DAC-1964-2006/PAPERS/2006/DAC06/PDFFILES/P0689.PDF
The algorithm maintain the max memory so far. The algorithm maintains the max memory so far.
At every iteration, for each scheduleable node, it computes: At every iteration, for each scheduleable node, it computes:
- how much memory needs to be allocated for the output buffers of this node; - how much memory needs to be allocated for the output buffers of this node;
- how much memory can be freed as a result of executing this node. - how much memory can be freed as a result of executing this node.


@ -2160,6 +2160,12 @@ class Scheduler:
OrderedSet(V.graph.get_output_names()), OrderedSet(V.graph.get_output_names()),
) )
if config.reorder_for_compute_comm_overlap: if config.reorder_for_compute_comm_overlap:
if not config.reorder_for_peak_memory:
from .memory import assign_memory_planning_info_for_scheduler_buffers
assign_memory_planning_info_for_scheduler_buffers(
self.nodes, self.name_to_buf
)
from torch._logging import trace_structured from torch._logging import trace_structured
trace_structured( trace_structured(
@ -2556,7 +2562,7 @@ class Scheduler:
) )
graph_outputs: OrderedSet[str] = OrderedSet(V.graph.get_output_names()) graph_outputs: OrderedSet[str] = OrderedSet(V.graph.get_output_names())
buf_info_list, _ = compute_memory_timeline( buf_info_list, _, _ = compute_memory_timeline(
self.nodes, self.nodes,
name_to_freeable_input_buf, name_to_freeable_input_buf,
graph_outputs, graph_outputs,