Revert "[inductor] Estimate peak memory allocfree and applying to reordering collectives (#160113)"

This reverts commit 517d38d3406abbba35d0694bff259a698cad3ec9.

Reverted https://github.com/pytorch/pytorch/pull/160113 on behalf of https://github.com/IvanKobzarev due to Segment tree starting to fail on trunk even though ciflows/trunk passed on the PR ([comment](https://github.com/pytorch/pytorch/pull/160113#issuecomment-3211286092))
PyTorch MergeBot
2025-08-21 16:22:44 +00:00
parent 517d38d340
commit 7006fd0c88
5 changed files with 190 additions and 741 deletions


@@ -4,7 +4,6 @@ from __future__ import annotations
import heapq
import importlib
import itertools
import logging
import operator
import sys
@@ -24,15 +23,8 @@ from .dependencies import WeakDep
if TYPE_CHECKING:
from .ir import IRNode, Operation
from .scheduler import SchedulerBuffer
from .memory import (
estimate_peak_memory,
estimate_peak_memory_allocfree,
FreeableInputBuffer,
get_freeable_input_buf,
SNodeMemory,
)
from .memory import estimate_peak_memory, FreeableInputBuffer, get_freeable_input_buf
from .utils import (
contains_collective,
contains_wait,
@@ -196,46 +188,6 @@ def _is_fake_dep(d):
return isinstance(d, WeakDep) and d.is_fake
def _group_names(gns: list[BaseSchedulerNode]) -> str:
return "~".join([gn.get_name() for gn in gns])
def _initialize_memory_tracking(snodes, graph_inputs, graph_outputs):
"""Initialize memory tracking data structures"""
name_to_freeable_input_buf = get_freeable_input_buf(snodes, graph_inputs)
peak_memory, snodes_curr_memory, snodes_allocfree, buf_to_snode_last_use = (
estimate_peak_memory_allocfree(
snodes, name_to_freeable_input_buf, graph_outputs
)
)
_curr_memory = dict(zip(snodes, snodes_curr_memory))
_curr_memory[None] = (0, 0)
return (
peak_memory,
_curr_memory,
snodes_allocfree,
buf_to_snode_last_use,
name_to_freeable_input_buf,
)
def _initialize_double_linked_list(
snodes: list[BaseSchedulerNode],
) -> tuple[
dict[BaseSchedulerNode, Optional[BaseSchedulerNode]],
dict[BaseSchedulerNode, Optional[BaseSchedulerNode]],
BaseSchedulerNode,
]:
"""Create double-linked list structure from snodes"""
_prev = {}
_next = {}
for i, snode in enumerate(snodes):
_prev[snode] = snodes[i - 1] if i > 0 else None
_next[snode] = snodes[i + 1] if i < len(snodes) - 1 else None
_head = snodes[0]
return _prev, _next, _head
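
For reference, the removed helper stores the schedule as two prev/next dicts so a node or a contiguous group can later be spliced past a neighbor in O(1), without rebuilding the list. A minimal self-contained sketch of the same pattern, with plain strings standing in for scheduler nodes (names are illustrative, not from this diff):

# Sketch: dict-based doubly-linked list with an O(1) adjacent swap.
def build_links(items):
    prev, nxt = {}, {}
    for i, it in enumerate(items):
        prev[it] = items[i - 1] if i > 0 else None
        nxt[it] = items[i + 1] if i < len(items) - 1 else None
    return prev, nxt, items[0]

def swap_adjacent(prev, nxt, head, left, right):
    # `left` immediately precedes `right`; afterwards `right` precedes `left`.
    lp, rn = prev[left], nxt[right]
    if lp is not None:
        nxt[lp] = right
    prev[right] = lp
    nxt[right], prev[left] = left, right
    nxt[left] = rn
    if rn is not None:
        prev[rn] = left
    return right if head == left else head

prev, nxt, head = build_links(["a", "b", "c", "d"])
head = swap_adjacent(prev, nxt, head, "b", "c")
order, n = [], head
while n is not None:
    order.append(n)
    n = nxt[n]
assert order == ["a", "c", "b", "d"]
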
def _reorder_communication_preserving_peak_memory_internal(
snodes: list[BaseSchedulerNode],
) -> tuple[list[BaseSchedulerNode], dict[BaseSchedulerNode, ReorderInfo]]:
@@ -259,22 +211,20 @@ def _reorder_communication_preserving_peak_memory_internal(
# heuristic to avoid degenerating to quadratic time
graph_inputs: OrderedSet[str] = OrderedSet(V.graph.graph_inputs.keys())
graph_outputs: OrderedSet[str] = OrderedSet(V.graph.get_output_names())
(
peak_memory,
_curr_memory,
snodes_allocfree,
buf_to_snode_last_use,
name_to_freeable_input_buf,
) = _initialize_memory_tracking(snodes, graph_inputs, graph_outputs)
runtimes: dict[BaseSchedulerNode, float] = {
snode: estimate_op_runtime(snode) for snode in snodes
}
name_to_freeable_input_buf: dict[str, FreeableInputBuffer] = get_freeable_input_buf(
snodes, graph_inputs
)
peak_memory, curr_memory = estimate_peak_memory(
snodes, name_to_freeable_input_buf, graph_outputs
)
runtimes = {snode: estimate_op_runtime(snode) for snode in snodes}
_curr_memory = dict(zip(snodes, curr_memory))
_curr_memory[None] = 0 # type: ignore[index]
# debug stats
stats: dict[BaseSchedulerNode, ReorderInfo] = {}
def exposed_communication_time(
collective_snode: BaseSchedulerNode, remaining_snodes: list[BaseSchedulerNode]
) -> float:
def exposed_communication_time(collective_snode, remaining_snodes):
# assumes a linear schedule and computes the overlap of the collective with the remaining nodes
comm_time = estimate_op_runtime(collective_snode)
compute_time = 0.0
@@ -286,7 +236,7 @@ def _reorder_communication_preserving_peak_memory_internal(
# we can ignore it. Otherwise, it's the end of the road for overlap opportunities
break
def accumulate_time(_snode: BaseSchedulerNode) -> None:
def accumulate_time(_snode):
nonlocal compute_time
compute_time += runtimes[_snode]
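
As an aside, exposed_communication_time assumes a linear schedule: it sums the compute that can run while the collective is in flight, and whatever the compute does not cover is "exposed". A hedged standalone version of that arithmetic (names and units are illustrative):

# Sketch: exposed latency of a collective under a linear schedule.
def exposed_time(comm_time_us, overlappable_compute_us):
    # Whatever part of the collective is not hidden by the compute
    # scheduled between it and its wait remains exposed.
    return max(0.0, comm_time_us - sum(overlappable_compute_us))

# A 500us collective followed by 350us of independent compute: 150us exposed.
assert exposed_time(500.0, [200.0, 150.0]) == 150.0
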
@@ -295,11 +245,18 @@ def _reorder_communication_preserving_peak_memory_internal(
total_moves = 0
_prev, _next, _head = _initialize_double_linked_list(snodes)
# Dicts to keep track of "next" and "previous" as double-linked structure during grouping
_prev: dict[Optional[BaseSchedulerNode], Optional[BaseSchedulerNode]] = {}
_next: dict[Optional[BaseSchedulerNode], Optional[BaseSchedulerNode]] = {}
for i, snode in enumerate(snodes):
_prev[snode] = snodes[i - 1] if i > 0 else None
_next[snode] = snodes[i + 1] if i < len(snodes) - 1 else None
_curr_memory = dict(zip(snodes, curr_memory))
_curr_memory[None] = 0 # type: ignore[index]
def _group_nodes(
head: Optional[BaseSchedulerNode], tail: Optional[BaseSchedulerNode]
) -> list[BaseSchedulerNode]:
_head = snodes[0]
def _group_nodes(head, tail):
ret = []
n = head
while True:
@@ -307,167 +264,37 @@ def _reorder_communication_preserving_peak_memory_internal(
ret.append(n)
if n == tail:
break
n = _next[n] # type: ignore[index]
n = _next[n]
return ret
def _perform_double_linked_list_swap(candidate, group_head, group_tail):
# swap (candidate, group_head...group_tail)
# Before:
# candidate_prev -0-> candidate -1-> group_head...group_tail -2-> group_tail_next
# After:
# candidate_prev -0-> group_head...group_tail -1-> candidate -2-> group_tail_next
# 0
candidate_prev = _prev[candidate]
if candidate_prev:
_next[candidate_prev] = group_head
_prev[group_head] = candidate_prev
def _group_names(head, tail):
ret = ""
for n in _group_nodes(head, tail):
if ret:
ret += "~"
ret += n.get_name()
return ret
# 2
group_tail_next = _next[group_tail]
if group_tail_next:
_prev[group_tail_next] = candidate
_next[candidate] = group_tail_next
# 1
_prev[candidate] = group_tail
_next[group_tail] = candidate
nonlocal _head
if _head == candidate:
_head = group_head
def _calculate_potential_peak_memory(
candidate, group_ns, group_n_to_bufs_after_swap_dealloc_by_candidate
):
# Cache memory calculations for the group nodes and candidate,
# so they can be applied after the swap without recalculation.
_post_alloc_update: dict[BaseSchedulerNode, int] = {}
potential_peak: int = 0
if not group_n_to_bufs_after_swap_dealloc_by_candidate:
# Not accounting for buffers' last-use changes
potential_peak = max(
group_peak_memory - candidate_delta_mem,
_curr_memory[group_tail][1]
- candidate_delta_mem
+ candidate_allocfree.size_alloc,
)
return potential_peak, _post_alloc_update
# If candidate ends up after the group, the starting memory level
# of the group nodes changes by -(candidate.size_alloc - candidate.size_free)
mem_after_reorder_delta: int = -candidate_delta_mem
for gn in gns:
gn_post_alloc_mem = _curr_memory[gn][0] + mem_after_reorder_delta
_post_alloc_update[gn] = gn_post_alloc_mem
potential_peak = max(potential_peak, gn_post_alloc_mem)
bufs = group_n_to_bufs_after_swap_dealloc_by_candidate.get(gn, None)
if bufs is not None:
for buf in bufs:
# Candidate will deallocate those buffers
mem_after_reorder_delta += buf.mpi_buffer.size_free
candidate_mem_post_alloc = (
_curr_memory[group_tail][1]
+ mem_after_reorder_delta
+ candidate_allocfree.size_alloc
)
_post_alloc_update[candidate] = candidate_mem_post_alloc
potential_peak = max(potential_peak, candidate_mem_post_alloc)
return potential_peak, _post_alloc_update
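
The early-return branch above can be checked by hand: moving candidate (net delta = size_alloc - size_free) from before the group to after it lowers every group node's post-alloc level by that delta, and candidate's new post-alloc level is the group tail's post-free level, minus the delta, plus candidate's size_alloc. A toy sketch under that same assumption (no buffer changes its last use; names are illustrative):

# Sketch: potential peak when candidate is moved behind the group.
def potential_peak_simple(group_post_alloc_levels, group_tail_post_free,
                          cand_size_alloc, cand_size_free):
    delta = cand_size_alloc - cand_size_free
    return max(
        max(group_post_alloc_levels) - delta,
        group_tail_post_free - delta + cand_size_alloc,
    )

# Group peaks at 160 post-alloc and ends at 100 post-free;
# candidate allocates 50 and frees 10, so the new peak is 120.
assert potential_peak_simple([120, 160], 100, 50, 10) == 120
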
def _update_memory_tracking_after_swap(
candidate,
gns,
group_n_to_bufs_after_swap_dealloc_by_candidate,
_post_alloc_update,
):
if not group_n_to_bufs_after_swap_dealloc_by_candidate:
for gn in gns:
cm = _curr_memory[gn]
_curr_memory[gn] = (
cm[0] - candidate_delta_mem,
cm[1] - candidate_delta_mem,
)
_candidate_post_alloc_mem = (
_curr_memory[group_tail][1] + candidate_allocfree.size_alloc
)
_candidate_post_free_mem = (
_candidate_post_alloc_mem - candidate_allocfree.size_free
)
_curr_memory[candidate] = (
_candidate_post_alloc_mem,
_candidate_post_free_mem,
)
return
# Candidate becomes last use of some bufs
for (
gn,
bufs,
) in group_n_to_bufs_after_swap_dealloc_by_candidate.items():
for buf in bufs:
buf_to_snode_last_use[buf] = candidate
size_free_to_move_to_candidate_sum: int = 0
for n in gns:
_gn_post_alloc_mem: int = _post_alloc_update[n]
size_free_to_move_to_candidate: int = sum(
buf.mpi_buffer.size_free
for buf in group_n_to_bufs_after_swap_dealloc_by_candidate[n]
)
size_free_to_move_to_candidate_sum += size_free_to_move_to_candidate
# group node does not deallocate this after swap
snodes_allocfree[n].size_free -= size_free_to_move_to_candidate
gn_post_free_mem: int = _gn_post_alloc_mem - snodes_allocfree[n].size_free
_curr_memory[n] = (_gn_post_alloc_mem, gn_post_free_mem)
_candidate_post_alloc_mem = _post_alloc_update[candidate]
snodes_allocfree[candidate].size_free += size_free_to_move_to_candidate_sum
candidate_post_free_mem = (
_candidate_post_alloc_mem - snodes_allocfree[candidate].size_free
)
_curr_memory[candidate] = (
_candidate_post_alloc_mem,
candidate_post_free_mem,
)
debug_num_collectives_to_reorder: Optional[int] = (
config.reorder_iterative_debug_limit_to_reorder
)
num_processed_collectives: int = 0
curr = _head
debug_iterative_memory_recompute = config.reorder_iterative_debug_memory_recompute
iterative_recompute_error = False
while _next[curr] is not None:
if iterative_recompute_error:
break
if contains_collective(curr):
if debug_num_collectives_to_reorder is not None and (
num_processed_collectives >= debug_num_collectives_to_reorder
):
break
num_processed_collectives += 1
info = stats[curr] = ReorderInfo()
info.initial_exposed = info.final_exposed = exposed_communication_time(
curr, _group_nodes(_next[curr], None)
reorder_info = stats[curr] = ReorderInfo()
reorder_info.initial_exposed = reorder_info.final_exposed = (
exposed_communication_time(curr, _group_nodes(_next[curr], None))
)
candidate = _prev[curr]
group_head = curr
group_tail = curr
group_peak_memory = _curr_memory[curr][0] # post_alloc memory
group_peak_memory = _curr_memory[curr]
while candidate is not None:
if contains_collective(candidate):
info.limiting_factor = "collective ordering"
reorder_info.limiting_factor = "collective ordering"
break
gns: list[BaseSchedulerNode] = _group_nodes(group_head, group_tail)
group = GroupedSchedulerNode(
curr.scheduler,
gns,
_group_nodes(group_head, group_tail),
temp_grouping=True,
)
@@ -487,9 +314,7 @@ def _reorder_communication_preserving_peak_memory_internal(
if data_dep is not None:
def is_groupable(
candidate: BaseSchedulerNode,
) -> tuple[bool, Optional[str]]:
def is_groupable(candidate):
# preserve ordering
if contains_collective(candidate):
return False, "contains_collective"
@@ -498,106 +323,73 @@ def _reorder_communication_preserving_peak_memory_internal(
return False, "contains_gemm_like"
return True, None
is_groupable_result, grouping_reason = is_groupable(candidate)
if is_groupable_result:
is_grp, grp_reason = is_groupable(candidate)
if is_grp:
group_head = candidate
group_peak_memory = max(
group_peak_memory, _curr_memory[candidate][0]
group_peak_memory, _curr_memory[candidate]
)
info.grouped += 1
info.grouped_info = _group_names(gns)
reorder_info.grouped += 1
reorder_info.grouped_info = _group_names(group_head, group_tail)
candidate = _prev[candidate]
continue
else:
msg = (
f"data dependency {data_dep}(dep_names:{list(data_deps.keys())})"
f"\n candidate:{candidate.get_name()}(outs:{[candidate.get_buffer_names()]})"
f"dep on {_group_names(gns)}"
f"\n non_group_reason:{grouping_reason}"
f"\n candidate:{candidate.get_name()}(os:{[candidate.get_buffer_names()]})"
f"dep on {_group_names(group_head, group_tail)}"
f"\n non_group_reason:{grp_reason}"
)
info.limiting_factor = msg
reorder_info.limiting_factor = msg
break
candidate_allocfree: SNodeMemory = snodes_allocfree[candidate]
candidate_delta_mem: int = (
candidate_allocfree.size_alloc - candidate_allocfree.size_free
)
# candidate and one of the group nodes are successors of the same buffer,
# and the last use of the buffer happens in the group nodes.
# This last use deallocates it.
# If we swap [candidate [group]] to [[group] candidate],
# candidate becomes the last use
# and deallocates this buffer instead of the group node.
# We need to update size_free accordingly for the group node and candidate,
# and recalculate post_alloc and post_free for them.
#
# A buf that changes its last-use snode
# will, after the swap, be deallocated only by candidate,
# while before it was deallocated by a group node.
group_n_to_bufs_after_swap_dealloc_by_candidate: dict[
BaseSchedulerNode, list[Union[FreeableInputBuffer, Any]]
] = defaultdict(list)
for (
buf,
snode_last_use,
) in buf_to_snode_last_use.items():
succ_nodes = buf.mpi_buffer.succ_nodes
if candidate not in succ_nodes:
continue
if not any(gn == snode_last_use for gn in gns):
continue
group_n_to_bufs_after_swap_dealloc_by_candidate[
snode_last_use
].append(buf)
potential_peak, _post_alloc_update = _calculate_potential_peak_memory(
candidate, gns, group_n_to_bufs_after_swap_dealloc_by_candidate
delta_memory_candidate = (
_curr_memory[candidate] - _curr_memory[_prev[candidate]] # type: ignore[index]
)
if potential_peak > peak_memory:
info.limiting_factor = (
f"peak memory new:{potential_peak} vs base:{peak_memory}"
)
if group_peak_memory - delta_memory_candidate > peak_memory:
reorder_info.limiting_factor = "peak memory"
break
info.moves += 1
reorder_info.moves += 1
total_moves += 1
_perform_double_linked_list_swap(candidate, group_head, group_tail)
mem_deltas = {}
for n in [candidate, *_group_nodes(group_head, group_tail)]:
mem_deltas[n] = _curr_memory[n] - _curr_memory[_prev[n]] # type: ignore[index]
# swap (candidate, group_head...group_tail)
# Before:
# candidate_prev -0-> candidate -1-> group_head...group_tail -2-> group_tail_next
# After:
# candidate_prev -0-> group_head...group_tail -1-> candidate -2-> group_tail_next
# 0
candidate_prev = _prev[candidate]
if candidate_prev:
_next[candidate_prev] = group_head
_prev[group_head] = candidate_prev
info.final_exposed = exposed_communication_time(
# 2
group_tail_next = _next[group_tail]
if group_tail_next:
_prev[group_tail_next] = candidate
_next[candidate] = group_tail_next
# 1
_prev[candidate] = group_tail
_next[group_tail] = candidate
if _head == candidate:
_head = group_head
reorder_info.final_exposed = exposed_communication_time(
curr, _group_nodes(_next[curr], None)
)
_update_memory_tracking_after_swap(
candidate,
gns,
group_n_to_bufs_after_swap_dealloc_by_candidate,
_post_alloc_update,
)
if debug_iterative_memory_recompute:
# Compare iteratively recomputed memory data
# with full run of estimate_peak_memory
from .comms_debug import _debug_iterative_memory_recompute
iterative_recompute_error = _debug_iterative_memory_recompute(
candidate,
gns,
_group_names(gns),
_group_nodes(_head, None),
name_to_freeable_input_buf,
graph_outputs,
peak_memory,
_curr_memory,
snodes_allocfree,
"reorder_communication_preserving_peak_memory",
group_n_to_bufs_after_swap_dealloc_by_candidate,
# Recompute curr_memory
_prev_curr_memory = _curr_memory[_prev[group_head]] # type: ignore[index]
for n in _group_nodes(group_head, candidate):
_curr_memory[n] = _prev_curr_memory = (
_prev_curr_memory + mem_deltas[n]
)
if iterative_recompute_error:
break
candidate = _prev[group_head]
curr = _next[curr] # type: ignore[assignment]
@@ -623,15 +415,15 @@ def _reorder_communication_preserving_peak_memory_internal(
rows = [
[
node_summary(snode),
node_info.initial_exposed,
node_info.final_exposed,
node_info.improvement,
node_info.limiting_factor,
node_info.moves,
node_info.grouped,
node_info.grouped_info,
node_reorder_info.initial_exposed,
node_reorder_info.final_exposed,
node_reorder_info.improvement,
node_reorder_info.limiting_factor,
node_reorder_info.moves,
node_reorder_info.grouped,
node_reorder_info.grouped_info,
]
for snode, node_info in node_stats.items()
for snode, node_reorder_info in node_stats.items()
]
if importlib.util.find_spec("tabulate"):
from tabulate import tabulate
@@ -649,7 +441,7 @@ def _reorder_communication_preserving_peak_memory_internal(
new_snodes = _group_nodes(_head, None)
assert len(new_snodes) == original_snodes_num
new_peak_memory, _, _, _ = estimate_peak_memory_allocfree(
new_peak_memory, curr_memory = estimate_peak_memory(
new_snodes, name_to_freeable_input_buf, graph_outputs
)
reorder_log_str += f"\n peak_memory_before:{peak_memory}"
@@ -865,21 +657,24 @@ def _sink_waits_iterative_internal(
return snodes, {}
graph_inputs: OrderedSet[str] = OrderedSet(V.graph.graph_inputs.keys())
graph_outputs: OrderedSet[str] = OrderedSet(V.graph.get_output_names())
(
peak_memory,
_curr_memory,
snodes_allocfree,
buf_to_snode_last_use,
name_to_freeable_input_buf,
) = _initialize_memory_tracking(snodes, graph_inputs, graph_outputs)
_prev, _next, _head = _initialize_double_linked_list(snodes)
name_to_freeable_input_buf: dict[str, FreeableInputBuffer] = get_freeable_input_buf(
snodes, graph_inputs
)
peak_memory, curr_memory = estimate_peak_memory(
snodes, name_to_freeable_input_buf, graph_outputs
)
stats: dict[BaseSchedulerNode, SinkWaitInfo] = {}
_prev: dict[Optional[BaseSchedulerNode], Optional[BaseSchedulerNode]] = {}
_next: dict[Optional[BaseSchedulerNode], Optional[BaseSchedulerNode]] = {}
_head = snodes[0]
for i, snode in enumerate(snodes):
_prev[snode] = snodes[i - 1] if i > 0 else None
_next[snode] = snodes[i + 1] if i < len(snodes) - 1 else None
_curr_memory = dict(zip(snodes, curr_memory))
_curr_memory[None] = 0 # type: ignore[index]
def _group_nodes(
head: Optional[BaseSchedulerNode], tail: Optional[BaseSchedulerNode]
) -> list[BaseSchedulerNode]:
def _group_nodes(head, tail):
ret = []
n = head
while True:
@@ -887,125 +682,21 @@ def _sink_waits_iterative_internal(
ret.append(n)
if n == tail:
break
n = _next[n] # type: ignore[index]
n = _next[n]
return ret
def _calculate_potential_peak_memory(
candidate, group_ns, group_n_to_bufs_after_swap_dealloc_instead_of_candidate
):
pre_group_mem = (
_curr_memory[group_head][0] - snodes_allocfree[group_head].size_alloc
)
# Stash memory tracing updates to not recompute them after swap
_post_alloc_update: dict[BaseSchedulerNode, int] = {}
_size_free_delta_update: dict[BaseSchedulerNode, int] = {}
potential_peak = 0
if not group_n_to_bufs_after_swap_dealloc_instead_of_candidate:
# Not accounting for buffer liveness changes
potential_peak = max(
group_peak_memory + candidate_delta_mem,
pre_group_mem + candidate_allocfree.size_alloc,
)
return potential_peak, _post_alloc_update, _size_free_delta_update
candidate_post_alloc = pre_group_mem + candidate_allocfree.size_alloc
_post_alloc_update[candidate] = candidate_post_alloc
potential_peak = candidate_post_alloc
candidate_size_free_to_move = sum(
buf.mpi_buffer.size_free # type: ignore[attr-defined]
for buf in itertools.chain.from_iterable(
group_n_to_bufs_after_swap_dealloc_instead_of_candidate.values()
)
)
_size_free_delta_update[candidate] = -candidate_size_free_to_move
delta_mem = candidate_delta_mem + candidate_size_free_to_move
for gn in gns:
gn_post_alloc = _curr_memory[gn][0] + delta_mem
_post_alloc_update[gn] = gn_post_alloc
potential_peak = max(potential_peak, gn_post_alloc)
gn_size_free_to_add = 0
if gn in group_n_to_bufs_after_swap_dealloc_instead_of_candidate:
bufs = group_n_to_bufs_after_swap_dealloc_instead_of_candidate[gn]
for buf in bufs:
gn_size_free_to_add += buf.mpi_buffer.size_free
_size_free_delta_update[gn] = gn_size_free_to_add
delta_mem -= gn_size_free_to_add
return potential_peak, _post_alloc_update, _size_free_delta_update
def _perform_double_linked_list_swap(candidate, group_head, group_tail):
# group_head_prev -0-> candidate -1-> group_head...group_tail -2-> candidate_next
# 0:
group_head_prev = _prev[group_head]
if group_head_prev:
_next[group_head_prev] = candidate
_prev[candidate] = group_head_prev
# 2:
candidate_next = _next[candidate]
if candidate_next:
_prev[candidate_next] = group_tail
_next[group_tail] = candidate_next
# 1:
_prev[group_head] = candidate
_next[candidate] = group_head
nonlocal _head
if group_head == _head:
_head = candidate
def _update_memory_tracking_after_swap(
candidate,
gns,
group_n_to_bufs_after_swap_dealloc_instead_of_candidate,
_post_alloc_update,
_size_free_delta_update,
):
group_head = gns[0]
pre_group_mem = (
_curr_memory[group_head][0] - snodes_allocfree[group_head].size_alloc
)
if not group_n_to_bufs_after_swap_dealloc_instead_of_candidate:
candidate_post_alloc = pre_group_mem + candidate_allocfree.size_alloc
_curr_memory[candidate] = (
candidate_post_alloc,
candidate_post_alloc - candidate_allocfree.size_free,
)
for gn in gns:
cm = _curr_memory[gn]
_curr_memory[gn] = (
cm[0] + candidate_delta_mem,
cm[1] + candidate_delta_mem,
)
return
for n in [candidate, *gns]:
post_alloc = _post_alloc_update[n]
snodes_allocfree[n].size_free += _size_free_delta_update[n]
_curr_memory[n] = (
post_alloc,
post_alloc - snodes_allocfree[n].size_free,
)
def _group_names(head, tail):
ret = ""
for n in _group_nodes(head, tail):
if ret:
ret += "~"
ret += n.get_name()
return ret
curr = snodes[-1]
processed_waits = OrderedSet() # type: ignore[var-annotated]
debug_iterative_memory_recompute = config.reorder_iterative_debug_memory_recompute
debug_num_sink_waits_to_reorder: Optional[int] = (
config.sink_waits_iterative_debug_limit_to_sink
)
iterative_recompute_error = False
while _prev[curr] is not None:
if iterative_recompute_error:
break
if (
debug_num_sink_waits_to_reorder is not None
and len(processed_waits) >= debug_num_sink_waits_to_reorder
):
break
if contains_wait(curr) and curr not in processed_waits:
processed_waits.add(curr)
info = stats[curr] = SinkWaitInfo()
@@ -1013,14 +704,11 @@ def _sink_waits_iterative_internal(
wait_snode = curr
group_head = curr
group_tail = curr
group_peak_memory = _curr_memory[curr][0]
group_peak_memory = _curr_memory[curr]
while candidate is not None:
if iterative_recompute_error:
break
gns: list[BaseSchedulerNode] = _group_nodes(group_head, group_tail)
group = GroupedSchedulerNode(
wait_snode.scheduler,
gns,
_group_nodes(group_head, group_tail),
temp_grouping=True,
)
@@ -1065,15 +753,15 @@ def _sink_waits_iterative_internal(
if is_grp:
group_tail = candidate
group_peak_memory = max(
group_peak_memory, _curr_memory[candidate][0]
group_peak_memory, _curr_memory[candidate]
)
info.grouped += 1
info.grouped_info = _group_names(gns)
info.grouped_info = _group_names(group_head, group_tail)
candidate = _next[candidate]
continue
elif (data_dep is None) and both_contain_comms:
info.limiting_factor = (
f"collective ordering {_group_names(gns)}"
f"collective ordering {_group_names(group_head, group_tail)}"
f" with candidate:{candidate.get_name()}"
)
break
@@ -1081,89 +769,49 @@ def _sink_waits_iterative_internal(
info.limiting_factor = (
f"data dependency {data_dep}(dep_names:{list(data_deps.keys())})"
f"\n candidate:{candidate.get_name()}(os:{[candidate.get_buffer_names()]})"
f"dep on {gns}"
f"dep on {_group_names(group_head, group_tail)}"
f"\n outs:{[o.get_name() for o in group_outs]}"
f"\n non_group_reason:{grp_reason}"
)
break
candidate_allocfree: SNodeMemory = snodes_allocfree[candidate]
candidate_delta_mem = (
candidate_allocfree.size_alloc - candidate_allocfree.size_free
candidate_delta_memory = (
_curr_memory[candidate] - _curr_memory[_prev[candidate]] # type: ignore[index]
)
# [group] candidate -> candidate [group]
# Check for buffers that have successors in the group while candidate is their last successor.
#
# A buf that changes its last-use snode
# was deallocated by candidate,
# but after the swap it will be deallocated by a group node.
group_n_to_bufs_after_swap_dealloc_instead_of_candidate: dict[
BaseSchedulerNode, list[Union[FreeableInputBuffer, SchedulerBuffer]]
] = defaultdict(list)
for (
buf,
snode_last_use,
) in buf_to_snode_last_use.items():
succ_nodes = buf.mpi_buffer.succ_nodes
if snode_last_use != candidate: # noqa: E711
continue
# candidate is last use of buf
last_succ_gn = None
for gn in gns:
if gn in succ_nodes:
last_succ_gn = gn
if last_succ_gn is None:
continue
# gn has successors of buf that, after a potential swap, will become
# the last use of buf and deallocate it instead of candidate
group_n_to_bufs_after_swap_dealloc_instead_of_candidate[
last_succ_gn
].append(buf)
potential_peak, _post_alloc_update, _size_free_delta_update = (
_calculate_potential_peak_memory(
candidate,
gns,
group_n_to_bufs_after_swap_dealloc_instead_of_candidate,
)
)
if potential_peak > peak_memory:
info.limiting_factor = (
f"peak memory new:{potential_peak} vs base:{peak_memory}"
)
if group_peak_memory + candidate_delta_memory > peak_memory:
info.limiting_factor = "peak_memory"
break
info.moves += 1
info.moves_info += f"+{candidate.get_name()}"
_perform_double_linked_list_swap(candidate, group_head, group_tail)
# group_head_prev -0-> candidate -1-> group_head...group_tail -2-> candidate_next
mem_deltas = {}
for n in [candidate, *_group_nodes(group_head, group_tail)]:
mem_deltas[n] = _curr_memory[n] - _curr_memory[_prev[n]] # type: ignore[index]
# 0:
group_head_prev = _prev[group_head]
if group_head_prev:
_next[group_head_prev] = candidate
_prev[candidate] = group_head_prev
_update_memory_tracking_after_swap(
candidate,
gns,
group_n_to_bufs_after_swap_dealloc_instead_of_candidate,
_post_alloc_update,
_size_free_delta_update,
)
# 2:
candidate_next = _next[candidate]
if candidate_next:
_prev[candidate_next] = group_tail
_next[group_tail] = candidate_next
if debug_iterative_memory_recompute:
from .comms_debug import _debug_iterative_memory_recompute
# 1:
_prev[group_head] = candidate
_next[candidate] = group_head
if group_head == _head:
_head = candidate
iterative_recompute_error = _debug_iterative_memory_recompute(
candidate,
gns,
_group_names(gns),
_group_nodes(_head, None),
name_to_freeable_input_buf,
graph_outputs,
peak_memory,
_curr_memory,
snodes_allocfree,
"sink_waits_iterative",
group_n_to_bufs_after_swap_dealloc_instead_of_candidate,
# Recompute curr_memory
_prev_curr_memory = _curr_memory[_prev[candidate]] # type: ignore[index]
for n in _group_nodes(candidate, group_tail):
_curr_memory[n] = _prev_curr_memory = (
_prev_curr_memory + mem_deltas[n]
)
if iterative_recompute_error:
break
candidate = _next[group_tail]
curr = _prev[curr] # type: ignore[assignment]
@@ -1202,11 +850,11 @@ def _sink_waits_iterative_internal(
overlap_log.info(log_str)
new_snodes = _group_nodes(_head, None)
assert len(new_snodes) == original_snodes_num
new_peak_memory, _, _, _ = estimate_peak_memory_allocfree(
new_peak_memory, curr_memory = estimate_peak_memory(
new_snodes, name_to_freeable_input_buf, graph_outputs
)
log_str += f"\n sink_waits_iterative peak_memory_before:{peak_memory}"
log_str += f"\n sink_waits_iterative peak_memory_after:{new_peak_memory}"
log_str += f"\n peak_memory_before:{peak_memory}"
log_str += f"\n peak_memory_after:{new_peak_memory}"
trace_structured(
"artifact",
metadata_fn=lambda: {


@@ -1,112 +0,0 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Union
from torch._logging import trace_structured
from .memory import estimate_peak_memory_allocfree
if TYPE_CHECKING:
from torch.utils._ordered_set import OrderedSet
from .memory import FreeableInputBuffer, SNodeMemory
from .scheduler import BaseSchedulerNode, SchedulerBuffer
def _debug_iterative_memory_recompute(
candidate: BaseSchedulerNode,
gns: list[BaseSchedulerNode],
group_names: str,
snodes: list[BaseSchedulerNode],
name_to_freeable_input_buf: dict[str, FreeableInputBuffer],
graph_outputs: OrderedSet[str],
peak_memory: int,
iter_curr_memory: dict[BaseSchedulerNode, tuple[int, int]],
snodes_allocfree: dict[BaseSchedulerNode, SNodeMemory],
tlparse_name: str,
gn_to_bufs_last_use: dict[
BaseSchedulerNode, list[Union[FreeableInputBuffer, SchedulerBuffer]]
],
) -> bool:
iterative_recompute_error = False
candidate_allocfree = snodes_allocfree[candidate]
est_peak_memory, snodes_curr_memory, snodes_allocfree, _ = (
estimate_peak_memory_allocfree(
snodes, name_to_freeable_input_buf, graph_outputs
)
)
est_curr_memory = dict(zip(snodes, snodes_curr_memory))
iter_cm = iter_curr_memory[candidate]
new_cm = est_curr_memory[candidate]
log = ""
if est_peak_memory > peak_memory:
log = "ITERATIVE PEAK DOES NOT MATCH"
iterative_recompute_error = True
if iter_cm != new_cm:
log = "ITERATIVE CURR MEMORY CANDIDATE DOES NOT MATCH"
iterative_recompute_error = True
for i, gn in enumerate(gns):
iter_gnm = iter_curr_memory[gn]
new_gnm = est_curr_memory[gn]
if iter_gnm != new_gnm:
log = f"ITERATIVE GN CURR MEMORY DOES NOT MATCH:{gn.get_name()}"
iterative_recompute_error = True
if iterative_recompute_error:
log += (
f"\nCANDIDATE:{candidate.get_name()}"
f"\nGROUP:{group_names}"
f"\nPEAK_MEMORY_BEFORE:{peak_memory}"
f"\nPEAK_MEMORY_AFTER_SWAP:{est_peak_memory}"
f"\nCANDIDATE:{candidate.debug_str()}"
f"\nCANDIDATE_ITER_CURR_MEMORY:{iter_cm}"
f"\nCANDIDATE_NEW__CURR_MEMORY:{new_cm}"
f"\nCANDIDATE_ITER_ALLOCFREE:{candidate_allocfree}"
f"\nCANDIDATE_NEW_ALLOCFREE:{snodes_allocfree[candidate]}"
)
peak_log = ""
for i, (pre, post) in enumerate(snodes_curr_memory):
if est_peak_memory == pre:
n = snodes[i]
peak_log = (
f"\nNEW_PEAK:{est_peak_memory}(BASE:{peak_memory})"
f" @ SNODE[{i}/{len(snodes)}]:{n.get_name()} {n.debug_str()}"
)
break
group_log = ""
for i, gn in enumerate(gns):
iter_gnm = iter_curr_memory[gn]
new_gnm = est_curr_memory[gn]
group_log += (
f"\nGROUP_NODE[{i}]:{gn.debug_str()}"
f"\nGROUP_NODE[{i}] ITER_GNM[{gn.get_name()}]:{iter_gnm}"
f"\nGROUP_NODE[{i}] ESTM_GNM[{gn.get_name()}]:{new_gnm}"
f"\nGROUP_NODE[{i}] ITER_allocfree:{snodes_allocfree[gn]}"
f"\nGROUP_NODE[{i}] ESTM_allocfree:{snodes_allocfree[gn]}"
)
log += peak_log
log += group_log
log += f"\nGN_TO_BUFS_LAST_USE:{gn_to_bufs_last_use}"
log += "\n\n".join(
[
(
f"\nSNODE[{i}]\n{n.debug_str()}"
f"\nITER_cur_mem:{iter_curr_memory[n]}"
f"\nESTM_cur_mem:{est_curr_memory[n]}"
f"\nITER_allocfree:{snodes_allocfree[n]}"
f"\nESTM_allocfree:{snodes_allocfree[n]}"
)
for i, n in enumerate(snodes)
]
)
tname = f"{tlparse_name}_ITERATIVE_RECOMPUTE_ERROR"
print(f"{tname}:\n{log}")
trace_structured(
"artifact",
metadata_fn=lambda: {
"name": tname,
"encoding": "string",
},
payload_fn=lambda: log,
)
return iterative_recompute_error
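
In short, the helper is a recompute-and-compare check: per-node memory maintained incrementally during reordering is validated against a from-scratch estimate, and any divergence aborts the pass. The skeleton of that pattern, with illustrative names:

# Sketch: validate incrementally maintained state against a full recompute.
def incremental_state_diverged(nodes, iter_curr_memory, recompute_fn):
    fresh = recompute_fn(nodes)
    # True means the incremental bookkeeping has drifted from the truth.
    return any(iter_curr_memory[n] != fresh[n] for n in nodes)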


@@ -389,16 +389,6 @@ reorder_prefetch_limit: Optional[int] = None
# enable operator reordering for peak memory optimization
reorder_for_peak_memory = True
reorder_iterative_debug_memory_recompute: bool = False
reorder_iterative_debug_limit_to_reorder: Optional[int] = (
None
if (env_str := os.getenv("PYTORCH_REORDER_COLLECTIVES_LIMIT")) is None
else int(env_str)
)
sink_waits_iterative_debug_limit_to_sink: Optional[int] = (
None if (env_str := os.getenv("PYTORCH_SINK_WAITS_LIMIT")) is None else int(env_str)
)
bucket_all_gathers_fx: Literal["none", "all", "only_fsdp"] = "none"
# By default torch._inductor.fx_passes.bucketing.bucket_size_determinator is used
bucket_all_gathers_fx_bucket_size_determinator: Optional[Callable[[int], int]] = None
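
Both removed debug limits read an optional integer from the environment with the same walrus-operator idiom; the equivalent long form is:

import os

# Long-hand equivalent of the removed walrus-based config line.
env_str = os.getenv("PYTORCH_REORDER_COLLECTIVES_LIMIT")
reorder_iterative_debug_limit_to_reorder = None if env_str is None else int(env_str)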


@@ -4,7 +4,7 @@ import collections
import dataclasses
import heapq
import logging
from typing import Callable, Optional, TYPE_CHECKING, TypedDict, Union
from typing import Callable, TYPE_CHECKING, TypedDict, Union
from torch._environment import is_fbcode
from torch._utils_internal import signpost_event
@@ -76,7 +76,7 @@ def get_freeable_input_buf(
Create and keep track of all input buffers that can be freed during the program
Returns:
A dictionary containing all freeable input buffers, keyed by their names.
A dictionary containing all freeable input buffers, keyed by their names.
"""
def _dep_size_hint(dep: Dep) -> int:
@@ -303,11 +303,7 @@ def compute_memory_timeline(
nodes: list[BaseSchedulerNode],
name_to_freeable_input_buf: dict[str, FreeableInputBuffer],
graph_outputs: OrderedSet[str],
) -> tuple[
list[BufferInfo],
dict[BaseSchedulerNode, int],
dict[Union[FreeableInputBuffer, SchedulerBuffer], BaseSchedulerNode],
]:
) -> tuple[list[BufferInfo], dict[BaseSchedulerNode, int]]:
"""
Compute buffer allocation and deallocation sizes and map their
lifetime to the node schedule
@@ -321,33 +317,15 @@ def compute_memory_timeline(
# get buffers' size and liveliness information
buf_info_list: list[BufferInfo] = []
buf_to_snode_last_use: dict[
Union[FreeableInputBuffer, SchedulerBuffer], BaseSchedulerNode
] = {}
def _get_end_step_and_snode(
buf: Union[FreeableInputBuffer, SchedulerBuffer],
) -> tuple[int, Optional[BaseSchedulerNode]]:
max_step: int = -1
max_step_snode: Optional[BaseSchedulerNode] = None
succ_nodes = buf.mpi_buffer.succ_nodes
if succ_nodes:
for succ_node in succ_nodes:
step = node_to_step[succ_node]
if step > max_step:
max_step = step
max_step_snode = succ_node
assert max_step_snode is not None
return max_step, max_step_snode
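
The removed helper simply picks the successor with the largest schedule step: that node is the buffer's last use, and hence the point where the buffer becomes freeable. The same idea in isolation (assumes at least one successor):

# Sketch: find the last-use step and node for a buffer.
def last_use(succ_nodes, node_to_step):
    return max(((node_to_step[n], n) for n in succ_nodes), key=lambda t: t[0])
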
# 1. for freeable input buffers
for buf_name, input_buf in name_to_freeable_input_buf.items():
end_step = -1
if buf_name not in graph_outputs:
end_step, end_step_snode = _get_end_step_and_snode(input_buf)
assert end_step_snode is not None
buf_to_snode_last_use[input_buf] = end_step_snode
end_step = (
len(nodes) - 1
if buf_name in graph_outputs
else max(
node_to_step[succ_node] for succ_node in input_buf.mpi_buffer.succ_nodes
)
)
buf_info_list.append(
BufferInfo(
input_buf,
@@ -364,17 +342,17 @@ def compute_memory_timeline(
# note: it is possible for a non-graph-output sched_buf to have no succ_nodes and
# to be only used by its defining op (e.g., due to fusion when all consumers of
# the buffer are fused with its defining op). In such cases, end_step is step.
buf_name = sched_buf.get_name()
end_step = -1
if buf_name not in graph_outputs:
end_step, end_step_snode = _get_end_step_and_snode(sched_buf)
if end_step == -1:
end_step = step
buf_to_snode_last_use[sched_buf] = node
else:
assert end_step_snode is not None
buf_to_snode_last_use[sched_buf] = end_step_snode
end_step = (
len(nodes) - 1
if sched_buf.get_name() in graph_outputs
else max(
[
node_to_step[succ_node]
for succ_node in sched_buf.mpi_buffer.succ_nodes
],
default=step,
)
)
buf_info_list.append(
BufferInfo(
sched_buf,
@@ -385,7 +363,7 @@ def compute_memory_timeline(
)
)
return buf_info_list, node_to_step, buf_to_snode_last_use
return buf_info_list, node_to_step
def estimate_peak_memory(
@@ -395,84 +373,35 @@ def estimate_peak_memory(
) -> tuple[int, list[int]]:
"""
Given a list of nodes in their execution order, estimate the peak memory, by
keeping track of the liveness of SchedulerBuffers and FreeableInputBuffers.
keeping track of the liveness of SchedulerBuffers and FreeableInputBuffers.
Returns:
int: peak memory
List[int]: memory usage at each node (or each step).
"""
# Use estimate_peak_memory_allocfree to keep one impl.
peak_memory, snodes_curr_memory, snodes_allocfree, buf_to_snode_last_use = (
estimate_peak_memory_allocfree(nodes, name_to_freeable_input_buf, graph_outputs)
)
return peak_memory, [(curr_mem[0] + curr_mem[1]) for curr_mem in snodes_curr_memory]
@dataclasses.dataclass
class SNodeMemory:
size_alloc: int
size_free: int
def estimate_peak_memory_allocfree(
nodes: list[BaseSchedulerNode],
name_to_freeable_input_buf: dict[str, FreeableInputBuffer],
graph_outputs: OrderedSet[str],
) -> tuple[
int,
list[tuple[int, int]],
dict[BaseSchedulerNode, SNodeMemory],
dict[Union[FreeableInputBuffer, SchedulerBuffer], BaseSchedulerNode],
]:
"""
Alternative version of estimate_peak_memory that respects the fact
that every SchedulerNode has multiple phases:
1. alloc (outputs)
2. run_kernel
3. dealloc last-use buffers
estimate_peak_memory collapses memory into one value per node (size_alloc - size_free),
while the peak actually occurs right after alloc.
The code is duplicated to avoid migrating all call sites at once;
future users of estimate_peak_memory will migrate to this version.
"""
buf_info_list, _, buf_to_snode_last_use = compute_memory_timeline(
buf_info_list, _ = compute_memory_timeline(
nodes, name_to_freeable_input_buf, graph_outputs
)
# incremental memory changes at each step
step_idx_allocfree = [SNodeMemory(0, 0) for _ in range(len(nodes))]
memory = [0 for _ in range(len(nodes) + 1)]
# for each buffer, update memory when created and when freed
for buf_info in buf_info_list:
step_idx_allocfree[buf_info.start_step].size_alloc += buf_info.size_alloc
if buf_info.end_step != -1:
step_idx_allocfree[buf_info.end_step].size_free += buf_info.size_free
snodes_allocfree = {}
for i, node in enumerate(nodes):
snodes_allocfree[node] = step_idx_allocfree[i]
memory[buf_info.start_step] += buf_info.size_alloc
memory[buf_info.end_step + 1] -= buf_info.size_free
# get peak memory by computing the cumulative memories
max_memory = 0
cur_memory = 0
snodes_curr_memory = []
for t in range(len(nodes)):
alloc = step_idx_allocfree[t].size_alloc
free = step_idx_allocfree[t].size_free
cur_memory += alloc
post_alloc = cur_memory
memories_at_nodes = []
for t in range(len(nodes) + 1):
cur_memory += memory[t]
memories_at_nodes.append(cur_memory)
max_memory = max(max_memory, cur_memory)
cur_memory -= free
post_free = cur_memory
snodes_curr_memory.append((post_alloc, post_free))
return (
max_memory,
snodes_curr_memory,
snodes_allocfree,
buf_to_snode_last_use,
)
return (max_memory, memories_at_nodes)
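
The restored estimate_peak_memory is a delta/prefix-sum scan: each buffer adds its size at its start step and subtracts it just after its last-use step, and the peak is the running maximum of the cumulative sum. A self-contained sketch of the technique (the buffer triples are made up for illustration):

# Sketch: peak memory via per-step deltas and a running sum.
def estimate_peak(num_steps, buffers):
    # buffers: (start_step, end_step, size); allocated at start_step,
    # freed right after end_step.
    delta = [0] * (num_steps + 1)
    for start, end, size in buffers:
        delta[start] += size
        delta[end + 1] -= size
    peak, cur, at_step = 0, 0, []
    for d in delta:
        cur += d
        at_step.append(cur)
        peak = max(peak, cur)
    return peak, at_step

# Two buffers overlap at step 1, so the peak is 10 + 20 = 30.
assert estimate_peak(3, [(0, 1, 10), (1, 2, 20)])[0] == 30
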
def topological_sort_lpmf(
@@ -488,7 +417,7 @@ def topological_sort_lpmf(
Buffer memory optimization for video codec application modeled in Simulink
https://www.cs.york.ac.uk/rts/docs/DAC-1964-2006/PAPERS/2006/DAC06/PDFFILES/P0689.PDF
The algorithm maintains the max memory so far.
The algorithm maintains the max memory so far.
At every iteration, for each schedulable node, it computes:
- how much memory needs to be allocated for the output buffers of this node;
- how much memory can be freed as a result of executing this node.
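
A toy version of that greedy rule, with made-up names and sizes, just to make the selection criterion concrete (the real topological_sort_lpmf also handles ties, watermarks, and fallback orders):

# Sketch: greedy "least peak memory first" scheduling on a small DAG.
def lpmf_order(succs, alloc, free):
    indeg = {n: 0 for n in succs}
    for outs in succs.values():
        for s in outs:
            indeg[s] += 1
    ready = {n for n, d in indeg.items() if d == 0}
    order, cur = [], 0
    while ready:
        # Among schedulable nodes, run the one leaving memory lowest.
        n = min(ready, key=lambda x: cur + alloc[x] - free[x])
        ready.remove(n)
        order.append(n)
        cur += alloc[n] - free[n]
        for s in succs[n]:
            indeg[s] -= 1
            if indeg[s] == 0:
                ready.add(s)
    return order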


@@ -2160,12 +2160,6 @@ class Scheduler:
OrderedSet(V.graph.get_output_names()),
)
if config.reorder_for_compute_comm_overlap:
if not config.reorder_for_peak_memory:
from .memory import assign_memory_planning_info_for_scheduler_buffers
assign_memory_planning_info_for_scheduler_buffers(
self.nodes, self.name_to_buf
)
from torch._logging import trace_structured
trace_structured(
@@ -2562,7 +2556,7 @@ class Scheduler:
)
graph_outputs: OrderedSet[str] = OrderedSet(V.graph.get_output_names())
buf_info_list, _, _ = compute_memory_timeline(
buf_info_list, _ = compute_memory_timeline(
self.nodes,
name_to_freeable_input_buf,
graph_outputs,