mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[inductor] Estimate peak memory allocfree and applying to reordering collectives (#160113)
1. Applying @eellison's idea from https://github.com/pytorch/pytorch/pull/146562#discussion_r2059363672 for `estimate_peak_memory`:
```
"""
Alternative version of estimate_peak_memory that respects the fact
that every SchedulerNode has multiple phases:
1. alloc (outputs)
2. run_kernel
3. dealloc last_use buffers
estimate_peak_memory collapses memory into one value (size_alloc - size_free),
while peak memory happens after alloc.
Duplicating the code to avoid migrating all callsites at once;
in the future, usages of estimate_peak_memory will migrate to this version.
"""
```
   - Applying this in the `reorder_communication_preserving_peak_memory` pass.
2. Buffers can change their deallocation point during reordering, when the candidate and the group to swap are both users of the same input buffer and the group contains its last-use snode.
   - Addressing this by tracking the last-use snode for each buffer and recomputing current memory to respect the change in size_free (after reordering, the group node is no longer the last user of the buffer, so its size_free -= buffer_size, while the candidate becomes the last user and its size_free += buffer_size).
3. Adding the env var `PYTORCH_REORDER_COLLECTIVES_LIMIT`, which limits the number of collectives to reorder, for ablation.
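To make the phase-aware accounting in point 1 concrete, here is a minimal runnable sketch. Only `SNodeMemory` and the (post_alloc, post_free) recurrence mirror the diff below; the `peak_allocfree` helper and the example sizes are hypothetical.

```python
from dataclasses import dataclass


@dataclass
class SNodeMemory:
    size_alloc: int  # bytes allocated for the node's outputs (phase 1)
    size_free: int   # bytes freed after the kernel runs (phase 3)


def peak_allocfree(steps: list[SNodeMemory]) -> tuple[int, list[tuple[int, int]]]:
    # Peak is sampled at each node's post-alloc point, not from the
    # collapsed per-node delta (size_alloc - size_free).
    cur, peak, curr_memory = 0, 0, []
    for s in steps:
        cur += s.size_alloc        # phase 1: allocate outputs
        post_alloc = cur           # the peak can only occur here
        peak = max(peak, post_alloc)
        cur -= s.size_free         # phase 3: dealloc last-use buffers
        curr_memory.append((post_alloc, cur))
    return peak, curr_memory


# A node with delta 0 (allocates 8, frees 8) still raises the peak to 12:
peak, mem = peak_allocfree([SNodeMemory(4, 0), SNodeMemory(8, 8), SNodeMemory(0, 4)])
assert peak == 12 and mem == [(4, 4), (12, 4), (4, 0)]
```

The collapsed estimate would report a peak of 4 here, missing the transient 12 reached right after the second node's allocation.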
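And a small sketch of point 2, the last-use hand-off on a swap. The buffer size and nodes are made up; only the size_free adjustment mirrors the `_update_memory_tracking_after_swap` helpers in the diff.

```python
# Swapping [candidate, group] -> [group, candidate]: a buffer whose last use
# was a group node is now deallocated by the candidate instead, so its
# size_free accounting moves from the group node to the candidate.
buf_size = 16
group_node = SNodeMemory(size_alloc=0, size_free=buf_size)  # was the last user
candidate = SNodeMemory(size_alloc=4, size_free=0)

group_node.size_free -= buf_size  # group node no longer frees the buffer
candidate.size_free += buf_size   # candidate became the last user
assert (group_node.size_free, candidate.size_free) == (0, 16)
```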
What is after this PR: iterative recomputation of memory estimations matches the full memory estimations. Active memory does not regress much, but reserved memory regresses significantly. Investigation and a fix for "reserved" memory will follow in subsequent PRs.

BASELINE (bucketing AG and RS): active: 32GiB, reserved: 34GiB
```
[rank0]:[titan] 2025-08-11 11:28:36,798 - root - INFO - step: 1 loss: 12.2722 grad_norm: 4.2192 active_memory: 24.66GiB(25.96%) reserved_memory: 25.38GiB(26.72%) tps: 99 tflops: 5.71 mfu: 0.58%
[rank0]:[titan] 2025-08-11 11:28:38,640 - root - INFO - step: 2 loss: 13.1738 grad_norm: 50.5566 active_memory: 32.14GiB(33.83%) reserved_memory: 34.21GiB(36.01%) tps: 4,448 tflops: 257.63 mfu: 26.05%
[rank0]:[titan] 2025-08-11 11:28:40,029 - root - INFO - step: 3 loss: 15.6866 grad_norm: 80.0862 active_memory: 32.14GiB(33.83%) reserved_memory: 34.21GiB(36.01%) tps: 5,900 tflops: 341.72 mfu: 34.55%
[rank0]:[titan] 2025-08-11 11:28:41,423 - root - INFO - step: 4 loss: 13.4853 grad_norm: 7.8538 active_memory: 32.14GiB(33.83%) reserved_memory: 34.21GiB(36.01%) tps: 5,881 tflops: 340.57 mfu: 34.44%
[rank0]:[titan] 2025-08-11 11:28:42,820 - root - INFO - step: 5 loss: 16.1191 grad_norm: 53.2481 active_memory: 32.14GiB(33.83%) reserved_memory: 34.21GiB(36.01%) tps: 5,867 tflops: 339.77 mfu: 34.35%
```

REORDER: active: 32GiB, reserved: 36GiB
```
[rank0]:[titan] 2025-08-11 11:34:32,772 - root - INFO - step: 1 loss: 12.2490 grad_norm: 4.1944 active_memory: 24.66GiB(25.96%) reserved_memory: 26.81GiB(28.22%) tps: 85 tflops: 4.90 mfu: 0.50%
[rank0]:[titan] 2025-08-11 11:34:35,329 - root - INFO - step: 2 loss: 13.1427 grad_norm: 39.5942 active_memory: 32.14GiB(33.83%) reserved_memory: 36.40GiB(38.31%) tps: 3,205 tflops: 185.61 mfu: 18.77%
[rank0]:[titan] 2025-08-11 11:34:36,770 - root - INFO - step: 3 loss: 14.6084 grad_norm: 51.0743 active_memory: 32.14GiB(33.83%) reserved_memory: 36.40GiB(38.31%) tps: 5,688 tflops: 329.44 mfu: 33.31%
[rank0]:[titan] 2025-08-11 11:34:38,197 - root - INFO - step: 4 loss: 13.6181 grad_norm: 8.1122 active_memory: 32.14GiB(33.83%) reserved_memory: 36.40GiB(38.31%) tps: 5,744 tflops: 332.68 mfu: 33.64%
[rank0]:[titan] 2025-08-11 11:34:39,821 - root - INFO - step: 5 loss: 15.8913 grad_norm: 59.8510 active_memory: 32.14GiB(33.83%) reserved_memory: 36.40GiB(38.31%) tps: 5,046 tflops: 292.22 mfu: 29.55%
```

REORDER + SINK_WAITS_ITERATIVE: active: 35GiB, reserved: 41GiB
```
[rank0]:[titan] 2025-08-11 11:31:36,119 - root - INFO - step: 1 loss: 12.2646 grad_norm: 4.1282 active_memory: 27.60GiB(29.05%) reserved_memory: 32.49GiB(34.20%) tps: 173 tflops: 10.00 mfu: 1.01%
[rank0]:[titan] 2025-08-11 11:31:37,452 - root - INFO - step: 2 loss: 13.2353 grad_norm: 42.4234 active_memory: 35.08GiB(36.92%) reserved_memory: 41.62GiB(43.80%) tps: 6,152 tflops: 356.26 mfu: 36.02%
[rank0]:[titan] 2025-08-11 11:31:38,780 - root - INFO - step: 3 loss: 13.8205 grad_norm: 24.0156 active_memory: 35.08GiB(36.92%) reserved_memory: 41.62GiB(43.80%) tps: 6,169 tflops: 357.29 mfu: 36.13%
[rank0]:[titan] 2025-08-11 11:31:40,106 - root - INFO - step: 4 loss: 13.1033 grad_norm: 9.1167 active_memory: 35.08GiB(36.92%) reserved_memory: 41.62GiB(43.80%) tps: 6,183 tflops: 358.10 mfu: 36.21%
[rank0]:[titan] 2025-08-11 11:31:41,443 - root - INFO - step: 5 loss: 16.3530 grad_norm: 51.8118 active_memory: 35.08GiB(36.92%) reserved_memory: 41.62GiB(43.80%) tps: 6,130 tflops: 355.03 mfu: 35.90%
```

Differential Revision: [D80718143](https://our.internmc.facebook.com/intern/diff/D80718143)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160113
Approved by: https://github.com/wconstab, https://github.com/eellison

Co-authored-by: eellison <elias.ellison@gmail.com>
committed by PyTorch MergeBot
parent 639b8cc51d
commit db44de4c0d
torch/_inductor/comms.py
@@ -4,6 +4,7 @@ from __future__ import annotations

+import heapq
 import importlib
 import itertools
 import logging
 import operator
 import sys
@@ -23,8 +24,15 @@ from .dependencies import WeakDep

 if TYPE_CHECKING:
     from .ir import IRNode, Operation
+    from .scheduler import SchedulerBuffer

-from .memory import estimate_peak_memory, FreeableInputBuffer, get_freeable_input_buf
+from .memory import (
+    estimate_peak_memory,
+    estimate_peak_memory_allocfree,
+    FreeableInputBuffer,
+    get_freeable_input_buf,
+    SNodeMemory,
+)
 from .utils import (
     contains_collective,
     contains_wait,
@@ -188,6 +196,46 @@ def _is_fake_dep(d):
     return isinstance(d, WeakDep) and d.is_fake


+def _group_names(gns: list[BaseSchedulerNode]) -> str:
+    return "~".join([gn.get_name() for gn in gns])
+
+
+def _initialize_memory_tracking(snodes, graph_inputs, graph_outputs):
+    """Initialize memory tracking data structures"""
+    name_to_freeable_input_buf = get_freeable_input_buf(snodes, graph_inputs)
+    peak_memory, snodes_curr_memory, snodes_allocfree, buf_to_snode_last_use = (
+        estimate_peak_memory_allocfree(
+            snodes, name_to_freeable_input_buf, graph_outputs
+        )
+    )
+    _curr_memory = dict(zip(snodes, snodes_curr_memory))
+    _curr_memory[None] = (0, 0)
+    return (
+        peak_memory,
+        _curr_memory,
+        snodes_allocfree,
+        buf_to_snode_last_use,
+        name_to_freeable_input_buf,
+    )
+
+
+def _initialize_double_linked_list(
+    snodes: list[BaseSchedulerNode],
+) -> tuple[
+    dict[BaseSchedulerNode, Optional[BaseSchedulerNode]],
+    dict[BaseSchedulerNode, Optional[BaseSchedulerNode]],
+    BaseSchedulerNode,
+]:
+    """Create double-linked list structure from snodes"""
+    _prev = {}
+    _next = {}
+    for i, snode in enumerate(snodes):
+        _prev[snode] = snodes[i - 1] if i > 0 else None
+        _next[snode] = snodes[i + 1] if i < len(snodes) - 1 else None
+    _head = snodes[0]
+    return _prev, _next, _head
+
+
 def _reorder_communication_preserving_peak_memory_internal(
     snodes: list[BaseSchedulerNode],
 ) -> tuple[list[BaseSchedulerNode], dict[BaseSchedulerNode, ReorderInfo]]:
@@ -211,20 +259,22 @@ def _reorder_communication_preserving_peak_memory_internal(
     # heuristic to avoid degenerating to quadratic time
     graph_inputs: OrderedSet[str] = OrderedSet(V.graph.graph_inputs.keys())
     graph_outputs: OrderedSet[str] = OrderedSet(V.graph.get_output_names())
-    name_to_freeable_input_buf: dict[str, FreeableInputBuffer] = get_freeable_input_buf(
-        snodes, graph_inputs
-    )
-    peak_memory, curr_memory = estimate_peak_memory(
-        snodes, name_to_freeable_input_buf, graph_outputs
-    )
-    runtimes = {snode: estimate_op_runtime(snode) for snode in snodes}
-    _curr_memory = dict(zip(snodes, curr_memory))
-    _curr_memory[None] = 0  # type: ignore[index]
-
+    (
+        peak_memory,
+        _curr_memory,
+        snodes_allocfree,
+        buf_to_snode_last_use,
+        name_to_freeable_input_buf,
+    ) = _initialize_memory_tracking(snodes, graph_inputs, graph_outputs)
+    runtimes: dict[BaseSchedulerNode, float] = {
+        snode: estimate_op_runtime(snode) for snode in snodes
+    }
     # debug stats
     stats: dict[BaseSchedulerNode, ReorderInfo] = {}

-    def exposed_communication_time(collective_snode, remaining_snodes):
+    def exposed_communication_time(
+        collective_snode: BaseSchedulerNode, remaining_snodes: list[BaseSchedulerNode]
+    ) -> float:
         # assumes a linear schedule and computes the overlap of the collective with the remaining nodes
         comm_time = estimate_op_runtime(collective_snode)
         compute_time = 0.0
@@ -236,7 +286,7 @@ def _reorder_communication_preserving_peak_memory_internal(
                 # we can ignore it. Otherwise, it's the end of the road for overlap opportunities
                 break

-        def accumulate_time(_snode):
+        def accumulate_time(_snode: BaseSchedulerNode) -> None:
             nonlocal compute_time
             compute_time += runtimes[_snode]
@@ -245,18 +295,11 @@ def _reorder_communication_preserving_peak_memory_internal(

     total_moves = 0

-    # Dicts to keep track of "next" and "previous" as double-linked structure during grouping
-    _prev: dict[Optional[BaseSchedulerNode], Optional[BaseSchedulerNode]] = {}
-    _next: dict[Optional[BaseSchedulerNode], Optional[BaseSchedulerNode]] = {}
-    for i, snode in enumerate(snodes):
-        _prev[snode] = snodes[i - 1] if i > 0 else None
-        _next[snode] = snodes[i + 1] if i < len(snodes) - 1 else None
-    _curr_memory = dict(zip(snodes, curr_memory))
-    _curr_memory[None] = 0  # type: ignore[index]
+    _prev, _next, _head = _initialize_double_linked_list(snodes)

-    _head = snodes[0]
-
-    def _group_nodes(head, tail):
+    def _group_nodes(
+        head: Optional[BaseSchedulerNode], tail: Optional[BaseSchedulerNode]
+    ) -> list[BaseSchedulerNode]:
         ret = []
         n = head
         while True:
@@ -264,37 +307,167 @@ def _reorder_communication_preserving_peak_memory_internal(
             ret.append(n)
             if n == tail:
                 break
-            n = _next[n]
+            n = _next[n]  # type: ignore[index]
         return ret

-    def _group_names(head, tail):
-        ret = ""
-        for n in _group_nodes(head, tail):
-            if ret:
-                ret += "~"
-            ret += n.get_name()
-        return ret
+    def _perform_double_linked_list_swap(candidate, group_head, group_tail):
+        # swap (candidate, group_head...group_tail)
+        # Before:
+        # candidate_prev -0-> candidate -1-> group_head...group_tail -2-> group_tail_next
+        # After:
+        # candidate_prev -0-> group_head...group_tail -1-> candidate -2-> group_tail_next
+        # 0
+        candidate_prev = _prev[candidate]
+        if candidate_prev:
+            _next[candidate_prev] = group_head
+        _prev[group_head] = candidate_prev
+
+        # 2
+        group_tail_next = _next[group_tail]
+        if group_tail_next:
+            _prev[group_tail_next] = candidate
+        _next[candidate] = group_tail_next
+
+        # 1
+        _prev[candidate] = group_tail
+        _next[group_tail] = candidate
+
+        nonlocal _head
+        if _head == candidate:
+            _head = group_head
+
+    def _calculate_potential_peak_memory(
+        candidate, group_ns, group_n_to_bufs_after_swap_dealloc_by_candidate
+    ):
+        # Caching calculations of memory for group nodes and candidate,
+        # to apply without recalculation after swap.
+        _post_alloc_update: dict[BaseSchedulerNode, int] = {}
+        potential_peak: int = 0
+        if not group_n_to_bufs_after_swap_dealloc_by_candidate:
+            # Not accounting for buffers last use change
+            potential_peak = max(
+                group_peak_memory - candidate_delta_mem,
+                _curr_memory[group_tail][1]
+                - candidate_delta_mem
+                + candidate_allocfree.size_alloc,
+            )
+            return potential_peak, _post_alloc_update
+
+        # If candidate will be after group, the starting memory level of group nodes
+        # changes to the -(candidate.size_alloc - candidate.size_free)
+        mem_after_reorder_delta: int = -candidate_delta_mem
+        for gn in gns:
+            gn_post_alloc_mem = _curr_memory[gn][0] + mem_after_reorder_delta
+            _post_alloc_update[gn] = gn_post_alloc_mem
+            potential_peak = max(potential_peak, gn_post_alloc_mem)
+
+            bufs = group_n_to_bufs_after_swap_dealloc_by_candidate.get(gn, None)
+            if bufs is not None:
+                for buf in bufs:
+                    # Candidate will deallocate those buffers
+                    mem_after_reorder_delta += buf.mpi_buffer.size_free
+
+        candidate_mem_post_alloc = (
+            _curr_memory[group_tail][1]
+            + mem_after_reorder_delta
+            + candidate_allocfree.size_alloc
+        )
+        _post_alloc_update[candidate] = candidate_mem_post_alloc
+        potential_peak = max(potential_peak, candidate_mem_post_alloc)
+        return potential_peak, _post_alloc_update
+
+    def _update_memory_tracking_after_swap(
+        candidate,
+        gns,
+        group_n_to_bufs_after_swap_dealloc_by_candidate,
+        _post_alloc_update,
+    ):
+        if not group_n_to_bufs_after_swap_dealloc_by_candidate:
+            for gn in gns:
+                cm = _curr_memory[gn]
+                _curr_memory[gn] = (
+                    cm[0] - candidate_delta_mem,
+                    cm[1] - candidate_delta_mem,
+                )
+            _candidate_post_alloc_mem = (
+                _curr_memory[group_tail][1] + candidate_allocfree.size_alloc
+            )
+            _candidate_post_free_mem = (
+                _candidate_post_alloc_mem - candidate_allocfree.size_free
+            )
+            _curr_memory[candidate] = (
+                _candidate_post_alloc_mem,
+                _candidate_post_free_mem,
+            )
+            return
+
+        # Candidate becomes last use of some bufs
+        for (
+            gn,
+            bufs,
+        ) in group_n_to_bufs_after_swap_dealloc_by_candidate.items():
+            for buf in bufs:
+                buf_to_snode_last_use[buf] = candidate
+
+        size_free_to_move_to_candidate_sum: int = 0
+        for n in gns:
+            _gn_post_alloc_mem: int = _post_alloc_update[n]
+            size_free_to_move_to_candidate: int = sum(
+                buf.mpi_buffer.size_free
+                for buf in group_n_to_bufs_after_swap_dealloc_by_candidate[n]
+            )
+            size_free_to_move_to_candidate_sum += size_free_to_move_to_candidate
+            # group node does not deallocate this after swap
+            snodes_allocfree[n].size_free -= size_free_to_move_to_candidate
+            gn_post_free_mem: int = _gn_post_alloc_mem - snodes_allocfree[n].size_free
+            _curr_memory[n] = (_gn_post_alloc_mem, gn_post_free_mem)
+        _candidate_post_alloc_mem = _post_alloc_update[candidate]
+        snodes_allocfree[candidate].size_free += size_free_to_move_to_candidate_sum
+        candidate_post_free_mem = (
+            _candidate_post_alloc_mem - snodes_allocfree[candidate].size_free
+        )
+        _curr_memory[candidate] = (
+            _candidate_post_alloc_mem,
+            candidate_post_free_mem,
+        )
+
+    debug_num_collectives_to_reorder: Optional[int] = (
+        config.reorder_iterative_debug_limit_to_reorder
+    )
+
+    num_processed_collectives: int = 0
     curr = _head
+    debug_iterative_memory_recompute = config.reorder_iterative_debug_memory_recompute
+    iterative_recompute_error = False
+
     while _next[curr] is not None:
+        if iterative_recompute_error:
+            break
         if contains_collective(curr):
-            reorder_info = stats[curr] = ReorderInfo()
-            reorder_info.initial_exposed = reorder_info.final_exposed = (
-                exposed_communication_time(curr, _group_nodes(_next[curr], None))
-            )
+            if debug_num_collectives_to_reorder is not None and (
+                num_processed_collectives >= debug_num_collectives_to_reorder
+            ):
+                break
+            num_processed_collectives += 1
+
+            info = stats[curr] = ReorderInfo()
+            info.initial_exposed = info.final_exposed = exposed_communication_time(
+                curr, _group_nodes(_next[curr], None)
+            )

             candidate = _prev[curr]
             group_head = curr
             group_tail = curr
-            group_peak_memory = _curr_memory[curr]
+            group_peak_memory = _curr_memory[curr][0]  # post_alloc memory
             while candidate is not None:
                 if contains_collective(candidate):
-                    reorder_info.limiting_factor = "collective ordering"
+                    info.limiting_factor = "collective ordering"
                     break

+                gns: list[BaseSchedulerNode] = _group_nodes(group_head, group_tail)
                 group = GroupedSchedulerNode(
                     curr.scheduler,
-                    _group_nodes(group_head, group_tail),
+                    gns,
                     temp_grouping=True,
                 )
@@ -314,7 +487,9 @@ def _reorder_communication_preserving_peak_memory_internal(

                 if data_dep is not None:

-                    def is_groupable(candidate):
+                    def is_groupable(
+                        candidate: BaseSchedulerNode,
+                    ) -> tuple[bool, Optional[str]]:
                         # preserve ordering
                         if contains_collective(candidate):
                             return False, "contains_collective"
@@ -323,73 +498,106 @@ def _reorder_communication_preserving_peak_memory_internal(
                             return False, "contains_gemm_like"
                         return True, None

-                    is_grp, grp_reason = is_groupable(candidate)
-                    if is_grp:
+                    is_groupable_result, grouping_reason = is_groupable(candidate)
+                    if is_groupable_result:
                         group_head = candidate
                         group_peak_memory = max(
-                            group_peak_memory, _curr_memory[candidate]
+                            group_peak_memory, _curr_memory[candidate][0]
                         )
-                        reorder_info.grouped += 1
-                        reorder_info.grouped_info = _group_names(group_head, group_tail)
+                        info.grouped += 1
+                        info.grouped_info = _group_names(gns)
                         candidate = _prev[candidate]
                         continue
                     else:
                         msg = (
                             f"data dependency {data_dep}(dep_names:{list(data_deps.keys())})"
-                            f"\n candidate:{candidate.get_name()}(os:{[candidate.get_buffer_names()]})"
-                            f"dep on {_group_names(group_head, group_tail)}"
-                            f"\n non_group_reason:{grp_reason}"
+                            f"\n candidate:{candidate.get_name()}(outs:{[candidate.get_buffer_names()]})"
+                            f"dep on {_group_names(gns)}"
+                            f"\n non_group_reason:{grouping_reason}"
                         )
-                        reorder_info.limiting_factor = msg
+                        info.limiting_factor = msg
                         break

-                delta_memory_candidate = (
-                    _curr_memory[candidate] - _curr_memory[_prev[candidate]]  # type: ignore[index]
-                )
+                candidate_allocfree: SNodeMemory = snodes_allocfree[candidate]
+                candidate_delta_mem: int = (
+                    candidate_allocfree.size_alloc - candidate_allocfree.size_free
+                )
+                # candidate and one of group nodes are successors of the same buffer
+                # and last use of the buffer happen in group nodes.
+                # This last use deallocates it.
+                # If we swap [candidate [group]] to [[group] candidate],
+                # candidate becomes the last use
+                # and deallocated this buffer instead of group node.
+                # we need to update size_free accordingly to group_node and candidate,
+                # and recalculate post_alloc, post_free for them.
+                #
+                # Buf that changes its last use snode,
+                # after swap will be deallocated only by candidate,
+                # while before it was deallocated by group node.
+                group_n_to_bufs_after_swap_dealloc_by_candidate: dict[
+                    BaseSchedulerNode, list[Union[FreeableInputBuffer, Any]]
+                ] = defaultdict(list)
+                for (
+                    buf,
+                    snode_last_use,
+                ) in buf_to_snode_last_use.items():
+                    succ_nodes = buf.mpi_buffer.succ_nodes
+                    if candidate not in succ_nodes:
+                        continue
+
+                    if not any(gn == snode_last_use for gn in gns):
+                        continue
+
+                    group_n_to_bufs_after_swap_dealloc_by_candidate[
+                        snode_last_use
+                    ].append(buf)
+
+                potential_peak, _post_alloc_update = _calculate_potential_peak_memory(
+                    candidate, gns, group_n_to_bufs_after_swap_dealloc_by_candidate
+                )

-                if group_peak_memory - delta_memory_candidate > peak_memory:
-                    reorder_info.limiting_factor = "peak memory"
+                if potential_peak > peak_memory:
+                    info.limiting_factor = (
+                        f"peak memory new:{potential_peak} vs base:{peak_memory}"
+                    )
                     break

-                reorder_info.moves += 1
+                info.moves += 1
                 total_moves += 1

-                mem_deltas = {}
-                for n in [candidate, *_group_nodes(group_head, group_tail)]:
-                    mem_deltas[n] = _curr_memory[n] - _curr_memory[_prev[n]]  # type: ignore[index]
-                # swap (candidate, group_head...group_tail)
-                # Before:
-                # candidate_prev -0-> candidate -1-> group_head...group_tail -2-> group_tail_next
-                # After:
-                # candidate_prev -0-> group_head...group_tail -1-> candidate -2-> group_tail_next
-                # 0
-                candidate_prev = _prev[candidate]
-                if candidate_prev:
-                    _next[candidate_prev] = group_head
-                _prev[group_head] = candidate_prev
+                _perform_double_linked_list_swap(candidate, group_head, group_tail)

-                # 2
-                group_tail_next = _next[group_tail]
-                if group_tail_next:
-                    _prev[group_tail_next] = candidate
-                _next[candidate] = group_tail_next
-
-                # 1
-                _prev[candidate] = group_tail
-                _next[group_tail] = candidate
-
-                if _head == candidate:
-                    _head = group_head
-
-                reorder_info.final_exposed = exposed_communication_time(
+                info.final_exposed = exposed_communication_time(
                     curr, _group_nodes(_next[curr], None)
                 )
-                # Recompute curr_memory
-                _prev_curr_memory = _curr_memory[_prev[group_head]]  # type: ignore[index]
-                for n in _group_nodes(group_head, candidate):
-                    _curr_memory[n] = _prev_curr_memory = (
-                        _prev_curr_memory + mem_deltas[n]
-                    )
+                _update_memory_tracking_after_swap(
+                    candidate,
+                    gns,
+                    group_n_to_bufs_after_swap_dealloc_by_candidate,
+                    _post_alloc_update,
+                )
+
+                if debug_iterative_memory_recompute:
+                    # Compare iteratively recomputed memory data
+                    # with full run of estimate_peak_memory
+
+                    from .comms_debug import _debug_iterative_memory_recompute
+
+                    iterative_recompute_error = _debug_iterative_memory_recompute(
+                        candidate,
+                        gns,
+                        _group_names(gns),
+                        _group_nodes(_head, None),
+                        name_to_freeable_input_buf,
+                        graph_outputs,
+                        peak_memory,
+                        _curr_memory,
+                        snodes_allocfree,
+                        "reorder_communication_preserving_peak_memory",
+                        group_n_to_bufs_after_swap_dealloc_by_candidate,
+                    )
+                    if iterative_recompute_error:
+                        break
                 candidate = _prev[group_head]
         curr = _next[curr]  # type: ignore[assignment]
@@ -415,15 +623,15 @@ def _reorder_communication_preserving_peak_memory_internal(
     rows = [
         [
             node_summary(snode),
-            node_reorder_info.initial_exposed,
-            node_reorder_info.final_exposed,
-            node_reorder_info.improvement,
-            node_reorder_info.limiting_factor,
-            node_reorder_info.moves,
-            node_reorder_info.grouped,
-            node_reorder_info.grouped_info,
+            node_info.initial_exposed,
+            node_info.final_exposed,
+            node_info.improvement,
+            node_info.limiting_factor,
+            node_info.moves,
+            node_info.grouped,
+            node_info.grouped_info,
         ]
-        for snode, node_reorder_info in node_stats.items()
+        for snode, node_info in node_stats.items()
     ]
     if importlib.util.find_spec("tabulate"):
         from tabulate import tabulate
@@ -441,7 +649,7 @@ def _reorder_communication_preserving_peak_memory_internal(

     new_snodes = _group_nodes(_head, None)
     assert len(new_snodes) == original_snodes_num
-    new_peak_memory, curr_memory = estimate_peak_memory(
+    new_peak_memory, _, _, _ = estimate_peak_memory_allocfree(
         new_snodes, name_to_freeable_input_buf, graph_outputs
     )
     reorder_log_str += f"\n peak_memory_before:{peak_memory}"
@@ -657,24 +865,21 @@ def _sink_waits_iterative_internal(
         return snodes, {}
     graph_inputs: OrderedSet[str] = OrderedSet(V.graph.graph_inputs.keys())
     graph_outputs: OrderedSet[str] = OrderedSet(V.graph.get_output_names())
-    name_to_freeable_input_buf: dict[str, FreeableInputBuffer] = get_freeable_input_buf(
-        snodes, graph_inputs
-    )
-    peak_memory, curr_memory = estimate_peak_memory(
-        snodes, name_to_freeable_input_buf, graph_outputs
-    )
+    (
+        peak_memory,
+        _curr_memory,
+        snodes_allocfree,
+        buf_to_snode_last_use,
+        name_to_freeable_input_buf,
+    ) = _initialize_memory_tracking(snodes, graph_inputs, graph_outputs)
+
+    _prev, _next, _head = _initialize_double_linked_list(snodes)

     stats: dict[BaseSchedulerNode, SinkWaitInfo] = {}
-    _prev: dict[Optional[BaseSchedulerNode], Optional[BaseSchedulerNode]] = {}
-    _next: dict[Optional[BaseSchedulerNode], Optional[BaseSchedulerNode]] = {}
-    _head = snodes[0]
-    for i, snode in enumerate(snodes):
-        _prev[snode] = snodes[i - 1] if i > 0 else None
-        _next[snode] = snodes[i + 1] if i < len(snodes) - 1 else None
-    _curr_memory = dict(zip(snodes, curr_memory))
-    _curr_memory[None] = 0  # type: ignore[index]

-    def _group_nodes(head, tail):
+    def _group_nodes(
+        head: Optional[BaseSchedulerNode], tail: Optional[BaseSchedulerNode]
+    ) -> list[BaseSchedulerNode]:
         ret = []
         n = head
         while True:
@@ -682,21 +887,125 @@ def _sink_waits_iterative_internal(
             ret.append(n)
             if n == tail:
                 break
-            n = _next[n]
+            n = _next[n]  # type: ignore[index]
         return ret

-    def _group_names(head, tail):
-        ret = ""
-        for n in _group_nodes(head, tail):
-            if ret:
-                ret += "~"
-            ret += n.get_name()
-        return ret
+    def _calculate_potential_peak_memory(
+        candidate, group_ns, group_n_to_bufs_after_swap_dealloc_instead_of_candidate
+    ):
+        pre_group_mem = (
+            _curr_memory[group_head][0] - snodes_allocfree[group_head].size_alloc
+        )
+        # Stash memory tracing updates to not recompute them after swap
+        _post_alloc_update: dict[BaseSchedulerNode, int] = {}
+        _size_free_delta_update: dict[BaseSchedulerNode, int] = {}
+
+        potential_peak = 0
+        if not group_n_to_bufs_after_swap_dealloc_instead_of_candidate:
+            # Not accounting for buffers liveliness change
+            potential_peak = max(
+                group_peak_memory + candidate_delta_mem,
+                pre_group_mem + candidate_allocfree.size_alloc,
+            )
+            return potential_peak, _post_alloc_update, _size_free_delta_update
+
+        candidate_post_alloc = pre_group_mem + candidate_allocfree.size_alloc
+        _post_alloc_update[candidate] = candidate_post_alloc
+        potential_peak = candidate_post_alloc
+        candidate_size_free_to_move = sum(
+            buf.mpi_buffer.size_free  # type: ignore[attr-defined]
+            for buf in itertools.chain.from_iterable(
+                group_n_to_bufs_after_swap_dealloc_instead_of_candidate.values()
+            )
+        )
+        _size_free_delta_update[candidate] = -candidate_size_free_to_move
+        delta_mem = candidate_delta_mem + candidate_size_free_to_move
+        for gn in gns:
+            gn_post_alloc = _curr_memory[gn][0] + delta_mem
+            _post_alloc_update[gn] = gn_post_alloc
+            potential_peak = max(potential_peak, gn_post_alloc)
+            gn_size_free_to_add = 0
+            if gn in group_n_to_bufs_after_swap_dealloc_instead_of_candidate:
+                bufs = group_n_to_bufs_after_swap_dealloc_instead_of_candidate[gn]
+                for buf in bufs:
+                    gn_size_free_to_add += buf.mpi_buffer.size_free
+            _size_free_delta_update[gn] = gn_size_free_to_add
+            delta_mem -= gn_size_free_to_add
+        return potential_peak, _post_alloc_update, _size_free_delta_update
+
+    def _perform_double_linked_list_swap(candidate, group_head, group_tail):
+        # group_head_prev -0-> candidate -1-> group_head...group_tail -2-> candidate_next
+        # 0:
+        group_head_prev = _prev[group_head]
+        if group_head_prev:
+            _next[group_head_prev] = candidate
+        _prev[candidate] = group_head_prev
+
+        # 2:
+        candidate_next = _next[candidate]
+        if candidate_next:
+            _prev[candidate_next] = group_tail
+        _next[group_tail] = candidate_next
+
+        # 1:
+        _prev[group_head] = candidate
+        _next[candidate] = group_head
+        nonlocal _head
+        if group_head == _head:
+            _head = candidate
+
+    def _update_memory_tracking_after_swap(
+        candidate,
+        gns,
+        group_n_to_bufs_after_swap_dealloc_instead_of_candidate,
+        _post_alloc_update,
+        _size_free_delta_update,
+    ):
+        group_head = gns[0]
+        pre_group_mem = (
+            _curr_memory[group_head][0] - snodes_allocfree[group_head].size_alloc
+        )
+        if not group_n_to_bufs_after_swap_dealloc_instead_of_candidate:
+            candidate_post_alloc = pre_group_mem + candidate_allocfree.size_alloc
+            _curr_memory[candidate] = (
+                candidate_post_alloc,
+                candidate_post_alloc - candidate_allocfree.size_free,
+            )
+            for gn in gns:
+                cm = _curr_memory[gn]
+                _curr_memory[gn] = (
+                    cm[0] + candidate_delta_mem,
+                    cm[1] + candidate_delta_mem,
+                )
+            return
+
+        for n in [candidate, *gns]:
+            post_alloc = _post_alloc_update[n]
+            snodes_allocfree[n].size_free += _size_free_delta_update[n]
+            _curr_memory[n] = (
+                post_alloc,
+                post_alloc - snodes_allocfree[n].size_free,
+            )

     curr = snodes[-1]

     processed_waits = OrderedSet()  # type: ignore[var-annotated]
+    debug_iterative_memory_recompute = config.reorder_iterative_debug_memory_recompute
+    debug_num_sink_waits_to_reorder: Optional[int] = (
+        config.sink_waits_iterative_debug_limit_to_sink
+    )
+
+    iterative_recompute_error = False
+
     while _prev[curr] is not None:
+        if iterative_recompute_error:
+            break
+        if (
+            debug_num_sink_waits_to_reorder is not None
+            and len(processed_waits) >= debug_num_sink_waits_to_reorder
+        ):
+            break
+
         if contains_wait(curr) and curr not in processed_waits:
             processed_waits.add(curr)
             info = stats[curr] = SinkWaitInfo()
@@ -704,11 +1013,14 @@ def _sink_waits_iterative_internal(
             wait_snode = curr
             group_head = curr
             group_tail = curr
-            group_peak_memory = _curr_memory[curr]
+            group_peak_memory = _curr_memory[curr][0]
             while candidate is not None:
+                if iterative_recompute_error:
+                    break
+                gns: list[BaseSchedulerNode] = _group_nodes(group_head, group_tail)
                 group = GroupedSchedulerNode(
                     wait_snode.scheduler,
-                    _group_nodes(group_head, group_tail),
+                    gns,
                     temp_grouping=True,
                 )
@@ -753,15 +1065,15 @@ def _sink_waits_iterative_internal(
                 if is_grp:
                     group_tail = candidate
                     group_peak_memory = max(
-                        group_peak_memory, _curr_memory[candidate]
+                        group_peak_memory, _curr_memory[candidate][0]
                     )
                     info.grouped += 1
-                    info.grouped_info = _group_names(group_head, group_tail)
+                    info.grouped_info = _group_names(gns)
                     candidate = _next[candidate]
                     continue
                 elif (data_dep is None) and both_contain_comms:
                     info.limiting_factor = (
-                        f"collective ordering {_group_names(group_head, group_tail)}"
+                        f"collective ordering {_group_names(gns)}"
                         f" with candidate:{candidate.get_name()}"
                     )
                     break
@@ -769,49 +1081,89 @@ def _sink_waits_iterative_internal(
                     info.limiting_factor = (
                         f"data dependency {data_dep}(dep_names:{list(data_deps.keys())})"
                         f"\n candidate:{candidate.get_name()}(os:{[candidate.get_buffer_names()]})"
-                        f"dep on {_group_names(group_head, group_tail)}"
+                        f"dep on {gns}"
                         f"\n outs:{[o.get_name() for o in group_outs]}"
                         f"\n non_group_reason:{grp_reason}"
                     )
                     break
-                candidate_delta_memory = (
-                    _curr_memory[candidate] - _curr_memory[_prev[candidate]]  # type: ignore[index]
-                )
-                if group_peak_memory + candidate_delta_memory > peak_memory:
-                    info.limiting_factor = "peak_memory"
+                candidate_allocfree: SNodeMemory = snodes_allocfree[candidate]
+                candidate_delta_mem = (
+                    candidate_allocfree.size_alloc - candidate_allocfree.size_free
+                )
+                # [group] candidate -> candidate [group]
+                # Check for buffers with successors in group and candidate last successor
+                #
+                # Buf that changes its last use snode,
+                # It was deallocated by candidate,
+                # but after swap it will be deallocated by group node.
+                group_n_to_bufs_after_swap_dealloc_instead_of_candidate: dict[
+                    BaseSchedulerNode, list[Union[FreeableInputBuffer, SchedulerBuffer]]
+                ] = defaultdict(list)
+                for (
+                    buf,
+                    snode_last_use,
+                ) in buf_to_snode_last_use.items():
+                    succ_nodes = buf.mpi_buffer.succ_nodes
+                    if snode_last_use != candidate:  # noqa: E711
+                        continue
+                    # candidate is last use of buf
+                    last_succ_gn = None
+                    for gn in gns:
+                        if gn in succ_nodes:
+                            last_succ_gn = gn
+                    if last_succ_gn is None:
+                        continue
+
+                    # gn has successors of buf that after potential swap will become
+                    # last use of buf and start deallocating buf instead of candidate
+                    group_n_to_bufs_after_swap_dealloc_instead_of_candidate[
+                        last_succ_gn
+                    ].append(buf)
+
+                potential_peak, _post_alloc_update, _size_free_delta_update = (
+                    _calculate_potential_peak_memory(
+                        candidate,
+                        gns,
+                        group_n_to_bufs_after_swap_dealloc_instead_of_candidate,
+                    )
+                )
+                if potential_peak > peak_memory:
+                    info.limiting_factor = (
+                        f"peak memory new:{potential_peak} vs base:{peak_memory}"
+                    )
                     break

                 info.moves += 1
                 info.moves_info += f"+{candidate.get_name()}"

-                # group_head_prev -0-> candidate -1-> group_head...group_tail -2-> candidate_next
-                mem_deltas = {}
-                for n in [candidate, *_group_nodes(group_head, group_tail)]:
-                    mem_deltas[n] = _curr_memory[n] - _curr_memory[_prev[n]]  # type: ignore[index]
-                # 0:
-                group_head_prev = _prev[group_head]
-                if group_head_prev:
-                    _next[group_head_prev] = candidate
-                _prev[candidate] = group_head_prev
+                _perform_double_linked_list_swap(candidate, group_head, group_tail)

-                # 2:
-                candidate_next = _next[candidate]
-                if candidate_next:
-                    _prev[candidate_next] = group_tail
-                _next[group_tail] = candidate_next
+                _update_memory_tracking_after_swap(
+                    candidate,
+                    gns,
+                    group_n_to_bufs_after_swap_dealloc_instead_of_candidate,
+                    _post_alloc_update,
+                    _size_free_delta_update,
+                )

-                # 1:
-                _prev[group_head] = candidate
-                _next[candidate] = group_head
-                if group_head == _head:
-                    _head = candidate
+                if debug_iterative_memory_recompute:
+                    from .comms_debug import _debug_iterative_memory_recompute

-                # Recompute curr_memory
-                _prev_curr_memory = _curr_memory[_prev[candidate]]  # type: ignore[index]
-                for n in _group_nodes(candidate, group_tail):
-                    _curr_memory[n] = _prev_curr_memory = (
-                        _prev_curr_memory + mem_deltas[n]
-                    )
+                    iterative_recompute_error = _debug_iterative_memory_recompute(
+                        candidate,
+                        gns,
+                        _group_names(gns),
+                        _group_nodes(_head, None),
+                        name_to_freeable_input_buf,
+                        graph_outputs,
+                        peak_memory,
+                        _curr_memory,
+                        snodes_allocfree,
+                        "sink_waits_iterative",
+                        group_n_to_bufs_after_swap_dealloc_instead_of_candidate,
+                    )
+                    if iterative_recompute_error:
+                        break

                 candidate = _next[group_tail]
         curr = _prev[curr]  # type: ignore[assignment]
@@ -850,11 +1202,11 @@ def _sink_waits_iterative_internal(
         overlap_log.info(log_str)
     new_snodes = _group_nodes(_head, None)
     assert len(new_snodes) == original_snodes_num
-    new_peak_memory, curr_memory = estimate_peak_memory(
+    new_peak_memory, _, _, _ = estimate_peak_memory_allocfree(
         new_snodes, name_to_freeable_input_buf, graph_outputs
     )
-    log_str += f"\n peak_memory_before:{peak_memory}"
-    log_str += f"\n peak_memory_after:{new_peak_memory}"
+    log_str += f"\n sink_waits_iterative peak_memory_before:{peak_memory}"
+    log_str += f"\n sink_waits_iterative peak_memory_after:{new_peak_memory}"
    trace_structured(
        "artifact",
        metadata_fn=lambda: {
torch/_inductor/comms_debug.py (new file, +112)
@@ -0,0 +1,112 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Union
+
+from torch._logging import trace_structured
+
+from .memory import estimate_peak_memory_allocfree
+
+
+if TYPE_CHECKING:
+    from torch.utils._ordered_set import OrderedSet
+
+    from .memory import FreeableInputBuffer, SNodeMemory
+    from .scheduler import BaseSchedulerNode, SchedulerBuffer
+
+
+def _debug_iterative_memory_recompute(
+    candidate: BaseSchedulerNode,
+    gns: list[BaseSchedulerNode],
+    group_names: str,
+    snodes: list[BaseSchedulerNode],
+    name_to_freeable_input_buf: dict[str, FreeableInputBuffer],
+    graph_outputs: OrderedSet[str],
+    peak_memory: int,
+    iter_curr_memory: dict[BaseSchedulerNode, tuple[int, int]],
+    snodes_allocfree: dict[BaseSchedulerNode, SNodeMemory],
+    tlparse_name: str,
+    gn_to_bufs_last_use: dict[
+        BaseSchedulerNode, list[Union[FreeableInputBuffer, SchedulerBuffer]]
+    ],
+) -> bool:
+    iterative_recompute_error = False
+    candidate_allocfree = snodes_allocfree[candidate]
+    est_peak_memory, snodes_curr_memory, snodes_allocfree, _ = (
+        estimate_peak_memory_allocfree(
+            snodes, name_to_freeable_input_buf, graph_outputs
+        )
+    )
+    est_curr_memory = dict(zip(snodes, snodes_curr_memory))
+    iter_cm = iter_curr_memory[candidate]
+    new_cm = est_curr_memory[candidate]
+    log = ""
+    if est_peak_memory > peak_memory:
+        log = "ITERATIVE PEAK DOES NOT MATCH"
+        iterative_recompute_error = True
+    if iter_cm != new_cm:
+        log = "ITERATIVE CURR MEMORY CANDIDATE DOES NOT MATCH"
+        iterative_recompute_error = True
+    for i, gn in enumerate(gns):
+        iter_gnm = iter_curr_memory[gn]
+        new_gnm = est_curr_memory[gn]
+        if iter_gnm != new_gnm:
+            log = f"ITERATIVE GN CURR MEMORY DOES NOT MATCH:{gn.get_name()}"
+            iterative_recompute_error = True
+    if iterative_recompute_error:
+        log += (
+            f"\nCANDIDATE:{candidate.get_name()}"
+            f"\nGROUP:{group_names}"
+            f"\nPEAK_MEMORY_BEFORE:{peak_memory}"
+            f"\nPEAK_MEMORY_AFTER_SWAP:{est_peak_memory}"
+            f"\nCANDIDATE:{candidate.debug_str()}"
+            f"\nCANDIDATE_ITER_CURR_MEMORY:{iter_cm}"
+            f"\nCANDIDATE_NEW__CURR_MEMORY:{new_cm}"
+            f"\nCANDIDATE_ITER_ALLOCFREE:{candidate_allocfree}"
+            f"\nCANDIDATE_NEW_ALLOCFREE:{snodes_allocfree[candidate]}"
+        )
+        peak_log = ""
+        for i, (pre, post) in enumerate(snodes_curr_memory):
+            if est_peak_memory == pre:
+                n = snodes[i]
+                peak_log = (
+                    f"\nNEW_PEAK:{est_peak_memory}(BASE:{peak_memory})"
+                    f" @ SNODE[{i}/{len(snodes)}]:{n.get_name()} {n.debug_str()}"
+                )
+                break
+        group_log = ""
+        for i, gn in enumerate(gns):
+            iter_gnm = iter_curr_memory[gn]
+            new_gnm = est_curr_memory[gn]
+            group_log += (
+                f"\nGROUP_NODE[{i}]:{gn.debug_str()}"
+                f"\nGROUP_NODE[{i}] ITER_GNM[{gn.get_name()}]:{iter_gnm}"
+                f"\nGROUP_NODE[{i}] ESTM_GNM[{gn.get_name()}]:{new_gnm}"
+                f"\nGROUP_NODE[{i}] ITER_allocfree:{snodes_allocfree[gn]}"
+                f"\nGROUP_NODE[{i}] ESTM_allocfree:{snodes_allocfree[gn]}"
+            )
+        log += peak_log
+        log += group_log
+        log += f"\nGN_TO_BUFS_LAST_USE:{gn_to_bufs_last_use}"
+        log += "\n\n".join(
+            [
+                (
+                    f"\nSNODE[{i}]\n{n.debug_str()}"
+                    f"\nITER_cur_mem:{iter_curr_memory[n]}"
+                    f"\nESTM_cur_mem:{est_curr_memory[n]}"
+                    f"\nITER_allocfree:{snodes_allocfree[n]}"
+                    f"\nESTM_allocfree:{snodes_allocfree[n]}"
+                )
+                for i, n in enumerate(snodes)
+            ]
+        )
+        tname = f"{tlparse_name}_ITERATIVE_RECOMPUTE_ERROR"
+        print(f"{tname}:\n{log}")
+        trace_structured(
+            "artifact",
+            metadata_fn=lambda: {
+                "name": tname,
+                "encoding": "string",
+            },
+            payload_fn=lambda: log,
+        )
+    return iterative_recompute_error
torch/_inductor/config.py
@@ -389,6 +389,16 @@ reorder_prefetch_limit: Optional[int] = None

 # enable operator reordering for peak memory optimization
 reorder_for_peak_memory = True

+reorder_iterative_debug_memory_recompute: bool = False
+reorder_iterative_debug_limit_to_reorder: Optional[int] = (
+    None
+    if (env_str := os.getenv("PYTORCH_REORDER_COLLECTIVES_LIMIT")) is None
+    else int(env_str)
+)
+sink_waits_iterative_debug_limit_to_sink: Optional[int] = (
+    None if (env_str := os.getenv("PYTORCH_SINK_WAITS_LIMIT")) is None else int(env_str)
+)
+
 bucket_all_gathers_fx: Literal["none", "all", "only_fsdp"] = "none"
 # By default torch._inductor.fx_passes.bucketing.bucket_size_determinator is used
 bucket_all_gathers_fx_bucket_size_determinator: Optional[Callable[[int], int]] = None
torch/_inductor/memory.py
@@ -4,7 +4,7 @@ import collections
 import dataclasses
 import heapq
 import logging
-from typing import Callable, TYPE_CHECKING, TypedDict, Union
+from typing import Callable, Optional, TYPE_CHECKING, TypedDict, Union

 from torch._environment import is_fbcode
 from torch._utils_internal import signpost_event
@@ -76,7 +76,7 @@ def get_freeable_input_buf(
     Create and keep track of all input buffers that can be freed during the program

     Returns:
-        A dictionary containing all freeble input buffers, keyed by their names.
+        A dictionary containing all freeable input buffers, keyed by their names.
     """

     def _dep_size_hint(dep: Dep) -> int:
@@ -315,7 +315,11 @@ def compute_memory_timeline(
     nodes: list[BaseSchedulerNode],
     name_to_freeable_input_buf: dict[str, FreeableInputBuffer],
     graph_outputs: OrderedSet[str],
-) -> tuple[list[BufferInfo], dict[BaseSchedulerNode, int]]:
+) -> tuple[
+    list[BufferInfo],
+    dict[BaseSchedulerNode, int],
+    dict[Union[FreeableInputBuffer, SchedulerBuffer], BaseSchedulerNode],
+]:
     """
     Compute buffer allocation and deallocation sizes and map their
     lifetime to the node schedule
@@ -329,15 +333,33 @@ def compute_memory_timeline(

     # get buffers' size and liveliness information
     buf_info_list: list[BufferInfo] = []
+    buf_to_snode_last_use: dict[
+        Union[FreeableInputBuffer, SchedulerBuffer], BaseSchedulerNode
+    ] = {}
+
+    def _get_end_step_and_snode(
+        buf: Union[FreeableInputBuffer, SchedulerBuffer],
+    ) -> tuple[int, Optional[BaseSchedulerNode]]:
+        max_step: int = -1
+        max_step_snode: Optional[BaseSchedulerNode] = None
+        succ_nodes = buf.mpi_buffer.succ_nodes
+        if succ_nodes:
+            for succ_node in succ_nodes:
+                step = node_to_step[succ_node]
+                if step > max_step:
+                    max_step = step
+                    max_step_snode = succ_node
+            assert max_step_snode is not None
+        return max_step, max_step_snode

     # 1. for freeable input buffers
     for buf_name, input_buf in name_to_freeable_input_buf.items():
-        end_step = (
-            len(nodes) - 1
-            if buf_name in graph_outputs
-            else max(
-                node_to_step[succ_node] for succ_node in input_buf.mpi_buffer.succ_nodes
-            )
-        )
+        end_step = -1
+        if buf_name not in graph_outputs:
+            end_step, end_step_snode = _get_end_step_and_snode(input_buf)
+            assert end_step_snode is not None
+            buf_to_snode_last_use[input_buf] = end_step_snode

         buf_info_list.append(
             BufferInfo(
                 input_buf,
@@ -354,17 +376,17 @@ def compute_memory_timeline(
             # note: it is possible for a non-graph-output sched_buf to have no succ_nodes and
             # to be only used by its defining op (e.g., due to fusion when all consumers of
             # the buffer are fused with its defining op). In such cases, end_step is step.
-            end_step = (
-                len(nodes) - 1
-                if sched_buf.get_name() in graph_outputs
-                else max(
-                    [
-                        node_to_step[succ_node]
-                        for succ_node in sched_buf.mpi_buffer.succ_nodes
-                    ],
-                    default=step,
-                )
-            )
+            buf_name = sched_buf.get_name()
+            end_step = -1
+            if buf_name not in graph_outputs:
+                end_step, end_step_snode = _get_end_step_and_snode(sched_buf)
+                if end_step == -1:
+                    end_step = step
+                    buf_to_snode_last_use[sched_buf] = node
+                else:
+                    assert end_step_snode is not None
+                    buf_to_snode_last_use[sched_buf] = end_step_snode

             buf_info_list.append(
                 BufferInfo(
                     sched_buf,
@@ -375,7 +397,7 @@ def compute_memory_timeline(
             )
         )

-    return buf_info_list, node_to_step
+    return buf_info_list, node_to_step, buf_to_snode_last_use


 def estimate_peak_memory(
@@ -392,7 +414,7 @@ def estimate_peak_memory(
         List[int]: memory usage at each node (or each step).
     """

-    buf_info_list, _ = compute_memory_timeline(
+    buf_info_list, _, _ = compute_memory_timeline(
         nodes, name_to_freeable_input_buf, graph_outputs
     )
@@ -416,6 +438,73 @@ def estimate_peak_memory(
     return (max_memory, memories_at_nodes)


+@dataclasses.dataclass
+class SNodeMemory:
+    size_alloc: int
+    size_free: int
+
+
+def estimate_peak_memory_allocfree(
+    nodes: list[BaseSchedulerNode],
+    name_to_freeable_input_buf: dict[str, FreeableInputBuffer],
+    graph_outputs: OrderedSet[str],
+) -> tuple[
+    int,
+    list[tuple[int, int]],
+    dict[BaseSchedulerNode, SNodeMemory],
+    dict[Union[FreeableInputBuffer, SchedulerBuffer], BaseSchedulerNode],
+]:
+    """
+    Alternative version of estimate_peak_memory, that respects the fact
+    that every SchedulerNode has multiple phases:
+    1. alloc ( outputs )
+    2. run_kernel
+    3. dealloc last_use buffers
+    estimate_peak_memory collapses memory into one value: size_alloc - size_free
+    While peak memory happens after alloc.
+
+    Duplicating the code to not migrate all callsites at once,
+    In future usages of estimate_peak_memory will migrate to this version.
+    """
+
+    buf_info_list, _, buf_to_snode_last_use = compute_memory_timeline(
+        nodes, name_to_freeable_input_buf, graph_outputs
+    )
+
+    # incremental memory changes at each step
+    step_idx_allocfree = [SNodeMemory(0, 0) for _ in range(len(nodes))]
+
+    # for each buffer, update memory when created and when freed
+    for buf_info in buf_info_list:
+        step_idx_allocfree[buf_info.start_step].size_alloc += buf_info.size_alloc
+        if buf_info.end_step != -1:
+            step_idx_allocfree[buf_info.end_step].size_free += buf_info.size_free
+
+    snodes_allocfree = {}
+    for i, node in enumerate(nodes):
+        snodes_allocfree[node] = step_idx_allocfree[i]
+
+    max_memory = 0
+    cur_memory = 0
+    snodes_curr_memory = []
+    for t in range(len(nodes)):
+        alloc = step_idx_allocfree[t].size_alloc
+        free = step_idx_allocfree[t].size_free
+        cur_memory += alloc
+        post_alloc = cur_memory
+        max_memory = max(max_memory, cur_memory)
+        cur_memory -= free
+        post_free = cur_memory
+        snodes_curr_memory.append((post_alloc, post_free))
+
+    return (
+        max_memory,
+        snodes_curr_memory,
+        snodes_allocfree,
+        buf_to_snode_last_use,
+    )
+
+
 def topological_sort_lpmf(
     nodes: list[BaseSchedulerNode],
     name_to_freeable_input_buf: dict[str, FreeableInputBuffer],
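For orientation, a minimal sketch of how the new four-tuple is consumed, following the `_initialize_memory_tracking` helper added to comms.py earlier in this diff (variable names are the diff's own):

```python
# peak_memory: int, max post-alloc memory over the schedule
# snodes_curr_memory: list of (post_alloc, post_free) per node, in order
# snodes_allocfree: node -> SNodeMemory(size_alloc, size_free)
# buf_to_snode_last_use: buffer -> node whose execution deallocates it
peak_memory, snodes_curr_memory, snodes_allocfree, buf_to_snode_last_use = (
    estimate_peak_memory_allocfree(snodes, name_to_freeable_input_buf, graph_outputs)
)
curr_memory = dict(zip(snodes, snodes_curr_memory))  # node -> (post_alloc, post_free)
```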
@@ -429,7 +518,7 @@ def topological_sort_lpmf(
     Buffer memory optimization for video codec application modeled in Simulink
     https://www.cs.york.ac.uk/rts/docs/DAC-1964-2006/PAPERS/2006/DAC06/PDFFILES/P0689.PDF

-    The algorithm maintain the max memory so far.
+    The algorithm maintains the max memory so far.
     At every iteration, for each scheduleable node, it computes:
        - how much memory needs to be allocated for the output buffers of this node;
        - how much memory can be freed as a result of executing this node.
torch/_inductor/scheduler.py
@@ -2160,6 +2160,12 @@ class Scheduler:
             OrderedSet(V.graph.get_output_names()),
         )
         if config.reorder_for_compute_comm_overlap:
+            if not config.reorder_for_peak_memory:
+                from .memory import assign_memory_planning_info_for_scheduler_buffers
+
+                assign_memory_planning_info_for_scheduler_buffers(
+                    self.nodes, self.name_to_buf
+                )
             from torch._logging import trace_structured

             trace_structured(
@@ -2556,7 +2562,7 @@ class Scheduler:
         )

         graph_outputs: OrderedSet[str] = OrderedSet(V.graph.get_output_names())
-        buf_info_list, _ = compute_memory_timeline(
+        buf_info_list, _, _ = compute_memory_timeline(
             self.nodes,
             name_to_freeable_input_buf,
             graph_outputs,