mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-29 11:14:56 +08:00
Refactor PE so fusion specializations are configurable (#71650)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/71650 * Refactors PE so there is a current fusion strategy set, which will take in a vector of e.g. [(STATIC, 2), (DYNAMIC, 10)] which means fuse two static invocations then fuse 10 dynamic ones, then stop specializing. Test Plan: Imported from OSS Reviewed By: albanD Differential Revision: D33801501 Pulled By: eellison fbshipit-source-id: ebc7ac3c57e35a3b9bb15ab751f0aa1d25cc9bd5 (cherry picked from commit 8dd89088d3ceae800ea110d0b6949b759d4fe582)
This commit is contained in:
committed by
PyTorch MergeBot
parent
cf1833df70
commit
f1499d6c18
@ -123,6 +123,7 @@
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <utility>
|
||||
#include <torch/csrc/jit/runtime/profiling_graph_executor_impl.h>
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
@ -752,10 +753,31 @@ void initJITBindings(PyObject* module) {
|
||||
.def(
|
||||
"_jit_set_bailout_depth",
|
||||
[](size_t depth) {
|
||||
TORCH_WARN("Use _jit_set_fusion_strategy, bailout depth is deprecated. Setting to (STATIC, ", depth, ")");
|
||||
size_t old_depth = getBailoutDepth();
|
||||
getBailoutDepth() = depth;
|
||||
FusionStrategy strat = {{FusionBehavior::STATIC, depth}};
|
||||
setFusionStrategy(strat);
|
||||
return old_depth;
|
||||
})
|
||||
.def("_jit_set_fusion_strategy",
|
||||
[](std::vector<std::pair<std::string, size_t>> strategy) {
|
||||
FusionStrategy vec_conv;
|
||||
for (const auto& pair: strategy) {
|
||||
if (pair.first == "STATIC") {
|
||||
vec_conv.emplace_back(FusionBehavior::STATIC, pair.second);
|
||||
} else if (pair.first == "DYNAMIC") {
|
||||
vec_conv.emplace_back(FusionBehavior::DYNAMIC, pair.second);
|
||||
} else {
|
||||
TORCH_INTERNAL_ASSERT("FusionBehavior only supported 'STATIC' or 'DYNAMIC', got: ", pair.first);
|
||||
}
|
||||
}
|
||||
auto old_strategy = getFusionStrategy();
|
||||
auto strat = fmap(old_strategy, [](std::pair<FusionBehavior, size_t> behav) {
|
||||
return std::pair<std::string, size_t>(behav.first == FusionBehavior::STATIC ? "STATIC" : "DYNAMIC", behav.second);
|
||||
});
|
||||
setFusionStrategy(vec_conv);
|
||||
return strat;
|
||||
})
|
||||
.def(
|
||||
"_jit_set_inline_everything_mode",
|
||||
[](bool enabled) { getInlineEverythingMode() = enabled; })
|
||||
|
||||
Reference in New Issue
Block a user