mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 13:44:15 +08:00
[JIT][script][ONNX] ScriptModule ONNX export + ONNX export for control flow nodes (#6608)
* ScriptModule ONNX export * ScriptModule ONNX export * Export for control flow nodes * Add pretty-print capability for ONNX export testing * Update tests and handling of mutliple GraphProto names * Maybe bugfix? * factor out code from export and pretty print
This commit is contained in:
@ -40,25 +40,6 @@ void initPythonTracerBindings(PyObject* module_) {
|
||||
ASSERT_UNEXPIRED("pop_scope");
|
||||
s.pop_scope();
|
||||
})
|
||||
.def("export", [](TracingState& s, const std::vector<at::Tensor>& initializers,
|
||||
int64_t onnx_opset_version, bool defer_weight_export=false) {
|
||||
ASSERT_UNEXPIRED("export");
|
||||
std::string graph;
|
||||
RawDataExportMap export_map;
|
||||
std::tie(graph, export_map) = ExportGraph(
|
||||
s.graph, initializers, onnx_opset_version, defer_weight_export);
|
||||
std::unordered_map<std::string, py::bytes> python_serialized_export_map;
|
||||
for (auto& kv : export_map) {
|
||||
auto t = kv.second;
|
||||
size_t copy_bytes = t.type().elementSizeInBytes() * t.numel();
|
||||
// TODO: this is an unecessary copy. In theory we can directly return
|
||||
// the map from identifier to Tensor, but we need some API in Python
|
||||
// to get raw `bytes` containing the raw tensor data.
|
||||
python_serialized_export_map[kv.first] = py::bytes(static_cast<const char*>(t.data_ptr()), copy_bytes);
|
||||
}
|
||||
return std::make_tuple(
|
||||
py::bytes(graph), python_serialized_export_map);
|
||||
})
|
||||
.def("set_graph", [](TracingState& s, std::shared_ptr<Graph> g) {
|
||||
s.graph = g;
|
||||
})
|
||||
|
Reference in New Issue
Block a user