mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Differential Revision: [D77459912](https://our.internmc.facebook.com/intern/diff/D77459912) Pull Request resolved: https://github.com/pytorch/pytorch/pull/156761 Approved by: https://github.com/angelayi
243 lines
7.7 KiB
C++
243 lines
7.7 KiB
C++
#include <torch/csrc/export/upgrader.h>
|
|
#include <limits>
|
|
#include <map>
|
|
#include <set>
|
|
#include <sstream>
|
|
#include <stdexcept>
|
|
#include <vector>
|
|
|
|
namespace torch::_export {
|
|
|
|
// Global upgrader registry organized by version.
|
|
// Using std::multiset to maintain automatic bottom-up ordering where
|
|
// deeper keypaths are processed before shallower ones.
|
|
static std::map<int, std::multiset<Upgrader>> upgrader_registry;
|
|
|
|
static const std::multiset<Upgrader>& getUpgrader(int current_version) {
|
|
static const std::multiset<Upgrader> empty_upgraders;
|
|
auto it = upgrader_registry.find(current_version);
|
|
if (it != upgrader_registry.end()) {
|
|
return it->second;
|
|
}
|
|
return empty_upgraders;
|
|
}
|
|
|
|
static nlohmann::json getFieldByKeypath(
|
|
const nlohmann::json& obj,
|
|
const std::vector<std::string>& keypath) {
|
|
nlohmann::json current = obj;
|
|
for (const auto& key : keypath) {
|
|
if (!current.contains(key)) {
|
|
throw std::runtime_error("Keypath not found: " + key);
|
|
}
|
|
current = current[key];
|
|
}
|
|
return current;
|
|
}
|
|
|
|
static void setFieldByKeypath(
|
|
nlohmann::json& obj,
|
|
const std::vector<std::string>& keypath,
|
|
const nlohmann::json& value) {
|
|
nlohmann::json* current = &obj;
|
|
for (size_t i = 0; i < keypath.size() - 1; ++i) {
|
|
const auto& key = keypath[i];
|
|
if (!current->contains(key)) {
|
|
throw std::runtime_error("Keypath not found: " + key);
|
|
}
|
|
current = &((*current)[key]);
|
|
}
|
|
if (!current->contains(keypath.back())) {
|
|
throw std::runtime_error("Keypath not found: " + keypath.back());
|
|
}
|
|
(*current)[keypath.back()] = value;
|
|
}
|
|
|
|
Upgrader::Upgrader(std::vector<std::string> kp, UpgraderFunction func)
|
|
: keypath(std::move(kp)), upgrade_func(std::move(func)) {}
|
|
|
|
bool Upgrader::operator<(const Upgrader& other) const {
|
|
// First compare by depth - deeper paths come first for bottom-up processing
|
|
if (keypath.size() != other.keypath.size()) {
|
|
return keypath.size() > other.keypath.size();
|
|
}
|
|
// If same depth, compare lexicographically for deterministic ordering
|
|
return keypath < other.keypath;
|
|
}
|
|
|
|
void registerUpgrader(
|
|
int version,
|
|
const std::vector<std::string>& keypath,
|
|
const UpgraderFunction& upgrade_func) {
|
|
// Check if an upgrader already exists for this version and keypath
|
|
auto version_it = upgrader_registry.find(version);
|
|
if (version_it != upgrader_registry.end()) {
|
|
const auto& upgraders = version_it->second;
|
|
|
|
// Search for existing upgrader with the same keypath
|
|
for (const auto& existing_upgrader : upgraders) {
|
|
if (existing_upgrader.keypath == keypath) {
|
|
std::ostringstream error_stream;
|
|
error_stream << "Upgrader already registered for version " << version
|
|
<< " and keypath: ";
|
|
for (size_t i = 0; i < keypath.size(); ++i) {
|
|
if (i > 0)
|
|
error_stream << ".";
|
|
error_stream << keypath[i];
|
|
}
|
|
throw std::runtime_error(error_stream.str());
|
|
}
|
|
}
|
|
}
|
|
|
|
upgrader_registry[version].emplace(keypath, upgrade_func);
|
|
}
|
|
|
|
void registerUpgrader(
|
|
int version,
|
|
const std::string& dot_keypath,
|
|
const UpgraderFunction& upgrade_func) {
|
|
// Convert dot-separated keypath to vector and delegate to main implementation
|
|
std::vector<std::string> keypath_vector;
|
|
std::stringstream ss(dot_keypath);
|
|
std::string component;
|
|
|
|
while (std::getline(ss, component, '.')) {
|
|
if (component.empty()) {
|
|
throw std::invalid_argument("Empty component in keypath: " + dot_keypath);
|
|
}
|
|
keypath_vector.push_back(component);
|
|
}
|
|
|
|
if (keypath_vector.empty()) {
|
|
throw std::invalid_argument("Empty keypath provided");
|
|
}
|
|
|
|
registerUpgrader(version, keypath_vector, upgrade_func);
|
|
}
|
|
|
|
bool deregisterUpgrader(int version, const std::vector<std::string>& keypath) {
|
|
auto version_it = upgrader_registry.find(version);
|
|
if (version_it == upgrader_registry.end()) {
|
|
return false; // Version not found
|
|
}
|
|
|
|
auto& upgraders = version_it->second;
|
|
|
|
// Find the upgrader with matching keypath
|
|
for (auto it = upgraders.begin(); it != upgraders.end(); ++it) {
|
|
if (it->keypath == keypath) {
|
|
upgraders.erase(it);
|
|
|
|
// If this was the last upgrader for this version, remove the version
|
|
// entry
|
|
if (upgraders.empty()) {
|
|
upgrader_registry.erase(version_it);
|
|
}
|
|
|
|
return true; // Successfully removed
|
|
}
|
|
}
|
|
|
|
return false; // Upgrader not found
|
|
}
|
|
|
|
bool deregisterUpgrader(int version, const std::string& dot_keypath) {
|
|
// Convert dot-separated keypath to vector and delegate to main implementation
|
|
std::vector<std::string> keypath_vector;
|
|
std::stringstream ss(dot_keypath);
|
|
std::string component;
|
|
|
|
while (std::getline(ss, component, '.')) {
|
|
if (component.empty()) {
|
|
throw std::invalid_argument("Empty component in keypath: " + dot_keypath);
|
|
}
|
|
keypath_vector.push_back(component);
|
|
}
|
|
|
|
if (keypath_vector.empty()) {
|
|
throw std::invalid_argument("Empty keypath provided");
|
|
}
|
|
|
|
return deregisterUpgrader(version, keypath_vector);
|
|
}
|
|
|
|
void throwUpgraderError(
|
|
const std::string& upgrader_name,
|
|
int from_version,
|
|
const std::string& error_message,
|
|
const nlohmann::json& problematic_object) {
|
|
std::ostringstream error_stream;
|
|
error_stream << "Error in upgrader '" << upgrader_name << "' "
|
|
<< "while upgrading from version " << from_version
|
|
<< " to version " << from_version + 1 << ": " << error_message;
|
|
|
|
if (!problematic_object.empty()) {
|
|
error_stream << "\nProblematic object: " << problematic_object.dump(2);
|
|
}
|
|
|
|
throw std::runtime_error(error_stream.str());
|
|
}
|
|
|
|
nlohmann::json upgrade(const nlohmann::json& artifact, int target_version) {
|
|
auto current_artifact = artifact;
|
|
|
|
// Validate that the artifact contains required schema version information
|
|
if (!current_artifact.contains("schema_version")) {
|
|
throw std::runtime_error("Missing schema_version field in artifact");
|
|
}
|
|
|
|
int current_version = current_artifact["schema_version"]["major"];
|
|
|
|
// Iteratively apply upgraders until target version is reached or no more are
|
|
// available
|
|
while (current_version < target_version) {
|
|
// Look up upgraders for the current version
|
|
const auto& upgraders = getUpgrader(current_version);
|
|
|
|
if (upgraders.empty()) {
|
|
// No more upgraders available - stop upgrading
|
|
break;
|
|
}
|
|
|
|
// Apply all upgraders for this version in bottom-up order
|
|
// (deeper keypaths first to prevent parent/child conflicts)
|
|
for (const auto& upgrader : upgraders) {
|
|
// Extract the field to be upgraded using its keypath
|
|
auto field_to_upgrade =
|
|
getFieldByKeypath(current_artifact, upgrader.keypath);
|
|
|
|
// Apply the upgrade transformation
|
|
auto upgraded_field = upgrader.upgrade_func(field_to_upgrade);
|
|
|
|
// Update the artifact with the upgraded field
|
|
setFieldByKeypath(current_artifact, upgrader.keypath, upgraded_field);
|
|
}
|
|
|
|
// Move to the next version for potential additional upgrades
|
|
current_version++;
|
|
}
|
|
|
|
// Update schema version to reflect the final upgraded version
|
|
if (current_artifact["schema_version"]["major"] != current_version) {
|
|
current_artifact["schema_version"]["major"] = current_version;
|
|
// Reset minor version to 0 - the correct minor version should be set
|
|
// when converting the json to in memory representation of ExportedProgram
|
|
current_artifact["schema_version"]["minor"] = 0;
|
|
}
|
|
|
|
// Validate that we reached the target version if requested
|
|
if (current_version != target_version) {
|
|
std::ostringstream error_stream;
|
|
error_stream
|
|
<< "Failed to upgrade to target version " << target_version
|
|
<< ". Final version reached: " << current_version
|
|
<< ". This may indicate missing upgraders for intermediate versions.";
|
|
throw std::runtime_error(error_stream.str());
|
|
}
|
|
|
|
return current_artifact;
|
|
}
|
|
|
|
} // namespace torch::_export
|