mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Fixes #157891 Pull Request resolved: https://github.com/pytorch/pytorch/pull/157892 Approved by: https://github.com/ezyang
66 lines
2.4 KiB
Python
66 lines
2.4 KiB
Python
# Owner(s): ["module: fx"]
|
|
|
|
from collections.abc import Mapping
|
|
|
|
import torch
|
|
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
|
|
from torch.fx.passes.operator_support import OperatorSupport
|
|
from torch.testing._internal.common_utils import TestCase
|
|
|
|
|
|
class DummyDevOperatorSupport(OperatorSupport):
    """Operator-support policy for a fake device that accepts every node."""

    def is_node_supported(
        self, submodules: Mapping[str, torch.nn.Module], node: torch.fx.Node
    ) -> bool:
        # Neither the submodule table nor the node is inspected: the dummy
        # device claims support for everything it is asked about.
        del submodules, node
        return True
|
|
|
|
|
|
class DummyPartitioner(CapabilityBasedPartitioner):
    """Capability-based partitioner driven by ``DummyDevOperatorSupport``.

    Every node is reported as supported, and the partitioner is configured
    with ``allows_single_node_partition=True``.
    """

    def __init__(self, graph_module: torch.fx.GraphModule):
        support = DummyDevOperatorSupport()
        super().__init__(
            graph_module,
            support,
            allows_single_node_partition=True,
        )
|
|
|
|
|
|
class AddModule(torch.nn.Module):
    """Tiny module computing ``x + x + x`` through two chained ``torch.add`` calls.

    Symbolically tracing it yields the graph node order
    ``['x', 'add', 'add_1', 'output']``.
    """

    def forward(self, x):
        doubled = torch.add(x, x)
        return torch.add(doubled, x)
|
|
|
|
|
|
class TestPartitionerOrder(TestCase):
    """Check that ``CapabilityBasedPartitioner`` keeps graph node order stable.

    Both tests partition a traced ``AddModule`` with ``DummyPartitioner``
    (whose operator support accepts every node) and inspect the node order
    of the resulting partition.
    """

    def _single_partition_node_names(self, traced_m):
        """Partition ``traced_m`` and return the node names of its only partition.

        Asserts that exactly one partition was proposed, so a change that
        splits the graph fails loudly instead of silently comparing only the
        first partition.
        """
        partitions = list(DummyPartitioner(traced_m).propose_partitions())
        self.assertEqual(len(partitions), 1)
        return [n.name for n in partitions[0].nodes]

    def test_partitioner_graph_node_order(self):
        """The partition's node order matches the original graph's node order."""
        m = AddModule()
        traced_m = torch.fx.symbolic_trace(m)
        origin_node_order = [n.name for n in traced_m.graph.nodes]
        partition_node_order = self._single_partition_node_names(traced_m)
        # assertEqual (rather than assertTrue(a == b)) shows both lists on failure.
        self.assertEqual(partition_node_order, origin_node_order)

    def test_partitioner_multiple_runs_order(self):
        """The partition's node order is identical across repeated runs."""
        m = AddModule()
        node_order = self._single_partition_node_names(torch.fx.symbolic_trace(m))
        for _ in range(10):
            # Re-trace on every run so each partitioning sees a fresh graph.
            traced_m = torch.fx.symbolic_trace(m)
            new_node_order = self._single_partition_node_names(traced_m)
            self.assertEqual(node_order, new_node_order)
|
|
|
|
|
|
if __name__ == "__main__":
    # This file is not wired into the test runner; refuse to run it directly
    # and point the reader at discover_tests.py instead.
    raise RuntimeError(
        "This test is not currently used and should be enabled in discover_tests.py if required."
    )
|