From 011026205a9d4c38458130f8ca242028f6184bf0 Mon Sep 17 00:00:00 2001 From: Songhao Jia Date: Tue, 15 Jul 2025 19:31:00 +0000 Subject: [PATCH] make node source hashable (#158322) Summary: as title Test Plan: ci Rollback Plan: Reviewed By: yushangdi Differential Revision: D78296410 Pull Request resolved: https://github.com/pytorch/pytorch/pull/158322 Approved by: https://github.com/yushangdi --- test/fx/test_fx_traceback.py | 38 ++++++++++++++++++++++++++++++++++++ torch/fx/traceback.py | 13 ++++++++++++ 2 files changed, 51 insertions(+) diff --git a/test/fx/test_fx_traceback.py b/test/fx/test_fx_traceback.py index e11ee19daaac..f02bc5a2e159 100644 --- a/test/fx/test_fx_traceback.py +++ b/test/fx/test_fx_traceback.py @@ -74,6 +74,9 @@ class TestFXNodeSource(TestCase): ) self.assertEqual(node_source1, node_source2) + # Test hash function - equivalent objects should have same hash + self.assertEqual(hash(node_source1), hash(node_source2)) + # Test two node sources are not same node_source3 = NodeSource( node=None, pass_name="test_pass_1", action=NodeSourceAction.CREATE @@ -83,6 +86,41 @@ class TestFXNodeSource(TestCase): ) self.assertNotEqual(node_source3, node_source4) + # Test hash function - different objects should have different hash + self.assertNotEqual(hash(node_source3), hash(node_source4)) + + # Test that equivalent NodeSource objects can be used in sets and dicts + node_set = {node_source1, node_source2} + self.assertEqual(len(node_set), 1) # Should only contain one unique element + + node_dict = {node_source1: "value1", node_source2: "value2"} + self.assertEqual(len(node_dict), 1) # Should only contain one key + self.assertEqual(node_dict[node_source1], "value2") # Last value should win + + # Test with more complex NodeSource objects + node_source_with_node = NodeSource( + node=node, pass_name="test_pass", action=NodeSourceAction.CREATE + ) + node_source_with_node_copy = NodeSource( + node=node, pass_name="test_pass", action=NodeSourceAction.CREATE + ) + 
def __hash__(self):
    """Return a hash computed from the ``to_dict()`` representation.

    The dictionary returned by ``to_dict()`` is not hashable as-is, so it
    is first converted into an equivalent structure of nested tuples and
    the hash of that structure is returned.
    """

    def _freeze(value):
        # A list becomes a tuple whose elements are frozen one by one,
        # keeping the original element order.
        if isinstance(value, list):
            return tuple(map(_freeze, value))
        # Anything that is neither a list nor a dict is passed through
        # unchanged.
        if not isinstance(value, dict):
            return value
        # A dict becomes a tuple of (key, frozen value) pairs; sorting the
        # pairs makes the result independent of insertion order.
        pairs = [(key, _freeze(item)) for key, item in value.items()]
        pairs.sort()
        return tuple(pairs)

    return hash(_freeze(self.to_dict()))