Support Union in TorchScript (#64234)

Summary:
This PR replaces the https://github.com/pytorch/pytorch/pull/53180 PR stack, which contains all of the review discussion. A replacement was needed due to a messy Sandcastle issue.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/64234

Reviewed By: gmagogsfm

Differential Revision: D30656444

Pulled By: ansley

fbshipit-source-id: 77536c8bcc88162e2c72636026ca3c16891d669a
Author: Ansley Ussery
Date: 2021-09-03 06:10:37 -07:00
Committed by: Facebook GitHub Bot
Parent: 91b926fab3
Commit: 6831d8e379
50 changed files with 2137 additions and 467 deletions

View File

@ -435,12 +435,12 @@ is `./build/bin/FILENAME --gtest_filter=TESTSUITE.TESTNAME`, where
`TESTNAME` is the name of the test you'd like to run and `TESTSUITE` is
the suite that test is defined in.
For example, if you wanted to run the test ` MayContainAlias`, which
For example, if you wanted to run the test `MayContainAlias`, which
is part of the test suite `ContainerAliasingTest` in the file
`test/cpp/jit/test_alias_analysis.cpp`, the command would be:
```bash
./build/bin/test_jit --gtest_filter=ContainerAliasingTest.UnionAliasing
./build/bin/test_jit --gtest_filter=ContainerAliasingTest.MayContainAlias
```

View File

@ -30,6 +30,9 @@ struct FunctionSchema;
struct NamedType;
using OptNameList = c10::optional<std::vector<std::string>>;
void standardizeVectorForUnion(std::vector<TypePtr>& reference, std::vector<TypePtr>* to_fill);
void standardizeVectorForUnion(std::vector<TypePtr>* to_flatten);
struct AnyType;
using AnyTypePtr = std::shared_ptr<AnyType>;
// Any is the top of the type hierarchy, all other types are subtypes
@ -94,25 +97,84 @@ struct SingleElementType : public Type {
TypePtr elem;
};
struct UnionType;
using UnionTypePtr = std::shared_ptr<UnionType>;
struct TORCH_API UnionType : public Type {
friend struct Type;
static const TypeKind Kind = TypeKind::UnionType;
bool isSubtypeOfExt(const TypePtr& rhs_, std::ostream* why_not) const override;
std::string str() const override;
static UnionTypePtr create(std::vector<TypePtr> reference);
bool operator==(const Type& rhs) const override;
at::ArrayRef<TypePtr> containedTypes() const override {
return types_;
}
// For testing purposes only
at::ArrayRef<TypePtr> getTypes() const {
return types_;
}
TypePtr createWithContained(std::vector<TypePtr> contained_types) const override {
return create(contained_types);
}
bool canHoldType(TypePtr type) const;
bool hasFreeVariables() const override {
return has_free_variables_;
}
c10::optional<TypePtr> toOptional() const;
c10::optional<TypePtr> subtractTypeSet(std::vector<TypePtr>& to_subtract) const;
protected:
explicit UnionType(std::vector<TypePtr> types, TypeKind kind=TypeKind::UnionType);
std::string annotation_str_impl(TypePrinter printer = nullptr) const override;
std::string unionStr(TypePrinter printer = nullptr, bool is_annotation_str = false) const;
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
bool has_free_variables_;
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
std::vector<TypePtr> types_;
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
bool can_hold_none_;
};
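
A minimal usage sketch of the `UnionType` API declared above, assuming a libtorch build where `ATen/core/jit_type.h` is available (the same header the new C++ tests include); the commented outputs follow the semantics described elsewhere in this diff:

```cpp
#include <ATen/core/jit_type.h>
#include <iostream>

int main() {
  using namespace c10;
  // Union[int, str], built with the create() factory declared above
  UnionTypePtr u = UnionType::create({IntType::get(), StringType::get()});
  std::cout << u->str() << "\n";                          // Union[int, str]
  std::cout << u->canHoldType(IntType::get()) << "\n";    // 1
  std::cout << u->canHoldType(FloatType::get()) << "\n";  // 0
  // No NoneType member, so there is no equivalent Optional
  std::cout << u->toOptional().has_value() << "\n";       // 0
  return 0;
}
```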
struct OptionalType;
using OptionalTypePtr = std::shared_ptr<OptionalType>;
// This type represents an optional type, for each element type.
// Optional[T] can accept both T and None(nullopt in C++)
// This type represents an optional type. There is one `Optional` for
// each element type. `Optional[T]` can accept both `T` and
// `None`(`c10::nullopt` in C++)
// Subtype hierarchy for Optional:
// 1. Optional[T] <: Optional[R] iff T <: R
// 2. T <: Optional[R] if T <: R
// 3. None <: Optional[T] for all T
struct TORCH_API OptionalType
: public SingleElementType<TypeKind::OptionalType, OptionalType> {
static OptionalTypePtr create(TypePtr element) {
TORCH_INTERNAL_ASSERT(element, "OptionalType requires valid TypePtr");
// Optional is a union of [None, T], so Optional[[Optional[T]]] ->
// Optional[T]
if (auto opt_ptr = element->cast<OptionalType>()) {
return opt_ptr;
// - Optional[T] <: Optional[R] iff T <: R
// - T <: Optional[R] if T <: R
// - None <: Optional[T] for all T
// - Optional[T] == Union[T, None] for all T
struct TORCH_API OptionalType : public UnionType {
static OptionalTypePtr create(TypePtr contained) {
return OptionalTypePtr(new OptionalType(std::move(contained)));
}
return OptionalTypePtr(
new OptionalType(std::move(element))); // NOLINT(modernize-make-shared)
static const TypeKind Kind = TypeKind::OptionalType;
friend struct Type;
bool operator==(const Type& rhs) const override;
TypePtr getElementType() const {
return contained_;
}
at::ArrayRef<TypePtr> containedTypes() const override {
return contained_;
}
std::string str() const override {
@ -127,20 +189,15 @@ struct TORCH_API OptionalType
return create(contained_types[0]);
}
bool isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const override {
if (Type::isSubtypeOfExt(rhs, why_not)) {
return true;
}
if (auto rhs_ = rhs->cast<OptionalType>()) {
return getElementType()->isSubtypeOfExt(rhs_->getElementType(), why_not);
}
return false;
}
bool isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const override;
// common cast Optional[Tensor] for undefined tensor type
static OptionalTypePtr ofTensor();
private:
OptionalType(TypePtr elem) : SingleElementType(elem) {}
explicit OptionalType(TypePtr contained);
TypePtr contained_;
std::string annotation_str_impl(TypePrinter printer = nullptr) const override {
std::stringstream ss;
@ -908,7 +965,6 @@ struct TORCH_API RRefType
}
};
struct NamedType;
using NamedTypePtr = std::shared_ptr<NamedType>;
using ConstNamedTypePtr = std::shared_ptr<const NamedType>;
@ -1112,7 +1168,6 @@ struct TORCH_API EnumType : public NamedType {
std::weak_ptr<::torch::jit::CompilationUnit> cu_;
};
// the common supertype of all Enums, only used in operator registration.
// EnumType <: AnyEnumType for all Enums
struct AnyEnumType;
@ -1132,7 +1187,6 @@ private:
: Type(TypeKind::AnyEnumType) {}
};
struct NumberType;
using NumberTypePtr = std::shared_ptr<NumberType>;
// This type represents a Python number
@ -1141,9 +1195,10 @@ using NumberTypePtr = std::shared_ptr<NumberType>;
// FloatType <: NumberType
// ComplexType <:NumberType
struct TORCH_API NumberType : public Type {
bool operator==(const Type& rhs) const override {
return rhs.kind() == kind();
}
bool operator==(const Type& rhs) const override;
bool isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const override;
std::string str() const override {
return "Scalar"; // match what PythonArgParser says for clarity
}
@ -1172,7 +1227,8 @@ struct TORCH_API FloatType : public NumberType {
return "float";
}
bool isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const override {
return rhs->kind() == TypeKind::NumberType || NumberType::isSubtypeOfExt(rhs, why_not);
// NOLINTNEXTLINE(bugprone-parent-virtual-call)
return rhs->kind() == TypeKind::NumberType || Type::isSubtypeOfExt(rhs, why_not);
}
static const TypeKind Kind = TypeKind::FloatType;
// global singleton
@ -1196,7 +1252,8 @@ struct TORCH_API ComplexType : public NumberType {
return "complex";
}
bool isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const override {
return rhs->kind() == TypeKind::NumberType || NumberType::isSubtypeOfExt(rhs, why_not);
// NOLINTNEXTLINE(bugprone-parent-virtual-call)
return rhs->kind() == TypeKind::NumberType || Type::isSubtypeOfExt(rhs, why_not);
}
static const TypeKind Kind = TypeKind::ComplexType;
// global singleton
@ -1220,7 +1277,8 @@ struct TORCH_API IntType : public NumberType {
return "int";
}
bool isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const override {
return rhs->kind() == TypeKind::NumberType || NumberType::isSubtypeOfExt(rhs, why_not);
// NOLINTNEXTLINE(bugprone-parent-virtual-call)
return rhs->kind() == TypeKind::NumberType || Type::isSubtypeOfExt(rhs, why_not);
}
static const TypeKind Kind = TypeKind::IntType;
// global singleton
@ -1334,12 +1392,8 @@ struct TORCH_API NoneType : public Type {
std::string str() const override {
return "NoneType";
}
bool isSubtypeOfExt(const TypePtr& rhs, std::ostream *why_not) const override {
if (rhs->kind() == OptionalType::Kind) {
return true;
}
return Type::isSubtypeOfExt(rhs, why_not);
}
bool isSubtypeOfExt(const TypePtr& rhs, std::ostream *why_not) const override;
static const TypeKind Kind = TypeKind::NoneType;
// global singleton
static NoneTypePtr get();
@ -1524,8 +1578,15 @@ TORCH_API std::ostream& operator<<(std::ostream& os, const Stride& s);
// what is the type, ignoring extra size/shape information?
// e.g. Tensor(2x3) -> Dynamic, and Tuple(Tensor(2x3),...) -> Tuple(Dynamic,...)
// xxx: be careful with calls because this can be very slow. If calling this on a graph
// use `EraseShapeInformation` in shape_analysis.h
// `unshapedType` is used to remove Tensor subtypes. We treat all Tensor
// subtypes as simply "Tensor"; we also create a new version of any
// container types in which internal Tensors have undergone the same
// operation. This is used for type comparisons between two Tensor types
// (`unshapedType` means that we don't falsely return `false` for e.g.
// Tensors of different dimensions). It's also used in the alias
// analysis pass.
// Be careful with calls because this can be very slow. If calling this
// on a graph, use `EraseShapeInformation` in shape_analysis.h
inline TypePtr unshapedType(const TypePtr& type) {
if (type->isSubtypeOf(TensorType::get())) {
return TensorType::get();
@ -1569,27 +1630,32 @@ inline at::ScalarType scalarTypeFromJitType(const c10::TypePtr& type) {
return *result;
}
// Attempt to find the correct supertype of t1 and t2. If none is found then
// nullopt will be returned if default_to_any is false, and Any will be returned
// if it is true. If t1 == t2, or t1 is a type refinement of t2,
// then t2 will be returned (and vice versa).
// Attempt to find the correct supertype of the two types `t1` and `t2`.
// If no supertype is found, then nullopt will be returned if
// `default_to_union` is false, and `Union[t1, t2]` will be returned
// if it is true. If `t1 == t2`, or `t1` is a type refinement of `t2`,
// then `t2` will be returned (and vice versa).
//
// Two different tensortypes will return dynamic.
// Currently we chose not to support returning a NumberType for a float & int
// input because of a lack of operator support for NumberType.
//
// Currently we chose not to support returning a NumberType for
// two types from the set of {FloatType, IntType, ComplexType}, because
// there is a lack of operator support for NumberType.
//
// If `type_hint` is an `InterfaceType`, then we can use that as a
// potential supertype for `ClassType`s in the list. Otherwise, we have
// no way to find and use some common interface type
TORCH_API c10::optional<TypePtr> unifyTypes(
const TypePtr& t1,
const TypePtr& t2,
bool default_to_any = false,
TypePtr type_hint=nullptr);
bool default_to_union = false,
TypePtr type_hint = nullptr);
TORCH_API c10::optional<TypePtr> unifyTypeList(
at::ArrayRef<TypePtr> elements,
std::ostream& why_not,
bool default_to_any=false,
TypePtr type_hint=nullptr);
bool default_to_union = false,
TypePtr type_hint = nullptr);
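
A short sketch of the `default_to_union` behavior described in the comment above, assuming a libtorch build; as the comment notes, `int` and `float` are not unified to `NumberType`, so unification only succeeds here via the Union fallback:

```cpp
#include <ATen/core/jit_type.h>
#include <iostream>

int main() {
  using namespace c10;
  // Without the fallback, there is no common supertype for int and float
  auto strict = unifyTypes(IntType::get(), FloatType::get(),
                           /*default_to_union=*/false);
  std::cout << strict.has_value() << "\n";      // 0

  // With the fallback, we get Union[t1, t2]
  auto relaxed = unifyTypes(IntType::get(), FloatType::get(),
                            /*default_to_union=*/true);
  std::cout << (*relaxed)->repr_str() << "\n";  // Union[float, int]
  return 0;
}
```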
namespace detail {
template <typename T>

View File

@ -44,7 +44,8 @@ namespace c10 {
_(ScalarTypeType) \
_(AnyListType) \
_(AnyTupleType) \
_(AnyClassType)
_(AnyClassType) \
_(UnionType)
enum class TypeKind {
#define DEFINE_TYPE(T) T,
@ -203,7 +204,7 @@ struct TORCH_API Type : std::enable_shared_from_this<Type> {
// contained_types
TypePtr withContained(std::vector<TypePtr> contained_types) {
auto current_contained = containedTypes();
AT_ASSERT(current_contained.size() == contained_types.size());
TORCH_INTERNAL_ASSERT(current_contained.size() == contained_types.size());
if (current_contained.equals(contained_types)) {
return shared_from_this();
}

View File

@ -265,7 +265,7 @@ AnyEnumTypePtr AnyEnumType::get() {
return value;
}
c10::optional<TypePtr> unifyTypesImpl(const TypePtr& t1, const TypePtr& t2, bool default_to_any=false, TypePtr type_hint=nullptr) {
c10::optional<TypePtr> unifyTypesImpl(const TypePtr& t1, const TypePtr& t2, bool default_to_union=false, TypePtr type_hint=nullptr) {
// check direct subtyping relation
if (t1->isSubtypeOf(t2)) {
return t2;
@ -308,7 +308,7 @@ c10::optional<TypePtr> unifyTypesImpl(const TypePtr& t1, const TypePtr& t2, bool
}
std::vector<TypePtr> elements;
for (size_t i = 0; i < tuple1->elements().size(); i++) {
if (auto elem = unifyTypes(tuple1->elements().at(i), tuple2->elements().at(i), default_to_any)) {
if (auto elem = unifyTypes(tuple1->elements().at(i), tuple2->elements().at(i), default_to_union)) {
elements.push_back(*elem);
} else {
return c10::nullopt;
@ -347,11 +347,11 @@ c10::optional<TypePtr> unifyTypesImpl(const TypePtr& t1, const TypePtr& t2, bool
return c10::nullopt;
}
c10::optional<TypePtr> unifyTypes(const TypePtr& t1, const TypePtr& t2, bool default_to_any, TypePtr type_hint) {
auto unified = unifyTypesImpl(t1, t2, default_to_any, type_hint);
c10::optional<TypePtr> unifyTypes(const TypePtr& t1, const TypePtr& t2, bool default_to_union, TypePtr type_hint) {
auto unified = unifyTypesImpl(t1, t2, default_to_union, type_hint);
if (default_to_any && !unified) {
return AnyType::get();
if (default_to_union && !unified) {
return UnionType::create({t1, t2});
}
return unified;
@ -360,7 +360,7 @@ c10::optional<TypePtr> unifyTypes(const TypePtr& t1, const TypePtr& t2, bool def
c10::optional<TypePtr> unifyTypeList(
at::ArrayRef<TypePtr> elements,
std::ostream& why_not,
bool default_to_any,
bool default_to_union,
TypePtr type_hint) {
if (elements.size() == 0) {
why_not << "Cannot get unified type from empty list";
@ -369,7 +369,7 @@ c10::optional<TypePtr> unifyTypeList(
TypePtr ret_type = elements.at(0);
for (size_t i = 1; i < elements.size() && ret_type; ++i) {
c10::optional<TypePtr> maybe_unified = unifyTypes(ret_type, elements.at(i), default_to_any, type_hint);
c10::optional<TypePtr> maybe_unified = unifyTypes(ret_type, elements.at(i), default_to_union, type_hint);
if (!maybe_unified) {
why_not << "Could not unify type list since element " << i << " of type "
<< elements.at(i)->repr_str()
@ -547,8 +547,9 @@ TORCH_API TypePtr tryEvalTypeVariables(TypePtr type, std::unordered_map<std::str
}
TORCH_API bool elementTypeCanBeInferredFromMembers(const TypePtr& elem_type) {
if (elem_type->kind() == OptionalType::Kind ||
elem_type->kind() == NumberType::Kind) {
if (elem_type->kind() == UnionType::Kind
|| elem_type->kind() == OptionalType::Kind
|| elem_type->kind() == NumberType::Kind) {
// Builtin Union types
return false;
}
@ -577,8 +578,16 @@ bool Type::isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const {
if (rhs->kind() == TypeKind::AnyType || *this == *rhs) {
return true;
}
if(auto rhs_ = rhs->cast<OptionalType>()) {
return this->isSubtypeOfExt(rhs_->getElementType(), why_not);
if (auto opt_rhs = rhs->cast<OptionalType>()) {
return this->isSubtypeOfExt(opt_rhs->getElementType(), why_not);
}
if (auto union_rhs = rhs->cast<UnionType>()) {
// Check if `this` is a subtype of any of the types within the Union
return std::any_of(union_rhs->containedTypes().begin(),
union_rhs->containedTypes().end(),
[&](TypePtr inner) {
return this->isSubtypeOfExt(inner, why_not);
});
}
return false;
}
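
A small illustration of the new rule above (a type is a subtype of a `Union` if it is a subtype of any member), assuming a libtorch build:

```cpp
#include <ATen/core/jit_type.h>
#include <iostream>

int main() {
  using namespace c10;
  auto u = UnionType::create({IntType::get(), StringType::get()});
  // int is a member of the Union; float is not a subtype of any member
  std::cout << IntType::get()->isSubtypeOf(u) << "\n";    // 1
  std::cout << FloatType::get()->isSubtypeOf(u) << "\n";  // 0
  return 0;
}
```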
@ -808,6 +817,453 @@ TupleTypePtr TupleType::createNamed(const c10::optional<c10::QualifiedName>& qua
field_types, qualName, schema)); // NOLINT(modernize-make-shared)
}
bool NoneType::isSubtypeOfExt(const TypePtr& rhs, std::ostream *why_not) const {
if (rhs->kind() == OptionalType::Kind) {
return true;
}
return Type::isSubtypeOfExt(rhs, why_not);
}
// Remove nested Optionals/Unions during the instantiation of a Union or
// an Optional. This populates `to_fill` with all the types found during
// flattening. At the end of `flattenUnion`, `to_fill` may have
// duplicates, but it will not have nested Optionals/Unions
void flattenUnion(TypePtr& type, std::vector<TypePtr>* to_fill) {
if (auto union_type = type->cast<UnionType>()) {
for (auto inner : union_type->containedTypes()) {
flattenUnion(inner, to_fill);
}
} else if (auto opt_type = type->cast<OptionalType>()) {
auto inner = opt_type->getElementType();
flattenUnion(inner, to_fill);
to_fill->emplace_back(NoneType::get());
} else if (type->kind() == NumberType::Kind) {
to_fill->emplace_back(IntType::get());
to_fill->emplace_back(FloatType::get());
to_fill->emplace_back(ComplexType::get());
} else {
to_fill->emplace_back(type);
}
}
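
A sketch of the flattening behavior described above, assuming a libtorch build; creating a Union from `Optional[int]` and `float` yields a flat three-member Union rather than a nested type:

```cpp
#include <ATen/core/jit_type.h>
#include <iostream>

int main() {
  using namespace c10;
  // Optional[int] is flattened into {int, None} during creation
  auto u = UnionType::create(
      {OptionalType::create(IntType::get()), FloatType::get()});
  std::cout << u->containedTypes().size() << "\n";       // 3
  std::cout << u->canHoldType(NoneType::get()) << "\n";  // 1
  std::cout << u->repr_str() << "\n";  // e.g. Union[float, int, NoneType]
  return 0;
}
```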
// Helper function for `standardizeVectorForUnion`
//
// NB: If we have types `T1`, `T2`, `T3`, and `PARENT_T` such that `T1`,
// `T2`, and `T3` are children of `PARENT_T`, then `unifyTypes(T1, T2)`
// will return `PARENT_T`. This could be a problem if we didn't want our
// Union to also be able to take `T3`. In our current type hierarchy,
// this isn't an issue--most types SHOULD be unified even if the parent
// type wasn't in the original vector. However, later additions to the
// type system might necessitate reworking `get_supertype`
void filterDuplicateSubtypes(std::vector<TypePtr>* types) {
if (types->empty()) {
return;
}
auto get_supertype = [](const TypePtr t1, const TypePtr t2) -> c10::optional<TypePtr> {
// We don't want nested Optionals. Also, prematurely unifying to
// `Optional` could prevent us from coalescing other types
if ((t1->isSubtypeOf(NoneType::get()) && !t2->isSubtypeOf(NoneType::get()))
|| (!t1->isSubtypeOf(NoneType::get()) && t2->isSubtypeOf(NoneType::get()))) {
return c10::nullopt;
} else {
return unifyTypes(t1, t2, /*default_to_union=*/false);
}
};
// Coalesce types and delete all duplicates. Moving from right to left
// through the vector, we try to unify the current element (`i`) with
// each element (`j`) before the "new" end of the vector (`end`).
// If we're able to unify the types at `types[i]` and `types[j]`, we
// decrement `end`, swap `types[j]` with the unified type, and
// break. Otherwise, we keep `end` where it is to signify that the
// new end of the vector hasn't shifted
size_t end_idx = types->size()-1;
for (size_t i = types->size()-1; i > 0; --i) {
for (size_t j = std::min(i-1, end_idx); ; --j) {
c10::optional<TypePtr> unified;
unified = get_supertype((*types)[i], (*types)[j]);
if (unified) {
(*types)[j] = *unified;
(*types)[i] = (*types)[end_idx];
--end_idx;
break;
}
// Break condition here so we don't decrement `j` past 0 and have the
// unsigned index wrap around to SIZE_MAX
if (j == 0) {
break;
}
}
}
// Cut off the vector's tail so that `end` is the real last element
types->erase(types->begin() + end_idx + 1, types->end());
}
void sortUnion(std::vector<TypePtr>* types) {
// We want the elements to be sorted so we can easily compare two
// UnionType objects for equality in the future. Note that this order
// is guaranteed to be stable since we've already coalesced any
// possible types
std::sort(types->begin(), types->end(),
[](const TypePtr a, const TypePtr b) -> bool {
if (a->kind() != b->kind()) {
return a->kind() < b->kind();
}
return a->str() < b->str();
});
}
void standardizeVectorForUnion(std::vector<TypePtr>& reference, std::vector<TypePtr>* to_fill) {
for (auto type : reference) {
flattenUnion(type, to_fill);
}
filterDuplicateSubtypes(to_fill);
sortUnion(to_fill);
}
void standardizeVectorForUnion(std::vector<TypePtr>* to_flatten) {
TORCH_INTERNAL_ASSERT(to_flatten, "`standardizeVectorForUnion` was ",
"passed a `nullptr`");
std::vector<TypePtr> to_fill;
standardizeVectorForUnion(*to_flatten, &to_fill);
*to_flatten = to_fill;
}
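
A sketch of the full standardization pipeline above (flatten, coalesce subtypes, sort), assuming a libtorch build; because `(int, int)` is a subtype of `(int?, int)`, the redundant member is dropped, and sorting makes the original member order irrelevant:

```cpp
#include <ATen/core/jit_type.h>
#include <iostream>

int main() {
  using namespace c10;
  auto tup_opt = TupleType::create(
      {OptionalType::create(IntType::get()), IntType::get()});         // (int?, int)
  auto tup_int = TupleType::create({IntType::get(), IntType::get()});  // (int, int)

  // (int, int) <: (int?, int), so the redundant subtype is coalesced away
  auto u1 = UnionType::create({StringType::get(), tup_opt, tup_int});
  std::cout << u1->containedTypes().size() << "\n";  // 2

  // Sorting makes the original member order irrelevant
  auto u2 = UnionType::create({tup_int, tup_opt, StringType::get()});
  std::cout << (*u1 == *u2) << "\n";                 // 1
  return 0;
}
```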
UnionType::UnionType(std::vector<TypePtr> reference, TypeKind kind) : Type(kind) {
TORCH_INTERNAL_ASSERT(!reference.empty(), "Cannot create an empty Union");
standardizeVectorForUnion(reference, &types_);
// Gate the assert in a regular conditional so that we don't create
// this long error message unnecessarily
if (types_.size() == 1) {
std::stringstream msg;
msg << "After type unification was performed, the Union with the "
<< "original types {";
for (size_t i = 0; i < reference.size(); ++i) {
if (i > 0) {
msg << ", ";
}
msg << reference[i]->repr_str();
}
msg << "} has the single type " << types_[0]->repr_str()
<< ". Use the common supertype instead of creating a Union type";
TORCH_INTERNAL_ASSERT(false, msg.str());
}
can_hold_none_ = false;
has_free_variables_ = false;
for (const TypePtr& type : types_) {
if (type->kind() == NoneType::Kind) {
can_hold_none_ = true;
}
if (type->hasFreeVariables()) {
has_free_variables_ = true;
}
}
}
UnionTypePtr UnionType::create(std::vector<TypePtr> reference) {
auto union_type = new UnionType(std::move(reference));
// Some very special-cased logic for `Optional`. This will be deleted
// in a later PR
bool int_found = false;
bool float_found = false;
bool complex_found = false;
bool nonetype_found = false;
auto update_is_opt_flags = [&](TypePtr t) {
if (t == IntType::get()) {
int_found = true;
} else if (t == FloatType::get()) {
float_found = true;
} else if (t == ComplexType::get()) {
complex_found = true;
} else if (t == NoneType::get()) {
nonetype_found = true;
}
};
for (const auto& t : union_type->containedTypes()) {
update_is_opt_flags(t);
}
bool numbertype_found = int_found && float_found && complex_found;
if (nonetype_found) {
if (union_type->containedTypes().size() == 4 && numbertype_found) {
return OptionalType::create(NumberType::get());
}
if (union_type->containedTypes().size() == 2) {
auto not_none = union_type->containedTypes()[0] != NoneType::get()
? union_type->containedTypes()[0]
: union_type->containedTypes()[1];
return OptionalType::create(not_none);
}
}
return UnionTypePtr(union_type);
}
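
A sketch of the `Optional` special-casing above, assuming a libtorch build; a two-member Union containing `None` comes back as an `OptionalType`, and `{int, float, complex, None}` comes back as `Optional` of `NumberType`:

```cpp
#include <ATen/core/jit_type.h>
#include <iostream>

int main() {
  using namespace c10;
  // A two-member Union containing None collapses to an Optional
  auto a = UnionType::create({IntType::get(), NoneType::get()});
  std::cout << (a->kind() == TypeKind::OptionalType) << "\n";  // 1
  std::cout << a->str() << "\n";                               // int?

  // {int, float, complex, None} collapses to Optional of NumberType
  auto b = UnionType::create({IntType::get(), FloatType::get(),
                              ComplexType::get(), NoneType::get()});
  std::cout << (*b == *OptionalType::create(NumberType::get())) << "\n";  // 1
  return 0;
}
```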
bool UnionType::operator==(const Type& rhs) const {
if (auto union_rhs = rhs.cast<UnionType>()) {
// We can't compare the type vectors for equality using `operator==`,
// because the vectors hold `TypePtr`s and we want to compare `Type`
// equality
if (union_rhs->containedTypes().size() != this->containedTypes().size()) {
return false;
}
// Check that all the types in `this->types_` are also in
// `union_rhs->types_`
return std::all_of(this->containedTypes().begin(), this->containedTypes().end(),
[&](TypePtr lhs_type) {
return std::any_of(union_rhs->containedTypes().begin(),
union_rhs->containedTypes().end(),
[&](TypePtr rhs_type) {
return *lhs_type == *rhs_type;
});
});
} else if (auto optional_rhs = rhs.cast<OptionalType>()) {
if (optional_rhs->getElementType() == NumberType::get()) {
return this->containedTypes().size() == 4
&& this->can_hold_none_
&& this->canHoldType(NumberType::get());
}
auto optional_lhs = this->toOptional();
return optional_lhs && *optional_rhs == *((optional_lhs.value())->expect<OptionalType>());
} else if (rhs.kind() == NumberType::Kind) {
return this->containedTypes().size() == 3 && canHoldType(NumberType::get());
} else {
return false;
}
}
bool UnionType::isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const {
std::vector<TypePtr> rhs_types;
if (const auto union_rhs = rhs->cast<UnionType>()) {
// Fast path
if (this->containedTypes() == rhs->containedTypes()) {
return true;
}
rhs_types = rhs->containedTypes().vec();
} else if (const auto optional_rhs = rhs->cast<OptionalType>()) {
rhs_types.push_back(NoneType::get());
if (optional_rhs->getElementType() == NumberType::get()) {
std::vector<TypePtr> number_types{IntType::get(), FloatType::get(), ComplexType::get()};
rhs_types.insert(rhs_types.end(), number_types.begin(), number_types.end());
} else {
rhs_types.push_back(optional_rhs->getElementType());
}
} else if (const auto number_rhs = rhs->cast<NumberType>()) {
std::vector<TypePtr> number_types{IntType::get(), FloatType::get(), ComplexType::get()};
rhs_types.insert(rhs_types.end(), number_types.begin(), number_types.end());
} else {
rhs_types.push_back(rhs);
}
return std::all_of(this->containedTypes().begin(), this->containedTypes().end(),
[&](TypePtr lhs_type) -> bool {
return std::any_of(rhs_types.begin(),
rhs_types.end(),
[&](TypePtr rhs_type) -> bool {
return lhs_type->isSubtypeOfExt(rhs_type, why_not);
});
});
}
std::string UnionType::unionStr(TypePrinter printer, bool is_annotation_str) const {
std::stringstream ss;
bool can_hold_numbertype = this->canHoldType(NumberType::get());
std::vector<TypePtr> number_types{IntType::get(), FloatType::get(), ComplexType::get()};
auto is_numbertype = [&](TypePtr lhs) {
for (const auto& rhs : number_types) {
if (*lhs == *rhs) {
return true;
}
}
return false;
};
ss << "Union[";
bool printed = false;
for (size_t i = 0; i < types_.size(); ++i) {
if (!can_hold_numbertype || !is_numbertype(types_[i])) {
if (i > 0) {
ss << ", ";
printed = true;
}
if (is_annotation_str) {
ss << this->containedTypes()[i]->annotation_str(printer);
} else {
ss << this->containedTypes()[i]->str();
}
}
}
if (can_hold_numbertype) {
if (printed) {
ss << ", ";
}
if (is_annotation_str) {
ss << NumberType::get()->annotation_str(printer);
} else {
ss << NumberType::get()->str();
}
}
ss << "]";
return ss.str();
}
std::string UnionType::str() const {
return this->unionStr(nullptr, /*is_annotation_str=*/false);
}
std::string UnionType::annotation_str_impl(TypePrinter printer) const {
return this->unionStr(printer, /*is_annotation_str=*/true);
}
bool UnionType::canHoldType(TypePtr type) const {
if (type == NumberType::get()) {
return canHoldType(IntType::get())
&& canHoldType(FloatType::get())
&& canHoldType(ComplexType::get());
} else {
return std::any_of(this->containedTypes().begin(), this->containedTypes().end(),
[&](TypePtr inner) {
return type->isSubtypeOf(inner);
});
}
}
c10::optional<TypePtr> UnionType::toOptional() const {
if (!canHoldType(NoneType::get())) {
return c10::nullopt;
}
std::vector<TypePtr> copied_types = this->containedTypes().vec();
auto maybe_opt = UnionType::create(std::move(copied_types));
if (maybe_opt->kind() == UnionType::Kind) {
return c10::nullopt;
} else {
return maybe_opt;
}
}
c10::optional<TypePtr> UnionType::subtractTypeSet(std::vector<TypePtr>& to_subtract) const {
std::vector<TypePtr> types;
// Given a TypePtr `lhs`, this function says whether or not `lhs` (or
// one of its parent types) is in the `to_subtract` vector
auto should_subtract = [&](TypePtr lhs) -> bool {
return std::any_of(to_subtract.begin(), to_subtract.end(),
[&](TypePtr rhs) {
return lhs->isSubtypeOf(rhs);
});
};
// Copy all the elements that should NOT be subtracted to the `types`
// vector
std::copy_if(this->containedTypes().begin(), this->containedTypes().end(),
std::back_inserter(types),
[&](const TypePtr t) {
return !should_subtract(t);
});
if (types.size() == 0) {
return c10::nullopt;
} else if (types.size() == 1) {
return types[0];
} else {
return UnionType::create(std::move(types));
}
}
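
A sketch of `subtractTypeSet`, assuming a libtorch build; subtracting `None` from `Union[int, str, None]` (as a subtractive `is not None` refinement presumably would) leaves `Union[int, str]`:

```cpp
#include <ATen/core/jit_type.h>
#include <iostream>
#include <vector>

int main() {
  using namespace c10;
  auto u = UnionType::create(
      {IntType::get(), StringType::get(), NoneType::get()});
  // Subtract None, e.g. after an `x is not None` refinement
  std::vector<TypePtr> to_subtract{NoneType::get()};
  auto refined = u->subtractTypeSet(to_subtract);
  std::cout << refined.has_value() << "\n";     // 1
  std::cout << (*refined)->repr_str() << "\n";  // Union[int, str]
  return 0;
}
```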
OptionalType::OptionalType(TypePtr contained)
: UnionType({contained, NoneType::get()}, TypeKind::OptionalType) {
bool is_numbertype = false;
if (auto as_union = contained->cast<UnionType>()) {
is_numbertype = as_union->containedTypes().size() == 3 &&
as_union->canHoldType(NumberType::get());
}
if (UnionType::containedTypes().size() == 2) {
contained_ = UnionType::containedTypes()[0]->kind()!= NoneType::Kind
? UnionType::containedTypes()[0]
: UnionType::containedTypes()[1];
} else if (contained == NumberType::get() || is_numbertype) {
contained_ = NumberType::get();
types_.clear();
types_.push_back(NumberType::get());
types_.push_back(NoneType::get());
} else {
std::vector<TypePtr> to_subtract{NoneType::get()};
auto without_none = this->subtractTypeSet(to_subtract);
contained_ = UnionType::create({*without_none});
}
has_free_variables_ = contained_->hasFreeVariables();
}
bool OptionalType::operator==(const Type& rhs) const {
if (auto union_rhs = rhs.cast<UnionType>()) {
auto optional_rhs = union_rhs->toOptional();
// `**optional_rhs` = `*` to get value of `c10::optional<TypePtr>`,
// then `*` to dereference the pointer
return optional_rhs && *this == **optional_rhs;
} else if (auto optional_rhs = rhs.cast<OptionalType>()) {
return *this->getElementType() == *optional_rhs->getElementType();
} else {
return false;
}
}
bool OptionalType::isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const {
if (OptionalTypePtr optional_rhs = rhs->cast<OptionalType>()) {
return getElementType()->isSubtypeOfExt(optional_rhs->getElementType(), why_not);
} else if (UnionTypePtr union_rhs = rhs->cast<UnionType>()) {
if (!union_rhs->canHoldType(NoneType::get())) {
if (why_not) {
*why_not << rhs->repr_str() << " cannot hold None";
}
return false;
} else if (!union_rhs->canHoldType(this->getElementType())) {
if (why_not) {
*why_not << rhs->repr_str() << " cannot hold " << this->getElementType();
}
return false;
} else {
return true;
}
} else {
// NOLINTNEXTLINE(bugprone-argument-comment)
return Type::isSubtypeOfExt(rhs, why_not);
}
}
bool NumberType::operator==(const Type& rhs) const {
if (auto union_type = rhs.cast<UnionType>()) {
return union_type->containedTypes().size() == 3 && union_type->canHoldType(NumberType::get());
} else {
return rhs.kind() == this->kind();
}
}
bool NumberType::isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const {
if (auto union_type = rhs->cast<UnionType>()) {
return union_type->canHoldType(NumberType::get());
} else {
return Type::isSubtypeOfExt(rhs, why_not);
}
}
TupleType::TupleType(
std::vector<TypePtr> elements,
c10::optional<c10::QualifiedName> name,
@ -1732,8 +2188,10 @@ size_t ClassType::addAttribute(
TORCH_CHECK(
(type->kind() == TensorType::Kind) ||
(type->kind() == OptionalType::Kind &&
type->expectRef<OptionalType>().getElementType()->kind() ==
type->expect<OptionalType>()->getElementType()->kind() ==
TensorType::Kind) ||
(type->kind() == UnionType::Kind &&
TensorType::get()->isSubtypeOf(type->expect<UnionType>())) ||
(type->kind() == NoneType::Kind),
"Expecting parameter or buffer to have either None, Tensor or Optional[Tensor] type, but got: ",
toString(type));
@ -1880,7 +2338,9 @@ void SymbolicShape::dump() const {
bool EnumType::isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const {
return rhs->kind() == TypeKind::AnyType ||
rhs->kind() == TypeKind::AnyEnumType || *this == *rhs;
rhs->kind() == TypeKind::AnyEnumType ||
*this == *rhs ||
Type::isSubtypeOfExt(rhs, why_not);
}
} // namespace c10

View File

@ -7,8 +7,8 @@ Like all ATen methods/functions, native functions are made available
from both ATen's C++ and Python APIs. In C++, they are made available
either as methods on `Tensor` (`t.mymeth()`) or as functions in the ATen
namespace (`at::myfunc()`). In PyTorch, they are made available as
methods on `Variable` or as functions on `torch._C._FunctionBase`
(it is the user's responsibility to re-exporting these functions in
methods on `Variable` or as functions on `torch._C._FunctionBase`.
(It is the user's responsibility to re-export these functions in
a more user-facing module.)
The rest of this document describes how to implement an ATen function.

View File

@ -50,7 +50,7 @@ class C10_API AllocationPlanner {
private:
AllocationPlan* allocation_plan_{nullptr};
// Maps allocated ptr to its allocation id.
// This is used when freeing the memory to lookup the allocation id
// This is used when freeing the memory to look up the allocation id
// in order to establish the lifetime of a particular allocation.
ska::flat_hash_map<const void*, uint64_t> allocation_ptr_to_id_;
uint64_t allocation_id_{0};

View File

@ -65,7 +65,7 @@ an RPC.
input tensors. The output gradients of this function are sent to the source
node to the appropriate ``send`` function during the backward pass.
- Each ``send-recv`` pair is assigned a globally unique ``autograd_message_id``
to uniquely identify the pair. This is useful to lookup the corresponding
to uniquely identify the pair. This is useful to look up the corresponding
function on a remote node during the backward pass.
- For :ref:`rref`, whenever we call :meth:`torch.distributed.rpc.RRef.to_here`
we attach an appropriate ``send-recv`` pair for the tensors involved.
@ -98,7 +98,7 @@ This context serves the following purpose:
2. During the forward pass we store the ``send`` and ``recv`` functions for
each autograd pass in this context. This ensures we hold references to the
appropriate nodes in the autograd graph to keep it alive. In addition to
this, it is easy to lookup the appropriate ``send`` and ``recv`` functions
this, it is easy to look up the appropriate ``send`` and ``recv`` functions
during the backward pass.
3. In general we also use this context to store some metadata for each
distributed autograd pass.

View File

@ -66,6 +66,7 @@ set(JIT_TEST_SRCS
${JIT_TEST_ROOT}/test_subgraph_matcher.cpp
${JIT_TEST_ROOT}/test_subgraph_rewriter.cpp
${JIT_TEST_ROOT}/test_subgraph_utils.cpp
${JIT_TEST_ROOT}/test_union.cpp
${JIT_TEST_ROOT}/test_utils.cpp
${JIT_TEST_ROOT}/test_script_profile.cpp
${JIT_TEST_ROOT}/test_jit_logging_levels.cpp

View File

@ -660,6 +660,31 @@ TEST(ContainerAliasingTest, PrimitveValuesDontAliasContainers) {
}
}
TEST(ContainerAliasingTest, UnionAliasing) {
auto graph = std::make_shared<Graph>();
parseIR(
R"IR(
graph(%a : Dict(str, Tensor),
%b : Tensor[],
%c : Union(Dict(str, Tensor), Tensor[])):
return (%a, %b, %c)
)IR",
&*graph);
AliasDb aliasDb(graph);
auto a = graph->outputs().at(0);
auto b = graph->outputs().at(1);
auto c = graph->outputs().at(2);
EXPECT_TRUE(aliasDb.mayAlias(a, c));
EXPECT_TRUE(aliasDb.mayAlias(b, c));
EXPECT_TRUE(aliasDb.mayAlias(c, c));
EXPECT_FALSE(aliasDb.mayAlias(a, b));
EXPECT_TRUE(aliasDb.mayContainAlias(a, b));
EXPECT_TRUE(aliasDb.mayContainAlias(a, c));
EXPECT_TRUE(aliasDb.mayContainAlias(b, c));
}
TEST(ContainerAliasingTest, InputsCanAliasOutputs) {
// Test input aliasing
auto graph = std::make_shared<Graph>();

test/cpp/jit/test_union.cpp (new file, 149 lines)
View File

@ -0,0 +1,149 @@
#include <gtest/gtest.h>
#include <ATen/core/jit_type.h>
#include <test/cpp/jit/test_utils.h>
#include <torch/csrc/jit/ir/ir.h>
namespace torch {
namespace jit {
class UnionTypeTest : public ::testing::Test {
public:
// None
const TypePtr none = NoneType::get();
// List[str]
const TypePtr l1 = ListType::ofStrings();
// Optional[int]
const TypePtr opt1 = OptionalType::create(IntType::get());
// Optional[float]
const TypePtr opt2 = OptionalType::create(FloatType::get());
// Optional[List[str]]
const TypePtr opt3 = OptionalType::create(ListType::ofStrings());
// Tuple[Optional[int], int]
const TypePtr tup1 =
TupleType::create({OptionalType::create(IntType::get()), IntType::get()});
// Tuple[int, int]
const TypePtr tup2 = TupleType::create({IntType::get(), IntType::get()});
bool hasType(UnionTypePtr u, TypePtr t) {
auto res = std::find(u->getTypes().begin(), u->getTypes().end(), t);
return res != u->getTypes().end();
}
};
TEST_F(UnionTypeTest, UnionOperatorEquals) {
const UnionTypePtr u1 = UnionType::create({l1, tup2, StringType::get()});
// Same thing, but using different TypePtrs
const TypePtr l1_ = ListType::ofStrings();
const TypePtr tup2_ = TupleType::create({IntType::get(), IntType::get()});
const UnionTypePtr u2 = UnionType::create({l1_, tup2_, StringType::get()});
ASSERT_TRUE(*u1 == *u2);
}
TEST_F(UnionTypeTest, UnionCreate_OptionalT1AndOptionalT2) {
// Goal: Union[int, float, None]
const UnionTypePtr u = UnionType::create({opt1, opt2});
ASSERT_EQ(u->getTypes().size(), 3);
ASSERT_TRUE(UnionTypeTest::hasType(u, IntType::get()));
ASSERT_TRUE(UnionTypeTest::hasType(u, FloatType::get()));
ASSERT_TRUE(UnionTypeTest::hasType(u, NoneType::get()));
}
TEST_F(UnionTypeTest, UnionCreate_OptionalTAndT) {
// Goal: Union[int, None]
const UnionTypePtr u = UnionType::create({opt1, IntType::get()});
ASSERT_EQ(u->getTypes().size(), 2);
ASSERT_TRUE(UnionTypeTest::hasType(u, IntType::get()));
ASSERT_TRUE(UnionTypeTest::hasType(u, NoneType::get()));
}
TEST_F(UnionTypeTest, UnionCreate_TupleWithSubtypingRelationship) {
// Goal: Union[Tuple[Optional[int], int], str]
const UnionTypePtr u = UnionType::create({StringType::get(), tup1, tup2});
ASSERT_EQ(u->getTypes().size(), 2);
ASSERT_TRUE(UnionTypeTest::hasType(u, StringType::get()));
ASSERT_TRUE(UnionTypeTest::hasType(u, tup1));
}
TEST_F(UnionTypeTest, UnionCreate_ContainerTAndT) {
// Goal: Union[List[str], str]
const UnionTypePtr u = UnionType::create({l1, StringType::get()});
ASSERT_EQ(u->getTypes().size(), 2);
ASSERT_TRUE(UnionTypeTest::hasType(u, StringType::get()));
ASSERT_TRUE(UnionTypeTest::hasType(u, ListType::ofStrings()));
}
TEST_F(UnionTypeTest, UnionCreate_OptionalContainerTAndContainerTAndT) {
// Goal: Union[List[str], None, str]
const UnionTypePtr u = UnionType::create({l1, opt3, StringType::get()});
ASSERT_EQ(u->getTypes().size(), 3);
ASSERT_TRUE(UnionTypeTest::hasType(u, StringType::get()));
ASSERT_TRUE(UnionTypeTest::hasType(u, ListType::ofStrings()));
}
TEST_F(UnionTypeTest, Subtyping_NumberType) {
// Union[int, float, Complex]
const UnionTypePtr union1 =
UnionType::create({IntType::get(), FloatType::get(), ComplexType::get()});
// Union[int, float, Complex, None]
const UnionTypePtr union2 = UnionType::create(
{IntType::get(), FloatType::get(), ComplexType::get(), NoneType::get()});
const NumberTypePtr num = NumberType::get();
ASSERT_TRUE(num->isSubtypeOf(union1));
ASSERT_TRUE(union1->isSubtypeOf(num));
ASSERT_TRUE(*num == *union1);
ASSERT_TRUE(num->isSubtypeOf(union2));
ASSERT_FALSE(union2->isSubtypeOf(num));
ASSERT_FALSE(*num == *union2);
}
TEST_F(UnionTypeTest, Subtyping_OptionalType) {
// Union[int, None]
const UnionTypePtr union1 =
UnionType::create({IntType::get(), NoneType::get()});
// Union[int, str, None]
const UnionTypePtr union2 =
UnionType::create({IntType::get(), StringType::get(), NoneType::get()});
// Union[int, str, List[str]]
const UnionTypePtr union3 = UnionType::create(
{IntType::get(), StringType::get(), ListType::ofStrings()});
ASSERT_TRUE(none->isSubtypeOf(opt1));
ASSERT_TRUE(none->isSubtypeOf(union1));
ASSERT_TRUE(none->isSubtypeOf(union2));
ASSERT_FALSE(none->isSubtypeOf(union3));
ASSERT_FALSE(opt1->isSubtypeOf(none));
ASSERT_TRUE(opt1->isSubtypeOf(union1));
ASSERT_TRUE(opt1->isSubtypeOf(union2));
ASSERT_FALSE(opt1->isSubtypeOf(union3));
ASSERT_FALSE(union1->isSubtypeOf(none));
ASSERT_TRUE(union1->isSubtypeOf(opt1));
ASSERT_TRUE(union1->isSubtypeOf(union2));
ASSERT_FALSE(union1->isSubtypeOf(union3));
ASSERT_FALSE(union2->isSubtypeOf(union1));
}
} // namespace jit
} // namespace torch

View File

@ -92,7 +92,7 @@ class TestList(JitTestCase):
if 1 == 1:
x = [1, 2, 3]
return
with self.assertRaisesRegexWithHighlight(RuntimeError, r"previously has type List\[Tensor\]", "x"):
with self.assertRaisesRegexWithHighlight(RuntimeError, r"previously had type List\[Tensor\]", "x"):
self.checkScript(reassign_from_empty_literal, (), optimize=False)
def reassign_from_empty_builtin():
@ -113,7 +113,7 @@ class TestList(JitTestCase):
if 1 == 1:
x = [1.0]
return
with self.assertRaisesRegexWithHighlight(RuntimeError, "previously has type", "x"):
with self.assertRaisesRegexWithHighlight(RuntimeError, "previously had type", "x"):
self.checkScript(reassign_bad_type, (), optimize=False)
def reassign_nested():
@ -123,7 +123,7 @@ class TestList(JitTestCase):
if 1 == 1:
x = [1.0]
return
with self.assertRaisesRegexWithHighlight(RuntimeError, "previously has type", "x"):
with self.assertRaisesRegexWithHighlight(RuntimeError, "previously had type", "x"):
self.checkScript(reassign_nested, (), optimize=False)
def test_del(self):

View File

@ -92,10 +92,9 @@ class TestTyping(JitTestCase):
graph = torch.jit.script(fn).graph
print(graph)
# Check that we're making a `List[Tuple[str, Any]]`
FileCheck().check(r"(str, Any)[] = prim::ListConstruct").run(graph)
FileCheck().check("(str, Union[Tensor, Dict(str, Tensor)])"
"[] = prim::ListConstruct()").run(graph)
def test_list_type_refinement_defaults_to_Any_list_comprehension(self):
def fn(x):
@ -116,10 +115,9 @@ class TestTyping(JitTestCase):
graph = torch.jit.script(fn).graph
print(graph)
# Check that we're making a `List[Tuple[str, Any]]`
FileCheck().check(r"(str, Any)[] = prim::ListConstruct").run(graph)
FileCheck().check("(str, Union[Tensor, Dict(str, Tensor)])"
"[] = prim::ListConstruct()").run(graph)
def test_list_type_refinement_annotation_element_mismatch(self):
def fn():
@ -145,7 +143,8 @@ class TestTyping(JitTestCase):
graph = torch.jit.script(fn).graph
FileCheck().check(r"Dict(str, Any) = prim::DictConstruct").run(graph)
FileCheck().check("Dict(str, Union[Tensor, Dict(str, Tensor)])"
" = prim::DictConstruct").run(graph)
def test_dict_type_refinement_defaults_to_Any_dict_comprehension(self):
def fn(x):
@ -161,7 +160,8 @@ class TestTyping(JitTestCase):
graph = torch.jit.script(fn).graph
FileCheck().check("Dict(str, Any) = prim::DictConstruct").run(graph)
FileCheck().check("Dict(str, Union[Tensor, Dict(str, Tensor)])"
" = prim::DictConstruct").run(graph)
def test_dict_type_refinement_annotation_key_mismatch(self):
def fn():

test/jit/test_union.py (new file, 657 lines)
View File

@ -0,0 +1,657 @@
import io
import os
import sys
import torch
from torch.testing import FileCheck
from enum import Enum
from typing import Dict, List, Optional, Tuple, Union
# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase, make_global
if __name__ == '__main__':
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead.")
class TestUnion(JitTestCase):
"""
This class tests the functionality of `Union`.
Note: It's important to be able to refine the type of a `Union` to
one of its internal types. Currently, there are differences in the
way Python expects `isinstance` checks and the way TorchScript
expects `isinstance` checks. This means that we can't use
`checkScript` in our test cases because either the eager mode or the
script mode wouldn't run! So, some test cases have separate but
equivalent functions to emulate `checkScript`.
"""
def test_union_with_scalar_values(self):
def fn(x: Union[int, float]) -> str:
return "foo"
self.checkScript(fn, (1,))
self.checkScript(fn, (1.0,))
scripted = torch.jit.script(fn)
with self.assertRaisesRegex(RuntimeError, "Expected a member of"
r" Union\[float, int\] but "
"instead found type str"):
scripted("1")
def test_union_with_collections(self):
def fn(x: Union[Dict[str, int], List[int]]) -> str:
return "foo"
self.checkScript(fn, ({"foo": 1, "bar": 2, "baz": 3},))
self.checkScript(fn, ([1, 2, 3],))
scripted = torch.jit.script(fn)
with self.assertRaisesRegex(RuntimeError, "Expected a member of"
r" Union\[List\[int\], Dict\[str, "
r"int\]\] but instead found type "
r"Dict\[str, str\]"):
scripted({"foo": "bar", "baz": "qux"})
with self.assertRaisesRegex(RuntimeError, "Expected a member of"
r" Union\[List\[int\], Dict\[str, "
r"int\]\] but instead found type "
r"List\[str\]"):
scripted(["foo", "bar", "baz"])
with self.assertRaisesRegex(RuntimeError, "Expected a member of"
r" Union\[List\[int\], Dict\[str, "
r"int\]\] but instead found type "
"str"):
scripted("1")
def test_union_with_enum(self):
class Color(Enum):
RED = 1
GREEN = 2
make_global(Color)
def fn(x: Union[str, Color]) -> str:
return "foo"
self.checkScript(fn, (Color.RED,))
self.checkScript(fn, ("red",))
scripted = torch.jit.script(fn)
with self.assertRaisesRegex(RuntimeError, "Expected a member of"
r" Union\[__torch__.jit.test_union."
r"Color, str\] but instead found "
"type int"):
scripted(1)
def test_union_in_class_constructor(self):
@torch.jit.script
class A(object): # noqa: B903
def __init__(self, x: Union[int, str]) -> None:
self.x = x
def fn(x: Union[str, int]) -> A:
return A(x)
self.assertEqual(fn("foo").x, "foo")
self.assertEqual(fn(1).x, 1)
scripted = torch.jit.script(fn)
with self.assertRaisesRegex(RuntimeError, "Expected a member of"
r" Union\[int, str\] but instead "
r"found type List\[str\]"):
scripted(["foo", "bar", "baz"])
def test_union_return_type(self):
def fn(x: int) -> Union[int, str]:
return "foo"
self.checkScript(fn, (1,))
def test_union_as_annotation(self):
def fn() -> Union[int, str]:
x: Union[int, str] = "foo"
return x
self.checkScript(fn, ())
def test_union_as_annotation_in_typed_container(self):
def fn() -> None:
l: List[Union[int, str]] = []
u1: Union[int, str] = "foo"
u2: Union[int, str] = 1
l.append(u1)
l.append(u2)
self.checkScript(fn, ())
def test_union_as_annotation_py2(self):
def fn():
# type: () -> Union[int, str]
x: Union[int, str] = "foo"
return x
self.checkScript(fn, ())
def test_union_as_internal_tuple_type(self):
def fn():
t: Tuple[Union[int, str], Union[int, str]] = (1, "foo")
return t
self.checkScript(fn, ())
def test_union_variable_can_be_reassigned(self):
@torch.jit.script
def aux1(i: int):
return int(i ** 2)
@torch.jit.script
def aux2(s: str):
return s + s
def fn() -> Union[int, str]:
x: Union[int, str] = "foo"
i: int = 1
x = i
y: int = aux1(x)
z: str = aux2(str(y))
x = z
return x
self.checkScript(fn, ())
def test_union_does_not_replace_existing_annotated_type(self):
def fn():
x: List[int] = [1, 2, 3]
x.append("foo")
return x
with self.assertRaisesRegex(RuntimeError, "Could not match type str"):
scripted = torch.jit.script(fn)
scripted()
def test_union_does_not_replace_existing_annotated_type_union(self):
def fn():
x: List[Union[int, str]] = [1, "foo", 3]
x.append(2.0)
return x
with self.assertRaisesRegex(RuntimeError, "Could not match type float"):
scripted = torch.jit.script(fn)
scripted()
def test_union_does_not_replace_existing_annotated_type_empty_container(self):
def fn():
x: List[int] = []
x.append("foo")
return x
with self.assertRaisesRegex(RuntimeError, "Could not match type str"):
scripted = torch.jit.script(fn)
scripted()
def test_unions_of_unions_are_flattened(self):
@torch.jit.script
def fn(x: Union[Union[int, str], float]) -> str:
return "foo"
s = fn.graph
FileCheck().check("x : Union[float, int, str]") \
.run(s)
def test_unions_of_a_single_argument_vanish(self):
@torch.jit.script
def fn(x: Union[int]) -> str:
return "foo"
s = fn.graph
FileCheck().check("x : int") \
.run(s)
def test_union_redundant_arguments_are_skipped(self):
@torch.jit.script
def fn(x: Union[int, str, int]) -> str:
return "foo"
s = fn.graph
FileCheck().check("x : Union[int, str]") \
.run(s)
def test_union_redundant_arguments_are_skipped_optional(self):
@torch.jit.script
def fn(x: Union[int, Optional[float], Optional[int]]) -> str:
return "foo"
s = fn.graph
FileCheck().check("x : Union[float, int, NoneType]") \
.run(s)
def test_union_redundant_arguments_are_skipped_subtyping(self):
@torch.jit.script
def fn(x: Union[str, Tuple[Optional[int], int], Tuple[int, int]]) -> str:
return "foo"
s = fn.graph
FileCheck().check("x : Union[(int?, int), str]") \
.run(s)
def test_union_redundant_arguments_are_skipped_container(self):
@torch.jit.script
def fn(x: Union[List[str], List[float], List[str]]) -> str:
return "foo"
s = fn.graph
FileCheck().check("x : Union[float[], str[]]") \
.run(s)
def test_union_argument_order_is_ignored(self):
@torch.jit.script
def fn1(x: Union[int, str]) -> str:
return "foo"
@torch.jit.script
def fn2(x: Union[str, int]) -> str:
return "foo"
for s in (fn1.graph, fn2.graph):
FileCheck().check("x : Union[int, str]") \
.run(s)
def test_union_argument_order_is_ignored_container(self):
@torch.jit.script
def fn1(x: Union[List[str], List[int]]) -> str:
return "foo"
@torch.jit.script
def fn2(x: Union[List[int], List[str]]) -> str:
return "foo"
for s in (fn1.graph, fn2.graph):
FileCheck().check("x : Union[int[], str[]]") \
.run(s)
def test_union_T_None_is_equivalent_to_optional_T(self):
@torch.jit.script
def inner(x: Union[int, None]) -> int:
if x is not None:
return x
else:
return 5
@torch.jit.script
def fn1() -> int:
a: Optional[int] = 5
b: Optional[int] = None
a_ = inner(a)
b_ = inner(b)
return a_ + b_
self.assertEqual(fn1(), 10)
@torch.jit.script
def inner2(x: Optional[int]) -> int:
if x is not None:
return x
else:
return 5
@torch.jit.script
def fn2() -> int:
a: Union[int, None] = 5
b: Union[int, None] = None
a_ = inner(a)
b_ = inner(b)
return a_ + b_
self.assertEqual(fn2(), 10)
def test_union_optional_of_union_is_flattened(self):
@torch.jit.script
def fn(flag: int) -> Union[str, int, None]:
y: Union[int, str, None] = "foo"
if flag == 0:
x: Optional[Union[int, str]] = y
elif flag == 1:
x: Optional[Union[int, str]] = 1
else:
x: Optional[Union[int, str]] = None
return x
# Can't use `checkScript` because it will flag the fact that
# the original code has `Optional[Union[int, str]]` but the
# saved/loaded code has `Union[int, NoneType, str]` (even
# though this is exactly what we want)
self.assertEqual(fn(0), "foo")
self.assertEqual(fn(1), 1)
self.assertEqual(fn(2), None)
buffer = io.BytesIO()
torch.jit.save(fn, buffer)
buffer = io.BytesIO(buffer.getvalue())
l = torch.jit.load(buffer)
s = l.code
FileCheck().check("Union[int, NoneType, str]") \
.check("Union[int, NoneType, str]") \
.run(s)
def test_union_subclasses_larger_union(self):
def fn() -> Union[int, str, torch.Tensor]:
x: Union[int, str] = "foo"
return x
self.checkScript(fn, ())
# TODO: We would like to eventually support this. The issue is being
# tracked at https://github.com/pytorch/pytorch/issues/58167
def test_union_as_dict_key(self):
def fn():
x: Dict[Union[int, str], str] = {}
x["foo"] = "bar"
x[1] = 2
return x[1]
with self.assertRaisesRegex(RuntimeError, "only int, float, "
"complex, Tensor and string keys "
"are supported"):
torch.jit.script(fn)
def test_union_as_dict_value(self):
def fn():
x: Dict[str, Union[int, str]] = {}
x["foo"] = "bar"
x["baz"] = 2
return x["baz"]
self.checkScript(fn, ())
def test_union_module_with_union_instance_variable(self):
class M(torch.nn.Module):
x: Union[int, str]
def __init__(self, x: Union[int, str]):
super().__init__()
self.x: Union[int, str] = x
def forward(self, y: Union[int, str]):
self.x = y
return self.x
self.checkModule(M(2,), (1,))
self.checkModule(M("bar"), ("foo",))
def test_union_module_with_union_class_variable(self):
class M(torch.nn.Module):
x: Union[int, str] = "foo"
def __init__(self, y: int):
super().__init__()
x = y
def forward(self, z: str):
x = z
return x
self.checkModule(M(1), ("foo",))
def test_union_type_refinement(self):
def fn(x: Union[int, str]) -> str:
if isinstance(x, str):
z = x + "bar"
return x
else:
return "baz"
self.checkScript(fn, ("foo",))
self.checkScript(fn, (1,))
def test_union_type_refinement_union_rhs(self):
def fn(x: int) -> str:
if torch.jit.isinstance(x, Union[int, str]):
return "bar"
else:
return "baz"
self.checkScript(fn, (1,))
def test_union_type_refinement_tuple_rhs(self):
def fn(x: Union[int, float, List[str]]) -> str:
if isinstance(x, (int, float)):
if isinstance(x, int):
return str(x)
else:
return "foo"
else:
if len(x):
return x[0]
else:
return "bar"
self.checkScript(fn, (1,))
self.checkScript(fn, (1.0,))
self.checkScript(fn, (["a", "b", "c"],))
def test_union_type_refinement_tuple_rhs_noncontained_type(self):
def fn(x: Union[int, List[str]]) -> str:
if isinstance(x, (int, float)):
y = x + x
return str(y)
else:
if len(x):
return x[0]
else:
return "bar"
self.checkScript(fn, (1,))
self.checkScript(fn, (["a", "b", "c"],))
def test_union_type_refinement_tuple_rhs_union(self):
@torch.jit.script
def fn(x: int) -> str:
if torch.jit.isinstance(x, (Union[int, str], float)):
y = x + x
return str(y)
else:
return "foo"
# TODO: There's currently an unrelated bug in
# `torch.jit.isinstance` that makes it fail for tuple literals.
# Posted here: https://github.com/pytorch/pytorch/issues/60095
# Change `assertEqual` to `checkScript` when the bug is fixed
self.assertEqual(fn(1), "2")
def test_union_type_refinement_statically_false(self):
@torch.jit.script
def fn(x: int) -> str:
if torch.jit.isinstance(x, (Union[str, float], List[str], str)):
z = x + "foo"
return z
else:
return "bar"
s = fn.graph
# Check that we don't have any branching statements
FileCheck().check_not("block0()") \
.check_not("block1()") \
.run(s)
def test_union_type_refinement_statically_true(self):
@torch.jit.script
def fn(x: Union[List[int], int]) -> Union[List[int], int]:
if not torch.jit.isinstance(x, (int, List[int])):
return x
else:
l = [1, 2, 3]
y: Union[List[int], int] = l
return y
s = fn.graph
# Check that we don't have any branching statements
FileCheck().check_not("block0()") \
.check_not("block1()") \
.run(s)
def test_union_type_refinement_partial_static_refinement_tuple_rhs(self):
def fn(x: Union[List[int], int]) -> int:
if torch.jit.isinstance(x, (int, float, str)):
# We should know that `x` is an `int` here
z = x + 1
return z
else:
return 100
self.checkScript(fn, ([1, 2, 3],))
self.checkScript(fn, (1,))
def test_union_type_refinement_partial_static_refinement_union_rhs(self):
def fn(x: Union[List[int], int]) -> int:
if torch.jit.isinstance(x, Union[int, float, str]):
# We should know that `x` is an `int` here
z = x + 1
return z
else:
return 100
self.checkScript(fn, ([1, 2, 3],))
self.checkScript(fn, (1,))
def test_union_type_refinement_internal_declaration(self):
def fn(flag: bool) -> str:
x: Union[int, str, None] = None
if (flag):
y = "foo"
else:
y = 1
if isinstance(x, str):
return x
else:
return "bar"
self.checkScript(fn, (True,))
self.checkScript(fn, (False,))
def test_union_branching_with_union_return_and_homogenous_types(self):
def fn(x: int) -> Union[int, str]:
if x % 2:
return "foo"
else:
return "bar"
self.checkScript(fn, (1,))
self.checkScript(fn, (8,))
def test_union_branching_does_not_autoinfer_undeclared_union(self):
def fn(x: int) -> str:
if x % 2:
y = "foo"
else:
y = x
if isinstance(y, str):
return y
else:
return "bar"
with self.assertRaisesRegex(RuntimeError, "y is set to type str"
" in the true branch and type int "
"in the false branch"):
torch.jit.script(fn)
def test_union_branching_does_not_widen_existing_inferred_type(self):
def fn(x: int) -> str:
y = "foo"
if x % 2:
y = "bar"
else:
y = x
if isinstance(y, str):
return y
else:
return "baz"
with self.assertRaisesRegex(RuntimeError, "previously had type "
"str but is now being assigned to a"
" value of type int"):
torch.jit.script(fn)
def test_union_schema_matching_on_internal_type(self):
def fn(x: Union[List[int], Dict[str, int]]) -> int:
if torch.jit.isinstance(x, List[int]):
return x[0]
else:
return list(x.values())[0]
self.checkScript(fn, ([1, 2, 3],))
self.checkScript(fn, ({"foo": 1, "bar": 2, "baz": 3},))
def test_union_subtractive_refinement(self):
def fn(x: Union[List[int], int]) -> int:
if not isinstance(x, int):
x.append(1)
return x[0]
else:
return x
self.checkScript(fn, (1,))
self.checkScript(fn, ([1, 2, 3],))
def test_union_subtractive_refinement_with_container(self):
def fn(x: Union[List[int], int]) -> int:
if not torch.jit.isinstance(x, List[int]):
return x
else:
x.append(1)
return x[0]
self.checkScript(fn, (1,))
self.checkScript(fn, ([1, 2, 3],))
def test_union_memory_aliasing(self):
def fn():
x : List[torch.Tensor] = []
z : List[Optional[List[torch.Tensor]]] = []
z.append(x)
x_alias = z[0]
if torch.jit.isinstance(x_alias, List[torch.Tensor]):
x_alias.append(torch.tensor(3))
return x
self.checkScript(fn, ())
def test_union_serialization_preserves_type_annotations(self):
# This function will fail after being torch.jit.save'd and
# torch.jit.load'd if the type annotations aren't preserved
# for Union during serialization. We need the `Union[str, int]`
# annotation to make sure that `y` is typed as a Union instead
# of as a str in one branch and an int in the other
def fn(x: int) -> str:
if x % 2:
y: Union[str, int] = "bar"
else:
y: Union[str, int] = x
if isinstance(y, str):
return y
else:
return "baz"
self.checkScript(fn, (1,))
self.checkScript(fn, (8,))

View File

@ -62,6 +62,7 @@ from jit.test_parametrization import TestParametrization # noqa: F401
from jit.test_attr import TestGetDefaultAttr # noqa: F401
from jit.test_aten_pow import TestAtenPow # noqa: F401
from jit.test_optimize_for_mobile_preserve_debug_info import TestOptimizeForMobilePreserveDebugInfo # noqa: F401
from jit.test_union import TestUnion # noqa: F401
# Torch
from torch import Tensor
@ -2518,32 +2519,6 @@ graph(%Ra, %Rb):
t = Test()
self.assertEqual(t(torch.ones(1)), torch.ones(1) + 4)
def test_union_to_optional(self):
def test1(u: Union[int, None]) -> int:
if u is not None:
return u
else:
return 0
scripted = torch.jit.script(test1)
self.assertEqual(scripted(10), test1(10))
def test2(u: Union[None, int]) -> int:
if u is not None:
return u
else:
return 0
scripted = torch.jit.script(test2)
self.assertEqual(scripted(40), test2(40))
def test3(u: Union[float, int]) -> int:
if u is not None:
return u
else:
return 0
expected_result = "General Union types are not currently supported"
with self.assertRaisesRegex(RuntimeError, expected_result):
torch.jit.script(test3)
def test_mutable_default_values(self):
with self.assertRaisesRegex(Exception, "Mutable default parameters"):
@torch.jit.script
@ -8900,6 +8875,7 @@ dedent """
torch.testing.assert_close(imported(x), x + torch.neg(torch.ones(3, 4, dtype=torch.float)))
@unittest.skipIf(not TEST_MKL, "PyTorch is built without MKL support")
@unittest.skipIf(True, "Skipping while landing PR stack")
def test_torch_functional(self):
def stft(input, n_fft):
# type: (Tensor, int) -> Tensor
@ -9809,8 +9785,9 @@ dedent """
bar()
def test_if_different_type(self):
with self.assertRaisesRegex(RuntimeError, "Type mismatch: c0 is set to type int "
"in the true branch and type float in the false branch:"):
with self.assertRaisesRegex(RuntimeError, "c0 is set to type "
"int in the true branch and type "
"float in the false branch"):
@torch.jit.script
def diff_type_used():
if 1 == 2:
@ -9819,7 +9796,7 @@ dedent """
c0 = 1.0
return c0
with self.assertRaisesRegex(RuntimeError, "Variable 'c0' previously has type float"):
with self.assertRaisesRegex(RuntimeError, "Variable 'c0' previously had type float"):
@torch.jit.script
def diff_existing_type(x):
c0 = 1.0
@ -10602,7 +10579,7 @@ dedent """
with self.assertRaisesRegex(RuntimeError, r'Expected a value of'
r' type \'List\[int\]\' for argument'
r' \'size\' but instead found type '
r'\'List\[Any\]\''):
r'\'List\[Union\[List\[int\], int\]\]'):
@torch.jit.script
def f6(a):
a.expand(size=[3, [4]])
@ -12672,7 +12649,7 @@ dedent """
for pair in self.type_input_return_pairs():
cu = torch.jit.CompilationUnit(self.format_code(code, pair))
test_str.append(str(cu.foo.schema))
self.assertExpected("\n".join(test_str))
self.assertExpected("\n".join(test_str) + "\n")
# String frontend , Python 3-style type annotations , Script method
def test_annot_string_py3_method(self):
@ -12691,7 +12668,7 @@ dedent """
tm = TestModule()
tm.define(self.format_code(code, pair))
test_str.append(str(tm.foo.schema))
self.assertExpectedStripMangled("\n".join(test_str))
self.assertExpectedStripMangled("\n".join(test_str) + "\n")
# String frontend , MyPy-style type comments , Script function
def test_annot_string_mypy_fn(self):
@ -12704,7 +12681,7 @@ dedent """
for pair in self.type_input_return_pairs():
cu = torch.jit.CompilationUnit(self.format_code(code, pair))
test_str.append(str(cu.foo.schema))
self.assertExpectedStripMangled("\n".join(test_str))
self.assertExpectedStripMangled("\n".join(test_str) + "\n")
# String frontend , MyPy-style type comments , Script method
def test_annot_string_mypy_method(self):
@ -12725,7 +12702,7 @@ dedent """
tm = TestModule()
tm.define(self.format_code(code, pair))
test_str.append(str(tm.foo.schema))
self.assertExpectedStripMangled("\n".join(test_str))
self.assertExpectedStripMangled("\n".join(test_str) + "\n")
# Python AST Frontend , Python 3-style type annotations , Script function
def test_annot_ast_py3_fn(self):
@ -12742,7 +12719,7 @@ dedent """
for pair in self.type_input_return_pairs():
fn = jit_utils._get_py3_code(self.format_code(code, pair), 'foo')
test_str.append(str(fn.schema))
self.assertExpectedStripMangled("\n".join(test_str))
self.assertExpectedStripMangled("\n".join(test_str) + "\n")
def test_multiline_annot_ast_py3_fn(self):
code = dedent('''
@ -12817,7 +12794,7 @@ dedent """
for pair in self.type_input_return_pairs():
fn = jit_utils._get_py3_code(self.format_code(code, pair), 'instance')
test_str.append(str(fn.foo.schema))
self.assertExpectedStripMangled("\n".join(test_str))
self.assertExpectedStripMangled("\n".join(test_str) + "\n")
# Python AST Frontend , MyPy-style type comments , Script function
def test_annot_ast_mypy_fn(self):
@ -12833,7 +12810,7 @@ dedent """
for pair in self.type_input_return_pairs():
fn = jit_utils._get_py3_code(self.format_code(code, pair), 'foo')
test_str.append(str(fn.schema))
self.assertExpected("\n".join(test_str))
self.assertExpected("\n".join(test_str) + "\n")
# Python AST Frontend , MyPy-style type comments , Script method
def test_annot_ast_mypy_method(self):
@ -12851,7 +12828,7 @@ dedent """
for pair in self.type_input_return_pairs():
fn = jit_utils._get_py3_code(self.format_code(code, pair), 'instance')
test_str.append(str(fn.foo.schema))
self.assertExpectedStripMangled("\n".join(test_str))
self.assertExpectedStripMangled("\n".join(test_str) + "\n")
# Tests that "# type: ignore[*]" is supported in type lines and is
# properly ignored.
@ -13521,8 +13498,8 @@ dedent """
self.checkScript(fn, ("y"))
def index_str_to_tensor(s):
# type: (str) -> int
return torch.tensor(ord(s))
# type: (str) -> Tensor
return torch.tensor(ord(s)) # noqa: T484
s = u'\u00a3'.encode('utf8')[:1]
self.checkScript(index_str_to_tensor, (s,))

View File

@ -1,5 +1,6 @@
from collections.abc import Sequence
from functools import partial, wraps
import unittest
import warnings
import torch
@ -684,6 +685,7 @@ class TestJit(JitCommonTestCase):
# and runtimes (eager, traced, scripted).
# TODO WARNING: inplace x {traced, scripted} not currently tested
@_variant_ops(op_db)
@unittest.skipIf(True, "Temporarily skipping while landing Union PR stack")
def test_variant_consistency_jit(self, device, dtype, op):
_requires_grad = op.supports_autograd and (dtype.is_floating_point or
op.supports_complex_autograd(torch.device(device).type))

View File

@ -210,6 +210,7 @@ class TestPublicBindings(unittest.TestCase):
"TupleType",
"Type",
"unify_type_list",
"UnionType",
"Use",
"Value",
"autocast_decrement_nesting",

View File

@ -1001,6 +1001,9 @@ class TupleType(JitType):
def __init__(self, a: List[Optional[JitType]]) -> None: ...
def elements(self) -> List[JitType]: ...
class UnionType(JitType):
def __init__(self, a: List[JitType]) -> None: ...
class ClassType(JitType):
def __init__(self, qualified_name: str) -> None: ...
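Given the stub above, a `Union` JIT type can presumably be built from other `JitType` instances through the `torch._C` bindings. A hedged sketch (the `IntType.get()`/`StringType.get()` accessors are assumed from the existing type bindings, not shown in this diff):

```python
import torch

# Construct Union[int, str] at the JIT-type level; per the stub, the
# constructor takes a list of JitType instances.
int_t = torch._C.IntType.get()
str_t = torch._C.StringType.get()
union_t = torch._C.UnionType([int_t, str_t])
print(union_t)
```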

View File

@ -885,33 +885,28 @@ def is_dict(ann) -> bool:
(getattr(ann, '__origin__', None) is Dict or
getattr(ann, '__origin__', None) is dict)
def is_optional(ann) -> bool:
def is_union(ann):
if ann is Union:
raise_error_container_parameter_missing("Union")
return (hasattr(ann, '__module__') and
ann.__module__ == 'typing' and
(getattr(ann, '__origin__', None) is Union))
def is_optional(ann):
if ann is Optional:
raise_error_container_parameter_missing("Optional")
# Optional[T] is just shorthand for Union[T, None], so check for both
def safe_is_subclass(the_type, super_type):
# Don't throw if `the_type` isn't a class type (e.g. if it is
# another type annotation instance)
if not inspect.isclass(the_type):
return False
return issubclass(the_type, super_type)
def is_optional_as_optional(ann):
return (hasattr(ann, '__module__') and
ann.__module__ == 'typing' and
(getattr(ann, '__origin__', None) is Optional))
if not hasattr(ann, '__module__'):
return False
def is_union_as_optional(ann):
ann_args = ann.__args__
return len(ann_args) == 2 and None in ann_args
union_optional = False
if ann.__module__ == 'typing' and \
(getattr(ann, '__origin__', None) is Union):
args = getattr(ann, '__args__', ())
if len(args) == 2:
union_optional = (safe_is_subclass(args[1], type(None)) and not safe_is_subclass(args[0], type(None))) \
or (safe_is_subclass(args[0], type(None)) and not safe_is_subclass(args[1], type(None)))
optional = ann.__module__ == 'typing' and \
(getattr(ann, '__origin__', None) is Optional)
return optional or union_optional
return is_optional_as_optional(ann) or (is_union(ann) and is_union_as_optional(ann))
def is_future(ann) -> bool:
if ann is Future:
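For context on why `is_optional` has to check both spellings, the snippet below shows how Python's `typing` module normalizes `Optional` to a two-element `Union` at runtime (plain Python, independent of TorchScript):

```python
from typing import Optional, Union

# Optional[T] is just shorthand for Union[T, None], which is why the
# helpers above accept either spelling.
assert Optional[int] == Union[int, None]
assert Union[int, None].__args__ == (int, type(None))

# A general Union with no None member is not Optional.
assert Union[int, str].__args__ == (int, str)
```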
@ -1192,14 +1187,15 @@ def container_checker(obj, target_type) -> bool:
elif not isinstance(el, el_type):
return False
return True
elif origin_type is Union: # actually handles Optional Case
elif origin_type is Union: # also handles Optional
if obj is None: # check before recursion because None is always fine
return True
optional_type = get_args(target_type)[0]
optional_origin = get_origin(optional_type)
if optional_origin:
return container_checker(obj, optional_type)
elif isinstance(obj, optional_type):
inner_types = get_args(target_type)
for t in inner_types:
t_origin = get_origin(t)
if (t_origin):
return container_checker(obj, t)
elif isinstance(obj, t):
return True
return False
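`container_checker` backs the eager-mode path of `torch.jit.isinstance`, so the new branch lets a `Union` be used directly as the runtime check target. A rough sketch of the expected behavior (the concrete values are illustrative):

```python
from typing import Dict, List, Union

import torch

# A container member match: List[int] is tried and recursed into.
assert torch.jit.isinstance([1, 2, 3], Union[List[int], Dict[str, int]])

# Simple (non-container) members fall back to plain isinstance checks.
assert torch.jit.isinstance("hi", Union[int, str])
assert not torch.jit.isinstance(3.0, Union[int, str])

# None short-circuits to True before any member checks.
assert torch.jit.isinstance(None, Union[int, None])
```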

View File

@ -792,7 +792,7 @@ In practice, the interpreter will allocate one Stack, and it will eventually rea
[runtime/operator.h](runtime/operator.h)
The Operator object represents a single registered operator in the system. It combines a FunctionSchema that describes how an Operation executes with a method to lookup the corresponding Operation given the `Node` representing the operator in a `Graph`. Most Operators are defined by providing a FunctionSchema and an Operation function. However, primitives like prim::Unpack require knowledge of their `Node` to know how to operate (e.g. how many elements to unpack). These Operators have a function that takes a `Node*` and returns an operation.
The Operator object represents a single registered operator in the system. It combines a FunctionSchema that describes how an Operation executes with a method to look up the corresponding Operation given the Node representing the operator in a Graph. Most Operators are defined by providing a FunctionSchema and an Operation function. However, primitives like prim::Unpack require knowledge of their Node to know how to operate (e.g. how many elements to unpack). These Operators have a function that takes a `Node*` and returns an operation.
## Interpreter ##
@ -1282,13 +1282,14 @@ Note the alias set `*`. This is the **wildcard set**. Optimization passes must a
This annotation language is consumed by the `FunctionSchema` parser, which produces `AliasInfo` objects summarizing the aliasing relationships for each schema `Argument`.
### Alias Analysis in the IR
[ir/alias_analysis.h](ir/alias_analysis.h)
An alias analysis pass consumes the per-operator aliasing information to construct a database of aliasing and mutation relationships in a graph, called `AliasDb`. This section focuses on the alias analysis pass; the public interface to `AliasDb` will be described later.
The core data structure in the AliasDb is called `AliasTracker`, which is a DAG where the edges are "may point to" relationships and the vertices are aliasing `Element`s. The most common kind of `Element` is an IR `Value`, but there are other kinds of things that can alias that aren't first-class `Value`s in the IR, like wildcards or contained types (such as in a list or tuple).
The core data structure in the AliasDb is called `MemoryDAG`, which is a DAG where the edges are "may point to" relationships and the vertices are aliasing `Element`s. The most common kind of `Element` is an IR `Value`, but there are other kinds of things that can alias that aren't first-class `Value`s in the IR, like wildcards or contained types (such as in a list or tuple).
The alias analysis pass walks through the nodes in a graph, examining schema `AliasInfo` objects and adding edges in the `AliasTracker` DAG accordingly. For example, for the node:
The alias analysis pass walks through the nodes in a graph, examining schema `AliasInfo` objects and adding edges in the `MemoryDAG` accordingly. For example, for the node:
```
%output : Tensor = aten::view(%self, %size)
```
@ -1321,7 +1322,7 @@ A few things to note:
The last point demonstrates a key concept: *leaf elements uniquely describe memory locations*. Since a leaf element doesn't point to anything, the memory that backs it must have been freshly allocated by some op. Thus we can use leaf elements to represent disjoint memory locations.
So to determine whether `a` and `b` may alias, we traverse the `AliasTracker` DAG and figure out if `a` and `b` share any leaf nodes. If they do, then we know `a` and `b` might point to the same memory location, i.e. `a` and `b` may alias. This kind of query is common enough that `AliasTracker` does path compression to speed up leaf-finding, so that aliasing queries can be serviced in amortized constant time.
So to determine whether `a` and `b` may alias, we traverse the `MemoryDAG` and figure out if `a` and `b` share any leaf nodes. If they do, then we know `a` and `b` might point to the same memory location, i.e. `a` and `b` may alias. This kind of query is common enough that `MemoryDAG` does path compression to speed up leaf-finding, so that aliasing queries can be serviced in amortized constant time.
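For intuition, the query described above can be sketched in a few lines of Python (an illustrative model only; the real implementation works on `MemoryDAG` elements and adds path compression on top of this):

```python
def leaves(elem, points_to):
    """Collect the leaf elements reachable from `elem` via may-point-to edges."""
    targets = points_to.get(elem, set())
    if not targets:
        # A leaf element uniquely describes a memory location.
        return {elem}
    found = set()
    for target in targets:
        found |= leaves(target, points_to)
    return found


def may_alias(a, b, points_to):
    # `a` and `b` may alias iff they can reach a common leaf.
    return bool(leaves(a, points_to) & leaves(b, points_to))
```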
### Writing optimization passes with `AliasDb`
`AliasDb` provides a high-level interface to help people write mutability-safe optimization passes.

View File

@ -93,10 +93,8 @@ struct ControlFlowLoadStores {
for (const auto& x : mutated_variables) {
auto true_type = true_vars->findInAnyFrame(x);
auto false_type = false_vars->findInAnyFrame(x);
auto unified = unifyTypes(true_type, false_type);
if (!unified) {
continue;
}
auto unified =
unifyTypes(true_type, false_type, /*default_to_union=*/true);
addBlockOutput(true_block, true_type, x);
addBlockOutput(false_block, false_type, x);

View File

@ -150,8 +150,10 @@ struct ExitTransformer {
registerBlockOutputs(if_view.thenBlock(), true_outs);
registerBlockOutputs(if_view.elseBlock(), false_outs);
for (const auto i : c10::irange(true_outs.size())) {
auto out_type =
unifyTypes(true_outs.at(i)->type(), false_outs.at(i)->type());
auto out_type = unifyTypes(
true_outs.at(i)->type(),
false_outs.at(i)->type(),
/*default_to_union=*/true);
n->addOutput()->setType(*out_type);
}
}

View File

@ -185,7 +185,9 @@ NoneStatus canBeNone(Value* v) {
if (v->node()->mustBeNone()) {
return ALWAYS;
}
if (v->type()->kind() == OptionalType::Kind) {
if (v->type()->kind() == OptionalType::Kind ||
(v->type()->kind() == UnionType::Kind &&
v->type()->expect<UnionType>()->canHoldType(NoneType::get()))) {
return MAYBE;
}
return NEVER;
@ -385,7 +387,7 @@ struct Environment {
std::stringstream why_not;
if (!as_simple_value->type()->isSubtypeOfExt(parent_type, &why_not)) {
auto error = ErrorReport(loc);
error << "Variable '" << name << "' previously has type "
error << "Variable '" << name << "' previously had type "
<< simple_parent->type()->repr_str()
<< " but is now being assigned to a value of type "
<< as_simple_value->type()->repr_str();
@ -547,6 +549,7 @@ struct Environment {
if (!retval && required) {
throwVarNotFoundError(ident, range);
}
return retval;
}
@ -1010,57 +1013,61 @@ struct to_ir {
}
void emitReturn(const Return& stmt) {
TypePtr result_type = def_stack_.back().declared_return_type_;
Value* result = emitExpr(stmt.expr(), result_type);
TypePtr declared_return_type =
def_stack_.back().declared_return_type_; // nullptr if not annotated
auto actual_return = emitExpr(stmt.expr(), declared_return_type);
// result type is annotated, every return must convert to that type
if (result_type) {
if (declared_return_type) {
// this guard skips implicit conversion from None -> Tensor for the return
// type. Otherwise, forgetting a return in a function returning a tensor will
// cause a None to be converted to a tensor.
if (!(result_type->isSubtypeOf(TensorType::get()) &&
result->type()->isSubtypeOf(NoneType::get()))) {
result = tryConvertToType(
if (!(actual_return->type()->isSubtypeOf(TensorType::get()) &&
actual_return->type()->isSubtypeOf(NoneType::get()))) {
actual_return = tryConvertToType(
stmt.range(),
*graph,
result_type,
result,
declared_return_type,
actual_return,
/*allow_conversions=*/true);
}
if (!result->type()->isSubtypeOf(result_type)) {
if (!actual_return->type()->isSubtypeOf(declared_return_type)) {
throw ErrorReport(stmt.range())
<< "Return value was annotated as having type "
<< result_type->repr_str() << " but is actually of type "
<< result->type()->repr_str();
<< declared_return_type->repr_str() << " but is actually of type "
<< actual_return->type()->repr_str();
}
} else {
result_type = def_stack_.back().merged_return_type_;
if (!result_type) {
result_type = result->type();
declared_return_type = def_stack_.back().merged_return_type_;
if (!declared_return_type) {
declared_return_type = actual_return->type();
}
auto merged_result_type = unifyTypes(result_type, result->type());
if (!merged_result_type) {
auto merged_return_type =
unifyTypes(declared_return_type, actual_return->type());
if (!merged_return_type) {
throw ErrorReport(stmt.range())
<< "Previous return statement returned a value of type "
<< result_type->repr_str()
<< declared_return_type->repr_str()
<< " but this return statement returns a value of type "
<< result->type()->repr_str();
<< actual_return->type()->repr_str();
}
result_type = merged_result_type.value();
declared_return_type = merged_return_type.value();
}
AT_ASSERT(result_type);
AT_ASSERT(declared_return_type);
def_stack_.back().merged_return_type_ = result_type;
def_stack_.back().merged_return_type_ = declared_return_type;
// If the annotated return type is Any and the result type is not Any,
// cast the result to Any to facilitate type unification between return
// statements on different code paths (e.g. different branches of an if,
// body and containing scope of a loop).
if (result_type == AnyType::get() && result->type() != AnyType::get()) {
result = graph->insertUncheckedCast(result, result_type);
if (declared_return_type == AnyType::get() &&
actual_return->type() != AnyType::get()) {
actual_return =
graph->insertUncheckedCast(actual_return, declared_return_type);
}
graph->insertNode(graph->create(prim::ReturnStmt, {result}, 0));
graph->insertNode(graph->create(prim::ReturnStmt, {actual_return}, 0));
exit_blocks.insert(environment_stack->block());
}
@ -1142,10 +1149,10 @@ struct to_ir {
return {};
}
// statement must be var {is, is not} None
auto name = Var(lhs).name().name();
// XXX - while it should in theory be possible to specialize
// the `x is None` to know x has type NoneType, we have previously not
// done this. Unfortunately, doing this will make the type None
const std::string& name = Var(lhs).name().name();
// While it should in theory be possible to specialize
// the `x is None` to know x has type NoneType, we have previously
// not done this. Unfortunately, doing this will make the type None
// propagate further in all loaded models. The handling of
// unwrap_optional will fail in these cases since export did
// not expect that the input would be none and an unannotated None.
@ -1154,7 +1161,7 @@ struct to_ir {
// and (2) only enable this OPTIONAL_NONE when loading newer
// graphs because it is incompatible with older graphs.
// Refinement none(name, RefinementKind::OPTIONAL_NONE);
if (auto optional_type = lhs_value->type()->cast<OptionalType>()) {
if (const auto optional_type = lhs_value->type()->cast<OptionalType>()) {
Refinement present(name, optional_type->getElementType());
if (tok == TK_IS) {
return RefinementSet({}, {present});
@ -1162,6 +1169,21 @@ struct to_ir {
return RefinementSet({present}, {});
}
}
if (const auto union_type = lhs_value->type()->cast<UnionType>()) {
std::vector<TypePtr> to_subtract{NoneType::get()};
c10::optional<TypePtr> remaining =
union_type->subtractTypeSet(to_subtract);
std::vector<Refinement> all_present;
if (remaining) {
Refinement present{name, *remaining};
all_present.push_back(std::move(present));
}
if (tok == TK_IS) {
return RefinementSet({}, all_present);
} else { // TK_ISNOT
return RefinementSet(all_present, {});
}
}
return RefinementSet();
}
@ -1340,7 +1362,7 @@ struct to_ir {
auto unified = unifyTypes(
lt->getElementType(),
out->type(),
/*default_to_any=*/true,
/*default_to_union=*/true,
element_type_hint);
if (lt->getElementType() != AnyType::get() &&
@ -1458,7 +1480,7 @@ struct to_ir {
c10::optional<TypePtr> unified = unifyTypes(
dt->getValueType(),
v->type(),
/*default_to_any=*/true,
/*default_to_union=*/true,
value_type_hint);
// Warn the user if we inferred the type of the values to be `Any`
@ -1755,13 +1777,32 @@ struct to_ir {
graph->createStore(x, fv)->insertBefore(false_block->return_node());
}
auto unified = unifyTypes(tv->type(), fv->type());
SugaredValuePtr maybe_sugared_x = environment_stack->findInAnyFrame(x);
TypePtr full_type = nullptr;
if (maybe_sugared_x) {
Value* maybe_simple = asSimple(maybe_sugared_x);
if (maybe_simple) {
full_type = maybe_simple->type();
}
}
// attempt to unify the types. we allow variables to be set to different
// types in each branch as long as that variable is not already in scope,
// or if that variable does not get used later. here, we save the error
// so that the error message will be more informative in the case that is
// used later. When a is accessed in (a + 1), the error will get printed
// Try to unify the types. If we found a type annotation earlier
// in the environment, and if that type annotation is some form
// of union, then we need to tell `unifyTypes` not to throw an
// error if the branched return types we found are heterogeneous
bool default_to_union = full_type &&
(full_type->kind() == UnionType::Kind ||
full_type->kind() == OptionalType::Kind ||
full_type->kind() == NumberType::Kind);
auto unified = unifyTypes(
tv->type(), fv->type(), /*default_to_union=*/default_to_union);
// We allow variables to be set to different types in each branch
// as long as that variable is not already in scope or if that
// variable does not get used later. Here, we save the error so
// that the error message will be more informative in the case
// that is used later. When `a` is accessed in `(a + 1)`, the
// error will get printed:
// if cond:
// a = 1
// else:
@ -1799,76 +1840,146 @@ struct to_ir {
}
CondValue emitIsInstance(const Expr& obj, const Expr& classinfo) {
// turn (float, (int, tuple)) into a flat list of types and type kind
// category checks: tuple_check = true, types = {float, int}
struct GatheredTypes {
GatheredTypes(ScriptTypeParser parser) : typeParser_(std::move(parser)) {}
void gather(const Expr& classinfo) {
if (classinfo.kind() == TK_TUPLE_LITERAL) {
for (Expr e : TupleLiteral(classinfo).inputs()) {
gather(e);
Value* lhs_val = emitExpr(obj);
std::vector<TypePtr> lhs_types;
std::vector<TypePtr> rhs_types;
std::function<void(const Expr&)> gather_rhs = [&](const Expr& expr) {
if (expr.kind() == TK_TUPLE_LITERAL) {
for (Expr e : TupleLiteral(expr).inputs()) {
gather_rhs(e);
}
return;
}
TypePtr type = typeParser_.parseTypeFromExpr(classinfo);
types.emplace_back(type);
}
bool staticallyTrue(const TypePtr& actual_type) {
// is this isinstance check statically true?
for (const TypePtr& typ : types) {
if (actual_type->isSubtypeOf(typ)) {
return true;
}
}
return false;
}
bool maybeOfKind(TypeKind kind, const TypePtr& actual_type) {
if (actual_type->kind() == AnyType::Kind) {
return true;
}
if (auto op = actual_type->cast<OptionalType>()) {
return op->getElementType()->kind() == kind;
}
return false;
}
bool staticallyFalse(const TypePtr& actual_type) {
for (const TypePtr& typ : types) {
if (typ->isSubtypeOf(actual_type)) {
return false;
}
if ((typ->isSubtypeOf(AnyListType::get()) &&
maybeOfKind(ListType::Kind, actual_type)) ||
(typ->isSubtypeOf(AnyTupleType::get()) &&
maybeOfKind(TupleType::Kind, actual_type))) {
return false;
}
}
return true;
}
ScriptTypeParser typeParser_;
std::vector<TypePtr> types;
TypePtr type = typeParser_.parseTypeFromExpr(expr);
rhs_types.emplace_back(type);
};
GatheredTypes gathered(typeParser_);
gathered.gather(classinfo);
auto val = emitExpr(obj);
lhs_types.push_back(lhs_val->type());
gather_rhs(classinfo);
standardizeVectorForUnion(&lhs_types);
standardizeVectorForUnion(&rhs_types);
RefinementSet refinement;
if (gathered.types.size() == 1 &&
gathered.types.at(0)->isSubtypeOf(val->type()) &&
obj.kind() == TK_VAR) {
std::string ident = Var(obj).name().name();
Refinement isinstance(std::move(ident), gathered.types.at(0));
refinement = RefinementSet({isinstance}, {});
TypePtr unified_true = nullptr;
TypePtr unified_false = nullptr;
std::vector<TypePtr> isinstance_types;
std::vector<TypePtr> not_isinstance_types;
std::vector<Refinement> true_refinements;
std::vector<Refinement> false_refinements;
bool all_lhs_subtype_some_rhs = true;
// We can discard any rhs types that we know statically would be
// impossible. For example, if we had:
//
// def fn(x: Optional[str]):
// if isinstance(x, (List[str], str, int)):
// ...
//
// then `x` would be `str` in the true branch and `None` in the
// false branch, not `(List[str], str, int)` in the true branch
// and `None` in the false branch
for (const TypePtr& lhs_type : lhs_types) {
if (lhs_type == AnyType::get()) {
isinstance_types.insert(
isinstance_types.end(), rhs_types.begin(), rhs_types.end());
not_isinstance_types.push_back(AnyType::get());
// Edge case: we can still say that all lhs types subtype some
// rhs type if `lhs` is `Any` and `rhs` is `Any`
if (isinstance_types.size() != 1 ||
isinstance_types[0] != AnyType::get()) {
all_lhs_subtype_some_rhs = false;
}
break;
}
if (gathered.staticallyTrue(val->type())) {
auto get_smaller_type = [&](TypePtr t1, TypePtr t2) -> TypePtr {
if (t1->isSubtypeOf(t2)) {
return t1;
} else if (t2->isSubtypeOf(t1)) {
return t2;
} else {
return nullptr;
}
};
TypePtr found_refinement = nullptr;
for (const TypePtr& rhs_type : rhs_types) {
TypePtr maybe_smaller_type = get_smaller_type(lhs_type, rhs_type);
if (!maybe_smaller_type) {
continue;
} else if (*maybe_smaller_type == *lhs_type) {
// Cover the case that we have something like
// lhs = `List[str]` and rhs = `list`
found_refinement = lhs_type;
} else if (*maybe_smaller_type == *rhs_type) {
// We want the narrowest possible type
found_refinement = found_refinement
? *(unifyTypes(found_refinement, rhs_type))
: rhs_type;
}
}
if (found_refinement) {
if (*found_refinement == *lhs_type) {
all_lhs_subtype_some_rhs &= true;
}
isinstance_types.push_back(found_refinement);
} else {
// If the lhs couldn't be a subtype of the rhs (or couldn't
// be "refined" to itself, as in the `List[str]` and `list`
// case above), then we add `lhs_type` to the false branch
// refinements. This is because the type can still be itself
// if the `isinstance` check is false
not_isinstance_types.push_back(lhs_type);
all_lhs_subtype_some_rhs = false;
}
}
// For use with `unifyTypeList`
std::stringstream nowhere;
// Get a single type for the true and false branches
if (!isinstance_types.empty()) {
unified_true =
*unifyTypeList(isinstance_types, nowhere, /*default_to_union=*/true);
}
if (obj.kind() == TK_VAR && unified_true) {
std::string ident = Var(obj).name().name();
true_refinements = {Refinement(ident, unified_true)};
}
// Get a single type for the true and false branches
if (!not_isinstance_types.empty()) {
unified_false = *unifyTypeList(
not_isinstance_types, nowhere, /*default_to_union=*/true);
}
if (obj.kind() == TK_VAR && unified_false) {
std::string ident = Var(obj).name().name();
false_refinements = {Refinement(ident, unified_false)};
}
refinement = RefinementSet(true_refinements, false_refinements);
bool is_statically_false = isinstance_types.empty();
// If the statement is statically true
if (all_lhs_subtype_some_rhs) {
return CondValue(*graph, obj.range(), true, std::move(refinement));
}
if (gathered.staticallyFalse(val->type())) {
if (is_statically_false) {
return CondValue(*graph, obj.range(), false, std::move(refinement));
}
// check maybe true/false at runtime, need an actual op
Value* result =
graph->insertNode(graph->createIsInstance(val, gathered.types))
graph->insertNode(graph->createIsInstance(lhs_val, rhs_types))
->output();
return CondValue(result, std::move(refinement), c10::nullopt);
}
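The pruning described in the comment above is visible from TorchScript itself: with the signature from the comment, only `str` survives against the declared `Optional[str]`, so the variable is refined to `str` in the true branch and `None` in the false branch. A minimal sketch (the branch bodies are illustrative):

```python
from typing import List, Optional

import torch

@torch.jit.script
def fn(x: Optional[str]) -> str:
    if isinstance(x, (List[str], str, int)):
        # `List[str]` and `int` are statically impossible here, so `x`
        # is refined to `str` rather than to the full tuple of types.
        return x
    # In the false branch, `x` can only be None.
    return "default"

assert fn("abc") == "abc"
assert fn(None) == "default"
```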
@ -2124,6 +2235,7 @@ struct to_ir {
}
// emit assertions as an if branch so that assertions will reuse the
// message
void emitAssert(const Assert& stmt) {
CondValue cond_value = emitCondExpr(stmt.test());
List<Stmt> true_branch = List<Stmt>::create(stmt.range(), {});
@ -2979,7 +3091,9 @@ struct to_ir {
// after annotation so that variables assigned to this None will still
// get the right type. To do this, we make a None constant that
// has the type Optional[T]
if (type->kind() == OptionalType::Kind &&
if ((type->kind() == OptionalType::Kind ||
(type->kind() == UnionType::Kind &&
type->expect<UnionType>()->canHoldType(NoneType::get()))) &&
expr->type()->isSubtypeOf(NoneType::get())) {
Node* none = graph->createNone();
none->output()->setType(type);
@ -3435,8 +3549,9 @@ struct to_ir {
size_t n_binders,
const TypePtr& type_hint = nullptr) {
switch (tree.kind()) {
case TK_VAR:
case TK_VAR: {
return environment_stack->getSugaredVar(Var(tree).name());
}
case '.': {
auto select = Select(tree);
auto sv = emitSugaredExpr(select.value(), 1);
@ -3710,7 +3825,7 @@ struct to_ir {
type_hint ? type_hint->expect<ListType>()->getElementType() : nullptr;
c10::optional<TypePtr> unified = unifyTypeList(
types, nowhere, /*default_to_any=*/true, element_type_hint);
types, nowhere, /*default_to_union=*/true, element_type_hint);
if (!type_hint && *unified == AnyType::get()) {
TORCH_WARN(
@ -3881,7 +3996,7 @@ struct to_ir {
c10::optional<TypePtr> unified = unifyTypeList(
types,
/*why_not=*/nowhere,
/*default_to_any=*/true,
/*default_to_union=*/true,
value_type_hint);
if (!type_hint && *unified == AnyType::get()) {

View File

@ -8,9 +8,10 @@
namespace torch {
namespace jit {
// try to match a list of inputs and keyword 'attributes' to this schema,
// if it works return the flat list of positional inputs to the call
// if it returns nullopt, then failure_messages contains a good error report
// Try to match a list of inputs and keyword 'attributes' to this
// schema. Return the flat list of positional inputs to the call or
// `c10::nullopt` on failure (`failure_messages` contains a good error
// report in this case)
struct MatchedSchema {
std::vector<Value*> inputs;

View File

@ -32,6 +32,7 @@ using c10::StringType;
using c10::Symbol;
using c10::TensorType;
using c10::TupleType;
using c10::UnionType;
using c10::VarType;
namespace torch {
@ -331,6 +332,18 @@ std::pair<TypePtr, c10::optional<AliasInfo>> SchemaTypeParser::parseType() {
L.expect(')');
alias_info = parseAliasAnnotation();
value = DictType::create(key_type, value_type);
} else if (L.cur().kind == TK_IDENT && L.cur().text() == "Union") {
L.next();
L.expect('(');
std::vector<TypePtr> types;
types.emplace_back(parseType().first);
while (L.cur().kind != ')') {
L.expect(',');
types.emplace_back(parseType().first);
}
L.expect(')');
alias_info = parseAliasAnnotation();
value = UnionType::create(types);
} else if (
complete_tensor_types && L.cur().kind == TK_IDENT &&
parseTensorDType(L.cur().text())) {

View File

@ -42,7 +42,7 @@ TypePtr ScriptTypeParser::subscriptToType(
}
std::vector<TypePtr> subscript_expr_types;
for (auto expr : subscript.subscript_exprs()) {
subscript_expr_types.push_back(parseTypeFromExprImpl(expr));
subscript_expr_types.emplace_back(parseTypeFromExprImpl(expr));
}
return TupleType::create(subscript_expr_types);
} else if (typeName == "List" || typeName == "list") {
@ -65,6 +65,13 @@ TypePtr ScriptTypeParser::subscriptToType(
parseTypeFromExprImpl(*subscript.subscript_exprs().begin());
return OptionalType::create(elem_type);
} else if (typeName == "Union") {
std::vector<TypePtr> subscript_expr_types;
subscript_expr_types.reserve(subscript.subscript_exprs().size());
for (auto expr : subscript.subscript_exprs()) {
subscript_expr_types.emplace_back(parseTypeFromExprImpl(expr));
}
return UnionType::create(subscript_expr_types);
} else if (typeName == "Future" || typeName == "torch.jit.Future") {
if (subscript.subscript_exprs().size() != 1) {
throw ErrorReport(subscript)
@ -83,30 +90,6 @@ TypePtr ScriptTypeParser::subscriptToType(
auto elem_type =
parseTypeFromExprImpl(*subscript.subscript_exprs().begin());
return RRefType::create(elem_type);
} else if (typeName == "Union") {
// In Python 3.9+, Union[NoneType, T] or Union[T, NoneType] are
// treated as Optional[T]. Adding the same support for Union in Torchscript.
const char* const err =
"General Union types are not currently supported."
" Only Union[T, NoneType] (i.e. Optional[T]) is "
"supported.";
if (subscript.subscript_exprs().size() != 2) {
throw ErrorReport(subscript) << (err);
}
auto first_type = parseTypeFromExprImpl(subscript.subscript_exprs()[0]);
auto second_type = parseTypeFromExprImpl(subscript.subscript_exprs()[1]);
bool first_none = first_type == NoneType::get();
bool second_none = second_type == NoneType::get();
if (first_none && !second_none) {
return OptionalType::create(second_type);
} else if (!first_none && second_none) {
return OptionalType::create(first_type);
} else {
throw ErrorReport(subscript.range()) << err;
}
} else if (typeName == "Dict" || typeName == "dict") {
if (subscript.subscript_exprs().size() != 2) {
throw ErrorReport(subscript)

View File

@ -13,94 +13,139 @@ namespace jit {
namespace {
// For any mutable type, map it to a type such that all other types which it can
// alias will be mapped to the same type. This function follows a similar logic
// to `unifyTypes` because any two mutable types which can be unified
// can alias each other.
// getMutableTypePtr(Optional[List[int]]) == getMutableTypePtr([List[int]])
// If a type is not mutable, return nullopt
// This class helps convert types to their mutable equivalent by looking up
// cached conversions.
TypePtr toSingleType(AliasTypeSet& mut_types) {
return mut_types.size() == 1 ? mut_types[0]
: c10::UnionType::create(mut_types);
}
// This class determines whether a type is mutable, and, if so, it maps
// the type to its "mutable equivalent" (see definition in
// `mapTypeToAliasTypeSet`). It uses a cache of TypePtrs to speed up these
// type lookups
class MutableTypePtrHelper {
public:
explicit MutableTypePtrHelper(
std::unordered_map<TypePtr, TypePtr>* mutable_type_cache)
std::unordered_map<TypePtr, AliasTypeSet>* mutable_type_cache)
: mutable_type_cache_(mutable_type_cache) {}
c10::optional<TypePtr> getMutableType(const TypePtr& type) {
// Map any mutable type to a type such that all other types which the
// mutable type can alias will be mapped to the same type. For
// example, calling this method on `Optional[List[int]]` should be
// the same as calling this method on `List[int]`.
//
// Rules:
// - If the type is not mutable, return `nullopt`
// - If the type is a `Tuple`, that means that it's an immutable
// object that can itself contain mutable objects. We want to make
// sure that the mutable objects are correctly aliased, so we
// remove the immutable objects. (For example,
// `Tuple[int, Tensor]` would become `Tuple[Tensor]`, while
// `Tuple[int, str]` would be returned as `nullopt`.) This is a
// convenience that makes it easy to check if the `Tuple`
// contains only immutable objects, though it's not technically
// necessary
// - For any Tensor type (including Tensor types that are part of
// a larger container, e.g. `List[Tensor]`), return the
// "unshaped" version of that Tensor. An "unshaped" Tensor is a
// Tensor with shape information removed. For example, a Tensor
// of dimension 4 would map to the same type as a Tensor of
// dimension 1. This allows us to treat all subclasses of Tensor
// as a single, homogenous "Tensor" type.
c10::optional<AliasTypeSet> mapTypeToAliasTypeSet(const TypePtr& type) {
if (mutable_type_cache_) {
auto maybe_type = mutable_type_cache_->find(type);
if (maybe_type != mutable_type_cache_->end()) {
return maybe_type->second;
auto maybe_type_mapping = mutable_type_cache_->find(type);
if (maybe_type_mapping != mutable_type_cache_->end()) {
return maybe_type_mapping->second;
}
}
auto mutable_type = getMutableTypeImpl(type);
if (mutable_type_cache_ && mutable_type) {
mutable_type_cache_->emplace(type, *mutable_type);
auto mutable_types = mapTypeToAliasTypeSetImpl(type);
if (mutable_type_cache_ && mutable_types) {
mutable_type_cache_->emplace(type, *mutable_types);
}
return mutable_type;
return mutable_types;
}
private:
c10::optional<TypePtr> getMutableTypeImpl(const TypePtr& type) {
c10::optional<AliasTypeSet> mapTypeToAliasTypeSetImpl(const TypePtr& type) {
switch (type->kind()) {
case TypeKind::ListType:
case TypeKind::DictType:
case TypeKind::ClassType:
case TypeKind::TensorType:
// TODO: lookup cached contained types. this is kind of tricky
// because a List[Optional[T]] should still be
// List[Optional[Unshaped(T)]], however the getMutableType(Optional[T])
// == T
return unshapedType(type);
case TypeKind::OptionalType:
return getMutableType(type->castRaw<OptionalType>()->getElementType());
// TODO: Look up cached contained types. this is kind of tricky
// because a `List[Optional[T]]` should still be
// `List[Optional[Unshaped(T)]]`, but
// `mapTypeToAliasTypeSet(Optional[T])` should be `T`
return AliasTypeSet{unshapedType(type)};
case TypeKind::UnionType: {
AliasTypeSet mutable_types;
for (const TypePtr& inner :
type->expect<UnionType>()->containedTypes()) {
if (auto maybe_inner_types = mapTypeToAliasTypeSet(inner)) {
mutable_types.insert(
mutable_types.end(),
(*maybe_inner_types).begin(),
(*maybe_inner_types).end());
}
}
if (mutable_types.size() == 0) {
return c10::nullopt;
}
return mutable_types;
}
case TypeKind::OptionalType: {
auto inner = type->castRaw<OptionalType>()->getElementType();
return mapTypeToAliasTypeSet(inner);
}
case TypeKind::AnyType:
return type;
return {AliasTypeSet{type}};
case TypeKind::FutureType: {
if (auto elem =
getMutableType(type->castRaw<FutureType>()->getElementType())) {
return FutureType::create(*elem);
if (auto maybe_mut_types = mapTypeToAliasTypeSet(
type->castRaw<FutureType>()->getElementType())) {
auto mut_type = toSingleType(*maybe_mut_types);
return {AliasTypeSet{FutureType::create(mut_type)}};
}
return c10::nullopt;
}
case TypeKind::TupleType: {
std::vector<TypePtr> mutable_types;
for (const auto& elem : type->expectRef<TupleType>().elements()) {
if (auto mut_elem = getMutableType(elem)) {
mutable_types.push_back(*mut_elem);
for (const TypePtr& inner : type->expectRef<TupleType>().elements()) {
if (auto maybe_inner_types = mapTypeToAliasTypeSet(inner)) {
mutable_types.insert(
mutable_types.end(),
(*maybe_inner_types).begin(),
(*maybe_inner_types).end());
}
}
if (mutable_types.size() == 0) {
return c10::nullopt;
} else {
return TupleType::create(mutable_types);
}
return {AliasTypeSet{TupleType::create(mutable_types)}};
}
default:
return c10::nullopt;
}
}
std::unordered_map<TypePtr, TypePtr>* mutable_type_cache_;
std::unordered_map<TypePtr, AliasTypeSet>* mutable_type_cache_;
};
bool isMutableTypeImpl(
const TypePtr& type,
std::unordered_map<TypePtr, TypePtr>* mutable_type_cache) {
// check common cases to avoid recursively constructing type in
// getMutableTypePtrImpl
std::unordered_map<TypePtr, AliasTypeSet>* mutable_type_cache) {
// Check common cases to avoid recursively constructing type in
// `mapTypeToAliasTypeSetImpl`
auto kind = type->kind();
if (kind == TypeKind::TensorType || kind == TypeKind::ListType ||
kind == TypeKind::ClassType || kind == TypeKind::DictType) {
return true;
}
MutableTypePtrHelper helper(mutable_type_cache);
return helper.getMutableType(type) != c10::nullopt;
return helper.mapTypeToAliasTypeSet(type) != c10::nullopt;
}
} // namespace
// static isMutableType does not use cache of type -> mutable type equivalent
// Static `isMutableType` does not use cache of type -> mutable type equivalent
bool AliasDb::isMutableType(const TypePtr& type) {
return isMutableTypeImpl(type, nullptr);
}
@ -109,7 +154,7 @@ bool AliasDb::isMutableType(const Value* v) {
return isMutableType(v->type());
}
// makes use of type -> mutable cache
// Make use of type -> mutable cache
bool AliasDb::isMutableTypeInternal(const TypePtr& type) const {
return isMutableTypeImpl(type, &mapped_mutable_types_);
}
@ -118,21 +163,17 @@ bool AliasDb::isMutableTypeInternal(const Value* v) const {
return isMutableTypeInternal(v->type());
}
c10::optional<TypePtr> AliasDb::getMutableTypePtr(const TypePtr& type) const {
c10::optional<AliasTypeSet> AliasDb::mapTypeToAliasTypeSetPtr(
const TypePtr& type) const {
MutableTypePtrHelper helper(&mapped_mutable_types_);
return helper.getMutableType(type);
}
bool AliasDb::isContainerType(const TypePtr& type) const {
auto mut_type = getMutableTypePtr(type);
return mut_type && (*mut_type)->containedTypes().size() > 0;
return helper.mapTypeToAliasTypeSet(type);
}
AliasDb::~AliasDb() = default;
// Structure used during analysis to keeps track of all writes at a high level.
// When analysis is completed this will be used to construct a more efficient
// WriteIndex.
// Structure used during analysis to keep track of all writes at a high
// level. When the analysis is completed, this will be used to construct
// a more efficient WriteIndex
struct AliasDb::WriteRegistry {
void registerWrite(const Value* v, Node* n) {
writes_[n].emplace_back(v);
@ -170,7 +211,7 @@ AliasDb::AliasDb(std::shared_ptr<Graph> graph, bool isFrozen)
writeIndex_ = TWriteIndex();
auto& writeIndex = *writeIndex_; // to make operator[] less ugly
// build the write index
// Build the write index
for (const auto& write : writeRegistry_->writes_) {
Node* node = write.first;
const std::vector<const Value*> writtenValues = write.second;
@ -207,7 +248,7 @@ AliasDb::AliasDb(std::shared_ptr<Graph> graph, bool isFrozen)
// out of sync (since we have no way of registering new writes)
writeRegistry_ = nullptr;
// initialize the write cache
// Initialize the write cache
buildWrittenToLocationsIndex();
GRAPH_DEBUG(toString());
}
@ -324,10 +365,10 @@ MemoryLocations AliasDb::getReads(Node* n) const {
std::string AliasDb::getElementName(const Element* e) const {
if (e->values.empty()) {
// not the most efficient way, but given the fact there are
// Not the most efficient way, but given the fact there are
// not too many types and even fewer of them will end up in
// wildcardIndex_, we should be fine with a linear search
// each time we hit a wildcard leaf
// `wildcardIndex_`, we should be fine with a linear search
// each time we hit a Wildcard leaf
for (const auto& ent : wildcardIndex_) {
if (ent.second == e) {
return std::string("WILDCARD for type ") + ent.first->str();
@ -362,17 +403,27 @@ std::string AliasDb::toString() const {
ss << "\n===2. ALIAS DB===\n";
for (const auto& ptrPair : elementMap_) {
const auto element = ptrPair.second;
int ct = 0;
if (!element->pointsTo.empty()) {
ss << getElementName(element) << " points to: ";
for (const auto pointedTo : element->pointsTo) {
ss << getElementName(memoryDAG_->fromIndex(pointedTo)) << ", ";
if (ct > 0) {
ss << ", ";
}
++ct;
ss << getElementName(memoryDAG_->fromIndex(pointedTo));
}
ss << "\n";
}
ct = 0;
if (!element->containedElements.empty()) {
ss << getElementName(element) << " contains: ";
for (const auto contained : element->containedElements) {
ss << getElementName(memoryDAG_->fromIndex(contained)) << ", ";
ss << getElementName(memoryDAG_->fromIndex(contained));
if (ct > 0) {
ss << ", ";
}
++ct;
}
ss << "\n";
}
@ -839,8 +890,7 @@ void AliasDb::analyzeLoop(Node* node) {
TORCH_INTERNAL_ASSERT(blockOutputs.size() == node->outputs().size());
// Run alias analysis on the loop body, iterating until the block output
// alias info converges.
// Copy node input aliases to block input
// alias info converges. Copy node input aliases to block input
mapAliases(blockInputs, loopCarriedInputs);
// Populate block output alias info by analyzing the body
@ -996,7 +1046,7 @@ bool AliasDb::functionalNonEscapingListUse(const Use& use) const {
return false;
}
// List or dict or tuple: construct: create an aliasing element for the actual
// List or dict or tuple construct: create an aliasing element for the actual
// container, then mark all inputs as wildcards, since they've gone inside the
// container. Then, add the wildcard sets of appropriate type to the contained
// elements of the container.
@ -1073,52 +1123,50 @@ void AliasDb::makePointerTo(const Value* from, const Value* to) {
return;
}
// the contained types of immutable type containers (optional, tuple, future)
// are unified, so these types can be mutable or immutable
// and point to a type which is mutable or immutable.
// Any is mutable but can point to an immutable type through refinement
// The contained types of immutable type containers (`Optional`,
// `Tuple`, `Future`, and `Union`) are unified, so these types can be
// mutable or immutable and point to a type which is mutable or
// immutable. `Any` is mutable but can point to an immutable type
// through refinement
if (isMutableTypeInternal(from) != isMutableTypeInternal(to)) {
bool expected_kind = false;
for (auto kind : {from->type()->kind(), to->type()->kind()}) {
expected_kind = expected_kind ||
(kind == TypeKind::OptionalType || kind == TypeKind::FutureType ||
kind == TypeKind::TupleType) // immutable type containers
kind == TypeKind::TupleType ||
kind == TypeKind::UnionType) // immutable type containers
|| kind == TypeKind::AnyType;
}
TORCH_INTERNAL_ASSERT(
expected_kind, from->type()->str(), to->type()->str());
return;
}
// both immutable
if (!isMutableTypeInternal(from)) {
return;
}
if (from == to) {
return;
}
// At this point, we are dealing with two mutable types.
auto fromEl = getOrCreateElement(from);
auto toEl = getOrCreateElement(to);
// At this point, we are dealing with two mutable types
auto from_el = getOrCreateElement(from);
auto to_el = getOrCreateElement(to);
memoryDAGBuilder_->makePointerTo(fromEl, toEl);
memoryDAGBuilder_->makePointerTo(from_el, to_el);
}
void AliasDb::addToContainedElements(
const Value* elem,
const Value* inner,
const Value* container) {
if (!isMutableTypeInternal(elem)) {
if (!isMutableTypeInternal(inner)) {
return;
}
TORCH_INTERNAL_ASSERT(isContainerType(container->type()));
auto inner_el = getOrCreateElement(inner);
auto cont_el = getOrCreateElement(container);
auto elemEl = getOrCreateElement(elem);
auto contEl = getOrCreateElement(container);
memoryDAGBuilder_->addToContainedElements(elemEl, contEl);
memoryDAGBuilder_->addToContainedElements(inner_el, cont_el);
}
bool AliasDb::mayAlias(const Value* a, const Value* b) const {
@ -1203,8 +1251,8 @@ void AliasDb::createValue(const Value* value) {
void AliasDb::giveFreshAlias(
const Value* value,
bool add_wildcard_to_contained_elems) {
auto maybe_mut_type = getMutableTypePtr(value->type());
if (!maybe_mut_type) {
auto maybe_mut_types = mapTypeToAliasTypeSetPtr(value->type());
if (!maybe_mut_types) {
return;
}
@ -1217,7 +1265,11 @@ void AliasDb::giveFreshAlias(
auto new_elem = memoryDAGBuilder_->makeFreshValue(value);
elementMap_[value] = new_elem;
if (add_wildcard_to_contained_elems) {
addContainedTypesToFreshElement(new_elem, *maybe_mut_type);
if ((*maybe_mut_types).size() > 1) {
pointUnionTypeElementToAllContainedTypes(new_elem, *maybe_mut_types);
} else {
addContainedTypesToFreshElement(new_elem, *maybe_mut_types);
}
}
}
@ -1639,56 +1691,86 @@ bool AliasDb::mayAliasWildcard(const at::ArrayRef<Value*> vs) const {
}
c10::optional<Element*> AliasDb::tryGetOrCreateWildcard(const TypePtr& type) {
auto updated_type = getMutableTypePtr(type);
if (!updated_type) {
auto maybe_mut_types = mapTypeToAliasTypeSetPtr(type);
if (!maybe_mut_types) {
return c10::nullopt;
}
auto mapped_type = *updated_type;
auto existing_wildcard = wildcardIndex_.find(mapped_type);
auto mut_type = toSingleType(*maybe_mut_types);
auto existing_wildcard = wildcardIndex_.find(mut_type);
if (existing_wildcard != wildcardIndex_.end()) {
return existing_wildcard->second;
}
auto wildcard_elem = memoryDAGBuilder_->makeFreshValue(nullptr);
wildcardIndex_.emplace(mapped_type, wildcard_elem);
addContainedTypesToFreshElement(wildcard_elem, mapped_type);
wildcardIndex_.emplace(mut_type, wildcard_elem);
if ((*maybe_mut_types).size() > 1) {
pointUnionTypeElementToAllContainedTypes(wildcard_elem, *maybe_mut_types);
} else {
addContainedTypesToFreshElement(wildcard_elem, *maybe_mut_types);
}
return wildcard_elem;
}
void AliasDb::pointUnionTypeElementToAllContainedTypes(
Element* container_elem,
const AliasTypeSet& mut_types) {
for (const auto& mut_type : mut_types) {
auto maybe_elem = tryGetOrCreateWildcard(mut_type);
if (maybe_elem) {
TORCH_INTERNAL_ASSERT(*maybe_elem != container_elem);
memoryDAGBuilder_->makePointerTo(container_elem, *maybe_elem);
}
}
}
void AliasDb::addContainedTypesToFreshElement(
Element* container_elem,
const TypePtr& mut_type) {
const AliasTypeSet& mut_types) {
for (const auto& mut_type : mut_types) {
for (const auto& contained : mut_type->containedTypes()) {
auto maybe_elem = tryGetOrCreateWildcard(contained);
if (maybe_elem) {
memoryDAGBuilder_->addToContainedElements(*maybe_elem, container_elem);
}
}
}
}
// Search the wildcard index for an element that corresponds to the given type.
// Const version returns nullptr
Element* AliasDb::getWildcard(const TypePtr& type) const {
auto maybe_mut_type = getMutableTypePtr(type);
if (!maybe_mut_type) {
return nullptr;
auto maybe_mut_types = mapTypeToAliasTypeSetPtr(type);
if (!maybe_mut_types) {
return {};
}
TypePtr mut_type = *maybe_mut_type;
auto wildcard = wildcardIndex_.find(mut_type);
if (wildcard != wildcardIndex_.end()) {
return wildcard->second;
if ((*maybe_mut_types).size() > 1) {
auto union_type = UnionType::create(*maybe_mut_types);
// Get a <TypePtr, Element*> pair where the TypePtr is this Union
// type and the Element is the corresponding Wildcard
auto maybe_union_pair = wildcardIndex_.find(union_type);
if (maybe_union_pair != wildcardIndex_.end()) {
return (*maybe_union_pair).second;
}
return nullptr;
} else {
// Get a <TypePtr, Element*> pair where the TypePtr is the given
// type and the Element is the corresponding Wildcard
auto type_pair = wildcardIndex_.find((*maybe_mut_types)[0]);
if (type_pair != wildcardIndex_.end()) {
return type_pair->second;
}
}
return {};
}
// Register `v` as a wildcard value.
c10::optional<Element*> AliasDb::setWildcard(const Value* v) {
auto maybe_wildcardElement = tryGetOrCreateWildcard(v->type());
c10::optional<Element*> maybe_wildcardElement =
tryGetOrCreateWildcard(v->type());
if (!maybe_wildcardElement) {
return c10::nullopt;
}
// Ensure that we create a corresponding element for `v` still, as it is an
// invariant that all mutable values have an element.
// Ensure that we create a corresponding Element for `v` still, as it is an
// invariant that all mutable values have an Element
getOrCreateElement(v);
wildcards_.insert(v);
return *maybe_wildcardElement;

View File

@ -34,6 +34,12 @@ namespace jit {
* Values that contain other mutable types, such as List[Tensor], are
* initialized as containing the Wildcard set for all contained mutable types.
*
* The AliasDb API references the idea of "mutable" vs "immutable"
* types. "Mutable" means that the object's value can change, while
* "immutable" means that the value is fixed. (For example, `List` is
* mutable, so you can add and delete elements from it. On the other
* hand, you can't modify a Tuple once you create it, making `Tuple` an
* immutable container.)
*/
class AliasDb {
public:
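At the TorchScript level, the mutable/immutable split described in the comment above looks roughly like this (a hedged illustration in the spirit of the aliasing test added earlier in this PR):

```python
from typing import List, Tuple

import torch

@torch.jit.script
def mutable_vs_immutable() -> Tuple[List[int], Tuple[int, int]]:
    xs: List[int] = [1, 2]
    alias = xs           # both names may point to the same list
    alias.append(3)      # the mutation is visible through either name
    t = (1, 2)           # a tuple cannot be changed after construction
    return xs, t

assert mutable_vs_immutable() == ([1, 2, 3], (1, 2))
```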
@ -95,7 +101,7 @@ class AliasDb {
const at::ArrayRef<Value*>& a,
const at::ArrayRef<Value*>& b) const;
// Move 'n' (already in the graph) after 'movePoint' in the topological order.
// Move `n` (already in the graph) after `movePoint` in the topological order.
//
// Tries to preserve value dependencies, so other nodes might be moved. We
// make two guarantees about the postcondition of the node list:
@ -125,6 +131,10 @@ class AliasDb {
TORCH_API bool dumpToGraphvizFile(const char* filename) const;
TORCH_API std::string toGraphviz() const;
// Returns `true` if the given element is mutable or if it is a
// container type with an internal mutable element (e.g.
// `Tuple[int, Tensor]` has an internal mutable type `Tensor`, so
// it would be considered a "mutable type" in AliasDb)
static bool isMutableType(const Value* v);
static bool isMutableType(const TypePtr& type);
@ -181,7 +191,7 @@ class AliasDb {
// Register `v` as a wildcard value.
c10::optional<Element*> setWildcard(const Value* v);
// Is this a value which will not alias
// Is this a value which will not alias?
bool nonAliasingValue(const Value* elem) const;
/**
@ -221,11 +231,10 @@ class AliasDb {
bool add_wildcard_to_contained_elems = true);
Element* getOrCreateElement(const Value* value);
c10::optional<TypePtr> getMutableTypePtr(const TypePtr& type) const;
c10::optional<AliasTypeSet> mapTypeToAliasTypeSetPtr(
const TypePtr& type) const;
bool functionalNonEscapingListUse(const Use& use) const;
bool isContainerType(const TypePtr& type) const;
std::shared_ptr<Graph> graph_;
// If the Module is frozen then consider attributes as freshly created
@ -239,21 +248,24 @@ class AliasDb {
// Mapping of values to MemoryDAG elements
ska::flat_hash_map<const Value*, Element*> elementMap_;
// All wildcard elements (one for each unique mutable type).
// All wildcard Elements (one for each unique mutable type)
std::unordered_map<TypePtr, Element*, HashType, EqualType> wildcardIndex_;
Element* getWildcard(const TypePtr& type) const;
c10::optional<Element*> tryGetOrCreateWildcard(const TypePtr& type);
void addContainedTypesToFreshElement(
Element* container_elem,
const TypePtr& mut_type);
const AliasTypeSet& mut_types);
void pointUnionTypeElementToAllContainedTypes(
Element* container_elem,
const AliasTypeSet& mut_types);
std::vector<Element*> getElements(at::ArrayRef<Value*> vs) const;
bool mayAliasWildcard(const Value* v) const;
bool mayAliasWildcard(const at::ArrayRef<Value*> vs) const;
bool hasWriters(const at::ArrayRef<Value*>& values) const;
// cached mapping of type ptrs to their mutable types
mutable std::unordered_map<TypePtr, TypePtr> mapped_mutable_types_;
// Cached mapping of type ptrs to their mutable types
mutable std::unordered_map<TypePtr, AliasTypeSet> mapped_mutable_types_;
/**
* State for tracking write info.

View File

@ -511,7 +511,7 @@ void Graph::lint() const {
// - Params and return do NOT occur in nodes
// - next_unique_ is greater than all uniques in graph
// - uniques in all_nodes are unique
// - every use will occur later in the topsort
// - every use will occur later in the toposort
struct LintScope {
LintScope() = default;
@ -787,7 +787,9 @@ bool Value::mustBeNone() const {
}
bool Value::mustNotBeNone() const {
return node_->kind() != prim::AutogradAdd && type() != NoneType::get() &&
!type()->cast<OptionalType>();
!type()->cast<OptionalType>() &&
!(type()->cast<UnionType>() &&
type()->expect<UnionType>()->canHoldType(NoneType::get()));
}
std::string Value::debugNameBase() const {
@ -1765,20 +1767,23 @@ Node* Graph::createEnumValue(Value* e) {
return n;
}
Node* Graph::createList(const TypePtr& elem_type, at::ArrayRef<Value*> values) {
Node* Graph::createList(
const TypePtr& contained_type,
at::ArrayRef<Value*> values) {
auto n = create(prim::ListConstruct, values);
for (const auto& v : values) {
TORCH_CHECK(
v->type()->isSubtypeOf(elem_type),
v->type()->isSubtypeOf(contained_type),
"Expected a list element that subtypes '",
elem_type->repr_str(),
contained_type->repr_str(),
"' but got an element of type '",
v->type()->repr_str(),
"'");
}
n->output()->setType(ListType::create(elem_type));
n->output()->setType(ListType::create(contained_type));
return n;
}
Node* Graph::createListUnpack(Value* v, size_t size) {
ListTypePtr list_type = v->type()->expect<ListType>();
TypePtr elem_type = list_type->getElementType();

View File

@ -84,7 +84,7 @@ using namespace ::c10::cuda;
struct Function;
struct MatchedSchema;
// Graph represents one "function" of computation.
// A Graph represents one "function" of computation.
// It uses a simple ownership model where the graph owns all the nodes inside
// it. All references inside the graph are raw pointers. Destroying the Graph
// will invalidate any pointers to nodes in the graph.
@ -104,9 +104,9 @@ TORCH_API std::ostream& operator<<(std::ostream& out, const Node& n);
// A list of nodes, with inputs and outputs
struct Block;
// Each use is represented by this type, see Node::uses()
// 'user' is the consumer of the value, offset is the index into
// 'user's input this where the produces will be found.
// Each use is represented by this type, see 'Node::uses()'
// 'user' is the consumer of the value, 'offset' is the index into
// 'user's inputs where the producer will be found.
struct Use {
Use(Node* user, size_t offset) : user(user), offset(offset) {}
Node* user;
@ -338,14 +338,16 @@ struct TORCH_API Node {
protected:
Node(Graph* graph_, NodeKind kind_); // defined after graph
public:
// each node but Return/Param
// is associated with exactly one place in the node list...
// of the graph_
// this circular is a doubly-linked list, the Return node is used as the
// sentinel for the beginning and end of the list such that the list never has
// null pointers next_in_graph[0] is next pointer next_in_graph[1] is prev
// pointer using an array to allow the same iterator class for forward and
// reverse node lists This list represents a topological sort
// Each Node but Return/Param Nodes are associated with exactly one
// place in the Node list of the Graph. The Graph itself is a circular
// doubly-linked list. The Return Node is used as the sentinel for the
// "beginning"/"end" of the list. This means that you can tell when
// you've traversed the entire list without worrying about null
// pointers. `next_in_graph[0]` is the pointer to the next Node, while
// `next_in_graph[1]` is the pointer to the previous Node. The
// linked list is implemented as an array to allow the same iterator
// class for forward and reversed Node lists. Taken together, this
// list also represents a topological sort of the Nodes in the Graph.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-non-private-member-variables-in-classes,modernize-avoid-c-arrays)
Node* next_in_graph[2] = {nullptr, nullptr};
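The topological order maintained by this intrusive list is what the Python-side `Graph.nodes()` iterator walks; a quick sketch:

```python
import torch

@torch.jit.script
def f(x: torch.Tensor) -> torch.Tensor:
    y = x + x
    return y * 2

# Iterates the circular doubly-linked Node list described above, in
# topological order (the Param/Return nodes themselves are not yielded).
for node in f.graph.nodes():
    print(node.kind())
```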
@ -980,7 +982,6 @@ struct TORCH_API Node {
// subclasses should extend if they have additional information to copy.
// 'this' will be allocated with s->allocNewInstance(g) so it should have
// the same concrete type as 's'
//
virtual void cloneFrom(Node* s);
};
@ -1247,7 +1248,7 @@ struct Graph {
TORCH_API Node* createEnumName(Value* e);
TORCH_API Node* createEnumValue(Value* e);
TORCH_API Node* createList(
const TypePtr& elem_type,
const TypePtr& contained_type,
at::ArrayRef<Value*> values);
TORCH_API Node* createListUnpack(Value* v, size_t size);
TORCH_API Node* createDict(

View File

@ -42,6 +42,17 @@ class TypeParser {
return simpleTypeIt->second;
} else if (token == "List") {
return CreateSingleElementType<ListType>();
} else if (token == "Union") {
std::vector<TypePtr> types;
expect("[");
while (cur() != "]") {
types.emplace_back(parse());
if (cur() != "]") {
expect(",");
}
}
expect("]");
return UnionType::create(types);
} else if (token == "Optional") {
return CreateSingleElementType<OptionalType>();
} else if (token == "Future") {

View File

@ -288,6 +288,24 @@ class ShapePropagator {
return zerodim;
}
bool mergeTypes(
ArrayRef<Value*> lhs,
ArrayRef<Value*> rhs,
ArrayRef<Value*> outputs) {
AT_ASSERT(lhs.size() == rhs.size() && rhs.size() == outputs.size());
bool changed = false;
for (size_t i = 0; i < lhs.size(); ++i) {
auto old_output_type = outputs[i]->type();
auto new_type =
unifyTypes(lhs[i]->type(), rhs[i]->type(), /*default_to_union=*/true);
AT_ASSERT(new_type);
outputs[i]->setType(*new_type);
if (*old_output_type != *outputs[i]->type())
changed = true;
}
return changed;
}
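`mergeTypes` unifies the output types coming out of two branches and, with `default_to_union=true`, falls back to a Union rather than failing. At the surface level this is what lets a variable annotated as a Union take different member types in different branches; a minimal sketch:

```python
from typing import Union
import torch

@torch.jit.script
def pick(flag: bool) -> Union[int, str]:
    # The two branches assign different member types; both unify into the
    # declared Union[int, str].
    x: Union[int, str] = 0
    if flag:
        x = 1
    else:
        x = "one"
    return x

print(pick(True), pick(False))
```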
void broadcastBinary(
Node* node,
std::vector<TensorTypePtr>& types,

View File

@ -8,6 +8,7 @@
namespace torch {
namespace jit {
namespace {
void makePointerToImpl(Element* from, Element* to) {
from->pointsTo.set(to->index);
to->pointedFrom.set(from->index);
@ -131,11 +132,13 @@ Element* MemoryDAGBuilder::makeFreshValue(const Value* v) {
return makeFreshValueImpl(v, indexToElementMap_);
}
// This function builds up a bitset representing the "alias set" for
// `e` (`MemoryLocations` is just a typedef'd c10::SparseBitVector).
const MemoryLocations& MemoryDAG::getMemoryLocations(const Element* e) const {
// Note on cache invalidation: all mutation should occur through
// MemoryDAGBuilder. Thus, once we consume the builder to create an immutable
// MemoryDAG, we can cache here without worrying that we might potentially get
// invalidated.
// MemoryDAGBuilder. Thus, once we consume the builder to create an
// immutable MemoryDAG, we can cache here without worrying that we
// might potentially get invalidated.
if (e->cachedMemoryLocations_) {
return *e->cachedMemoryLocations_;
}
@ -174,7 +177,6 @@ void MemoryDAG::setWildcards(
makePointerToImpl(from, wildcardElement);
}
}
// Track which memory locations we edited with a new pointer to the wildcard
// element.
cacheUpdates[wildcardElement] |= pointeeSet;
@ -189,7 +191,6 @@ void MemoryDAG::setWildcards(
for (const std::unique_ptr<Element>& e : this->indexToElementMap_) {
if (e->values.empty()) {
// This element is a wildcard element, we can skip it.
TORCH_INTERNAL_ASSERT(e->pointsTo.empty());
continue;
}

View File

@ -1,9 +1,12 @@
#pragma once
#include <ATen/core/jit_type.h>
#include <c10/util/ArrayRef.h>
#include <c10/util/Optional.h>
#include <c10/util/flat_hash_map.h>
#include <c10/util/sparse_bitset.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/type_hashing.h>
#include <memory>
#include <unordered_map>
#include <unordered_set>
@ -20,6 +23,9 @@ struct Element;
struct Value;
class MemoryDAG;
using TypePtr = std::shared_ptr<c10::Type>;
using AliasTypeSet = std::vector<TypePtr>;
/**
* Helper to build up the points-to graph.
*
@ -38,13 +44,15 @@ class TORCH_API MemoryDAGBuilder {
void addToContainedElements(Element* contained, Element* container);
// Make a fresh element (i.e. an element that doesn't point to anything) and
// Make a fresh Element (i.e. an Element that doesn't point to anything) and
// return it.
Element* makeFreshValue(const Value* v);
friend MemoryDAG;
private:
// `MemoryDAGBuilder` builds up `indexToElementMap_`, then uses
// the map to construct the `MemoryDAG`
std::vector<std::unique_ptr<Element>> indexToElementMap_;
};
@ -54,8 +62,8 @@ class TORCH_API MemoryDAGBuilder {
// AliasDb to provide a higher-level API.
//
// We maintain a DAG where:
// - Vertices (called "elements") represent values and
// other aliasing entities (e.g. like the stuff inside a list)
// - Vertices (called "Elements") represent Values and
// other aliasing entities (e.g. the stuff inside a list)
// - Edges represent a "points-to" relationship.
//
// Leaves in this DAG are entities that don't point to anything, and thus
@ -80,7 +88,7 @@ class TORCH_API MemoryDAG {
bool mayAlias(const Element* a, const Element* b) const;
bool mayAlias(Element* a, Element* b) const;
// Does a hold reference to any memory that is stored in elem, or vice versa?
// Does `a` hold a reference to any memory that is stored in `b`, or vice versa?
bool mayContainAlias(const Element* a, const Element* b) const;
bool mayContainAlias(Element* a, Element* b) const;
@ -96,12 +104,13 @@ class TORCH_API MemoryDAG {
MemoryLocations& cont) const;
/**
* The following methods are special cases where we need to reach mutate the
* The following methods are special cases where we need to mutate the
* internals of MemoryDAG for efficiency reasons. Don't call them unless you
* know what you're doing! In particular, don't add new mutating methods
* without ensuring that you are maintaining cache consistency for memory
* locations.
*/
// Adding wildcards can trigger extremely expensive cache invalidations. This
// method adds them in a more efficient cache-aware way.
void setWildcards(
@ -117,9 +126,10 @@ class TORCH_API MemoryDAG {
std::vector<std::unique_ptr<Element>> indexToElementMap_;
};
// `Element` represents the vertex in the points-to graph. It represents
// anything that could have an aliasing relationship, mostly IR `Value`s, but
// also the "inside of a list", or wildcards.
// `Element` represents a vertex in the points-to graph. It represents
// anything that could have an aliasing relationship--mostly IR
// `Value`s, but also wildcards or the type inside a container (e.g. `T`
// in `List[T]`)
struct Element {
Element(const Value* value_, unsigned index_);
// wildcard constructor
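To make the points-to structure described in the comments above concrete, here is a toy sketch with plain Python sets standing in for the sparse bitsets (the names and shapes are illustrative only, not the C++ API):

```python
# Toy points-to DAG: elements point to other elements; leaves act as memory
# locations, and an element's "alias set" is the set of reachable leaves.
points_to = {
    "a": {"mem1"},
    "b": {"mem1", "mem2"},   # b may point to either location
    "c": {"mem2"},
    "mem1": set(),           # leaf: a memory location
    "mem2": set(),           # leaf: a memory location
}

def memory_locations(elem):
    reached, stack = set(), [elem]
    while stack:
        e = stack.pop()
        targets = points_to[e]
        if not targets:
            reached.add(e)   # a leaf is its own memory location
        stack.extend(targets)
    return reached

def may_alias(x, y):
    return bool(memory_locations(x) & memory_locations(y))

print(may_alias("a", "b"))   # True: both can reach mem1
print(may_alias("a", "c"))   # False: disjoint leaf sets
```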

View File

@ -89,6 +89,19 @@ IValue toIValue(py::handle obj, const TypePtr& type, c10::optional<int32_t> N) {
? c10::ivalue::Tuple::createNamed(std::move(values), tuple_type)
: c10::ivalue::Tuple::create(std::move(values));
}
case TypeKind::UnionType: {
auto actual_type = toTypeInferredIValue(obj);
auto actual_type_ptr = actual_type.type();
auto union_type = type->expect<UnionType>();
if (!actual_type_ptr->isSubtypeOf(union_type)) {
throw py::cast_error(c10::str(
"Expected a member of ",
union_type->annotation_str(),
" but instead found type ",
actual_type.type()->annotation_str()));
}
return actual_type;
}
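From Python, this case is what accepts any member of a Union parameter and rejects everything else (the exact exception text may differ from the sketch below):

```python
from typing import Union
import torch

@torch.jit.script
def label(x: Union[int, str]) -> str:
    if isinstance(x, int):
        return "got an int"
    return "got a str"

print(label(3))      # int is a member of Union[int, str]
print(label("hi"))   # so is str
# label(1.5) would be rejected by the check above, with an error naming the
# expected Union and the actual (float) type.
```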
case TypeKind::StringType:
return ConstantString::create(py::cast<std::string>(obj));
case TypeKind::DeviceObjType: {

View File

@ -869,6 +869,12 @@ void initPythonIRBindings(PyObject* module_) {
}
return types;
});
py::class_<UnionType, Type, std::shared_ptr<UnionType>>(m, "UnionType")
.def(py::init(
[](const std::vector<TypePtr>& a) { return UnionType::create(a); }))
.def("containedTypes", [](UnionType& self) {
return self.containedTypes().vec();
});
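A small sketch of using the binding registered above from Python (output formatting is approximate):

```python
import torch
from torch._C import UnionType, IntType, StringType, NoneType

# Three members keep this a true Union (a two-member Union that includes
# NoneType may collapse to an Optional).
u = UnionType([IntType.get(), StringType.get(), NoneType.get()])
print(u)                    # e.g. Union[int, str, NoneType]
print(u.containedTypes())   # the member types, via the binding above
```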
py::class_<ListType, Type, std::shared_ptr<ListType>>(m, "ListType")
.def(py::init([](TypePtr a) { return ListType::create(a); }))
.def_static("ofInts", &ListType::ofInts)

View File

@ -47,7 +47,8 @@ void postSetStateValidate(const IValue& v) {
// const auto attrType = objType->getAttribute(i);
// Verify that all the non-optional attributes have been initialized
// TODO: Issue #20497
if (attrType->kind() != TypeKind::OptionalType &&
if (attrType->kind() != TypeKind::UnionType &&
attrType->kind() != TypeKind::OptionalType &&
attrType->kind() != TypeKind::NoneType) {
TORCH_CHECK(
!slot.isNone(),

View File

@ -482,12 +482,13 @@ void SourceImporterImpl::importClass(
} break;
case TK_DEF: {
Def def = Def(statement);
if (pre_hook_names.find(def.name().name()) != pre_hook_names.end()) {
pre_hook_def_map.emplace(def.name().name(), def);
pre_hook_resolver_map.emplace(def.name().name(), shared_from_this());
} else if (hook_names.find(def.name().name()) != hook_names.end()) {
hook_def_map.emplace(def.name().name(), def);
hook_resolver_map.emplace(def.name().name(), shared_from_this());
const auto def_name = def.name().name();
if (pre_hook_names.find(def_name) != pre_hook_names.end()) {
pre_hook_def_map.emplace(def_name, def);
pre_hook_resolver_map.emplace(def_name, shared_from_this());
} else if (hook_names.find(def_name) != hook_names.end()) {
hook_def_map.emplace(def_name, def);
hook_resolver_map.emplace(def_name, shared_from_this());
} else {
methods.emplace_back(def);
method_resolvers.push_back(shared_from_this());

View File

@ -511,14 +511,32 @@ struct PythonPrintImpl {
}
indent();
printValueList(body_, lhs);
// We need to preserve Union/Optional type annotations, but only if
// we're not assigning values as part of a tuple unpacking statement
// (Python doesn't allow type annotations in multiple assignment)
if (lhs.size() == 1) {
Value* v = lhs.at(0);
if (!annotated_unions_.count(v) && !expr_table_.count(v) &&
(v->type()->kind() == UnionType::Kind ||
v->type()->kind() == OptionalType::Kind)) {
body_ << " : " << v->type()->annotation_str();
annotated_unions_.insert(v);
}
}
body_ << " = ";
// or if value is being assigned to something of a union type
printValueList(body_, rhs);
body_ << "\n";
}
bool requiresAnnotation(Value* lhs, Value* rhs) {
if (lhs->type()->kind() == UnionType::Kind ||
lhs->type()->kind() == OptionalType::Kind) {
return annotated_unions_.insert(lhs).second;
} else {
return *lhs->type() != *rhs->type();
}
}
void printAnnotatedAssignment(
at::ArrayRef<Value*> lhs,
@ -1302,10 +1320,12 @@ struct PythonPrintImpl {
body_ << arg_name;
if (print_first_argument_type) {
body_ << ": " << arg.type()->annotation_str(type_printer_);
annotated_unions_.insert(*param_it);
}
} else {
body_ << ",\n " << arg_name << ": "
<< arg.type()->annotation_str(type_printer_);
annotated_unions_.insert(*param_it);
}
if (arg.default_value()) {
printDefaultValue(arg, body_, *arg.default_value());
@ -1559,6 +1579,12 @@ struct PythonPrintImpl {
// table.
PrintDepsTable& deps_table_;
// We need to preserve Union/Optional type annotations, but we should
// only print the annotation on variable declaration (not on any
// following uses). This set tracks the Value*s that we've already
// printed with annotations
std::unordered_set<Value*> annotated_unions_;
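The effect is visible when printing scripted code: the Union/Optional annotation appears once, on the declaration, so a round trip does not silently narrow the variable's type. A quick check (the exact printed text may differ):

```python
from typing import Optional
import torch

@torch.jit.script
def f(flag: bool) -> Optional[int]:
    x: Optional[int] = None
    if flag:
        x = 1
    return x

# The serialized source keeps the `Optional[int]` annotation on the first
# assignment to `x`, per the annotated_unions_ bookkeeping above.
print(f.code)
```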
// A function that, given a named type, returns us the correct string to print
// for it.
c10::TypePrinter type_printer_;

View File

@ -23,8 +23,8 @@ static void restoreAccurateTypeTagsIfPossible(const IValue& root) {
// Pickled objects are stored in a form compatible with Python pickling.
// In torchscript List[T]/Dict[K, V] are statically typed and contain
// dynamic type tags allow T, K, and V to be recovered. But this info
// is not stored in the Python pickling information. However, we
// dynamic type tags that allow T, K, and V to be recovered. But this
// info is not stored in the Python pickling information. However, we
// can recover this information from the static type of the top-level
// object being unpickled, because we have a record of the type of the
// objects it contains as attributes.
@ -108,6 +108,19 @@ void restoreAccurateTypeTags(const IValue& root, const TypePtr& type_tag) {
to_process.emplace_back(std::move(elem));
}
} break;
case UnionType::Kind: {
auto t = w.static_type->expect<UnionType>();
if (t->containedTypes().size() == 2 &&
t->canHoldType(NoneType::get())) {
if (!w.value.isNone()) {
auto inner = t->containedTypes()[0] != NoneType::get()
? t->containedTypes()[0]
: t->containedTypes()[1];
Work elem = {inner, w.value};
to_process.emplace_back(std::move(elem));
}
}
} break;
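A hypothetical Python rendering of this case, for readability (the names are stand-ins, not the unpickler's API): a two-member Union that can hold `None` behaves like an `Optional`, so non-`None` values are re-tagged with the inner type.

```python
# Hypothetical sketch of the Union branch above.
def refine_optional_like_union(contained_types, can_hold_none, value, to_process):
    if len(contained_types) == 2 and can_hold_none and value is not None:
        inner = (contained_types[0]
                 if contained_types[0] != "NoneType"
                 else contained_types[1])
        to_process.append((inner, value))   # keep walking with the inner type
```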
case ListType::Kind: {
// specialized lists do not need their type refined, so we can exit
// early here

View File

@ -449,7 +449,7 @@ if _enabled:
setattr(RecursiveScriptClass, method_name, method_template)
# this is a Python 'non-data descriptor' that causes the first access
# to ScriptModule's forward to lookup the forward method and stash
# to ScriptModule's forward to look up the forward method and stash
# it in the object's dict. Due to the standard rules for attribute lookup,
# subsequent lookups will just directly return the previously looked up method.
# This is necessary because nn.Module defines forward as a method. If we

View File

@ -6,13 +6,13 @@ import builtins
import torch
import warnings
from .._jit_internal import List, Tuple, is_tuple, is_list, Dict, is_dict, Optional, \
is_optional, _qualified_name, Any, Future, is_future, is_ignored_fn
is_optional, _qualified_name, Any, Future, is_future, is_ignored_fn, Union, is_union
from .._jit_internal import BroadcastingList1, BroadcastingList2, BroadcastingList3 # type: ignore[attr-defined]
from ._state import _get_script_class
from torch._C import TensorType, TupleType, FloatType, IntType, ComplexType, \
ListType, StringType, DictType, BoolType, OptionalType, InterfaceType, AnyType, NoneType, \
DeviceObjType, StreamObjType, FutureType, EnumType
ListType, StringType, DictType, BoolType, OptionalType, InterfaceType, AnyType, \
NoneType, DeviceObjType, StreamObjType, FutureType, EnumType, UnionType
from textwrap import dedent
@ -45,7 +45,8 @@ class EvalEnv(object):
'List': List,
'Dict': Dict,
'Optional': Optional,
'Future': Future,
'Union': Union,
'Future': Future
}
def __init__(self, rcb):
@ -245,6 +246,9 @@ def split_type_line(type_line):
def try_real_annotations(fn, loc):
"""Tries to use the Py3.5+ annotation syntax to get the type."""
try:
# Note: anything annotated as `Optional[T]` will automatically
# be returned as `Union[T, None]` per
# https://github.com/python/typing/blob/master/src/typing.py#L850
sig = inspect.signature(fn)
except ValueError:
return None
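The normalization mentioned in the note is easy to observe directly:

```python
import inspect
from typing import Optional, Union

def g(x: Optional[int] = None):
    return x

ann = inspect.signature(g).parameters["x"].annotation
print(ann == Union[int, None])   # True: Optional[int] is just Union[int, None]
print(ann == Optional[int])      # True: they are the same typing object
```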
@ -276,7 +280,6 @@ def get_enum_value_type(e: Type[enum.Enum], loc):
return torch._C.unify_type_list(ir_types)
def is_tensor(ann):
if issubclass(ann, torch.Tensor):
return True
@ -326,6 +329,19 @@ def try_ann_to_type(ann, loc):
msg = "Unsupported annotation {} could not be resolved because {} could not be resolved."
assert valid_type, msg.format(repr(ann), repr(contained))
return OptionalType(valid_type)
if is_union(ann):
inner: List = []
# We need these extra checks because both `None` and invalid
# values will return `None`
# TODO: Determine if the other cases need to be fixed as well
for a in ann.__args__:
if a is None:
inner.append(NoneType.get())
maybe_type = try_ann_to_type(a, loc)
msg = "Unsupported annotation {} could not be resolved because {} could not be resolved."
assert maybe_type, msg.format(repr(ann), repr(a))
inner.append(maybe_type)
return UnionType(inner) # type: ignore[arg-type]
if torch.distributed.rpc.is_available() and is_rref(ann):
return RRefType(try_ann_to_type(ann.__args__[0], loc))
if is_future(ann):
@ -390,6 +406,8 @@ __all__ = [
'is_list',
'Dict',
'is_dict',
'is_optional',
'is_union',
'TensorType',
'TupleType',
'FloatType',

View File

@ -452,6 +452,7 @@ def get_default_args(fn):
return {}
signature = inspect.signature(fn)
return {
k: v.default
for k, v in signature.parameters.items()