Bucket collectives of multiple dtypes so they are processed in one bucketed collective.
The first target is bucketing bf16 and f32 together, but this can already be used with other dtypes.
For now, multi-dtype bucketing is supported only in "custom_ops" mode.
Non-custom_ops mode needs additional work on the inductor side.
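One way to picture multi-dtype bucketing is as a byte-level layout problem: tensors of different dtypes are placed back to back in one flat byte buffer. The sketch below is plain Python and is not this PR's implementation; `ITEMSIZE`, `bucket_layout`, and the alignment rule are assumptions made for illustration.

```python
# Hypothetical illustration (not the PR's code): compute byte offsets for
# packing tensors of mixed dtypes into one flat byte buffer (the "bucket").
ITEMSIZE = {"bf16": 2, "f32": 4}  # bytes per element


def bucket_layout(tensors):
    """tensors: list of (name, dtype, numel).

    Returns ({name: (byte_offset, nbytes)}, total_bytes).
    """
    layout = {}
    offset = 0
    for name, dtype, numel in tensors:
        size = ITEMSIZE[dtype]
        # Assumption: align each start to the element size so the byte range
        # can be reinterpreted as that dtype.
        offset = (offset + size - 1) // size * size
        nbytes = numel * size
        layout[name] = (offset, nbytes)
        offset += nbytes
    return layout, offset


layout, total = bucket_layout([("a", "bf16", 3), ("b", "f32", 4), ("c", "bf16", 5)])
```

Here `b` starts at byte 8 rather than 6 because of the assumed 4-byte alignment for f32.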
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162470
Approved by: https://github.com/eellison
Bucketing: a number of smallish improvements:
- Account for bucketing in the overlap calculation: if an in-flight collective exists with the same bucket key, reduce the new collective's estimated time by its latency time.
- Update compute domination so that we order based on compute idx rather than compute depth, so we never reorder compute. This makes it a bit easier to reason about memory and pre-fetching, although we can explore reordering in the future.
- When we wait on a collective, force all collectives on the same process group that were enqueued before it to wait as well.
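The forced-wait rule in the last bullet can be sketched in plain Python. This is an illustration only: `wait_on` and the `in_flight` mapping are hypothetical names, not the scheduler's actual data structures.

```python
from collections import OrderedDict


def wait_on(in_flight, target):
    """in_flight: OrderedDict of collective name -> process group, in enqueue order.

    Returns the collectives that get waited: every collective on the target's
    process group that was enqueued before the target, plus the target itself.
    Waited collectives are removed from in_flight.
    """
    pg = in_flight[target]
    waited = []
    for name, group in in_flight.items():
        if group == pg:
            waited.append(name)
        if name == target:
            break  # collectives enqueued after the target stay in flight
    for name in waited:
        del in_flight[name]
    return waited


in_flight = OrderedDict(
    [("ag0", "pg_a"), ("rs0", "pg_b"), ("ag1", "pg_a"), ("ag2", "pg_a")]
)
waited = wait_on(in_flight, "ag1")
```

Waiting on `ag1` also waits `ag0` (same group, enqueued earlier), while `rs0` (different group) and `ag2` (enqueued later) remain in flight.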
Better Memory Handling:
- Pre-fetch limiting - when scheduling collectives for overlap, only pre-fetch up to a certain distance, then schedule off-path collectives (which are typically memory reducing).
- When we are above peak memory, schedule waits.
TODO:
- For each compute node, we know its original memory in the graph. We could limit pre-fetching that goes across peak memory.
- By scheduling off-path collectives for overlap, we reduce memory, but if there isn't enough compute for overlap, we need to schedule them proactively. Not an issue yet on examples.
- Make some hard-coded constants configurable, and clean up enablement (can do in a subsequent PR).
On small llama 2d backward:
578 of 618 potentially hideable collectives hidden
original mem 14.4GB, rescheduled mem 15.9GB
On forward:
254 of 256 potentially hideable collectives hidden
original mem 5.8GB, rescheduled mem 5.8GB
WIP: adding tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165318
Approved by: https://github.com/ezyang, https://github.com/IvanKobzarev
ghstack dependencies: #164738, #164783, #164944, #164945, #165059
tl;dr performs bucketing while preserving comm-compute overlap.
In comm-compute overlap we will have a graph with:
```
def foo(...):
    ag = all_gather(...)
    hiding_compute = mm(...)
    wait(ag)
```
There is no explicit dependency between the hiding compute and the collectives, but we want to add implicit dependencies from wait->hiding_compute, and from hiding_compute->all_gather to preserve overlap.
Additionally, while bucketing, we will merge collective starts and collective waits together. In this case, we will want to treat the two nodes as a single subgraph - each node in the merged set will have the union of all deps in the set.
We perform bucketing while augmenting the graph with these relationships. This can be done separately from comm-compute overlap, so long as the hiding compute relationships are passed in.
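A minimal sketch of the two ideas above, using a plain dict of dependency sets instead of an fx graph. `add_overlap_deps` and `merge_nodes` are hypothetical helpers, and excluding the merged nodes from their own dependency set is an assumption of this sketch.

```python
from collections import defaultdict


def add_overlap_deps(deps, start, hiding_compute, wait):
    """deps maps node -> set of nodes that must be scheduled before it."""
    deps[hiding_compute].add(start)  # hiding_compute -> all_gather
    deps[wait].add(hiding_compute)   # wait -> hiding_compute


def merge_nodes(deps, nodes):
    """Give every node in a merged (bucketed) set the union of the set's deps."""
    # Assumption: a merged node should not depend on members of its own set.
    merged = set().union(*(deps[n] for n in nodes)) - set(nodes)
    for n in nodes:
        deps[n] = set(merged)
    return merged


deps = defaultdict(set)
add_overlap_deps(deps, "ag0", "mm0", "wait0")
add_overlap_deps(deps, "ag1", "mm1", "wait1")
merge_nodes(deps, ["wait0", "wait1"])
```

After merging the two waits, each of them depends on both `mm0` and `mm1`, so the bucketed wait cannot be hoisted above either hiding compute.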
TODO:
- need to instrument fx graph so inductor respects these relationships.
- the compile time of the bucketing search can be sped up significantly by limiting what portion of the graph we traverse through
- more memory aware handling
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163960
Approved by: https://github.com/ruisizhang123, https://github.com/v0i0, https://github.com/IvanKobzarev
ghstack dependencies: #163215, #163754, #163959
The previous implementation created `n_gpu * n_tensors` intermediate tensors, which added a lot of CPU overhead, especially given that inductor was generating a number of individual tensor copy kernels for `torch.cat`.
This PR changes the implementation so that only `n_tensors` intermediate tensors are created, making the CPU overhead proportional to the number of tensors being bucketed.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159723
Approved by: https://github.com/IvanKobzarev
The output of a reduce_scatter is n_gpu times smaller than its input, while the output of an all_gather is n_gpu times larger than its input. This means that with the current heuristic for bucketing reduce_scatter, we would need a bucket size that is n_gpu times larger than the one for all_gather, making it gpu-dependent and less intuitive. This PR proposes to use the max of the input and output sizes instead, so that one can use the same bucket_size value for both passes.
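The effect of the max-based heuristic can be shown with a small sketch. `bucket_cost_bytes` and `greedy_buckets` are hypothetical names, and the greedy grouping is an assumption for illustration rather than the pass's actual bucketing logic.

```python
def bucket_cost_bytes(in_bytes, out_bytes):
    # Size a collective by the larger of its input and output.
    return max(in_bytes, out_bytes)


def greedy_buckets(costs, cap):
    """Group consecutive collectives until adding one would exceed cap."""
    buckets, cur, cur_cost = [], [], 0
    for i, cost in enumerate(costs):
        if cur and cur_cost + cost > cap:
            buckets.append(cur)
            cur, cur_cost = [], 0
        cur.append(i)
        cur_cost += cost
    if cur:
        buckets.append(cur)
    return buckets


n_gpu = 8
shard = 1 << 20  # 1 MiB per-rank shard
ag_cost = bucket_cost_bytes(shard, shard * n_gpu)  # all_gather: small in, large out
rs_cost = bucket_cost_bytes(shard * n_gpu, shard)  # reduce_scatter: large in, small out
buckets = greedy_buckets([8, 8, 8, 8, 8], cap=20)  # costs in MiB
```

Both collectives get the same cost (8 MiB here), so a single bucket_size threshold groups them the same way regardless of direction.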
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159717
Approved by: https://github.com/wconstab
Rewriting bucketing of all_gather and reduce_scatter by defining the "merge graph" via torch functions:
`all_gather_merge_fn_to_trace`
`reduce_scatter_merge_fn_to_trace`
(instead of creating nodes and doing FakeTensor prop manually).
This makes it possible to experiment with the merge function.
Used foreach_copy_ in the merging function for all_gather, and added an inductor lowering for `foreach_copy_`.
Adding topological sort after bucketing passes (comment in post_grad.py):
```
# Fx collectives bucketing passes require topological sort for the cases:
# when bucketed collectives have users before the last collective in the bucket
# AND when inputs of bucketed collective have ancestors after the first collective in the bucket.
#
# In this case we can not manually pick the place for bucketed collective insertion.
# But we are guaranteed by the bucketing (independent collectives in the bucket),
# that it is possible to reorder nodes to satisfy all ordering requirements.
#
# --- before bucketing ---
# in0 = ...
# wait_ag0 = ag(in0)
# user0(wait_ag0)
# ...
# pre_in1 = ...
# in1 = transform(pre_in1)
# wait_ag1 = ag(in1)
# user1(wait_ag1)
#
# --- after bucketing ---
#
# in0 = ...
# user0(wait_ag0) <--- wait_ag0 is defined only after bucketed collective.
#
# pre_in1 = ...
# in1 = transform(pre_in1)
# ag_bucket(in0+in1)
# wait_bucket
# wait_ag0 = wait_bucket[0]
# wait_ag1 = wait_bucket[1]
# user1(wait_ag1)
```
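The situation in the comment above can be reproduced with the standard library's `graphlib`. This is a sketch of the idea only; the actual pass sorts fx graph nodes, and the node names below are taken from the comment's example.

```python
from graphlib import TopologicalSorter

# node -> set of nodes it depends on, for the "after bucketing" example.
# In the raw node list, user0 appears before wait_ag0 is defined.
deps = {
    "in1": {"pre_in1"},
    "ag_bucket": {"in0", "in1"},
    "wait_bucket": {"ag_bucket"},
    "wait_ag0": {"wait_bucket"},
    "wait_ag1": {"wait_bucket"},
    "user0": {"wait_ag0"},
    "user1": {"wait_ag1"},
}

# static_order() yields each node only after all of its dependencies.
order = list(TopologicalSorter(deps).static_order())
pos = {node: i for i, node in enumerate(order)}
```

In the sorted order, `user0` lands after `wait_ag0`, which itself lands after the bucketed collective and its wait.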
Correctness of the passes was verified by loss curves for llama3 8b, for simple_fsdp and for autoparallel:
<img width="1364" height="495" alt="Screenshot 2025-07-22 at 14 27 28" src="https://github.com/user-attachments/assets/67b2cabb-3206-450b-b529-e23c24292fc6" />
<img width="1355" height="509" alt="Screenshot 2025-07-22 at 14 27 56" src="https://github.com/user-attachments/assets/4d0e6b25-2eb1-47b2-8d68-dcec185239c4" />
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158663
Approved by: https://github.com/wconstab
Main changes:
- Bucket collectives only from the same process_group, identified by group_name.
- Support groups like [0,2,4,6], [0,1,3,5] by using `rank_idx_dict` for in-pass operations such as computing slice idxs.
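To illustrate why a rank-to-index mapping is needed for non-contiguous groups, here is a minimal sketch. The exact shape of `rank_idx_dict` in the pass is not shown in this description, so `make_rank_idx_dict` and `shard_slice` are hypothetical helpers.

```python
def make_rank_idx_dict(group_ranks):
    # Map each global rank to its position inside the process group.
    return {rank: idx for idx, rank in enumerate(group_ranks)}


def shard_slice(rank_idx_dict, rank, shard_numel):
    # Slice bounds of this rank's shard inside a gathered flat buffer.
    idx = rank_idx_dict[rank]
    return idx * shard_numel, (idx + 1) * shard_numel


rank_idx_dict = make_rank_idx_dict([0, 2, 4, 6])
start, end = shard_slice(rank_idx_dict, rank=4, shard_numel=10)
```

Global rank 4 is the third member of the group, so its shard occupies elements 20 to 30; indexing by the global rank directly would give the wrong slice.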
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158632
Approved by: https://github.com/wconstab
Porting passes to bucket all_gathers
The main logic of the pass is done via
1. Searching for all all_gathers from the buckets
Copying tests from @wconstab's PR to test compatibility with reordering.
The test checks only compatibility: because of (3), the joint all_gather is already scheduled as early as possible, leaving no room for reordering.
Pass changes:
Using mutation ops to match the performance of fsdp. In the future, the ideal scenario is to have only a functional graph, with inductor doing all memory optimizations on its own without mutable ops.
Inductor changes:
Adding foreach_copy_ lowering
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157396
Approved by: https://github.com/wconstab