[inductor] Estimate peak memory allocfree and applying to reordering collectives (#160113)
1. Applies @eellison's idea from https://github.com/pytorch/pytorch/pull/146562#discussion_r2059363672 to `estimate_peak_memory`:

   ```
   """
   Alternative version of estimate_peak_memory, that respects the fact,
   that every SchedulerNode has multiple phases:
   1. alloc ( outputs )
   2. run_kernel
   3. dealloc last_use buffers
   estimate_peak_memory collapses memory into one value: size_alloc - size_free
   While peak memory happens after alloc.

   Duplicating the code to not migrate all callsites at once,
   In future usages of estimate_peak_memory will migrate to this version.
   """
   ```

   The new estimate is applied in the `reorder_communication_preserving_peak_memory` pass.

2. During reordering, a buffer's deallocation point can move: if the candidate and the group it swaps past are both users of the same freeable input buffer (`f_input_buf`) and the group contains its last-use node (`last_use_snode`), the last use changes. This is handled by tracking the `last_use_snode` of every buffer and recomputing the current memory to reflect the change in `size_free`: after reordering, the group node is no longer the buffer's last user (its `size_free -= buffer_size`), while the candidate becomes the last user (`candidate.size_free += buffer_size`).

3. Adds the env var `PYTORCH_REORDER_COLLECTIVES_LIMIT`, which limits the number of collectives to reorder, for ablation.

State after this PR: iterative recomputation of the memory estimations matches the full memory estimations. Active memory barely regresses, but reserved memory regresses significantly; investigating and fixing the "reserved" memory regression will be done in follow-up PRs.

BASELINE (bucketing AG and RS): active: 32 GiB, reserved: 34 GiB

```
[rank0]:[titan] 2025-08-11 11:28:36,798 - root - INFO - step: 1 loss: 12.2722 grad_norm: 4.2192 active_memory: 24.66GiB(25.96%) reserved_memory: 25.38GiB(26.72%) tps: 99 tflops: 5.71 mfu: 0.58%
[rank0]:[titan] 2025-08-11 11:28:38,640 - root - INFO - step: 2 loss: 13.1738 grad_norm: 50.5566 active_memory: 32.14GiB(33.83%) reserved_memory: 34.21GiB(36.01%) tps: 4,448 tflops: 257.63 mfu: 26.05%
[rank0]:[titan] 2025-08-11 11:28:40,029 - root - INFO - step: 3 loss: 15.6866 grad_norm: 80.0862 active_memory: 32.14GiB(33.83%) reserved_memory: 34.21GiB(36.01%) tps: 5,900 tflops: 341.72 mfu: 34.55%
[rank0]:[titan] 2025-08-11 11:28:41,423 - root - INFO - step: 4 loss: 13.4853 grad_norm: 7.8538 active_memory: 32.14GiB(33.83%) reserved_memory: 34.21GiB(36.01%) tps: 5,881 tflops: 340.57 mfu: 34.44%
[rank0]:[titan] 2025-08-11 11:28:42,820 - root - INFO - step: 5 loss: 16.1191 grad_norm: 53.2481 active_memory: 32.14GiB(33.83%) reserved_memory: 34.21GiB(36.01%) tps: 5,867 tflops: 339.77 mfu: 34.35%
```

REORDER: active: 32 GiB, reserved: 36 GiB

```
[rank0]:[titan] 2025-08-11 11:34:32,772 - root - INFO - step: 1 loss: 12.2490 grad_norm: 4.1944 active_memory: 24.66GiB(25.96%) reserved_memory: 26.81GiB(28.22%) tps: 85 tflops: 4.90 mfu: 0.50%
[rank0]:[titan] 2025-08-11 11:34:35,329 - root - INFO - step: 2 loss: 13.1427 grad_norm: 39.5942 active_memory: 32.14GiB(33.83%) reserved_memory: 36.40GiB(38.31%) tps: 3,205 tflops: 185.61 mfu: 18.77%
[rank0]:[titan] 2025-08-11 11:34:36,770 - root - INFO - step: 3 loss: 14.6084 grad_norm: 51.0743 active_memory: 32.14GiB(33.83%) reserved_memory: 36.40GiB(38.31%) tps: 5,688 tflops: 329.44 mfu: 33.31%
[rank0]:[titan] 2025-08-11 11:34:38,197 - root - INFO - step: 4 loss: 13.6181 grad_norm: 8.1122 active_memory: 32.14GiB(33.83%) reserved_memory: 36.40GiB(38.31%) tps: 5,744 tflops: 332.68 mfu: 33.64%
[rank0]:[titan] 2025-08-11 11:34:39,821 - root - INFO - step: 5 loss: 15.8913 grad_norm: 59.8510 active_memory: 32.14GiB(33.83%) reserved_memory: 36.40GiB(38.31%) tps: 5,046 tflops: 292.22 mfu: 29.55%
```

REORDER + SINK_WAITS_ITERATIVE: active: 35 GiB, reserved: 41 GiB

```
[rank0]:[titan] 2025-08-11 11:31:36,119 - root - INFO - step: 1 loss: 12.2646 grad_norm: 4.1282 active_memory: 27.60GiB(29.05%) reserved_memory: 32.49GiB(34.20%) tps: 173 tflops: 10.00 mfu: 1.01%
[rank0]:[titan] 2025-08-11 11:31:37,452 - root - INFO - step: 2 loss: 13.2353 grad_norm: 42.4234 active_memory: 35.08GiB(36.92%) reserved_memory: 41.62GiB(43.80%) tps: 6,152 tflops: 356.26 mfu: 36.02%
[rank0]:[titan] 2025-08-11 11:31:38,780 - root - INFO - step: 3 loss: 13.8205 grad_norm: 24.0156 active_memory: 35.08GiB(36.92%) reserved_memory: 41.62GiB(43.80%) tps: 6,169 tflops: 357.29 mfu: 36.13%
[rank0]:[titan] 2025-08-11 11:31:40,106 - root - INFO - step: 4 loss: 13.1033 grad_norm: 9.1167 active_memory: 35.08GiB(36.92%) reserved_memory: 41.62GiB(43.80%) tps: 6,183 tflops: 358.10 mfu: 36.21%
[rank0]:[titan] 2025-08-11 11:31:41,443 - root - INFO - step: 5 loss: 16.3530 grad_norm: 51.8118 active_memory: 35.08GiB(36.92%) reserved_memory: 41.62GiB(43.80%) tps: 6,130 tflops: 355.03 mfu: 35.90%
```

Differential Revision: [D79886535](https://our.internmc.facebook.com/intern/diff/D79886535)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160113
Approved by: https://github.com/wconstab, https://github.com/eellison

Co-authored-by: eellison <elias.ellison@gmail.com>
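To make point 1 concrete, here is a minimal, self-contained sketch of the per-node accounting that `estimate_peak_memory_allocfree` introduces (toy sizes and a standalone rendition of its final loop, not the Inductor implementation): peak memory is sampled between a node's alloc phase and its dealloc phase.

```python
from dataclasses import dataclass


@dataclass
class SNodeMemory:  # same shape as the dataclass added in this PR
    size_alloc: int
    size_free: int


# Toy per-node (alloc, free) sizes in bytes, standing in for step_idx_allocfree.
step_allocfree = [SNodeMemory(8, 0), SNodeMemory(16, 16), SNodeMemory(4, 12)]

max_memory = 0
cur_memory = 0
snodes_curr_memory: list[tuple[int, int]] = []
for m in step_allocfree:
    cur_memory += m.size_alloc      # phase 1: allocate this node's outputs
    post_alloc = cur_memory         # phase 2: the kernel runs at this footprint
    max_memory = max(max_memory, cur_memory)
    cur_memory -= m.size_free       # phase 3: free buffers whose last use is here
    snodes_curr_memory.append((post_alloc, cur_memory))

print(max_memory)          # 24 -- the peak lands after node 1's alloc phase
print(snodes_curr_memory)  # [(8, 8), (24, 8), (12, 0)]
```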
Committed by: PyTorch MergeBot
Parent: 67b98da1b2
Commit: 9d18bf01b1
```diff
@@ -4,7 +4,7 @@ import collections
 import dataclasses
 import heapq
 import logging
-from typing import Callable, TYPE_CHECKING, TypedDict, Union
+from typing import Callable, Optional, TYPE_CHECKING, TypedDict, Union
 
 from torch._environment import is_fbcode
 from torch._utils_internal import signpost_event
```
```diff
@@ -76,7 +76,7 @@ def get_freeable_input_buf(
     Create and keep track of all input buffers that can be freed during the program
 
     Returns:
-        A dictionary containing all freeble input buffers, keyed by their names.
+        A dictionary containing all freeable input buffers, keyed by their names.
     """
 
     def _dep_size_hint(dep: Dep) -> int:
```
```diff
@@ -303,7 +303,11 @@ def compute_memory_timeline(
     nodes: list[BaseSchedulerNode],
     name_to_freeable_input_buf: dict[str, FreeableInputBuffer],
     graph_outputs: OrderedSet[str],
-) -> tuple[list[BufferInfo], dict[BaseSchedulerNode, int]]:
+) -> tuple[
+    list[BufferInfo],
+    dict[BaseSchedulerNode, int],
+    dict[Union[FreeableInputBuffer, SchedulerBuffer], BaseSchedulerNode],
+]:
     """
     Compute buffer allocation and deallocation sizes and map their
     lifetime to the node schedule
```
```diff
@@ -317,15 +321,33 @@ def compute_memory_timeline(
 
     # get buffers' size and liveliness information
     buf_info_list: list[BufferInfo] = []
+    buf_to_snode_last_use: dict[
+        Union[FreeableInputBuffer, SchedulerBuffer], BaseSchedulerNode
+    ] = {}
+
+    def _get_end_step_and_snode(
+        buf: Union[FreeableInputBuffer, SchedulerBuffer],
+    ) -> tuple[int, Optional[BaseSchedulerNode]]:
+        max_step: int = -1
+        max_step_snode: Optional[BaseSchedulerNode] = None
+        succ_nodes = buf.mpi_buffer.succ_nodes
+        if succ_nodes:
+            for succ_node in succ_nodes:
+                step = node_to_step[succ_node]
+                if step > max_step:
+                    max_step = step
+                    max_step_snode = succ_node
+            assert max_step_snode is not None
+        return max_step, max_step_snode
 
     # 1. for freeable input buffers
     for buf_name, input_buf in name_to_freeable_input_buf.items():
-        end_step = (
-            len(nodes) - 1
-            if buf_name in graph_outputs
-            else max(
-                node_to_step[succ_node] for succ_node in input_buf.mpi_buffer.succ_nodes
-            )
-        )
+        end_step = -1
+        if buf_name not in graph_outputs:
+            end_step, end_step_snode = _get_end_step_and_snode(input_buf)
+            assert end_step_snode is not None
+            buf_to_snode_last_use[input_buf] = end_step_snode
 
         buf_info_list.append(
             BufferInfo(
                 input_buf,
```
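The `_get_end_step_and_snode` helper added in this hunk returns the latest schedule step among a buffer's successors, together with the node at that step, or `(-1, None)` when the buffer has no successors. A standalone sketch of the same selection logic, using plain strings as hypothetical stand-ins for scheduler nodes:

```python
from typing import Optional

# Hypothetical schedule positions for three consumer nodes.
node_to_step = {"add": 2, "mm": 5, "relu": 3}


def get_end_step_and_node(succ_nodes: list[str]) -> tuple[int, Optional[str]]:
    # Mirrors _get_end_step_and_snode: pick the successor scheduled last.
    max_step, max_step_node = -1, None
    for node in succ_nodes:
        if node_to_step[node] > max_step:
            max_step, max_step_node = node_to_step[node], node
    return max_step, max_step_node


print(get_end_step_and_node(["add", "mm", "relu"]))  # (5, 'mm'): the last use
print(get_end_step_and_node([]))                     # (-1, None): no users
```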
```diff
@@ -342,17 +364,17 @@ def compute_memory_timeline(
             # note: it is possible for a non-graph-output sched_buf to have no succ_nodes and
             # to be only used by its defining op (e.g., due to fusion when all consumers of
             # the buffer are fused with its defining op). In such cases, end_step is step.
-            end_step = (
-                len(nodes) - 1
-                if sched_buf.get_name() in graph_outputs
-                else max(
-                    [
-                        node_to_step[succ_node]
-                        for succ_node in sched_buf.mpi_buffer.succ_nodes
-                    ],
-                    default=step,
-                )
-            )
+            buf_name = sched_buf.get_name()
+            end_step = -1
+            if buf_name not in graph_outputs:
+                end_step, end_step_snode = _get_end_step_and_snode(sched_buf)
+                if end_step == -1:
+                    end_step = step
+                    buf_to_snode_last_use[sched_buf] = node
+                else:
+                    assert end_step_snode is not None
+                    buf_to_snode_last_use[sched_buf] = end_step_snode
 
             buf_info_list.append(
                 BufferInfo(
                     sched_buf,
```
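The `buf_to_snode_last_use` map populated in these hunks is what enables the incremental fix-up described in point 2 of the commit message: when a swap moves a buffer's last use from the group to the candidate, the buffer's free migrates with it. A minimal sketch of that adjustment, assuming a hypothetical 512-byte buffer read by both nodes (toy ledgers, not the pass's actual code):

```python
from dataclasses import dataclass


@dataclass
class SNodeMemory:
    size_alloc: int
    size_free: int


buffer_size = 512  # hypothetical freeable input buffer used by both nodes

# Before the swap, the group node is the buffer's last user and carries its free.
candidate = SNodeMemory(size_alloc=128, size_free=128)
group_node = SNodeMemory(size_alloc=64, size_free=64 + buffer_size)

# After the swap the candidate runs last among the buffer's users, so the
# deallocation moves with it: exactly the size_free adjustment from the PR text.
group_node.size_free -= buffer_size
candidate.size_free += buffer_size

assert (group_node.size_free, candidate.size_free) == (64, 128 + buffer_size)
```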
```diff
@@ -363,7 +385,7 @@ def compute_memory_timeline(
                 )
             )
 
-    return buf_info_list, node_to_step
+    return buf_info_list, node_to_step, buf_to_snode_last_use
 
 
 def estimate_peak_memory(
```
```diff
@@ -373,35 +395,84 @@ def estimate_peak_memory(
 ) -> tuple[int, list[int]]:
     """
     Given a list of nodes in their execution order, estimate the peak memory, by
-    keeping track of the liveliness of SchedulerBuffers and FreeableInputBuffers.
+    keeping track of the liveness of SchedulerBuffers and FreeableInputBuffers.
 
     Returns:
         int: peak memory
         List[int]: memory usage at each node (or each step).
     """
+    # Use estimate_peak_memory_allocfree to keep one impl.
+    peak_memory, snodes_curr_memory, snodes_allocfree, buf_to_snode_last_use = (
+        estimate_peak_memory_allocfree(nodes, name_to_freeable_input_buf, graph_outputs)
+    )
+    return peak_memory, [(curr_mem[0] + curr_mem[1]) for curr_mem in snodes_curr_memory]
+
+
+@dataclasses.dataclass
+class SNodeMemory:
+    size_alloc: int
+    size_free: int
+
+
+def estimate_peak_memory_allocfree(
+    nodes: list[BaseSchedulerNode],
+    name_to_freeable_input_buf: dict[str, FreeableInputBuffer],
+    graph_outputs: OrderedSet[str],
+) -> tuple[
+    int,
+    list[tuple[int, int]],
+    dict[BaseSchedulerNode, SNodeMemory],
+    dict[Union[FreeableInputBuffer, SchedulerBuffer], BaseSchedulerNode],
+]:
+    """
+    Alternative version of estimate_peak_memory, that respects the fact,
+    that every SchedulerNode has multiple phases:
+    1. alloc ( outputs )
+    2. run_kernel
+    3. dealloc last_use buffers
+    estimate_peak_memory collapses memory into one value: size_alloc - size_free
+    While peak memory happens after alloc.
+
+    Duplicating the code to not migrate all callsites at once,
+    In future usages of estimate_peak_memory will migrate to this version.
+    """
 
-    buf_info_list, _ = compute_memory_timeline(
+    buf_info_list, _, buf_to_snode_last_use = compute_memory_timeline(
         nodes, name_to_freeable_input_buf, graph_outputs
     )
 
     # incremental memory changes at each step
-    memory = [0 for _ in range(len(nodes) + 1)]
+    step_idx_allocfree = [SNodeMemory(0, 0) for _ in range(len(nodes))]
 
     # for each buffer, update memory when created and when freed
     for buf_info in buf_info_list:
-        memory[buf_info.start_step] += buf_info.size_alloc
-        memory[buf_info.end_step + 1] -= buf_info.size_free
+        step_idx_allocfree[buf_info.start_step].size_alloc += buf_info.size_alloc
+        if buf_info.end_step != -1:
+            step_idx_allocfree[buf_info.end_step].size_free += buf_info.size_free
+
+    snodes_allocfree = {}
+    for i, node in enumerate(nodes):
+        snodes_allocfree[node] = step_idx_allocfree[i]
 
     # get peak memory by compute the cumulative memories
     max_memory = 0
     cur_memory = 0
-    memories_at_nodes = []
-    for t in range(len(nodes) + 1):
-        cur_memory += memory[t]
-        memories_at_nodes.append(cur_memory)
+    snodes_curr_memory = []
+    for t in range(len(nodes)):
+        alloc = step_idx_allocfree[t].size_alloc
+        free = step_idx_allocfree[t].size_free
+        cur_memory += alloc
+        post_alloc = cur_memory
+        max_memory = max(max_memory, cur_memory)
+        cur_memory -= free
+        post_free = cur_memory
+        snodes_curr_memory.append((post_alloc, post_free))
 
-    return (max_memory, memories_at_nodes)
+    return (
+        max_memory,
+        snodes_curr_memory,
+        snodes_allocfree,
+        buf_to_snode_last_use,
+    )
 
 
 def topological_sort_lpmf(
```
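One piece of the PR not visible in these hunks is the ablation knob. A sketch of how such an environment variable might be consumed (the variable name comes from the PR description; the actual plumbing inside Inductor's config module is not shown in this diff and is assumed here):

```python
import os
from typing import Optional

# PYTORCH_REORDER_COLLECTIVES_LIMIT caps how many collectives the reordering
# pass may move; an unset variable is assumed to mean "no limit".
_limit = os.environ.get("PYTORCH_REORDER_COLLECTIVES_LIMIT")
reorder_collectives_limit: Optional[int] = int(_limit) if _limit else None
```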
```diff
@@ -417,7 +488,7 @@ def topological_sort_lpmf(
         Buffer memory optimization for video codec application modeled in Simulink
         https://www.cs.york.ac.uk/rts/docs/DAC-1964-2006/PAPERS/2006/DAC06/PDFFILES/P0689.PDF
 
-    The algorithm maintain the max memory so far.
+    The algorithm maintains the max memory so far.
     At every iteration, for each scheduleable node, it computes:
         - how much memory needs to be allocated for the output buffers of this node;
         - how much memory can be freed as a result of executing this node.
```
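For context on the docstring fixed above: `topological_sort_lpmf` is a greedy list scheduler that weighs what each ready node would allocate against what it would free. A toy rendition of that idea (illustrative only, not Inductor's implementation):

```python
from dataclasses import dataclass, field


@dataclass
class ToyNode:
    name: str
    alloc: int  # bytes its output buffers need
    free: int   # bytes that become freeable once it runs
    deps: set[str] = field(default_factory=set)


def schedule(nodes: list[ToyNode]) -> list[str]:
    done: set[str] = set()
    remaining = {n.name: n for n in nodes}
    order = []
    while remaining:
        ready = [n for n in remaining.values() if n.deps <= done]
        # Greedy choice: prefer the node with the smallest net memory increase.
        pick = min(ready, key=lambda n: n.alloc - n.free)
        order.append(pick.name)
        done.add(pick.name)
        del remaining[pick.name]
    return order


nodes = [
    ToyNode("a", alloc=64, free=0),
    ToyNode("b", alloc=8, free=64, deps={"a"}),
    ToyNode("c", alloc=128, free=0, deps={"a"}),
]
print(schedule(nodes))  # ['a', 'b', 'c']: 'b' goes before 'c', it shrinks memory
```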