[nativert] Move serialization to PyTorch core (#155229)

Summary:
Serialization contains utilities to deserialize a graph saved on disk in json format as defined in `torch/csrc/utils/generated_serialization_types.h` to the in-memory representation as defined in `torch/nativert/graph/Graph.h`

Test Plan:
buck2 run @mode/dev-nosan caffe2/test/cpp/nativert:serialization_test

Rollback Plan:

Differential Revision: D76012641

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155229
Approved by: https://github.com/zhxchen17
This commit is contained in:
Yiming Zhou
2025-06-09 21:12:30 +00:00
committed by PyTorch MergeBot
parent 1e6a653234
commit eba5fc91ac
6 changed files with 712 additions and 0 deletions

View File

@ -593,6 +593,7 @@ libtorch_core_jit_sources = sorted(jit_sources_full)
libtorch_nativert_sources = [
"torch/nativert/graph/Graph.cpp",
"torch/nativert/graph/GraphSignature.cpp",
"torch/nativert/graph/Serialization.cpp",
"torch/nativert/graph/TensorMeta.cpp",
"torch/nativert/executor/Placement.cpp",
"torch/nativert/executor/PlacementUtils.cpp",

View File

@ -8,6 +8,7 @@ set(NATIVERT_TEST_SRCS
${TORCH_ROOT}/torch/nativert/graph/TensorMeta.cpp
${TORCH_ROOT}/torch/nativert/graph/Graph.cpp
${TORCH_ROOT}/torch/nativert/graph/GraphSignature.cpp
${TORCH_ROOT}/torch/nativert/graph/Serialization.cpp
${TORCH_ROOT}/torch/nativert/executor/PlacementUtils.cpp
${TORCH_ROOT}/torch/nativert/executor/Weights.cpp
${TORCH_ROOT}/torch/nativert/common/FileUtil.cpp

View File

@ -562,4 +562,86 @@ return (%a)
EXPECT_NE(ids.end(), ids.find(i));
}
}
TEST(SerializationTest, RoundTrip) {
static constexpr std::string_view source =
R"(graph(%foo, %bar, %baz):
%o1 = aten.foo(self=%foo, target=%bar, alpha=0.1)
return(%o1, %baz)
)";
const auto graph = stringToGraph(source);
const auto serialized = graphToString(*graph);
EXPECT_EQ(source, serialized);
}
TEST(SerializationTest, EscapedStringConstant) {
const auto parsed =
std::get<std::string>(convertAtomicConstant(R"("string_\"escape")"));
std::string expected = "string_\\\"escape";
EXPECT_EQ(parsed, expected);
}
TEST(SerializationTest, DeviceConstant) {
const auto device =
std::get<c10::Device>(convertAtomicConstant("Device{cuda:1}"));
EXPECT_EQ(device.index(), 1);
EXPECT_EQ(device.type(), c10::DeviceType::CUDA);
}
TEST(SerializationTest, TrueConstant) {
const auto parsedTrue = std::get<bool>(convertAtomicConstant("true"));
EXPECT_EQ(parsedTrue, true);
const auto parsedFalse = std::get<bool>(convertAtomicConstant("false"));
EXPECT_EQ(parsedFalse, false);
}
TEST(SerializationTest, MemoryFormatConstant) {
const auto parsed = std::get<c10::MemoryFormat>(
convertAtomicConstant("MemoryFormat::ContiguousFormat"));
EXPECT_EQ(parsed, c10::MemoryFormat::Contiguous);
}
TEST(SerializationTest, FloatConstant) {
const auto parsed = std::get<double>(convertAtomicConstant("5.0"));
EXPECT_EQ(parsed, 5.0);
}
TEST(SerializationTest, IntConstant) {
const auto parsed = std::get<int64_t>(convertAtomicConstant("5"));
EXPECT_EQ(parsed, 5);
}
TEST(SerializationTest, FloatExponentConstant) {
const auto parsed = std::get<double>(convertAtomicConstant("1e-05"));
EXPECT_EQ(parsed, 0.00001);
}
TEST(SerializationTest, SingleElementListConstant) {
const auto parsed =
std::get<std::vector<int64_t>>(convertListConstant("[1]"));
const auto expected = std::vector<int64_t>{1};
EXPECT_EQ(parsed, expected);
}
TEST(SerializationTest, IntListConstant) {
const auto parsed =
std::get<std::vector<int64_t>>(convertListConstant("[1, 2, 3, 4]"));
const auto expected = std::vector<int64_t>{1, 2, 3, 4};
EXPECT_EQ(parsed, expected);
}
TEST(SerializationTest, FloatListConstant) {
const auto parsed = std::get<std::vector<double>>(
convertListConstant("[1.0, 2.0, 3.0, 4.0]"));
const auto expected = std::vector<double>{1.0, 2.0, 3.0, 4.0};
EXPECT_EQ(parsed, expected);
}
TEST(SerializationTest, BoolListConstant) {
const auto parsed =
std::get<std::vector<bool>>(convertListConstant("[false, true, false]"));
const auto expected = std::vector<bool>{false, true, false};
EXPECT_EQ(parsed, expected);
}
} // namespace torch::nativert

