Summary:
Torch Native Runtime RFC: https://github.com/pytorch/rfcs/pull/72

Added an in-memory representation for the input and output specs of a graph. The GraphSignature class models the input and output specs of an exported graph produced by torch.export, holding the graph information deserialized from the pt2 archive package. The runtime relies on the GraphSignature for weight-name lookup and weight loading.

The serialization schema is defined in torch/_export/serde/schema.py. See more at: https://docs.pytorch.org/docs/stable/export.html#torch.export.ExportGraphSignature

Test Plan: Added tests under `test/cpp/nativert/test_graph_signature.cpp`

Differential Revision: D73895378

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152969
Approved by: https://github.com/swolchok
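As a rough illustration (not part of this commit), a runtime component could consume the signature's input-to-weight mapping along these lines. This is a minimal sketch assuming inputsToWeights() is a public accessor returning (graph input name, weight name) pairs, as it is used inside this file; collectWeightNames is a hypothetical helper, not an API defined here.

#include <string>
#include <vector>

#include <torch/nativert/graph/GraphSignature.h>

// Hypothetical sketch: gather the archive weight names referenced by a
// graph's signature, e.g. to drive weight loading from a pt2 package.
std::vector<std::string> collectWeightNames(
    const torch::nativert::GraphSignature& sig) {
  std::vector<std::string> weightNames;
  for (const auto& [inputName, weightName] : sig.inputsToWeights()) {
    (void)inputName;  // the graph-side value name; unused in this sketch
    weightNames.push_back(weightName);  // key for weight lookup/loading
  }
  return weightNames;
}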
#include <c10/util/Exception.h>
#include <c10/util/Logging.h>
#include <fmt/format.h>
#include <fmt/ranges.h>
#include <nlohmann/json.hpp>
#include <algorithm>
#include <array>
#include <iostream>

#include <torch/csrc/utils/generated_serialization_types.h>
#include <torch/nativert/graph/GraphSignature.h>

namespace torch::nativert {

namespace {

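// Returns true if the serialized output argument is a symbolic value
// (tensor(s), sym int/float/bool(s), or custom object) that carries a name
// in the graph, as opposed to an inlined constant.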
bool isSymbolicOutput(torch::_export::Argument::Tag t) {
  switch (t) {
    case torch::_export::Argument::Tag::AS_TENSOR:
    case torch::_export::Argument::Tag::AS_TENSORS:
    case torch::_export::Argument::Tag::AS_OPTIONAL_TENSORS:
    case torch::_export::Argument::Tag::AS_SYM_BOOL:
    case torch::_export::Argument::Tag::AS_SYM_BOOLS:
    case torch::_export::Argument::Tag::AS_SYM_INT:
    case torch::_export::Argument::Tag::AS_SYM_INTS:
    case torch::_export::Argument::Tag::AS_SYM_FLOAT:
    case torch::_export::Argument::Tag::AS_SYM_FLOATS:
    case torch::_export::Argument::Tag::AS_CUSTOM_OBJ:
      return true;
    default:
      return false;
  }
}

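// Returns the argument name and a human-readable tag string for an
// InputSpec; used to build the error message when inputs are out of order.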
std::pair<std::string, std::string> getSpecDetails(
    const torch::_export::InputSpec& inputSpec) {
  // Retrieve the argument name and spec tag name
  switch (inputSpec.tag()) {
    case torch::_export::InputSpec::Tag::PARAMETER:
      return std::make_pair(
          inputSpec.get_parameter().get_arg().get_name(), "PARAMETER");
    case torch::_export::InputSpec::Tag::BUFFER:
      return std::make_pair(
          inputSpec.get_buffer().get_arg().get_name(), "BUFFER");
    case torch::_export::InputSpec::Tag::TENSOR_CONSTANT:
      return std::make_pair(
          inputSpec.get_tensor_constant().get_arg().get_name(),
          "TENSOR_CONSTANT");
    case torch::_export::InputSpec::Tag::CUSTOM_OBJ:
      return std::make_pair(
          inputSpec.get_custom_obj().get_arg().get_name(), "CUSTOM_OBJ");
    case torch::_export::InputSpec::Tag::USER_INPUT:
      if (inputSpec.get_user_input().get_arg().tag() ==
          torch::_export::Argument::Tag::AS_TENSOR) {
        return std::make_pair(
            inputSpec.get_user_input().get_arg().get_as_tensor().get_name(),
            "USER_INPUT");
      } else if (
          inputSpec.get_user_input().get_arg().tag() ==
          torch::_export::Argument::Tag::AS_CUSTOM_OBJ) {
        return std::make_pair(
            inputSpec.get_user_input().get_arg().get_as_custom_obj().get_name(),
            "USER_INPUT");
      } else {
        TORCH_CHECK(false, "Unsupported USER_INPUT argument type.");
      }
      break;
    case torch::_export::InputSpec::Tag::CONSTANT_INPUT:
      return std::make_pair(
          inputSpec.get_constant_input().get_name(), "CONSTANT_INPUT");
    case torch::_export::InputSpec::Tag::TOKEN:
      TORCH_CHECK(false, "Token inputs not implemented yet.");
    default:
      TORCH_CHECK(false, "Unknown InputSpec tag encountered.");
  }
}

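// Enforces the expected ordering of non-user input specs: tokens, then
// parameters, then buffers (persistent before non-persistent), then tensor
// constants, then custom objects. USER_INPUT and CONSTANT_INPUT specs may
// appear anywhere and are skipped.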
void checkInputOrders(
    const std::vector<torch::_export::InputSpec>& inputSpecs) {
  // Map each tag to its index in the expected order
  static constexpr std::
      array<std::pair<torch::_export::InputSpec::Tag, uint32_t>, 5>
          tagOrderArray = {
              {{torch::_export::InputSpec::Tag::TOKEN, 0},
               {torch::_export::InputSpec::Tag::PARAMETER, 1},
               {torch::_export::InputSpec::Tag::BUFFER, 2},
               {torch::_export::InputSpec::Tag::TENSOR_CONSTANT, 3},
               {torch::_export::InputSpec::Tag::CUSTOM_OBJ, 4}}};
  uint32_t currentOrderIndex = 0;
  bool seenNonPersistentBuffer = false;
  for (const auto& inputSpec : inputSpecs) {
    if (inputSpec.tag() == torch::_export::InputSpec::Tag::USER_INPUT ||
        inputSpec.tag() == torch::_export::InputSpec::Tag::CONSTANT_INPUT) {
      continue;
    }
    auto it = std::find_if(
        tagOrderArray.begin(),
        tagOrderArray.end(),
        [&inputSpec](const auto& pair) {
          return pair.first == inputSpec.tag();
        });
    TORCH_CHECK(
        it != tagOrderArray.end(), "Unknown InputSpec tag encountered.");
    uint32_t tagIndex = it->second;
    if (tagIndex < currentOrderIndex) {
      auto [argName, tagName] = getSpecDetails(inputSpec);
      TORCH_CHECK(
          false,
          fmt::format(
              "Input arg {} with InputSpec {} is out of order!",
              argName,
              tagName));
    }
    currentOrderIndex = tagIndex;
    // Additional check for buffers
    if (inputSpec.tag() == torch::_export::InputSpec::Tag::BUFFER) {
      if (!inputSpec.get_buffer().get_persistent()) {
        seenNonPersistentBuffer = true;
      } else {
        TORCH_CHECK(
            !seenNonPersistentBuffer,
            "Persistent buffers must come before non-persistent buffers.");
      }
    }
  }
}

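// Requires the input names from the graph signature to match the graph's
// input value names exactly; fails with a listing of both sets otherwise.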
void checkInputNames(
    const c10::FastSet<std::string>& sigNames,
    const c10::FastSet<std::string>& graphNames) {
  if (sigNames == graphNames) {
    return;
  }

  std::string errorMsg = fmt::format(
      "Error: Value name difference detected between graph signature and graph nodes:\n"
      "Signature value names:\n[{}]\n"
      "Graph node names:\n[{}]",
      fmt::join(sigNames, ", "),
      fmt::join(graphNames, ", "));
  TORCH_CHECK(false, errorMsg);
}

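// Output names are optional in the signature (constant outputs are
// unnamed), so only the named outputs must appear among the graph's
// output value names.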
void checkOutputNames(
    const c10::FastSet<std::optional<std::string>>& sigNames,
    const c10::FastSet<std::string>& graphNames) {
  std::vector<std::string> validNames;
  for (const auto& nameOpt : sigNames) {
    if (nameOpt.has_value()) {
      validNames.push_back(*nameOpt);
    }
  }

  for (const auto& name : validNames) {
    if (graphNames.find(name) == graphNames.end()) {
      std::string errorMsg = fmt::format(
          "Error: Value name difference detected between graph signature and graph nodes:\n"
          "Signature value names:\n[{}]\n"
          "Graph node names:\n[{}]",
          fmt::join(validNames, ", "),
          fmt::join(graphNames, ", "));
      TORCH_CHECK(false, errorMsg);
    }
  }
}

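// Re-keys a single map entry from `old` to `replacement`, if present.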
void replaceInMap(
    c10::FastMap<std::string, std::string>& map,
    std::string_view old,
    std::string_view replacement) {
  auto it = map.find(std::string{old});
  if (it == map.end()) {
    return;
  }
  std::string value = std::move(it->second);
  map.erase(it);
  map.emplace(replacement, std::move(value));
}

} // namespace

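// Builds the in-memory signature from its serialized form: validates input
// ordering, then records user inputs, the input-to-weight mappings for
// parameters/buffers/tensor constants, custom object inputs, and the
// output-side mappings (user outputs, buffer and user-input mutations,
// gradients, loss).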
GraphSignature::GraphSignature(const torch::_export::GraphSignature& storage) {
  checkInputOrders(storage.get_input_specs());

  for (const torch::_export::InputSpec& inputSpec : storage.get_input_specs()) {
    switch (inputSpec.tag()) {
      case torch::_export::InputSpec::Tag::USER_INPUT: {
        const auto& userInputArg = inputSpec.get_user_input().get_arg();
        if (userInputArg.tag() == torch::_export::Argument::Tag::AS_TENSOR) {
          userInputs_.emplace_back(userInputArg.get_as_tensor().get_name());
        } else if (
            userInputArg.tag() ==
            torch::_export::Argument::Tag::AS_CUSTOM_OBJ) {
          userInputs_.emplace_back(userInputArg.get_as_custom_obj().get_name());
        } else {
          // TODO: handle other types
          TORCH_CHECK(false, "Non tensor inputs not implemented yet.");
        }
        break;
      }
      case torch::_export::InputSpec::Tag::PARAMETER: {
        numParameters_++;
        const auto& inputName = inputSpec.get_parameter().get_arg().get_name();
        const auto& weightName = inputSpec.get_parameter().get_parameter_name();
        inputsToWeights_.emplace_back(inputName, weightName);
        break;
      }
      case torch::_export::InputSpec::Tag::BUFFER: {
        const bool isPersistent = inputSpec.get_buffer().get_persistent();
        const auto& inputName = inputSpec.get_buffer().get_arg().get_name();
        const auto& weightName = inputSpec.get_buffer().get_buffer_name();
        if (isPersistent) {
          numPersistentBuffers_++;
        } else {
          numNonPersistentBuffers_++;
        }
        inputsToWeights_.emplace_back(inputName, weightName);
        break;
      }
      case torch::_export::InputSpec::Tag::TENSOR_CONSTANT: {
        numTensorConstants_++;
        const auto& inputName =
            inputSpec.get_tensor_constant().get_arg().get_name();
        const auto& weightName =
            inputSpec.get_tensor_constant().get_tensor_constant_name();
        inputsToWeights_.emplace_back(inputName, weightName);
        break;
      }
      case torch::_export::InputSpec::Tag::CUSTOM_OBJ: {
        numCustomObjs_++;
        const auto& inputName = inputSpec.get_custom_obj().get_arg().get_name();
        const auto& customObjName =
            inputSpec.get_custom_obj().get_custom_obj_name();
        inputsToCustomObjs_.emplace_back(inputName, customObjName);
        break;
      }
      case torch::_export::InputSpec::Tag::CONSTANT_INPUT: {
        break;
      }
      case torch::_export::InputSpec::Tag::TOKEN: {
        TORCH_CHECK(false, "Token inputs not implemented yet.");
      }
      default:
        TORCH_CHECK(false, "Unknown InputSpec tag encountered.");
        break;
    }
  }

  for (const torch::_export::OutputSpec& outputSpec :
       storage.get_output_specs()) {
    switch (outputSpec.tag()) {
      case torch::_export::OutputSpec::Tag::LOSS_OUTPUT:
        lossOutput_ = outputSpec.get_loss_output().get_arg().get_name();
        break;
      case torch::_export::OutputSpec::Tag::USER_OUTPUT:
        if (isSymbolicOutput(outputSpec.get_user_output().get_arg().tag())) {
          switch (outputSpec.get_user_output().get_arg().tag()) {
            case torch::_export::Argument::Tag::AS_TENSOR: {
              userOutputs_.emplace_back(outputSpec.get_user_output()
                                            .get_arg()
                                            .get_as_tensor()
                                            .get_name());
              break;
            }
            case torch::_export::Argument::Tag::AS_SYM_INT: {
              userOutputs_.emplace_back(outputSpec.get_user_output()
                                            .get_arg()
                                            .get_as_sym_int()
                                            .get_as_name());
              break;
            }
            default: {
              TORCH_CHECK(
                  false, "Unsupported symbolic user output type encountered.");
            }
          }
        } else {
          // for constant outputs, we don't have a name
          userOutputs_.emplace_back(std::nullopt);
        }
        break;
      case torch::_export::OutputSpec::Tag::BUFFER_MUTATION:
        buffersToMutate_.emplace(
            outputSpec.get_buffer_mutation().get_arg().get_name(),
            outputSpec.get_buffer_mutation().get_buffer_name());
        break;
      case torch::_export::OutputSpec::Tag::GRADIENT_TO_PARAMETER:
        gradientsToParameters_.emplace(
            outputSpec.get_gradient_to_parameter().get_arg().get_name(),
            outputSpec.get_gradient_to_parameter().get_parameter_name());
        break;
      case torch::_export::OutputSpec::Tag::GRADIENT_TO_USER_INPUT:
        gradientsToUserInputs_.emplace(
            outputSpec.get_gradient_to_user_input().get_arg().get_name(),
            outputSpec.get_gradient_to_user_input().get_user_input_name());
        break;
      case torch::_export::OutputSpec::Tag::USER_INPUT_MUTATION:
        userInputsToMutate_.emplace(
            outputSpec.get_user_input_mutation().get_arg().get_name(),
            outputSpec.get_user_input_mutation().get_user_input_name());
        break;
      case torch::_export::OutputSpec::Tag::TOKEN: {
        TORCH_CHECK(false, "Token outputs not implemented yet.");
      }
      default:
        TORCH_CHECK(false, "Unknown OutputSpec tag encountered.");
    }
  }

  if (FLAGS_caffe2_log_level > 2) {
    std::cout << *this << "\n";
  }
}

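// All input names declared by the signature: user inputs, weight-bound
// inputs (parameters, buffers, tensor constants), and custom object inputs.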
c10::FastSet<std::string> GraphSignature::inputNames() const {
  c10::FastSet<std::string> ret;
  size_t numInputs = userInputs().size() + inputsToWeights().size() +
      inputsToCustomObjs().size();
  ret.reserve(numInputs);
  for (const auto& name : userInputs()) {
    ret.insert(name);
  }
  for (const auto& [inputName, _] : inputsToWeights()) {
    ret.insert(inputName);
  }
  for (const auto& [inputName, _] : inputsToCustomObjs()) {
    ret.insert(inputName);
  }
  return ret;
}

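// All output names declared by the signature; entries are std::nullopt for
// constant user outputs. Gradients and the loss output are included only
// when the signature has a backward section.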
c10::FastSet<std::optional<std::string>> GraphSignature::outputNames() const {
  c10::FastSet<std::optional<std::string>> ret;
  size_t numOutputs = userOutputs().size() + buffersToMutate().size() +
      userInputsToMutate().size() +
      (hasBackward() ? gradientsToParameters().size() +
               gradientsToUserInputs().size() + (lossOutput().empty() ? 0 : 1)
                     : 0);
  ret.reserve(numOutputs);
  for (const auto& name : userOutputs()) {
    ret.insert(name);
  }
  for (const auto& [outputName, _] : buffersToMutate()) {
    ret.insert(outputName);
  }
  for (const auto& [outputName, _] : userInputsToMutate()) {
    ret.insert(outputName);
  }
  if (hasBackward()) {
    if (!gradientsToParameters().empty()) {
      for (const auto& [outputName, _] : gradientsToParameters()) {
        ret.insert(outputName);
      }
    }
    if (!gradientsToUserInputs().empty()) {
      for (const auto& [outputName, _] : gradientsToUserInputs()) {
        ret.insert(outputName);
      }
    }
    if (!lossOutput().empty()) {
      ret.insert(lossOutput());
    }
  }
  return ret;
}

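// Cross-checks the signature against the actual graph's input and output
// value names.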
void GraphSignature::lint(
    const c10::FastSet<std::string>& graphInputs,
    const c10::FastSet<std::string>& graphOutputs) const {
  checkInputNames(inputNames(), graphInputs);
  checkOutputNames(outputNames(), graphOutputs);
}

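// Renames a value everywhere it appears on the output side of the
// signature (user outputs, buffer mutations, gradients, loss output);
// input-side names are left untouched.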
void GraphSignature::replaceAllUses(
    std::string_view old,
    std::string_view replacement) {
  if (old == replacement) {
    return;
  }
  for (auto& name : userOutputs_) {
    if (name == old) {
      name = replacement;
    }
  }
  replaceInMap(buffersToMutate_, old, replacement);
  if (hasBackward()) {
    replaceInMap(gradientsToParameters_, old, replacement);
    replaceInMap(gradientsToUserInputs_, old, replacement);
    if (old == lossOutput_) {
      lossOutput_ = replacement;
    }
  }
}

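// Debug printer: emits each non-empty section of the signature as
// "name : target" pairs. Used by the constructor when verbose logging is
// enabled.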
std::ostream& operator<<(std::ostream& out, const GraphSignature& sig) {
  if (!sig.inputsToParameters().empty()) {
    out << "inputsToParameters: {\n";
    for (const auto& [inputName, paramName] : sig.inputsToParameters()) {
      out << "\t" << inputName << " : " << paramName << "\n";
    }
    out << "}\n";
  }
  if (!sig.inputsToBuffers().empty()) {
    out << "inputsToBuffers: {\n";
    for (const auto& [inputName, bufferName] : sig.inputsToBuffers()) {
      out << "\t" << inputName << " : " << bufferName << "\n";
    }
    out << "}\n";
  }
  if (!sig.inputsToTensorConstants().empty()) {
    out << "inputsToTensorConstants: {\n";
    for (const auto& [inputName, tensorConstantName] :
         sig.inputsToTensorConstants()) {
      out << "\t" << inputName << " : " << tensorConstantName << "\n";
    }
    out << "}\n";
  }
  if (!sig.inputsToCustomObjs().empty()) {
    out << "inputsToCustomObjs: {\n";
    for (const auto& [inputName, customObjName] : sig.inputsToCustomObjs()) {
      out << "\t" << inputName << " : " << customObjName << "\n";
    }
    out << "}\n";
  }
  if (!sig.userOutputs().empty()) {
    out << "userOutputs: {\n";
    for (const auto& outputName : sig.userOutputs()) {
      out << "\t" << outputName.value_or("Constant") << "\n";
    }
    out << "}\n";
  }
  if (!sig.buffersToMutate().empty()) {
    out << "buffersToMutate: {\n";
    for (const auto& [outputName, mutatedBufferName] : sig.buffersToMutate()) {
      out << "\t" << outputName << " : " << mutatedBufferName << "\n";
    }
    out << "}\n";
  }
  if (!sig.userInputsToMutate().empty()) {
    out << "userInputsToMutate: {\n";
    for (const auto& [outputName, mutatedUserInputName] :
         sig.userInputsToMutate()) {
      out << "\t" << outputName << " : " << mutatedUserInputName << "\n";
    }
    out << "}\n";
  }
  if (sig.hasBackward()) {
    if (!sig.gradientsToParameters().empty()) {
      out << "gradientsToParameters: {\n";
      for (const auto& [outputName, paramName] : sig.gradientsToParameters()) {
        out << "\t" << outputName << " : " << paramName << "\n";
      }
      out << "}\n";
    }
    if (!sig.gradientsToUserInputs().empty()) {
      out << "gradientsToUserInputs: {\n";
      for (const auto& [outputName, userInputName] :
           sig.gradientsToUserInputs()) {
        out << "\t" << outputName << " : " << userInputName << "\n";
      }
      out << "}\n";
    }
    out << "lossOutput: " << sig.lossOutput() << "\n";
  }
  return out;
}

} // namespace torch::nativert