mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Summary: Separating the CUDA fuser from the CPU fuser. 1. New node in IR - prim::CudaFusionGroup: This enables the CUDA fuser to coexist alongside the old fuser and allows us to incrementally build and expand the CUDA fuser. 2. Copied the FuseGraph optimization passes to CudaFuserGraph: We will refactor and reuse Chunk/Concat from the old fuser logic, which is handled in the optimization pass at this moment. Unfortunately, much of the code in the pass is tightly bound to the legacy fuser, which makes code sharing difficult. The CudaFusionGraph will support only a subset of operations compared to the legacy fuser (CUDA only). It is registered as a custom pass post fusion via ```torch._C._jit_register_cuda_fuser()``` To have it take effect, you should also turn off fusion on GPU via ```torch._C._jit_override_can_fuse_on_gpu(False)``` 3. We don't have codegen in this PR yet (WIP). Currently we just fall back to the old fuser. Pull Request resolved: https://github.com/pytorch/pytorch/pull/33527 Differential Revision: D20171598 Pulled By: ZolotukhinM fbshipit-source-id: 9a3c0f06f46da7eaa80ae7551c04869f5b03ef71
1357 lines
42 KiB
C++
1357 lines
42 KiB
C++
#include <torch/csrc/jit/ir/alias_analysis.h>
|
|
|
|
#include <torch/csrc/jit/jit_log.h>
|
|
#include <torch/csrc/jit/runtime/operator.h>
|
|
#include <torch/csrc/utils/memory.h>
|
|
|
|
namespace torch {
|
|
namespace jit {
|
|
|
|
// For any mutable type, map it to a type such that all other types which it can
|
|
// alias will be mapped to the same type. This function follows a similar logic
|
|
// to `unifyTypes` because any two mutable types which can be unified
|
|
// can alias each other.
|
|
// getMutableTypePtr(Optional[List[int]]) == getMutableTypePtr([List[int]])
|
|
// If a type is not mutable, return nullopt
|
|
c10::optional<TypePtr> getMutableTypePtr(const TypePtr& type) {
|
|
switch (type->kind()) {
|
|
case TypeKind::ListType:
|
|
case TypeKind::DictType:
|
|
case TypeKind::ClassType:
|
|
case TypeKind::TensorType:
|
|
return unshapedType(type);
|
|
case TypeKind::OptionalType:
|
|
return getMutableTypePtr(type->cast<OptionalType>()->getElementType());
|
|
case TypeKind::FutureType: {
|
|
if (auto elem = getMutableTypePtr(type->cast<FutureType>()->getElementType())) {
|
|
return FutureType::create(*elem);
|
|
}
|
|
return c10::nullopt;
|
|
}
|
|
case TypeKind::TupleType: {
|
|
std::vector<TypePtr> mutable_types;
|
|
for (const auto& elem : type->expect<TupleType>()->elements()) {
|
|
if (auto mut_elem = getMutableTypePtr(elem)) {
|
|
mutable_types.push_back(*mut_elem);
|
|
}
|
|
}
|
|
if (mutable_types.size() == 0) {
|
|
return c10::nullopt;
|
|
} else {
|
|
return TupleType::create(mutable_types);
|
|
}
|
|
}
|
|
default:
|
|
return c10::nullopt;
|
|
}
|
|
}
|
|
|
|
bool AliasDb::mutableType(const TypePtr& type) {
|
|
return getMutableTypePtr(type) != c10::nullopt;
|
|
}
|
|
|
|
// We only need to annotate values that either are mutable or could contain
|
|
// mutable types.
|
|
bool AliasDb::mutableType(const Value* v) {
|
|
return mutableType(v->type());
|
|
}
|
|
|
|
bool AliasDb::isContainerType(const TypePtr& type) {
|
|
auto mut_type = getMutableTypePtr(type);
|
|
return mut_type && (*mut_type)->containedTypes().size() > 0;
|
|
}
|
|
|
|
// Defaulted destructor, defined out of line in this translation unit.
// NOTE(review): presumably kept here so that members such as memoryDAG_ can be
// destroyed where their types are complete -- confirm against the header.
AliasDb::~AliasDb() = default;
|
|
|
|
// Builds the alias database for `graph`: takes ownership of the graph
// pointer, creates an empty MemoryDAG, runs the analysis over the whole graph
// and finally hands the textual dump to GRAPH_DEBUG.
AliasDb::AliasDb(std::shared_ptr<Graph> graph, bool isFrozen)
    : graph_(std::move(graph)), isFrozen_(isFrozen) {
  // The DAG has to exist before analysis starts adding elements to it.
  memoryDAG_ = torch::make_unique<MemoryDAG>();

  analyze(graph_);

  GRAPH_DEBUG(toString());
}
|
|
|
|
// Returns whether `n` writes to an alias of any of its own inputs.
bool AliasDb::isMutable(Node* n) const {
  const auto nodeInputs = n->inputs();

  // Gather the node's inputs into a set for writesToAlias.
  ValueSet inputSet;
  for (size_t idx = 0; idx < nodeInputs.size(); ++idx) {
    inputSet.insert(nodeInputs[idx]);
  }

  return writesToAlias(n, inputSet);
}
|
|
|
|
bool AliasDb::hasInputWriters(const Node* n) const {
|
|
for (const auto input : n->inputs()) {
|
|
if (hasWriters(input)) {
|
|
return true;
|
|
}
|
|
}
|
|
return false;
|
|
}
|
|
|
|
bool AliasDb::hasOutputWriters(const Node* n) const {
|
|
for (const auto output : n->outputs()) {
|
|
if (hasWriters(output)) {
|
|
return true;
|
|
}
|
|
}
|
|
return false;
|
|
}
|
|
|
|
bool AliasDb::hasWriters(const Node* n) const {
|
|
return hasInputWriters(n) || hasOutputWriters(n);
|
|
}
|
|
|
|
bool AliasDb::hasWriters(const Value* v) const {
|
|
if (v->mustBeNone()) {
|
|
return false;
|
|
}
|
|
|
|
auto it = elementMap_.find(v);
|
|
if (it == elementMap_.end()) {
|
|
return false;
|
|
}
|
|
|
|
if (isWriteCacheStale_) {
|
|
rebuildWriteCache();
|
|
}
|
|
const auto& el = it->second;
|
|
return writeCache_.intersects(el->getMemoryLocations());
|
|
}
|
|
|
|
// Accumulates into `ret` every memory location recorded in writeIndex_ for
// `n`, then recurses into all nodes of all blocks nested under `n`.
void AliasDb::getWritesImpl(Node* n, MemoryLocations& ret) const {
  const auto recorded = writeIndex_.find(n);
  if (recorded != writeIndex_.end()) {
    ret |= recorded->second;
  }

  for (auto nestedBlock : n->blocks()) {
    for (auto nestedNode : nestedBlock->nodes()) {
      getWritesImpl(nestedNode, ret);
    }
  }
}
|
|
|
|
// Does `n` write to an alias of one of the values in `vs`?
|
|
bool AliasDb::writesToAlias(Node* n, const ValueSet& vs) const {
|
|
const auto writtenTo = getWrites(n);
|
|
if (writtenTo.empty()) {
|
|
return false;
|
|
}
|
|
|
|
MemoryLocations locs;
|
|
for (const auto v : vs) {
|
|
auto it = elementMap_.find(v);
|
|
if (it != elementMap_.end()) {
|
|
const auto& vlocs = it->second->getMemoryLocations();
|
|
if (writtenTo.intersects(vlocs)) {
|
|
return true;
|
|
}
|
|
}
|
|
}
|
|
|
|
return false;
|
|
}
|
|
|
|
// Returns the set of memory locations written by `n`, as gathered by
// getWritesImpl.
MemoryLocations AliasDb::getWrites(Node* n) const {
  MemoryLocations collected;
  getWritesImpl(n, collected);
  return collected;
}
|
|
|
|
// Accumulates into `ret` the memory locations read by `n`: everything each
// tracked input may alias, plus the wildcard locations for the types
// contained in that input's type. Nested blocks are visited recursively.
void AliasDb::getReadsImpl(Node* n, MemoryLocations& ret) const {
  for (const auto input : n->inputs()) {
    const auto entry = elementMap_.find(input);
    if (entry == elementMap_.end()) {
      // Inputs without an element contribute nothing.
      continue;
    }

    // All memory locations this element may alias.
    ret |= entry->second->getMemoryLocations();

    // Memory locations of contained values also count as "read".
    for (const auto& containedType : input->type()->containedTypes()) {
      auto wildcardElem = getWildcard(containedType);
      if (wildcardElem) {
        ret |= wildcardElem->getMemoryLocations();
      }
    }
  }

  for (auto nestedBlock : n->blocks()) {
    for (auto nestedNode : nestedBlock->nodes()) {
      getReadsImpl(nestedNode, ret);
    }
  }
}
|
|
|
|
// Returns the set of memory locations read by `n`, as gathered by
// getReadsImpl.
MemoryLocations AliasDb::getReads(Node* n) const {
  MemoryLocations collected;
  getReadsImpl(n, collected);
  return collected;
}
|
|
|
|
std::string AliasDb::getElementName(const Element* e) const {
|
|
if (e->value == nullptr) {
|
|
// not the most efficient way, but given the fact there are
|
|
// not too many types and even fewer of them will end up in
|
|
// wildcardIndex_, we should be fine with a linear search
|
|
// each time we hit a wildcard leaf
|
|
for (const auto& ent : wildcardIndex_) {
|
|
if (ent.second == e) {
|
|
return std::string("WILDCARD for type ") + ent.first->str();
|
|
}
|
|
}
|
|
return "WILDCARD";
|
|
} else {
|
|
return e->value->debugName();
|
|
}
|
|
}
|
|
|
|
// Writes the textual form of the database (see toString) to std::cout.
void AliasDb::dump() const {
  const std::string rendered = toString();
  std::cout << rendered;
}
|
|
|
|
std::string AliasDb::toString() const {
|
|
std::stringstream ss{};
|
|
|
|
ss << "\n===1. GRAPH===\n";
|
|
ss << graph_->toString();
|
|
|
|
ss << "\n===2. ALIAS DB===\n";
|
|
for (const auto& ptrPair : elementMap_) {
|
|
const auto element = ptrPair.second;
|
|
if (!element->pointsTo.empty()) {
|
|
ss << getElementName(element) << " points to: ";
|
|
for (const auto pointedTo : element->pointsTo) {
|
|
ss << getElementName(memoryDAG_->fromIndex(pointedTo)) << ", ";
|
|
}
|
|
ss << "\n";
|
|
}
|
|
if (!element->containedElements.empty()) {
|
|
ss << getElementName(element) << " contains: ";
|
|
for (const auto contained : element->containedElements) {
|
|
ss << getElementName(memoryDAG_->fromIndex(contained)) << ", ";
|
|
}
|
|
ss << "\n";
|
|
}
|
|
}
|
|
|
|
ss << "\n===3. Writes===\n";
|
|
for (const auto& pr : writeIndex_) {
|
|
const auto node = pr.first;
|
|
const auto& values = pr.second;
|
|
ss << *node;
|
|
ss << " ";
|
|
for (const auto value : values) {
|
|
ss << getElementName(memoryDAG_->fromIndex(value)) << ", ";
|
|
}
|
|
ss << "\n";
|
|
}
|
|
ss << "\n";
|
|
return ss.str();
|
|
}
|
|
|
|
void AliasDb::analyze(const std::shared_ptr<Graph>& graph) {
|
|
for (auto input : graph->inputs()) {
|
|
setWildcard(input);
|
|
}
|
|
analyze(graph->block());
|
|
}
|
|
|
|
// Analyzes every node of `block`, in the order the block yields them.
void AliasDb::analyze(Block* block) {
  for (auto current : block->nodes()) {
    analyze(current);
  }
}
|
|
|
|
// Analyzes a single node; all of the work is delegated to analyzeImpl.
void AliasDb::analyze(Node* node) {
  analyzeImpl(node);
}
|
|
|
|
// Returns true if analysis was run using
|
|
// the registered analyzer.
|
|
bool AliasDb::tryRegisteredAnalysis(Node* node) {
|
|
const Operator& op = node->getOperator();
|
|
auto analysis = op.aliasAnalysisKind();
|
|
if (AliasAnalysisKind::PURE_FUNCTION == analysis) {
|
|
analyzeCreator(node);
|
|
return true;
|
|
}
|
|
return false;
|
|
}
|
|
|
|
// The basic strategy is:
|
|
// 1. Retrieve alias information for every input.
|
|
// 2. Use the node's schema's alias annotations to propgagate alias/write
|
|
// information to the outputs. For unschematized nodes, a special analyzer
|
|
// will have to be handwritten.
|
|
void AliasDb::analyzeImpl(Node* node) {
|
|
auto op = node->maybeOperator();
|
|
const bool hasSpecialCase = aliasAnalysisHasSpecialCaseFor(node->kind());
|
|
if (op) {
|
|
const auto analysis = op->aliasAnalysisKind();
|
|
|
|
const bool registeredAsSpecialCase =
|
|
analysis == AliasAnalysisKind::INTERNAL_SPECIAL_CASE;
|
|
if (C10_UNLIKELY(registeredAsSpecialCase && !hasSpecialCase)) {
|
|
TORCH_INTERNAL_ASSERT(
|
|
false,
|
|
"Op ",
|
|
node->kind().toDisplayString(),
|
|
" is registered with AliasAnalysisKind::INTERNAL_SPECIAL_CASE but doesn't have a special case.");
|
|
} else if (C10_UNLIKELY(!registeredAsSpecialCase && hasSpecialCase)) {
|
|
TORCH_INTERNAL_ASSERT(
|
|
false,
|
|
"Op ",
|
|
node->kind().toDisplayString(),
|
|
" has a special case and should be registered with AliasAnalysisKind::INTERNAL_SPECIAL_CASE but is registered with ",
|
|
c10::toString(analysis));
|
|
}
|
|
} else {
|
|
if (!hasSpecialCase) {
|
|
std::ostringstream oss;
|
|
for (const auto input : node->inputs()) {
|
|
oss << input->type()->str() << ", ";
|
|
}
|
|
TORCH_INTERNAL_ASSERT(
|
|
0,
|
|
"We don't have an op for ",
|
|
node->kind().toDisplayString(),
|
|
" but it isn't a special case. ",
|
|
"Argument types: ", oss.str());
|
|
}
|
|
}
|
|
|
|
// These nodes are not schematized, so we need to handle them specially
|
|
switch (node->kind()) {
|
|
case prim::If:
|
|
return analyzeIf(node);
|
|
case prim::Loop:
|
|
return analyzeLoop(node);
|
|
case prim::FusionGroup:
|
|
case prim::CudaFusionGroup:
|
|
case prim::DifferentiableGraph:
|
|
return analyzeSubgraph(node);
|
|
case prim::fork:
|
|
return analyzeFork(node);
|
|
case aten::wait:
|
|
return analyzeWait(node);
|
|
case prim::rpc_async:
|
|
return analyzeRpcAsync(node);
|
|
case prim::GradOf:
|
|
return analyzeGradOf(node);
|
|
case prim::Constant:
|
|
case prim::AutogradZero:
|
|
case prim::AutogradAdd:
|
|
case prim::FusedConcat:
|
|
case prim::MMTreeReduce:
|
|
case prim::MMBatchSide:
|
|
case prim::BroadcastSizes:
|
|
case prim::ChunkSizes:
|
|
case prim::Function:
|
|
case prim::CreateObject:
|
|
case prim::tolist:
|
|
return analyzeCreator(node);
|
|
case prim::TupleConstruct:
|
|
case prim::DictConstruct:
|
|
case prim::ListConstruct:
|
|
return analyzeContainerConstruct(node);
|
|
case prim::TupleUnpack:
|
|
case prim::TupleIndex:
|
|
case prim::TupleSlice:
|
|
case prim::ListUnpack:
|
|
case prim::PythonOp:
|
|
case prim::GetAttr:
|
|
if (isFrozen_ && node->kind() == prim::GetAttr)
|
|
return analyzeCreator(node);
|
|
return analyzeExtractor(node);
|
|
case prim::unchecked_cast:
|
|
return makePointerTo(node->output(), node->input());
|
|
case prim::ConstantChunk:
|
|
return analyzeChunk(node);
|
|
case prim::BroadcastingChunk:
|
|
return analyzeBroadcastingChunk(node);
|
|
case prim::SetAttr:
|
|
return analyzeSetAttr(node);
|
|
case prim::profile:
|
|
if (node->inputs().size() > 0) {
|
|
makePointerTo(node->output(), node->inputs().at(0));
|
|
}
|
|
return;
|
|
case prim::BailOut:
|
|
TORCH_INTERNAL_ASSERT(node->inputs().at(0)->node()->kind() ==
|
|
prim::BailoutTemplate);
|
|
makePointerTo(node->output(), node->inputs().at(1));
|
|
return;
|
|
case prim::Guard:
|
|
makePointerTo(node->output(), node->inputs().at(0));
|
|
return;
|
|
case prim::CallFunction:
|
|
case prim::CallMethod:
|
|
// TODO: this can be improved with summarizes of what the function does
|
|
// for now we assume the worst
|
|
// NB: update safeToChangeAliasingRelationship if changed
|
|
return analyzeConservative(node);
|
|
case prim::Uninitialized:
|
|
giveFreshAlias(node->output());
|
|
return;
|
|
case prim::Print:
|
|
case prim::isinstance:
|
|
// These ops do nothing
|
|
return;
|
|
default:
|
|
if (tryRegisteredAnalysis(node)) {
|
|
return;
|
|
}
|
|
}
|
|
|
|
TORCH_INTERNAL_ASSERT(op, "We should have an op schema if we get to here");
|
|
const AliasAnalysisKind analysis = op->aliasAnalysisKind();
|
|
TORCH_INTERNAL_ASSERT(
|
|
analysis != AliasAnalysisKind::INTERNAL_SPECIAL_CASE &&
|
|
!aliasAnalysisHasSpecialCaseFor(node->kind()),
|
|
"Special cases should be handled already if we're here.");
|
|
|
|
if (node->kind().is_aten() || node->kind().is_prim()) {
|
|
// TODO There is nothing in the system that relies on aten:: and prim::
|
|
// ops using AliasAnalysisKind::FROM_SCHEMA or AliasAnalysisKind::INTERNAL_SPECIAL_CASE,
|
|
// but this is the intended behavior for all current ops and a good error check.
|
|
// We can consider lifting this constraint later if we have a use case for it.
|
|
TORCH_INTERNAL_ASSERT(
|
|
analysis == AliasAnalysisKind::FROM_SCHEMA ||
|
|
analysis == AliasAnalysisKind::CONSERVATIVE,
|
|
"aten:: and prim:: operators should use AliasAnalysisKind::FROM_SCHEMA or "
|
|
"AliasAnalysisKind::CONSERVATIVE(if really necessary), but ",
|
|
node->kind().toDisplayString(),
|
|
" doesn't. Note: Ideally, prim:: operators actually shouldn't have a schema ",
|
|
"and then use AliasAnalysisKind::INTERNAL_SPECIAL_CASE instead.");
|
|
}
|
|
|
|
if (analysis == AliasAnalysisKind::CONSERVATIVE) {
|
|
// TODO A previous implementation of alias analysis always accessed
|
|
// node->schema , which cause the schema caches in the Node class to be
|
|
// filled for the full graph. Unfortunately, our JIT passes started relying
|
|
// on that, so we need to keep doing this. Details: in
|
|
// caffe2/torch/onnx/utils.py, _jit_pass_onnx is called on an invalid JIT
|
|
// graph because we called _jit_pass_erase_number_types right before and
|
|
// ints are now Tensors instead. So if _jit_pass_onnx tries to look up
|
|
// operator schemas, it will crash. However, _jit_pass_constant_propagation,
|
|
// which is called before it, runs alias analysis and prefills the schema
|
|
// cache in the all Node instances so that _jit_pass_onnx doesn't look up
|
|
// operators to get the schemas anymore. We should fix this.
|
|
node->schema(); // fill the schema cache in the Node class
|
|
|
|
return analyzeConservative(node);
|
|
}
|
|
|
|
TORCH_INTERNAL_ASSERT(
|
|
analysis == AliasAnalysisKind::FROM_SCHEMA,
|
|
"AliasAnalysisKind::CONSERVATIVE/PURE_FUNCTION/INTERNAL_SPECIAL_CASE should already have been handled above");
|
|
const auto& schema = node->schema();
|
|
|
|
// Bind the schema's "formal" alias annotation to the actual values those
|
|
// schema arguments represent
|
|
std::unordered_map<Symbol, Value*> formalToActual;
|
|
for (size_t i = 0; i < schema.arguments().size(); i++) {
|
|
const auto& formal = schema.arguments()[i].alias_info();
|
|
const auto& actualValue = node->inputs().at(i);
|
|
// Skip if there's no alias annotation
|
|
if (!formal) {
|
|
continue;
|
|
}
|
|
|
|
// If this type cannot alias, continue. Can occur with a VarType schema
|
|
if (!mutableType(actualValue)) {
|
|
continue;
|
|
}
|
|
|
|
// Do sanity checks on the alias annotation
|
|
TORCH_INTERNAL_ASSERT(
|
|
formal->containedTypes().size() == 0,
|
|
"Composite types for alias analysis not yet supported");
|
|
TORCH_INTERNAL_ASSERT(
|
|
!formal->isWildcardBefore(),
|
|
"Doesn't make sense for a input value to begin as a wildcard");
|
|
|
|
const auto& formalAlias = formal->beforeSet();
|
|
|
|
// skip if we've already bound this alias
|
|
if (formalToActual.count(formalAlias) != 0) {
|
|
continue;
|
|
}
|
|
|
|
// Bind the formal to the actual
|
|
formalToActual[formalAlias] = actualValue;
|
|
|
|
// Record writes
|
|
if (formal->isWrite()) {
|
|
registerWrite(actualValue, node);
|
|
}
|
|
|
|
// Now deal with sets after the '->'
|
|
if (formal->isWildcardAfter()) {
|
|
TORCH_INTERNAL_ASSERT(
|
|
formal->afterSets().size() == 1,
|
|
"If the after set contains a wildcard, "
|
|
"there should be no other alias sets specified.");
|
|
setWildcard(actualValue);
|
|
} else {
|
|
// We don't understand anything else in the after yet, so assert there's
|
|
// been no change.
|
|
TORCH_INTERNAL_ASSERT(formal->beforeSets() == formal->afterSets());
|
|
}
|
|
}
|
|
|
|
// Use the formal-actual mapping to give aliases to the outputs
|
|
for (size_t i = 0; i < schema.returns().size(); i++) {
|
|
const auto actual = node->outputs().at(i);
|
|
const auto& formal = schema.returns()[i].alias_info();
|
|
if (!formal) {
|
|
// This is a fresh tensor
|
|
giveFreshAlias(actual);
|
|
continue;
|
|
}
|
|
|
|
// If this type cannot alias, continue. Can occur with a VarType schema
|
|
if (!mutableType(actual)) {
|
|
continue;
|
|
}
|
|
|
|
TORCH_INTERNAL_ASSERT(
|
|
formal->containedTypes().size() == 0,
|
|
"Composite types for alias analysis not yet supported");
|
|
TORCH_INTERNAL_ASSERT(formal->beforeSets() == formal->afterSets());
|
|
|
|
if (formal->isWildcardBefore()) {
|
|
TORCH_INTERNAL_ASSERT(
|
|
formal->beforeSets().size() == 1,
|
|
"If an output is a wildcard, "
|
|
"there should be no other alias sets specified.");
|
|
setWildcard(actual);
|
|
continue;
|
|
}
|
|
|
|
for (const auto& formalAlias : formal->beforeSets()) {
|
|
// If we encounter an alias annotation that wasn't in the inputs:
|
|
if (!formalToActual.count(formalAlias)) {
|
|
// If this alias is not seen elsewhere and is the only annotation on
|
|
// the output, it's equivalent to being fresh:
|
|
// e.g. foo(Tensor(a) self) -> Tensor(b)
|
|
if (formal->beforeSets().size() == 1) {
|
|
giveFreshAlias(actual);
|
|
}
|
|
// Or it is the form of a|fresh, which we can ignore, taking the
|
|
// conservative assumption that the output must alias `a`, e.g
|
|
// aten::cuda(Tensor(a) self) -> Tensor(a|fresh)
|
|
|
|
// Don't assign an alias set in that case.
|
|
continue;
|
|
}
|
|
|
|
auto toAlias = formalToActual.at(formalAlias);
|
|
makePointerTo(actual, toAlias);
|
|
}
|
|
|
|
// Record writes
|
|
if (formal->isWrite()) {
|
|
registerWrite(actual, node);
|
|
}
|
|
}
|
|
}
|
|
|
|
// Register the fact that `n` writes to `v`.
|
|
void AliasDb::registerWrite(const Value* v, Node* n) {
|
|
if (!mutableType(v)) {
|
|
// don't need to register a write if the value isn't mutable
|
|
return;
|
|
}
|
|
auto it = elementMap_.find(v);
|
|
TORCH_INTERNAL_ASSERT(
|
|
it != elementMap_.end(), "Tried to write to value not in MemoryDAG");
|
|
const auto& writtenMemoryLocations = it->second->getMemoryLocations();
|
|
writeIndex_[n] |= writtenMemoryLocations;
|
|
}
|
|
|
|
// Register the fact that `n` writes directly to element `e`. `e` must be a
// memory location element, i.e. it must not point to anything.
void AliasDb::registerWrite(const Element* e, Node* n) {
  TORCH_INTERNAL_ASSERT(
      e->pointsTo.empty(),
      "Can only register writes to memory location elements");

  auto& writesOfNode = writeIndex_[n];
  writesOfNode.set(e->index);
}
|
|
|
|
// For if statements, the alias set of an output is the union of the alias
// sets produced by the true block and the false block: both blocks are
// analyzed, then each node output is made to point at the matching output of
// each block.
void AliasDb::analyzeIf(Node* node) {
  const auto trueBlock = node->blocks().at(0);
  const auto falseBlock = node->blocks().at(1);
  analyze(trueBlock);
  analyze(falseBlock);

  const auto nodeOutputs = node->outputs();
  const size_t outputCount = nodeOutputs.size();
  for (size_t idx = 0; idx < outputCount; ++idx) {
    const auto merged = nodeOutputs[idx];

    // Fetch both branch outputs before wiring either of them.
    const auto fromTrue = trueBlock->outputs().at(idx);
    const auto fromFalse = falseBlock->outputs().at(idx);

    makePointerTo(merged, fromTrue);
    makePointerTo(merged, fromFalse);
  }
}
|
|
|
|
// Loop analysis: the loop-carried node inputs are mapped onto the body
// block's inputs, the body is analyzed (a single pass in this function), and
// the body block's outputs are then mapped onto the node's outputs.
void AliasDb::analyzeLoop(Node* node) {
  const auto bodyBlock = node->blocks().at(0);

  // Slices drop the leading bookkeeping entries on each side.
  const auto loopCarriedInputs = node->inputs().slice(2); // skip max, cond
  const auto blockInputs = bodyBlock->inputs().slice(1); // skip trip
  const auto blockOutputs = bodyBlock->outputs().slice(1); // skip trip

  TORCH_INTERNAL_ASSERT(loopCarriedInputs.size() == blockInputs.size());
  TORCH_INTERNAL_ASSERT(blockOutputs.size() == node->outputs().size());

  // Node input aliases -> block inputs.
  mapAliases(blockInputs, loopCarriedInputs);

  // Analyzing the body populates the alias info of the block outputs.
  analyze(bodyBlock);

  // Block output aliases -> node outputs.
  mapAliases(node->outputs(), blockOutputs);
}
|
|
|
|
// GradOf: analyze the node's single body block, then make each node output
// point at the corresponding block output.
void AliasDb::analyzeGradOf(Node* node) {
  const auto body = node->blocks().at(0);
  analyze(body);
  mapAliases(node->outputs(), body->outputs());
}
|
|
|
|
// Nodes carrying a Subgraph attribute: the subgraph inputs are aliased to the
// node inputs, the subgraph body is analyzed, and each node output is made to
// point at the subgraph output with the same index.
void AliasDb::analyzeSubgraph(Node* node) {
  const auto subgraphBlock = node->g(attr::Subgraph)->block();

  mapAliases(subgraphBlock->inputs(), node->inputs());

  analyze(subgraphBlock);

  // Note: the subgraph outputs and node outputs are NOT NECESSARILY the same
  // length. Autodifferentiation may capture additional outputs in the
  // subgraph block.
  TORCH_INTERNAL_ASSERT(
      subgraphBlock->outputs().size() >= node->outputs().size());

  const size_t sharedCount = node->outputs().size();
  for (size_t idx = 0; idx < sharedCount; ++idx) {
    makePointerTo(node->outputs()[idx], subgraphBlock->outputs()[idx]);
  }
}
|
|
|
|
// For nodes that generate a fresh value from nothing
|
|
void AliasDb::analyzeCreator(Node* node) {
|
|
for (Value* output : node->outputs()) {
|
|
giveFreshAlias(output);
|
|
}
|
|
}
|
|
|
|
// For nodes that extract values from a composite type. Right now, this just
|
|
// gives up and creates wildcards for everything.
|
|
void AliasDb::analyzeExtractor(Node* node) {
|
|
for (const auto output : node->outputs()) {
|
|
setWildcard(output);
|
|
}
|
|
}
|
|
|
|
// For torch.chunk(), all returned tensors may alias the input tensor
|
|
void AliasDb::analyzeChunk(Node* node) {
|
|
for (auto output : node->outputs()) {
|
|
makePointerTo(output, node->input());
|
|
}
|
|
}
|
|
|
|
// prim::fork: every input becomes a wildcard, and the future(s) the fork
// emits receive fresh aliases.
void AliasDb::analyzeFork(Node* node) {
  const auto forkInputs = node->inputs();
  for (size_t idx = 0; idx < forkInputs.size(); ++idx) {
    setWildcard(forkInputs[idx]);
  }

  // Give the future that the fork emits a fresh value
  const auto forkOutputs = node->outputs();
  for (size_t idx = 0; idx < forkOutputs.size(); ++idx) {
    giveFreshAlias(forkOutputs[idx]);
  }
}
|
|
|
|
// aten::wait: every output becomes a wildcard, and the node is recorded as a
// writer of every wildcard element currently in wildcardIndex_.
void AliasDb::analyzeWait(Node* node) {
  TORCH_INTERNAL_ASSERT(node->kind() == aten::wait);

  for (const auto waited : node->outputs()) {
    setWildcard(waited);
  }

  // The forked subgraph that `wait` is waiting on may write to any of its
  // inputs. There is no reliable way of recovering the fork inputs, so for
  // safety a write is registered to every wildcard.
  for (const auto& wildcardEntry : wildcardIndex_) {
    registerWrite(wildcardEntry.second, node);
  }
}
|
|
|
|
// prim::rpc_async: every input becomes a wildcard, and the future(s) the
// call emits receive fresh aliases.
void AliasDb::analyzeRpcAsync(Node* node) {
  const auto rpcInputs = node->inputs();
  for (size_t idx = 0; idx < rpcInputs.size(); ++idx) {
    setWildcard(rpcInputs[idx]);
  }

  // Give the future that the rpc_async emits a fresh value
  const auto rpcOutputs = node->outputs();
  for (size_t idx = 0; idx < rpcOutputs.size(); ++idx) {
    giveFreshAlias(rpcOutputs[idx]);
  }
}
|
|
|
|
// SetAttr: writes to the `self` field
|
|
void AliasDb::analyzeSetAttr(Node* node) {
|
|
const auto self = node->inputs().at(0);
|
|
TORCH_INTERNAL_ASSERT(self->type()->kind() == TypeKind::ClassType);
|
|
registerWrite(self, node);
|
|
// Also the value being set must become a wildcard.
|
|
const auto newValue = node->inputs().at(1);
|
|
setWildcard(newValue);
|
|
}
|
|
|
|
// Used for anything where we do not have accurate alias summaries
|
|
// may write to any input and produce wildcards
|
|
void AliasDb::analyzeConservative(Node* node) {
|
|
for (const auto input : node->inputs()) {
|
|
if (!mutableType(input)) {
|
|
continue;
|
|
}
|
|
auto elem = elementMap_.at(input);
|
|
registerWrite(input, node);
|
|
MemoryLocations mem_locations;
|
|
memoryDAG_->collectAllContainedMemoryLocations(elem, mem_locations);
|
|
for (unsigned loc : mem_locations) {
|
|
auto contained_elem = memoryDAG_->fromIndex(loc);
|
|
// we only register writes on memory locations
|
|
if (contained_elem->pointsTo.empty()) {
|
|
registerWrite(contained_elem, node);
|
|
}
|
|
}
|
|
setWildcard(input);
|
|
}
|
|
|
|
for (const auto output : node->outputs()) {
|
|
setWildcard(output);
|
|
}
|
|
}
|
|
|
|
// List or dict or tuple: construct: create an aliasing element for the actual
|
|
// container, then mark all inputs as wildcards, since they've gone inside the
|
|
// container. Then, add the wildcard sets of appropriate type to the contained
|
|
// elements of the container.
|
|
void AliasDb::analyzeContainerConstruct(Node* node) {
|
|
TORCH_INTERNAL_ASSERT(
|
|
node->kind() == prim::ListConstruct ||
|
|
node->kind() == prim::DictConstruct ||
|
|
node->kind() == prim::TupleConstruct);
|
|
|
|
// tuples which contain immutable types are immutable
|
|
if (!mutableType(node->output())) {
|
|
return;
|
|
}
|
|
|
|
TORCH_INTERNAL_ASSERT(node->outputs().size() == 1);
|
|
auto container = node->output();
|
|
giveFreshAlias(container);
|
|
auto container_elem = elementMap_.at(container);
|
|
for (auto input : node->inputs()) {
|
|
auto maybe_wildcard_elem = setWildcard(input);
|
|
if (maybe_wildcard_elem) {
|
|
memoryDAG_->addToContainedElements(*maybe_wildcard_elem, container_elem);
|
|
}
|
|
}
|
|
}
|
|
|
|
// BroadcastingChunk: all inputs are broadcasted, and then individually chunked.
|
|
// This is an intermediate node used only in the graph fuser.
|
|
void AliasDb::analyzeBroadcastingChunk(Node* node) {
|
|
auto inputs = node->inputs();
|
|
auto outputs = node->outputs();
|
|
auto nchunks = node->i(attr::chunks);
|
|
for (size_t index = 0; index < inputs.size(); ++index) {
|
|
// Each inputs[i] is aliased by exactly `nchunks` distinct output tensors:
|
|
// inputs[i] produces chunks outputs[i * nchunks + k] for k in [0..nchunks)
|
|
auto output_begin = outputs.begin() + index * nchunks;
|
|
for (auto it = output_begin; it != output_begin + nchunks; ++it) {
|
|
makePointerTo(*it, inputs.at(index));
|
|
}
|
|
}
|
|
}
|
|
|
|
bool AliasDb::nonAliasingValue(const Value* elem) const {
|
|
// these are values which can point to aliasing types in the graph,
|
|
// as with a None value pointing to an optional if node output,
|
|
// but will never alias themselves
|
|
return elem->mustBeNone() || elem->node()->kind() == prim::Uninitialized;
|
|
}
|
|
|
|
// Register the fact that `from` is a pointer to `to`
|
|
void AliasDb::makePointerTo(const Value* from, const Value* to) {
|
|
if (nonAliasingValue(from) || nonAliasingValue(to)) {
|
|
// if either value is guaranteed to be non-aliasing, we do not need to
|
|
// connect the two elements. however, it is invariant that aliasing types
|
|
// that are not wildcards have a memory dag element, so we create one if
|
|
// needed
|
|
giveFreshAlias(from);
|
|
giveFreshAlias(to);
|
|
return;
|
|
}
|
|
|
|
// covariant type containers can be point to types which are not
|
|
// also mutable/immutable because we unify the contained types
|
|
if (mutableType(from) != mutableType(to)) {
|
|
auto from_kind = from->type()->kind();
|
|
TORCH_INTERNAL_ASSERT(
|
|
from_kind == TypeKind::OptionalType ||
|
|
from_kind == TypeKind::FutureType || from_kind == TypeKind::TupleType);
|
|
return;
|
|
}
|
|
|
|
// both immutable
|
|
if (!mutableType(from)) {
|
|
return;
|
|
}
|
|
|
|
if (from == to) {
|
|
return;
|
|
}
|
|
|
|
// At this point, we are dealing with two mutable types.
|
|
auto fromEl = getOrCreateElement(from);
|
|
auto toEl = getOrCreateElement(to);
|
|
|
|
memoryDAG_->makePointerTo(fromEl, toEl);
|
|
}
|
|
|
|
void AliasDb::addToContainedElements(
|
|
const Value* elem,
|
|
const Value* container) {
|
|
if (!mutableType(elem)) {
|
|
return;
|
|
}
|
|
|
|
TORCH_INTERNAL_ASSERT(isContainerType(container->type()));
|
|
|
|
auto elemEl = getOrCreateElement(elem);
|
|
auto contEl = getOrCreateElement(container);
|
|
|
|
memoryDAG_->addToContainedElements(elemEl, contEl);
|
|
}
|
|
|
|
bool AliasDb::mayAlias(const Value* a, const Value* b) const {
|
|
if (!mutableType(a) || !mutableType(b)) {
|
|
return false;
|
|
}
|
|
|
|
return memoryDAG_->mayAlias(elementMap_.at(a), elementMap_.at(b));
|
|
}
|
|
|
|
bool AliasDb::mayAlias(const ValueSet& a, const ValueSet& b) const {
|
|
if (a.empty() || b.empty()) {
|
|
return false;
|
|
}
|
|
|
|
// Record all memory locations from group `a`
|
|
MemoryLocations aMemLocs;
|
|
for (const auto value : a) {
|
|
auto it = elementMap_.find(value);
|
|
if (it != elementMap_.end()) {
|
|
aMemLocs |= it->second->getMemoryLocations();
|
|
}
|
|
}
|
|
|
|
// If any of group `b`s memory locations overlap, return true.
|
|
for (const auto value : b) {
|
|
auto it = elementMap_.find(value);
|
|
if (it != elementMap_.end()) {
|
|
if (aMemLocs.intersects(it->second->getMemoryLocations())) {
|
|
return true;
|
|
}
|
|
}
|
|
}
|
|
// No overlap, so group `a` and `b` do not share a memory location
|
|
return false;
|
|
}
|
|
|
|
// Single-value convenience form: wraps each value in a one-element list and
// defers to the list-based mayContainAlias.
bool AliasDb::mayContainAlias(Value* a, Value* b) const {
  const std::vector<Value*> onlyA = {a};
  const std::vector<Value*> onlyB = {b};

  return mayContainAlias(
      at::ArrayRef<Value*>(onlyA), at::ArrayRef<Value*>(onlyB));
}
|
|
|
|
// Returns the elements of all values in `vs` that have a mutable type, in the
// order the values appear. Values of immutable type are skipped.
std::vector<Element*> AliasDb::getElements(at::ArrayRef<Value*> vs) const {
  std::vector<Element*> result;
  for (const auto& candidate : vs) {
    if (!mutableType(candidate)) {
      continue;
    }
    result.push_back(elementMap_.at(candidate));
  }
  return result;
}
|
|
|
|
bool AliasDb::mayContainAlias(
|
|
const at::ArrayRef<Value*> a,
|
|
const at::ArrayRef<Value*> b) const {
|
|
auto a_elems = getElements(a);
|
|
return a_elems.size() == 0 ? false : memoryDAG_->mayContainAlias(a_elems, getElements(b));
|
|
}
|
|
|
|
// Make each value in the `from` list point to its partner in the `to` list
|
|
void AliasDb::mapAliases(at::ArrayRef<Value*> from, at::ArrayRef<Value*> to) {
|
|
TORCH_INTERNAL_ASSERT(to.size() == from.size());
|
|
for (size_t i = 0; i < to.size(); i++) {
|
|
makePointerTo(from[i], to[i]);
|
|
}
|
|
}
|
|
|
|
void AliasDb::giveFreshAlias(const Value* value) {
|
|
auto maybe_mut_type = getMutableTypePtr(value->type());
|
|
if (!maybe_mut_type) {
|
|
return;
|
|
}
|
|
|
|
if (elementMap_.count(value)) {
|
|
// Inside a loop, we may have given a fresh alias to this value already, so
|
|
// skip
|
|
return;
|
|
}
|
|
|
|
auto new_elem = memoryDAG_->makeFreshValue(value);
|
|
elementMap_[value] = new_elem;
|
|
addContainedTypesToFreshElement(new_elem, *maybe_mut_type);
|
|
}
|
|
|
|
// Returns the element registered for `value`, first creating one through
// giveFreshAlias when none exists yet.
Element* AliasDb::getOrCreateElement(const Value* value) {
  const auto existing = elementMap_.find(value);
  if (existing != elementMap_.end()) {
    return existing->second;
  }

  giveFreshAlias(value);
  return elementMap_.at(value);
}
|
|
|
|
// Attempts to move `n` after `movePoint` through tryMove with dryRun off;
// returns tryMove's result.
bool AliasDb::moveAfterTopologicallyValid(Node* n, Node* movePoint) {
  const bool dryRun = false;
  return tryMove(n, movePoint, MoveSide::AFTER, dryRun);
}
|
|
|
|
// Same query as moveAfterTopologicallyValid, but issued to tryMove as a dry
// run.
bool AliasDb::couldMoveAfterTopologically(Node* n, Node* movePoint) {
  const bool dryRun = true;
  return tryMove(n, movePoint, MoveSide::AFTER, dryRun);
}
|
|
|
|
// Attempts to move `n` before `movePoint` through tryMove with dryRun off;
// returns tryMove's result.
bool AliasDb::moveBeforeTopologicallyValid(Node* n, Node* movePoint) {
  // The move side has to be distinguished explicitly (instead of just moving
  // after n->prev()). Consider a dependency graph that looks like
  //   n -> movePoint -> o
  // Then moveBefore(o) will end up with
  //   n, o, movePoint
  // but moveAfter(n) will return false.
  const bool dryRun = false;
  return tryMove(n, movePoint, MoveSide::BEFORE, dryRun);
}
|
|
|
|
// Asks whether `n` could be moved to the BEFORE side of `movePoint`. This is
// the dry-run flavor of `tryMove`, so the move itself is not executed.
bool AliasDb::couldMoveBeforeTopologically(Node* n, Node* movePoint) {
  constexpr bool kDryRun = true;
  return tryMove(n, movePoint, MoveSide::BEFORE, kDryRun);
}
|
|
|
|
// List overload: true as soon as any single value in `values` has writers
// according to the per-value `hasWriters`; false for an empty list.
bool AliasDb::hasWriters(const at::ArrayRef<Value*>& values) const {
  for (Value* candidate : values) {
    if (hasWriters(candidate)) {
      return true;
    }
  }
  return false;
}
|
|
|
|
// True when any of `vs` may be visible outside the current graph: it may
// share contained memory with a graph input, with a graph output, or it may
// alias a wildcard. The three checks run in that order and stop at the first
// hit.
bool AliasDb::escapesScope(const at::ArrayRef<Value*>& vs) const {
  if (mayContainAlias(graph_->inputs(), vs)) {
    return true;
  }
  if (mayContainAlias(graph_->outputs(), vs)) {
    return true;
  }
  return mayAliasWildcard(vs);
}
|
|
|
|
// Correctness conditions:
|
|
// no values in either set can have writers, and values in both sets
|
|
// cannot escape the current graph scope. Values can escape the current scope
|
|
// by aliasing a graph output or input, or by aliasing the wildcard set.
|
|
bool AliasDb::safeToChangeAliasingRelationship(
|
|
const at::ArrayRef<Value*>& a,
|
|
const at::ArrayRef<Value*>& b) const {
|
|
if (hasWriters(a) || hasWriters(b)) {
|
|
return false;
|
|
}
|
|
|
|
return !(escapesScope(a) && escapesScope(b));
|
|
}
|
|
|
|
// Helper for topologically-safe node moves. See `tryMove()` for details.
|
|
class AliasDb::WorkingSet {
|
|
public:
|
|
explicit WorkingSet(Node* mover, const AliasDb& aliasDb) : aliasDb_(aliasDb) {
|
|
mover_ = mover;
|
|
for (const auto user : getUsersSameBlock(mover_)) {
|
|
moverUsers_.insert(user);
|
|
}
|
|
moverWrites_ |= aliasDb_.getWrites(mover_);
|
|
moverReads_ |= aliasDb_.getReads(mover_);
|
|
}
|
|
|
|
// Add `n` to the working set
|
|
void add(Node* n) {
|
|
nodes_.push_back(n);
|
|
for (const auto user : getUsersSameBlock(n)) {
|
|
users_.insert(user);
|
|
}
|
|
|
|
writes_ |= aliasDb_.getWrites(n);
|
|
reads_ |= aliasDb_.getReads(n);
|
|
}
|
|
|
|
void eraseMover() {
|
|
mover_ = nullptr;
|
|
moverWrites_.clear();
|
|
moverReads_.clear();
|
|
moverUsers_.clear();
|
|
}
|
|
|
|
const std::vector<Node*>& dependentNodes() {
|
|
return nodes_;
|
|
}
|
|
|
|
// Does the working set depend on `n`?
|
|
bool dependsOn(Node* n) const {
|
|
if (!mover_ && nodes_.empty()) {
|
|
return false;
|
|
}
|
|
|
|
return hasDataDependency(n) || hasMutabilityDependency(n);
|
|
}
|
|
|
|
private:
|
|
bool hasDataDependency(Node* n) const {
|
|
if (!mover_ && nodes_.empty()) {
|
|
return false;
|
|
}
|
|
const Node* pivot = mover_ ? mover_ : nodes_.front();
|
|
if (n->isAfter(pivot)) {
|
|
return producesFor(n);
|
|
} else {
|
|
return consumesFrom(n);
|
|
}
|
|
}
|
|
|
|
bool hasMutabilityDependency(Node* n) const {
|
|
// Check that `n` does not write to anything used by the working set
|
|
const auto& nWrites = aliasDb_.getWrites(n);
|
|
if (reads_.intersects(nWrites)) {
|
|
return true;
|
|
}
|
|
if (mover_ && moverReads_.intersects(nWrites)) {
|
|
return true;
|
|
}
|
|
|
|
// Check that the working set doesn't write to anything that `n` uses.
|
|
const auto& nReads = aliasDb_.getReads(n);
|
|
if (writes_.intersects(nReads)) {
|
|
return true;
|
|
}
|
|
if (mover_ && moverWrites_.intersects(nReads)) {
|
|
return true;
|
|
}
|
|
return false;
|
|
}
|
|
|
|
// Does the working set produce any values consumed by `n`?
|
|
bool producesFor(Node* n) const {
|
|
// This equivalent to asking: does the total use-set of all the nodes in the
|
|
// working set include `n`?
|
|
if (mover_ && moverUsers_.count(n)) {
|
|
return true;
|
|
}
|
|
return users_.count(n) != 0;
|
|
}
|
|
|
|
// Does the working set consume any values produced by `n`?
|
|
bool consumesFrom(Node* n) const {
|
|
const auto users = getUsersSameBlock(n);
|
|
|
|
if (mover_ && users.count(mover_)) {
|
|
return true;
|
|
}
|
|
return std::any_of(nodes_.begin(), nodes_.end(), [&](Node* node) {
|
|
return users.count(node) != 0;
|
|
});
|
|
}
|
|
|
|
// Get all users of outputs of `n`, in the same block as `n`.
|
|
// This means if there is an `if` node that uses an output of `n` in some
|
|
// inner sub-block, we will consider the whole `if` node a user of `n`.
|
|
std::unordered_set<Node*> getUsersSameBlock(Node* n) const {
|
|
std::unordered_set<Node*> users;
|
|
for (const auto output : n->outputs()) {
|
|
for (const auto& use : output->uses()) {
|
|
if (auto sameBlock = findSameBlock(use.user, n)) {
|
|
users.insert(sameBlock);
|
|
}
|
|
}
|
|
}
|
|
return users;
|
|
}
|
|
|
|
// Traverse `target`'s blockchain upward until we find a node that shares a
|
|
// block with `n`.
|
|
//
|
|
// If one can't be found (say, because `n` is an inner block and target is
|
|
// outside), then return nullptr. Since we can only reorder nodes within a
|
|
// block, `target` would be irrelevant.
|
|
static Node* findSameBlock(Node* target, Node* n) {
|
|
TORCH_INTERNAL_ASSERT(target->owningGraph() == n->owningGraph());
|
|
if (target->owningBlock() == n->owningBlock()) {
|
|
return target;
|
|
} else {
|
|
// This user is in a sub-block. Traverse the blockchain upward until
|
|
// we arrive at a node that shares a block with `this`
|
|
auto curNode = target;
|
|
while (curNode->owningBlock() != n->owningBlock()) {
|
|
curNode = curNode->owningBlock()->owningNode();
|
|
if (curNode == nullptr) {
|
|
return curNode;
|
|
}
|
|
}
|
|
return curNode;
|
|
}
|
|
}
|
|
|
|
const AliasDb& aliasDb_;
|
|
std::vector<Node*> nodes_;
|
|
|
|
// Mover dependencies. We track these separately since we may erase the mover
|
|
// from the working set.
|
|
Node* mover_;
|
|
MemoryLocations moverWrites_;
|
|
MemoryLocations moverReads_;
|
|
std::unordered_set<Node*> moverUsers_;
|
|
|
|
// users => # of working set nodes it uses
|
|
std::unordered_set<Node*> users_;
|
|
// Values written to by the working set => number of nodes writing to value
|
|
MemoryLocations writes_;
|
|
MemoryLocations reads_;
|
|
};
|
|
|
|
// Try to move `toMove` before/after `movePoint` while preserving value
|
|
// dependencies. Returns false iff such a move could not be made.
|
|
//
|
|
// If `dryRun` is set, don't actually execute the move, just check if the move
|
|
// is possible
|
|
//
|
|
// The basic approach is: have a "working set" that we are moving forward, one
|
|
// node at a time. When we can't move past a node (because it depends on the
|
|
// working set), then add it to the working set and keep moving until we hit
|
|
// `moveAfter`.
|
|
bool AliasDb::tryMove(
|
|
Node* toMove,
|
|
Node* movePoint,
|
|
MoveSide moveSide,
|
|
bool dryRun) {
|
|
TORCH_INTERNAL_ASSERT(toMove->owningBlock() == movePoint->owningBlock());
|
|
if (toMove == movePoint) {
|
|
return true;
|
|
}
|
|
|
|
// 1. Move from `this` toward movePoint, building up the working set of
|
|
// dependencies
|
|
WorkingSet workingSet(toMove, *this);
|
|
|
|
int direction;
|
|
if (toMove->isAfter(movePoint)) {
|
|
direction = kPrevDirection;
|
|
} else {
|
|
direction = kNextDirection;
|
|
}
|
|
|
|
auto curNode = toMove->next_in_graph[direction];
|
|
// Move forward one node at a time
|
|
while (curNode != movePoint) {
|
|
if (workingSet.dependsOn(curNode)) {
|
|
// If we can't move past this node, add it to the working set
|
|
workingSet.add(curNode);
|
|
}
|
|
curNode = curNode->next_in_graph[direction];
|
|
}
|
|
|
|
// 2. Decide whether we can move it all to `movePoint`.
|
|
|
|
// Say we are moving directly before movePoint and `toMove` starts before
|
|
// movePoint in the graph. The move looks like
|
|
//
|
|
// `toMove` `toMove` |
|
|
// <dependencies> -> `movePoint` | `toMove` and deps are split
|
|
// `movePoint` <dependencies> |
|
|
//
|
|
// Contrast with the case where `toMove` starts AFTER movePoint:
|
|
//
|
|
// `movePoint` <dependencies> |
|
|
// <dependencies> -> `toMove` | `toMove` and deps are together
|
|
// `toMove` `movePoint` |
|
|
//
|
|
// In the first case, we need to split `this` off from its dependencies, so we
|
|
// can move the dependencies below `movePoint` and keep `toMove` above.
|
|
const bool splitToMoveAndDeps =
|
|
(moveSide == MoveSide::BEFORE && toMove->isBefore(movePoint)) ||
|
|
(moveSide == MoveSide::AFTER && toMove->isAfter(movePoint));
|
|
|
|
if (splitToMoveAndDeps) {
|
|
// remove `this` from dependencies to be moved past `movePoint`
|
|
workingSet.eraseMover();
|
|
}
|
|
|
|
// Check if we can move the working set past the move point
|
|
if (workingSet.dependsOn(movePoint)) {
|
|
// if we can't, then there are intermediate dependencies between the
|
|
// `this` and `movePoint`, so we can't do the move
|
|
return false;
|
|
}
|
|
|
|
if (dryRun) {
|
|
return true;
|
|
}
|
|
|
|
// 3. Execute the move
|
|
TORCH_INTERNAL_ASSERT(curNode == movePoint);
|
|
if (splitToMoveAndDeps) {
|
|
// Move `toMove`
|
|
move(toMove, movePoint, moveSide);
|
|
|
|
// Then move all of its dependencies on the other side of `movePoint`
|
|
const auto reversed =
|
|
moveSide == MoveSide::BEFORE ? MoveSide::AFTER : MoveSide::BEFORE;
|
|
for (auto n : workingSet.dependentNodes()) {
|
|
move(n, curNode, reversed);
|
|
curNode = n;
|
|
}
|
|
} else {
|
|
// Just append/prepend everything to `movePoint`
|
|
move(toMove, curNode, moveSide);
|
|
curNode = toMove;
|
|
for (auto n : workingSet.dependentNodes()) {
|
|
move(n, curNode, moveSide);
|
|
curNode = n;
|
|
}
|
|
}
|
|
return true;
|
|
}
|
|
|
|
// Helper function so we can generalize `tryMove`
|
|
void AliasDb::move(Node* toMove, Node* movePoint, MoveSide moveSide) {
|
|
switch (moveSide) {
|
|
case MoveSide::BEFORE:
|
|
toMove->moveBefore(movePoint);
|
|
break;
|
|
case MoveSide::AFTER:
|
|
toMove->moveAfter(movePoint);
|
|
break;
|
|
}
|
|
}
|
|
|
|
// True when the write set recorded for `n` in `writeIndex_` includes at least
// one wildcard element. A node with no recorded writes yields false.
bool AliasDb::writesToWildcard(Node* n) const {
  if (writeIndex_.count(n) == 0) {
    return false;
  }
  const auto& nodeWrites = writeIndex_.at(n);

  // Scan the wildcard elements until one of them shows up in the write set.
  bool hitsWildcard = false;
  for (auto it = wildcardIndex_.begin();
       it != wildcardIndex_.end() && !hitsWildcard;
       ++it) {
    hitsWildcard = nodeWrites.test(it->second->index);
  }
  return hitsWildcard;
}
|
|
|
|
bool AliasDb::mayAliasWildcard(const Value* v) const {
|
|
if (auto e = getWildcard(v->type())) {
|
|
return memoryDAG_->mayAlias(elementMap_.at(v), e);
|
|
}
|
|
// There were no wildcards of this type, so return false.
|
|
return false;
|
|
}
|
|
|
|
// List overload: true as soon as any single value in `vs` may alias a
// wildcard; false for an empty list.
bool AliasDb::mayAliasWildcard(const at::ArrayRef<Value*> vs) const {
  for (Value* candidate : vs) {
    if (mayAliasWildcard(candidate)) {
      return true;
    }
  }
  return false;
}
|
|
|
|
// Looks up, or lazily creates, the wildcard element for `type`.
//
// Wildcards are keyed by the type that `getMutableTypePtr` maps `type` to.
// Returns nullopt when there is no such mapping. A newly created wildcard is
// stored in `wildcardIndex_` and then wired up to the wildcards of its
// contained types.
c10::optional<Element*> AliasDb::tryGetOrCreateWildcard(const TypePtr& type) {
  const auto mutTypeOpt = getMutableTypePtr(type);
  if (!mutTypeOpt) {
    return c10::nullopt;
  }
  const TypePtr wildcardKey = *mutTypeOpt;

  const auto found = wildcardIndex_.find(wildcardKey);
  if (found != wildcardIndex_.end()) {
    return found->second;
  }

  // First request for this key: mint a fresh element not tied to any Value.
  auto freshWildcard = memoryDAG_->makeFreshValue(nullptr);
  wildcardIndex_.emplace(wildcardKey, freshWildcard);
  addContainedTypesToFreshElement(freshWildcard, wildcardKey);
  return freshWildcard;
}
|
|
|
|
// For every type contained in `mut_type` that has (or can be given) a
// wildcard element, registers that wildcard as contained in `container_elem`.
// Contained types for which no wildcard is available are skipped.
void AliasDb::addContainedTypesToFreshElement(
    Element* container_elem,
    const TypePtr& mut_type) {
  for (const auto& innerType : mut_type->containedTypes()) {
    const auto innerWildcard = tryGetOrCreateWildcard(innerType);
    if (!innerWildcard) {
      continue;
    }
    memoryDAG_->addToContainedElements(*innerWildcard, container_elem);
  }
}
|
|
|
|
// Search the wildcard index for an element that corresponds to the given type.
|
|
// Const version returns nullptr
|
|
Element* AliasDb::getWildcard(const TypePtr& type) const {
|
|
auto maybe_mut_type = getMutableTypePtr(type);
|
|
if (!maybe_mut_type) {
|
|
return nullptr;
|
|
}
|
|
TypePtr mut_type = *maybe_mut_type;
|
|
auto wildcard = wildcardIndex_.find(mut_type);
|
|
if (wildcard != wildcardIndex_.end()) {
|
|
return wildcard->second;
|
|
}
|
|
return nullptr;
|
|
}
|
|
|
|
// Register `v` as a wildcard value.
|
|
c10::optional<Element*> AliasDb::setWildcard(const Value* v) {
|
|
auto maybe_wildcardElement = tryGetOrCreateWildcard(v->type());
|
|
if (!maybe_wildcardElement) {
|
|
return c10::nullopt;
|
|
}
|
|
|
|
auto wildcardElement = *maybe_wildcardElement;
|
|
// Making a value a wildcard means that all its potential memory locations
|
|
// may alias the wildcard set.
|
|
const MemoryLocations pointeeSet = getOrCreateElement(v)->getMemoryLocations();
|
|
for (const auto& pointee : pointeeSet) {
|
|
auto from = memoryDAG_->fromIndex(pointee);
|
|
// avoid cycles where the wildcard points to itself
|
|
if (from != wildcardElement) {
|
|
memoryDAG_->makePointerTo(from, wildcardElement);
|
|
}
|
|
}
|
|
|
|
// We also need to update the write index with new writes to the wildcard set.
|
|
for (auto& pr : writeIndex_) {
|
|
auto& writtenTo = pr.second;
|
|
if (writtenTo.intersects(pointeeSet)) {
|
|
writtenTo.set(wildcardElement->index);
|
|
}
|
|
}
|
|
isWriteCacheStale_ = true;
|
|
return wildcardElement;
|
|
}
|
|
|
|
void AliasDb::rebuildWriteCache() const {
|
|
for (const auto& pr : writeIndex_) {
|
|
const auto& writtenLocs = pr.second;
|
|
writeCache_ |= writtenLocs;
|
|
}
|
|
isWriteCacheStale_ = false;
|
|
}
|
|
} // namespace jit
|
|
} // namespace torch
|