mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/129055 Approved by: https://github.com/r-barnes
242 lines
7.3 KiB
C++
242 lines
7.3 KiB
C++
#include <c10/util/irange.h>
|
|
#include <torch/csrc/jit/serialization/onnx.h>
|
|
#include <torch/csrc/onnx/onnx.h>
|
|
|
|
#include <sstream>
|
|
#include <string>
|
|
|
|
namespace torch::jit {
|
|
|
|
namespace {
|
|
// Short alias for the namespace holding the ONNX proto message types
// (TensorProto, GraphProto, ModelProto, ...) that the dump() overloads print.
namespace onnx = ::ONNX_NAMESPACE;
|
|
|
|
// Pretty printing for ONNX
//
// Indentation helpers used by every dump() overload: one nesting level is
// rendered as `indent_multiplier` copies of `indent_char`.
constexpr char indent_char = ' ';
constexpr size_t indent_multiplier = 2;

// Returns the whitespace prefix for nesting level `indent`.
std::string idt(size_t indent) {
  const size_t width = indent_multiplier * indent;
  std::string prefix;
  prefix.assign(width, indent_char);
  return prefix;
}

// Returns a line break followed by the whitespace prefix for level `indent`.
std::string nlidt(size_t indent) {
  std::string line_break(1, '\n');
  line_break += idt(indent);
  return line_break;
}
|
|
|
|
// Writes only the dimensions of `tensor`, as "TensorProto shape: [d0 d1 ...]".
// The tensor's name, element type and data are not printed.
void dump(const onnx::TensorProto& tensor, std::ostream& stream) {
  stream << "TensorProto shape: [";
  const auto rank = tensor.dims_size();
  for (const auto d : c10::irange(rank)) {
    // Single space between consecutive dims, none after the last one.
    if (d != 0) {
      stream << " ";
    }
    stream << tensor.dims(d);
  }
  stream << "]";
}
|
|
|
|
// Writes the dimensions of `shape` separated by single spaces. A dimension
// with a concrete dim_value is printed as that number; any other dimension
// (presumably a symbolic dim_param or an unset dim) is printed as "?".
void dump(const onnx::TensorShapeProto& shape, std::ostream& stream) {
  const auto rank = shape.dim_size();
  for (const auto idx : c10::irange(rank)) {
    if (idx > 0) {
      stream << " ";
    }
    const auto& dim = shape.dim(idx);
    if (!dim.has_dim_value()) {
      stream << "?";
      continue;
    }
    stream << dim.dim_value();
  }
}
|
|
|
|
// Writes "Tensor dtype: <elem_type>, Tensor dims: <shape>". The elem_type is
// streamed as-is (presumably the integer ONNX data-type code, not a name).
// Either part prints "None." when the corresponding field is absent.
void dump(const onnx::TypeProto_Tensor& tensor_type, std::ostream& stream) {
  stream << "Tensor dtype: ";
  if (!tensor_type.has_elem_type()) {
    stream << "None.";
  } else {
    stream << tensor_type.elem_type();
  }

  stream << ", " << "Tensor dims: ";
  if (!tensor_type.has_shape()) {
    stream << "None.";
  } else {
    dump(tensor_type.shape(), stream);
  }
}
|
|
|
|
// Forward declaration: the Optional and Sequence overloads below recurse into
// their element TypeProto, and the TypeProto overload (defined after them)
// dispatches back to those overloads, so it must be declared first.
void dump(const onnx::TypeProto& type, std::ostream& stream);
|
|
|
|
// Writes "Optional<...>", recursing into the element type. When no element
// type is set the result is "Optional<None>".
void dump(const onnx::TypeProto_Optional& optional_type, std::ostream& stream) {
  stream << "Optional<";
  if (!optional_type.has_elem_type()) {
    stream << "None";
  } else {
    dump(optional_type.elem_type(), stream);
  }
  stream << ">";
}
|
|
|
|
// Writes "Sequence<...>", recursing into the element type. When no element
// type is set the result is "Sequence<None>".
void dump(const onnx::TypeProto_Sequence& sequence_type, std::ostream& stream) {
  stream << "Sequence<";
  if (!sequence_type.has_elem_type()) {
    stream << "None";
  } else {
    dump(sequence_type.elem_type(), stream);
  }
  stream << ">";
}
|
|
|
|
// Dispatches on the kind of type held by `type`. Tensor, sequence and optional
// types go to their dedicated overloads (checked in that order); any other
// kind, or an empty TypeProto, prints "None".
void dump(const onnx::TypeProto& type, std::ostream& stream) {
  if (type.has_tensor_type()) {
    dump(type.tensor_type(), stream);
    return;
  }
  if (type.has_sequence_type()) {
    dump(type.sequence_type(), stream);
    return;
  }
  if (type.has_optional_type()) {
    dump(type.optional_type(), stream);
    return;
  }
  stream << "None";
}
|
|
|
|
// Writes a value info as {name: "<name>", type:<type>}. There is no space
// between "type:" and the dumped type.
void dump(const onnx::ValueInfoProto& value_info, std::ostream& stream) {
  const auto& name = value_info.name();
  stream << "{name: \"" << name << "\", type:";
  dump(value_info.type(), stream);
  stream << "}";
}
|
|
|
|
// Forward declaration: attributes may hold subgraphs, so the AttributeProto
// overload below calls the GraphProto overload that is defined further down.
void dump(const onnx::GraphProto& graph, std::ostream& stream, size_t indent);
|
|
|
|
void dump(
|
|
const onnx::AttributeProto& attr,
|
|
std::ostream& stream,
|
|
size_t indent) {
|
|
stream << "{ name: '" << attr.name() << "', type: ";
|
|
if (attr.has_f()) {
|
|
stream << "float, value: " << attr.f();
|
|
} else if (attr.has_i()) {
|
|
stream << "int, value: " << attr.i();
|
|
} else if (attr.has_s()) {
|
|
stream << "string, value: '" << attr.s() << "'";
|
|
} else if (attr.has_g()) {
|
|
stream << "graph, value:\n";
|
|
dump(attr.g(), stream, indent + 1);
|
|
stream << nlidt(indent);
|
|
} else if (attr.has_t()) {
|
|
stream << "tensor, value:";
|
|
dump(attr.t(), stream);
|
|
} else if (attr.floats_size()) {
|
|
stream << "floats, values: [";
|
|
for (const auto i : c10::irange(attr.floats_size())) {
|
|
stream << attr.floats(i) << (i == attr.floats_size() - 1 ? "" : " ");
|
|
}
|
|
stream << "]";
|
|
} else if (attr.ints_size()) {
|
|
stream << "ints, values: [";
|
|
for (const auto i : c10::irange(attr.ints_size())) {
|
|
stream << attr.ints(i) << (i == attr.ints_size() - 1 ? "" : " ");
|
|
}
|
|
stream << "]";
|
|
} else if (attr.strings_size()) {
|
|
stream << "strings, values: [";
|
|
for (const auto i : c10::irange(attr.strings_size())) {
|
|
stream << "'" << attr.strings(i) << "'"
|
|
<< (i == attr.strings_size() - 1 ? "" : " ");
|
|
}
|
|
stream << "]";
|
|
} else if (attr.tensors_size()) {
|
|
stream << "tensors, values: [";
|
|
for (auto& t : attr.tensors()) {
|
|
dump(t, stream);
|
|
}
|
|
stream << "]";
|
|
} else if (attr.graphs_size()) {
|
|
stream << "graphs, values: [";
|
|
for (auto& g : attr.graphs()) {
|
|
dump(g, stream, indent + 1);
|
|
}
|
|
stream << "]";
|
|
} else {
|
|
stream << "UNKNOWN";
|
|
}
|
|
stream << "}";
|
|
}
|
|
|
|
// Writes a node as
//   Node {type: "<op_type>", inputs: [a,b], outputs: [c], attributes: [...]}
// Input/output names and attributes are separated by a bare ",". Attributes
// are dumped one nesting level deeper than `indent`, which only matters for
// attributes that hold subgraphs.
void dump(const onnx::NodeProto& node, std::ostream& stream, size_t indent) {
  stream << "Node {type: \"" << node.op_type() << "\", inputs: [";
  const auto num_inputs = node.input_size();
  for (const auto k : c10::irange(num_inputs)) {
    if (k > 0) {
      stream << ",";
    }
    stream << node.input(k);
  }

  stream << "], outputs: [";
  const auto num_outputs = node.output_size();
  for (const auto k : c10::irange(num_outputs)) {
    if (k > 0) {
      stream << ",";
    }
    stream << node.output(k);
  }

  stream << "], attributes: [";
  const auto num_attrs = node.attribute_size();
  for (const auto k : c10::irange(num_attrs)) {
    if (k > 0) {
      stream << ",";
    }
    dump(node.attribute(k), stream, indent + 1);
  }
  stream << "]}";
}
|
|
|
|
// Writes a graph as a "GraphProto { ... }" block starting at nesting level
// `indent`. The fields name, inputs, outputs, value_infos and initializers
// each start on their own line one level in. Nodes are listed one per line two
// levels in. The output always ends with "}\n".
void dump(const onnx::GraphProto& graph, std::ostream& stream, size_t indent) {
  // Break used before each top-level field of the graph.
  const std::string field_break = nlidt(indent + 1);
  // Break used before each node entry.
  const std::string node_break = nlidt(indent + 2);

  stream << idt(indent) << "GraphProto {" << field_break << "name: \""
         << graph.name() << "\"" << field_break << "inputs: [";
  for (const auto k : c10::irange(graph.input_size())) {
    if (k > 0) {
      stream << ",";
    }
    dump(graph.input(k), stream);
  }

  stream << "]" << field_break << "outputs: [";
  for (const auto k : c10::irange(graph.output_size())) {
    if (k > 0) {
      stream << ",";
    }
    dump(graph.output(k), stream);
  }

  stream << "]" << field_break << "value_infos: [";
  for (const auto k : c10::irange(graph.value_info_size())) {
    if (k > 0) {
      stream << ",";
    }
    dump(graph.value_info(k), stream);
  }

  stream << "]" << field_break << "initializers: [";
  for (const auto k : c10::irange(graph.initializer_size())) {
    if (k > 0) {
      stream << ",";
    }
    dump(graph.initializer(k), stream);
  }

  stream << "]" << field_break << "nodes: [" << node_break;
  for (const auto k : c10::irange(graph.node_size())) {
    if (k > 0) {
      stream << "," << node_break;
    }
    dump(graph.node(k), stream, indent + 2);
  }

  stream << field_break << "]\n" << idt(indent) << "}\n";
}
|
|
|
|
void dump(
|
|
const onnx::OperatorSetIdProto& operator_set_id,
|
|
std::ostream& stream) {
|
|
stream << "OperatorSetIdProto { domain: " << operator_set_id.domain()
|
|
<< ", version: " << operator_set_id.version() << "}";
|
|
}
|
|
|
|
// Writes the top-level "ModelProto { ... }" block at nesting level `indent`.
// The fields producer_name, domain and doc_string each go on their own line.
// They are followed by the graph (if present) and the opset imports (if any).
// The output always ends with "}\n".
void dump(const onnx::ModelProto& model, std::ostream& stream, size_t indent) {
  stream << idt(indent) << "ModelProto {" << nlidt(indent + 1)
         << "producer_name: \"" << model.producer_name() << "\""
         << nlidt(indent + 1) << "domain: \"" << model.domain() << "\""
         << nlidt(indent + 1) << "doc_string: \"" << model.doc_string() << "\"";
  if (model.has_graph()) {
    stream << nlidt(indent + 1) << "graph:\n";
    // The graph dump ends with "}\n", so the next field starts on a fresh
    // line.
    dump(model.graph(), stream, indent + 2);
  } else {
    // Without a graph nothing terminates the doc_string line. End it here so
    // that "opset_import" or the closing brace is not glued onto it.
    stream << "\n";
  }
  if (model.opset_import_size()) {
    stream << idt(indent + 1) << "opset_import: [";
    for (const auto i : c10::irange(model.opset_import_size())) {
      // Separate entries; previously consecutive imports were concatenated.
      if (i > 0) {
        stream << ", ";
      }
      dump(model.opset_import(i), stream);
    }
    stream << "],\n";
  }
  stream << idt(indent) << "}\n";
}
|
|
|
|
} // namespace
|
|
|
|
// Renders `model` as an indented, multi-line, human-readable string, starting
// at nesting level 0. This is a custom layout, not protobuf's text format.
std::string prettyPrint(const ::ONNX_NAMESPACE::ModelProto& model) {
  std::ostringstream out;
  constexpr size_t kTopLevel = 0;
  dump(model, out, kTopLevel);
  return out.str();
}
|
|
|
|
} // namespace torch::jit
|