Limit coll bucketing within node idxs (#164944)

Respect max_coll_distance from the overlap scheduler when bucketing collectives, and add an optimization to the path search.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164944
Approved by: https://github.com/IvanKobzarev
ghstack dependencies: #164738, #164783
Committed by: PyTorch MergeBot
Parent: 5a1fbf45ad
Commit: af40828bbb
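The main functional change is the distance gate: when forming a bucket, candidate collectives are only considered if their position in the schedule is within max_coll_distance of the bucket's first collective. A minimal sketch of that check, using plain Python containers and illustrative names (only node_idx, scheduled, and max_coll_distance correspond to names in the diff):

# Minimal sketch of the distance gate added to _find_buckets (illustrative only).
scheduled = ["coll_a", "mm_1", "coll_b", "mm_2", "coll_c"]  # hypothetical schedule
node_idx = {n: i for i, n in enumerate(scheduled)}          # same mapping as in the diff
max_coll_distance = 2

def within_distance(start_node: str, candidate: str) -> bool:
    """Candidates farther than max_coll_distance from the bucket start are
    skipped before any path-search work is done."""
    return abs(node_idx[candidate] - node_idx[start_node]) <= max_coll_distance

assert within_distance("coll_a", "coll_b")       # distance 2: still considered
assert not within_distance("coll_a", "coll_c")   # distance 4: skipped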
@@ -5,6 +5,7 @@ import torch
 import torch.fx as fx
 from torch._inductor.augmented_graph_helper import AugmentedGraphHelper
 from torch.testing._internal.common_utils import TestCase
+from torch.utils._ordered_set import OrderedSet


 class TestAugmentedGraphHelper(TestCase):
@@ -61,9 +62,29 @@ class TestAugmentedGraphHelper(TestCase):
         ]:
             self.nodes[node.name] = node

-        # Get all nodes and create tracker
+        # Get all nodes and compute ancestors
         self.all_nodes = list(self.graph.nodes)
-        self.tracker = AugmentedGraphHelper(self.graph)
+        self.node_ancestors = self._collect_node_ancestors(self.graph)

+        # Create tracker with ancestors
+        self.tracker = AugmentedGraphHelper(
+            self.graph, node_ancestors=self.node_ancestors
+        )
+
+    def _collect_node_ancestors(
+        self, graph: fx.Graph
+    ) -> dict[fx.Node, OrderedSet[fx.Node]]:
+        """Collect all ancestors for each node."""
+        from collections import defaultdict
+
+        from torch.utils._ordered_set import OrderedSet
+
+        ancestors: dict[fx.Node, OrderedSet[fx.Node]] = defaultdict(OrderedSet)
+        for node in graph.nodes:
+            for input_node in node.all_input_nodes:
+                ancestors[node].add(input_node)
+                ancestors[node] |= ancestors[input_node]
+        return ancestors
+
     def get_deps(self, node):
         """Helper to get dependencies for a node."""
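One subtlety in the test helper above: a single forward pass is enough to build full transitive ancestor sets because the graph's nodes are visited in topological order (inputs appear before their users), so ancestors[input_node] is already complete when node consumes it. The same accumulation over a plain adjacency map, as a quick illustration with made-up node names:

from collections import defaultdict

# Direct inputs per node, listed in topological order (made-up names).
inputs = {"a": [], "b": ["a"], "c": ["b"]}

ancestors: dict[str, set[str]] = defaultdict(set)
for node, deps in inputs.items():
    for dep in deps:
        ancestors[node].add(dep)
        ancestors[node] |= ancestors[dep]  # dep's ancestors are already complete

assert ancestors["c"] == {"a", "b"}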
@@ -1,4 +1,5 @@
 from collections import defaultdict
+from typing import Optional

 import torch
 import torch.fx as fx
@@ -14,13 +15,20 @@ class AugmentedGraphHelper:
     graphcycles.cc
     """

-    def __init__(self, graph: fx.Graph):
+    def __init__(
+        self,
+        graph: fx.Graph,
+        node_ancestors: Optional[dict[fx.Node, OrderedSet[fx.Node]]] = None,
+    ):
         # Each node starts in its own singleton set
         self.graph = graph
         self.merge_sets = {node: OrderedSet([node]) for node in graph.nodes}

         # Extra dependencies: node depends on dep (dep must come before node)
         self.extra_deps: dict[fx.Node, OrderedSet[fx.Node]] = defaultdict(OrderedSet)
+        # Note: only reflect original ancestors, not maintained through additional deps
+        # or merge sets
+        self.node_ancestors = node_ancestors

     def add_extra_dep(self, *, n: fx.Node, dep: fx.Node) -> None:
         """Add extra dependency: node depends on dep."""
@@ -90,13 +98,27 @@ class AugmentedGraphHelper:
         while queue:
             current = queue.pop()

             # Get all dependencies
             for dep in self.get_merged_deps(current):
                 # Check if we reached source or its equivalent
                 if dep in self.merge_sets[source]:
                     return True

-                if dep not in visited:
-                    visited.add(dep)
-                    queue.append(dep)
-
+                if dep in visited:
+                    continue
+
+                # We are searching from target, so this node is necessarily an ancestor
+                # of target.
+                # If dep is an ancestor of source, any path through dep to source would imply a cycle
+                if self.node_ancestors:
+                    source_set = self.merge_sets[source]
+                    is_ancestor_of_source = any(
+                        dep in self.node_ancestors[s] for s in source_set
+                    )
+                    # Add to visited to avoid recomputing this check if we see dep again
+                    if is_ancestor_of_source:
+                        visited.add(dep)
+                        continue
+
+                visited.add(dep)
+                queue.append(dep)
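This new branch is the "optimization in path searching" from the commit message. The search walks dependency edges backwards from target, so every node it reaches is an ancestor of target; if such a node is also a recorded ancestor of source, continuing through it could only reach source by way of a cycle, so the branch is pruned (and cached in visited). A toy illustration of the predicate, using plain sets instead of OrderedSet and made-up node names:

# Toy illustration of the pruning predicate (made-up names, plain sets).
node_ancestors = {
    "dep": set(),
    "source": {"dep"},            # dep precedes source in the original graph
    "target": {"dep", "source"},
}
source_set = {"source"}           # merge set of the search source

dep = "dep"
is_ancestor_of_source = any(dep in node_ancestors[s] for s in source_set)

# dep precedes source, so a backward search from target cannot reach source
# through dep without the graph containing a cycle; the search skips it.
assert is_ancestor_of_source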
@@ -37,6 +37,7 @@ class OverlapPreservingBucketer:
         node_ancestors: dict[fx.Node, OrderedSet[fx.Node]],
         scheduled: OrderedSet[fx.Node],
         max_bucket_memory_gb: float = 1.0,
+        max_coll_distance: int = 1000,
     ):
         self.graph = graph
         self.collective_info = collective_info
@@ -44,20 +45,20 @@ class OverlapPreservingBucketer:
         self.scheduled = scheduled
         self.max_bucket_memory_gb = max_bucket_memory_gb
         self.node_idx = {n: i for i, n in enumerate(scheduled)}
+        self.aug_graph = AugmentedGraphHelper(self.graph, self.node_ancestors)
+        self.max_coll_distance = max_coll_distance

     def bucket_collectives(self) -> None:
         """Main entry point for bucketing collectives."""

-        aug_graph = AugmentedGraphHelper(self.graph)
-
         # Add extra dependencies for hidden collectives
         # For each hidden collective, add: compute -> start and wait -> compute
         for start_node, info in self.collective_info.items():
             if info.hiding_node and not info.is_exposed:
                 # Add edge: hiding_compute depends on start (start must come before compute)
-                aug_graph.add_extra_dep(n=info.hiding_node, dep=start_node)
+                self.aug_graph.add_extra_dep(n=info.hiding_node, dep=start_node)
                 # Add edge: wait depends on hiding_compute (compute must come before wait)
-                aug_graph.add_extra_dep(n=info.wait_node, dep=info.hiding_node)
+                self.aug_graph.add_extra_dep(n=info.wait_node, dep=info.hiding_node)

         # Group collectives by bucket key (type, group, etc.)
         grouped_collectives: dict[object, OrderedSet[fx.Node]] = defaultdict(OrderedSet)
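For context on the two add_extra_dep calls: they pin down the interval that keeps a hidden collective hidden, i.e. the start stays before the compute that hides it and that compute stays before the wait, so bucketing cannot reorder the collective out from under the compute that overlaps it. A rough sketch of the ordering those two edges imply, with hypothetical node names:

# Rough sketch of the ordering encoded by the two extra deps (hypothetical names).
# "n depends on dep" means dep must be scheduled before n.
extra_deps = {
    "hiding_compute": ["collective_start"],   # start before the hiding compute
    "collective_wait": ["hiding_compute"],    # hiding compute before the wait
}

def must_precede(a: str, b: str) -> bool:
    """True if the extra deps force a before b, directly or transitively."""
    stack, seen = [b], set()
    while stack:
        for dep in extra_deps.get(stack.pop(), []):
            if dep == a:
                return True
            if dep not in seen:
                seen.add(dep)
                stack.append(dep)
    return False

assert must_precede("collective_start", "collective_wait")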
@@ -68,7 +69,7 @@ class OverlapPreservingBucketer:

         all_buckets: list[CollBucket] = []
         for collective_group in grouped_collectives.values():
-            buckets = self._find_buckets(collective_group, aug_graph)
+            buckets = self._find_buckets(collective_group)
             all_buckets.extend(buckets)

         # Collect all extra dependencies to preserve after bucketing
@@ -95,7 +96,6 @@
     def _find_buckets(
         self,
         collective_group: OrderedSet[fx.Node],
-        aug_graph: AugmentedGraphHelper,
     ) -> list[CollBucket]:
         """Find valid buckets within a group of similar collectives."""

@@ -113,17 +113,23 @@
                 total_bytes=self.collective_info[start_node].size_bytes,
             )
             processed.add(start_node)
+            start_node_idx = self.node_idx[start_node]

-            # TODO - limit within range
             for candidate in collective_group:
                 if candidate in processed:
                     continue

+                candidate_idx = self.node_idx[candidate]
+                # Check if candidate is within max distance from the bucket start
+                if abs(candidate_idx - start_node_idx) > self.max_coll_distance:
+                    continue
+
                 candidate_bytes = self.collective_info[candidate].size_bytes
                 if bucket_info.total_bytes + candidate_bytes > max_bucket_bytes:
                     continue

-                if self._can_add_to_bucket(bucket_info, candidate, aug_graph):
+                if self._can_add_to_bucket(bucket_info, candidate):
                     bucket_info.collectives.append(candidate)
                     bucket_info.total_bytes += candidate_bytes
                     processed.add(candidate)
@@ -141,7 +147,6 @@
         self,
         bucket_info: CollBucket,
         candidate: fx.Node,
-        aug_graph: AugmentedGraphHelper,
     ) -> bool:
         """
         Check if candidate can be added to bucket without interfering
@@ -177,26 +182,26 @@
         # TODO: we have a range of possible idxs of the merged node, and idx of new node.
         # we should not do path search beyond that range
         existing_coll = bucket_info.collectives[0]
-        if aug_graph.has_path(existing_coll, candidate):
+        if self.aug_graph.has_path(existing_coll, candidate):
             return False
-        if aug_graph.has_path(candidate, existing_coll):
+        if self.aug_graph.has_path(candidate, existing_coll):
             return False

         # Safe to merge starts - do the merge
-        aug_graph.merge_to_set(existing_coll, candidate)
+        self.aug_graph.merge_to_set(existing_coll, candidate)

         # Step 3: Check and merge waits
         existing_wait = self.collective_info[existing_coll].wait_node
         candidate_wait = candidate_info.wait_node
         # TODO - as above, limit search by idx
-        if aug_graph.has_path(existing_wait, candidate_wait) or aug_graph.has_path(
-            candidate_wait, existing_wait
-        ):
+        if self.aug_graph.has_path(
+            existing_wait, candidate_wait
+        ) or self.aug_graph.has_path(candidate_wait, existing_wait):
             # Unmerge the start we just merged
-            aug_graph.unmerge_node(candidate)
+            self.aug_graph.unmerge_node(candidate)
             return False

-        aug_graph.merge_to_set(existing_wait, candidate_wait)
+        self.aug_graph.merge_to_set(existing_wait, candidate_wait)
         return True

     def _apply_bucket(
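The wait-side handling above follows a small try/rollback pattern: merge the starts, test whether the waits are already ordered in either direction, and undo the start merge if they are. Condensed into one helper for readability (aug_graph stands in for self.aug_graph; the has_path / merge_to_set / unmerge_node calls are the ones shown in the diff):

def try_merge_pair(aug_graph, existing_coll, candidate, existing_wait, candidate_wait) -> bool:
    # Starts must not already be ordered relative to each other.
    if aug_graph.has_path(existing_coll, candidate) or aug_graph.has_path(
        candidate, existing_coll
    ):
        return False
    aug_graph.merge_to_set(existing_coll, candidate)

    # Waits must not be ordered either; otherwise roll back the start merge.
    if aug_graph.has_path(existing_wait, candidate_wait) or aug_graph.has_path(
        candidate_wait, existing_wait
    ):
        aug_graph.unmerge_node(candidate)
        return False

    aug_graph.merge_to_set(existing_wait, candidate_wait)
    return True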
@@ -679,6 +679,7 @@ class OverlapScheduler:
             node_ancestors=self.node_ancestors,
             scheduled=self.scheduled,
             max_bucket_memory_gb=1.0,  # Could make this configurable
+            max_coll_distance=self.max_node_distance,
         )
         bucketer.bucket_collectives()