View File

@ -0,0 +1,51 @@
#include <gtest/gtest.h>
#include <torch/nativert/graph/Serialization.h>
namespace torch::nativert {
TEST(SerializationTest, CheckIsSymbolic) {
torch::_export::TensorArgument tensor_arg;
torch::_export::Argument as_tensor_arg;
as_tensor_arg.set_as_tensor(tensor_arg);
EXPECT_TRUE(isSymbolic(as_tensor_arg));
std::vector<torch::_export::TensorArgument> tensor_args;
torch::_export::Argument as_tensors_arg;
as_tensors_arg.set_as_tensors(tensor_args);
EXPECT_TRUE(isSymbolic(as_tensors_arg));
torch::_export::SymIntArgument sym_int_arg;
torch::_export::Argument as_sym_int_arg;
as_sym_int_arg.set_as_sym_int(sym_int_arg);
EXPECT_TRUE(isSymbolic(as_sym_int_arg));
torch::_export::Argument as_int_arg;
as_int_arg.set_as_int(static_cast<int64_t>(1));
EXPECT_FALSE(isSymbolic(as_int_arg));
torch::_export::Argument as_bool_arg;
as_bool_arg.set_as_bool(true);
EXPECT_FALSE(isSymbolic(as_bool_arg));
torch::_export::Argument as_string_arg;
as_string_arg.set_as_string("test_string");
EXPECT_FALSE(isSymbolic(as_string_arg));
}
TEST(SerializationTest, ConstantToValue) {
torch::_export::Argument as_int_arg;
as_int_arg.set_as_int(static_cast<int64_t>(42));
auto value = constantToValue(as_int_arg, false);
EXPECT_EQ(value, Constant(static_cast<int64_t>(42)));
torch::_export::Argument as_bool_arg;
as_bool_arg.set_as_bool(true);
value = constantToValue(as_bool_arg, false);
EXPECT_EQ(value, Constant(true));
torch::_export::Argument as_string_arg;
as_string_arg.set_as_string("test_string");
value = constantToValue(as_string_arg, false);
EXPECT_EQ(value, Constant("test_string"));
}
} // namespace torch::nativert

View File

