mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-24 23:54:56 +08:00
[JIT] Make aot autograd decompositions usable in JIT, add script for serializing the decompositions (#73938)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/73938 This is a first step in porting and making usable all of the decompositions defined in [functorch](https://github.com/pytorch/functorch/blob/main/functorch/_src/decompositions.py#L349) in core and in JIT as well as C++. The decompositions are defined in python, scripted and inlined, and then serialized as C++ code which TorchScript can parse. The workflow is edit python decomposition file then run [tools/codegen/decompositions/gen_jit_decompositions.py](https://github.com/pytorch/pytorch/pull/73938/files#diff-6adef2116be233c3524e3b583e373ab0ffc9169beb6c1f6d96b5d0385e75afa1). Decompositions are mapped to their corresponding aten schemas via the schema in their python def. This allows multiple decompositions for an overloaded op like `aten.var` (shown here in the example). This is just a first PR, i'm sure there will be many follows ups such as: - making these runnable in C++ with simple executor - porting over more decompositions from AOT Autograd - Using opinfos / more robust testing - Categorizing decompositions - Hooking in decompositions at various points of JIT execution Test Plan: Imported from OSS Reviewed By: gchanan Differential Revision: D34938126 Pulled By: eellison fbshipit-source-id: 9559a7cb731982e3a726f2f95af498b84fb09c13 (cherry picked from commit a4e0e748791e378e7e12a9dd0b63fb3c62dc1890)
This commit is contained in:
committed by
PyTorch MergeBot
parent
a6ed689173
commit
aacdf291e0
@ -77,6 +77,7 @@
|
||||
#include <torch/csrc/jit/python/script_init.h>
|
||||
#include <torch/csrc/jit/runtime/argument_spec.h>
|
||||
#include <torch/csrc/jit/runtime/autodiff.h>
|
||||
#include <torch/csrc/jit/runtime/decomposition_registry.h>
|
||||
#include <torch/csrc/jit/runtime/graph_executor.h>
|
||||
#include <torch/csrc/jit/runtime/jit_exception.h>
|
||||
#include <torch/csrc/jit/runtime/jit_trace.h>
|
||||
@ -160,6 +161,15 @@ void initJITBindings(PyObject* module) {
|
||||
}
|
||||
return shapeComputeGraphForSchema(n->schema());
|
||||
})
|
||||
.def(
|
||||
"_jit_decomposition_graph_for_node",
|
||||
[](Node* n) -> c10::optional<std::shared_ptr<Graph>> {
|
||||
if (!n->maybeSchema()) {
|
||||
return c10::nullopt;
|
||||
}
|
||||
return DecompositionGraphForSchema(n->schema());
|
||||
})
|
||||
.def("_jit_pass_run_decompositions", RunDecompositions)
|
||||
.def("_jit_pass_propagate_shapes_on_graph", PropagateShapesOnGraph)
|
||||
.def(
|
||||
"_jit_pass_propagate_shapes_on_graph_and_build_compute",
|
||||
|
||||
Reference in New Issue
Block a user