Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 12:54:11 +08:00
This PR applies clang-tidy readability checks to the jit sources and to all headers in the code base. `readability-redundant-inline-specifier`, which detects redundant `inline` specifiers on function and variable declarations, is suppressed because it would incur too many changes: there are many in-class method definitions that are marked `inline`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/164652 Approved by: https://github.com/Skylion007
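For illustration only (a minimal hypothetical snippet, not taken from this PR): the suppressed check flags code like the following, where the `inline` keyword is redundant because an in-class definition is already implicitly inline.

    struct Example {
      // readability-redundant-inline-specifier would flag this: in-class
      // definitions are implicitly inline, so the keyword adds nothing.
      inline int value() const { return 42; }
    };

    // Equivalent declaration without the redundant specifier:
    struct ExampleFixed {
      int value() const { return 42; }
    };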
1479 lines | 50 KiB | C++
#include <torch/csrc/jit/serialization/export.h>

#include <ATen/ATen.h>
#include <ATen/Utils.h>
#include <ATen/core/functional.h>
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>
#include <c10/util/accumulate.h>
#include <c10/util/irange.h>
#include <torch/csrc/autograd/symbolic.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/inliner.h>
#include <torch/csrc/jit/runtime/instruction.h>
#include <torch/csrc/jit/serialization/import_export_constants.h>
#include <torch/csrc/jit/serialization/import_export_functions.h>
#include <torch/csrc/jit/serialization/import_export_helpers.h>
#include <torch/csrc/jit/serialization/onnx.h>
#include <torch/csrc/jit/serialization/pickler.h>
#include <torch/csrc/onnx/back_compat.h>
#include <torch/csrc/onnx/onnx.h>
#include <torch/version.h>

#include <onnx/checker.h>
#include <onnx/onnx_pb.h>
#include <onnx/proto_utils.h>
#include <onnx/shape_inference/implementation.h>

#include <memory>
#include <optional>
#include <regex>
#include <set>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

namespace torch::jit {

static std::string get_little_endian_data(const at::Tensor& t) {
#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
  return std::string(
      static_cast<char*>(t.data_ptr()), t.element_size() * t.numel());
#elif __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
  const size_t element_size = t.element_size();
  const size_t num_elements = t.numel();

  std::vector<char> data_copy{
      static_cast<char*>(t.data_ptr()),
      static_cast<char*>(t.data_ptr()) + element_size * num_elements};

  for (size_t i = 0; i < num_elements; ++i) {
    char* start_byte = data_copy.data() + i * element_size;
    char* end_byte = start_byte + element_size - 1;
    /* keep swapping */
    for (size_t count = 0; count < element_size / 2; ++count) {
      std::swap(*start_byte, *end_byte);
      ++start_byte;
      --end_byte;
    }
  }

  return std::string(data_copy.data(), element_size * num_elements);
#else
#error Unexpected or undefined __BYTE_ORDER__
#endif
}

void writeArchiveAndTensors(
    const std::string& archive_name,
    const char* data,
    size_t size,
    const std::vector<at::Tensor>& tensors,
    caffe2::serialize::PyTorchStreamWriter& out) {
  std::string prefix = archive_name + "/";
  size_t i = 0;
  for (const auto& td : tensors) {
    WriteableTensorData writable_td = getWriteableTensorData(td);
    std::string fname = prefix + std::to_string(i++);
    out.writeRecord(fname, writable_td.data(), writable_td.sizeInBytes());
  }
  std::string fname = archive_name + ".pkl";
  out.writeRecord(fname, data, size);
}

namespace {
namespace onnx_torch = ::torch::onnx;
namespace onnx = ::ONNX_NAMESPACE;

constexpr int kInvalidOpsetVersion = -1;
constexpr int kMainOpsetVersion = 23;
// Based on OP_SET_ID_VERSION_MAP in
// https://github.com/onnx/onnx/blob/master/onnx/helper.py.
constexpr static std::array<int64_t, kMainOpsetVersion + 1>
    kOpsetVersionToIRVersion = {
        kInvalidOpsetVersion,
        3, // opset 1
        kInvalidOpsetVersion,
        kInvalidOpsetVersion,
        kInvalidOpsetVersion,
        3, // opset 5
        3, // opset 6
        3, // opset 7
        3, // opset 8
        4, // opset 9
        5, // opset 10
        6, // opset 11
        7, // opset 12
        7, // opset 13
        7, // opset 14
        8, // opset 15
        8, // opset 16
        8, // opset 17
        8, // opset 18
        9, // opset 19
        9, // opset 20
        10, // opset 21
        10, // opset 22
        11, // opset 23
};

std::string getNodeStackTraceString(const Node* n) {
  return n->sourceRange().str();
}

void validateBlock(
    Block* b,
    onnx_torch::OperatorExportTypes operator_export_type) {
  for (auto node : b->nodes()) {
    for (Block* sub_block : node->blocks()) {
      validateBlock(sub_block, operator_export_type);
    }
    // Macro'ed so we get a marginally better line number on failed export
#define FAIL_EXPORT(name)                          \
  throw std::runtime_error(                        \
      std::string("ONNX export failed: ") + name + \
      "\n\nGraph we tried to export:\n" + b->owningGraph()->toString());
    // Special error messages for certain types of operators
    if (node->kind() == prim::PythonOp) {
      if (operator_export_type !=
          onnx_torch::OperatorExportTypes::ONNX_FALLTHROUGH) {
        auto py_node = static_cast<PythonOp*>(node);
        FAIL_EXPORT(
            "Couldn't export Python operator " + py_node->name() +
            "\n\nDefined at:\n" + getNodeStackTraceString(node))
      }
    } else {
      if (node->kind() == prim::PackPadded || node->kind() == prim::PadPacked) {
        if (operator_export_type !=
            onnx_torch::OperatorExportTypes::ONNX_FALLTHROUGH) {
          FAIL_EXPORT(
              "Cannot export individual pack_padded_sequence or pad_packed_sequence; these operations must occur in pairs.\n\nUsage of this operation occurred at:\n" +
              getNodeStackTraceString(node));
        }
      }
      bool is_aten_enabled = operator_export_type ==
              onnx_torch::OperatorExportTypes::ONNX_ATEN_FALLBACK ||
          operator_export_type == onnx_torch::OperatorExportTypes::ONNX_ATEN ||
          operator_export_type ==
              onnx_torch::OperatorExportTypes::ONNX_FALLTHROUGH;
      if (node->kind().is_aten() && !is_aten_enabled && !node->mustBeNone()) {
        FAIL_EXPORT(
            "Couldn't export operator " + node->kind().toDisplayString() +
            "\n\nDefined at:\n" + getNodeStackTraceString(node));
      }
    }
#undef FAIL_EXPORT
  }
}

void validateGraph(
    const std::shared_ptr<Graph>& graph,
    onnx_torch::OperatorExportTypes operator_export_type) {
  validateBlock(graph->block(), operator_export_type);
}

std::string GetFileRootPath(const std::string& rootPath) {
  std::string rootPath_ = rootPath;
  // First, make slashes consistent.
  std::replace(rootPath_.begin(), rootPath_.end(), '\\', '/');
  // Second, remove trailing slashes, if any.
  std::regex trailer("/+$");
  std::string root = std::regex_replace(rootPath_, trailer, std::string());
  std::string folder = root.substr(0, root.find_last_of('/'));
  if (folder == rootPath_) { // If no root folder is specified, select cwd.
    return std::string(".");
  }
  return folder;
}

std::string GetExternalFileName(
    const std::optional<std::string>& external_ref) {
  auto tensorName = external_ref.value();
  const std::string illegalChars = "\\/:?\"<>|";
  for (char& i : tensorName) {
    if (illegalChars.find(i) != std::string::npos) {
      i = '_';
    }
  }
  return tensorName;
}

void CloseFile(FILE* fp) {
  fclose(fp);
}

void CreateExternalFile(
    const at::Tensor& tensor,
    const std::string& tensorName,
    const std::string& onnx_file_path) {
  auto folder = GetFileRootPath(onnx_file_path);
  std::string fullFilePath = folder + "/" + tensorName;
  std::unique_ptr<FILE, decltype(&CloseFile)> fp(
      fopen(fullFilePath.c_str(), "wb"), &CloseFile);
  if (fp == nullptr) {
    throw std::runtime_error(
        std::string("ONNX export failed. Could not open file or directory: ") +
        fullFilePath);
  }
  std::string s = get_little_endian_data(tensor);
  fwrite(s.c_str(), tensor.element_size(), tensor.numel(), fp.get());
} // fclose() is called here through CloseFile(), if FILE* is not a null pointer.

class GraphEncoder {
 public:
  GraphEncoder(
      const std::shared_ptr<Graph>& graph,
      int64_t onnx_opset_version,
      onnx_torch::OperatorExportTypes operator_export_type,
      const std::map<std::string, at::Tensor>& initializers,
      const std::unordered_map<
          std::string,
          std::unordered_map<int64_t, std::string>>& dynamic_axes,
      bool defer_weight_export,
      bool strip_doc,
      bool keep_initializers_as_inputs,
      const std::map<std::string, int>& custom_opsets,
      bool add_node_names,
      bool use_external_data_format,
      const std::string& onnx_file_path,
      NodeAttrNameMap node_attr_to_name = {});

  std::shared_ptr<onnx::ModelProto> get_model_proto() {
    return model_proto_;
  }

  SymbolDimMap get_symbol_dim_param_map() {
    return symbol_dim_map_;
  }

  RawDataExportMap get_raw_data_export_map() {
    return raw_data_export_map_;
  }

  bool get_use_external_data_format() {
    return use_external_data_format_;
  }

  NodeNameMap get_onnx_node_names() {
    return onnx_node_name_map_;
  }

 private:
  // Using std::map instead of std::unordered_map for initializers
  // in the EncodeGraph constructor so that the order in which initializers
  // get written to the ONNX graph is always deterministic and
  // predictable. While this is not an ONNX requirement, it is needed
  // for testing purposes in tests that use _export_to_pretty_string()
  // for validating ONNX graphs.
  void EncodeGraph(
      onnx::GraphProto* graph_proto,
      const std::shared_ptr<Graph>& graph,
      const std::map<std::string, at::Tensor>& initializers =
          std::map<std::string, at::Tensor>(),
      const std::
          unordered_map<std::string, std::unordered_map<int64_t, std::string>>&
              dynamic_axes = std::unordered_map<
                  std::string,
                  std::unordered_map<int64_t, std::string>>(),
      bool keep_initializers_as_inputs = true,
      bool add_node_names = true,
      bool use_external_data_format = false,
      const std::string& onnx_file_path = std::string());

  void EncodeBlock(
      onnx::GraphProto* graph_proto,
      const Block* block,
      const std::map<std::string, at::Tensor>& initializers =
          std::map<std::string, at::Tensor>(),
      const std::
          unordered_map<std::string, std::unordered_map<int64_t, std::string>>&
              dynamic_axes = std::unordered_map<
                  std::string,
                  std::unordered_map<int64_t, std::string>>(),
      bool keep_initializers_as_inputs = true,
      bool add_node_names = true,
      bool use_external_data_format = false,
      const std::string& onnx_file_path = std::string());

  void AddInitializersIntoGraphProto(
      onnx::GraphProto* graph_proto,
      const Block* block,
      const std::map<std::string, at::Tensor>& initializers =
          std::map<std::string, at::Tensor>(),
      bool use_external_data_format = false,
      const std::string& onnx_file_path = std::string());

  unsigned long long int GetGraphProtoSize(
      onnx::GraphProto* graph_proto,
      const std::shared_ptr<Graph>& graph,
      bool add_node_names,
      bool use_external_data_format,
      const std::string& onnx_file_path,
      const std::map<std::string, at::Tensor>& initializers =
          std::map<std::string, at::Tensor>());

  void EncodeNode(
      onnx::GraphProto* graph_proto,
      onnx::NodeProto* node_proto,
      const Node* node,
      bool add_node_names = true,
      bool use_external_data_format = false,
      const std::string& onnx_file_path = std::string());

  void EncodeTypeProto(
      onnx::TypeProto* type_proto,
      const TypePtr& node_type,
      const std::string& name);

  void EncodeLocalFunctionOpsetImport(
      onnx::FunctionProto* func_proto,
      const Node* n,
      std::unordered_set<std::string>& custom_domains);

  void EncodeLocalFunction(
      onnx::GraphProto* graph_proto,
      onnx::FunctionProto* func_proto,
      const Node* n,
      bool add_node_names = true,
      bool use_external_data_format = false,
      const std::string& onnx_file_path = std::string());

  void EncodeTensor(
      onnx::TensorProto* tensor_proto,
      const at::Tensor& tensor,
      const std::optional<std::string>& external_ref = {},
      const bool use_external_data_format = false,
      const std::string& onnx_file_path = std::string());

  void EncodeIntermediateValueInfo(
      onnx::GraphProto* graph_proto,
      const Value* n);

  void EncodeValueInfo(
      onnx::GraphProto* graph_proto,
      onnx::ValueInfoProto* v,
      const Value* n,
      const std::
          unordered_map<std::string, std::unordered_map<int64_t, std::string>>&
              dynamic_axes = std::unordered_map<
                  std::string,
                  std::unordered_map<int64_t, std::string>>());

  void EncodeValueInfoType(
      onnx::TypeProto* onnx_type,
      const TypePtr& node_type,
      const Value* n,
      const std::unordered_map<
          std::string,
          std::unordered_map<int64_t, std::string>>& dynamic_axes);

  void AddAttribute(
      onnx::NodeProto* node_proto,
      const jit::Symbol name,
      const std::string& ref_attr_name,
      const AttributeKind attr_kind);

  void AddAttribute(
      onnx::NodeProto* node_proto,
      const jit::Node* node,
      const jit::Symbol name,
      const bool use_external_data_format = false,
      const std::string& onnx_file_path = std::string());

  void AddAttribute(onnx::FunctionProto* func_proto, const std::string& name);

  void TensorTypeToONNXType(
      const TensorTypePtr& tensor_type,
      const std::string& dim_name_prefix,
      const std::string& name,
      const std::unordered_map<
          std::string,
          std::unordered_map<int64_t, std::string>>& dynamic_axes,
      onnx::TypeProto_Tensor* onnx_tensor_type,
      bool assign_dim_param = true);

  SymbolDimMap symbol_dim_map_;
  std::shared_ptr<onnx::ModelProto> model_proto_;
  size_t num_blocks_{0};
  size_t num_op_nodes_{0};
  size_t num_external_data_{0};
  onnx_torch::OperatorExportTypes operator_export_type_;
  bool strip_doc_;
  std::set<std::string> domains_;
  RawDataExportMap raw_data_export_map_;
  bool defer_weight_export_;
  bool use_external_data_format_;
  int64_t onnx_opset_version_;
  std::map<std::string, int> custom_opsets_;
  std::shared_ptr<Graph> graph_;
  NodeAttrNameMap node_attr_to_name_;
  NodeNameMap onnx_node_name_map_;
  // For large models, the parameters can be stored in separate binary files.
  // This parameter sets a threshold on the number of elements in the parameter
  // tensor, beyond which the parameter is stored in a separate file (if
  // use_external_data_format_ is True). This threshold is in place
  // so as not to create too many external files.
  static constexpr size_t ParamSizeThresholdForExternalStorage = 1024;
};

static onnx::TensorProto_DataType ATenTypeToOnnxType(at::ScalarType at_type) {
  switch (at_type) {
    case at::kDouble:
      return onnx::TensorProto_DataType_DOUBLE;
    case at::kFloat:
      return onnx::TensorProto_DataType_FLOAT;
    case at::kHalf:
      return onnx::TensorProto_DataType_FLOAT16;
    case at::kByte:
      return onnx::TensorProto_DataType_UINT8;
    case at::kChar:
      return onnx::TensorProto_DataType_INT8;
    case at::kShort:
      return onnx::TensorProto_DataType_INT16;
    case at::kInt:
      return onnx::TensorProto_DataType_INT32;
    case at::kLong:
      return onnx::TensorProto_DataType_INT64;
    case at::kBool:
      return onnx::TensorProto_DataType_BOOL;
    case at::kQInt8:
      return onnx::TensorProto_DataType_INT8;
    case at::kQUInt8:
      return onnx::TensorProto_DataType_UINT8;
    case at::kQInt32:
      return onnx::TensorProto_DataType_INT32;
    case at::kBFloat16:
      return onnx::TensorProto_DataType_BFLOAT16;
    case at::kFloat8_e4m3fn:
      return onnx_torch::TensorProto_DataType_FLOAT8E4M3FN;
    case at::kFloat8_e5m2:
      return onnx_torch::TensorProto_DataType_FLOAT8E5M2;
    case at::kFloat8_e4m3fnuz:
      return onnx_torch::TensorProto_DataType_FLOAT8E4M3FNUZ;
    case at::kFloat8_e5m2fnuz:
      return onnx_torch::TensorProto_DataType_FLOAT8E5M2FNUZ;
    default:
      TORCH_CHECK(
          false,
          "ScalarType ",
          toString(at_type),
          " is an unexpected tensor scalar type");
  }
}

static onnx::AttributeProto_AttributeType ATenAttributeKindToOnnxAttributeType(
    AttributeKind at_kind,
    const jit::Symbol name) {
  switch (at_kind) {
    case AttributeKind::f:
      return onnx::AttributeProto_AttributeType_FLOAT;
    case AttributeKind::fs:
      return onnx::AttributeProto_AttributeType_FLOATS;
    case AttributeKind::i:
      return onnx::AttributeProto_AttributeType_INT;
    case AttributeKind::is:
      return onnx::AttributeProto_AttributeType_INTS;
    case AttributeKind::s:
      return onnx::AttributeProto_AttributeType_STRING;
    case AttributeKind::ss:
      return onnx::AttributeProto_AttributeType_STRINGS;
    case AttributeKind::t:
      return onnx::AttributeProto_AttributeType_TENSOR;
    case AttributeKind::ts:
      return onnx::AttributeProto_AttributeType_TENSORS;
    case AttributeKind::ty:
      return onnx::AttributeProto_AttributeType_TYPE_PROTO;
    case AttributeKind::tys:
      return onnx::AttributeProto_AttributeType_TYPE_PROTOS;
    case AttributeKind::g:
      return onnx::AttributeProto_AttributeType_GRAPH;
    case AttributeKind::gs:
      return onnx::AttributeProto_AttributeType_GRAPHS;
    default:
      std::ostringstream err_msg;
      err_msg << "attribute \"" << name.toDisplayString()
              << "\" has unexpected kind: " << toString(at_kind);
      throw std::runtime_error(err_msg.str());
  }
}

GraphEncoder::GraphEncoder(
    const std::shared_ptr<Graph>& graph,
    int64_t onnx_opset_version,
    onnx_torch::OperatorExportTypes operator_export_type,
    const std::map<std::string, at::Tensor>& initializers,
    const std::unordered_map<
        std::string,
        std::unordered_map<int64_t, std::string>>& dynamic_axes,
    bool defer_weight_export,
    bool strip_doc,
    bool keep_initializers_as_inputs,
    const std::map<std::string, int>& custom_opsets,
    bool add_node_names,
    bool use_external_data_format,
    const std::string& onnx_file_path,
    NodeAttrNameMap node_attr_to_name)
    : model_proto_(std::make_shared<onnx::ModelProto>()),
      operator_export_type_(operator_export_type),
      strip_doc_(strip_doc),
      defer_weight_export_(defer_weight_export),
      use_external_data_format_(use_external_data_format),
      onnx_opset_version_(onnx_opset_version),
      custom_opsets_(custom_opsets),
      graph_(graph),
      node_attr_to_name_(std::move(node_attr_to_name)) {
  model_proto_->set_producer_name("pytorch");
  TORCH_CHECK(
      onnx_opset_version > 0 &&
          static_cast<size_t>(onnx_opset_version) <
              kOpsetVersionToIRVersion.size() &&
          kOpsetVersionToIRVersion[onnx_opset_version] != kInvalidOpsetVersion,
      "Unsupported onnx_opset_version: ",
      onnx_opset_version);

  model_proto_->set_ir_version(kOpsetVersionToIRVersion[onnx_opset_version]);
  model_proto_->set_producer_version(TORCH_VERSION);
  validateGraph(graph, operator_export_type);

  // If the graph proto size exceeds the maximum protobuf size of 2GB, set
  // use_external_data_format to true.
  if (!use_external_data_format &&
      GetGraphProtoSize(
          model_proto_->mutable_graph(),
          graph,
          add_node_names,
          use_external_data_format,
          onnx_file_path,
          initializers) > INT_MAX) {
    GRAPH_DEBUG(
        "Exported model exceeds the maximum protobuf size of 2GB. Storing model parameters in external data files");
    use_external_data_format = true;
    // use_external_data_format_ is a private member of GraphEncoder, kept in
    // sync here so that get_use_external_data_format() returns the updated
    // value.
    use_external_data_format_ = use_external_data_format;
  }

  if (use_external_data_format) {
    TORCH_CHECK(
        !onnx_file_path.empty(),
        "The serialized model is larger than the 2GiB limit imposed by the protobuf library. ",
        "Therefore the output file must be a file path, so that the ONNX external data can ",
        "be written to the same directory. Please specify the output file name.");
  }

  auto* imp = model_proto_->add_opset_import();
  // This is the version of the ONNX operator set we are targeting.
  imp->set_version(onnx_opset_version);

  EncodeGraph(
      model_proto_->mutable_graph(),
      graph,
      initializers,
      dynamic_axes,
      keep_initializers_as_inputs,
      add_node_names,
      use_external_data_format,
      onnx_file_path);

  for (const std::string& domain : domains_) {
    auto* opset = model_proto_->add_opset_import();
    opset->set_domain(domain);
    // Check if the domain version is registered. If not, set it to version 1.
    auto it = custom_opsets.find(domain);
    if (it == custom_opsets.end()) {
      opset->set_version(1);
    } else {
      opset->set_version(it->second);
    }
  }

  for (auto const& custom_opset : custom_opsets) {
    if (!std::count(domains_.begin(), domains_.end(), custom_opset.first)) {
      TORCH_WARN(
          "Custom opset domain: '",
          custom_opset.first,
          "' provided is not used in the model. ",
          "Please verify custom opset domain names.");
    }
  }
}

// NOLINTBEGIN(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
void GraphEncoder::TensorTypeToONNXType(
    const TensorTypePtr& tensor_type,
    const std::string& dim_name_prefix,
    const std::string& name,
    const std::unordered_map<
        std::string,
        std::unordered_map<int64_t, std::string>>& dynamic_axes,
    onnx::TypeProto_Tensor* onnx_tensor_type,
    bool assign_dim_param) {
  if (tensor_type->dim()) {
    onnx::TensorShapeProto* shape = onnx_tensor_type->mutable_shape();
    auto sizes = tensor_type->symbolic_sizes().sizes().value();
    for (const auto i : c10::irange(sizes.size())) {
      shape->add_dim();
      if ((dynamic_axes.find(name) != dynamic_axes.end()) &&
          (dynamic_axes.at(name).find(i) != dynamic_axes.at(name).end())) {
        shape->mutable_dim(i)->set_dim_param(dynamic_axes.at(name).at(i));
        if (!sizes[i].is_static()) {
          symbol_dim_map_[sizes[i]] = dynamic_axes.at(name).at(i);
        }
      } else if (sizes[i].is_static()) {
        shape->mutable_dim(i)->set_dim_value(sizes[i].static_size());
      } else if (assign_dim_param) {
        if (symbol_dim_map_.find(sizes[i]) == symbol_dim_map_.end()) {
          symbol_dim_map_[sizes[i]] =
              dim_name_prefix + name + "_dim_" + std::to_string(i);
        }
        shape->mutable_dim(i)->set_dim_param(symbol_dim_map_[sizes[i]]);
      }
    }
  }
  if (tensor_type->scalarType()) {
    onnx_tensor_type->set_elem_type(
        ATenTypeToOnnxType(tensor_type->scalarType().value()));
  }
}
// NOLINTEND(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)

void GraphEncoder::EncodeValueInfoType(
    onnx::TypeProto* onnx_type,
    const TypePtr& node_type,
    const Value* n,
    const std::unordered_map<
        std::string,
        std::unordered_map<int64_t, std::string>>& dynamic_axes) {
  std::string dim_name_prefix;
  if (n->node()->kind() != prim::Param) {
    dim_name_prefix = n->node()->kind().toUnqualString();
  }
  if (TensorTypePtr tensor_type = node_type->cast<TensorType>()) {
    if (tensor_type->dim() || tensor_type->scalarType()) {
      // Encode type if either shape or dtype exists.
      onnx::TypeProto_Tensor* onnx_tensor_type =
          onnx_type->mutable_tensor_type();
      // Do not assign dim_param for sequence tensor type.
      // Sequence of tensors could differ in dimension size.
      // Use a dimension with neither dim_value nor dim_param set
      // to denote an unknown dimension.
      // Create and assign dim_param for normal tensor type.
      auto is_sequence_tensor = static_cast<bool>(n->type()->cast<ListType>());
      TensorTypeToONNXType(
          tensor_type,
          dim_name_prefix,
          n->debugName(),
          dynamic_axes,
          onnx_tensor_type,
          !is_sequence_tensor);
    }
  } else if (BoolTypePtr bool_type = node_type->cast<BoolType>()) {
    onnx::TypeProto_Tensor* onnx_tensor_type = onnx_type->mutable_tensor_type();
    onnx_tensor_type->set_elem_type(ATenTypeToOnnxType(at::kBool));
  } else if (IntTypePtr int_type = node_type->cast<IntType>()) {
    onnx::TypeProto_Tensor* onnx_tensor_type = onnx_type->mutable_tensor_type();
    onnx_tensor_type->set_elem_type(ATenTypeToOnnxType(at::kLong));
  } else if (FloatTypePtr float_type = node_type->cast<FloatType>()) {
    onnx::TypeProto_Tensor* onnx_tensor_type = onnx_type->mutable_tensor_type();
    onnx_tensor_type->set_elem_type(ATenTypeToOnnxType(at::kFloat));
  } else if (ListTypePtr list_type = node_type->cast<ListType>()) {
    auto list_elem_type = list_type->getElementType();
    onnx::TypeProto_Sequence* sequence_type =
        onnx_type->mutable_sequence_type();
    onnx::TypeProto* onnx_tensor_type = sequence_type->mutable_elem_type();
    EncodeValueInfoType(onnx_tensor_type, list_elem_type, n, dynamic_axes);
  } else if (OptionalTypePtr optional_type = node_type->cast<OptionalType>()) {
    auto elem_type = optional_type->getElementType();
    if (TensorTypePtr tensor_type = elem_type->cast<TensorType>()) {
      onnx::TypeProto_Optional* onnx_optional_type =
          onnx_type->mutable_optional_type();
      onnx::TypeProto_Tensor* onnx_tensor_type =
          onnx_optional_type->mutable_elem_type()->mutable_tensor_type();
      TensorTypeToONNXType(
          tensor_type,
          dim_name_prefix,
          n->debugName(),
          dynamic_axes,
          onnx_tensor_type);
    } else if (ListTypePtr inner_node_type = elem_type->cast<ListType>()) {
      auto list_elem_type = inner_node_type->getElementType();
      if (TensorTypePtr tensor_type = list_elem_type->cast<TensorType>()) {
        onnx::TypeProto_Optional* onnx_optional_type =
            onnx_type->mutable_optional_type();
        onnx::TypeProto_Sequence* onnx_optional_sequence_type =
            onnx_optional_type->mutable_elem_type()->mutable_sequence_type();
        onnx::TypeProto_Tensor* onnx_tensor_type =
            onnx_optional_sequence_type->mutable_elem_type()
                ->mutable_tensor_type();
        TensorTypeToONNXType(
            tensor_type,
            dim_name_prefix,
            n->debugName(),
            dynamic_axes,
            onnx_tensor_type);
      }
    }
  }
}

void GraphEncoder::EncodeValueInfo(
    onnx::GraphProto* graph_proto,
    onnx::ValueInfoProto* v,
    const Value* n,
    const std::unordered_map<
        std::string,
        std::unordered_map<int64_t, std::string>>& dynamic_axes) {
  std::string name = n->debugName();
  v->set_name(name);
  EncodeValueInfoType(v->mutable_type(), n->type(), n, dynamic_axes);
}

void GraphEncoder::EncodeGraph(
    onnx::GraphProto* graph_proto,
    const std::shared_ptr<Graph>& graph,
    const std::map<std::string, at::Tensor>& initializers,
    const std::unordered_map<
        std::string,
        std::unordered_map<int64_t, std::string>>& dynamic_axes,
    bool keep_initializers_as_inputs,
    bool add_node_names,
    bool use_external_data_format,
    const std::string& onnx_file_path) {
  EncodeBlock(
      graph_proto,
      graph->block(),
      initializers,
      dynamic_axes,
      keep_initializers_as_inputs,
      add_node_names,
      use_external_data_format,
      onnx_file_path);
}

void GraphEncoder::EncodeBlock(
    onnx::GraphProto* graph_proto,
    const Block* block,
    const std::map<std::string, at::Tensor>& initializers,
    const std::unordered_map<
        std::string,
        std::unordered_map<int64_t, std::string>>& dynamic_axes,
    bool keep_initializers_as_inputs,
    bool add_node_names,
    bool use_external_data_format,
    const std::string& onnx_file_path) {
  TORCH_INTERNAL_ASSERT(graph_proto != nullptr);
  if (nullptr == block->owningNode()) {
    // Top level main graph.
    graph_proto->set_name("main_graph");
  } else {
    // TODO: Set more meaningful names for sub-graphs.
    std::string block_name = "sub_graph";
    if (num_blocks_) {
      block_name += std::to_string(num_blocks_);
    }
    num_blocks_++;
    graph_proto->set_name(block_name);
  }

  // Since ONNX IR VERSION 4, initializers do not have to
  // be a subset of graph inputs. We use the keep_initializers_as_inputs
  // argument to determine whether to add initializers
  // as inputs or not. If keep_initializers_as_inputs=false,
  // we only add non-parameter inputs as inputs to the ONNX graph, and
  // not the initializers (parameters). If keep_initializers_as_inputs
  // =true, we add initializers as inputs too. Setting
  // keep_initializers_as_inputs=false allows better
  // optimizations, such as constant-folding, on ONNX graphs
  // by backends/optimizers.
  if (keep_initializers_as_inputs) {
    for (auto input : block->inputs()) {
      onnx::ValueInfoProto* v = graph_proto->add_input();
      EncodeValueInfo(graph_proto, v, input, dynamic_axes);
    }
  } else {
    for (auto input : block->inputs()) {
      auto it = initializers.find(input->debugName());
      if (it == initializers.end()) {
        onnx::ValueInfoProto* v = graph_proto->add_input();
        EncodeValueInfo(graph_proto, v, input, dynamic_axes);
      }
    }
  }
  for (auto output : block->outputs()) {
    onnx::ValueInfoProto* v = graph_proto->add_output();
    EncodeValueInfo(graph_proto, v, output, dynamic_axes);
  }
  for (auto node : block->nodes()) {
    if (node->mustBeNone()) {
      // None nodes are used to implement optional inputs. One
      // way to "not provide" an optional input is to create an
      // Undefined node, and pass its output as that input.
      continue;
    }
    if (node->kind() == ::c10::Symbol::onnx("LocalFunctionDef")) {
      auto* func_proto = model_proto_->add_functions();
      EncodeLocalFunction(
          graph_proto,
          func_proto,
          node,
          add_node_names,
          use_external_data_format,
          onnx_file_path);
      continue;
    }
    auto* n_proto = graph_proto->add_node();
    EncodeNode(
        graph_proto,
        n_proto,
        node,
        add_node_names,
        use_external_data_format,
        onnx_file_path);
  }
  AddInitializersIntoGraphProto(
      graph_proto,
      block,
      initializers,
      use_external_data_format,
      onnx_file_path);
}

void GraphEncoder::AddInitializersIntoGraphProto(
    onnx::GraphProto* graph_proto,
    const Block* block,
    const std::map<std::string, at::Tensor>& initializers,
    bool use_external_data_format,
    const std::string& onnx_file_path) {
  TORCH_INTERNAL_ASSERT(block->inputs().size() >= initializers.size());
  for (auto input : block->inputs()) {
    auto name_tensor_pair = initializers.find(input->debugName());
    if (name_tensor_pair == initializers.end()) {
      continue;
    }
    auto p = graph_proto->add_initializer();
    p->set_name(name_tensor_pair->first);
    EncodeTensor(
        p,
        name_tensor_pair->second,
        name_tensor_pair->first,
        use_external_data_format,
        onnx_file_path);
  }
}

unsigned long long int GraphEncoder::GetGraphProtoSize(
    onnx::GraphProto* graph_proto,
    const std::shared_ptr<Graph>& graph,
    bool add_node_names,
    bool use_external_data_format,
    const std::string& onnx_file_path,
    const std::map<std::string, at::Tensor>& initializers) {
  // Model size = sum(size(initializers)) + sum(size(onnx_constant_nodes))

  // Add up all initializers
  onnx::GraphProto graph_proto_copy = onnx::GraphProto(*graph_proto);
  unsigned long long int size = graph_proto_copy.ByteSizeLong();
  for (auto input : graph->inputs()) {
    auto name_tensor_pair = initializers.find(input->debugName());
    if (name_tensor_pair == initializers.end()) {
      continue;
    }
    auto tensor_proto = graph_proto_copy.add_initializer();
    const at::Tensor& tensor = name_tensor_pair->second;
    for (auto d : tensor.sizes()) {
      tensor_proto->add_dims(d);
    }
    tensor_proto->set_data_type(ATenTypeToOnnxType(tensor.scalar_type()));

    // Don't actually copy the buffer into tensor_proto since that is
    // expensive. All we actually need is its size.
    size += tensor_proto->ByteSizeLong();
    size += tensor.element_size() * tensor.numel();
  }

  // Add up all onnx::Constant nodes that are Tensors
  for (const auto& node : graph->nodes()) {
    if (node->kind() == ::c10::onnx::Constant &&
        node->hasAttribute(attr::value) &&
        node->kindOf(attr::value) == AttributeKind::t) {
      at::Tensor tensor = node->t(attr::value);

      // Don't actually copy the buffer into n_proto since that is expensive.
      // All we actually need is its size.
      auto* n_proto = graph_proto_copy.add_node();
      EncodeNode(
          &graph_proto_copy,
          n_proto,
          node,
          add_node_names,
          use_external_data_format,
          onnx_file_path);

      // Calculate the size of the tensor in bytes.
      size += n_proto->ByteSizeLong();
      size += tensor.element_size() * tensor.numel();
    }
  }
  return size;
}

void GraphEncoder::EncodeNode(
    onnx::GraphProto* graph_proto,
    onnx::NodeProto* node_proto,
    const Node* node,
    bool add_node_names,
    bool use_external_data_format,
    const std::string& onnx_file_path) {
  if (!strip_doc_) {
    node_proto->set_doc_string(node->sourceRange().str());
  }
  for (auto input : node->inputs()) {
    if (input->node()->mustBeNone()) {
      node_proto->add_input("");
    } else {
      node_proto->add_input(input->debugName());
    }
  }
  for (auto output : node->outputs()) {
    node_proto->add_output(output->debugName());
    EncodeIntermediateValueInfo(graph_proto, output);
  }
  if (!node->kind().is_onnx()) {
    std::string domain;
    if (node->kind().is_aten() || node->kind().is_caffe2()) {
      domain = node->kind().domainString();
    } else { // Custom namespace and domain
      domain = node->kind().ns().toUnqualString();
    }
    // TODO: set correct domain for function proto.
    domains_.insert(domain);
    node_proto->set_domain(domain);
  }
  if (operator_export_type_ == onnx_torch::OperatorExportTypes::ONNX) {
    TORCH_INTERNAL_ASSERT(
        !node->kind().is_aten() && !node->kind().is_prim() &&
        !node->kind().is_attr());
  }
  node_proto->set_op_type(node->kind().toUnqualString());
  const auto node_name_attribute_symbol =
      Symbol::attr(::torch::onnx::kOnnxNodeNameAttribute);
  if (add_node_names) {
    std::string node_name =
        node_proto->op_type() + "_" + std::to_string(num_op_nodes_);
    if (node->hasAttribute(node_name_attribute_symbol)) {
      node_name = node->s(node_name_attribute_symbol);
    }
    node_proto->set_name(node_name);
    onnx_node_name_map_[node] = node_name;
    num_op_nodes_++;
  }
  auto attrs_it = node_attr_to_name_.find(node);
  for (auto attr_name : node->attributeNames()) {
    if (attr_name == node_name_attribute_symbol) {
      // Skip the node name attribute.
      continue;
    }
    if (attrs_it != node_attr_to_name_.end()) {
      auto attr_it = attrs_it->second.find(attr_name.toUnqualString());
      if (attr_it != attrs_it->second.end()) {
        AddAttribute(
            node_proto, attr_name, attr_it->second, node->kindOf(attr_name));
        continue;
      }
    }
    AddAttribute(
        node_proto, node, attr_name, use_external_data_format, onnx_file_path);
  }
  if (node->kind() == ::c10::onnx::Loop) {
    TORCH_INTERNAL_ASSERT(node->blocks().size() == 1);

    auto body = node_proto->add_attribute();
    body->set_name("body");
    body->set_type(onnx::AttributeProto_AttributeType_GRAPH);
    auto g = body->mutable_g();
    EncodeBlock(
        g,
        node->blocks()[0],
        {},
        {},
        true,
        true,
        use_external_data_format,
        onnx_file_path);
  }
  if (node->kind() == ::c10::onnx::If) {
    TORCH_INTERNAL_ASSERT(node->blocks().size() == 2);

    auto then_branch = node_proto->add_attribute();
    then_branch->set_name("then_branch");
    then_branch->set_type(onnx::AttributeProto_AttributeType_GRAPH);
    auto true_g = then_branch->mutable_g();
    EncodeBlock(
        true_g,
        node->blocks()[0],
        {},
        {},
        true,
        true,
        use_external_data_format,
        onnx_file_path);

    auto else_branch = node_proto->add_attribute();
    else_branch->set_name("else_branch");
    else_branch->set_type(onnx::AttributeProto_AttributeType_GRAPH);
    auto false_g = else_branch->mutable_g();
    EncodeBlock(
        false_g,
        node->blocks()[1],
        {},
        {},
        true,
        true,
        use_external_data_format,
        onnx_file_path);
  }
}

void GraphEncoder::AddAttribute(
    onnx::NodeProto* node_proto,
    const jit::Symbol name,
    const std::string& ref_attr_name,
    const AttributeKind attr_kind) {
  auto attr = node_proto->add_attribute();
  TORCH_INTERNAL_ASSERT(name.is_attr());
  attr->set_name(name.toUnqualString());
  attr->set_ref_attr_name(ref_attr_name);
  attr->set_type(ATenAttributeKindToOnnxAttributeType(attr_kind, name));
}

void GraphEncoder::AddAttribute(
    onnx::NodeProto* node_proto,
    const jit::Node* node,
    const jit::Symbol name,
    const bool use_external_data_format,
    const std::string& onnx_file_path) {
  auto createAttributeTensorName =
      [](const onnx::NodeProto* node_proto,
         onnx::TensorProto* tensor_proto,
         const jit::Symbol attr_name,
         size_t& num_external_data) -> std::string {
    if (tensor_proto->has_name()) {
      return tensor_proto->name();
    }
    if (!node_proto->has_name()) {
      auto name = node_proto->op_type() + "_" + attr_name.toDisplayString() +
          "_" + std::to_string(num_external_data);
      num_external_data++;
      return name;
    } else {
      return node_proto->name() + "_" + attr_name.toDisplayString();
    }
  };

  auto attr = node_proto->add_attribute();
  TORCH_INTERNAL_ASSERT(name.is_attr());
  attr->set_name(name.toUnqualString());
  attr->set_type(
      ATenAttributeKindToOnnxAttributeType(node->kindOf(name), name));
  switch (node->kindOf(name)) {
    case AttributeKind::f:
      attr->set_f(static_cast<float>(node->f(name)));
      break;
    case AttributeKind::fs:
      for (auto& v : node->fs(name))
        attr->add_floats(static_cast<float>(v));
      break;
    case AttributeKind::i:
      attr->set_i(node->i(name));
      break;
    case AttributeKind::is:
      for (auto& v : node->is(name))
        attr->add_ints(v);
      break;
    case AttributeKind::s:
      attr->set_s(node->s(name));
      break;
    case AttributeKind::ss:
      for (auto& v : node->ss(name))
        attr->add_strings(v);
      break;
    case AttributeKind::t: {
      auto t = attr->mutable_t();
      if (use_external_data_format && !t->has_name()) {
        t->set_name(
            createAttributeTensorName(node_proto, t, name, num_external_data_));
      }
      EncodeTensor(
          t, node->t(name), {}, use_external_data_format, onnx_file_path);
    } break;
    case AttributeKind::ts:
      for (auto& v : node->ts(name)) {
        auto t = attr->add_tensors();
        if (use_external_data_format && !t->has_name()) {
          t->set_name(createAttributeTensorName(
              node_proto, t, name, num_external_data_));
        }
        EncodeTensor(t, v, {}, use_external_data_format, onnx_file_path);
      }
      break;
    case AttributeKind::ty: {
      attr->set_type(onnx::AttributeProto_AttributeType_TYPE_PROTO);
      auto tp = attr->mutable_tp();
      const TypePtr& node_type = node->ty(name);
      EncodeTypeProto(
          tp, node_type, node_proto->op_type() + "_" + name.toDisplayString());
    } break;
    case AttributeKind::tys: {
      attr->set_type(onnx::AttributeProto_AttributeType_TYPE_PROTOS);
      size_t index = 0;
      for (auto& v : node->tys(name)) {
        auto tp = attr->add_type_protos();
        EncodeTypeProto(
            tp,
            v,
            node_proto->op_type() + "_" + name.toDisplayString() + "_" +
                std::to_string(index));
        index++;
      }
    } break;
    case AttributeKind::g: {
      auto g = attr->mutable_g();
      EncodeGraph(
          g,
          node->g(name),
          {},
          {},
          true,
          true,
          use_external_data_format,
          onnx_file_path);
    } break;
    case AttributeKind::gs:
      for (auto& v : node->gs(name)) {
        auto g = attr->add_graphs();
        EncodeGraph(
            g, v, {}, {}, true, true, use_external_data_format, onnx_file_path);
      }
      break;
    default:
      std::ostringstream err_msg;
      err_msg << "attribute \"" << name.toDisplayString()
              << "\" has unexpected kind: " << toString(node->kindOf(name));
      throw std::runtime_error(err_msg.str());
  }
}

void GraphEncoder::AddAttribute(
    onnx::FunctionProto* func_proto,
    const std::string& name) {
  TORCH_INTERNAL_ASSERT(nullptr != func_proto);
  func_proto->add_attribute(name);
}

void GraphEncoder::EncodeLocalFunctionOpsetImport(
    onnx::FunctionProto* func_proto,
    const Node* n,
    std::unordered_set<std::string>& custom_domains) {
  if (!n->kind().is_onnx()) {
    std::string domain;
    if (n->kind().is_aten() || n->kind().is_caffe2()) {
      domain = n->kind().domainString();
    } else { // Custom namespace and domain
      domain = n->kind().ns().toUnqualString();
    }
    domains_.insert(domain);

    if (custom_domains.find(domain) == custom_domains.end()) {
      custom_domains.insert(domain);

      auto* custom_imp = func_proto->add_opset_import();
      custom_imp->set_domain(domain);
      // Check if the domain version is registered. If not, set it to version 1.
      auto it = custom_opsets_.find(domain);
      if (it == custom_opsets_.end()) {
        custom_imp->set_version(1);
      } else {
        custom_imp->set_version(it->second);
      }
    }
  }

  for (auto* b : n->blocks()) {
    for (auto* sub_n : b->nodes()) {
      EncodeLocalFunctionOpsetImport(func_proto, sub_n, custom_domains);
    }
  }
}

void GraphEncoder::EncodeLocalFunction(
    onnx::GraphProto* graph_proto,
    onnx::FunctionProto* func_proto,
    const Node* n,
    bool add_node_names,
    bool use_external_data_format,
    const std::string& onnx_file_path) {
  const auto fsub_g = n->g(Symbol::attr("graph"));
  func_proto->set_name(n->s(::c10::attr::name));

  for (auto input : fsub_g->inputs()) {
    func_proto->add_input(input->debugName());
  }
  for (auto output : fsub_g->outputs()) {
    func_proto->add_output(output->debugName());
  }

  // Encode attribute names.
  if (n->hasAttribute(Symbol::attr("attributes"))) {
    for (const auto& attr_name : n->ss(Symbol::attr("attributes"))) {
      AddAttribute(func_proto, attr_name);
    }
  }

  auto* imp = func_proto->add_opset_import();
  // This is the version of the ONNX operator set we are targeting.
  imp->set_version(onnx_opset_version_);

  // Add the custom domain as well.
  const auto& domain = n->s(Symbol::attr("domain"));
  func_proto->set_domain(domain);
  domains_.insert(domain);
  std::unordered_set<std::string> custom_domains;

  for (auto* fsub_n : fsub_g->nodes()) {
    if (fsub_n->mustBeNone()) {
      // None nodes are used to implement optional inputs. One
      // way to "not provide" an optional input is to create an
      // Undefined node, and pass its output as that input.
      continue;
    }
    auto* n_proto = func_proto->add_node();
    EncodeNode(
        graph_proto,
        n_proto,
        fsub_n,
        add_node_names,
        use_external_data_format,
        onnx_file_path);
    EncodeLocalFunctionOpsetImport(func_proto, fsub_n, custom_domains);
  }
}

void GraphEncoder::EncodeTypeProto(
    onnx::TypeProto* type_proto,
    const TypePtr& node_type,
    const std::string& name) {
  if (TensorTypePtr tensor_type = node_type->cast<TensorType>()) {
    onnx::TypeProto_Tensor* onnx_tensor_type =
        type_proto->mutable_tensor_type();
    TensorTypeToONNXType(tensor_type, "", name, {}, onnx_tensor_type);
  } else if (ListTypePtr list_type = node_type->cast<ListType>()) {
    onnx::TypeProto_Sequence* seq_type = type_proto->mutable_sequence_type();
    auto elem_type = list_type->getElementType();
    EncodeTypeProto(seq_type->mutable_elem_type(), elem_type, name);
  }
}

void GraphEncoder::EncodeTensor(
    onnx::TensorProto* tensor_proto,
    const at::Tensor& tensor,
    const std::optional<std::string>& external_ref,
    const bool use_external_data_format,
    const std::string& onnx_file_path) {
  for (auto d : tensor.sizes()) {
    tensor_proto->add_dims(d);
  }
  tensor_proto->set_data_type(ATenTypeToOnnxType(tensor.scalar_type()));
  at::Tensor t;
  // CPU's HalfTensor doesn't have contiguous(), so call contiguous() first.
  // TODO: We don't call .cpu() on quantized tensors as it fails when calling
  // aten::empty() on quantized tensors beyond a certain size. Issue #29435.
  if (tensor.is_quantized()) {
    t = tensor.contiguous();
  } else {
    t = tensor.contiguous().cpu();
  }

  // Either defer_weight_export should be true and external_ref must be present,
  // or use_external_data_format should be true, not both at the same time. They
  // can both be false at the same time (for ONNX export for regular model
  // size).
  TORCH_INTERNAL_ASSERT(
      !((defer_weight_export_ && external_ref) && use_external_data_format));
  // Add a buffer to the raw_data_export_map for the caller to dump into an
  // external data store. If external_ref is not specified, we instead dump
  // the contiguous data into the protobuf itself.
  if (defer_weight_export_ && external_ref) {
    // For now, we use the name of the tensor as the external lookup name to
    // avoid ONNX protobuf changes.
    TORCH_INTERNAL_ASSERT(external_ref.value() == tensor_proto->name());
    TORCH_INTERNAL_ASSERT(
        raw_data_export_map_.count(external_ref.value()) == 0);
    raw_data_export_map_[external_ref.value()] = t;
    tensor_proto->set_raw_data("__EXTERNAL");
  } else {
    TORCH_INTERNAL_ASSERT(t.is_contiguous());
    size_t tensorSize = static_cast<size_t>(c10::multiply_integers(
        std::begin(tensor.sizes()), std::end(tensor.sizes())));
    if (use_external_data_format &&
        tensorSize > ParamSizeThresholdForExternalStorage) {
      TORCH_INTERNAL_ASSERT(!onnx_file_path.empty());
      TORCH_INTERNAL_ASSERT(tensor_proto->has_name());
      auto tensorName = GetExternalFileName(tensor_proto->name());
      CreateExternalFile(t, tensorName, onnx_file_path);
      onnx::StringStringEntryProto* location =
          tensor_proto->mutable_external_data()->Add();
      location->set_key("location");
      location->set_value(tensorName);
      tensor_proto->set_data_location(onnx::TensorProto_DataLocation_EXTERNAL);
    } else {
      // According to the ParseData function's comments in onnx, tensor data is
      // always little endian.
      tensor_proto->set_raw_data(get_little_endian_data(t));
    }
  }
}

void GraphEncoder::EncodeIntermediateValueInfo(
    onnx::GraphProto* graph_proto,
    const Value* v) {
  // The motivation is to encode ValueInfo for onnx local function nodes.
  auto n = v->node();
  if (n->kind().is_onnx() || n->kind().is_aten()) {
    // Encode value info only for nodes that are neither ONNX nor ATen ops.
    return;
  }
  if (n->owningGraph() != graph_.get()) {
    // Encode value info only for nodes in the main graph.
    return;
  }
  for (const auto* o : graph_->outputs()) {
    // Do not encode value info for graph outputs.
    if (o == v) {
      return;
    }
  }
  auto v_info_p = graph_proto->add_value_info();
  EncodeValueInfo(graph_proto, v_info_p, v);
}

} // namespace

std::string pretty_print_onnx(
    const std::shared_ptr<Graph>& graph,
    const std::map<std::string, at::Tensor>& initializers,
    int64_t onnx_opset_version,
    bool defer_weight_export,
    ::torch::onnx::OperatorExportTypes operator_export_type,
    bool google_printer,
    bool keep_initializers_as_inputs,
    const std::map<std::string, int>& custom_opsets,
    bool add_node_names) {
  auto graph_encoder = GraphEncoder(
      graph,
      onnx_opset_version,
      operator_export_type,
      initializers,
      std::unordered_map<
          std::string,
          std::unordered_map<int64_t, std::string>>{},
      defer_weight_export,
      true,
      keep_initializers_as_inputs,
      custom_opsets,
      add_node_names,
      false,
      std::string());
  if (google_printer) {
    return graph_encoder.get_model_proto()->DebugString();
  }
  return prettyPrint(*graph_encoder.get_model_proto());
}

std::tuple<
    std::shared_ptr<::ONNX_NAMESPACE::ModelProto>,
    RawDataExportMap,
    SymbolDimMap,
    bool,
    NodeNameMap>
export_onnx(
    const std::shared_ptr<Graph>& graph,
    const std::map<std::string, at::Tensor>& initializers,
    int64_t onnx_opset_version,
    const std::unordered_map<
        std::string,
        std::unordered_map<std::int64_t, std::string>>& dynamic_axes,
    bool defer_weight_export,
    ::torch::onnx::OperatorExportTypes operator_export_type,
    bool strip_doc_string,
    bool keep_initializers_as_inputs,
    const std::map<std::string, int>& custom_opsets,
    bool add_node_names,
    bool use_external_data_format,
    const std::string& onnx_file_path,
    const NodeAttrNameMap& node_attr_to_name) {
  auto graph_encoder = GraphEncoder(
      graph,
      onnx_opset_version,
      operator_export_type,
      initializers,
      dynamic_axes,
      defer_weight_export,
      strip_doc_string,
      keep_initializers_as_inputs,
      custom_opsets,
      add_node_names,
      use_external_data_format,
      onnx_file_path,
      node_attr_to_name);
  GRAPH_DEBUG("onnx proto:", prettyPrint(*graph_encoder.get_model_proto()));
  return std::make_tuple(
      graph_encoder.get_model_proto(),
      graph_encoder.get_raw_data_export_map(),
      graph_encoder.get_symbol_dim_param_map(),
      graph_encoder.get_use_external_data_format(),
      graph_encoder.get_onnx_node_names());
}

std::string serialize_model_proto_to_string(
    const std::shared_ptr<::ONNX_NAMESPACE::ModelProto>& model_proto) {
  return model_proto->SerializeAsString();
}

void check_onnx_proto(const std::string& proto_string) {
  onnx::ModelProto model;
  if (!ParseProtoFromBytes(&model, proto_string.c_str(), proto_string.size())) {
    throw std::runtime_error("Invalid ONNX proto string.");
  }
  // 1. Baseline check.
  // These two checks prevent a broken graph from being generated,
  // and error out the export if that happens.
  onnx::checker::check_model(model);
  onnx::shape_inference::InferShapes(model);
  // 2. Full check.
  // Apply the strict-mode shape/type inference check, which examines
  // whether the graph is a fully valid ONNX graph. Since some users
  // don't need a fully valid ONNX graph to run their model, we simply
  // report this information as a warning message if the check fails.
  try {
    auto* schema_registry = onnx::OpSchemaRegistry::Instance();
    onnx::ShapeInferenceOptions options{
        /*check_type_val=*/true,
        /*strict_mode_val=*/true};
    onnx::shape_inference::InferShapes(model, schema_registry, options);
  } catch (const onnx::InferenceError& ex) {
    TORCH_WARN(
        "The exported ONNX model failed ONNX shape inference. "
        "The model will not be executable by the ONNX Runtime. "
        "If this is unintended and you believe there is a bug, "
        "please report an issue at https://github.com/pytorch/pytorch/issues. "
        "Error reported by strict ONNX shape inference: ",
        ex.what());
  }
}

} // namespace torch::jit