Revert "[inductor] Estimate peak memory allocfree and applying to reordering collectives (#160113)"

This reverts commit 9d18bf01b1661d227f6af41ac07a1e9ef20a9e1a.

Reverted https://github.com/pytorch/pytorch/pull/160113 on behalf of https://github.com/huydhn due to Sorry for reverting your change, but lots of failures showing up after this lands ([comment](https://github.com/pytorch/pytorch/pull/160113#issuecomment-3209487237))
PyTorch MergeBot
2025-08-21 08:13:33 +00:00
parent 23b033452f
commit bd5857a1d6
5 changed files with 190 additions and 741 deletions
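For context on what is being reverted: the docstring of the removed `estimate_peak_memory_allocfree` (deleted below in torch/_inductor/memory.py) argues that collapsing each node's memory effect into one net delta (`size_alloc - size_free`) can miss the true peak, which occurs after a node allocates its outputs but before it frees its last-use inputs. A toy, self-contained illustration of that difference — the numbers, helper names, and `schedule` list are invented for illustration, not inductor code:

```python
from dataclasses import dataclass


@dataclass
class SNodeMemory:
    size_alloc: int
    size_free: int


# Toy schedule: per-node alloc/free sizes (invented numbers).
schedule = [SNodeMemory(10, 0), SNodeMemory(8, 10), SNodeMemory(4, 12)]


def peak_collapsed(schedule):
    # One net delta per node: the within-node high-water mark
    # (after alloc, before dealloc) is invisible.
    cur = peak = 0
    for m in schedule:
        cur += m.size_alloc - m.size_free
        peak = max(peak, cur)
    return peak


def peak_allocfree(schedule):
    # Two-phase accounting, as in the reverted PR: observe the level
    # after allocation, then apply the frees.
    cur = peak = 0
    for m in schedule:
        cur += m.size_alloc   # phase 1: allocate outputs
        peak = max(peak, cur)  # true peak is observed here
        cur -= m.size_free    # phase 3: free last-use buffers
    return peak


print(peak_collapsed(schedule))   # 10 (underestimates)
print(peak_allocfree(schedule))   # 18
```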

diff --git a/torch/_inductor/comms.py b/torch/_inductor/comms.py
--- a/torch/_inductor/comms.py
+++ b/torch/_inductor/comms.py

@@ -4,7 +4,6 @@ from __future__ import annotations
 import heapq
 import importlib
-import itertools
 import logging
 import operator
 import sys
@@ -24,15 +23,8 @@ from .dependencies import WeakDep
 if TYPE_CHECKING:
     from .ir import IRNode, Operation
-    from .scheduler import SchedulerBuffer

-from .memory import (
-    estimate_peak_memory,
-    estimate_peak_memory_allocfree,
-    FreeableInputBuffer,
-    get_freeable_input_buf,
-    SNodeMemory,
-)
+from .memory import estimate_peak_memory, FreeableInputBuffer, get_freeable_input_buf
 from .utils import (
     contains_collective,
     contains_wait,
@@ -196,46 +188,6 @@ def _is_fake_dep(d):
     return isinstance(d, WeakDep) and d.is_fake

-def _group_names(gns: list[BaseSchedulerNode]) -> str:
-    return "~".join([gn.get_name() for gn in gns])
-
-
-def _initialize_memory_tracking(snodes, graph_inputs, graph_outputs):
-    """Initialize memory tracking data structures"""
-    name_to_freeable_input_buf = get_freeable_input_buf(snodes, graph_inputs)
-    peak_memory, snodes_curr_memory, snodes_allocfree, buf_to_snode_last_use = (
-        estimate_peak_memory_allocfree(
-            snodes, name_to_freeable_input_buf, graph_outputs
-        )
-    )
-    _curr_memory = dict(zip(snodes, snodes_curr_memory))
-    _curr_memory[None] = (0, 0)
-    return (
-        peak_memory,
-        _curr_memory,
-        snodes_allocfree,
-        buf_to_snode_last_use,
-        name_to_freeable_input_buf,
-    )
-
-
-def _initialize_double_linked_list(
-    snodes: list[BaseSchedulerNode],
-) -> tuple[
-    dict[BaseSchedulerNode, Optional[BaseSchedulerNode]],
-    dict[BaseSchedulerNode, Optional[BaseSchedulerNode]],
-    BaseSchedulerNode,
-]:
-    """Create double-linked list structure from snodes"""
-    _prev = {}
-    _next = {}
-    for i, snode in enumerate(snodes):
-        _prev[snode] = snodes[i - 1] if i > 0 else None
-        _next[snode] = snodes[i + 1] if i < len(snodes) - 1 else None
-    _head = snodes[0]
-    return _prev, _next, _head
-
-
 def _reorder_communication_preserving_peak_memory_internal(
     snodes: list[BaseSchedulerNode],
 ) -> tuple[list[BaseSchedulerNode], dict[BaseSchedulerNode, ReorderInfo]]:
@@ -259,22 +211,20 @@ def _reorder_communication_preserving_peak_memory_internal(
     # heuristic to avoid degenerating to quadratic time
     graph_inputs: OrderedSet[str] = OrderedSet(V.graph.graph_inputs.keys())
     graph_outputs: OrderedSet[str] = OrderedSet(V.graph.get_output_names())
-    (
-        peak_memory,
-        _curr_memory,
-        snodes_allocfree,
-        buf_to_snode_last_use,
-        name_to_freeable_input_buf,
-    ) = _initialize_memory_tracking(snodes, graph_inputs, graph_outputs)
-    runtimes: dict[BaseSchedulerNode, float] = {
-        snode: estimate_op_runtime(snode) for snode in snodes
-    }
+    name_to_freeable_input_buf: dict[str, FreeableInputBuffer] = get_freeable_input_buf(
+        snodes, graph_inputs
+    )
+    peak_memory, curr_memory = estimate_peak_memory(
+        snodes, name_to_freeable_input_buf, graph_outputs
+    )
+    runtimes = {snode: estimate_op_runtime(snode) for snode in snodes}
+    _curr_memory = dict(zip(snodes, curr_memory))
+    _curr_memory[None] = 0  # type: ignore[index]

     # debug stats
     stats: dict[BaseSchedulerNode, ReorderInfo] = {}

-    def exposed_communication_time(
-        collective_snode: BaseSchedulerNode, remaining_snodes: list[BaseSchedulerNode]
-    ) -> float:
+    def exposed_communication_time(collective_snode, remaining_snodes):
         # assumes a linear schedule and computes the overlap of the collective with the remaining nodes
         comm_time = estimate_op_runtime(collective_snode)
         compute_time = 0.0
@@ -286,7 +236,7 @@ def _reorder_communication_preserving_peak_memory_internal(
                 # we can ignore it. Otherwise, it's the end of the road for overlap opportunities
                 break

-        def accumulate_time(_snode: BaseSchedulerNode) -> None:
+        def accumulate_time(_snode):
             nonlocal compute_time
             compute_time += runtimes[_snode]
@@ -295,11 +245,18 @@ def _reorder_communication_preserving_peak_memory_internal(
     total_moves = 0

-    _prev, _next, _head = _initialize_double_linked_list(snodes)
+    # Dicts to keep track of "next" and "previous" as double-linked structure during grouping
+    _prev: dict[Optional[BaseSchedulerNode], Optional[BaseSchedulerNode]] = {}
+    _next: dict[Optional[BaseSchedulerNode], Optional[BaseSchedulerNode]] = {}
+    for i, snode in enumerate(snodes):
+        _prev[snode] = snodes[i - 1] if i > 0 else None
+        _next[snode] = snodes[i + 1] if i < len(snodes) - 1 else None
+    _curr_memory = dict(zip(snodes, curr_memory))
+    _curr_memory[None] = 0  # type: ignore[index]

-    def _group_nodes(
-        head: Optional[BaseSchedulerNode], tail: Optional[BaseSchedulerNode]
-    ) -> list[BaseSchedulerNode]:
+    _head = snodes[0]
+
+    def _group_nodes(head, tail):
         ret = []
         n = head
         while True:
@@ -307,167 +264,37 @@ def _reorder_communication_preserving_peak_memory_internal(
             ret.append(n)
             if n == tail:
                 break
-            n = _next[n]  # type: ignore[index]
+            n = _next[n]
         return ret

-    def _perform_double_linked_list_swap(candidate, group_head, group_tail):
-        # swap (candidate, group_head...group_tail)
-        # Before:
-        # candidate_prev -0-> candidate -1-> group_head...group_tail -2-> group_tail_next
-        # After:
-        # candidate_prev -0-> group_head...group_tail -1-> candidate -2-> group_tail_next
-        # 0
-        candidate_prev = _prev[candidate]
-        if candidate_prev:
-            _next[candidate_prev] = group_head
-            _prev[group_head] = candidate_prev
-        # 2
-        group_tail_next = _next[group_tail]
-        if group_tail_next:
-            _prev[group_tail_next] = candidate
-            _next[candidate] = group_tail_next
-        # 1
-        _prev[candidate] = group_tail
-        _next[group_tail] = candidate
-        nonlocal _head
-        if _head == candidate:
-            _head = group_head
-
-    def _calculate_potential_peak_memory(
-        candidate, group_ns, group_n_to_bufs_after_swap_dealloc_by_candidate
-    ):
-        # Caching calculations of memory for group nodes and candidate,
-        # to apply without recalculation after swap.
-        _post_alloc_update: dict[BaseSchedulerNode, int] = {}
-        potential_peak: int = 0
-        if not group_n_to_bufs_after_swap_dealloc_by_candidate:
-            # Not accounting for buffers last use change
-            potential_peak = max(
-                group_peak_memory - candidate_delta_mem,
-                _curr_memory[group_tail][1]
-                - candidate_delta_mem
-                + candidate_allocfree.size_alloc,
-            )
-            return potential_peak, _post_alloc_update
-        # If candidate will be after group, the starting memory level of group nodes
-        # changes to the -(candidate.size_alloc - candidate.size_free)
-        mem_after_reorder_delta: int = -candidate_delta_mem
-        for gn in gns:
-            gn_post_alloc_mem = _curr_memory[gn][0] + mem_after_reorder_delta
-            _post_alloc_update[gn] = gn_post_alloc_mem
-            potential_peak = max(potential_peak, gn_post_alloc_mem)
-            bufs = group_n_to_bufs_after_swap_dealloc_by_candidate.get(gn, None)
-            if bufs is not None:
-                for buf in bufs:
-                    # Candidate will deallocate those buffers
-                    mem_after_reorder_delta += buf.mpi_buffer.size_free
-        candidate_mem_post_alloc = (
-            _curr_memory[group_tail][1]
-            + mem_after_reorder_delta
-            + candidate_allocfree.size_alloc
-        )
-        _post_alloc_update[candidate] = candidate_mem_post_alloc
-        potential_peak = max(potential_peak, candidate_mem_post_alloc)
-        return potential_peak, _post_alloc_update
-
-    def _update_memory_tracking_after_swap(
-        candidate,
-        gns,
-        group_n_to_bufs_after_swap_dealloc_by_candidate,
-        _post_alloc_update,
-    ):
-        if not group_n_to_bufs_after_swap_dealloc_by_candidate:
-            for gn in gns:
-                cm = _curr_memory[gn]
-                _curr_memory[gn] = (
-                    cm[0] - candidate_delta_mem,
-                    cm[1] - candidate_delta_mem,
-                )
-            _candidate_post_alloc_mem = (
-                _curr_memory[group_tail][1] + candidate_allocfree.size_alloc
-            )
-            _candidate_post_free_mem = (
-                _candidate_post_alloc_mem - candidate_allocfree.size_free
-            )
-            _curr_memory[candidate] = (
-                _candidate_post_alloc_mem,
-                _candidate_post_free_mem,
-            )
-            return
-        # Candidate becomes last use of some bufs
-        for (
-            gn,
-            bufs,
-        ) in group_n_to_bufs_after_swap_dealloc_by_candidate.items():
-            for buf in bufs:
-                buf_to_snode_last_use[buf] = candidate
-        size_free_to_move_to_candidate_sum: int = 0
-        for n in gns:
-            _gn_post_alloc_mem: int = _post_alloc_update[n]
-            size_free_to_move_to_candidate: int = sum(
-                buf.mpi_buffer.size_free
-                for buf in group_n_to_bufs_after_swap_dealloc_by_candidate[n]
-            )
-            size_free_to_move_to_candidate_sum += size_free_to_move_to_candidate
-            # group node does not deallocate this after swap
-            snodes_allocfree[n].size_free -= size_free_to_move_to_candidate
-            gn_post_free_mem: int = _gn_post_alloc_mem - snodes_allocfree[n].size_free
-            _curr_memory[n] = (_gn_post_alloc_mem, gn_post_free_mem)
-        _candidate_post_alloc_mem = _post_alloc_update[candidate]
-        snodes_allocfree[candidate].size_free += size_free_to_move_to_candidate_sum
-        candidate_post_free_mem = (
-            _candidate_post_alloc_mem - snodes_allocfree[candidate].size_free
-        )
-        _curr_memory[candidate] = (
-            _candidate_post_alloc_mem,
-            candidate_post_free_mem,
-        )
-
-    debug_num_collectives_to_reorder: Optional[int] = (
-        config.reorder_iterative_debug_limit_to_reorder
-    )
-    num_processed_collectives: int = 0
+    def _group_names(head, tail):
+        ret = ""
+        for n in _group_nodes(head, tail):
+            if ret:
+                ret += "~"
+            ret += n.get_name()
+        return ret

     curr = _head
-    debug_iterative_memory_recompute = config.reorder_iterative_debug_memory_recompute
-    iterative_recompute_error = False
     while _next[curr] is not None:
-        if iterative_recompute_error:
-            break
         if contains_collective(curr):
-            if debug_num_collectives_to_reorder is not None and (
-                num_processed_collectives >= debug_num_collectives_to_reorder
-            ):
-                break
-            num_processed_collectives += 1
-            info = stats[curr] = ReorderInfo()
-            info.initial_exposed = info.final_exposed = exposed_communication_time(
-                curr, _group_nodes(_next[curr], None)
-            )
+            reorder_info = stats[curr] = ReorderInfo()
+            reorder_info.initial_exposed = reorder_info.final_exposed = (
+                exposed_communication_time(curr, _group_nodes(_next[curr], None))
+            )
             candidate = _prev[curr]
             group_head = curr
             group_tail = curr
-            group_peak_memory = _curr_memory[curr][0]  # post_alloc memory
+            group_peak_memory = _curr_memory[curr]
             while candidate is not None:
                 if contains_collective(candidate):
-                    info.limiting_factor = "collective ordering"
+                    reorder_info.limiting_factor = "collective ordering"
                     break
-                gns: list[BaseSchedulerNode] = _group_nodes(group_head, group_tail)
                 group = GroupedSchedulerNode(
                     curr.scheduler,
-                    gns,
+                    _group_nodes(group_head, group_tail),
                     temp_grouping=True,
                 )
@@ -487,9 +314,7 @@ def _reorder_communication_preserving_peak_memory_internal(
                 if data_dep is not None:

-                    def is_groupable(
-                        candidate: BaseSchedulerNode,
-                    ) -> tuple[bool, Optional[str]]:
+                    def is_groupable(candidate):
                         # preserve ordering
                         if contains_collective(candidate):
                             return False, "contains_collective"
@@ -498,106 +323,73 @@ def _reorder_communication_preserving_peak_memory_internal(
                             return False, "contains_gemm_like"
                         return True, None

-                    is_groupable_result, grouping_reason = is_groupable(candidate)
-                    if is_groupable_result:
+                    is_grp, grp_reason = is_groupable(candidate)
+                    if is_grp:
                         group_head = candidate
                         group_peak_memory = max(
-                            group_peak_memory, _curr_memory[candidate][0]
+                            group_peak_memory, _curr_memory[candidate]
                         )
-                        info.grouped += 1
-                        info.grouped_info = _group_names(gns)
+                        reorder_info.grouped += 1
+                        reorder_info.grouped_info = _group_names(group_head, group_tail)
                         candidate = _prev[candidate]
                         continue
                     else:
                         msg = (
                             f"data dependency {data_dep}(dep_names:{list(data_deps.keys())})"
-                            f"\n candidate:{candidate.get_name()}(outs:{[candidate.get_buffer_names()]})"
-                            f"dep on {_group_names(gns)}"
-                            f"\n non_group_reason:{grouping_reason}"
+                            f"\n candidate:{candidate.get_name()}(os:{[candidate.get_buffer_names()]})"
+                            f"dep on {_group_names(group_head, group_tail)}"
+                            f"\n non_group_reason:{grp_reason}"
                         )
-                        info.limiting_factor = msg
+                        reorder_info.limiting_factor = msg
                         break

-                candidate_allocfree: SNodeMemory = snodes_allocfree[candidate]
-                candidate_delta_mem: int = (
-                    candidate_allocfree.size_alloc - candidate_allocfree.size_free
-                )
-                # candidate and one of group nodes are successors of the same buffer
-                # and last use of the buffer happen in group nodes.
-                # This last use deallocates it.
-                # If we swap [candidate [group]] to [[group] candidate],
-                # candidate becomes the last use
-                # and deallocated this buffer instead of group node.
-                # we need to update size_free accordingly to group_node and candidate,
-                # and recalculate post_alloc, post_free for them.
-                #
-                # Buf that changes its last use snode,
-                # after swap will be deallocated only by candidate,
-                # while before it was deallocated by group node.
-                group_n_to_bufs_after_swap_dealloc_by_candidate: dict[
-                    BaseSchedulerNode, list[Union[FreeableInputBuffer, Any]]
-                ] = defaultdict(list)
-                for (
-                    buf,
-                    snode_last_use,
-                ) in buf_to_snode_last_use.items():
-                    succ_nodes = buf.mpi_buffer.succ_nodes
-                    if candidate not in succ_nodes:
-                        continue
-                    if not any(gn == snode_last_use for gn in gns):
-                        continue
-                    group_n_to_bufs_after_swap_dealloc_by_candidate[
-                        snode_last_use
-                    ].append(buf)
-                potential_peak, _post_alloc_update = _calculate_potential_peak_memory(
-                    candidate, gns, group_n_to_bufs_after_swap_dealloc_by_candidate
-                )
-                if potential_peak > peak_memory:
-                    info.limiting_factor = (
-                        f"peak memory new:{potential_peak} vs base:{peak_memory}"
-                    )
+                delta_memory_candidate = (
+                    _curr_memory[candidate] - _curr_memory[_prev[candidate]]  # type: ignore[index]
+                )
+                if group_peak_memory - delta_memory_candidate > peak_memory:
+                    reorder_info.limiting_factor = "peak memory"
                     break
-                info.moves += 1
+
+                reorder_info.moves += 1
                 total_moves += 1
-                _perform_double_linked_list_swap(candidate, group_head, group_tail)
+                mem_deltas = {}
+                for n in [candidate, *_group_nodes(group_head, group_tail)]:
+                    mem_deltas[n] = _curr_memory[n] - _curr_memory[_prev[n]]  # type: ignore[index]
+                # swap (candidate, group_head...group_tail)
+                # Before:
+                # candidate_prev -0-> candidate -1-> group_head...group_tail -2-> group_tail_next
+                # After:
+                # candidate_prev -0-> group_head...group_tail -1-> candidate -2-> group_tail_next
+                # 0
+                candidate_prev = _prev[candidate]
+                if candidate_prev:
+                    _next[candidate_prev] = group_head
+                    _prev[group_head] = candidate_prev

-                info.final_exposed = exposed_communication_time(
+                # 2
+                group_tail_next = _next[group_tail]
+                if group_tail_next:
+                    _prev[group_tail_next] = candidate
+                    _next[candidate] = group_tail_next
+
+                # 1
+                _prev[candidate] = group_tail
+                _next[group_tail] = candidate
+
+                if _head == candidate:
+                    _head = group_head
+
+                reorder_info.final_exposed = exposed_communication_time(
                     curr, _group_nodes(_next[curr], None)
                 )
-                _update_memory_tracking_after_swap(
-                    candidate,
-                    gns,
-                    group_n_to_bufs_after_swap_dealloc_by_candidate,
-                    _post_alloc_update,
-                )
-                if debug_iterative_memory_recompute:
-                    # Compare iteratively recomputed memory data
-                    # with full run of estimate_peak_memory
-                    from .comms_debug import _debug_iterative_memory_recompute
-
-                    iterative_recompute_error = _debug_iterative_memory_recompute(
-                        candidate,
-                        gns,
-                        _group_names(gns),
-                        _group_nodes(_head, None),
-                        name_to_freeable_input_buf,
-                        graph_outputs,
-                        peak_memory,
-                        _curr_memory,
-                        snodes_allocfree,
-                        "reorder_communication_preserving_peak_memory",
-                        group_n_to_bufs_after_swap_dealloc_by_candidate,
-                    )
-                    if iterative_recompute_error:
-                        break
+
+                # Recompute curr_memory
+                _prev_curr_memory = _curr_memory[_prev[group_head]]  # type: ignore[index]
+                for n in _group_nodes(group_head, candidate):
+                    _curr_memory[n] = _prev_curr_memory = (
+                        _prev_curr_memory + mem_deltas[n]
+                    )
+
                 candidate = _prev[group_head]
         curr = _next[curr]  # type: ignore[assignment]
@@ -623,15 +415,15 @@ def _reorder_communication_preserving_peak_memory_internal(
     rows = [
         [
             node_summary(snode),
-            node_info.initial_exposed,
-            node_info.final_exposed,
-            node_info.improvement,
-            node_info.limiting_factor,
-            node_info.moves,
-            node_info.grouped,
-            node_info.grouped_info,
+            node_reorder_info.initial_exposed,
+            node_reorder_info.final_exposed,
+            node_reorder_info.improvement,
+            node_reorder_info.limiting_factor,
+            node_reorder_info.moves,
+            node_reorder_info.grouped,
+            node_reorder_info.grouped_info,
         ]
-        for snode, node_info in node_stats.items()
+        for snode, node_reorder_info in node_stats.items()
     ]
     if importlib.util.find_spec("tabulate"):
         from tabulate import tabulate
@@ -649,7 +441,7 @@ def _reorder_communication_preserving_peak_memory_internal(
     new_snodes = _group_nodes(_head, None)
     assert len(new_snodes) == original_snodes_num
-    new_peak_memory, _, _, _ = estimate_peak_memory_allocfree(
+    new_peak_memory, curr_memory = estimate_peak_memory(
         new_snodes, name_to_freeable_input_buf, graph_outputs
     )
     reorder_log_str += f"\n peak_memory_before:{peak_memory}"
@@ -865,21 +657,24 @@ def _sink_waits_iterative_internal(
         return snodes, {}
     graph_inputs: OrderedSet[str] = OrderedSet(V.graph.graph_inputs.keys())
     graph_outputs: OrderedSet[str] = OrderedSet(V.graph.get_output_names())
-    (
-        peak_memory,
-        _curr_memory,
-        snodes_allocfree,
-        buf_to_snode_last_use,
-        name_to_freeable_input_buf,
-    ) = _initialize_memory_tracking(snodes, graph_inputs, graph_outputs)
-    _prev, _next, _head = _initialize_double_linked_list(snodes)
+    name_to_freeable_input_buf: dict[str, FreeableInputBuffer] = get_freeable_input_buf(
+        snodes, graph_inputs
+    )
+    peak_memory, curr_memory = estimate_peak_memory(
+        snodes, name_to_freeable_input_buf, graph_outputs
+    )
     stats: dict[BaseSchedulerNode, SinkWaitInfo] = {}
+    _prev: dict[Optional[BaseSchedulerNode], Optional[BaseSchedulerNode]] = {}
+    _next: dict[Optional[BaseSchedulerNode], Optional[BaseSchedulerNode]] = {}
+    _head = snodes[0]
+    for i, snode in enumerate(snodes):
+        _prev[snode] = snodes[i - 1] if i > 0 else None
+        _next[snode] = snodes[i + 1] if i < len(snodes) - 1 else None
+    _curr_memory = dict(zip(snodes, curr_memory))
+    _curr_memory[None] = 0  # type: ignore[index]

-    def _group_nodes(
-        head: Optional[BaseSchedulerNode], tail: Optional[BaseSchedulerNode]
-    ) -> list[BaseSchedulerNode]:
+    def _group_nodes(head, tail):
         ret = []
         n = head
         while True:
@@ -887,125 +682,21 @@ def _sink_waits_iterative_internal(
             ret.append(n)
             if n == tail:
                 break
-            n = _next[n]  # type: ignore[index]
+            n = _next[n]
         return ret

-    def _calculate_potential_peak_memory(
-        candidate, group_ns, group_n_to_bufs_after_swap_dealloc_instead_of_candidate
-    ):
-        pre_group_mem = (
-            _curr_memory[group_head][0] - snodes_allocfree[group_head].size_alloc
-        )
-        # Stash memory tracing updates to not recompute them after swap
-        _post_alloc_update: dict[BaseSchedulerNode, int] = {}
-        _size_free_delta_update: dict[BaseSchedulerNode, int] = {}
-        potential_peak = 0
-        if not group_n_to_bufs_after_swap_dealloc_instead_of_candidate:
-            # Not accounting for buffers liveliness change
-            potential_peak = max(
-                group_peak_memory + candidate_delta_mem,
-                pre_group_mem + candidate_allocfree.size_alloc,
-            )
-            return potential_peak, _post_alloc_update, _size_free_delta_update
-        candidate_post_alloc = pre_group_mem + candidate_allocfree.size_alloc
-        _post_alloc_update[candidate] = candidate_post_alloc
-        potential_peak = candidate_post_alloc
-        candidate_size_free_to_move = sum(
-            buf.mpi_buffer.size_free  # type: ignore[attr-defined]
-            for buf in itertools.chain.from_iterable(
-                group_n_to_bufs_after_swap_dealloc_instead_of_candidate.values()
-            )
-        )
-        _size_free_delta_update[candidate] = -candidate_size_free_to_move
-        delta_mem = candidate_delta_mem + candidate_size_free_to_move
-        for gn in gns:
-            gn_post_alloc = _curr_memory[gn][0] + delta_mem
-            _post_alloc_update[gn] = gn_post_alloc
-            potential_peak = max(potential_peak, gn_post_alloc)
-            gn_size_free_to_add = 0
-            if gn in group_n_to_bufs_after_swap_dealloc_instead_of_candidate:
-                bufs = group_n_to_bufs_after_swap_dealloc_instead_of_candidate[gn]
-                for buf in bufs:
-                    gn_size_free_to_add += buf.mpi_buffer.size_free
-            _size_free_delta_update[gn] = gn_size_free_to_add
-            delta_mem -= gn_size_free_to_add
-        return potential_peak, _post_alloc_update, _size_free_delta_update
-
-    def _perform_double_linked_list_swap(candidate, group_head, group_tail):
-        # group_head_prev -0-> candidate -1-> group_head...group_tail -2-> candidate_next
-        # 0:
-        group_head_prev = _prev[group_head]
-        if group_head_prev:
-            _next[group_head_prev] = candidate
-            _prev[candidate] = group_head_prev
-        # 2:
-        candidate_next = _next[candidate]
-        if candidate_next:
-            _prev[candidate_next] = group_tail
-            _next[group_tail] = candidate_next
-        # 1:
-        _prev[group_head] = candidate
-        _next[candidate] = group_head
-        nonlocal _head
-        if group_head == _head:
-            _head = candidate
-
-    def _update_memory_tracking_after_swap(
-        candidate,
-        gns,
-        group_n_to_bufs_after_swap_dealloc_instead_of_candidate,
-        _post_alloc_update,
-        _size_free_delta_update,
-    ):
-        group_head = gns[0]
-        pre_group_mem = (
-            _curr_memory[group_head][0] - snodes_allocfree[group_head].size_alloc
-        )
-        if not group_n_to_bufs_after_swap_dealloc_instead_of_candidate:
-            candidate_post_alloc = pre_group_mem + candidate_allocfree.size_alloc
-            _curr_memory[candidate] = (
-                candidate_post_alloc,
-                candidate_post_alloc - candidate_allocfree.size_free,
-            )
-            for gn in gns:
-                cm = _curr_memory[gn]
-                _curr_memory[gn] = (
-                    cm[0] + candidate_delta_mem,
-                    cm[1] + candidate_delta_mem,
-                )
-            return
-        for n in [candidate, *gns]:
-            post_alloc = _post_alloc_update[n]
-            snodes_allocfree[n].size_free += _size_free_delta_update[n]
-            _curr_memory[n] = (
-                post_alloc,
-                post_alloc - snodes_allocfree[n].size_free,
-            )
+    def _group_names(head, tail):
+        ret = ""
+        for n in _group_nodes(head, tail):
+            if ret:
+                ret += "~"
+            ret += n.get_name()
+        return ret

     curr = snodes[-1]
     processed_waits = OrderedSet()  # type: ignore[var-annotated]
-    debug_iterative_memory_recompute = config.reorder_iterative_debug_memory_recompute
-    debug_num_sink_waits_to_reorder: Optional[int] = (
-        config.sink_waits_iterative_debug_limit_to_sink
-    )
-    iterative_recompute_error = False

     while _prev[curr] is not None:
-        if iterative_recompute_error:
-            break
-        if (
-            debug_num_sink_waits_to_reorder is not None
-            and len(processed_waits) >= debug_num_sink_waits_to_reorder
-        ):
-            break
         if contains_wait(curr) and curr not in processed_waits:
             processed_waits.add(curr)
             info = stats[curr] = SinkWaitInfo()
@@ -1013,14 +704,11 @@ def _sink_waits_iterative_internal(
             wait_snode = curr
             group_head = curr
             group_tail = curr
-            group_peak_memory = _curr_memory[curr][0]
+            group_peak_memory = _curr_memory[curr]
             while candidate is not None:
-                if iterative_recompute_error:
-                    break
-                gns: list[BaseSchedulerNode] = _group_nodes(group_head, group_tail)
                 group = GroupedSchedulerNode(
                     wait_snode.scheduler,
-                    gns,
+                    _group_nodes(group_head, group_tail),
                     temp_grouping=True,
                 )
@@ -1065,15 +753,15 @@ def _sink_waits_iterative_internal(
                 if is_grp:
                     group_tail = candidate
                     group_peak_memory = max(
-                        group_peak_memory, _curr_memory[candidate][0]
+                        group_peak_memory, _curr_memory[candidate]
                     )
                     info.grouped += 1
-                    info.grouped_info = _group_names(gns)
+                    info.grouped_info = _group_names(group_head, group_tail)
                     candidate = _next[candidate]
                     continue
                 elif (data_dep is None) and both_contain_comms:
                     info.limiting_factor = (
-                        f"collective ordering {_group_names(gns)}"
+                        f"collective ordering {_group_names(group_head, group_tail)}"
                         f" with candidate:{candidate.get_name()}"
                     )
                     break
@@ -1081,89 +769,49 @@ def _sink_waits_iterative_internal(
                     info.limiting_factor = (
                         f"data dependency {data_dep}(dep_names:{list(data_deps.keys())})"
                         f"\n candidate:{candidate.get_name()}(os:{[candidate.get_buffer_names()]})"
-                        f"dep on {gns}"
+                        f"dep on {_group_names(group_head, group_tail)}"
                        f"\n outs:{[o.get_name() for o in group_outs]}"
                         f"\n non_group_reason:{grp_reason}"
                     )
                     break

-                candidate_allocfree: SNodeMemory = snodes_allocfree[candidate]
-                candidate_delta_mem = (
-                    candidate_allocfree.size_alloc - candidate_allocfree.size_free
-                )
-                # [group] candidate -> candidate [group]
-                # Check for buffers with successors in group and candidate last successor
-                #
-                # Buf that changes its last use snode,
-                # It was deallocated by candidate,
-                # but after swap it will be deallocated by group node.
-                group_n_to_bufs_after_swap_dealloc_instead_of_candidate: dict[
-                    BaseSchedulerNode, list[Union[FreeableInputBuffer, SchedulerBuffer]]
-                ] = defaultdict(list)
-                for (
-                    buf,
-                    snode_last_use,
-                ) in buf_to_snode_last_use.items():
-                    succ_nodes = buf.mpi_buffer.succ_nodes
-                    if snode_last_use != candidate:  # noqa: E711
-                        continue
-                    # candidate is last use of buf
-                    last_succ_gn = None
-                    for gn in gns:
-                        if gn in succ_nodes:
-                            last_succ_gn = gn
-                    if last_succ_gn is None:
-                        continue
-                    # gn has successors of buf that after potential swap will become
-                    # last use of buf and start deallocating buf instead of candidate
-                    group_n_to_bufs_after_swap_dealloc_instead_of_candidate[
-                        last_succ_gn
-                    ].append(buf)
-                potential_peak, _post_alloc_update, _size_free_delta_update = (
-                    _calculate_potential_peak_memory(
-                        candidate,
-                        gns,
-                        group_n_to_bufs_after_swap_dealloc_instead_of_candidate,
-                    )
-                )
-                if potential_peak > peak_memory:
-                    info.limiting_factor = (
-                        f"peak memory new:{potential_peak} vs base:{peak_memory}"
-                    )
+                candidate_delta_memory = (
+                    _curr_memory[candidate] - _curr_memory[_prev[candidate]]  # type: ignore[index]
+                )
+                if group_peak_memory + candidate_delta_memory > peak_memory:
+                    info.limiting_factor = "peak_memory"
                     break
                 info.moves += 1
                 info.moves_info += f"+{candidate.get_name()}"
-                _perform_double_linked_list_swap(candidate, group_head, group_tail)
+                # group_head_prev -0-> candidate -1-> group_head...group_tail -2-> candidate_next
+                mem_deltas = {}
+                for n in [candidate, *_group_nodes(group_head, group_tail)]:
+                    mem_deltas[n] = _curr_memory[n] - _curr_memory[_prev[n]]  # type: ignore[index]
+                # 0:
+                group_head_prev = _prev[group_head]
+                if group_head_prev:
+                    _next[group_head_prev] = candidate
+                    _prev[candidate] = group_head_prev

-                _update_memory_tracking_after_swap(
-                    candidate,
-                    gns,
-                    group_n_to_bufs_after_swap_dealloc_instead_of_candidate,
-                    _post_alloc_update,
-                    _size_free_delta_update,
-                )
+                # 2:
+                candidate_next = _next[candidate]
+                if candidate_next:
+                    _prev[candidate_next] = group_tail
+                    _next[group_tail] = candidate_next

-                if debug_iterative_memory_recompute:
-                    from .comms_debug import _debug_iterative_memory_recompute
+                # 1:
+                _prev[group_head] = candidate
+                _next[candidate] = group_head
+                if group_head == _head:
+                    _head = candidate

-                    iterative_recompute_error = _debug_iterative_memory_recompute(
-                        candidate,
-                        gns,
-                        _group_names(gns),
-                        _group_nodes(_head, None),
-                        name_to_freeable_input_buf,
-                        graph_outputs,
-                        peak_memory,
-                        _curr_memory,
-                        snodes_allocfree,
-                        "sink_waits_iterative",
-                        group_n_to_bufs_after_swap_dealloc_instead_of_candidate,
-                    )
-                    if iterative_recompute_error:
-                        break
+                # Recompute curr_memory
+                _prev_curr_memory = _curr_memory[_prev[candidate]]  # type: ignore[index]
+                for n in _group_nodes(candidate, group_tail):
+                    _curr_memory[n] = _prev_curr_memory = (
+                        _prev_curr_memory + mem_deltas[n]
+                    )
                 candidate = _next[group_tail]
         curr = _prev[curr]  # type: ignore[assignment]
@@ -1202,11 +850,11 @@ def _sink_waits_iterative_internal(
     overlap_log.info(log_str)
     new_snodes = _group_nodes(_head, None)
     assert len(new_snodes) == original_snodes_num
-    new_peak_memory, _, _, _ = estimate_peak_memory_allocfree(
+    new_peak_memory, curr_memory = estimate_peak_memory(
        new_snodes, name_to_freeable_input_buf, graph_outputs
     )
-    log_str += f"\n sink_waits_iterative peak_memory_before:{peak_memory}"
-    log_str += f"\n sink_waits_iterative peak_memory_after:{new_peak_memory}"
+    log_str += f"\n peak_memory_before:{peak_memory}"
+    log_str += f"\n peak_memory_after:{new_peak_memory}"
     trace_structured(
         "artifact",
         metadata_fn=lambda: {
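Both passes above splice a candidate node past a run of grouped nodes in a dict-based doubly linked list (`_prev`/`_next`), exactly as the restored inline swap does. A self-contained sketch of that pointer surgery on toy string labels (not real scheduler nodes; `_prev[group_head]` is set unconditionally here for simplicity):

```python
# Toy schedule as a dict-based doubly linked list.
nodes = ["a", "candidate", "g1", "g2", "b"]
_prev = {n: (nodes[i - 1] if i > 0 else None) for i, n in enumerate(nodes)}
_next = {n: (nodes[i + 1] if i < len(nodes) - 1 else None) for i, n in enumerate(nodes)}
_head = nodes[0]


def swap_candidate_after_group(candidate, group_head, group_tail):
    # Before: prev -> candidate -> group_head...group_tail -> next
    # After:  prev -> group_head...group_tail -> candidate -> next
    global _head
    candidate_prev = _prev[candidate]
    if candidate_prev:
        _next[candidate_prev] = group_head
    _prev[group_head] = candidate_prev
    group_tail_next = _next[group_tail]
    if group_tail_next:
        _prev[group_tail_next] = candidate
    _next[candidate] = group_tail_next
    _prev[candidate] = group_tail
    _next[group_tail] = candidate
    if _head == candidate:
        _head = group_head


swap_candidate_after_group("candidate", "g1", "g2")

# Walk from the head to confirm the new order.
order, n = [], _head
while n is not None:
    order.append(n)
    n = _next[n]
print(order)  # ['a', 'g1', 'g2', 'candidate', 'b']
```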

diff --git a/torch/_inductor/comms_debug.py b/torch/_inductor/comms_debug.py
deleted file mode 100644
--- a/torch/_inductor/comms_debug.py
+++ /dev/null

@@ -1,112 +0,0 @@
-from __future__ import annotations
-
-from typing import TYPE_CHECKING, Union
-
-from torch._logging import trace_structured
-
-from .memory import estimate_peak_memory_allocfree
-
-
-if TYPE_CHECKING:
-    from torch.utils._ordered_set import OrderedSet
-
-    from .memory import FreeableInputBuffer, SNodeMemory
-    from .scheduler import BaseSchedulerNode, SchedulerBuffer
-
-
-def _debug_iterative_memory_recompute(
-    candidate: BaseSchedulerNode,
-    gns: list[BaseSchedulerNode],
-    group_names: str,
-    snodes: list[BaseSchedulerNode],
-    name_to_freeable_input_buf: dict[str, FreeableInputBuffer],
-    graph_outputs: OrderedSet[str],
-    peak_memory: int,
-    iter_curr_memory: dict[BaseSchedulerNode, tuple[int, int]],
-    snodes_allocfree: dict[BaseSchedulerNode, SNodeMemory],
-    tlparse_name: str,
-    gn_to_bufs_last_use: dict[
-        BaseSchedulerNode, list[Union[FreeableInputBuffer, SchedulerBuffer]]
-    ],
-) -> bool:
-    iterative_recompute_error = False
-    candidate_allocfree = snodes_allocfree[candidate]
-    est_peak_memory, snodes_curr_memory, snodes_allocfree, _ = (
-        estimate_peak_memory_allocfree(
-            snodes, name_to_freeable_input_buf, graph_outputs
-        )
-    )
-    est_curr_memory = dict(zip(snodes, snodes_curr_memory))
-    iter_cm = iter_curr_memory[candidate]
-    new_cm = est_curr_memory[candidate]
-    log = ""
-    if est_peak_memory > peak_memory:
-        log = "ITERATIVE PEAK DOES NOT MATCH"
-        iterative_recompute_error = True
-    if iter_cm != new_cm:
-        log = "ITERATIVE CURR MEMORY CANDIDATE DOES NOT MATCH"
-        iterative_recompute_error = True
-    for i, gn in enumerate(gns):
-        iter_gnm = iter_curr_memory[gn]
-        new_gnm = est_curr_memory[gn]
-        if iter_gnm != new_gnm:
-            log = f"ITERATIVE GN CURR MEMORY DOES NOT MATCH:{gn.get_name()}"
-            iterative_recompute_error = True
-    if iterative_recompute_error:
-        log += (
-            f"\nCANDIDATE:{candidate.get_name()}"
-            f"\nGROUP:{group_names}"
-            f"\nPEAK_MEMORY_BEFORE:{peak_memory}"
-            f"\nPEAK_MEMORY_AFTER_SWAP:{est_peak_memory}"
-            f"\nCANDIDATE:{candidate.debug_str()}"
-            f"\nCANDIDATE_ITER_CURR_MEMORY:{iter_cm}"
-            f"\nCANDIDATE_NEW__CURR_MEMORY:{new_cm}"
-            f"\nCANDIDATE_ITER_ALLOCFREE:{candidate_allocfree}"
-            f"\nCANDIDATE_NEW_ALLOCFREE:{snodes_allocfree[candidate]}"
-        )
-        peak_log = ""
-        for i, (pre, post) in enumerate(snodes_curr_memory):
-            if est_peak_memory == pre:
-                n = snodes[i]
-                peak_log = (
-                    f"\nNEW_PEAK:{est_peak_memory}(BASE:{peak_memory})"
-                    f" @ SNODE[{i}/{len(snodes)}]:{n.get_name()} {n.debug_str()}"
-                )
-                break
-        group_log = ""
-        for i, gn in enumerate(gns):
-            iter_gnm = iter_curr_memory[gn]
-            new_gnm = est_curr_memory[gn]
-            group_log += (
-                f"\nGROUP_NODE[{i}]:{gn.debug_str()}"
-                f"\nGROUP_NODE[{i}] ITER_GNM[{gn.get_name()}]:{iter_gnm}"
-                f"\nGROUP_NODE[{i}] ESTM_GNM[{gn.get_name()}]:{new_gnm}"
-                f"\nGROUP_NODE[{i}] ITER_allocfree:{snodes_allocfree[gn]}"
-                f"\nGROUP_NODE[{i}] ESTM_allocfree:{snodes_allocfree[gn]}"
-            )
-        log += peak_log
-        log += group_log
-        log += f"\nGN_TO_BUFS_LAST_USE:{gn_to_bufs_last_use}"
-        log += "\n\n".join(
-            [
-                (
-                    f"\nSNODE[{i}]\n{n.debug_str()}"
-                    f"\nITER_cur_mem:{iter_curr_memory[n]}"
-                    f"\nESTM_cur_mem:{est_curr_memory[n]}"
-                    f"\nITER_allocfree:{snodes_allocfree[n]}"
-                    f"\nESTM_allocfree:{snodes_allocfree[n]}"
-                )
-                for i, n in enumerate(snodes)
-            ]
-        )
-        tname = f"{tlparse_name}_ITERATIVE_RECOMPUTE_ERROR"
-        print(f"{tname}:\n{log}")
-        trace_structured(
-            "artifact",
-            metadata_fn=lambda: {
-                "name": tname,
-                "encoding": "string",
-            },
-            payload_fn=lambda: log,
-        )
-    return iterative_recompute_error

diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py
--- a/torch/_inductor/config.py
+++ b/torch/_inductor/config.py

@@ -389,16 +389,6 @@ reorder_prefetch_limit: Optional[int] = None
 # enable operator reordering for peak memory optimization
 reorder_for_peak_memory = True

-reorder_iterative_debug_memory_recompute: bool = False
-reorder_iterative_debug_limit_to_reorder: Optional[int] = (
-    None
-    if (env_str := os.getenv("PYTORCH_REORDER_COLLECTIVES_LIMIT")) is None
-    else int(env_str)
-)
-sink_waits_iterative_debug_limit_to_sink: Optional[int] = (
-    None if (env_str := os.getenv("PYTORCH_SINK_WAITS_LIMIT")) is None else int(env_str)
-)
-
 bucket_all_gathers_fx: Literal["none", "all", "only_fsdp"] = "none"

 # By default torch._inductor.fx_passes.bucketing.bucket_size_determinator is used
 bucket_all_gathers_fx_bucket_size_determinator: Optional[Callable[[int], int]] = None
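The deleted knobs followed a common bisection pattern: an environment variable (`PYTORCH_REORDER_COLLECTIVES_LIMIT`, `PYTORCH_SINK_WAITS_LIMIT`) optionally caps how many collectives or waits the pass may process, so a failure can be narrowed to a specific move. A minimal sketch of that pattern outside inductor (the helper name and loop are illustrative, not inductor code):

```python
import os
from typing import Optional


def _optional_int_from_env(var: str) -> Optional[int]:
    # Same env-var-to-Optional[int] idiom the removed options used.
    return None if (env_str := os.getenv(var)) is None else int(env_str)


reorder_limit = _optional_int_from_env("PYTORCH_REORDER_COLLECTIVES_LIMIT")

processed = 0
for _collective in range(10):  # stand-in for the collectives loop
    if reorder_limit is not None and processed >= reorder_limit:
        break  # bisection: stop reordering after N collectives
    processed += 1
```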

diff --git a/torch/_inductor/memory.py b/torch/_inductor/memory.py
--- a/torch/_inductor/memory.py
+++ b/torch/_inductor/memory.py

@@ -4,7 +4,7 @@ import collections
 import dataclasses
 import heapq
 import logging
-from typing import Callable, Optional, TYPE_CHECKING, TypedDict, Union
+from typing import Callable, TYPE_CHECKING, TypedDict, Union

 from torch._environment import is_fbcode
 from torch._utils_internal import signpost_event
@@ -76,7 +76,7 @@ def get_freeable_input_buf(
     Create and keep track of all input buffers that can be freed during the program

     Returns:
-        A dictionary containing all freeable input buffers, keyed by their names.
+        A dictionary containing all freeble input buffers, keyed by their names.
     """

     def _dep_size_hint(dep: Dep) -> int:
@@ -303,11 +303,7 @@ def compute_memory_timeline(
     nodes: list[BaseSchedulerNode],
     name_to_freeable_input_buf: dict[str, FreeableInputBuffer],
     graph_outputs: OrderedSet[str],
-) -> tuple[
-    list[BufferInfo],
-    dict[BaseSchedulerNode, int],
-    dict[Union[FreeableInputBuffer, SchedulerBuffer], BaseSchedulerNode],
-]:
+) -> tuple[list[BufferInfo], dict[BaseSchedulerNode, int]]:
     """
     Compute buffer allocation and deallocation sizes and map their
     lifetime to the node schedule
@@ -321,33 +317,15 @@ def compute_memory_timeline(
     # get buffers' size and liveliness information
     buf_info_list: list[BufferInfo] = []
-    buf_to_snode_last_use: dict[
-        Union[FreeableInputBuffer, SchedulerBuffer], BaseSchedulerNode
-    ] = {}
-
-    def _get_end_step_and_snode(
-        buf: Union[FreeableInputBuffer, SchedulerBuffer],
-    ) -> tuple[int, Optional[BaseSchedulerNode]]:
-        max_step: int = -1
-        max_step_snode: Optional[BaseSchedulerNode] = None
-        succ_nodes = buf.mpi_buffer.succ_nodes
-        if succ_nodes:
-            for succ_node in succ_nodes:
-                step = node_to_step[succ_node]
-                if step > max_step:
-                    max_step = step
-                    max_step_snode = succ_node
-            assert max_step_snode is not None
-        return max_step, max_step_snode

     # 1. for freeable input buffers
     for buf_name, input_buf in name_to_freeable_input_buf.items():
-        end_step = -1
-        if buf_name not in graph_outputs:
-            end_step, end_step_snode = _get_end_step_and_snode(input_buf)
-            assert end_step_snode is not None
-            buf_to_snode_last_use[input_buf] = end_step_snode
+        end_step = (
+            len(nodes) - 1
+            if buf_name in graph_outputs
+            else max(
+                node_to_step[succ_node] for succ_node in input_buf.mpi_buffer.succ_nodes
+            )
+        )
         buf_info_list.append(
             BufferInfo(
                 input_buf,
@@ -364,17 +342,17 @@ def compute_memory_timeline(
             # note: it is possible for a non-graph-output sched_buf to have no succ_nodes and
             # to be only used by its defining op (e.g., due to fusion when all consumers of
             # the buffer are fused with its defining op). In such cases, end_step is step.
-            buf_name = sched_buf.get_name()
-            end_step = -1
-            if buf_name not in graph_outputs:
-                end_step, end_step_snode = _get_end_step_and_snode(sched_buf)
-                if end_step == -1:
-                    end_step = step
-                    buf_to_snode_last_use[sched_buf] = node
-                else:
-                    assert end_step_snode is not None
-                    buf_to_snode_last_use[sched_buf] = end_step_snode
+            end_step = (
+                len(nodes) - 1
+                if sched_buf.get_name() in graph_outputs
+                else max(
+                    [
+                        node_to_step[succ_node]
+                        for succ_node in sched_buf.mpi_buffer.succ_nodes
+                    ],
+                    default=step,
+                )
+            )
             buf_info_list.append(
                 BufferInfo(
                     sched_buf,
@@ -385,7 +363,7 @@ def compute_memory_timeline(
                 )
             )

-    return buf_info_list, node_to_step, buf_to_snode_last_use
+    return buf_info_list, node_to_step


 def estimate_peak_memory(
@@ -395,84 +373,35 @@ def estimate_peak_memory(
 ) -> tuple[int, list[int]]:
     """
     Given a list of nodes in their execution order, estimate the peak memory, by
-    keeping track of the liveness of SchedulerBuffers and FreeableInputBuffers.
+    keeping track of the liveliness of SchedulerBuffers and FreeableInputBuffers.

     Returns:
         int: peak memory
         List[int]: memory usage at each node (or each step).
     """
-    # Use estimate_peak_memory_allocfree to keep one impl.
-    peak_memory, snodes_curr_memory, snodes_allocfree, buf_to_snode_last_use = (
-        estimate_peak_memory_allocfree(nodes, name_to_freeable_input_buf, graph_outputs)
-    )
-    return peak_memory, [(curr_mem[0] + curr_mem[1]) for curr_mem in snodes_curr_memory]
-
-
-@dataclasses.dataclass
-class SNodeMemory:
-    size_alloc: int
-    size_free: int
-
-
-def estimate_peak_memory_allocfree(
-    nodes: list[BaseSchedulerNode],
-    name_to_freeable_input_buf: dict[str, FreeableInputBuffer],
-    graph_outputs: OrderedSet[str],
-) -> tuple[
-    int,
-    list[tuple[int, int]],
-    dict[BaseSchedulerNode, SNodeMemory],
-    dict[Union[FreeableInputBuffer, SchedulerBuffer], BaseSchedulerNode],
-]:
-    """
-    Alternative version of estimate_peak_memory, that respects the fact,
-    that every SchedulerNode has multiple phases:
-    1. alloc ( outputs )
-    2. run_kernel
-    3. dealloc last_use buffers
-    estimate_peak_memory collapses memory into one value: size_alloc - size_free
-    While peak memory happens after alloc.
-    Duplicating the code to not migrate all callsites at once,
-    In future usages of estimate_peak_memory will migrate to this version.
-    """
-    buf_info_list, _, buf_to_snode_last_use = compute_memory_timeline(
+    buf_info_list, _ = compute_memory_timeline(
         nodes, name_to_freeable_input_buf, graph_outputs
     )

     # incremental memory changes at each step
-    step_idx_allocfree = [SNodeMemory(0, 0) for _ in range(len(nodes))]
+    memory = [0 for _ in range(len(nodes) + 1)]

     # for each buffer, update memory when created and when freed
     for buf_info in buf_info_list:
-        step_idx_allocfree[buf_info.start_step].size_alloc += buf_info.size_alloc
-        if buf_info.end_step != -1:
-            step_idx_allocfree[buf_info.end_step].size_free += buf_info.size_free
-
-    snodes_allocfree = {}
-    for i, node in enumerate(nodes):
-        snodes_allocfree[node] = step_idx_allocfree[i]
+        memory[buf_info.start_step] += buf_info.size_alloc
+        memory[buf_info.end_step + 1] -= buf_info.size_free

+    # get peak memory by compute the cumulative memories
     max_memory = 0
     cur_memory = 0
-    snodes_curr_memory = []
-    for t in range(len(nodes)):
-        alloc = step_idx_allocfree[t].size_alloc
-        free = step_idx_allocfree[t].size_free
-        cur_memory += alloc
-        post_alloc = cur_memory
+    memories_at_nodes = []
+    for t in range(len(nodes) + 1):
+        cur_memory += memory[t]
+        memories_at_nodes.append(cur_memory)
         max_memory = max(max_memory, cur_memory)
-        cur_memory -= free
-        post_free = cur_memory
-        snodes_curr_memory.append((post_alloc, post_free))

-    return (
-        max_memory,
-        snodes_curr_memory,
-        snodes_allocfree,
-        buf_to_snode_last_use,
-    )
+    return (max_memory, memories_at_nodes)


 def topological_sort_lpmf(
@@ -488,7 +417,7 @@ def topological_sort_lpmf(
     Buffer memory optimization for video codec application modeled in Simulink
     https://www.cs.york.ac.uk/rts/docs/DAC-1964-2006/PAPERS/2006/DAC06/PDFFILES/P0689.PDF

-    The algorithm maintains the max memory so far.
+    The algorithm maintain the max memory so far.
     At every iteration, for each scheduleable node, it computes:
     - how much memory needs to be allocated for the output buffers of this node;
     - how much memory can be freed as a result of executing this node.
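The restored `estimate_peak_memory` is a difference-array scan: each buffer adds its allocation size at its first step and subtracts its free size just after its last-use step; a running sum then yields the memory level at every step and the peak. A toy walkthrough with invented buffer sizes and steps (`size_alloc == size_free` here for brevity, unlike the real `BufferInfo`):

```python
# Buffers as (start_step, end_step, size) -- invented numbers.
buffers = [(0, 2, 10), (1, 1, 8), (2, 3, 4)]
num_nodes = 4

# Difference array: +size when a buffer is first computed,
# -size just *after* its last-use step.
memory = [0] * (num_nodes + 1)
for start, end, size in buffers:
    memory[start] += size
    memory[end + 1] -= size

# Running sum gives the memory level at each step; track the peak.
cur = peak = 0
levels = []
for t in range(num_nodes + 1):
    cur += memory[t]
    levels.append(cur)
    peak = max(peak, cur)

print(levels)  # [10, 18, 14, 4, 0]
print(peak)    # 18
```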

diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py
--- a/torch/_inductor/scheduler.py
+++ b/torch/_inductor/scheduler.py

@@ -2160,12 +2160,6 @@ class Scheduler:
                 OrderedSet(V.graph.get_output_names()),
             )
         if config.reorder_for_compute_comm_overlap:
-            if not config.reorder_for_peak_memory:
-                from .memory import assign_memory_planning_info_for_scheduler_buffers
-
-                assign_memory_planning_info_for_scheduler_buffers(
-                    self.nodes, self.name_to_buf
-                )
             from torch._logging import trace_structured

             trace_structured(
@@ -2562,7 +2556,7 @@ class Scheduler:
         )
         graph_outputs: OrderedSet[str] = OrderedSet(V.graph.get_output_names())
-        buf_info_list, _, _ = compute_memory_timeline(
+        buf_info_list, _ = compute_memory_timeline(
             self.nodes,
             name_to_freeable_input_buf,
             graph_outputs,