Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Revert "Helper to augment graph with additional deps (#163959)"
This reverts commit b5d4d350f573db12b8181ee13f9386d6ef8a1e57.
Reverted https://github.com/pytorch/pytorch/pull/163959 on behalf of https://github.com/yangw-dev because it appears to fail inductor/test_aten_comm_compute_reordering in the macOS test job, see c9b5af9a38 (51526707590-box)
([comment](https://github.com/pytorch/pytorch/pull/163215#issuecomment-3349177940))
@@ -1,346 +0,0 @@
# Owner(s): ["module: inductor"]
import operator

import torch
import torch.fx as fx
from torch._inductor.augmented_graph_helper import AugmentedGraphHelper
from torch.testing._internal.common_utils import TestCase


class TestAugmentedGraphHelper(TestCase):
    """Test suite for AugmentedGraphHelper dependency and merge management."""

    def setUp(self):
        """Create a simple graph structure for testing."""
        # Create a torch.fx.Graph with multiple nodes
        self.graph = fx.Graph()

        # Create placeholder nodes (inputs)
        self.x = self.graph.placeholder("x")
        self.y = self.graph.placeholder("y")

        # Create computation nodes with specific names for easy reference
        self.node_a = self.graph.call_function(
            torch.add, args=(self.x, self.y), name="A"
        )
        self.node_b = self.graph.call_function(
            torch.mul, args=(self.node_a, self.x), name="B"
        )
        self.node_c = self.graph.call_function(
            torch.sub, args=(self.node_a, self.y), name="C"
        )
        self.node_d = self.graph.call_function(
            torch.div, args=(self.node_b, self.node_c), name="D"
        )
        self.node_e = self.graph.call_function(
            operator.neg, args=(self.node_d,), name="E"
        )
        self.node_f = self.graph.call_function(torch.abs, args=(self.node_e,), name="F")
        self.node_g = self.graph.call_function(
            torch.relu, args=(self.node_f,), name="G"
        )
        self.node_h = self.graph.call_function(
            torch.sigmoid, args=(self.node_g,), name="H"
        )

        # Create output
        self.graph.output(self.node_h)

        # Create a mapping of nodes by name for easier access in tests
        self.nodes = {}
        for node in self.graph.nodes:
            if hasattr(node, "name") and node.name in [
                "A",
                "B",
                "C",
                "D",
                "E",
                "F",
                "G",
                "H",
            ]:
                self.nodes[node.name] = node

        # Get all nodes and create tracker
        self.all_nodes = list(self.graph.nodes)
        self.tracker = AugmentedGraphHelper(self.graph)

    def get_deps(self, node):
        """Helper to get dependencies for a node."""
        return list(getattr(node, "args", []))

    # ========== Basic Functionality Tests ==========

    def test_initial_state(self):
        """Test that nodes start as singletons."""
        for node in self.all_nodes:
            merge_set = self.tracker.merge_sets[node]
            self.assertEqual(merge_set, {node})
            self.assertEqual(len(merge_set), 1)

    def test_simple_merge(self):
        """Test merging two nodes."""
        node_a = self.nodes["A"]
        node_b = self.nodes["B"]

        self.merge_nodes(self.tracker, [node_a, node_b])

        # Both should be in same merge set
        self.assertEqual(self.tracker.merge_sets[node_a], {node_a, node_b})
        self.assertEqual(self.tracker.merge_sets[node_b], {node_a, node_b})
        self.assertEqual(
            self.tracker.merge_sets[node_a], self.tracker.merge_sets[node_b]
        )

    def test_transitive_merge(self):
        """Test merging already merged nodes."""
        node_a = self.nodes["A"]
        node_b = self.nodes["B"]
        node_c = self.nodes["C"]
        node_d = self.nodes["D"]

        # Merge B, C, D into A's set one at a time
        for node in node_b, node_c, node_d:
            self.tracker.merge_to_set(node_a, node)

        expected_set = {node_a, node_b, node_c, node_d}
        for node in [node_a, node_b, node_c, node_d]:
            self.assertEqual(self.tracker.merge_sets[node], expected_set)

    def merge_nodes(self, tracker, nodes):
        for n in nodes[1:]:
            tracker.merge_to_set(nodes[0], n)

    def test_unmerge_node(self):
        """Test removing a node from its merge set."""
        node_a = self.nodes["A"]
        node_b = self.nodes["B"]
        node_c = self.nodes["C"]

        # Merge all three
        self.merge_nodes(self.tracker, [node_a, node_b, node_c])
        self.assertEqual(len(self.tracker.merge_sets[node_a]), 3)

        # Unmerge B
        self.tracker.unmerge_node(node_b)

        # B should be singleton
        self.assertEqual(self.tracker.merge_sets[node_b], {node_b})

        # A and C should still be together
        self.assertEqual(self.tracker.merge_sets[node_a], {node_a, node_c})
        self.assertEqual(self.tracker.merge_sets[node_c], {node_a, node_c})

    def test_unmerge_from_singleton(self):
        """Test unmerging a node that's already singleton."""
        node_a = self.nodes["A"]

        # Should be a no-op
        self.tracker.unmerge_node(node_a)
        self.assertEqual(self.tracker.merge_sets[node_a], {node_a})

    # ========== Dependency Propagation Tests ==========

    def test_merged_deps_collection(self):
        """Test that dependencies are collected from all merged nodes."""
        node_a = self.nodes["A"]
        node_b = self.nodes["B"]
        node_c = self.nodes["C"]

        # B already depends on A (and x) from graph construction
        # C already depends on A (and y) from graph construction

        # Merge B and C
        self.merge_nodes(self.tracker, [node_b, node_c])

        # Get merged deps for B - should include deps from both B and C
        deps = self.tracker.get_merged_deps(node_b)

        # Should include all dependencies from both nodes
        self.assertIn(node_a, deps)  # From both B and C
        self.assertIn(self.x, deps)  # From B
        self.assertIn(self.y, deps)  # From C

    def test_extra_deps_with_merge(self):
        """Test extra dependencies work correctly with merged nodes."""
        node_a = self.nodes["A"]
        node_b = self.nodes["B"]
        node_c = self.nodes["C"]
        node_d = self.nodes["D"]

        # Add extra dep from A to C
        self.tracker.add_extra_dep(n=node_a, dep=node_c)

        # Merge A and B
        self.merge_nodes(self.tracker, [node_a, node_b])

        # Add extra dep from D to the merged node (via B)
        self.tracker.add_extra_dep(n=node_d, dep=node_b)

        # D should depend on B through extra deps
        deps = self.tracker.get_merged_deps(node_d)
        self.assertIn(node_b, deps)

        # A should still have its dep on C
        deps = self.tracker.get_merged_deps(node_a)
        self.assertIn(node_c, deps)

    # ========== Path Finding Tests ==========

    def test_has_path_direct(self):
        """Test path finding for direct dependencies."""
        # In our graph: B depends on A
        node_a = self.nodes["A"]
        node_b = self.nodes["B"]

        self.assertTrue(self.tracker.has_path(node_a, node_b))
        self.assertFalse(self.tracker.has_path(node_b, node_a))

    def test_has_path_transitive(self):
        """Test path finding through multiple nodes."""
        # In our graph: A -> B -> D and A -> C -> D -> E
        node_a = self.nodes["A"]
        node_e = self.nodes["E"]

        self.assertTrue(self.tracker.has_path(node_a, node_e))
        self.assertFalse(self.tracker.has_path(node_e, node_a))

    def test_has_path_through_merge(self):
        """Test path finding when nodes are merged."""
        # Create a new graph for this specific test
        graph2 = fx.Graph()
        x2 = graph2.placeholder("x")
        a2 = graph2.call_function(torch.neg, args=(x2,), name="A2")
        b2 = graph2.call_function(torch.abs, args=(a2,), name="B2")
        c2 = graph2.call_function(torch.relu, args=(x2,), name="C2")
        d2 = graph2.call_function(torch.sigmoid, args=(c2,), name="D2")
        graph2.output(d2)

        tracker2 = AugmentedGraphHelper(graph2)

        # Initially no path from B2 to D2
        self.assertFalse(tracker2.has_path(b2, d2))

        # Merge B2 and C2
        tracker2.merge_to_set(b2, c2)

        # Now there should be a path B2/C2 -> D2
        self.assertTrue(tracker2.has_path(b2, d2))

    def test_has_path_with_extra_deps(self):
        """Test path finding with extra dependencies."""
        graph2 = fx.Graph()
        x2 = graph2.placeholder("x")
        a2 = graph2.call_function(torch.neg, args=(x2,), name="A2")
        b2 = graph2.call_function(torch.abs, args=(a2,), name="B2")
        c2 = graph2.call_function(torch.relu, args=(x2,), name="C2")
        d2 = graph2.call_function(torch.sigmoid, args=(c2,), name="D2")
        graph2.output(d2)

        tracker2 = AugmentedGraphHelper(graph2)

        # Initially no path from B2 to D2
        self.assertFalse(tracker2.has_path(b2, d2))

        tracker2.add_extra_dep(n=c2, dep=b2)

        # Now there should be a path B2/C2 -> D2
        self.assertTrue(tracker2.has_path(b2, d2))

    # ========== Cycle Detection Tests ==========

    def test_no_cycle_in_dag(self):
        """Test that a DAG has no cycles."""
        # Our original graph is a DAG, so it should have no cycles
        self.assertFalse(self.tracker.has_cycle())

    def test_simple_cycle_detection(self):
        """Test detection of a simple cycle."""
        # Create a graph with a cycle
        graph3 = fx.Graph()
        x3 = graph3.placeholder("x")

        # We can't create true cycles in fx.Graph directly,
        # but we can simulate them with extra_deps
        a3 = graph3.call_function(torch.neg, args=(x3,))
        b3 = graph3.call_function(torch.abs, args=(a3,))
        c3 = graph3.call_function(torch.relu, args=(b3,))
        graph3.output(c3)

        tracker3 = AugmentedGraphHelper(graph3)
        self.assertFalse(tracker3.has_cycle())

        # Add extra dep to create cycle: a3 -> c3
        tracker3.add_extra_dep(n=a3, dep=c3)

        self.assertTrue(tracker3.has_cycle())

    def test_cycle_through_merge(self):
        """Test that merging can create cycles."""
        # Create specific graph for this test
        graph4 = fx.Graph()
        x4 = graph4.placeholder("x")
        a4 = graph4.call_function(torch.neg, args=(x4,))
        b4 = graph4.call_function(torch.abs, args=(a4,))
        c4 = graph4.call_function(torch.relu, args=(x4,))
        d4 = graph4.call_function(torch.sigmoid, args=(c4,))
        graph4.output(d4)

        tracker4 = AugmentedGraphHelper(graph4)

        # Add extra dep d4 -> a4
        tracker4.add_extra_dep(n=a4, dep=d4)

        # Now: a4 -> b4, c4 -> d4 -> a4
        # Merging b4 and c4 creates a cycle
        tracker4.merge_to_set(b4, c4)

        self.assertTrue(tracker4.has_cycle())

    def test_cycle_with_extra_deps(self):
        """Test cycle detection with extra dependencies."""
        node_a = self.nodes["A"]
        node_b = self.nodes["B"]

        # B already depends on A naturally
        # Add reverse dependency to create cycle
        self.tracker.add_extra_dep(n=node_a, dep=node_b)

        self.assertTrue(self.tracker.has_cycle())

    def test_multiple_merge_unmerge(self):
        """Test a sequence of merge and unmerge operations."""
        nodes = [self.nodes[c] for c in ["A", "B", "C", "D", "E"]]

        # Merge A, B, C
        self.merge_nodes(self.tracker, nodes[:3])
        self.assertEqual(len(self.tracker.merge_sets[nodes[0]]), 3)

        # Merge D, E
        self.merge_nodes(self.tracker, nodes[3:5])
        self.assertEqual(len(self.tracker.merge_sets[nodes[3]]), 2)

        # Merging the two groups via B and D must fail: merge_to_set requires
        # the incoming node to still be a singleton
        try:
            self.merge_nodes(self.tracker, [nodes[1], nodes[3]])
            thrown = False
        except AssertionError:
            thrown = True
        self.assertTrue(thrown)

        # Unmerge C
        self.tracker.unmerge_node(nodes[2])
        self.assertEqual(len(self.tracker.merge_sets[nodes[0]]), 2)
        self.assertEqual(self.tracker.merge_sets[nodes[2]], {nodes[2]})

        # Unmerge A
        self.tracker.unmerge_node(nodes[0])
        self.assertEqual(self.tracker.merge_sets[nodes[0]], {nodes[0]})
        self.assertEqual(len(self.tracker.merge_sets[nodes[1]]), 1)


if __name__ == "__main__":
    from torch._inductor.test_case import run_tests

    run_tests()
@@ -324,11 +324,10 @@ def _create_subgraph(
     return subgraph, external_node_usages, node_usage_to_tuple_elems, ind_to_tuple_spec


-def _stable_topological_sort_impl(
+def _stable_topological_sort(
     graph: torch.fx.Graph,
     node_to_additional_deps: dict[Node, OrderedSet[Node]],
-    do_sort: bool = True,
-) -> bool:
+) -> None:
     # Nodes are in exactly one of these four collections:

     # - Nodes in `pending` are waiting to be processed (in reverse order):
@@ -367,7 +366,7 @@ def _stable_topological_sort_impl(
             waiting[waiting_for[-1]].append(node)
         else:
             ready.add(node)
-            if cursor and cursor.next is not node and do_sort:
+            if cursor and cursor.next is not node:
                 cursor.append(node)
             cursor = node
             # Mark the nodes that have been waiting for this node to finish as
@@ -375,23 +374,7 @@ def _stable_topological_sort_impl(
         pending.extend(reversed(waiting.pop(node, ())))

     ready.update(outputs)
-    return not waiting and len(ready) == len(graph.nodes)
-
-
-def _stable_topological_sort(
-    graph: torch.fx.Graph,
-    node_to_additional_deps: dict[Node, OrderedSet[Node]],
-) -> None:
-    assert _stable_topological_sort_impl(graph, node_to_additional_deps)
-
-
-def _has_cycle(
-    graph: torch.fx.Graph,
-    node_to_additional_deps: dict[Node, OrderedSet[Node]],
-) -> bool:
-    return not _stable_topological_sort_impl(
-        graph, node_to_additional_deps, do_sort=False
-    )
+    assert not waiting and len(ready) == len(graph.nodes)


 def _populate_additional_deps(
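The hunks above show only fragments of the reverted sort. As a rough, self-contained illustration of the idea (not the code in torch/_dynamo/graph_deduplication.py; the name stable_topo_sort and its signature are hypothetical), a stable topological sort that also honors extra ordering dependencies can be sketched like this:

# Illustrative sketch only: a stable topological sort over an fx.Graph that
# also honors extra dependencies (node -> set of nodes that must come first).
import torch.fx as fx


def stable_topo_sort(
    graph: fx.Graph,
    extra_deps: dict[fx.Node, set[fx.Node]],
) -> list[fx.Node]:
    def deps(n: fx.Node) -> set[fx.Node]:
        # Regular data-flow inputs plus any additional ordering constraints.
        return set(n.all_input_nodes) | extra_deps.get(n, set())

    order: list[fx.Node] = []
    placed: set[fx.Node] = set()
    pending = list(graph.nodes)
    while pending:
        progressed = False
        remaining: list[fx.Node] = []
        # Within each pass, nodes that are ready keep their existing relative
        # order, which is what makes the sort "stable".
        for n in pending:
            if deps(n) <= placed:
                order.append(n)
                placed.add(n)
                progressed = True
            else:
                remaining.append(n)
        if not progressed:
            # No node became ready in a full pass: the extra deps form a cycle.
            raise RuntimeError("cycle detected among nodes and extra deps")
        pending = remaining
    return order

The real implementation shown in the hunk does the same thing in a single pass with a `waiting` map and a `cursor` into the graph's node list, and its failure mode (an unsortable graph) is what the pre-revert `_has_cycle` repurposed by passing `do_sort=False`.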
@@ -1,103 +0,0 @@
from collections import defaultdict

import torch
import torch.fx as fx
from torch.utils._ordered_set import OrderedSet


class AugmentedGraphHelper:
    """
    Graph helper that augments the original graph with additional
    dependencies and uses, plus tracks node equivalences for coalescing.

    TODO: if this becomes too large of a compile-time cost, consider binding
    graphcycles.cc
    """

    def __init__(self, graph: fx.Graph):
        # Each node starts in its own singleton set
        self.graph = graph
        self.merge_sets = {node: OrderedSet([node]) for node in graph.nodes}

        # Extra dependencies: node depends on dep (dep must come before node)
        self.extra_deps: dict[fx.Node, OrderedSet[fx.Node]] = defaultdict(OrderedSet)

    def add_extra_dep(self, *, n: fx.Node, dep: fx.Node) -> None:
        """Add extra dependency: node depends on dep."""
        self.extra_deps[n].add(dep)

    def merge_to_set(self, existing_node: fx.Node, new_node: fx.Node) -> None:
        """
        Merge new_node into existing_node's set. The new node must be a singleton set.
        """
        existing_set = self.merge_sets[existing_node]
        new_set = self.merge_sets[new_node]
        assert len(new_set) == 1

        # Add all nodes from new_set to existing_set
        existing_set.update(new_set)

        # Update all nodes from new_set to point to existing_set
        for node in new_set:
            self.merge_sets[node] = existing_set

    def unmerge_node(self, node: fx.Node) -> None:
        """Remove a node from its merge set, making it a singleton."""
        old_set = self.merge_sets[node]

        # If already singleton, nothing to do
        if len(old_set) == 1:
            return

        # Remove from old set
        old_set.remove(node)

        # Make node singleton
        self.merge_sets[node] = OrderedSet([node])

    def get_merged_deps(self, node: fx.Node) -> OrderedSet[fx.Node]:
        """
        Get all dependencies of a node considering merges and extra deps.
        Combines:
        1. Direct deps (all_input_nodes) of node and its merge equivalents
        2. Extra deps of node and its merge equivalents
        """
        deps: OrderedSet[fx.Node] = OrderedSet()

        # For each node in the merge set
        for merged_node in self.merge_sets[node]:
            # Add direct dependencies from all_input_nodes
            deps.update(merged_node.all_input_nodes)
            # Add extra dependencies
            deps.update(self.extra_deps[merged_node])

        return deps

    def has_cycle(self) -> bool:
        merged_deps = {n: self.get_merged_deps(n) for n in self.graph.nodes}
        return torch._dynamo.graph_deduplication._has_cycle(self.graph, merged_deps)

    def has_path(self, source: fx.Node, target: fx.Node) -> bool:
        """Check if there's a path from source to target."""
        # We should not be checking for a path from a node to itself
        assert self.merge_sets[source] is not self.merge_sets[target]

        # Search backwards from target to source
        visited: OrderedSet[fx.Node] = OrderedSet()
        queue = [target]
        visited.add(target)

        while queue:
            current = queue.pop()

            # Get all dependencies
            for dep in self.get_merged_deps(current):
                # Check if we reached source or its equivalent
                if dep in self.merge_sets[source]:
                    return True

                if dep not in visited:
                    visited.add(dep)
                    queue.append(dep)

        return False
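For orientation, the deleted tests above exercise this helper roughly as follows. This condensed sketch mirrors those tests against the pre-revert tree (the helper and graph_deduplication._has_cycle are removed by this commit); the variable names are illustrative, the API calls are the ones defined above.

# Condensed usage sketch of the deleted helper, mirroring the tests above.
import torch
import torch.fx as fx

from torch._inductor.augmented_graph_helper import AugmentedGraphHelper  # removed by this revert

graph = fx.Graph()
x = graph.placeholder("x")
a = graph.call_function(torch.neg, args=(x,))
b = graph.call_function(torch.abs, args=(a,))
c = graph.call_function(torch.relu, args=(x,))
graph.output(c)

helper = AugmentedGraphHelper(graph)

# Extra ordering constraint: c must come after b, even though there is no
# data-flow edge between them.
helper.add_extra_dep(n=c, dep=b)
assert b in helper.get_merged_deps(c)
assert helper.has_path(a, c)       # a -> b (data flow) -> c (extra dep)
assert not helper.has_cycle()

# Treat a and c as one scheduling unit; their deps are now combined, so the
# group needs b while b needs a, which the helper reports as a cycle.
helper.merge_to_set(a, c)
assert helper.has_cycle()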