Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[nativert] Move serialization to PyTorch core (#155229)
Summary: Serialization contains the utilities to deserialize a graph saved on disk in JSON format, as defined in `torch/csrc/utils/generated_serialization_types.h`, into the in-memory representation defined in `torch/nativert/graph/Graph.h`.

Test Plan: buck2 run @mode/dev-nosan caffe2/test/cpp/nativert:serialization_test

Differential Revision: D76012641

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155229
Approved by: https://github.com/zhxchen17
Committed by: PyTorch MergeBot
Parent: 1e6a653234
Commit: eba5fc91ac
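Before the diff, a quick orientation: the new entry point is `jsonToGraph`, which takes an already-parsed `torch::_export::GraphModule`. A minimal sketch of a caller follows; the JSON-parsing helper is an assumption for illustration only, not part of this PR:

#include <torch/nativert/graph/Serialization.h>

#include <memory>
#include <string>

// Assumed helper (hypothetical, not in this PR): parse the serialized JSON
// artifact into the generated torch::_export schema types.
torch::_export::GraphModule loadGraphModuleFromJson(const std::string& path);

std::unique_ptr<torch::nativert::Graph> loadGraph(const std::string& path) {
  torch::_export::GraphModule gm = loadGraphModuleFromJson(path);
  // jsonToGraph walks the schema and builds the in-memory Graph;
  // loadNodeMetadata=true also copies per-node metadata strings.
  return torch::nativert::jsonToGraph(gm, /*loadNodeMetadata=*/true);
}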
build_variables.bzl
@@ -593,6 +593,7 @@ libtorch_core_jit_sources = sorted(jit_sources_full)
libtorch_nativert_sources = [
    "torch/nativert/graph/Graph.cpp",
    "torch/nativert/graph/GraphSignature.cpp",
    "torch/nativert/graph/Serialization.cpp",
    "torch/nativert/graph/TensorMeta.cpp",
    "torch/nativert/executor/Placement.cpp",
    "torch/nativert/executor/PlacementUtils.cpp",
test/cpp/nativert/CMakeLists.txt
@@ -8,6 +8,7 @@ set(NATIVERT_TEST_SRCS
  ${TORCH_ROOT}/torch/nativert/graph/TensorMeta.cpp
  ${TORCH_ROOT}/torch/nativert/graph/Graph.cpp
  ${TORCH_ROOT}/torch/nativert/graph/GraphSignature.cpp
  ${TORCH_ROOT}/torch/nativert/graph/Serialization.cpp
  ${TORCH_ROOT}/torch/nativert/executor/PlacementUtils.cpp
  ${TORCH_ROOT}/torch/nativert/executor/Weights.cpp
  ${TORCH_ROOT}/torch/nativert/common/FileUtil.cpp
@@ -562,4 +562,86 @@ return (%a)
    EXPECT_NE(ids.end(), ids.find(i));
  }
}

TEST(SerializationTest, RoundTrip) {
  static constexpr std::string_view source =
      R"(graph(%foo, %bar, %baz):
%o1 = aten.foo(self=%foo, target=%bar, alpha=0.1)
return(%o1, %baz)
)";
  const auto graph = stringToGraph(source);
  const auto serialized = graphToString(*graph);
  EXPECT_EQ(source, serialized);
}

TEST(SerializationTest, EscapedStringConstant) {
  const auto parsed =
      std::get<std::string>(convertAtomicConstant(R"("string_\"escape")"));
  std::string expected = "string_\\\"escape";
  EXPECT_EQ(parsed, expected);
}

TEST(SerializationTest, DeviceConstant) {
  const auto device =
      std::get<c10::Device>(convertAtomicConstant("Device{cuda:1}"));
  EXPECT_EQ(device.index(), 1);
  EXPECT_EQ(device.type(), c10::DeviceType::CUDA);
}

TEST(SerializationTest, TrueConstant) {
  const auto parsedTrue = std::get<bool>(convertAtomicConstant("true"));
  EXPECT_EQ(parsedTrue, true);
  const auto parsedFalse = std::get<bool>(convertAtomicConstant("false"));
  EXPECT_EQ(parsedFalse, false);
}

TEST(SerializationTest, MemoryFormatConstant) {
  const auto parsed = std::get<c10::MemoryFormat>(
      convertAtomicConstant("MemoryFormat::ContiguousFormat"));
  EXPECT_EQ(parsed, c10::MemoryFormat::Contiguous);
}

TEST(SerializationTest, FloatConstant) {
  const auto parsed = std::get<double>(convertAtomicConstant("5.0"));
  EXPECT_EQ(parsed, 5.0);
}

TEST(SerializationTest, IntConstant) {
  const auto parsed = std::get<int64_t>(convertAtomicConstant("5"));
  EXPECT_EQ(parsed, 5);
}

TEST(SerializationTest, FloatExponentConstant) {
  const auto parsed = std::get<double>(convertAtomicConstant("1e-05"));
  EXPECT_EQ(parsed, 0.00001);
}

TEST(SerializationTest, SingleElementListConstant) {
  const auto parsed =
      std::get<std::vector<int64_t>>(convertListConstant("[1]"));
  const auto expected = std::vector<int64_t>{1};
  EXPECT_EQ(parsed, expected);
}

TEST(SerializationTest, IntListConstant) {
  const auto parsed =
      std::get<std::vector<int64_t>>(convertListConstant("[1, 2, 3, 4]"));
  const auto expected = std::vector<int64_t>{1, 2, 3, 4};
  EXPECT_EQ(parsed, expected);
}

TEST(SerializationTest, FloatListConstant) {
  const auto parsed = std::get<std::vector<double>>(
      convertListConstant("[1.0, 2.0, 3.0, 4.0]"));
  const auto expected = std::vector<double>{1.0, 2.0, 3.0, 4.0};
  EXPECT_EQ(parsed, expected);
}

TEST(SerializationTest, BoolListConstant) {
  const auto parsed =
      std::get<std::vector<bool>>(convertListConstant("[false, true, false]"));
  const auto expected = std::vector<bool>{false, true, false};
  EXPECT_EQ(parsed, expected);
}

} // namespace torch::nativert
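A note on the `std::get` calls above: they throw `std::bad_variant_access` when the parsed constant holds a different alternative than expected, which is exactly what a test wants (a hard failure). A hedged, self-contained sketch of a more defensive consumer, using a stand-in variant since the real `Constant` type in torch/nativert carries more alternatives than shown here:

#include <cstdint>
#include <iostream>
#include <string>
#include <variant>

// Minimal stand-in for the constant variant exercised by the tests above.
using DemoConstant = std::variant<bool, int64_t, double, std::string>;

void describe(const DemoConstant& c) {
  // std::get_if returns nullptr instead of throwing on a tag mismatch.
  if (const auto* i = std::get_if<int64_t>(&c)) {
    std::cout << "int constant: " << *i << '\n';
  } else if (const auto* d = std::get_if<double>(&c)) {
    std::cout << "float constant: " << *d << '\n';
  } else if (const auto* b = std::get_if<bool>(&c)) {
    std::cout << "bool constant: " << std::boolalpha << *b << '\n';
  } else {
    std::cout << "string constant\n";
  }
}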
test/cpp/nativert/test_serialization.cpp (new file, 51 lines)
@@ -0,0 +1,51 @@
#include <gtest/gtest.h>
#include <torch/nativert/graph/Serialization.h>

namespace torch::nativert {

TEST(SerializationTest, CheckIsSymbolic) {
  torch::_export::TensorArgument tensor_arg;
  torch::_export::Argument as_tensor_arg;
  as_tensor_arg.set_as_tensor(tensor_arg);
  EXPECT_TRUE(isSymbolic(as_tensor_arg));

  std::vector<torch::_export::TensorArgument> tensor_args;
  torch::_export::Argument as_tensors_arg;
  as_tensors_arg.set_as_tensors(tensor_args);
  EXPECT_TRUE(isSymbolic(as_tensors_arg));

  torch::_export::SymIntArgument sym_int_arg;
  torch::_export::Argument as_sym_int_arg;
  as_sym_int_arg.set_as_sym_int(sym_int_arg);
  EXPECT_TRUE(isSymbolic(as_sym_int_arg));

  torch::_export::Argument as_int_arg;
  as_int_arg.set_as_int(static_cast<int64_t>(1));
  EXPECT_FALSE(isSymbolic(as_int_arg));

  torch::_export::Argument as_bool_arg;
  as_bool_arg.set_as_bool(true);
  EXPECT_FALSE(isSymbolic(as_bool_arg));

  torch::_export::Argument as_string_arg;
  as_string_arg.set_as_string("test_string");
  EXPECT_FALSE(isSymbolic(as_string_arg));
}

TEST(SerializationTest, ConstantToValue) {
  torch::_export::Argument as_int_arg;
  as_int_arg.set_as_int(static_cast<int64_t>(42));
  auto value = constantToValue(as_int_arg, false);
  EXPECT_EQ(value, Constant(static_cast<int64_t>(42)));

  torch::_export::Argument as_bool_arg;
  as_bool_arg.set_as_bool(true);
  value = constantToValue(as_bool_arg, false);
  EXPECT_EQ(value, Constant(true));

  torch::_export::Argument as_string_arg;
  as_string_arg.set_as_string("test_string");
  value = constantToValue(as_string_arg, false);
  EXPECT_EQ(value, Constant("test_string"));
}

} // namespace torch::nativert
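The generated `torch::_export::Argument` type behaves like a tagged union: each `set_as_*` setter fixes the active `tag()`, which `isSymbolic` then switches on. A minimal sketch of that pattern (the names here are illustrative, not the generated API):

#include <cstdint>
#include <string>
#include <variant>

// Illustrative tagged union mirroring the generated set_as_* / tag() pattern.
class DemoArgument {
 public:
  // Tag order must match the variant's alternative order below.
  enum class Tag { AS_INT, AS_BOOL, AS_STRING };
  void set_as_int(int64_t v) { value_ = v; }
  void set_as_bool(bool v) { value_ = v; }
  void set_as_string(std::string v) { value_ = std::move(v); }
  Tag tag() const { return static_cast<Tag>(value_.index()); }

 private:
  std::variant<int64_t, bool, std::string> value_;
};

This mirrors how, in the tests above, calling `set_as_int` makes `tag()` report `AS_INT` and thereby drives the `isSymbolic` switch.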
torch/nativert/graph/Serialization.cpp (new file, 550 lines)
@@ -0,0 +1,550 @@
#include <fmt/format.h>
#include <fmt/ostream.h>
#include <fmt/ranges.h>
#include <torch/nativert/graph/Serialization.h>
#include <limits>

namespace torch::nativert {

namespace {

std::unique_ptr<Graph> jsonToSubgraph(
    const torch::_export::Graph& jsonGraph,
    const torch::_export::GraphSignature* signature,
    bool loadNodeMetadata);

Value* symbolicToValue(
    const torch::_export::Argument& arg,
    Graph& graph,
    Node* insertBefore) {
  switch (arg.tag()) {
    case torch::_export::Argument::Tag::AS_TENSOR:
      return graph.getValue(arg.get_as_tensor().get_name());
    case torch::_export::Argument::Tag::AS_TENSORS: {
      // Need to insert a list pack node
      std::vector<Value*> listValue;
      for (const auto& listEl : arg.get_as_tensors()) {
        listValue.push_back(graph.getValue(listEl.get_name()));
      }
      auto listPack =
          graph.createListPack(std::move(listValue), Type::Kind::Tensor);
      return graph.insertBefore(listPack, insertBefore)->outputs()[0];
    }
    case torch::_export::Argument::Tag::AS_OPTIONAL_TENSORS: {
      // Need to insert a list pack node
      std::vector<Value*> listValue;
      for (const auto& listEl : arg.get_as_optional_tensors()) {
        switch (listEl.tag()) {
          case torch::_export::OptionalTensorArgument::Tag::AS_TENSOR: {
            listValue.push_back(
                graph.getValue(listEl.get_as_tensor().get_name()));
            break;
          }
          case torch::_export::OptionalTensorArgument::Tag::AS_NONE: {
            listValue.push_back(
                graph.addValue(std::nullopt, Type::Kind::None, nullptr));
            break;
          }
          default:
            TORCH_CHECK(
                false,
                fmt::format(
                    "Unknown OptionalTensorArgument type: {}",
                    torch::_export::printEnum(listEl.tag())));
        }
      }
      auto listPack = graph.createOptionalListPack(std::move(listValue));
      return graph.insertBefore(listPack, insertBefore)->outputs()[0];
    }
    case torch::_export::Argument::Tag::AS_SYM_INT: {
      return graph.getValue(arg.get_as_sym_int().get_as_name());
    }
    case torch::_export::Argument::Tag::AS_SYM_INTS: {
      // Need to insert a list pack node
      std::vector<Value*> listValue;
      for (const auto& listEl : arg.get_as_sym_ints()) {
        switch (listEl.tag()) {
          case torch::_export::SymIntArgument::Tag::AS_NAME: {
            listValue.push_back(graph.getValue(listEl.get_as_name()));
            break;
          }
          case torch::_export::SymIntArgument::Tag::AS_INT: {
            // These are concrete int values in the SymIntList, e.g. [s0, 8].
            // We convert them into constant Values in the graph; these
            // values don't have a producer node.
            int64_t value = listEl.get_as_int();
            TORCH_CHECK(
                value >= std::numeric_limits<int>::min() &&
                value <= std::numeric_limits<int>::max());
            Value* symintValue =
                graph.createConstantSymIntValue(static_cast<int>(value));
            listValue.push_back(symintValue);
            break;
          }
          default:
            TORCH_CHECK(
                false,
                fmt::format(
                    "Unknown SymIntArgument type: {}",
                    torch::_export::printEnum(listEl.tag())));
        }
      }
      auto listPack =
          graph.createListPack(std::move(listValue), Type::Kind::SymInt);
      return graph.insertBefore(listPack, insertBefore)->outputs()[0];
    }
    case torch::_export::Argument::Tag::AS_CUSTOM_OBJ: {
      return graph.getValue(arg.get_as_custom_obj().get_name());
    }
    case torch::_export::Argument::Tag::AS_SYM_BOOL: {
      return graph.getValue(arg.get_as_sym_bool().get_as_name());
    }
    case torch::_export::Argument::Tag::AS_SYM_FLOAT: {
      return graph.getValue(arg.get_as_sym_float().get_as_name());
    }
    default:
      TORCH_CHECK(
          false,
          fmt::format(
              "This function should only be called with symbolic arguments, got {} instead",
              torch::_export::printEnum(arg.tag())));
  }
}
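In effect, `symbolicToValue` desugars a tensor-list argument into an explicit pack node inserted just before the consumer, so the list becomes a first-class `Value`. In the human-readable debug format exercised by the tests, that would look roughly like the following (the pack node's spelling and the consumer op are illustrative, not taken from this diff):

graph(%a, %b):
%packed = prim.ListPack(l0=%a, l1=%b)
%out = aten.cat(tensors=%packed)
return(%out)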

std::pair<
    std::vector<torch::_export::InputSpec>,
    std::vector<torch::_export::Argument>>
enforceInputOrder(
    const std::vector<torch::_export::InputSpec>& inputSpecs,
    const std::vector<torch::_export::Argument>& graphInputs) {
  // Enforce the order of inputSpecs and graphInputs to be the following:
  // 1. token
  // 2. parameter
  // 3. persistent buffer, non-persistent buffer
  // 4. tensor_constant
  // 5. custom_obj
  // 6. user_input/constant_input
  std::vector<torch::_export::InputSpec> reorderedInputSpecs;
  std::vector<torch::_export::Argument> reorderedGraphInputs;
  std::vector<torch::_export::InputSpec::Tag> desiredOrder = {
      torch::_export::InputSpec::Tag::TOKEN,
      torch::_export::InputSpec::Tag::PARAMETER,
      torch::_export::InputSpec::Tag::BUFFER,
      torch::_export::InputSpec::Tag::TENSOR_CONSTANT,
      torch::_export::InputSpec::Tag::CUSTOM_OBJ};

  auto reorder = [&](auto condition) {
    for (size_t i = 0; i < inputSpecs.size(); ++i) {
      if (condition(inputSpecs[i])) {
        reorderedInputSpecs.push_back(inputSpecs[i]);
        reorderedGraphInputs.push_back(graphInputs[i]);
      }
    }
  };

  for (const auto& tag : desiredOrder) {
    if (tag == torch::_export::InputSpec::Tag::BUFFER) {
      // Add persistent buffers first, then non-persistent
      reorder([&](const auto& spec) {
        return spec.tag() == tag && spec.get_buffer().get_persistent();
      });
      reorder([&](const auto& spec) {
        return spec.tag() == tag && !spec.get_buffer().get_persistent();
      });
    } else {
      reorder([&](const auto& spec) { return spec.tag() == tag; });
    }
  }

  // Append USER_INPUT and CONSTANT_INPUT without reordering
  for (size_t i = 0; i < inputSpecs.size(); ++i) {
    auto tag = inputSpecs[i].tag();
    if (tag == torch::_export::InputSpec::Tag::USER_INPUT ||
        tag == torch::_export::InputSpec::Tag::CONSTANT_INPUT) {
      reorderedInputSpecs.push_back(inputSpecs[i]);
      reorderedGraphInputs.push_back(graphInputs[i]);
    }
  }
  return {std::move(reorderedInputSpecs), std::move(reorderedGraphInputs)};
}
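Because `reorder` makes one full pass over `inputSpecs` per tag, the result is a stable partition: relative order within each tag is preserved, at a cost of O(tags × inputs). A self-contained sketch of the same idea on plain data (nothing here is from the PR):

#include <iostream>
#include <utility>
#include <vector>

int main() {
  // Tags 0..2 stand in for TOKEN/PARAMETER/BUFFER; order within a tag is kept.
  std::vector<std::pair<int, char>> items = {
      {2, 'a'}, {0, 'b'}, {1, 'c'}, {2, 'd'}, {0, 'e'}};
  std::vector<std::pair<int, char>> out;
  for (int tag : {0, 1, 2}) {
    for (const auto& it : items) {
      if (it.first == tag) {
        out.push_back(it); // one pass per tag => stable within each tag
      }
    }
  }
  for (const auto& [t, c] : out) {
    std::cout << t << c << ' '; // prints: 0b 0e 1c 2a 2d
  }
  std::cout << '\n';
}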

std::unique_ptr<Graph> jsonToSubgraph(
    const torch::_export::Graph& jsonGraph,
    const torch::_export::GraphSignature* signature,
    bool loadNodeMetadata) {
  auto graphInputs = jsonGraph.get_inputs();
  auto graph = Graph::createGraph();

  if (signature) {
    // Enforce the order of the signature's input specs and the graph inputs
    const auto& inputSpecs = signature->get_input_specs();

    auto [reorderedInputSpecs, reorderedGraphInputs] =
        enforceInputOrder(inputSpecs, graphInputs);

    graphInputs = std::move(reorderedGraphInputs);
    auto reorderedSignature = *signature;
    reorderedSignature.set_input_specs(reorderedInputSpecs);
    graph->setSignature(torch::nativert::GraphSignature{reorderedSignature});
  }

  for (const auto& input : graphInputs) {
    if (isSymbolic(input)) {
      switch (input.tag()) {
        case torch::_export::Argument::Tag::AS_TENSOR: {
          const auto& asTensor = input.get_as_tensor();
          const auto& name = asTensor.get_name();
          graph->addInput(name, Type::Kind::Tensor);
          break;
        }
        case torch::_export::Argument::Tag::AS_CUSTOM_OBJ: {
          const auto& asCustomObj = input.get_as_custom_obj();
          const std::string& name = asCustomObj.get_name();
          const std::string& classFqn = asCustomObj.get_class_fqn();
          graph->addInput(name, Type(Type::Kind::CustomObj, classFqn));
          break;
        }
        default:
          TORCH_CHECK(
              false,
              fmt::format(
                  "Unsupported symbolic graph input type: {}",
                  torch::_export::printEnum(input.tag())));
      }
    } else {
      switch (input.tag()) {
        case torch::_export::Argument::Tag::AS_INT:
        case torch::_export::Argument::Tag::AS_FLOAT:
        case torch::_export::Argument::Tag::AS_STRING:
        case torch::_export::Argument::Tag::AS_BOOL:
        case torch::_export::Argument::Tag::AS_NONE: {
          // Constant graph inputs are specialized in the graph; here we simply
          // add a nullptr Value to the graph input node.
          graph->addInput();
          break;
        }
        default:
          TORCH_CHECK(
              false,
              fmt::format(
                  "Unsupported constant graph input type: {}",
                  torch::_export::printEnum(input.tag())));
      }
    }
  }

  for (const auto& jsonNode : jsonGraph.get_nodes()) {
    auto node = graph->insertNode(
        jsonNode.get_target(),
        {},
        loadNodeMetadata ? jsonNode.get_metadata()
                         : std::unordered_map<std::string, std::string>());

    std::vector<NamedArgument> args;
    std::vector<Attribute> attributes;
    for (const auto& input : jsonNode.get_inputs()) {
      // We handle constants and symbolic inputs differently.
      const auto& arg = input.get_arg();
      if (isSymbolic(arg)) {
        // Symbolic values are made part of the inputs to the node.
        node->addInput(NamedArgument{
            input.get_name(), symbolicToValue(input.get_arg(), *graph, node)});
      } else if (arg.tag() == torch::_export::Argument::Tag::AS_NONE) {
        node->addInput(NamedArgument{
            input.get_name(),
            graph->addValue(std::nullopt, Type::Kind::None, node)});
      } else {
        // Constant values are added as "attributes" to the node.
        node->addAttribute(Attribute{
            input.get_name(),
            constantToValue(input.get_arg(), loadNodeMetadata)});
      }
    }

    std::vector<Value*> outputs;
    std::vector<Value*> listUnpacksToCreate;
    for (const auto& output : jsonNode.get_outputs()) {
      switch (output.tag()) {
        case torch::_export::Argument::Tag::AS_NONE: {
          node->addOutput(Type::Kind::None);
          break;
        }
        case torch::_export::Argument::Tag::AS_TENSOR: {
          const auto name = output.get_as_tensor().get_name();
          node->addOutput(name, Type::Kind::Tensor);
          break;
        }
        case torch::_export::Argument::Tag::AS_TENSORS: {
          auto outputValue = node->addOutput(
              graph->getUniqueValueName(), Type::Kind::TensorList);

          Node* listUnpack =
              graph->insertNode("prim.ListUnpack", {{"input", outputValue}});
          for (const auto& arg : output.get_as_tensors()) {
            listUnpack->addOutput(arg.get_name(), Type::Kind::Tensor);
          }
          break;
        }
        case torch::_export::Argument::Tag::AS_SYM_INT: {
          const auto name = output.get_as_sym_int().get_as_name();
          node->addOutput(name, Type::Kind::SymInt);
          break;
        }
        case torch::_export::Argument::Tag::AS_SYM_INTS: {
          TORCH_CHECK(
              false,
              "SymInts NYI. We currently don't have ops that produce SymInts as output");
        }
        case torch::_export::Argument::Tag::AS_SYM_BOOL: {
          const auto name = output.get_as_sym_bool().get_as_name();
          node->addOutput(name, Type::Kind::SymBool);
          break;
        }
        case torch::_export::Argument::Tag::AS_SYM_BOOLS: {
          TORCH_CHECK(
              false,
              "SymBools NYI. We currently don't have ops that produce SymBools as output");
        }
        case torch::_export::Argument::Tag::AS_SYM_FLOAT: {
          const auto name = output.get_as_sym_float().get_as_name();
          node->addOutput(name, Type::Kind::SymFloat);
          break;
        }
        case torch::_export::Argument::Tag::AS_SYM_FLOATS: {
          TORCH_CHECK(
              false,
              "SymFloats NYI. We currently don't have ops that produce SymFloats as output");
        }
        default:
          TORCH_CHECK(
              false,
              fmt::format(
                  "Unsupported graph output type: {}",
                  torch::_export::printEnum(output.tag())));
      }
    }
  }

  for (const auto& output : jsonGraph.get_outputs()) {
    // Handle symbolic outputs and constant outputs differently.
    if (isSymbolic(output)) {
      switch (output.tag()) {
        case torch::_export::Argument::Tag::AS_TENSOR: {
          const auto& asTensor = output.get_as_tensor();
          const auto& name = asTensor.get_name();
          Value* outputValue = graph->getValue(name);
          graph->addOutput(outputValue);
          break;
        }
        case torch::_export::Argument::Tag::AS_SYM_INT: {
          const auto& asSymInt = output.get_as_sym_int();
          TORCH_CHECK(
              asSymInt.tag() == torch::_export::SymIntArgument::Tag::AS_NAME);
          const auto& name = asSymInt.get_as_name();
          Value* outputValue = graph->getValue(name);
          graph->addOutput(outputValue);
          break;
        }
        default:
          TORCH_CHECK(
              false,
              fmt::format(
                  "Unsupported graph output type: {}",
                  torch::_export::printEnum(output.tag())));
      }
    } else {
      Constant constValue = constantToValue(output, loadNodeMetadata);
      graph->addConstantOutput(std::move(constValue));
    }
  }

  auto jsonTensorValue = jsonGraph.get_tensor_values();

  if (!signature) {
    // For subgraphs we just need to derive a graph signature that only
    // contains user inputs and outputs, because we don't need to handle any
    // special semantics for them, e.g. mutation or gradients.
    torch::_export::GraphSignature sig;
    std::vector<torch::_export::InputSpec> inputSpecs;
    for (const auto& input : graph->inputs()) {
      torch::_export::Argument arg;
      if (input->type().kind() == Type::Kind::Tensor) {
        torch::_export::TensorArgument targ;
        targ.set_name(std::string{input->name()});
        arg.set_as_tensor(std::move(targ));
      } else {
        TORCH_CHECK(
            false,
            fmt::format(
                "Unsupported subgraph input type {}",
                fmt::streamed(input->type())));
      }
      torch::_export::UserInputSpec userInputSpec;
      userInputSpec.set_arg(std::move(arg));
      torch::_export::InputSpec inputSpec;
      inputSpec.set_user_input(std::move(userInputSpec));
      inputSpecs.push_back(std::move(inputSpec));
    }
    sig.set_input_specs(std::move(inputSpecs));

    std::vector<torch::_export::OutputSpec> outputSpecs;
    for (const auto& output : graph->outputs()) {
      torch::_export::Argument arg;
      if (output->type().kind() == Type::Kind::Tensor) {
        torch::_export::TensorArgument targ;
        targ.set_name(std::string{output->name()});
        arg.set_as_tensor(std::move(targ));
      } else {
        TORCH_CHECK(
            false,
            fmt::format(
                "Unsupported subgraph output type {}",
                fmt::streamed(output->type())));
      }
      torch::_export::UserOutputSpec userOutputSpec;
      userOutputSpec.set_arg(std::move(arg));
      torch::_export::OutputSpec outputSpec;
      outputSpec.set_user_output(std::move(userOutputSpec));
      outputSpecs.push_back(std::move(outputSpec));
    }
    sig.set_output_specs(std::move(outputSpecs));

    graph->setSignature(torch::nativert::GraphSignature{sig});
  }

  // weightsTensorMeta is indexed by the weight's name, not the graph input's name
  std::unordered_map<std::string, torch::_export::TensorMeta> weightsTensorMeta;
  for (const auto& [inputName, weightName] :
       graph->signature().inputsToWeights()) {
    auto value = graph->getValue(inputName);
    if (value->type().kind() == Type::Kind::CustomObj) {
      // Skip setting meta for non-tensor inputs
      continue;
    }

    auto it = jsonTensorValue.find(inputName);
    CHECK(it != jsonTensorValue.end())
        << "Missing tensor metadata for " << inputName
        << " in jsonGraph.tensor_values";
    weightsTensorMeta[weightName] = it->second;
  }
  graph->setWeightsMeta(weightsTensorMeta);

  graph->setTensorValuesMeta(jsonTensorValue);

  graph->finalize();

  graph->lint();
  return graph;
}

} // namespace

bool isSymbolic(const torch::_export::Argument& arg) {
  switch (arg.tag()) {
    case torch::_export::Argument::Tag::AS_TENSOR:
    case torch::_export::Argument::Tag::AS_TENSORS:
    case torch::_export::Argument::Tag::AS_OPTIONAL_TENSORS:
    case torch::_export::Argument::Tag::AS_SYM_INT:
    case torch::_export::Argument::Tag::AS_SYM_INTS:
    case torch::_export::Argument::Tag::AS_SYM_BOOL:
    case torch::_export::Argument::Tag::AS_SYM_BOOLS:
    case torch::_export::Argument::Tag::AS_SYM_FLOAT:
    case torch::_export::Argument::Tag::AS_SYM_FLOATS:
    case torch::_export::Argument::Tag::AS_CUSTOM_OBJ:
      return true;
    default:
      return false;
  }
}

Constant constantToValue(
    const torch::_export::Argument& jsonArg,
    bool loadNodeMetadata) {
  switch (jsonArg.tag()) {
    case torch::_export::Argument::Tag::AS_NONE:
      return torch::nativert::None();
    case torch::_export::Argument::Tag::AS_INT:
      return jsonArg.get_as_int();
    case torch::_export::Argument::Tag::AS_INTS: {
      std::vector<int64_t> ret;
      for (const auto& arg : jsonArg.get_as_ints()) {
        ret.push_back(arg);
      }
      return ret;
    }
    case torch::_export::Argument::Tag::AS_FLOAT:
      return jsonArg.get_as_float().get();
    case torch::_export::Argument::Tag::AS_FLOATS: {
      std::vector<double> ret;
      for (const auto& arg : jsonArg.get_as_floats()) {
        ret.push_back(arg.get());
      }
      return ret;
    }
    case torch::_export::Argument::Tag::AS_STRING:
      return jsonArg.get_as_string();
    case torch::_export::Argument::Tag::AS_STRINGS: {
      std::vector<std::string> ret;
      for (const auto& arg : jsonArg.get_as_strings()) {
        ret.push_back(arg);
      }
      return ret;
    }
    case torch::_export::Argument::Tag::AS_SCALAR_TYPE:
      return torch::nativert::convertJsonScalarType(
          jsonArg.get_as_scalar_type());
    case torch::_export::Argument::Tag::AS_MEMORY_FORMAT:
      return torch::nativert::convertJsonMemoryFormat(
          jsonArg.get_as_memory_format());
    case torch::_export::Argument::Tag::AS_LAYOUT:
      return torch::nativert::convertJsonLayout(jsonArg.get_as_layout());
    case torch::_export::Argument::Tag::AS_DEVICE:
      return torch::nativert::convertJsonDevice(jsonArg.get_as_device());
    case torch::_export::Argument::Tag::AS_BOOL:
      return jsonArg.get_as_bool();
    case torch::_export::Argument::Tag::AS_BOOLS: {
      std::vector<bool> ret;
      for (const auto& arg : jsonArg.get_as_bools()) {
        ret.push_back(arg);
      }
      return ret;
    }
    case torch::_export::Argument::Tag::AS_GRAPH: {
      return jsonToSubgraph(
          *jsonArg.get_as_graph().get_graph(), nullptr, loadNodeMetadata);
    }
    case torch::_export::Argument::Tag::AS_TENSOR:
    case torch::_export::Argument::Tag::AS_TENSORS:
    case torch::_export::Argument::Tag::AS_OPTIONAL_TENSORS:
      TORCH_CHECK(false, "Tensor values are symbolic, not constant.");
    case torch::_export::Argument::Tag::AS_SYM_INT:
    case torch::_export::Argument::Tag::AS_SYM_INTS:
    case torch::_export::Argument::Tag::AS_SYM_BOOL:
    case torch::_export::Argument::Tag::AS_SYM_BOOLS:
      TORCH_CHECK(false, "SymInt/SymBool values are symbolic, not constant.");
    case torch::_export::Argument::Tag::AS_CUSTOM_OBJ:
      TORCH_CHECK(false, "Custom objects are symbolic, not constant.");
    case torch::_export::Argument::Tag::AS_OPERATOR:
      return jsonArg.get_as_operator();
    case torch::_export::Argument::Tag::AS_SYM_FLOAT: {
      TORCH_CHECK(false, "SymFloat is not yet implemented");
    }
    case torch::_export::Argument::Tag::AS_SYM_FLOATS: {
      TORCH_CHECK(false, "SymFloats is not yet implemented");
    }
    default:
      TORCH_CHECK(false, "Got unknown json argument");
  }
}

std::unique_ptr<Graph> jsonToGraph(
    const torch::_export::GraphModule& jsonGraphModule,
    bool loadNodeMetadata) {
  auto graph = jsonToSubgraph(
      jsonGraphModule.get_graph(),
      &jsonGraphModule.get_signature(),
      loadNodeMetadata);
  return graph;
}

} // namespace torch::nativert
torch/nativert/graph/Serialization.h (new file, 27 lines)
@@ -0,0 +1,27 @@
#pragma once

#include <torch/nativert/graph/Graph.h>

#include <torch/csrc/utils/generated_serialization_types.h>

namespace torch::nativert {
/**
 * This file contains serialization utilities for Graph.
 *
 * There are two serialized representations we care about:
 * - JSON: stable, but hard to work with and not really human-readable.
 * - Debug format: human-readable, but not stable.
 */

// JSON -> Graph
std::unique_ptr<Graph> jsonToGraph(
    const torch::_export::GraphModule& jsonGraph,
    bool loadNodeMetadata = true);

bool isSymbolic(const torch::_export::Argument& arg);

Constant constantToValue(
    const torch::_export::Argument& jsonArg,
    bool loadNodeMetadata);

} // namespace torch::nativert
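Putting the two representations together: JSON is the stable interchange format, while the debug format exists for humans. A hedged sketch of a round-trip check in the spirit of the tests above; `stringToGraph` and `graphToString` appear in the tests but are declared elsewhere in the graph headers, so their availability through these includes is an assumption:

#include <torch/nativert/graph/Graph.h>
#include <torch/nativert/graph/Serialization.h>

#include <cassert>
#include <string_view>

void debugRoundTrip() {
  // Debug-format source, in the same shape as SerializationTest.RoundTrip.
  static constexpr std::string_view source =
      R"(graph(%foo, %bar, %baz):
%o1 = aten.foo(self=%foo, target=%bar, alpha=0.1)
return(%o1, %baz)
)";
  const auto graph = torch::nativert::stringToGraph(source);    // assumed declared in graph headers
  const auto printed = torch::nativert::graphToString(*graph);  // assumed declared in graph headers
  assert(printed == source); // the debug format round-trips exactly
}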