mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Differential Revision: [D77459912](https://our.internmc.facebook.com/intern/diff/D77459912) Pull Request resolved: https://github.com/pytorch/pytorch/pull/156761 Approved by: https://github.com/angelayi
285 lines
11 KiB
Python
285 lines
11 KiB
Python
# Owner(s): ["oncall: export"]
|
|
|
|
import json
|
|
|
|
import torch
|
|
from torch.testing._internal.common_utils import TestCase
|
|
|
|
|
|
class TestUpgrader(TestCase):
|
|
def setUp(self) -> None:
|
|
# Register example upgraders dynamically
|
|
torch._C._export.register_example_upgraders()
|
|
|
|
def tearDown(self) -> None:
|
|
# Clean up registered upgraders
|
|
torch._C._export.deregister_example_upgraders()
|
|
|
|
def test_nn_module_stack_transformation_from_v0(self):
|
|
"""Test that nn_module_stack strings are prepended with 'test_upgrader_' when upgrading from version 0"""
|
|
|
|
# Create a mock JSON object that simulates version 0 schema
|
|
# with nn_module_stack as a string that needs to be upgraded
|
|
mock_json = {
|
|
"schema_version": {"major": 0, "minor": 0},
|
|
"graph_module": {
|
|
"graph": {
|
|
"nodes": [
|
|
{
|
|
"target": "aten.add.Tensor",
|
|
"inputs": [],
|
|
"outputs": [],
|
|
"metadata": {
|
|
"nn_module_stack": "original_stack_info",
|
|
"other_field": "some_value",
|
|
},
|
|
},
|
|
{
|
|
"target": "aten.mul.Tensor",
|
|
"inputs": [],
|
|
"outputs": [],
|
|
"metadata": {
|
|
"nn_module_stack": "another_stack",
|
|
"stack_trace": "some trace",
|
|
},
|
|
},
|
|
]
|
|
}
|
|
},
|
|
}
|
|
|
|
# Test the upgrader using the Python binding
|
|
serialized_json = json.dumps(mock_json)
|
|
upgraded_json_str = torch._C._export.upgrade(serialized_json, 2)
|
|
upgraded_json = json.loads(upgraded_json_str)
|
|
|
|
# Verify the schema version was updated (version 0 -> version 2 due to both v0 and v1 upgraders)
|
|
self.assertEqual(upgraded_json["schema_version"]["major"], 2)
|
|
self.assertEqual(upgraded_json["schema_version"]["minor"], 0)
|
|
|
|
# Verify nn_module_stack was prepended with "test_upgrader_"
|
|
nodes = upgraded_json["graph_module"]["graph"]["nodes"]
|
|
|
|
# Check first node
|
|
first_node_metadata = nodes[0]["metadata"]
|
|
nn_stack = first_node_metadata["nn_module_stack"]
|
|
self.assertIsInstance(nn_stack, str)
|
|
self.assertEqual(nn_stack, "test_upgrader_original_stack_info")
|
|
# Other metadata should be unchanged
|
|
self.assertEqual(first_node_metadata["other_field"], "some_value")
|
|
|
|
# Check second node
|
|
second_node_metadata = nodes[1]["metadata"]
|
|
nn_stack2 = second_node_metadata["nn_module_stack"]
|
|
self.assertIsInstance(nn_stack2, str)
|
|
self.assertEqual(nn_stack2, "test_upgrader_another_stack")
|
|
# Other metadata should be unchanged
|
|
self.assertEqual(second_node_metadata["stack_trace"], "some trace")
|
|
|
|
def test_nn_module_stack_error_handling_invalid_type(self):
|
|
"""Test error handling when nn_module_stack is not a string"""
|
|
|
|
# Test case: nn_module_stack is not a string
|
|
mock_json_invalid_type = {
|
|
"schema_version": {"major": 0, "minor": 0},
|
|
"graph_module": {
|
|
"graph": {
|
|
"nodes": [
|
|
{
|
|
"target": "aten.add.Tensor",
|
|
"inputs": [],
|
|
"outputs": [],
|
|
"metadata": {
|
|
"nn_module_stack": 42 # Invalid: should be string
|
|
},
|
|
}
|
|
]
|
|
}
|
|
},
|
|
}
|
|
|
|
with self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"Error in upgrader 'version_0_upgrader_registered'",
|
|
):
|
|
serialized_json = json.dumps(mock_json_invalid_type)
|
|
torch._C._export.upgrade(serialized_json, 2)
|
|
|
|
def test_nodes_without_metadata_handled_gracefully(self):
|
|
"""Test that nodes without metadata or nn_module_stack are handled gracefully"""
|
|
|
|
mock_json = {
|
|
"schema_version": {"major": 0, "minor": 0},
|
|
"graph_module": {
|
|
"graph": {
|
|
"nodes": [
|
|
{
|
|
"target": "aten.add.Tensor",
|
|
"inputs": [],
|
|
"outputs": [],
|
|
# No metadata field
|
|
},
|
|
{
|
|
"target": "aten.mul.Tensor",
|
|
"inputs": [],
|
|
"outputs": [],
|
|
"metadata": {
|
|
"stack_trace": "some trace"
|
|
# No nn_module_stack field
|
|
},
|
|
},
|
|
]
|
|
}
|
|
},
|
|
}
|
|
|
|
# Should not raise an error
|
|
serialized_json = json.dumps(mock_json)
|
|
upgraded_json_str = torch._C._export.upgrade(serialized_json, 2)
|
|
upgraded_json = json.loads(upgraded_json_str)
|
|
|
|
# Verify the schema version was updated (version 0 -> version 2 due to both v0 and v1 upgraders)
|
|
self.assertEqual(upgraded_json["schema_version"]["major"], 2)
|
|
self.assertEqual(upgraded_json["schema_version"]["minor"], 0)
|
|
|
|
# Verify nodes are unchanged
|
|
nodes = upgraded_json["graph_module"]["graph"]["nodes"]
|
|
self.assertEqual(len(nodes), 2)
|
|
|
|
# First node should have no metadata
|
|
self.assertNotIn("metadata", nodes[0])
|
|
|
|
# Second node should have unchanged metadata
|
|
self.assertEqual(nodes[1]["metadata"]["stack_trace"], "some trace")
|
|
self.assertNotIn("nn_module_stack", nodes[1]["metadata"])
|
|
|
|
def test_field_renaming_chain_from_v0_complete(self):
|
|
"""Test complete field renaming chain from v0: old_test_field -> new_test_field -> new_test_field2"""
|
|
|
|
mock_json = {
|
|
"schema_version": {"major": 0, "minor": 0},
|
|
"graph_module": {
|
|
"graph": {
|
|
"inputs": [],
|
|
"outputs": [],
|
|
"nodes": [
|
|
{
|
|
"target": "aten.add.Tensor",
|
|
"inputs": [],
|
|
"outputs": [],
|
|
"metadata": {"nn_module_stack": "test_stack"},
|
|
}
|
|
],
|
|
"old_test_field": "original_value",
|
|
"existing_field": "existing_value",
|
|
}
|
|
},
|
|
}
|
|
|
|
# Test the upgrader using the Python binding
|
|
serialized_json = json.dumps(mock_json)
|
|
upgraded_json_str = torch._C._export.upgrade(serialized_json, 2)
|
|
upgraded_json = json.loads(upgraded_json_str)
|
|
|
|
# Verify the schema version was updated (version 0 -> version 2 due to both v0 and v1 upgraders)
|
|
self.assertEqual(upgraded_json["schema_version"]["major"], 2)
|
|
self.assertEqual(upgraded_json["schema_version"]["minor"], 0)
|
|
|
|
# Verify complete field transformation: old_test_field -> new_test_field -> new_test_field2
|
|
graph = upgraded_json["graph_module"]["graph"]
|
|
self.assertIn("new_test_field2", graph)
|
|
self.assertEqual(graph["new_test_field2"], "original_value")
|
|
self.assertNotIn("old_test_field", graph)
|
|
self.assertNotIn("new_test_field", graph)
|
|
|
|
# Verify existing fields are preserved
|
|
self.assertEqual(graph["existing_field"], "existing_value")
|
|
self.assertIn("inputs", graph)
|
|
self.assertIn("outputs", graph)
|
|
self.assertIn("nodes", graph)
|
|
|
|
# Verify the nn_module_stack was also upgraded by the other upgrader
|
|
nodes = graph["nodes"]
|
|
self.assertEqual(
|
|
nodes[0]["metadata"]["nn_module_stack"], "test_upgrader_test_stack"
|
|
)
|
|
|
|
def test_field_renaming_chain_from_v0_missing_field(self):
|
|
"""Test that upgraders work gracefully when old_test_field doesn't exist"""
|
|
|
|
mock_json = {
|
|
"schema_version": {"major": 0, "minor": 0},
|
|
"graph_module": {
|
|
"graph": {
|
|
"inputs": [],
|
|
"outputs": [],
|
|
"nodes": [],
|
|
"existing_field": "existing_value",
|
|
}
|
|
},
|
|
}
|
|
|
|
# Test the upgrader using the Python binding
|
|
serialized_json = json.dumps(mock_json)
|
|
upgraded_json_str = torch._C._export.upgrade(serialized_json, 2)
|
|
upgraded_json = json.loads(upgraded_json_str)
|
|
|
|
# Verify the schema version was updated (version 0 -> version 2 due to both v0 and v1 upgraders)
|
|
self.assertEqual(upgraded_json["schema_version"]["major"], 2)
|
|
self.assertEqual(upgraded_json["schema_version"]["minor"], 0)
|
|
|
|
# Verify no field transformations occurred since old_test_field didn't exist
|
|
graph = upgraded_json["graph_module"]["graph"]
|
|
self.assertNotIn("new_test_field2", graph)
|
|
self.assertNotIn("new_test_field", graph)
|
|
self.assertNotIn("old_test_field", graph)
|
|
|
|
# Verify existing fields are preserved
|
|
self.assertEqual(graph["existing_field"], "existing_value")
|
|
self.assertIn("inputs", graph)
|
|
self.assertIn("outputs", graph)
|
|
self.assertIn("nodes", graph)
|
|
|
|
def test_field_renaming_from_v1_partial_chain(self):
|
|
"""Test partial upgrade chain starting from v1: new_test_field -> new_test_field2"""
|
|
|
|
mock_json = {
|
|
"schema_version": {"major": 1, "minor": 0},
|
|
"graph_module": {
|
|
"graph": {
|
|
"inputs": [],
|
|
"outputs": [],
|
|
"nodes": [],
|
|
"new_test_field": "test_value",
|
|
"existing_field": "existing_value",
|
|
}
|
|
},
|
|
}
|
|
|
|
# Test the upgrader using the Python binding
|
|
serialized_json = json.dumps(mock_json)
|
|
upgraded_json_str = torch._C._export.upgrade(serialized_json, 2)
|
|
upgraded_json = json.loads(upgraded_json_str)
|
|
|
|
# Verify the schema version was updated (version 1 -> version 2 due to v1 upgrader only)
|
|
self.assertEqual(upgraded_json["schema_version"]["major"], 2)
|
|
self.assertEqual(upgraded_json["schema_version"]["minor"], 0)
|
|
|
|
# Verify new_test_field was renamed to new_test_field2
|
|
graph = upgraded_json["graph_module"]["graph"]
|
|
self.assertIn("new_test_field2", graph)
|
|
self.assertEqual(graph["new_test_field2"], "test_value")
|
|
self.assertNotIn("new_test_field", graph)
|
|
|
|
# Verify existing fields are preserved
|
|
self.assertEqual(graph["existing_field"], "existing_value")
|
|
self.assertIn("inputs", graph)
|
|
self.assertIn("outputs", graph)
|
|
self.assertIn("nodes", graph)
|
|
|
|
|
|
if __name__ == "__main__":
    # When executed as a script, hand control to the run_tests entry point
    # provided by torch._dynamo.test_case.
    import torch._dynamo.test_case as dynamo_test_case

    dynamo_test_case.run_tests()
|