mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
1. Applying @eellison idea from https://github.com/pytorch/pytorch/pull/146562#discussion_r2059363672 for estimate_peak_memory: ``` """ Alternative version of estimate_peak_memory, that respects the fact, that every SchedulerNode has multiple phases: 1. alloc ( outputs ) 2. run_kernel 3. dealloc last_use buffers estimate_peak_memory collapses memory into one value: size_alloc - size_free While peak memory happens after alloc. Duplicating the code to not migrate all callsites at once, In future usages of estimate_peak_memory will migrate to this version. """ ``` - Applying this in `reorder_communication_preserving_peak_memory` pass. 2. Buffers during reordering can change deallocation point, if candidate and group to swap both are users of the f_input_buf and group contains last_use_snode. - Addressing this tracking the last_use_snode for each buffer and recomputing current memory respecting the change in size_free (group_node after reordering is not the last user of the buffer and its size_free -= buffer_size, while candidate becomes the last user and candidate.size_free += buffer_size). 4. Adding env var `PYTORCH_REORDER_COLLECTIVES_LIMIT` for ablation to limit number of collectives to reorder. What is after this PR: Iterative recomputation of memory estimations matches full memory estimations. Active memory is not regressing a lot, but reserved memory is significantly regressed. Investigation and fix of "reserved" memory will be in following PRs. BASELINE (bucketing AG and RS): active: 32Gb reserved: 34Gb ``` [rank0]:[titan] 2025-08-11 11:28:36,798 - root - INFO - step: 1 loss: 12.2722 grad_norm: 4.2192 active_memory: 24.66GiB(25.96%) reserved_memory: 25.38GiB(26.72%) tps: 99 tflops: 5.71 mfu: 0.58% [rank0]:[titan] 2025-08-11 11:28:38,640 - root - INFO - step: 2 loss: 13.1738 grad_norm: 50.5566 active_memory: 32.14GiB(33.83%) reserved_memory: 34.21GiB(36.01%) tps: 4,448 tflops: 257.63 mfu: 26.05% [rank0]:[titan] 2025-08-11 11:28:40,029 - root - INFO - step: 3 loss: 15.6866 grad_norm: 80.0862 active_memory: 32.14GiB(33.83%) reserved_memory: 34.21GiB(36.01%) tps: 5,900 tflops: 341.72 mfu: 34.55% [rank0]:[titan] 2025-08-11 11:28:41,423 - root - INFO - step: 4 loss: 13.4853 grad_norm: 7.8538 active_memory: 32.14GiB(33.83%) reserved_memory: 34.21GiB(36.01%) tps: 5,881 tflops: 340.57 mfu: 34.44% [rank0]:[titan] 2025-08-11 11:28:42,820 - root - INFO - step: 5 loss: 16.1191 grad_norm: 53.2481 active_memory: 32.14GiB(33.83%) reserved_memory: 34.21GiB(36.01%) tps: 5,867 tflops: 339.77 mfu: 34.35% ``` REORDER: active: 32Gb reserved: 36Gb ``` [rank0]:[titan] 2025-08-11 11:34:32,772 - root - INFO - step: 1 loss: 12.2490 grad_norm: 4.1944 active_memory: 24.66GiB(25.96%) reserved_memory: 26.81GiB(28.22%) tps: 85 tflops: 4.90 mfu: 0.50% [rank0]:[titan] 2025-08-11 11:34:35,329 - root - INFO - step: 2 loss: 13.1427 grad_norm: 39.5942 active_memory: 32.14GiB(33.83%) reserved_memory: 36.40GiB(38.31%) tps: 3,205 tflops: 185.61 mfu: 18.77% [rank0]:[titan] 2025-08-11 11:34:36,770 - root - INFO - step: 3 loss: 14.6084 grad_norm: 51.0743 active_memory: 32.14GiB(33.83%) reserved_memory: 36.40GiB(38.31%) tps: 5,688 tflops: 329.44 mfu: 33.31% [rank0]:[titan] 2025-08-11 11:34:38,197 - root - INFO - step: 4 loss: 13.6181 grad_norm: 8.1122 active_memory: 32.14GiB(33.83%) reserved_memory: 36.40GiB(38.31%) tps: 5,744 tflops: 332.68 mfu: 33.64% [rank0]:[titan] 2025-08-11 11:34:39,821 - root - INFO - step: 5 loss: 15.8913 grad_norm: 59.8510 active_memory: 32.14GiB(33.83%) reserved_memory: 36.40GiB(38.31%) tps: 5,046 tflops: 292.22 mfu: 29.55% ``` REORDER + SINK_WAITS_ITERATIVE: active: 35Gb reserved: 41Gb ``` [rank0]:[titan] 2025-08-11 11:31:36,119 - root - INFO - step: 1 loss: 12.2646 grad_norm: 4.1282 active_memory: 27.60GiB(29.05%) reserved_memory: 32.49GiB(34.20%) tps: 173 tflops: 10.00 mfu: 1.01% [rank0]:[titan] 2025-08-11 11:31:37,452 - root - INFO - step: 2 loss: 13.2353 grad_norm: 42.4234 active_memory: 35.08GiB(36.92%) reserved_memory: 41.62GiB(43.80%) tps: 6,152 tflops: 356.26 mfu: 36.02% [rank0]:[titan] 2025-08-11 11:31:38,780 - root - INFO - step: 3 loss: 13.8205 grad_norm: 24.0156 active_memory: 35.08GiB(36.92%) reserved_memory: 41.62GiB(43.80%) tps: 6,169 tflops: 357.29 mfu: 36.13% [rank0]:[titan] 2025-08-11 11:31:40,106 - root - INFO - step: 4 loss: 13.1033 grad_norm: 9.1167 active_memory: 35.08GiB(36.92%) reserved_memory: 41.62GiB(43.80%) tps: 6,183 tflops: 358.10 mfu: 36.21% [rank0]:[titan] 2025-08-11 11:31:41,443 - root - INFO - step: 5 loss: 16.3530 grad_norm: 51.8118 active_memory: 35.08GiB(36.92%) reserved_memory: 41.62GiB(43.80%) tps: 6,130 tflops: 355.03 mfu: 35.90% ``` Differential Revision: [D80718143](https://our.internmc.facebook.com/intern/diff/D80718143) Pull Request resolved: https://github.com/pytorch/pytorch/pull/160113 Approved by: https://github.com/wconstab, https://github.com/eellison Co-authored-by: eellison <elias.ellison@gmail.com>