[JIT] Make aot autograd decompositions usable in JIT, add script for serializing the decompositions (#73938)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/73938

This is a first step in porting and making usable all of the decompositions defined in [functorch](https://github.com/pytorch/functorch/blob/main/functorch/_src/decompositions.py#L349) in core and in JIT as well as C++.

The decompositions are defined in python, scripted and inlined, and then serialized as C++ code which TorchScript can parse. The workflow is edit python decomposition file then run [tools/codegen/decompositions/gen_jit_decompositions.py](https://github.com/pytorch/pytorch/pull/73938/files#diff-6adef2116be233c3524e3b583e373ab0ffc9169beb6c1f6d96b5d0385e75afa1).

Decompositions are mapped to their corresponding aten schemas via the schema in their python def. This allows multiple decompositions for an overloaded op like `aten.var` (shown here in the example).

This is just a first PR, i'm sure there will be many follows ups such as:
- making these runnable in C++ with simple executor
- porting over more decompositions from AOT Autograd
- Using opinfos / more robust testing
- Categorizing decompositions
- Hooking in decompositions at various points of JIT execution

Test Plan: Imported from OSS

Reviewed By: gchanan

Differential Revision: D34938126

Pulled By: eellison

fbshipit-source-id: 9559a7cb731982e3a726f2f95af498b84fb09c13
(cherry picked from commit a4e0e748791e378e7e12a9dd0b63fb3c62dc1890)
This commit is contained in:
Elias Ellison
2022-03-29 11:32:31 -07:00
committed by PyTorch MergeBot
parent a6ed689173
commit aacdf291e0
11 changed files with 460 additions and 1 deletions

View File

@ -77,6 +77,7 @@
#include <torch/csrc/jit/python/script_init.h>
#include <torch/csrc/jit/runtime/argument_spec.h>
#include <torch/csrc/jit/runtime/autodiff.h>
#include <torch/csrc/jit/runtime/decomposition_registry.h>
#include <torch/csrc/jit/runtime/graph_executor.h>
#include <torch/csrc/jit/runtime/jit_exception.h>
#include <torch/csrc/jit/runtime/jit_trace.h>
@ -160,6 +161,15 @@ void initJITBindings(PyObject* module) {
}
return shapeComputeGraphForSchema(n->schema());
})
.def(
"_jit_decomposition_graph_for_node",
[](Node* n) -> c10::optional<std::shared_ptr<Graph>> {
if (!n->maybeSchema()) {
return c10::nullopt;
}
return DecompositionGraphForSchema(n->schema());
})
.def("_jit_pass_run_decompositions", RunDecompositions)
.def("_jit_pass_propagate_shapes_on_graph", PropagateShapesOnGraph)
.def(
"_jit_pass_propagate_shapes_on_graph_and_build_compute",