diff --git a/test/test_jit_fuser_te.py b/test/test_jit_fuser_te.py
index 06661d31a90c..49f5ab305de8 100644
--- a/test/test_jit_fuser_te.py
+++ b/test/test_jit_fuser_te.py
@@ -58,13 +58,12 @@ def texpr_reductions_enabled():
         torch._C._jit_set_texpr_reductions_enabled(old)
 
 @contextlib.contextmanager
-def texpr_dynamic_enabled():
-    old = torch._C._jit_texpr_dynamic_shape_enabled()
-    torch._C._jit_set_texpr_dynamic_shape_enabled(True)
+def texpr_enable_strategy(strategy):
+    old = torch._C._jit_set_fusion_strategy(strategy)
     try:
         yield
     finally:
-        torch._C._jit_set_texpr_dynamic_shape_enabled(old)
+        torch._C._jit_set_fusion_strategy(old)
 
 @contextlib.contextmanager
 def inline_fusion_groups():
@@ -1921,7 +1920,7 @@ class TestTEFuser(JitTestCase):
            size = [2, 3, 4, 5]
            size[i] = 1
            inp = torch.rand(size).to(memory_format=torch.channels_last)
-           with texpr_dynamic_enabled():
+           with texpr_enable_strategy([("DYNAMIC", 20)]):
                foo_s = torch.jit.trace(eager, (inp, inp))
                for _ in range(3):
                    out = foo_s(inp, inp)
@@ -1931,6 +1930,23 @@ class TestTEFuser(JitTestCase):
            g = torch.jit.last_executed_optimized_graph()
            FileCheck().check("TensorExpr").run(g)
 
+    def test_exhaust_specializations(self):
+        with texpr_enable_strategy([("STATIC", 1)]):
+            @torch.jit.script
+            def foo(x):
+                return x + x + x
+
+            for _ in range(3):
+                foo(torch.rand([2, 2]))
+
+            for _ in range(3):
+                foo(torch.rand([4, 4, 4]))
+
+            g = torch.jit.last_executed_optimized_graph()
+            torch._C._jit_pass_inline(g)
+
+            FileCheck().check_count("TensorExpr", 2, exactly=True).run(g)
+
     def test_unsqueeze_var_dim(self):
         def eager(x, y, z: int):
             return x * torch.unsqueeze(y, dim=z)
@@ -2094,7 +2110,7 @@ class TestTEFuser(JitTestCase):
            lambda n: R(n, n + 1, n + 2, n + 3).to(memory_format=torch.channels_last),
        )
 
-       with texpr_dynamic_enabled():
+       with texpr_enable_strategy([("DYNAMIC", 20)]):
            def foo(x, y, z):
                return torch.sigmoid(torch.tanh(x))
 
diff --git a/torch/csrc/jit/python/init.cpp b/torch/csrc/jit/python/init.cpp
index 925ab21c390c..aefb164b7fff 100644
--- a/torch/csrc/jit/python/init.cpp
+++ b/torch/csrc/jit/python/init.cpp
@@ -123,6 +123,7 @@
 #include
 #include
 #include
+#include <torch/csrc/jit/runtime/profiling_graph_executor_impl.h>
 
 namespace torch {
 namespace jit {
@@ -752,10 +753,31 @@ void initJITBindings(PyObject* module) {
       .def(
           "_jit_set_bailout_depth",
           [](size_t depth) {
+            TORCH_WARN("Use _jit_set_fusion_strategy, bailout depth is deprecated. Setting to (STATIC, ", depth, ")");
             size_t old_depth = getBailoutDepth();
-            getBailoutDepth() = depth;
+            FusionStrategy strat = {{FusionBehavior::STATIC, depth}};
+            setFusionStrategy(strat);
             return old_depth;
           })
+      .def(
+          "_jit_set_fusion_strategy",
+          [](std::vector<std::pair<std::string, size_t>> strategy) {
+            FusionStrategy vec_conv;
+            for (const auto& pair : strategy) {
+              if (pair.first == "STATIC") {
+                vec_conv.emplace_back(FusionBehavior::STATIC, pair.second);
+              } else if (pair.first == "DYNAMIC") {
+                vec_conv.emplace_back(FusionBehavior::DYNAMIC, pair.second);
+              } else {
+                TORCH_INTERNAL_ASSERT(
+                    false,
+                    "FusionBehavior only supports 'STATIC' or 'DYNAMIC', got: ",
+                    pair.first);
+              }
+            }
+            auto old_strategy = getFusionStrategy();
+            auto strat = fmap(old_strategy, [](std::pair<FusionBehavior, size_t> behav) {
+              return std::pair<std::string, size_t>(
+                  behav.first == FusionBehavior::STATIC ?
"STATIC" : "DYNAMIC", behav.second); + }); + setFusionStrategy(vec_conv); + return strat; + }) .def( "_jit_set_inline_everything_mode", [](bool enabled) { getInlineEverythingMode() = enabled; }) diff --git a/torch/csrc/jit/runtime/graph_executor.cpp b/torch/csrc/jit/runtime/graph_executor.cpp index 1f6cf315cfeb..b24fa49461e7 100644 --- a/torch/csrc/jit/runtime/graph_executor.cpp +++ b/torch/csrc/jit/runtime/graph_executor.cpp @@ -773,7 +773,7 @@ c10::intrusive_ptr GraphExecutor::runAsync( } size_t GraphExecutor::getDefaultNumBailOuts() { - return getProfilingMode() ? getBailoutDepth().load() : 0; + return getProfilingMode() ? getBailoutDepth() : 0; } const ExecutionPlan& GraphExecutor::getPlanFor( diff --git a/torch/csrc/jit/runtime/graph_executor.h b/torch/csrc/jit/runtime/graph_executor.h index 9610d193a1e5..381fcc41062b 100644 --- a/torch/csrc/jit/runtime/graph_executor.h +++ b/torch/csrc/jit/runtime/graph_executor.h @@ -104,7 +104,7 @@ TORCH_API std::shared_ptr lastExecutedOptimizedGraph(); TORCH_API std::atomic& getProfilingMode(); TORCH_API std::atomic& getExecutorMode(); TORCH_API std::atomic& getNumProfiledRuns(); -TORCH_API std::atomic& getBailoutDepth(); +TORCH_API size_t getBailoutDepth(); TORCH_API bool IsNewExecutorEnabled(); struct TORCH_API GraphOptimizerEnabledGuard { diff --git a/torch/csrc/jit/runtime/profiling_graph_executor_impl.cpp b/torch/csrc/jit/runtime/profiling_graph_executor_impl.cpp index 4848ce981fb8..6c87b80a0ae7 100644 --- a/torch/csrc/jit/runtime/profiling_graph_executor_impl.cpp +++ b/torch/csrc/jit/runtime/profiling_graph_executor_impl.cpp @@ -68,6 +68,22 @@ static std::atomic executor_mode{true}; static std::atomic profiling_mode{true}; #endif +static std::mutex fusion_strategy_lock; +static FusionStrategy fusion_strategy = {{FusionBehavior::STATIC, 20}}; + +FusionStrategy getFusionStrategy() { + std::lock_guard guard(fusion_strategy_lock); + FusionStrategy strategy = fusion_strategy; + return strategy; +} + +FusionStrategy setFusionStrategy(FusionStrategy& strategy) { + std::lock_guard guard(fusion_strategy_lock); + auto old_strategy = fusion_strategy; + fusion_strategy = strategy; + return old_strategy; +} + static std::atomic num_profiled_runs{kDefaultNumProfiledRuns}; static std::atomic bailout_depth{kDefaultBailoutDepth}; @@ -88,13 +104,13 @@ std::atomic& getNumProfiledRuns() { return num_profiled_runs; } -std::atomic& getBailoutDepth() { +size_t getBailoutDepth() { // Initialize bailout_depth from command-line flag. - static const size_t init = []() { - return bailout_depth = FLAGS_torch_jit_bailout_depth; - }(); - (void)init; // Silence clang-tidy. 
-  return bailout_depth;
+  size_t depth = 0;
+  for (const auto& pair : getFusionStrategy()) {
+    depth += pair.second;
+  }
+  return depth;
 }
 
 static bool needsGradientInProfilingMode(Block* b) {
@@ -350,7 +366,21 @@ void runPreAutodiffPassPipeline(std::shared_ptr<Graph>& graph) {
       "After CheckInplace (end of runPreAutodiffPassPipeline)\n", *graph);
 }
 
-void runNoGradOptimizations(std::shared_ptr<Graph>& graph) {
+FusionBehavior getCurrentBehavior(size_t remaining_depth) {
+  size_t curr_depth = 0;
+  auto curr_strategy = getFusionStrategy();
+  for (int i = static_cast<int>(curr_strategy.size()) - 1; i >= 0; i--) {
+    curr_depth += curr_strategy[i].second;
+    if (remaining_depth <= curr_depth) {
+      return curr_strategy[i].first;
+    }
+  }
+  // should never get here
+  TORCH_WARN("Strategy changed mid-invocation, NYI");
+  return FusionBehavior::STATIC;
+}
+
+void runNoGradOptimizations(std::shared_ptr<Graph>& graph, size_t remaining_bailout_depth) {
   GRAPH_DEBUG(
       "After customPostPasses (beginning of runNoGradOptimizations)\n", *graph);
   // runNondiffOptimization
@@ -383,7 +413,7 @@ void runNoGradOptimizations(std::shared_ptr<Graph>& graph, size_t remaining_bailout_depth) {
   BatchMM(graph);
   GRAPH_DEBUG("After BatchMM, before Fusion\n", *graph);
   auto min_size = getFusionGroupInlining() ? 2 : 1;
-  auto dyn_shapes = tensorExprDynamicShapeFusionEnabled();
+  bool dyn_shapes = getCurrentBehavior(remaining_bailout_depth) == FusionBehavior::DYNAMIC;
   FuseTensorExprs(graph, min_size, /*composed_op*/ false, dyn_shapes);
   GRAPH_DEBUG(
       "After Fusion, before RemoveTensorTypeSpecializations\n", *graph);
@@ -412,7 +442,7 @@ void runNoGradOptimizations(std::shared_ptr<Graph>& graph, size_t remaining_bailout_depth) {
 }
 
 void ProfilingGraphExecutorImpl::runProfilingOptimizations(
-    std::shared_ptr<Graph>& copy) {
+    std::shared_ptr<Graph>& copy, size_t remaining_bailout_depth) {
   GRAPH_DEBUG("Before runProfilingOptimizations:\n", *copy);
   if (!getGraphExecutorOptimize()) {
     runNooptPassPipeline(copy);
@@ -460,7 +490,7 @@ void ProfilingGraphExecutorImpl::runProfilingOptimizations(
     GRAPH_DEBUG(
         "After InlineAutodiffSubgraphs and Removing Profiling Nodes\n", *copy);
   } else {
-    runNoGradOptimizations(copy);
+    runNoGradOptimizations(copy, remaining_bailout_depth);
   }
   EliminateDeadCode(copy);
   GRAPH_DEBUG("After runProfilingOptimizations:\n", *copy);
@@ -604,7 +634,7 @@ const ExecutionPlan& ProfilingGraphExecutorImpl::getOptimizedPlanFor(
   auto copy = pr_->graph()->copy();
   ProfilingRecord::removeProfileCounter(copy->block());
 
-  runProfilingOptimizations(copy);
+  runProfilingOptimizations(copy, *remaining_bailout_depth_);
   // replaces a fallback graph inserted by
   // specialize_autogradzero if one exists
   replaceFallbackGraphWithFallbackFunction(copy->block());
diff --git a/torch/csrc/jit/runtime/profiling_graph_executor_impl.h b/torch/csrc/jit/runtime/profiling_graph_executor_impl.h
index 88a315794389..b9c62377e49b 100644
--- a/torch/csrc/jit/runtime/profiling_graph_executor_impl.h
+++ b/torch/csrc/jit/runtime/profiling_graph_executor_impl.h
@@ -4,6 +4,17 @@ namespace torch {
 namespace jit {
 
+enum class FusionBehavior { STATIC, DYNAMIC };
+
+using FusionStrategy = std::vector<std::pair<FusionBehavior, size_t>>;
+// E.g. {(STATIC, 2), (DYNAMIC, 10)}: fuse with static shapes twice, then
+// fall back to 10 dynamic fusions, then stop compiling new fusion groups.
+TORCH_API FusionStrategy getFusionStrategy();
+// returns previous strategy
+TORCH_API FusionStrategy setFusionStrategy(FusionStrategy& fusion_strategy);
+
 struct TORCH_API ProfilingGraphExecutorImpl : public GraphExecutorImplBase {
   ProfilingGraphExecutorImpl(
       const std::shared_ptr<Graph>& graph,
@@ -22,7 +33,7 @@ struct TORCH_API ProfilingGraphExecutorImpl : public GraphExecutorImplBase {
     optimized_plan_.reset();
     // prevent memory leaks
     fallback_functions_.clear();
-    remaining_bailout_depth_.reset();
+    remaining_bailout_depth_.reset();
   }
 
   bool isOptimized() const override {
@@ -34,7 +45,7 @@ struct TORCH_API ProfilingGraphExecutorImpl : public GraphExecutorImplBase {
       Stack& stack,
       size_t remaining_bailout_depth);
   void runProfilingInsensitiveOptimizations(std::shared_ptr<Graph>& graph);
-  void runProfilingOptimizations(std::shared_ptr<Graph>& graph);
+  void runProfilingOptimizations(std::shared_ptr<Graph>& graph, size_t remaining_depth);
   void replaceFallbackGraphWithFallbackFunction(Block* b);
   std::unique_ptr<ProfilingRecord> pr_;
  c10::optional
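
---

For reviewers who want to try the new binding: a minimal usage sketch, assuming a build that includes this patch. The strategy is a list of `(behavior, depth)` pairs consumed in order as specializations are exhausted, and the setter returns the previous strategy in the same format (which is what the `texpr_enable_strategy` context manager in the test changes relies on). The function `f` here is just an illustrative example, not from the patch.

```python
import torch

# Allow 2 static-shape specializations, then up to 10 dynamic-shape
# fusions, then stop compiling new fusion groups.
old = torch._C._jit_set_fusion_strategy([("STATIC", 2), ("DYNAMIC", 10)])

@torch.jit.script
def f(x):
    return torch.sigmoid(x) * x

# Profile and specialize under the new strategy.
for _ in range(3):
    f(torch.rand(4, 4))

# The setter returned the previous strategy as a list of
# ("STATIC"|"DYNAMIC", depth) tuples, so it can be restored.
torch._C._jit_set_fusion_strategy(old)
```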
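The depth accounting in `getCurrentBehavior` is the subtlest part of the patch: `remaining_bailout_depth` counts down from the total of all strategy entries, so walking the strategy from the back recovers the entry currently in effect. A Python sketch of the same logic (not part of the patch, for review only):

```python
# Mirrors getCurrentBehavior() in profiling_graph_executor_impl.cpp.
def get_current_behavior(strategy, remaining_depth):
    curr_depth = 0
    for behavior, depth in reversed(strategy):
        curr_depth += depth
        if remaining_depth <= curr_depth:
            return behavior
    return "STATIC"  # strategy changed mid-invocation; conservative fallback

strategy = [("STATIC", 2), ("DYNAMIC", 10)]
assert get_current_behavior(strategy, 12) == "STATIC"   # first invocation
assert get_current_behavior(strategy, 10) == "DYNAMIC"  # static budget spent
assert get_current_behavior(strategy, 1) == "DYNAMIC"   # last dynamic slot
```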