Introduce scopes during tracing (#3016)

This commit is contained in:
Luca Antiga
2017-12-04 18:19:06 +01:00
committed by Adam Paszke
parent 7ddcb91c7f
commit 4eb8e12765
11 changed files with 242 additions and 11 deletions

View File

@ -19,7 +19,7 @@ namespace torch { namespace jit {
void initPythonTracerBindings(PyObject* module_) {
auto m = py::handle(module_).cast<py::module>();
py::class_<TracingState,std::shared_ptr<TracingState>>(m, "TracingState")
py::class_<TracingState,std::shared_ptr<TracingState>>(m, "TracingState", py::dynamic_attr())
// NB: no constructor; you have to get it from C++ code
.def("__repr__", [](const TracingState& s) {
std::ostringstream ss;
@ -32,6 +32,14 @@ void initPythonTracerBindings(PyObject* module_) {
ss << *s.graph;
return ss.str();
})
.def("push_scope", [](TracingState& s, const std::string& scope_name) {
ASSERT_UNEXPIRED("push_scope");
s.push_scope(scope_name);
})
.def("pop_scope", [](TracingState& s) {
ASSERT_UNEXPIRED("pop_scope");
s.pop_scope();
})
.def("export", [](TracingState& s, const std::vector<at::Tensor>& initializers, int64_t onnx_opset_version) {
ASSERT_UNEXPIRED("export");
return py::bytes(ExportGraph(s.graph, initializers, onnx_opset_version));
@ -52,6 +60,12 @@ void initPythonTracerBindings(PyObject* module_) {
m.def("_tracer_exit", [](variable_list var_outputs) {
tracer::exit(var_outputs);
});
m.def("_get_tracing_state", [](const variable_list& vars) {
return getTracingState(vars);
});
m.def("_is_tracing", [](const variable_list& vars) {
return isTracing(vars);
});
}
}} // namespace torch::jit