mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
add eq function to NodeSource (#158170)
Summary: add eq function to NodeSouce by comparing their dict representation. Test Plan: ci Rollback Plan: Differential Revision: D78200762 Pull Request resolved: https://github.com/pytorch/pytorch/pull/158170 Approved by: https://github.com/ezyang, https://github.com/yushangdi
This commit is contained in:
committed by
PyTorch MergeBot
parent
7e433d5f42
commit
1c6057fd17
@ -2,6 +2,7 @@
|
||||
|
||||
import torch
|
||||
from torch._inductor.compile_fx import aot_export_module
|
||||
from torch.export import default_decompositions
|
||||
from torch.fx.traceback import get_graph_provenance_json, NodeSource, NodeSourceAction
|
||||
from torch.testing._internal.common_utils import TestCase
|
||||
|
||||
@ -64,6 +65,24 @@ class TestFXNodeSource(TestCase):
|
||||
},
|
||||
)
|
||||
|
||||
# Test two node sources are same
|
||||
node_source1 = NodeSource(
|
||||
node=None, pass_name="test_pass", action=NodeSourceAction.CREATE
|
||||
)
|
||||
node_source2 = NodeSource(
|
||||
node=None, pass_name="test_pass", action=NodeSourceAction.CREATE
|
||||
)
|
||||
self.assertEqual(node_source1, node_source2)
|
||||
|
||||
# Test two node sources are not same
|
||||
node_source3 = NodeSource(
|
||||
node=None, pass_name="test_pass_1", action=NodeSourceAction.CREATE
|
||||
)
|
||||
node_source4 = NodeSource(
|
||||
node=None, pass_name="test_pass_2", action=NodeSourceAction.CREATE
|
||||
)
|
||||
self.assertNotEqual(node_source3, node_source4)
|
||||
|
||||
def test_graph_provenance(self):
|
||||
def check_node_source(node_source_dict, name, pass_name, action):
|
||||
self.assertEqual(node_source_dict["name"], name)
|
||||
@ -95,6 +114,43 @@ class TestFXNodeSource(TestCase):
|
||||
model = Model()
|
||||
example_inputs = (torch.randn(8, 10),)
|
||||
ep = torch.export.export(model, example_inputs, strict=True)
|
||||
|
||||
decomposed_ep = ep.run_decompositions(default_decompositions())
|
||||
# node decomposed from same ancestor node should have same from_node info
|
||||
for node in decomposed_ep.graph.nodes:
|
||||
if node.op not in {"placeholder", "output"}:
|
||||
assert "from_node" in node.meta
|
||||
|
||||
node_name_to_from_node = {
|
||||
node.name: node.meta["from_node"]
|
||||
for node in decomposed_ep.graph.nodes
|
||||
if node.op not in {"placeholder", "output"}
|
||||
}
|
||||
same_ancestor_nodes = {
|
||||
"permute": "addmm",
|
||||
"addmm": "permute",
|
||||
"permute_1": "addmm_1",
|
||||
"addmm_1": "permute_1",
|
||||
}
|
||||
|
||||
for node_name_1 in node_name_to_from_node:
|
||||
for node_name_2 in node_name_to_from_node:
|
||||
if node_name_2 in {
|
||||
node_name_1,
|
||||
same_ancestor_nodes[node_name_1]
|
||||
if node_name_1 in same_ancestor_nodes
|
||||
else None,
|
||||
}:
|
||||
self.assertTrue(
|
||||
node_name_to_from_node[node_name_1]
|
||||
== node_name_to_from_node[node_name_2]
|
||||
)
|
||||
else:
|
||||
self.assertTrue(
|
||||
node_name_to_from_node[node_name_1]
|
||||
!= node_name_to_from_node[node_name_2]
|
||||
)
|
||||
|
||||
gm = ep.module()
|
||||
provenance = get_graph_provenance_json(gm.graph)
|
||||
self.assertEqual(
|
||||
|
@ -123,6 +123,11 @@ class NodeSource:
|
||||
"from_node": [node.to_dict() for node in self.from_node],
|
||||
}
|
||||
|
||||
def __eq__(self, other: object):
|
||||
if not isinstance(other, NodeSource):
|
||||
return False
|
||||
return self.to_dict() == other.to_dict()
|
||||
|
||||
|
||||
@compatibility(is_backward_compatible=False)
|
||||
@contextmanager
|
||||
|
Reference in New Issue
Block a user