Files
pytorch/torch
IvanKobzarev db44de4c0d [inductor] Estimate peak memory allocfree and applying to reordering collectives (#160113)
1. Applying @eellison idea from https://github.com/pytorch/pytorch/pull/146562#discussion_r2059363672 for estimate_peak_memory:
```
    """
    Alternative version of estimate_peak_memory, that respects the fact,
    that every SchedulerNode has multiple phases:
    1. alloc ( outputs )
    2. run_kernel
    3. dealloc last_use buffers
    estimate_peak_memory collapses memory into one value: size_alloc - size_free
    While peak memory happens after alloc.

    Duplicating the code to not migrate all callsites at once,
    In future usages of estimate_peak_memory will migrate to this version.
    """
```

- Applying this in `reorder_communication_preserving_peak_memory` pass.

2. Buffers during reordering can change deallocation point, if candidate and group to swap both are users of the f_input_buf and group contains last_use_snode.

- Addressing this tracking the last_use_snode for each buffer and recomputing current memory respecting the change in size_free (group_node after reordering is not the last user of the buffer and its size_free -= buffer_size, while candidate becomes the last user and candidate.size_free += buffer_size).

4. Adding env var `PYTORCH_REORDER_COLLECTIVES_LIMIT` for ablation to limit number of collectives to reorder.

What is after this PR:

Iterative recomputation of memory estimations matches full memory estimations.

Active memory is not regressing a lot, but reserved memory is significantly regressed.

Investigation and fix of "reserved" memory will be in following PRs.

BASELINE (bucketing AG and RS): active: 32Gb reserved: 34Gb
```
[rank0]:[titan] 2025-08-11 11:28:36,798 - root - INFO - step:  1  loss: 12.2722  grad_norm:  4.2192  active_memory: 24.66GiB(25.96%)  reserved_memory: 25.38GiB(26.72%)  tps: 99  tflops: 5.71  mfu: 0.58%
[rank0]:[titan] 2025-08-11 11:28:38,640 - root - INFO - step:  2  loss: 13.1738  grad_norm: 50.5566  active_memory: 32.14GiB(33.83%)  reserved_memory: 34.21GiB(36.01%)  tps: 4,448  tflops: 257.63  mfu: 26.05%
[rank0]:[titan] 2025-08-11 11:28:40,029 - root - INFO - step:  3  loss: 15.6866  grad_norm: 80.0862  active_memory: 32.14GiB(33.83%)  reserved_memory: 34.21GiB(36.01%)  tps: 5,900  tflops: 341.72  mfu: 34.55%
[rank0]:[titan] 2025-08-11 11:28:41,423 - root - INFO - step:  4  loss: 13.4853  grad_norm:  7.8538  active_memory: 32.14GiB(33.83%)  reserved_memory: 34.21GiB(36.01%)  tps: 5,881  tflops: 340.57  mfu: 34.44%
[rank0]:[titan] 2025-08-11 11:28:42,820 - root - INFO - step:  5  loss: 16.1191  grad_norm: 53.2481  active_memory: 32.14GiB(33.83%)  reserved_memory: 34.21GiB(36.01%)  tps: 5,867  tflops: 339.77  mfu: 34.35%
```
REORDER: active: 32Gb reserved: 36Gb
```
[rank0]:[titan] 2025-08-11 11:34:32,772 - root - INFO - step:  1  loss: 12.2490  grad_norm:  4.1944  active_memory: 24.66GiB(25.96%)  reserved_memory: 26.81GiB(28.22%)  tps: 85  tflops: 4.90  mfu: 0.50%
[rank0]:[titan] 2025-08-11 11:34:35,329 - root - INFO - step:  2  loss: 13.1427  grad_norm: 39.5942  active_memory: 32.14GiB(33.83%)  reserved_memory: 36.40GiB(38.31%)  tps: 3,205  tflops: 185.61  mfu: 18.77%
[rank0]:[titan] 2025-08-11 11:34:36,770 - root - INFO - step:  3  loss: 14.6084  grad_norm: 51.0743  active_memory: 32.14GiB(33.83%)  reserved_memory: 36.40GiB(38.31%)  tps: 5,688  tflops: 329.44  mfu: 33.31%
[rank0]:[titan] 2025-08-11 11:34:38,197 - root - INFO - step:  4  loss: 13.6181  grad_norm:  8.1122  active_memory: 32.14GiB(33.83%)  reserved_memory: 36.40GiB(38.31%)  tps: 5,744  tflops: 332.68  mfu: 33.64%
[rank0]:[titan] 2025-08-11 11:34:39,821 - root - INFO - step:  5  loss: 15.8913  grad_norm: 59.8510  active_memory: 32.14GiB(33.83%)  reserved_memory: 36.40GiB(38.31%)  tps: 5,046  tflops: 292.22  mfu: 29.55%
```

REORDER + SINK_WAITS_ITERATIVE: active: 35Gb reserved: 41Gb
```
[rank0]:[titan] 2025-08-11 11:31:36,119 - root - INFO - step:  1  loss: 12.2646  grad_norm:  4.1282  active_memory: 27.60GiB(29.05%)  reserved_memory: 32.49GiB(34.20%)  tps: 173  tflops: 10.00  mfu: 1.01%
[rank0]:[titan] 2025-08-11 11:31:37,452 - root - INFO - step:  2  loss: 13.2353  grad_norm: 42.4234  active_memory: 35.08GiB(36.92%)  reserved_memory: 41.62GiB(43.80%)  tps: 6,152  tflops: 356.26  mfu: 36.02%
[rank0]:[titan] 2025-08-11 11:31:38,780 - root - INFO - step:  3  loss: 13.8205  grad_norm: 24.0156  active_memory: 35.08GiB(36.92%)  reserved_memory: 41.62GiB(43.80%)  tps: 6,169  tflops: 357.29  mfu: 36.13%
[rank0]:[titan] 2025-08-11 11:31:40,106 - root - INFO - step:  4  loss: 13.1033  grad_norm:  9.1167  active_memory: 35.08GiB(36.92%)  reserved_memory: 41.62GiB(43.80%)  tps: 6,183  tflops: 358.10  mfu: 36.21%
[rank0]:[titan] 2025-08-11 11:31:41,443 - root - INFO - step:  5  loss: 16.3530  grad_norm: 51.8118  active_memory: 35.08GiB(36.92%)  reserved_memory: 41.62GiB(43.80%)  tps: 6,130  tflops: 355.03  mfu: 35.90%
```

Differential Revision: [D80718143](https://our.internmc.facebook.com/intern/diff/D80718143)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160113
Approved by: https://github.com/wconstab, https://github.com/eellison

Co-authored-by: eellison <elias.ellison@gmail.com>
2025-08-22 14:19:57 +00:00
..
2025-08-22 11:31:09 +00:00
2025-08-15 02:09:31 +00:00
2025-04-27 09:56:42 +00:00
2025-04-27 09:56:42 +00:00
2025-06-14 18:18:43 +00:00