mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 13:44:15 +08:00
Changes: switch JIT passes to take a graph rather than a TracingState; add a pybind11 binding for the ONNX pass from a graph; fix the canonicalize pass; address review comments; switch ToONNX to explicitly return a new graph; use optimize_graph instead of optimize_trace.
96 lines
3.5 KiB
C++
96 lines
3.5 KiB
C++
#include <Python.h>
|
|
|
|
#include "torch/csrc/jit/python_tracer.h"
|
|
#include "torch/csrc/jit/tracer.h"
|
|
#include "torch/csrc/assertions.h"
|
|
#include "torch/csrc/jit/export.h"
|
|
#include "torch/csrc/jit/pybind.h"
|
|
#include "torch/csrc/utils/python_strings.h"
|
|
|
|
#include <sstream>
|
|
|
|
using namespace torch::autograd;
|
|
using namespace torch::jit;
|
|
using namespace torch::jit::tracer;
|
|
|
|
namespace torch { namespace jit {
|
|
|
|
#define ASSERT_UNEXPIRED(METHOD_NAME) if (s.is_expired()) throw std::runtime_error("calling " METHOD_NAME " on an expired trace")
|
|
|
|
// Registers the JIT tracer's Python bindings on the given module:
// the TracingState class plus the module-level _tracer_* / _*_trace helpers
// used by torch.jit to drive a trace from Python.
void initPythonTracerBindings(PyObject* module_) {
  auto m = py::handle(module_).cast<py::module>();
  py::class_<TracingState,std::shared_ptr<TracingState>>(m, "TracingState", py::dynamic_attr())
    // NB: no constructor; you have to get it from C++ code
    .def("__repr__", [](const TracingState& s) {
      // Identity-based repr; unlike __str__ it is safe on an expired trace.
      std::ostringstream ss;
      ss << "<TracingState " << static_cast<const void*>(&s) << ">";
      return ss.str();
    })
    .def("__str__", [](const TracingState& s) -> std::string {
      // Guard: dereferencing s.graph on an expired trace would be invalid.
      if (s.is_expired()) return "<expired TracingState>";
      std::ostringstream ss;
      ss << *s.graph;
      return ss.str();
    })
    .def("push_scope", [](TracingState& s, const std::string& scope_name) {
      ASSERT_UNEXPIRED("push_scope");
      s.push_scope(scope_name);
    })
    .def("pop_scope", [](TracingState& s) {
      ASSERT_UNEXPIRED("pop_scope");
      s.pop_scope();
    })
    // Serializes the traced graph via ExportGraph. Returns a tuple of
    // (serialized graph bytes, {identifier: raw tensor bytes}); the map is
    // populated only when weight export is deferred.
    .def("export", [](TracingState& s, const std::vector<at::Tensor>& initializers,
                      int64_t onnx_opset_version, bool defer_weight_export) {
      ASSERT_UNEXPIRED("export");
      std::string graph;
      RawDataExportMap export_map;
      std::tie(graph, export_map) = ExportGraph(
        s.graph, initializers, onnx_opset_version, defer_weight_export);
      std::unordered_map<std::string, py::bytes> python_serialized_export_map;
      for (auto& kv : export_map) {
        auto t = kv.second;
        size_t copy_bytes = t.type().elementSizeInBytes() * t.numel();
        // TODO: this is an unnecessary copy. In theory we can directly return
        // the map from identifier to Tensor, but we need some API in Python
        // to get raw `bytes` containing the raw tensor data.
        python_serialized_export_map[kv.first] = py::bytes(static_cast<const char*>(t.data_ptr()), copy_bytes);
      }
      return std::make_tuple(
        py::bytes(graph), python_serialized_export_map);
    },
    // BUGFIX: the original spelled the default as a C++ lambda default
    // argument (`bool defer_weight_export=false`), which pybind11 silently
    // ignores — the default must be attached via py::arg to be visible to
    // Python callers.
    py::arg("initializers"), py::arg("onnx_opset_version"),
    py::arg("defer_weight_export") = false)
    .def("set_graph", [](TracingState& s, std::shared_ptr<Graph> g) {
      // Move instead of copying: avoids an extra shared_ptr refcount bump.
      s.graph = std::move(g);
    })
    .def("graph", [](TracingState& s) {
      return s.graph;
    })
    .def_property_readonly("is_expired", [](TracingState& s) {
      return s.is_expired();
    })
    .def_property_readonly("is_complete", [](TracingState& s) {
      return s.is_complete();
    });

  m.def("_tracer_enter", [](variable_list trace_inputs, std::size_t num_backwards) {
    // NOTE(review): enter() is handed num_backwards + 1 — presumably the
    // forward stage counts as one trace stage; confirm against tracer.h.
    return tracer::enter(std::move(trace_inputs), num_backwards + 1);
  });
  m.def("_tracer_exit", [](variable_list var_outputs) {
    tracer::exit(var_outputs);
  });
  m.def("_get_tracing_state", [](const variable_list& vars) {
    return getTracingState(vars);
  });
  m.def("_get_value_trace", [](std::shared_ptr<TracingState>& state, const Variable& var) {
    return getValueTrace(state, var);
  });
  m.def("_set_value_trace", [](std::shared_ptr<TracingState>& state, const Variable& var, Value* value) {
    return setValueTrace(state, var, value);
  });
  m.def("_is_tracing", [](const variable_list& vars) {
    return isTracingVar(vars);
  });
}
|
|
|
|
}} // namespace torch::jit
|