Limit coll bucketing within node idxs (#164944)

Respect max_coll_distance from the overlap scheduler when bucketing collectives, and add an ancestor-based pruning optimization to path searching.
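
The distance limit means a candidate collective is only considered for a bucket when its position in the scheduled order is close enough to the bucket's first collective; the path-search optimization prunes backward searches at nodes already known to be ancestors of the source. As a rough standalone sketch of the distance gate (hypothetical names, not the code in this diff):

    # node_idx maps each node to its position in the scheduled order.
    def within_distance(node_idx, start, candidate, max_coll_distance):
        return abs(node_idx[candidate] - node_idx[start]) <= max_coll_distance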

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164944
Approved by: https://github.com/IvanKobzarev
ghstack dependencies: #164738, #164783
eellison
2025-10-08 09:06:20 -07:00
committed by PyTorch MergeBot
parent 5a1fbf45ad
commit af40828bbb
4 changed files with 72 additions and 23 deletions

View File

@@ -5,6 +5,7 @@ import torch
 import torch.fx as fx
 from torch._inductor.augmented_graph_helper import AugmentedGraphHelper
 from torch.testing._internal.common_utils import TestCase
+from torch.utils._ordered_set import OrderedSet


 class TestAugmentedGraphHelper(TestCase):
@@ -61,9 +62,29 @@ class TestAugmentedGraphHelper(TestCase):
         ]:
             self.nodes[node.name] = node

-        # Get all nodes and create tracker
+        # Get all nodes and compute ancestors
         self.all_nodes = list(self.graph.nodes)
-        self.tracker = AugmentedGraphHelper(self.graph)
+        self.node_ancestors = self._collect_node_ancestors(self.graph)
+
+        # Create tracker with ancestors
+        self.tracker = AugmentedGraphHelper(
+            self.graph, node_ancestors=self.node_ancestors
+        )
+
+    def _collect_node_ancestors(
+        self, graph: fx.Graph
+    ) -> dict[fx.Node, OrderedSet[fx.Node]]:
+        """Collect all ancestors for each node."""
+        from collections import defaultdict
+
+        ancestors: dict[fx.Node, OrderedSet[fx.Node]] = defaultdict(OrderedSet)
+        for node in graph.nodes:
+            for input_node in node.all_input_nodes:
+                ancestors[node].add(input_node)
+                ancestors[node] |= ancestors[input_node]
+        return ancestors

     def get_deps(self, node):
         """Helper to get dependencies for a node."""

View File

@@ -1,4 +1,5 @@
 from collections import defaultdict
+from typing import Optional

 import torch
 import torch.fx as fx
@@ -14,13 +15,20 @@ class AugmentedGraphHelper:
     graphcycles.cc
     """

-    def __init__(self, graph: fx.Graph):
+    def __init__(
+        self,
+        graph: fx.Graph,
+        node_ancestors: Optional[dict[fx.Node, OrderedSet[fx.Node]]] = None,
+    ):
         # Each node starts in its own singleton set
         self.graph = graph
         self.merge_sets = {node: OrderedSet([node]) for node in graph.nodes}

         # Extra dependencies: node depends on dep (dep must come before node)
         self.extra_deps: dict[fx.Node, OrderedSet[fx.Node]] = defaultdict(OrderedSet)
+        # Note: these only reflect original ancestors; they are not maintained
+        # through extra deps or merge sets
+        self.node_ancestors = node_ancestors

     def add_extra_dep(self, *, n: fx.Node, dep: fx.Node) -> None:
         """Add extra dependency: node depends on dep."""
@@ -90,14 +98,28 @@ class AugmentedGraphHelper:
         while queue:
             current = queue.pop()

             # Get all dependencies
             for dep in self.get_merged_deps(current):
                 # Check if we reached source or its equivalent
                 if dep in self.merge_sets[source]:
                     return True

-                if dep not in visited:
-                    visited.add(dep)
-                    queue.append(dep)
+                if dep in visited:
+                    continue
+
+                # We are searching backward from target, so dep is necessarily an
+                # ancestor of target. If dep is also an ancestor of source, any path
+                # through dep to source would imply a cycle, so we can prune here.
+                if self.node_ancestors:
+                    source_set = self.merge_sets[source]
+                    is_ancestor_of_source = any(
+                        dep in self.node_ancestors[s] for s in source_set
+                    )
+                    # Add to visited to avoid recomputing this check if we see dep again
+                    if is_ancestor_of_source:
+                        visited.add(dep)
+                        continue
+
+                visited.add(dep)
+                queue.append(dep)

         return False
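
To see why the new pruning is sound: has_path walks backward from target through dependencies, so every dep it visits is an ancestor of target. If dep is also an ancestor of source, then reaching source below dep would require a path from source down to dep as well, closing a cycle in an acyclic graph; the subtree under dep can therefore be skipped. A toy version of the pruned search (standalone sketch with plain dicts, not the class above):

    # Chain a -> b -> c: b depends on a, c depends on b.
    deps = {"c": ["b"], "b": ["a"], "a": []}
    ancestors = {"a": set(), "b": {"a"}, "c": {"a", "b"}}

    def has_path(source, target):
        queue, visited = [target], set()
        while queue:
            for dep in deps[queue.pop()]:
                if dep == source:
                    return True
                if dep in visited:
                    continue
                if dep in ancestors[source]:
                    # Any path through dep to source would imply a cycle: prune.
                    visited.add(dep)
                    continue
                visited.add(dep)
                queue.append(dep)
        return False

    assert has_path("b", "c")      # c transitively depends on b
    assert not has_path("c", "b")  # "a" is pruned: it is an ancestor of c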

View File

@@ -37,6 +37,7 @@ class OverlapPreservingBucketer:
         node_ancestors: dict[fx.Node, OrderedSet[fx.Node]],
         scheduled: OrderedSet[fx.Node],
         max_bucket_memory_gb: float = 1.0,
+        max_coll_distance: int = 1000,
     ):
         self.graph = graph
         self.collective_info = collective_info
@@ -44,20 +45,20 @@ class OverlapPreservingBucketer:
         self.scheduled = scheduled
         self.max_bucket_memory_gb = max_bucket_memory_gb
         self.node_idx = {n: i for i, n in enumerate(scheduled)}
+        self.aug_graph = AugmentedGraphHelper(self.graph, self.node_ancestors)
+        self.max_coll_distance = max_coll_distance

     def bucket_collectives(self) -> None:
         """Main entry point for bucketing collectives."""
-        aug_graph = AugmentedGraphHelper(self.graph)
-
         # Add extra dependencies for hidden collectives
         # For each hidden collective, add: compute -> start and wait -> compute
         for start_node, info in self.collective_info.items():
             if info.hiding_node and not info.is_exposed:
                 # Add edge: hiding_compute depends on start (start must come before compute)
-                aug_graph.add_extra_dep(n=info.hiding_node, dep=start_node)
+                self.aug_graph.add_extra_dep(n=info.hiding_node, dep=start_node)
                 # Add edge: wait depends on hiding_compute (compute must come before wait)
-                aug_graph.add_extra_dep(n=info.wait_node, dep=info.hiding_node)
+                self.aug_graph.add_extra_dep(n=info.wait_node, dep=info.hiding_node)

         # Group collectives by bucket key (type, group, etc.)
         grouped_collectives: dict[object, OrderedSet[fx.Node]] = defaultdict(OrderedSet)
@@ -68,7 +69,7 @@ class OverlapPreservingBucketer:
         all_buckets: list[CollBucket] = []

         for collective_group in grouped_collectives.values():
-            buckets = self._find_buckets(collective_group, aug_graph)
+            buckets = self._find_buckets(collective_group)
             all_buckets.extend(buckets)

         # Collect all extra dependencies to preserve after bucketing
@@ -95,7 +96,6 @@ class OverlapPreservingBucketer:

     def _find_buckets(
         self,
         collective_group: OrderedSet[fx.Node],
-        aug_graph: AugmentedGraphHelper,
     ) -> list[CollBucket]:
         """Find valid buckets within a group of similar collectives."""
@@ -113,17 +113,23 @@ class OverlapPreservingBucketer:
                 total_bytes=self.collective_info[start_node].size_bytes,
             )
             processed.add(start_node)
+            start_node_idx = self.node_idx[start_node]

-            # TODO - limit within range
             for candidate in collective_group:
                 if candidate in processed:
                     continue

+                candidate_idx = self.node_idx[candidate]
+                # Check if candidate is within max distance from the bucket start
+                if abs(candidate_idx - start_node_idx) > self.max_coll_distance:
+                    continue
+
                 candidate_bytes = self.collective_info[candidate].size_bytes
                 if bucket_info.total_bytes + candidate_bytes > max_bucket_bytes:
                     continue

-                if self._can_add_to_bucket(bucket_info, candidate, aug_graph):
+                if self._can_add_to_bucket(bucket_info, candidate):
                     bucket_info.collectives.append(candidate)
                     bucket_info.total_bytes += candidate_bytes
                     processed.add(candidate)
@@ -141,7 +147,6 @@ class OverlapPreservingBucketer:
         self,
         bucket_info: CollBucket,
         candidate: fx.Node,
-        aug_graph: AugmentedGraphHelper,
     ) -> bool:
         """
         Check if candidate can be added to bucket without interfering
@@ -177,26 +182,26 @@ class OverlapPreservingBucketer:
         # TODO: we have a range of possible idxs of the merged node, and idx of new node.
         # we should not do path search beyond that range
         existing_coll = bucket_info.collectives[0]
-        if aug_graph.has_path(existing_coll, candidate):
+        if self.aug_graph.has_path(existing_coll, candidate):
             return False
-        if aug_graph.has_path(candidate, existing_coll):
+        if self.aug_graph.has_path(candidate, existing_coll):
             return False

         # Safe to merge starts - do the merge
-        aug_graph.merge_to_set(existing_coll, candidate)
+        self.aug_graph.merge_to_set(existing_coll, candidate)

         # Step 3: Check and merge waits
         existing_wait = self.collective_info[existing_coll].wait_node
         candidate_wait = candidate_info.wait_node

         # TODO - as above, limit search by idx
-        if aug_graph.has_path(existing_wait, candidate_wait) or aug_graph.has_path(
-            candidate_wait, existing_wait
-        ):
+        if self.aug_graph.has_path(
+            existing_wait, candidate_wait
+        ) or self.aug_graph.has_path(candidate_wait, existing_wait):
             # Unmerge the start we just merged
-            aug_graph.unmerge_node(candidate)
+            self.aug_graph.unmerge_node(candidate)
             return False
-        aug_graph.merge_to_set(existing_wait, candidate_wait)
+        self.aug_graph.merge_to_set(existing_wait, candidate_wait)

         return True

     def _apply_bucket(
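
The start/wait merging above is effectively a small transaction: the starts are merged tentatively, and if the wait-side path check then fails, the start merge is rolled back with unmerge_node. In outline (a sketch of the control flow, reusing only the aug_graph methods shown in this diff):

    def try_merge(aug_graph, coll_a, coll_b, wait_a, wait_b):
        # Phase 1: the two starts must not already be ordered relative to each other.
        if aug_graph.has_path(coll_a, coll_b) or aug_graph.has_path(coll_b, coll_a):
            return False
        aug_graph.merge_to_set(coll_a, coll_b)
        # Phase 2: same check for the waits; undo the start merge on failure.
        if aug_graph.has_path(wait_a, wait_b) or aug_graph.has_path(wait_b, wait_a):
            aug_graph.unmerge_node(coll_b)  # roll back the tentative merge
            return False
        aug_graph.merge_to_set(wait_a, wait_b)
        return True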
def _apply_bucket(

View File

@@ -679,6 +679,7 @@ class OverlapScheduler:
             node_ancestors=self.node_ancestors,
             scheduled=self.scheduled,
             max_bucket_memory_gb=1.0,  # Could make this configurable
+            max_coll_distance=self.max_node_distance,
         )
         bucketer.bucket_collectives()