#include "Python.h"
|
|
#include "torch/csrc/jit/graph_executor.h"
|
|
#include "torch/csrc/jit/ir.h"
|
|
#include "torch/csrc/jit/argument_spec.h"
|
|
#include "torch/csrc/jit/autodiff.h"
|
|
#include "torch/csrc/jit/interpreter.h"
|
|
#include "torch/csrc/autograd/grad_mode.h"
|
|
#include "torch/csrc/jit/passes/create_autodiff_subgraphs.h"
|
|
#include "torch/csrc/jit/passes/shape_analysis.h"
|
|
#include "torch/csrc/jit/passes/dead_code_elimination.h"
|
|
#include "torch/csrc/jit/passes/common_subexpression_elimination.h"
|
|
#include "torch/csrc/jit/passes/peephole.h"
|
|
#include "torch/csrc/jit/passes/graph_fuser.h"
|
|
#include "torch/csrc/jit/passes/inplace_check.h"
|
|
#include "torch/csrc/jit/passes/batch_mm.h"
|
|
|
|
#include "torch/csrc/autograd/function.h"
|
|
#include "torch/csrc/autograd/edge.h"
|
|
|
|
#include <unordered_map>
|
|
|
|
namespace torch { namespace jit {

namespace {

using tensor_list = std::vector<at::Tensor>;
using Variable = autograd::Variable;
using autograd::variable_list;

// This type is used by ExecutionPlan to run its Gradient if one is specified.
// It holds a list of inputs captured by ExecutionPlan that it concatenates
// with the inputs it receives to form the full set of inputs to the graph.
// See struct Gradient for a description of how the derivative graph
// is constructed and which variables are captured.
struct ExecutionPlanAutogradFunction : public autograd::Function {
  ExecutionPlanAutogradFunction(GraphExecutor graph, size_t capture_size)
  : graph(std::move(graph)) {
    captures.reserve(capture_size);
  }
  virtual variable_list apply(const variable_list& inputs) override {
    // TODO: expensive copies here to convert to/from tensor_list
    // TODO: because inputs is passed by const reference there is no
    // way to release tensors incrementally as this runs
    variable_tensor_list all_inputs;
    all_inputs.reserve(captures.size() + inputs.size());
    all_inputs.insert(all_inputs.end(), inputs.begin(), inputs.end());
    for(auto & sv : captures) {
      all_inputs.push_back(sv.unpack(this->shared_from_this()));
    }
    auto tensors = graph.run(std::move(all_inputs));
    // TODO: another copy that needs to be removed
    return autograd::variable_list(tensors.begin(), tensors.end());
  }
private:
  friend struct ExecutionPlan;
  GraphExecutor graph;
  std::vector<autograd::SavedVariable> captures;
};

// An optimized way of executing the subgraph computed directly on
// tensors rather than Variables.
// This will unwrap Variables, run the plan, and re-wrap them.
// It can optionally also have a gradient which is hooked up
// to the output Variables if present.
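// run() dispatches to runWithGrad() when a Gradient is attached; otherwise it
// simply unwraps the input Variables, interprets the Code, and re-wraps the
// outputs as Variables that do not require grad.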
struct ExecutionPlan {
  ExecutionPlan(std::shared_ptr<Graph> & graph)
  : f(graph) {}
  ExecutionPlan(std::shared_ptr<Graph> & graph, Gradient grad)
  : f(graph), grad(std::move(grad)), grad_executor(this->grad.df) {}

  variable_tensor_list run(variable_tensor_list inputs) const {
    if(grad) {
      return runWithGrad(std::move(inputs));
    }
    // TODO: interpreter needs to accept moved inputs
    // and delete incrementally
    tensor_list outputs;
    InterpreterState(f).runOneStage(unwrapVariables(std::move(inputs)), outputs);
    return wrapTensors(std::move(outputs));
  }
private:
  // inplace to avoid allocations
  tensor_list unwrapVariables(variable_tensor_list && list) const {
    for(auto & v : list) {
      v = v.defined() ? autograd::as_variable_ref(v).data() : at::Tensor();
    }
    return std::move(list);
  }
  // inplace to avoid allocations
  variable_tensor_list wrapTensors(tensor_list && list) const {
    for(auto & v : list) {
      v = autograd::make_variable(v, /*requires_grad=*/false);
    }
    return variable_tensor_list(std::move(list));
  }
  // Capture (save) inputs that would be required to subsequently run backwards
  void captureInputs(ExecutionPlanAutogradFunction & grad_fn, variable_tensor_list & inputs) const {
    for(auto offset : grad.df_input_captured_inputs) {
      grad_fn.captures.emplace_back(autograd::as_variable_ref(inputs[offset]), false);
    }
  }
  void captureOutputs(ExecutionPlanAutogradFunction & grad_fn, variable_tensor_list & outputs) const {
    for(auto offset : grad.df_input_captured_outputs) {
      grad_fn.captures.emplace_back(autograd::as_variable_ref(outputs[offset]), true);
    }
  }

  variable_tensor_list runWithGrad(variable_tensor_list&& inputs) const {
    auto grad_fn = std::make_shared<ExecutionPlanAutogradFunction>(grad_executor,
      grad.df_input_captured_inputs.size() + grad.df_input_captured_outputs.size());
    // hook up the outputs of df to the gradient functions of the inputs that require
    // gradients
    for(auto idx : grad.df_output_vjps) {
      auto & v = autograd::as_variable_ref(inputs[idx]);
      // TODO: this kind of thing is _way_ too low level for the public API of Variable.
      // Why do I have to care here whether v has a grad_fn or grad accumulator?
      // Why do I have to care here about output_nr? I just want to say
      // grad_fn->setOutputTo(i, v.input_port());
      // TODO(psag): remove above rant once Function API is improved.
      grad_fn->next_functions.push_back(v.gradient_edge());
    }
    captureInputs(*grad_fn, inputs);

    tensor_list outputs_;
    InterpreterState(f).runOneStage(unwrapVariables(std::move(inputs)), outputs_);
    variable_tensor_list outputs = wrapTensors(std::move(outputs_));

    // hook up the gradients for the output tensors that require gradients
    // to the inputs of our gradient function df
    // TODO - XXX - if any output is the same tensor multiple times, views have to be
    // set up here. We need to refactor autograd until it is safe for
    // tensors to be constructed without all the viewing infrastructure.
    // this is currently intentionally not done here so we can get an idea of our
    // perf before introducing overhead for correctness
    for(auto idx : grad.df_input_vjps) {
      // Note: we have to set this up in place, or we have to throw away and
      // reallocate variables that were already created in wrapTensors. We
      // should add an API for this.
      auto& output = autograd::as_variable_ref(outputs[idx]);
      output.set_gradient_edge(autograd::Edge(grad_fn, grad_fn->num_inputs++));
      output.set_requires_grad(true);
    }
    captureOutputs(*grad_fn, outputs);
    // drop the temporary outputs so that we return the same number of
    // outputs as if we were not also calculating gradient
    outputs.erase(outputs.begin() + grad.f_real_outputs, outputs.end());
    return outputs;
  }
  Code f;
  // description of gradient as a graph
  Gradient grad; // if(grad) is false when this is unused
  // executor for df, including code caches
  GraphExecutor grad_executor;
};

} // anonymous namespace

// A Graph can be created via tracing, or via a language-based frontend;
// GraphExecutor runs it. It can run the same graph with inputs of many different
// sizes and different requires_grad states, and handles specializations for each situation.
// GraphExecutor is completely unaware of tracing or module parameters, to keep the
// tracing concerns separated.
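// Typical usage, as a rough sketch (the Graph is assumed to come from the tracer
// or the script frontend; the names below are illustrative, not part of this file):
//
//   std::shared_ptr<Graph> graph = ...;   // built elsewhere
//   GraphExecutor executor(graph, /*optimize=*/true);
//   variable_tensor_list outputs = executor.run(std::move(inputs));
//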
struct GraphExecutorImpl {

  GraphExecutorImpl(std::shared_ptr<Graph> graph, bool optimize, bool symbolically_differentiable)
  : graph(std::move(graph))
  , optimize(optimize)
  , symbolically_differentiable(symbolically_differentiable) {}
  GraphExecutorImpl(std::shared_ptr<Graph> graph, bool optimize)
  : graph(std::move(graph))
  , optimize(optimize)
  , symbolically_differentiable(isDifferentiable(*this->graph)) {}

  // entry point where execution begins
  variable_tensor_list run(variable_tensor_list inputs) {
    // this is the fallback pathway: it is used when we are not optimizing, or
    // when we cannot symbolically differentiate but a gradient is needed
    if(!optimize || (!symbolically_differentiable && needsGradient(inputs))) {
      auto & fb = getOrCreateAutogradFallback();
      InterpreterState state(fb);
      tensor_list outputs;
      state.runOneStage(std::move(inputs), outputs);
      // note: we never unwrapped inputs, because we want autograd to record the trace
      return variable_tensor_list(std::move(outputs));
    }

    // either we can symbolically differentiate, or we do not need a gradient.
    // go down the route where we treat the inputs as tensors
    // and fully optimize
    auto & implementation = getOrCompile(inputs);
    return implementation.run(std::move(inputs));
  }

private:

  static bool needsGradient(const variable_tensor_list & inputs) {
    if (!autograd::GradMode::is_enabled()) {
      return false;
    }
    for (const auto & tensor : inputs) {
      if(tensor.defined() && static_cast<const Variable&>(tensor).requires_grad())
        return true;
    }
    return false;
  }
  static bool isDifferentiable(Graph & g) {
    for(auto n : g.nodes()) {
      if(!jit::isDifferentiable(n))
        return false;
    }
    return true;
  }

  void runOptimization(std::shared_ptr<Graph> & graph, bool graphMustSupportVariables) {

    // these optimizations must run in the presence of variables
    // and when shape information is not statically known.

    EliminateDeadCode(graph);
    CheckInplace(graph);
    EliminateCommonSubexpression(graph);

    if (!graphMustSupportVariables) {
      // These optimizations can introduce operators like FusionGroup that
      // do not work on variables.
      // They also may assume that concrete sizes/strides are available.

      // TODO: create peephole optimizations that are safe to run
      // when we are using variables, and when we do not know sizes.
      PeepholeOptimize(graph);
      // TODO: remove mandatory size checking in BatchMM, otherwise
      // it works fine on variables.
      BatchMM(graph);
      FuseGraph(graph);
    }
  }
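  // The fallback below never unwraps its inputs: the interpreter runs directly
  // on Variables so that autograd records the trace. When optimize is set,
  // CreateAutodiffSubgraphs wraps up the fully differentiable subgraphs, so only
  // the outer graph has to go through autograd (see the comment on
  // autograd_fallback near the end of this struct).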
const Code & getOrCreateAutogradFallback() {
|
|
std::lock_guard<std::mutex> lock(compile_mutex);
|
|
if(autograd_fallback) {
|
|
return autograd_fallback;
|
|
}
|
|
auto graph_ = graph->copy();
|
|
if(optimize) {
|
|
CreateAutodiffSubgraphs(*graph_);
|
|
runOptimization(graph_, /*graphMustSupportVariables=*/true);
|
|
}
|
|
autograd_fallback = Code(graph_);
|
|
return autograd_fallback;
|
|
}
|
|
const ExecutionPlan & getOrCompile(const variable_tensor_list & inputs) {
|
|
// outside lock guard, to minimize the time holding the lock on the fast path
|
|
// ArgumentSpec even computes its hashCode here.
|
|
ArgumentSpec spec(autograd::GradMode::is_enabled(), inputs);
|
|
{
|
|
std::lock_guard<std::mutex> lock(compile_mutex);
|
|
auto it = plan_cache.find(spec);
|
|
if(it != plan_cache.end())
|
|
return it->second;
|
|
auto plan = compileSpec(spec);
|
|
auto r = plan_cache.emplace(std::move(spec), std::move(plan));
|
|
return r.first->second;
|
|
}
|
|
}
|
|
  bool needsGradient(const ArgumentSpec & spec) {
    for(size_t i = 0; i < spec.size(); ++i) {
      if(spec.tensorInfo(i).requires_grad())
        return true;
    }
    return false;
  }

  // replace ReplaceIfUndef(v, replacement) nodes that consume graph inputs:
  // the node's output becomes 'v' when the input is defined,
  // and 'replacement' when it is not.
  // Note: this is a very limited pass. It looks at undefined inputs,
  // and cleans up ReplaceIfUndef nodes inserted by autodiff.
void specializeUndef(Graph & g, const ArgumentSpec & spec) {
|
|
for(size_t i = 0; i < spec.size(); i++) {
|
|
std::vector<Value*> to_replace;
|
|
// do not edit in place, since it invalidates uses iterator
|
|
for(auto u : g.inputs()[i]->uses()) {
|
|
if(u.user->kind() == kReplaceIfUndef) {
|
|
to_replace.push_back(u.user->output());
|
|
}
|
|
}
|
|
for(auto v : to_replace) {
|
|
// if it is defined, then we replace with 'v' if not,
|
|
// we replace with 'replacement' which is normally just a zero tensor
|
|
int idx = spec.tensorInfo(i).defined() ? 0 : 1;
|
|
v->replaceAllUsesWith(v->node()->inputs()[idx]);
|
|
v->node()->destroy();
|
|
}
|
|
}
|
|
}
|
|
// a + 0 -> a
|
|
// 0 + a -> a
|
|
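  // Note that propagateZeros only fires when the add node's alpha attribute is
  // 1.0, so that both patterns above are exact identities
  // (add(a, b, alpha) computes a + alpha * b).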
void propagateZeros(Graph & g) {
|
|
for(auto it = g.nodes().begin(); it != g.nodes().end(); ++it) {
|
|
if(it->kind() == kadd && at::Scalar(it->t(kalpha)).toDouble() == 1.0) {
|
|
if(isZero(it->inputs()[0])) {
|
|
it->output()->replaceAllUsesWith(it->inputs()[1]);
|
|
it.destroyCurrent();
|
|
} else if(isZero(it->inputs()[1])) {
|
|
it->output()->replaceAllUsesWith(it->inputs()[0]);
|
|
it.destroyCurrent();
|
|
}
|
|
}
|
|
}
|
|
}
|
|
void specializeToSpec(std::shared_ptr<Graph> g, const ArgumentSpec & spec) {
|
|
|
|
// The following passes are specialized to clean up after autograd
|
|
// decisions to insert/remove undefs nodes and to work before
|
|
// we propagate input shapes.
|
|
|
|
// clean up replaceIfUndef nodes
|
|
specializeUndef(*g, spec);
|
|
// clean up additions resulting from nodes that were in fact undefined
|
|
propagateZeros(*g);
|
|
// clean up dead constants from specialization
|
|
EliminateDeadCode(g);
|
|
// calculate all input shapes
|
|
PropagateInputShapes(*g, spec);
|
|
}
|
|
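  // Build an ExecutionPlan specialized to one ArgumentSpec: copy and specialize
  // the base graph, and if no input requires grad simply optimize it; otherwise
  // symbolically differentiate it (which requires symbolically_differentiable)
  // and attach the resulting Gradient so that run() can hook it up to autograd.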
ExecutionPlan compileSpec(const ArgumentSpec & spec) {
|
|
auto graph_ = graph->copy();
|
|
specializeToSpec(graph_, spec);
|
|
if(!needsGradient(spec)) {
|
|
runOptimization(graph_, /*graphMustSupportVariables=*/false);
|
|
return ExecutionPlan(graph_);
|
|
}
|
|
JIT_ASSERT(symbolically_differentiable);
|
|
|
|
std::vector<bool> requires_grads;
|
|
requires_grads.reserve(spec.size());
|
|
for(size_t i = 0; i < spec.size(); i++)
|
|
requires_grads.push_back(spec.tensorInfo(i).requires_grad());
|
|
|
|
Gradient gradient = differentiate(graph_, requires_grads);
|
|
graph_ = gradient.f;
|
|
runOptimization(graph_, /*graphMustSupportVariables=*/false);
|
|
return ExecutionPlan(graph_, std::move(gradient));
|
|
}
|
|
  // the unoptimized starting graph
  // this is never mutated
  std::shared_ptr<Graph> graph;

  // true - do everything we can to make this graph run fast
  // false - do not modify the graph at all and just use the interpreter
  // to run the graph. Useful for debugging correctness issues in the implementation
  bool optimize;

  // GraphExecutor optimizes more aggressively when we _know_ the graph will be
  // symbolically differentiable.
  bool symbolically_differentiable;

  // when this graph has some parts that are not symbolically_differentiable,
  // but some input does require a derivative, we create and use autograd_fallback,
  // which wraps up the fully differentiable subgraphs, and then runs the outer
  // graph through autograd.
  // Since we can't optimize black box functions anyway, there is only one fallback path,
  // and it must work on all sizes (so no optimizations that inspect sizes can run on it)
  Code autograd_fallback;

  // optimizable code paths, used when we can differentiate or when no derivative is needed
  // Spec describes input conditions, Plan describes how to execute them.
  std::unordered_map<ArgumentSpec, ExecutionPlan> plan_cache;

  // GraphExecutor can be accessed from multiple threads, so
  // any time we are checking or updating the autograd_fallback or
  // plan_cache, we must hold the compile mutex.
  // along the fast path (no compilation) code should
  // hold this for as little time as possible.
  std::mutex compile_mutex;
};

GraphExecutor::GraphExecutor(std::shared_ptr<Graph> graph, bool optimize)
|
|
: pImpl(new GraphExecutorImpl(std::move(graph), optimize)) {}
|
|
|
|
GraphExecutor::GraphExecutor(std::shared_ptr<Graph> graph, bool optimize, bool symbolically_differentiable)
|
|
: pImpl(new GraphExecutorImpl(std::move(graph), optimize, symbolically_differentiable)) {}
|
|
|
|
variable_tensor_list GraphExecutor::run(variable_tensor_list && inputs) {
|
|
return pImpl->run(std::move(inputs));
|
|
}
|
|
|
|
}}
|