pytorch/torch/_inductor/comms_debug.py
IvanKobzarev db44de4c0d [inductor] Estimate peak memory allocfree and applying to reordering collectives (#160113)
1. Applying @eellison's idea from https://github.com/pytorch/pytorch/pull/146562#discussion_r2059363672 to `estimate_peak_memory`:
```
    """
    Alternative version of estimate_peak_memory, that respects the fact,
    that every SchedulerNode has multiple phases:
    1. alloc ( outputs )
    2. run_kernel
    3. dealloc last_use buffers
    estimate_peak_memory collapses memory into one value: size_alloc - size_free
    While peak memory happens after alloc.

    Duplicating the code to not migrate all callsites at once,
    In future usages of estimate_peak_memory will migrate to this version.
    """
```

- Applying this in the `reorder_communication_preserving_peak_memory` pass (see the sketch below).
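
A minimal sketch of the phased estimate (illustrative names, not the actual `torch/_inductor/memory.py` implementation): the running total is sampled after each node's alloc phase and before its frees, so transient highs are not hidden by the collapsed `size_alloc - size_free` delta.
```
# Sketch only: illustrative names, not the actual memory.py implementation.
from dataclasses import dataclass


@dataclass
class SNodeMemorySketch:
    size_alloc: int  # bytes allocated for the node's outputs
    size_free: int   # bytes freed when the node's last-use buffers die


def estimate_peak_allocfree(per_node: list[SNodeMemorySketch]) -> int:
    """Sample the peak after each node's alloc phase, before its frees."""
    curr = 0
    peak = 0
    for m in per_node:
        curr += m.size_alloc    # phase 1: outputs are allocated
        peak = max(peak, curr)  # the peak can occur here, before any frees
        curr -= m.size_free     # phase 3: last-use buffers are deallocated
    return peak


# A node that allocates 8 and frees 8 nets to zero under the collapsed
# estimate, but the phased estimate correctly reports a transient peak of 8.
assert estimate_peak_allocfree([SNodeMemorySketch(8, 8)]) == 8
```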

2. Reordering can move a buffer's deallocation point: if the candidate and the group being swapped are both users of the same `f_input_buf` and the group contains its `last_use_snode`, the swap changes which node frees the buffer.

- Addressed by tracking the `last_use_snode` for each buffer and recomputing current memory to reflect the change in `size_free`: after reordering, the group node is no longer the last user of the buffer (its `size_free -= buffer_size`), while the candidate becomes the last user (its `size_free += buffer_size`). See the sketch below.
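
For illustration, the `size_free` handover under a swap looks like this (hypothetical helper, not the pass's actual code):
```
# Hypothetical helper, for illustration only: when a swap makes the candidate,
# rather than the group node, the last user of a shared input buffer, the
# buffer's free moves from the group node's free set to the candidate's.
def transfer_last_use(
    buffer_size: int,
    group_size_free: int,
    candidate_size_free: int,
) -> tuple[int, int]:
    # The group node no longer frees the buffer after reordering ...
    group_size_free -= buffer_size
    # ... and the candidate, now the last user, frees it instead.
    candidate_size_free += buffer_size
    return group_size_free, candidate_size_free


# A 1 MiB buffer moves from the group node's free set to the candidate's.
assert transfer_last_use(2**20, 2**20, 0) == (0, 2**20)
```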

3. Adding the env var `PYTORCH_REORDER_COLLECTIVES_LIMIT` for ablations, to limit the number of collectives reordered (a usage sketch follows).
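
Typical consumption of the variable for an ablation run might look like the following sketch (the exact parsing inside the pass may differ):
```
# Sketch only: how such a limit is typically read; the pass's parsing may differ.
import os

limit_str = os.environ.get("PYTORCH_REORDER_COLLECTIVES_LIMIT")
reorder_limit = int(limit_str) if limit_str is not None else None

# e.g. launch with:
#   PYTORCH_REORDER_COLLECTIVES_LIMIT=8 torchrun train.py
# and stop reordering once `reorder_limit` collectives have been moved.
```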

State after this PR:

Iteratively recomputed memory estimates match the full recomputation.

Active memory regresses only slightly, but reserved memory regresses significantly.

Investigating and fixing the reserved-memory regression will follow in subsequent PRs.

BASELINE (bucketing AG and RS): active: 32GiB, reserved: 34GiB
```
[rank0]:[titan] 2025-08-11 11:28:36,798 - root - INFO - step:  1  loss: 12.2722  grad_norm:  4.2192  active_memory: 24.66GiB(25.96%)  reserved_memory: 25.38GiB(26.72%)  tps: 99  tflops: 5.71  mfu: 0.58%
[rank0]:[titan] 2025-08-11 11:28:38,640 - root - INFO - step:  2  loss: 13.1738  grad_norm: 50.5566  active_memory: 32.14GiB(33.83%)  reserved_memory: 34.21GiB(36.01%)  tps: 4,448  tflops: 257.63  mfu: 26.05%
[rank0]:[titan] 2025-08-11 11:28:40,029 - root - INFO - step:  3  loss: 15.6866  grad_norm: 80.0862  active_memory: 32.14GiB(33.83%)  reserved_memory: 34.21GiB(36.01%)  tps: 5,900  tflops: 341.72  mfu: 34.55%
[rank0]:[titan] 2025-08-11 11:28:41,423 - root - INFO - step:  4  loss: 13.4853  grad_norm:  7.8538  active_memory: 32.14GiB(33.83%)  reserved_memory: 34.21GiB(36.01%)  tps: 5,881  tflops: 340.57  mfu: 34.44%
[rank0]:[titan] 2025-08-11 11:28:42,820 - root - INFO - step:  5  loss: 16.1191  grad_norm: 53.2481  active_memory: 32.14GiB(33.83%)  reserved_memory: 34.21GiB(36.01%)  tps: 5,867  tflops: 339.77  mfu: 34.35%
```
REORDER: active: 32GiB, reserved: 36GiB
```
[rank0]:[titan] 2025-08-11 11:34:32,772 - root - INFO - step:  1  loss: 12.2490  grad_norm:  4.1944  active_memory: 24.66GiB(25.96%)  reserved_memory: 26.81GiB(28.22%)  tps: 85  tflops: 4.90  mfu: 0.50%
[rank0]:[titan] 2025-08-11 11:34:35,329 - root - INFO - step:  2  loss: 13.1427  grad_norm: 39.5942  active_memory: 32.14GiB(33.83%)  reserved_memory: 36.40GiB(38.31%)  tps: 3,205  tflops: 185.61  mfu: 18.77%
[rank0]:[titan] 2025-08-11 11:34:36,770 - root - INFO - step:  3  loss: 14.6084  grad_norm: 51.0743  active_memory: 32.14GiB(33.83%)  reserved_memory: 36.40GiB(38.31%)  tps: 5,688  tflops: 329.44  mfu: 33.31%
[rank0]:[titan] 2025-08-11 11:34:38,197 - root - INFO - step:  4  loss: 13.6181  grad_norm:  8.1122  active_memory: 32.14GiB(33.83%)  reserved_memory: 36.40GiB(38.31%)  tps: 5,744  tflops: 332.68  mfu: 33.64%
[rank0]:[titan] 2025-08-11 11:34:39,821 - root - INFO - step:  5  loss: 15.8913  grad_norm: 59.8510  active_memory: 32.14GiB(33.83%)  reserved_memory: 36.40GiB(38.31%)  tps: 5,046  tflops: 292.22  mfu: 29.55%
```

REORDER + SINK_WAITS_ITERATIVE: active: 35GiB, reserved: 41GiB
```
[rank0]:[titan] 2025-08-11 11:31:36,119 - root - INFO - step:  1  loss: 12.2646  grad_norm:  4.1282  active_memory: 27.60GiB(29.05%)  reserved_memory: 32.49GiB(34.20%)  tps: 173  tflops: 10.00  mfu: 1.01%
[rank0]:[titan] 2025-08-11 11:31:37,452 - root - INFO - step:  2  loss: 13.2353  grad_norm: 42.4234  active_memory: 35.08GiB(36.92%)  reserved_memory: 41.62GiB(43.80%)  tps: 6,152  tflops: 356.26  mfu: 36.02%
[rank0]:[titan] 2025-08-11 11:31:38,780 - root - INFO - step:  3  loss: 13.8205  grad_norm: 24.0156  active_memory: 35.08GiB(36.92%)  reserved_memory: 41.62GiB(43.80%)  tps: 6,169  tflops: 357.29  mfu: 36.13%
[rank0]:[titan] 2025-08-11 11:31:40,106 - root - INFO - step:  4  loss: 13.1033  grad_norm:  9.1167  active_memory: 35.08GiB(36.92%)  reserved_memory: 41.62GiB(43.80%)  tps: 6,183  tflops: 358.10  mfu: 36.21%
[rank0]:[titan] 2025-08-11 11:31:41,443 - root - INFO - step:  5  loss: 16.3530  grad_norm: 51.8118  active_memory: 35.08GiB(36.92%)  reserved_memory: 41.62GiB(43.80%)  tps: 6,130  tflops: 355.03  mfu: 35.90%
```

Differential Revision: [D80718143](https://our.internmc.facebook.com/intern/diff/D80718143)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160113
Approved by: https://github.com/wconstab, https://github.com/eellison

Co-authored-by: eellison <elias.ellison@gmail.com>
2025-08-22 14:19:57 +00:00


from __future__ import annotations

from typing import TYPE_CHECKING, Union

from torch._logging import trace_structured

from .memory import estimate_peak_memory_allocfree


if TYPE_CHECKING:
    from torch.utils._ordered_set import OrderedSet

    from .memory import FreeableInputBuffer, SNodeMemory
    from .scheduler import BaseSchedulerNode, SchedulerBuffer


def _debug_iterative_memory_recompute(
    candidate: BaseSchedulerNode,
    gns: list[BaseSchedulerNode],
    group_names: str,
    snodes: list[BaseSchedulerNode],
    name_to_freeable_input_buf: dict[str, FreeableInputBuffer],
    graph_outputs: OrderedSet[str],
    peak_memory: int,
    iter_curr_memory: dict[BaseSchedulerNode, tuple[int, int]],
    snodes_allocfree: dict[BaseSchedulerNode, SNodeMemory],
    tlparse_name: str,
    gn_to_bufs_last_use: dict[
        BaseSchedulerNode, list[Union[FreeableInputBuffer, SchedulerBuffer]]
    ],
) -> bool:
    """Cross-check the iteratively maintained memory bookkeeping against a
    full recomputation; on mismatch, print and emit a tlparse artifact with
    a detailed per-node comparison."""
    iterative_recompute_error = False
    # Snapshot the iteratively tracked alloc/free of the candidate before
    # recomputing, so both versions can be reported below.
    candidate_allocfree = snodes_allocfree[candidate]
    # Recompute everything from scratch; keep the result in a separate dict
    # so it does not shadow the iteratively tracked `snodes_allocfree`.
    est_peak_memory, snodes_curr_memory, est_snodes_allocfree, _ = (
        estimate_peak_memory_allocfree(
            snodes, name_to_freeable_input_buf, graph_outputs
        )
    )
    est_curr_memory = dict(zip(snodes, snodes_curr_memory))
    iter_cm = iter_curr_memory[candidate]
    new_cm = est_curr_memory[candidate]
    log = ""
    if est_peak_memory > peak_memory:
        log = "ITERATIVE PEAK DOES NOT MATCH"
        iterative_recompute_error = True
    if iter_cm != new_cm:
        log = "ITERATIVE CURR MEMORY CANDIDATE DOES NOT MATCH"
        iterative_recompute_error = True
    for gn in gns:
        if iter_curr_memory[gn] != est_curr_memory[gn]:
            log = f"ITERATIVE GN CURR MEMORY DOES NOT MATCH:{gn.get_name()}"
            iterative_recompute_error = True
    if iterative_recompute_error:
        log += (
            f"\nCANDIDATE:{candidate.get_name()}"
            f"\nGROUP:{group_names}"
            f"\nPEAK_MEMORY_BEFORE:{peak_memory}"
            f"\nPEAK_MEMORY_AFTER_SWAP:{est_peak_memory}"
            f"\nCANDIDATE:{candidate.debug_str()}"
            f"\nCANDIDATE_ITER_CURR_MEMORY:{iter_cm}"
            f"\nCANDIDATE_NEW__CURR_MEMORY:{new_cm}"
            f"\nCANDIDATE_ITER_ALLOCFREE:{candidate_allocfree}"
            f"\nCANDIDATE_NEW_ALLOCFREE:{est_snodes_allocfree[candidate]}"
        )
        # Locate the node at which the recomputed peak is reached.
        peak_log = ""
        for i, (pre, _post) in enumerate(snodes_curr_memory):
            if est_peak_memory == pre:
                n = snodes[i]
                peak_log = (
                    f"\nNEW_PEAK:{est_peak_memory}(BASE:{peak_memory})"
                    f" @ SNODE[{i}/{len(snodes)}]:{n.get_name()} {n.debug_str()}"
                )
                break
        group_log = ""
        for i, gn in enumerate(gns):
            iter_gnm = iter_curr_memory[gn]
            new_gnm = est_curr_memory[gn]
            group_log += (
                f"\nGROUP_NODE[{i}]:{gn.debug_str()}"
                f"\nGROUP_NODE[{i}] ITER_GNM[{gn.get_name()}]:{iter_gnm}"
                f"\nGROUP_NODE[{i}] ESTM_GNM[{gn.get_name()}]:{new_gnm}"
                f"\nGROUP_NODE[{i}] ITER_allocfree:{snodes_allocfree[gn]}"
                f"\nGROUP_NODE[{i}] ESTM_allocfree:{est_snodes_allocfree[gn]}"
            )
        log += peak_log
        log += group_log
        log += f"\nGN_TO_BUFS_LAST_USE:{gn_to_bufs_last_use}"
        # Full per-node dump: iterative vs. recomputed current memory and
        # alloc/free sizes.
        log += "\n\n".join(
            f"\nSNODE[{i}]\n{n.debug_str()}"
            f"\nITER_cur_mem:{iter_curr_memory[n]}"
            f"\nESTM_cur_mem:{est_curr_memory[n]}"
            f"\nITER_allocfree:{snodes_allocfree[n]}"
            f"\nESTM_allocfree:{est_snodes_allocfree[n]}"
            for i, n in enumerate(snodes)
        )
        tname = f"{tlparse_name}_ITERATIVE_RECOMPUTE_ERROR"
        print(f"{tname}:\n{log}")
        trace_structured(
            "artifact",
            metadata_fn=lambda: {
                "name": tname,
                "encoding": "string",
            },
            payload_fn=lambda: log,
        )
    return iterative_recompute_error
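

# --- Usage sketch (not part of comms_debug.py) --------------------------------
# Hypothetical call-site wrapper, for illustration only: the real caller is the
# reorder_communication_preserving_peak_memory pass. `ctx` and its attributes
# are assumptions standing in for the bookkeeping that the pass maintains
# iteratively across swaps.
from torch._inductor.comms_debug import _debug_iterative_memory_recompute


def validate_after_swap(candidate, group_nodes, snodes, ctx) -> bool:
    """Return True if the iterative estimates still match a full recompute."""
    error = _debug_iterative_memory_recompute(
        candidate,
        group_nodes,
        ", ".join(gn.get_name() for gn in group_nodes),
        snodes,
        ctx.name_to_freeable_input_buf,
        ctx.graph_outputs,
        ctx.peak_memory,
        ctx.iter_curr_memory,
        ctx.snodes_allocfree,
        "reorder_communication_preserving_peak_memory",
        ctx.gn_to_bufs_last_use,
    )
    return not error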