From 1bc3571078e9f08c1642b8b7378ccd96f29ae6a2 Mon Sep 17 00:00:00 2001 From: Han Qi Date: Wed, 12 Jan 2022 16:27:21 -0800 Subject: [PATCH] [pytorch][PR] Add ability for a mobile::Module to save as flatbuffer (#70201) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/70201 Included functions: save_mobile_module -> saves a mobile::Module to flatbuffer load_mobile_module_from_file -> loads a flatbuffer into mobile::Module parse_mobile_module -> parses from bytes or deserialized flatbuffer module object Compared to previous attempts, this diff only adds flatbuffer to cmake target and leaves fbcode/xplat ones unchanged. Test Plan: unittest Reviewed By: malfet, gmagogsfm Differential Revision: D33239362 fbshipit-source-id: b9ca36b83d6af2d78cc50b9eb9e2a6fa7fce0763 --- .github/workflows/lint.yml | 2 +- .gitmodules | 3 + BUILD.bazel | 3 + WORKSPACE | 5 + caffe2/CMakeLists.txt | 5 + cmake/Dependencies.cmake | 3 + cmake/FlatBuffers.cmake | 10 + test/cpp/jit/CMakeLists.txt | 5 + test/cpp/jit/test_flatbuffer.cpp | 1086 +++++++ third_party/flatbuffers | 1 + torch/CMakeLists.txt | 3 + torch/csrc/jit/mobile/flatbuffer_loader.cpp | 518 ++++ torch/csrc/jit/mobile/flatbuffer_loader.h | 54 + torch/csrc/jit/mobile/module.cpp | 9 +- torch/csrc/jit/mobile/module.h | 8 + torch/csrc/jit/runtime/instruction.h | 1 + .../serialization/flatbuffer_serializer.cpp | 681 +++++ .../jit/serialization/flatbuffer_serializer.h | 26 + .../jit/serialization/mobile_bytecode.fbs | 197 ++ .../serialization/mobile_bytecode_generated.h | 2514 +++++++++++++++++ 20 files changed, 5132 insertions(+), 2 deletions(-) create mode 100644 cmake/FlatBuffers.cmake create mode 100644 test/cpp/jit/test_flatbuffer.cpp create mode 160000 third_party/flatbuffers create mode 100644 torch/csrc/jit/mobile/flatbuffer_loader.cpp create mode 100644 torch/csrc/jit/mobile/flatbuffer_loader.h create mode 100644 torch/csrc/jit/serialization/flatbuffer_serializer.cpp create mode 100644 torch/csrc/jit/serialization/flatbuffer_serializer.h create mode 100644 torch/csrc/jit/serialization/mobile_bytecode.fbs create mode 100644 torch/csrc/jit/serialization/mobile_bytecode_generated.h diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 85b38e175509..e429903b75ce 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -49,7 +49,7 @@ jobs: - name: Ensure canonical include if: always() run: | - (! git --no-pager grep -In $'#include "' -- ./c10 ./aten ./torch/csrc ':(exclude)aten/src/ATen/native/quantized/cpu/qnnpack/**' || (echo "The above lines have include with quotes; please convert them to #include "; false)) + (! git --no-pager grep -In $'#include "' -- ./c10 ./aten ./torch/csrc ':(exclude)aten/src/ATen/native/quantized/cpu/qnnpack/**' ':(exclude)torch/csrc/jit/serialization/mobile_bytecode_generated.h'|| (echo "The above lines have include with quotes; please convert them to #include "; false)) - name: Ensure no versionless Python shebangs if: always() run: | diff --git a/.gitmodules b/.gitmodules index a7cc437f4384..9c9373ef7229 100644 --- a/.gitmodules +++ b/.gitmodules @@ -142,3 +142,6 @@ [submodule "third_party/breakpad"] path = third_party/breakpad url = https://github.com/driazati/breakpad.git +[submodule "third_party/flatbuffers"] + path = third_party/flatbuffers + url = https://github.com/google/flatbuffers.git diff --git a/BUILD.bazel b/BUILD.bazel index 4b7b8b8fce0d..9c6e5d759174 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -1692,6 +1692,7 @@ cc_library( ":aten_headers", ":caffe2_headers", "//c10:headers", + "@com_github_google_flatbuffers//:flatbuffers", "@local_config_python//:python_headers", "@onnx", ], @@ -1725,6 +1726,8 @@ cc_library( ], )) + libtorch_core_sources + libtorch_distributed_sources + torch_cpp_srcs + libtorch_extra_sources + jit_core_sources + [ ":cpp_generated_code", + "torch/csrc/jit/serialization/flatbuffer_serializer.cpp", + "torch/csrc/jit/mobile/flatbuffer_loader.cpp" ], copts = TORCH_COPTS, defines = [ diff --git a/WORKSPACE b/WORKSPACE index 0497bef41039..95eee3bdd494 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -197,3 +197,8 @@ new_local_repository( build_file = "@//third_party:cudnn.BUILD", path = "/usr/", ) + +local_repository( + name = "com_github_google_flatbuffers", + path = "third_party/flatbuffers", +) diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index 44557c014b16..1a6566488c94 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -560,6 +560,7 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE) ${TORCH_SRC_DIR}/csrc/jit/mobile/interpreter.cpp ${TORCH_SRC_DIR}/csrc/jit/mobile/model_compatibility.cpp ${TORCH_SRC_DIR}/csrc/jit/mobile/module.cpp + ${TORCH_SRC_DIR}/csrc/jit/mobile/flatbuffer_loader.cpp ${TORCH_SRC_DIR}/csrc/jit/mobile/observer.cpp ${TORCH_SRC_DIR}/csrc/jit/mobile/parse_bytecode.cpp ${TORCH_SRC_DIR}/csrc/jit/mobile/parse_operators.cpp @@ -595,6 +596,7 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE) ${TORCH_SRC_DIR}/csrc/jit/serialization/export.cpp ${TORCH_SRC_DIR}/csrc/jit/serialization/export_bytecode.cpp ${TORCH_SRC_DIR}/csrc/jit/serialization/export_module.cpp + ${TORCH_SRC_DIR}/csrc/jit/serialization/flatbuffer_serializer.cpp ${TORCH_SRC_DIR}/csrc/jit/codegen/fuser/cpu/fused_kernel.cpp ${TORCH_SRC_DIR}/csrc/jit/api/module_save.cpp ${TORCH_SRC_DIR}/csrc/utils/byte_order.cpp @@ -1645,6 +1647,9 @@ if(APPLE AND USE_PYTORCH_METAL) endif() endif() + +target_link_libraries(torch_cpu PRIVATE flatbuffers) + # Note [Global dependencies] # Some libraries (e.g. OpenMPI) like to dlopen plugins after they're initialized, # and they assume that all of their symbols will be available in the global namespace. diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake index bb14176f3255..34c6bb285dee 100644 --- a/cmake/Dependencies.cmake +++ b/cmake/Dependencies.cmake @@ -1996,3 +1996,6 @@ if(USE_KINETO) message(STATUS "Configured Kineto") endif() endif() + +# Include google/FlatBuffers +include(${CMAKE_CURRENT_LIST_DIR}/FlatBuffers.cmake) diff --git a/cmake/FlatBuffers.cmake b/cmake/FlatBuffers.cmake new file mode 100644 index 000000000000..2e8e1b957db7 --- /dev/null +++ b/cmake/FlatBuffers.cmake @@ -0,0 +1,10 @@ +set(FlatBuffers_Include ${PROJECT_SOURCE_DIR}/third_party/flatbuffers/include) +file(GLOB FlatBuffers_Library_SRCS + ${FlatBuffers_Include}/flatbuffers/*.h +) +add_library(flatbuffers INTERFACE) +target_sources( + flatbuffers + INTERFACE ${FlatBuffers_Library_SRCS} +) +target_include_directories(flatbuffers INTERFACE ${FlatBuffers_Include}) diff --git a/test/cpp/jit/CMakeLists.txt b/test/cpp/jit/CMakeLists.txt index a86910bd0c5e..cfdbb28a6765 100644 --- a/test/cpp/jit/CMakeLists.txt +++ b/test/cpp/jit/CMakeLists.txt @@ -89,6 +89,7 @@ set(JIT_TEST_SRCS ${JIT_TEST_ROOT}/test_script_profile.cpp ${JIT_TEST_ROOT}/test_shape_analysis.cpp ${JIT_TEST_ROOT}/test_jit_logging_levels.cpp + ${JIT_TEST_ROOT}/test_flatbuffer.cpp ) if(USE_CUDA) @@ -101,6 +102,10 @@ add_executable(test_jit ${JIT_TEST_SRCS} ) +target_link_libraries( + test_jit PRIVATE flatbuffers) + + # TODO temporary until we can delete the old gtest polyfills. target_compile_definitions(test_jit PRIVATE USE_GTEST) diff --git a/test/cpp/jit/test_flatbuffer.cpp b/test/cpp/jit/test_flatbuffer.cpp new file mode 100644 index 000000000000..73a84297f449 --- /dev/null +++ b/test/cpp/jit/test_flatbuffer.cpp @@ -0,0 +1,1086 @@ +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +// Tests go in torch::jit +namespace torch { +namespace jit { + +mobile::Module parse_mobile_module(void* data, size_t) { + auto* flatbuffer_module = mobile::serialization::GetMutableModule(data); + return initialize_mobile_module(flatbuffer_module); +} + +TEST(FlatbufferTest, UpsampleNearest2d) { + Module m("m"); + m.define(R"( + def forward(self, input: Tensor, scale:float): + return torch.upsample_nearest2d(input, [1, 1], float(scale), float(scale)) + )"); + + std::vector inputs; + inputs.emplace_back(torch::rand({1, 3, 128, 128})); + inputs.emplace_back(at::Scalar(2.0)); + auto ref = m.forward(inputs); + + CompilationOptions options; + mobile::Module bc = jitModuleToMobile(m, options); + IValue res; + res = bc.forward(inputs); + + auto resd = res.toTensor(); + auto refd = ref.toTensor(); + ASSERT_TRUE(resd.equal(refd)); + + auto buff = save_mobile_module_to_bytes(bc); + mobile::Module bc2 = parse_mobile_module(buff.data(), buff.size()); + auto res2 = bc2.forward(inputs); + auto resd2 = res2.toTensor(); + ASSERT_TRUE(resd2.equal(refd)); +} + +TEST(FlatbufferTest, CheckAttrAccess) { + Module m("m"); + m.register_attribute("mobile_optimized", BoolType::get(), true); + + CompilationOptions options; + mobile::Module bc = jitModuleToMobile(m, options); + bool mobile_optimized = bc.attr("mobile_optimized", false).toBool(); + + AT_ASSERT(mobile_optimized); + m.setattr("mobile_optimized", false); + bc = jitModuleToMobile(m, options); + mobile_optimized = bc.attr("mobile_optimized", false).toBool(); + + AT_ASSERT(!mobile_optimized); + + auto buff = save_mobile_module_to_bytes(bc); + mobile::Module bc2 = parse_mobile_module(buff.data(), buff.size()); + auto mobile_optimized2 = bc2.attr("mobile_optimized", false).toBool(); + AT_ASSERT(!mobile_optimized2); +} + +TEST(FlatbufferTest, MethodInvocation) { // NOLINT (use =delete in gtest) + const std::vector test_programs{ + // test invoking a method with default parameter + R"( + def test_func(self, x, b : int = 4): + return self.foo + x + b + )", + // inner method call with default parameter (gets inlined) + R"( + def add_with_default_arg(self, x, b : int = 4): + return self.foo + x + b + def test_func(self, x): + return self.add_with_default_arg(x) # invoke method w/ default arg + )", + // simple method call + R"( + def test_func(self, x): + b = 4 + return self.foo + x + b + )", + }; + for (const auto& test_program : test_programs) { + Module m("m"); + m.register_parameter("foo", torch::ones({}), false); + m.define(test_program); + + const int fortyTwo = 42; // (keep linter happy) + auto minput = fortyTwo * torch::ones({}); + auto ref = m.run_method("test_func", minput); + + CompilationOptions options; + mobile::Module bc = jitModuleToMobile(m, options); + const auto& test_func = bc.get_method("test_func"); + IValue res; + for (int i = 0; i < 3; ++i) { + res = test_func({minput}); + } + + auto resd = res.toTensor().item(); + auto refd = ref.toTensor().item(); + AT_ASSERT(resd == refd); + + auto buff = save_mobile_module_to_bytes(bc); + mobile::Module bc2 = parse_mobile_module(buff.data(), buff.size()); + const auto& test_func2 = bc2.get_method("test_func"); + IValue res2; + for (int i = 0; i < 3; ++i) { + res2 = test_func2({minput}); + } + auto resd2 = res2.toTensor().item(); + AT_ASSERT(resd2 == refd); + } +} + +TEST(FlatbufferTest, Conv) { + auto s = std::getenv("PYTORCH_TEST_WITH_TSAN"); + if (s && strcmp(s, "1") == 0) + return; + + std::vector inputs; + + Module m("m"); + m.register_parameter("weight", torch::ones({20, 1, 5, 5}), false); + m.register_parameter("bias", torch::ones({20}), false); + m.define(R"( + def forward(self, input): + return torch._convolution(input, self.weight, self.bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1, False, False, True, True) + )"); + + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,modernize-use-emplace) + inputs.push_back(torch::ones({1, 1, 28, 28})); + + auto outputref = m.forward(inputs).toTensor(); + + CompilationOptions options; + mobile::Module bc = jitModuleToMobile(m, options); + IValue res; + for (int i = 0; i < 3; ++i) { + res = bc.get_method("forward")(inputs); + } + auto output = res.toTensor(); + AT_ASSERT(outputref.dim() == output.dim()); + AT_ASSERT( + outputref[0][0][0][0].item() == output[0][0][0][0].item()); + + auto buff = save_mobile_module_to_bytes(bc); + mobile::Module bc2 = parse_mobile_module(buff.data(), buff.size()); + for (int i = 0; i < 3; ++i) { + res = bc2.get_method("forward")(inputs); + } + output = res.toTensor(); + AT_ASSERT(outputref.dim() == output.dim()); + AT_ASSERT( + outputref[0][0][0][0].item() == output[0][0][0][0].item()); +} + +TEST(FlatbufferTest, Inline) { + Module m("m"); + m.define(R"JIT( + def foo1(self, x): + return x + 1 + + def foo2(self, x): + return self.foo1(x) + 2 + + def foo3(self, x): + return self.foo2(x) + 3 + )JIT"); + CompilationOptions options; + mobile::Module bc = jitModuleToMobile(m, options); + std::vector inputs({torch::ones({})}); + auto output = bc.get_method("foo3")(inputs); + AT_ASSERT(output.toTensor().item() == 7.0); + + auto buff = save_mobile_module_to_bytes(bc); + mobile::Module bc2 = parse_mobile_module(buff.data(), buff.size()); + std::vector inputs2({torch::ones({})}); + output = bc2.get_method("foo3")(inputs2); + AT_ASSERT(output.toTensor().item() == 7.0); +} + +TEST(FlatbufferTest, Tuple) { + Module m("m"); + m.define(R"JIT( + def foo(self, x): + return (1, 2, x + 3) + + def forward(self, x): + tuple = self.foo(x) + return tuple + )JIT"); + CompilationOptions options; + mobile::Module bc = jitModuleToMobile(m, options); + std::vector inputs({torch::ones({})}); + auto output = bc.get_method("forward")(inputs); + AT_ASSERT(output.toTupleRef().elements()[1].toInt() == 2); + + auto buff = save_mobile_module_to_bytes(bc); + mobile::Module bc2 = parse_mobile_module(buff.data(), buff.size()); + output = bc2.get_method("forward")(inputs); + AT_ASSERT(output.toTuple()->elements()[1].toInt() == 2); +} + +TEST(FlatbufferTest, Dict) { + Module m("m"); + m.define(R"JIT( + def foo(self, x): + return {"result": x + 1} + + def forward(self, x): + d = self.foo(x) + return d + )JIT"); + CompilationOptions options; + mobile::Module bc = jitModuleToMobile(m, options); + std::vector inputs({torch::ones({})}); + auto output = bc.get_method("forward")(inputs); + AT_ASSERT(output.toGenericDict().at("result").toTensor().item().toInt() == 2); + + auto buff = save_mobile_module_to_bytes(bc); + mobile::Module bc2 = parse_mobile_module(buff.data(), buff.size()); + output = bc2.get_method("forward")(inputs); + AT_ASSERT(output.toGenericDict().at("result").toTensor().item().toInt() == 2); +} + +TEST(FlatbufferTest, Prim) { + Module m("m"); + m.define(R"JIT( + def forward(self, x): + return int(x) + )JIT"); + + std::vector inputs; + auto minput = 3.5 * torch::ones({}); + inputs.emplace_back(minput); + auto ref = m.run_method("forward", minput); + + CompilationOptions options; + mobile::Module bc = jitModuleToMobile(m, options); + IValue res; + for (int i = 0; i < 3; ++i) { + // NOLINTNEXTLINE(performance-unnecessary-copy-initialization) + auto bcinputs = inputs; + res = bc.get_method("forward")(bcinputs); + } + + auto resi = res.toInt(); + auto refi = ref.toInt(); + AT_ASSERT(resi == refi); + + auto buff = save_mobile_module_to_bytes(bc); + mobile::Module bc2 = parse_mobile_module(buff.data(), buff.size()); + for (int i = 0; i < 3; ++i) { + // NOLINTNEXTLINE(performance-unnecessary-copy-initialization) + auto bcinputs = inputs; + res = bc2.get_method("forward")(bcinputs); + } + auto resi2 = res.toInt(); + AT_ASSERT(resi2 == refi); +} + +TEST(FlatbufferTest, PrimScalar) { + Module m("m"); + m.define(R"JIT( + def forward(self, x): + return int(x.item()) + )JIT"); + + std::vector inputs; + auto minput = 3.5 * torch::ones({}); + inputs.emplace_back(minput); + auto ref = m.run_method("forward", minput); + + CompilationOptions options; + mobile::Module bc = jitModuleToMobile(m, options); + IValue res; + for (int i = 0; i < 3; ++i) { + // NOLINTNEXTLINE(performance-unnecessary-copy-initialization) + auto bcinputs = inputs; + res = bc.get_method("forward")(bcinputs); + } + + auto resi = res.toInt(); + auto refi = ref.toInt(); + AT_ASSERT(resi == refi); + + auto buff = save_mobile_module_to_bytes(bc); + mobile::Module bc2 = parse_mobile_module(buff.data(), buff.size()); + for (int i = 0; i < 3; ++i) { + // NOLINTNEXTLINE(performance-unnecessary-copy-initialization) + auto bcinputs = inputs; + res = bc2.get_method("forward")(bcinputs); + } + auto resi2 = res.toInt(); + AT_ASSERT(resi2 == refi); +} + +TEST(FlatbufferTest, WrongMethodName) { + Module m("m"); + m.register_parameter("foo", torch::ones({}), false); + m.define(R"( + def add(self, x): + b = 4 + return self.foo + x + b + )"); + CompilationOptions options; + mobile::Module bc = jitModuleToMobile(m, options); + std::vector inputs; + auto minput = 5 * torch::ones({}); + inputs.emplace_back(minput); + ASSERT_THROWS_WITH_MESSAGE( + bc.get_method("forward")(inputs), "is not defined"); + + auto buff = save_mobile_module_to_bytes(bc); + mobile::Module bc2 = parse_mobile_module(buff.data(), buff.size()); + ASSERT_THROWS_WITH_MESSAGE( + bc2.get_method("forward")(inputs), "is not defined"); +} + +TEST(FlatbufferTest, SetState) { + Module m("m"); + m.register_parameter("foo", torch::ones({}), false); + m.define(R"( + def __getstate__(self): + return self.foo + def __setstate__(self, a): + self.foo = a + def forward(self, x): + b = 4 + return self.foo + x + b + )"); + + std::vector inputs; + auto minput = 5 * torch::ones({}); + inputs.emplace_back(minput); + + std::stringstream ms; + m.save(ms); + auto loaded_m = load(ms); + auto ref = loaded_m.run_method("forward", minput); + + CompilationOptions options; + mobile::Module bc = jitModuleToMobile(m, options); + IValue res; + for (int i = 0; i < 3; ++i) { + // NOLINTNEXTLINE(performance-unnecessary-copy-initialization) + auto bcinputs = inputs; + res = bc.get_method("forward")(bcinputs); + } + + auto resd = res.toTensor().item(); + auto refd = ref.toTensor().item(); + AT_ASSERT(resd == refd); + + auto buff = save_mobile_module_to_bytes(bc); + mobile::Module bc2 = parse_mobile_module(buff.data(), buff.size()); + for (int i = 0; i < 3; ++i) { + // NOLINTNEXTLINE(performance-unnecessary-copy-initialization) + auto bcinputs = inputs; + res = bc2.get_method("forward")(bcinputs); + } + + auto resd2 = res.toTensor().item(); + AT_ASSERT(resd2 == refd); +} + +class TorchBindFlatbufferTestStruct : public torch::jit::CustomClassHolder { + public: + std::string get(at::Tensor t) { + std::stringstream ss; + ss << "Hello! Your tensor has "; + ss << t.numel(); + ss << " elements!"; + return ss.str(); + } +}; + +namespace { +struct ClassNamespaceValue : public SugaredValue { + explicit ClassNamespaceValue(c10::QualifiedName name) + : basename_(std::move(name)) {} + + std::shared_ptr attr( + const SourceRange& loc, + GraphFunction& m, + const std::string& name) override { + const auto fullName = c10::QualifiedName(basename_, name); + + // Check to see if it is a custom class. + if (auto custom_class = getCustomClass(fullName.qualifiedName())) { + return std::make_shared(custom_class); + } + + // If it's not a custom class, assume it's another namespace + // NOLINTNEXTLINE(performance-move-const-arg) + return std::make_shared(std::move(fullName)); + } + + std::string kind() const override { + return "Class Namespace"; + } + + private: + c10::QualifiedName basename_; +}; + +struct TestModuleResolver : public Resolver { + std::shared_ptr resolveValue( + const std::string& name, + GraphFunction& m, + const SourceRange& loc) override { + if (name == "torch") { + return std::make_shared("aten"); + } else if (name == "__torch__") { + return std::make_shared(c10::QualifiedName(name)); + } + + return nullptr; + } + + TypePtr resolveType(const std::string& name, const SourceRange& loc) + override { + return nullptr; + } +}; +} // namespace + +TEST(FlatbufferTest, BuiltinClass) { + script::Module m("m"); + + auto cls = getCustomClass( + "__torch__.torch.classes._TorchScriptTesting._FlatbufferTest"); + TORCH_INTERNAL_ASSERT(cls); + c10::intrusive_ptr obj_holder; + m.register_attribute("my_obj", cls, IValue::make_capsule(obj_holder)); + + m.register_parameter("foo", torch::ones({}), false); + m.define( + R"( + def __getstate__(self): + return 1 + def __setstate__(self, a): + self.my_obj = __torch__.torch.classes._TorchScriptTesting._FlatbufferTest() + + def forward(self, x) -> str: + return self.my_obj.get(x) + )", + std::make_shared()); + + CompilationOptions options; + mobile::Module bc = jitModuleToMobile(m, options); + auto buff = save_mobile_module_to_bytes(bc); + mobile::Module bc2 = parse_mobile_module(buff.data(), buff.size()); + std::string expected = "Hello! Your tensor has 12 elements!"; + auto res = + bc2.get_method("forward")(std::vector{torch::zeros({3, 4})}); + const auto& str2 = res.toStringRef(); + AT_ASSERT(str2 == expected); +} + +TEST(FlatbufferTest, BuiltinFunction) { + script::Module m("m"); + auto custom_class_obj = make_custom_class(); + m.register_attribute("my_obj", custom_class_obj.type(), custom_class_obj); + m.define(R"( + def forward(self, x) -> str: + return self.my_obj.get(x) + )"); + + CompilationOptions options; + mobile::Module bc = jitModuleToMobile(m, options); + auto res = + bc.get_method("forward")(std::vector{torch::zeros({3, 4})}); + // NOLINTNEXTLINE(performance-unnecessary-copy-initialization) + auto str = res.toStringRef(); + std::string expected = "Hello! Your tensor has 12 elements!"; + AT_ASSERT(str == expected); + + auto buff = save_mobile_module_to_bytes(bc); + mobile::Module bc2 = parse_mobile_module(buff.data(), buff.size()); + res = bc2.get_method("forward")(std::vector{torch::zeros({3, 4})}); + // NOLINTNEXTLINE(performance-unnecessary-copy-initialization) + str = res.toStringRef(); + AT_ASSERT(str == expected); +} + +TEST(FlatbufferTest, Eval) { + std::vector inputs; + + Module m("m"); + m.define(R"( + def __init__(self, x): + self.training = True + + def forward(self, input): + return torch.dropout(input, 1.0, self.training) + )"); + + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,modernize-use-emplace) + inputs.push_back(torch::ones({1, 1, 28, 28})); + m.eval(); + auto outputref = m.forward(inputs).toTensor(); + + // save m in training mode to make sure that mobile eval() will correctly + // change back to eval mode + m.train(); + CompilationOptions options; + mobile::Module bc = jitModuleToMobile(m, options); + bc.eval(); + IValue res; + for (int i = 0; i < 3; ++i) { + res = bc.get_method("forward")(inputs); + } + auto output = res.toTensor(); + AT_ASSERT(outputref.dim() == output.dim()); + AT_ASSERT( + outputref[0][0][0][0].item() == output[0][0][0][0].item()); + + auto buff = save_mobile_module_to_bytes(bc); + mobile::Module bc2 = parse_mobile_module(buff.data(), buff.size()); + bc2.eval(); + for (int i = 0; i < 3; ++i) { + res = bc2.get_method("forward")(inputs); + } + output = res.toTensor(); + AT_ASSERT(outputref.dim() == output.dim()); + AT_ASSERT( + outputref[0][0][0][0].item() == output[0][0][0][0].item()); +} + +TEST(FlatbufferTest, FindWrongMethodName) { + Module m("m"); + m.register_parameter("foo", torch::ones({}), false); + m.define(R"( + def add(self, x): + b = 4 + return self.foo + x + b + )"); + CompilationOptions options; + mobile::Module bc = jitModuleToMobile(m, options); + ASSERT_TRUE(bc.find_method("forward") == c10::nullopt); + + auto buff = save_mobile_module_to_bytes(bc); + mobile::Module bc2 = parse_mobile_module(buff.data(), buff.size()); + ASSERT_TRUE(bc2.find_method("forward") == c10::nullopt); +} + +TEST(FlatbufferTest, FindAndRunMethod) { + Module m("m"); + m.register_parameter("foo", torch::ones({}), false); + m.define(R"( + def add_it(self, x): + b = 4 + return self.foo + x + b + )"); + + std::vector inputs; + auto minput = 5 * torch::ones({}); + inputs.emplace_back(minput); + auto ref = m.get_method("add_it")(inputs); + + CompilationOptions options; + mobile::Module bc = jitModuleToMobile(m, options); + IValue res; + for (int i = 0; i < 3; ++i) { + auto bcinputs = inputs; + auto method = bc.find_method("add_it"); + AT_ASSERT(method != c10::nullopt); + res = (*method)(std::move(bcinputs)); + } + + auto resd = res.toTensor().item(); + auto refd = ref.toTensor().item(); + AT_ASSERT(resd == refd); + + auto buff = save_mobile_module_to_bytes(bc); + mobile::Module bc2 = parse_mobile_module(buff.data(), buff.size()); + + for (int i = 0; i < 3; ++i) { + auto bcinputs = inputs; + auto method = bc2.find_method("add_it"); + AT_ASSERT(method != c10::nullopt); + res = (*method)(std::move(bcinputs)); + } + + resd = res.toTensor().item(); + AT_ASSERT(resd == refd); +} + +TEST(FlatbufferTest, RunMethodVariadic) { + Module m("m"); + m.register_parameter("foo", torch::ones({}), false); + m.define(R"( + def add_three(self, x, y): + return self.foo + x + y + )"); + + std::vector inputs; + auto inputx = 5 * torch::ones({}); + auto inputy = 4 * torch::ones({}); + auto ref = m.run_method("add_three", inputx, inputy); + + CompilationOptions options; + mobile::Module bc = jitModuleToMobile(m, options); + IValue res = bc.run_method("add_three", inputx, inputy); + + auto resd = res.toTensor().item(); + auto refd = ref.toTensor().item(); + AT_ASSERT(resd == refd); + + auto buff = save_mobile_module_to_bytes(bc); + mobile::Module bc2 = parse_mobile_module(buff.data(), buff.size()); + res = bc.run_method("add_three", inputx, inputy); + resd = res.toTensor().item(); + AT_ASSERT(resd == refd); +} + +TEST(FlatbufferTest, DuplicateSetState) { + Module m("M"); + m.register_parameter("foo", torch::ones({}), false); + m.define(R"( + def __getstate__(self): + return self.foo + self.foo + def __setstate__(self, a): + self.foo = a + def forward(self, x): + b = 4 + return self.foo + x + b + )"); + + Module b("B"); + b.register_module("M0", m); + b.register_module("M1", m); + b.define(R"( + def forward(self, x): + return self.M0.forward(x) + self.M1.forward(x) + )"); + + CompilationOptions options; + mobile::Module bc = jitModuleToMobile(m, options); + const auto methods = bc.get_methods(); + const size_t expected_n = 3; + ASSERT_EQ(methods.size(), expected_n); + + auto buff = save_mobile_module_to_bytes(bc); + mobile::Module bc2 = parse_mobile_module(buff.data(), buff.size()); + const auto methods2 = bc.get_methods(); + ASSERT_EQ(methods2.size(), expected_n); +} + +TEST(FlatbufferTest, OpNameExportFetchRootOperators) { + torch::jit::Module m("m"); + m.register_parameter("weight", torch::ones({20, 1, 5, 5}), false); + m.register_parameter("bias", torch::ones({20}), false); + m.define(R"( + def forward(self, input): + x1 = torch.zeros(2, 2) + x2 = torch.empty_like(torch.empty(2, 2)) + x3 = torch._convolution(input, self.weight, self.bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1, False, False, True, True) + return (x1, x2, x3) + )"); + m.eval(); + + CompilationOptions options; + mobile::Module ptl_model = jitModuleToMobile(m, options); + std::set operator_names = + torch::jit::mobile::_export_operator_list(ptl_model); + std::set expected_operator_names = { + "aten::_convolution", + "aten::empty.memory_format", + "aten::empty_like", + "aten::zeros", + }; + EXPECT_EQ(operator_names, expected_operator_names) + << "Expected the root operator lists to be the same"; + + auto buff = save_mobile_module_to_bytes(ptl_model); + mobile::Module bc2 = parse_mobile_module(buff.data(), buff.size()); + operator_names = torch::jit::mobile::_export_operator_list(bc2); + EXPECT_EQ(operator_names, expected_operator_names) + << "Expected the root operator lists to be the same"; +} + +TEST(FlatbufferTest, DefaultArgsConv) { + auto s = std::getenv("PYTORCH_TEST_WITH_TSAN"); + if (s && strcmp(s, "1") == 0) + return; + + std::vector inputs; + + Module m("m"); + m.register_parameter("weight", torch::ones({20, 1, 5, 5}), false); + m.register_parameter("bias", torch::ones({20}), false); + m.define(R"( + def forward(self, input): + return torch.conv2d(input, self.weight, self.bias, [1, 1], [0, 0], [1, 1], 1) + )"); + + inputs.emplace_back(torch::ones({1, 1, 28, 28})); + + auto outputref = m.forward(inputs).toTensor(); + + CompilationOptions options; + mobile::Module bc = jitModuleToMobile(m, options); + IValue res; + for (int i = 0; i < 1; ++i) { + res = bc.get_method("forward")(inputs); + } + auto output = res.toTensor(); + AT_ASSERT(outputref.dim() == output.dim()); + AT_ASSERT(output.equal(outputref)); + + auto buff = save_mobile_module_to_bytes(bc); + mobile::Module bc2 = parse_mobile_module(buff.data(), buff.size()); + for (int i = 0; i < 1; ++i) { + res = bc2.get_method("forward")(inputs); + } + output = res.toTensor(); + AT_ASSERT(outputref.dim() == output.dim()); + AT_ASSERT(output.equal(outputref)); +} + +namespace { +void testLiteModuleCompareResultTensors( + Module& m, + const std::vector& inputs, + const std::string& method_name = "forward") { + auto outputref = m.get_method(method_name)(inputs).toTensor(); + + CompilationOptions options; + mobile::Module bc = jitModuleToMobile(m, options); + IValue res; + for (int i = 0; i < 3; ++i) { + res = bc.get_method(method_name)(inputs); + } + auto output = res.toTensor(); + AT_ASSERT(outputref.dim() == output.dim()); + AT_ASSERT(output.equal(outputref)); + + auto buff = save_mobile_module_to_bytes(bc); + mobile::Module bc2 = parse_mobile_module(buff.data(), buff.size()); + for (int i = 0; i < 3; ++i) { + res = bc2.get_method(method_name)(inputs); + } + output = res.toTensor(); + AT_ASSERT(outputref.dim() == output.dim()); + AT_ASSERT(output.equal(outputref)); +} + +void testDefaultArgsPinv(int num_args) { + Module m("m"); + if (num_args == 1) { + m.define(R"( + def forward(self, input): + return torch.linalg_pinv(input) + )"); + } else if (num_args == 2) { + m.define(R"( + def forward(self, input): + return torch.linalg_pinv(input, 1e-5) + )"); + } else if (num_args == 3) { + m.define(R"( + def forward(self, input): + return torch.linalg_pinv(input, 1e-5, True) + )"); + } + + std::vector inputs; + const int N = 28; + auto input = torch::range(1, N * N, 1); + input[0] = 1; // a more stable matrix + input = input.view({N, N}); + inputs.emplace_back(input); + testLiteModuleCompareResultTensors(m, inputs); +} + +void testDefaultArgsPinvWithOutArg(int num_args) { + Module m("m"); + if (num_args == 1) { + m.define(R"( + def forward(self, input): + return torch.linalg_pinv(input, out=input) + )"); + } else if (num_args == 2) { + m.define(R"( + def forward(self, input): + return torch.linalg_pinv(input, 1e-5, out=input) + )"); + } else if (num_args == 3) { + m.define(R"( + def forward(self, input): + return torch.linalg_pinv(input, 1e-5, True, out=input) + )"); + } + + const int N = 28; + auto input = torch::range(1, N * N, 1); + input[0] = 10000; // a more stable matrix + input = input.view({N, N}); + auto ref = m.run_method("forward", input); + TORCH_CHECK(!input.equal(torch::range(1, N * N, 1))); + TORCH_CHECK(input.equal(ref.toTensor())); +} + +TEST(FlatbufferTest, DefaultArgsPinvWithOutArg) { + // Test with different number of specified arguments + out arg. + // Arguments not specified take default value. + for (int num_args = 1; num_args <= 3; ++num_args) { + testDefaultArgsPinvWithOutArg(num_args); + } +} + +TEST(FlatbufferTest, DefaultArgsWithOutArg) { + Module m("m"); + m.define(R"( + def forward(self, x, h): + torch.add(x, h, out=x) + )"); + + std::vector inputs; + auto input_x = 2 * torch::ones({}); + auto input_h = torch::ones({}); + auto ref = m.run_method("forward", input_x, input_h); + + CompilationOptions options; + mobile::Module bc = jitModuleToMobile(m, options); + bc.run_method("forward", input_x, input_h); + AT_ASSERT(input_x.equal(4 * torch::ones({}))); + + auto buff = save_mobile_module_to_bytes(bc); + mobile::Module bc2 = parse_mobile_module(buff.data(), buff.size()); + auto input_x2 = 2 * torch::ones({}); + auto input_h2 = torch::ones({}); + m.run_method("forward", input_x2, input_h2); + bc2.run_method("forward", input_x2, input_h2); + AT_ASSERT(input_x2.equal(4 * torch::ones({}))); +} +} // namespace + +#if !defined FB_XPLAT_BUILD +TEST(FlatbufferTest, DefaultArgsPinv) { + // Test with different number of specified arguments. + // Arguments not specified take default value. + for (int num_args = 1; num_args <= 3; ++num_args) { + testDefaultArgsPinv(num_args); + } + + // bytecode with one specified argument: + // (6, + // ('__torch__.m.forward', + // (('instructions', + // (('STOREN', 1, 2), + // ('DROPR', 1, 0), + // ('MOVE', 2, 0), + // ('OP', 0, 0), + // ('RET', 0, 0))), + // ('operators', (('aten::linalg_pinv', '', 1),)), + // ('constants', (False, 1e-15)), # default constants are not + // used + // ('types', ()), + // ('register_size', 2)), + // (('arguments', + // ((('name', 'self'), ('type', '__torch__.m'), ('default_value', + // None)), + // (('name', 'input'), ('type', 'Tensor'), ('default_value', + // None)))), + // ('returns', + // ((('name', ''), ('type', 'Tensor'), ('default_value', + // None)),))))) + + // bytecode with 2 specified argument: + // (6, + // ('__torch__.m.forward', + // (('instructions', + // (('STOREN', 1, 2), + // ('DROPR', 1, 0), + // ('MOVE', 2, 0), + // ('LOADC', 1, 0), # added LOADC for specified argument + // ('OP', 0, 0), + // ('RET', 0, 0))), + // ('operators', (('aten::linalg_pinv', '', 2),)), + // ('constants', (False, 1e-05)), # updated constant table + // ('types', ()), + // ('register_size', 2)), + // (('arguments', + // ((('name', 'self'), ('type', '__torch__.m'), ('default_value', + // None)), + // (('name', 'input'), ('type', 'Tensor'), ('default_value', + // None)))), + // ('returns', + // ((('name', ''), ('type', 'Tensor'), ('default_value', + // None)),))))) + + // bytecode with 3 specified arguments: + // (6, + // ('__torch__.m.forward', + // (('instructions', + // (('STOREN', 1, 2), + // ('DROPR', 1, 0), + // ('MOVE', 2, 0), + // ('LOADC', 1, 0), + // ('LOADC', 0, 0), + // ('OP', 0, 0), + // ('RET', 0, 0))), + // ('operators', (('aten::linalg_pinv', '', 3),)), + // ('constants', (True, 1e-05)), + // ('types', ()), + // ('register_size', 2)), + // (('arguments', + // ((('name', 'self'), ('type', '__torch__.m'), ('default_value', + // None)), + // (('name', 'input'), ('type', 'Tensor'), ('default_value', + // None)))), + // ('returns', + // ((('name', ''), ('type', 'Tensor'), ('default_value', + // None)),))))) +} + +TEST(FlatbufferTest, DefaultArgsTensorinvSpecifyDefault) { + // The second argument is specified, but the value is the same as the default + // value. It's treated as "not specified" since the value can be fetched from + // schema. + Module m("m"); + m.define(R"( + def forward(self, input): + return torch.linalg_tensorinv(input, 2) + )"); + torch::jit::MobileCode code(m.get_method("forward").graph(), "forward"); + auto arg_nums = code.op_to_num_specified_args(); + ASSERT_EQ(arg_nums.size(), 1); + ASSERT_EQ(arg_nums["aten::linalg_tensorinv"], 1); + std::vector inputs; + const int N = 4; + auto input = torch::rand({N, N, N, N}); + inputs.emplace_back(input); + testLiteModuleCompareResultTensors(m, inputs); +} + +#endif // !defined(FB_XPLAT_BUILD) + +namespace { +static auto reg = + torch::class_( + "_TorchScriptTesting", + "_FlatbufferTest") + .def(torch::init<>()) + .def("get", &TorchBindFlatbufferTestStruct::get) + .def_pickle( + // __getattr__ + [](const c10::intrusive_ptr& self) + -> int64_t { return 0; }, + // __setattr__ + [](int64_t state) { + return c10::make_intrusive(); + }); + +} // namespace + +TEST(FlatbufferTest, OperatorCacheDifferentiatesDefaultArgs) { + // Create 3 methods: + // + // 1. forward() returns a tensor with dtype=torch.int64 (4) + // 2. forward2() returns a tensor with dtype=torch.float32 (6) + // 3. forward3() returns a tensor with dtype=torch.float32 but + // the dtype is inferred by the input tensor's dtype + // + // If caching works correctly, then the result from the full-jit + // module and the lite module will be the same. Otherwise, it + // will be different if we don't correctly ignore the cache + // entry for an operator that has a different number of + // arguments. + Module m("m"); + m.define(R"( + def forward(self): + ret1 = torch.new_empty(torch.zeros(10), [10], dtype=4) + return ret1.fill_(25) + )"); + m.define(R"( + def forward2(self): + ret1 = torch.new_empty(torch.zeros(10), [10], dtype=6) + return ret1.fill_(32.0) + )"); + m.define(R"( + def forward3(self): + ret1 = torch.new_empty(torch.zeros(10), [10]) + return ret1.fill_(12.0) + )"); + + std::vector inputs; + testLiteModuleCompareResultTensors(m, inputs, "forward"); + testLiteModuleCompareResultTensors(m, inputs, "forward2"); + testLiteModuleCompareResultTensors(m, inputs, "forward3"); +} + +TEST(FlatbufferTest, OperatorSize1) { + Module m("m"); + m.define(R"( + def forward(self, input: Tensor, scale:float): + return torch.upsample_nearest2d(input, [1, 1], float(scale), float(scale)) + )"); + + CompilationOptions options; + mobile::Module bc = jitModuleToMobile(m, options); + const auto& func = bc.get_method("forward").function(); + ASSERT_EQ( + func.get_code().operator_input_sizes_.size(), + func.get_code().operators_.size()); + + auto buff = save_mobile_module_to_bytes(bc); + mobile::Module bc2 = parse_mobile_module(buff.data(), buff.size()); + const auto& func2 = bc.get_method("forward").function(); + ASSERT_EQ( + func2.get_code().operator_input_sizes_.size(), + func2.get_code().operators_.size()); +} + +TEST(FlatbufferTest, OperatorTest2) { // NOLINT (use =delete in gtest) + const std::vector test_programs{ + // test invoking a method with default parameter + R"( + def test_func(self, x, b : int = 4): + return self.foo + x + b + )", + // inner method call with default parameter (gets inlined) + R"( + def add_with_default_arg(self, x, b : int = 4): + return self.foo + x + b + def test_func(self, x): + return self.add_with_default_arg(x) # invoke method w/ default arg + )", + // simple method call + R"( + def test_func(self, x): + b = 4 + return self.foo + x + b + )", + }; + for (const auto& test_program : test_programs) { + Module m("m"); + m.register_parameter("foo", torch::ones({}), false); + m.define(test_program); + + CompilationOptions options; + mobile::Module bc = jitModuleToMobile(m, options); + const auto& func = bc.get_method("test_func").function(); + ASSERT_EQ( + func.get_code().operator_input_sizes_.size(), + func.get_code().operators_.size()); + + auto buff = save_mobile_module_to_bytes(bc); + mobile::Module bc2 = parse_mobile_module(buff.data(), buff.size()); + const auto& func2 = bc.get_method("test_func").function(); + ASSERT_EQ( + func2.get_code().operator_input_sizes_.size(), + func2.get_code().operators_.size()); + } +} + +} // namespace jit +} // namespace torch diff --git a/third_party/flatbuffers b/third_party/flatbuffers new file mode 160000 index 000000000000..f2f9380c86a7 --- /dev/null +++ b/third_party/flatbuffers @@ -0,0 +1 @@ +Subproject commit f2f9380c86a762ef0d9410693c61c35567923d63 diff --git a/torch/CMakeLists.txt b/torch/CMakeLists.txt index 1c2f8e329c3a..94d9cbb33b3b 100644 --- a/torch/CMakeLists.txt +++ b/torch/CMakeLists.txt @@ -70,6 +70,7 @@ set(TORCH_PYTHON_INCLUDE_DIRECTORIES ${TORCH_ROOT}/third_party/gloo ${TORCH_ROOT}/third_party/onnx + ${TORCH_ROOT}/third_party/flatbuffers/include ${pybind11_INCLUDE_DIRS} ${TORCH_SRC_DIR}/csrc @@ -345,6 +346,8 @@ if(HAVE_SOVERSION) VERSION ${TORCH_VERSION} SOVERSION ${TORCH_SOVERSION}) endif() add_dependencies(torch_python torch_python_stubs) +add_dependencies(torch_python flatbuffers) + if(USE_PRECOMPILED_HEADERS) target_precompile_headers(torch_python PRIVATE diff --git a/torch/csrc/jit/mobile/flatbuffer_loader.cpp b/torch/csrc/jit/mobile/flatbuffer_loader.cpp new file mode 100644 index 000000000000..7b40ca82d4a8 --- /dev/null +++ b/torch/csrc/jit/mobile/flatbuffer_loader.cpp @@ -0,0 +1,518 @@ +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#if defined(HAVE_MMAP) +#include +#include +#include +#include +#endif + +#include +#include + +namespace torch { +namespace jit { +namespace { + +using caffe2::serialize::IStreamAdapter; +using caffe2::serialize::PyTorchStreamReader; +using caffe2::serialize::ReadAdapterInterface; + +static constexpr c10::string_view kCustomClassPrefix = + "__torch__.torch.classes"; +static constexpr c10::string_view kTorchPrefix = "__torch__"; +static constexpr c10::string_view kJitPrefix = "torch.jit"; + +class FlatbufferLoader { + public: + FlatbufferLoader() + : mcu_(std::make_shared()), + cu_(std::make_shared()) {} + + mobile::Module parseModule(mobile::serialization::Module* module); + + private: + IValue parseIValue(const mobile::serialization::IValue* ivalue); + IValue parseList(const mobile::serialization::List* list); + at::Tensor parseTensor(const mobile::serialization::TensorMetadata* tensor); + IValue parseTuple(const mobile::serialization::Tuple* tuple); + IValue parseDict(const mobile::serialization::Dict* dict); + IValue parseObject(const mobile::serialization::Object* object); + std::unique_ptr parseFunction( + const mobile::serialization::Function* method); + + IValue& getIValue(uint32_t pos) { + TORCH_CHECK(pos < all_ivalues_.size()); + return all_ivalues_[pos]; + } + + mobile::Function* getFunction(uint32_t pos) { + return all_functions_[pos]; + } + + ClassTypePtr getType(uint32_t pos) const { + TORCH_CHECK(pos < all_ivalues_.size()); + return all_types_[pos]; + // auto iter = all_types_.find(pos); + // AT_ASSERT(iter != all_types_.end(), "type not found at pos: ", pos); + // return iter->second; + } + + c10::Storage getStorage(uint32_t index); + TypePtr getOrCreateTypeAnnotations(const flatbuffers::String* offset); + + // fields + std::unordered_map all_functions_; + std::vector all_types_; + std::unordered_set initialized_types_; + std::unordered_map type_annotations_; + std::vector storage_loaded_; + std::vector storages_; + std::vector all_ivalues_; + std::shared_ptr mcu_; + std::shared_ptr cu_; + mobile::serialization::Module* module_ = nullptr; +}; + +mobile::Module FlatbufferLoader::parseModule( + mobile::serialization::Module* module) { + module_ = module; + all_ivalues_.clear(); + all_types_.clear(); + storages_.clear(); + storage_loaded_.clear(); + + const auto* ivalues = module->ivalues(); + all_ivalues_.resize(ivalues->size()); + all_types_.resize(module->object_types()->size()); + storages_.resize(module->storage_data_size()); + storage_loaded_.resize(module->storage_data_size(), false); + + for (uint32_t i = 0; i < ivalues->size(); i++) { + const auto* ival = ivalues->Get(i); + if (const auto* func = ival->val_as_Function()) { + auto func_ptr = parseFunction(func); + all_functions_[i] = func_ptr.get(); + mcu_->register_function(std::move(func_ptr)); + } else { + all_ivalues_[i] = parseIValue(ival); + } + } + + IValue& module_ivalue = getIValue(module->state_obj()); + // register function to class + // for (const auto& func: all_functions_) { + // const auto* fb_func = ivalues->Get(func.first)->val_as_Function(); + // auto class_type = getType(fb_func->class_type()); + // class_type->addMethod(func.second); + // } + return mobile::Module(module_ivalue.toObject(), mcu_); +} + +std::unique_ptr FlatbufferLoader::parseFunction( + const mobile::serialization::Function* method) { + auto function = std::make_unique( + c10::QualifiedName(method->qn()->str())); + // TODO(qihan) add debug handle + // const auto* debug_handle = method->debug_info()->debug_handle(); + for (const auto* inst : *method->instructions()) { + function->append_instruction( + static_cast(inst->op()), inst->x(), inst->n()); + } + + for (uint32_t i : *method->constants()) { + function->append_constant(getIValue(i)); + } + + std::unordered_set unsupported_op_names; + const int64_t model_version = 0x6L; + for (const auto* op : *method->operators()) { + c10::optional num_args = c10::nullopt; + if (op->num_args_serialized() > -1) { + num_args = op->num_args_serialized(); + } + + auto op_found = function->append_operator( + op->name()->str(), op->overload_name()->str(), num_args, model_version); + + if (!op_found) { + unsupported_op_names.emplace( + op->name()->str() + "/" + op->overload_name()->str()); + } + } + + AT_ASSERT(unsupported_op_names.empty()); + + for (const auto i : *method->type_annotations()) { + function->append_type(getOrCreateTypeAnnotations(i)); + } + + function->set_register_size(method->register_size()); + if (method->schema()) { + auto parseArgList = [this](const auto* args_fb) { + std::vector args; + for (const auto* arg_tb : *args_fb) { + IValue default_value = getIValue(arg_tb->default_value()); + TypePtr type_ptr = getOrCreateTypeAnnotations(arg_tb->type()); + auto arg = c10::Argument( + arg_tb->name()->str(), + std::move(type_ptr), + c10::nullopt /*N*/, + std::move(default_value)); + args.emplace_back(std::move(arg)); + } + return args; + }; + c10::FunctionSchema schema( + method->qn()->str(), + "" /*overload_name*/, + parseArgList(method->schema()->arguments()), + parseArgList(method->schema()->returns()), + false /*is_varargs*/, + false /*is_varret*/); + + function->setSchema(std::move(schema)); + } + return function; +} + +at::Tensor FlatbufferLoader::parseTensor( + const mobile::serialization::TensorMetadata* tensor_md) { + at::ScalarType type = static_cast(tensor_md->scalar_type()); + auto options = at::CPU(type).options(); + at::Tensor tensor; + if (tensor_md->quantized_schema() != nullptr) { + // is quantized + const auto* schema = tensor_md->quantized_schema(); + auto qscheme_type = static_cast(schema->qscheme()); + switch (qscheme_type) { + case at::kPerTensorAffine: { + tensor = at::_empty_affine_quantized( + {0}, options, schema->scale(), schema->zero_point()); + } break; + case at::kPerChannelAffineFloatQParams: + case at::kPerChannelAffine: { + at::Tensor scales = parseTensor(schema->scales()); + at::Tensor zero_points = parseTensor(schema->zero_points()); + tensor = at::_empty_per_channel_affine_quantized( + {0}, scales, zero_points, schema->axis(), options); + } break; + default: + TORCH_CHECK( + false, + "Unsupported tensor quantization type in serialization ", + toString(qscheme_type)); + break; + } + } else { + tensor = at::empty({0}, options); + } + at::TensorImpl* impl = tensor.unsafeGetTensorImpl(); + + c10::Storage storage; + storage = getStorage(tensor_md->storage_location_index()); + impl->set_storage_keep_dtype(storage); + impl->set_storage_offset(tensor_md->storage_offset()); + + std::vector size{ + tensor_md->sizes()->begin(), tensor_md->sizes()->end()}; + std::vector stride{ + tensor_md->strides()->begin(), tensor_md->strides()->end()}; + impl->set_sizes_and_strides(size, stride); + tensor = autograd::make_variable(tensor, tensor_md->requires_grad()); + return tensor; +} +IValue FlatbufferLoader::parseList(const mobile::serialization::List* list) { + auto res = c10::impl::GenericList(AnyType::get()); + for (int i : *list->items()) { + res.emplace_back(getIValue(i)); + } + auto type = + getOrCreateTypeAnnotations(list->annotation_str())->cast(); + res.unsafeSetElementType(type->getElementType()); + return res; +} + +IValue FlatbufferLoader::parseTuple(const mobile::serialization::Tuple* tuple) { + std::vector res; + for (int i : *tuple->items()) { + res.emplace_back(getIValue(i)); + } + return c10::ivalue::Tuple::create(res); +} + +IValue FlatbufferLoader::parseDict(const mobile::serialization::Dict* dict) { + auto result = c10::impl::GenericDict(AnyType::get(), AnyType::get()); + const auto* keys = dict->keys(); + const auto* values = dict->values(); + for (size_t i = 0; i < keys->size(); ++i) { + uint32_t key = keys->Get(i); + uint32_t val = values->Get(i); + result.insert_or_assign(getIValue(key), getIValue(val)); + } + auto type = + getOrCreateTypeAnnotations(dict->annotation_str())->cast(); + result.unsafeSetKeyType(type->getKeyType()); + result.unsafeSetValueType(type->getValueType()); + return result; +} + +IValue FlatbufferLoader::parseObject( + const mobile::serialization::Object* object) { + const mobile::serialization::ObjectType* obj_type = + module_->object_types()->Get(object->type_index()); + auto cls = getType(object->type_index()); + bool initialized = true; + if (cls == nullptr) { + c10::string_view qn_str( + obj_type->type_name()->c_str(), obj_type->type_name()->size()); + if (qn_str.starts_with(kTorchPrefix) || qn_str.starts_with(kJitPrefix)) { + c10::QualifiedName qn(obj_type->type_name()->str()); + cls = cu_->get_class(qn); + if (cls == nullptr) { + cls = ClassType::create(qn, cu_, true); + cu_->register_type(cls); + } + } else { + cls = c10::parseType(std::string(qn_str))->cast(); + } + TORCH_CHECK(object->type_index() < all_ivalues_.size()); + all_types_[object->type_index()] = cls; + initialized = false; + } + Stack stack; + switch (obj_type->type()) { + case mobile::serialization::TypeType::CLASS_WITH_FIELD: { + auto obj = c10::ivalue::Object::create( + at::StrongTypePtr(cu_, cls), object->attrs()->size()); + if (!initialized) { + for (uint32_t i = 0; i < object->attrs()->size(); i++) { + IValue val = getIValue(object->attrs()->Get(i)); + cls->addAttribute(obj_type->attr_names()->Get(i)->str(), val.type()); + obj->setSlot(i, std::move(val)); + } + initialized_types_.insert(object->type_index()); + } else { + for (uint32_t i = 0; i < object->attrs()->size(); i++) { + IValue val = getIValue(object->attrs()->Get(i)); + obj->setSlot(i, std::move(val)); + } + } + return obj; + } + case mobile::serialization::TypeType::CLASS_WITH_SETSTATE: { + IValue input = getIValue(object->state()); + mobile::Function* setstate = getFunction(object->setstate_func()); + auto obj = c10::ivalue::Object::create(at::StrongTypePtr(cu_, cls), 0); + std::cerr << "here 2: " << cls.get() << std::endl; + stack.push_back(obj); + stack.emplace_back(std::move(input)); + setstate->run(stack); + return obj; + } + case mobile::serialization::TypeType::CUSTOM_CLASS: { + auto custom_class_type = + torch::jit::getCustomClass(cls->name()->qualifiedName()); + IValue input = getIValue(object->state()); + auto obj = c10::ivalue::Object::create( + c10::StrongTypePtr(nullptr, custom_class_type), 1); + std::cerr << "here 3: " << cls.get() << std::endl; + stack.push_back(obj); + stack.emplace_back(std::move(input)); + custom_class_type->getMethod("__setstate__").run(stack); + return obj; + } + default: + AT_ASSERT(false, "need to be object"); + } +} + +template +std::vector parseListNative(const U* list) { + return {list->items()->begin(), list->items()->end()}; +} + +IValue FlatbufferLoader::parseIValue( + const mobile::serialization::IValue* ivalue) { + switch (ivalue->val_type()) { + case mobile::serialization::IValueUnion::NONE: + return {}; + case mobile::serialization::IValueUnion::Int: + return ivalue->val_as_Int()->int_val(); + case mobile::serialization::IValueUnion::Bool: + return ivalue->val_as_Bool()->bool_val(); + case mobile::serialization::IValueUnion::Double: + return ivalue->val_as_Double()->double_val(); + case mobile::serialization::IValueUnion::ComplexDouble: { + const auto* comp = ivalue->val_as_ComplexDouble(); + return c10::complex(comp->real(), comp->imag()); + } + case mobile::serialization::IValueUnion::TensorMetadata: + return parseTensor(ivalue->val_as_TensorMetadata()); + case mobile::serialization::IValueUnion::String: + return ivalue->val_as_String()->data()->str(); + case mobile::serialization::IValueUnion::List: + return parseList(ivalue->val_as_List()); + case mobile::serialization::IValueUnion::IntList: + return parseListNative(ivalue->val_as_IntList()); + case mobile::serialization::IValueUnion::DoubleList: + return parseListNative(ivalue->val_as_DoubleList()); + case mobile::serialization::IValueUnion::BoolList: { + std::vector res = + parseListNative(ivalue->val_as_BoolList()); + c10::List boollist; + for (auto x : res) { + boollist.push_back(x); + } + return boollist; + } + case mobile::serialization::IValueUnion::Tuple: + return parseTuple(ivalue->val_as_Tuple()); + case mobile::serialization::IValueUnion::Dict: + return parseDict(ivalue->val_as_Dict()); + case mobile::serialization::IValueUnion::Object: { + auto val = parseObject(ivalue->val_as_Object()); + return val; + } + case mobile::serialization::IValueUnion::Device: { + return c10::Device(ivalue->val_as_Device()->str()->str()); + } + case mobile::serialization::IValueUnion::EnumValue: { + const auto* enum_val = ivalue->val_as_EnumValue(); + auto enum_type = getOrCreateTypeAnnotations(enum_val->type_name()) + ->cast(); + AT_ASSERT( + enum_type, + "Enum with type: " + enum_val->type_name()->str() + " not found."); + IValue val = getIValue(enum_val->value()); + for (const auto& p : enum_type->enumNamesValues()) { + if (p.second == val) { + auto enum_holder = c10::make_intrusive( + enum_type, p.first, p.second); + return IValue(std::move(enum_holder)); + } + } + AT_ASSERT( + false, + "Enum with type: " + enum_val->type_name()->str() + " not found."); + } + default: + return {}; + } +} + +void deleteNothing2(void*); +void deleteNothing2(void*) {} + +c10::Storage FlatbufferLoader::getStorage(uint32_t index) { + TORCH_CHECK(index < storage_loaded_.size()); + TORCH_CHECK(index < storages_.size()); + if (!storage_loaded_[index]) { + auto* storage = module_->storage_data()->GetMutableObject(index); + size_t size = storage->data()->size(); + void* ptr = static_cast(storage->mutable_data()->data()); + at::DataPtr data(ptr, ptr, deleteNothing2, DeviceType::CPU); + storages_[index] = + c10::Storage(c10::Storage::use_byte_size_t(), size, std::move(data)); + storage_loaded_[index] = true; + } + return storages_[index]; +} + +TypePtr FlatbufferLoader::getOrCreateTypeAnnotations( + const flatbuffers::String* offset) { + auto iter = type_annotations_.find(offset); + if (iter != type_annotations_.end()) { + return iter->second; + } + TypePtr type; + c10::string_view qn_str(offset->c_str(), offset->size()); + c10::QualifiedName qn(offset->str()); + if (qn_str.starts_with(kCustomClassPrefix)) { + type = getCustomClass(qn.qualifiedName()); + TORCH_CHECK( + type, + "The implementation of class ", + qn.qualifiedName(), + " cannot be found."); + } else if ( + qn_str.starts_with(kTorchPrefix) || qn_str.starts_with(kJitPrefix)) { + if (cu_->get_class(qn) == nullptr) { + auto classtype = ClassType::create(qn, cu_, true); + cu_->register_type(classtype); + type = classtype; + } else { + type = cu_->get_class(qn); + } + } else { + type = c10::parseType(qn.qualifiedName()); + } + type_annotations_[offset] = type; + return type; +} + +} // namespace + +mobile::Module parse_and_initialize_mobile_module( + std::shared_ptr data, + size_t, + c10::optional) { + auto* flatbuffer_module = mobile::serialization::GetMutableModule(data.get()); + mobile::Module m = FlatbufferLoader().parseModule(flatbuffer_module); + m.set_delete_memory(std::move(data)); + return m; +} + +mobile::Module initialize_mobile_module( + mobile::serialization::Module* flatbuffer_module, + c10::optional) { + mobile::Module m = FlatbufferLoader().parseModule(flatbuffer_module); + return m; +} + +mobile::Module load_mobile_module_from_file( + const std::string& filename, + c10::optional device) { +#if defined(HAVE_MMAP) + int fd = open(filename.c_str(), O_RDONLY); + struct stat statbuf {}; + fstat(fd, &statbuf); + int size = statbuf.st_size; + void* ptr = mmap(nullptr, statbuf.st_size, PROT_READ, MAP_PRIVATE, fd, 0); + close(fd); + auto deleter = [statbuf](char* ptr) { munmap(ptr, statbuf.st_size); }; + std::shared_ptr data(reinterpret_cast(ptr), deleter); +#else + FILE* f = fopen(filename.c_str(), "rb"); + fseek(f, 0, SEEK_END); + long size = ftell(f); + fseek(f, 0, SEEK_SET); + std::shared_ptr data(static_cast(malloc(size)), free); // NOLINT + fread(data.get(), size, 1, f); + fclose(f); +#endif + return parse_and_initialize_mobile_module(std::move(data), size, device); +} + +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/mobile/flatbuffer_loader.h b/torch/csrc/jit/mobile/flatbuffer_loader.h new file mode 100644 index 000000000000..ee76cddc27e4 --- /dev/null +++ b/torch/csrc/jit/mobile/flatbuffer_loader.h @@ -0,0 +1,54 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include // NOLINT +#include + +#include +#include + +namespace torch { +namespace jit { + +// On high level, to produce a Module from a file on disk, we need to go +// through the follow steps: +// 1. Read: Read the file from disk -> memory +// 2. Deserialize: Parse the bytes to produce some in memory manipulable +// structure +// 3. Module initialization: Produce mobile::Module out of the structure +// produced in 2. +// Under this context, the structure described in 2. is +// mobile::serialization::Module + +// Parse a mobile::Module from flatbuffer's in-memory Module representation. +// The caller is assumed to manage the lifetimes of Module. +// This function does step 3 described above. +TORCH_API mobile::Module initialize_mobile_module( + mobile::serialization::Module* flatbuffer_module, + c10::optional device = c10::nullopt); + +// Parse a mobile::Module from raw bytes. +// ownership of data is shared to the returned Module. +// (Feel free to pass in a unique_ptr too!) +// This function does steps 2+3 described above +TORCH_API mobile::Module parse_and_initialize_mobile_module( + std::shared_ptr data, + size_t size, + c10::optional device = c10::nullopt); + +// Load a mobile::Module from a filepath. +// This function does steps 1+2+3 described above. +// We need to have this as a convienience because Python +// API will need to wrap this. C++ clients should use one +// versions above. +TORCH_API mobile::Module load_mobile_module_from_file( + const std::string& filename, + c10::optional device = c10::nullopt); + +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/mobile/module.cpp b/torch/csrc/jit/mobile/module.cpp index 55db4128bdc1..27f013808a78 100644 --- a/torch/csrc/jit/mobile/module.cpp +++ b/torch/csrc/jit/mobile/module.cpp @@ -19,7 +19,8 @@ void CompilationUnit::register_function(std::unique_ptr fn) { methods_.emplace_back(std::move(fn)); } -Function* CompilationUnit::find_function(const c10::QualifiedName& qn) { +const Function* CompilationUnit::find_function( + const c10::QualifiedName& qn) const { for (auto& fn : methods_) { if (fn->qualname() == qn) { return fn.get(); @@ -28,6 +29,12 @@ Function* CompilationUnit::find_function(const c10::QualifiedName& qn) { return nullptr; } +Function* CompilationUnit::find_function(const c10::QualifiedName& qn) { + // NOLINTNEXTLINE + return const_cast( + static_cast(this)->find_function(qn)); +} + Method Module::get_method(const std::string& name) const { if (auto method = find_method(name)) { return *method; diff --git a/torch/csrc/jit/mobile/module.h b/torch/csrc/jit/mobile/module.h index 31d3f58b81e9..fc4c046c77fa 100644 --- a/torch/csrc/jit/mobile/module.h +++ b/torch/csrc/jit/mobile/module.h @@ -40,6 +40,7 @@ class CompilationUnit { return methods_; } Function* find_function(const c10::QualifiedName& qn); + const Function* find_function(const c10::QualifiedName& qn) const; private: std::vector> methods_; @@ -130,12 +131,19 @@ class TORCH_API Module { return *cu_.get(); } + void set_delete_memory(std::shared_ptr delete_mem) { + mem_to_delete_ = delete_mem; + } + private: c10::intrusive_ptr object_; std::unordered_map metadata_; std::shared_ptr cu_; MobileDebugTable debug_table_; bool has_debug_handles_ = false; + + // Extra handle for the module to delete when itself is deleted + std::shared_ptr mem_to_delete_; }; } // namespace mobile } // namespace jit diff --git a/torch/csrc/jit/runtime/instruction.h b/torch/csrc/jit/runtime/instruction.h index 8e72b55c7ae0..17bb2135906b 100644 --- a/torch/csrc/jit/runtime/instruction.h +++ b/torch/csrc/jit/runtime/instruction.h @@ -76,6 +76,7 @@ std::ostream& operator<<(std::ostream& out, Instruction inst); bool isOpSupportedInMobile(OpCode op); char const* toString(OpCode op); +std::ostream& operator<<(std::ostream& out, Instruction inst); } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/serialization/flatbuffer_serializer.cpp b/torch/csrc/jit/serialization/flatbuffer_serializer.cpp new file mode 100644 index 000000000000..928bf32256a7 --- /dev/null +++ b/torch/csrc/jit/serialization/flatbuffer_serializer.cpp @@ -0,0 +1,681 @@ +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace torch { +namespace jit { + +using flatbuffers::FlatBufferBuilder; +using mobile::serialization::CreateArg; +using mobile::serialization::CreateDebugInfo; +using mobile::serialization::CreateDict; +using mobile::serialization::CreateFunctionDirect; +using mobile::serialization::CreateIValue; +using mobile::serialization::CreateList; +using mobile::serialization::CreateModule; +using mobile::serialization::CreateObject; +using mobile::serialization::CreateOperator; +using mobile::serialization::CreateTensorMetadataDirect; +using mobile::serialization::CreateTupleDirect; + +namespace { + +// We will store IValue NONE in index 0 in flatbuffer. +constexpr int kNoneIndex = 0; + +class FlatbufferSerializer { + public: + FlatbufferSerializer() = default; + + flatbuffers::DetachedBuffer serializeModule( + const mobile::Module& module, + bool include_tensor_data_in_flatbuffer); + + private: + template + std::vector storeIValuesAndGetIndexes( + flatbuffers::FlatBufferBuilder& fbb, + It begin, + It end) { + std::vector indexes; + for (; begin != end; ++begin) { + indexes.push_back(storeIValueAndGetIndex(fbb, *begin)); + } + return indexes; + } + + flatbuffers::Offset tupleToFB( + flatbuffers::FlatBufferBuilder& fbb, + const IValue& tuple); + + flatbuffers::Offset listToFB( + flatbuffers::FlatBufferBuilder& fbb, + const IValue& list); + + flatbuffers::Offset dictToFB( + flatbuffers::FlatBufferBuilder& fbb, + const IValue& list); + + flatbuffers::Offset objectToFB( + flatbuffers::FlatBufferBuilder& fbb, + const IValue& ivalue); + + flatbuffers::Offset tensorToFB( + flatbuffers::FlatBufferBuilder& fbb, + const IValue& ivalue); + + flatbuffers::Offset functionToFB( + flatbuffers::FlatBufferBuilder& fbb, + const std::string& qn, + const mobile::Function& func); + + flatbuffers::Offset iValueToFB( + flatbuffers::FlatBufferBuilder& fbb, + const IValue& ivalue); + + flatbuffers::Offset CreateFBSchema( + flatbuffers::FlatBufferBuilder& fbb, + const std::vector& args, + const std::vector& returns, + c10::TypePrinter type_printer); + + flatbuffers::Offset classTypeToFB( + flatbuffers::FlatBufferBuilder& fbb, + ClassTypePtr class_ptr); + + uint32_t storeIValueAndGetIndex( + flatbuffers::FlatBufferBuilder& fbb, + const IValue& ivalue); + uint32_t storeFunctionAndGetIndex( + flatbuffers::FlatBufferBuilder& fbb, + const std::string& qn, + const mobile::Function& function); + + uint32_t storeClassTypeAndGetIndex( + flatbuffers::FlatBufferBuilder& fbb, + ClassTypePtr class_type); + + uint32_t insertIValue( + flatbuffers::Offset ivalue) { + uint32_t size = ivalue_offsets_.size(); + ivalue_offsets_.push_back(ivalue); + return size; + } + + std::vector tensor_data_; + + std::unordered_map memoized_storage_map_; + + std::vector> + ivalue_offsets_; + std::vector> + obj_types_offset_; + + // qualified name to serialized class, type or function + std::unordered_map qn_to_serialized_values_; + + // cache of some ivalues + struct IValueHash { + size_t operator()(const IValue& val) const { + return IValue::hash(val); + } + }; + + std::unordered_map cached_ivalues_; + + const mobile::CompilationUnit* mcu_ = nullptr; +}; + +flatbuffers::Offset FlatbufferSerializer:: + CreateFBSchema( + flatbuffers::FlatBufferBuilder& fbb, + const std::vector& args, + const std::vector& returns, + c10::TypePrinter type_printer) { + std::vector> arg_vec; + arg_vec.reserve(args.size()); + std::vector> return_vec; + return_vec.reserve(returns.size()); + for (const auto& arg : args) { + int index = storeIValueAndGetIndex(fbb, arg.default_value()); + arg_vec.emplace_back(CreateArg( + fbb, + fbb.CreateSharedString(arg.name()), + fbb.CreateSharedString(arg.type()->annotation_str(type_printer)), + index)); + } + + for (const auto& ret : returns) { + int index = storeIValueAndGetIndex(fbb, ret.default_value()); + return_vec.emplace_back(CreateArg( + fbb, + fbb.CreateSharedString(ret.name()), + fbb.CreateSharedString(ret.type()->annotation_str(type_printer)), + index)); + } + return CreateSchema( + fbb, fbb.CreateVector(arg_vec), fbb.CreateVector(return_vec)); +} + +flatbuffers::Offset FlatbufferSerializer:: + functionToFB( + FlatBufferBuilder& fbb, + const std::string& qn, + const mobile::Function& func) { + const auto& code = func.get_code(); + + // instructions + std::vector instruction_vector; + for (const auto& inst : code.instructions_) { + instruction_vector.emplace_back(inst.op, inst.N, inst.X); + } + + // operators + std::vector> + operator_vector; + operator_vector.reserve(code.op_names_.size()); + for (int i = 0; i < code.op_names_.size(); ++i) { + const auto& opname = code.op_names_[i]; + const int op_size = code.operator_input_sizes_[i]; + operator_vector.push_back(CreateOperator( + fbb, + fbb.CreateSharedString(opname.name), + fbb.CreateSharedString(opname.overload_name), + op_size)); + } + + const auto& constants = code.constants_; + + std::vector constant_indexes; + constant_indexes.reserve(constants.size()); + for (const auto& constant : constants) { + constant_indexes.push_back(storeIValueAndGetIndex(fbb, constant)); + } + + // types + static const std::string torch_prefix("__torch__"); + static const std::string class_prefix("__torch__.torch.classes"); + std::vector> type_offsets; + + for (const TypePtr& t : code.types_) { + auto type_str = t->annotation_str(); + if (type_str.find(torch_prefix) == 0) { + TORCH_CHECK( + type_str.find(class_prefix) == 0, + "__torch__ types other than torchbind (__torch__.torch.classes)" + "are not supported in lite interpreter. ", + "Workaround: instead of using arbitrary class type (class Foo()), ", + "define a pytorch class (class Foo(torch.nn.Module))."); + } + + type_offsets.push_back(fbb.CreateSharedString(type_str)); + } + + // since the register location is embedded into the bytecode, pass the + // register size + auto register_size = static_cast(code.register_size_); + + // schema + auto type_printer = [&](const c10::Type& t) -> c10::optional { + auto namedType = t.cast(); + if (namedType && namedType->name()) { + return namedType->name().value().qualifiedName(); + } + return c10::nullopt; + }; + + flatbuffers::Offset schema_offset = 0; + if (func.hasSchema()) { + const auto& schema = func.getSchema(); + TORCH_CHECK( + schema.overload_name().empty(), // @TODO: is this check correct? + "Overloads are not supported in mobile modules."); + TORCH_CHECK( + !schema.is_vararg(), + "Python *args are not supported in mobile modules."); + TORCH_CHECK( + !schema.is_varret(), + "A variable number of return values is not supported in mobile modules."); + schema_offset = + CreateFBSchema(fbb, schema.arguments(), schema.returns(), type_printer); + } + + auto debug_info_offset = + CreateDebugInfo(fbb, fbb.CreateVector(code.debug_handles_)); + + // auto classtype = schema.arguments()[0].type()->cast(); + // uint32_t class_type = storeClassTypeAndGetIndex(fbb, classtype); + + auto function_offset = CreateFunctionDirect( + fbb, + qn.c_str(), + &instruction_vector, + &operator_vector, + &constant_indexes, + &type_offsets, + register_size, + schema_offset, + debug_info_offset, + 0); + return function_offset; +} + +flatbuffers::DetachedBuffer FlatbufferSerializer::serializeModule( + const mobile::Module& module, + bool include_tensor_data_in_flatbuffer) { + FlatBufferBuilder fbb; + + mcu_ = &module.compilation_unit(); + + // first element is None. + insertIValue(CreateIValue(fbb, mobile::serialization::IValueUnion::NONE, 0)); + + auto methods = module.get_methods(); + std::vector functions_index; + functions_index.reserve(methods.size()); + for (const auto& method : methods) { + auto func_offset = storeFunctionAndGetIndex( + fbb, method.function().qualname().qualifiedName(), method.function()); + functions_index.push_back(func_offset); + } + + auto functions_offset = fbb.CreateVector(functions_index); + uint32_t ivalue_index = storeIValueAndGetIndex(fbb, module._ivalue()); + + flatbuffers::Offset>> + storage_data_offset = 0; + if (include_tensor_data_in_flatbuffer) { + std::vector> + storage_data; + for (auto td : tensor_data_) { + if (td.storage().device_type() != DeviceType::CPU) { + td = at::empty({0}, td.options()) + .set_( + td.storage(), + /* storage_offset = */ 0, + /* size = */ + {static_cast( + td.storage().nbytes() / td.element_size())}, + /* stride = */ {1}) + .cpu(); + } + fbb.ForceVectorAlignment( + td.storage().nbytes(), sizeof(uint8_t), FLATBUFFERS_MAX_ALIGNMENT); + auto storage_offset = mobile::serialization::CreateStorageData( + fbb, + fbb.CreateVector( + reinterpret_cast(td.storage().data()), + td.storage().nbytes())); + storage_data.push_back(storage_offset); + } + storage_data_offset = fbb.CreateVector(storage_data); + } + + auto mod = CreateModule( + fbb, + 0, /* version */ + 0, /* extra_files */ + functions_offset, + ivalue_index, + fbb.CreateVector(ivalue_offsets_), + tensor_data_.size(), + storage_data_offset, + fbb.CreateVector(obj_types_offset_)); + fbb.Finish(mod); + return fbb.Release(); +} + +flatbuffers::Offset FlatbufferSerializer:: + tupleToFB(flatbuffers::FlatBufferBuilder& fbb, const IValue& tuple) { + const auto& elements = tuple.toTuple()->elements(); + std::vector items = + storeIValuesAndGetIndexes(fbb, elements.begin(), elements.end()); + return CreateTupleDirect(fbb, &items); +} + +flatbuffers::Offset FlatbufferSerializer::listToFB( + flatbuffers::FlatBufferBuilder& fbb, + const IValue& list) { + const auto& elements = list.toList(); + std::vector items = + storeIValuesAndGetIndexes(fbb, elements.begin(), elements.end()); + return CreateList( + fbb, + fbb.CreateVector(items), + fbb.CreateSharedString(list.type()->annotation_str())); +} + +flatbuffers::Offset FlatbufferSerializer::dictToFB( + flatbuffers::FlatBufferBuilder& fbb, + const IValue& ivalue) { + const auto& dict = ivalue.toGenericDict(); + std::vector keys; + std::vector values; + keys.reserve(dict.size()); + values.reserve(dict.size()); + for (const auto& entry : dict) { + int key_index = storeIValueAndGetIndex(fbb, entry.key()); + keys.push_back(key_index); + int value_index = storeIValueAndGetIndex(fbb, entry.value()); + values.push_back(value_index); + } + return CreateDict( + fbb, + fbb.CreateVector(keys), + fbb.CreateVector(values), + fbb.CreateSharedString(ivalue.type()->annotation_str())); +} + +flatbuffers::Offset FlatbufferSerializer:: + classTypeToFB(FlatBufferBuilder& fbb, ClassTypePtr class_ptr) { + mobile::serialization::TypeType typetype = + mobile::serialization::TypeType::UNSET; + + flatbuffers::Offset< + flatbuffers::Vector>> + names_offset = 0; + c10::QualifiedName setstate_name(*class_ptr->name(), "__setstate__"); + const mobile::Function* setstate = mcu_->find_function(setstate_name); + if (setstate != nullptr) { + typetype = mobile::serialization::TypeType::CLASS_WITH_SETSTATE; + } else if (class_ptr->findMethod("__setstate__")) { + typetype = mobile::serialization::TypeType::CUSTOM_CLASS; + } else { + size_t num_attr = class_ptr->numAttributes(); + std::vector> names; + std::vector type_index; + for (size_t i = 0; i < num_attr; ++i) { + names.push_back(fbb.CreateSharedString(class_ptr->getAttributeName(i))); + } + names_offset = fbb.CreateVector(names); + typetype = mobile::serialization::TypeType::CLASS_WITH_FIELD; + } + + auto name_offset = fbb.CreateString(class_ptr->name()->qualifiedName()); + return CreateObjectType(fbb, name_offset, typetype, names_offset); +} + +uint32_t FlatbufferSerializer::storeFunctionAndGetIndex( + flatbuffers::FlatBufferBuilder& fbb, + const std::string& qn, + const mobile::Function& function) { + auto iter = qn_to_serialized_values_.find(qn); + if (iter != qn_to_serialized_values_.end()) { + return iter->second; + } + + auto offset = CreateIValue( + fbb, + mobile::serialization::IValueUnion::Function, + functionToFB(fbb, qn, function).Union()); + + uint32_t index = insertIValue(offset); + qn_to_serialized_values_[qn] = index; + return index; +} + +uint32_t FlatbufferSerializer::storeClassTypeAndGetIndex( + FlatBufferBuilder& fbb, + ClassTypePtr class_ptr) { + const auto& type_str = class_ptr->name()->qualifiedName(); + auto iter = qn_to_serialized_values_.find(type_str); + if (iter != qn_to_serialized_values_.end()) { + return iter->second; + } + + auto offset = classTypeToFB(fbb, class_ptr); + uint32_t res = obj_types_offset_.size(); + obj_types_offset_.push_back(offset); + qn_to_serialized_values_[type_str] = res; + return res; +} + +flatbuffers::Offset FlatbufferSerializer:: + objectToFB(flatbuffers::FlatBufferBuilder& fbb, const IValue& ivalue) { + auto obj = ivalue.toObject(); + auto type = obj->type(); + // rename type? + // check getstate + + // save state as ivalue + flatbuffers::Offset> attrs = 0; + uint32_t state_index = 0; + uint32_t setstate_func_index = 0; + const auto qn = type->name()->qualifiedName() + ".__setstate__"; + auto getstate = type->findMethod("__getstate__"); + auto setstate = type->findMethod("__setstate__"); + if (getstate && setstate) { + auto state = (*getstate)({obj}); + state_index = storeIValueAndGetIndex(fbb, state); + auto func_index = qn_to_serialized_values_.find(qn); + if (func_index != qn_to_serialized_values_.end()) { + setstate_func_index = func_index->second; + } + } else { + size_t num_attr = type->numAttributes(); + std::vector tuple_index; + for (size_t i = 0; i < num_attr; ++i) { + tuple_index.push_back(storeIValueAndGetIndex(fbb, obj->getSlot(i))); + } + attrs = fbb.CreateVector(tuple_index); + } + + uint32_t type_index = storeClassTypeAndGetIndex(fbb, type); + return CreateObject(fbb, type_index, state_index, attrs, setstate_func_index); +} + +flatbuffers::Offset FlatbufferSerializer:: + FlatbufferSerializer::tensorToFB( + flatbuffers::FlatBufferBuilder& fbb, + const IValue& ivalue) { + auto& tensor = ivalue.toTensor(); + bool quantized = tensor.is_quantized(); + const at::Storage& storage = tensor.storage(); + + flatbuffers::Offset qschema_offset = + 0; + if (quantized) { + double scale = 0; + int32_t zero_point = 0; + flatbuffers::Offset scales = 0; + flatbuffers::Offset zero_points = 0; + int32_t axis = 0; + + switch (tensor.qscheme()) { + case at::kPerTensorAffine: + scale = tensor.q_scale(); + zero_point = tensor.q_zero_point(); + break; + case at::kPerChannelAffineFloatQParams: + case at::kPerChannelAffine: { + scales = tensorToFB(fbb, tensor.q_per_channel_scales()); + zero_points = tensorToFB(fbb, tensor.q_per_channel_zero_points()); + axis = tensor.q_per_channel_axis(); + } break; + default: + TORCH_CHECK( + false, + "Unsupported tensor quantization type in serialization ", + toString(tensor.qscheme())); + break; + } + + qschema_offset = mobile::serialization::CreateQuantizedSchema( + fbb, + static_cast(tensor.qscheme()), + scale, + zero_point, + scales, + zero_points, + axis); + } + + void* addr = storage.unsafeGetStorageImpl(); + uint32_t storage_index = 0; + auto it = memoized_storage_map_.find(addr); + if (it != memoized_storage_map_.end()) { + storage_index = it->second; + } else { + storage_index = tensor_data_.size(); + memoized_storage_map_[addr] = storage_index; + tensor_data_.push_back(tensor); + } + + std::vector sizes{tensor.sizes().begin(), tensor.sizes().end()}; + std::vector strides{tensor.strides().begin(), tensor.strides().end()}; + + return CreateTensorMetadataDirect( + fbb, + /* storage_location_index */ storage_index, + /* scalar_type */ static_cast(tensor.scalar_type()), + /* int32_t storage_offset */ tensor.storage_offset(), + /* sizes */ &sizes, + /* strides */ &strides, + /* bool requires_grad */ tensor.requires_grad(), + /* qschema */ qschema_offset); +} + +uint32_t FlatbufferSerializer::storeIValueAndGetIndex( + flatbuffers::FlatBufferBuilder& fbb, + const IValue& ivalue) { + if (ivalue.isNone()) { + return kNoneIndex; + } + + try { + auto iter = cached_ivalues_.find(ivalue); + if (iter != cached_ivalues_.end()) { + return iter->second; + } + } catch (const std::runtime_error&) { + // Threw if ivalue is not hashable + } catch (const c10::Error&) { + // Threw if ivalue is don't have proper operator== + } + + auto offset = iValueToFB(fbb, ivalue); + uint32_t index = insertIValue(offset); + try { + cached_ivalues_[ivalue] = index; + } catch (const std::runtime_error&) { + } catch (const c10::Error&) { + } + + return index; +} + +flatbuffers::Offset FlatbufferSerializer:: + iValueToFB(flatbuffers::FlatBufferBuilder& fbb, const IValue& ivalue) { + using mobile::serialization::IValueUnion; + + IValueUnion ivalue_type = IValueUnion::NONE; + flatbuffers::Offset offset = 0; + + if (ivalue.isTensor()) { + ivalue_type = IValueUnion::TensorMetadata; + offset = tensorToFB(fbb, ivalue).Union(); + } else if (ivalue.isTuple()) { + ivalue_type = IValueUnion::Tuple; + offset = tupleToFB(fbb, ivalue).Union(); + } else if (ivalue.isDouble()) { + ivalue_type = IValueUnion::Double; + offset = fbb.CreateStruct(mobile::serialization::Double(ivalue.toDouble())) + .Union(); + } else if (ivalue.isComplexDouble()) { + auto comp = ivalue.toComplexDouble(); + ivalue_type = IValueUnion::ComplexDouble; + offset = fbb.CreateStruct(mobile::serialization::ComplexDouble( + comp.real(), comp.imag())) + .Union(); + } else if (ivalue.isInt()) { + ivalue_type = IValueUnion::Int; + offset = + fbb.CreateStruct(mobile::serialization::Int(ivalue.toInt())).Union(); + } else if (ivalue.isBool()) { + ivalue_type = IValueUnion::Bool; + offset = + fbb.CreateStruct(mobile::serialization::Bool(ivalue.toBool())).Union(); + } else if (ivalue.isString()) { + ivalue_type = IValueUnion::String; + offset = mobile::serialization::CreateString( + fbb, fbb.CreateSharedString(ivalue.toString()->string())) + .Union(); + } else if (ivalue.isGenericDict()) { + ivalue_type = IValueUnion::Dict; + offset = dictToFB(fbb, ivalue).Union(); + } else if (ivalue.isNone()) { + ivalue_type = IValueUnion::NONE; + offset = 0; + } else if (ivalue.isIntList()) { + ivalue_type = IValueUnion::IntList; + offset = mobile::serialization::CreateIntList( + fbb, fbb.CreateVector(ivalue.toIntVector())) + .Union(); + } else if (ivalue.isDoubleList()) { + ivalue_type = IValueUnion::DoubleList; + offset = mobile::serialization::CreateDoubleList( + fbb, fbb.CreateVector(ivalue.toDoubleVector())) + .Union(); + } else if (ivalue.isBoolList()) { + ivalue_type = IValueUnion::BoolList; + auto boollist = ivalue.toBoolList(); + std::vector bool_vec(boollist.begin(), boollist.end()); + offset = + mobile::serialization::CreateBoolListDirect(fbb, &bool_vec).Union(); + } else if (ivalue.isList()) { + ivalue_type = IValueUnion::List; + offset = listToFB(fbb, ivalue).Union(); + } else if (ivalue.isObject()) { + ivalue_type = IValueUnion::Object; + offset = objectToFB(fbb, ivalue).Union(); + } else if (ivalue.isDevice()) { + ivalue_type = IValueUnion::Device; + offset = mobile::serialization::CreateDevice( + fbb, fbb.CreateSharedString(ivalue.toDevice().str())) + .Union(); + } else if (ivalue.isEnum()) { + const auto& enum_holder = ivalue.toEnumHolder(); + const auto& qualified_class_name = + enum_holder->type()->qualifiedClassName(); + uint32_t ival_pos = storeIValueAndGetIndex(fbb, enum_holder->value()); + ivalue_type = IValueUnion::EnumValue; + offset = mobile::serialization::CreateEnumValue( + fbb, + fbb.CreateSharedString(qualified_class_name.qualifiedName()), + ival_pos) + .Union(); + } else { + AT_ERROR("Invalid IValue type for serialization: ", ivalue.tagKind()); + } + return CreateIValue(fbb, ivalue_type, offset); +} + +} // namespace + +void save_mobile_module( + const mobile::Module& module, + const std::string& filename) { + FlatbufferSerializer fb_serializer; + auto buffer = fb_serializer.serializeModule(module, true); + std::fstream ofile(filename, std::ios::binary | std::ios::out); + ofile.write(reinterpret_cast(buffer.data()), buffer.size()); + ofile.close(); +} + +flatbuffers::DetachedBuffer save_mobile_module_to_bytes( + const mobile::Module& module) { + FlatbufferSerializer fb_serializer; + return fb_serializer.serializeModule(module, true); +} + +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/serialization/flatbuffer_serializer.h b/torch/csrc/jit/serialization/flatbuffer_serializer.h new file mode 100644 index 000000000000..6f20a1f799ba --- /dev/null +++ b/torch/csrc/jit/serialization/flatbuffer_serializer.h @@ -0,0 +1,26 @@ +#pragma once + +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#include // NOLINT + +namespace torch { +namespace jit { + +TORCH_API void save_mobile_module( + const mobile::Module& module, + const std::string& filename); +TORCH_API flatbuffers::DetachedBuffer save_mobile_module_to_bytes( + const mobile::Module& module); + +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/serialization/mobile_bytecode.fbs b/torch/csrc/jit/serialization/mobile_bytecode.fbs new file mode 100644 index 000000000000..b936a8ebd4ff --- /dev/null +++ b/torch/csrc/jit/serialization/mobile_bytecode.fbs @@ -0,0 +1,197 @@ +namespace torch.jit.mobile.serialization; + +struct Int { + int_val:long; +} + +struct Bool { + bool_val:bool; +} + +struct Double{ + double_val:double; +} + +struct PerTensorAffineSchema { + q_scale:double; + q_zero_point:int; +} + +table QuantizedSchema { + qscheme:byte; + scale:double; + zero_point:int; + scales:TensorMetadata; + zero_points:TensorMetadata; + axis:int; +} + +table TensorMetadata { + // torch._utils _rebuild_tensor_v2 + storage_location_index:uint; + // enum ScalarType + scalar_type:byte; + storage_offset:int; + sizes:[int]; + strides:[int]; + requires_grad:bool; + + // only set for quantized tensors + quantized_schema:QuantizedSchema; +} + +table String { + data:string; +} + +table Device { + str:string; +} + +table List { + items:[uint]; + annotation_str:string; // to recover key/val type +} + +table IntList { + items:[long]; +} + +table DoubleList { + items:[double]; +} + +table BoolList { + items:[bool]; +} + +table Tuple { + items:[uint]; +} + +table Dict { + keys:[uint]; + values:[uint]; + annotation_str:string; // to recover key/val type +} + +enum TypeType :ubyte { + UNSET, + CLASS_WITH_FIELD, + CUSTOM_CLASS, + CLASS_WITH_SETSTATE, + NON_OBJ, +} + +table ObjectType { + type_name:string; + type:TypeType; + // Below fields are optional + attr_names:[string]; +} + +table Object { + type_index:uint; + state:uint; + attrs:[uint]; + setstate_func:uint; +} + +struct ComplexDouble { + real:double; + imag:double; +} + +table EnumValue { + type_name:string; + value:uint; // index to ivalues; +} + + +struct Instruction { + // Should op be enum instead? + op:byte; + n:ushort; + x:int; +} + +table Operator { + name:string; + overload_name:string; + num_args_serialized:int = -1; +} + +table Arg { + name:string; + // Why do we use string to represent types + // rather than index into Code.types? + type:string; + default_value:uint; // position into ivalues +} + +table Schema { + arguments:[Arg]; + returns:[Arg]; +} + +table DebugInfo { + debug_handle:[long]; +} + +table Function { + qn:string; + instructions:[Instruction]; + operators:[Operator]; + constants:[uint]; // index to ivalue + type_annotations:[string]; + register_size:int; + schema:Schema; + debug_info:DebugInfo; + class_type:uint; // index into type table +} + +table StorageData { + data:[ubyte] (force_align:16); +} + +// Is it needed to represent other types? +union IValueUnion { + Int, + Bool, + Double, + ComplexDouble, + TensorMetadata, + String, + List, + Tuple, + Dict, + Object, + IntList, + DoubleList, + BoolList, + Device, + EnumValue, + Function, +} + +table IValue { + val:IValueUnion; +} + +table ExtraFile { + name:string; + content:string; +} + +table Module { + version:int; + extra_files:[ExtraFile]; + methods:[uint]; // index to ivalues + state_obj:uint; // index to ivalues + ivalues:[IValue]; + storage_data_size:int; // number of storage data; + storage_data:[StorageData]; + object_types:[ObjectType]; +} + +root_type Module; diff --git a/torch/csrc/jit/serialization/mobile_bytecode_generated.h b/torch/csrc/jit/serialization/mobile_bytecode_generated.h new file mode 100644 index 000000000000..082e9dce0069 --- /dev/null +++ b/torch/csrc/jit/serialization/mobile_bytecode_generated.h @@ -0,0 +1,2514 @@ +// automatically generated by the FlatBuffers compiler, do not modify + + +#ifndef FLATBUFFERS_GENERATED_MOBILEBYTECODE_TORCH_JIT_MOBILE_SERIALIZATION_H_ +#define FLATBUFFERS_GENERATED_MOBILEBYTECODE_TORCH_JIT_MOBILE_SERIALIZATION_H_ + +#include "flatbuffers/flatbuffers.h" + +namespace torch { +namespace jit { +namespace mobile { +namespace serialization { + +struct Int; + +struct Bool; + +struct Double; + +struct PerTensorAffineSchema; + +struct QuantizedSchema; +struct QuantizedSchemaBuilder; + +struct TensorMetadata; +struct TensorMetadataBuilder; + +struct String; +struct StringBuilder; + +struct Device; +struct DeviceBuilder; + +struct List; +struct ListBuilder; + +struct IntList; +struct IntListBuilder; + +struct DoubleList; +struct DoubleListBuilder; + +struct BoolList; +struct BoolListBuilder; + +struct Tuple; +struct TupleBuilder; + +struct Dict; +struct DictBuilder; + +struct ObjectType; +struct ObjectTypeBuilder; + +struct Object; +struct ObjectBuilder; + +struct ComplexDouble; + +struct EnumValue; +struct EnumValueBuilder; + +struct Instruction; + +struct Operator; +struct OperatorBuilder; + +struct Arg; +struct ArgBuilder; + +struct Schema; +struct SchemaBuilder; + +struct DebugInfo; +struct DebugInfoBuilder; + +struct Function; +struct FunctionBuilder; + +struct StorageData; +struct StorageDataBuilder; + +struct IValue; +struct IValueBuilder; + +struct ExtraFile; +struct ExtraFileBuilder; + +struct Module; +struct ModuleBuilder; + +enum class TypeType : uint8_t { + UNSET = 0, + CLASS_WITH_FIELD = 1, + CUSTOM_CLASS = 2, + CLASS_WITH_SETSTATE = 3, + NON_OBJ = 4, + MIN = UNSET, + MAX = NON_OBJ +}; + +inline const TypeType (&EnumValuesTypeType())[5] { + static const TypeType values[] = { + TypeType::UNSET, + TypeType::CLASS_WITH_FIELD, + TypeType::CUSTOM_CLASS, + TypeType::CLASS_WITH_SETSTATE, + TypeType::NON_OBJ + }; + return values; +} + +inline const char * const *EnumNamesTypeType() { + static const char * const names[6] = { + "UNSET", + "CLASS_WITH_FIELD", + "CUSTOM_CLASS", + "CLASS_WITH_SETSTATE", + "NON_OBJ", + nullptr + }; + return names; +} + +inline const char *EnumNameTypeType(TypeType e) { + if (flatbuffers::IsOutRange(e, TypeType::UNSET, TypeType::NON_OBJ)) return ""; + const size_t index = static_cast(e); + return EnumNamesTypeType()[index]; +} + +enum class IValueUnion : uint8_t { + NONE = 0, + Int = 1, + Bool = 2, + Double = 3, + ComplexDouble = 4, + TensorMetadata = 5, + String = 6, + List = 7, + Tuple = 8, + Dict = 9, + Object = 10, + IntList = 11, + DoubleList = 12, + BoolList = 13, + Device = 14, + EnumValue = 15, + Function = 16, + MIN = NONE, + MAX = Function +}; + +inline const IValueUnion (&EnumValuesIValueUnion())[17] { + static const IValueUnion values[] = { + IValueUnion::NONE, + IValueUnion::Int, + IValueUnion::Bool, + IValueUnion::Double, + IValueUnion::ComplexDouble, + IValueUnion::TensorMetadata, + IValueUnion::String, + IValueUnion::List, + IValueUnion::Tuple, + IValueUnion::Dict, + IValueUnion::Object, + IValueUnion::IntList, + IValueUnion::DoubleList, + IValueUnion::BoolList, + IValueUnion::Device, + IValueUnion::EnumValue, + IValueUnion::Function + }; + return values; +} + +inline const char * const *EnumNamesIValueUnion() { + static const char * const names[18] = { + "NONE", + "Int", + "Bool", + "Double", + "ComplexDouble", + "TensorMetadata", + "String", + "List", + "Tuple", + "Dict", + "Object", + "IntList", + "DoubleList", + "BoolList", + "Device", + "EnumValue", + "Function", + nullptr + }; + return names; +} + +inline const char *EnumNameIValueUnion(IValueUnion e) { + if (flatbuffers::IsOutRange(e, IValueUnion::NONE, IValueUnion::Function)) return ""; + const size_t index = static_cast(e); + return EnumNamesIValueUnion()[index]; +} + +template struct IValueUnionTraits { + static const IValueUnion enum_value = IValueUnion::NONE; +}; + +template<> struct IValueUnionTraits { + static const IValueUnion enum_value = IValueUnion::Int; +}; + +template<> struct IValueUnionTraits { + static const IValueUnion enum_value = IValueUnion::Bool; +}; + +template<> struct IValueUnionTraits { + static const IValueUnion enum_value = IValueUnion::Double; +}; + +template<> struct IValueUnionTraits { + static const IValueUnion enum_value = IValueUnion::ComplexDouble; +}; + +template<> struct IValueUnionTraits { + static const IValueUnion enum_value = IValueUnion::TensorMetadata; +}; + +template<> struct IValueUnionTraits { + static const IValueUnion enum_value = IValueUnion::String; +}; + +template<> struct IValueUnionTraits { + static const IValueUnion enum_value = IValueUnion::List; +}; + +template<> struct IValueUnionTraits { + static const IValueUnion enum_value = IValueUnion::Tuple; +}; + +template<> struct IValueUnionTraits { + static const IValueUnion enum_value = IValueUnion::Dict; +}; + +template<> struct IValueUnionTraits { + static const IValueUnion enum_value = IValueUnion::Object; +}; + +template<> struct IValueUnionTraits { + static const IValueUnion enum_value = IValueUnion::IntList; +}; + +template<> struct IValueUnionTraits { + static const IValueUnion enum_value = IValueUnion::DoubleList; +}; + +template<> struct IValueUnionTraits { + static const IValueUnion enum_value = IValueUnion::BoolList; +}; + +template<> struct IValueUnionTraits { + static const IValueUnion enum_value = IValueUnion::Device; +}; + +template<> struct IValueUnionTraits { + static const IValueUnion enum_value = IValueUnion::EnumValue; +}; + +template<> struct IValueUnionTraits { + static const IValueUnion enum_value = IValueUnion::Function; +}; + +bool VerifyIValueUnion(flatbuffers::Verifier &verifier, const void *obj, IValueUnion type); +bool VerifyIValueUnionVector(flatbuffers::Verifier &verifier, const flatbuffers::Vector> *values, const flatbuffers::Vector *types); + +FLATBUFFERS_MANUALLY_ALIGNED_STRUCT(8) Int FLATBUFFERS_FINAL_CLASS { + private: + int64_t int_val_; + + public: + Int() + : int_val_(0) { + } + Int(int64_t _int_val) + : int_val_(flatbuffers::EndianScalar(_int_val)) { + } + int64_t int_val() const { + return flatbuffers::EndianScalar(int_val_); + } + void mutate_int_val(int64_t _int_val) { + flatbuffers::WriteScalar(&int_val_, _int_val); + } +}; +FLATBUFFERS_STRUCT_END(Int, 8); + +FLATBUFFERS_MANUALLY_ALIGNED_STRUCT(1) Bool FLATBUFFERS_FINAL_CLASS { + private: + uint8_t bool_val_; + + public: + Bool() + : bool_val_(0) { + } + Bool(bool _bool_val) + : bool_val_(flatbuffers::EndianScalar(static_cast(_bool_val))) { + } + bool bool_val() const { + return flatbuffers::EndianScalar(bool_val_) != 0; + } + void mutate_bool_val(bool _bool_val) { + flatbuffers::WriteScalar(&bool_val_, static_cast(_bool_val)); + } +}; +FLATBUFFERS_STRUCT_END(Bool, 1); + +FLATBUFFERS_MANUALLY_ALIGNED_STRUCT(8) Double FLATBUFFERS_FINAL_CLASS { + private: + double double_val_; + + public: + Double() + : double_val_(0) { + } + Double(double _double_val) + : double_val_(flatbuffers::EndianScalar(_double_val)) { + } + double double_val() const { + return flatbuffers::EndianScalar(double_val_); + } + void mutate_double_val(double _double_val) { + flatbuffers::WriteScalar(&double_val_, _double_val); + } +}; +FLATBUFFERS_STRUCT_END(Double, 8); + +FLATBUFFERS_MANUALLY_ALIGNED_STRUCT(8) PerTensorAffineSchema FLATBUFFERS_FINAL_CLASS { + private: + double q_scale_; + int32_t q_zero_point_; + int32_t padding0__; + + public: + PerTensorAffineSchema() + : q_scale_(0), + q_zero_point_(0), + padding0__(0) { + (void)padding0__; + } + PerTensorAffineSchema(double _q_scale, int32_t _q_zero_point) + : q_scale_(flatbuffers::EndianScalar(_q_scale)), + q_zero_point_(flatbuffers::EndianScalar(_q_zero_point)), + padding0__(0) { + (void)padding0__; + } + double q_scale() const { + return flatbuffers::EndianScalar(q_scale_); + } + void mutate_q_scale(double _q_scale) { + flatbuffers::WriteScalar(&q_scale_, _q_scale); + } + int32_t q_zero_point() const { + return flatbuffers::EndianScalar(q_zero_point_); + } + void mutate_q_zero_point(int32_t _q_zero_point) { + flatbuffers::WriteScalar(&q_zero_point_, _q_zero_point); + } +}; +FLATBUFFERS_STRUCT_END(PerTensorAffineSchema, 16); + +FLATBUFFERS_MANUALLY_ALIGNED_STRUCT(8) ComplexDouble FLATBUFFERS_FINAL_CLASS { + private: + double real_; + double imag_; + + public: + ComplexDouble() + : real_(0), + imag_(0) { + } + ComplexDouble(double _real, double _imag) + : real_(flatbuffers::EndianScalar(_real)), + imag_(flatbuffers::EndianScalar(_imag)) { + } + double real() const { + return flatbuffers::EndianScalar(real_); + } + void mutate_real(double _real) { + flatbuffers::WriteScalar(&real_, _real); + } + double imag() const { + return flatbuffers::EndianScalar(imag_); + } + void mutate_imag(double _imag) { + flatbuffers::WriteScalar(&imag_, _imag); + } +}; +FLATBUFFERS_STRUCT_END(ComplexDouble, 16); + +FLATBUFFERS_MANUALLY_ALIGNED_STRUCT(4) Instruction FLATBUFFERS_FINAL_CLASS { + private: + int8_t op_; + int8_t padding0__; + uint16_t n_; + int32_t x_; + + public: + Instruction() + : op_(0), + padding0__(0), + n_(0), + x_(0) { + (void)padding0__; + } + Instruction(int8_t _op, uint16_t _n, int32_t _x) + : op_(flatbuffers::EndianScalar(_op)), + padding0__(0), + n_(flatbuffers::EndianScalar(_n)), + x_(flatbuffers::EndianScalar(_x)) { + (void)padding0__; + } + int8_t op() const { + return flatbuffers::EndianScalar(op_); + } + void mutate_op(int8_t _op) { + flatbuffers::WriteScalar(&op_, _op); + } + uint16_t n() const { + return flatbuffers::EndianScalar(n_); + } + void mutate_n(uint16_t _n) { + flatbuffers::WriteScalar(&n_, _n); + } + int32_t x() const { + return flatbuffers::EndianScalar(x_); + } + void mutate_x(int32_t _x) { + flatbuffers::WriteScalar(&x_, _x); + } +}; +FLATBUFFERS_STRUCT_END(Instruction, 8); + +struct QuantizedSchema FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef QuantizedSchemaBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_QSCHEME = 4, + VT_SCALE = 6, + VT_ZERO_POINT = 8, + VT_SCALES = 10, + VT_ZERO_POINTS = 12, + VT_AXIS = 14 + }; + int8_t qscheme() const { + return GetField(VT_QSCHEME, 0); + } + bool mutate_qscheme(int8_t _qscheme = 0) { + return SetField(VT_QSCHEME, _qscheme, 0); + } + double scale() const { + return GetField(VT_SCALE, 0.0); + } + bool mutate_scale(double _scale = 0.0) { + return SetField(VT_SCALE, _scale, 0.0); + } + int32_t zero_point() const { + return GetField(VT_ZERO_POINT, 0); + } + bool mutate_zero_point(int32_t _zero_point = 0) { + return SetField(VT_ZERO_POINT, _zero_point, 0); + } + const torch::jit::mobile::serialization::TensorMetadata *scales() const { + return GetPointer(VT_SCALES); + } + torch::jit::mobile::serialization::TensorMetadata *mutable_scales() { + return GetPointer(VT_SCALES); + } + const torch::jit::mobile::serialization::TensorMetadata *zero_points() const { + return GetPointer(VT_ZERO_POINTS); + } + torch::jit::mobile::serialization::TensorMetadata *mutable_zero_points() { + return GetPointer(VT_ZERO_POINTS); + } + int32_t axis() const { + return GetField(VT_AXIS, 0); + } + bool mutate_axis(int32_t _axis = 0) { + return SetField(VT_AXIS, _axis, 0); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_QSCHEME) && + VerifyField(verifier, VT_SCALE) && + VerifyField(verifier, VT_ZERO_POINT) && + VerifyOffset(verifier, VT_SCALES) && + verifier.VerifyTable(scales()) && + VerifyOffset(verifier, VT_ZERO_POINTS) && + verifier.VerifyTable(zero_points()) && + VerifyField(verifier, VT_AXIS) && + verifier.EndTable(); + } +}; + +struct QuantizedSchemaBuilder { + typedef QuantizedSchema Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_qscheme(int8_t qscheme) { + fbb_.AddElement(QuantizedSchema::VT_QSCHEME, qscheme, 0); + } + void add_scale(double scale) { + fbb_.AddElement(QuantizedSchema::VT_SCALE, scale, 0.0); + } + void add_zero_point(int32_t zero_point) { + fbb_.AddElement(QuantizedSchema::VT_ZERO_POINT, zero_point, 0); + } + void add_scales(flatbuffers::Offset scales) { + fbb_.AddOffset(QuantizedSchema::VT_SCALES, scales); + } + void add_zero_points(flatbuffers::Offset zero_points) { + fbb_.AddOffset(QuantizedSchema::VT_ZERO_POINTS, zero_points); + } + void add_axis(int32_t axis) { + fbb_.AddElement(QuantizedSchema::VT_AXIS, axis, 0); + } + explicit QuantizedSchemaBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateQuantizedSchema( + flatbuffers::FlatBufferBuilder &_fbb, + int8_t qscheme = 0, + double scale = 0.0, + int32_t zero_point = 0, + flatbuffers::Offset scales = 0, + flatbuffers::Offset zero_points = 0, + int32_t axis = 0) { + QuantizedSchemaBuilder builder_(_fbb); + builder_.add_scale(scale); + builder_.add_axis(axis); + builder_.add_zero_points(zero_points); + builder_.add_scales(scales); + builder_.add_zero_point(zero_point); + builder_.add_qscheme(qscheme); + return builder_.Finish(); +} + +struct TensorMetadata FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef TensorMetadataBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_STORAGE_LOCATION_INDEX = 4, + VT_SCALAR_TYPE = 6, + VT_STORAGE_OFFSET = 8, + VT_SIZES = 10, + VT_STRIDES = 12, + VT_REQUIRES_GRAD = 14, + VT_QUANTIZED_SCHEMA = 16 + }; + uint32_t storage_location_index() const { + return GetField(VT_STORAGE_LOCATION_INDEX, 0); + } + bool mutate_storage_location_index(uint32_t _storage_location_index = 0) { + return SetField(VT_STORAGE_LOCATION_INDEX, _storage_location_index, 0); + } + int8_t scalar_type() const { + return GetField(VT_SCALAR_TYPE, 0); + } + bool mutate_scalar_type(int8_t _scalar_type = 0) { + return SetField(VT_SCALAR_TYPE, _scalar_type, 0); + } + int32_t storage_offset() const { + return GetField(VT_STORAGE_OFFSET, 0); + } + bool mutate_storage_offset(int32_t _storage_offset = 0) { + return SetField(VT_STORAGE_OFFSET, _storage_offset, 0); + } + const flatbuffers::Vector *sizes() const { + return GetPointer *>(VT_SIZES); + } + flatbuffers::Vector *mutable_sizes() { + return GetPointer *>(VT_SIZES); + } + const flatbuffers::Vector *strides() const { + return GetPointer *>(VT_STRIDES); + } + flatbuffers::Vector *mutable_strides() { + return GetPointer *>(VT_STRIDES); + } + bool requires_grad() const { + return GetField(VT_REQUIRES_GRAD, 0) != 0; + } + bool mutate_requires_grad(bool _requires_grad = 0) { + return SetField(VT_REQUIRES_GRAD, static_cast(_requires_grad), 0); + } + const torch::jit::mobile::serialization::QuantizedSchema *quantized_schema() const { + return GetPointer(VT_QUANTIZED_SCHEMA); + } + torch::jit::mobile::serialization::QuantizedSchema *mutable_quantized_schema() { + return GetPointer(VT_QUANTIZED_SCHEMA); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_STORAGE_LOCATION_INDEX) && + VerifyField(verifier, VT_SCALAR_TYPE) && + VerifyField(verifier, VT_STORAGE_OFFSET) && + VerifyOffset(verifier, VT_SIZES) && + verifier.VerifyVector(sizes()) && + VerifyOffset(verifier, VT_STRIDES) && + verifier.VerifyVector(strides()) && + VerifyField(verifier, VT_REQUIRES_GRAD) && + VerifyOffset(verifier, VT_QUANTIZED_SCHEMA) && + verifier.VerifyTable(quantized_schema()) && + verifier.EndTable(); + } +}; + +struct TensorMetadataBuilder { + typedef TensorMetadata Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_storage_location_index(uint32_t storage_location_index) { + fbb_.AddElement(TensorMetadata::VT_STORAGE_LOCATION_INDEX, storage_location_index, 0); + } + void add_scalar_type(int8_t scalar_type) { + fbb_.AddElement(TensorMetadata::VT_SCALAR_TYPE, scalar_type, 0); + } + void add_storage_offset(int32_t storage_offset) { + fbb_.AddElement(TensorMetadata::VT_STORAGE_OFFSET, storage_offset, 0); + } + void add_sizes(flatbuffers::Offset> sizes) { + fbb_.AddOffset(TensorMetadata::VT_SIZES, sizes); + } + void add_strides(flatbuffers::Offset> strides) { + fbb_.AddOffset(TensorMetadata::VT_STRIDES, strides); + } + void add_requires_grad(bool requires_grad) { + fbb_.AddElement(TensorMetadata::VT_REQUIRES_GRAD, static_cast(requires_grad), 0); + } + void add_quantized_schema(flatbuffers::Offset quantized_schema) { + fbb_.AddOffset(TensorMetadata::VT_QUANTIZED_SCHEMA, quantized_schema); + } + explicit TensorMetadataBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateTensorMetadata( + flatbuffers::FlatBufferBuilder &_fbb, + uint32_t storage_location_index = 0, + int8_t scalar_type = 0, + int32_t storage_offset = 0, + flatbuffers::Offset> sizes = 0, + flatbuffers::Offset> strides = 0, + bool requires_grad = false, + flatbuffers::Offset quantized_schema = 0) { + TensorMetadataBuilder builder_(_fbb); + builder_.add_quantized_schema(quantized_schema); + builder_.add_strides(strides); + builder_.add_sizes(sizes); + builder_.add_storage_offset(storage_offset); + builder_.add_storage_location_index(storage_location_index); + builder_.add_requires_grad(requires_grad); + builder_.add_scalar_type(scalar_type); + return builder_.Finish(); +} + +inline flatbuffers::Offset CreateTensorMetadataDirect( + flatbuffers::FlatBufferBuilder &_fbb, + uint32_t storage_location_index = 0, + int8_t scalar_type = 0, + int32_t storage_offset = 0, + const std::vector *sizes = nullptr, + const std::vector *strides = nullptr, + bool requires_grad = false, + flatbuffers::Offset quantized_schema = 0) { + auto sizes__ = sizes ? _fbb.CreateVector(*sizes) : 0; + auto strides__ = strides ? _fbb.CreateVector(*strides) : 0; + return torch::jit::mobile::serialization::CreateTensorMetadata( + _fbb, + storage_location_index, + scalar_type, + storage_offset, + sizes__, + strides__, + requires_grad, + quantized_schema); +} + +struct String FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef StringBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_DATA = 4 + }; + const flatbuffers::String *data() const { + return GetPointer(VT_DATA); + } + flatbuffers::String *mutable_data() { + return GetPointer(VT_DATA); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_DATA) && + verifier.VerifyString(data()) && + verifier.EndTable(); + } +}; + +struct StringBuilder { + typedef String Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_data(flatbuffers::Offset data) { + fbb_.AddOffset(String::VT_DATA, data); + } + explicit StringBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateString( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset data = 0) { + StringBuilder builder_(_fbb); + builder_.add_data(data); + return builder_.Finish(); +} + +inline flatbuffers::Offset CreateStringDirect( + flatbuffers::FlatBufferBuilder &_fbb, + const char *data = nullptr) { + auto data__ = data ? _fbb.CreateString(data) : 0; + return torch::jit::mobile::serialization::CreateString( + _fbb, + data__); +} + +struct Device FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef DeviceBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_STR = 4 + }; + const flatbuffers::String *str() const { + return GetPointer(VT_STR); + } + flatbuffers::String *mutable_str() { + return GetPointer(VT_STR); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_STR) && + verifier.VerifyString(str()) && + verifier.EndTable(); + } +}; + +struct DeviceBuilder { + typedef Device Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_str(flatbuffers::Offset str) { + fbb_.AddOffset(Device::VT_STR, str); + } + explicit DeviceBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateDevice( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset str = 0) { + DeviceBuilder builder_(_fbb); + builder_.add_str(str); + return builder_.Finish(); +} + +inline flatbuffers::Offset CreateDeviceDirect( + flatbuffers::FlatBufferBuilder &_fbb, + const char *str = nullptr) { + auto str__ = str ? _fbb.CreateString(str) : 0; + return torch::jit::mobile::serialization::CreateDevice( + _fbb, + str__); +} + +struct List FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef ListBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_ITEMS = 4, + VT_ANNOTATION_STR = 6 + }; + const flatbuffers::Vector *items() const { + return GetPointer *>(VT_ITEMS); + } + flatbuffers::Vector *mutable_items() { + return GetPointer *>(VT_ITEMS); + } + const flatbuffers::String *annotation_str() const { + return GetPointer(VT_ANNOTATION_STR); + } + flatbuffers::String *mutable_annotation_str() { + return GetPointer(VT_ANNOTATION_STR); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_ITEMS) && + verifier.VerifyVector(items()) && + VerifyOffset(verifier, VT_ANNOTATION_STR) && + verifier.VerifyString(annotation_str()) && + verifier.EndTable(); + } +}; + +struct ListBuilder { + typedef List Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_items(flatbuffers::Offset> items) { + fbb_.AddOffset(List::VT_ITEMS, items); + } + void add_annotation_str(flatbuffers::Offset annotation_str) { + fbb_.AddOffset(List::VT_ANNOTATION_STR, annotation_str); + } + explicit ListBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateList( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset> items = 0, + flatbuffers::Offset annotation_str = 0) { + ListBuilder builder_(_fbb); + builder_.add_annotation_str(annotation_str); + builder_.add_items(items); + return builder_.Finish(); +} + +inline flatbuffers::Offset CreateListDirect( + flatbuffers::FlatBufferBuilder &_fbb, + const std::vector *items = nullptr, + const char *annotation_str = nullptr) { + auto items__ = items ? _fbb.CreateVector(*items) : 0; + auto annotation_str__ = annotation_str ? _fbb.CreateString(annotation_str) : 0; + return torch::jit::mobile::serialization::CreateList( + _fbb, + items__, + annotation_str__); +} + +struct IntList FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef IntListBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_ITEMS = 4 + }; + const flatbuffers::Vector *items() const { + return GetPointer *>(VT_ITEMS); + } + flatbuffers::Vector *mutable_items() { + return GetPointer *>(VT_ITEMS); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_ITEMS) && + verifier.VerifyVector(items()) && + verifier.EndTable(); + } +}; + +struct IntListBuilder { + typedef IntList Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_items(flatbuffers::Offset> items) { + fbb_.AddOffset(IntList::VT_ITEMS, items); + } + explicit IntListBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateIntList( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset> items = 0) { + IntListBuilder builder_(_fbb); + builder_.add_items(items); + return builder_.Finish(); +} + +inline flatbuffers::Offset CreateIntListDirect( + flatbuffers::FlatBufferBuilder &_fbb, + const std::vector *items = nullptr) { + auto items__ = items ? _fbb.CreateVector(*items) : 0; + return torch::jit::mobile::serialization::CreateIntList( + _fbb, + items__); +} + +struct DoubleList FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef DoubleListBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_ITEMS = 4 + }; + const flatbuffers::Vector *items() const { + return GetPointer *>(VT_ITEMS); + } + flatbuffers::Vector *mutable_items() { + return GetPointer *>(VT_ITEMS); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_ITEMS) && + verifier.VerifyVector(items()) && + verifier.EndTable(); + } +}; + +struct DoubleListBuilder { + typedef DoubleList Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_items(flatbuffers::Offset> items) { + fbb_.AddOffset(DoubleList::VT_ITEMS, items); + } + explicit DoubleListBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateDoubleList( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset> items = 0) { + DoubleListBuilder builder_(_fbb); + builder_.add_items(items); + return builder_.Finish(); +} + +inline flatbuffers::Offset CreateDoubleListDirect( + flatbuffers::FlatBufferBuilder &_fbb, + const std::vector *items = nullptr) { + auto items__ = items ? _fbb.CreateVector(*items) : 0; + return torch::jit::mobile::serialization::CreateDoubleList( + _fbb, + items__); +} + +struct BoolList FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef BoolListBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_ITEMS = 4 + }; + const flatbuffers::Vector *items() const { + return GetPointer *>(VT_ITEMS); + } + flatbuffers::Vector *mutable_items() { + return GetPointer *>(VT_ITEMS); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_ITEMS) && + verifier.VerifyVector(items()) && + verifier.EndTable(); + } +}; + +struct BoolListBuilder { + typedef BoolList Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_items(flatbuffers::Offset> items) { + fbb_.AddOffset(BoolList::VT_ITEMS, items); + } + explicit BoolListBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateBoolList( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset> items = 0) { + BoolListBuilder builder_(_fbb); + builder_.add_items(items); + return builder_.Finish(); +} + +inline flatbuffers::Offset CreateBoolListDirect( + flatbuffers::FlatBufferBuilder &_fbb, + const std::vector *items = nullptr) { + auto items__ = items ? _fbb.CreateVector(*items) : 0; + return torch::jit::mobile::serialization::CreateBoolList( + _fbb, + items__); +} + +struct Tuple FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef TupleBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_ITEMS = 4 + }; + const flatbuffers::Vector *items() const { + return GetPointer *>(VT_ITEMS); + } + flatbuffers::Vector *mutable_items() { + return GetPointer *>(VT_ITEMS); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_ITEMS) && + verifier.VerifyVector(items()) && + verifier.EndTable(); + } +}; + +struct TupleBuilder { + typedef Tuple Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_items(flatbuffers::Offset> items) { + fbb_.AddOffset(Tuple::VT_ITEMS, items); + } + explicit TupleBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateTuple( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset> items = 0) { + TupleBuilder builder_(_fbb); + builder_.add_items(items); + return builder_.Finish(); +} + +inline flatbuffers::Offset CreateTupleDirect( + flatbuffers::FlatBufferBuilder &_fbb, + const std::vector *items = nullptr) { + auto items__ = items ? _fbb.CreateVector(*items) : 0; + return torch::jit::mobile::serialization::CreateTuple( + _fbb, + items__); +} + +struct Dict FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef DictBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_KEYS = 4, + VT_VALUES = 6, + VT_ANNOTATION_STR = 8 + }; + const flatbuffers::Vector *keys() const { + return GetPointer *>(VT_KEYS); + } + flatbuffers::Vector *mutable_keys() { + return GetPointer *>(VT_KEYS); + } + const flatbuffers::Vector *values() const { + return GetPointer *>(VT_VALUES); + } + flatbuffers::Vector *mutable_values() { + return GetPointer *>(VT_VALUES); + } + const flatbuffers::String *annotation_str() const { + return GetPointer(VT_ANNOTATION_STR); + } + flatbuffers::String *mutable_annotation_str() { + return GetPointer(VT_ANNOTATION_STR); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_KEYS) && + verifier.VerifyVector(keys()) && + VerifyOffset(verifier, VT_VALUES) && + verifier.VerifyVector(values()) && + VerifyOffset(verifier, VT_ANNOTATION_STR) && + verifier.VerifyString(annotation_str()) && + verifier.EndTable(); + } +}; + +struct DictBuilder { + typedef Dict Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_keys(flatbuffers::Offset> keys) { + fbb_.AddOffset(Dict::VT_KEYS, keys); + } + void add_values(flatbuffers::Offset> values) { + fbb_.AddOffset(Dict::VT_VALUES, values); + } + void add_annotation_str(flatbuffers::Offset annotation_str) { + fbb_.AddOffset(Dict::VT_ANNOTATION_STR, annotation_str); + } + explicit DictBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateDict( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset> keys = 0, + flatbuffers::Offset> values = 0, + flatbuffers::Offset annotation_str = 0) { + DictBuilder builder_(_fbb); + builder_.add_annotation_str(annotation_str); + builder_.add_values(values); + builder_.add_keys(keys); + return builder_.Finish(); +} + +inline flatbuffers::Offset CreateDictDirect( + flatbuffers::FlatBufferBuilder &_fbb, + const std::vector *keys = nullptr, + const std::vector *values = nullptr, + const char *annotation_str = nullptr) { + auto keys__ = keys ? _fbb.CreateVector(*keys) : 0; + auto values__ = values ? _fbb.CreateVector(*values) : 0; + auto annotation_str__ = annotation_str ? _fbb.CreateString(annotation_str) : 0; + return torch::jit::mobile::serialization::CreateDict( + _fbb, + keys__, + values__, + annotation_str__); +} + +struct ObjectType FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef ObjectTypeBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_TYPE_NAME = 4, + VT_TYPE = 6, + VT_ATTR_NAMES = 8 + }; + const flatbuffers::String *type_name() const { + return GetPointer(VT_TYPE_NAME); + } + flatbuffers::String *mutable_type_name() { + return GetPointer(VT_TYPE_NAME); + } + torch::jit::mobile::serialization::TypeType type() const { + return static_cast(GetField(VT_TYPE, 0)); + } + bool mutate_type(torch::jit::mobile::serialization::TypeType _type = static_cast(0)) { + return SetField(VT_TYPE, static_cast(_type), 0); + } + const flatbuffers::Vector> *attr_names() const { + return GetPointer> *>(VT_ATTR_NAMES); + } + flatbuffers::Vector> *mutable_attr_names() { + return GetPointer> *>(VT_ATTR_NAMES); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_TYPE_NAME) && + verifier.VerifyString(type_name()) && + VerifyField(verifier, VT_TYPE) && + VerifyOffset(verifier, VT_ATTR_NAMES) && + verifier.VerifyVector(attr_names()) && + verifier.VerifyVectorOfStrings(attr_names()) && + verifier.EndTable(); + } +}; + +struct ObjectTypeBuilder { + typedef ObjectType Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_type_name(flatbuffers::Offset type_name) { + fbb_.AddOffset(ObjectType::VT_TYPE_NAME, type_name); + } + void add_type(torch::jit::mobile::serialization::TypeType type) { + fbb_.AddElement(ObjectType::VT_TYPE, static_cast(type), 0); + } + void add_attr_names(flatbuffers::Offset>> attr_names) { + fbb_.AddOffset(ObjectType::VT_ATTR_NAMES, attr_names); + } + explicit ObjectTypeBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateObjectType( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset type_name = 0, + torch::jit::mobile::serialization::TypeType type = torch::jit::mobile::serialization::TypeType::UNSET, + flatbuffers::Offset>> attr_names = 0) { + ObjectTypeBuilder builder_(_fbb); + builder_.add_attr_names(attr_names); + builder_.add_type_name(type_name); + builder_.add_type(type); + return builder_.Finish(); +} + +inline flatbuffers::Offset CreateObjectTypeDirect( + flatbuffers::FlatBufferBuilder &_fbb, + const char *type_name = nullptr, + torch::jit::mobile::serialization::TypeType type = torch::jit::mobile::serialization::TypeType::UNSET, + const std::vector> *attr_names = nullptr) { + auto type_name__ = type_name ? _fbb.CreateString(type_name) : 0; + auto attr_names__ = attr_names ? _fbb.CreateVector>(*attr_names) : 0; + return torch::jit::mobile::serialization::CreateObjectType( + _fbb, + type_name__, + type, + attr_names__); +} + +struct Object FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef ObjectBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_TYPE_INDEX = 4, + VT_STATE = 6, + VT_ATTRS = 8, + VT_SETSTATE_FUNC = 10 + }; + uint32_t type_index() const { + return GetField(VT_TYPE_INDEX, 0); + } + bool mutate_type_index(uint32_t _type_index = 0) { + return SetField(VT_TYPE_INDEX, _type_index, 0); + } + uint32_t state() const { + return GetField(VT_STATE, 0); + } + bool mutate_state(uint32_t _state = 0) { + return SetField(VT_STATE, _state, 0); + } + const flatbuffers::Vector *attrs() const { + return GetPointer *>(VT_ATTRS); + } + flatbuffers::Vector *mutable_attrs() { + return GetPointer *>(VT_ATTRS); + } + uint32_t setstate_func() const { + return GetField(VT_SETSTATE_FUNC, 0); + } + bool mutate_setstate_func(uint32_t _setstate_func = 0) { + return SetField(VT_SETSTATE_FUNC, _setstate_func, 0); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_TYPE_INDEX) && + VerifyField(verifier, VT_STATE) && + VerifyOffset(verifier, VT_ATTRS) && + verifier.VerifyVector(attrs()) && + VerifyField(verifier, VT_SETSTATE_FUNC) && + verifier.EndTable(); + } +}; + +struct ObjectBuilder { + typedef Object Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_type_index(uint32_t type_index) { + fbb_.AddElement(Object::VT_TYPE_INDEX, type_index, 0); + } + void add_state(uint32_t state) { + fbb_.AddElement(Object::VT_STATE, state, 0); + } + void add_attrs(flatbuffers::Offset> attrs) { + fbb_.AddOffset(Object::VT_ATTRS, attrs); + } + void add_setstate_func(uint32_t setstate_func) { + fbb_.AddElement(Object::VT_SETSTATE_FUNC, setstate_func, 0); + } + explicit ObjectBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateObject( + flatbuffers::FlatBufferBuilder &_fbb, + uint32_t type_index = 0, + uint32_t state = 0, + flatbuffers::Offset> attrs = 0, + uint32_t setstate_func = 0) { + ObjectBuilder builder_(_fbb); + builder_.add_setstate_func(setstate_func); + builder_.add_attrs(attrs); + builder_.add_state(state); + builder_.add_type_index(type_index); + return builder_.Finish(); +} + +inline flatbuffers::Offset CreateObjectDirect( + flatbuffers::FlatBufferBuilder &_fbb, + uint32_t type_index = 0, + uint32_t state = 0, + const std::vector *attrs = nullptr, + uint32_t setstate_func = 0) { + auto attrs__ = attrs ? _fbb.CreateVector(*attrs) : 0; + return torch::jit::mobile::serialization::CreateObject( + _fbb, + type_index, + state, + attrs__, + setstate_func); +} + +struct EnumValue FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef EnumValueBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_TYPE_NAME = 4, + VT_VALUE = 6 + }; + const flatbuffers::String *type_name() const { + return GetPointer(VT_TYPE_NAME); + } + flatbuffers::String *mutable_type_name() { + return GetPointer(VT_TYPE_NAME); + } + uint32_t value() const { + return GetField(VT_VALUE, 0); + } + bool mutate_value(uint32_t _value = 0) { + return SetField(VT_VALUE, _value, 0); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_TYPE_NAME) && + verifier.VerifyString(type_name()) && + VerifyField(verifier, VT_VALUE) && + verifier.EndTable(); + } +}; + +struct EnumValueBuilder { + typedef EnumValue Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_type_name(flatbuffers::Offset type_name) { + fbb_.AddOffset(EnumValue::VT_TYPE_NAME, type_name); + } + void add_value(uint32_t value) { + fbb_.AddElement(EnumValue::VT_VALUE, value, 0); + } + explicit EnumValueBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateEnumValue( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset type_name = 0, + uint32_t value = 0) { + EnumValueBuilder builder_(_fbb); + builder_.add_value(value); + builder_.add_type_name(type_name); + return builder_.Finish(); +} + +inline flatbuffers::Offset CreateEnumValueDirect( + flatbuffers::FlatBufferBuilder &_fbb, + const char *type_name = nullptr, + uint32_t value = 0) { + auto type_name__ = type_name ? _fbb.CreateString(type_name) : 0; + return torch::jit::mobile::serialization::CreateEnumValue( + _fbb, + type_name__, + value); +} + +struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef OperatorBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_NAME = 4, + VT_OVERLOAD_NAME = 6, + VT_NUM_ARGS_SERIALIZED = 8 + }; + const flatbuffers::String *name() const { + return GetPointer(VT_NAME); + } + flatbuffers::String *mutable_name() { + return GetPointer(VT_NAME); + } + const flatbuffers::String *overload_name() const { + return GetPointer(VT_OVERLOAD_NAME); + } + flatbuffers::String *mutable_overload_name() { + return GetPointer(VT_OVERLOAD_NAME); + } + int32_t num_args_serialized() const { + return GetField(VT_NUM_ARGS_SERIALIZED, -1); + } + bool mutate_num_args_serialized(int32_t _num_args_serialized = -1) { + return SetField(VT_NUM_ARGS_SERIALIZED, _num_args_serialized, -1); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_NAME) && + verifier.VerifyString(name()) && + VerifyOffset(verifier, VT_OVERLOAD_NAME) && + verifier.VerifyString(overload_name()) && + VerifyField(verifier, VT_NUM_ARGS_SERIALIZED) && + verifier.EndTable(); + } +}; + +struct OperatorBuilder { + typedef Operator Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_name(flatbuffers::Offset name) { + fbb_.AddOffset(Operator::VT_NAME, name); + } + void add_overload_name(flatbuffers::Offset overload_name) { + fbb_.AddOffset(Operator::VT_OVERLOAD_NAME, overload_name); + } + void add_num_args_serialized(int32_t num_args_serialized) { + fbb_.AddElement(Operator::VT_NUM_ARGS_SERIALIZED, num_args_serialized, -1); + } + explicit OperatorBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateOperator( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset name = 0, + flatbuffers::Offset overload_name = 0, + int32_t num_args_serialized = -1) { + OperatorBuilder builder_(_fbb); + builder_.add_num_args_serialized(num_args_serialized); + builder_.add_overload_name(overload_name); + builder_.add_name(name); + return builder_.Finish(); +} + +inline flatbuffers::Offset CreateOperatorDirect( + flatbuffers::FlatBufferBuilder &_fbb, + const char *name = nullptr, + const char *overload_name = nullptr, + int32_t num_args_serialized = -1) { + auto name__ = name ? _fbb.CreateString(name) : 0; + auto overload_name__ = overload_name ? _fbb.CreateString(overload_name) : 0; + return torch::jit::mobile::serialization::CreateOperator( + _fbb, + name__, + overload_name__, + num_args_serialized); +} + +struct Arg FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef ArgBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_NAME = 4, + VT_TYPE = 6, + VT_DEFAULT_VALUE = 8 + }; + const flatbuffers::String *name() const { + return GetPointer(VT_NAME); + } + flatbuffers::String *mutable_name() { + return GetPointer(VT_NAME); + } + const flatbuffers::String *type() const { + return GetPointer(VT_TYPE); + } + flatbuffers::String *mutable_type() { + return GetPointer(VT_TYPE); + } + uint32_t default_value() const { + return GetField(VT_DEFAULT_VALUE, 0); + } + bool mutate_default_value(uint32_t _default_value = 0) { + return SetField(VT_DEFAULT_VALUE, _default_value, 0); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_NAME) && + verifier.VerifyString(name()) && + VerifyOffset(verifier, VT_TYPE) && + verifier.VerifyString(type()) && + VerifyField(verifier, VT_DEFAULT_VALUE) && + verifier.EndTable(); + } +}; + +struct ArgBuilder { + typedef Arg Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_name(flatbuffers::Offset name) { + fbb_.AddOffset(Arg::VT_NAME, name); + } + void add_type(flatbuffers::Offset type) { + fbb_.AddOffset(Arg::VT_TYPE, type); + } + void add_default_value(uint32_t default_value) { + fbb_.AddElement(Arg::VT_DEFAULT_VALUE, default_value, 0); + } + explicit ArgBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateArg( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset name = 0, + flatbuffers::Offset type = 0, + uint32_t default_value = 0) { + ArgBuilder builder_(_fbb); + builder_.add_default_value(default_value); + builder_.add_type(type); + builder_.add_name(name); + return builder_.Finish(); +} + +inline flatbuffers::Offset CreateArgDirect( + flatbuffers::FlatBufferBuilder &_fbb, + const char *name = nullptr, + const char *type = nullptr, + uint32_t default_value = 0) { + auto name__ = name ? _fbb.CreateString(name) : 0; + auto type__ = type ? _fbb.CreateString(type) : 0; + return torch::jit::mobile::serialization::CreateArg( + _fbb, + name__, + type__, + default_value); +} + +struct Schema FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef SchemaBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_ARGUMENTS = 4, + VT_RETURNS = 6 + }; + const flatbuffers::Vector> *arguments() const { + return GetPointer> *>(VT_ARGUMENTS); + } + flatbuffers::Vector> *mutable_arguments() { + return GetPointer> *>(VT_ARGUMENTS); + } + const flatbuffers::Vector> *returns() const { + return GetPointer> *>(VT_RETURNS); + } + flatbuffers::Vector> *mutable_returns() { + return GetPointer> *>(VT_RETURNS); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_ARGUMENTS) && + verifier.VerifyVector(arguments()) && + verifier.VerifyVectorOfTables(arguments()) && + VerifyOffset(verifier, VT_RETURNS) && + verifier.VerifyVector(returns()) && + verifier.VerifyVectorOfTables(returns()) && + verifier.EndTable(); + } +}; + +struct SchemaBuilder { + typedef Schema Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_arguments(flatbuffers::Offset>> arguments) { + fbb_.AddOffset(Schema::VT_ARGUMENTS, arguments); + } + void add_returns(flatbuffers::Offset>> returns) { + fbb_.AddOffset(Schema::VT_RETURNS, returns); + } + explicit SchemaBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateSchema( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset>> arguments = 0, + flatbuffers::Offset>> returns = 0) { + SchemaBuilder builder_(_fbb); + builder_.add_returns(returns); + builder_.add_arguments(arguments); + return builder_.Finish(); +} + +inline flatbuffers::Offset CreateSchemaDirect( + flatbuffers::FlatBufferBuilder &_fbb, + const std::vector> *arguments = nullptr, + const std::vector> *returns = nullptr) { + auto arguments__ = arguments ? _fbb.CreateVector>(*arguments) : 0; + auto returns__ = returns ? _fbb.CreateVector>(*returns) : 0; + return torch::jit::mobile::serialization::CreateSchema( + _fbb, + arguments__, + returns__); +} + +struct DebugInfo FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef DebugInfoBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_DEBUG_HANDLE = 4 + }; + const flatbuffers::Vector *debug_handle() const { + return GetPointer *>(VT_DEBUG_HANDLE); + } + flatbuffers::Vector *mutable_debug_handle() { + return GetPointer *>(VT_DEBUG_HANDLE); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_DEBUG_HANDLE) && + verifier.VerifyVector(debug_handle()) && + verifier.EndTable(); + } +}; + +struct DebugInfoBuilder { + typedef DebugInfo Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_debug_handle(flatbuffers::Offset> debug_handle) { + fbb_.AddOffset(DebugInfo::VT_DEBUG_HANDLE, debug_handle); + } + explicit DebugInfoBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateDebugInfo( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset> debug_handle = 0) { + DebugInfoBuilder builder_(_fbb); + builder_.add_debug_handle(debug_handle); + return builder_.Finish(); +} + +inline flatbuffers::Offset CreateDebugInfoDirect( + flatbuffers::FlatBufferBuilder &_fbb, + const std::vector *debug_handle = nullptr) { + auto debug_handle__ = debug_handle ? _fbb.CreateVector(*debug_handle) : 0; + return torch::jit::mobile::serialization::CreateDebugInfo( + _fbb, + debug_handle__); +} + +struct Function FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef FunctionBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_QN = 4, + VT_INSTRUCTIONS = 6, + VT_OPERATORS = 8, + VT_CONSTANTS = 10, + VT_TYPE_ANNOTATIONS = 12, + VT_REGISTER_SIZE = 14, + VT_SCHEMA = 16, + VT_DEBUG_INFO = 18, + VT_CLASS_TYPE = 20 + }; + const flatbuffers::String *qn() const { + return GetPointer(VT_QN); + } + flatbuffers::String *mutable_qn() { + return GetPointer(VT_QN); + } + const flatbuffers::Vector *instructions() const { + return GetPointer *>(VT_INSTRUCTIONS); + } + flatbuffers::Vector *mutable_instructions() { + return GetPointer *>(VT_INSTRUCTIONS); + } + const flatbuffers::Vector> *operators() const { + return GetPointer> *>(VT_OPERATORS); + } + flatbuffers::Vector> *mutable_operators() { + return GetPointer> *>(VT_OPERATORS); + } + const flatbuffers::Vector *constants() const { + return GetPointer *>(VT_CONSTANTS); + } + flatbuffers::Vector *mutable_constants() { + return GetPointer *>(VT_CONSTANTS); + } + const flatbuffers::Vector> *type_annotations() const { + return GetPointer> *>(VT_TYPE_ANNOTATIONS); + } + flatbuffers::Vector> *mutable_type_annotations() { + return GetPointer> *>(VT_TYPE_ANNOTATIONS); + } + int32_t register_size() const { + return GetField(VT_REGISTER_SIZE, 0); + } + bool mutate_register_size(int32_t _register_size = 0) { + return SetField(VT_REGISTER_SIZE, _register_size, 0); + } + const torch::jit::mobile::serialization::Schema *schema() const { + return GetPointer(VT_SCHEMA); + } + torch::jit::mobile::serialization::Schema *mutable_schema() { + return GetPointer(VT_SCHEMA); + } + const torch::jit::mobile::serialization::DebugInfo *debug_info() const { + return GetPointer(VT_DEBUG_INFO); + } + torch::jit::mobile::serialization::DebugInfo *mutable_debug_info() { + return GetPointer(VT_DEBUG_INFO); + } + uint32_t class_type() const { + return GetField(VT_CLASS_TYPE, 0); + } + bool mutate_class_type(uint32_t _class_type = 0) { + return SetField(VT_CLASS_TYPE, _class_type, 0); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_QN) && + verifier.VerifyString(qn()) && + VerifyOffset(verifier, VT_INSTRUCTIONS) && + verifier.VerifyVector(instructions()) && + VerifyOffset(verifier, VT_OPERATORS) && + verifier.VerifyVector(operators()) && + verifier.VerifyVectorOfTables(operators()) && + VerifyOffset(verifier, VT_CONSTANTS) && + verifier.VerifyVector(constants()) && + VerifyOffset(verifier, VT_TYPE_ANNOTATIONS) && + verifier.VerifyVector(type_annotations()) && + verifier.VerifyVectorOfStrings(type_annotations()) && + VerifyField(verifier, VT_REGISTER_SIZE) && + VerifyOffset(verifier, VT_SCHEMA) && + verifier.VerifyTable(schema()) && + VerifyOffset(verifier, VT_DEBUG_INFO) && + verifier.VerifyTable(debug_info()) && + VerifyField(verifier, VT_CLASS_TYPE) && + verifier.EndTable(); + } +}; + +struct FunctionBuilder { + typedef Function Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_qn(flatbuffers::Offset qn) { + fbb_.AddOffset(Function::VT_QN, qn); + } + void add_instructions(flatbuffers::Offset> instructions) { + fbb_.AddOffset(Function::VT_INSTRUCTIONS, instructions); + } + void add_operators(flatbuffers::Offset>> operators) { + fbb_.AddOffset(Function::VT_OPERATORS, operators); + } + void add_constants(flatbuffers::Offset> constants) { + fbb_.AddOffset(Function::VT_CONSTANTS, constants); + } + void add_type_annotations(flatbuffers::Offset>> type_annotations) { + fbb_.AddOffset(Function::VT_TYPE_ANNOTATIONS, type_annotations); + } + void add_register_size(int32_t register_size) { + fbb_.AddElement(Function::VT_REGISTER_SIZE, register_size, 0); + } + void add_schema(flatbuffers::Offset schema) { + fbb_.AddOffset(Function::VT_SCHEMA, schema); + } + void add_debug_info(flatbuffers::Offset debug_info) { + fbb_.AddOffset(Function::VT_DEBUG_INFO, debug_info); + } + void add_class_type(uint32_t class_type) { + fbb_.AddElement(Function::VT_CLASS_TYPE, class_type, 0); + } + explicit FunctionBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateFunction( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset qn = 0, + flatbuffers::Offset> instructions = 0, + flatbuffers::Offset>> operators = 0, + flatbuffers::Offset> constants = 0, + flatbuffers::Offset>> type_annotations = 0, + int32_t register_size = 0, + flatbuffers::Offset schema = 0, + flatbuffers::Offset debug_info = 0, + uint32_t class_type = 0) { + FunctionBuilder builder_(_fbb); + builder_.add_class_type(class_type); + builder_.add_debug_info(debug_info); + builder_.add_schema(schema); + builder_.add_register_size(register_size); + builder_.add_type_annotations(type_annotations); + builder_.add_constants(constants); + builder_.add_operators(operators); + builder_.add_instructions(instructions); + builder_.add_qn(qn); + return builder_.Finish(); +} + +inline flatbuffers::Offset CreateFunctionDirect( + flatbuffers::FlatBufferBuilder &_fbb, + const char *qn = nullptr, + const std::vector *instructions = nullptr, + const std::vector> *operators = nullptr, + const std::vector *constants = nullptr, + const std::vector> *type_annotations = nullptr, + int32_t register_size = 0, + flatbuffers::Offset schema = 0, + flatbuffers::Offset debug_info = 0, + uint32_t class_type = 0) { + auto qn__ = qn ? _fbb.CreateString(qn) : 0; + auto instructions__ = instructions ? _fbb.CreateVectorOfStructs(*instructions) : 0; + auto operators__ = operators ? _fbb.CreateVector>(*operators) : 0; + auto constants__ = constants ? _fbb.CreateVector(*constants) : 0; + auto type_annotations__ = type_annotations ? _fbb.CreateVector>(*type_annotations) : 0; + return torch::jit::mobile::serialization::CreateFunction( + _fbb, + qn__, + instructions__, + operators__, + constants__, + type_annotations__, + register_size, + schema, + debug_info, + class_type); +} + +struct StorageData FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef StorageDataBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_DATA = 4 + }; + const flatbuffers::Vector *data() const { + return GetPointer *>(VT_DATA); + } + flatbuffers::Vector *mutable_data() { + return GetPointer *>(VT_DATA); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_DATA) && + verifier.VerifyVector(data()) && + verifier.EndTable(); + } +}; + +struct StorageDataBuilder { + typedef StorageData Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_data(flatbuffers::Offset> data) { + fbb_.AddOffset(StorageData::VT_DATA, data); + } + explicit StorageDataBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateStorageData( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset> data = 0) { + StorageDataBuilder builder_(_fbb); + builder_.add_data(data); + return builder_.Finish(); +} + +inline flatbuffers::Offset CreateStorageDataDirect( + flatbuffers::FlatBufferBuilder &_fbb, + const std::vector *data = nullptr) { + if (data) { _fbb.ForceVectorAlignment(data->size(), sizeof(uint8_t), 16); } + auto data__ = data ? _fbb.CreateVector(*data) : 0; + return torch::jit::mobile::serialization::CreateStorageData( + _fbb, + data__); +} + +struct IValue FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef IValueBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_VAL_TYPE = 4, + VT_VAL = 6 + }; + torch::jit::mobile::serialization::IValueUnion val_type() const { + return static_cast(GetField(VT_VAL_TYPE, 0)); + } + const void *val() const { + return GetPointer(VT_VAL); + } + template const T *val_as() const; + const torch::jit::mobile::serialization::Int *val_as_Int() const { + return val_type() == torch::jit::mobile::serialization::IValueUnion::Int ? static_cast(val()) : nullptr; + } + const torch::jit::mobile::serialization::Bool *val_as_Bool() const { + return val_type() == torch::jit::mobile::serialization::IValueUnion::Bool ? static_cast(val()) : nullptr; + } + const torch::jit::mobile::serialization::Double *val_as_Double() const { + return val_type() == torch::jit::mobile::serialization::IValueUnion::Double ? static_cast(val()) : nullptr; + } + const torch::jit::mobile::serialization::ComplexDouble *val_as_ComplexDouble() const { + return val_type() == torch::jit::mobile::serialization::IValueUnion::ComplexDouble ? static_cast(val()) : nullptr; + } + const torch::jit::mobile::serialization::TensorMetadata *val_as_TensorMetadata() const { + return val_type() == torch::jit::mobile::serialization::IValueUnion::TensorMetadata ? static_cast(val()) : nullptr; + } + const torch::jit::mobile::serialization::String *val_as_String() const { + return val_type() == torch::jit::mobile::serialization::IValueUnion::String ? static_cast(val()) : nullptr; + } + const torch::jit::mobile::serialization::List *val_as_List() const { + return val_type() == torch::jit::mobile::serialization::IValueUnion::List ? static_cast(val()) : nullptr; + } + const torch::jit::mobile::serialization::Tuple *val_as_Tuple() const { + return val_type() == torch::jit::mobile::serialization::IValueUnion::Tuple ? static_cast(val()) : nullptr; + } + const torch::jit::mobile::serialization::Dict *val_as_Dict() const { + return val_type() == torch::jit::mobile::serialization::IValueUnion::Dict ? static_cast(val()) : nullptr; + } + const torch::jit::mobile::serialization::Object *val_as_Object() const { + return val_type() == torch::jit::mobile::serialization::IValueUnion::Object ? static_cast(val()) : nullptr; + } + const torch::jit::mobile::serialization::IntList *val_as_IntList() const { + return val_type() == torch::jit::mobile::serialization::IValueUnion::IntList ? static_cast(val()) : nullptr; + } + const torch::jit::mobile::serialization::DoubleList *val_as_DoubleList() const { + return val_type() == torch::jit::mobile::serialization::IValueUnion::DoubleList ? static_cast(val()) : nullptr; + } + const torch::jit::mobile::serialization::BoolList *val_as_BoolList() const { + return val_type() == torch::jit::mobile::serialization::IValueUnion::BoolList ? static_cast(val()) : nullptr; + } + const torch::jit::mobile::serialization::Device *val_as_Device() const { + return val_type() == torch::jit::mobile::serialization::IValueUnion::Device ? static_cast(val()) : nullptr; + } + const torch::jit::mobile::serialization::EnumValue *val_as_EnumValue() const { + return val_type() == torch::jit::mobile::serialization::IValueUnion::EnumValue ? static_cast(val()) : nullptr; + } + const torch::jit::mobile::serialization::Function *val_as_Function() const { + return val_type() == torch::jit::mobile::serialization::IValueUnion::Function ? static_cast(val()) : nullptr; + } + void *mutable_val() { + return GetPointer(VT_VAL); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_VAL_TYPE) && + VerifyOffset(verifier, VT_VAL) && + VerifyIValueUnion(verifier, val(), val_type()) && + verifier.EndTable(); + } +}; + +template<> inline const torch::jit::mobile::serialization::Int *IValue::val_as() const { + return val_as_Int(); +} + +template<> inline const torch::jit::mobile::serialization::Bool *IValue::val_as() const { + return val_as_Bool(); +} + +template<> inline const torch::jit::mobile::serialization::Double *IValue::val_as() const { + return val_as_Double(); +} + +template<> inline const torch::jit::mobile::serialization::ComplexDouble *IValue::val_as() const { + return val_as_ComplexDouble(); +} + +template<> inline const torch::jit::mobile::serialization::TensorMetadata *IValue::val_as() const { + return val_as_TensorMetadata(); +} + +template<> inline const torch::jit::mobile::serialization::String *IValue::val_as() const { + return val_as_String(); +} + +template<> inline const torch::jit::mobile::serialization::List *IValue::val_as() const { + return val_as_List(); +} + +template<> inline const torch::jit::mobile::serialization::Tuple *IValue::val_as() const { + return val_as_Tuple(); +} + +template<> inline const torch::jit::mobile::serialization::Dict *IValue::val_as() const { + return val_as_Dict(); +} + +template<> inline const torch::jit::mobile::serialization::Object *IValue::val_as() const { + return val_as_Object(); +} + +template<> inline const torch::jit::mobile::serialization::IntList *IValue::val_as() const { + return val_as_IntList(); +} + +template<> inline const torch::jit::mobile::serialization::DoubleList *IValue::val_as() const { + return val_as_DoubleList(); +} + +template<> inline const torch::jit::mobile::serialization::BoolList *IValue::val_as() const { + return val_as_BoolList(); +} + +template<> inline const torch::jit::mobile::serialization::Device *IValue::val_as() const { + return val_as_Device(); +} + +template<> inline const torch::jit::mobile::serialization::EnumValue *IValue::val_as() const { + return val_as_EnumValue(); +} + +template<> inline const torch::jit::mobile::serialization::Function *IValue::val_as() const { + return val_as_Function(); +} + +struct IValueBuilder { + typedef IValue Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_val_type(torch::jit::mobile::serialization::IValueUnion val_type) { + fbb_.AddElement(IValue::VT_VAL_TYPE, static_cast(val_type), 0); + } + void add_val(flatbuffers::Offset val) { + fbb_.AddOffset(IValue::VT_VAL, val); + } + explicit IValueBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateIValue( + flatbuffers::FlatBufferBuilder &_fbb, + torch::jit::mobile::serialization::IValueUnion val_type = torch::jit::mobile::serialization::IValueUnion::NONE, + flatbuffers::Offset val = 0) { + IValueBuilder builder_(_fbb); + builder_.add_val(val); + builder_.add_val_type(val_type); + return builder_.Finish(); +} + +struct ExtraFile FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef ExtraFileBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_NAME = 4, + VT_CONTENT = 6 + }; + const flatbuffers::String *name() const { + return GetPointer(VT_NAME); + } + flatbuffers::String *mutable_name() { + return GetPointer(VT_NAME); + } + const flatbuffers::String *content() const { + return GetPointer(VT_CONTENT); + } + flatbuffers::String *mutable_content() { + return GetPointer(VT_CONTENT); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_NAME) && + verifier.VerifyString(name()) && + VerifyOffset(verifier, VT_CONTENT) && + verifier.VerifyString(content()) && + verifier.EndTable(); + } +}; + +struct ExtraFileBuilder { + typedef ExtraFile Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_name(flatbuffers::Offset name) { + fbb_.AddOffset(ExtraFile::VT_NAME, name); + } + void add_content(flatbuffers::Offset content) { + fbb_.AddOffset(ExtraFile::VT_CONTENT, content); + } + explicit ExtraFileBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateExtraFile( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset name = 0, + flatbuffers::Offset content = 0) { + ExtraFileBuilder builder_(_fbb); + builder_.add_content(content); + builder_.add_name(name); + return builder_.Finish(); +} + +inline flatbuffers::Offset CreateExtraFileDirect( + flatbuffers::FlatBufferBuilder &_fbb, + const char *name = nullptr, + const char *content = nullptr) { + auto name__ = name ? _fbb.CreateString(name) : 0; + auto content__ = content ? _fbb.CreateString(content) : 0; + return torch::jit::mobile::serialization::CreateExtraFile( + _fbb, + name__, + content__); +} + +struct Module FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef ModuleBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_VERSION = 4, + VT_EXTRA_FILES = 6, + VT_METHODS = 8, + VT_STATE_OBJ = 10, + VT_IVALUES = 12, + VT_STORAGE_DATA_SIZE = 14, + VT_STORAGE_DATA = 16, + VT_OBJECT_TYPES = 18 + }; + int32_t version() const { + return GetField(VT_VERSION, 0); + } + bool mutate_version(int32_t _version = 0) { + return SetField(VT_VERSION, _version, 0); + } + const flatbuffers::Vector> *extra_files() const { + return GetPointer> *>(VT_EXTRA_FILES); + } + flatbuffers::Vector> *mutable_extra_files() { + return GetPointer> *>(VT_EXTRA_FILES); + } + const flatbuffers::Vector *methods() const { + return GetPointer *>(VT_METHODS); + } + flatbuffers::Vector *mutable_methods() { + return GetPointer *>(VT_METHODS); + } + uint32_t state_obj() const { + return GetField(VT_STATE_OBJ, 0); + } + bool mutate_state_obj(uint32_t _state_obj = 0) { + return SetField(VT_STATE_OBJ, _state_obj, 0); + } + const flatbuffers::Vector> *ivalues() const { + return GetPointer> *>(VT_IVALUES); + } + flatbuffers::Vector> *mutable_ivalues() { + return GetPointer> *>(VT_IVALUES); + } + int32_t storage_data_size() const { + return GetField(VT_STORAGE_DATA_SIZE, 0); + } + bool mutate_storage_data_size(int32_t _storage_data_size = 0) { + return SetField(VT_STORAGE_DATA_SIZE, _storage_data_size, 0); + } + const flatbuffers::Vector> *storage_data() const { + return GetPointer> *>(VT_STORAGE_DATA); + } + flatbuffers::Vector> *mutable_storage_data() { + return GetPointer> *>(VT_STORAGE_DATA); + } + const flatbuffers::Vector> *object_types() const { + return GetPointer> *>(VT_OBJECT_TYPES); + } + flatbuffers::Vector> *mutable_object_types() { + return GetPointer> *>(VT_OBJECT_TYPES); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_VERSION) && + VerifyOffset(verifier, VT_EXTRA_FILES) && + verifier.VerifyVector(extra_files()) && + verifier.VerifyVectorOfTables(extra_files()) && + VerifyOffset(verifier, VT_METHODS) && + verifier.VerifyVector(methods()) && + VerifyField(verifier, VT_STATE_OBJ) && + VerifyOffset(verifier, VT_IVALUES) && + verifier.VerifyVector(ivalues()) && + verifier.VerifyVectorOfTables(ivalues()) && + VerifyField(verifier, VT_STORAGE_DATA_SIZE) && + VerifyOffset(verifier, VT_STORAGE_DATA) && + verifier.VerifyVector(storage_data()) && + verifier.VerifyVectorOfTables(storage_data()) && + VerifyOffset(verifier, VT_OBJECT_TYPES) && + verifier.VerifyVector(object_types()) && + verifier.VerifyVectorOfTables(object_types()) && + verifier.EndTable(); + } +}; + +struct ModuleBuilder { + typedef Module Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_version(int32_t version) { + fbb_.AddElement(Module::VT_VERSION, version, 0); + } + void add_extra_files(flatbuffers::Offset>> extra_files) { + fbb_.AddOffset(Module::VT_EXTRA_FILES, extra_files); + } + void add_methods(flatbuffers::Offset> methods) { + fbb_.AddOffset(Module::VT_METHODS, methods); + } + void add_state_obj(uint32_t state_obj) { + fbb_.AddElement(Module::VT_STATE_OBJ, state_obj, 0); + } + void add_ivalues(flatbuffers::Offset>> ivalues) { + fbb_.AddOffset(Module::VT_IVALUES, ivalues); + } + void add_storage_data_size(int32_t storage_data_size) { + fbb_.AddElement(Module::VT_STORAGE_DATA_SIZE, storage_data_size, 0); + } + void add_storage_data(flatbuffers::Offset>> storage_data) { + fbb_.AddOffset(Module::VT_STORAGE_DATA, storage_data); + } + void add_object_types(flatbuffers::Offset>> object_types) { + fbb_.AddOffset(Module::VT_OBJECT_TYPES, object_types); + } + explicit ModuleBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateModule( + flatbuffers::FlatBufferBuilder &_fbb, + int32_t version = 0, + flatbuffers::Offset>> extra_files = 0, + flatbuffers::Offset> methods = 0, + uint32_t state_obj = 0, + flatbuffers::Offset>> ivalues = 0, + int32_t storage_data_size = 0, + flatbuffers::Offset>> storage_data = 0, + flatbuffers::Offset>> object_types = 0) { + ModuleBuilder builder_(_fbb); + builder_.add_object_types(object_types); + builder_.add_storage_data(storage_data); + builder_.add_storage_data_size(storage_data_size); + builder_.add_ivalues(ivalues); + builder_.add_state_obj(state_obj); + builder_.add_methods(methods); + builder_.add_extra_files(extra_files); + builder_.add_version(version); + return builder_.Finish(); +} + +inline flatbuffers::Offset CreateModuleDirect( + flatbuffers::FlatBufferBuilder &_fbb, + int32_t version = 0, + const std::vector> *extra_files = nullptr, + const std::vector *methods = nullptr, + uint32_t state_obj = 0, + const std::vector> *ivalues = nullptr, + int32_t storage_data_size = 0, + const std::vector> *storage_data = nullptr, + const std::vector> *object_types = nullptr) { + auto extra_files__ = extra_files ? _fbb.CreateVector>(*extra_files) : 0; + auto methods__ = methods ? _fbb.CreateVector(*methods) : 0; + auto ivalues__ = ivalues ? _fbb.CreateVector>(*ivalues) : 0; + auto storage_data__ = storage_data ? _fbb.CreateVector>(*storage_data) : 0; + auto object_types__ = object_types ? _fbb.CreateVector>(*object_types) : 0; + return torch::jit::mobile::serialization::CreateModule( + _fbb, + version, + extra_files__, + methods__, + state_obj, + ivalues__, + storage_data_size, + storage_data__, + object_types__); +} + +inline bool VerifyIValueUnion(flatbuffers::Verifier &verifier, const void *obj, IValueUnion type) { + switch (type) { + case IValueUnion::NONE: { + return true; + } + case IValueUnion::Int: { + return verifier.Verify(static_cast(obj), 0); + } + case IValueUnion::Bool: { + return verifier.Verify(static_cast(obj), 0); + } + case IValueUnion::Double: { + return verifier.Verify(static_cast(obj), 0); + } + case IValueUnion::ComplexDouble: { + return verifier.Verify(static_cast(obj), 0); + } + case IValueUnion::TensorMetadata: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case IValueUnion::String: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case IValueUnion::List: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case IValueUnion::Tuple: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case IValueUnion::Dict: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case IValueUnion::Object: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case IValueUnion::IntList: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case IValueUnion::DoubleList: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case IValueUnion::BoolList: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case IValueUnion::Device: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case IValueUnion::EnumValue: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case IValueUnion::Function: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + default: return true; + } +} + +inline bool VerifyIValueUnionVector(flatbuffers::Verifier &verifier, const flatbuffers::Vector> *values, const flatbuffers::Vector *types) { + if (!values || !types) return !values && !types; + if (values->size() != types->size()) return false; + for (flatbuffers::uoffset_t i = 0; i < values->size(); ++i) { + if (!VerifyIValueUnion( + verifier, values->Get(i), types->GetEnum(i))) { + return false; + } + } + return true; +} + +inline const torch::jit::mobile::serialization::Module *GetModule(const void *buf) { + return flatbuffers::GetRoot(buf); +} + +inline const torch::jit::mobile::serialization::Module *GetSizePrefixedModule(const void *buf) { + return flatbuffers::GetSizePrefixedRoot(buf); +} + +inline Module *GetMutableModule(void *buf) { + return flatbuffers::GetMutableRoot(buf); +} + +inline torch::jit::mobile::serialization::Module *GetMutableSizePrefixedModule(void *buf) { + return flatbuffers::GetMutableSizePrefixedRoot(buf); +} + +inline bool VerifyModuleBuffer( + flatbuffers::Verifier &verifier) { + return verifier.VerifyBuffer(nullptr); +} + +inline bool VerifySizePrefixedModuleBuffer( + flatbuffers::Verifier &verifier) { + return verifier.VerifySizePrefixedBuffer(nullptr); +} + +inline void FinishModuleBuffer( + flatbuffers::FlatBufferBuilder &fbb, + flatbuffers::Offset root) { + fbb.Finish(root); +} + +inline void FinishSizePrefixedModuleBuffer( + flatbuffers::FlatBufferBuilder &fbb, + flatbuffers::Offset root) { + fbb.FinishSizePrefixed(root); +} + +} // namespace serialization +} // namespace mobile +} // namespace jit +} // namespace torch + +#endif // FLATBUFFERS_GENERATED_MOBILEBYTECODE_TORCH_JIT_MOBILE_SERIALIZATION_H_ +// @generated