// pytorch/torch/csrc/jit/passes/create_autodiff_subgraphs.cpp

#include <torch/csrc/jit/passes/create_autodiff_subgraphs.h>
#include <c10/util/Exception.h>
#include <torch/csrc/jit/ir/alias_analysis.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/canonicalize.h>
#include <torch/csrc/jit/passes/common_subexpression_elimination.h>
#include <torch/csrc/jit/passes/remove_redundant_profiles.h>
#include <torch/csrc/jit/passes/utils/subgraph_utils.h>
#include <torch/csrc/jit/runtime/autodiff.h>
namespace torch::jit {
namespace {
struct WorkBlock : public std::pair<Node*, Node*> {
using pair::pair;
Node* begin() {
return this->first;
}
Node* end() {
return this->second;
}
};
class SubgraphSlicer {
public:
SubgraphSlicer(
Block* block,
std::shared_ptr<Graph> graph,
size_t minSubgraphSize,
AliasDb& aliasDb,
std::vector<Node*>& diff_nodes)
: block_(block),
graph_(std::move(graph)),
minSubgraphSize_(minSubgraphSize),
aliasDb_(aliasDb),
diff_nodes_(diff_nodes) {}
void run() {
// We maintain alias db correctness in-place while building up the autodiff
// subgraphs; however, it is difficult to preserve correctness when
// un-inlining autodiff subgraphs. We first recursively construct all
// subgraphs and then recursively clean up & unmerge the small subgraphs.
buildupSubgraphs();
GRAPH_DUMP("before unfuseAliasedOutputs", graph_);
unfuseAliasedOutputs(block_);
cleanupSubgraphs();
// Run CSE globally once to eliminate duplicates that may have occurred
// while inlining subgraphs.
EliminateCommonSubexpression(graph_);
}
void cleanupSubgraphs() {
auto curNode = *block_->nodes().rbegin();
while (curNode != *block_->nodes().rend()) {
// Save the previous node, since we might delete `curNode` in the next block
auto prevNode = curNode->prev();
if (curNode->kind() == prim::DifferentiableGraph) {
// Inlining nodes may cause some subexpressions to come back in the
// subgraphs (for example, repeatedly copying constants into a subgraph
// generates redundant prim::Constants). Run CSE to clean them up.
EliminateCommonSubexpression(curNode->g(attr::Subgraph));
if (!inlineIfTooSmall(curNode)) {
diff_nodes_.push_back(curNode);
}
}
curNode = prevNode;
}
for (Node* n : block_->nodes()) {
for (Block* b : n->blocks()) {
SubgraphSlicer(b, graph_, minSubgraphSize_, aliasDb_, diff_nodes_)
.cleanupSubgraphs();
}
}
}
void buildupSubgraphs() {
// We need to run the slicer multiple times in order to get all merge
// opportunities. This is because moveBeforeTopologicallyValid may reorder
// nodes to be AFTER the current iteration point. In order to properly
// consider those nodes for merging, we need to run the pass until no changes
// have been made.
//
// Example:
// c = f(a, b)
// d = f(c)
// e = f(d) <- iter is here, moving upward
// After c.moveBeforeTopologicallyValid(e), we have:
// c = f(a, b)
// e = f(d) <- iter still here
// d = f(c) <- this is the node that was moved to the other side.
// see [workblocks]
auto workblocks = buildWorkBlocks();
for (auto& workblock : workblocks) {
bool any_changed = true;
while (any_changed) {
any_changed = false;
for (auto it = workblock.end()->reverseIterator();
it != workblock.begin()->reverseIterator();) {
auto [tmp_it, changed] = scanNode(*it);
it = tmp_it;
any_changed |= changed;
}
}
}
// Construct Subgraphs Recursively
for (Node* n : block_->nodes()) {
for (auto subBlock : n->blocks()) {
SubgraphSlicer(
subBlock, graph_, minSubgraphSize_, aliasDb_, diff_nodes_)
.buildupSubgraphs();
}
}
}
private:
void unfuseAliasedOutputs(Block* b) {
bool any_changed = true;
while (any_changed) {
any_changed = false;
// we walk in the reverse order, so we can skip
// nodes that might get unfused after the current
// prim::DifferentiableGraph
for (auto n : b->nodes().reverse()) {
if (n->kind() == prim::DifferentiableGraph) {
// aliased outputs in DifferentiableGraphs must be unfused
// since autodiff doesn't know how to handle them correctly
// N.B. we use |= rather than || so that the unmerge calls are not
// short-circuited once any_changed is already true
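// A hypothetical IR sketch of the problematic shape (illustrative only, not
// literal output of this pass): if the subgraph returned both
//   %x = aten::mul(%a, %b)
//   %y = aten::slice(%x, ...)
// as outputs, %y would alias %x, and autodiff's backward cannot account for
// two outputs that share storage; unmerging moves the offending node back
// out into the surrounding graph.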
any_changed |= SubgraphUtils::unmergeAliasedOutputs(n);
any_changed |= SubgraphUtils::unmergeOutputsAlisingInputs(n);
GRAPH_DEBUG(
"any_changed on ",
any_changed,
" ",
n->g(attr::Subgraph)->toString(false));
}
}
}
for (Node* n : b->nodes()) {
for (Block* ib : n->blocks()) {
unfuseAliasedOutputs(ib);
}
}
}
std::vector<WorkBlock> buildWorkBlocks() {
// [workblocks]
// the IR has many nodes which can never be reordered around, such as a
// prim::Bailout. If a node N is surrounded by two nodes which cannot be
// reordered, A and B, then a differentiable subgraph that is created from N
// can only contain nodes from (A, B). The nodes from A to B represent one
// work block for the subgraph slicer to work on. By creating these up
// front, we avoid retraversing the whole graph block every time scanNode
// returns, and we can also avoid attempting to create differentiable
// subgraphs in work blocks that do not contain at least minSubgraphSize_
// differentiable nodes.
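// For illustration, a hypothetical IR sketch (the ops and names are made up,
// not literal output of this pass):
//   %1 = aten::mul(%a, %b)
//   %2 = aten::add(%1, %c, %alpha)
//    = prim::Print(%2)   // side effectful: closes the current work block
//   %3 = aten::relu(%2)
// Nothing can be reordered across prim::Print, so {%1, %2} and {%3} end up
// in separate work blocks, and a DifferentiableGraph is only attempted
// within a block that holds at least minSubgraphSize_ differentiable nodes.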
Node* end_bound_node = block_->return_node();
Node* curr = end_bound_node->prev();
std::vector<WorkBlock> worklist;
size_t differentiable_nodes = 0;
while (curr != block_->param_node()) {
differentiable_nodes += shouldConsiderForMerge(curr);
// cannot reorder around side effectful nodes
if (curr->hasSideEffects()) {
// only record a work block if it contains enough differentiable nodes
if (differentiable_nodes >= minSubgraphSize_) {
worklist.emplace_back(curr, end_bound_node);
}
differentiable_nodes = 0;
end_bound_node = curr;
}
curr = curr->prev();
}
if (differentiable_nodes >= minSubgraphSize_) {
worklist.emplace_back(curr, end_bound_node);
}
return worklist;
}
// Inline this node's group subgraph into the outer graph if it's smaller
// than the specified minimum size.
//
// Returns true if an inlining has occurred, false otherwise.
bool inlineIfTooSmall(Node* n) {
AT_ASSERT(n->kind() == prim::DifferentiableGraph);
auto subgraph = SubgraphUtils::getSubgraph(n);
size_t i = 0;
for (auto it = subgraph->nodes().begin(); it != subgraph->nodes().end();
++it) {
i += !it->notExecutedOp();
if (i >= minSubgraphSize_) {
return false;
}
}
SubgraphUtils::unmergeSubgraph(n);
return true;
}
value_list sortReverseTopological(ArrayRef<Value*> inputs) {
value_list result;
for (auto i : inputs) {
if (i->node()->owningBlock() == block_) {
result.push_back(i);
}
}
// Sort in reverse topological order
std::sort(result.begin(), result.end(), [&](Value* a, Value* b) {
return a->node()->isAfter(b->node());
});
return result;
}
bool isViewOp(Node* n) {
switch (n->kind()) {
case aten::view:
case aten::view_as:
case aten::reshape:
case aten::reshape_as:
case aten::transpose:
case aten::expand:
case aten::expand_as:
return true;
}
return false;
}
bool shouldConsiderForMerge(Node* node) {
// if we're already in the process of merging
if (node->kind() == prim::DifferentiableGraph) {
return true;
}
if (node->kind() == prim::Constant) {
return false;
}
// view ops as outputs of differentiable subgraphs can cause incorrect
// differentiation; for now, do not include them in the subgraph
if (isViewOp(node)) {
return false;
}
return isDifferentiable(node);
}
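// Wrap `consumer` in a singleton prim::DifferentiableGraph if it is
// differentiable (and not already one), then try to pull its producers into
// that subgraph. Returns the reverse iterator to resume the scan from and
// whether anything was merged.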
std::pair<graph_node_list::iterator, bool> scanNode(Node* consumer) {
if (shouldConsiderForMerge(consumer)) {
if (consumer->kind() != prim::DifferentiableGraph) {
consumer = SubgraphUtils::createSingletonSubgraphAndUpdateAliasing(
consumer, prim::DifferentiableGraph, aliasDb_);
}
auto inputs = sortReverseTopological(consumer->inputs());
for (auto input : inputs) {
if (auto group = tryMerge(consumer, input->node())) {
// we successfully merged, so the new group's `inputs` may have
// changed. So rescan the new group for more merging opportunities.
return std::make_pair(group.value()->reverseIterator(), true);
}
}
}
return std::make_pair(++consumer->reverseIterator(), false);
}
// Try to merge `producer` into `consumer`. If successful, this destroys
// `producer` and returns the `consumer` group.
std::optional<Node*> tryMerge(Node* consumer, Node* producer) {
AT_ASSERT(consumer->kind() == prim::DifferentiableGraph);
bool canMerge = shouldConsiderForMerge(producer) &&
aliasDb_.moveBeforeTopologicallyValid(producer, consumer);
if (!canMerge) {
return std::nullopt;
}
SubgraphUtils::mergeNodeIntoSubgraphAndUpdateAliasing(
producer, consumer, aliasDb_);
return consumer;
}
Block* block_;
std::shared_ptr<Graph> graph_;
size_t minSubgraphSize_;
AliasDb& aliasDb_;
std::vector<Node*>& diff_nodes_;
};
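// Returns the requires_grad bit recorded by a prim::profile node, if its
// profiled_type attribute is a TensorType carrying that information.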
std::optional<bool> getProfileNodeRequiresGrad(Node* n) {
TORCH_INTERNAL_ASSERT(n->kind() == prim::profile);
if (!n->hasAttribute(attr::profiled_type)) {
return std::nullopt;
}
auto& type = n->ty(attr::profiled_type);
if (type->castRaw<TensorType>() == nullptr) {
return std::nullopt;
}
return type->expectRef<TensorType>().requiresGrad();
}
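// Maps every node to the innermost prim::Enter node that is active where it
// appears (nullptr if it is not nested in any context manager). Used below
// to ignore profile nodes recorded under a different context than the
// DifferentiableGraph, e.g. inside a torch.no_grad() block.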
struct ContextMapping {
std::vector<const Node*> ctx_stack_;
std::unordered_map<const Node*, const Node*> node_to_ctx_;
void processNode(Node* n) {
node_to_ctx_[n] = ctx_stack_.back();
if (n->kind() == prim::Enter) {
ctx_stack_.push_back(n);
} else if (n->kind() == prim::Exit) {
ctx_stack_.pop_back();
}
}
void processBlock(Block* block) {
for (Node* n : block->nodes()) {
processNode(n);
for (Block* b : n->blocks()) {
processBlock(b);
}
if (n->kind() == prim::DifferentiableGraph) {
const auto& subgraph = n->g(attr::Subgraph);
processBlock(subgraph->block());
}
}
}
ContextMapping(const std::shared_ptr<Graph>& graph) {
ctx_stack_.push_back(nullptr);
processBlock(graph->block());
}
const Node* get(const Node* n) const {
auto it = node_to_ctx_.find(n);
TORCH_INTERNAL_ASSERT(
it != node_to_ctx_.end(),
"Cannot find node in node-to-context mapping.");
return it->second;
}
bool has(const Node* n) const {
return node_to_ctx_.find(n) != node_to_ctx_.end();
}
};
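// Try to recover requires_grad for one output of `diff_graph` by scanning
// the output's uses (and, if a use is itself a DifferentiableGraph, the uses
// of the corresponding subgraph input) for a prim::profile node recorded in
// the same context as `diff_graph`.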
std::optional<bool> findRequiresGradForOutput(
Node* diff_graph,
Value* output,
const ContextMapping& ctx_mapping) {
for (auto& use : output->uses()) {
// [Only consider profiles in the same context]
// Ignore profiled uses if the use is within a different context.
// For example, a profile node within a no_grad() context will record the
// wrong requires_grad information.
if (ctx_mapping.has(use.user) &&
ctx_mapping.get(use.user) != ctx_mapping.get(diff_graph)) {
continue;
}
if (use.user->kind() == prim::profile) {
auto req_grad_use = getProfileNodeRequiresGrad(use.user);
if (req_grad_use.has_value()) {
return req_grad_use;
}
}
// maybe the profile node got absorbed into a differentiable graph
if (use.user->kind() == prim::DifferentiableGraph) {
const auto& dg = use.user->g(attr::Subgraph);
// check all the uses of this graph input to look for profile nodes.
Value* dg_value = dg->inputs()[use.offset];
for (auto& dg_use : dg_value->uses()) {
// See [Only consider profiles in the same context]
if (ctx_mapping.has(dg_use.user) &&
ctx_mapping.get(dg_use.user) != ctx_mapping.get(diff_graph)) {
continue;
}
if (dg_use.user->kind() == prim::profile) {
auto req_grad_use = getProfileNodeRequiresGrad(dg_use.user);
if (req_grad_use.has_value()) {
return req_grad_use;
}
}
}
}
}
return std::nullopt;
}
void AddRequiresGradToDifferentiableGraph(
Node* diff_graph,
const ContextMapping& ctx_mapping) {
TORCH_INTERNAL_ASSERT(diff_graph->kind() == prim::DifferentiableGraph);
const auto& subgraph = diff_graph->g(attr::Subgraph);
for (auto i : c10::irange(subgraph->outputs().size())) {
Value* output = subgraph->outputs()[i];
if (output->node()->kind() == prim::profile) {
// already have requires_grad info from this profile node
continue;
}
if (output->type()->castRaw<TensorType>() == nullptr) {
// non-tensors don't get profiled.
continue;
}
if (output->type()->expectRef<TensorType>().requiresGrad().has_value()) {
continue;
}
// this node doesn't have any requires_grad info.
// look at its uses to try to find a profile node.
auto requires_grad = findRequiresGradForOutput(
diff_graph, diff_graph->output(i), ctx_mapping);
output->setType(output->type()->expectRef<TensorType>().withRequiresGrad(
requires_grad));
}
}
void AddRequiresGradOnOutputNodes(
Block* block,
const ContextMapping& ctx_mapping) {
for (Node* n : block->nodes()) {
if (n->kind() == prim::DifferentiableGraph) {
AddRequiresGradToDifferentiableGraph(n, ctx_mapping);
}
for (Block* b : n->blocks()) {
AddRequiresGradOnOutputNodes(b, ctx_mapping);
}
}
}
// autodiff.cpp needs to know, for each output, whether or not it requires
// grad. Sometimes a profile node will be present on the output, and sometimes
// it won't be. This can happen if there's a node with side effects in between
// the definition of the output node and the profile node; in that case the
// profile node and output node would be in different work blocks and couldn't
// be merged into the same DifferentiableGraph (see [workblocks]). Or it could
// happen if the output is profiled twice and the profile nodes get removed by
// unfuseAliasedOutputs.
void AddRequiresGradOnOutputNodes(const std::shared_ptr<Graph>& graph) {
ContextMapping ctx_mapping(graph);
AddRequiresGradOnOutputNodes(graph->block(), ctx_mapping);
}
} // anonymous namespace
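// A minimal usage sketch (the threshold value is illustrative; differentiate()
// is the API declared in runtime/autodiff.h, included above):
//   std::vector<Node*> diff_nodes =
//       CreateAutodiffSubgraphs(graph, /*threshold=*/2);
//   for (Node* n : diff_nodes) {
//     auto subgraph = n->g(attr::Subgraph);
//     Gradient grad_spec = differentiate(subgraph);
//   }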
std::vector<Node*> CreateAutodiffSubgraphs(
const std::shared_ptr<Graph>& graph,
size_t threshold) {
std::vector<Node*> diff_nodes;
AliasDb db(graph);
GRAPH_DEBUG("Before creating autodiff subgraphs", *graph);
SubgraphSlicer(graph->block(), graph, threshold, db, diff_nodes).run();
GRAPH_DEBUG("After creating autodiff subgraphs", *graph);
AddRequiresGradOnOutputNodes(graph);
GRAPH_DEBUG("diff_nodes.size() ", diff_nodes.size());
return diff_nodes;
}
} // namespace torch::jit