Split module.cpp and export.cpp to support saving on mobile (#29881)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/29881

Breaking these into separate files allows us to have three different builds:
- Mobile inference-only.
- Mobile with module saving.
- Server with module saving and other export functions like ONNX.

This is accomplished purely by selecting which .cpp files to compile,
without setting any preprocessor flags.
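
For concreteness, a minimal sketch of the third configuration, assuming a
server build that compiles export_module.cpp and module_save.cpp (the model
paths are illustrative; the link-time-failure behavior for inference-only
builds is an inference from this diff, not something the summary states):

#include <torch/script.h>

int main() {
  // Round-trip a TorchScript module through load and save. Module::save is
  // only defined when module_save.cpp (and transitively export_module.cpp)
  // is part of the build; an inference-only mobile build that called it
  // would presumably fail at link time rather than via the old C10_MOBILE
  // runtime error.
  torch::jit::script::Module m = torch::jit::load("model.pt");
  m.save("model_resaved.pt");
  return 0;
}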

Test Plan: CI. Local mobile+saving build.

Reviewed By: smessmer

Differential Revision: D18509296

fbshipit-source-id: 9438273bac4624df5c7f035b2bacb901cce43053
Author: David Reiss
Date: 2019-11-20 10:44:36 -08:00
Committed by: Facebook Github Bot
Parent: 72bc7bf37b
Commit: fbcb88e8b3

6 changed files with 326 additions and 320 deletions

caffe2/CMakeLists.txt

@@ -443,6 +443,7 @@ if (NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
${TORCH_SRC_DIR}/csrc/jit/script/edit_distance.cpp
${TORCH_SRC_DIR}/csrc/jit/script/logging.cpp
${TORCH_SRC_DIR}/csrc/jit/script/module.cpp
${TORCH_SRC_DIR}/csrc/jit/script/module_save.cpp
${TORCH_SRC_DIR}/csrc/jit/script/object.cpp
${TORCH_SRC_DIR}/csrc/jit/script/jit_exception.cpp
${TORCH_SRC_DIR}/csrc/jit/source_range_serialization.cpp
@@ -485,6 +486,7 @@ if (NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
list(APPEND TORCH_SRCS
${TORCH_SRC_DIR}/csrc/api/src/jit.cpp
${TORCH_SRC_DIR}/csrc/jit/export.cpp
${TORCH_SRC_DIR}/csrc/jit/export_module.cpp
${TORCH_SRC_DIR}/csrc/jit/import_legacy.cpp
${TORCH_SRC_DIR}/csrc/jit/netdef_converter.cpp
${TORCH_SRC_DIR}/csrc/jit/fuser/cpu/fused_kernel.cpp

tools/build_variables.py

@@ -82,6 +82,7 @@ libtorch_sources = [
"torch/csrc/jit/constants.cpp",
"torch/csrc/jit/node_hashing.cpp",
"torch/csrc/jit/export.cpp",
"torch/csrc/jit/export_module.cpp",
"torch/csrc/jit/pass_manager.cpp",
"torch/csrc/jit/pickler.cpp",
"torch/csrc/jit/unpickler.cpp",
@@ -164,6 +165,7 @@ libtorch_sources = [
"torch/csrc/jit/hooks_for_testing.cpp",
"torch/csrc/jit/script/builtin_functions.cpp",
"torch/csrc/jit/script/module.cpp",
"torch/csrc/jit/script/module_save.cpp",
"torch/csrc/jit/script/object.cpp",
"torch/csrc/jit/tracer.cpp",
"torch/csrc/jit/fuser/kernel_cache.cpp",

torch/csrc/jit/export.cpp

@@ -1,6 +1,3 @@
#include <google/protobuf/util/json_util.h>
#include <google/protobuf/util/type_resolver_util.h>
#include <torch/csrc/autograd/symbolic.h>
#include <torch/csrc/jit/export.h>
#include <torch/csrc/onnx/onnx.h>
@@ -9,16 +6,8 @@
#include <c10/util/Exception.h>
#include <torch/csrc/jit/import_export_helpers.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/python_print.h>
#include <torch/csrc/jit/pickle.h>
#include <torch/csrc/jit/source_range_serialization.h>
#include <torch/csrc/jit/instruction.h>
#include <torch/csrc/jit/passes/inliner.h>
#include <caffe2/core/types.h>
#include <caffe2/proto/caffe2_pb.h>
#include <caffe2/proto/torch_pb.h>
#include <caffe2/serialize/inline_container.h>
#include <onnx/onnx_pb.h>
#include <ATen/ATen.h>
@@ -33,21 +22,11 @@
namespace torch {
namespace jit {
char const * toString(OpCode op);
namespace {
namespace onnx_torch = ::torch::onnx;
namespace onnx = ::ONNX_NAMESPACE;
namespace {
ExportModuleExtraFilesHook& GetExtraFilesHook() {
static ExportModuleExtraFilesHook func = nullptr;
return func;
};
}
class ScriptModuleSerializer;
std::string getNodeStackTraceString(const Node* n) {
return n->sourceRange().str();
}
@@ -544,231 +523,6 @@ void GraphEncoder::EncodeTensor(
}
}
class ScriptModuleSerializer {
public:
explicit ScriptModuleSerializer(const std::string& filename)
: writer_(filename) {}
explicit ScriptModuleSerializer(
const std::function<size_t(const void *, size_t)>& writer_func)
: writer_(writer_func) {}
void serialize(
const script::Module& module,
const script::ExtraFilesMap& extra_files,
bool bytecode_format) {
C10_LOG_API_USAGE_ONCE("torch.script.save");
writeExtraFiles(module, extra_files);
// Serialize the model object
writeArchive("data", module._ivalue());
// Then we serialize all code info.
writeCode(module.type());
// The tensor constants from the code are written to a separate archive
// so loading the code does not depend on loading the data
std::vector<IValue> ivalue_constants(
constant_table_.begin(), constant_table_.end());
writeArchive("constants", c10::ivalue::Tuple::create(ivalue_constants));
if (bytecode_format) {
writeByteCode(module);
}
}
private:
void writeArchive(const std::string& archive_name, const IValue& value) {
std::vector<char> data;
// Vector to capture the run-time class types encountered while pickling the IValues
std::vector<c10::ClassTypePtr> memorizedClassTypes;
Pickler data_pickle(
[&](const char* buf, size_t size) {
data.insert(data.end(), buf, buf + size);
},
nullptr,
&memorizedClassTypes);
data_pickle.protocol();
data_pickle.pushIValue(value);
data_pickle.stop();
size_t i = 0;
std::string prefix = archive_name + "/";
for (const auto& td : data_pickle.tensorData()) {
std::string fname = prefix + std::to_string(i++);
writer_.writeRecord(fname, td.data(), td.sizeInBytes());
}
std::string fname = archive_name + ".pkl";
writer_.writeRecord(fname, data.data(), data.size());
// serialize all the captured run-time class types
for (const c10::ClassTypePtr& wroteType : memorizedClassTypes) {
convertNamedType(wroteType);
}
}
void writeExtraFiles(
const script::Module& module,
const script::ExtraFilesMap& extra_files) {
// Write out extra files.
for (const auto& kv : extra_files) {
const std::string key = "extra/" + kv.first;
writer_.writeRecord(key, kv.second.data(), kv.second.size());
}
auto hook = GetExtraFilesHook();
if (hook) {
script::ExtraFilesMap hook_files = hook(module);
for (const auto& kv : hook_files) {
const std::string key = "extra/" + kv.first;
writer_.writeRecord(key, kv.second.data(), kv.second.size());
}
}
}
void writeCode(const at::NamedTypePtr& root_type) {
class_deps_.push_back(root_type);
for (size_t i = 0; i < class_deps_.size(); ++i) {
// note: convertNamedType may extend class_deps_, so re-checking
// .size() is necessary
convertNamedType(class_deps_[i]);
}
// Mapping of filename => src. We need this because multiple classes may go
// in the same file (e.g. foo.bar.Baz and foo.bar.Qux)
for (auto& item : file_streams_) {
const std::string filename = qualifierToArchivePath(item.key(), "code/");
std::string src = item.value().str();
// Only compress these records if they're not tiny.
// The CPU cost of generating zip data structures and compressing isn't
// well-spent for very small records.
static constexpr size_t kMinToCompress = 200;
writer_.writeRecord(
filename, src.c_str(), src.size(),
src.size() > kMinToCompress /*compress*/);
// Write out the debug information
std::string debugFilename = filename + ".debug_pkl";
SourceRangePickler source_range_pickler;
auto range_data =
source_range_pickler.pickle(item.value().ranges());
writer_.writeRecord(
debugFilename,
range_data.data(),
range_data.size(),
range_data.size() > kMinToCompress /*compress*/);
}
}
void writeByteCode(const script::Module& module) {
auto methods = module.get_methods();
std::vector<c10::IValue> elements;
for (const auto& method : methods) {
const auto& func = method.function();
auto graph = func.graph()->copy();
Inline(*graph);
torch::jit::Code code(graph);
// Make a copy of opnames. Some of them may be changed for mobile later.
std::vector<c10::OperatorName> opnames;
for (size_t i = 0; i < code.instructions().size(); ++i) {
Instruction ins = code.instructions()[i];
if (ins.op == OP) {
auto node = code.instructions_source()[i];
opnames.emplace_back(node->schema().operator_name());
}
}
// instructions
std::vector<IValue> inss;
for (size_t i = 0; i < code.instructions().size(); ++i) {
Instruction ins = code.instructions()[i];
TORCH_CHECK(isOpSupportedInMobile(ins.op), toString(ins.op),
" is not supported in mobile module.");
if (ins.op == OP) {
if (opnames[ins.X].name == "prim::ListConstruct" ||
opnames[ins.X].name == "prim::TupleConstruct" ||
opnames[ins.X].name == "prim::TupleUnpack" ||
opnames[ins.X].name == "aten::format") {
auto node = code.instructions_source()[i];
ins.op = OPN;
if (opnames[ins.X].name == "prim::TupleUnpack") {
ins.N = node->outputs().size();
} else {
ins.N = node->inputs().size();
}
if (opnames[ins.X].name == "prim::ListConstruct") {
ListTypePtr lt = node->output()->type()->expect<ListType>();
if (lt->getElementType() == IntType::get()) {
opnames[ins.X].overload_name = "int";
} else if (lt->getElementType() == FloatType::get()) {
opnames[ins.X].overload_name = "float";
} else if (lt->getElementType() == BoolType::get()) {
opnames[ins.X].overload_name = "bool";
} else if (lt->getElementType()->isSubtypeOf(TensorType::get())) {
opnames[ins.X].overload_name = "Tensor";
} else {
opnames[ins.X].overload_name = "generic";
}
} else if (opnames[ins.X].name == "prim::TupleConstruct" &&
node->output()->type()->expect<TupleType>()->name().has_value()) {
AT_WARN("Named tuple is serialized as un-named tuple.");
}
}
}
std::vector<IValue> insv{toString(ins.op), ins.X, ins.N};
inss.emplace_back(c10::ivalue::Tuple::create(std::move(insv)));
}
auto instructions = c10::ivalue::Tuple::create(std::move(inss));
auto named_ins = c10::ivalue::Tuple::create({"instructions", instructions});
// operators
std::vector<IValue> opss;
for (const auto& opname : opnames) {
opss.emplace_back(c10::ivalue::Tuple::create({opname.name, opname.overload_name}));
}
auto operators = c10::ivalue::Tuple::create(std::move(opss));
auto named_ops = c10::ivalue::Tuple::create({"operators", operators});
// constants
auto constants = c10::ivalue::Tuple::create(code.constant_table());
auto named_consts = c10::ivalue::Tuple::create({"constants", constants});
// since the register location is embedded into the bytecode, pass the register size
auto named_regsize = c10::ivalue::Tuple::create({"register_size",
static_cast<int>(code.register_size())});
auto element = c10::ivalue::Tuple::create({named_ins, named_ops, named_consts, named_regsize});
elements.push_back(c10::ivalue::Tuple::create({func.qualname().qualifiedName(), element}));
}
auto telements = c10::ivalue::Tuple::create(std::move(elements));
writeArchive("bytecode", telements);
}
void convertNamedType(const c10::NamedTypePtr& class_type) {
if (converted_types_.count(class_type)) {
return;
}
converted_types_.insert(class_type);
std::string qualifier = class_type->name()->prefix();
PythonPrint* pp = file_streams_.find(qualifier);
if (!pp) {
pp = &file_streams_.insert(
qualifier,
PythonPrint(
constant_table_, class_deps_, /*enforce_importable=*/true));
pp->LEGACY_printOpVersion();
}
pp->printNamedType(class_type);
}
caffe2::serialize::PyTorchStreamWriter writer_;
std::vector<at::Tensor> constant_table_;
std::unordered_set<c10::NamedTypePtr> converted_types_;
std::vector<c10::NamedTypePtr> class_deps_;
// qualifier, e.g. '__torch__.Bar' -> PythonPrint for the file that will be
// created
OrderedDict<std::string, PythonPrint> file_streams_;
bool bytecode_format_;
};
// Pretty printing for ONNX
constexpr char indent_char = ' ';
constexpr size_t indent_multiplier = 2;
@@ -946,10 +700,6 @@ std::string prettyPrint(const onnx::ModelProto& model) {
} // namespace
void SetExportModuleExtraFilesHook(ExportModuleExtraFilesHook hook) {
GetExtraFilesHook() = hook;
}
std::string pretty_print_onnx(
const std::shared_ptr<Graph>& graph,
const std::map<std::string, at::Tensor>& initializers,
@@ -1001,37 +751,5 @@ std::tuple<std::string, RawDataExportMap> export_onnx(
graph_encoder.get_raw_data_export_map());
}
void ExportModule(
const script::Module& module,
std::ostream& out,
const script::ExtraFilesMap& extra_files,
bool bytecode_format) {
ScriptModuleSerializer serializer(
[&](const void* buf, size_t nbytes) -> size_t {
out.write(static_cast<const char *>(buf), nbytes);
return !out ? 0 : nbytes;
});
serializer.serialize(module, extra_files, bytecode_format);
}
void ExportModule(
const script::Module& module,
const std::string& filename,
const script::ExtraFilesMap& extra_files,
bool bytecode_format) {
ScriptModuleSerializer serializer(filename);
serializer.serialize(module, extra_files, bytecode_format);
}
void ExportModule(
const script::Module& module,
const std::function<size_t(const void*, size_t)>& writer_func,
const script::ExtraFilesMap& extra_files,
bool bytecode_format) {
ScriptModuleSerializer serializer(writer_func);
serializer.serialize(module, extra_files, bytecode_format);
}
} // namespace jit
} // namespace torch
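
SetExportModuleExtraFilesHook moves to export_module.cpp unchanged (next
file), so callers keep the same API. A hedged sketch of installing a hook;
the key and payload are hypothetical, and per writeExtraFiles the record
would land under "extra/" in the output zip:

#include <torch/csrc/jit/export.h>

// Attach the same extra file to every module saved by this process.
void install_metadata_hook() {
  torch::jit::SetExportModuleExtraFilesHook(
      [](const torch::jit::script::Module& module)
          -> torch::jit::script::ExtraFilesMap {
        // Hypothetical payload; written to the archive as "extra/build_info.txt".
        return {{"build_info.txt", "exported by my_app"}};
      });
}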

torch/csrc/jit/export_module.cpp (new file)

@@ -0,0 +1,291 @@
#include <torch/csrc/jit/export.h>
#include <c10/util/Exception.h>
#include <torch/csrc/jit/import_export_helpers.h>
#include <torch/csrc/jit/passes/python_print.h>
#include <torch/csrc/jit/pickle.h>
#include <torch/csrc/jit/source_range_serialization.h>
#include <torch/csrc/jit/instruction.h>
#include <torch/csrc/jit/passes/inliner.h>
#include <caffe2/serialize/inline_container.h>
#include <ATen/ATen.h>
#include <string>
#include <vector>
namespace torch {
namespace jit {
char const * toString(OpCode op);
namespace {
ExportModuleExtraFilesHook& GetExtraFilesHook() {
static ExportModuleExtraFilesHook func = nullptr;
return func;
};
}
void SetExportModuleExtraFilesHook(ExportModuleExtraFilesHook hook) {
GetExtraFilesHook() = hook;
}
class ScriptModuleSerializer {
public:
explicit ScriptModuleSerializer(const std::string& filename)
: writer_(filename) {}
explicit ScriptModuleSerializer(
const std::function<size_t(const void *, size_t)>& writer_func)
: writer_(writer_func) {}
void serialize(
const script::Module& module,
const script::ExtraFilesMap& extra_files,
bool bytecode_format) {
C10_LOG_API_USAGE_ONCE("torch.script.save");
writeExtraFiles(module, extra_files);
// Serialize the model object
writeArchive("data", module._ivalue());
// Then we serialize all code info.
writeCode(module.type());
// The tensor constants from the code are written to a separate archive
// so loading the code does not depend on loading the data
std::vector<IValue> ivalue_constants(
constant_table_.begin(), constant_table_.end());
writeArchive("constants", c10::ivalue::Tuple::create(ivalue_constants));
if (bytecode_format) {
writeByteCode(module);
}
}
private:
void writeArchive(const std::string& archive_name, const IValue& value) {
std::vector<char> data;
// Vector to capture the run-time class types encountered while pickling the IValues
std::vector<c10::ClassTypePtr> memorizedClassTypes;
Pickler data_pickle(
[&](const char* buf, size_t size) {
data.insert(data.end(), buf, buf + size);
},
nullptr,
&memorizedClassTypes);
data_pickle.protocol();
data_pickle.pushIValue(value);
data_pickle.stop();
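// Tensors are pickled out-of-band: each tensor becomes its own record named
// "<archive_name>/<index>", while the pickle bytes go to "<archive_name>.pkl".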
size_t i = 0;
std::string prefix = archive_name + "/";
for (const auto& td : data_pickle.tensorData()) {
std::string fname = prefix + std::to_string(i++);
writer_.writeRecord(fname, td.data(), td.sizeInBytes());
}
std::string fname = archive_name + ".pkl";
writer_.writeRecord(fname, data.data(), data.size());
// serialize all the captured run-time class types
for (const c10::ClassTypePtr& wroteType : memorizedClassTypes) {
convertNamedType(wroteType);
}
}
void writeExtraFiles(
const script::Module& module,
const script::ExtraFilesMap& extra_files) {
// Write out extra files.
for (const auto& kv : extra_files) {
const std::string key = "extra/" + kv.first;
writer_.writeRecord(key, kv.second.data(), kv.second.size());
}
auto hook = GetExtraFilesHook();
if (hook) {
script::ExtraFilesMap hook_files = hook(module);
for (const auto& kv : hook_files) {
const std::string key = "extra/" + kv.first;
writer_.writeRecord(key, kv.second.data(), kv.second.size());
}
}
}
void writeCode(const at::NamedTypePtr& root_type) {
class_deps_.push_back(root_type);
for (size_t i = 0; i < class_deps_.size(); ++i) {
// note: convertNamedType may extend class_deps_, so re-checking
// .size() is necessary
convertNamedType(class_deps_[i]);
}
// Mapping of filename => src. We need this because multiple classes may go
// in the same file (e.g. foo.bar.Baz and foo.bar.Qux)
for (auto& item : file_streams_) {
const std::string filename = qualifierToArchivePath(item.key(), "code/");
std::string src = item.value().str();
// Only compress these records if they're not tiny.
// The CPU cost of generating zip data structures and compressing isn't
// well-spent for very small records.
static constexpr size_t kMinToCompress = 200;
writer_.writeRecord(
filename, src.c_str(), src.size(),
src.size() > kMinToCompress /*compress*/);
// Write out the debug information
std::string debugFilename = filename + ".debug_pkl";
SourceRangePickler source_range_pickler;
auto range_data =
source_range_pickler.pickle(item.value().ranges());
writer_.writeRecord(
debugFilename,
range_data.data(),
range_data.size(),
range_data.size() > kMinToCompress /*compress*/);
}
}
void writeByteCode(const script::Module& module) {
auto methods = module.get_methods();
std::vector<c10::IValue> elements;
for (const auto& method : methods) {
const auto& func = method.function();
auto graph = func.graph()->copy();
Inline(*graph);
torch::jit::Code code(graph);
// Make a copy of opnames. Some of them may be changed for mobile later.
std::vector<c10::OperatorName> opnames;
for (size_t i = 0; i < code.instructions().size(); ++i) {
Instruction ins = code.instructions()[i];
if (ins.op == OP) {
auto node = code.instructions_source()[i];
opnames.emplace_back(node->schema().operator_name());
}
}
// instructions
std::vector<IValue> inss;
for (size_t i = 0; i < code.instructions().size(); ++i) {
Instruction ins = code.instructions()[i];
TORCH_CHECK(isOpSupportedInMobile(ins.op), toString(ins.op),
" is not supported in mobile module.");
if (ins.op == OP) {
if (opnames[ins.X].name == "prim::ListConstruct" ||
opnames[ins.X].name == "prim::TupleConstruct" ||
opnames[ins.X].name == "prim::TupleUnpack" ||
opnames[ins.X].name == "aten::format") {
auto node = code.instructions_source()[i];
ins.op = OPN;
if (opnames[ins.X].name == "prim::TupleUnpack") {
ins.N = node->outputs().size();
} else {
ins.N = node->inputs().size();
}
if (opnames[ins.X].name == "prim::ListConstruct") {
ListTypePtr lt = node->output()->type()->expect<ListType>();
if (lt->getElementType() == IntType::get()) {
opnames[ins.X].overload_name = "int";
} else if (lt->getElementType() == FloatType::get()) {
opnames[ins.X].overload_name = "float";
} else if (lt->getElementType() == BoolType::get()) {
opnames[ins.X].overload_name = "bool";
} else if (lt->getElementType()->isSubtypeOf(TensorType::get())) {
opnames[ins.X].overload_name = "Tensor";
} else {
opnames[ins.X].overload_name = "generic";
}
} else if (opnames[ins.X].name == "prim::TupleConstruct" &&
node->output()->type()->expect<TupleType>()->name().has_value()) {
AT_WARN("Named tuple is serialized as un-named tuple.");
}
}
}
std::vector<IValue> insv{toString(ins.op), ins.X, ins.N};
inss.emplace_back(c10::ivalue::Tuple::create(std::move(insv)));
}
auto instructions = c10::ivalue::Tuple::create(std::move(inss));
auto named_ins = c10::ivalue::Tuple::create({"instructions", instructions});
// operators
std::vector<IValue> opss;
for (const auto& opname : opnames) {
opss.emplace_back(c10::ivalue::Tuple::create({opname.name, opname.overload_name}));
}
auto operators = c10::ivalue::Tuple::create(std::move(opss));
auto named_ops = c10::ivalue::Tuple::create({"operators", operators});
// constants
auto constants = c10::ivalue::Tuple::create(code.constant_table());
auto named_consts = c10::ivalue::Tuple::create({"constants", constants});
// since the register location is embedded into the bytecode, pass the register size
auto named_regsize = c10::ivalue::Tuple::create({"register_size",
static_cast<int>(code.register_size())});
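// Each method serializes as
// (qualified_name, (instructions, operators, constants, register_size)).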
auto element = c10::ivalue::Tuple::create({named_ins, named_ops, named_consts, named_regsize});
elements.push_back(c10::ivalue::Tuple::create({func.qualname().qualifiedName(), element}));
}
auto telements = c10::ivalue::Tuple::create(std::move(elements));
writeArchive("bytecode", telements);
}
void convertNamedType(const c10::NamedTypePtr& class_type) {
if (converted_types_.count(class_type)) {
return;
}
converted_types_.insert(class_type);
std::string qualifier = class_type->name()->prefix();
PythonPrint* pp = file_streams_.find(qualifier);
if (!pp) {
pp = &file_streams_.insert(
qualifier,
PythonPrint(
constant_table_, class_deps_, /*enforce_importable=*/true));
pp->LEGACY_printOpVersion();
}
pp->printNamedType(class_type);
}
caffe2::serialize::PyTorchStreamWriter writer_;
std::vector<at::Tensor> constant_table_;
std::unordered_set<c10::NamedTypePtr> converted_types_;
std::vector<c10::NamedTypePtr> class_deps_;
// qualifier, e.g. '__torch__.Bar' -> PythonPrint for the file that will be
// created
OrderedDict<std::string, PythonPrint> file_streams_;
bool bytecode_format_;
};
void ExportModule(
const script::Module& module,
std::ostream& out,
const script::ExtraFilesMap& extra_files,
bool bytecode_format) {
ScriptModuleSerializer serializer(
[&](const void* buf, size_t nbytes) -> size_t {
out.write(static_cast<const char *>(buf), nbytes);
return !out ? 0 : nbytes;
});
serializer.serialize(module, extra_files, bytecode_format);
}
void ExportModule(
const script::Module& module,
const std::string& filename,
const script::ExtraFilesMap& extra_files,
bool bytecode_format) {
ScriptModuleSerializer serializer(filename);
serializer.serialize(module, extra_files, bytecode_format);
}
void ExportModule(
const script::Module& module,
const std::function<size_t(const void*, size_t)>& writer_func,
const script::ExtraFilesMap& extra_files,
bool bytecode_format) {
ScriptModuleSerializer serializer(writer_func);
serializer.serialize(module, extra_files, bytecode_format);
}
} // namespace jit
} // namespace torch
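
The std::function overload of ExportModule shown above also enables
in-memory serialization; a minimal sketch (serialize_to_buffer is a
hypothetical helper, and the explicit extra_files/bytecode_format arguments
assume exactly the signatures defined in this file):

#include <torch/csrc/jit/export.h>

#include <cstddef>
#include <vector>

// Stream the serialized archive into a byte buffer instead of a file.
std::vector<char> serialize_to_buffer(const torch::jit::script::Module& m) {
  std::vector<char> buf;
  torch::jit::ExportModule(
      m,
      [&](const void* data, size_t nbytes) -> size_t {
        const char* p = static_cast<const char*>(data);
        buf.insert(buf.end(), p, p + nbytes);
        return nbytes; // returning anything less signals a failed write
      },
      /*extra_files=*/{},
      /*bytecode_format=*/false);
  return buf;
}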

torch/csrc/jit/script/module.cpp

@@ -1,7 +1,6 @@
#include <torch/csrc/jit/script/module.h>
#include <c10/util/Exception.h>
#include <torch/csrc/autograd/generated/variable_factories.h>
#include <torch/csrc/jit/export.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/operator.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
@@ -75,43 +74,6 @@ void Module::to(at::Device device, bool non_blocking) {
to_impl(device, /*dtype=*/c10::nullopt, non_blocking);
}
void Module::save(std::ostream& out, const ExtraFilesMap& extra_files) const {
#ifndef C10_MOBILE
ExportModule(*this, out, extra_files, false);
#else
AT_ERROR("Saving module is not supported on mobile.");
#endif
}
void Module::save(const std::string& filename, const ExtraFilesMap& extra_files)
const {
#ifndef C10_MOBILE
ExportModule(*this, filename, extra_files, false);
#else
AT_ERROR("Saving module is not supported on mobile.");
#endif
}
void Module::_save_for_mobile(
std::ostream& out,
const ExtraFilesMap& extra_files) const {
#ifndef C10_MOBILE
ExportModule(*this, out, extra_files, true);
#else
AT_ERROR("Saving module is not supported on mobile.");
#endif
}
void Module::_save_for_mobile(
const std::string& filename,
const ExtraFilesMap& extra_files) const {
#ifndef C10_MOBILE
ExportModule(*this, filename, extra_files, true);
#else
AT_ERROR("Saving module is not supported on mobile.");
#endif
}
void module_state_to(
autograd::Variable variable,
const c10::optional<at::Device>& device,

torch/csrc/jit/script/module_save.cpp (new file)

@@ -0,0 +1,31 @@
#include <torch/csrc/jit/script/module.h>
#include <torch/csrc/jit/export.h>
namespace torch {
namespace jit {
namespace script {
void Module::save(std::ostream& out, const ExtraFilesMap& extra_files) const {
ExportModule(*this, out, extra_files, false);
}
void Module::save(const std::string& filename, const ExtraFilesMap& extra_files)
const {
ExportModule(*this, filename, extra_files, false);
}
void Module::_save_for_mobile(
std::ostream& out,
const ExtraFilesMap& extra_files) const {
ExportModule(*this, out, extra_files, true);
}
void Module::_save_for_mobile(
const std::string& filename,
const ExtraFilesMap& extra_files) const {
ExportModule(*this, filename, extra_files, true);
}
} // namespace script
} // namespace jit
} // namespace torch
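
A hedged usage sketch of the two entry points defined above, assuming the
header declares default ExtraFilesMap arguments (the definitions here omit
them); file names are illustrative:

#include <torch/script.h>

void save_models(const torch::jit::script::Module& m) {
  // Standard TorchScript archive: code, data, and constants.
  m.save("model.pt");
  // Same content plus the "bytecode" archive for the mobile interpreter.
  m._save_for_mobile("model_mobile.pt");
}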