mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
make wildcards alias only each other (#20670)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/20670 ghimport-source-id: f5704c49fcb829e4668441f31fcf9305da22335c Reviewed By: jamesr66a Differential Revision: D15447567 Pulled By: suo fbshipit-source-id: 391236806838de2524410e26946456441e562470
This commit is contained in:
committed by
Facebook Github Bot
parent
90910fc6cb
commit
7aa3887f43
@ -9,15 +9,33 @@
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
// Get a typekind that can be used as a key to distinguish different kinds of
|
||||
// mutable types. If the type is not mutable, return nullopt.
|
||||
//
|
||||
// TODO: We use these rules to divide wildcards into distinct "buckets", where
|
||||
// every wildcard that resolves to the same kind will alias each other. We can
|
||||
// introduce more granularity here (e.g. List[int] will never alias
|
||||
// List[float]).
|
||||
c10::optional<TypeKind> AliasDb::getMutableTypeKind(const TypePtr& type) {
|
||||
if (type->isSubtypeOf(TensorType::get())) {
|
||||
return TypeKind::TensorType;
|
||||
}
|
||||
|
||||
switch (type->kind()) {
|
||||
case TypeKind::ListType:
|
||||
case TypeKind::TupleType:
|
||||
case TypeKind::DictType:
|
||||
case TypeKind::ClassType:
|
||||
return type->kind();
|
||||
case TypeKind::OptionalType:
|
||||
return getMutableTypeKind(type->cast<OptionalType>()->getElementType());
|
||||
default:
|
||||
return c10::nullopt;
|
||||
}
|
||||
}
|
||||
|
||||
bool AliasDb::shouldAnnotate(const TypePtr& type) {
|
||||
return type->isSubtypeOf(TensorType::get()) ||
|
||||
type->kind() == TypeKind::ListType ||
|
||||
type->kind() == TypeKind::TupleType ||
|
||||
type->kind() == TypeKind::DictType || type->kind() == TypeKind::VarType ||
|
||||
type->kind() == TypeKind::FutureType ||
|
||||
type->kind() == TypeKind::ClassType ||
|
||||
(type->kind() == TypeKind::OptionalType &&
|
||||
shouldAnnotate(type->cast<OptionalType>()->getElementType()));
|
||||
return getMutableTypeKind(type) != c10::nullopt;
|
||||
}
|
||||
|
||||
// We only need to annotate values that either are mutable or could contain
|
||||
@ -43,48 +61,6 @@ AliasDb::AliasDb(std::shared_ptr<Graph> graph) : graph_(std::move(graph)) {
|
||||
analyze(graph_);
|
||||
}
|
||||
|
||||
// Does `n` use or write to any wildcard aliases?
|
||||
bool AliasDb::hasWildcard(const Node* n) const {
|
||||
for (const auto input : n->inputs()) {
|
||||
if (isWildcard(input)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
for (const auto output : n->outputs()) {
|
||||
if (isWildcard(output)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool AliasDb::isWildcard(const Value* v) const {
|
||||
return wildcards_.count(v);
|
||||
}
|
||||
|
||||
bool AliasDb::writesTo(Node* n, const Value* v) const {
|
||||
if (!shouldAnnotate(v) || v->mustBeNone()) {
|
||||
// This is a non-aliasing value
|
||||
return false;
|
||||
}
|
||||
if (isWildcard(v)) {
|
||||
return wildcardWriters_.count(n);
|
||||
}
|
||||
|
||||
if (!elementMap_.count(v) || !writeIndex_.count(n)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Can short-circuit if we know this node writes directly to `v`
|
||||
if (writeIndex_.at(n).count(v)) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Otherwise, check if `v` may alias any of written-to values in `n`
|
||||
const auto vSet = ValueSet{v};
|
||||
return mayAlias(vSet, writeIndex_.at(n));
|
||||
}
|
||||
|
||||
bool AliasDb::hasWriters(const Node* n) const {
|
||||
for (const auto input : n->inputs()) {
|
||||
if (hasWriters(input)) {
|
||||
@ -100,22 +76,10 @@ bool AliasDb::hasWriters(const Node* n) const {
|
||||
}
|
||||
|
||||
bool AliasDb::hasWriters(const Value* v) const {
|
||||
if (isWildcard(v)) {
|
||||
// If `n` has a wildcard, any write in the graph may write to it.
|
||||
// So the only way we know there are no writers is if there are no writes
|
||||
// at all.
|
||||
return numWrites_ != 0;
|
||||
}
|
||||
|
||||
if (!elementMap_.count(v) || v->mustBeNone()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (wildcardWriters_.size() > 0) {
|
||||
// A write to the wildcard may be a write to any value.
|
||||
return true;
|
||||
}
|
||||
|
||||
if (isWriteCacheStale_) {
|
||||
rebuildWriteCache();
|
||||
}
|
||||
@ -129,44 +93,6 @@ bool AliasDb::hasWriters(const Value* v) const {
|
||||
return false;
|
||||
}
|
||||
|
||||
bool AliasDb::hasWrites(Node* n) const {
|
||||
for (const auto input : n->inputs()) {
|
||||
if (writesTo(n, input)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
for (const auto output : n->outputs()) {
|
||||
if (writesTo(n, output)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool AliasDb::writesToInputAlias(Node* n) const {
|
||||
std::vector<const Value*> writes;
|
||||
for (const auto input : n->inputs()) {
|
||||
if (writesTo(n, input)) {
|
||||
writes.push_back(input);
|
||||
}
|
||||
}
|
||||
for (const auto output : n->outputs()) {
|
||||
if (writesTo(n, output)) {
|
||||
writes.push_back(output);
|
||||
}
|
||||
}
|
||||
|
||||
// For all writes, check if the written value may alias a graph input
|
||||
return std::any_of(writes.cbegin(), writes.cend(), [&](const Value* v) {
|
||||
return std::any_of(
|
||||
graph_->inputs().cbegin(),
|
||||
graph_->inputs().cend(),
|
||||
[&](const Value* graphInput) {
|
||||
return shouldAnnotate(graphInput) && mayAlias(graphInput, v);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
void AliasDb::getWritesImpl(Block* b, ValueSet& ret, bool recurseBlocks) const {
|
||||
for (auto node : b->nodes()) {
|
||||
getWritesImpl(node, ret, recurseBlocks);
|
||||
@ -174,14 +100,10 @@ void AliasDb::getWritesImpl(Block* b, ValueSet& ret, bool recurseBlocks) const {
|
||||
}
|
||||
|
||||
void AliasDb::getWritesImpl(Node* n, ValueSet& ret, bool recurseBlocks) const {
|
||||
for (const auto input : n->inputs()) {
|
||||
if (writesTo(n, input)) {
|
||||
ret.insert(input);
|
||||
}
|
||||
}
|
||||
for (const auto output : n->outputs()) {
|
||||
if (writesTo(n, output)) {
|
||||
ret.insert(output);
|
||||
if (writeIndex_.count(n)) {
|
||||
const auto& writes = writeIndex_.at(n);
|
||||
for (const auto write : writes) {
|
||||
ret.insert(write);
|
||||
}
|
||||
}
|
||||
|
||||
@ -192,13 +114,6 @@ void AliasDb::getWritesImpl(Node* n, ValueSet& ret, bool recurseBlocks) const {
|
||||
}
|
||||
}
|
||||
|
||||
// Get all writes by all nodes in a block, recursively exploring sub-blocks
|
||||
ValueSet AliasDb::getWrites(Block* b) const {
|
||||
ValueSet writes;
|
||||
getWritesImpl(b, writes, /*recurseBlocks=*/true);
|
||||
return writes;
|
||||
}
|
||||
|
||||
// Does `n` write to an alias of one of the values in `vs`?
|
||||
bool AliasDb::writesToAlias(Node* n, const ValueSet& vs, bool recurseBlocks)
|
||||
const {
|
||||
@ -236,6 +151,14 @@ ValueSet AliasDb::getReads(Node* n, bool recurseBlocks) const {
|
||||
return reads;
|
||||
}
|
||||
|
||||
static std::string getElementName(const Element* e) {
|
||||
if (e->value == nullptr) {
|
||||
return "WILDCARD";
|
||||
} else {
|
||||
return e->value->uniqueName();
|
||||
}
|
||||
}
|
||||
|
||||
void AliasDb::dump() const {
|
||||
std::cout << "\n===1. GRAPH===\n";
|
||||
graph_->dump();
|
||||
@ -244,28 +167,22 @@ void AliasDb::dump() const {
|
||||
for (const auto& ptrPair : elementMap_) {
|
||||
const auto element = ptrPair.second;
|
||||
if (element->pointsTo.size() > 0) {
|
||||
std::cout << element->value->uniqueName() << " points to: ";
|
||||
std::cout << getElementName(element) << " points to: ";
|
||||
for (const auto pointedTo : element->pointsTo) {
|
||||
std::cout << pointedTo->value->uniqueName() << ", ";
|
||||
std::cout << getElementName(pointedTo) << ", ";
|
||||
}
|
||||
std::cout << "\n";
|
||||
}
|
||||
if (element->contained_elements.size() > 0) {
|
||||
std::cout << element->value->uniqueName() << " contains: ";
|
||||
std::cout << getElementName(element) << " contains: ";
|
||||
for (const auto contained : element->contained_elements) {
|
||||
std::cout << contained->value->uniqueName() << ", ";
|
||||
std::cout << getElementName(contained) << ", ";
|
||||
}
|
||||
std::cout << "\n";
|
||||
}
|
||||
}
|
||||
|
||||
std::cout << "\n===3. WILDCARDS===\n";
|
||||
for (const auto wildcard : wildcards_) {
|
||||
std::cout << wildcard->uniqueName() << ", ";
|
||||
}
|
||||
std::cout << "\n";
|
||||
|
||||
std::cout << "\n===4. Writes===\n";
|
||||
std::cout << "\n===3. Writes===\n";
|
||||
for (const auto& pr : writeIndex_) {
|
||||
const auto node = pr.first;
|
||||
const auto& values = pr.second;
|
||||
@ -279,75 +196,10 @@ void AliasDb::dump() const {
|
||||
std::cout << "\n";
|
||||
}
|
||||
|
||||
// TODO: need to create a dummy "graph input alias" value in MemoryDAG for all
|
||||
// inputs of the same type to point to. Currently they all point to the first
|
||||
// element, which is technically wrong.
|
||||
void AliasDb::makeAllAlias(const std::vector<Value*>& values) {
|
||||
if (values.size() > 0) {
|
||||
giveFreshAlias(values[0]);
|
||||
}
|
||||
for (const auto value : values) {
|
||||
makePointerTo(value, values[0]);
|
||||
}
|
||||
}
|
||||
|
||||
void AliasDb::analyze(const std::shared_ptr<Graph>& graph) {
|
||||
// Assign aliases to the graph's inputs, assuming that all inputs of a given
|
||||
// type may alias to each other.
|
||||
|
||||
// 1. Partition inputs by their type
|
||||
std::map<TypeKind, std::vector<Value*>> listTypes;
|
||||
std::unordered_map<TupleTypePtr, std::vector<Value*>> tupleTypes;
|
||||
std::unordered_map<DictTypePtr, std::vector<Value*>> dictTypes;
|
||||
std::unordered_map<ClassTypePtr, std::vector<Value*>> classTypes;
|
||||
std::vector<Value*> tensors;
|
||||
|
||||
for (auto input : graph->inputs()) {
|
||||
auto inputType = input->type();
|
||||
// unwrap optional types
|
||||
if (inputType->kind() == TypeKind::OptionalType) {
|
||||
inputType = inputType->cast<OptionalType>()->getElementType();
|
||||
}
|
||||
|
||||
if (inputType->isSubtypeOf(TensorType::get())) {
|
||||
tensors.push_back(input);
|
||||
} else if (inputType->kind() == TypeKind::ListType) {
|
||||
auto containedType = inputType->containedTypes().at(0);
|
||||
// All tensor subtypes may alias to each other, so we should consider all
|
||||
// lists of them to alias to each other.
|
||||
if (containedType->isSubtypeOf(TensorType::get())) {
|
||||
containedType = TensorType::get();
|
||||
}
|
||||
listTypes[containedType->kind()].push_back(input);
|
||||
} else if (inputType->kind() == TypeKind::TupleType) {
|
||||
auto tupleType = inputType->cast<TupleType>();
|
||||
tupleTypes[tupleType].push_back(input);
|
||||
} else if (inputType->kind() == TypeKind::DictType) {
|
||||
auto dictType = inputType->cast<DictType>();
|
||||
dictTypes[dictType].push_back(input);
|
||||
} else if (inputType->kind() == TypeKind::ClassType) {
|
||||
auto classType = inputType->cast<ClassType>();
|
||||
classTypes[classType].push_back(input);
|
||||
} else {
|
||||
AT_ASSERT(!shouldAnnotate(input));
|
||||
}
|
||||
setWildcard(input);
|
||||
}
|
||||
|
||||
// 2. Make all partitions alias each other
|
||||
for (const auto& pr : listTypes) {
|
||||
makeAllAlias(pr.second);
|
||||
}
|
||||
for (const auto& pr : tupleTypes) {
|
||||
makeAllAlias(pr.second);
|
||||
}
|
||||
for (const auto& pr : dictTypes) {
|
||||
makeAllAlias(pr.second);
|
||||
}
|
||||
for (const auto& pr : classTypes) {
|
||||
makeAllAlias(pr.second);
|
||||
}
|
||||
makeAllAlias(tensors);
|
||||
|
||||
analyze(graph->block());
|
||||
}
|
||||
|
||||
@ -359,11 +211,6 @@ void AliasDb::analyze(Block* block) {
|
||||
|
||||
void AliasDb::analyze(Node* node) {
|
||||
analyzeImpl(node);
|
||||
|
||||
// After analyzing, update the wildcard index
|
||||
if (hasWildcard(node)) {
|
||||
wildcardNodes_.insert(node);
|
||||
}
|
||||
}
|
||||
|
||||
// Returns true if analysis was run using
|
||||
@ -587,15 +434,7 @@ void AliasDb::registerWrite(const Value* v, Node* n) {
|
||||
// don't need to register a write if the value isn't mutable
|
||||
return;
|
||||
}
|
||||
|
||||
numWrites_++;
|
||||
|
||||
if (isWildcard(v)) {
|
||||
wildcardWriters_.insert(n);
|
||||
return;
|
||||
}
|
||||
|
||||
AT_ASSERT(elementMap_.count(v));
|
||||
TORCH_INTERNAL_ASSERT(elementMap_.count(v));
|
||||
writeIndex_[n].insert(v);
|
||||
}
|
||||
|
||||
@ -646,9 +485,6 @@ void AliasDb::analyzeGradOf(Node* node) {
|
||||
|
||||
void AliasDb::analyzeSubgraph(Node* node) {
|
||||
const auto subgraph = node->g(attr::Subgraph).get();
|
||||
|
||||
subgraphToOwner_.insert({subgraph, node});
|
||||
|
||||
const auto subgraphBlock = subgraph->block();
|
||||
mapAliases(subgraphBlock->inputs(), node->inputs());
|
||||
|
||||
@ -674,9 +510,7 @@ void AliasDb::analyzeCreator(Node* node) {
|
||||
// gives up and creates wildcards for everything.
|
||||
void AliasDb::analyzeExtractor(Node* node) {
|
||||
for (const auto output : node->outputs()) {
|
||||
if (shouldAnnotate(output)) {
|
||||
setWildcard(output);
|
||||
}
|
||||
setWildcard(output);
|
||||
}
|
||||
}
|
||||
|
||||
@ -687,16 +521,10 @@ void AliasDb::analyzeChunk(Node* node) {
|
||||
}
|
||||
}
|
||||
|
||||
// Propagate aliasing and write information from the subgraph outputs to the
|
||||
// outputs of the corresponding aten::wait() calls, since that's where the
|
||||
// values will eventually emerge.
|
||||
void AliasDb::analyzeFork(Node* node) {
|
||||
const auto subgraph = node->g(attr::Subgraph).get();
|
||||
subgraphToOwner_.insert({subgraph, node});
|
||||
|
||||
const auto subgraphBlock = subgraph->block();
|
||||
mapAliases(subgraphBlock->inputs(), node->inputs());
|
||||
analyze(subgraphBlock);
|
||||
for (const auto input : node->inputs()) {
|
||||
setWildcard(input);
|
||||
}
|
||||
|
||||
// Give the future that the fork emits a fresh value
|
||||
for (const auto output : node->outputs()) {
|
||||
@ -705,51 +533,23 @@ void AliasDb::analyzeFork(Node* node) {
|
||||
}
|
||||
|
||||
void AliasDb::analyzeWait(Node* node) {
|
||||
const auto fut = node->input();
|
||||
AT_ASSERT(fut->type()->kind() == TypeKind::FutureType);
|
||||
|
||||
if (isWildcard(fut)) {
|
||||
for (const auto output : node->outputs()) {
|
||||
setWildcard(output);
|
||||
}
|
||||
return;
|
||||
TORCH_INTERNAL_ASSERT(node->kind() == aten::wait);
|
||||
for (const auto output : node->outputs()) {
|
||||
setWildcard(output);
|
||||
}
|
||||
|
||||
const auto originFuts = getMemoryLocations(fut);
|
||||
for (const auto originFut : originFuts) {
|
||||
const auto subgraphNode = originFut->node();
|
||||
|
||||
const auto subgraph = subgraphNode->g(attr::Subgraph).get();
|
||||
const auto subgraphWrites = getWrites(subgraph->block());
|
||||
|
||||
// Retrieve aliasing info from the subgraph
|
||||
mapAliases(node->outputs(), subgraph->outputs());
|
||||
|
||||
// Propagate write information to the `wait` node.
|
||||
//
|
||||
// We need to do this for all writes in the entire subgraph, so that we
|
||||
// disallow reorders past a call to "aten::wait".
|
||||
//
|
||||
// Consider the following Fork where the subgraph writes to %a:
|
||||
//
|
||||
// %c : Future[Tensor] = prim::Fork(%a, %b) <-- writes to %a
|
||||
// ...
|
||||
// aten::wait(%c)
|
||||
// aten::use(%a) <-- we can't move this node before the `wait` safely!
|
||||
//
|
||||
// Say we define the "live interval" of a fork the interval between the
|
||||
// `fork` and its first corresponding `wait` (inclusive).
|
||||
//
|
||||
// Any writes in the subgraph can happen at any point in the live interval,
|
||||
// so it's not safe to re-order any reads to those memory locations from
|
||||
// outside the live interval to inside.
|
||||
//
|
||||
// In reality, any reads *inside* the live interval are undefined behavior,
|
||||
// since the writes may or may not have been executed yet. But we'll let
|
||||
// users do that and shoot themselves in the foot for now.
|
||||
for (const auto write : subgraphWrites) {
|
||||
registerWrite(write, node);
|
||||
}
|
||||
// the forked subgraph that `wait` is waiting on may write to any of its
|
||||
// inputs. We don't have a reliable way of recovering the fork inputs, so
|
||||
// for safety we just register a write to every wildcard.
|
||||
for (const auto& pr : wildcardIndex_) {
|
||||
// TODO: Given the way the write query API is written, we can't regiser a
|
||||
// write directly against the wildcard element. So find a wildcard value in
|
||||
// the graph to write to.
|
||||
const auto el = pr.second;
|
||||
const auto& pointedFrom = el->pointedFrom;
|
||||
TORCH_INTERNAL_ASSERT(!pointedFrom.empty());
|
||||
const auto wildcardValue = (*pointedFrom.begin())->value;
|
||||
TORCH_INTERNAL_ASSERT(wildcardValue);
|
||||
registerWrite(wildcardValue, node);
|
||||
}
|
||||
}
|
||||
|
||||
@ -771,6 +571,9 @@ void AliasDb::analyzeSetAttr(Node* node) {
|
||||
const auto self = node->inputs().at(0);
|
||||
AT_ASSERT(self->type()->kind() == TypeKind::ClassType);
|
||||
registerWrite(self, node);
|
||||
// Also the value being set must become a wildcard.
|
||||
const auto newValue = node->inputs().at(1);
|
||||
setWildcard(newValue);
|
||||
}
|
||||
|
||||
// Custom ops may write to any input and produce wildcards
|
||||
@ -819,7 +622,7 @@ void AliasDb::analyzeBroadcastingChunk(Node* node) {
|
||||
}
|
||||
}
|
||||
|
||||
// Register the fact that `value` is a pointer to `to`
|
||||
// Register the fact that `from` is a pointer to `to`
|
||||
void AliasDb::makePointerTo(const Value* from, const Value* to) {
|
||||
if (!shouldAnnotate(from)) {
|
||||
AT_ASSERT(!shouldAnnotate(to));
|
||||
@ -840,13 +643,6 @@ void AliasDb::makePointerTo(const Value* from, const Value* to) {
|
||||
// At this point, we should be dealing with two mutable types.
|
||||
AT_ASSERT(shouldAnnotate(from) && shouldAnnotate(to));
|
||||
|
||||
// If either value is a wildcard, don't insert anything into the graph;
|
||||
// wildcards are tracked separately since they have different aliasing rules.
|
||||
if (isWildcard(to) || isWildcard(from)) {
|
||||
setWildcard(from);
|
||||
return;
|
||||
}
|
||||
|
||||
auto fromEl = getOrCreateElement(from);
|
||||
auto toEl = getOrCreateElement(to);
|
||||
|
||||
@ -860,11 +656,6 @@ void AliasDb::addToContainedElements(
|
||||
return;
|
||||
}
|
||||
|
||||
// wildcards tracked separately
|
||||
if (isWildcard(elem)) {
|
||||
return;
|
||||
}
|
||||
|
||||
AT_ASSERT(isContainerType(container->type()));
|
||||
|
||||
auto elemEl = getOrCreateElement(elem);
|
||||
@ -878,18 +669,10 @@ bool AliasDb::mayAlias(const Value* a, const Value* b) const {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (isWildcard(a) || isWildcard(b)) {
|
||||
return true;
|
||||
}
|
||||
|
||||
return memoryDAG_->mayAlias(elementMap_.at(a), elementMap_.at(b));
|
||||
}
|
||||
|
||||
bool AliasDb::cannotCheckAliasContainment(const Value* elem) const {
|
||||
if (isWildcard(elem)) {
|
||||
return true;
|
||||
}
|
||||
|
||||
if (isContainerType(elem->type())) {
|
||||
if (elem->node()->kind() != prim::TupleConstruct) {
|
||||
return true;
|
||||
@ -952,7 +735,7 @@ void AliasDb::giveFreshAlias(const Value* value) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (isTracked(value)) {
|
||||
if (elementMap_.count(value)) {
|
||||
// Inside a loop, we may have given a fresh alias to this value already, so
|
||||
// skip
|
||||
return;
|
||||
@ -962,16 +745,12 @@ void AliasDb::giveFreshAlias(const Value* value) {
|
||||
}
|
||||
|
||||
Element* AliasDb::getOrCreateElement(const Value* value) {
|
||||
if (!isTracked(value)) {
|
||||
if (!elementMap_.count(value)) {
|
||||
giveFreshAlias(value);
|
||||
}
|
||||
return elementMap_.at(value);
|
||||
}
|
||||
|
||||
bool AliasDb::isTracked(const Value* v) const {
|
||||
return isWildcard(v) || elementMap_.count(v);
|
||||
}
|
||||
|
||||
bool AliasDb::moveAfterTopologicallyValid(Node* n, Node* movePoint) {
|
||||
return tryMove(n, movePoint, MoveSide::AFTER, /*dryRun=*/false);
|
||||
}
|
||||
@ -1015,9 +794,6 @@ class AliasDb::WorkingSet {
|
||||
for (const auto& read : aliasDb_.getReads(n, /*recurseBlocks=*/true)) {
|
||||
reads_.insert(read);
|
||||
}
|
||||
if (aliasDb_.hasWildcard(n)) {
|
||||
numWildcards_++;
|
||||
}
|
||||
}
|
||||
|
||||
void eraseMover() {
|
||||
@ -1042,9 +818,6 @@ class AliasDb::WorkingSet {
|
||||
reads_.erase(it);
|
||||
}
|
||||
}
|
||||
if (aliasDb_.hasWildcard(mover)) {
|
||||
numWildcards_--;
|
||||
}
|
||||
nodes_.pop_front();
|
||||
}
|
||||
|
||||
@ -1071,18 +844,6 @@ class AliasDb::WorkingSet {
|
||||
}
|
||||
|
||||
bool hasMutabilityDependency(Node* n) const {
|
||||
// 1. Handle wildcard dependencies:
|
||||
// If the working set has a wildcard, `n` can't write to anything.
|
||||
if (numWildcards_ > 0 && aliasDb_.hasWrites(n)) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// If `n` has a wildcard, the working set can't write to anything.
|
||||
if (aliasDb_.hasWildcard(n) && writes_.size() > 0) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// 2. Handle regular mutable dependencies
|
||||
// Check that `n` does not write to anything used by the working set
|
||||
const auto nWrites = aliasDb_.getWrites(n, /*recurseBlocks=*/true);
|
||||
if (aliasDb_.mayAlias(nWrites, reads_)) {
|
||||
@ -1158,7 +919,6 @@ class AliasDb::WorkingSet {
|
||||
// Values written to by the working set => number of nodes writing to value
|
||||
std::unordered_multiset<const Value*> writes_;
|
||||
std::unordered_multiset<const Value*> reads_;
|
||||
size_t numWildcards_ = 0;
|
||||
};
|
||||
|
||||
// Try to move `toMove` before/after `movePoint` while preserving value
|
||||
@ -1274,53 +1034,16 @@ void AliasDb::move(Node* toMove, Node* movePoint, MoveSide moveSide) {
|
||||
}
|
||||
}
|
||||
|
||||
bool AliasDb::hasUntrackedEffects(Node* node) const {
|
||||
bool touchesWildcard = false;
|
||||
if (const auto lastWildcard = getLastWildcard()) {
|
||||
touchesWildcard = hasWrites(node) &&
|
||||
(isBeforeSameGraph(node, *lastWildcard) || node == *lastWildcard);
|
||||
bool AliasDb::writesToWildcard(Node* n) const {
|
||||
if (!writeIndex_.count(n)) {
|
||||
return false;
|
||||
}
|
||||
const auto& writes = writeIndex_.at(n);
|
||||
|
||||
return writesToInputAlias(node) || touchesWildcard;
|
||||
}
|
||||
|
||||
// Nodes must be in the same graph in order to do `isBefore` or `isAfter`. This
|
||||
// traverses the subgraph "chain" upward until we find two nodes that share an
|
||||
// owning graph.
|
||||
//
|
||||
// NOTE: this is n^2 in subgraph depth. Right now the maximum depth is like 2,
|
||||
// but if we ever do huge nested subgraphs we'll need to reconsider this.
|
||||
bool AliasDb::isBeforeSameGraph(const Node* a, const Node* b) const {
|
||||
auto lhs = a;
|
||||
while (true) {
|
||||
auto rhs = b;
|
||||
while (true) {
|
||||
if (lhs->owningGraph() == rhs->owningGraph()) {
|
||||
return lhs->isBefore(rhs);
|
||||
}
|
||||
if (!subgraphToOwner_.count(rhs->owningGraph())) {
|
||||
break;
|
||||
}
|
||||
rhs = subgraphToOwner_.at(rhs->owningGraph());
|
||||
}
|
||||
if (!subgraphToOwner_.count(lhs->owningGraph())) {
|
||||
break;
|
||||
}
|
||||
lhs = subgraphToOwner_.at(lhs->owningGraph());
|
||||
}
|
||||
AT_ASSERT(false);
|
||||
}
|
||||
|
||||
c10::optional<const Node*> AliasDb::getLastWildcard() const {
|
||||
auto it = std::max_element(
|
||||
wildcardNodes_.cbegin(),
|
||||
wildcardNodes_.cend(),
|
||||
[this](const Node* a, const Node* b) { return isBeforeSameGraph(a, b); });
|
||||
if (it != wildcardNodes_.end()) {
|
||||
return *it;
|
||||
} else {
|
||||
return c10::nullopt;
|
||||
}
|
||||
// For all writes, check if the written value is a wildcard
|
||||
return std::any_of(writes.cbegin(), writes.cend(), [&](const Value* v) {
|
||||
return mayAliasWildcard(v);
|
||||
});
|
||||
}
|
||||
|
||||
bool aliasAnalysisHasSpecialCaseFor(Symbol symbol) {
|
||||
@ -1379,12 +1102,51 @@ bool aliasAnalysisHasSpecialCaseFor(Symbol symbol) {
|
||||
return handled.count(symbol) || purposefully_not_handled.count(symbol);
|
||||
}
|
||||
|
||||
bool AliasDb::mayAliasWildcard(const Value* v) const {
|
||||
if (!shouldAnnotate(v)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (auto e = getWildcard(v->type())) {
|
||||
return memoryDAG_->mayAlias(elementMap_.at(v), e);
|
||||
}
|
||||
// There were no wildcards of this type, so return false.
|
||||
return false;
|
||||
}
|
||||
|
||||
// Search the wildcard index for an element that corresponds to the given type.
|
||||
Element* AliasDb::getOrCreateWildcard(const TypePtr& type) {
|
||||
TORCH_INTERNAL_ASSERT(shouldAnnotate(type));
|
||||
const auto kind = getMutableTypeKind(type);
|
||||
TORCH_INTERNAL_ASSERT(kind);
|
||||
|
||||
if (!wildcardIndex_.count(*kind)) {
|
||||
// create a new empty Element to stand in for the wildcard set.
|
||||
wildcardIndex_.emplace(*kind, memoryDAG_->makeFreshValue(nullptr));
|
||||
}
|
||||
return wildcardIndex_.at(*kind);
|
||||
}
|
||||
|
||||
// Search the wildcard index for an element that corresponds to the given type.
|
||||
// Const version returns nullptr
|
||||
Element* AliasDb::getWildcard(const TypePtr& type) const {
|
||||
TORCH_INTERNAL_ASSERT(shouldAnnotate(type));
|
||||
const auto kind = getMutableTypeKind(type);
|
||||
TORCH_INTERNAL_ASSERT(kind);
|
||||
if (!wildcardIndex_.count(*kind)) {
|
||||
return nullptr;
|
||||
}
|
||||
return wildcardIndex_.at(*kind);
|
||||
}
|
||||
|
||||
// Register `v` as a wildcard value.
|
||||
void AliasDb::setWildcard(const Value* v) {
|
||||
if (!shouldAnnotate(v)) {
|
||||
return;
|
||||
}
|
||||
wildcards_.insert(v);
|
||||
auto e = getOrCreateWildcard(v->type());
|
||||
TORCH_INTERNAL_ASSERT(e != nullptr);
|
||||
memoryDAG_->makePointerTo(getOrCreateElement(v), e);
|
||||
}
|
||||
|
||||
void AliasDb::rebuildWriteCache() const {
|
||||
@ -1399,17 +1161,5 @@ void AliasDb::rebuildWriteCache() const {
|
||||
}
|
||||
isWriteCacheStale_ = false;
|
||||
}
|
||||
|
||||
ValueSet AliasDb::getMemoryLocations(const Value* v) const {
|
||||
ValueSet ret;
|
||||
if (!elementMap_.count(v)) {
|
||||
return ret;
|
||||
}
|
||||
|
||||
for (const auto el : elementMap_.at(v)->getMemoryLocations()) {
|
||||
ret.insert(el->value);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
@ -38,7 +38,7 @@ class AliasDb {
|
||||
//
|
||||
// These nodes are considered not safe to eliminate or mutate under any
|
||||
// circumstances.
|
||||
bool hasUntrackedEffects(Node* n) const;
|
||||
bool writesToWildcard(Node* n) const;
|
||||
|
||||
// Does `n` write to an alias of one of the values in `vs`?
|
||||
// if `recurseBlocks` is true, consider writes on the nodes in `n`s sub-blocks
|
||||
@ -78,17 +78,6 @@ class AliasDb {
|
||||
if (a.empty() || b.empty()) {
|
||||
return false;
|
||||
}
|
||||
// Short-circuit for special case: if any value is a wildcard, the two sets
|
||||
// may alias
|
||||
if (std::any_of(
|
||||
a.cbegin(),
|
||||
a.cend(),
|
||||
[this](const Value* v) { return isWildcard(v); }) ||
|
||||
std::any_of(b.cbegin(), b.cend(), [this](const Value* v) {
|
||||
return isWildcard(v);
|
||||
})) {
|
||||
return true;
|
||||
}
|
||||
|
||||
T<Element*> aElements;
|
||||
for (const Value* v : a) {
|
||||
@ -139,14 +128,11 @@ class AliasDb {
|
||||
/**
|
||||
* Write and read internal API
|
||||
*/
|
||||
// Does `n` write to any alias sets?
|
||||
bool hasWrites(Node* n) const;
|
||||
// Get all the values that `n` writes to.
|
||||
// NOTE: this only returns values directly written to, not aliases thereof
|
||||
//
|
||||
// if `recurseBlocks` is true, gather writes on the nodes in `n`s sub-blocks
|
||||
ValueSet getWrites(Node* n, bool recurseBlocks = false) const;
|
||||
ValueSet getWrites(Block* b) const;
|
||||
void getWritesImpl(Block* b, ValueSet& ret, bool recurseBlocks = false) const;
|
||||
void getWritesImpl(Node* n, ValueSet& ret, bool recurseBlocks = false) const;
|
||||
// Do any nodes write to `v`s memory location?
|
||||
@ -158,26 +144,11 @@ class AliasDb {
|
||||
ValueSet getReads(Node* n, bool recurseBlocks = false) const;
|
||||
void getReadsImpl(Node* n, ValueSet& ret, bool recurseBlocks = false) const;
|
||||
|
||||
// Does `n` write to a value that may alias one of the graph inputs?
|
||||
bool writesToInputAlias(Node* n) const;
|
||||
// Does `n` write to `v` or any aliases of `v`?
|
||||
bool writesTo(Node* n, const Value* v) const;
|
||||
|
||||
/**
|
||||
* Wildcard methods
|
||||
*/
|
||||
// is `v` a wildcard?
|
||||
TORCH_API bool isWildcard(const Value* v) const;
|
||||
// Register `v` as a wildcard value.
|
||||
void setWildcard(const Value* v);
|
||||
// Get all nodes that write to a wildcard value.
|
||||
const std::unordered_set<Node*>& getWildcardWriters() const {
|
||||
return wildcardWriters_;
|
||||
}
|
||||
// Does `n` use or write to any wildcard aliases?
|
||||
bool hasWildcard(const Node* n) const;
|
||||
// Returns nullopt if there are no wildcard nodes
|
||||
c10::optional<const Node*> getLastWildcard() const;
|
||||
|
||||
// Is the element a wildcard or an unhandled container type,
|
||||
// or does the element contain an element for which that's true
|
||||
@ -220,38 +191,28 @@ class AliasDb {
|
||||
|
||||
static bool shouldAnnotate(const Value* v);
|
||||
static bool shouldAnnotate(const TypePtr& type);
|
||||
static c10::optional<TypeKind> getMutableTypeKind(const TypePtr& type);
|
||||
|
||||
static bool isContainerType(const TypePtr& type);
|
||||
|
||||
bool hasUsesAfter(Symbol alias, const Node* n) const;
|
||||
bool isBeforeSameGraph(const Node* lhs, const Node* rhs) const;
|
||||
|
||||
// Returns true iff `v` is part of the alias tracker/is a wildcard
|
||||
bool isTracked(const Value* v) const;
|
||||
|
||||
// Get the values that represent the memory locations that `v` may point to.
|
||||
// Return values are guaranteed to be "fresh" tensors--they do not point to
|
||||
// anything else.
|
||||
ValueSet getMemoryLocations(const Value* v) const;
|
||||
|
||||
std::shared_ptr<Graph> graph_;
|
||||
std::unordered_map<const Graph*, const Node*> subgraphToOwner_;
|
||||
|
||||
// The points-to graph that stores aliasing relationships
|
||||
std::unique_ptr<MemoryDAG> memoryDAG_;
|
||||
// Mapping of values to MemoryDAG elements
|
||||
std::unordered_map<const Value*, Element*> elementMap_;
|
||||
// All wildcard elements (one for each unique mutable type).
|
||||
std::map<TypeKind, Element*> wildcardIndex_;
|
||||
Element* getWildcard(const TypePtr& type) const;
|
||||
Element* getOrCreateWildcard(const TypePtr& type);
|
||||
bool mayAliasWildcard(const Value* v) const;
|
||||
|
||||
// All values that may point to a wildcard value.
|
||||
ValueSet wildcards_;
|
||||
// All nodes that write to a wildcard
|
||||
std::unordered_set<Node*> wildcardWriters_;
|
||||
// All nodes that contain a wildcard
|
||||
std::unordered_set<const Node*> wildcardNodes_;
|
||||
|
||||
// State for tracking write info
|
||||
size_t numWrites_ = 0;
|
||||
/**
|
||||
* State for tracking write info.
|
||||
*/
|
||||
// Map of nodes to the values that they write to
|
||||
std::unordered_map<Node*, ValueSet> writeIndex_;
|
||||
// Set of all memory locations that may have been written to.
|
||||
mutable std::unordered_set<const Element*> writeCache_;
|
||||
mutable bool isWriteCacheStale_ = true;
|
||||
void rebuildWriteCache() const;
|
||||
|
@ -202,7 +202,7 @@ class DeadCodeEliminator {
|
||||
auto schema = node->maybeSchema();
|
||||
return schema && schema->is_mutable();
|
||||
} else {
|
||||
return aliasDb_->hasUntrackedEffects(node);
|
||||
return aliasDb_->writesToWildcard(node);
|
||||
}
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user