1. Apply @eellison's idea from https://github.com/pytorch/pytorch/pull/146562#discussion_r2059363672 to `estimate_peak_memory`:

```
"""
Alternative version of estimate_peak_memory that respects the fact that
every SchedulerNode has multiple phases:
1. alloc (outputs)
2. run_kernel
3. dealloc last_use buffers
estimate_peak_memory collapses memory into one value, size_alloc - size_free,
while the peak actually happens right after alloc. The code is duplicated to
avoid migrating all callsites at once; future users of estimate_peak_memory
will migrate to this version.
"""
```

   This version is applied in the `reorder_communication_preserving_peak_memory` pass; a minimal sketch of the phase-aware accounting follows this list.

2. Reordering can move a buffer's deallocation point: if the candidate and the group it is swapped with are both users of the same `f_input_buf` and the group contains its `last_use_snode`, the last use changes hands. This is addressed by tracking the `last_use_snode` of each buffer and recomputing current memory to reflect the change in `size_free`: after reordering, the group node is no longer the last user of the buffer (its `size_free -= buffer_size`), while the candidate becomes the last user (its `size_free += buffer_size`). A toy demonstration of this bookkeeping follows the source listing below.

3. Add the env var `PYTORCH_REORDER_COLLECTIVES_LIMIT`, which limits the number of collectives to reorder, for ablation studies.
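To make point 1 concrete, here is a minimal sketch of the phase-aware accounting, assuming only the per-node alloc/free sizes described above. `AllocFree` and `phase_aware_peak` are hypothetical stand-ins, not the actual `estimate_peak_memory_allocfree` from `torch/_inductor/memory` (which also tracks buffers, last uses, and graph outputs):

```python
from dataclasses import dataclass


@dataclass
class AllocFree:
    """Hypothetical per-node footprint: bytes allocated for the node's
    outputs, and bytes freed after it runs (its last-use buffers)."""
    size_alloc: int
    size_free: int


def phase_aware_peak(nodes: list[AllocFree]) -> tuple[int, list[tuple[int, int]]]:
    """The peak is reached right after each node's alloc phase, before its
    dealloc phase, rather than at the collapsed size_alloc - size_free."""
    curr = peak = 0
    curr_memory = []  # (post_alloc, post_free) per node
    for n in nodes:
        post_alloc = curr + n.size_alloc  # phase 1: allocate outputs
        peak = max(peak, post_alloc)      # high-water mark sits here
        curr = post_alloc - n.size_free   # phase 3: free last-use buffers
        curr_memory.append((post_alloc, curr))
    return peak, curr_memory
```

The `(post_alloc, post_free)` pairs correspond to the `tuple[int, int]` entries of `iter_curr_memory` in the debug helper below.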
Status after this PR: the iterative recomputation of memory estimates matches the full recomputation. Active memory does not regress much, but reserved memory regresses significantly; investigating and fixing the reserved-memory regression is left to follow-up PRs.

BASELINE (bucketing AG and RS): active: 32GiB, reserved: 34GiB
```
[rank0]:[titan] 2025-08-11 11:28:36,798 - root - INFO - step: 1 loss: 12.2722 grad_norm: 4.2192 active_memory: 24.66GiB(25.96%) reserved_memory: 25.38GiB(26.72%) tps: 99 tflops: 5.71 mfu: 0.58%
[rank0]:[titan] 2025-08-11 11:28:38,640 - root - INFO - step: 2 loss: 13.1738 grad_norm: 50.5566 active_memory: 32.14GiB(33.83%) reserved_memory: 34.21GiB(36.01%) tps: 4,448 tflops: 257.63 mfu: 26.05%
[rank0]:[titan] 2025-08-11 11:28:40,029 - root - INFO - step: 3 loss: 15.6866 grad_norm: 80.0862 active_memory: 32.14GiB(33.83%) reserved_memory: 34.21GiB(36.01%) tps: 5,900 tflops: 341.72 mfu: 34.55%
[rank0]:[titan] 2025-08-11 11:28:41,423 - root - INFO - step: 4 loss: 13.4853 grad_norm: 7.8538 active_memory: 32.14GiB(33.83%) reserved_memory: 34.21GiB(36.01%) tps: 5,881 tflops: 340.57 mfu: 34.44%
[rank0]:[titan] 2025-08-11 11:28:42,820 - root - INFO - step: 5 loss: 16.1191 grad_norm: 53.2481 active_memory: 32.14GiB(33.83%) reserved_memory: 34.21GiB(36.01%) tps: 5,867 tflops: 339.77 mfu: 34.35%
```

REORDER: active: 32GiB, reserved: 36GiB
```
[rank0]:[titan] 2025-08-11 11:34:32,772 - root - INFO - step: 1 loss: 12.2490 grad_norm: 4.1944 active_memory: 24.66GiB(25.96%) reserved_memory: 26.81GiB(28.22%) tps: 85 tflops: 4.90 mfu: 0.50%
[rank0]:[titan] 2025-08-11 11:34:35,329 - root - INFO - step: 2 loss: 13.1427 grad_norm: 39.5942 active_memory: 32.14GiB(33.83%) reserved_memory: 36.40GiB(38.31%) tps: 3,205 tflops: 185.61 mfu: 18.77%
[rank0]:[titan] 2025-08-11 11:34:36,770 - root - INFO - step: 3 loss: 14.6084 grad_norm: 51.0743 active_memory: 32.14GiB(33.83%) reserved_memory: 36.40GiB(38.31%) tps: 5,688 tflops: 329.44 mfu: 33.31%
[rank0]:[titan] 2025-08-11 11:34:38,197 - root - INFO - step: 4 loss: 13.6181 grad_norm: 8.1122 active_memory: 32.14GiB(33.83%) reserved_memory: 36.40GiB(38.31%) tps: 5,744 tflops: 332.68 mfu: 33.64%
[rank0]:[titan] 2025-08-11 11:34:39,821 - root - INFO - step: 5 loss: 15.8913 grad_norm: 59.8510 active_memory: 32.14GiB(33.83%) reserved_memory: 36.40GiB(38.31%) tps: 5,046 tflops: 292.22 mfu: 29.55%
```

REORDER + SINK_WAITS_ITERATIVE: active: 35GiB, reserved: 41GiB
```
[rank0]:[titan] 2025-08-11 11:31:36,119 - root - INFO - step: 1 loss: 12.2646 grad_norm: 4.1282 active_memory: 27.60GiB(29.05%) reserved_memory: 32.49GiB(34.20%) tps: 173 tflops: 10.00 mfu: 1.01%
[rank0]:[titan] 2025-08-11 11:31:37,452 - root - INFO - step: 2 loss: 13.2353 grad_norm: 42.4234 active_memory: 35.08GiB(36.92%) reserved_memory: 41.62GiB(43.80%) tps: 6,152 tflops: 356.26 mfu: 36.02%
[rank0]:[titan] 2025-08-11 11:31:38,780 - root - INFO - step: 3 loss: 13.8205 grad_norm: 24.0156 active_memory: 35.08GiB(36.92%) reserved_memory: 41.62GiB(43.80%) tps: 6,169 tflops: 357.29 mfu: 36.13%
[rank0]:[titan] 2025-08-11 11:31:40,106 - root - INFO - step: 4 loss: 13.1033 grad_norm: 9.1167 active_memory: 35.08GiB(36.92%) reserved_memory: 41.62GiB(43.80%) tps: 6,183 tflops: 358.10 mfu: 36.21%
[rank0]:[titan] 2025-08-11 11:31:41,443 - root - INFO - step: 5 loss: 16.3530 grad_norm: 51.8118 active_memory: 35.08GiB(36.92%) reserved_memory: 41.62GiB(43.80%) tps: 6,130 tflops: 355.03 mfu: 35.90%
```

Differential Revision: [D80718143](https://our.internmc.facebook.com/intern/diff/D80718143)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160113
Approved by: https://github.com/wconstab, https://github.com/eellison

Co-authored-by: eellison <elias.ellison@gmail.com>
from __future__ import annotations

from typing import TYPE_CHECKING, Union

from torch._logging import trace_structured

from .memory import estimate_peak_memory_allocfree


if TYPE_CHECKING:
    from torch.utils._ordered_set import OrderedSet

    from .memory import FreeableInputBuffer, SNodeMemory
    from .scheduler import BaseSchedulerNode, SchedulerBuffer


def _debug_iterative_memory_recompute(
    candidate: BaseSchedulerNode,
    gns: list[BaseSchedulerNode],
    group_names: str,
    snodes: list[BaseSchedulerNode],
    name_to_freeable_input_buf: dict[str, FreeableInputBuffer],
    graph_outputs: OrderedSet[str],
    peak_memory: int,
    iter_curr_memory: dict[BaseSchedulerNode, tuple[int, int]],
    snodes_allocfree: dict[BaseSchedulerNode, SNodeMemory],
    tlparse_name: str,
    gn_to_bufs_last_use: dict[
        BaseSchedulerNode, list[Union[FreeableInputBuffer, SchedulerBuffer]]
    ],
) -> bool:
    """
    Debug check for the iterative memory bookkeeping: recompute peak memory
    and per-node alloc/free stats from scratch and compare them against the
    iteratively maintained ones after swapping `candidate` past the group
    `gns`. Returns True (and emits a tlparse artifact) on divergence.
    """
    iterative_recompute_error = False
    # Snapshot the iteratively tracked stats for the candidate before
    # comparing them against the freshly recomputed ones.
    candidate_allocfree = snodes_allocfree[candidate]
    # Full from-scratch recomputation; kept separate from the iteratively
    # maintained `snodes_allocfree` so the two can be diffed below.
    est_peak_memory, snodes_curr_memory, new_snodes_allocfree, _ = (
        estimate_peak_memory_allocfree(
            snodes, name_to_freeable_input_buf, graph_outputs
        )
    )
    est_curr_memory = dict(zip(snodes, snodes_curr_memory))
    iter_cm = iter_curr_memory[candidate]
    new_cm = est_curr_memory[candidate]
    log = ""
    if est_peak_memory > peak_memory:
        log = "ITERATIVE PEAK DOES NOT MATCH"
        iterative_recompute_error = True
    if iter_cm != new_cm:
        log = "ITERATIVE CURR MEMORY CANDIDATE DOES NOT MATCH"
        iterative_recompute_error = True
    for i, gn in enumerate(gns):
        iter_gnm = iter_curr_memory[gn]
        new_gnm = est_curr_memory[gn]
        if iter_gnm != new_gnm:
            log = f"ITERATIVE GN CURR MEMORY DOES NOT MATCH:{gn.get_name()}"
            iterative_recompute_error = True
    if iterative_recompute_error:
        log += (
            f"\nCANDIDATE:{candidate.get_name()}"
            f"\nGROUP:{group_names}"
            f"\nPEAK_MEMORY_BEFORE:{peak_memory}"
            f"\nPEAK_MEMORY_AFTER_SWAP:{est_peak_memory}"
            f"\nCANDIDATE:{candidate.debug_str()}"
            f"\nCANDIDATE_ITER_CURR_MEMORY:{iter_cm}"
            f"\nCANDIDATE_NEW_CURR_MEMORY:{new_cm}"
            f"\nCANDIDATE_ITER_ALLOCFREE:{candidate_allocfree}"
            f"\nCANDIDATE_NEW_ALLOCFREE:{new_snodes_allocfree[candidate]}"
        )
        # Locate the node whose post-alloc memory is the new peak.
        peak_log = ""
        for i, (pre, post) in enumerate(snodes_curr_memory):
            if est_peak_memory == pre:
                n = snodes[i]
                peak_log = (
                    f"\nNEW_PEAK:{est_peak_memory}(BASE:{peak_memory})"
                    f" @ SNODE[{i}/{len(snodes)}]:{n.get_name()} {n.debug_str()}"
                )
                break
        # Per-group-node dump of iterative vs. recomputed stats.
        group_log = ""
        for i, gn in enumerate(gns):
            iter_gnm = iter_curr_memory[gn]
            new_gnm = est_curr_memory[gn]
            group_log += (
                f"\nGROUP_NODE[{i}]:{gn.debug_str()}"
                f"\nGROUP_NODE[{i}] ITER_GNM[{gn.get_name()}]:{iter_gnm}"
                f"\nGROUP_NODE[{i}] ESTM_GNM[{gn.get_name()}]:{new_gnm}"
                f"\nGROUP_NODE[{i}] ITER_allocfree:{snodes_allocfree[gn]}"
                f"\nGROUP_NODE[{i}] ESTM_allocfree:{new_snodes_allocfree[gn]}"
            )
        log += peak_log
        log += group_log
        log += f"\nGN_TO_BUFS_LAST_USE:{gn_to_bufs_last_use}"
        # Full per-node dump of iterative vs. recomputed stats.
        log += "\n\n".join(
            [
                (
                    f"\nSNODE[{i}]\n{n.debug_str()}"
                    f"\nITER_cur_mem:{iter_curr_memory[n]}"
                    f"\nESTM_cur_mem:{est_curr_memory[n]}"
                    f"\nITER_allocfree:{snodes_allocfree[n]}"
                    f"\nESTM_allocfree:{new_snodes_allocfree[n]}"
                )
                for i, n in enumerate(snodes)
            ]
        )
        tname = f"{tlparse_name}_ITERATIVE_RECOMPUTE_ERROR"
        print(f"{tname}:\n{log}")
        trace_structured(
            "artifact",
            metadata_fn=lambda: {
                "name": tname,
                "encoding": "string",
            },
            payload_fn=lambda: log,
        )
    return iterative_recompute_error
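To ground the last-use bookkeeping from point 2 above, here is a self-contained toy demonstration. `ToyNode` and `full_recompute` are hypothetical stand-ins for scheduler nodes and `estimate_peak_memory_allocfree`, with made-up sizes:

```python
from dataclasses import dataclass


@dataclass
class ToyNode:
    """Hypothetical stand-in for a SchedulerNode's memory footprint."""
    name: str
    size_alloc: int  # bytes allocated for this node's outputs
    size_free: int   # bytes freed after it runs (its last-use buffers)


def full_recompute(order: list[ToyNode]) -> tuple[int, list[tuple[int, int]]]:
    """From-scratch phase-aware estimate, as in the sketch above."""
    curr = peak = 0
    curr_memory = []  # (post_alloc, post_free) per node
    for n in order:
        post_alloc = curr + n.size_alloc
        peak = max(peak, post_alloc)
        curr = post_alloc - n.size_free
        curr_memory.append((post_alloc, curr))
    return peak, curr_memory


# `a` allocates a 100-byte buffer read by both `candidate` and `group`;
# before the swap `group` runs last, so it carries the buffer's 100 bytes
# of size_free on top of freeing its own 50-byte output.
a = ToyNode("a", size_alloc=100, size_free=0)
candidate = ToyNode("candidate", size_alloc=30, size_free=30)
group = ToyNode("group", size_alloc=50, size_free=50 + 100)

peak_before, _ = full_recompute([a, candidate, group])

# Swap candidate past group: candidate becomes the buffer's last user, so
# the iterative bookkeeping transfers 100 bytes of size_free to it.
group.size_free -= 100
candidate.size_free += 100

peak_after, curr_memory = full_recompute([a, group, candidate])
assert curr_memory[-1][1] == 0  # all memory accounted for after the swap
print(f"peak before swap: {peak_before}, peak after swap: {peak_after}")
```

Without the transfer, the recomputed per-node stats would disagree with the iteratively updated ones, which is the kind of divergence `_debug_iterative_memory_recompute` reports.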