@ -0,0 +1,550 @@
#include <fmt/format.h>
#include <fmt/ostream.h>
#include <fmt/ranges.h>
#include <torch/nativert/graph/Serialization.h>
#include <limits>
namespace torch::nativert {
namespace {
std::unique_ptr<Graph> jsonToSubgraph(
const torch::_export::Graph& jsonGraph,
const torch::_export::GraphSignature* signature,
bool loadNodeMetadata);
Value* symbolicToValue(
const torch::_export::Argument& arg,
Graph& graph,
Node* insertBefore) {
switch (arg.tag()) {
case torch::_export::Argument::Tag::AS_TENSOR:
return graph.getValue(arg.get_as_tensor().get_name());
case torch::_export::Argument::Tag::AS_TENSORS: {
// Need to insert a list pack node
std::vector<Value*> listValue;
for (const auto& listEl : arg.get_as_tensors()) {
listValue.push_back(graph.getValue(listEl.get_name()));
}
auto listPack =
graph.createListPack(std::move(listValue), Type::Kind::Tensor);
return graph.insertBefore(listPack, insertBefore)->outputs()[0];
}
case torch::_export::Argument::Tag::AS_OPTIONAL_TENSORS: {
// Need to insert a list pack node
std::vector<Value*> listValue;
for (const auto& listEl : arg.get_as_optional_tensors()) {
switch (listEl.tag()) {
case torch::_export::OptionalTensorArgument::Tag::AS_TENSOR: {
listValue.push_back(
graph.getValue(listEl.get_as_tensor().get_name()));
break;
}
case torch::_export::OptionalTensorArgument::Tag::AS_NONE: {
listValue.push_back(
graph.addValue(std::nullopt, Type::Kind::None, nullptr));
break;
}
default:
TORCH_CHECK(
false,
fmt::format(
"Unknown OptionalTensorArgument type: {}",
torch::_export::printEnum(listEl.tag())));
}
}
auto listPack = graph.createOptionalListPack(std::move(listValue));
return graph.insertBefore(listPack, insertBefore)->outputs()[0];
}
case torch::_export::Argument::Tag::AS_SYM_INT: {
return graph.getValue(arg.get_as_sym_int().get_as_name());
}
case torch::_export::Argument::Tag::AS_SYM_INTS: {
// Need to insert a list pack node
std::vector<Value*> listValue;
for (const auto& listEl : arg.get_as_sym_ints()) {
switch (listEl.tag()) {
case torch::_export::SymIntArgument::Tag::AS_NAME: {
listValue.push_back(graph.getValue(listEl.get_as_name()));
break;
}
case torch::_export::SymIntArgument::Tag::AS_INT: {
// These are concrete int values in the SymIntList, e.g [s0, 8]
// We convert them into a constant Value in graph. These value
// doesn't have producer node
int64_t value = listEl.get_as_int();
TORCH_CHECK(
value >= std::numeric_limits<int>::min() &&
value <= std::numeric_limits<int>::max());
Value* symintValue =
graph.createConstantSymIntValue(static_cast<int>(value));
listValue.push_back(symintValue);
break;
}
default:
TORCH_CHECK(
false,
fmt::format(
"Unknown SymIntArgument type: {}",
torch::_export::printEnum(listEl.tag())));
}
}
auto listPack =
graph.createListPack(std::move(listValue), Type::Kind::SymInt);
return graph.insertBefore(listPack, insertBefore)->outputs()[0];
}
case torch::_export::Argument::Tag::AS_CUSTOM_OBJ: {
return graph.getValue(arg.get_as_custom_obj().get_name());
}
case torch::_export::Argument::Tag::AS_SYM_BOOL: {
return graph.getValue(arg.get_as_sym_bool().get_as_name());
}
case torch::_export::Argument::Tag::AS_SYM_FLOAT: {
return graph.getValue(arg.get_as_sym_float().get_as_name());
}
default:
TORCH_CHECK(
false,
fmt::format(
"This function should only be called with symbolic arguments, got {} instead",
torch::_export::printEnum(arg.tag())));
}
}
std::pair<
std::vector<torch::_export::InputSpec>,
std::vector<torch::_export::Argument>>
enforceInputOrder(
const std::vector<torch::_export::InputSpec>& inputSpecs,
const std::vector<torch::_export::Argument>& graphInputs) {
// Enforce the order of inputSpecs and graphInputs to be the following:
// 1. token
// 2. parameter
// 3. persistent buffer, non-persistent buffer
// 4. tensor_constant
// 5. custom_obj
// 6. user_input/constant_input
std::vector<torch::_export::InputSpec> reorderedInputSpecs;
std::vector<torch::_export::Argument> reorderedGraphInputs;
std::vector<torch::_export::InputSpec::Tag> desiredOrder = {
torch::_export::InputSpec::Tag::TOKEN,
torch::_export::InputSpec::Tag::PARAMETER,
torch::_export::InputSpec::Tag::BUFFER,
torch::_export::InputSpec::Tag::TENSOR_CONSTANT,
torch::_export::InputSpec::Tag::CUSTOM_OBJ};
auto reorder = [&](auto condition) {
for (size_t i = 0; i < inputSpecs.size(); ++i) {
if (condition(inputSpecs[i])) {
reorderedInputSpecs.push_back(inputSpecs[i]);
reorderedGraphInputs.push_back(graphInputs[i]);
}
}
};
for (const auto& tag : desiredOrder) {
if (tag == torch::_export::InputSpec::Tag::BUFFER) {
// Add persistent buffers first, then non-persistent
reorder([&](const auto& spec) {
return spec.tag() == tag && spec.get_buffer().get_persistent();
});
reorder([&](const auto& spec) {
return spec.tag() == tag && !spec.get_buffer().get_persistent();
});
} else {
reorder([&](const auto& spec) { return spec.tag() == tag; });
}
}
// Append USER_INPUT and CONSTANT_INPUT without reordering
for (size_t i = 0; i < inputSpecs.size(); ++i) {
auto tag = inputSpecs[i].tag();
if (tag == torch::_export::InputSpec::Tag::USER_INPUT ||
tag == torch::_export::InputSpec::Tag::CONSTANT_INPUT) {
reorderedInputSpecs.push_back(inputSpecs[i]);
reorderedGraphInputs.push_back(graphInputs[i]);
}
}
return {std::move(reorderedInputSpecs), std::move(reorderedGraphInputs)};
}
std::unique_ptr<Graph> jsonToSubgraph(
const torch::_export::Graph& jsonGraph,
const torch::_export::GraphSignature* signature,
bool loadNodeMetadata) {
auto graphInputs = jsonGraph.get_inputs();
auto graph = Graph::createGraph();
if (signature) {
// enforcing the order signature inputspecs and graph inputs
const auto& inputSpecs = signature->get_input_specs();
auto [reorderedInputSpecs, reorderedGraphInputs] =
enforceInputOrder(inputSpecs, graphInputs);
graphInputs = std::move(reorderedGraphInputs);
auto reorderedSignature = *signature;
reorderedSignature.set_input_specs(reorderedInputSpecs);
graph->setSignature(torch::nativert::GraphSignature{reorderedSignature});
}
for (const auto& input : graphInputs) {
if (isSymbolic(input)) {
switch (input.tag()) {
case torch::_export::Argument::Tag::AS_TENSOR: {
const auto& asTensor = input.get_as_tensor();
const auto& name = asTensor.get_name();
graph->addInput(name, Type::Kind::Tensor);
break;
}
case torch::_export::Argument::Tag::AS_CUSTOM_OBJ: {
const auto& asCustomObj = input.get_as_custom_obj();
const std::string& name = asCustomObj.get_name();
const std::string& classFqn = asCustomObj.get_class_fqn();
graph->addInput(name, Type(Type::Kind::CustomObj, classFqn));
break;
}
default:
TORCH_CHECK(
false,
fmt::format(
"Unsupported symbolic graph input type: {}",
torch::_export::printEnum(input.tag())));
}
} else {
switch (input.tag()) {
case torch::_export::Argument::Tag::AS_INT:
case torch::_export::Argument::Tag::AS_FLOAT:
case torch::_export::Argument::Tag::AS_STRING:
case torch::_export::Argument::Tag::AS_BOOL:
case torch::_export::Argument::Tag::AS_NONE: {
// Constant graph inputs are specialized in the graph, here we simply
// add a nullptr of Value to the graph input node.
graph->addInput();
break;
}
default:
TORCH_CHECK(
false,
fmt::format(
"Unsupported constant graph input type: {}",
torch::_export::printEnum(input.tag())));
}
}
}
for (const auto& jsonNode : jsonGraph.get_nodes()) {
auto node = graph->insertNode(
jsonNode.get_target(),
{},
loadNodeMetadata ? jsonNode.get_metadata()
: std::unordered_map<std::string, std::string>());
std::vector<NamedArgument> args;
std::vector<Attribute> attributes;
for (const auto& input : jsonNode.get_inputs()) {
// We handle constants and symbolic inputs differently.
const auto& arg = input.get_arg();
if (isSymbolic(arg)) {
// Symbolic values are made part of the inputs to the node
node->addInput(NamedArgument{
input.get_name(), symbolicToValue(input.get_arg(), *graph, node)});
} else if (arg.tag() == torch::_export::Argument::Tag::AS_NONE) {
node->addInput(NamedArgument{
input.get_name(),
graph->addValue(std::nullopt, Type::Kind::None, node)});
} else {
node->addAttribute(Attribute{
input.get_name(),
constantToValue(input.get_arg(), loadNodeMetadata)});
// Constant values are added as "attributes" to the node.
}
}
std::vector<Value*> outputs;
std::vector<Value*> listUnpacksToCreate;
for (const auto& output : jsonNode.get_outputs()) {
switch (output.tag()) {
case torch::_export::Argument::Tag::AS_NONE: {
node->addOutput(Type::Kind::None);
break;
}
case torch::_export::Argument::Tag::AS_TENSOR: {
const auto name = output.get_as_tensor().get_name();
node->addOutput(name, Type::Kind::Tensor);
break;
}
case torch::_export::Argument::Tag::AS_TENSORS: {
auto outputValue = node->addOutput(
graph->getUniqueValueName(), Type::Kind::TensorList);
Node* listUnpack =
graph->insertNode("prim.ListUnpack", {{"input", outputValue}});
for (const auto& arg : output.get_as_tensors()) {
listUnpack->addOutput(arg.get_name(), Type::Kind::Tensor);
}
break;
}
case torch::_export::Argument::Tag::AS_SYM_INT: {
const auto name = output.get_as_sym_int().get_as_name();
node->addOutput(name, Type::Kind::SymInt);
break;
}
case torch::_export::Argument::Tag::AS_SYM_INTS: {
TORCH_CHECK(
false,
"SymInts NYI. We currently don't have ops that produce SymInts as output");
}
case torch::_export::Argument::Tag::AS_SYM_BOOL: {
const auto name = output.get_as_sym_bool().get_as_name();
node->addOutput(name, Type::Kind::SymBool);
break;
}
case torch::_export::Argument::Tag::AS_SYM_BOOLS: {
TORCH_CHECK(
false,
"SymBools NYI. We currently don't have ops that produce SymBools as output");
}
case torch::_export::Argument::Tag::AS_SYM_FLOAT: {
const auto name = output.get_as_sym_float().get_as_name();
node->addOutput(name, Type::Kind::SymFloat);
break;
}
case torch::_export::Argument::Tag::AS_SYM_FLOATS: {
TORCH_CHECK(
false,
"SymFloats NYI. We currently doesn't have op that produces SymFloats as output");
}
default:
TORCH_CHECK(
false,
fmt::format(
"Unsupported graph output type: {}",
torch::_export::printEnum(output.tag())));
}
}
}
for (const auto& output : jsonGraph.get_outputs()) {
// handle symbolic outputs and constant outputs differently
if (isSymbolic(output)) {
switch (output.tag()) {
case torch::_export::Argument::Tag::AS_TENSOR: {
const auto& asTensor = output.get_as_tensor();
const auto& name = asTensor.get_name();
Value* outputValue = graph->getValue(name);
graph->addOutput(outputValue);
break;
}
case torch::_export::Argument::Tag::AS_SYM_INT: {
const auto& asSymInt = output.get_as_sym_int();
TORCH_CHECK(
asSymInt.tag() == torch::_export::SymIntArgument::Tag::AS_NAME);
const auto& name = asSymInt.get_as_name();
Value* outputValue = graph->getValue(name);
graph->addOutput(outputValue);
break;
}
default:
TORCH_CHECK(
false,
fmt::format(
"Unsupported graph output type: {}",
torch::_export::printEnum(output.tag())));
}
} else {
Constant constValue = constantToValue(output, loadNodeMetadata);
graph->addConstantOutput(std::move(constValue));
}
}
auto jsonTensorValue = jsonGraph.get_tensor_values();
if (!signature) {
// For subgraphs we just need to derive a graph signature that only
// contains user inputs and outputs, because we don't need to handle any
// special semantics for them, e.g. mutation or gradients.
torch::_export::GraphSignature sig;
std::vector<torch::_export::InputSpec> inputSpecs;
for (const auto& input : graph->inputs()) {
torch::_export::Argument arg;
if (input->type().kind() == Type::Kind::Tensor) {
torch::_export::TensorArgument targ;
targ.set_name(std::string{input->name()});
arg.set_as_tensor(std::move(targ));
} else {
TORCH_CHECK(
false,
fmt::format(
"Unsupported subgraph input type {}",
fmt::streamed(input->type())));
}
torch::_export::UserInputSpec userInputSpec;
userInputSpec.set_arg(std::move(arg));
torch::_export::InputSpec inputSpec;
inputSpec.set_user_input(std::move(userInputSpec));
inputSpecs.push_back(std::move(inputSpec));
}
sig.set_input_specs(std::move(inputSpecs));
std::vector<torch::_export::OutputSpec> outputSpecs;
for (const auto& output : graph->outputs()) {
torch::_export::Argument arg;
if (output->type().kind() == Type::Kind::Tensor) {
torch::_export::TensorArgument targ;
targ.set_name(std::string{output->name()});
arg.set_as_tensor(std::move(targ));
} else {
TORCH_CHECK(
false,
fmt::format(
"Unsupported subgraph output type {}",
fmt::streamed(output->type())));
}
torch::_export::UserOutputSpec userOutputSpec;
userOutputSpec.set_arg(std::move(arg));
torch::_export::OutputSpec outputSpec;
outputSpec.set_user_output(std::move(userOutputSpec));
outputSpecs.push_back(std::move(outputSpec));
}
sig.set_output_specs(std::move(outputSpecs));
graph->setSignature(torch::nativert::GraphSignature{sig});
}
// weightsTensorMeta are indexed by weight's name, not graph input's name
std::unordered_map<std::string, torch::_export::TensorMeta> weightsTensorMeta;
for (const auto& [inputName, weightName] :
graph->signature().inputsToWeights()) {
auto value = graph->getValue(inputName);
if (value->type().kind() == Type::Kind::CustomObj) {
// skip setting meta for non-tensor inputs
continue;
}
auto it = jsonTensorValue.find(inputName);
CHECK(it != jsonTensorValue.end())
<< "Missing tensor metadata for " << inputName
<< "in thriftGraph.tensorValue";
weightsTensorMeta[weightName] = it->second;
}
graph->setWeightsMeta(weightsTensorMeta);
graph->setTensorValuesMeta(jsonTensorValue);
graph->finalize();
graph->lint();
return graph;
}
} // namespace
bool isSymbolic(const torch::_export::Argument& arg) {
switch (arg.tag()) {
case torch::_export::Argument::Tag::AS_TENSOR:
case torch::_export::Argument::Tag::AS_TENSORS:
case torch::_export::Argument::Tag::AS_OPTIONAL_TENSORS:
case torch::_export::Argument::Tag::AS_SYM_INT:
case torch::_export::Argument::Tag::AS_SYM_INTS:
case torch::_export::Argument::Tag::AS_SYM_BOOL:
case torch::_export::Argument::Tag::AS_SYM_BOOLS:
case torch::_export::Argument::Tag::AS_SYM_FLOAT:
case torch::_export::Argument::Tag::AS_SYM_FLOATS:
case torch::_export::Argument::Tag::AS_CUSTOM_OBJ:
return true;
default:
return false;
}
}
Constant constantToValue(
const torch::_export::Argument& jsonArg,
bool loadNodeMetadata) {
switch (jsonArg.tag()) {
case torch::_export::Argument::Tag::AS_NONE:
return torch::nativert::None();
case torch::_export::Argument::Tag::AS_INT:
return jsonArg.get_as_int();
case torch::_export::Argument::Tag::AS_INTS: {
std::vector<int64_t> ret;
for (const auto& arg : jsonArg.get_as_ints()) {
ret.push_back(arg);
}
return ret;
}
case torch::_export::Argument::Tag::AS_FLOAT:
return jsonArg.get_as_float().get();
case torch::_export::Argument::Tag::AS_FLOATS: {
std::vector<double> ret;
for (const auto& arg : jsonArg.get_as_floats()) {
ret.push_back(arg.get());
}
return ret;
}
case torch::_export::Argument::Tag::AS_STRING:
return jsonArg.get_as_string();
case torch::_export::Argument::Tag::AS_STRINGS: {
std::vector<std::string> ret;
for (const auto& arg : jsonArg.get_as_strings()) {
ret.push_back(arg);
}
return ret;
}
case torch::_export::Argument::Tag::AS_SCALAR_TYPE:
return torch::nativert::convertJsonScalarType(
jsonArg.get_as_scalar_type());
case torch::_export::Argument::Tag::AS_MEMORY_FORMAT:
return torch::nativert::convertJsonMemoryFormat(
jsonArg.get_as_memory_format());
case torch::_export::Argument::Tag::AS_LAYOUT:
return torch::nativert::convertJsonLayout(jsonArg.get_as_layout());
case torch::_export::Argument::Tag::AS_DEVICE:
return torch::nativert::convertJsonDevice(jsonArg.get_as_device());
case torch::_export::Argument::Tag::AS_BOOL:
return jsonArg.get_as_bool();
case torch::_export::Argument::Tag::AS_BOOLS: {
std::vector<bool> ret;
for (const auto& arg : jsonArg.get_as_bools()) {
ret.push_back(arg);
}
return ret;
}
case torch::_export::Argument::Tag::AS_GRAPH: {
return jsonToSubgraph(
*jsonArg.get_as_graph().get_graph(), nullptr, loadNodeMetadata);
}
case torch::_export::Argument::Tag::AS_TENSOR:
case torch::_export::Argument::Tag::AS_TENSORS:
case torch::_export::Argument::Tag::AS_OPTIONAL_TENSORS:
TORCH_CHECK(false, "Tensor values are symbolic, not constant.");
case torch::_export::Argument::Tag::AS_SYM_INT:
case torch::_export::Argument::Tag::AS_SYM_INTS:
case torch::_export::Argument::Tag::AS_SYM_BOOL:
case torch::_export::Argument::Tag::AS_SYM_BOOLS:
TORCH_CHECK(false, "Symint/Symbool Values are symbolic, not constant.");
case torch::_export::Argument::Tag::AS_CUSTOM_OBJ:
TORCH_CHECK(false, "custom obj is symbolic, not constant");
case torch::_export::Argument::Tag::AS_OPERATOR:
return jsonArg.get_as_operator();
case torch::_export::Argument::Tag::AS_SYM_FLOAT: {
TORCH_CHECK(false, "SymFloat is not yet implemented");
}
case torch::_export::Argument::Tag::AS_SYM_FLOATS: {
TORCH_CHECK(false, "SymFloats is not yet implemented");
}
default:
TORCH_CHECK(false, "Got unknown json argument");
}
}
std::unique_ptr<Graph> jsonToGraph(
const torch::_export::GraphModule& jsonGraphModule,
bool loadNodeMetadata) {
auto graph = jsonToSubgraph(
jsonGraphModule.get_graph(),
&jsonGraphModule.get_signature(),
loadNodeMetadata);
return graph;
}
} // namespace torch::nativert

View File

@ -0,0 +1,27 @@
#pragma once
#include <torch/nativert/graph/Graph.h>
#include <torch/csrc/utils/generated_serialization_types.h>
namespace torch::nativert {
/**
* This file contains serialization utilities for Graph.
*
* There are two serialized representations we care about:
* - Json: stable but hard to work with, not really human readable
* - Debug format: human-readable, not stable.
*/
// Json -> Graph
std::unique_ptr<Graph> jsonToGraph(
const torch::_export::GraphModule& jsonGraph,
bool loadNodeMetadata = true);
bool isSymbolic(const torch::_export::Argument& arg);
Constant constantToValue(
const torch::_export::Argument& jsonArg,
bool loadNodeMetadata);
} // namespace torch::nativert