[inductor] Estimate peak memory allocfree and apply it to reordering collectives (#160113)

1. Applying @eellison's idea from https://github.com/pytorch/pytorch/pull/146562#discussion_r2059363672 to `estimate_peak_memory`:
```
    """
    Alternative version of estimate_peak_memory that respects the fact
    that every SchedulerNode has multiple phases:
    1. alloc (outputs)
    2. run_kernel
    3. dealloc last_use buffers
    estimate_peak_memory collapses these into one value per node (size_alloc - size_free),
    while the peak actually happens right after alloc.

    The code is duplicated to avoid migrating all call sites at once;
    in the future, users of estimate_peak_memory will migrate to this version.
    """
```

- Applying this in the `reorder_communication_preserving_peak_memory` pass; a minimal sketch of the idea follows.
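
For context, a minimal sketch of the allocfree-style estimate (simplified names; the real `estimate_peak_memory_allocfree` in the diff below additionally returns the per-node alloc/free sizes and the snode that last uses each buffer):
```
# Minimal sketch, not the exact Inductor implementation: the peak is checked
# after each node's allocations and before its frees.
from dataclasses import dataclass


@dataclass
class SNodeMemory:
    size_alloc: int  # bytes allocated for this node's outputs
    size_free: int   # bytes freed because this node is their last use


def estimate_peak_allocfree(
    allocfree: list[SNodeMemory],
) -> tuple[int, list[tuple[int, int]]]:
    peak = cur = 0
    per_node: list[tuple[int, int]] = []  # (post_alloc, post_free) at each node
    for m in allocfree:
        cur += m.size_alloc           # phase 1: alloc (outputs)
        post_alloc = cur
        peak = max(peak, post_alloc)  # peak is observed here, before dealloc
        cur -= m.size_free            # phase 3: dealloc last_use buffers
        per_node.append((post_alloc, cur))
    return peak, per_node


# A node that allocates 4 and frees 6 still spikes memory by 4 before its frees:
# estimate_peak_allocfree([SNodeMemory(10, 0), SNodeMemory(4, 6)]) == (14, [(10, 10), (14, 8)])
```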

2. Reordering can change a buffer's deallocation point: if both the candidate and the group being swapped are users of the same `f_input_buf` and the group contains its `last_use_snode`, the swap moves the buffer's last use to the candidate.

- Addressing this by tracking the `last_use_snode` for each buffer and recomputing current memory to reflect the change in `size_free`: after reordering, the group node is no longer the last user of the buffer (its `size_free -= buffer_size`), while the candidate becomes the last user (`candidate.size_free += buffer_size`); see the sketch below.
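
A minimal sketch of that re-attribution (hypothetical helper with simplified arguments; the actual pass also recomputes the cached `(post_alloc, post_free)` pair for every affected group node and for the candidate):
```
# Sketch of the bookkeeping from point 2: if swapping [candidate, group] ->
# [group, candidate] moves a buffer's last use from a group node to the candidate,
# re-attribute its deallocation before recomputing current memory.
def move_last_use_to_candidate(
    buf, buf_size, group_node, candidate, snodes_allocfree, buf_to_snode_last_use
):
    # after the swap the group node no longer deallocates this buffer
    snodes_allocfree[group_node].size_free -= buf_size
    # the candidate, now scheduled after the group, becomes the last user and frees it
    snodes_allocfree[candidate].size_free += buf_size
    buf_to_snode_last_use[buf] = candidate
```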

3. Adding the env var `PYTORCH_REORDER_COLLECTIVES_LIMIT`, which limits the number of collectives to reorder, for ablation studies.
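
For example (a hedged sketch; both variables are read when `torch._inductor.config` is imported, per the config change in the diff below, so they must be set before that import):
```
import os

# Hypothetical ablation setup: reorder at most 8 collectives and sink at most 8 waits.
os.environ["PYTORCH_REORDER_COLLECTIVES_LIMIT"] = "8"
os.environ["PYTORCH_SINK_WAITS_LIMIT"] = "8"  # analogous limit for sink_waits_iterative
```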

State after this PR:

The iteratively recomputed memory estimates match a full recomputation of the memory estimation.

Active memory does not regress much, but reserved memory regresses significantly.

Investigation and a fix for the reserved-memory regression will follow in subsequent PRs.

BASELINE (bucketing AG and RS): active: 32GiB, reserved: 34GiB
```
[rank0]:[titan] 2025-08-11 11:28:36,798 - root - INFO - step:  1  loss: 12.2722  grad_norm:  4.2192  active_memory: 24.66GiB(25.96%)  reserved_memory: 25.38GiB(26.72%)  tps: 99  tflops: 5.71  mfu: 0.58%
[rank0]:[titan] 2025-08-11 11:28:38,640 - root - INFO - step:  2  loss: 13.1738  grad_norm: 50.5566  active_memory: 32.14GiB(33.83%)  reserved_memory: 34.21GiB(36.01%)  tps: 4,448  tflops: 257.63  mfu: 26.05%
[rank0]:[titan] 2025-08-11 11:28:40,029 - root - INFO - step:  3  loss: 15.6866  grad_norm: 80.0862  active_memory: 32.14GiB(33.83%)  reserved_memory: 34.21GiB(36.01%)  tps: 5,900  tflops: 341.72  mfu: 34.55%
[rank0]:[titan] 2025-08-11 11:28:41,423 - root - INFO - step:  4  loss: 13.4853  grad_norm:  7.8538  active_memory: 32.14GiB(33.83%)  reserved_memory: 34.21GiB(36.01%)  tps: 5,881  tflops: 340.57  mfu: 34.44%
[rank0]:[titan] 2025-08-11 11:28:42,820 - root - INFO - step:  5  loss: 16.1191  grad_norm: 53.2481  active_memory: 32.14GiB(33.83%)  reserved_memory: 34.21GiB(36.01%)  tps: 5,867  tflops: 339.77  mfu: 34.35%
```
REORDER: active: 32GiB, reserved: 36GiB
```
[rank0]:[titan] 2025-08-11 11:34:32,772 - root - INFO - step:  1  loss: 12.2490  grad_norm:  4.1944  active_memory: 24.66GiB(25.96%)  reserved_memory: 26.81GiB(28.22%)  tps: 85  tflops: 4.90  mfu: 0.50%
[rank0]:[titan] 2025-08-11 11:34:35,329 - root - INFO - step:  2  loss: 13.1427  grad_norm: 39.5942  active_memory: 32.14GiB(33.83%)  reserved_memory: 36.40GiB(38.31%)  tps: 3,205  tflops: 185.61  mfu: 18.77%
[rank0]:[titan] 2025-08-11 11:34:36,770 - root - INFO - step:  3  loss: 14.6084  grad_norm: 51.0743  active_memory: 32.14GiB(33.83%)  reserved_memory: 36.40GiB(38.31%)  tps: 5,688  tflops: 329.44  mfu: 33.31%
[rank0]:[titan] 2025-08-11 11:34:38,197 - root - INFO - step:  4  loss: 13.6181  grad_norm:  8.1122  active_memory: 32.14GiB(33.83%)  reserved_memory: 36.40GiB(38.31%)  tps: 5,744  tflops: 332.68  mfu: 33.64%
[rank0]:[titan] 2025-08-11 11:34:39,821 - root - INFO - step:  5  loss: 15.8913  grad_norm: 59.8510  active_memory: 32.14GiB(33.83%)  reserved_memory: 36.40GiB(38.31%)  tps: 5,046  tflops: 292.22  mfu: 29.55%
```

REORDER + SINK_WAITS_ITERATIVE: active: 35GiB, reserved: 41GiB
```
[rank0]:[titan] 2025-08-11 11:31:36,119 - root - INFO - step:  1  loss: 12.2646  grad_norm:  4.1282  active_memory: 27.60GiB(29.05%)  reserved_memory: 32.49GiB(34.20%)  tps: 173  tflops: 10.00  mfu: 1.01%
[rank0]:[titan] 2025-08-11 11:31:37,452 - root - INFO - step:  2  loss: 13.2353  grad_norm: 42.4234  active_memory: 35.08GiB(36.92%)  reserved_memory: 41.62GiB(43.80%)  tps: 6,152  tflops: 356.26  mfu: 36.02%
[rank0]:[titan] 2025-08-11 11:31:38,780 - root - INFO - step:  3  loss: 13.8205  grad_norm: 24.0156  active_memory: 35.08GiB(36.92%)  reserved_memory: 41.62GiB(43.80%)  tps: 6,169  tflops: 357.29  mfu: 36.13%
[rank0]:[titan] 2025-08-11 11:31:40,106 - root - INFO - step:  4  loss: 13.1033  grad_norm:  9.1167  active_memory: 35.08GiB(36.92%)  reserved_memory: 41.62GiB(43.80%)  tps: 6,183  tflops: 358.10  mfu: 36.21%
[rank0]:[titan] 2025-08-11 11:31:41,443 - root - INFO - step:  5  loss: 16.3530  grad_norm: 51.8118  active_memory: 35.08GiB(36.92%)  reserved_memory: 41.62GiB(43.80%)  tps: 6,130  tflops: 355.03  mfu: 35.90%
```

Differential Revision: [D80718143](https://our.internmc.facebook.com/intern/diff/D80718143)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160113
Approved by: https://github.com/wconstab, https://github.com/eellison

Co-authored-by: eellison <elias.ellison@gmail.com>
Author: IvanKobzarev
Date: 2025-08-22 00:47:04 -07:00
Committed by: PyTorch MergeBot
Parent: 639b8cc51d
Commit: db44de4c0d
5 changed files with 749 additions and 180 deletions

View File

@ -4,6 +4,7 @@ from __future__ import annotations
import heapq
import importlib
import itertools
import logging
import operator
import sys
@ -23,8 +24,15 @@ from .dependencies import WeakDep
if TYPE_CHECKING:
from .ir import IRNode, Operation
from .scheduler import SchedulerBuffer
from .memory import estimate_peak_memory, FreeableInputBuffer, get_freeable_input_buf
from .memory import (
estimate_peak_memory,
estimate_peak_memory_allocfree,
FreeableInputBuffer,
get_freeable_input_buf,
SNodeMemory,
)
from .utils import (
contains_collective,
contains_wait,
@ -188,6 +196,46 @@ def _is_fake_dep(d):
return isinstance(d, WeakDep) and d.is_fake
def _group_names(gns: list[BaseSchedulerNode]) -> str:
return "~".join([gn.get_name() for gn in gns])
def _initialize_memory_tracking(snodes, graph_inputs, graph_outputs):
"""Initialize memory tracking data structures"""
name_to_freeable_input_buf = get_freeable_input_buf(snodes, graph_inputs)
peak_memory, snodes_curr_memory, snodes_allocfree, buf_to_snode_last_use = (
estimate_peak_memory_allocfree(
snodes, name_to_freeable_input_buf, graph_outputs
)
)
_curr_memory = dict(zip(snodes, snodes_curr_memory))
_curr_memory[None] = (0, 0)
return (
peak_memory,
_curr_memory,
snodes_allocfree,
buf_to_snode_last_use,
name_to_freeable_input_buf,
)
def _initialize_double_linked_list(
snodes: list[BaseSchedulerNode],
) -> tuple[
dict[BaseSchedulerNode, Optional[BaseSchedulerNode]],
dict[BaseSchedulerNode, Optional[BaseSchedulerNode]],
BaseSchedulerNode,
]:
"""Create double-linked list structure from snodes"""
_prev = {}
_next = {}
for i, snode in enumerate(snodes):
_prev[snode] = snodes[i - 1] if i > 0 else None
_next[snode] = snodes[i + 1] if i < len(snodes) - 1 else None
_head = snodes[0]
return _prev, _next, _head
def _reorder_communication_preserving_peak_memory_internal(
snodes: list[BaseSchedulerNode],
) -> tuple[list[BaseSchedulerNode], dict[BaseSchedulerNode, ReorderInfo]]:
@ -211,20 +259,22 @@ def _reorder_communication_preserving_peak_memory_internal(
# heuristic to avoid degenerating to quadratic time
graph_inputs: OrderedSet[str] = OrderedSet(V.graph.graph_inputs.keys())
graph_outputs: OrderedSet[str] = OrderedSet(V.graph.get_output_names())
name_to_freeable_input_buf: dict[str, FreeableInputBuffer] = get_freeable_input_buf(
snodes, graph_inputs
)
peak_memory, curr_memory = estimate_peak_memory(
snodes, name_to_freeable_input_buf, graph_outputs
)
runtimes = {snode: estimate_op_runtime(snode) for snode in snodes}
_curr_memory = dict(zip(snodes, curr_memory))
_curr_memory[None] = 0 # type: ignore[index]
(
peak_memory,
_curr_memory,
snodes_allocfree,
buf_to_snode_last_use,
name_to_freeable_input_buf,
) = _initialize_memory_tracking(snodes, graph_inputs, graph_outputs)
runtimes: dict[BaseSchedulerNode, float] = {
snode: estimate_op_runtime(snode) for snode in snodes
}
# debug stats
stats: dict[BaseSchedulerNode, ReorderInfo] = {}
def exposed_communication_time(collective_snode, remaining_snodes):
def exposed_communication_time(
collective_snode: BaseSchedulerNode, remaining_snodes: list[BaseSchedulerNode]
) -> float:
# assumes a linear schedule and computes the overlap of the collective with the remaining nodes
comm_time = estimate_op_runtime(collective_snode)
compute_time = 0.0
@ -236,7 +286,7 @@ def _reorder_communication_preserving_peak_memory_internal(
# we can ignore it. Otherwise, it's the end of the road for overlap opportunities
break
def accumulate_time(_snode):
def accumulate_time(_snode: BaseSchedulerNode) -> None:
nonlocal compute_time
compute_time += runtimes[_snode]
@ -245,18 +295,11 @@ def _reorder_communication_preserving_peak_memory_internal(
total_moves = 0
# Dicts to keep track of "next" and "previous" as double-linked structure during grouping
_prev: dict[Optional[BaseSchedulerNode], Optional[BaseSchedulerNode]] = {}
_next: dict[Optional[BaseSchedulerNode], Optional[BaseSchedulerNode]] = {}
for i, snode in enumerate(snodes):
_prev[snode] = snodes[i - 1] if i > 0 else None
_next[snode] = snodes[i + 1] if i < len(snodes) - 1 else None
_curr_memory = dict(zip(snodes, curr_memory))
_curr_memory[None] = 0 # type: ignore[index]
_prev, _next, _head = _initialize_double_linked_list(snodes)
_head = snodes[0]
def _group_nodes(head, tail):
def _group_nodes(
head: Optional[BaseSchedulerNode], tail: Optional[BaseSchedulerNode]
) -> list[BaseSchedulerNode]:
ret = []
n = head
while True:
@ -264,37 +307,167 @@ def _reorder_communication_preserving_peak_memory_internal(
ret.append(n)
if n == tail:
break
n = _next[n]
n = _next[n] # type: ignore[index]
return ret
def _group_names(head, tail):
ret = ""
for n in _group_nodes(head, tail):
if ret:
ret += "~"
ret += n.get_name()
return ret
def _perform_double_linked_list_swap(candidate, group_head, group_tail):
# swap (candidate, group_head...group_tail)
# Before:
# candidate_prev -0-> candidate -1-> group_head...group_tail -2-> group_tail_next
# After:
# candidate_prev -0-> group_head...group_tail -1-> candidate -2-> group_tail_next
# 0
candidate_prev = _prev[candidate]
if candidate_prev:
_next[candidate_prev] = group_head
_prev[group_head] = candidate_prev
# 2
group_tail_next = _next[group_tail]
if group_tail_next:
_prev[group_tail_next] = candidate
_next[candidate] = group_tail_next
# 1
_prev[candidate] = group_tail
_next[group_tail] = candidate
nonlocal _head
if _head == candidate:
_head = group_head
def _calculate_potential_peak_memory(
candidate, group_ns, group_n_to_bufs_after_swap_dealloc_by_candidate
):
# Caching calculations of memory for group nodes and candidate,
# to apply without recalculation after swap.
_post_alloc_update: dict[BaseSchedulerNode, int] = {}
potential_peak: int = 0
if not group_n_to_bufs_after_swap_dealloc_by_candidate:
# Not accounting for buffers last use change
potential_peak = max(
group_peak_memory - candidate_delta_mem,
_curr_memory[group_tail][1]
- candidate_delta_mem
+ candidate_allocfree.size_alloc,
)
return potential_peak, _post_alloc_update
# If candidate will be after group, the starting memory level of group nodes
# changes to the -(candidate.size_alloc - candidate.size_free)
mem_after_reorder_delta: int = -candidate_delta_mem
for gn in gns:
gn_post_alloc_mem = _curr_memory[gn][0] + mem_after_reorder_delta
_post_alloc_update[gn] = gn_post_alloc_mem
potential_peak = max(potential_peak, gn_post_alloc_mem)
bufs = group_n_to_bufs_after_swap_dealloc_by_candidate.get(gn, None)
if bufs is not None:
for buf in bufs:
# Candidate will deallocate those buffers
mem_after_reorder_delta += buf.mpi_buffer.size_free
candidate_mem_post_alloc = (
_curr_memory[group_tail][1]
+ mem_after_reorder_delta
+ candidate_allocfree.size_alloc
)
_post_alloc_update[candidate] = candidate_mem_post_alloc
potential_peak = max(potential_peak, candidate_mem_post_alloc)
return potential_peak, _post_alloc_update
def _update_memory_tracking_after_swap(
candidate,
gns,
group_n_to_bufs_after_swap_dealloc_by_candidate,
_post_alloc_update,
):
if not group_n_to_bufs_after_swap_dealloc_by_candidate:
for gn in gns:
cm = _curr_memory[gn]
_curr_memory[gn] = (
cm[0] - candidate_delta_mem,
cm[1] - candidate_delta_mem,
)
_candidate_post_alloc_mem = (
_curr_memory[group_tail][1] + candidate_allocfree.size_alloc
)
_candidate_post_free_mem = (
_candidate_post_alloc_mem - candidate_allocfree.size_free
)
_curr_memory[candidate] = (
_candidate_post_alloc_mem,
_candidate_post_free_mem,
)
return
# Candidate becomes last use of some bufs
for (
gn,
bufs,
) in group_n_to_bufs_after_swap_dealloc_by_candidate.items():
for buf in bufs:
buf_to_snode_last_use[buf] = candidate
size_free_to_move_to_candidate_sum: int = 0
for n in gns:
_gn_post_alloc_mem: int = _post_alloc_update[n]
size_free_to_move_to_candidate: int = sum(
buf.mpi_buffer.size_free
for buf in group_n_to_bufs_after_swap_dealloc_by_candidate[n]
)
size_free_to_move_to_candidate_sum += size_free_to_move_to_candidate
# group node does not deallocate this after swap
snodes_allocfree[n].size_free -= size_free_to_move_to_candidate
gn_post_free_mem: int = _gn_post_alloc_mem - snodes_allocfree[n].size_free
_curr_memory[n] = (_gn_post_alloc_mem, gn_post_free_mem)
_candidate_post_alloc_mem = _post_alloc_update[candidate]
snodes_allocfree[candidate].size_free += size_free_to_move_to_candidate_sum
candidate_post_free_mem = (
_candidate_post_alloc_mem - snodes_allocfree[candidate].size_free
)
_curr_memory[candidate] = (
_candidate_post_alloc_mem,
candidate_post_free_mem,
)
debug_num_collectives_to_reorder: Optional[int] = (
config.reorder_iterative_debug_limit_to_reorder
)
num_processed_collectives: int = 0
curr = _head
debug_iterative_memory_recompute = config.reorder_iterative_debug_memory_recompute
iterative_recompute_error = False
while _next[curr] is not None:
if iterative_recompute_error:
break
if contains_collective(curr):
reorder_info = stats[curr] = ReorderInfo()
reorder_info.initial_exposed = reorder_info.final_exposed = (
exposed_communication_time(curr, _group_nodes(_next[curr], None))
if debug_num_collectives_to_reorder is not None and (
num_processed_collectives >= debug_num_collectives_to_reorder
):
break
num_processed_collectives += 1
info = stats[curr] = ReorderInfo()
info.initial_exposed = info.final_exposed = exposed_communication_time(
curr, _group_nodes(_next[curr], None)
)
candidate = _prev[curr]
group_head = curr
group_tail = curr
group_peak_memory = _curr_memory[curr]
group_peak_memory = _curr_memory[curr][0] # post_alloc memory
while candidate is not None:
if contains_collective(candidate):
reorder_info.limiting_factor = "collective ordering"
info.limiting_factor = "collective ordering"
break
gns: list[BaseSchedulerNode] = _group_nodes(group_head, group_tail)
group = GroupedSchedulerNode(
curr.scheduler,
_group_nodes(group_head, group_tail),
gns,
temp_grouping=True,
)
@ -314,7 +487,9 @@ def _reorder_communication_preserving_peak_memory_internal(
if data_dep is not None:
def is_groupable(candidate):
def is_groupable(
candidate: BaseSchedulerNode,
) -> tuple[bool, Optional[str]]:
# preserve ordering
if contains_collective(candidate):
return False, "contains_collective"
@ -323,73 +498,106 @@ def _reorder_communication_preserving_peak_memory_internal(
return False, "contains_gemm_like"
return True, None
is_grp, grp_reason = is_groupable(candidate)
if is_grp:
is_groupable_result, grouping_reason = is_groupable(candidate)
if is_groupable_result:
group_head = candidate
group_peak_memory = max(
group_peak_memory, _curr_memory[candidate]
group_peak_memory, _curr_memory[candidate][0]
)
reorder_info.grouped += 1
reorder_info.grouped_info = _group_names(group_head, group_tail)
info.grouped += 1
info.grouped_info = _group_names(gns)
candidate = _prev[candidate]
continue
else:
msg = (
f"data dependency {data_dep}(dep_names:{list(data_deps.keys())})"
f"\n candidate:{candidate.get_name()}(os:{[candidate.get_buffer_names()]})"
f"dep on {_group_names(group_head, group_tail)}"
f"\n non_group_reason:{grp_reason}"
f"\n candidate:{candidate.get_name()}(outs:{[candidate.get_buffer_names()]})"
f"dep on {_group_names(gns)}"
f"\n non_group_reason:{grouping_reason}"
)
reorder_info.limiting_factor = msg
info.limiting_factor = msg
break
delta_memory_candidate = (
_curr_memory[candidate] - _curr_memory[_prev[candidate]] # type: ignore[index]
candidate_allocfree: SNodeMemory = snodes_allocfree[candidate]
candidate_delta_mem: int = (
candidate_allocfree.size_alloc - candidate_allocfree.size_free
)
# candidate and one of group nodes are successors of the same buffer
# and last use of the buffer happen in group nodes.
# This last use deallocates it.
# If we swap [candidate [group]] to [[group] candidate],
# candidate becomes the last use
# and deallocated this buffer instead of group node.
# we need to update size_free accordingly to group_node and candidate,
# and recalculate post_alloc, post_free for them.
#
# Buf that changes its last use snode,
# after swap will be deallocated only by candidate,
# while before it was deallocated by group node.
group_n_to_bufs_after_swap_dealloc_by_candidate: dict[
BaseSchedulerNode, list[Union[FreeableInputBuffer, Any]]
] = defaultdict(list)
for (
buf,
snode_last_use,
) in buf_to_snode_last_use.items():
succ_nodes = buf.mpi_buffer.succ_nodes
if candidate not in succ_nodes:
continue
if not any(gn == snode_last_use for gn in gns):
continue
group_n_to_bufs_after_swap_dealloc_by_candidate[
snode_last_use
].append(buf)
potential_peak, _post_alloc_update = _calculate_potential_peak_memory(
candidate, gns, group_n_to_bufs_after_swap_dealloc_by_candidate
)
if group_peak_memory - delta_memory_candidate > peak_memory:
reorder_info.limiting_factor = "peak memory"
if potential_peak > peak_memory:
info.limiting_factor = (
f"peak memory new:{potential_peak} vs base:{peak_memory}"
)
break
reorder_info.moves += 1
info.moves += 1
total_moves += 1
mem_deltas = {}
for n in [candidate, *_group_nodes(group_head, group_tail)]:
mem_deltas[n] = _curr_memory[n] - _curr_memory[_prev[n]] # type: ignore[index]
# swap (candidate, group_head...group_tail)
# Before:
# candidate_prev -0-> candidate -1-> group_head...group_tail -2-> group_tail_next
# After:
# candidate_prev -0-> group_head...group_tail -1-> candidate -2-> group_tail_next
# 0
candidate_prev = _prev[candidate]
if candidate_prev:
_next[candidate_prev] = group_head
_prev[group_head] = candidate_prev
_perform_double_linked_list_swap(candidate, group_head, group_tail)
# 2
group_tail_next = _next[group_tail]
if group_tail_next:
_prev[group_tail_next] = candidate
_next[candidate] = group_tail_next
# 1
_prev[candidate] = group_tail
_next[group_tail] = candidate
if _head == candidate:
_head = group_head
reorder_info.final_exposed = exposed_communication_time(
info.final_exposed = exposed_communication_time(
curr, _group_nodes(_next[curr], None)
)
# Recompute curr_memory
_prev_curr_memory = _curr_memory[_prev[group_head]] # type: ignore[index]
for n in _group_nodes(group_head, candidate):
_curr_memory[n] = _prev_curr_memory = (
_prev_curr_memory + mem_deltas[n]
_update_memory_tracking_after_swap(
candidate,
gns,
group_n_to_bufs_after_swap_dealloc_by_candidate,
_post_alloc_update,
)
if debug_iterative_memory_recompute:
# Compare iteratively recomputed memory data
# with full run of estimate_peak_memory
from .comms_debug import _debug_iterative_memory_recompute
iterative_recompute_error = _debug_iterative_memory_recompute(
candidate,
gns,
_group_names(gns),
_group_nodes(_head, None),
name_to_freeable_input_buf,
graph_outputs,
peak_memory,
_curr_memory,
snodes_allocfree,
"reorder_communication_preserving_peak_memory",
group_n_to_bufs_after_swap_dealloc_by_candidate,
)
if iterative_recompute_error:
break
candidate = _prev[group_head]
curr = _next[curr] # type: ignore[assignment]
@ -415,15 +623,15 @@ def _reorder_communication_preserving_peak_memory_internal(
rows = [
[
node_summary(snode),
node_reorder_info.initial_exposed,
node_reorder_info.final_exposed,
node_reorder_info.improvement,
node_reorder_info.limiting_factor,
node_reorder_info.moves,
node_reorder_info.grouped,
node_reorder_info.grouped_info,
node_info.initial_exposed,
node_info.final_exposed,
node_info.improvement,
node_info.limiting_factor,
node_info.moves,
node_info.grouped,
node_info.grouped_info,
]
for snode, node_reorder_info in node_stats.items()
for snode, node_info in node_stats.items()
]
if importlib.util.find_spec("tabulate"):
from tabulate import tabulate
@ -441,7 +649,7 @@ def _reorder_communication_preserving_peak_memory_internal(
new_snodes = _group_nodes(_head, None)
assert len(new_snodes) == original_snodes_num
new_peak_memory, curr_memory = estimate_peak_memory(
new_peak_memory, _, _, _ = estimate_peak_memory_allocfree(
new_snodes, name_to_freeable_input_buf, graph_outputs
)
reorder_log_str += f"\n peak_memory_before:{peak_memory}"
@ -657,24 +865,21 @@ def _sink_waits_iterative_internal(
return snodes, {}
graph_inputs: OrderedSet[str] = OrderedSet(V.graph.graph_inputs.keys())
graph_outputs: OrderedSet[str] = OrderedSet(V.graph.get_output_names())
name_to_freeable_input_buf: dict[str, FreeableInputBuffer] = get_freeable_input_buf(
snodes, graph_inputs
)
peak_memory, curr_memory = estimate_peak_memory(
snodes, name_to_freeable_input_buf, graph_outputs
)
(
peak_memory,
_curr_memory,
snodes_allocfree,
buf_to_snode_last_use,
name_to_freeable_input_buf,
) = _initialize_memory_tracking(snodes, graph_inputs, graph_outputs)
_prev, _next, _head = _initialize_double_linked_list(snodes)
stats: dict[BaseSchedulerNode, SinkWaitInfo] = {}
_prev: dict[Optional[BaseSchedulerNode], Optional[BaseSchedulerNode]] = {}
_next: dict[Optional[BaseSchedulerNode], Optional[BaseSchedulerNode]] = {}
_head = snodes[0]
for i, snode in enumerate(snodes):
_prev[snode] = snodes[i - 1] if i > 0 else None
_next[snode] = snodes[i + 1] if i < len(snodes) - 1 else None
_curr_memory = dict(zip(snodes, curr_memory))
_curr_memory[None] = 0 # type: ignore[index]
def _group_nodes(head, tail):
def _group_nodes(
head: Optional[BaseSchedulerNode], tail: Optional[BaseSchedulerNode]
) -> list[BaseSchedulerNode]:
ret = []
n = head
while True:
@ -682,21 +887,125 @@ def _sink_waits_iterative_internal(
ret.append(n)
if n == tail:
break
n = _next[n]
n = _next[n] # type: ignore[index]
return ret
def _group_names(head, tail):
ret = ""
for n in _group_nodes(head, tail):
if ret:
ret += "~"
ret += n.get_name()
return ret
def _calculate_potential_peak_memory(
candidate, group_ns, group_n_to_bufs_after_swap_dealloc_instead_of_candidate
):
pre_group_mem = (
_curr_memory[group_head][0] - snodes_allocfree[group_head].size_alloc
)
# Stash memory tracing updates to not recompute them after swap
_post_alloc_update: dict[BaseSchedulerNode, int] = {}
_size_free_delta_update: dict[BaseSchedulerNode, int] = {}
potential_peak = 0
if not group_n_to_bufs_after_swap_dealloc_instead_of_candidate:
# Not accounting for buffers liveliness change
potential_peak = max(
group_peak_memory + candidate_delta_mem,
pre_group_mem + candidate_allocfree.size_alloc,
)
return potential_peak, _post_alloc_update, _size_free_delta_update
candidate_post_alloc = pre_group_mem + candidate_allocfree.size_alloc
_post_alloc_update[candidate] = candidate_post_alloc
potential_peak = candidate_post_alloc
candidate_size_free_to_move = sum(
buf.mpi_buffer.size_free # type: ignore[attr-defined]
for buf in itertools.chain.from_iterable(
group_n_to_bufs_after_swap_dealloc_instead_of_candidate.values()
)
)
_size_free_delta_update[candidate] = -candidate_size_free_to_move
delta_mem = candidate_delta_mem + candidate_size_free_to_move
for gn in gns:
gn_post_alloc = _curr_memory[gn][0] + delta_mem
_post_alloc_update[gn] = gn_post_alloc
potential_peak = max(potential_peak, gn_post_alloc)
gn_size_free_to_add = 0
if gn in group_n_to_bufs_after_swap_dealloc_instead_of_candidate:
bufs = group_n_to_bufs_after_swap_dealloc_instead_of_candidate[gn]
for buf in bufs:
gn_size_free_to_add += buf.mpi_buffer.size_free
_size_free_delta_update[gn] = gn_size_free_to_add
delta_mem -= gn_size_free_to_add
return potential_peak, _post_alloc_update, _size_free_delta_update
def _perform_double_linked_list_swap(candidate, group_head, group_tail):
# group_head_prev -0-> candidate -1-> group_head...group_tail -2-> candidate_next
# 0:
group_head_prev = _prev[group_head]
if group_head_prev:
_next[group_head_prev] = candidate
_prev[candidate] = group_head_prev
# 2:
candidate_next = _next[candidate]
if candidate_next:
_prev[candidate_next] = group_tail
_next[group_tail] = candidate_next
# 1:
_prev[group_head] = candidate
_next[candidate] = group_head
nonlocal _head
if group_head == _head:
_head = candidate
def _update_memory_tracking_after_swap(
candidate,
gns,
group_n_to_bufs_after_swap_dealloc_instead_of_candidate,
_post_alloc_update,
_size_free_delta_update,
):
group_head = gns[0]
pre_group_mem = (
_curr_memory[group_head][0] - snodes_allocfree[group_head].size_alloc
)
if not group_n_to_bufs_after_swap_dealloc_instead_of_candidate:
candidate_post_alloc = pre_group_mem + candidate_allocfree.size_alloc
_curr_memory[candidate] = (
candidate_post_alloc,
candidate_post_alloc - candidate_allocfree.size_free,
)
for gn in gns:
cm = _curr_memory[gn]
_curr_memory[gn] = (
cm[0] + candidate_delta_mem,
cm[1] + candidate_delta_mem,
)
return
for n in [candidate, *gns]:
post_alloc = _post_alloc_update[n]
snodes_allocfree[n].size_free += _size_free_delta_update[n]
_curr_memory[n] = (
post_alloc,
post_alloc - snodes_allocfree[n].size_free,
)
curr = snodes[-1]
processed_waits = OrderedSet() # type: ignore[var-annotated]
debug_iterative_memory_recompute = config.reorder_iterative_debug_memory_recompute
debug_num_sink_waits_to_reorder: Optional[int] = (
config.sink_waits_iterative_debug_limit_to_sink
)
iterative_recompute_error = False
while _prev[curr] is not None:
if iterative_recompute_error:
break
if (
debug_num_sink_waits_to_reorder is not None
and len(processed_waits) >= debug_num_sink_waits_to_reorder
):
break
if contains_wait(curr) and curr not in processed_waits:
processed_waits.add(curr)
info = stats[curr] = SinkWaitInfo()
@ -704,11 +1013,14 @@ def _sink_waits_iterative_internal(
wait_snode = curr
group_head = curr
group_tail = curr
group_peak_memory = _curr_memory[curr]
group_peak_memory = _curr_memory[curr][0]
while candidate is not None:
if iterative_recompute_error:
break
gns: list[BaseSchedulerNode] = _group_nodes(group_head, group_tail)
group = GroupedSchedulerNode(
wait_snode.scheduler,
_group_nodes(group_head, group_tail),
gns,
temp_grouping=True,
)
@ -753,15 +1065,15 @@ def _sink_waits_iterative_internal(
if is_grp:
group_tail = candidate
group_peak_memory = max(
group_peak_memory, _curr_memory[candidate]
group_peak_memory, _curr_memory[candidate][0]
)
info.grouped += 1
info.grouped_info = _group_names(group_head, group_tail)
info.grouped_info = _group_names(gns)
candidate = _next[candidate]
continue
elif (data_dep is None) and both_contain_comms:
info.limiting_factor = (
f"collective ordering {_group_names(group_head, group_tail)}"
f"collective ordering {_group_names(gns)}"
f" with candidate:{candidate.get_name()}"
)
break
@ -769,49 +1081,89 @@ def _sink_waits_iterative_internal(
info.limiting_factor = (
f"data dependency {data_dep}(dep_names:{list(data_deps.keys())})"
f"\n candidate:{candidate.get_name()}(os:{[candidate.get_buffer_names()]})"
f"dep on {_group_names(group_head, group_tail)}"
f"dep on {gns}"
f"\n outs:{[o.get_name() for o in group_outs]}"
f"\n non_group_reason:{grp_reason}"
)
break
candidate_delta_memory = (
_curr_memory[candidate] - _curr_memory[_prev[candidate]] # type: ignore[index]
candidate_allocfree: SNodeMemory = snodes_allocfree[candidate]
candidate_delta_mem = (
candidate_allocfree.size_alloc - candidate_allocfree.size_free
)
if group_peak_memory + candidate_delta_memory > peak_memory:
info.limiting_factor = "peak_memory"
# [group] candidate -> candidate [group]
# Check for buffers with successors in group and candidate last successor
#
# Buf that changes its last use snode,
# It was deallocated by candidate,
# but after swap it will be deallocated by group node.
group_n_to_bufs_after_swap_dealloc_instead_of_candidate: dict[
BaseSchedulerNode, list[Union[FreeableInputBuffer, SchedulerBuffer]]
] = defaultdict(list)
for (
buf,
snode_last_use,
) in buf_to_snode_last_use.items():
succ_nodes = buf.mpi_buffer.succ_nodes
if snode_last_use != candidate: # noqa: E711
continue
# candidate is last use of buf
last_succ_gn = None
for gn in gns:
if gn in succ_nodes:
last_succ_gn = gn
if last_succ_gn is None:
continue
# gn has successors of buf that after potential swap will become
# last use of buf and start deallocating buf instead of candidate
group_n_to_bufs_after_swap_dealloc_instead_of_candidate[
last_succ_gn
].append(buf)
potential_peak, _post_alloc_update, _size_free_delta_update = (
_calculate_potential_peak_memory(
candidate,
gns,
group_n_to_bufs_after_swap_dealloc_instead_of_candidate,
)
)
if potential_peak > peak_memory:
info.limiting_factor = (
f"peak memory new:{potential_peak} vs base:{peak_memory}"
)
break
info.moves += 1
info.moves_info += f"+{candidate.get_name()}"
# group_head_prev -0-> candidate -1-> group_head...group_tail -2-> candidate_next
mem_deltas = {}
for n in [candidate, *_group_nodes(group_head, group_tail)]:
mem_deltas[n] = _curr_memory[n] - _curr_memory[_prev[n]] # type: ignore[index]
# 0:
group_head_prev = _prev[group_head]
if group_head_prev:
_next[group_head_prev] = candidate
_prev[candidate] = group_head_prev
_perform_double_linked_list_swap(candidate, group_head, group_tail)
# 2:
candidate_next = _next[candidate]
if candidate_next:
_prev[candidate_next] = group_tail
_next[group_tail] = candidate_next
_update_memory_tracking_after_swap(
candidate,
gns,
group_n_to_bufs_after_swap_dealloc_instead_of_candidate,
_post_alloc_update,
_size_free_delta_update,
)
# 1:
_prev[group_head] = candidate
_next[candidate] = group_head
if group_head == _head:
_head = candidate
if debug_iterative_memory_recompute:
from .comms_debug import _debug_iterative_memory_recompute
# Recompute curr_memory
_prev_curr_memory = _curr_memory[_prev[candidate]] # type: ignore[index]
for n in _group_nodes(candidate, group_tail):
_curr_memory[n] = _prev_curr_memory = (
_prev_curr_memory + mem_deltas[n]
iterative_recompute_error = _debug_iterative_memory_recompute(
candidate,
gns,
_group_names(gns),
_group_nodes(_head, None),
name_to_freeable_input_buf,
graph_outputs,
peak_memory,
_curr_memory,
snodes_allocfree,
"sink_waits_iterative",
group_n_to_bufs_after_swap_dealloc_instead_of_candidate,
)
if iterative_recompute_error:
break
candidate = _next[group_tail]
curr = _prev[curr] # type: ignore[assignment]
@ -850,11 +1202,11 @@ def _sink_waits_iterative_internal(
overlap_log.info(log_str)
new_snodes = _group_nodes(_head, None)
assert len(new_snodes) == original_snodes_num
new_peak_memory, curr_memory = estimate_peak_memory(
new_peak_memory, _, _, _ = estimate_peak_memory_allocfree(
new_snodes, name_to_freeable_input_buf, graph_outputs
)
log_str += f"\n peak_memory_before:{peak_memory}"
log_str += f"\n peak_memory_after:{new_peak_memory}"
log_str += f"\n sink_waits_iterative peak_memory_before:{peak_memory}"
log_str += f"\n sink_waits_iterative peak_memory_after:{new_peak_memory}"
trace_structured(
"artifact",
metadata_fn=lambda: {

View File

@ -0,0 +1,112 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Union
from torch._logging import trace_structured
from .memory import estimate_peak_memory_allocfree
if TYPE_CHECKING:
from torch.utils._ordered_set import OrderedSet
from .memory import FreeableInputBuffer, SNodeMemory
from .scheduler import BaseSchedulerNode, SchedulerBuffer
def _debug_iterative_memory_recompute(
candidate: BaseSchedulerNode,
gns: list[BaseSchedulerNode],
group_names: str,
snodes: list[BaseSchedulerNode],
name_to_freeable_input_buf: dict[str, FreeableInputBuffer],
graph_outputs: OrderedSet[str],
peak_memory: int,
iter_curr_memory: dict[BaseSchedulerNode, tuple[int, int]],
snodes_allocfree: dict[BaseSchedulerNode, SNodeMemory],
tlparse_name: str,
gn_to_bufs_last_use: dict[
BaseSchedulerNode, list[Union[FreeableInputBuffer, SchedulerBuffer]]
],
) -> bool:
iterative_recompute_error = False
candidate_allocfree = snodes_allocfree[candidate]
est_peak_memory, snodes_curr_memory, snodes_allocfree, _ = (
estimate_peak_memory_allocfree(
snodes, name_to_freeable_input_buf, graph_outputs
)
)
est_curr_memory = dict(zip(snodes, snodes_curr_memory))
iter_cm = iter_curr_memory[candidate]
new_cm = est_curr_memory[candidate]
log = ""
if est_peak_memory > peak_memory:
log = "ITERATIVE PEAK DOES NOT MATCH"
iterative_recompute_error = True
if iter_cm != new_cm:
log = "ITERATIVE CURR MEMORY CANDIDATE DOES NOT MATCH"
iterative_recompute_error = True
for i, gn in enumerate(gns):
iter_gnm = iter_curr_memory[gn]
new_gnm = est_curr_memory[gn]
if iter_gnm != new_gnm:
log = f"ITERATIVE GN CURR MEMORY DOES NOT MATCH:{gn.get_name()}"
iterative_recompute_error = True
if iterative_recompute_error:
log += (
f"\nCANDIDATE:{candidate.get_name()}"
f"\nGROUP:{group_names}"
f"\nPEAK_MEMORY_BEFORE:{peak_memory}"
f"\nPEAK_MEMORY_AFTER_SWAP:{est_peak_memory}"
f"\nCANDIDATE:{candidate.debug_str()}"
f"\nCANDIDATE_ITER_CURR_MEMORY:{iter_cm}"
f"\nCANDIDATE_NEW__CURR_MEMORY:{new_cm}"
f"\nCANDIDATE_ITER_ALLOCFREE:{candidate_allocfree}"
f"\nCANDIDATE_NEW_ALLOCFREE:{snodes_allocfree[candidate]}"
)
peak_log = ""
for i, (pre, post) in enumerate(snodes_curr_memory):
if est_peak_memory == pre:
n = snodes[i]
peak_log = (
f"\nNEW_PEAK:{est_peak_memory}(BASE:{peak_memory})"
f" @ SNODE[{i}/{len(snodes)}]:{n.get_name()} {n.debug_str()}"
)
break
group_log = ""
for i, gn in enumerate(gns):
iter_gnm = iter_curr_memory[gn]
new_gnm = est_curr_memory[gn]
group_log += (
f"\nGROUP_NODE[{i}]:{gn.debug_str()}"
f"\nGROUP_NODE[{i}] ITER_GNM[{gn.get_name()}]:{iter_gnm}"
f"\nGROUP_NODE[{i}] ESTM_GNM[{gn.get_name()}]:{new_gnm}"
f"\nGROUP_NODE[{i}] ITER_allocfree:{snodes_allocfree[gn]}"
f"\nGROUP_NODE[{i}] ESTM_allocfree:{snodes_allocfree[gn]}"
)
log += peak_log
log += group_log
log += f"\nGN_TO_BUFS_LAST_USE:{gn_to_bufs_last_use}"
log += "\n\n".join(
[
(
f"\nSNODE[{i}]\n{n.debug_str()}"
f"\nITER_cur_mem:{iter_curr_memory[n]}"
f"\nESTM_cur_mem:{est_curr_memory[n]}"
f"\nITER_allocfree:{snodes_allocfree[n]}"
f"\nESTM_allocfree:{snodes_allocfree[n]}"
)
for i, n in enumerate(snodes)
]
)
tname = f"{tlparse_name}_ITERATIVE_RECOMPUTE_ERROR"
print(f"{tname}:\n{log}")
trace_structured(
"artifact",
metadata_fn=lambda: {
"name": tname,
"encoding": "string",
},
payload_fn=lambda: log,
)
return iterative_recompute_error

View File

@ -389,6 +389,16 @@ reorder_prefetch_limit: Optional[int] = None
# enable operator reordering for peak memory optimization
reorder_for_peak_memory = True
reorder_iterative_debug_memory_recompute: bool = False
reorder_iterative_debug_limit_to_reorder: Optional[int] = (
None
if (env_str := os.getenv("PYTORCH_REORDER_COLLECTIVES_LIMIT")) is None
else int(env_str)
)
sink_waits_iterative_debug_limit_to_sink: Optional[int] = (
None if (env_str := os.getenv("PYTORCH_SINK_WAITS_LIMIT")) is None else int(env_str)
)
bucket_all_gathers_fx: Literal["none", "all", "only_fsdp"] = "none"
# By default torch._inductor.fx_passes.bucketing.bucket_size_determinator is used
bucket_all_gathers_fx_bucket_size_determinator: Optional[Callable[[int], int]] = None

View File

@ -4,7 +4,7 @@ import collections
import dataclasses
import heapq
import logging
from typing import Callable, TYPE_CHECKING, TypedDict, Union
from typing import Callable, Optional, TYPE_CHECKING, TypedDict, Union
from torch._environment import is_fbcode
from torch._utils_internal import signpost_event
@ -76,7 +76,7 @@ def get_freeable_input_buf(
Create and keep track of all input buffers that can be freed during the program
Returns:
A dictionary containing all freeble input buffers, keyed by their names.
A dictionary containing all freeable input buffers, keyed by their names.
"""
def _dep_size_hint(dep: Dep) -> int:
@ -315,7 +315,11 @@ def compute_memory_timeline(
nodes: list[BaseSchedulerNode],
name_to_freeable_input_buf: dict[str, FreeableInputBuffer],
graph_outputs: OrderedSet[str],
) -> tuple[list[BufferInfo], dict[BaseSchedulerNode, int]]:
) -> tuple[
list[BufferInfo],
dict[BaseSchedulerNode, int],
dict[Union[FreeableInputBuffer, SchedulerBuffer], BaseSchedulerNode],
]:
"""
Compute buffer allocation and deallocation sizes and map their
lifetime to the node schedule
@ -329,15 +333,33 @@ def compute_memory_timeline(
# get buffers' size and liveliness information
buf_info_list: list[BufferInfo] = []
buf_to_snode_last_use: dict[
Union[FreeableInputBuffer, SchedulerBuffer], BaseSchedulerNode
] = {}
def _get_end_step_and_snode(
buf: Union[FreeableInputBuffer, SchedulerBuffer],
) -> tuple[int, Optional[BaseSchedulerNode]]:
max_step: int = -1
max_step_snode: Optional[BaseSchedulerNode] = None
succ_nodes = buf.mpi_buffer.succ_nodes
if succ_nodes:
for succ_node in succ_nodes:
step = node_to_step[succ_node]
if step > max_step:
max_step = step
max_step_snode = succ_node
assert max_step_snode is not None
return max_step, max_step_snode
# 1. for freeable input buffers
for buf_name, input_buf in name_to_freeable_input_buf.items():
end_step = (
len(nodes) - 1
if buf_name in graph_outputs
else max(
node_to_step[succ_node] for succ_node in input_buf.mpi_buffer.succ_nodes
)
)
end_step = -1
if buf_name not in graph_outputs:
end_step, end_step_snode = _get_end_step_and_snode(input_buf)
assert end_step_snode is not None
buf_to_snode_last_use[input_buf] = end_step_snode
buf_info_list.append(
BufferInfo(
input_buf,
@ -354,17 +376,17 @@ def compute_memory_timeline(
# note: it is possible for a non-graph-output sched_buf to have no succ_nodes and
# to be only used by its defining op (e.g., due to fusion when all consumers of
# the buffer are fused with its defining op). In such cases, end_step is step.
end_step = (
len(nodes) - 1
if sched_buf.get_name() in graph_outputs
else max(
[
node_to_step[succ_node]
for succ_node in sched_buf.mpi_buffer.succ_nodes
],
default=step,
)
)
buf_name = sched_buf.get_name()
end_step = -1
if buf_name not in graph_outputs:
end_step, end_step_snode = _get_end_step_and_snode(sched_buf)
if end_step == -1:
end_step = step
buf_to_snode_last_use[sched_buf] = node
else:
assert end_step_snode is not None
buf_to_snode_last_use[sched_buf] = end_step_snode
buf_info_list.append(
BufferInfo(
sched_buf,
@ -375,7 +397,7 @@ def compute_memory_timeline(
)
)
return buf_info_list, node_to_step
return buf_info_list, node_to_step, buf_to_snode_last_use
def estimate_peak_memory(
@ -392,7 +414,7 @@ def estimate_peak_memory(
List[int]: memory usage at each node (or each step).
"""
buf_info_list, _ = compute_memory_timeline(
buf_info_list, _, _ = compute_memory_timeline(
nodes, name_to_freeable_input_buf, graph_outputs
)
@ -416,6 +438,73 @@ def estimate_peak_memory(
return (max_memory, memories_at_nodes)
@dataclasses.dataclass
class SNodeMemory:
size_alloc: int
size_free: int
def estimate_peak_memory_allocfree(
nodes: list[BaseSchedulerNode],
name_to_freeable_input_buf: dict[str, FreeableInputBuffer],
graph_outputs: OrderedSet[str],
) -> tuple[
int,
list[tuple[int, int]],
dict[BaseSchedulerNode, SNodeMemory],
dict[Union[FreeableInputBuffer, SchedulerBuffer], BaseSchedulerNode],
]:
"""
Alternative version of estimate_peak_memory that respects the fact
that every SchedulerNode has multiple phases:
1. alloc (outputs)
2. run_kernel
3. dealloc last_use buffers
estimate_peak_memory collapses these into one value per node (size_alloc - size_free),
while the peak actually happens right after alloc.
The code is duplicated to avoid migrating all call sites at once;
in the future, users of estimate_peak_memory will migrate to this version.
"""
buf_info_list, _, buf_to_snode_last_use = compute_memory_timeline(
nodes, name_to_freeable_input_buf, graph_outputs
)
# incremental memory changes at each step
step_idx_allocfree = [SNodeMemory(0, 0) for _ in range(len(nodes))]
# for each buffer, update memory when created and when freed
for buf_info in buf_info_list:
step_idx_allocfree[buf_info.start_step].size_alloc += buf_info.size_alloc
if buf_info.end_step != -1:
step_idx_allocfree[buf_info.end_step].size_free += buf_info.size_free
snodes_allocfree = {}
for i, node in enumerate(nodes):
snodes_allocfree[node] = step_idx_allocfree[i]
max_memory = 0
cur_memory = 0
snodes_curr_memory = []
for t in range(len(nodes)):
alloc = step_idx_allocfree[t].size_alloc
free = step_idx_allocfree[t].size_free
cur_memory += alloc
post_alloc = cur_memory
max_memory = max(max_memory, cur_memory)
cur_memory -= free
post_free = cur_memory
snodes_curr_memory.append((post_alloc, post_free))
return (
max_memory,
snodes_curr_memory,
snodes_allocfree,
buf_to_snode_last_use,
)
def topological_sort_lpmf(
nodes: list[BaseSchedulerNode],
name_to_freeable_input_buf: dict[str, FreeableInputBuffer],
@ -429,7 +518,7 @@ def topological_sort_lpmf(
Buffer memory optimization for video codec application modeled in Simulink
https://www.cs.york.ac.uk/rts/docs/DAC-1964-2006/PAPERS/2006/DAC06/PDFFILES/P0689.PDF
The algorithm maintain the max memory so far.
The algorithm maintains the max memory so far.
At every iteration, for each scheduleable node, it computes:
- how much memory needs to be allocated for the output buffers of this node;
- how much memory can be freed as a result of executing this node.

View File

@ -2160,6 +2160,12 @@ class Scheduler:
OrderedSet(V.graph.get_output_names()),
)
if config.reorder_for_compute_comm_overlap:
if not config.reorder_for_peak_memory:
from .memory import assign_memory_planning_info_for_scheduler_buffers
assign_memory_planning_info_for_scheduler_buffers(
self.nodes, self.name_to_buf
)
from torch._logging import trace_structured
trace_structured(
@ -2556,7 +2562,7 @@ class Scheduler:
)
graph_outputs: OrderedSet[str] = OrderedSet(V.graph.get_output_names())
buf_info_list, _ = compute_memory_timeline(
buf_info_list, _, _ = compute_memory_timeline(
self.nodes,
name_to_freeable_input_buf,
graph_outputs,