[inductor] Estimate peak memory allocfree and apply it to reordering collectives (#160113)

1. Applies @eellison's idea from https://github.com/pytorch/pytorch/pull/146562#discussion_r2059363672 to `estimate_peak_memory`:
```
    """
    Alternative version of estimate_peak_memory that respects the fact
    that every SchedulerNode has multiple phases:
    1. alloc (outputs)
    2. run_kernel
    3. dealloc last_use buffers
    estimate_peak_memory collapses each node's memory effect into one value
    (size_alloc - size_free), while the peak actually happens after alloc.

    The code is duplicated so that all callsites do not have to migrate at
    once; in the future, users of estimate_peak_memory will migrate to this
    version.
    """
```

- This is applied in the `reorder_communication_preserving_peak_memory` pass.
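The gist, as a minimal standalone sketch (`SNodeMemory` mirrors the dataclass added in the diff below; `peak_from_allocfree` is a hypothetical helper, not the PR's actual entry point): tracking a `(post_alloc, post_free)` pair per node catches peaks that a collapsed `size_alloc - size_free` delta would miss.

```python
from dataclasses import dataclass


@dataclass
class SNodeMemory:
    size_alloc: int  # bytes allocated for the node's outputs (phase 1)
    size_free: int   # bytes freed when its last-use buffers die (phase 3)


def peak_from_allocfree(
    allocfree: list[SNodeMemory],
) -> tuple[int, list[tuple[int, int]]]:
    """Return peak memory and per-node (post_alloc, post_free) levels."""
    cur, peak, per_node = 0, 0, []
    for m in allocfree:
        cur += m.size_alloc    # phase 1: allocate outputs
        post_alloc = cur
        peak = max(peak, cur)  # a peak can only occur right after alloc
        cur -= m.size_free     # phase 3: free last-use buffers
        per_node.append((post_alloc, cur))
    return peak, per_node


# A node that allocates 8 and frees 6 spikes to 8 post-alloc, although its
# collapsed delta (8 - 6 = 2) suggests a much smaller footprint.
assert peak_from_allocfree([SNodeMemory(8, 6)]) == (8, [(8, 2)])
```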

2. During reordering, a buffer's deallocation point can change: this happens if the candidate and the group of nodes being swapped are both users of the same `f_input_buf` and the group contains its `last_use_snode`.

- This is addressed by tracking the `last_use_snode` for each buffer and recomputing the current memory to account for the change in `size_free`: after reordering, the group node is no longer the last user of the buffer, so its `size_free` shrinks by the buffer size (`group_node.size_free -= buffer_size`), while the candidate becomes the last user and grows by it (`candidate.size_free += buffer_size`).
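A sketch of that bookkeeping, reusing `SNodeMemory` from the sketch above (`transfer_last_use` is an illustrative name, not the PR's):

```python
def transfer_last_use(
    buf_size: int,
    group_mem: SNodeMemory,      # per-node alloc/free stats of the group node
    candidate_mem: SNodeMemory,  # per-node alloc/free stats of the candidate
) -> None:
    # After the swap, the group node no longer holds the buffer's last use,
    # so it stops freeing it; the candidate, now the last user, frees it.
    group_mem.size_free -= buf_size
    candidate_mem.size_free += buf_size
```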

3. Adds the env var `PYTORCH_REORDER_COLLECTIVES_LIMIT`, which limits the number of collectives to reorder, for ablation studies.
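Presumably consumed along these lines (a sketch; the PR's exact parsing may differ):

```python
import os
from typing import Optional

# Unset means "no limit"; an integer caps how many collectives get reordered.
_limit = os.environ.get("PYTORCH_REORDER_COLLECTIVES_LIMIT")
reorder_limit: Optional[int] = int(_limit) if _limit is not None else None


def may_reorder(num_reordered: int) -> bool:
    return reorder_limit is None or num_reordered < reorder_limit
```

For an ablation run, e.g. `PYTORCH_REORDER_COLLECTIVES_LIMIT=8 torchrun ...`.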

Status after this PR:

The iterative recomputation of memory estimates matches the full memory estimation.
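That is, the running bookkeeping the reordering pass maintains across swaps agrees with re-estimating from scratch over the new order; conceptually a debug check like this (hypothetical helper; `estimate_peak_memory_allocfree` and its return shape are as in the diff below):

```python
def check_iterative_matches_full(
    snodes, name_to_freeable_input_buf, graph_outputs, iterative_curr_memory
) -> None:
    # Recompute (post_alloc, post_free) per node from scratch over the
    # reordered schedule and compare with the iteratively maintained values.
    _, full_curr_memory, _, _ = estimate_peak_memory_allocfree(
        snodes, name_to_freeable_input_buf, graph_outputs
    )
    assert iterative_curr_memory == full_curr_memory
```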

Active memory does not regress much, but reserved memory regresses significantly.

Investigating and fixing the "reserved" memory regression is left to follow-up PRs.

BASELINE (bucketing all-gathers and reduce-scatters): active: 32 GiB, reserved: 34 GiB
```
[rank0]:[titan] 2025-08-11 11:28:36,798 - root - INFO - step:  1  loss: 12.2722  grad_norm:  4.2192  active_memory: 24.66GiB(25.96%)  reserved_memory: 25.38GiB(26.72%)  tps: 99  tflops: 5.71  mfu: 0.58%
[rank0]:[titan] 2025-08-11 11:28:38,640 - root - INFO - step:  2  loss: 13.1738  grad_norm: 50.5566  active_memory: 32.14GiB(33.83%)  reserved_memory: 34.21GiB(36.01%)  tps: 4,448  tflops: 257.63  mfu: 26.05%
[rank0]:[titan] 2025-08-11 11:28:40,029 - root - INFO - step:  3  loss: 15.6866  grad_norm: 80.0862  active_memory: 32.14GiB(33.83%)  reserved_memory: 34.21GiB(36.01%)  tps: 5,900  tflops: 341.72  mfu: 34.55%
[rank0]:[titan] 2025-08-11 11:28:41,423 - root - INFO - step:  4  loss: 13.4853  grad_norm:  7.8538  active_memory: 32.14GiB(33.83%)  reserved_memory: 34.21GiB(36.01%)  tps: 5,881  tflops: 340.57  mfu: 34.44%
[rank0]:[titan] 2025-08-11 11:28:42,820 - root - INFO - step:  5  loss: 16.1191  grad_norm: 53.2481  active_memory: 32.14GiB(33.83%)  reserved_memory: 34.21GiB(36.01%)  tps: 5,867  tflops: 339.77  mfu: 34.35%
```
REORDER: active: 32 GiB, reserved: 36 GiB
```
[rank0]:[titan] 2025-08-11 11:34:32,772 - root - INFO - step:  1  loss: 12.2490  grad_norm:  4.1944  active_memory: 24.66GiB(25.96%)  reserved_memory: 26.81GiB(28.22%)  tps: 85  tflops: 4.90  mfu: 0.50%
[rank0]:[titan] 2025-08-11 11:34:35,329 - root - INFO - step:  2  loss: 13.1427  grad_norm: 39.5942  active_memory: 32.14GiB(33.83%)  reserved_memory: 36.40GiB(38.31%)  tps: 3,205  tflops: 185.61  mfu: 18.77%
[rank0]:[titan] 2025-08-11 11:34:36,770 - root - INFO - step:  3  loss: 14.6084  grad_norm: 51.0743  active_memory: 32.14GiB(33.83%)  reserved_memory: 36.40GiB(38.31%)  tps: 5,688  tflops: 329.44  mfu: 33.31%
[rank0]:[titan] 2025-08-11 11:34:38,197 - root - INFO - step:  4  loss: 13.6181  grad_norm:  8.1122  active_memory: 32.14GiB(33.83%)  reserved_memory: 36.40GiB(38.31%)  tps: 5,744  tflops: 332.68  mfu: 33.64%
[rank0]:[titan] 2025-08-11 11:34:39,821 - root - INFO - step:  5  loss: 15.8913  grad_norm: 59.8510  active_memory: 32.14GiB(33.83%)  reserved_memory: 36.40GiB(38.31%)  tps: 5,046  tflops: 292.22  mfu: 29.55%
```

REORDER + SINK_WAITS_ITERATIVE: active: 35 GiB, reserved: 41 GiB
```
[rank0]:[titan] 2025-08-11 11:31:36,119 - root - INFO - step:  1  loss: 12.2646  grad_norm:  4.1282  active_memory: 27.60GiB(29.05%)  reserved_memory: 32.49GiB(34.20%)  tps: 173  tflops: 10.00  mfu: 1.01%
[rank0]:[titan] 2025-08-11 11:31:37,452 - root - INFO - step:  2  loss: 13.2353  grad_norm: 42.4234  active_memory: 35.08GiB(36.92%)  reserved_memory: 41.62GiB(43.80%)  tps: 6,152  tflops: 356.26  mfu: 36.02%
[rank0]:[titan] 2025-08-11 11:31:38,780 - root - INFO - step:  3  loss: 13.8205  grad_norm: 24.0156  active_memory: 35.08GiB(36.92%)  reserved_memory: 41.62GiB(43.80%)  tps: 6,169  tflops: 357.29  mfu: 36.13%
[rank0]:[titan] 2025-08-11 11:31:40,106 - root - INFO - step:  4  loss: 13.1033  grad_norm:  9.1167  active_memory: 35.08GiB(36.92%)  reserved_memory: 41.62GiB(43.80%)  tps: 6,183  tflops: 358.10  mfu: 36.21%
[rank0]:[titan] 2025-08-11 11:31:41,443 - root - INFO - step:  5  loss: 16.3530  grad_norm: 51.8118  active_memory: 35.08GiB(36.92%)  reserved_memory: 41.62GiB(43.80%)  tps: 6,130  tflops: 355.03  mfu: 35.90%
```

Differential Revision: [D79886535](https://our.internmc.facebook.com/intern/diff/D79886535)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160113
Approved by: https://github.com/wconstab, https://github.com/eellison

Co-authored-by: eellison <elias.ellison@gmail.com>
`torch/_inductor/memory.py`:

```diff
@@ -4,7 +4,7 @@ import collections
 import dataclasses
 import heapq
 import logging
-from typing import Callable, TYPE_CHECKING, TypedDict, Union
+from typing import Callable, Optional, TYPE_CHECKING, TypedDict, Union
 
 from torch._environment import is_fbcode
 from torch._utils_internal import signpost_event
@@ -76,7 +76,7 @@ def get_freeable_input_buf(
     Create and keep track of all input buffers that can be freed during the program
 
     Returns:
-        A dictionary containing all freeble input buffers, keyed by their names.
+        A dictionary containing all freeable input buffers, keyed by their names.
     """
 
     def _dep_size_hint(dep: Dep) -> int:
@@ -303,7 +303,11 @@ def compute_memory_timeline(
     nodes: list[BaseSchedulerNode],
     name_to_freeable_input_buf: dict[str, FreeableInputBuffer],
     graph_outputs: OrderedSet[str],
-) -> tuple[list[BufferInfo], dict[BaseSchedulerNode, int]]:
+) -> tuple[
+    list[BufferInfo],
+    dict[BaseSchedulerNode, int],
+    dict[Union[FreeableInputBuffer, SchedulerBuffer], BaseSchedulerNode],
+]:
     """
     Compute buffer allocation and deallocation sizes and map their
     lifetime to the node schedule
@@ -317,15 +321,33 @@ def compute_memory_timeline(
 
     # get buffers' size and liveliness information
     buf_info_list: list[BufferInfo] = []
+    buf_to_snode_last_use: dict[
+        Union[FreeableInputBuffer, SchedulerBuffer], BaseSchedulerNode
+    ] = {}
+
+    def _get_end_step_and_snode(
+        buf: Union[FreeableInputBuffer, SchedulerBuffer],
+    ) -> tuple[int, Optional[BaseSchedulerNode]]:
+        max_step: int = -1
+        max_step_snode: Optional[BaseSchedulerNode] = None
+        succ_nodes = buf.mpi_buffer.succ_nodes
+        if succ_nodes:
+            for succ_node in succ_nodes:
+                step = node_to_step[succ_node]
+                if step > max_step:
+                    max_step = step
+                    max_step_snode = succ_node
+            assert max_step_snode is not None
+        return max_step, max_step_snode
 
     # 1. for freeable input buffers
     for buf_name, input_buf in name_to_freeable_input_buf.items():
-        end_step = (
-            len(nodes) - 1
-            if buf_name in graph_outputs
-            else max(
-                node_to_step[succ_node] for succ_node in input_buf.mpi_buffer.succ_nodes
-            )
-        )
+        end_step = -1
+        if buf_name not in graph_outputs:
+            end_step, end_step_snode = _get_end_step_and_snode(input_buf)
+            assert end_step_snode is not None
+            buf_to_snode_last_use[input_buf] = end_step_snode
         buf_info_list.append(
             BufferInfo(
                 input_buf,
@@ -342,17 +364,17 @@ def compute_memory_timeline(
             # note: it is possible for a non-graph-output sched_buf to have no succ_nodes and
             # to be only used by its defining op (e.g., due to fusion when all consumers of
             # the buffer are fused with its defining op). In such cases, end_step is step.
-            end_step = (
-                len(nodes) - 1
-                if sched_buf.get_name() in graph_outputs
-                else max(
-                    [
-                        node_to_step[succ_node]
-                        for succ_node in sched_buf.mpi_buffer.succ_nodes
-                    ],
-                    default=step,
-                )
-            )
+            buf_name = sched_buf.get_name()
+            end_step = -1
+            if buf_name not in graph_outputs:
+                end_step, end_step_snode = _get_end_step_and_snode(sched_buf)
+                if end_step == -1:
+                    end_step = step
+                    buf_to_snode_last_use[sched_buf] = node
+                else:
+                    assert end_step_snode is not None
+                    buf_to_snode_last_use[sched_buf] = end_step_snode
             buf_info_list.append(
                 BufferInfo(
                     sched_buf,
@@ -363,7 +385,7 @@ def compute_memory_timeline(
             )
         )
 
-    return buf_info_list, node_to_step
+    return buf_info_list, node_to_step, buf_to_snode_last_use
 
 
 def estimate_peak_memory(
@@ -373,35 +395,84 @@ def estimate_peak_memory(
 ) -> tuple[int, list[int]]:
     """
     Given a list of nodes in their execution order, estimate the peak memory, by
-    keeping track of the liveliness of SchedulerBuffers and FreeableInputBuffers.
+    keeping track of the liveness of SchedulerBuffers and FreeableInputBuffers.
 
     Returns:
         int: peak memory
         List[int]: memory usage at each node (or each step).
     """
+    # Use estimate_peak_memory_allocfree to keep one impl.
+    peak_memory, snodes_curr_memory, snodes_allocfree, buf_to_snode_last_use = (
+        estimate_peak_memory_allocfree(nodes, name_to_freeable_input_buf, graph_outputs)
+    )
+    return peak_memory, [(curr_mem[0] + curr_mem[1]) for curr_mem in snodes_curr_memory]
+
+
+@dataclasses.dataclass
+class SNodeMemory:
+    size_alloc: int
+    size_free: int
+
 
-    buf_info_list, _ = compute_memory_timeline(
+def estimate_peak_memory_allocfree(
+    nodes: list[BaseSchedulerNode],
+    name_to_freeable_input_buf: dict[str, FreeableInputBuffer],
+    graph_outputs: OrderedSet[str],
+) -> tuple[
+    int,
+    list[tuple[int, int]],
+    dict[BaseSchedulerNode, SNodeMemory],
+    dict[Union[FreeableInputBuffer, SchedulerBuffer], BaseSchedulerNode],
+]:
+    """
+    Alternative version of estimate_peak_memory that respects the fact
+    that every SchedulerNode has multiple phases:
+    1. alloc (outputs)
+    2. run_kernel
+    3. dealloc last_use buffers
+    estimate_peak_memory collapses each node's memory effect into one value
+    (size_alloc - size_free), while the peak actually happens after alloc.
+
+    The code is duplicated so that all callsites do not have to migrate at
+    once; in the future, users of estimate_peak_memory will migrate to this
+    version.
+    """
+    buf_info_list, _, buf_to_snode_last_use = compute_memory_timeline(
         nodes, name_to_freeable_input_buf, graph_outputs
     )
 
-    # incremental memory changes at each step
-    memory = [0 for _ in range(len(nodes) + 1)]
+    step_idx_allocfree = [SNodeMemory(0, 0) for _ in range(len(nodes))]
 
     # for each buffer, update memory when created and when freed
     for buf_info in buf_info_list:
-        memory[buf_info.start_step] += buf_info.size_alloc
-        memory[buf_info.end_step + 1] -= buf_info.size_free
+        step_idx_allocfree[buf_info.start_step].size_alloc += buf_info.size_alloc
+        if buf_info.end_step != -1:
+            step_idx_allocfree[buf_info.end_step].size_free += buf_info.size_free
+
+    snodes_allocfree = {}
+    for i, node in enumerate(nodes):
+        snodes_allocfree[node] = step_idx_allocfree[i]
 
     # get peak memory by compute the cumulative memories
     max_memory = 0
     cur_memory = 0
-    memories_at_nodes = []
-    for t in range(len(nodes) + 1):
-        cur_memory += memory[t]
-        memories_at_nodes.append(cur_memory)
+    snodes_curr_memory = []
+    for t in range(len(nodes)):
+        alloc = step_idx_allocfree[t].size_alloc
+        free = step_idx_allocfree[t].size_free
+        cur_memory += alloc
+        post_alloc = cur_memory
+        max_memory = max(max_memory, cur_memory)
+        cur_memory -= free
+        post_free = cur_memory
+        snodes_curr_memory.append((post_alloc, post_free))
 
-    return (max_memory, memories_at_nodes)
+    return (
+        max_memory,
+        snodes_curr_memory,
+        snodes_allocfree,
+        buf_to_snode_last_use,
+    )
 
 
 def topological_sort_lpmf(
@@ -417,7 +488,7 @@ def topological_sort_lpmf(
         Buffer memory optimization for video codec application modeled in Simulink
         https://www.cs.york.ac.uk/rts/docs/DAC-1964-2006/PAPERS/2006/DAC06/PDFFILES/P0689.PDF
 
-    The algorithm maintain the max memory so far.
+    The algorithm maintains the max memory so far.
     At every iteration, for each scheduleable node, it computes:
         - how much memory needs to be allocated for the output buffers of this node;
        - how much memory can be freed as a result of executing this node.
```
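
For orientation, a hedged sketch of consuming the new entry point (names as in the diff above; how the scheduler obtains `nodes`, `name_to_freeable_input_buf`, and `graph_outputs` is elided):

```python
peak, curr_memory, allocfree, last_use = estimate_peak_memory_allocfree(
    nodes, name_to_freeable_input_buf, graph_outputs
)
# curr_memory[i] is (post_alloc, post_free) for nodes[i]; since frees only
# lower the watermark, the peak is the max over post_alloc values
# (assuming a non-empty schedule).
assert peak == max(post_alloc for post_alloc, _ in curr_memory)
# allocfree maps each node to its SNodeMemory; last_use maps each buffer to
# the node after which it can be freed, which is exactly the state the
# reordering pass updates incrementally when a swap moves a buffer's last use.
```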