diff --git a/test/inductor/test_augmented_graph_helper.py b/test/inductor/test_augmented_graph_helper.py index 7267b4660169..ef1f92e23268 100644 --- a/test/inductor/test_augmented_graph_helper.py +++ b/test/inductor/test_augmented_graph_helper.py @@ -5,6 +5,7 @@ import torch import torch.fx as fx from torch._inductor.augmented_graph_helper import AugmentedGraphHelper from torch.testing._internal.common_utils import TestCase +from torch.utils._ordered_set import OrderedSet class TestAugmentedGraphHelper(TestCase): @@ -61,9 +62,29 @@ class TestAugmentedGraphHelper(TestCase): ]: self.nodes[node.name] = node - # Get all nodes and create tracker + # Get all nodes and compute ancestors self.all_nodes = list(self.graph.nodes) - self.tracker = AugmentedGraphHelper(self.graph) + self.node_ancestors = self._collect_node_ancestors(self.graph) + + # Create tracker with ancestors + self.tracker = AugmentedGraphHelper( + self.graph, node_ancestors=self.node_ancestors + ) + + def _collect_node_ancestors( + self, graph: fx.Graph + ) -> dict[fx.Node, OrderedSet[fx.Node]]: + """Collect all ancestors for each node.""" + from collections import defaultdict + + from torch.utils._ordered_set import OrderedSet + + ancestors: dict[fx.Node, OrderedSet[fx.Node]] = defaultdict(OrderedSet) + for node in graph.nodes: + for input_node in node.all_input_nodes: + ancestors[node].add(input_node) + ancestors[node] |= ancestors[input_node] + return ancestors def get_deps(self, node): """Helper to get dependencies for a node.""" diff --git a/torch/_inductor/augmented_graph_helper.py b/torch/_inductor/augmented_graph_helper.py index c83bdd7d5396..ac61c015888e 100644 --- a/torch/_inductor/augmented_graph_helper.py +++ b/torch/_inductor/augmented_graph_helper.py @@ -1,4 +1,5 @@ from collections import defaultdict +from typing import Optional import torch import torch.fx as fx @@ -14,13 +15,20 @@ class AugmentedGraphHelper: graphcycles.cc """ - def __init__(self, graph: fx.Graph): + def __init__( + self, 
+        graph: fx.Graph, +        node_ancestors: Optional[dict[fx.Node, OrderedSet[fx.Node]]] = None, +    ): # Each node starts in its own singleton set self.graph = graph self.merge_sets = {node: OrderedSet([node]) for node in graph.nodes} # Extra dependencies: node depends on dep (dep must come before node) self.extra_deps: dict[fx.Node, OrderedSet[fx.Node]] = defaultdict(OrderedSet) +        # Note: only reflects original ancestors, not maintained through additional deps +        # or merge sets +        self.node_ancestors = node_ancestors def add_extra_dep(self, *, n: fx.Node, dep: fx.Node) -> None: """Add extra dependency: node depends on dep.""" @@ -90,14 +98,28 @@ class AugmentedGraphHelper: while queue: current = queue.pop() -            # Get all dependencies for dep in self.get_merged_deps(current): # Check if we reached source or its equivalent if dep in self.merge_sets[source]: return True -            if dep not in visited: -                visited.add(dep) -                queue.append(dep) +            if dep in visited: +                continue + +            # We are searching from target, so this node is necessarily an ancestor +            # of target.
+ # If dep is an ancestor of source, any path through dep to source would imply a cycle + if self.node_ancestors: + source_set = self.merge_sets[source] + is_ancestor_of_source = any( + dep in self.node_ancestors[s] for s in source_set + ) + # Add to visited to avoid recomputing this check if we see dep again + if is_ancestor_of_source: + visited.add(dep) + continue + + visited.add(dep) + queue.append(dep) return False diff --git a/torch/_inductor/fx_passes/overlap_preserving_bucketer.py b/torch/_inductor/fx_passes/overlap_preserving_bucketer.py index aed6076207eb..0e3c627a95c5 100644 --- a/torch/_inductor/fx_passes/overlap_preserving_bucketer.py +++ b/torch/_inductor/fx_passes/overlap_preserving_bucketer.py @@ -37,6 +37,7 @@ class OverlapPreservingBucketer: node_ancestors: dict[fx.Node, OrderedSet[fx.Node]], scheduled: OrderedSet[fx.Node], max_bucket_memory_gb: float = 1.0, + max_coll_distance: int = 1000, ): self.graph = graph self.collective_info = collective_info @@ -44,20 +45,20 @@ class OverlapPreservingBucketer: self.scheduled = scheduled self.max_bucket_memory_gb = max_bucket_memory_gb self.node_idx = {n: i for i, n in enumerate(scheduled)} + self.aug_graph = AugmentedGraphHelper(self.graph, self.node_ancestors) + self.max_coll_distance = max_coll_distance def bucket_collectives(self) -> None: """Main entry point for bucketing collectives.""" - aug_graph = AugmentedGraphHelper(self.graph) - # Add extra dependencies for hidden collectives # For each hidden collective, add: compute -> start and wait -> compute for start_node, info in self.collective_info.items(): if info.hiding_node and not info.is_exposed: # Add edge: hiding_compute depends on start (start must come before compute) - aug_graph.add_extra_dep(n=info.hiding_node, dep=start_node) + self.aug_graph.add_extra_dep(n=info.hiding_node, dep=start_node) # Add edge: wait depends on hiding_compute (compute must come before wait) - aug_graph.add_extra_dep(n=info.wait_node, dep=info.hiding_node) + 
self.aug_graph.add_extra_dep(n=info.wait_node, dep=info.hiding_node) # Group collectives by bucket key (type, group, etc.) grouped_collectives: dict[object, OrderedSet[fx.Node]] = defaultdict(OrderedSet) @@ -68,7 +69,7 @@ class OverlapPreservingBucketer: all_buckets: list[CollBucket] = [] for collective_group in grouped_collectives.values(): - buckets = self._find_buckets(collective_group, aug_graph) + buckets = self._find_buckets(collective_group) all_buckets.extend(buckets) # Collect all extra dependencies to preserve after bucketing @@ -95,7 +96,6 @@ class OverlapPreservingBucketer: def _find_buckets( self, collective_group: OrderedSet[fx.Node], - aug_graph: AugmentedGraphHelper, ) -> list[CollBucket]: """Find valid buckets within a group of similar collectives.""" @@ -113,17 +113,23 @@ class OverlapPreservingBucketer: total_bytes=self.collective_info[start_node].size_bytes, ) processed.add(start_node) + start_node_idx = self.node_idx[start_node] # TODO - limit within range for candidate in collective_group: if candidate in processed: continue + candidate_idx = self.node_idx[candidate] + # Check if candidate is within max distance from the bucket start + if abs(candidate_idx - start_node_idx) > self.max_coll_distance: + continue + candidate_bytes = self.collective_info[candidate].size_bytes if bucket_info.total_bytes + candidate_bytes > max_bucket_bytes: continue - if self._can_add_to_bucket(bucket_info, candidate, aug_graph): + if self._can_add_to_bucket(bucket_info, candidate): bucket_info.collectives.append(candidate) bucket_info.total_bytes += candidate_bytes processed.add(candidate) @@ -141,7 +147,6 @@ class OverlapPreservingBucketer: self, bucket_info: CollBucket, candidate: fx.Node, - aug_graph: AugmentedGraphHelper, ) -> bool: """ Check if candidate can be added to bucket without interfering @@ -177,26 +182,26 @@ class OverlapPreservingBucketer: # TODO: we have a range of possible idxs of the merged node, and idx of new node. 
# we should not do path search beyond that range existing_coll = bucket_info.collectives[0] - if aug_graph.has_path(existing_coll, candidate): + if self.aug_graph.has_path(existing_coll, candidate): return False - if aug_graph.has_path(candidate, existing_coll): + if self.aug_graph.has_path(candidate, existing_coll): return False # Safe to merge starts - do the merge - aug_graph.merge_to_set(existing_coll, candidate) + self.aug_graph.merge_to_set(existing_coll, candidate) # Step 3: Check and merge waits existing_wait = self.collective_info[existing_coll].wait_node candidate_wait = candidate_info.wait_node # TODO - as above, limit search by idx - if aug_graph.has_path(existing_wait, candidate_wait) or aug_graph.has_path( - candidate_wait, existing_wait - ): + if self.aug_graph.has_path( + existing_wait, candidate_wait + ) or self.aug_graph.has_path(candidate_wait, existing_wait): # Unmerge the start we just merged - aug_graph.unmerge_node(candidate) + self.aug_graph.unmerge_node(candidate) return False - aug_graph.merge_to_set(existing_wait, candidate_wait) + self.aug_graph.merge_to_set(existing_wait, candidate_wait) return True def _apply_bucket( diff --git a/torch/_inductor/fx_passes/overlap_scheduling.py b/torch/_inductor/fx_passes/overlap_scheduling.py index e6468d349e2e..69cc0bb476b9 100644 --- a/torch/_inductor/fx_passes/overlap_scheduling.py +++ b/torch/_inductor/fx_passes/overlap_scheduling.py @@ -679,6 +679,7 @@ class OverlapScheduler: node_ancestors=self.node_ancestors, scheduled=self.scheduled, max_bucket_memory_gb=1.0, # Could make this configurable + max_coll_distance=self.max_node_distance, ) bucketer.bucket_collectives()