Overlap scheduler improvements (#165318)

Bucketing a number of smallish improvements:

- Account for bucketing in the overlap calculation: if an in-flight collective exists with the same bucket key, reduce the new collective's estimated time by its latency.
- Update compute domination so we order based on compute index, as opposed to compute depth, so that we never reorder compute. This makes it a bit easier to reason about memory and pre-fetching, although we can explore reordering in the future (see the sketch after this list).
- When we wait on a collective, also force waits on all collectives that were enqueued earlier on the same process group.
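
To make the second bullet concrete, here is a minimal standalone sketch of the domination-index idea. The toy graph, node names (`ag_start`, `mm1`, ...), and dict-based representation are made up for illustration; they are not the scheduler's actual data structures:

```python
import sys

# Hypothetical toy graph: node -> users (edges point forward in topological order).
# "mm1"/"mm2" stand in for compute (matmul) nodes; the all-gather pair feeds mm2,
# and the reduce-scatter pair feeds no compute at all.
users = {
    "ag_start": ["ag_wait"],
    "ag_wait": ["mm2"],
    "mm1": ["mm2"],
    "mm2": [],
    "rs_start": ["rs_wait"],
    "rs_wait": [],
}
topo_order = ["ag_start", "ag_wait", "mm1", "mm2", "rs_start", "rs_wait"]
is_compute = {n: n.startswith("mm") for n in topo_order}

# Pass 1: number compute nodes by their topological order (0, 1, 2, ...).
compute_index = {n: i for i, n in enumerate(n for n in topo_order if is_compute[n])}

# Pass 2: reverse topological walk; each non-compute node dominates the earliest
# compute node any of its users dominates (sys.maxsize if it blocks no compute).
domination_index = {}
for n in reversed(topo_order):
    if n in compute_index:
        domination_index[n] = compute_index[n]
    else:
        domination_index[n] = min(
            (domination_index[u] for u in users[n]), default=sys.maxsize
        )

# mm1 -> 0, mm2 -> 1, ag_start/ag_wait -> 1 (they block mm2),
# rs_start/rs_wait -> sys.maxsize (off the compute path).
print(domination_index)
```

Ranking schedulable nodes by this index keeps compute in its original relative order (each compute node dominates itself), which is what makes memory and pre-fetching easier to reason about.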

Better Memory Handling:
- Pre-fetch limiting: when scheduling collectives for overlap, only pre-fetch up to a certain compute distance; beyond that, schedule only off-path collectives (which typically reduce memory).
- When we are above the original peak memory, schedule waits (a combined sketch of both guards follows this list).
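
A rough, standalone sketch of how the two memory guards above could compose. The `MemoryGuards` class and its method names are invented for illustration (they are not the scheduler's real API); the numeric defaults mirror the values used in this PR's diff (pre-fetch distance 5, 0.5 GB in-flight cap, 1 GB overshoot above the original peak):

```python
import sys
from dataclasses import dataclass

GB = 1024 ** 3


@dataclass
class MemoryGuards:
    """Toy stand-in for the scheduler state involved in the two memory guards."""

    max_compute_prefetch: int = 5             # how far past the current compute index we may pre-fetch
    max_in_flight_bytes: int = int(0.5 * GB)  # cap on bytes of started-but-unwaited collectives
    peak_overshoot_bytes: int = 1 * GB        # tolerated excess over the graph's original peak memory
    in_flight_bytes: int = 0
    current_compute_index: int = 0

    def may_prefetch(self, coll_domination_index: int) -> bool:
        """Pre-fetch limiting: off-path collectives (domination index == sys.maxsize)
        are always allowed since they typically reduce memory; on-path collectives
        only within a bounded compute distance."""
        if coll_domination_index == sys.maxsize:
            return True
        return (
            coll_domination_index - self.current_compute_index
        ) <= self.max_compute_prefetch

    def should_force_wait(self, current_mem_bytes: int, original_peak_bytes: int) -> bool:
        """Schedule the oldest in-flight wait when too many bytes are in flight
        or current memory is well above the original peak."""
        if self.in_flight_bytes == 0:  # nothing in flight, nothing to wait on
            return False
        return (
            self.in_flight_bytes >= self.max_in_flight_bytes
            or current_mem_bytes - original_peak_bytes > self.peak_overshoot_bytes
        )
```

In this sketch, `may_prefetch` would filter overlap candidates before they are started and `should_force_wait` would decide when to pop the oldest in-flight wait; the real scheduler performs the analogous checks inside its scheduling loop.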

TODO:
- For each compute node, we know its original memory in the graph; we could limit pre-fetching that crosses the peak-memory region.
- Scheduling off-path collectives for overlap reduces memory, but if there is not enough compute to overlap with, we need to schedule them proactively. Not yet an issue on the examples tested.
- Make the hard-coded constants configurable and clean up enablement (can be done in a subsequent PR).

On small Llama 2D backward:
578 of 618 potentially hideable collectives hidden
Original mem 14.4 GB, rescheduled mem 15.9 GB

On forward:
254 of 256 potentially hideable collectives hidden
Original mem 5.8 GB, rescheduled mem 5.8 GB

WIP: adding tests

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165318
Approved by: https://github.com/ezyang, https://github.com/IvanKobzarev
ghstack dependencies: #164738, #164783, #164944, #164945, #165059
eellison
2025-10-15 11:48:43 -07:00
committed by PyTorch MergeBot
parent bc1f2108d7
commit b3f6d49b69
7 changed files with 197 additions and 80 deletions

View File

@@ -70,6 +70,8 @@ def get_patches():
"force_disable_caches": True,
# Messes up existing test strings
"test_configs.aten_fx_overlap_insert_overlap_deps": False,
# interferes with testing / custom estimation
"test_configs.assume_bucketing_reduces_latency": False,
}
@@ -364,6 +366,8 @@ def get_bucket_patches(compute_multiplier=1.0):
"force_disable_caches": True,
# messes up test strings
"test_configs.aten_fx_overlap_insert_overlap_deps": False,
# interferes with testing / custom estimation
"test_configs.assume_bucketing_reduces_latency": False,
}
@@ -579,7 +583,7 @@ class TestComputeCommReorderingBucketing(TestComputeCommReorderingMultiProc):
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@torch._inductor.config.patch(get_bucket_patches(2.0))
def test_bucketing_split_for_overlap_blocking(self):
def test_bucketing_split_for_overlap_blocking_no_deps(self):
"""Test that 4 independent all-gathers split into 2+2 buckets for better overlap with compute."""
def func(a, b, c, d, *, ranks):

View File

@@ -252,11 +252,6 @@ class TestMemoryTracker(InductorTestCase):
if node.op not in ("placeholder", "get_attr", "output")
]
if len(compute_nodes) < 3:
self.skipTest(
f"Need at least 3 compute nodes, got {len(compute_nodes)}"
)
# Test original order: zeros_like, add, sum
# zeros gets freed after sum (last use of zeros)
memory_tracker1 = MemoryTracker(fx_graph.graph, device_filter=device_filter)

View File

@@ -356,7 +356,9 @@ def estimate_fx_collective_size(fx_node: torch.fx.Node) -> int:
return size
def estimate_nccl_collective_runtime_from_fx_node(fx_node: torch.fx.Node) -> float:
def estimate_nccl_collective_runtime_from_fx_node(
fx_node: torch.fx.Node, override_size: Optional[int] = None
) -> float:
"""
Returns estimated NCCL collective runtime in nanoseconds (ns).
@@ -371,7 +373,10 @@ def estimate_nccl_collective_runtime_from_fx_node(fx_node: torch.fx.Node) -> flo
"""
from torch.distributed.distributed_c10d import _get_group_size_by_name
tensor_storage_size_bytes = estimate_fx_collective_size(fx_node)
if override_size is None:
tensor_storage_size_bytes = estimate_fx_collective_size(fx_node)
else:
tensor_storage_size_bytes = override_size
assert not isinstance(fx_node.target, str)
opt_args_kwargs = normalize_function(

View File

@@ -2072,6 +2072,9 @@ class test_configs:
# to be migrated when ready for use
aten_fx_overlap_preserving_bucketing = False
# mostly disabled testing
assume_bucketing_reduces_latency = True
# to be migrated when ready for use
# runtime estimation function for ops
# for user-defined estimation function, pass in the function handle

View File

@@ -34,6 +34,15 @@ def _rs_group_key(node: torch.fx.Node) -> tuple[str, str, torch.dtype]:
return (group_name, reduce_op, dtype)
def bucket_key(node: torch.fx.Node) -> Optional[object]:
if is_all_gather_into_tensor(node):
return _ag_group_key(node)
elif is_reduce_scatter_tensor(node):
return _rs_group_key(node)
else:
return None
def bucket_cap_mb_by_bucket_idx_default(bucket_id: int) -> float:
"""
Determine the size of a bucket based on its ID.

View File

@@ -1,12 +1,10 @@
from collections import defaultdict
from typing import Optional
import torch
import torch.fx as fx
from torch._inductor.augmented_graph_helper import AugmentedGraphHelper
from torch._inductor.fx_passes.bucketing import (
_ag_group_key,
_rs_group_key,
bucket_key,
is_all_gather_into_tensor as is_all_gather,
is_reduce_scatter_tensor as is_reduce_scatter,
is_wait_tensor,
@@ -15,15 +13,6 @@ from torch._inductor.fx_passes.overlap_scheduling import CollBucket, CollectiveI
from torch.utils._ordered_set import OrderedSet
def bucket_key(node: torch.fx.Node) -> Optional[object]:
if is_all_gather(node):
return _ag_group_key(node)
elif is_reduce_scatter(node):
return _rs_group_key(node)
else:
return None
class OverlapPreservingBucketer:
"""
Buckets collective operations while preserving compute-collective overlap relationships.

View File

@@ -17,15 +17,31 @@ from torch._inductor.fx_passes.memory_estimator import (
build_memory_profile,
MemoryTracker,
)
from torch.fx.operator_schemas import normalize_function
from torch.utils._mode_utils import no_dispatch
from torch.utils._ordered_set import OrderedSet
log = logging.getLogger(__name__)
from torch._inductor.fx_passes.bucketing import bucket_key
from ..pattern_matcher import stable_topological_sort
def get_group_name(n: fx.Node) -> str:
"""Extract the group name from a collective operation node."""
opt_args_kwargs = normalize_function(
n.target, # type: ignore[arg-type]
args=n.args,
kwargs=n.kwargs,
normalize_to_only_use_kwargs=True,
)
assert opt_args_kwargs is not None
_, kwargs = opt_args_kwargs
return kwargs["group_name"]
def get_custom_estimation(n: fx.Node) -> Optional[float]:
runtime_estimation = torch._inductor.config.test_configs.estimate_aten_runtime
if runtime_estimation == "default":
@@ -35,12 +51,13 @@ def get_custom_estimation(n: fx.Node) -> Optional[float]:
return runtime_estimation(n)
def estimate_collective_time(n: fx.Node) -> float:
def estimate_collective_time(n: fx.Node, override_size: Optional[int] = None) -> float:
"""Estimate the runtime of a collective operation, optionally with an overridden size."""
if (est := get_custom_estimation(n)) is not None:
return est
return torch._inductor.comm_analysis.estimate_nccl_collective_runtime_from_fx_node(
n
n, override_size
)
@@ -55,6 +72,10 @@ def estimate_fx_collective_size(fx_node: torch.fx.Node) -> int:
def is_compute_node(n: fx.Node) -> bool:
"""
Should we consider this node computationally expensive?
Currently uses flop registration, but we could expand more generally.
"""
return (
getattr(n.target, "overloadpacket", None)
in torch.utils.flop_counter.flop_registry
@@ -173,6 +194,11 @@ class CollBucket:
total_bytes: int = 0
def gb_to_bytes(gb: float) -> int:
"""Convert gigabytes to bytes."""
return int(gb * 1024 * 1024 * 1024)
class OverlapScheduler:
"""
Scheduler that reorders operations to maximize compute-collective overlap.
@@ -180,7 +206,7 @@ class OverlapScheduler:
The reordering is done as a scheduling pass. We maintain a priority queue of
schedulable nodes. The nodes are ranked by:
1) the compute node depth they dominate. this allows reordering locally, such as with
1) the compute node index they dominate. this allows reordering locally, such as with
parallel mms, and also allows overlapping reduce scatter nodes outputs in the backward
with compute by deferring their waits.
@@ -200,15 +226,16 @@
def __init__(
self,
gm: torch.fx.GraphModule,
max_in_flight_gb: float = 2.0,
max_in_flight_gb: float = 0.5,
compute_overlap_multipler: float = 2.0,
max_coll_distance: int = 1000,
max_compute_pre_fetch: int = 5,
):
self.gm = gm
self.graph = gm.graph
self.compute_overlap_multipler = compute_overlap_multipler
self.max_node_distance = max_coll_distance
self.max_in_flight_bytes: int = int(max_in_flight_gb * 1024 * 1024 * 1024)
self.max_in_flight_bytes: int = gb_to_bytes(max_in_flight_gb)
# Build structures
stable_topological_sort(self.graph)
@@ -231,8 +258,9 @@
self.wait_to_start: dict[fx.Node, fx.Node] = {}
self._identify_collectives()
self.compute_depth = self._calculate_compute_node_depth()
self.compute_index_domination = self._calculate_compute_node_domination_index()
self.compute_nodes = [n for n in self.nodes if is_compute_node(n)]
self.current_compute_index = 0
# Scheduling state
self.potentially_hidden_collectives = (
@@ -249,6 +277,7 @@
self.in_flight: dict[fx.Node, CollectiveInfo] = {} # start -> info
self.in_flight_bytes = 0
self.scheduled: OrderedSet[fx.Node] = OrderedSet()
self.max_compute_pre_fetch = max_compute_pre_fetch
def _collect_node_ancestors(self) -> dict[fx.Node, OrderedSet[fx.Node]]:
"""Collect all ancestors for each node."""
@@ -260,6 +289,10 @@
return ancestors
def off_compute_path(self, n: fx.Node) -> bool:
"""Check if a node is off the compute path (doesn't block any compute)."""
return self.compute_index_domination[n] == sys.maxsize
def _identify_collectives(self) -> None:
"""Identify all collective operations."""
for node in self.nodes:
@@ -278,51 +311,30 @@
self.wait_to_start[node] = start
self.unscheduled_collectives.add(start)
def _calculate_compute_node_depth(self) -> dict[fx.Node, int]:
"""Compute forward depth and minimum dominance depth (infinity if blocks no compute)."""
# First pass: forward compute depth
in_degree: dict[fx.Node, int] = {}
compute_depth: dict[fx.Node, int] = {}
queue: list[fx.Node] = []
def _calculate_compute_node_domination_index(self) -> dict[fx.Node, int]:
"""
Compute the topological index of the earliest compute node each node dominates.
Compute nodes are assigned indices based on their topological order (0, 1, 2, ...).
For each node, returns the minimum index of compute nodes it blocks/dominates.
Returns sys.maxsize if the node doesn't block any compute nodes.
"""
compute_node_index: dict[fx.Node, int] = {}
for node in self.graph.nodes:
num_inputs = len(node.all_input_nodes)
if num_inputs == 0:
queue.append(node)
else:
in_degree[node] = num_inputs
while queue:
node = queue.pop()
max_input_depth = max(
(compute_depth[inp] for inp in node.all_input_nodes), default=0
)
compute_depth[node] = max_input_depth + is_compute_node(node)
for use in node.users:
in_degree[use] -= 1
if in_degree[use] == 0:
queue.append(use)
# Second pass: minimum dominance (what's the earliest compute this blocks)
compute_depth_dominance: dict[fx.Node, int] = {}
for node in reversed(self.graph.nodes):
if is_compute_node(node):
# consider compute nodes to be at their own depth
dominance = compute_depth[node]
compute_node_index[node] = len(compute_node_index)
domination_index: dict[fx.Node, int] = {}
for node in reversed(self.graph.nodes):
if node in compute_node_index:
# Compute nodes dominate themselves (return their own index)
domination_index[node] = compute_node_index[node]
else:
# For non-compute nodes, find minimum compute they block
dominance = min(
(compute_depth_dominance[succ] for succ in node.users),
default=sys.maxsize,
domination_index[node] = min(
(domination_index[succ] for succ in node.users), default=sys.maxsize
)
compute_depth_dominance[node] = dominance
return compute_depth_dominance
return domination_index
def _align_compute_nodes_runtime_estimations_across_all_distributed_ranks(
self,
@@ -435,6 +447,13 @@ class OverlapScheduler:
self.scheduled.add(node)
self.memory_tracker.schedule_node(node)
log.debug(
"Scheduled node %s: current_memory=%d bytes, total_scheduled=%d",
node.name,
self.memory_tracker.get_current_memory_bytes(),
len(self.scheduled),
)
for user in node.users:
self.in_degree[user] -= 1
if self.in_degree[user] == 0:
@@ -445,15 +464,20 @@
if is_wait_tensor(node):
info = self.collective_info[self.wait_to_start[node]]
# TODO: we could consider even deferring waits that are not potentially hidden
# so as to overlap comm with itself, although exposed comms should be bucketed with each other.
overlappable = info.is_exposed and node in self.potentially_hidden_waits
# defer waits locally if they are exposed.
compute_local_priority = int(info.is_exposed)
else:
overlappable = self.in_overlappable_collective_unary_chain(node)
# if we're scheduling this collective via its queue, then it was not
# pre-fetched. we might as well maximize overlap for the
# local, non-mm nodes prior to the next compute node.
if self.in_overlappable_collective_unary_chain(node):
compute_local_priority = -1
else:
compute_local_priority = 0
return (
self.compute_depth[node], # what depth compute it blocks
overlappable, # Defer hideable collective ops
self.compute_index_domination[node], # what index compute it blocks
compute_local_priority, # collective_start=-1, wait=1, or neither=0
self.node_idx[node], # Original order for stability
)
@@ -473,7 +497,7 @@
return False
if user in self.unscheduled_collectives:
return user in self.potentially_hidden_collectives
return True
if not self.is_cheap_fn(user):
return False
@@ -484,7 +508,11 @@
def _should_force_wait_for_memory(self) -> bool:
"""Check if we need to force a wait due to memory pressure"""
return self.in_flight_bytes >= self.max_in_flight_bytes
if not self.in_flight:
return False
return self.in_flight_bytes >= self.max_in_flight_bytes or (
self.memory_tracker.current_memory_bytes - self.original_peak_memory
) > gb_to_bytes(1.0)
def _force_oldest_wait(self) -> None:
"""Schedule the oldest in flight wait"""
@@ -493,6 +521,12 @@
def _handle_collective_start(self, node: fx.Node) -> None:
"""Handle scheduling a collective start."""
info = self.collective_info[node]
if self.should_assume_bucketed(node):
latency = estimate_collective_time(node, 0)
assert latency <= info.exposed_time_ms
info.exposed_time_ms = info.exposed_time_ms - latency
self.in_flight[node] = info
self.in_flight_bytes += info.size_bytes
self.unscheduled_collectives.discard(node)
@@ -502,8 +536,22 @@
"""Handle scheduling a wait."""
assert node in self.wait_to_start
coll_start = self.wait_to_start[node]
assert coll_start in self.in_flight
# Scheduling a wait of a collective also forces the wait
# of every node enqueued prior to the collective on the
# same process group
group_name = get_group_name(coll_start)
to_schedule: list[fx.Node] = []
for in_flight_coll in self.in_flight:
if in_flight_coll == coll_start:
break
if get_group_name(in_flight_coll) == group_name:
to_schedule.append(in_flight_coll)
for coll_to_schedule in to_schedule:
self._handle_wait(self.collective_info[coll_to_schedule].wait_node)
self.in_flight_bytes -= self.in_flight[coll_start].size_bytes
del self.in_flight[coll_start]
self._schedule(node)
@@ -514,6 +562,7 @@
compute_time = benchmark_node(node)
available_compute = compute_time * self.compute_overlap_multipler
# TODO: separate overlap time per process group
# First reduce exposed time of in-flight collectives
for info in self.in_flight.values():
if info.exposed_time_ms == 0:
@@ -531,6 +580,7 @@
self._schedule_collectives_for_overlap(node, available_compute)
self._schedule(node)
self.current_compute_index += 1
def _schedule_collectives_for_overlap(
self, compute_node: fx.Node, available_compute_time: float
@@ -538,15 +588,39 @@
"""Opportunistically schedule collectives that can be hidden by compute."""
compute_ancestors = self.node_ancestors[compute_node]
# copy unscheduled_collectives to local because we modify it during iteration
# Filter collectives by distance and compute index domination
possible_collectives = []
for collective in self.unscheduled_collectives:
distance = abs(self.node_idx[compute_node] - self.node_idx[collective])
if distance > self.max_node_distance:
break
# Skip collectives that are too far ahead in compute index, but allow scheduling
# collectives which are off compute path (which typically release memory)
# TODO: we could potentially be more strict about limiting the amount of
# pre-fetched memory before memory peak, and adjust allowed collective mem.
if not self.off_compute_path(collective):
if (
self.compute_index_domination[collective]
- self.current_compute_index
) > self.max_compute_pre_fetch:
continue
possible_collectives.append(collective)
possible_collectives = sorted(
possible_collectives,
key=lambda n: (self.compute_index_domination[n], self.node_idx[n]),
)
log.debug(
"Scheduling collectives for overlap: compute_node=%s, available_time=%.2f ms, candidates=%d, current_memory=%d bytes",
compute_node.name,
available_compute_time,
len(possible_collectives),
self.memory_tracker.current_memory_bytes,
)
for collective in possible_collectives:
if available_compute_time == 0:
break
@@ -575,15 +649,25 @@
if path is None:
continue
log.debug(
"Overlapping collective %s with compute %s: coll_domination=%d, current_depth=%d",
collective.name,
compute_node.name,
self.compute_index_domination[collective],
self.current_compute_index,
)
# Schedule path to this collective
self._schedule_path_to_collective(path, compute_node)
self._handle_collective_start(collective)
# Update the exposed time for this newly scheduled collective
overlap_amount = min(info.estimated_time_ms, available_compute_time)
# after scheduling, which will account for latency reduction of bucketing
overlap_amount = min(available_compute_time, info.exposed_time_ms)
info.exposed_time_ms -= overlap_amount
if info.exposed_time_ms == 0:
info.hiding_node = compute_node
available_compute_time -= overlap_amount
self._handle_collective_start(collective)
def _find_schedulable_path(
self, target: fx.Node, curr_compute_node: Optional[fx.Node]
@@ -618,6 +702,24 @@
return unscheduled_ancestors
def should_assume_bucketed(self, node: fx.Node) -> bool:
"""
Check if there's an in-flight collective that can be bucketed with the given node. If so, assume they will bucket.
This is an optimistic heuristic to account for latency reduction with bucketing. The two nodes may not get bucketed.
"""
if not torch._inductor.config.test_configs.assume_bucketing_reduces_latency:
return False
key = bucket_key(node)
if key is None:
return False
for in_flight_coll in self.in_flight.keys():
if bucket_key(in_flight_coll) == key:
return True
return False
def _get_oldest_wait(self) -> fx.Node:
oldest_start = next(iter(self.in_flight))
return self.collective_info[oldest_start].wait_node
@@ -633,10 +735,18 @@
self, path: OrderedSet[fx.Node], curr_compute_node: fx.Node
) -> None:
"""Schedule all nodes needed to reach a collective."""
assert all(n not in self.scheduled for n in path)
for node in sorted(path, key=lambda n: self.node_idx[n]):
assert not (is_compute_node(node) or node in self.unscheduled_collectives)
if is_wait_tensor(node):
# When we schedule wait tensors, we also force realization of all
# collectives enqueued prior to their corresponding collective.
# It's possible the scheduling of one wait tensor here has forced
# another in the path. If so, skip scheduling it.
if node in self.scheduled:
continue
info = self.collective_info[self.wait_to_start[node]]
assert info.hiding_node != curr_compute_node
self._handle_wait(node)
@@ -672,15 +782,17 @@
counters["inductor"]["overlap_scheduling_potentially_hidden"] += len(
potentially_hidden_collectives
)
counters["inductor"]["overlap_original_mem"] = self.original_peak_memory
counters["inductor"]["rescheduled_mem"] = self.memory_tracker.peak_memory
log.info(
"Overlap scheduling: total exposed %s, total bad exposed %s, total potentially hidden %s",
"Overlap scheduling results: exposed=%d, bad_exposed=%d, potentially_hidden=%d, "
"original_peak_memory=%d bytes, rescheduled_peak_memory=%d bytes",
len(exposed),
len(bad_exposed),
len(potentially_hidden_collectives),
self.original_peak_memory,
self.memory_tracker.peak_memory,
)
self.reorder_graph()