Overlap scheduler improvements (#165318)

This PR bundles a number of small improvements:

- Account for bucketing in the overlap calculation: if an in-flight collective exists with the same bucket key, reduce the new collective's estimated time by its latency cost (a sketch follows this list).
- Update compute domination so that we order based on compute index rather than compute depth, so we never reorder compute. This makes it easier to reason about memory and pre-fetching, although we can explore reordering in the future.
- When we wait on a collective, also force waits on all collectives on the same process group that were enqueued before it.
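
A minimal sketch of the bucketing-aware estimate from the first bullet; the names (InFlightCollective, bucket_key, estimate_exposed_time_ns) are illustrative, not the scheduler's actual API:

# Illustrative sketch: if an in-flight collective shares the bucket key, assume
# the new collective will be bucketed with it and does not pay the fixed
# latency cost again; only the transfer time remains.
from dataclasses import dataclass

@dataclass
class InFlightCollective:
    bucket_key: tuple    # assumed to be something like (process group, op, dtype)
    latency_ns: float    # fixed launch/latency cost, independent of message size

def estimate_exposed_time_ns(new_latency_ns, new_transfer_ns, new_bucket_key, in_flight):
    estimate = new_latency_ns + new_transfer_ns
    if any(c.bucket_key == new_bucket_key for c in in_flight):
        # Bucketed with an in-flight collective: drop the duplicate latency term.
        estimate -= new_latency_ns
    return max(estimate, 0.0)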

Better Memory Handling:
- Pre-fetch limiting: when scheduling collectives for overlap, only pre-fetch up to a certain distance, then schedule off-path collectives (which are typically memory-reducing).
- When we are above peak memory, schedule waits instead (see the sketch after this list).
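
A rough sketch of both memory guards; MAX_PREFETCH_DISTANCE, compute_idx, off_path, and pending_waits are assumed names, not the PR's actual code:

# Illustrative only: cap the pre-fetch distance and, above the original peak
# memory, retire an outstanding collective instead of issuing new ones.
MAX_PREFETCH_DISTANCE = 5  # assumed value; the PR uses hard-coded constants (see TODO)

def pick_next_step(candidates, current_compute_idx, est_memory_bytes,
                   peak_memory_bytes, pending_waits):
    # Above peak: schedule a wait, which frees the waited collective's buffers.
    if est_memory_bytes > peak_memory_bytes and pending_waits:
        return ("wait", pending_waits[0])

    # Only pre-fetch on-path collectives within the distance limit.
    for node in candidates:
        if not node.off_path and node.compute_idx - current_compute_idx <= MAX_PREFETCH_DISTANCE:
            return ("issue", node)

    # Past the limit, prefer off-path collectives, which typically reduce memory.
    for node in candidates:
        if node.off_path:
            return ("issue", node)
    return None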

TODO:
- For each compute node we know its original memory in the graph; we could limit pre-fetching that crosses peak memory.
- Scheduling off-path collectives for overlap reduces memory, but if there is not enough compute to overlap with, we need to schedule them proactively. Not an issue on the examples so far.
- Make the hard-coded constants configurable and clean up enablement (can be done in a subsequent PR).

On small Llama 2D backward:
578 of 618 potentially hideable collectives hidden
Original memory 14.4 GB, rescheduled memory 15.9 GB

On forward:
254 of 256 potentially hideable collectives hidden
Original memory 5.8 GB, rescheduled memory 5.8 GB

WIP: adding tests

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165318
Approved by: https://github.com/ezyang, https://github.com/IvanKobzarev
ghstack dependencies: #164738, #164783, #164944, #164945, #165059
eellison
2025-10-15 11:48:43 -07:00
committed by PyTorch MergeBot
parent bc1f2108d7
commit b3f6d49b69
7 changed files with 197 additions and 80 deletions


@@ -356,7 +356,9 @@ def estimate_fx_collective_size(fx_node: torch.fx.Node) -> int:
     return size


-def estimate_nccl_collective_runtime_from_fx_node(fx_node: torch.fx.Node) -> float:
+def estimate_nccl_collective_runtime_from_fx_node(
+    fx_node: torch.fx.Node, override_size: Optional[int] = None
+) -> float:
     """
     Returns estimated NCCL collective runtime in nanoseconds (ns).
@@ -371,7 +373,10 @@ def estimate_nccl_collective_runtime_from_fx_node(fx_node: torch.fx.Node) -> float:
     """
     from torch.distributed.distributed_c10d import _get_group_size_by_name

-    tensor_storage_size_bytes = estimate_fx_collective_size(fx_node)
+    if override_size is None:
+        tensor_storage_size_bytes = estimate_fx_collective_size(fx_node)
+    else:
+        tensor_storage_size_bytes = override_size

     assert not isinstance(fx_node.target, str)
     opt_args_kwargs = normalize_function(
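
For context, an illustrative call showing how the new override_size parameter can be used: pricing one representative fx node at the combined bucket size to estimate the runtime of a bucketed collective (bucket_nodes and representative_node are hypothetical names, not from this PR):

# Hypothetical usage: estimate the runtime of a bucket from one representative node.
bucket_size_bytes = sum(estimate_fx_collective_size(n) for n in bucket_nodes)
bucketed_runtime_ns = estimate_nccl_collective_runtime_from_fx_node(
    representative_node, override_size=bucket_size_bytes
)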