mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-06 00:54:56 +08:00
jit trace (#59949)
Summary:
Fixes #{issue number}
Pull Request resolved: https://github.com/pytorch/pytorch/pull/59949
Reviewed By: ZolotukhinM
Differential Revision: D31366787
Pulled By: Krovatkin
fbshipit-source-id: 798cbcd97e8ecfba984f98cd70214954be9309af
This commit is contained in:
committed by
Facebook GitHub Bot
parent
f1b5f1898b
commit
a7ebf76a15
@ -91,6 +91,7 @@
|
||||
#include <torch/csrc/jit/runtime/autodiff.h>
|
||||
#include <torch/csrc/jit/runtime/graph_executor.h>
|
||||
#include <torch/csrc/jit/runtime/jit_exception.h>
|
||||
#include <torch/csrc/jit/runtime/jit_trace.h>
|
||||
#include <torch/csrc/jit/runtime/operator.h>
|
||||
#include <torch/csrc/jit/runtime/print_handler.h>
|
||||
#include <torch/csrc/jit/runtime/static/init.h>
|
||||
@ -520,6 +521,22 @@ void initJITBindings(PyObject* module) {
|
||||
},
|
||||
py::doc(
|
||||
"Interpret a JIT graph with given inputs without running any optimization passes on it"))
|
||||
.def(
|
||||
"_jit_trace_graph",
|
||||
[](std::shared_ptr<Graph>& graph, const py::tuple& inputs) {
|
||||
Stack stack;
|
||||
stack.reserve(inputs.size()); // captures?
|
||||
for (auto& obj : inputs) {
|
||||
stack.push_back(toTypeInferredIValue(obj));
|
||||
}
|
||||
auto g_inputs = graph->inputs();
|
||||
for (const auto i : c10::irange(inputs.size())) {
|
||||
if (stack[i].isTensor()) {
|
||||
g_inputs[i]->setType(stack[i].type());
|
||||
}
|
||||
}
|
||||
return TraceGraph(graph, stack);
|
||||
})
|
||||
.def("_jit_pass_remove_expands", RemoveExpands)
|
||||
.def("_jit_pass_erase_number_types", EraseNumberTypes)
|
||||
.def("_jit_pass_inline_fork_wait", InlineForkWait)
|
||||
|
||||
Reference in New Issue
Block a user