pytorch/torch/csrc/jit/runtime/graph_executor.h
Elias Ellison f1499d6c18 Refactor PE so fusion specializations are configurable (#71650)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71650

Refactors the profiling executor (PE) so that there is a configurable fusion strategy: it takes a vector such as [(STATIC, 2), (DYNAMIC, 10)], meaning compile two static fusion specializations, then ten dynamic ones, then stop specializing.
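
As a sketch of the resulting configuration API (the FusionBehavior, FusionStrategy, and setFusionStrategy names come from this work but do not appear in the header excerpt below, so treat them as assumptions):

#include <torch/csrc/jit/runtime/graph_executor.h>

void configureFusion() {
  // Compile up to 2 static specializations, then up to 10 dynamic ones,
  // then stop specializing.
  torch::jit::FusionStrategy strategy = {
      {torch::jit::FusionBehavior::STATIC, 2},
      {torch::jit::FusionBehavior::DYNAMIC, 10}};
  torch::jit::setFusionStrategy(strategy); // returns the previous strategy
}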

Test Plan: Imported from OSS

Reviewed By: albanD

Differential Revision: D33801501

Pulled By: eellison

fbshipit-source-id: ebc7ac3c57e35a3b9bb15ab751f0aa1d25cc9bd5
(cherry picked from commit 8dd89088d3ceae800ea110d0b6949b759d4fe582)
2022-02-01 19:07:02 +00:00

139 lines
4.2 KiB
C++

#pragma once

#include <atomic>
#include <memory>

#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/python/update_graph_executor_opt.h>
#include <torch/csrc/jit/runtime/argument_spec.h>
#include <torch/csrc/jit/runtime/interpreter.h>
#include <torch/csrc/jit/runtime/variable_tensor_list.h>

C10_DECLARE_bool(torch_jit_enable_new_executor);

namespace torch {
namespace jit {

struct GraphExecutorState;
struct Code;

struct ExecutionPlan {
  ExecutionPlan() = default;
  ExecutionPlan(
      std::shared_ptr<Graph> graph,
      std::string function_name,
      size_t remaining_bailout_depth = 0)
      : code(graph, std::move(function_name), remaining_bailout_depth),
        graph(std::move(graph)) {}

  operator bool() const {
    return static_cast<bool>(graph);
  }

  Code code;
  std::shared_ptr<Graph> graph;
};

// Notice that these structs don't manage the lifetime of their members.
// They are only valid right after you call getDebugState() and should never
// be used again once another GraphExecutor function is called.
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
struct GraphExecutorState {
  const Graph* graph = nullptr;
  ExecutionPlan fallback; // XXX: members of this field are optional
  std::unordered_map<ArgumentSpec, ExecutionPlan> execution_plans;
};
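
// Example (illustrative; `executor` and `stack` are hypothetical locals):
//   GraphExecutorState state = executor.getDebugState();
//   size_t n = state.execution_plans.size(); // safe: copies the count out
//   executor.run(stack); // `state` must not be used after this call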

struct TORCH_API EnableProfilingGuard {
  EnableProfilingGuard();
  ~EnableProfilingGuard();

 private:
  bool old_executor_mode = false;
  bool old_profiling_mode = false;
};

struct GraphExecutorImplBase;

struct TORCH_API GraphExecutor {
  GraphExecutor() = default;
  GraphExecutor(const std::shared_ptr<Graph>& graph, std::string function_name);

  void run(Stack& inputs);
  c10::intrusive_ptr<Future> runAsync(
      Stack& stack,
      TaskLauncher taskLauncher = at::launch);

  // `remaining_bailout_depth` stands for the maximum number of profiled and
  // specialized recompilations allowed for the current `GraphExecutor`. If
  // `remaining_bailout_depth` is equal to 0, `GraphExecutor` won't perform
  // any profiling or specialization; this is equivalent to SIMPLE_EXECUTOR
  // mode. If `remaining_bailout_depth` is greater than 0, `GraphExecutor`
  // will profile and specialize its input graph based on the profiled
  // information. Whenever a bailout check fails or is triggered, a new
  // `GraphExecutor` is created, with its `remaining_bailout_depth` reduced
  // by 1.
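  //
  // Example (illustrative): with remaining_bailout_depth = 2, the first
  // failed bailout check builds a child `GraphExecutor` with depth 1; a
  // second failure builds one with depth 0, which then runs without any
  // further profiling or specialization (SIMPLE_EXECUTOR behavior).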
  const ExecutionPlan& getPlanFor(
      Stack& inputs,
      size_t remaining_bailout_depth);
  GraphExecutorState getDebugState();
  static size_t getDefaultNumBailOuts();

  void debugFlushCompilationCache();

  bool isOptimized() const;

 private:
  std::shared_ptr<GraphExecutorImplBase> pImpl;
};

TORCH_API Node* replaceBlockWithFallbackGraph(
    Block* b,
    ArrayRef<Value*> inputs);

// These passes need to run before a graph can be passed to the interpreter,
// regardless of whether sizes have been specialized or not.
TORCH_API void runRequiredPasses(const std::shared_ptr<Graph>& g);

TORCH_API void debugSetFusionGroupInlining(bool state);
TORCH_API bool getFusionGroupInlining();

TORCH_API void debugSetAutodiffSubgraphInlining(bool state);
TORCH_API std::shared_ptr<Graph> lastExecutedOptimizedGraph();

TORCH_API std::atomic<bool>& getProfilingMode();
TORCH_API std::atomic<bool>& getExecutorMode();
TORCH_API std::atomic<size_t>& getNumProfiledRuns();
TORCH_API size_t getBailoutDepth();
TORCH_API bool IsNewExecutorEnabled();

struct TORCH_API GraphOptimizerEnabledGuard {
  GraphOptimizerEnabledGuard(bool state)
      : old_state_(getGraphExecutorOptimize()) {
    setGraphExecutorOptimize(state);
  }

  ~GraphOptimizerEnabledGuard() {
    setGraphExecutorOptimize(old_state_);
  }

  bool old_state_;
};
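
// Example usage (illustrative):
//   {
//     GraphOptimizerEnabledGuard no_opt(/*state=*/false);
//     // graph executor optimization is disabled within this scope
//   } // previous setting restored by the destructor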

namespace detail {

GraphExecutor* getGradExecutor(Operation& op);

GraphExecutor* getDifferentiableGraphOpExecutor(Operation& op);

// For debugging we expose a way to get the last actually-run graph (see
// lastExecutedOptimizedGraph above). Previous approaches allowed querying
// the GraphExecutor for what graph it would run in certain circumstances
// (graphFor), but this is fragile because we sometimes change how these
// decisions are made. This interface still allows our tests to look at
// optimized graphs, but with less plumbing.

} // namespace detail
} // namespace jit
} // namespace torch
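
For orientation, a minimal end-to-end sketch of driving this header, assuming a PyTorch development build; the IR string, the function name "add_fn", and the use of parseIR (from torch/csrc/jit/ir/irparser.h) are illustrative, not taken from this file:

#include <ATen/ATen.h>
#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/runtime/graph_executor.h>

int main() {
  // Build a trivial graph computing c = a + b.
  auto graph = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(
      R"IR(
graph(%a : Tensor, %b : Tensor):
  %one : int = prim::Constant[value=1]()
  %c : Tensor = aten::add(%a, %b, %one)
  return (%c))IR",
      graph.get());

  // The executor pops its inputs from the stack and pushes its outputs.
  torch::jit::GraphExecutor executor(graph, "add_fn");
  torch::jit::Stack stack;
  stack.emplace_back(at::ones({2, 2}));
  stack.emplace_back(at::ones({2, 2}));
  executor.run(stack);

  at::Tensor result = stack.back().toTensor(); // 2x2 tensor filled with 2s
  return 0;
}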