Mark values entering containers as wildcards (#20556)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/20556
ghimport-source-id: d7c62e38a2f6928f6f8d988c26a38ea8f8cff8b6

Reviewed By: jamesr66a

Differential Revision: D15447568

Pulled By: suo

fbshipit-source-id: 77ebc11b571b8517d3bad3ee1b3ee5ac037542b2
This commit is contained in:
Michael Suo
2019-05-22 16:42:33 -07:00
committed by Facebook Github Bot
parent 28be521e39
commit 90910fc6cb
6 changed files with 128 additions and 47 deletions

View File

@ -22,11 +22,6 @@ class AliasInfo {
static const Symbol wc = Symbol::fromQualString("alias::*");
return wc;
}
static AliasInfo createWildcard() {
AliasInfo ret;
ret.addBeforeSet(wildcardSet());
return ret;
}
void setIsWrite(bool isWrite) {
isWrite_ = isWrite;
@ -57,10 +52,14 @@ class AliasInfo {
return *beforeSets_.begin();
}
bool isWildcard() const {
bool isWildcardBefore() const {
return beforeSets_.count(wildcardSet()) != 0;
}
bool isWildcardAfter() const {
return afterSets_.count(wildcardSet()) != 0;
}
// the alias info for the contained types of the type
// e.g. if this is an annotation on List[T], `sets` refers to
// the alias sets that the list may be in

View File

@ -680,6 +680,49 @@ graph():
AT_ASSERT(!aliasDb.mayContainAlias(first_st, second_st));
AT_ASSERT(!aliasDb.mayContainAlias(second_st, tup_st));
}
{
// Test list container aliasing
auto graph = std::make_shared<Graph>();
std::unordered_map<std::string, Value*> vmap;
script::parseIR(
R"IR(
graph():
%10 : bool? = prim::Constant()
%8 : Device? = prim::Constant()
%4 : int? = prim::Constant()
%0 : int = prim::Constant[value=2]()
%1 : int = prim::Constant[value=3]()
%2 : int[] = prim::ListConstruct(%0, %1)
%x : Tensor = aten::rand(%2, %4, %4, %8, %10)
%12 : int[] = prim::ListConstruct(%0, %1)
%y : Tensor = aten::rand(%12, %4, %4, %8, %10)
%22 : int[] = prim::ListConstruct(%0, %1)
%z : Tensor = aten::rand(%22, %4, %4, %8, %10)
%32 : int[] = prim::ListConstruct(%0, %1)
%fresh : Tensor = aten::rand(%32, %4, %4, %8, %10)
%foo : Tensor[] = prim::ListConstruct(%x, %y)
%43 : Tensor[] = aten::append(%foo, %z)
return ()
)IR",
graph.get(),
vmap);
AliasDb aliasDb(graph);
auto x = vmap["x"];
auto y = vmap["y"];
auto z = vmap["z"];
// Tensors x, y, and z went into a list, so they all may alias each other.
ASSERT_TRUE(aliasDb.mayAlias(x, y));
ASSERT_TRUE(aliasDb.mayAlias(y, z));
ASSERT_TRUE(aliasDb.mayAlias(x, z));
// But we know `fresh` didn't go into a list, so x, y, and z should not
// alias it.
// auto fresh = vmap["fresh"];
// ASSERT_FALSE(aliasDb.mayAlias(x, fresh));
// ASSERT_FALSE(aliasDb.mayAlias(y, fresh));
// ASSERT_FALSE(aliasDb.mayAlias(z, fresh));
}
}
void testWildcards() {

View File

@ -405,8 +405,6 @@ void AliasDb::analyzeImpl(Node* node) {
case prim::GradOf:
return analyzeGradOf(node);
case prim::Constant:
case prim::DictConstruct:
case prim::ListConstruct:
case prim::AutogradZero:
case prim::AutogradAdd:
case prim::FusedConcat:
@ -417,6 +415,9 @@ void AliasDb::analyzeImpl(Node* node) {
case prim::Function:
case prim::CreateObject:
return analyzeCreator(node);
case prim::DictConstruct:
case prim::ListConstruct:
return analyzeContainerConstruct(node);
case prim::TupleUnpack:
case prim::TupleIndex:
case prim::DictIndex:
@ -483,7 +484,8 @@ void AliasDb::analyzeImpl(Node* node) {
return analyzeCustomOp(node);
}
// Bind formal alias annotation to actual alias sets
// Bind the schema's "formal" alias annotation to the actual values those
// schema arguments represent
std::unordered_map<Symbol, Value*> formalToActual;
for (size_t i = 0; i < schema.arguments().size(); i++) {
const auto& formal = schema.arguments()[i].alias_info();
@ -498,11 +500,12 @@ void AliasDb::analyzeImpl(Node* node) {
continue;
}
// We don't support composite types for alias analysis yet.
// Do sanity checks on the alias annotation
// - We don't support composite types for alias analysis yet.
AT_ASSERT(formal->containedTypes().size() == 0);
// TODO neither unions nor wildcards make sense on an input. We should
// disallow them in function schema
AT_ASSERT(!formal->isWildcard())
// - Doesn't make sense for a value to start annotated as a wildcard.
AT_ASSERT(!formal->isWildcardBefore());
const auto& formalAlias = formal->beforeSet();
// skip if we've already bound this alias
@ -517,6 +520,15 @@ void AliasDb::analyzeImpl(Node* node) {
if (formal->isWrite()) {
registerWrite(actualValue, node);
}
// Now deal with sets after the '->'
if (formal->isWildcardAfter()) {
setWildcard(actualValue);
} else {
// We don't understand anything else in the after yet, so assert there's
// been no change.
AT_ASSERT(formal->beforeSets() == formal->afterSets());
}
}
// Use the formal-actual mapping to give aliases to the outputs
@ -537,7 +549,7 @@ void AliasDb::analyzeImpl(Node* node) {
// We don't support composite types for alias analysis yet.
AT_ASSERT(formal->containedTypes().size() == 0);
if (formal->isWildcard()) {
if (formal->isWildcardBefore() || formal->isWildcardAfter()) {
setWildcard(actual);
continue;
}
@ -774,6 +786,23 @@ void AliasDb::analyzeCustomOp(Node* node) {
}
}
// List or dict construct: create an aliasing element for the actual container,
// then mark all inputs as wildcards, since they've gone inside the container.
// TODO: tuples are treated differently since we actually compare the contained
// values for aliasing, so we don't need wildcards.
void AliasDb::analyzeContainerConstruct(Node* node) {
AT_ASSERT(
node->kind() == prim::ListConstruct ||
node->kind() == prim::DictConstruct);
for (auto input : node->inputs()) {
setWildcard(input);
}
for (auto output : node->outputs()) {
giveFreshAlias(output);
}
}
// BroadcastingChunk: all inputs are broadcasted, and then individually chunked.
// This is an intermediate node used only in the graph fuser.
void AliasDb::analyzeBroadcastingChunk(Node* node) {

View File

@ -42,8 +42,10 @@ class AliasDb {
// Does `n` write to an alias of one of the values in `vs`?
// if `recurseBlocks` is true, consider writes on the nodes in `n`s sub-blocks
TORCH_API bool writesToAlias(Node* n, const ValueSet& vs, bool recurseBlocks = false)
const;
TORCH_API bool writesToAlias(
Node* n,
const ValueSet& vs,
bool recurseBlocks = false) const;
// Does `a` and `b` potentially share a memory location or do either
// hold in memory any element that exists in the other
@ -65,9 +67,11 @@ class AliasDb {
// `Elements`.
template <
typename... Other1,
template <typename, typename...> class T,
template <typename, typename...>
class T,
typename... Other2,
template <typename, typename...> class U>
template <typename, typename...>
class U>
bool mayAlias(
const T<const Value*, Other1...>& a,
const U<const Value*, Other2...>& b) const {
@ -199,6 +203,7 @@ class AliasDb {
void analyzeSetAttr(Node* node);
void analyzeTupleConstruct(Node* node);
void analyzeCustomOp(Node* node);
void analyzeContainerConstruct(Node* node);
bool tryRegisteredAnalysis(Node* node);
/**
@ -206,7 +211,9 @@ class AliasDb {
*/
void makeAllAlias(const std::vector<Value*>& values);
void makePointerTo(const Value* value, const Value* to);
TORCH_API void addToContainedElements(const Value* element, const Value* container);
TORCH_API void addToContainedElements(
const Value* element,
const Value* container);
void mapAliases(at::ArrayRef<Value*> to, at::ArrayRef<Value*> from);
void giveFreshAlias(const Value* value);
Element* getOrCreateElement(const Value* value);

View File

@ -1752,7 +1752,7 @@ RegisterOperators reg2({
listSelect<Shared<c_type>>), \
Operator( \
"aten::append( " decl_type "[](a!) self, " decl_type \
"(c) el) -> " decl_type "[](a!)", \
"(c -> *) el) -> " decl_type "[](a!)", \
listAppend<Shared<c_type>, c_type::ElemType>), \
Operator( \
"aten::reverse( " decl_type "[](a!) self) -> ()", \
@ -1768,7 +1768,7 @@ RegisterOperators reg2({
listCopy<Shared<c_type>>), \
Operator( \
"aten::_set_item(" decl_type "[](a!) l, int idx, " decl_type \
" el) -> " decl_type "[](a!)", \
"(b -> *) el) -> " decl_type "[](a!)", \
listSetItem<Shared<c_type>, c_type::ElemType>), \
Operator( \
"aten::clear( " decl_type "[](a!) self) -> ()", \
@ -1776,7 +1776,7 @@ RegisterOperators reg2({
Operator( \
"aten::insert( " decl_type \
"[](a!) self, int idx, \
" decl_type " el) -> ()", \
" decl_type "(b -> *) el) -> ()", \
listInsert<Shared<c_type>, c_type::ElemType>), \
Operator( \
"aten::pop(" decl_type \
@ -2263,7 +2263,7 @@ RegisterOperators reg2({
dictGetDefault), \
Operator( \
"aten::_set_item(Dict(" key_type ", t)(a!) l, " key_type \
" idx, t v) -> ()", \
" idx, t(b -> *) v) -> ()", \
dictSetItem)
CREATE_DICT_OPS("str"),

View File

@ -1,31 +1,31 @@
#include <torch/csrc/jit/script/schema_type_parser.h>
#include <ATen/core/interned_strings.h>
#include <ATen/core/alias_info.h>
#include <ATen/core/interned_strings.h>
#include <ATen/core/jit_type.h>
#include <c10/util/string_utils.h>
#include <torch/csrc/jit/script/lexer.h>
#include <torch/csrc/jit/script/parse_string_literal.h>
#include <c10/util/string_utils.h>
#include <torch/csrc/jit/script/schema_type_parser.h>
#include <string>
using c10::Symbol;
using c10::AliasInfo;
using c10::BoolType;
using c10::CompleteTensorType;
using c10::DeviceObjType;
using c10::DictType;
using c10::DimensionedTensorType;
using c10::FloatType;
using c10::FutureType;
using c10::GeneratorType;
using c10::IntType;
using c10::DeviceObjType;
using c10::NumberType;
using c10::StringType;
using c10::BoolType;
using c10::NoneType;
using c10::FloatType;
using c10::OptionalType;
using c10::TupleType;
using c10::TensorType;
using c10::DimensionedTensorType;
using c10::CompleteTensorType;
using c10::FutureType;
using c10::DictType;
using c10::ListType;
using c10::NoneType;
using c10::NumberType;
using c10::OptionalType;
using c10::StringType;
using c10::Symbol;
using c10::TensorType;
using c10::TupleType;
using c10::VarType;
using c10::AliasInfo;
namespace torch {
namespace jit {
@ -76,10 +76,10 @@ c10::optional<AliasInfo> SchemaTypeParser::parseAliasAnnotation() {
// optional 'alias set annotation'
parseList(TK_NOTHING, '|', TK_NOTHING, [&] {
if (L.nextIf('*')) {
alias_info = AliasInfo::createWildcard();
alias_info.addBeforeSet(AliasInfo::wildcardSet());
// If we found a wildcard, ignore all subsequent annotations
} else if (!alias_info.isWildcard()) {
} else if (!alias_info.isWildcardBefore()) {
alias_info.addBeforeSet(
Symbol::fromQualString("alias::" + L.expect(TK_IDENT).text()));
}
@ -90,11 +90,14 @@ c10::optional<AliasInfo> SchemaTypeParser::parseAliasAnnotation() {
if (L.nextIf(TK_ARROW)) {
// optional 'alias set annotation'
parseList(TK_NOTHING, '|', TK_NOTHING, [&] {
if (L.cur().kind == '*') {
L.reportError("Wildcard not allowed as part of the output set");
if (L.nextIf('*')) {
alias_info.addAfterSet(AliasInfo::wildcardSet());
// If we found a wildcard, ignore all subsequent annotations
} else if (!alias_info.isWildcardAfter()) {
alias_info.addAfterSet(
Symbol::fromQualString("alias::" + L.expect(TK_IDENT).text()));
}
alias_info.addAfterSet(
Symbol::fromQualString("alias::" + L.expect(TK_IDENT).text()));
});
} else {
// We didn't encounter an ->, so assume the "after set" is identical