make node source hashable (#158322)

Summary: as title

Test Plan:
ci

Rollback Plan:

Reviewed By: yushangdi

Differential Revision: D78296410

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158322
Approved by: https://github.com/yushangdi
This commit is contained in:
Songhao Jia
2025-07-15 19:31:00 +00:00
committed by PyTorch MergeBot
parent 4657a84bc5
commit 011026205a
2 changed files with 51 additions and 0 deletions

View File

@ -74,6 +74,9 @@ class TestFXNodeSource(TestCase):
)
self.assertEqual(node_source1, node_source2)
# Test hash function - equivalent objects should have same hash
self.assertEqual(hash(node_source1), hash(node_source2))
# Test two node sources are not same
node_source3 = NodeSource(
node=None, pass_name="test_pass_1", action=NodeSourceAction.CREATE
@ -83,6 +86,41 @@ class TestFXNodeSource(TestCase):
)
self.assertNotEqual(node_source3, node_source4)
# Test hash function - different objects should have different hash
self.assertNotEqual(hash(node_source3), hash(node_source4))
# Test that equivalent NodeSource objects can be used in sets and dicts
node_set = {node_source1, node_source2}
self.assertEqual(len(node_set), 1) # Should only contain one unique element
node_dict = {node_source1: "value1", node_source2: "value2"}
self.assertEqual(len(node_dict), 1) # Should only contain one key
self.assertEqual(node_dict[node_source1], "value2") # Last value should win
# Test with more complex NodeSource objects
node_source_with_node = NodeSource(
node=node, pass_name="test_pass", action=NodeSourceAction.CREATE
)
node_source_with_node_copy = NodeSource(
node=node, pass_name="test_pass", action=NodeSourceAction.CREATE
)
# These should be equal and have same hash
self.assertEqual(node_source_with_node, node_source_with_node_copy)
self.assertEqual(hash(node_source_with_node), hash(node_source_with_node_copy))
# Test with different actions
node_source_replace = NodeSource(
node=None, pass_name="test_pass", action=NodeSourceAction.REPLACE
)
node_source_create = NodeSource(
node=None, pass_name="test_pass", action=NodeSourceAction.CREATE
)
# These should be different and have different hashes
self.assertNotEqual(node_source_replace, node_source_create)
self.assertNotEqual(hash(node_source_replace), hash(node_source_create))
def test_graph_provenance(self):
def check_node_source(node_source_dict, name, pass_name, action):
self.assertEqual(node_source_dict["name"], name)

View File

@ -128,6 +128,19 @@ class NodeSource:
return False
return self.to_dict() == other.to_dict()
def __hash__(self):
# Create a hash based on the dictionary representation
# We need to convert the dict to a hashable form
def _make_hashable(obj):
if isinstance(obj, dict):
return tuple(sorted((k, _make_hashable(v)) for k, v in obj.items()))
elif isinstance(obj, list):
return tuple(_make_hashable(item) for item in obj)
else:
return obj
return hash(_make_hashable(self.to_dict()))
@compatibility(is_backward_compatible=False)
@contextmanager