Files
pytorch/test/export/test_upgrader.py

285 lines
11 KiB
Python

# Owner(s): ["oncall: export"]
import json
import torch
from torch.testing._internal.common_utils import TestCase
class TestUpgrader(TestCase):
def setUp(self) -> None:
# Register example upgraders dynamically
torch._C._export.register_example_upgraders()
def tearDown(self) -> None:
# Clean up registered upgraders
torch._C._export.deregister_example_upgraders()
def test_nn_module_stack_transformation_from_v0(self):
"""Test that nn_module_stack strings are prepended with 'test_upgrader_' when upgrading from version 0"""
# Create a mock JSON object that simulates version 0 schema
# with nn_module_stack as a string that needs to be upgraded
mock_json = {
"schema_version": {"major": 0, "minor": 0},
"graph_module": {
"graph": {
"nodes": [
{
"target": "aten.add.Tensor",
"inputs": [],
"outputs": [],
"metadata": {
"nn_module_stack": "original_stack_info",
"other_field": "some_value",
},
},
{
"target": "aten.mul.Tensor",
"inputs": [],
"outputs": [],
"metadata": {
"nn_module_stack": "another_stack",
"stack_trace": "some trace",
},
},
]
}
},
}
# Test the upgrader using the Python binding
serialized_json = json.dumps(mock_json)
upgraded_json_str = torch._C._export.upgrade(serialized_json, 2)
upgraded_json = json.loads(upgraded_json_str)
# Verify the schema version was updated (version 0 -> version 2 due to both v0 and v1 upgraders)
self.assertEqual(upgraded_json["schema_version"]["major"], 2)
self.assertEqual(upgraded_json["schema_version"]["minor"], 0)
# Verify nn_module_stack was prepended with "test_upgrader_"
nodes = upgraded_json["graph_module"]["graph"]["nodes"]
# Check first node
first_node_metadata = nodes[0]["metadata"]
nn_stack = first_node_metadata["nn_module_stack"]
self.assertIsInstance(nn_stack, str)
self.assertEqual(nn_stack, "test_upgrader_original_stack_info")
# Other metadata should be unchanged
self.assertEqual(first_node_metadata["other_field"], "some_value")
# Check second node
second_node_metadata = nodes[1]["metadata"]
nn_stack2 = second_node_metadata["nn_module_stack"]
self.assertIsInstance(nn_stack2, str)
self.assertEqual(nn_stack2, "test_upgrader_another_stack")
# Other metadata should be unchanged
self.assertEqual(second_node_metadata["stack_trace"], "some trace")
def test_nn_module_stack_error_handling_invalid_type(self):
"""Test error handling when nn_module_stack is not a string"""
# Test case: nn_module_stack is not a string
mock_json_invalid_type = {
"schema_version": {"major": 0, "minor": 0},
"graph_module": {
"graph": {
"nodes": [
{
"target": "aten.add.Tensor",
"inputs": [],
"outputs": [],
"metadata": {
"nn_module_stack": 42 # Invalid: should be string
},
}
]
}
},
}
with self.assertRaisesRegex(
RuntimeError,
"Error in upgrader 'version_0_upgrader_registered'",
):
serialized_json = json.dumps(mock_json_invalid_type)
torch._C._export.upgrade(serialized_json, 2)
def test_nodes_without_metadata_handled_gracefully(self):
"""Test that nodes without metadata or nn_module_stack are handled gracefully"""
mock_json = {
"schema_version": {"major": 0, "minor": 0},
"graph_module": {
"graph": {
"nodes": [
{
"target": "aten.add.Tensor",
"inputs": [],
"outputs": []
# No metadata field
},
{
"target": "aten.mul.Tensor",
"inputs": [],
"outputs": [],
"metadata": {
"stack_trace": "some trace"
# No nn_module_stack field
},
},
]
}
},
}
# Should not raise an error
serialized_json = json.dumps(mock_json)
upgraded_json_str = torch._C._export.upgrade(serialized_json, 2)
upgraded_json = json.loads(upgraded_json_str)
# Verify the schema version was updated (version 0 -> version 2 due to both v0 and v1 upgraders)
self.assertEqual(upgraded_json["schema_version"]["major"], 2)
self.assertEqual(upgraded_json["schema_version"]["minor"], 0)
# Verify nodes are unchanged
nodes = upgraded_json["graph_module"]["graph"]["nodes"]
self.assertEqual(len(nodes), 2)
# First node should have no metadata
self.assertNotIn("metadata", nodes[0])
# Second node should have unchanged metadata
self.assertEqual(nodes[1]["metadata"]["stack_trace"], "some trace")
self.assertNotIn("nn_module_stack", nodes[1]["metadata"])
def test_field_renaming_chain_from_v0_complete(self):
"""Test complete field renaming chain from v0: old_test_field -> new_test_field -> new_test_field2"""
mock_json = {
"schema_version": {"major": 0, "minor": 0},
"graph_module": {
"graph": {
"inputs": [],
"outputs": [],
"nodes": [
{
"target": "aten.add.Tensor",
"inputs": [],
"outputs": [],
"metadata": {"nn_module_stack": "test_stack"},
}
],
"old_test_field": "original_value",
"existing_field": "existing_value",
}
},
}
# Test the upgrader using the Python binding
serialized_json = json.dumps(mock_json)
upgraded_json_str = torch._C._export.upgrade(serialized_json, 2)
upgraded_json = json.loads(upgraded_json_str)
# Verify the schema version was updated (version 0 -> version 2 due to both v0 and v1 upgraders)
self.assertEqual(upgraded_json["schema_version"]["major"], 2)
self.assertEqual(upgraded_json["schema_version"]["minor"], 0)
# Verify complete field transformation: old_test_field -> new_test_field -> new_test_field2
graph = upgraded_json["graph_module"]["graph"]
self.assertIn("new_test_field2", graph)
self.assertEqual(graph["new_test_field2"], "original_value")
self.assertNotIn("old_test_field", graph)
self.assertNotIn("new_test_field", graph)
# Verify existing fields are preserved
self.assertEqual(graph["existing_field"], "existing_value")
self.assertIn("inputs", graph)
self.assertIn("outputs", graph)
self.assertIn("nodes", graph)
# Verify the nn_module_stack was also upgraded by the other upgrader
nodes = graph["nodes"]
self.assertEqual(
nodes[0]["metadata"]["nn_module_stack"], "test_upgrader_test_stack"
)
def test_field_renaming_chain_from_v0_missing_field(self):
"""Test that upgraders work gracefully when old_test_field doesn't exist"""
mock_json = {
"schema_version": {"major": 0, "minor": 0},
"graph_module": {
"graph": {
"inputs": [],
"outputs": [],
"nodes": [],
"existing_field": "existing_value",
}
},
}
# Test the upgrader using the Python binding
serialized_json = json.dumps(mock_json)
upgraded_json_str = torch._C._export.upgrade(serialized_json, 2)
upgraded_json = json.loads(upgraded_json_str)
# Verify the schema version was updated (version 0 -> version 2 due to both v0 and v1 upgraders)
self.assertEqual(upgraded_json["schema_version"]["major"], 2)
self.assertEqual(upgraded_json["schema_version"]["minor"], 0)
# Verify no field transformations occurred since old_test_field didn't exist
graph = upgraded_json["graph_module"]["graph"]
self.assertNotIn("new_test_field2", graph)
self.assertNotIn("new_test_field", graph)
self.assertNotIn("old_test_field", graph)
# Verify existing fields are preserved
self.assertEqual(graph["existing_field"], "existing_value")
self.assertIn("inputs", graph)
self.assertIn("outputs", graph)
self.assertIn("nodes", graph)
def test_field_renaming_from_v1_partial_chain(self):
"""Test partial upgrade chain starting from v1: new_test_field -> new_test_field2"""
mock_json = {
"schema_version": {"major": 1, "minor": 0},
"graph_module": {
"graph": {
"inputs": [],
"outputs": [],
"nodes": [],
"new_test_field": "test_value",
"existing_field": "existing_value",
}
},
}
# Test the upgrader using the Python binding
serialized_json = json.dumps(mock_json)
upgraded_json_str = torch._C._export.upgrade(serialized_json, 2)
upgraded_json = json.loads(upgraded_json_str)
# Verify the schema version was updated (version 1 -> version 2 due to v1 upgrader only)
self.assertEqual(upgraded_json["schema_version"]["major"], 2)
self.assertEqual(upgraded_json["schema_version"]["minor"], 0)
# Verify new_test_field was renamed to new_test_field2
graph = upgraded_json["graph_module"]["graph"]
self.assertIn("new_test_field2", graph)
self.assertEqual(graph["new_test_field2"], "test_value")
self.assertNotIn("new_test_field", graph)
# Verify existing fields are preserved
self.assertEqual(graph["existing_field"], "existing_value")
self.assertIn("inputs", graph)
self.assertIn("outputs", graph)
self.assertIn("nodes", graph)
if __name__ == "__main__":
    # Only start a test run when this file is executed as a script; the
    # runner module is imported here, not at module import time.
    import torch._dynamo.test_case as _runner_module

    _runner_module.run_tests()