Refactor PE so fusion specializations are configurable (#71650)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71650

*

Refactors PE so there is a current fusion strategy set, which will take in a vector of e.g. [(STATIC, 2), (DYNAMIC, 10)] which means fuse two static invocations then fuse 10 dynamic ones, then stop specializing.

Test Plan: Imported from OSS

Reviewed By: albanD

Differential Revision: D33801501

Pulled By: eellison

fbshipit-source-id: ebc7ac3c57e35a3b9bb15ab751f0aa1d25cc9bd5
(cherry picked from commit 8dd89088d3ceae800ea110d0b6949b759d4fe582)
This commit is contained in:
Elias Ellison
2022-02-01 11:00:10 -08:00
committed by PyTorch MergeBot
parent cf1833df70
commit f1499d6c18
6 changed files with 101 additions and 22 deletions

View File

@ -123,6 +123,7 @@
#include <string>
#include <tuple>
#include <utility>
#include <torch/csrc/jit/runtime/profiling_graph_executor_impl.h>
namespace torch {
namespace jit {
@ -752,10 +753,31 @@ void initJITBindings(PyObject* module) {
.def(
"_jit_set_bailout_depth",
[](size_t depth) {
TORCH_WARN("Use _jit_set_fusion_strategy, bailout depth is deprecated. Setting to (STATIC, ", depth, ")");
size_t old_depth = getBailoutDepth();
getBailoutDepth() = depth;
FusionStrategy strat = {{FusionBehavior::STATIC, depth}};
setFusionStrategy(strat);
return old_depth;
})
.def("_jit_set_fusion_strategy",
[](std::vector<std::pair<std::string, size_t>> strategy) {
FusionStrategy vec_conv;
for (const auto& pair: strategy) {
if (pair.first == "STATIC") {
vec_conv.emplace_back(FusionBehavior::STATIC, pair.second);
} else if (pair.first == "DYNAMIC") {
vec_conv.emplace_back(FusionBehavior::DYNAMIC, pair.second);
} else {
TORCH_INTERNAL_ASSERT("FusionBehavior only supported 'STATIC' or 'DYNAMIC', got: ", pair.first);
}
}
auto old_strategy = getFusionStrategy();
auto strat = fmap(old_strategy, [](std::pair<FusionBehavior, size_t> behav) {
return std::pair<std::string, size_t>(behav.first == FusionBehavior::STATIC ? "STATIC" : "DYNAMIC", behav.second);
});
setFusionStrategy(vec_conv);
return strat;
})
.def(
"_jit_set_inline_everything_mode",
[](bool enabled) { getInlineEverythingMode() = enabled; })