mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[schema_upgrader] add C++ upgrader for json based upgrading (#156761)
Differential Revision: [D77459912](https://our.internmc.facebook.com/intern/diff/D77459912) Pull Request resolved: https://github.com/pytorch/pytorch/pull/156761 Approved by: https://github.com/angelayi
This commit is contained in:
committed by
PyTorch MergeBot
parent
2815ade9a8
commit
61712e6f2b
@ -894,6 +894,8 @@ libtorch_python_core_sources = [
|
||||
"torch/csrc/mps/Module.cpp",
|
||||
"torch/csrc/mtia/Module.cpp",
|
||||
"torch/csrc/export/pybind.cpp",
|
||||
"torch/csrc/export/upgrader.cpp",
|
||||
"torch/csrc/export/example_upgraders.cpp",
|
||||
"torch/csrc/inductor/aoti_package/pybind.cpp",
|
||||
"torch/csrc/inductor/aoti_runner/pybind.cpp",
|
||||
"torch/csrc/inductor/aoti_eager/kernel_holder.cpp",
|
||||
|
284
test/export/test_upgrader.py
Normal file
284
test/export/test_upgrader.py
Normal file
@ -0,0 +1,284 @@
|
||||
# Owner(s): ["oncall: export"]
|
||||
|
||||
import json
|
||||
|
||||
import torch
|
||||
from torch.testing._internal.common_utils import TestCase
|
||||
|
||||
|
||||
class TestUpgrader(TestCase):
|
||||
def setUp(self) -> None:
|
||||
# Register example upgraders dynamically
|
||||
torch._C._export.register_example_upgraders()
|
||||
|
||||
def tearDown(self) -> None:
|
||||
# Clean up registered upgraders
|
||||
torch._C._export.deregister_example_upgraders()
|
||||
|
||||
def test_nn_module_stack_transformation_from_v0(self):
|
||||
"""Test that nn_module_stack strings are prepended with 'test_upgrader_' when upgrading from version 0"""
|
||||
|
||||
# Create a mock JSON object that simulates version 0 schema
|
||||
# with nn_module_stack as a string that needs to be upgraded
|
||||
mock_json = {
|
||||
"schema_version": {"major": 0, "minor": 0},
|
||||
"graph_module": {
|
||||
"graph": {
|
||||
"nodes": [
|
||||
{
|
||||
"target": "aten.add.Tensor",
|
||||
"inputs": [],
|
||||
"outputs": [],
|
||||
"metadata": {
|
||||
"nn_module_stack": "original_stack_info",
|
||||
"other_field": "some_value",
|
||||
},
|
||||
},
|
||||
{
|
||||
"target": "aten.mul.Tensor",
|
||||
"inputs": [],
|
||||
"outputs": [],
|
||||
"metadata": {
|
||||
"nn_module_stack": "another_stack",
|
||||
"stack_trace": "some trace",
|
||||
},
|
||||
},
|
||||
]
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
# Test the upgrader using the Python binding
|
||||
serialized_json = json.dumps(mock_json)
|
||||
upgraded_json_str = torch._C._export.upgrade(serialized_json, 2)
|
||||
upgraded_json = json.loads(upgraded_json_str)
|
||||
|
||||
# Verify the schema version was updated (version 0 -> version 2 due to both v0 and v1 upgraders)
|
||||
self.assertEqual(upgraded_json["schema_version"]["major"], 2)
|
||||
self.assertEqual(upgraded_json["schema_version"]["minor"], 0)
|
||||
|
||||
# Verify nn_module_stack was prepended with "test_upgrader_"
|
||||
nodes = upgraded_json["graph_module"]["graph"]["nodes"]
|
||||
|
||||
# Check first node
|
||||
first_node_metadata = nodes[0]["metadata"]
|
||||
nn_stack = first_node_metadata["nn_module_stack"]
|
||||
self.assertIsInstance(nn_stack, str)
|
||||
self.assertEqual(nn_stack, "test_upgrader_original_stack_info")
|
||||
# Other metadata should be unchanged
|
||||
self.assertEqual(first_node_metadata["other_field"], "some_value")
|
||||
|
||||
# Check second node
|
||||
second_node_metadata = nodes[1]["metadata"]
|
||||
nn_stack2 = second_node_metadata["nn_module_stack"]
|
||||
self.assertIsInstance(nn_stack2, str)
|
||||
self.assertEqual(nn_stack2, "test_upgrader_another_stack")
|
||||
# Other metadata should be unchanged
|
||||
self.assertEqual(second_node_metadata["stack_trace"], "some trace")
|
||||
|
||||
def test_nn_module_stack_error_handling_invalid_type(self):
|
||||
"""Test error handling when nn_module_stack is not a string"""
|
||||
|
||||
# Test case: nn_module_stack is not a string
|
||||
mock_json_invalid_type = {
|
||||
"schema_version": {"major": 0, "minor": 0},
|
||||
"graph_module": {
|
||||
"graph": {
|
||||
"nodes": [
|
||||
{
|
||||
"target": "aten.add.Tensor",
|
||||
"inputs": [],
|
||||
"outputs": [],
|
||||
"metadata": {
|
||||
"nn_module_stack": 42 # Invalid: should be string
|
||||
},
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
"Error in upgrader 'version_0_upgrader_registered'",
|
||||
):
|
||||
serialized_json = json.dumps(mock_json_invalid_type)
|
||||
torch._C._export.upgrade(serialized_json, 2)
|
||||
|
||||
def test_nodes_without_metadata_handled_gracefully(self):
|
||||
"""Test that nodes without metadata or nn_module_stack are handled gracefully"""
|
||||
|
||||
mock_json = {
|
||||
"schema_version": {"major": 0, "minor": 0},
|
||||
"graph_module": {
|
||||
"graph": {
|
||||
"nodes": [
|
||||
{
|
||||
"target": "aten.add.Tensor",
|
||||
"inputs": [],
|
||||
"outputs": []
|
||||
# No metadata field
|
||||
},
|
||||
{
|
||||
"target": "aten.mul.Tensor",
|
||||
"inputs": [],
|
||||
"outputs": [],
|
||||
"metadata": {
|
||||
"stack_trace": "some trace"
|
||||
# No nn_module_stack field
|
||||
},
|
||||
},
|
||||
]
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
# Should not raise an error
|
||||
serialized_json = json.dumps(mock_json)
|
||||
upgraded_json_str = torch._C._export.upgrade(serialized_json, 2)
|
||||
upgraded_json = json.loads(upgraded_json_str)
|
||||
|
||||
# Verify the schema version was updated (version 0 -> version 2 due to both v0 and v1 upgraders)
|
||||
self.assertEqual(upgraded_json["schema_version"]["major"], 2)
|
||||
self.assertEqual(upgraded_json["schema_version"]["minor"], 0)
|
||||
|
||||
# Verify nodes are unchanged
|
||||
nodes = upgraded_json["graph_module"]["graph"]["nodes"]
|
||||
self.assertEqual(len(nodes), 2)
|
||||
|
||||
# First node should have no metadata
|
||||
self.assertNotIn("metadata", nodes[0])
|
||||
|
||||
# Second node should have unchanged metadata
|
||||
self.assertEqual(nodes[1]["metadata"]["stack_trace"], "some trace")
|
||||
self.assertNotIn("nn_module_stack", nodes[1]["metadata"])
|
||||
|
||||
def test_field_renaming_chain_from_v0_complete(self):
|
||||
"""Test complete field renaming chain from v0: old_test_field -> new_test_field -> new_test_field2"""
|
||||
|
||||
mock_json = {
|
||||
"schema_version": {"major": 0, "minor": 0},
|
||||
"graph_module": {
|
||||
"graph": {
|
||||
"inputs": [],
|
||||
"outputs": [],
|
||||
"nodes": [
|
||||
{
|
||||
"target": "aten.add.Tensor",
|
||||
"inputs": [],
|
||||
"outputs": [],
|
||||
"metadata": {"nn_module_stack": "test_stack"},
|
||||
}
|
||||
],
|
||||
"old_test_field": "original_value",
|
||||
"existing_field": "existing_value",
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
# Test the upgrader using the Python binding
|
||||
serialized_json = json.dumps(mock_json)
|
||||
upgraded_json_str = torch._C._export.upgrade(serialized_json, 2)
|
||||
upgraded_json = json.loads(upgraded_json_str)
|
||||
|
||||
# Verify the schema version was updated (version 0 -> version 2 due to both v0 and v1 upgraders)
|
||||
self.assertEqual(upgraded_json["schema_version"]["major"], 2)
|
||||
self.assertEqual(upgraded_json["schema_version"]["minor"], 0)
|
||||
|
||||
# Verify complete field transformation: old_test_field -> new_test_field -> new_test_field2
|
||||
graph = upgraded_json["graph_module"]["graph"]
|
||||
self.assertIn("new_test_field2", graph)
|
||||
self.assertEqual(graph["new_test_field2"], "original_value")
|
||||
self.assertNotIn("old_test_field", graph)
|
||||
self.assertNotIn("new_test_field", graph)
|
||||
|
||||
# Verify existing fields are preserved
|
||||
self.assertEqual(graph["existing_field"], "existing_value")
|
||||
self.assertIn("inputs", graph)
|
||||
self.assertIn("outputs", graph)
|
||||
self.assertIn("nodes", graph)
|
||||
|
||||
# Verify the nn_module_stack was also upgraded by the other upgrader
|
||||
nodes = graph["nodes"]
|
||||
self.assertEqual(
|
||||
nodes[0]["metadata"]["nn_module_stack"], "test_upgrader_test_stack"
|
||||
)
|
||||
|
||||
def test_field_renaming_chain_from_v0_missing_field(self):
|
||||
"""Test that upgraders work gracefully when old_test_field doesn't exist"""
|
||||
|
||||
mock_json = {
|
||||
"schema_version": {"major": 0, "minor": 0},
|
||||
"graph_module": {
|
||||
"graph": {
|
||||
"inputs": [],
|
||||
"outputs": [],
|
||||
"nodes": [],
|
||||
"existing_field": "existing_value",
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
# Test the upgrader using the Python binding
|
||||
serialized_json = json.dumps(mock_json)
|
||||
upgraded_json_str = torch._C._export.upgrade(serialized_json, 2)
|
||||
upgraded_json = json.loads(upgraded_json_str)
|
||||
|
||||
# Verify the schema version was updated (version 0 -> version 2 due to both v0 and v1 upgraders)
|
||||
self.assertEqual(upgraded_json["schema_version"]["major"], 2)
|
||||
self.assertEqual(upgraded_json["schema_version"]["minor"], 0)
|
||||
|
||||
# Verify no field transformations occurred since old_test_field didn't exist
|
||||
graph = upgraded_json["graph_module"]["graph"]
|
||||
self.assertNotIn("new_test_field2", graph)
|
||||
self.assertNotIn("new_test_field", graph)
|
||||
self.assertNotIn("old_test_field", graph)
|
||||
|
||||
# Verify existing fields are preserved
|
||||
self.assertEqual(graph["existing_field"], "existing_value")
|
||||
self.assertIn("inputs", graph)
|
||||
self.assertIn("outputs", graph)
|
||||
self.assertIn("nodes", graph)
|
||||
|
||||
def test_field_renaming_from_v1_partial_chain(self):
|
||||
"""Test partial upgrade chain starting from v1: new_test_field -> new_test_field2"""
|
||||
|
||||
mock_json = {
|
||||
"schema_version": {"major": 1, "minor": 0},
|
||||
"graph_module": {
|
||||
"graph": {
|
||||
"inputs": [],
|
||||
"outputs": [],
|
||||
"nodes": [],
|
||||
"new_test_field": "test_value",
|
||||
"existing_field": "existing_value",
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
# Test the upgrader using the Python binding
|
||||
serialized_json = json.dumps(mock_json)
|
||||
upgraded_json_str = torch._C._export.upgrade(serialized_json, 2)
|
||||
upgraded_json = json.loads(upgraded_json_str)
|
||||
|
||||
# Verify the schema version was updated (version 1 -> version 2 due to v1 upgrader only)
|
||||
self.assertEqual(upgraded_json["schema_version"]["major"], 2)
|
||||
self.assertEqual(upgraded_json["schema_version"]["minor"], 0)
|
||||
|
||||
# Verify new_test_field was renamed to new_test_field2
|
||||
graph = upgraded_json["graph_module"]["graph"]
|
||||
self.assertIn("new_test_field2", graph)
|
||||
self.assertEqual(graph["new_test_field2"], "test_value")
|
||||
self.assertNotIn("new_test_field", graph)
|
||||
|
||||
# Verify existing fields are preserved
|
||||
self.assertEqual(graph["existing_field"], "existing_value")
|
||||
self.assertIn("inputs", graph)
|
||||
self.assertIn("outputs", graph)
|
||||
self.assertIn("nodes", graph)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from torch._dynamo.test_case import run_tests
|
||||
|
||||
run_tests()
|
89
torch/csrc/export/example_upgraders.cpp
Normal file
89
torch/csrc/export/example_upgraders.cpp
Normal file
@ -0,0 +1,89 @@
|
||||
#include <torch/csrc/export/example_upgraders.h>
|
||||
#include <torch/csrc/export/upgrader.h>
|
||||
|
||||
namespace torch::_export {
|
||||
|
||||
/// Register test upgraders for the upgrader system.
|
||||
/// and shows some common upgrade patterns.
|
||||
static bool test_upgraders_registered = false;
|
||||
|
||||
void registerExampleUpgraders() {
|
||||
if (test_upgraders_registered) {
|
||||
return;
|
||||
}
|
||||
|
||||
registerUpgrader(
|
||||
0,
|
||||
"graph_module.graph.nodes",
|
||||
[](const nlohmann::json& nodes_array) -> nlohmann::json {
|
||||
nlohmann::json upgraded_nodes = nodes_array;
|
||||
|
||||
// Process each node in the nodes array
|
||||
for (auto& node : upgraded_nodes) {
|
||||
if (node.contains("metadata") && node["metadata"].is_object()) {
|
||||
// Process each metadata key-value pair
|
||||
for (auto& [key, value] : node["metadata"].items()) {
|
||||
if (key == "nn_module_stack") {
|
||||
// Transform nn_module_stack values by prepending prefix
|
||||
if (value.is_string()) {
|
||||
std::string stack_str = value.get<std::string>();
|
||||
value = "test_upgrader_" + stack_str;
|
||||
} else {
|
||||
throwUpgraderError(
|
||||
"version_0_upgrader_registered",
|
||||
0,
|
||||
"nn_module_stack metadata value must be a string, got: " +
|
||||
std::string(value.type_name()),
|
||||
node);
|
||||
}
|
||||
}
|
||||
// Other metadata keys remain unchanged
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return upgraded_nodes;
|
||||
});
|
||||
|
||||
registerUpgrader(
|
||||
0,
|
||||
"graph_module.graph",
|
||||
[](const nlohmann::json& graph_obj) -> nlohmann::json {
|
||||
nlohmann::json upgraded_graph = graph_obj;
|
||||
|
||||
// Rename field if it exists in the graph object
|
||||
if (upgraded_graph.contains("old_test_field")) {
|
||||
upgraded_graph["new_test_field"] = upgraded_graph["old_test_field"];
|
||||
upgraded_graph.erase("old_test_field");
|
||||
}
|
||||
|
||||
return upgraded_graph;
|
||||
});
|
||||
|
||||
registerUpgrader(
|
||||
1,
|
||||
std::vector<std::string>{"graph_module", "graph"},
|
||||
[](const nlohmann::json& graph_obj) -> nlohmann::json {
|
||||
nlohmann::json upgraded_graph = graph_obj;
|
||||
|
||||
// Continue the field renaming chain from version 0
|
||||
if (upgraded_graph.contains("new_test_field")) {
|
||||
upgraded_graph["new_test_field2"] = upgraded_graph["new_test_field"];
|
||||
upgraded_graph.erase("new_test_field");
|
||||
}
|
||||
|
||||
return upgraded_graph;
|
||||
});
|
||||
|
||||
test_upgraders_registered = true;
|
||||
}
|
||||
|
||||
/// Deregister test upgraders for the upgrader system.
|
||||
void deregisterExampleUpgraders() {
|
||||
deregisterUpgrader(0, "graph_module.graph.nodes");
|
||||
deregisterUpgrader(0, "graph_module.graph");
|
||||
deregisterUpgrader(1, std::vector<std::string>{"graph_module", "graph"});
|
||||
test_upgraders_registered = false;
|
||||
}
|
||||
|
||||
} // namespace torch::_export
|
15
torch/csrc/export/example_upgraders.h
Normal file
15
torch/csrc/export/example_upgraders.h
Normal file
@ -0,0 +1,15 @@
|
||||
#pragma once
|
||||
|
||||
namespace torch::_export {
|
||||
|
||||
/// Register example upgraders for the upgrader system for testing.
|
||||
/// This function demonstrates common upgrade patterns and is primarily
|
||||
/// used for testing and demonstration purposes.
|
||||
void registerExampleUpgraders();
|
||||
|
||||
/// Deregister example upgraders for the upgrader system for testing.
|
||||
/// This function cleans up the example upgraders that were registered
|
||||
/// by registerExampleUpgraders().
|
||||
void deregisterExampleUpgraders();
|
||||
|
||||
} // namespace torch::_export
|
@ -1,5 +1,7 @@
|
||||
#include <torch/csrc/export/example_upgraders.h>
|
||||
#include <torch/csrc/export/pt2_archive_constants.h>
|
||||
#include <torch/csrc/export/pybind.h>
|
||||
#include <torch/csrc/export/upgrader.h>
|
||||
#include <torch/csrc/utils/generated_serialization_types.h>
|
||||
#include <torch/csrc/utils/pybind.h>
|
||||
|
||||
@ -15,13 +17,37 @@ void initExportBindings(PyObject* module) {
|
||||
|
||||
exportModule.def(
|
||||
"deserialize_exported_program", [](const std::string& serialized) {
|
||||
return nlohmann::json::parse(serialized).get<ExportedProgram>();
|
||||
auto parsed = nlohmann::json::parse(serialized);
|
||||
|
||||
// Query the current Python schema version as target
|
||||
// TODO: expose schema_version in gneerated_serialization_types.h and
|
||||
// access it here directly.
|
||||
py::module_ schema_module =
|
||||
py::module_::import("torch._export.serde.schema");
|
||||
py::tuple schema_version_tuple = schema_module.attr("SCHEMA_VERSION");
|
||||
int target_version = schema_version_tuple[0].cast<int>();
|
||||
|
||||
auto upgraded = upgrade(parsed, target_version);
|
||||
return upgraded.get<ExportedProgram>();
|
||||
});
|
||||
|
||||
exportModule.def("serialize_exported_program", [](const ExportedProgram& ep) {
|
||||
return nlohmann::json(ep).dump();
|
||||
});
|
||||
|
||||
exportModule.def(
|
||||
"upgrade", [](const std::string& serialized_json, int target_version) {
|
||||
auto parsed = nlohmann::json::parse(serialized_json);
|
||||
auto upgraded = upgrade(parsed, target_version);
|
||||
return upgraded.dump();
|
||||
});
|
||||
|
||||
exportModule.def(
|
||||
"register_example_upgraders", []() { registerExampleUpgraders(); });
|
||||
|
||||
exportModule.def(
|
||||
"deregister_example_upgraders", []() { deregisterExampleUpgraders(); });
|
||||
|
||||
for (const auto& entry : torch::_export::archive_spec::kAllConstants) {
|
||||
pt2ArchiveModule.attr(entry.first) = entry.second;
|
||||
}
|
||||
|
242
torch/csrc/export/upgrader.cpp
Normal file
242
torch/csrc/export/upgrader.cpp
Normal file
@ -0,0 +1,242 @@
|
||||
#include <torch/csrc/export/upgrader.h>
|
||||
#include <limits>
|
||||
#include <map>
|
||||
#include <set>
|
||||
#include <sstream>
|
||||
#include <stdexcept>
|
||||
#include <vector>
|
||||
|
||||
namespace torch::_export {
|
||||
|
||||
// Global upgrader registry organized by version.
|
||||
// Using std::multiset to maintain automatic bottom-up ordering where
|
||||
// deeper keypaths are processed before shallower ones.
|
||||
static std::map<int, std::multiset<Upgrader>> upgrader_registry;
|
||||
|
||||
static const std::multiset<Upgrader>& getUpgrader(int current_version) {
|
||||
static const std::multiset<Upgrader> empty_upgraders;
|
||||
auto it = upgrader_registry.find(current_version);
|
||||
if (it != upgrader_registry.end()) {
|
||||
return it->second;
|
||||
}
|
||||
return empty_upgraders;
|
||||
}
|
||||
|
||||
static nlohmann::json getFieldByKeypath(
|
||||
const nlohmann::json& obj,
|
||||
const std::vector<std::string>& keypath) {
|
||||
nlohmann::json current = obj;
|
||||
for (const auto& key : keypath) {
|
||||
if (!current.contains(key)) {
|
||||
throw std::runtime_error("Keypath not found: " + key);
|
||||
}
|
||||
current = current[key];
|
||||
}
|
||||
return current;
|
||||
}
|
||||
|
||||
static void setFieldByKeypath(
|
||||
nlohmann::json& obj,
|
||||
const std::vector<std::string>& keypath,
|
||||
const nlohmann::json& value) {
|
||||
nlohmann::json* current = &obj;
|
||||
for (size_t i = 0; i < keypath.size() - 1; ++i) {
|
||||
const auto& key = keypath[i];
|
||||
if (!current->contains(key)) {
|
||||
throw std::runtime_error("Keypath not found: " + key);
|
||||
}
|
||||
current = &((*current)[key]);
|
||||
}
|
||||
if (!current->contains(keypath.back())) {
|
||||
throw std::runtime_error("Keypath not found: " + keypath.back());
|
||||
}
|
||||
(*current)[keypath.back()] = value;
|
||||
}
|
||||
|
||||
Upgrader::Upgrader(std::vector<std::string> kp, UpgraderFunction func)
|
||||
: keypath(std::move(kp)), upgrade_func(std::move(func)) {}
|
||||
|
||||
bool Upgrader::operator<(const Upgrader& other) const {
|
||||
// First compare by depth - deeper paths come first for bottom-up processing
|
||||
if (keypath.size() != other.keypath.size()) {
|
||||
return keypath.size() > other.keypath.size();
|
||||
}
|
||||
// If same depth, compare lexicographically for deterministic ordering
|
||||
return keypath < other.keypath;
|
||||
}
|
||||
|
||||
void registerUpgrader(
|
||||
int version,
|
||||
const std::vector<std::string>& keypath,
|
||||
const UpgraderFunction& upgrade_func) {
|
||||
// Check if an upgrader already exists for this version and keypath
|
||||
auto version_it = upgrader_registry.find(version);
|
||||
if (version_it != upgrader_registry.end()) {
|
||||
const auto& upgraders = version_it->second;
|
||||
|
||||
// Search for existing upgrader with the same keypath
|
||||
for (const auto& existing_upgrader : upgraders) {
|
||||
if (existing_upgrader.keypath == keypath) {
|
||||
std::ostringstream error_stream;
|
||||
error_stream << "Upgrader already registered for version " << version
|
||||
<< " and keypath: ";
|
||||
for (size_t i = 0; i < keypath.size(); ++i) {
|
||||
if (i > 0)
|
||||
error_stream << ".";
|
||||
error_stream << keypath[i];
|
||||
}
|
||||
throw std::runtime_error(error_stream.str());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
upgrader_registry[version].emplace(keypath, upgrade_func);
|
||||
}
|
||||
|
||||
void registerUpgrader(
|
||||
int version,
|
||||
const std::string& dot_keypath,
|
||||
const UpgraderFunction& upgrade_func) {
|
||||
// Convert dot-separated keypath to vector and delegate to main implementation
|
||||
std::vector<std::string> keypath_vector;
|
||||
std::stringstream ss(dot_keypath);
|
||||
std::string component;
|
||||
|
||||
while (std::getline(ss, component, '.')) {
|
||||
if (component.empty()) {
|
||||
throw std::invalid_argument("Empty component in keypath: " + dot_keypath);
|
||||
}
|
||||
keypath_vector.push_back(component);
|
||||
}
|
||||
|
||||
if (keypath_vector.empty()) {
|
||||
throw std::invalid_argument("Empty keypath provided");
|
||||
}
|
||||
|
||||
registerUpgrader(version, keypath_vector, upgrade_func);
|
||||
}
|
||||
|
||||
bool deregisterUpgrader(int version, const std::vector<std::string>& keypath) {
|
||||
auto version_it = upgrader_registry.find(version);
|
||||
if (version_it == upgrader_registry.end()) {
|
||||
return false; // Version not found
|
||||
}
|
||||
|
||||
auto& upgraders = version_it->second;
|
||||
|
||||
// Find the upgrader with matching keypath
|
||||
for (auto it = upgraders.begin(); it != upgraders.end(); ++it) {
|
||||
if (it->keypath == keypath) {
|
||||
upgraders.erase(it);
|
||||
|
||||
// If this was the last upgrader for this version, remove the version
|
||||
// entry
|
||||
if (upgraders.empty()) {
|
||||
upgrader_registry.erase(version_it);
|
||||
}
|
||||
|
||||
return true; // Successfully removed
|
||||
}
|
||||
}
|
||||
|
||||
return false; // Upgrader not found
|
||||
}
|
||||
|
||||
bool deregisterUpgrader(int version, const std::string& dot_keypath) {
|
||||
// Convert dot-separated keypath to vector and delegate to main implementation
|
||||
std::vector<std::string> keypath_vector;
|
||||
std::stringstream ss(dot_keypath);
|
||||
std::string component;
|
||||
|
||||
while (std::getline(ss, component, '.')) {
|
||||
if (component.empty()) {
|
||||
throw std::invalid_argument("Empty component in keypath: " + dot_keypath);
|
||||
}
|
||||
keypath_vector.push_back(component);
|
||||
}
|
||||
|
||||
if (keypath_vector.empty()) {
|
||||
throw std::invalid_argument("Empty keypath provided");
|
||||
}
|
||||
|
||||
return deregisterUpgrader(version, keypath_vector);
|
||||
}
|
||||
|
||||
void throwUpgraderError(
|
||||
const std::string& upgrader_name,
|
||||
int from_version,
|
||||
const std::string& error_message,
|
||||
const nlohmann::json& problematic_object) {
|
||||
std::ostringstream error_stream;
|
||||
error_stream << "Error in upgrader '" << upgrader_name << "' "
|
||||
<< "while upgrading from version " << from_version
|
||||
<< " to version " << from_version + 1 << ": " << error_message;
|
||||
|
||||
if (!problematic_object.empty()) {
|
||||
error_stream << "\nProblematic object: " << problematic_object.dump(2);
|
||||
}
|
||||
|
||||
throw std::runtime_error(error_stream.str());
|
||||
}
|
||||
|
||||
nlohmann::json upgrade(const nlohmann::json& artifact, int target_version) {
|
||||
auto current_artifact = artifact;
|
||||
|
||||
// Validate that the artifact contains required schema version information
|
||||
if (!current_artifact.contains("schema_version")) {
|
||||
throw std::runtime_error("Missing schema_version field in artifact");
|
||||
}
|
||||
|
||||
int current_version = current_artifact["schema_version"]["major"];
|
||||
|
||||
// Iteratively apply upgraders until target version is reached or no more are
|
||||
// available
|
||||
while (current_version < target_version) {
|
||||
// Look up upgraders for the current version
|
||||
const auto& upgraders = getUpgrader(current_version);
|
||||
|
||||
if (upgraders.empty()) {
|
||||
// No more upgraders available - stop upgrading
|
||||
break;
|
||||
}
|
||||
|
||||
// Apply all upgraders for this version in bottom-up order
|
||||
// (deeper keypaths first to prevent parent/child conflicts)
|
||||
for (const auto& upgrader : upgraders) {
|
||||
// Extract the field to be upgraded using its keypath
|
||||
auto field_to_upgrade =
|
||||
getFieldByKeypath(current_artifact, upgrader.keypath);
|
||||
|
||||
// Apply the upgrade transformation
|
||||
auto upgraded_field = upgrader.upgrade_func(field_to_upgrade);
|
||||
|
||||
// Update the artifact with the upgraded field
|
||||
setFieldByKeypath(current_artifact, upgrader.keypath, upgraded_field);
|
||||
}
|
||||
|
||||
// Move to the next version for potential additional upgrades
|
||||
current_version++;
|
||||
}
|
||||
|
||||
// Update schema version to reflect the final upgraded version
|
||||
if (current_artifact["schema_version"]["major"] != current_version) {
|
||||
current_artifact["schema_version"]["major"] = current_version;
|
||||
// Reset minor version to 0 - the correct minor version should be set
|
||||
// when converting the json to in memory representation of ExportedProgram
|
||||
current_artifact["schema_version"]["minor"] = 0;
|
||||
}
|
||||
|
||||
// Validate that we reached the target version if requested
|
||||
if (current_version != target_version) {
|
||||
std::ostringstream error_stream;
|
||||
error_stream
|
||||
<< "Failed to upgrade to target version " << target_version
|
||||
<< ". Final version reached: " << current_version
|
||||
<< ". This may indicate missing upgraders for intermediate versions.";
|
||||
throw std::runtime_error(error_stream.str());
|
||||
}
|
||||
|
||||
return current_artifact;
|
||||
}
|
||||
|
||||
} // namespace torch::_export
|
118
torch/csrc/export/upgrader.h
Normal file
118
torch/csrc/export/upgrader.h
Normal file
@ -0,0 +1,118 @@
|
||||
#pragma once
|
||||
|
||||
#include <nlohmann/json.hpp>
|
||||
#include <functional>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
namespace torch::_export {
|
||||
|
||||
/// Function type for upgrading JSON fields during schema version migration.
|
||||
/// Takes a JSON field and returns the upgraded version of that field.
|
||||
using UpgraderFunction = std::function<nlohmann::json(const nlohmann::json&)>;
|
||||
|
||||
/// Structure containing upgrader information for a specific keypath.
|
||||
/// The version is stored as the map key in the registry, so it's not
|
||||
/// duplicated here.
|
||||
struct Upgrader {
|
||||
/// Path to the field that should be upgraded (e.g., {"graph_module", "graph",
|
||||
/// "nodes"}) Assuming top-level is a JSON object that represents
|
||||
/// ExportedProgram
|
||||
std::vector<std::string> keypath;
|
||||
|
||||
/// Function that performs the actual upgrade transformation
|
||||
UpgraderFunction upgrade_func;
|
||||
|
||||
/// Constructor for creating an upgrader with keypath and function
|
||||
Upgrader(std::vector<std::string> kp, UpgraderFunction func);
|
||||
|
||||
/// Comparator for maintaining bottom-up ordering in the registry.
|
||||
/// Deeper keypaths are processed first to ensure safe upgrade application
|
||||
/// without conflicts between parent and child field modifications.
|
||||
bool operator<(const Upgrader& other) const;
|
||||
};
|
||||
|
||||
/// Register an upgrader function for a specific schema version and keypath.
|
||||
///
|
||||
/// This function allows registration of custom upgrade logic that will be
|
||||
/// applied when upgrading artifacts from the specified version. Upgraders
|
||||
/// are applied in bottom-up order (deeper keypaths first) to prevent
|
||||
/// conflicts between parent and child field modifications.
|
||||
///
|
||||
/// @param version The schema version this upgrader applies to
|
||||
/// @param keypath The key path to the field that should be upgraded
|
||||
/// @param upgrade_func Function that performs the upgrade transformation
|
||||
void registerUpgrader(
|
||||
int version,
|
||||
const std::vector<std::string>& keypath,
|
||||
const UpgraderFunction& upgrade_func);
|
||||
|
||||
/// Register an upgrader function using dot-separated keypath notation.
|
||||
///
|
||||
/// Convenience overload that accepts dot-separated keypath strings for
|
||||
/// simpler syntax. For example: "graph_module.graph.nodes" instead of
|
||||
/// {"graph_module", "graph", "nodes"}.
|
||||
///
|
||||
/// @param version The schema version this upgrader applies to
|
||||
/// @param dot_keypath Dot-separated keypath string (e.g., "graph.nodes")
|
||||
/// @param upgrade_func Function that performs the upgrade transformation
|
||||
void registerUpgrader(
|
||||
int version,
|
||||
const std::string& dot_keypath,
|
||||
const UpgraderFunction& upgrade_func);
|
||||
|
||||
/// Deregister an upgrader function for a specific schema version and keypath.
|
||||
///
|
||||
/// This function allows removal of previously registered upgrade logic for
|
||||
/// the specified version and keypath. This is useful for testing scenarios
|
||||
/// where you need to clean up registered upgraders or modify upgrader
|
||||
/// behavior dynamically.
|
||||
///
|
||||
/// @param version The schema version to deregister the upgrader from
|
||||
/// @param keypath The key path to the field that should be deregistered
|
||||
/// @return true if an upgrader was found and removed, false otherwise
|
||||
bool deregisterUpgrader(int version, const std::vector<std::string>& keypath);
|
||||
|
||||
/// Deregister an upgrader function using dot-separated keypath notation.
|
||||
///
|
||||
/// Convenience overload that accepts dot-separated keypath strings for
|
||||
/// simpler syntax. For example: "graph_module.graph.nodes" instead of
|
||||
/// {"graph_module", "graph", "nodes"}.
|
||||
///
|
||||
/// @param version The schema version to deregister the upgrader from
|
||||
/// @param dot_keypath Dot-separated keypath string (e.g., "graph.nodes")
|
||||
/// @return true if an upgrader was found and removed, false otherwise
|
||||
bool deregisterUpgrader(int version, const std::string& dot_keypath);
|
||||
|
||||
/// Utility function for throwing consistent upgrader errors.
|
||||
///
|
||||
/// This function formats error messages in a standardized way for upgrader
|
||||
/// failures, including version information and optional problematic object
|
||||
/// details for debugging.
|
||||
///
|
||||
/// @param upgrader_name Name of the upgrader that failed
|
||||
/// @param from_version Source schema version being upgraded from
|
||||
/// @param error_message Descriptive error message
|
||||
/// @param problematic_object Optional JSON object that caused the error
|
||||
/// @throws std::runtime_error Always throws with formatted error message
|
||||
void throwUpgraderError(
|
||||
const std::string& upgrader_name,
|
||||
int from_version,
|
||||
const std::string& error_message,
|
||||
const nlohmann::json& problematic_object = nlohmann::json::object());
|
||||
|
||||
/// Upgrade a JSON artifact to a specific target version with available
|
||||
/// upgraders until a target version is reached.
|
||||
///
|
||||
/// This handles major version upgrade only. For minor version upgrade,
|
||||
/// e.g. adding a new field with default value, it's automatically handled by
|
||||
/// the default constructor in generated_serialization_types.h.
|
||||
///
|
||||
/// @param artifact The JSON artifact to upgrade
|
||||
/// @param target_version The target schema version to upgrade to
|
||||
/// @return The upgraded JSON artifact with updated schema version
|
||||
/// @throws std::runtime_error if artifact is missing schema_version field
|
||||
/// @throws std::runtime_error if final version doesn't match target version
|
||||
nlohmann::json upgrade(const nlohmann::json& artifact, int target_version);
|
||||
|
||||
} // namespace torch::_export
|
Reference in New Issue
Block a user