mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-26 08:34:52 +08:00
[JIT] Register decomp reland
Reland of https://github.com/pytorch/pytorch/pull/76252 Pull Request resolved: https://github.com/pytorch/pytorch/pull/76397 Approved by: https://github.com/davidberard98
This commit is contained in:
committed by
PyTorch MergeBot
parent
4d1d1b3179
commit
81b9cb741c
@ -1,6 +1,7 @@
|
||||
#include <torch/csrc/utils/pybind.h>
|
||||
#include <torch/csrc/utils/python_arg_parser.h>
|
||||
|
||||
#include <ATen/core/operator_name.h>
|
||||
#include <torch/csrc/jit/api/module.h>
|
||||
#include <torch/csrc/jit/backends/backend_init.h>
|
||||
#include <torch/csrc/jit/codegen/cuda/interface.h>
|
||||
@ -168,7 +169,7 @@ void initJITBindings(PyObject* module) {
|
||||
if (!n->maybeSchema()) {
|
||||
return c10::nullopt;
|
||||
}
|
||||
return DecompositionGraphForSchema(n->schema());
|
||||
return GetDecomposition(n->schema());
|
||||
})
|
||||
.def("_jit_pass_run_decompositions", RunDecompositions)
|
||||
// using Node* here instead of Schema because looking up the schema
|
||||
@ -184,6 +185,16 @@ void initJITBindings(PyObject* module) {
|
||||
TORCH_INTERNAL_ASSERT(false, "Expected schema", n);
|
||||
}
|
||||
})
|
||||
.def(
|
||||
"_jit_register_decomposition_for_schema",
|
||||
[](const FunctionSchema& s, std::shared_ptr<Graph>& graph) {
|
||||
// because this is invoked by python, the function schema *
|
||||
// becomes different, and we need to find and reuse the
|
||||
// one that is used for caching
|
||||
auto op =
|
||||
findOperatorFor(c10::OperatorName(s.name(), s.overload_name()));
|
||||
RegisterDecomposition(op->schema(), graph);
|
||||
})
|
||||
.def("_jit_pass_propagate_shapes_on_graph", PropagateShapesOnGraph)
|
||||
.def(
|
||||
"_jit_pass_propagate_shapes_on_graph_and_build_compute",
|
||||
|
||||
Reference in New Issue
Block a user