[JIT] Register decomp reland

Reland of https://github.com/pytorch/pytorch/pull/76252
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76397
Approved by: https://github.com/davidberard98
This commit is contained in:
Elias Ellison
2022-04-26 23:17:18 +00:00
committed by PyTorch MergeBot
parent 4d1d1b3179
commit 81b9cb741c
8 changed files with 81 additions and 16 deletions

View File

@ -1,6 +1,7 @@
#include <torch/csrc/utils/pybind.h>
#include <torch/csrc/utils/python_arg_parser.h>
#include <ATen/core/operator_name.h>
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/backends/backend_init.h>
#include <torch/csrc/jit/codegen/cuda/interface.h>
@ -168,7 +169,7 @@ void initJITBindings(PyObject* module) {
if (!n->maybeSchema()) {
return c10::nullopt;
}
return DecompositionGraphForSchema(n->schema());
return GetDecomposition(n->schema());
})
.def("_jit_pass_run_decompositions", RunDecompositions)
// using Node* here instead of Schema because looking up the schema
@ -184,6 +185,16 @@ void initJITBindings(PyObject* module) {
TORCH_INTERNAL_ASSERT(false, "Expected schema", n);
}
})
.def(
"_jit_register_decomposition_for_schema",
[](const FunctionSchema& s, std::shared_ptr<Graph>& graph) {
// because this is invoked by python, the function schema *
// becomes different, and we need to find and reuse the
// one that is used for caching
auto op =
findOperatorFor(c10::OperatorName(s.name(), s.overload_name()));
RegisterDecomposition(op->schema(), graph);
})
.def("_jit_pass_propagate_shapes_on_graph", PropagateShapesOnGraph)
.def(
"_jit_pass_propagate_shapes_on_graph_and_build_compute",