mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
make node source hashable (#158322)
Summary: as title Test Plan: ci Rollback Plan: Reviewed By: yushangdi Differential Revision: D78296410 Pull Request resolved: https://github.com/pytorch/pytorch/pull/158322 Approved by: https://github.com/yushangdi
This commit is contained in:
committed by
PyTorch MergeBot
parent
4657a84bc5
commit
011026205a
@ -74,6 +74,9 @@ class TestFXNodeSource(TestCase):
|
||||
)
|
||||
self.assertEqual(node_source1, node_source2)
|
||||
|
||||
# Test hash function - equivalent objects should have same hash
|
||||
self.assertEqual(hash(node_source1), hash(node_source2))
|
||||
|
||||
# Test two node sources are not same
|
||||
node_source3 = NodeSource(
|
||||
node=None, pass_name="test_pass_1", action=NodeSourceAction.CREATE
|
||||
@ -83,6 +86,41 @@ class TestFXNodeSource(TestCase):
|
||||
)
|
||||
self.assertNotEqual(node_source3, node_source4)
|
||||
|
||||
# Test hash function - different objects should have different hash
|
||||
self.assertNotEqual(hash(node_source3), hash(node_source4))
|
||||
|
||||
# Test that equivalent NodeSource objects can be used in sets and dicts
|
||||
node_set = {node_source1, node_source2}
|
||||
self.assertEqual(len(node_set), 1) # Should only contain one unique element
|
||||
|
||||
node_dict = {node_source1: "value1", node_source2: "value2"}
|
||||
self.assertEqual(len(node_dict), 1) # Should only contain one key
|
||||
self.assertEqual(node_dict[node_source1], "value2") # Last value should win
|
||||
|
||||
# Test with more complex NodeSource objects
|
||||
node_source_with_node = NodeSource(
|
||||
node=node, pass_name="test_pass", action=NodeSourceAction.CREATE
|
||||
)
|
||||
node_source_with_node_copy = NodeSource(
|
||||
node=node, pass_name="test_pass", action=NodeSourceAction.CREATE
|
||||
)
|
||||
|
||||
# These should be equal and have same hash
|
||||
self.assertEqual(node_source_with_node, node_source_with_node_copy)
|
||||
self.assertEqual(hash(node_source_with_node), hash(node_source_with_node_copy))
|
||||
|
||||
# Test with different actions
|
||||
node_source_replace = NodeSource(
|
||||
node=None, pass_name="test_pass", action=NodeSourceAction.REPLACE
|
||||
)
|
||||
node_source_create = NodeSource(
|
||||
node=None, pass_name="test_pass", action=NodeSourceAction.CREATE
|
||||
)
|
||||
|
||||
# These should be different and have different hashes
|
||||
self.assertNotEqual(node_source_replace, node_source_create)
|
||||
self.assertNotEqual(hash(node_source_replace), hash(node_source_create))
|
||||
|
||||
def test_graph_provenance(self):
|
||||
def check_node_source(node_source_dict, name, pass_name, action):
|
||||
self.assertEqual(node_source_dict["name"], name)
|
||||
|
@ -128,6 +128,19 @@ class NodeSource:
|
||||
return False
|
||||
return self.to_dict() == other.to_dict()
|
||||
|
||||
def __hash__(self):
|
||||
# Create a hash based on the dictionary representation
|
||||
# We need to convert the dict to a hashable form
|
||||
def _make_hashable(obj):
|
||||
if isinstance(obj, dict):
|
||||
return tuple(sorted((k, _make_hashable(v)) for k, v in obj.items()))
|
||||
elif isinstance(obj, list):
|
||||
return tuple(_make_hashable(item) for item in obj)
|
||||
else:
|
||||
return obj
|
||||
|
||||
return hash(_make_hashable(self.to_dict()))
|
||||
|
||||
|
||||
@compatibility(is_backward_compatible=False)
|
||||
@contextmanager
|
||||
|
Reference in New Issue
Block a user