mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
make node source hashable (#158322)
Summary: as title Test Plan: ci Rollback Plan: Reviewed By: yushangdi Differential Revision: D78296410 Pull Request resolved: https://github.com/pytorch/pytorch/pull/158322 Approved by: https://github.com/yushangdi
This commit is contained in:
committed by
PyTorch MergeBot
parent
4657a84bc5
commit
011026205a
@ -74,6 +74,9 @@ class TestFXNodeSource(TestCase):
|
|||||||
)
|
)
|
||||||
self.assertEqual(node_source1, node_source2)
|
self.assertEqual(node_source1, node_source2)
|
||||||
|
|
||||||
|
# Test hash function - equivalent objects should have same hash
|
||||||
|
self.assertEqual(hash(node_source1), hash(node_source2))
|
||||||
|
|
||||||
# Test two node sources are not same
|
# Test two node sources are not same
|
||||||
node_source3 = NodeSource(
|
node_source3 = NodeSource(
|
||||||
node=None, pass_name="test_pass_1", action=NodeSourceAction.CREATE
|
node=None, pass_name="test_pass_1", action=NodeSourceAction.CREATE
|
||||||
@ -83,6 +86,41 @@ class TestFXNodeSource(TestCase):
|
|||||||
)
|
)
|
||||||
self.assertNotEqual(node_source3, node_source4)
|
self.assertNotEqual(node_source3, node_source4)
|
||||||
|
|
||||||
|
# Test hash function - different objects should have different hash
|
||||||
|
self.assertNotEqual(hash(node_source3), hash(node_source4))
|
||||||
|
|
||||||
|
# Test that equivalent NodeSource objects can be used in sets and dicts
|
||||||
|
node_set = {node_source1, node_source2}
|
||||||
|
self.assertEqual(len(node_set), 1) # Should only contain one unique element
|
||||||
|
|
||||||
|
node_dict = {node_source1: "value1", node_source2: "value2"}
|
||||||
|
self.assertEqual(len(node_dict), 1) # Should only contain one key
|
||||||
|
self.assertEqual(node_dict[node_source1], "value2") # Last value should win
|
||||||
|
|
||||||
|
# Test with more complex NodeSource objects
|
||||||
|
node_source_with_node = NodeSource(
|
||||||
|
node=node, pass_name="test_pass", action=NodeSourceAction.CREATE
|
||||||
|
)
|
||||||
|
node_source_with_node_copy = NodeSource(
|
||||||
|
node=node, pass_name="test_pass", action=NodeSourceAction.CREATE
|
||||||
|
)
|
||||||
|
|
||||||
|
# These should be equal and have same hash
|
||||||
|
self.assertEqual(node_source_with_node, node_source_with_node_copy)
|
||||||
|
self.assertEqual(hash(node_source_with_node), hash(node_source_with_node_copy))
|
||||||
|
|
||||||
|
# Test with different actions
|
||||||
|
node_source_replace = NodeSource(
|
||||||
|
node=None, pass_name="test_pass", action=NodeSourceAction.REPLACE
|
||||||
|
)
|
||||||
|
node_source_create = NodeSource(
|
||||||
|
node=None, pass_name="test_pass", action=NodeSourceAction.CREATE
|
||||||
|
)
|
||||||
|
|
||||||
|
# These should be different and have different hashes
|
||||||
|
self.assertNotEqual(node_source_replace, node_source_create)
|
||||||
|
self.assertNotEqual(hash(node_source_replace), hash(node_source_create))
|
||||||
|
|
||||||
def test_graph_provenance(self):
|
def test_graph_provenance(self):
|
||||||
def check_node_source(node_source_dict, name, pass_name, action):
|
def check_node_source(node_source_dict, name, pass_name, action):
|
||||||
self.assertEqual(node_source_dict["name"], name)
|
self.assertEqual(node_source_dict["name"], name)
|
||||||
|
@ -128,6 +128,19 @@ class NodeSource:
|
|||||||
return False
|
return False
|
||||||
return self.to_dict() == other.to_dict()
|
return self.to_dict() == other.to_dict()
|
||||||
|
|
||||||
|
def __hash__(self):
|
||||||
|
# Create a hash based on the dictionary representation
|
||||||
|
# We need to convert the dict to a hashable form
|
||||||
|
def _make_hashable(obj):
|
||||||
|
if isinstance(obj, dict):
|
||||||
|
return tuple(sorted((k, _make_hashable(v)) for k, v in obj.items()))
|
||||||
|
elif isinstance(obj, list):
|
||||||
|
return tuple(_make_hashable(item) for item in obj)
|
||||||
|
else:
|
||||||
|
return obj
|
||||||
|
|
||||||
|
return hash(_make_hashable(self.to_dict()))
|
||||||
|
|
||||||
|
|
||||||
@compatibility(is_backward_compatible=False)
|
@compatibility(is_backward_compatible=False)
|
||||||
@contextmanager
|
@contextmanager
|
||||||
|
Reference in New Issue
Block a user