mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-05 00:14:54 +08:00
Introduce scopes during tracing (#3016)
This commit is contained in:
@ -19,7 +19,7 @@ namespace torch { namespace jit {
|
||||
|
||||
void initPythonTracerBindings(PyObject* module_) {
|
||||
auto m = py::handle(module_).cast<py::module>();
|
||||
py::class_<TracingState,std::shared_ptr<TracingState>>(m, "TracingState")
|
||||
py::class_<TracingState,std::shared_ptr<TracingState>>(m, "TracingState", py::dynamic_attr())
|
||||
// NB: no constructor; you have to get it from C++ code
|
||||
.def("__repr__", [](const TracingState& s) {
|
||||
std::ostringstream ss;
|
||||
@ -32,6 +32,14 @@ void initPythonTracerBindings(PyObject* module_) {
|
||||
ss << *s.graph;
|
||||
return ss.str();
|
||||
})
|
||||
.def("push_scope", [](TracingState& s, const std::string& scope_name) {
|
||||
ASSERT_UNEXPIRED("push_scope");
|
||||
s.push_scope(scope_name);
|
||||
})
|
||||
.def("pop_scope", [](TracingState& s) {
|
||||
ASSERT_UNEXPIRED("pop_scope");
|
||||
s.pop_scope();
|
||||
})
|
||||
.def("export", [](TracingState& s, const std::vector<at::Tensor>& initializers, int64_t onnx_opset_version) {
|
||||
ASSERT_UNEXPIRED("export");
|
||||
return py::bytes(ExportGraph(s.graph, initializers, onnx_opset_version));
|
||||
@ -52,6 +60,12 @@ void initPythonTracerBindings(PyObject* module_) {
|
||||
m.def("_tracer_exit", [](variable_list var_outputs) {
|
||||
tracer::exit(var_outputs);
|
||||
});
|
||||
m.def("_get_tracing_state", [](const variable_list& vars) {
|
||||
return getTracingState(vars);
|
||||
});
|
||||
m.def("_is_tracing", [](const variable_list& vars) {
|
||||
return isTracing(vars);
|
||||
});
|
||||
}
|
||||
|
||||
}} // namespace torch::jit
|
||||
|
||||
Reference in New Issue
Block a user