#include <torch/csrc/jit/passes/create_autodiff_subgraphs.h>

#include <c10/util/Exception.h>
#include <torch/csrc/jit/ir/alias_analysis.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/canonicalize.h>
#include <torch/csrc/jit/passes/common_subexpression_elimination.h>
#include <torch/csrc/jit/passes/remove_redundant_profiles.h>
#include <torch/csrc/jit/passes/utils/subgraph_utils.h>
#include <torch/csrc/jit/runtime/autodiff.h>

namespace torch::jit {

namespace {
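
// A WorkBlock delimits a contiguous range of nodes bounded by nodes that
// cannot be reordered around (see [workblocks] in buildWorkBlocks() below).
// The subgraph slicer only merges nodes within a single work block.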
struct WorkBlock : public std::pair<Node*, Node*> {
  using pair::pair;

  Node* begin() {
    return this->first;
  }
  Node* end() {
    return this->second;
  }
};
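
// Greedily merges differentiable nodes within each work block of `block`
// into prim::DifferentiableGraph nodes, then inlines any resulting subgraph
// that contains fewer than minSubgraphSize_ executed ops.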
class SubgraphSlicer {
 public:
  SubgraphSlicer(
      Block* block,
      std::shared_ptr<Graph> graph,
      size_t minSubgraphSize,
      AliasDb& aliasDb,
      std::vector<Node*>& diff_nodes)
      : block_(block),
        graph_(std::move(graph)),
        minSubgraphSize_(minSubgraphSize),
        aliasDb_(aliasDb),
        diff_nodes_(diff_nodes) {}

  void run() {
    // We maintain alias db correctness in-place while building up the autodiff
    // subgraphs; however, it is difficult to preserve correctness when
    // un-inlining autodiff subgraphs. We first recursively construct all
    // subgraphs, then recursively clean up & unmerge the small ones.
    buildupSubgraphs();
    GRAPH_DUMP("before unfuseAliasedOutputs", graph_);
    unfuseAliasedOutputs(block_);
    cleanupSubgraphs();
    // Run CSE globally once to eliminate duplicates that may have occurred
    // while inlining subgraphs.
    EliminateCommonSubexpression(graph_);
  }

  void cleanupSubgraphs() {
    auto curNode = *block_->nodes().rbegin();
    while (curNode != *block_->nodes().rend()) {
      // Save the previous node, since we might delete `curNode` in the next
      // block
      auto prevNode = curNode->prev();
      if (curNode->kind() == prim::DifferentiableGraph) {
        // Inlining nodes may cause some subexpressions to come back in the
        // subgraphs (for example, copying constants in repeatedly will
        // generate redundant prim::Constants). Run CSE to clean them up.
        EliminateCommonSubexpression(curNode->g(attr::Subgraph));

        if (!inlineIfTooSmall(curNode)) {
          diff_nodes_.push_back(curNode);
        }
      }
      curNode = prevNode;
    }

    for (Node* n : block_->nodes()) {
      for (Block* b : n->blocks()) {
        SubgraphSlicer(b, graph_, minSubgraphSize_, aliasDb_, diff_nodes_)
            .cleanupSubgraphs();
      }
    }
  }

  void buildupSubgraphs() {
    // We need to run the slicer multiple times in order to get all merge
    // opportunities. This is because moveBeforeTopologicallyValid may reorder
    // nodes to be AFTER the current iteration point. In order to properly
    // consider those nodes for merging, we need to run the pass until no
    // changes have been made.
    //
    // Example:
    //   c = f(a, b)
    //   d = f(c)
    //   e = f(d)  <- iter is here, moving upward
    // After c.moveBeforeTopologicallyValid(e), we have:
    //   c = f(a, b)
    //   e = f(d)  <- iter still here
    //   d = f(c)  <- this is the node that was moved to the other side.

    // see [workblocks]
    auto workblocks = buildWorkBlocks();
    for (auto& workblock : workblocks) {
      bool any_changed = true;
      while (any_changed) {
        any_changed = false;
        for (auto it = workblock.end()->reverseIterator();
             it != workblock.begin()->reverseIterator();) {
          auto [tmp_it, changed] = scanNode(*it);
          it = tmp_it;
          any_changed |= changed;
        }
      }
    }

    // Construct subgraphs recursively.
    for (Node* n : block_->nodes()) {
      for (auto subBlock : n->blocks()) {
        SubgraphSlicer(
            subBlock, graph_, minSubgraphSize_, aliasDb_, diff_nodes_)
            .buildupSubgraphs();
      }
    }
  }

 private:
  void unfuseAliasedOutputs(Block* b) {
    bool any_changed = true;
    while (any_changed) {
      any_changed = false;
      // We walk in reverse order, so we can skip nodes that might get unfused
      // after the current prim::DifferentiableGraph.
      for (auto n : b->nodes().reverse()) {
        if (n->kind() == prim::DifferentiableGraph) {
          // Aliased outputs in DifferentiableGraphs must be unfused, since
          // autodiff doesn't know how to handle them correctly.
          // N.B. we use |= since we don't want `unfuseAliasedOutputs`
          // to short-circuit.
          any_changed |= SubgraphUtils::unmergeAliasedOutputs(n);
          any_changed |= SubgraphUtils::unmergeOutputsAlisingInputs(n);
          GRAPH_DEBUG(
              "any_changed on ",
              any_changed,
              " ",
              n->g(attr::Subgraph)->toString(false));
        }
      }
    }

    for (Node* n : b->nodes()) {
      for (Block* ib : n->blocks()) {
        unfuseAliasedOutputs(ib);
      }
    }
  }

  std::vector<WorkBlock> buildWorkBlocks() {
    // [workblocks]
    // The IR has many nodes which can never be reordered around, such as a
    // prim::Bailout. If a node N is surrounded by two nodes which cannot be
    // reordered, A and B, then a differentiable subgraph that is created from
    // N can only contain nodes from (A, B). The nodes from A to B represent
    // one work block for the subgraph slicer to work on. By creating these up
    // front, we avoid retraversing the whole graph block any time scanNode
    // returns, and we can also avoid attempting to create differentiable
    // subgraphs in work blocks that do not contain at least minSubgraphSize_
    // differentiable nodes.

    Node* end_bound_node = block_->return_node();
    Node* curr = end_bound_node->prev();

    std::vector<WorkBlock> worklist;
    size_t differentiable_nodes = 0;

    while (curr != block_->param_node()) {
      differentiable_nodes += shouldConsiderForMerge(curr);

      // cannot reorder around side-effectful nodes
      if (curr->hasSideEffects()) {
        // only record a work block if it has enough differentiable nodes to
        // create a differentiable subgraph
        if (differentiable_nodes >= minSubgraphSize_) {
          worklist.emplace_back(curr, end_bound_node);
        }
        differentiable_nodes = 0;
        end_bound_node = curr;
      }
      curr = curr->prev();
    }

    if (differentiable_nodes >= minSubgraphSize_) {
      worklist.emplace_back(curr, end_bound_node);
    }

    return worklist;
  }

  // Inline this node's group subgraph into the outer graph if it's smaller
  // than the specified minimum size.
  //
  // Returns true if an inlining has occurred, false otherwise.
  bool inlineIfTooSmall(Node* n) {
    AT_ASSERT(n->kind() == prim::DifferentiableGraph);
    auto subgraph = SubgraphUtils::getSubgraph(n);
    size_t i = 0;
    for (auto it = subgraph->nodes().begin(); it != subgraph->nodes().end();
         ++it) {
      i += !it->notExecutedOp();
      if (i >= minSubgraphSize_) {
        return false;
      }
    }

    SubgraphUtils::unmergeSubgraph(n);
    return true;
  }
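
  // Returns the subset of `inputs` that is produced in this block, ordered
  // so that the most recently defined value comes first.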
  value_list sortReverseTopological(ArrayRef<Value*> inputs) {
    value_list result;
    for (auto i : inputs) {
      if (i->node()->owningBlock() == block_) {
        result.push_back(i);
      }
    }
    // Sort in reverse topological order
    std::sort(result.begin(), result.end(), [&](Value* a, Value* b) {
      return a->node()->isAfter(b->node());
    });
    return result;
  }
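
  // Ops that may return a view of their input rather than a fresh tensor.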
  bool isViewOp(Node* n) {
    switch (n->kind()) {
      case aten::view:
      case aten::view_as:
      case aten::reshape:
      case aten::reshape_as:
      case aten::transpose:
      case aten::expand:
      case aten::expand_as:
        return true;
    }
    return false;
  }

  bool shouldConsiderForMerge(Node* node) {
    // if we're already in the process of merging
    if (node->kind() == prim::DifferentiableGraph) {
      return true;
    }
    if (node->kind() == prim::Constant) {
      return false;
    }

    // View ops as outputs of differentiable subgraphs can cause incorrect
    // differentiation. For now, do not include them in the subgraph.
    if (isViewOp(node)) {
      return false;
    }

    return isDifferentiable(node);
  }
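
  // Visit `consumer`: if it is differentiable, wrap it in a singleton
  // DifferentiableGraph (unless it already is one) and try to pull its
  // producers into the group. Returns the next node to scan and whether
  // anything was merged.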
  std::pair<graph_node_list::iterator, bool> scanNode(Node* consumer) {
    if (shouldConsiderForMerge(consumer)) {
      if (consumer->kind() != prim::DifferentiableGraph) {
        consumer = SubgraphUtils::createSingletonSubgraphAndUpdateAliasing(
            consumer, prim::DifferentiableGraph, aliasDb_);
      }
      auto inputs = sortReverseTopological(consumer->inputs());
      for (auto input : inputs) {
        if (auto group = tryMerge(consumer, input->node())) {
          // We successfully merged, so the new group's `inputs` may have
          // changed. Rescan the new group for more merging opportunities.
          return std::make_pair(group.value()->reverseIterator(), true);
        }
      }
    }

    return std::make_pair(++consumer->reverseIterator(), false);
  }

  // Try to merge `producer` into `consumer`. If successful, this destroys
  // `producer` and returns the `consumer` group.
  std::optional<Node*> tryMerge(Node* consumer, Node* producer) {
    AT_ASSERT(consumer->kind() == prim::DifferentiableGraph);
    bool canMerge = shouldConsiderForMerge(producer) &&
        aliasDb_.moveBeforeTopologicallyValid(producer, consumer);

    if (!canMerge) {
      return std::nullopt;
    }

    SubgraphUtils::mergeNodeIntoSubgraphAndUpdateAliasing(
        producer, consumer, aliasDb_);
    return consumer;
  }

  Block* block_;
  std::shared_ptr<Graph> graph_;
  size_t minSubgraphSize_;
  AliasDb& aliasDb_;
  std::vector<Node*>& diff_nodes_;
};
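
// Returns the requires_grad bit recorded on a prim::profile node, if the
// node profiled a tensor type carrying that information.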
std::optional<bool> getProfileNodeRequiresGrad(Node* n) {
  TORCH_INTERNAL_ASSERT(n->kind() == prim::profile);
  if (!n->hasAttribute(attr::profiled_type)) {
    return std::nullopt;
  }
  auto& type = n->ty(attr::profiled_type);
  if (type->castRaw<TensorType>() == nullptr) {
    return std::nullopt;
  }
  return type->expectRef<TensorType>().requiresGrad();
}
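
// Maps each node to the innermost prim::Enter node it executes under, or
// nullptr if it is not inside any such context (e.g. a no_grad() block).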
struct ContextMapping {
  std::vector<const Node*> ctx_stack_;
  std::unordered_map<const Node*, const Node*> node_to_ctx_;

  void processNode(Node* n) {
    node_to_ctx_[n] = ctx_stack_.back();

    if (n->kind() == prim::Enter) {
      ctx_stack_.push_back(n);
    } else if (n->kind() == prim::Exit) {
      ctx_stack_.pop_back();
    }
  }

  void processBlock(Block* block) {
    for (Node* n : block->nodes()) {
      processNode(n);
      for (Block* b : n->blocks()) {
        processBlock(b);
      }
      if (n->kind() == prim::DifferentiableGraph) {
        const auto& subgraph = n->g(attr::Subgraph);
        processBlock(subgraph->block());
      }
    }
  }

  ContextMapping(const std::shared_ptr<Graph>& graph) {
    ctx_stack_.push_back(nullptr);
    processBlock(graph->block());
  }

  const Node* get(const Node* n) const {
    auto it = node_to_ctx_.find(n);
    TORCH_INTERNAL_ASSERT(
        it != node_to_ctx_.end(),
        "Cannot find node in node-to-context mapping.");
    return it->second;
  }

  bool has(const Node* n) const {
    return node_to_ctx_.find(n) != node_to_ctx_.end();
  }
};
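
// Searches the uses of `output` (including uses forwarded into a
// DifferentiableGraph input) for a profile node recorded in the same context
// as `diff_graph`, and returns its requires_grad bit if one is found.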
std::optional<bool> findRequiresGradForOutput(
    Node* diff_graph,
    Value* output,
    const ContextMapping& ctx_mapping) {
  for (auto& use : output->uses()) {
    // [Only consider profiles in the same context]
    // Ignore profiled uses if the use is within a different context.
    // For example, a profile node within a no_grad() context will record the
    // wrong requires_grad information.
    if (ctx_mapping.has(use.user) &&
        ctx_mapping.get(use.user) != ctx_mapping.get(diff_graph)) {
      continue;
    }

    if (use.user->kind() == prim::profile) {
      auto req_grad_use = getProfileNodeRequiresGrad(use.user);
      if (req_grad_use.has_value()) {
        return req_grad_use;
      }
    }

    // maybe the profile node got absorbed into a differentiable graph
    if (use.user->kind() == prim::DifferentiableGraph) {
      const auto& dg = use.user->g(attr::Subgraph);
      // check all the uses of this graph input to look for profile nodes.
      Value* dg_value = dg->inputs()[use.offset];
      for (auto& dg_use : dg_value->uses()) {
        // See [Only consider profiles in the same context]
        if (ctx_mapping.has(dg_use.user) &&
            ctx_mapping.get(dg_use.user) != ctx_mapping.get(diff_graph)) {
          continue;
        }

        if (dg_use.user->kind() == prim::profile) {
          auto req_grad_use = getProfileNodeRequiresGrad(dg_use.user);
          if (req_grad_use.has_value()) {
            return req_grad_use;
          }
        }
      }
    }
  }

  return std::nullopt;
}
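
// For each tensor output of `diff_graph`'s subgraph that lacks requires_grad
// information, recover it from profile nodes observed on the corresponding
// outer output and record it on the output's type.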
void AddRequiresGradToDifferentiableGraph(
    Node* diff_graph,
    const ContextMapping& ctx_mapping) {
  TORCH_INTERNAL_ASSERT(diff_graph->kind() == prim::DifferentiableGraph);
  const auto& subgraph = diff_graph->g(attr::Subgraph);
  for (auto i : c10::irange(subgraph->outputs().size())) {
    Value* output = subgraph->outputs()[i];
    if (output->node()->kind() == prim::profile) {
      // already have requires_grad info from this profile node
      continue;
    }
    if (output->type()->castRaw<TensorType>() == nullptr) {
      // non-tensors don't get profiled.
      continue;
    }
    if (output->type()->expectRef<TensorType>().requiresGrad().has_value()) {
      continue;
    }

    // This output doesn't have any requires_grad info;
    // look at its uses to try to find a profile node.
    auto requires_grad = findRequiresGradForOutput(
        diff_graph, diff_graph->output(i), ctx_mapping);

    output->setType(output->type()->expectRef<TensorType>().withRequiresGrad(
        requires_grad));
  }
}

void AddRequiresGradOnOutputNodes(
    Block* block,
    const ContextMapping& ctx_mapping) {
  for (Node* n : block->nodes()) {
    if (n->kind() == prim::DifferentiableGraph) {
      AddRequiresGradToDifferentiableGraph(n, ctx_mapping);
    }
    for (Block* b : n->blocks()) {
      AddRequiresGradOnOutputNodes(b, ctx_mapping);
    }
  }
}

// autodiff.cpp needs to know, for each output, whether or not it requires
// grad. Sometimes a profile node will be present on the output, and sometimes
// it won't be. A profile node might be missing if there's a node with side
// effects in between the definition of the output node and the profile node;
// in that case the profile node and output node would be in different work
// blocks and couldn't be merged into the same DifferentiableGraph (see
// [workblocks]). Or it could happen if the output is profiled twice and the
// profile nodes get removed by unfuseAliasedOutputs.
void AddRequiresGradOnOutputNodes(const std::shared_ptr<Graph>& graph) {
  ContextMapping ctx_mapping(graph);
  AddRequiresGradOnOutputNodes(graph->block(), ctx_mapping);
}

} // anonymous namespace
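
// Partitions `graph` into prim::DifferentiableGraph nodes, each containing
// at least `threshold` executed ops, and returns those nodes.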
std::vector<Node*> CreateAutodiffSubgraphs(
    const std::shared_ptr<Graph>& graph,
    size_t threshold) {
  std::vector<Node*> diff_nodes;
  AliasDb db(graph);
  GRAPH_DEBUG("Before creating autodiff subgraphs", *graph);
  SubgraphSlicer(graph->block(), graph, threshold, db, diff_nodes).run();
  GRAPH_DEBUG("After creating autodiff subgraphs", *graph);
  AddRequiresGradOnOutputNodes(graph);
  GRAPH_DEBUG("diff_nodes.size() ", diff_nodes.size());
  return diff_nodes;
}

} // namespace torch::jit