Overlap scheduler improvements (#165318)

Bucketing a number of smallish improvements:

- Account for bucketing in the overlap calculation: if an in-flight collective exists with the same bucket key, reduce the new collective's estimated time by its latency (see the sketch after this list).
- Update compute domination so that we order based on compute index, as opposed to compute depth, so we never reorder compute. This makes it a bit easier to reason about memory and pre-fetching, although we can explore reordering in the future.
- When we wait on a collective, force waits on all collectives on the same process group that were enqueued before it.
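
A minimal sketch of the bucketing-aware estimate, using hypothetical names (`InFlightCollective`, `effective_collective_time_ms`) rather than the actual scheduler internals: when a new collective shares a bucket key with one already in flight, the two would be bucketed together, so only the latency-free portion of the new collective's time is counted.

```python
from dataclasses import dataclass


@dataclass
class InFlightCollective:
    """Hypothetical record of a collective that has been started but not waited on."""
    bucket_key: object
    estimated_time_ms: float
    latency_ms: float


def effective_collective_time_ms(
    new_key: object,
    new_estimated_time_ms: float,
    new_latency_ms: float,
    in_flight: list[InFlightCollective],
) -> float:
    """Estimated exposed time of a new collective, accounting for bucketing."""
    for coll in in_flight:
        if coll.bucket_key == new_key:
            # Same bucket key: the collectives would be bucketed together, so the
            # per-collective launch latency is paid only once.
            return max(new_estimated_time_ms - new_latency_ms, 0.0)
    return new_estimated_time_ms
```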

Better Memory Handling:
- Pre-fetch limiting: when scheduling collectives for overlap, only pre-fetch up to a certain distance, then schedule off-path collectives (which are typically memory reducing). See the sketch after this list.
- When we are above peak memory, schedule waits.
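
A rough sketch of the memory-aware part of the loop, under assumed names and structure (`CollectiveInfo`, `SchedulerState`, `MAX_PREFETCH_DISTANCE` are illustrative, not the scheduler's real API): bound how far ahead on-path collectives may be pre-fetched, and flush pending waits when the running memory estimate exceeds the original peak.

```python
from dataclasses import dataclass, field

# Illustrative limit; the PR notes such constants are hard coded for now.
MAX_PREFETCH_DISTANCE = 5


@dataclass
class CollectiveInfo:
    """Hypothetical record of a collective considered for overlap."""
    name: str
    compute_idx: int        # index of the compute node it would be pre-fetched across
    on_critical_path: bool  # off-path collectives are typically memory reducing
    size_bytes: int


@dataclass
class SchedulerState:
    """Hypothetical running state of the overlap scheduler."""
    current_compute_idx: int
    current_memory: int
    peak_memory: int
    scheduled: list = field(default_factory=list)
    pending_waits: list = field(default_factory=list)

    def schedule_pending_waits(self) -> None:
        # Waiting on in-flight collectives frees their communication buffers.
        for name, size in self.pending_waits:
            self.current_memory -= size
            self.scheduled.append(f"wait({name})")
        self.pending_waits.clear()


def try_overlap(coll: CollectiveInfo, state: SchedulerState) -> bool:
    if state.current_memory > state.peak_memory:
        # Above the original peak: schedule waits before hiding more collectives.
        state.schedule_pending_waits()
    distance = coll.compute_idx - state.current_compute_idx
    if coll.on_critical_path and distance > MAX_PREFETCH_DISTANCE:
        # Limit pre-fetch distance for on-path collectives to bound memory growth;
        # off-path collectives are still allowed since they tend to reduce memory.
        return False
    state.current_memory += coll.size_bytes
    state.pending_waits.append((coll.name, coll.size_bytes))
    state.scheduled.append(f"start({coll.name})")
    return True
```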

TODO:
- For each compute node, we know its original memory in the graph; we could limit pre-fetching that crosses peak memory.
- By scheduling off-path collectives for overlap, we reduce memory, but if there isn't enough compute for overlap, we need to schedule them proactively. Not an issue yet in the examples.
- Make the hard-coded constants configurable and clean up enablement (can be done in a subsequent PR).

On the small Llama 2D backward:
578 of 618 potentially hideable collectives hidden
original mem 14.4 GB, rescheduled mem 15.9 GB

On forward:
254/256 potentially hideable collectives hidden
original mem 5.8 GB, rescheduled mem 5.8 GB

WIP: adding tests

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165318
Approved by: https://github.com/ezyang, https://github.com/IvanKobzarev
ghstack dependencies: #164738, #164783, #164944, #164945, #165059
eellison
2025-10-15 11:48:43 -07:00
committed by PyTorch MergeBot
parent bc1f2108d7
commit b3f6d49b69
7 changed files with 197 additions and 80 deletions


@@ -34,6 +34,15 @@ def _rs_group_key(node: torch.fx.Node) -> tuple[str, str, torch.dtype]:
    return (group_name, reduce_op, dtype)


def bucket_key(node: torch.fx.Node) -> Optional[object]:
    if is_all_gather_into_tensor(node):
        return _ag_group_key(node)
    elif is_reduce_scatter_tensor(node):
        return _rs_group_key(node)
    else:
        return None


def bucket_cap_mb_by_bucket_idx_default(bucket_id: int) -> float:
    """
    Determine the size of a bucket based on its ID.