make wildcards alias only each other (#20670)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/20670
ghimport-source-id: f5704c49fcb829e4668441f31fcf9305da22335c

Reviewed By: jamesr66a

Differential Revision: D15447567

Pulled By: suo

fbshipit-source-id: 391236806838de2524410e26946456441e562470
This commit is contained in:
Michael Suo
2019-05-22 16:42:33 -07:00
committed by Facebook Github Bot
parent 90910fc6cb
commit 7aa3887f43
4 changed files with 140 additions and 431 deletions

View File

@@ -717,11 +717,10 @@ graph():
// But we know `fresh` didn't go into a list, so x, y, and z should not
// alias it.
// auto fresh = vmap["fresh"];
// ASSERT_FALSE(aliasDb.mayAlias(x, fresh));
// ASSERT_FALSE(aliasDb.mayAlias(y, fresh));
// ASSERT_FALSE(aliasDb.mayAlias(z, fresh));
auto fresh = vmap["fresh"];
ASSERT_FALSE(aliasDb.mayAlias(x, fresh));
ASSERT_FALSE(aliasDb.mayAlias(y, fresh));
ASSERT_FALSE(aliasDb.mayAlias(z, fresh));
}
}
@@ -750,7 +749,7 @@ void testWildcards() {
AliasDb aliasDb(graph);
ASSERT_FALSE(aliasDb.mayAlias(a, fresh));
ASSERT_TRUE(aliasDb.mayAlias(wildcard, fresh));
ASSERT_FALSE(aliasDb.mayAlias(wildcard, fresh));
ASSERT_TRUE(aliasDb.mayAlias(wildcard, a));
ASSERT_FALSE(aliasDb.mayAlias(
std::unordered_set<const Value*>({wildcard}),
@@ -762,8 +761,7 @@ void testWildcards() {
{
graph->lint();
AliasDb aliasDb(graph);
// Any write should be considered a write to the wildcard
ASSERT_TRUE(aliasDb.hasWriters(wildcard->node()));
ASSERT_FALSE(aliasDb.hasWriters(wildcard->node()));
}
const auto wildcardWrite = graph->insert(writes, {wildcard})->node();
@@ -771,9 +769,9 @@ void testWildcards() {
graph->lint();
AliasDb aliasDb(graph);
// Test writes to wildcards
ASSERT_TRUE(aliasDb.writesToAlias(
ASSERT_FALSE(aliasDb.writesToAlias(
wildcardWrite, std::unordered_set<const Value*>{fresh}));
ASSERT_TRUE(aliasDb.writesToAlias(
ASSERT_FALSE(aliasDb.writesToAlias(
wildcardWrite, std::unordered_set<const Value*>{fresh2}));
ASSERT_TRUE(aliasDb.writesToAlias(
wildcardWrite, std::unordered_set<const Value*>{a}));
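
Taken together, these flipped assertions pin down the new contract: a wildcard may alias other wildcards of the same mutable type bucket, but never a provably fresh value. A minimal self-contained sketch of that rule; the names (`Val`, `mayAliasSketch`) are hypothetical stand-ins, not the real AliasDb API:

```cpp
#include <cassert>
#include <optional>

// Hypothetical miniature of the new rule: a value carries an optional
// "wildcard bucket" keyed by its mutable type kind; fresh values carry none.
enum class Kind { Tensor, List, Dict };

struct Val {
  int id;                      // identity of the value
  std::optional<Kind> bucket;  // set iff the value escaped (became a wildcard)
};

// Two values may alias if they are the same value, or if both are wildcards
// in the same bucket. A fresh value (no bucket) aliases only itself.
bool mayAliasSketch(const Val& a, const Val& b) {
  if (a.id == b.id) return true;
  return a.bucket && b.bucket && *a.bucket == *b.bucket;
}

int main() {
  Val wildcard{0, Kind::Tensor};
  Val other{1, Kind::Tensor};
  Val fresh{2, std::nullopt};
  assert(mayAliasSketch(wildcard, other));   // wildcards alias each other
  assert(!mayAliasSketch(wildcard, fresh));  // but never a fresh value
}
```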

View File

@@ -9,15 +9,33 @@
namespace torch {
namespace jit {
// Get a typekind that can be used as a key to distinguish different kinds of
// mutable types. If the type is not mutable, return nullopt.
//
// TODO: We use these rules to divide wildcards into distinct "buckets", where
// all wildcards that resolve to the same kind alias each other. We can
// introduce more granularity here (e.g. List[int] will never alias
// List[float]).
c10::optional<TypeKind> AliasDb::getMutableTypeKind(const TypePtr& type) {
if (type->isSubtypeOf(TensorType::get())) {
return TypeKind::TensorType;
}
switch (type->kind()) {
case TypeKind::ListType:
case TypeKind::TupleType:
case TypeKind::DictType:
case TypeKind::ClassType:
return type->kind();
case TypeKind::OptionalType:
return getMutableTypeKind(type->cast<OptionalType>()->getElementType());
default:
return c10::nullopt;
}
}
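
A few concrete expectations for this bucketing, shown as a runnable miniature; the enum and helper below are illustrative stand-ins, not the real TypeKind or TypePtr machinery:

```cpp
#include <cassert>
#include <optional>

// Illustrative miniature of the bucketing rule (hypothetical stand-ins, not
// the real TypeKind/TypePtr API): mutable kinds map to themselves, Optional
// unwraps to its element's bucket, and immutable kinds get no bucket at all.
enum class K { Tensor, List, Tuple, Dict, Class, Optional, Int };

struct Ty {
  K kind;
  const Ty* elem = nullptr;  // element type, used only for Optional
};

std::optional<K> mutableKind(const Ty& t) {
  switch (t.kind) {
    case K::Tensor:
    case K::List:
    case K::Tuple:
    case K::Dict:
    case K::Class:
      return t.kind;
    case K::Optional:
      return mutableKind(*t.elem);  // unwrap Optional recursively
    default:
      return std::nullopt;  // immutable: never annotated, never a wildcard
  }
}

int main() {
  Ty tensor{K::Tensor};
  Ty optTensor{K::Optional, &tensor};
  Ty intTy{K::Int};
  assert(mutableKind(optTensor) == K::Tensor);  // Optional[Tensor] buckets as Tensor
  assert(!mutableKind(intTy).has_value());      // int is immutable
}
```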
bool AliasDb::shouldAnnotate(const TypePtr& type) {
return type->isSubtypeOf(TensorType::get()) ||
type->kind() == TypeKind::ListType ||
type->kind() == TypeKind::TupleType ||
type->kind() == TypeKind::DictType || type->kind() == TypeKind::VarType ||
type->kind() == TypeKind::FutureType ||
type->kind() == TypeKind::ClassType ||
(type->kind() == TypeKind::OptionalType &&
shouldAnnotate(type->cast<OptionalType>()->getElementType()));
return getMutableTypeKind(type) != c10::nullopt;
}
// We only need to annotate values that either are mutable or could contain
@@ -43,48 +61,6 @@ AliasDb::AliasDb(std::shared_ptr<Graph> graph) : graph_(std::move(graph)) {
analyze(graph_);
}
// Does `n` use or write to any wildcard aliases?
bool AliasDb::hasWildcard(const Node* n) const {
for (const auto input : n->inputs()) {
if (isWildcard(input)) {
return true;
}
}
for (const auto output : n->outputs()) {
if (isWildcard(output)) {
return true;
}
}
return false;
}
bool AliasDb::isWildcard(const Value* v) const {
return wildcards_.count(v);
}
bool AliasDb::writesTo(Node* n, const Value* v) const {
if (!shouldAnnotate(v) || v->mustBeNone()) {
// This is a non-aliasing value
return false;
}
if (isWildcard(v)) {
return wildcardWriters_.count(n);
}
if (!elementMap_.count(v) || !writeIndex_.count(n)) {
return false;
}
// Can short-circuit if we know this node writes directly to `v`
if (writeIndex_.at(n).count(v)) {
return true;
}
// Otherwise, check if `v` may alias any of the values written to by `n`
const auto vSet = ValueSet{v};
return mayAlias(vSet, writeIndex_.at(n));
}
bool AliasDb::hasWriters(const Node* n) const {
for (const auto input : n->inputs()) {
if (hasWriters(input)) {
@@ -100,22 +76,10 @@ bool AliasDb::hasWriters(const Node* n) const {
}
bool AliasDb::hasWriters(const Value* v) const {
if (isWildcard(v)) {
// If `v` is a wildcard, any write in the graph may write to it.
// So the only way we know there are no writers is if there are no writes
// at all.
return numWrites_ != 0;
}
if (!elementMap_.count(v) || v->mustBeNone()) {
return false;
}
if (wildcardWriters_.size() > 0) {
// A write to the wildcard may be a write to any value.
return true;
}
if (isWriteCacheStale_) {
rebuildWriteCache();
}
@@ -129,44 +93,6 @@ bool AliasDb::hasWriters(const Value* v) const {
return false;
}
bool AliasDb::hasWrites(Node* n) const {
for (const auto input : n->inputs()) {
if (writesTo(n, input)) {
return true;
}
}
for (const auto output : n->outputs()) {
if (writesTo(n, output)) {
return true;
}
}
return false;
}
bool AliasDb::writesToInputAlias(Node* n) const {
std::vector<const Value*> writes;
for (const auto input : n->inputs()) {
if (writesTo(n, input)) {
writes.push_back(input);
}
}
for (const auto output : n->outputs()) {
if (writesTo(n, output)) {
writes.push_back(output);
}
}
// For all writes, check if the written value may alias a graph input
return std::any_of(writes.cbegin(), writes.cend(), [&](const Value* v) {
return std::any_of(
graph_->inputs().cbegin(),
graph_->inputs().cend(),
[&](const Value* graphInput) {
return shouldAnnotate(graphInput) && mayAlias(graphInput, v);
});
});
}
void AliasDb::getWritesImpl(Block* b, ValueSet& ret, bool recurseBlocks) const {
for (auto node : b->nodes()) {
getWritesImpl(node, ret, recurseBlocks);
@@ -174,14 +100,10 @@ void AliasDb::getWritesImpl(Block* b, ValueSet& ret, bool recurseBlocks) const {
}
void AliasDb::getWritesImpl(Node* n, ValueSet& ret, bool recurseBlocks) const {
for (const auto input : n->inputs()) {
if (writesTo(n, input)) {
ret.insert(input);
}
}
for (const auto output : n->outputs()) {
if (writesTo(n, output)) {
ret.insert(output);
if (writeIndex_.count(n)) {
const auto& writes = writeIndex_.at(n);
for (const auto write : writes) {
ret.insert(write);
}
}
@@ -192,13 +114,6 @@ void AliasDb::getWritesImpl(Node* n, ValueSet& ret, bool recurseBlocks) const {
}
}
// Get all writes by all nodes in a block, recursively exploring sub-blocks
ValueSet AliasDb::getWrites(Block* b) const {
ValueSet writes;
getWritesImpl(b, writes, /*recurseBlocks=*/true);
return writes;
}
// Does `n` write to an alias of one of the values in `vs`?
bool AliasDb::writesToAlias(Node* n, const ValueSet& vs, bool recurseBlocks)
const {
@@ -236,6 +151,14 @@ ValueSet AliasDb::getReads(Node* n, bool recurseBlocks) const {
return reads;
}
static std::string getElementName(const Element* e) {
if (e->value == nullptr) {
return "WILDCARD";
} else {
return e->value->uniqueName();
}
}
void AliasDb::dump() const {
std::cout << "\n===1. GRAPH===\n";
graph_->dump();
@@ -244,28 +167,22 @@ void AliasDb::dump() const {
for (const auto& ptrPair : elementMap_) {
const auto element = ptrPair.second;
if (element->pointsTo.size() > 0) {
std::cout << element->value->uniqueName() << " points to: ";
std::cout << getElementName(element) << " points to: ";
for (const auto pointedTo : element->pointsTo) {
std::cout << pointedTo->value->uniqueName() << ", ";
std::cout << getElementName(pointedTo) << ", ";
}
std::cout << "\n";
}
if (element->contained_elements.size() > 0) {
std::cout << element->value->uniqueName() << " contains: ";
std::cout << getElementName(element) << " contains: ";
for (const auto contained : element->contained_elements) {
std::cout << contained->value->uniqueName() << ", ";
std::cout << getElementName(contained) << ", ";
}
std::cout << "\n";
}
}
std::cout << "\n===3. WILDCARDS===\n";
for (const auto wildcard : wildcards_) {
std::cout << wildcard->uniqueName() << ", ";
}
std::cout << "\n";
std::cout << "\n===4. Writes===\n";
std::cout << "\n===3. Writes===\n";
for (const auto& pr : writeIndex_) {
const auto node = pr.first;
const auto& values = pr.second;
@@ -279,75 +196,10 @@ void AliasDb::dump() const {
std::cout << "\n";
}
// TODO: need to create a dummy "graph input alias" value in MemoryDAG for all
// inputs of the same type to point to. Currently they all point to the first
// element, which is technically wrong.
void AliasDb::makeAllAlias(const std::vector<Value*>& values) {
if (values.size() > 0) {
giveFreshAlias(values[0]);
}
for (const auto value : values) {
makePointerTo(value, values[0]);
}
}
void AliasDb::analyze(const std::shared_ptr<Graph>& graph) {
// Assign aliases to the graph's inputs, assuming that all inputs of a given
// type may alias to each other.
// 1. Partition inputs by their type
std::map<TypeKind, std::vector<Value*>> listTypes;
std::unordered_map<TupleTypePtr, std::vector<Value*>> tupleTypes;
std::unordered_map<DictTypePtr, std::vector<Value*>> dictTypes;
std::unordered_map<ClassTypePtr, std::vector<Value*>> classTypes;
std::vector<Value*> tensors;
for (auto input : graph->inputs()) {
auto inputType = input->type();
// unwrap optional types
if (inputType->kind() == TypeKind::OptionalType) {
inputType = inputType->cast<OptionalType>()->getElementType();
}
if (inputType->isSubtypeOf(TensorType::get())) {
tensors.push_back(input);
} else if (inputType->kind() == TypeKind::ListType) {
auto containedType = inputType->containedTypes().at(0);
// All tensor subtypes may alias to each other, so we should consider all
// lists of them to alias to each other.
if (containedType->isSubtypeOf(TensorType::get())) {
containedType = TensorType::get();
}
listTypes[containedType->kind()].push_back(input);
} else if (inputType->kind() == TypeKind::TupleType) {
auto tupleType = inputType->cast<TupleType>();
tupleTypes[tupleType].push_back(input);
} else if (inputType->kind() == TypeKind::DictType) {
auto dictType = inputType->cast<DictType>();
dictTypes[dictType].push_back(input);
} else if (inputType->kind() == TypeKind::ClassType) {
auto classType = inputType->cast<ClassType>();
classTypes[classType].push_back(input);
} else {
AT_ASSERT(!shouldAnnotate(input));
}
setWildcard(input);
}
// 2. Make all partitions alias each other
for (const auto& pr : listTypes) {
makeAllAlias(pr.second);
}
for (const auto& pr : tupleTypes) {
makeAllAlias(pr.second);
}
for (const auto& pr : dictTypes) {
makeAllAlias(pr.second);
}
for (const auto& pr : classTypes) {
makeAllAlias(pr.second);
}
makeAllAlias(tensors);
analyze(graph->block());
}
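
So the old partition-inputs-by-type-then-makeAllAlias dance collapses to a single setWildcard per graph input, with the per-kind wildcard buckets supplying the "same type may alias" behavior. A toy rendering of the new input handling; every name here is illustrative, not the real Graph/Value API:

```cpp
#include <cassert>
#include <string>
#include <vector>

// Toy contrast of the two schemes on graph inputs. Old: partition inputs by
// type and makeAllAlias within each partition. New: just setWildcard every
// input; the per-kind wildcard buckets provide "same type may alias".
struct Value {
  std::string typeName;  // e.g. "Tensor", "List[int]"
  bool escaped = false;  // did this value become a wildcard?
};

void setWildcardSketch(Value& v) { v.escaped = true; }

void analyzeInputsSketch(std::vector<Value>& graphInputs) {
  for (auto& in : graphInputs) {
    setWildcardSketch(in);  // replaces the old per-type partitioning entirely
  }
}

int main() {
  std::vector<Value> inputs = {{"Tensor"}, {"Tensor"}, {"List[int]"}};
  analyzeInputsSketch(inputs);
  for (const auto& in : inputs) {
    assert(in.escaped);  // every input conservatively joins its wildcard bucket
  }
}
```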
@@ -359,11 +211,6 @@ void AliasDb::analyze(Block* block) {
void AliasDb::analyze(Node* node) {
analyzeImpl(node);
// After analyzing, update the wildcard index
if (hasWildcard(node)) {
wildcardNodes_.insert(node);
}
}
// Returns true if analysis was run using
@@ -587,15 +434,7 @@ void AliasDb::registerWrite(const Value* v, Node* n) {
// don't need to register a write if the value isn't mutable
return;
}
numWrites_++;
if (isWildcard(v)) {
wildcardWriters_.insert(n);
return;
}
AT_ASSERT(elementMap_.count(v));
TORCH_INTERNAL_ASSERT(elementMap_.count(v));
writeIndex_[n].insert(v);
}
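
With wildcardWriters_ gone, every mutation is recorded uniformly in writeIndex_. A rough sketch of the index's shape and the registration path, with toy stand-ins for Node and Value:

```cpp
#include <cassert>
#include <unordered_map>
#include <unordered_set>

// Rough sketch of the unified write index after this change: every mutation,
// wildcard-aliasing or not, lands in one node -> written-values map.
struct Node {};
struct Value {};
using ValueSet = std::unordered_set<const Value*>;

std::unordered_map<Node*, ValueSet> writeIndex;

void registerWriteSketch(const Value* v, Node* n) {
  writeIndex[n].insert(v);  // no separate wildcardWriters_ path anymore
}

int main() {
  Node n;
  Value v;
  registerWriteSketch(&v, &n);
  assert(writeIndex.at(&n).count(&v) == 1);
}
```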
@@ -646,9 +485,6 @@ void AliasDb::analyzeGradOf(Node* node) {
void AliasDb::analyzeSubgraph(Node* node) {
const auto subgraph = node->g(attr::Subgraph).get();
subgraphToOwner_.insert({subgraph, node});
const auto subgraphBlock = subgraph->block();
mapAliases(subgraphBlock->inputs(), node->inputs());
@@ -674,9 +510,7 @@ void AliasDb::analyzeCreator(Node* node) {
// gives up and creates wildcards for everything.
void AliasDb::analyzeExtractor(Node* node) {
for (const auto output : node->outputs()) {
if (shouldAnnotate(output)) {
setWildcard(output);
}
setWildcard(output);
}
}
@@ -687,16 +521,10 @@ void AliasDb::analyzeChunk(Node* node) {
}
}
// Propagate aliasing and write information from the subgraph outputs to the
// outputs of the corresponding aten::wait() calls, since that's where the
// values will eventually emerge.
void AliasDb::analyzeFork(Node* node) {
const auto subgraph = node->g(attr::Subgraph).get();
subgraphToOwner_.insert({subgraph, node});
const auto subgraphBlock = subgraph->block();
mapAliases(subgraphBlock->inputs(), node->inputs());
analyze(subgraphBlock);
for (const auto input : node->inputs()) {
setWildcard(input);
}
// Give the future that the fork emits a fresh value
for (const auto output : node->outputs()) {
@@ -705,51 +533,23 @@
}
void AliasDb::analyzeWait(Node* node) {
const auto fut = node->input();
AT_ASSERT(fut->type()->kind() == TypeKind::FutureType);
if (isWildcard(fut)) {
for (const auto output : node->outputs()) {
setWildcard(output);
}
return;
TORCH_INTERNAL_ASSERT(node->kind() == aten::wait);
for (const auto output : node->outputs()) {
setWildcard(output);
}
const auto originFuts = getMemoryLocations(fut);
for (const auto originFut : originFuts) {
const auto subgraphNode = originFut->node();
const auto subgraph = subgraphNode->g(attr::Subgraph).get();
const auto subgraphWrites = getWrites(subgraph->block());
// Retrieve aliasing info from the subgraph
mapAliases(node->outputs(), subgraph->outputs());
// Propagate write information to the `wait` node.
//
// We need to do this for all writes in the entire subgraph, so that we
// disallow reorders past a call to "aten::wait".
//
// Consider the following Fork where the subgraph writes to %a:
//
// %c : Future[Tensor] = prim::Fork(%a, %b) <-- writes to %a
// ...
// aten::wait(%c)
// aten::use(%a) <-- we can't move this node before the `wait` safely!
//
// Say we define the "live interval" of a fork as the interval between the
// `fork` and its first corresponding `wait` (inclusive).
//
// Any writes in the subgraph can happen at any point in the live interval,
// so it's not safe to re-order any reads to those memory locations from
// outside the live interval to inside.
//
// In reality, any reads *inside* the live interval are undefined behavior,
// since the writes may or may not have been executed yet. But we'll let
// users do that and shoot themselves in the foot for now.
for (const auto write : subgraphWrites) {
registerWrite(write, node);
}
// the forked subgraph that `wait` is waiting on may write to any of its
// inputs. We don't have a reliable way of recovering the fork inputs, so
// for safety we just register a write to every wildcard.
for (const auto& pr : wildcardIndex_) {
// TODO: Given the way the write query API is written, we can't register a
// write directly against the wildcard element. So find a wildcard value in
// the graph to write to.
const auto el = pr.second;
const auto& pointedFrom = el->pointedFrom;
TORCH_INTERNAL_ASSERT(!pointedFrom.empty());
const auto wildcardValue = (*pointedFrom.begin())->value;
TORCH_INTERNAL_ASSERT(wildcardValue);
registerWrite(wildcardValue, node);
}
}
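
The loop above works around the write-API limitation noted in the TODO: it cannot name a wildcard Element directly, so it registers the write against one representative value that points at each bucket. A toy model of that workaround; the names are illustrative, not the real API:

```cpp
#include <cassert>
#include <map>
#include <vector>

// Toy model of the workaround: a write must name a Value, so for each
// wildcard bucket we take one value that points at it and register the write
// against that representative.
enum class TypeKind { Tensor, List };

struct Value { int id; };
struct Element {
  std::vector<const Value*> pointedFrom;  // values pointing at this wildcard
};

std::map<TypeKind, Element*> wildcardIndex;
std::vector<const Value*> registeredWrites;

void registerWriteSketch(const Value* v) { registeredWrites.push_back(v); }

void analyzeWaitSketch() {
  for (const auto& pr : wildcardIndex) {
    const Element* el = pr.second;
    assert(!el->pointedFrom.empty());              // mirrors TORCH_INTERNAL_ASSERT
    registerWriteSketch(el->pointedFrom.front());  // write via a representative
  }
}

int main() {
  Value v{1};
  Element tensorWildcard{{&v}};
  wildcardIndex[TypeKind::Tensor] = &tensorWildcard;
  analyzeWaitSketch();
  assert(registeredWrites.size() == 1);
}
```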
@@ -771,6 +571,9 @@ void AliasDb::analyzeSetAttr(Node* node) {
const auto self = node->inputs().at(0);
AT_ASSERT(self->type()->kind() == TypeKind::ClassType);
registerWrite(self, node);
// Also the value being set must become a wildcard.
const auto newValue = node->inputs().at(1);
setWildcard(newValue);
}
// Custom ops may write to any input and produce wildcards
@@ -819,7 +622,7 @@ void AliasDb::analyzeBroadcastingChunk(Node* node) {
}
}
// Register the fact that `value` is a pointer to `to`
// Register the fact that `from` is a pointer to `to`
void AliasDb::makePointerTo(const Value* from, const Value* to) {
if (!shouldAnnotate(from)) {
AT_ASSERT(!shouldAnnotate(to));
@@ -840,13 +643,6 @@ void AliasDb::makePointerTo(const Value* from, const Value* to) {
// At this point, we should be dealing with two mutable types.
AT_ASSERT(shouldAnnotate(from) && shouldAnnotate(to));
// If either value is a wildcard, don't insert anything into the graph;
// wildcards are tracked separately since they have different aliasing rules.
if (isWildcard(to) || isWildcard(from)) {
setWildcard(from);
return;
}
auto fromEl = getOrCreateElement(from);
auto toEl = getOrCreateElement(to);
@@ -860,11 +656,6 @@ void AliasDb::addToContainedElements(
return;
}
// wildcards tracked separately
if (isWildcard(elem)) {
return;
}
AT_ASSERT(isContainerType(container->type()));
auto elemEl = getOrCreateElement(elem);
@@ -878,18 +669,10 @@ bool AliasDb::mayAlias(const Value* a, const Value* b) const {
return false;
}
if (isWildcard(a) || isWildcard(b)) {
return true;
}
return memoryDAG_->mayAlias(elementMap_.at(a), elementMap_.at(b));
}
bool AliasDb::cannotCheckAliasContainment(const Value* elem) const {
if (isWildcard(elem)) {
return true;
}
if (isContainerType(elem->type())) {
if (elem->node()->kind() != prim::TupleConstruct) {
return true;
@@ -952,7 +735,7 @@ void AliasDb::giveFreshAlias(const Value* value) {
return;
}
if (isTracked(value)) {
if (elementMap_.count(value)) {
// Inside a loop, we may have given a fresh alias to this value already, so
// skip
return;
@@ -962,16 +745,12 @@ void AliasDb::giveFreshAlias(const Value* value) {
}
Element* AliasDb::getOrCreateElement(const Value* value) {
if (!isTracked(value)) {
if (!elementMap_.count(value)) {
giveFreshAlias(value);
}
return elementMap_.at(value);
}
bool AliasDb::isTracked(const Value* v) const {
return isWildcard(v) || elementMap_.count(v);
}
bool AliasDb::moveAfterTopologicallyValid(Node* n, Node* movePoint) {
return tryMove(n, movePoint, MoveSide::AFTER, /*dryRun=*/false);
}
@@ -1015,9 +794,6 @@ class AliasDb::WorkingSet {
for (const auto& read : aliasDb_.getReads(n, /*recurseBlocks=*/true)) {
reads_.insert(read);
}
if (aliasDb_.hasWildcard(n)) {
numWildcards_++;
}
}
void eraseMover() {
@@ -1042,9 +818,6 @@
reads_.erase(it);
}
}
if (aliasDb_.hasWildcard(mover)) {
numWildcards_--;
}
nodes_.pop_front();
}
@@ -1071,18 +844,6 @@
}
bool hasMutabilityDependency(Node* n) const {
// 1. Handle wildcard dependencies:
// If the working set has a wildcard, `n` can't write to anything.
if (numWildcards_ > 0 && aliasDb_.hasWrites(n)) {
return true;
}
// If `n` has a wildcard, the working set can't write to anything.
if (aliasDb_.hasWildcard(n) && writes_.size() > 0) {
return true;
}
// 2. Handle regular mutable dependencies
// Check that `n` does not write to anything used by the working set
const auto nWrites = aliasDb_.getWrites(n, /*recurseBlocks=*/true);
if (aliasDb_.mayAlias(nWrites, reads_)) {
@@ -1158,7 +919,6 @@ class AliasDb::WorkingSet {
// Values written to by the working set => number of nodes writing to value
std::unordered_multiset<const Value*> writes_;
std::unordered_multiset<const Value*> reads_;
size_t numWildcards_ = 0;
};
// Try to move `toMove` before/after `movePoint` while preserving value
@@ -1274,53 +1034,16 @@ void AliasDb::move(Node* toMove, Node* movePoint, MoveSide moveSide) {
}
}
bool AliasDb::hasUntrackedEffects(Node* node) const {
bool touchesWildcard = false;
if (const auto lastWildcard = getLastWildcard()) {
touchesWildcard = hasWrites(node) &&
(isBeforeSameGraph(node, *lastWildcard) || node == *lastWildcard);
bool AliasDb::writesToWildcard(Node* n) const {
if (!writeIndex_.count(n)) {
return false;
}
const auto& writes = writeIndex_.at(n);
return writesToInputAlias(node) || touchesWildcard;
}
// Nodes must be in the same graph in order to do `isBefore` or `isAfter`. This
// traverses the subgraph "chain" upward until we find two nodes that share an
// owning graph.
//
// NOTE: this is n^2 in subgraph depth. Right now the maximum depth is like 2,
// but if we ever do huge nested subgraphs we'll need to reconsider this.
bool AliasDb::isBeforeSameGraph(const Node* a, const Node* b) const {
auto lhs = a;
while (true) {
auto rhs = b;
while (true) {
if (lhs->owningGraph() == rhs->owningGraph()) {
return lhs->isBefore(rhs);
}
if (!subgraphToOwner_.count(rhs->owningGraph())) {
break;
}
rhs = subgraphToOwner_.at(rhs->owningGraph());
}
if (!subgraphToOwner_.count(lhs->owningGraph())) {
break;
}
lhs = subgraphToOwner_.at(lhs->owningGraph());
}
AT_ASSERT(false);
}
c10::optional<const Node*> AliasDb::getLastWildcard() const {
auto it = std::max_element(
wildcardNodes_.cbegin(),
wildcardNodes_.cend(),
[this](const Node* a, const Node* b) { return isBeforeSameGraph(a, b); });
if (it != wildcardNodes_.end()) {
return *it;
} else {
return c10::nullopt;
}
// For all writes, check if the written value is a wildcard
return std::any_of(writes.cbegin(), writes.cend(), [&](const Value* v) {
return mayAliasWildcard(v);
});
}
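
In other words, the broad hasUntrackedEffects query shrinks to "does this node write to anything that may reach a wildcard". A toy version of the any_of check, with mayAliasWildcard stubbed as a flag on the value; the real query walks the MemoryDAG instead:

```cpp
#include <algorithm>
#include <cassert>
#include <unordered_map>
#include <unordered_set>

// Toy version of writesToWildcard: does any value written by `n` reach a
// wildcard?
struct Node {};
struct Value { bool aliasesWildcard = false; };
using ValueSet = std::unordered_set<const Value*>;

std::unordered_map<Node*, ValueSet> writeIndex;

bool writesToWildcardSketch(Node* n) {
  auto it = writeIndex.find(n);
  if (it == writeIndex.end()) {
    return false;  // `n` writes nothing at all
  }
  const ValueSet& writes = it->second;
  return std::any_of(writes.cbegin(), writes.cend(),
                     [](const Value* v) { return v->aliasesWildcard; });
}

int main() {
  Node n;
  Value v{/*aliasesWildcard=*/true};
  writeIndex[&n].insert(&v);
  assert(writesToWildcardSketch(&n));
}
```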
bool aliasAnalysisHasSpecialCaseFor(Symbol symbol) {
@@ -1379,12 +1102,51 @@ bool aliasAnalysisHasSpecialCaseFor(Symbol symbol) {
return handled.count(symbol) || purposefully_not_handled.count(symbol);
}
bool AliasDb::mayAliasWildcard(const Value* v) const {
if (!shouldAnnotate(v)) {
return false;
}
if (auto e = getWildcard(v->type())) {
return memoryDAG_->mayAlias(elementMap_.at(v), e);
}
// There were no wildcards of this type, so return false.
return false;
}
// Search the wildcard index for an element that corresponds to the given type.
Element* AliasDb::getOrCreateWildcard(const TypePtr& type) {
TORCH_INTERNAL_ASSERT(shouldAnnotate(type));
const auto kind = getMutableTypeKind(type);
TORCH_INTERNAL_ASSERT(kind);
if (!wildcardIndex_.count(*kind)) {
// create a new empty Element to stand in for the wildcard set.
wildcardIndex_.emplace(*kind, memoryDAG_->makeFreshValue(nullptr));
}
return wildcardIndex_.at(*kind);
}
// Search the wildcard index for an element that corresponds to the given type.
// Const version; returns nullptr if no wildcard of that kind exists.
Element* AliasDb::getWildcard(const TypePtr& type) const {
TORCH_INTERNAL_ASSERT(shouldAnnotate(type));
const auto kind = getMutableTypeKind(type);
TORCH_INTERNAL_ASSERT(kind);
if (!wildcardIndex_.count(*kind)) {
return nullptr;
}
return wildcardIndex_.at(*kind);
}
// Register `v` as a wildcard value.
void AliasDb::setWildcard(const Value* v) {
if (!shouldAnnotate(v)) {
return;
}
wildcards_.insert(v);
auto e = getOrCreateWildcard(v->type());
TORCH_INTERNAL_ASSERT(e != nullptr);
memoryDAG_->makePointerTo(getOrCreateElement(v), e);
}
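
getOrCreateWildcard and setWildcard together maintain one shared Element per mutable TypeKind and point escaping values at it, which is what makes wildcards alias only each other. A compact self-contained model of that machinery (no MemoryDAG; names are illustrative):

```cpp
#include <cassert>
#include <map>
#include <set>

// Self-contained model of the wildcard index: one shared Element per mutable
// type kind; an escaping value points at its kind's Element, so two escaped
// values "may alias" exactly when they share a bucket.
enum class TypeKind { Tensor, List, Dict };

struct Element { std::set<const Element*> pointsTo; };

std::map<TypeKind, Element> wildcardIndex;

Element& getOrCreateWildcardSketch(TypeKind k) {
  return wildcardIndex[k];  // default-constructs the bucket on first use
}

void setWildcardSketch(Element& valueEl, TypeKind k) {
  valueEl.pointsTo.insert(&getOrCreateWildcardSketch(k));  // value escapes
}

bool shareWildcard(const Element& a, const Element& b, TypeKind k) {
  const Element* w = &wildcardIndex.at(k);
  return a.pointsTo.count(w) && b.pointsTo.count(w);
}

int main() {
  Element x, y, fresh;
  setWildcardSketch(x, TypeKind::Tensor);
  setWildcardSketch(y, TypeKind::Tensor);
  assert(shareWildcard(x, y, TypeKind::Tensor));  // same bucket: may alias
  assert(fresh.pointsTo.empty());                 // fresh value: no bucket
}
```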
void AliasDb::rebuildWriteCache() const {
@@ -1399,17 +1161,5 @@ void AliasDb::rebuildWriteCache() const {
}
isWriteCacheStale_ = false;
}
ValueSet AliasDb::getMemoryLocations(const Value* v) const {
ValueSet ret;
if (!elementMap_.count(v)) {
return ret;
}
for (const auto el : elementMap_.at(v)->getMemoryLocations()) {
ret.insert(el->value);
}
return ret;
}
} // namespace jit
} // namespace torch

View File

@@ -38,7 +38,7 @@ class AliasDb {
//
// These nodes are considered not safe to eliminate or mutate under any
// circumstances.
bool hasUntrackedEffects(Node* n) const;
bool writesToWildcard(Node* n) const;
// Does `n` write to an alias of one of the values in `vs`?
// if `recurseBlocks` is true, consider writes on the nodes in `n`s sub-blocks
@@ -78,17 +78,6 @@
if (a.empty() || b.empty()) {
return false;
}
// Short-circuit for special case: if any value is a wildcard, the two sets
// may alias
if (std::any_of(
a.cbegin(),
a.cend(),
[this](const Value* v) { return isWildcard(v); }) ||
std::any_of(b.cbegin(), b.cend(), [this](const Value* v) {
return isWildcard(v);
})) {
return true;
}
T<Element*> aElements;
for (const Value* v : a) {
@@ -139,14 +128,11 @@
/**
* Write and read internal API
*/
// Does `n` write to any alias sets?
bool hasWrites(Node* n) const;
// Get all the values that `n` writes to.
// NOTE: this only returns values directly written to, not aliases thereof
//
// if `recurseBlocks` is true, gather writes on the nodes in `n`s sub-blocks
ValueSet getWrites(Node* n, bool recurseBlocks = false) const;
ValueSet getWrites(Block* b) const;
void getWritesImpl(Block* b, ValueSet& ret, bool recurseBlocks = false) const;
void getWritesImpl(Node* n, ValueSet& ret, bool recurseBlocks = false) const;
// Do any nodes write to `v`s memory location?
@@ -158,26 +144,11 @@
ValueSet getReads(Node* n, bool recurseBlocks = false) const;
void getReadsImpl(Node* n, ValueSet& ret, bool recurseBlocks = false) const;
// Does `n` write to a value that may alias one of the graph inputs?
bool writesToInputAlias(Node* n) const;
// Does `n` write to `v` or any aliases of `v`?
bool writesTo(Node* n, const Value* v) const;
/**
* Wildcard methods
*/
// is `v` a wildcard?
TORCH_API bool isWildcard(const Value* v) const;
// Register `v` as a wildcard value.
void setWildcard(const Value* v);
// Get all nodes that write to a wildcard value.
const std::unordered_set<Node*>& getWildcardWriters() const {
return wildcardWriters_;
}
// Does `n` use or write to any wildcard aliases?
bool hasWildcard(const Node* n) const;
// Returns nullopt if there are no wildcard nodes
c10::optional<const Node*> getLastWildcard() const;
// Is the element a wildcard or an unhandled container type,
// or does the element contain an element for which that's true
@@ -220,38 +191,28 @@
static bool shouldAnnotate(const Value* v);
static bool shouldAnnotate(const TypePtr& type);
static c10::optional<TypeKind> getMutableTypeKind(const TypePtr& type);
static bool isContainerType(const TypePtr& type);
bool hasUsesAfter(Symbol alias, const Node* n) const;
bool isBeforeSameGraph(const Node* lhs, const Node* rhs) const;
// Returns true iff `v` is part of the alias tracker/is a wildcard
bool isTracked(const Value* v) const;
// Get the values that represent the memory locations that `v` may point to.
// Return values are guaranteed to be "fresh" tensors--they do not point to
// anything else.
ValueSet getMemoryLocations(const Value* v) const;
std::shared_ptr<Graph> graph_;
std::unordered_map<const Graph*, const Node*> subgraphToOwner_;
// The points-to graph that stores aliasing relationships
std::unique_ptr<MemoryDAG> memoryDAG_;
// Mapping of values to MemoryDAG elements
std::unordered_map<const Value*, Element*> elementMap_;
// All wildcard elements (one for each unique mutable type).
std::map<TypeKind, Element*> wildcardIndex_;
Element* getWildcard(const TypePtr& type) const;
Element* getOrCreateWildcard(const TypePtr& type);
bool mayAliasWildcard(const Value* v) const;
// All values that may point to a wildcard value.
ValueSet wildcards_;
// All nodes that write to a wildcard
std::unordered_set<Node*> wildcardWriters_;
// All nodes that contain a wildcard
std::unordered_set<const Node*> wildcardNodes_;
// State for tracking write info
size_t numWrites_ = 0;
/**
* State for tracking write info.
*/
// Map of nodes to the values that they write to
std::unordered_map<Node*, ValueSet> writeIndex_;
// Set of all memory locations that may have been written to.
mutable std::unordered_set<const Element*> writeCache_;
mutable bool isWriteCacheStale_ = true;
void rebuildWriteCache() const;

View File

@@ -202,7 +202,7 @@ class DeadCodeEliminator {
auto schema = node->maybeSchema();
return schema && schema->is_mutable();
} else {
return aliasDb_->hasUntrackedEffects(node);
return aliasDb_->writesToWildcard(node);
}
}