mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Mark values entering containers as wildcards (#20556)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/20556 ghimport-source-id: d7c62e38a2f6928f6f8d988c26a38ea8f8cff8b6 Reviewed By: jamesr66a Differential Revision: D15447568 Pulled By: suo fbshipit-source-id: 77ebc11b571b8517d3bad3ee1b3ee5ac037542b2
This commit is contained in:
committed by
Facebook Github Bot
parent
28be521e39
commit
90910fc6cb
@ -22,11 +22,6 @@ class AliasInfo {
|
||||
static const Symbol wc = Symbol::fromQualString("alias::*");
|
||||
return wc;
|
||||
}
|
||||
static AliasInfo createWildcard() {
|
||||
AliasInfo ret;
|
||||
ret.addBeforeSet(wildcardSet());
|
||||
return ret;
|
||||
}
|
||||
|
||||
void setIsWrite(bool isWrite) {
|
||||
isWrite_ = isWrite;
|
||||
@ -57,10 +52,14 @@ class AliasInfo {
|
||||
return *beforeSets_.begin();
|
||||
}
|
||||
|
||||
bool isWildcard() const {
|
||||
bool isWildcardBefore() const {
|
||||
return beforeSets_.count(wildcardSet()) != 0;
|
||||
}
|
||||
|
||||
bool isWildcardAfter() const {
|
||||
return afterSets_.count(wildcardSet()) != 0;
|
||||
}
|
||||
|
||||
// the alias info for the contained types of the type
|
||||
// e.g. if this is an annotation on List[T], `sets` refers to
|
||||
// the alias sets that the list may be in
|
||||
|
@ -680,6 +680,49 @@ graph():
|
||||
AT_ASSERT(!aliasDb.mayContainAlias(first_st, second_st));
|
||||
AT_ASSERT(!aliasDb.mayContainAlias(second_st, tup_st));
|
||||
}
|
||||
{
|
||||
// Test list container aliasing
|
||||
auto graph = std::make_shared<Graph>();
|
||||
std::unordered_map<std::string, Value*> vmap;
|
||||
script::parseIR(
|
||||
R"IR(
|
||||
graph():
|
||||
%10 : bool? = prim::Constant()
|
||||
%8 : Device? = prim::Constant()
|
||||
%4 : int? = prim::Constant()
|
||||
%0 : int = prim::Constant[value=2]()
|
||||
%1 : int = prim::Constant[value=3]()
|
||||
%2 : int[] = prim::ListConstruct(%0, %1)
|
||||
%x : Tensor = aten::rand(%2, %4, %4, %8, %10)
|
||||
%12 : int[] = prim::ListConstruct(%0, %1)
|
||||
%y : Tensor = aten::rand(%12, %4, %4, %8, %10)
|
||||
%22 : int[] = prim::ListConstruct(%0, %1)
|
||||
%z : Tensor = aten::rand(%22, %4, %4, %8, %10)
|
||||
%32 : int[] = prim::ListConstruct(%0, %1)
|
||||
%fresh : Tensor = aten::rand(%32, %4, %4, %8, %10)
|
||||
%foo : Tensor[] = prim::ListConstruct(%x, %y)
|
||||
%43 : Tensor[] = aten::append(%foo, %z)
|
||||
return ()
|
||||
)IR",
|
||||
graph.get(),
|
||||
vmap);
|
||||
AliasDb aliasDb(graph);
|
||||
auto x = vmap["x"];
|
||||
auto y = vmap["y"];
|
||||
auto z = vmap["z"];
|
||||
// Tensors x, y, and z went into a list, so they all may alias each other.
|
||||
ASSERT_TRUE(aliasDb.mayAlias(x, y));
|
||||
ASSERT_TRUE(aliasDb.mayAlias(y, z));
|
||||
ASSERT_TRUE(aliasDb.mayAlias(x, z));
|
||||
|
||||
// But we know `fresh` didn't go into a list, so x, y, and z should not
|
||||
// alias it.
|
||||
|
||||
// auto fresh = vmap["fresh"];
|
||||
// ASSERT_FALSE(aliasDb.mayAlias(x, fresh));
|
||||
// ASSERT_FALSE(aliasDb.mayAlias(y, fresh));
|
||||
// ASSERT_FALSE(aliasDb.mayAlias(z, fresh));
|
||||
}
|
||||
}
|
||||
|
||||
void testWildcards() {
|
||||
|
@ -405,8 +405,6 @@ void AliasDb::analyzeImpl(Node* node) {
|
||||
case prim::GradOf:
|
||||
return analyzeGradOf(node);
|
||||
case prim::Constant:
|
||||
case prim::DictConstruct:
|
||||
case prim::ListConstruct:
|
||||
case prim::AutogradZero:
|
||||
case prim::AutogradAdd:
|
||||
case prim::FusedConcat:
|
||||
@ -417,6 +415,9 @@ void AliasDb::analyzeImpl(Node* node) {
|
||||
case prim::Function:
|
||||
case prim::CreateObject:
|
||||
return analyzeCreator(node);
|
||||
case prim::DictConstruct:
|
||||
case prim::ListConstruct:
|
||||
return analyzeContainerConstruct(node);
|
||||
case prim::TupleUnpack:
|
||||
case prim::TupleIndex:
|
||||
case prim::DictIndex:
|
||||
@ -483,7 +484,8 @@ void AliasDb::analyzeImpl(Node* node) {
|
||||
return analyzeCustomOp(node);
|
||||
}
|
||||
|
||||
// Bind formal alias annotation to actual alias sets
|
||||
// Bind the schema's "formal" alias annotation to the actual values those
|
||||
// schema arguments represent
|
||||
std::unordered_map<Symbol, Value*> formalToActual;
|
||||
for (size_t i = 0; i < schema.arguments().size(); i++) {
|
||||
const auto& formal = schema.arguments()[i].alias_info();
|
||||
@ -498,11 +500,12 @@ void AliasDb::analyzeImpl(Node* node) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// We don't support composite types for alias analysis yet.
|
||||
// Do sanity checks on the alias annotation
|
||||
// - We don't support composite types for alias analysis yet.
|
||||
AT_ASSERT(formal->containedTypes().size() == 0);
|
||||
// TODO neither unions nor wildcards make sense on an input. We should
|
||||
// disallow them in function schema
|
||||
AT_ASSERT(!formal->isWildcard())
|
||||
// - Doesn't make sense for a value to start annotated as a wildcard.
|
||||
AT_ASSERT(!formal->isWildcardBefore());
|
||||
|
||||
const auto& formalAlias = formal->beforeSet();
|
||||
|
||||
// skip if we've already bound this alias
|
||||
@ -517,6 +520,15 @@ void AliasDb::analyzeImpl(Node* node) {
|
||||
if (formal->isWrite()) {
|
||||
registerWrite(actualValue, node);
|
||||
}
|
||||
|
||||
// Now deal with sets after the '->'
|
||||
if (formal->isWildcardAfter()) {
|
||||
setWildcard(actualValue);
|
||||
} else {
|
||||
// We don't understand anything else in the after yet, so assert there's
|
||||
// been no change.
|
||||
AT_ASSERT(formal->beforeSets() == formal->afterSets());
|
||||
}
|
||||
}
|
||||
|
||||
// Use the formal-actual mapping to give aliases to the outputs
|
||||
@ -537,7 +549,7 @@ void AliasDb::analyzeImpl(Node* node) {
|
||||
// We don't support composite types for alias analysis yet.
|
||||
AT_ASSERT(formal->containedTypes().size() == 0);
|
||||
|
||||
if (formal->isWildcard()) {
|
||||
if (formal->isWildcardBefore() || formal->isWildcardAfter()) {
|
||||
setWildcard(actual);
|
||||
continue;
|
||||
}
|
||||
@ -774,6 +786,23 @@ void AliasDb::analyzeCustomOp(Node* node) {
|
||||
}
|
||||
}
|
||||
|
||||
// List or dict construct: create an aliasing element for the actual container,
|
||||
// then mark all inputs as wildcards, since they've gone inside the container.
|
||||
// TODO: tuples are treated differently since we actually compare the contained
|
||||
// values for aliasing, so we don't need wildcards.
|
||||
void AliasDb::analyzeContainerConstruct(Node* node) {
|
||||
AT_ASSERT(
|
||||
node->kind() == prim::ListConstruct ||
|
||||
node->kind() == prim::DictConstruct);
|
||||
|
||||
for (auto input : node->inputs()) {
|
||||
setWildcard(input);
|
||||
}
|
||||
for (auto output : node->outputs()) {
|
||||
giveFreshAlias(output);
|
||||
}
|
||||
}
|
||||
|
||||
// BroadcastingChunk: all inputs are broadcasted, and then individually chunked.
|
||||
// This is an intermediate node used only in the graph fuser.
|
||||
void AliasDb::analyzeBroadcastingChunk(Node* node) {
|
||||
|
@ -42,8 +42,10 @@ class AliasDb {
|
||||
|
||||
// Does `n` write to an alias of one of the values in `vs`?
|
||||
// if `recurseBlocks` is true, consider writes on the nodes in `n`s sub-blocks
|
||||
TORCH_API bool writesToAlias(Node* n, const ValueSet& vs, bool recurseBlocks = false)
|
||||
const;
|
||||
TORCH_API bool writesToAlias(
|
||||
Node* n,
|
||||
const ValueSet& vs,
|
||||
bool recurseBlocks = false) const;
|
||||
|
||||
// Does `a` and `b` potentially share a memory location or do either
|
||||
// hold in memory any element that exists in the other
|
||||
@ -65,9 +67,11 @@ class AliasDb {
|
||||
// `Elements`.
|
||||
template <
|
||||
typename... Other1,
|
||||
template <typename, typename...> class T,
|
||||
template <typename, typename...>
|
||||
class T,
|
||||
typename... Other2,
|
||||
template <typename, typename...> class U>
|
||||
template <typename, typename...>
|
||||
class U>
|
||||
bool mayAlias(
|
||||
const T<const Value*, Other1...>& a,
|
||||
const U<const Value*, Other2...>& b) const {
|
||||
@ -199,6 +203,7 @@ class AliasDb {
|
||||
void analyzeSetAttr(Node* node);
|
||||
void analyzeTupleConstruct(Node* node);
|
||||
void analyzeCustomOp(Node* node);
|
||||
void analyzeContainerConstruct(Node* node);
|
||||
bool tryRegisteredAnalysis(Node* node);
|
||||
|
||||
/**
|
||||
@ -206,7 +211,9 @@ class AliasDb {
|
||||
*/
|
||||
void makeAllAlias(const std::vector<Value*>& values);
|
||||
void makePointerTo(const Value* value, const Value* to);
|
||||
TORCH_API void addToContainedElements(const Value* element, const Value* container);
|
||||
TORCH_API void addToContainedElements(
|
||||
const Value* element,
|
||||
const Value* container);
|
||||
void mapAliases(at::ArrayRef<Value*> to, at::ArrayRef<Value*> from);
|
||||
void giveFreshAlias(const Value* value);
|
||||
Element* getOrCreateElement(const Value* value);
|
||||
|
@ -1752,7 +1752,7 @@ RegisterOperators reg2({
|
||||
listSelect<Shared<c_type>>), \
|
||||
Operator( \
|
||||
"aten::append( " decl_type "[](a!) self, " decl_type \
|
||||
"(c) el) -> " decl_type "[](a!)", \
|
||||
"(c -> *) el) -> " decl_type "[](a!)", \
|
||||
listAppend<Shared<c_type>, c_type::ElemType>), \
|
||||
Operator( \
|
||||
"aten::reverse( " decl_type "[](a!) self) -> ()", \
|
||||
@ -1768,7 +1768,7 @@ RegisterOperators reg2({
|
||||
listCopy<Shared<c_type>>), \
|
||||
Operator( \
|
||||
"aten::_set_item(" decl_type "[](a!) l, int idx, " decl_type \
|
||||
" el) -> " decl_type "[](a!)", \
|
||||
"(b -> *) el) -> " decl_type "[](a!)", \
|
||||
listSetItem<Shared<c_type>, c_type::ElemType>), \
|
||||
Operator( \
|
||||
"aten::clear( " decl_type "[](a!) self) -> ()", \
|
||||
@ -1776,7 +1776,7 @@ RegisterOperators reg2({
|
||||
Operator( \
|
||||
"aten::insert( " decl_type \
|
||||
"[](a!) self, int idx, \
|
||||
" decl_type " el) -> ()", \
|
||||
" decl_type "(b -> *) el) -> ()", \
|
||||
listInsert<Shared<c_type>, c_type::ElemType>), \
|
||||
Operator( \
|
||||
"aten::pop(" decl_type \
|
||||
@ -2263,7 +2263,7 @@ RegisterOperators reg2({
|
||||
dictGetDefault), \
|
||||
Operator( \
|
||||
"aten::_set_item(Dict(" key_type ", t)(a!) l, " key_type \
|
||||
" idx, t v) -> ()", \
|
||||
" idx, t(b -> *) v) -> ()", \
|
||||
dictSetItem)
|
||||
|
||||
CREATE_DICT_OPS("str"),
|
||||
|
@ -1,31 +1,31 @@
|
||||
#include <torch/csrc/jit/script/schema_type_parser.h>
|
||||
#include <ATen/core/interned_strings.h>
|
||||
#include <ATen/core/alias_info.h>
|
||||
#include <ATen/core/interned_strings.h>
|
||||
#include <ATen/core/jit_type.h>
|
||||
#include <c10/util/string_utils.h>
|
||||
#include <torch/csrc/jit/script/lexer.h>
|
||||
#include <torch/csrc/jit/script/parse_string_literal.h>
|
||||
#include <c10/util/string_utils.h>
|
||||
#include <torch/csrc/jit/script/schema_type_parser.h>
|
||||
#include <string>
|
||||
|
||||
using c10::Symbol;
|
||||
using c10::AliasInfo;
|
||||
using c10::BoolType;
|
||||
using c10::CompleteTensorType;
|
||||
using c10::DeviceObjType;
|
||||
using c10::DictType;
|
||||
using c10::DimensionedTensorType;
|
||||
using c10::FloatType;
|
||||
using c10::FutureType;
|
||||
using c10::GeneratorType;
|
||||
using c10::IntType;
|
||||
using c10::DeviceObjType;
|
||||
using c10::NumberType;
|
||||
using c10::StringType;
|
||||
using c10::BoolType;
|
||||
using c10::NoneType;
|
||||
using c10::FloatType;
|
||||
using c10::OptionalType;
|
||||
using c10::TupleType;
|
||||
using c10::TensorType;
|
||||
using c10::DimensionedTensorType;
|
||||
using c10::CompleteTensorType;
|
||||
using c10::FutureType;
|
||||
using c10::DictType;
|
||||
using c10::ListType;
|
||||
using c10::NoneType;
|
||||
using c10::NumberType;
|
||||
using c10::OptionalType;
|
||||
using c10::StringType;
|
||||
using c10::Symbol;
|
||||
using c10::TensorType;
|
||||
using c10::TupleType;
|
||||
using c10::VarType;
|
||||
using c10::AliasInfo;
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
@ -76,10 +76,10 @@ c10::optional<AliasInfo> SchemaTypeParser::parseAliasAnnotation() {
|
||||
// optional 'alias set annotation'
|
||||
parseList(TK_NOTHING, '|', TK_NOTHING, [&] {
|
||||
if (L.nextIf('*')) {
|
||||
alias_info = AliasInfo::createWildcard();
|
||||
alias_info.addBeforeSet(AliasInfo::wildcardSet());
|
||||
|
||||
// If we found a wildcard, ignore all subsequent annotations
|
||||
} else if (!alias_info.isWildcard()) {
|
||||
} else if (!alias_info.isWildcardBefore()) {
|
||||
alias_info.addBeforeSet(
|
||||
Symbol::fromQualString("alias::" + L.expect(TK_IDENT).text()));
|
||||
}
|
||||
@ -90,11 +90,14 @@ c10::optional<AliasInfo> SchemaTypeParser::parseAliasAnnotation() {
|
||||
if (L.nextIf(TK_ARROW)) {
|
||||
// optional 'alias set annotation'
|
||||
parseList(TK_NOTHING, '|', TK_NOTHING, [&] {
|
||||
if (L.cur().kind == '*') {
|
||||
L.reportError("Wildcard not allowed as part of the output set");
|
||||
if (L.nextIf('*')) {
|
||||
alias_info.addAfterSet(AliasInfo::wildcardSet());
|
||||
|
||||
// If we found a wildcard, ignore all subsequent annotations
|
||||
} else if (!alias_info.isWildcardAfter()) {
|
||||
alias_info.addAfterSet(
|
||||
Symbol::fromQualString("alias::" + L.expect(TK_IDENT).text()));
|
||||
}
|
||||
alias_info.addAfterSet(
|
||||
Symbol::fromQualString("alias::" + L.expect(TK_IDENT).text()));
|
||||
});
|
||||
} else {
|
||||
// We didn't encounter an ->, so assume the "after set" is identical
|
||||
|
Reference in New Issue
Block a user