[JIT][script][ONNX] ScriptModule ONNX export + ONNX export for control flow nodes (#6608)

* ScriptModule ONNX export

* ScriptModule ONNX export

* Export for control flow nodes

* Add pretty-print capability for ONNX export testing

* Update tests and handling of mutliple GraphProto names

* Maybe bugfix?

* factor out code from export and pretty print
This commit is contained in:
James Reed
2018-04-19 23:45:03 -07:00
committed by GitHub
parent 945cb0fabc
commit ef76e24f60
24 changed files with 723 additions and 83 deletions

View File

@ -40,25 +40,6 @@ void initPythonTracerBindings(PyObject* module_) {
ASSERT_UNEXPIRED("pop_scope");
s.pop_scope();
})
.def("export", [](TracingState& s, const std::vector<at::Tensor>& initializers,
int64_t onnx_opset_version, bool defer_weight_export=false) {
ASSERT_UNEXPIRED("export");
std::string graph;
RawDataExportMap export_map;
std::tie(graph, export_map) = ExportGraph(
s.graph, initializers, onnx_opset_version, defer_weight_export);
std::unordered_map<std::string, py::bytes> python_serialized_export_map;
for (auto& kv : export_map) {
auto t = kv.second;
size_t copy_bytes = t.type().elementSizeInBytes() * t.numel();
// TODO: this is an unecessary copy. In theory we can directly return
// the map from identifier to Tensor, but we need some API in Python
// to get raw `bytes` containing the raw tensor data.
python_serialized_export_map[kv.first] = py::bytes(static_cast<const char*>(t.data_ptr()), copy_bytes);
}
return std::make_tuple(
py::bytes(graph), python_serialized_export_map);
})
.def("set_graph", [](TracingState& s, std::shared_ptr<Graph> g) {
s.graph = g;
})