add eq function to NodeSource (#158170)

Summary: add eq function to NodeSouce by comparing their dict representation.

Test Plan:
ci

Rollback Plan:

Differential Revision: D78200762

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158170
Approved by: https://github.com/ezyang, https://github.com/yushangdi
This commit is contained in:
Songhao Jia
2025-07-15 00:50:03 +00:00
committed by PyTorch MergeBot
parent 7e433d5f42
commit 1c6057fd17
2 changed files with 61 additions and 0 deletions

View File

@ -2,6 +2,7 @@
import torch
from torch._inductor.compile_fx import aot_export_module
from torch.export import default_decompositions
from torch.fx.traceback import get_graph_provenance_json, NodeSource, NodeSourceAction
from torch.testing._internal.common_utils import TestCase
@ -64,6 +65,24 @@ class TestFXNodeSource(TestCase):
},
)
# Test two node sources are same
node_source1 = NodeSource(
node=None, pass_name="test_pass", action=NodeSourceAction.CREATE
)
node_source2 = NodeSource(
node=None, pass_name="test_pass", action=NodeSourceAction.CREATE
)
self.assertEqual(node_source1, node_source2)
# Test two node sources are not same
node_source3 = NodeSource(
node=None, pass_name="test_pass_1", action=NodeSourceAction.CREATE
)
node_source4 = NodeSource(
node=None, pass_name="test_pass_2", action=NodeSourceAction.CREATE
)
self.assertNotEqual(node_source3, node_source4)
def test_graph_provenance(self):
def check_node_source(node_source_dict, name, pass_name, action):
self.assertEqual(node_source_dict["name"], name)
@ -95,6 +114,43 @@ class TestFXNodeSource(TestCase):
model = Model()
example_inputs = (torch.randn(8, 10),)
ep = torch.export.export(model, example_inputs, strict=True)
decomposed_ep = ep.run_decompositions(default_decompositions())
# node decomposed from same ancestor node should have same from_node info
for node in decomposed_ep.graph.nodes:
if node.op not in {"placeholder", "output"}:
assert "from_node" in node.meta
node_name_to_from_node = {
node.name: node.meta["from_node"]
for node in decomposed_ep.graph.nodes
if node.op not in {"placeholder", "output"}
}
same_ancestor_nodes = {
"permute": "addmm",
"addmm": "permute",
"permute_1": "addmm_1",
"addmm_1": "permute_1",
}
for node_name_1 in node_name_to_from_node:
for node_name_2 in node_name_to_from_node:
if node_name_2 in {
node_name_1,
same_ancestor_nodes[node_name_1]
if node_name_1 in same_ancestor_nodes
else None,
}:
self.assertTrue(
node_name_to_from_node[node_name_1]
== node_name_to_from_node[node_name_2]
)
else:
self.assertTrue(
node_name_to_from_node[node_name_1]
!= node_name_to_from_node[node_name_2]
)
gm = ep.module()
provenance = get_graph_provenance_json(gm.graph)
self.assertEqual(

View File

@ -123,6 +123,11 @@ class NodeSource:
"from_node": [node.to_dict() for node in self.from_node],
}
def __eq__(self, other: object):
if not isinstance(other, NodeSource):
return False
return self.to_dict() == other.to_dict()
@compatibility(is_backward_compatible=False)
@contextmanager