pytorch/torch/csrc/jit/passes/dead_code_elimination.cpp
David Berard a237831bc2 [JIT] Optimize DCE by storing a MemoryLocations for an entire set<Value*> (#153645)
Summary:
**TL;DR**: make DCE faster by replacing a `Set<Value*>` with a `MemoryLocations` sparse bitset representing the union of the memory locations of all values in the set.

**Details**
The goal of this PR is to optimize this function from AliasDb:

```
bool AliasDb::writesToAlias(Node* n, const ValueSet& vs) const {
  const auto writtenTo = getWrites(n);
  if (writtenTo.empty()) {
    return false;
  }

  MemoryLocations locs;
  for (const auto v : vs) {
    auto it = elementMap_.find(v);
    if (it != elementMap_.end()) {
      const auto& vlocs = memoryDAG_->getMemoryLocations(it->second);
      if (writtenTo.intersects(vlocs)) {
        return true;
      }
    }
  }

  return false;
}
```

In the DCE use case, we maintain a ValueSet of live values, inserting `Value*`s as nodes are marked live, and we sometimes need to check whether a node mutates any of the live values using `writesToAlias`.

Looping over every value in the ValueSet and indexing into `elementMap_` on each query is slow; if we instead pre-compute the MemoryLocations for the whole set, the query becomes a single bitset intersection. On some large model examples, I see ~15-25x speedups from this change.
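
To make the before/after concrete, here is a minimal standalone sketch (not the actual AliasDb code) of the idea: the live set keeps the union of its members' memory locations up to date on every insert, so the write-intersection query no longer loops over the values. `LiveSetSketch`, the fixed-width `std::bitset`, and the `elementMap` pointer are illustrative stand-ins for the real sparse bitset and AliasDb internals.

```
#include <bitset>
#include <unordered_map>
#include <unordered_set>

struct Value {}; // stand-in for torch::jit::Value

// The real MemoryLocations is a sparse bitset; a fixed-size bitset is enough
// to illustrate the access pattern.
using MemoryLocations = std::bitset<1024>;

struct LiveSetSketch {
  // Value* -> memory locations it may point to (filled in by alias analysis
  // in the real implementation).
  const std::unordered_map<const Value*, MemoryLocations>* elementMap;

  std::unordered_set<const Value*> values; // the live Value*s
  MemoryLocations locations;               // union of their memory locations

  void insert(const Value* v) {
    if (values.insert(v).second) {
      auto it = elementMap->find(v);
      if (it != elementMap->end()) {
        locations |= it->second; // maintain the union incrementally
      }
    }
  }

  // What used to be a loop over `values` with per-value elementMap lookups is
  // now a single intersection against the precomputed union.
  bool intersectsWrites(const MemoryLocations& writtenTo) const {
    return (locations & writtenTo).any();
  }
};
```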

**Implementation**: To avoid exposing too many details of AliasDb, I introduce a friend class `ValueAndMemoryLocationSet`: an insert-only set of Values that also maintains the corresponding MemoryLocations.

Then in DCE, I use a `ValueAndMemoryLocationSet` when AliasDb is available for analysis, and fall back to a plain `Set<Value*>` when it is not.
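
For reference, this is the interface shape of `ValueAndMemoryLocationSet` as it is exercised by the DCE code below (`insert`, `getValueSet`, construction via `AliasDb::getValueAndMemoryLocationSet`, and the `writesToAlias` overload that accepts it); the actual declaration lives in the alias analysis headers and may differ in detail.

```
// Sketch of the interface as used in this file; not the actual declaration.
class ValueAndMemoryLocationSet {
 public:
  // Inserts a live Value and folds its memory locations into an internal
  // MemoryLocations bitset maintained alongside the set.
  void insert(Value* v);

  // The underlying set of Value*s, used for membership checks and for the
  // delete callback.
  ValueSet& getValueSet();
};

// AliasDb hands one out (it is a friend class, so it can read elementMap_):
//   auto liveSet = aliasDb.getValueAndMemoryLocationSet();
// and answers write queries against it without looping over the Values:
//   bool mutatesLive = aliasDb.writesToAlias(node, liveSet);
```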

Test Plan: Rely on unit tests.

Differential Revision: D74827086

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153645
Approved by: https://github.com/eellison
2025-05-19 21:04:59 +00:00


#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <c10/util/irange.h>
#include <torch/csrc/jit/ir/alias_analysis.h>
#include <torch/csrc/jit/ir/ir_views.h>
#include <torch/csrc/jit/jit_log.h>
#include <unordered_map>
namespace torch::jit {
namespace prim {
using namespace ::c10::prim;
}
class DeadCodeEliminator {
public:
explicit DeadCodeEliminator(
std::shared_ptr<Graph> graph,
DCESideEffectPolicy sideEffectPolicy)
: sideEffectPolicy_(sideEffectPolicy),
graph_(std::move(graph)),
useAliasDb_(true) {}
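// Without a graph we cannot construct an AliasDb, so useAliasDb_ stays false
// and liveness tracking falls back to a plain ValueSet.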
DeadCodeEliminator(DCESideEffectPolicy sideEffectPolicy)
: sideEffectPolicy_(sideEffectPolicy) {}
// The algorithm is an inverse mark-and-sweep. Starting from the return node,
// we mark "live" nodes that are necessary for the output. Nodes that have
// side effects are also marked.
void run(Block* block, bool recurse) {
// clean up unused fork inputs before starting the main algorithm
eliminateDeadForkInputs(block, recurse);
// Initialize by marking the return node and all its consumed values as live
mark(block->return_node());
mark(block);
deleteCallback_(getLiveValues());
sweep(block, recurse);
}
void setDeleteCallback(
std::function<void(const std::unordered_set<const Value*>&)>
deleteCallback) {
deleteCallback_ = std::move(deleteCallback);
}
private:
void eliminateDeadForkInputs(Block* block, bool recurse) {
for (Node* node : block->nodes()) {
if (recurse) {
for (Block* sb : node->blocks()) {
eliminateDeadForkInputs(sb, recurse);
}
}
if (node->kind() != prim::fork) {
continue;
}
Graph& g = *node->g(attr::Subgraph);
// WARNING: Do not use a ranged loop. The loop bounds are changed by the
// loop body.
for (size_t i = 0; i < g.inputs().size(); ++i) {
if (!g.inputs().at(i)->hasUses()) {
GRAPH_UPDATE(
"Dead ",
i,
"-th input ",
node->inputs().at(i)->debugName(),
"(",
g.inputs().at(i)->debugName(),
" in a subgraph) will be removed");
g.eraseInput(i);
node->removeInput(i);
}
}
}
}
// Special handling for block return nodes. Unlike other nodes, the block
// return node doesn't really "use" its inputs. Consider:
//
// %a0 = aten::foo()
// %b = aten::foo()
// %a2, %b2 = prim::If(%cond) {
// block0() {
// %a1 = aten::foo(%a0)
// %b1 = aten::foo(%b)
// } -> (%a1, %b1)
// }
// return (%a2)
//
// We want to be able to DCE all the %b stuff. So when processing block
// returns, we only mark producers for values that are "live" (i.e., used outside
// the block).
//
// Returns true iff this marked something we haven't marked before.
bool markReturnNode(Node* node) {
if (marked_.count(node)) {
return false;
}
AT_ASSERT(node->owningBlock()->return_node() == node);
auto outerNode = node->owningBlock()->owningNode();
if (outerNode == nullptr || outerNode->kind() == prim::Reverse) {
// If there's no outer node, we're looking at the graph's top-level
// return block. We consider all graph outputs to be "used", so just mark
// this node normally.
return mark(node);
}
// Collect all inputs that are actually live
if (outerNode->kind() == prim::Loop ||
outerNode->kind() == c10::onnx::Loop) {
// Special handling to deal with loop carried dependencies.
auto loop = LoopView(outerNode);
for (const auto i : c10::irange(loop.carriedOutputs().size())) {
if (outerNode->kind() == c10::onnx::Loop) {
// Special handling for onnx loop.
// The number of body carried inputs and outputs are different.
// They cannot be mapped to each other easily by the same index.
insertLiveValue(loop.bodyCarriedOutputs().at(i));
continue;
}
auto innerInput = loop.bodyCarriedInputs().at(i);
auto innerOutput = loop.bodyCarriedOutputs().at(i);
auto outerOutput = loop.carriedOutputs().at(i);
if (liveValuesContains(outerOutput) || innerInput->hasUses()) {
insertLiveValue(innerOutput);
}
}
// Also mark the loop next condition as live, since it will be used inside
// the loop body.
insertLiveValue(loop.nextCond());
} else {
AT_ASSERT(outerNode->outputs().size() == node->inputs().size());
for (const auto i : c10::irange(outerNode->outputs().size())) {
auto innerOutput = node->inputs()[i];
auto outerOutput = outerNode->outputs()[i];
if (liveValuesContains(outerOutput)) {
insertLiveValue(innerOutput);
}
}
}
marked_.insert(node);
return true;
}
// Loops are special, because we need to run them to convergence.
// Consider the following loop:
// for i in range(3):
// tot += a[0][0]
// b = a[0]
// b[0] += 1
// print(tot)
//
// If we only process the loop block once, we will conclude that `b[0]` and
// `b` are dead, even though `b[0] += 1` mutates a live memory location (since
// `b[0]` is an alias of `a`, and `a` is used to compute `tot` in the next
// iteration).
//
// We need to mark the loop again with the information that `a` is live, and
// repeat until we're not marking new stuff anymore.
//
// Returns true iff this marked something we haven't marked before.
bool markLoop(Node* node) {
TORCH_INTERNAL_ASSERT(node->kind() == prim::Loop);
// Did a single iteration over the loop block mark anything new?
// If this is false, we've converged.
bool marked = false;
// Did we ever mark anything new?
bool anyMarked = false;
do {
marked = mark(node->blocks().at(0));
anyMarked |= marked;
} while (marked);
return anyMarked;
}
// Returns true iff this marked something we haven't marked before.
bool mark(Block* block) {
bool anyMarked = false;
// Mark all nodes with side effects.
for (auto node : block->nodes()) {
if (sideEffectPolicy_ ==
DCESideEffectPolicy::DONT_DELETE_NODES_WITH_SIDE_EFFECTS &&
hasSideEffects(node)) {
anyMarked |= mark(node);
}
}
// Initialize by marking the return node
anyMarked |= markReturnNode(block->return_node());
for (auto it = block->nodes().rbegin(); it != block->nodes().rend(); ++it) {
auto node = *it;
if (node->kind() == prim::Loop) {
// Special casing for loops, see comment in markLoop.
anyMarked |= markLoop(node);
} else {
// Other nodes with sub-blocks get marked normally.
for (auto subBlock : node->blocks()) {
anyMarked |= mark(subBlock);
}
}
anyMarked |= markIfLive(node);
}
return anyMarked;
}
// If we output or write to a live memory location, mark this node
// Returns true iff this marked something we haven't marked before.
bool markIfLive(Node* node) {
for (const auto output : node->outputs()) {
if (liveValuesContains(output)) {
return mark(node);
}
}
if (useAliasDb_) {
if (getOrCreateAliasDb()->writesToAlias(
node, getLiveValuesAndMemoryLocations())) {
return mark(node);
}
}
return false;
}
// Mark this node as live and add this node's inputs and aliases to the live
// value sets.
// Returns true iff this marked something we haven't marked before.
bool mark(Node* node) {
if (marked_.count(node)) {
return false;
}
marked_.insert(node);
// Mark all of this node's owning nodes, walking up the chain of owning
// blocks (owning nodes are considered live if they contain a live node).
auto curNode = node;
while (curNode) {
if (!curNode->owningBlock()) {
break;
}
mark(curNode);
curNode = curNode->owningBlock()->owningNode();
}
for (const auto input : node->inputs()) {
if (liveValuesContains(input)) {
continue;
}
insertLiveValue(input);
}
return true;
}
// Delete all unmarked nodes.
void sweep(Block* block, bool recurse) {
auto nodes = block->nodes().reverse();
for (auto it = nodes.begin(); it != nodes.end(); it++) {
auto node = *it;
// note these occur before the recursion because we want to uncover
// dead code in the blocks used to calculate the output
removeDeadBlockOutputs(node);
removeDeadLoopOutputs(node);
if (recurse) {
for (Block* block : node->blocks()) {
sweep(block, true);
}
}
// NB: Checking hasUses() is required. AD graphs are not perfectly
// valid, as a node in grad_desc.f might be used in reverse_block.
// Reverse_block is inlined in grad_desc.f before it's separated
// to grad_desc.df.
if (!(marked_.count(node) || node->hasUses())) {
GRAPH_UPDATE(
"Node ",
it->kind().toQualString(),
" which outputs ",
(!node->outputs().empty() ? node->outputs().at(0)->debugName()
: "n/a"),
" will be removed");
it.destroyCurrent();
}
}
}
bool hasUntrackedMutation(Node* node) {
if (!useAliasDb_) {
// If we don't have alias information, all mutable ops have unknown
// effects and can't be considered for elimination.
if (node->kind() == prim::SetAttr) {
// SetAttr is a special case: it doesn't have a schema, but does
// have untracked mutations
return true;
}
// onnx export calls EliminateDeadCode but sometimes passes invalid
// aten operators, so we use maybeSchema() to handle nodes that have no
// valid schema.
auto schema = node->maybeSchema();
return schema && schema->is_mutable();
} else {
return getOrCreateAliasDb()->writesToWildcard(node);
}
}
bool hasSideEffects(Node* node) {
auto it = memo_.find(node);
if (it != memo_.end())
return it->second;
bool has_side_effects = node->hasSideEffects() ||
std::any_of(node->blocks().begin(),
node->blocks().end(),
[&](Block* b) {
return std::any_of(
b->nodes().begin(), b->nodes().end(), [&](Node* n) {
return hasSideEffects(n);
});
}) ||
hasUntrackedMutation(node);
memo_.emplace(node, has_side_effects);
return has_side_effects;
}
void removeDeadBlockOutputs(Node* node) {
if (node->kind() != prim::If && node->kind() != prim::GradOf) {
return;
}
for (size_t i_1 = node->outputs().size(); i_1 > 0; --i_1) {
size_t i = i_1 - 1;
if (!node->outputs().at(i)->hasUses()) {
GRAPH_UPDATE(
"Dead ",
i,
"-th output ",
node->outputs().at(i)->debugName(),
" of node ",
node->kind().toQualString(),
" will be removed");
node->eraseOutput(i);
for (Block* b : node->blocks()) {
GRAPH_UPDATE(
"\tCorresponding block output ",
b->outputs().at(i)->debugName(),
" will be removed");
b->eraseOutput(i);
}
}
}
}
void removeDeadLoopOutputs(Node* node) {
if (node->kind() != prim::Loop)
return;
auto loop_body = node->blocks().at(0);
auto loop_input_offset = 2; // offset of loop carried deps in input list
auto loop_body_offset =
1; // offset to the loop carried dependencies in block inputs/outputs
for (size_t i_1 = node->outputs().size(); i_1 > 0; --i_1) {
size_t i = i_1 - 1;
if (!node->outputs().at(i)->hasUses() &&
!loop_body->inputs().at(loop_body_offset + i)->hasUses()) {
logDeadLoopOutputs(node, i, loop_input_offset, loop_body_offset);
node->eraseOutput(i);
node->removeInput(loop_input_offset + i);
loop_body->eraseInput(loop_body_offset + i);
loop_body->eraseOutput(loop_body_offset + i);
}
}
}
void logDeadLoopOutputs(
Node* node,
size_t i,
size_t loop_input_offset,
size_t loop_body_offset) {
auto loop_body = node->blocks().at(0);
GRAPH_UPDATE(
"Dead ",
loop_input_offset + i,
"-th input ",
node->inputs().at(loop_input_offset + i)->debugName(),
" will be removed");
GRAPH_UPDATE(
"Dead ",
i,
"-th output ",
node->outputs().at(i)->debugName(),
" will be removed");
GRAPH_UPDATE(
"\tDead block input ",
loop_body->inputs().at(loop_body_offset + i)->debugName(),
" at offset ",
loop_body_offset + i,
" will be removed");
GRAPH_UPDATE(
"\tDead block output ",
loop_body->outputs().at(loop_body_offset + i)->debugName(),
" at offset ",
loop_body_offset + i,
" will be removed");
}
AliasDb* getOrCreateAliasDb() {
if (!aliasDb_) {
aliasDb_ = std::make_unique<AliasDb>(graph_);
}
return aliasDb_.get();
}
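// Lazily constructs the AliasDb-backed live set, which maintains the union of
// memory locations alongside the Value*s; used only when useAliasDb_ is true.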
ValueAndMemoryLocationSet& getLiveValuesAndMemoryLocations() {
if (!liveValuesAndMemoryLocations_) {
liveValuesAndMemoryLocations_ =
std::make_unique<ValueAndMemoryLocationSet>(
getOrCreateAliasDb()->getValueAndMemoryLocationSet());
}
return *liveValuesAndMemoryLocations_;
}
ValueSet& getLiveValuesSet() {
if (!liveValuesSet_) {
liveValuesSet_ = std::make_unique<ValueSet>();
}
return *liveValuesSet_;
}
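// The three helpers below dispatch between the two live-value
// representations, depending on whether alias analysis is in use.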
ValueSet& getLiveValues() {
if (useAliasDb_) {
return getLiveValuesAndMemoryLocations().getValueSet();
} else {
return getLiveValuesSet();
}
}
void insertLiveValue(Value* v) {
if (useAliasDb_) {
getLiveValuesAndMemoryLocations().insert(v);
} else {
getLiveValuesSet().insert(v);
}
}
bool liveValuesContains(Value* v) {
if (useAliasDb_) {
return getLiveValuesAndMemoryLocations().getValueSet().count(v);
} else {
return getLiveValuesSet().count(v);
}
}
DCESideEffectPolicy sideEffectPolicy_;
std::shared_ptr<Graph> graph_;
bool useAliasDb_ = false;
// lazily initialized
std::unique_ptr<AliasDb> aliasDb_ = nullptr;
std::unordered_map<Node*, bool> memo_;
std::unordered_set<Node*> marked_;
// we should have at most 1 of these as a non-nullptr; they are lazily
// initialized. liveValuesAndMemoryLocations_ is used if we are using AliasDb
// (in order to store aliasing info),
// otherwise liveValuesSet_ is used.
std::unique_ptr<ValueAndMemoryLocationSet> liveValuesAndMemoryLocations_ =
nullptr;
std::unique_ptr<ValueSet> liveValuesSet_ = nullptr;
std::function<void(const std::unordered_set<const Value*>&)> deleteCallback_ =
[](const std::unordered_set<const Value*>&) {};
};
void EliminateDeadCode(
const std::shared_ptr<Graph>& graph,
DCESideEffectPolicy sideEffectPolicy) {
DeadCodeEliminator(graph, sideEffectPolicy)
.run(graph->block(), /*recurse=*/true);
GRAPH_DUMP("After EliminateDeadCode: ", graph);
}
void EliminateDeadCode(
Block* block,
bool recurse,
DCESideEffectPolicy sideEffectPolicy) {
DeadCodeEliminator(sideEffectPolicy).run(block, recurse);
}
void EliminateDeadCode(
Block* block,
std::function<void(const std::unordered_set<const Value*>&)> cb,
DCESideEffectPolicy sideEffectPolicy) {
DeadCodeEliminator eliminator(sideEffectPolicy);
eliminator.setDeleteCallback(std::move(cb));
eliminator.run(block, /*recurse=*/true);
}
} // namespace torch::jit