mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Differential Revision: [D77459912](https://our.internmc.facebook.com/intern/diff/D77459912) Pull Request resolved: https://github.com/pytorch/pytorch/pull/156761 Approved by: https://github.com/angelayi
119 lines
5.1 KiB
C++
119 lines
5.1 KiB
C++
#pragma once
|
|
|
|
#include <nlohmann/json.hpp>
|
|
#include <functional>
|
|
#include <string>
|
|
#include <vector>
|
|
|
|
namespace torch::_export {
|
|
|
|
/// Function type for upgrading JSON fields during schema version migration.
|
|
/// Takes a JSON field and returns the upgraded version of that field.
|
|
using UpgraderFunction = std::function<nlohmann::json(const nlohmann::json&)>;
|
|
|
|
/// Structure containing upgrader information for a specific keypath.
|
|
/// The version is stored as the map key in the registry, so it's not
|
|
/// duplicated here.
|
|
struct Upgrader {
|
|
/// Path to the field that should be upgraded (e.g., {"graph_module", "graph",
|
|
/// "nodes"}) Assuming top-level is a JSON object that represents
|
|
/// ExportedProgram
|
|
std::vector<std::string> keypath;
|
|
|
|
/// Function that performs the actual upgrade transformation
|
|
UpgraderFunction upgrade_func;
|
|
|
|
/// Constructor for creating an upgrader with keypath and function
|
|
Upgrader(std::vector<std::string> kp, UpgraderFunction func);
|
|
|
|
/// Comparator for maintaining bottom-up ordering in the registry.
|
|
/// Deeper keypaths are processed first to ensure safe upgrade application
|
|
/// without conflicts between parent and child field modifications.
|
|
bool operator<(const Upgrader& other) const;
|
|
};
|
|
|
|
/// Register an upgrader function for a specific schema version and keypath.
|
|
///
|
|
/// This function allows registration of custom upgrade logic that will be
|
|
/// applied when upgrading artifacts from the specified version. Upgraders
|
|
/// are applied in bottom-up order (deeper keypaths first) to prevent
|
|
/// conflicts between parent and child field modifications.
|
|
///
|
|
/// @param version The schema version this upgrader applies to
|
|
/// @param keypath The key path to the field that should be upgraded
|
|
/// @param upgrade_func Function that performs the upgrade transformation
|
|
void registerUpgrader(
|
|
int version,
|
|
const std::vector<std::string>& keypath,
|
|
const UpgraderFunction& upgrade_func);
|
|
|
|
/// Register an upgrader function using dot-separated keypath notation.
|
|
///
|
|
/// Convenience overload that accepts dot-separated keypath strings for
|
|
/// simpler syntax. For example: "graph_module.graph.nodes" instead of
|
|
/// {"graph_module", "graph", "nodes"}.
|
|
///
|
|
/// @param version The schema version this upgrader applies to
|
|
/// @param dot_keypath Dot-separated keypath string (e.g., "graph.nodes")
|
|
/// @param upgrade_func Function that performs the upgrade transformation
|
|
void registerUpgrader(
|
|
int version,
|
|
const std::string& dot_keypath,
|
|
const UpgraderFunction& upgrade_func);
|
|
|
|
/// Deregister an upgrader function for a specific schema version and keypath.
///
/// Removes previously registered upgrade logic for the specified version and
/// keypath. This is useful for testing scenarios where registered upgraders
/// need to be cleaned up, or where upgrader behavior is modified dynamically.
///
/// @param version The schema version to deregister the upgrader from
/// @param keypath The key path to the field that should be deregistered
/// @return true if an upgrader was found and removed, false otherwise
bool deregisterUpgrader(int version, const std::vector<std::string>& keypath);
|
|
|
|
/// Deregister an upgrader function using dot-separated keypath notation.
///
/// Convenience overload that accepts a dot-separated keypath string for
/// simpler syntax, for example "graph_module.graph.nodes" instead of
/// {"graph_module", "graph", "nodes"}.
///
/// @param version The schema version to deregister the upgrader from
/// @param dot_keypath Dot-separated keypath string (e.g., "graph.nodes")
/// @return true if an upgrader was found and removed, false otherwise
bool deregisterUpgrader(int version, const std::string& dot_keypath);
|
|
|
|
/// Utility function for throwing consistent upgrader errors.
|
|
///
|
|
/// This function formats error messages in a standardized way for upgrader
|
|
/// failures, including version information and optional problematic object
|
|
/// details for debugging.
|
|
///
|
|
/// @param upgrader_name Name of the upgrader that failed
|
|
/// @param from_version Source schema version being upgraded from
|
|
/// @param error_message Descriptive error message
|
|
/// @param problematic_object Optional JSON object that caused the error
|
|
/// @throws std::runtime_error Always throws with formatted error message
|
|
void throwUpgraderError(
|
|
const std::string& upgrader_name,
|
|
int from_version,
|
|
const std::string& error_message,
|
|
const nlohmann::json& problematic_object = nlohmann::json::object());
|
|
|
|
/// Upgrade a JSON artifact to a specific target version with available
|
|
/// upgraders until a target version is reached.
|
|
///
|
|
/// This handles major version upgrade only. For minor version upgrade,
|
|
/// e.g. adding a new field with default value, it's automatically handled by
|
|
/// the default constructor in generated_serialization_types.h.
|
|
///
|
|
/// @param artifact The JSON artifact to upgrade
|
|
/// @param target_version The target schema version to upgrade to
|
|
/// @return The upgraded JSON artifact with updated schema version
|
|
/// @throws std::runtime_error if artifact is missing schema_version field
|
|
/// @throws std::runtime_error if final version doesn't match target version
|
|
nlohmann::json upgrade(const nlohmann::json& artifact, int target_version);
|
|
|
|
} // namespace torch::_export
|