mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Support Union in TorchScript (#64234)
Summary: This PR is created to replace https://github.com/pytorch/pytorch/pull/53180 PR stack, which has all the review discussions. Reason for needing a replacement is due to a messy Sandcastle issue. Pull Request resolved: https://github.com/pytorch/pytorch/pull/64234 Reviewed By: gmagogsfm Differential Revision: D30656444 Pulled By: ansley fbshipit-source-id: 77536c8bcc88162e2c72636026ca3c16891d669a
This commit is contained in:
committed by
Facebook GitHub Bot
parent
91b926fab3
commit
6831d8e379
@ -435,12 +435,12 @@ is `./build/bin/FILENAME --gtest_filter=TESTSUITE.TESTNAME`, where
|
||||
`TESTNAME` is the name of the test you'd like to run and `TESTSUITE` is
|
||||
the suite that test is defined in.
|
||||
|
||||
For example, if you wanted to run the test ` MayContainAlias`, which
|
||||
For example, if you wanted to run the test `MayContainAlias`, which
|
||||
is part of the test suite `ContainerAliasingTest` in the file
|
||||
`test/cpp/jit/test_alias_analysis.cpp`, the command would be:
|
||||
|
||||
```bash
|
||||
./build/bin/test_jit --gtest_filter=ContainerAliasingTest.UnionAliasing
|
||||
./build/bin/test_jit --gtest_filter=ContainerAliasingTest.MayContainAlias
|
||||
```
|
||||
|
||||
|
||||
|
@ -30,6 +30,9 @@ struct FunctionSchema;
|
||||
struct NamedType;
|
||||
using OptNameList = c10::optional<std::vector<std::string>>;
|
||||
|
||||
void standardizeVectorForUnion(std::vector<TypePtr>& reference, std::vector<TypePtr>* to_fill);
|
||||
void standardizeVectorForUnion(std::vector<TypePtr>* to_flatten);
|
||||
|
||||
struct AnyType;
|
||||
using AnyTypePtr = std::shared_ptr<AnyType>;
|
||||
// Any is the top of the type hierarchy, all other types are subtypes
|
||||
@ -94,25 +97,84 @@ struct SingleElementType : public Type {
|
||||
TypePtr elem;
|
||||
};
|
||||
|
||||
struct UnionType;
|
||||
using UnionTypePtr = std::shared_ptr<UnionType>;
|
||||
struct TORCH_API UnionType : public Type {
|
||||
friend struct Type;
|
||||
|
||||
static const TypeKind Kind = TypeKind::UnionType;
|
||||
|
||||
bool isSubtypeOfExt(const TypePtr& rhs_, std::ostream* why_not) const override;
|
||||
|
||||
std::string str() const override;
|
||||
|
||||
static UnionTypePtr create(std::vector<TypePtr> reference);
|
||||
|
||||
bool operator==(const Type& rhs) const override;
|
||||
|
||||
at::ArrayRef<TypePtr> containedTypes() const override {
|
||||
return types_;
|
||||
}
|
||||
|
||||
// For testing purposes only
|
||||
at::ArrayRef<TypePtr> getTypes() const {
|
||||
return types_;
|
||||
}
|
||||
|
||||
TypePtr createWithContained(std::vector<TypePtr> contained_types) const override {
|
||||
return create(contained_types);
|
||||
}
|
||||
|
||||
bool canHoldType(TypePtr type) const;
|
||||
|
||||
bool hasFreeVariables() const override {
|
||||
return has_free_variables_;
|
||||
}
|
||||
|
||||
c10::optional<TypePtr> toOptional() const;
|
||||
|
||||
c10::optional<TypePtr> subtractTypeSet(std::vector<TypePtr>& to_subtract) const;
|
||||
|
||||
protected:
|
||||
explicit UnionType(std::vector<TypePtr> types, TypeKind kind=TypeKind::UnionType);
|
||||
std::string annotation_str_impl(TypePrinter printer = nullptr) const override;
|
||||
std::string unionStr(TypePrinter printer = nullptr, bool is_annotation_str = false) const;
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
|
||||
bool has_free_variables_;
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
|
||||
std::vector<TypePtr> types_;
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
|
||||
bool can_hold_none_;
|
||||
|
||||
};
|
||||
|
||||
struct OptionalType;
|
||||
using OptionalTypePtr = std::shared_ptr<OptionalType>;
|
||||
// This type represents an optional type, for each element type.
|
||||
// Optional[T] can accept both T and None(nullopt in C++)
|
||||
// This type represents an optional type. There is one `Optional` for
|
||||
// each element type. `Optional[T]` can accept both `T` and
|
||||
// `None`(`c10::nullopt` in C++)
|
||||
// Subtype hierarchy for Optional:
|
||||
// 1. Optional[T] <: Optional[R] iff T <: R
|
||||
// 2. T <: Optional[R] if T <: R
|
||||
// 3. None <: Optional[T] for all T
|
||||
struct TORCH_API OptionalType
|
||||
: public SingleElementType<TypeKind::OptionalType, OptionalType> {
|
||||
static OptionalTypePtr create(TypePtr element) {
|
||||
TORCH_INTERNAL_ASSERT(element, "OptionalType requires valid TypePtr");
|
||||
// Optional is a union of [None, T], so Optional[[Optional[T]]] ->
|
||||
// Optional[T]
|
||||
if (auto opt_ptr = element->cast<OptionalType>()) {
|
||||
return opt_ptr;
|
||||
// - Optional[T] <: Optional[R] iff T <: R
|
||||
// - T <: Optional[R] if T <: R
|
||||
// - None <: Optional[T] for all T
|
||||
// - Optional[T] == Union[T, None] for all T
|
||||
struct TORCH_API OptionalType : public UnionType {
|
||||
static OptionalTypePtr create(TypePtr contained) {
|
||||
return OptionalTypePtr(new OptionalType(std::move(contained)));
|
||||
}
|
||||
return OptionalTypePtr(
|
||||
new OptionalType(std::move(element))); // NOLINT(modernize-make-shared)
|
||||
|
||||
static const TypeKind Kind = TypeKind::OptionalType;
|
||||
|
||||
friend struct Type;
|
||||
|
||||
bool operator==(const Type& rhs) const override;
|
||||
|
||||
TypePtr getElementType() const {
|
||||
return contained_;
|
||||
}
|
||||
|
||||
at::ArrayRef<TypePtr> containedTypes() const override {
|
||||
return contained_;
|
||||
}
|
||||
|
||||
std::string str() const override {
|
||||
@ -127,20 +189,15 @@ struct TORCH_API OptionalType
|
||||
return create(contained_types[0]);
|
||||
}
|
||||
|
||||
bool isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const override {
|
||||
if (Type::isSubtypeOfExt(rhs, why_not)) {
|
||||
return true;
|
||||
}
|
||||
if (auto rhs_ = rhs->cast<OptionalType>()) {
|
||||
return getElementType()->isSubtypeOfExt(rhs_->getElementType(), why_not);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
bool isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const override;
|
||||
|
||||
// common cast Optional[Tensor] for undefined tensor type
|
||||
static OptionalTypePtr ofTensor();
|
||||
|
||||
private:
|
||||
OptionalType(TypePtr elem) : SingleElementType(elem) {}
|
||||
explicit OptionalType(TypePtr contained);
|
||||
|
||||
TypePtr contained_;
|
||||
|
||||
std::string annotation_str_impl(TypePrinter printer = nullptr) const override {
|
||||
std::stringstream ss;
|
||||
@ -908,7 +965,6 @@ struct TORCH_API RRefType
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
struct NamedType;
|
||||
using NamedTypePtr = std::shared_ptr<NamedType>;
|
||||
using ConstNamedTypePtr = std::shared_ptr<const NamedType>;
|
||||
@ -1112,7 +1168,6 @@ struct TORCH_API EnumType : public NamedType {
|
||||
std::weak_ptr<::torch::jit::CompilationUnit> cu_;
|
||||
};
|
||||
|
||||
|
||||
// the common supertype of all Enums, only used in operator registraion.
|
||||
// EnumType <: AnyEnumType for all Enums
|
||||
struct AnyEnumType;
|
||||
@ -1132,7 +1187,6 @@ private:
|
||||
: Type(TypeKind::AnyEnumType) {}
|
||||
};
|
||||
|
||||
|
||||
struct NumberType;
|
||||
using NumberTypePtr = std::shared_ptr<NumberType>;
|
||||
// This type represents a Python number
|
||||
@ -1141,9 +1195,10 @@ using NumberTypePtr = std::shared_ptr<NumberType>;
|
||||
// FloatType <: NumberType
|
||||
// ComplexType <:NumberType
|
||||
struct TORCH_API NumberType : public Type {
|
||||
bool operator==(const Type& rhs) const override {
|
||||
return rhs.kind() == kind();
|
||||
}
|
||||
bool operator==(const Type& rhs) const override;
|
||||
|
||||
bool isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const override;
|
||||
|
||||
std::string str() const override {
|
||||
return "Scalar"; // match what PythonArgParser says for clarity
|
||||
}
|
||||
@ -1172,7 +1227,8 @@ struct TORCH_API FloatType : public NumberType {
|
||||
return "float";
|
||||
}
|
||||
bool isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const override {
|
||||
return rhs->kind() == TypeKind::NumberType || NumberType::isSubtypeOfExt(rhs, why_not);
|
||||
// NOLINTNEXTLINE(bugprone-parent-virtual-call)
|
||||
return rhs->kind() == TypeKind::NumberType || Type::isSubtypeOfExt(rhs, why_not);
|
||||
}
|
||||
static const TypeKind Kind = TypeKind::FloatType;
|
||||
// global singleton
|
||||
@ -1196,7 +1252,8 @@ struct TORCH_API ComplexType : public NumberType {
|
||||
return "complex";
|
||||
}
|
||||
bool isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const override {
|
||||
return rhs->kind() == TypeKind::NumberType || NumberType::isSubtypeOfExt(rhs, why_not);
|
||||
// NOLINTNEXTLINE(bugprone-parent-virtual-call)
|
||||
return rhs->kind() == TypeKind::NumberType || Type::isSubtypeOfExt(rhs, why_not);
|
||||
}
|
||||
static const TypeKind Kind = TypeKind::ComplexType;
|
||||
// global singleton
|
||||
@ -1220,7 +1277,8 @@ struct TORCH_API IntType : public NumberType {
|
||||
return "int";
|
||||
}
|
||||
bool isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const override {
|
||||
return rhs->kind() == TypeKind::NumberType || NumberType::isSubtypeOfExt(rhs, why_not);
|
||||
// NOLINTNEXTLINE(bugprone-parent-virtual-call)
|
||||
return rhs->kind() == TypeKind::NumberType || Type::isSubtypeOfExt(rhs, why_not);
|
||||
}
|
||||
static const TypeKind Kind = TypeKind::IntType;
|
||||
// global singleton
|
||||
@ -1334,12 +1392,8 @@ struct TORCH_API NoneType : public Type {
|
||||
std::string str() const override {
|
||||
return "NoneType";
|
||||
}
|
||||
bool isSubtypeOfExt(const TypePtr& rhs, std::ostream *why_not) const override {
|
||||
if (rhs->kind() == OptionalType::Kind) {
|
||||
return true;
|
||||
}
|
||||
return Type::isSubtypeOfExt(rhs, why_not);
|
||||
}
|
||||
bool isSubtypeOfExt(const TypePtr& rhs, std::ostream *why_not) const override;
|
||||
|
||||
static const TypeKind Kind = TypeKind::NoneType;
|
||||
// global singleton
|
||||
static NoneTypePtr get();
|
||||
@ -1524,8 +1578,15 @@ TORCH_API std::ostream& operator<<(std::ostream& os, const Stride& s);
|
||||
// what is the type, ignoring extra size/shape information?
|
||||
// e.g. Tensor(2x3) -> Dynamic, and Tuple(Tensor(2x3),...) -> Tuple(Dynamic,...)
|
||||
|
||||
// xxx: be careful with calls because this can be very slow. If calling this on a graph
|
||||
// use `EraseShapeInformation` in shape_analysis.h
|
||||
// `unshapedType` is used to remove Tensor subtypes. We treat all Tensor
|
||||
// subtypes as simply "Tensor"; we also create a new version of any
|
||||
// container types in which internal Tensors have undergone the same
|
||||
// operation. This is used for type comparisons between two Tensor types
|
||||
// (`unshapedType` means that we don't falsely return `false` for e.g.
|
||||
// Tensors of different dimensions). It's also used in the alias
|
||||
// analysis pass.
|
||||
// Be careful with calls because this can be very slow. If calling this
|
||||
// on a graph, use `EraseShapeInformation` in shape_analysis.h
|
||||
inline TypePtr unshapedType(const TypePtr& type) {
|
||||
if (type->isSubtypeOf(TensorType::get())) {
|
||||
return TensorType::get();
|
||||
@ -1569,27 +1630,32 @@ inline at::ScalarType scalarTypeFromJitType(const c10::TypePtr& type) {
|
||||
return *result;
|
||||
}
|
||||
|
||||
// Attempt to find the correct supertype of t1 and t2. If none is found then
|
||||
// nullopt will be returned if default_to_any is false, and Any will be returned
|
||||
// if it is true. If t1 == t2, or t1 is a type refinement of t2,
|
||||
// then t2 will be returned (and vice versa).
|
||||
// Attempt to find the correct supertype of the two types `t1` and `t2`.
|
||||
// If no supertype is found, then nullopt will be returned if
|
||||
// `default_to_union` is false, and `Union[t1, t2]` will be returned
|
||||
// if it is true. If `t1 == t2`, or `t1` is a type refinement of `t2`,
|
||||
// then `t2` will be returned (and vice versa).
|
||||
//
|
||||
// Two different tensortypes will return dynamic.
|
||||
// Currently we chose not to support returning a NumberType for a float & int
|
||||
// input because of a lack of operator support for NumberType.
|
||||
//
|
||||
// Currently we chose not to support returning a NumberType for
|
||||
// two types from the set of {FloatType, IntType, ComplexType}, because
|
||||
// there is a lack of operator support for NumberType.
|
||||
//
|
||||
// If `type_hint` is an `InterfaceType`, then we can use that as a
|
||||
// potential supertype for `ClassType`s in the list. Otherwise, we have
|
||||
// no way to find and use some common interface type
|
||||
TORCH_API c10::optional<TypePtr> unifyTypes(
|
||||
const TypePtr& t1,
|
||||
const TypePtr& t2,
|
||||
bool default_to_any = false,
|
||||
TypePtr type_hint=nullptr);
|
||||
bool default_to_union = false,
|
||||
TypePtr type_hint = nullptr);
|
||||
|
||||
TORCH_API c10::optional<TypePtr> unifyTypeList(
|
||||
at::ArrayRef<TypePtr> elements,
|
||||
std::ostream& why_not,
|
||||
bool default_to_any=false,
|
||||
TypePtr type_hint=nullptr);
|
||||
bool default_to_union = false,
|
||||
TypePtr type_hint = nullptr);
|
||||
|
||||
namespace detail {
|
||||
template <typename T>
|
||||
|
@ -44,7 +44,8 @@ namespace c10 {
|
||||
_(ScalarTypeType) \
|
||||
_(AnyListType) \
|
||||
_(AnyTupleType) \
|
||||
_(AnyClassType)
|
||||
_(AnyClassType) \
|
||||
_(UnionType)
|
||||
|
||||
enum class TypeKind {
|
||||
#define DEFINE_TYPE(T) T,
|
||||
@ -203,7 +204,7 @@ struct TORCH_API Type : std::enable_shared_from_this<Type> {
|
||||
// contained_types
|
||||
TypePtr withContained(std::vector<TypePtr> contained_types) {
|
||||
auto current_contained = containedTypes();
|
||||
AT_ASSERT(current_contained.size() == contained_types.size());
|
||||
TORCH_INTERNAL_ASSERT(current_contained.size() == contained_types.size());
|
||||
if (current_contained.equals(contained_types)) {
|
||||
return shared_from_this();
|
||||
}
|
||||
|
@ -265,7 +265,7 @@ AnyEnumTypePtr AnyEnumType::get() {
|
||||
return value;
|
||||
}
|
||||
|
||||
c10::optional<TypePtr> unifyTypesImpl(const TypePtr& t1, const TypePtr& t2, bool default_to_any=false, TypePtr type_hint=nullptr) {
|
||||
c10::optional<TypePtr> unifyTypesImpl(const TypePtr& t1, const TypePtr& t2, bool default_to_union=false, TypePtr type_hint=nullptr) {
|
||||
// check direct subtyping relation
|
||||
if (t1->isSubtypeOf(t2)) {
|
||||
return t2;
|
||||
@ -308,7 +308,7 @@ c10::optional<TypePtr> unifyTypesImpl(const TypePtr& t1, const TypePtr& t2, bool
|
||||
}
|
||||
std::vector<TypePtr> elements;
|
||||
for (size_t i = 0; i < tuple1->elements().size(); i++) {
|
||||
if (auto elem = unifyTypes(tuple1->elements().at(i), tuple2->elements().at(i), default_to_any)) {
|
||||
if (auto elem = unifyTypes(tuple1->elements().at(i), tuple2->elements().at(i), default_to_union)) {
|
||||
elements.push_back(*elem);
|
||||
} else {
|
||||
return c10::nullopt;
|
||||
@ -347,11 +347,11 @@ c10::optional<TypePtr> unifyTypesImpl(const TypePtr& t1, const TypePtr& t2, bool
|
||||
return c10::nullopt;
|
||||
}
|
||||
|
||||
c10::optional<TypePtr> unifyTypes(const TypePtr& t1, const TypePtr& t2, bool default_to_any, TypePtr type_hint) {
|
||||
auto unified = unifyTypesImpl(t1, t2, default_to_any, type_hint);
|
||||
c10::optional<TypePtr> unifyTypes(const TypePtr& t1, const TypePtr& t2, bool default_to_union, TypePtr type_hint) {
|
||||
auto unified = unifyTypesImpl(t1, t2, default_to_union, type_hint);
|
||||
|
||||
if (default_to_any && !unified) {
|
||||
return AnyType::get();
|
||||
if (default_to_union && !unified) {
|
||||
return UnionType::create({t1, t2});
|
||||
}
|
||||
|
||||
return unified;
|
||||
@ -360,7 +360,7 @@ c10::optional<TypePtr> unifyTypes(const TypePtr& t1, const TypePtr& t2, bool def
|
||||
c10::optional<TypePtr> unifyTypeList(
|
||||
at::ArrayRef<TypePtr> elements,
|
||||
std::ostream& why_not,
|
||||
bool default_to_any,
|
||||
bool default_to_union,
|
||||
TypePtr type_hint) {
|
||||
if (elements.size() == 0) {
|
||||
why_not << "Cannot get unified type from empty list";
|
||||
@ -369,7 +369,7 @@ c10::optional<TypePtr> unifyTypeList(
|
||||
|
||||
TypePtr ret_type = elements.at(0);
|
||||
for (size_t i = 1; i < elements.size() && ret_type; ++i) {
|
||||
c10::optional<TypePtr> maybe_unified = unifyTypes(ret_type, elements.at(i), default_to_any, type_hint);
|
||||
c10::optional<TypePtr> maybe_unified = unifyTypes(ret_type, elements.at(i), default_to_union, type_hint);
|
||||
if (!maybe_unified) {
|
||||
why_not << "Could not unify type list since element " << i << " of type "
|
||||
<< elements.at(i)->repr_str()
|
||||
@ -547,8 +547,9 @@ TORCH_API TypePtr tryEvalTypeVariables(TypePtr type, std::unordered_map<std::str
|
||||
}
|
||||
|
||||
TORCH_API bool elementTypeCanBeInferredFromMembers(const TypePtr& elem_type) {
|
||||
if (elem_type->kind() == OptionalType::Kind ||
|
||||
elem_type->kind() == NumberType::Kind) {
|
||||
if (elem_type->kind() == UnionType::Kind
|
||||
|| elem_type->kind() == OptionalType::Kind
|
||||
|| elem_type->kind() == NumberType::Kind) {
|
||||
// Builtin Union types
|
||||
return false;
|
||||
}
|
||||
@ -577,8 +578,16 @@ bool Type::isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const {
|
||||
if (rhs->kind() == TypeKind::AnyType || *this == *rhs) {
|
||||
return true;
|
||||
}
|
||||
if(auto rhs_ = rhs->cast<OptionalType>()) {
|
||||
return this->isSubtypeOfExt(rhs_->getElementType(), why_not);
|
||||
if (auto opt_rhs = rhs->cast<OptionalType>()) {
|
||||
return this->isSubtypeOfExt(opt_rhs->getElementType(), why_not);
|
||||
}
|
||||
if (auto union_rhs = rhs->cast<UnionType>()) {
|
||||
// Check if `this` is a subtype of any of the types within the Union
|
||||
return std::any_of(union_rhs->containedTypes().begin(),
|
||||
union_rhs->containedTypes().end(),
|
||||
[&](TypePtr inner) {
|
||||
return this->isSubtypeOfExt(inner, why_not);
|
||||
});
|
||||
}
|
||||
return false;
|
||||
}
|
||||
@ -808,6 +817,453 @@ TupleTypePtr TupleType::createNamed(const c10::optional<c10::QualifiedName>& qua
|
||||
field_types, qualName, schema)); // NOLINT(modernize-make-shared)
|
||||
}
|
||||
|
||||
bool NoneType::isSubtypeOfExt(const TypePtr& rhs, std::ostream *why_not) const {
|
||||
if (rhs->kind() == OptionalType::Kind) {
|
||||
return true;
|
||||
}
|
||||
return Type::isSubtypeOfExt(rhs, why_not);
|
||||
}
|
||||
|
||||
// Remove nested Optionals/Unions during the instantiation of a Union or
|
||||
// an Optional. This populates `types` with all the types found during
|
||||
// flattening. At the end of `flattenUnion`, `types` may have
|
||||
// duplicates, but it will not have nested Optionals/Unions
|
||||
void flattenUnion(TypePtr& type, std::vector<TypePtr>* to_fill) {
|
||||
if (auto union_type = type->cast<UnionType>()) {
|
||||
for (auto inner : union_type->containedTypes()) {
|
||||
flattenUnion(inner, to_fill);
|
||||
}
|
||||
} else if (auto opt_type = type->cast<OptionalType>()) {
|
||||
auto inner = opt_type->getElementType();
|
||||
flattenUnion(inner, to_fill);
|
||||
to_fill->emplace_back(NoneType::get());
|
||||
} else if (type->kind() == NumberType::Kind) {
|
||||
to_fill->emplace_back(IntType::get());
|
||||
to_fill->emplace_back(FloatType::get());
|
||||
to_fill->emplace_back(ComplexType::get());
|
||||
} else {
|
||||
to_fill->emplace_back(type);
|
||||
}
|
||||
}
|
||||
|
||||
// Helper function for `standardizeUnion`
|
||||
//
|
||||
// NB: If we have types `T1`, `T2`, `T3`, and `PARENT_T` such that `T1`,
|
||||
// `T2`, and `T2` are children of `PARENT_T`, then `unifyTypes(T1, T2)`
|
||||
// will return `PARENT_T`. This could be a problem if we didn't want our
|
||||
// Union to also be able to take `T3 `. In our current type hierarchy,
|
||||
// this isn't an issue--most types SHOULD be unified even if the parent
|
||||
// type wasn't in the original vector. However, later additions to the
|
||||
// type system might necessitate reworking `get_supertype`
|
||||
void filterDuplicateSubtypes(std::vector<TypePtr>* types) {
|
||||
if (types->empty()) {
|
||||
return;
|
||||
}
|
||||
auto get_supertype = [](const TypePtr t1, const TypePtr t2) -> c10::optional<TypePtr> {
|
||||
// We don't want nested Optionals. Also, prematurely unifying to
|
||||
// `Optional` could prevent us from coalescing other types
|
||||
if ((t1->isSubtypeOf(NoneType::get()) && !t2->isSubtypeOf(NoneType::get()))
|
||||
|| (!t1->isSubtypeOf(NoneType::get()) && t2->isSubtypeOf(NoneType::get()))) {
|
||||
return c10::nullopt;
|
||||
} else {
|
||||
return unifyTypes(t1, t2, /*default_to_union=*/false);
|
||||
}
|
||||
};
|
||||
|
||||
// Coalesce types and delete all duplicates. Moving from right to left
|
||||
// through the vector, we try to unify the current element (`i`) with
|
||||
// each element (`j`) before the "new" end of the vector (`end`).
|
||||
// If we're able to unify the types at `types[i]` and `types[j]`, we
|
||||
// decrement `end`, swap `types[j]` with the unified type, and
|
||||
// break. Otherwise, we keep `end` where it is to signify that the
|
||||
// new end of the vector hasn't shifted
|
||||
size_t end_idx = types->size()-1;
|
||||
for (size_t i = types->size()-1; i > 0; --i) {
|
||||
for (size_t j = std::min(i-1, end_idx); ; --j) {
|
||||
c10::optional<TypePtr> unified;
|
||||
unified = get_supertype((*types)[i], (*types)[j]);
|
||||
if (unified) {
|
||||
(*types)[j] = *unified;
|
||||
(*types)[i] = (*types)[end_idx];
|
||||
--end_idx;
|
||||
break;
|
||||
}
|
||||
// Break condition here so we don't get `j = 0; j = j-1` and end
|
||||
// up with MAX_INT
|
||||
if (j == 0) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
// Cut off the vector's tail so that `end` is the real last element
|
||||
types->erase(types->begin() + end_idx + 1, types->end());
|
||||
|
||||
}
|
||||
|
||||
void sortUnion(std::vector<TypePtr>* types) {
|
||||
// We want the elements to be sorted so we can easily compare two
|
||||
// UnionType objects for equality in the future. Note that this order
|
||||
// is guaranteed to be stable since we've already coalesced any
|
||||
// possible types
|
||||
std::sort(types->begin(), types->end(),
|
||||
[](const TypePtr a, const TypePtr b) -> bool {
|
||||
if (a->kind() != b->kind()) {
|
||||
return a->kind() < b->kind();
|
||||
}
|
||||
return a->str() < b->str();
|
||||
});
|
||||
}
|
||||
|
||||
void standardizeVectorForUnion(std::vector<TypePtr>& reference, std::vector<TypePtr>* to_fill) {
|
||||
for (auto type : reference) {
|
||||
flattenUnion(type, to_fill);
|
||||
}
|
||||
filterDuplicateSubtypes(to_fill);
|
||||
sortUnion(to_fill);
|
||||
}
|
||||
|
||||
void standardizeVectorForUnion(std::vector<TypePtr>* to_flatten) {
|
||||
TORCH_INTERNAL_ASSERT(to_flatten, "`standardizeVectorForUnion` was ",
|
||||
"passed a `nullptr`");
|
||||
std::vector<TypePtr> to_fill;
|
||||
standardizeVectorForUnion(*to_flatten, &to_fill);
|
||||
*to_flatten = to_fill;
|
||||
}
|
||||
|
||||
UnionType::UnionType(std::vector<TypePtr> reference, TypeKind kind) : Type(kind) {
|
||||
TORCH_INTERNAL_ASSERT(!reference.empty(), "Cannot create an empty Union");
|
||||
|
||||
standardizeVectorForUnion(reference, &types_);
|
||||
|
||||
// Gate the assert in a regular conditional so that we don't create
|
||||
// this long error message unnecessarily
|
||||
if (types_.size() == 1) {
|
||||
std::stringstream msg;
|
||||
msg << "After type unification was performed, the Union with the "
|
||||
<< "original types {";
|
||||
for (auto i = 0; i < reference.size(); ++i) {
|
||||
msg << reference[i]->repr_str();
|
||||
if (i > 0) {
|
||||
msg << ",";
|
||||
}
|
||||
msg << " ";
|
||||
}
|
||||
msg << "} has the single type " << types_[0]->repr_str()
|
||||
<< ". Use the common supertype instead of creating a Union"
|
||||
<< "type";
|
||||
TORCH_INTERNAL_ASSERT(false, msg.str());
|
||||
}
|
||||
|
||||
can_hold_none_ = false;
|
||||
has_free_variables_ = false;
|
||||
|
||||
for (const TypePtr& type : types_) {
|
||||
if (type->kind() == NoneType::Kind) {
|
||||
can_hold_none_ = true;
|
||||
}
|
||||
if (type->hasFreeVariables()) {
|
||||
has_free_variables_ = true;
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
UnionTypePtr UnionType::create(std::vector<TypePtr> reference) {
|
||||
auto union_type = new UnionType(std::move(reference));
|
||||
|
||||
// Some very special-cased logic for `Optional`. This will be deleted
|
||||
// in a later PR
|
||||
bool int_found = false;
|
||||
bool float_found = false;
|
||||
bool complex_found = false;
|
||||
bool nonetype_found = false;
|
||||
|
||||
auto update_is_opt_flags = [&](TypePtr t) {
|
||||
if (t == IntType::get()) {
|
||||
int_found = true;
|
||||
} else if (t == FloatType::get()) {
|
||||
float_found = true;
|
||||
} else if (t == ComplexType::get()) {
|
||||
complex_found = true;
|
||||
} else if (t == NoneType::get()) {
|
||||
nonetype_found = true;
|
||||
}
|
||||
};
|
||||
|
||||
for (const auto& t : union_type->containedTypes()) {
|
||||
update_is_opt_flags(t);
|
||||
}
|
||||
|
||||
bool numbertype_found = int_found && float_found && complex_found;
|
||||
|
||||
if (nonetype_found) {
|
||||
if (union_type->containedTypes().size() == 4 && numbertype_found) {
|
||||
return OptionalType::create(NumberType::get());
|
||||
}
|
||||
if (union_type->containedTypes().size() == 2) {
|
||||
auto not_none = union_type->containedTypes()[0] != NoneType::get()
|
||||
? union_type->containedTypes()[0]
|
||||
: union_type->containedTypes()[1];
|
||||
return OptionalType::create(not_none);
|
||||
}
|
||||
}
|
||||
|
||||
return UnionTypePtr(union_type);
|
||||
}
|
||||
|
||||
bool UnionType::operator==(const Type& rhs) const {
|
||||
if (auto union_rhs = rhs.cast<UnionType>()) {
|
||||
// We can't compare the type vectors for equality using `operator=`,
|
||||
// because the vectors hold `TypePtr`s and we want to compare `Type`
|
||||
// equality
|
||||
if (union_rhs->containedTypes().size() != this->containedTypes().size()) {
|
||||
return false;
|
||||
}
|
||||
// Check that all the types in `this->types_` are also in
|
||||
// `union_rhs->types_`
|
||||
return std::all_of(this->containedTypes().begin(), this->containedTypes().end(),
|
||||
[&](TypePtr lhs_type) {
|
||||
return std::any_of(union_rhs->containedTypes().begin(),
|
||||
union_rhs->containedTypes().end(),
|
||||
[&](TypePtr rhs_type) {
|
||||
return *lhs_type == *rhs_type;
|
||||
});
|
||||
});
|
||||
} else if (auto optional_rhs = rhs.cast<OptionalType>()) {
|
||||
if (optional_rhs->getElementType() == NumberType::get()) {
|
||||
return this->containedTypes().size() == 4
|
||||
&& this->can_hold_none_
|
||||
&& this->canHoldType(NumberType::get());
|
||||
}
|
||||
auto optional_lhs = this->toOptional();
|
||||
return optional_lhs && *optional_rhs == *((optional_lhs.value())->expect<OptionalType>());
|
||||
} else if (rhs.kind() == NumberType::Kind) {
|
||||
return this->containedTypes().size() == 3 && canHoldType(NumberType::get());
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
bool UnionType::isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const {
|
||||
std::vector<TypePtr> rhs_types;
|
||||
if (const auto union_rhs = rhs->cast<UnionType>()) {
|
||||
// Fast path
|
||||
if (this->containedTypes() == rhs->containedTypes()) {
|
||||
return true;
|
||||
}
|
||||
rhs_types = rhs->containedTypes().vec();
|
||||
} else if (const auto optional_rhs = rhs->cast<OptionalType>()) {
|
||||
rhs_types.push_back(NoneType::get());
|
||||
if (optional_rhs->getElementType() == NumberType::get()) {
|
||||
std::vector<TypePtr> number_types{IntType::get(), FloatType::get(), ComplexType::get()};
|
||||
rhs_types.insert(rhs_types.end(), number_types.begin(), number_types.end());
|
||||
} else {
|
||||
rhs_types.push_back(optional_rhs->getElementType());
|
||||
}
|
||||
} else if (const auto number_rhs = rhs->cast<NumberType>()) {
|
||||
std::vector<TypePtr> number_types{IntType::get(), FloatType::get(), ComplexType::get()};
|
||||
rhs_types.insert(rhs_types.end(), number_types.begin(), number_types.end());
|
||||
} else {
|
||||
rhs_types.push_back(rhs);
|
||||
}
|
||||
return std::all_of(this->containedTypes().begin(), this->containedTypes().end(),
|
||||
[&](TypePtr lhs_type) -> bool {
|
||||
return std::any_of(rhs_types.begin(),
|
||||
rhs_types.end(),
|
||||
[&](TypePtr rhs_type) -> bool {
|
||||
return lhs_type->isSubtypeOfExt(rhs_type, why_not);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
std::string UnionType::unionStr(TypePrinter printer, bool is_annotation_str) const {
|
||||
std::stringstream ss;
|
||||
|
||||
bool can_hold_numbertype = this->canHoldType(NumberType::get());
|
||||
|
||||
std::vector<TypePtr> number_types{IntType::get(), FloatType::get(), ComplexType::get()};
|
||||
|
||||
auto is_numbertype = [&](TypePtr lhs) {
|
||||
for (const auto& rhs : number_types) {
|
||||
if (*lhs == *rhs) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
};
|
||||
|
||||
ss << "Union[";
|
||||
bool printed = false;
|
||||
for (size_t i = 0; i < types_.size(); ++i) {
|
||||
if (!can_hold_numbertype || !is_numbertype(types_[i])) {
|
||||
if (i > 0) {
|
||||
ss << ", ";
|
||||
printed = true;
|
||||
}
|
||||
if (is_annotation_str) {
|
||||
ss << this->containedTypes()[i]->annotation_str(printer);
|
||||
} else {
|
||||
ss << this->containedTypes()[i]->str();
|
||||
}
|
||||
}
|
||||
}
|
||||
if (can_hold_numbertype) {
|
||||
if (printed) {
|
||||
ss << ", ";
|
||||
}
|
||||
if (is_annotation_str) {
|
||||
ss << NumberType::get()->annotation_str(printer);
|
||||
} else {
|
||||
ss << NumberType::get()->str();
|
||||
}
|
||||
}
|
||||
ss << "]";
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
std::string UnionType::str() const {
|
||||
return this->unionStr(nullptr, /*is_annotation_str=*/false);
|
||||
}
|
||||
|
||||
std::string UnionType::annotation_str_impl(TypePrinter printer) const {
|
||||
return this->unionStr(printer, /*is_annotation_str=*/true);
|
||||
}
|
||||
|
||||
bool UnionType::canHoldType(TypePtr type) const {
|
||||
if (type == NumberType::get()) {
|
||||
return canHoldType(IntType::get())
|
||||
&& canHoldType(FloatType::get())
|
||||
&& canHoldType(ComplexType::get());
|
||||
} else {
|
||||
return std::any_of(this->containedTypes().begin(), this->containedTypes().end(),
|
||||
[&](TypePtr inner) {
|
||||
return type->isSubtypeOf(inner);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
c10::optional<TypePtr> UnionType::toOptional() const {
|
||||
if (!canHoldType(NoneType::get())) {
|
||||
return c10::nullopt;
|
||||
}
|
||||
|
||||
std::vector<TypePtr> copied_types = this->containedTypes().vec();
|
||||
|
||||
auto maybe_opt = UnionType::create(std::move(copied_types));
|
||||
|
||||
if (maybe_opt->kind() == UnionType::Kind) {
|
||||
return c10::nullopt;
|
||||
} else {
|
||||
return maybe_opt;
|
||||
}
|
||||
}
|
||||
|
||||
c10::optional<TypePtr> UnionType::subtractTypeSet(std::vector<TypePtr>& to_subtract) const {
|
||||
std::vector<TypePtr> types;
|
||||
|
||||
// Given a TypePtr `lhs`, this function says whether or not `lhs` (or
|
||||
// one of its parent types) is in the `to_subtract` vector
|
||||
auto should_subtract = [&](TypePtr lhs) -> bool {
|
||||
return std::any_of(to_subtract.begin(), to_subtract.end(),
|
||||
[&](TypePtr rhs) {
|
||||
return lhs->isSubtypeOf(rhs);
|
||||
});
|
||||
};
|
||||
|
||||
// Copy all the elements that should NOT be subtracted to the `types`
|
||||
// vector
|
||||
std::copy_if(this->containedTypes().begin(), this->containedTypes().end(),
|
||||
std::back_inserter(types),
|
||||
[&](const TypePtr t) {
|
||||
return !should_subtract(t);
|
||||
});
|
||||
|
||||
if (types.size() == 0) {
|
||||
return c10::nullopt;
|
||||
} else if (types.size() == 1) {
|
||||
return types[0];
|
||||
} else {
|
||||
return UnionType::create(std::move(types));
|
||||
}
|
||||
}
|
||||
|
||||
OptionalType::OptionalType(TypePtr contained)
|
||||
: UnionType({contained, NoneType::get()}, TypeKind::OptionalType) {
|
||||
bool is_numbertype = false;
|
||||
if (auto as_union = contained->cast<UnionType>()) {
|
||||
is_numbertype = as_union->containedTypes().size() == 3 &&
|
||||
as_union->canHoldType(NumberType::get());
|
||||
}
|
||||
if (UnionType::containedTypes().size() == 2) {
|
||||
contained_ = UnionType::containedTypes()[0]->kind()!= NoneType::Kind
|
||||
? UnionType::containedTypes()[0]
|
||||
: UnionType::containedTypes()[1];
|
||||
} else if (contained == NumberType::get() || is_numbertype) {
|
||||
contained_ = NumberType::get();
|
||||
types_.clear();
|
||||
types_.push_back(NumberType::get());
|
||||
types_.push_back(NoneType::get());
|
||||
} else {
|
||||
std::vector<TypePtr> to_subtract{NoneType::get()};
|
||||
auto without_none = this->subtractTypeSet(to_subtract);
|
||||
contained_ = UnionType::create({*without_none});
|
||||
}
|
||||
has_free_variables_ = contained_->hasFreeVariables();
|
||||
}
|
||||
|
||||
bool OptionalType::operator==(const Type& rhs) const {
|
||||
if (auto union_rhs = rhs.cast<UnionType>()) {
|
||||
auto optional_rhs = union_rhs->toOptional();
|
||||
// `**optional_rhs` = `*` to get value of `c10::optional<TypePtr>`,
|
||||
// then `*` to dereference the pointer
|
||||
return optional_rhs && *this == **optional_rhs;
|
||||
} else if (auto optional_rhs = rhs.cast<OptionalType>()) {
|
||||
return *this->getElementType() == *optional_rhs->getElementType();
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
bool OptionalType::isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const {
|
||||
if (OptionalTypePtr optional_rhs = rhs->cast<OptionalType>()) {
|
||||
return getElementType()->isSubtypeOfExt(optional_rhs->getElementType(), why_not);
|
||||
} else if (UnionTypePtr union_rhs = rhs->cast<UnionType>()) {
|
||||
if (!union_rhs->canHoldType(NoneType::get())) {
|
||||
if (why_not) {
|
||||
*why_not << rhs->repr_str() << " cannot hold None";
|
||||
}
|
||||
return false;
|
||||
} else if (!union_rhs->canHoldType(this->getElementType())) {
|
||||
if (why_not) {
|
||||
*why_not << rhs->repr_str() << " cannot hold " << this->getElementType();
|
||||
}
|
||||
return false;
|
||||
} else {
|
||||
return true;
|
||||
}
|
||||
} else {
|
||||
// NOLINTNEXTLINE(bugprone-argument-comment)
|
||||
return Type::isSubtypeOfExt(rhs, why_not);
|
||||
}
|
||||
}
|
||||
|
||||
bool NumberType::operator==(const Type& rhs) const {
|
||||
if (auto union_type = rhs.cast<UnionType>()) {
|
||||
return union_type->containedTypes().size() == 3 && union_type->canHoldType(NumberType::get());
|
||||
} else {
|
||||
return rhs.kind() == this->kind();
|
||||
}
|
||||
}
|
||||
|
||||
bool NumberType::isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const {
|
||||
if (auto union_type = rhs->cast<UnionType>()) {
|
||||
return union_type->canHoldType(NumberType::get());
|
||||
} else {
|
||||
return Type::isSubtypeOfExt(rhs, why_not);
|
||||
}
|
||||
}
|
||||
|
||||
TupleType::TupleType(
|
||||
std::vector<TypePtr> elements,
|
||||
c10::optional<c10::QualifiedName> name,
|
||||
@ -1732,8 +2188,10 @@ size_t ClassType::addAttribute(
|
||||
TORCH_CHECK(
|
||||
(type->kind() == TensorType::Kind) ||
|
||||
(type->kind() == OptionalType::Kind &&
|
||||
type->expectRef<OptionalType>().getElementType()->kind() ==
|
||||
type->expect<OptionalType>()->getElementType()->kind() ==
|
||||
TensorType::Kind) ||
|
||||
(type->kind() == UnionType::Kind &&
|
||||
TensorType::get()->isSubtypeOf(type->expect<UnionType>())) ||
|
||||
(type->kind() == NoneType::Kind),
|
||||
"Expecting parameter or buffer to have either None, Tensor or Optional[Tensor] type, but got: ",
|
||||
toString(type));
|
||||
@ -1880,7 +2338,9 @@ void SymbolicShape::dump() const {
|
||||
|
||||
bool EnumType::isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const {
|
||||
return rhs->kind() == TypeKind::AnyType ||
|
||||
rhs->kind() == TypeKind::AnyEnumType || *this == *rhs;
|
||||
rhs->kind() == TypeKind::AnyEnumType ||
|
||||
*this == *rhs ||
|
||||
Type::isSubtypeOfExt(rhs, why_not);
|
||||
}
|
||||
|
||||
} // namespace c10
|
||||
|
@ -7,8 +7,8 @@ Like all ATen methods/functions, native functions are made available
|
||||
from both ATen's C++ and Python APIs. In C++, they are made available
|
||||
either as methods on `Tensor` (`t.mymeth()`) and functions in the ATen
|
||||
namespace (`at::myfunc()`). In PyTorch, they are made available as
|
||||
methods on `Variable` or as functions on `torch._C._FunctionBase`
|
||||
(it is the user's responsibility to re-exporting these functions in
|
||||
methods on `Variable` or as functions on `torch._C._FunctionBase`.
|
||||
(It is the user's responsibility to re-export these functions in
|
||||
a more user-facing module.)
|
||||
|
||||
The rest of this document describes how to implement an ATen function.
|
||||
|
@ -50,7 +50,7 @@ class C10_API AllocationPlanner {
|
||||
private:
|
||||
AllocationPlan* allocation_plan_{nullptr};
|
||||
// Maps allocated ptr to its allocation id.
|
||||
// This is used when freeing the memory to lookup the allocation id
|
||||
// This is used when freeing the memory to look up the allocation id
|
||||
// in order to establish the lifetime of a particular allocation.
|
||||
ska::flat_hash_map<const void*, uint64_t> allocation_ptr_to_id_;
|
||||
uint64_t allocation_id_{0};
|
||||
|
@ -65,7 +65,7 @@ an RPC.
|
||||
input tensors. The output gradients of this function are sent to the source
|
||||
node to the appropriate ``send`` function during the backward pass.
|
||||
- Each ``send-recv`` pair is assigned a globally unique ``autograd_message_id``
|
||||
to uniquely identify the pair. This is useful to lookup the corresponding
|
||||
to uniquely identify the pair. This is useful to look up the corresponding
|
||||
function on a remote node during the backward pass.
|
||||
- For :ref:`rref`, whenever we call :meth:`torch.distributed.rpc.RRef.to_here`
|
||||
we attach an appropriate ``send-recv`` pair for the tensors involved.
|
||||
@ -98,7 +98,7 @@ This context serves the following purpose:
|
||||
2. During the forward pass we store the ``send`` and ``recv`` functions for
|
||||
each autograd pass in this context. This ensures we hold references to the
|
||||
appropriate nodes in the autograd graph to keep it alive. In addition to
|
||||
this, it is easy to lookup the appropriate ``send`` and ``recv`` functions
|
||||
this, it is easy to look up the appropriate ``send`` and ``recv`` functions
|
||||
during the backward pass.
|
||||
3. In general we also use this context to store some metadata for each
|
||||
distributed autograd pass.
|
||||
|
@ -66,6 +66,7 @@ set(JIT_TEST_SRCS
|
||||
${JIT_TEST_ROOT}/test_subgraph_matcher.cpp
|
||||
${JIT_TEST_ROOT}/test_subgraph_rewriter.cpp
|
||||
${JIT_TEST_ROOT}/test_subgraph_utils.cpp
|
||||
${JIT_TEST_ROOT}/test_union.cpp
|
||||
${JIT_TEST_ROOT}/test_utils.cpp
|
||||
${JIT_TEST_ROOT}/test_script_profile.cpp
|
||||
${JIT_TEST_ROOT}/test_jit_logging_levels.cpp
|
||||
|
@ -660,6 +660,31 @@ TEST(ContainerAliasingTest, PrimitveValuesDontAliasContainers) {
|
||||
}
|
||||
}
|
||||
|
||||
TEST(ContainerAliasingTest, UnionAliasing) {
|
||||
auto graph = std::make_shared<Graph>();
|
||||
parseIR(
|
||||
R"IR(
|
||||
graph(%a : Dict(str, Tensor),
|
||||
%b : Tensor[],
|
||||
%c : Union(Dict(str, Tensor), Tensor[])):
|
||||
return (%a, %b, %c)
|
||||
)IR",
|
||||
&*graph);
|
||||
|
||||
AliasDb aliasDb(graph);
|
||||
auto a = graph->outputs().at(0);
|
||||
auto b = graph->outputs().at(1);
|
||||
auto c = graph->outputs().at(2);
|
||||
|
||||
EXPECT_TRUE(aliasDb.mayAlias(a, c));
|
||||
EXPECT_TRUE(aliasDb.mayAlias(b, c));
|
||||
EXPECT_TRUE(aliasDb.mayAlias(c, c));
|
||||
EXPECT_FALSE(aliasDb.mayAlias(a, b));
|
||||
EXPECT_TRUE(aliasDb.mayContainAlias(a, b));
|
||||
EXPECT_TRUE(aliasDb.mayContainAlias(a, c));
|
||||
EXPECT_TRUE(aliasDb.mayContainAlias(b, c));
|
||||
}
|
||||
|
||||
TEST(ContainerAliasingTest, InputsCanAliasOutputs) {
|
||||
// Test input aliasing
|
||||
auto graph = std::make_shared<Graph>();
|
||||
|
149
test/cpp/jit/test_union.cpp
Normal file
149
test/cpp/jit/test_union.cpp
Normal file
@ -0,0 +1,149 @@
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <ATen/core/jit_type.h>
|
||||
#include <test/cpp/jit/test_utils.h>
|
||||
#include <torch/csrc/jit/ir/ir.h>
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
class UnionTypeTest : public ::testing::Test {
|
||||
public:
|
||||
// None
|
||||
const TypePtr none = NoneType::get();
|
||||
|
||||
// List[str]
|
||||
const TypePtr l1 = ListType::ofStrings();
|
||||
|
||||
// Optional[int]
|
||||
const TypePtr opt1 = OptionalType::create(IntType::get());
|
||||
|
||||
// Optional[float]
|
||||
const TypePtr opt2 = OptionalType::create(FloatType::get());
|
||||
|
||||
// Optional[List[str]]
|
||||
const TypePtr opt3 = OptionalType::create(ListType::ofStrings());
|
||||
|
||||
// Tuple[Optional[int], int]
|
||||
const TypePtr tup1 =
|
||||
TupleType::create({OptionalType::create(IntType::get()), IntType::get()});
|
||||
|
||||
// Tuple[int, int]
|
||||
const TypePtr tup2 = TupleType::create({IntType::get(), IntType::get()});
|
||||
|
||||
bool hasType(UnionTypePtr u, TypePtr t) {
|
||||
auto res = std::find(u->getTypes().begin(), u->getTypes().end(), t);
|
||||
return res != u->getTypes().end();
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(UnionTypeTest, UnionOperatorEquals) {
|
||||
const UnionTypePtr u1 = UnionType::create({l1, tup2, StringType::get()});
|
||||
|
||||
// Same thing, but using different TypePtrs
|
||||
const TypePtr l1_ = ListType::ofStrings();
|
||||
const TypePtr tup2_ = TupleType::create({IntType::get(), IntType::get()});
|
||||
const UnionTypePtr u2 = UnionType::create({l1_, tup2_, StringType::get()});
|
||||
|
||||
ASSERT_TRUE(*u1 == *u2);
|
||||
}
|
||||
|
||||
TEST_F(UnionTypeTest, UnionCreate_OptionalT1AndOptionalT2) {
|
||||
// Goal: Union[int, float, None]
|
||||
const UnionTypePtr u = UnionType::create({opt1, opt2});
|
||||
|
||||
ASSERT_EQ(u->getTypes().size(), 3);
|
||||
ASSERT_TRUE(UnionTypeTest::hasType(u, IntType::get()));
|
||||
ASSERT_TRUE(UnionTypeTest::hasType(u, FloatType::get()));
|
||||
ASSERT_TRUE(UnionTypeTest::hasType(u, NoneType::get()));
|
||||
}
|
||||
|
||||
TEST_F(UnionTypeTest, UnionCreate_OptionalTAndT) {
|
||||
// Goal: Union[int, None]
|
||||
const UnionTypePtr u = UnionType::create({opt1, IntType::get()});
|
||||
|
||||
ASSERT_EQ(u->getTypes().size(), 2);
|
||||
ASSERT_TRUE(UnionTypeTest::hasType(u, IntType::get()));
|
||||
ASSERT_TRUE(UnionTypeTest::hasType(u, NoneType::get()));
|
||||
}
|
||||
|
||||
TEST_F(UnionTypeTest, UnionCreate_TupleWithSubtypingRelationship) {
|
||||
// Goal: Union[Tuple[Optional[int], int], str]
|
||||
const UnionTypePtr u = UnionType::create({StringType::get(), tup1, tup2});
|
||||
|
||||
ASSERT_EQ(u->getTypes().size(), 2);
|
||||
ASSERT_TRUE(UnionTypeTest::hasType(u, StringType::get()));
|
||||
ASSERT_TRUE(UnionTypeTest::hasType(u, tup1));
|
||||
}
|
||||
|
||||
TEST_F(UnionTypeTest, UnionCreate_ContainerTAndT) {
|
||||
// Goal: Union[List[str], str]
|
||||
const UnionTypePtr u = UnionType::create({l1, StringType::get()});
|
||||
|
||||
ASSERT_EQ(u->getTypes().size(), 2);
|
||||
ASSERT_TRUE(UnionTypeTest::hasType(u, StringType::get()));
|
||||
ASSERT_TRUE(UnionTypeTest::hasType(u, ListType::ofStrings()));
|
||||
}
|
||||
|
||||
TEST_F(UnionTypeTest, UnionCreate_OptionalContainerTAndContainerTAndT) {
|
||||
// Goal: Union[List[str], None, str]
|
||||
const UnionTypePtr u = UnionType::create({l1, opt3, StringType::get()});
|
||||
|
||||
ASSERT_EQ(u->getTypes().size(), 3);
|
||||
ASSERT_TRUE(UnionTypeTest::hasType(u, StringType::get()));
|
||||
ASSERT_TRUE(UnionTypeTest::hasType(u, ListType::ofStrings()));
|
||||
}
|
||||
|
||||
TEST_F(UnionTypeTest, Subtyping_NumberType) {
|
||||
// Union[int, float, Complex]
|
||||
const UnionTypePtr union1 =
|
||||
UnionType::create({IntType::get(), FloatType::get(), ComplexType::get()});
|
||||
|
||||
// Union[int, float, Complex, None]
|
||||
const UnionTypePtr union2 = UnionType::create(
|
||||
{IntType::get(), FloatType::get(), ComplexType::get(), NoneType::get()});
|
||||
|
||||
const NumberTypePtr num = NumberType::get();
|
||||
|
||||
ASSERT_TRUE(num->isSubtypeOf(union1));
|
||||
ASSERT_TRUE(union1->isSubtypeOf(num));
|
||||
ASSERT_TRUE(*num == *union1);
|
||||
|
||||
ASSERT_TRUE(num->isSubtypeOf(union2));
|
||||
ASSERT_FALSE(union2->isSubtypeOf(num));
|
||||
ASSERT_FALSE(*num == *union2);
|
||||
}
|
||||
|
||||
TEST_F(UnionTypeTest, Subtyping_OptionalType) {
|
||||
// Union[int, None]
|
||||
const UnionTypePtr union1 =
|
||||
UnionType::create({IntType::get(), NoneType::get()});
|
||||
|
||||
// Union[int, str, None]
|
||||
const UnionTypePtr union2 =
|
||||
UnionType::create({IntType::get(), StringType::get(), NoneType::get()});
|
||||
|
||||
// Union[int, str, List[str]]
|
||||
const UnionTypePtr union3 = UnionType::create(
|
||||
{IntType::get(), StringType::get(), ListType::ofStrings()});
|
||||
|
||||
ASSERT_TRUE(none->isSubtypeOf(opt1));
|
||||
ASSERT_TRUE(none->isSubtypeOf(union1));
|
||||
ASSERT_TRUE(none->isSubtypeOf(union2));
|
||||
ASSERT_FALSE(none->isSubtypeOf(union3));
|
||||
|
||||
ASSERT_FALSE(opt1->isSubtypeOf(none));
|
||||
ASSERT_TRUE(opt1->isSubtypeOf(union1));
|
||||
ASSERT_TRUE(opt1->isSubtypeOf(union2));
|
||||
ASSERT_FALSE(opt1->isSubtypeOf(union3));
|
||||
|
||||
ASSERT_FALSE(union1->isSubtypeOf(none));
|
||||
ASSERT_TRUE(union1->isSubtypeOf(opt1));
|
||||
ASSERT_TRUE(union1->isSubtypeOf(union2));
|
||||
ASSERT_FALSE(union1->isSubtypeOf(union3));
|
||||
|
||||
ASSERT_FALSE(union2->isSubtypeOf(union1));
|
||||
}
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
@ -92,7 +92,7 @@ class TestList(JitTestCase):
|
||||
if 1 == 1:
|
||||
x = [1, 2, 3]
|
||||
return
|
||||
with self.assertRaisesRegexWithHighlight(RuntimeError, r"previously has type List\[Tensor\]", "x"):
|
||||
with self.assertRaisesRegexWithHighlight(RuntimeError, r"previously had type List\[Tensor\]", "x"):
|
||||
self.checkScript(reassign_from_empty_literal, (), optimize=False)
|
||||
|
||||
def reassign_from_empty_builtin():
|
||||
@ -113,7 +113,7 @@ class TestList(JitTestCase):
|
||||
if 1 == 1:
|
||||
x = [1.0]
|
||||
return
|
||||
with self.assertRaisesRegexWithHighlight(RuntimeError, "previously has type", "x"):
|
||||
with self.assertRaisesRegexWithHighlight(RuntimeError, "previously had type", "x"):
|
||||
self.checkScript(reassign_bad_type, (), optimize=False)
|
||||
|
||||
def reassign_nested():
|
||||
@ -123,7 +123,7 @@ class TestList(JitTestCase):
|
||||
if 1 == 1:
|
||||
x = [1.0]
|
||||
return
|
||||
with self.assertRaisesRegexWithHighlight(RuntimeError, "previously has type", "x"):
|
||||
with self.assertRaisesRegexWithHighlight(RuntimeError, "previously had type", "x"):
|
||||
self.checkScript(reassign_nested, (), optimize=False)
|
||||
|
||||
def test_del(self):
|
||||
|
@ -92,10 +92,9 @@ class TestTyping(JitTestCase):
|
||||
|
||||
graph = torch.jit.script(fn).graph
|
||||
|
||||
print(graph)
|
||||
|
||||
# Check that we're making a `List[Tuple[str, Any]]`
|
||||
FileCheck().check(r"(str, Any)[] = prim::ListConstruct").run(graph)
|
||||
FileCheck().check("(str, Union[Tensor, Dict(str, Tensor)])"
|
||||
"[] = prim::ListConstruct()").run(graph)
|
||||
|
||||
def test_list_type_refinement_defaults_to_Any_list_comprehension(self):
|
||||
def fn(x):
|
||||
@ -116,10 +115,9 @@ class TestTyping(JitTestCase):
|
||||
|
||||
graph = torch.jit.script(fn).graph
|
||||
|
||||
print(graph)
|
||||
|
||||
# Check that we're making a `List[Tuple[str, Any]]`
|
||||
FileCheck().check(r"(str, Any)[] = prim::ListConstruct").run(graph)
|
||||
FileCheck().check("(str, Union[Tensor, Dict(str, Tensor)])"
|
||||
"[] = prim::ListConstruct()").run(graph)
|
||||
|
||||
def test_list_type_refinement_annotation_element_mismatch(self):
|
||||
def fn():
|
||||
@ -145,7 +143,8 @@ class TestTyping(JitTestCase):
|
||||
|
||||
graph = torch.jit.script(fn).graph
|
||||
|
||||
FileCheck().check(r"Dict(str, Any) = prim::DictConstruct").run(graph)
|
||||
FileCheck().check("Dict(str, Union[Tensor, Dict(str, Tensor)])"
|
||||
" = prim::DictConstruct").run(graph)
|
||||
|
||||
def test_dict_type_refinement_defaults_to_Any_dict_comprehension(self):
|
||||
def fn(x):
|
||||
@ -161,7 +160,8 @@ class TestTyping(JitTestCase):
|
||||
|
||||
graph = torch.jit.script(fn).graph
|
||||
|
||||
FileCheck().check("Dict(str, Any) = prim::DictConstruct").run(graph)
|
||||
FileCheck().check("Dict(str, Union[Tensor, Dict(str, Tensor)])"
|
||||
" = prim::DictConstruct").run(graph)
|
||||
|
||||
def test_dict_type_refinement_annotation_key_mismatch(self):
|
||||
def fn():
|
||||
|
657
test/jit/test_union.py
Normal file
657
test/jit/test_union.py
Normal file
@ -0,0 +1,657 @@
|
||||
import io
|
||||
import os
|
||||
import sys
|
||||
|
||||
import torch
|
||||
from torch.testing import FileCheck
|
||||
from enum import Enum
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
# Make the helper files in test/ importable
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
from torch.testing._internal.jit_utils import JitTestCase, make_global
|
||||
|
||||
if __name__ == '__main__':
|
||||
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead.")
|
||||
|
||||
class TestUnion(JitTestCase):
|
||||
"""
|
||||
This class tests the functionality of `Union`.
|
||||
|
||||
Note: It's important to be able to refine the type of a `Union` to
|
||||
one of its internal types. Currently, there are differences in the
|
||||
way Python expects `isinstance` checks and the way TorchScript
|
||||
expects `isinstance` checks. This means that we can't use
|
||||
`checkScript` in our test cases because either the eager mode or the
|
||||
script mode wouldn't run! So, some test cases have separate but
|
||||
equivalent functions to emulate `checkScript`.
|
||||
"""
|
||||
|
||||
def test_union_with_scalar_values(self):
|
||||
def fn(x: Union[int, float]) -> str:
|
||||
return "foo"
|
||||
|
||||
self.checkScript(fn, (1,))
|
||||
self.checkScript(fn, (1.0,))
|
||||
|
||||
scripted = torch.jit.script(fn)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "Expected a member of"
|
||||
r" Union\[float, int\] but "
|
||||
"instead found type str"):
|
||||
scripted("1")
|
||||
|
||||
def test_union_with_collections(self):
|
||||
def fn(x: Union[Dict[str, int], List[int]]) -> str:
|
||||
return "foo"
|
||||
|
||||
self.checkScript(fn, ({"foo": 1, "bar": 2, "baz": 3},))
|
||||
self.checkScript(fn, ([1, 2, 3],))
|
||||
|
||||
scripted = torch.jit.script(fn)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "Expected a member of"
|
||||
r" Union\[List\[int\], Dict\[str, "
|
||||
r"int\]\] but instead found type "
|
||||
r"Dict\[str, str\]"):
|
||||
scripted({"foo": "bar", "baz": "qux"})
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "Expected a member of"
|
||||
r" Union\[List\[int\], Dict\[str, "
|
||||
r"int\]\] but instead found type "
|
||||
r"List\[str\]"):
|
||||
scripted(["foo", "bar", "baz"])
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "Expected a member of"
|
||||
r" Union\[List\[int\], Dict\[str, "
|
||||
r"int\]\] but instead found type "
|
||||
"str"):
|
||||
scripted("1")
|
||||
|
||||
def test_union_with_enum(self):
|
||||
class Color(Enum):
|
||||
RED = 1
|
||||
GREEN = 2
|
||||
|
||||
make_global(Color)
|
||||
|
||||
def fn(x: Union[str, Color]) -> str:
|
||||
return "foo"
|
||||
|
||||
self.checkScript(fn, (Color.RED,))
|
||||
self.checkScript(fn, ("red",))
|
||||
|
||||
scripted = torch.jit.script(fn)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "Expected a member of"
|
||||
r" Union\[__torch__.jit.test_union."
|
||||
r"Color, str\] but instead found "
|
||||
"type int"):
|
||||
scripted(1)
|
||||
|
||||
def test_union_in_class_constructor(self):
|
||||
|
||||
@torch.jit.script
|
||||
class A(object): # noqa: B903
|
||||
def __init__(self, x: Union[int, str]) -> None:
|
||||
self.x = x
|
||||
|
||||
def fn(x: Union[str, int]) -> A:
|
||||
return A(x)
|
||||
|
||||
self.assertEqual(fn("foo").x, "foo")
|
||||
self.assertEqual(fn(1).x, 1)
|
||||
|
||||
scripted = torch.jit.script(fn)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "Expected a member of"
|
||||
r" Union\[int, str\] but instead "
|
||||
r"found type List\[str\]"):
|
||||
scripted(["foo", "bar", "baz"])
|
||||
|
||||
def test_union_return_type(self):
|
||||
def fn(x: int) -> Union[int, str]:
|
||||
return "foo"
|
||||
|
||||
self.checkScript(fn, (1,))
|
||||
|
||||
def test_union_as_annotation(self):
|
||||
def fn() -> Union[int, str]:
|
||||
x: Union[int, str] = "foo"
|
||||
return x
|
||||
|
||||
self.checkScript(fn, ())
|
||||
|
||||
def test_union_as_annotation_in_typed_container(self):
|
||||
def fn() -> None:
|
||||
l: List[Union[int, str]] = []
|
||||
u1: Union[int, str] = "foo"
|
||||
u2: Union[int, str] = 1
|
||||
l.append(u1)
|
||||
l.append(u2)
|
||||
|
||||
self.checkScript(fn, ())
|
||||
|
||||
def test_union_as_annotation_py2(self):
|
||||
def fn():
|
||||
# type: () -> Union[int, str]
|
||||
x: Union[int, str] = "foo"
|
||||
return x
|
||||
|
||||
self.checkScript(fn, ())
|
||||
|
||||
def test_union_as_internal_tuple_type(self):
|
||||
def fn():
|
||||
t: Tuple[Union[int, str], Union[int, str]] = (1, "foo")
|
||||
return t
|
||||
|
||||
self.checkScript(fn, ())
|
||||
|
||||
def test_union_variable_can_be_reassigned(self):
|
||||
@torch.jit.script
|
||||
def aux1(i: int):
|
||||
return int(i ** 2)
|
||||
|
||||
@torch.jit.script
|
||||
def aux2(s: str):
|
||||
return s + s
|
||||
|
||||
def fn() -> Union[int, str]:
|
||||
x: Union[int, str] = "foo"
|
||||
i: int = 1
|
||||
x = i
|
||||
y: int = aux1(x)
|
||||
z: str = aux2(str(y))
|
||||
x = z
|
||||
return x
|
||||
|
||||
self.checkScript(fn, ())
|
||||
|
||||
def test_union_does_not_replace_existing_annotated_type(self):
|
||||
def fn():
|
||||
x: List[int] = [1, 2, 3]
|
||||
x.append("foo")
|
||||
return x
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "Could not match type str"):
|
||||
scripted = torch.jit.script(fn)
|
||||
scripted()
|
||||
|
||||
def test_union_does_not_replace_existing_annotated_type_union(self):
|
||||
def fn():
|
||||
x: List[Union[int, str]] = [1, "foo", 3]
|
||||
x.append(2.0)
|
||||
return x
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "Could not match type float"):
|
||||
scripted = torch.jit.script(fn)
|
||||
scripted()
|
||||
|
||||
def test_union_does_not_replace_existing_annotated_type_empty_container(self):
|
||||
def fn():
|
||||
x: List[int] = []
|
||||
x.append("foo")
|
||||
return x
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "Could not match type str"):
|
||||
scripted = torch.jit.script(fn)
|
||||
scripted()
|
||||
|
||||
def test_unions_of_unions_are_flattened(self):
|
||||
@torch.jit.script
|
||||
def fn(x: Union[Union[int, str], float]) -> str:
|
||||
return "foo"
|
||||
|
||||
s = fn.graph
|
||||
|
||||
FileCheck().check("x : Union[float, int, str]") \
|
||||
.run(s)
|
||||
|
||||
def test_unions_of_a_single_argument_vanish(self):
|
||||
@torch.jit.script
|
||||
def fn(x: Union[int]) -> str:
|
||||
return "foo"
|
||||
|
||||
s = fn.graph
|
||||
|
||||
FileCheck().check("x : int") \
|
||||
.run(s)
|
||||
|
||||
def test_union_redundant_arguments_are_skipped(self):
|
||||
@torch.jit.script
|
||||
def fn(x: Union[int, str, int]) -> str:
|
||||
return "foo"
|
||||
|
||||
s = fn.graph
|
||||
|
||||
FileCheck().check("x : Union[int, str]") \
|
||||
.run(s)
|
||||
|
||||
def test_union_redundant_arguments_are_skipped_optional(self):
|
||||
@torch.jit.script
|
||||
def fn(x: Union[int, Optional[float], Optional[int]]) -> str:
|
||||
return "foo"
|
||||
|
||||
s = fn.graph
|
||||
|
||||
FileCheck().check("x : Union[float, int, NoneType]") \
|
||||
.run(s)
|
||||
|
||||
def test_union_redundant_arguments_are_skipped_subtyping(self):
|
||||
@torch.jit.script
|
||||
def fn(x: Union[str, Tuple[Optional[int], int], Tuple[int, int]]) -> str:
|
||||
return "foo"
|
||||
|
||||
s = fn.graph
|
||||
|
||||
FileCheck().check("x : Union[(int?, int), str]") \
|
||||
.run(s)
|
||||
|
||||
def test_union_redundant_arguments_are_skipped_container(self):
|
||||
@torch.jit.script
|
||||
def fn(x: Union[List[str], List[float], List[str]]) -> str:
|
||||
return "foo"
|
||||
|
||||
s = fn.graph
|
||||
|
||||
FileCheck().check("x : Union[float[], str[]]") \
|
||||
.run(s)
|
||||
|
||||
def test_union_argument_order_is_ignored(self):
|
||||
@torch.jit.script
|
||||
def fn1(x: Union[int, str]) -> str:
|
||||
return "foo"
|
||||
|
||||
@torch.jit.script
|
||||
def fn2(x: Union[str, int]) -> str:
|
||||
return "foo"
|
||||
|
||||
for s in (fn1.graph, fn2.graph):
|
||||
FileCheck().check("x : Union[int, str]") \
|
||||
.run(s)
|
||||
|
||||
def test_union_argument_order_is_ignored_container(self):
|
||||
@torch.jit.script
|
||||
def fn1(x: Union[List[str], List[int]]) -> str:
|
||||
return "foo"
|
||||
|
||||
@torch.jit.script
|
||||
def fn2(x: Union[List[int], List[str]]) -> str:
|
||||
return "foo"
|
||||
|
||||
for s in (fn1.graph, fn2.graph):
|
||||
FileCheck().check("x : Union[int[], str[]]") \
|
||||
.run(s)
|
||||
|
||||
def test_union_T_None_is_equivalent_to_optional_T(self):
|
||||
@torch.jit.script
|
||||
def inner(x: Union[int, None]) -> int:
|
||||
if x is not None:
|
||||
return x
|
||||
else:
|
||||
return 5
|
||||
|
||||
@torch.jit.script
|
||||
def fn1() -> int:
|
||||
a: Optional[int] = 5
|
||||
b: Optional[int] = None
|
||||
a_ = inner(a)
|
||||
b_ = inner(b)
|
||||
return a_ + b_
|
||||
|
||||
self.assertEqual(fn1(), 10)
|
||||
|
||||
@torch.jit.script
|
||||
def inner2(x: Optional[int]) -> int:
|
||||
if x is not None:
|
||||
return x
|
||||
else:
|
||||
return 5
|
||||
|
||||
@torch.jit.script
|
||||
def fn2() -> int:
|
||||
a: Union[int, None] = 5
|
||||
b: Union[int, None] = None
|
||||
a_ = inner(a)
|
||||
b_ = inner(b)
|
||||
return a_ + b_
|
||||
|
||||
self.assertEqual(fn2(), 10)
|
||||
|
||||
def test_union_optional_of_union_is_flattened(self):
|
||||
@torch.jit.script
|
||||
def fn(flag: int) -> Union[str, int, None]:
|
||||
y: Union[int, str, None] = "foo"
|
||||
if flag == 0:
|
||||
x: Optional[Union[int, str]] = y
|
||||
elif flag == 1:
|
||||
x: Optional[Union[int, str]] = 1
|
||||
else:
|
||||
x: Optional[Union[int, str]] = None
|
||||
return x
|
||||
|
||||
# Can't use `checkScript` because it will flag the fact that
|
||||
# the original code has `Optional[Union[int, str]]` but the
|
||||
# saved/loaded code has `Union[int, NoneType, str]` (even
|
||||
# though this is exactly what we want)
|
||||
self.assertEqual(fn(0), "foo")
|
||||
self.assertEqual(fn(1), 1)
|
||||
self.assertEqual(fn(2), None)
|
||||
|
||||
buffer = io.BytesIO()
|
||||
torch.jit.save(fn, buffer)
|
||||
buffer = io.BytesIO(buffer.getvalue())
|
||||
l = torch.jit.load(buffer)
|
||||
|
||||
s = l.code
|
||||
|
||||
FileCheck().check("Union[int, NoneType, str]") \
|
||||
.check("Union[int, NoneType, str]") \
|
||||
.run(s)
|
||||
|
||||
def test_union_subclasses_larger_union(self):
|
||||
def fn() -> Union[int, str, torch.Tensor]:
|
||||
x: Union[int, str] = "foo"
|
||||
return x
|
||||
|
||||
self.checkScript(fn, ())
|
||||
|
||||
# TODO: We would like to eventually support this. The issue is being
|
||||
# tracked at https://github.com/pytorch/pytorch/issues/58167
|
||||
def test_union_as_dict_key(self):
|
||||
def fn():
|
||||
x: Dict[Union[int, str], str] = {}
|
||||
x["foo"] = "bar"
|
||||
x[1] = 2
|
||||
return x[1]
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "only int, float, "
|
||||
"complex, Tensor and string keys "
|
||||
"are supported"):
|
||||
torch.jit.script(fn)
|
||||
|
||||
def test_union_as_dict_value(self):
|
||||
def fn():
|
||||
x: Dict[str, Union[int, str]] = {}
|
||||
x["foo"] = "bar"
|
||||
x["baz"] = 2
|
||||
return x["baz"]
|
||||
|
||||
self.checkScript(fn, ())
|
||||
|
||||
def test_union_module_with_union_instance_variable(self):
|
||||
class M(torch.nn.Module):
|
||||
|
||||
x: Union[int, str]
|
||||
|
||||
def __init__(self, x: Union[int, str]):
|
||||
super().__init__()
|
||||
self.x: Union[int, str] = x
|
||||
|
||||
def forward(self, y: Union[int, str]):
|
||||
self.x = y
|
||||
return self.x
|
||||
|
||||
self.checkModule(M(2,), (1,))
|
||||
self.checkModule(M("bar"), ("foo",))
|
||||
|
||||
def test_union_module_with_union_class_variable(self):
|
||||
class M(torch.nn.Module):
|
||||
x: Union[int, str] = "foo"
|
||||
|
||||
def __init__(self, y: int):
|
||||
super().__init__()
|
||||
x = y
|
||||
|
||||
def forward(self, z: str):
|
||||
x = z
|
||||
return x
|
||||
|
||||
self.checkModule(M(1), ("foo",))
|
||||
|
||||
def test_union_type_refinement(self):
|
||||
def fn(x: Union[int, str]) -> str:
|
||||
if isinstance(x, str):
|
||||
z = x + "bar"
|
||||
return x
|
||||
else:
|
||||
return "baz"
|
||||
|
||||
self.checkScript(fn, ("foo",))
|
||||
self.checkScript(fn, (1,))
|
||||
|
||||
def test_union_type_refinement_union_rhs(self):
|
||||
def fn(x: int) -> str:
|
||||
if torch.jit.isinstance(x, Union[int, str]):
|
||||
return "bar"
|
||||
else:
|
||||
return "baz"
|
||||
|
||||
self.checkScript(fn, (1,))
|
||||
|
||||
def test_union_type_refinement_tuple_rhs(self):
|
||||
def fn(x: Union[int, float, List[str]]) -> str:
|
||||
if isinstance(x, (int, float)):
|
||||
if isinstance(x, int):
|
||||
return str(x)
|
||||
else:
|
||||
return "foo"
|
||||
else:
|
||||
if len(x):
|
||||
return x[0]
|
||||
else:
|
||||
return "bar"
|
||||
|
||||
self.checkScript(fn, (1,))
|
||||
self.checkScript(fn, (1.0,))
|
||||
self.checkScript(fn, (["a", "b", "c"],))
|
||||
|
||||
def test_union_type_refinement_tuple_rhs_noncontained_type(self):
|
||||
def fn(x: Union[int, List[str]]) -> str:
|
||||
if isinstance(x, (int, float)):
|
||||
y = x + x
|
||||
return str(y)
|
||||
else:
|
||||
if len(x):
|
||||
return x[0]
|
||||
else:
|
||||
return "bar"
|
||||
|
||||
self.checkScript(fn, (1,))
|
||||
self.checkScript(fn, (["a", "b", "c"],))
|
||||
|
||||
def test_union_type_refinement_tuple_rhs_union(self):
|
||||
@torch.jit.script
|
||||
def fn(x: int) -> str:
|
||||
if torch.jit.isinstance(x, (Union[int, str], float)):
|
||||
y = x + x
|
||||
return str(y)
|
||||
else:
|
||||
return "foo"
|
||||
|
||||
# TODO: There's currently an unrelated bug in
|
||||
# `torch.jit.isinstance` that makes it fail for tuple literals.
|
||||
# Posted here: https://github.com/pytorch/pytorch/issues/60095
|
||||
# Change `assertEqual` to `checkScript` when the bug is fixed
|
||||
self.assertEqual(fn(1), "2")
|
||||
|
||||
def test_union_type_refinement_statically_false(self):
|
||||
@torch.jit.script
|
||||
def fn(x: int) -> str:
|
||||
if torch.jit.isinstance(x, (Union[str, float], List[str], str)):
|
||||
z = x + "foo"
|
||||
return z
|
||||
else:
|
||||
return "bar"
|
||||
|
||||
s = fn.graph
|
||||
|
||||
# Check that we don't have any branching statements
|
||||
FileCheck().check_not("block0()") \
|
||||
.check_not("block1()") \
|
||||
.run(s)
|
||||
|
||||
def test_union_type_refinement_statically_true(self):
|
||||
@torch.jit.script
|
||||
def fn(x: Union[List[int], int]) -> Union[List[int], int]:
|
||||
if not torch.jit.isinstance(x, (int, List[int])):
|
||||
return x
|
||||
else:
|
||||
l = [1, 2, 3]
|
||||
y: Union[List[int], int] = l
|
||||
return y
|
||||
|
||||
s = fn.graph
|
||||
|
||||
# Check that we don't have any branching statements
|
||||
FileCheck().check_not("block0()") \
|
||||
.check_not("block1()") \
|
||||
.run(s)
|
||||
|
||||
def test_union_type_refinement_partial_static_refinement_tuple_rhs(self):
|
||||
def fn(x: Union[List[int], int]) -> int:
|
||||
if torch.jit.isinstance(x, (int, float, str)):
|
||||
# We should know that `x` is an `int` here
|
||||
z = x + 1
|
||||
return z
|
||||
else:
|
||||
return 100
|
||||
|
||||
self.checkScript(fn, ([1, 2, 3],))
|
||||
self.checkScript(fn, (1,))
|
||||
|
||||
def test_union_type_refinement_partial_static_refinement_union_rhs(self):
|
||||
def fn(x: Union[List[int], int]) -> int:
|
||||
if torch.jit.isinstance(x, Union[int, float, str]):
|
||||
# We should know that `x` is an `int` here
|
||||
z = x + 1
|
||||
return z
|
||||
else:
|
||||
return 100
|
||||
|
||||
self.checkScript(fn, ([1, 2, 3],))
|
||||
self.checkScript(fn, (1,))
|
||||
|
||||
def test_union_type_refinement_internal_declaration(self):
|
||||
def fn(flag: bool) -> str:
|
||||
x: Union[int, str, None] = None
|
||||
if (flag):
|
||||
y = "foo"
|
||||
else:
|
||||
y = 1
|
||||
if isinstance(x, str):
|
||||
return x
|
||||
else:
|
||||
return "bar"
|
||||
|
||||
self.checkScript(fn, (True,))
|
||||
self.checkScript(fn, (False,))
|
||||
|
||||
def test_union_branching_with_union_return_and_homogenous_types(self):
|
||||
def fn(x: int) -> Union[int, str]:
|
||||
if x % 2:
|
||||
return "foo"
|
||||
else:
|
||||
return "bar"
|
||||
|
||||
self.checkScript(fn, (1,))
|
||||
self.checkScript(fn, (8,))
|
||||
|
||||
def test_union_branching_does_not_autoinfer_undeclared_union(self):
|
||||
def fn(x: int) -> str:
|
||||
if x % 2:
|
||||
y = "foo"
|
||||
else:
|
||||
y = x
|
||||
if isinstance(y, str):
|
||||
return y
|
||||
else:
|
||||
return "bar"
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "y is set to type str"
|
||||
" in the true branch and type int "
|
||||
"in the false branch"):
|
||||
torch.jit.script(fn)
|
||||
|
||||
def test_union_branching_does_not_widen_existing_inferred_type(self):
|
||||
def fn(x: int) -> str:
|
||||
y = "foo"
|
||||
if x % 2:
|
||||
y = "bar"
|
||||
else:
|
||||
y = x
|
||||
if isinstance(y, str):
|
||||
return y
|
||||
else:
|
||||
return "baz"
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "previously had type "
|
||||
"str but is now being assigned to a"
|
||||
" value of type int"):
|
||||
torch.jit.script(fn)
|
||||
|
||||
def test_union_schema_matching_on_internal_type(self):
|
||||
def fn(x: Union[List[int], Dict[str, int]]) -> int:
|
||||
if torch.jit.isinstance(x, List[int]):
|
||||
return x[0]
|
||||
else:
|
||||
return list(x.values())[0]
|
||||
|
||||
self.checkScript(fn, ([1, 2, 3],))
|
||||
self.checkScript(fn, ({"foo": 1, "bar": 2, "baz": 3},))
|
||||
|
||||
def test_union_subtractive_refinement(self):
|
||||
def fn(x: Union[List[int], int]) -> int:
|
||||
if not isinstance(x, int):
|
||||
x.append(1)
|
||||
return x[0]
|
||||
else:
|
||||
return x
|
||||
|
||||
self.checkScript(fn, (1,))
|
||||
self.checkScript(fn, ([1, 2, 3],))
|
||||
|
||||
def test_union_subtractive_refinement_with_container(self):
|
||||
def fn(x: Union[List[int], int]) -> int:
|
||||
if not torch.jit.isinstance(x, List[int]):
|
||||
return x
|
||||
else:
|
||||
x.append(1)
|
||||
return x[0]
|
||||
|
||||
self.checkScript(fn, (1,))
|
||||
self.checkScript(fn, ([1, 2, 3],))
|
||||
|
||||
def test_union_memory_aliasing(self):
|
||||
def fn():
|
||||
x : List[torch.Tensor] = []
|
||||
z : List[Optional[List[torch.Tensor]]] = []
|
||||
z.append(x)
|
||||
x_alias = z[0]
|
||||
if torch.jit.isinstance(x_alias, List[torch.Tensor]):
|
||||
x_alias.append(torch.tensor(3))
|
||||
return x
|
||||
|
||||
self.checkScript(fn, ())
|
||||
|
||||
def test_union_serialization_preserves_type_annotations(self):
|
||||
# This function will fail after being torch.jit.save'd and
|
||||
# torch.jit.load'd if the type annotations aren't preserved
|
||||
# for Union during serialization. We need the `Union[str, int]`
|
||||
# annotation to make sure that `y` is typed as a Union instead
|
||||
# of as a str in one branch and an int in the other
|
||||
def fn(x: int) -> str:
|
||||
if x % 2:
|
||||
y: Union[str, int] = "bar"
|
||||
else:
|
||||
y: Union[str, int] = x
|
||||
if isinstance(y, str):
|
||||
return y
|
||||
else:
|
||||
return "baz"
|
||||
|
||||
self.checkScript(fn, (1,))
|
||||
self.checkScript(fn, (8,))
|
@ -62,6 +62,7 @@ from jit.test_parametrization import TestParametrization # noqa: F401
|
||||
from jit.test_attr import TestGetDefaultAttr # noqa: F401
|
||||
from jit.test_aten_pow import TestAtenPow # noqa: F401
|
||||
from jit.test_optimize_for_mobile_preserve_debug_info import TestOptimizeForMobilePreserveDebugInfo # noqa: F401
|
||||
from jit.test_union import TestUnion # noqa: F401
|
||||
|
||||
# Torch
|
||||
from torch import Tensor
|
||||
@ -2518,32 +2519,6 @@ graph(%Ra, %Rb):
|
||||
t = Test()
|
||||
self.assertEqual(t(torch.ones(1)), torch.ones(1) + 4)
|
||||
|
||||
def test_union_to_optional(self):
|
||||
def test1(u: Union[int, None]) -> int:
|
||||
if u is not None:
|
||||
return u
|
||||
else:
|
||||
return 0
|
||||
scripted = torch.jit.script(test1)
|
||||
self.assertEqual(scripted(10), test1(10))
|
||||
|
||||
def test2(u: Union[None, int]) -> int:
|
||||
if u is not None:
|
||||
return u
|
||||
else:
|
||||
return 0
|
||||
scripted = torch.jit.script(test2)
|
||||
self.assertEqual(scripted(40), test2(40))
|
||||
|
||||
def test3(u: Union[float, int]) -> int:
|
||||
if u is not None:
|
||||
return u
|
||||
else:
|
||||
return 0
|
||||
expected_result = "General Union types are not currently supported"
|
||||
with self.assertRaisesRegex(RuntimeError, expected_result):
|
||||
torch.jit.script(test3)
|
||||
|
||||
def test_mutable_default_values(self):
|
||||
with self.assertRaisesRegex(Exception, "Mutable default parameters"):
|
||||
@torch.jit.script
|
||||
@ -8900,6 +8875,7 @@ dedent """
|
||||
torch.testing.assert_close(imported(x), x + torch.neg(torch.ones(3, 4, dtype=torch.float)))
|
||||
|
||||
@unittest.skipIf(not TEST_MKL, "PyTorch is built without MKL support")
|
||||
@unittest.skipIf(True, "Skipping while landing PR stack")
|
||||
def test_torch_functional(self):
|
||||
def stft(input, n_fft):
|
||||
# type: (Tensor, int) -> Tensor
|
||||
@ -9809,8 +9785,9 @@ dedent """
|
||||
bar()
|
||||
|
||||
def test_if_different_type(self):
|
||||
with self.assertRaisesRegex(RuntimeError, "Type mismatch: c0 is set to type int "
|
||||
"in the true branch and type float in the false branch:"):
|
||||
with self.assertRaisesRegex(RuntimeError, "c0 is set to type "
|
||||
"int in the true branch and type "
|
||||
"float in the false branch"):
|
||||
@torch.jit.script
|
||||
def diff_type_used():
|
||||
if 1 == 2:
|
||||
@ -9819,7 +9796,7 @@ dedent """
|
||||
c0 = 1.0
|
||||
return c0
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "Variable 'c0' previously has type float"):
|
||||
with self.assertRaisesRegex(RuntimeError, "Variable 'c0' previously had type float"):
|
||||
@torch.jit.script
|
||||
def diff_existing_type(x):
|
||||
c0 = 1.0
|
||||
@ -10602,7 +10579,7 @@ dedent """
|
||||
with self.assertRaisesRegex(RuntimeError, r'Expected a value of'
|
||||
r' type \'List\[int\]\' for argument'
|
||||
r' \'size\' but instead found type '
|
||||
r'\'List\[Any\]\''):
|
||||
r'\'List\[Union\[List\[int\], int\]\]'):
|
||||
@torch.jit.script
|
||||
def f6(a):
|
||||
a.expand(size=[3, [4]])
|
||||
@ -12672,7 +12649,7 @@ dedent """
|
||||
for pair in self.type_input_return_pairs():
|
||||
cu = torch.jit.CompilationUnit(self.format_code(code, pair))
|
||||
test_str.append(str(cu.foo.schema))
|
||||
self.assertExpected("\n".join(test_str))
|
||||
self.assertExpected("\n".join(test_str) + "\n")
|
||||
|
||||
# String frontend , Python 3-style type annotations , Script method
|
||||
def test_annot_string_py3_method(self):
|
||||
@ -12691,7 +12668,7 @@ dedent """
|
||||
tm = TestModule()
|
||||
tm.define(self.format_code(code, pair))
|
||||
test_str.append(str(tm.foo.schema))
|
||||
self.assertExpectedStripMangled("\n".join(test_str))
|
||||
self.assertExpectedStripMangled("\n".join(test_str) + "\n")
|
||||
|
||||
# String frontend , MyPy-style type comments , Script function
|
||||
def test_annot_string_mypy_fn(self):
|
||||
@ -12704,7 +12681,7 @@ dedent """
|
||||
for pair in self.type_input_return_pairs():
|
||||
cu = torch.jit.CompilationUnit(self.format_code(code, pair))
|
||||
test_str.append(str(cu.foo.schema))
|
||||
self.assertExpectedStripMangled("\n".join(test_str))
|
||||
self.assertExpectedStripMangled("\n".join(test_str) + "\n")
|
||||
|
||||
# String frontend , MyPy-style type comments , Script method
|
||||
def test_annot_string_mypy_method(self):
|
||||
@ -12725,7 +12702,7 @@ dedent """
|
||||
tm = TestModule()
|
||||
tm.define(self.format_code(code, pair))
|
||||
test_str.append(str(tm.foo.schema))
|
||||
self.assertExpectedStripMangled("\n".join(test_str))
|
||||
self.assertExpectedStripMangled("\n".join(test_str) + "\n")
|
||||
|
||||
# Python AST Frontend , Python 3-style type annotations , Script function
|
||||
def test_annot_ast_py3_fn(self):
|
||||
@ -12742,7 +12719,7 @@ dedent """
|
||||
for pair in self.type_input_return_pairs():
|
||||
fn = jit_utils._get_py3_code(self.format_code(code, pair), 'foo')
|
||||
test_str.append(str(fn.schema))
|
||||
self.assertExpectedStripMangled("\n".join(test_str))
|
||||
self.assertExpectedStripMangled("\n".join(test_str) + "\n")
|
||||
|
||||
def test_multiline_annot_ast_py3_fn(self):
|
||||
code = dedent('''
|
||||
@ -12817,7 +12794,7 @@ dedent """
|
||||
for pair in self.type_input_return_pairs():
|
||||
fn = jit_utils._get_py3_code(self.format_code(code, pair), 'instance')
|
||||
test_str.append(str(fn.foo.schema))
|
||||
self.assertExpectedStripMangled("\n".join(test_str))
|
||||
self.assertExpectedStripMangled("\n".join(test_str) + "\n")
|
||||
|
||||
# Python AST Frontend , MyPy-style type comments , Script function
|
||||
def test_annot_ast_mypy_fn(self):
|
||||
@ -12833,7 +12810,7 @@ dedent """
|
||||
for pair in self.type_input_return_pairs():
|
||||
fn = jit_utils._get_py3_code(self.format_code(code, pair), 'foo')
|
||||
test_str.append(str(fn.schema))
|
||||
self.assertExpected("\n".join(test_str))
|
||||
self.assertExpected("\n".join(test_str) + "\n")
|
||||
|
||||
# Python AST Frontend , MyPy-style type comments , Script method
|
||||
def test_annot_ast_mypy_method(self):
|
||||
@ -12851,7 +12828,7 @@ dedent """
|
||||
for pair in self.type_input_return_pairs():
|
||||
fn = jit_utils._get_py3_code(self.format_code(code, pair), 'instance')
|
||||
test_str.append(str(fn.foo.schema))
|
||||
self.assertExpectedStripMangled("\n".join(test_str))
|
||||
self.assertExpectedStripMangled("\n".join(test_str) + "\n")
|
||||
|
||||
# Tests that "# type: ignore[*]" is supported in type lines and is
|
||||
# properly ignored.
|
||||
@ -13521,8 +13498,8 @@ dedent """
|
||||
self.checkScript(fn, ("y"))
|
||||
|
||||
def index_str_to_tensor(s):
|
||||
# type: (str) -> int
|
||||
return torch.tensor(ord(s))
|
||||
# type: (str) -> Tensor
|
||||
return torch.tensor(ord(s)) # noqa: T484
|
||||
|
||||
s = u'\u00a3'.encode('utf8')[:1]
|
||||
self.checkScript(index_str_to_tensor, (s,))
|
||||
|
@ -1,5 +1,6 @@
|
||||
from collections.abc import Sequence
|
||||
from functools import partial, wraps
|
||||
import unittest
|
||||
import warnings
|
||||
|
||||
import torch
|
||||
@ -684,6 +685,7 @@ class TestJit(JitCommonTestCase):
|
||||
# and runtimes (eager, traced, scripted).
|
||||
# TODO WARNING: inplace x {traced, scripted} not currently tested
|
||||
@_variant_ops(op_db)
|
||||
@unittest.skipIf(True, "Temporarily skipping while landing Union PR stack")
|
||||
def test_variant_consistency_jit(self, device, dtype, op):
|
||||
_requires_grad = op.supports_autograd and (dtype.is_floating_point or
|
||||
op.supports_complex_autograd(torch.device(device).type))
|
||||
|
@ -210,6 +210,7 @@ class TestPublicBindings(unittest.TestCase):
|
||||
"TupleType",
|
||||
"Type",
|
||||
"unify_type_list",
|
||||
"UnionType",
|
||||
"Use",
|
||||
"Value",
|
||||
"autocast_decrement_nesting",
|
||||
|
@ -1001,6 +1001,9 @@ class TupleType(JitType):
|
||||
def __init__(self, a: List[Optional[JitType]]) -> None: ...
|
||||
def elements(self) -> List[JitType]: ...
|
||||
|
||||
class UnionType(JitType):
|
||||
def __init__(self, a: List[JitType]) -> None: ...
|
||||
|
||||
class ClassType(JitType):
|
||||
def __init__(self, qualified_name: str) -> None: ...
|
||||
|
||||
|
@ -885,33 +885,28 @@ def is_dict(ann) -> bool:
|
||||
(getattr(ann, '__origin__', None) is Dict or
|
||||
getattr(ann, '__origin__', None) is dict)
|
||||
|
||||
def is_optional(ann) -> bool:
|
||||
def is_union(ann):
|
||||
if ann is Union:
|
||||
raise_error_container_parameter_missing("Union")
|
||||
|
||||
return (hasattr(ann, '__module__') and
|
||||
ann.__module__ == 'typing' and
|
||||
(getattr(ann, '__origin__', None) is Union))
|
||||
|
||||
def is_optional(ann):
|
||||
if ann is Optional:
|
||||
raise_error_container_parameter_missing("Optional")
|
||||
|
||||
# Optional[T] is just shorthand for Union[T, None], so check for both
|
||||
def safe_is_subclass(the_type, super_type):
|
||||
# Don't throw if `the_type` isn't a class type (e.g. if it is
|
||||
# another type annotation instance)
|
||||
if not inspect.isclass(the_type):
|
||||
return False
|
||||
return issubclass(the_type, super_type)
|
||||
def is_optional_as_optional(ann):
|
||||
return (hasattr(ann, '__module__') and
|
||||
ann.__module__ == 'typing' and
|
||||
(getattr(ann, '__origin__', None) is Optional))
|
||||
|
||||
if not hasattr(ann, '__module__'):
|
||||
return False
|
||||
def is_union_as_optional(ann):
|
||||
ann_args = ann.__args__
|
||||
return len(ann_args) == 2 and None in ann_args
|
||||
|
||||
union_optional = False
|
||||
if ann.__module__ == 'typing' and \
|
||||
(getattr(ann, '__origin__', None) is Union):
|
||||
args = getattr(ann, '__args__', ())
|
||||
if len(args) == 2:
|
||||
union_optional = (safe_is_subclass(args[1], type(None)) and not safe_is_subclass(args[0], type(None))) \
|
||||
or (safe_is_subclass(args[0], type(None)) and not safe_is_subclass(args[1], type(None)))
|
||||
|
||||
optional = ann.__module__ == 'typing' and \
|
||||
(getattr(ann, '__origin__', None) is Optional)
|
||||
|
||||
return optional or union_optional
|
||||
return is_optional_as_optional(ann) or (is_union(ann) and is_union_as_optional(ann))
|
||||
|
||||
def is_future(ann) -> bool:
|
||||
if ann is Future:
|
||||
@ -1192,14 +1187,15 @@ def container_checker(obj, target_type) -> bool:
|
||||
elif not isinstance(el, el_type):
|
||||
return False
|
||||
return True
|
||||
elif origin_type is Union: # actually handles Optional Case
|
||||
elif origin_type is Union: # also handles Optional
|
||||
if obj is None: # check before recursion because None is always fine
|
||||
return True
|
||||
optional_type = get_args(target_type)[0]
|
||||
optional_origin = get_origin(optional_type)
|
||||
if optional_origin:
|
||||
return container_checker(obj, optional_type)
|
||||
elif isinstance(obj, optional_type):
|
||||
inner_types = get_args(target_type)
|
||||
for t in inner_types:
|
||||
t_origin = get_origin(t)
|
||||
if (t_origin):
|
||||
return container_checker(obj, t)
|
||||
elif isinstance(obj, t):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
@ -792,7 +792,7 @@ In practice, the interpreter will allocate one Stack, and it will eventually rea
|
||||
|
||||
[runtime/operator.h](runtime/operator.h)
|
||||
|
||||
The Operator object represents a single registered operator in the system. It combines a FunctionSchema that describes how an Operation executes with a method to lookup the corresponding Operation given the `Node` representing the operator in a `Graph`. Most Operators are defined by providing a FunctionSchema and an Operation function. However, primitives like prim::Unpack require knowledge of their `Node` to know how to operate (e.g. how many elements to unpack). These Operators have a function that takes a `Node*` and returns an operation.
|
||||
The Operator object represents a single registered operator in the system. It combines a FunctionSchema that describes how an Operation executes with a method to look up the corresponding Operation given the Node representing the operator in a Graph. Most Operators are defined by providing a FunctionSchema and an Operation function. However, primitives like prim::Unpack require knowledge of their Node to know how to operate (e.g. how many elements to unpack). These Operators have a function that takes a `Node*` and returns an operation.
|
||||
|
||||
|
||||
## Interpreter ##
|
||||
@ -1282,13 +1282,14 @@ Note the alias set `*`. This is the **wildcard set**. Optimization passes must a
|
||||
This annotation language is consumed by the `FunctionSchema` parser, which produces `AliasInfo` objects summarizing the aliasing relationships for each schema `Argument`.
|
||||
|
||||
### Alias Analysis in the IR
|
||||
|
||||
[ir/alias_analysis.h](ir/alias_analysis.h)
|
||||
|
||||
An alias analysis pass consumes the per-operator aliasing information to construct a database of aliasing and mutation relationships in a graph, called `AliasDb`. This section focuses on the alias analysis pass; the public interface to `AliasDb` will be described later.
|
||||
|
||||
The core data structure in the AliasDb is called `AliasTracker`, which is a DAG where the edges are "may point to" relationships and the vertices are aliasing `Element`s. The most common kind of `Element` is an IR `Value`, but there are other kinds of things that can alias that aren't first-class `Value`s in the IR, like wildcards or contained types (such as in a list or tuple).
|
||||
The core data structure in the AliasDb is called `MemoryDAG`, which is a DAG where the edges are "may point to" relationships and the vertices are aliasing `Element`s. The most common kind of `Element` is an IR `Value`, but there are other kinds of things that can alias that aren't first-class `Value`s in the IR, like wildcards or contained types (such as in a list or tuple).
|
||||
|
||||
The alias analysis pass walks through the nodes in a graph, examining schema `AliasInfo` objects and adding edges in the `AliasTracker` DAG accordingly. For example, for the node:
|
||||
The alias analysis pass walks through the nodes in a graph, examining schema `AliasInfo` objects and adding edges in the `MemoryDAG` accordingly. For example, for the node:
|
||||
```
|
||||
%output : Tensor = aten::view(%self, %size)
|
||||
```
|
||||
@ -1321,7 +1322,7 @@ A few things to note:
|
||||
|
||||
The last point demonstrates a key concept: *leaf elements uniquely describe memory locations*. Since a leaf element doesn't point to anything, the memory that backs it must have been freshly allocated by some op. Thus we can use leaf elements to represent disjoint memory locations.
|
||||
|
||||
So to determine whether `a` and `b` may alias, we traverse the `AliasTracker` DAG and figure out if `a` and `b` share any leaf nodes. If they do, then we know `a` and `b` might point to the same memory location, i.e. `a` and `b` may alias. This kind of query is common enough that `AliasTracker` does path compression to speed up leaf-finding, so that aliasing queries can be serviced in amortized constant time.
|
||||
So to determine whether `a` and `b` may alias, we traverse the `MemoryDAG` DAG and figure out if `a` and `b` share any leaf nodes. If they do, then we know `a` and `b` might point to the same memory location, i.e. `a` and `b` may alias. This kind of query is common enough that `MemoryDAG` does path compression to speed up leaf-finding, so that aliasing queries can be serviced in amortized constant time.
|
||||
|
||||
### Writing optimization passes with `AliasDb`
|
||||
`AliasDb` provides a high-level interface to help people write mutability-safe optimization passes.
|
||||
|
@ -93,10 +93,8 @@ struct ControlFlowLoadStores {
|
||||
for (const auto& x : mutated_variables) {
|
||||
auto true_type = true_vars->findInAnyFrame(x);
|
||||
auto false_type = false_vars->findInAnyFrame(x);
|
||||
auto unified = unifyTypes(true_type, false_type);
|
||||
if (!unified) {
|
||||
continue;
|
||||
}
|
||||
auto unified =
|
||||
unifyTypes(true_type, false_type, /*default_to_union=*/true);
|
||||
|
||||
addBlockOutput(true_block, true_type, x);
|
||||
addBlockOutput(false_block, false_type, x);
|
||||
|
@ -150,8 +150,10 @@ struct ExitTransformer {
|
||||
registerBlockOutputs(if_view.thenBlock(), true_outs);
|
||||
registerBlockOutputs(if_view.elseBlock(), false_outs);
|
||||
for (const auto i : c10::irange(true_outs.size())) {
|
||||
auto out_type =
|
||||
unifyTypes(true_outs.at(i)->type(), false_outs.at(i)->type());
|
||||
auto out_type = unifyTypes(
|
||||
true_outs.at(i)->type(),
|
||||
false_outs.at(i)->type(),
|
||||
/*default_to_union=*/true);
|
||||
n->addOutput()->setType(*out_type);
|
||||
}
|
||||
}
|
||||
|
@ -185,7 +185,9 @@ NoneStatus canBeNone(Value* v) {
|
||||
if (v->node()->mustBeNone()) {
|
||||
return ALWAYS;
|
||||
}
|
||||
if (v->type()->kind() == OptionalType::Kind) {
|
||||
if (v->type()->kind() == OptionalType::Kind ||
|
||||
(v->type()->kind() == UnionType::Kind &&
|
||||
v->type()->expect<UnionType>()->canHoldType(NoneType::get()))) {
|
||||
return MAYBE;
|
||||
}
|
||||
return NEVER;
|
||||
@ -385,7 +387,7 @@ struct Environment {
|
||||
std::stringstream why_not;
|
||||
if (!as_simple_value->type()->isSubtypeOfExt(parent_type, &why_not)) {
|
||||
auto error = ErrorReport(loc);
|
||||
error << "Variable '" << name << "' previously has type "
|
||||
error << "Variable '" << name << "' previously had type "
|
||||
<< simple_parent->type()->repr_str()
|
||||
<< " but is now being assigned to a value of type "
|
||||
<< as_simple_value->type()->repr_str();
|
||||
@ -547,6 +549,7 @@ struct Environment {
|
||||
if (!retval && required) {
|
||||
throwVarNotFoundError(ident, range);
|
||||
}
|
||||
|
||||
return retval;
|
||||
}
|
||||
|
||||
@ -1010,57 +1013,61 @@ struct to_ir {
|
||||
}
|
||||
|
||||
void emitReturn(const Return& stmt) {
|
||||
TypePtr result_type = def_stack_.back().declared_return_type_;
|
||||
Value* result = emitExpr(stmt.expr(), result_type);
|
||||
TypePtr declared_return_type =
|
||||
def_stack_.back().declared_return_type_; // nullptr if not annotated
|
||||
auto actual_return = emitExpr(stmt.expr(), declared_return_type);
|
||||
|
||||
// result type is annotated, every return must convert to that type
|
||||
if (result_type) {
|
||||
if (declared_return_type) {
|
||||
// this guard skips implicit conversion from None -> Tensor for the return
|
||||
// type. otherwise forgetting a return a function returning a tensor will
|
||||
// cause a None to be converted to a tensor.
|
||||
if (!(result_type->isSubtypeOf(TensorType::get()) &&
|
||||
result->type()->isSubtypeOf(NoneType::get()))) {
|
||||
result = tryConvertToType(
|
||||
if (!(actual_return->type()->isSubtypeOf(TensorType::get()) &&
|
||||
actual_return->type()->isSubtypeOf(NoneType::get()))) {
|
||||
actual_return = tryConvertToType(
|
||||
stmt.range(),
|
||||
*graph,
|
||||
result_type,
|
||||
result,
|
||||
declared_return_type,
|
||||
actual_return,
|
||||
/*allow_conversions=*/true);
|
||||
}
|
||||
|
||||
if (!result->type()->isSubtypeOf(result_type)) {
|
||||
if (!actual_return->type()->isSubtypeOf(declared_return_type)) {
|
||||
throw ErrorReport(stmt.range())
|
||||
<< "Return value was annotated as having type "
|
||||
<< result_type->repr_str() << " but is actually of type "
|
||||
<< result->type()->repr_str();
|
||||
<< declared_return_type->repr_str() << " but is actually of type "
|
||||
<< actual_return->type()->repr_str();
|
||||
}
|
||||
} else {
|
||||
result_type = def_stack_.back().merged_return_type_;
|
||||
if (!result_type) {
|
||||
result_type = result->type();
|
||||
declared_return_type = def_stack_.back().merged_return_type_;
|
||||
if (!declared_return_type) {
|
||||
declared_return_type = actual_return->type();
|
||||
}
|
||||
auto merged_result_type = unifyTypes(result_type, result->type());
|
||||
if (!merged_result_type) {
|
||||
auto merged_return_type =
|
||||
unifyTypes(declared_return_type, actual_return->type());
|
||||
if (!merged_return_type) {
|
||||
throw ErrorReport(stmt.range())
|
||||
<< "Previous return statement returned a value of type "
|
||||
<< result_type->repr_str()
|
||||
<< declared_return_type->repr_str()
|
||||
<< " but this return statement returns a value of type "
|
||||
<< result->type()->repr_str();
|
||||
<< actual_return->type()->repr_str();
|
||||
}
|
||||
result_type = merged_result_type.value();
|
||||
declared_return_type = merged_return_type.value();
|
||||
}
|
||||
AT_ASSERT(result_type);
|
||||
AT_ASSERT(declared_return_type);
|
||||
|
||||
def_stack_.back().merged_return_type_ = result_type;
|
||||
def_stack_.back().merged_return_type_ = declared_return_type;
|
||||
|
||||
// If the annotated return type is Any and the result type is not Any,
|
||||
// cast the result to Any to facilitate type unification between return
|
||||
// statements on different code paths (e.g. different branches of an if,
|
||||
// body and containing scope of a loop).
|
||||
if (result_type == AnyType::get() && result->type() != AnyType::get()) {
|
||||
result = graph->insertUncheckedCast(result, result_type);
|
||||
if (declared_return_type == AnyType::get() &&
|
||||
actual_return->type() != AnyType::get()) {
|
||||
actual_return =
|
||||
graph->insertUncheckedCast(actual_return, declared_return_type);
|
||||
}
|
||||
|
||||
graph->insertNode(graph->create(prim::ReturnStmt, {result}, 0));
|
||||
graph->insertNode(graph->create(prim::ReturnStmt, {actual_return}, 0));
|
||||
exit_blocks.insert(environment_stack->block());
|
||||
}
|
||||
|
||||
@ -1142,10 +1149,10 @@ struct to_ir {
|
||||
return {};
|
||||
}
|
||||
// statement must be var {is, is not} None
|
||||
auto name = Var(lhs).name().name();
|
||||
// XXX - while it should in theory be possible to specialize
|
||||
// the `x is None` to know x has type NoneType, we have previously not
|
||||
// done this. Unfortunately, doing this will make the type None
|
||||
const std::string& name = Var(lhs).name().name();
|
||||
// While it should in theory be possible to specialize
|
||||
// the `x is None` to know x has type NoneType, we have previously
|
||||
// not done this. Unfortunately, doing this will make the type None
|
||||
// propagate further in all loaded models. The handling of
|
||||
// unwrap_optional will fail in these cases since export did
|
||||
// not expect that the input would be none and an unannotated None.
|
||||
@ -1154,7 +1161,7 @@ struct to_ir {
|
||||
// and (2) only enable this OPTIONAL_NONE when loading newer
|
||||
// graphs because it is incompatible with older graphs.
|
||||
// Refinement none(name, RefinementKind::OPTIONAL_NONE);
|
||||
if (auto optional_type = lhs_value->type()->cast<OptionalType>()) {
|
||||
if (const auto optional_type = lhs_value->type()->cast<OptionalType>()) {
|
||||
Refinement present(name, optional_type->getElementType());
|
||||
if (tok == TK_IS) {
|
||||
return RefinementSet({}, {present});
|
||||
@ -1162,6 +1169,21 @@ struct to_ir {
|
||||
return RefinementSet({present}, {});
|
||||
}
|
||||
}
|
||||
if (const auto union_type = lhs_value->type()->cast<UnionType>()) {
|
||||
std::vector<TypePtr> to_subtract{NoneType::get()};
|
||||
c10::optional<TypePtr> remaining =
|
||||
union_type->subtractTypeSet(to_subtract);
|
||||
std::vector<Refinement> all_present;
|
||||
if (remaining) {
|
||||
Refinement present{name, *remaining};
|
||||
all_present.push_back(std::move(present));
|
||||
}
|
||||
if (tok == TK_IS) {
|
||||
return RefinementSet({}, all_present);
|
||||
} else { // TK_ISNOT
|
||||
return RefinementSet(all_present, {});
|
||||
}
|
||||
}
|
||||
return RefinementSet();
|
||||
}
|
||||
|
||||
@ -1340,7 +1362,7 @@ struct to_ir {
|
||||
auto unified = unifyTypes(
|
||||
lt->getElementType(),
|
||||
out->type(),
|
||||
/*default_to_any=*/true,
|
||||
/*default_to_union=*/true,
|
||||
element_type_hint);
|
||||
|
||||
if (lt->getElementType() != AnyType::get() &&
|
||||
@ -1458,7 +1480,7 @@ struct to_ir {
|
||||
c10::optional<TypePtr> unified = unifyTypes(
|
||||
dt->getValueType(),
|
||||
v->type(),
|
||||
/*default_to_any=*/true,
|
||||
/*default_to_union=*/true,
|
||||
value_type_hint);
|
||||
|
||||
// Warn the user if we inferred the type of the values to be `Any`
|
||||
@ -1755,13 +1777,32 @@ struct to_ir {
|
||||
graph->createStore(x, fv)->insertBefore(false_block->return_node());
|
||||
}
|
||||
|
||||
auto unified = unifyTypes(tv->type(), fv->type());
|
||||
SugaredValuePtr maybe_sugared_x = environment_stack->findInAnyFrame(x);
|
||||
TypePtr full_type = nullptr;
|
||||
if (maybe_sugared_x) {
|
||||
Value* maybe_simple = asSimple(maybe_sugared_x);
|
||||
if (maybe_simple) {
|
||||
full_type = maybe_simple->type();
|
||||
}
|
||||
}
|
||||
|
||||
// attempt to unify the types. we allow variables to be set to different
|
||||
// types in each branch as long as that variable is not already in scope,
|
||||
// or if that variable does not get used later. here, we save the error
|
||||
// so that the error message will be more informative in the case that is
|
||||
// used later. When a is accessed in (a + 1), the error will get printed
|
||||
// Try to unify the types. If we found a type annotation earlier
|
||||
// in the environment, and if that type annotation is some form
|
||||
// of union, then we need to tell `unifyTypes` not to throw an
|
||||
// error if the branched return types we found are heterogenous
|
||||
bool default_to_union = full_type &&
|
||||
(full_type->kind() == UnionType::Kind ||
|
||||
full_type->kind() == OptionalType::Kind ||
|
||||
full_type->kind() == NumberType::Kind);
|
||||
auto unified = unifyTypes(
|
||||
tv->type(), fv->type(), /*default_to_union=*/default_to_union);
|
||||
|
||||
// We allow variables to be set to different types in each branch
|
||||
// as long as that variable is not already in scope or if that
|
||||
// variable does not get used later. Here, we save the error so
|
||||
// that the error message will be more informative in the case
|
||||
// that is used later. When `a` is accessed in `(a + 1)`, the
|
||||
// error will get printed:
|
||||
// if cond:
|
||||
// a = 1
|
||||
// else:
|
||||
@ -1799,76 +1840,146 @@ struct to_ir {
|
||||
}
|
||||
|
||||
CondValue emitIsInstance(const Expr& obj, const Expr& classinfo) {
|
||||
// turn (float, (int, tuple)) into a flat list of types and type kind
|
||||
// category checks: tuple_check = true, types = {float, int}
|
||||
struct GatheredTypes {
|
||||
GatheredTypes(ScriptTypeParser parser) : typeParser_(std::move(parser)) {}
|
||||
void gather(const Expr& classinfo) {
|
||||
if (classinfo.kind() == TK_TUPLE_LITERAL) {
|
||||
for (Expr e : TupleLiteral(classinfo).inputs()) {
|
||||
gather(e);
|
||||
Value* lhs_val = emitExpr(obj);
|
||||
std::vector<TypePtr> lhs_types;
|
||||
std::vector<TypePtr> rhs_types;
|
||||
|
||||
std::function<void(const Expr&)> gather_rhs = [&](const Expr& expr) {
|
||||
if (expr.kind() == TK_TUPLE_LITERAL) {
|
||||
for (Expr e : TupleLiteral(expr).inputs()) {
|
||||
gather_rhs(e);
|
||||
}
|
||||
return;
|
||||
}
|
||||
TypePtr type = typeParser_.parseTypeFromExpr(classinfo);
|
||||
types.emplace_back(type);
|
||||
}
|
||||
bool staticallyTrue(const TypePtr& actual_type) {
|
||||
// is this isinstance check statically true?
|
||||
for (const TypePtr& typ : types) {
|
||||
if (actual_type->isSubtypeOf(typ)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
bool maybeOfKind(TypeKind kind, const TypePtr& actual_type) {
|
||||
if (actual_type->kind() == AnyType::Kind) {
|
||||
return true;
|
||||
}
|
||||
if (auto op = actual_type->cast<OptionalType>()) {
|
||||
return op->getElementType()->kind() == kind;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
bool staticallyFalse(const TypePtr& actual_type) {
|
||||
for (const TypePtr& typ : types) {
|
||||
if (typ->isSubtypeOf(actual_type)) {
|
||||
return false;
|
||||
}
|
||||
if ((typ->isSubtypeOf(AnyListType::get()) &&
|
||||
maybeOfKind(ListType::Kind, actual_type)) ||
|
||||
(typ->isSubtypeOf(AnyTupleType::get()) &&
|
||||
maybeOfKind(TupleType::Kind, actual_type))) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
ScriptTypeParser typeParser_;
|
||||
std::vector<TypePtr> types;
|
||||
TypePtr type = typeParser_.parseTypeFromExpr(expr);
|
||||
rhs_types.emplace_back(type);
|
||||
};
|
||||
GatheredTypes gathered(typeParser_);
|
||||
gathered.gather(classinfo);
|
||||
auto val = emitExpr(obj);
|
||||
|
||||
lhs_types.push_back(lhs_val->type());
|
||||
gather_rhs(classinfo);
|
||||
|
||||
standardizeVectorForUnion(&lhs_types);
|
||||
standardizeVectorForUnion(&rhs_types);
|
||||
|
||||
RefinementSet refinement;
|
||||
if (gathered.types.size() == 1 &&
|
||||
gathered.types.at(0)->isSubtypeOf(val->type()) &&
|
||||
obj.kind() == TK_VAR) {
|
||||
std::string ident = Var(obj).name().name();
|
||||
Refinement isinstance(std::move(ident), gathered.types.at(0));
|
||||
refinement = RefinementSet({isinstance}, {});
|
||||
|
||||
TypePtr unified_true = nullptr;
|
||||
TypePtr unified_false = nullptr;
|
||||
|
||||
std::vector<TypePtr> isinstance_types;
|
||||
std::vector<TypePtr> not_isinstance_types;
|
||||
|
||||
std::vector<Refinement> true_refinements;
|
||||
std::vector<Refinement> false_refinements;
|
||||
|
||||
bool all_lhs_subtype_some_rhs = true;
|
||||
|
||||
// We can discard any rhs types that we know statically would be
|
||||
// impossible. For example, if we had:
|
||||
//
|
||||
// def fn(x: Optional[str]):
|
||||
// if isinstance(x, (List[str], str, int)):
|
||||
// ...
|
||||
//
|
||||
// then `x` would be `str` in the true branch and `None` in the
|
||||
// false branch, not `(List[str], str, int)` in the true branch
|
||||
// and `None` in the false branch
|
||||
for (const TypePtr& lhs_type : lhs_types) {
|
||||
if (lhs_type == AnyType::get()) {
|
||||
isinstance_types.insert(
|
||||
isinstance_types.end(), rhs_types.begin(), rhs_types.end());
|
||||
not_isinstance_types.push_back(AnyType::get());
|
||||
// Edge case: we can still say that all lhs types subtype some
|
||||
// rhs type if `lhs` is `Any` and `rhs` is `Any`
|
||||
if (isinstance_types.size() != 1 ||
|
||||
isinstance_types[0] != AnyType::get()) {
|
||||
all_lhs_subtype_some_rhs = false;
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
if (gathered.staticallyTrue(val->type())) {
|
||||
auto get_smaller_type = [&](TypePtr t1, TypePtr t2) -> TypePtr {
|
||||
if (t1->isSubtypeOf(t2)) {
|
||||
return t1;
|
||||
} else if (t2->isSubtypeOf(t1)) {
|
||||
return t2;
|
||||
} else {
|
||||
return nullptr;
|
||||
}
|
||||
};
|
||||
|
||||
TypePtr found_refinement = nullptr;
|
||||
for (const TypePtr& rhs_type : rhs_types) {
|
||||
TypePtr maybe_smaller_type = get_smaller_type(lhs_type, rhs_type);
|
||||
if (!maybe_smaller_type) {
|
||||
continue;
|
||||
} else if (*maybe_smaller_type == *lhs_type) {
|
||||
// Cover the case that we have something like
|
||||
// lhs = `List[str]` and rhs = `list`
|
||||
found_refinement = lhs_type;
|
||||
} else if (*maybe_smaller_type == *rhs_type) {
|
||||
// We want the narrowest possible type
|
||||
found_refinement = found_refinement
|
||||
? *(unifyTypes(found_refinement, rhs_type))
|
||||
: rhs_type;
|
||||
}
|
||||
}
|
||||
|
||||
if (found_refinement) {
|
||||
if (*found_refinement == *lhs_type) {
|
||||
all_lhs_subtype_some_rhs &= true;
|
||||
}
|
||||
isinstance_types.push_back(found_refinement);
|
||||
} else {
|
||||
// If the lhs couldn't be a subtype of the rhs (or couldn't
|
||||
// be "refined" to itself, as in the `List[str]` and `list`
|
||||
// case above), then we add `lhs_type` to the false branch
|
||||
// refinements. This is because the type can still be itself
|
||||
// if the `isinstance` check is false
|
||||
not_isinstance_types.push_back(lhs_type);
|
||||
all_lhs_subtype_some_rhs = false;
|
||||
}
|
||||
}
|
||||
|
||||
// For use with `unifyTypeList`
|
||||
std::stringstream nowhere;
|
||||
|
||||
// Get a single type for the true and false branches
|
||||
if (!isinstance_types.empty()) {
|
||||
unified_true =
|
||||
*unifyTypeList(isinstance_types, nowhere, /*default_to_union=*/true);
|
||||
}
|
||||
if (obj.kind() == TK_VAR && unified_true) {
|
||||
std::string ident = Var(obj).name().name();
|
||||
true_refinements = {Refinement(ident, unified_true)};
|
||||
}
|
||||
|
||||
// Get a single type for the true and false branches
|
||||
if (!not_isinstance_types.empty()) {
|
||||
unified_false = *unifyTypeList(
|
||||
not_isinstance_types, nowhere, /*default_to_union=*/true);
|
||||
}
|
||||
if (obj.kind() == TK_VAR && unified_false) {
|
||||
std::string ident = Var(obj).name().name();
|
||||
false_refinements = {Refinement(ident, unified_false)};
|
||||
}
|
||||
|
||||
refinement = RefinementSet(true_refinements, false_refinements);
|
||||
|
||||
bool is_statically_false = isinstance_types.empty();
|
||||
|
||||
// If the statement is statically true
|
||||
if (all_lhs_subtype_some_rhs) {
|
||||
return CondValue(*graph, obj.range(), true, std::move(refinement));
|
||||
}
|
||||
if (gathered.staticallyFalse(val->type())) {
|
||||
|
||||
if (is_statically_false) {
|
||||
return CondValue(*graph, obj.range(), false, std::move(refinement));
|
||||
}
|
||||
|
||||
// check maybe true/false at runtime, need an actual op
|
||||
Value* result =
|
||||
graph->insertNode(graph->createIsInstance(val, gathered.types))
|
||||
graph->insertNode(graph->createIsInstance(lhs_val, rhs_types))
|
||||
->output();
|
||||
return CondValue(result, std::move(refinement), c10::nullopt);
|
||||
}
|
||||
@ -2124,6 +2235,7 @@ struct to_ir {
|
||||
}
|
||||
|
||||
// emit assserions as an if branch so that assertions will reuse the
|
||||
// message
|
||||
void emitAssert(const Assert& stmt) {
|
||||
CondValue cond_value = emitCondExpr(stmt.test());
|
||||
List<Stmt> true_branch = List<Stmt>::create(stmt.range(), {});
|
||||
@ -2979,7 +3091,9 @@ struct to_ir {
|
||||
// after annotation so that variables assigned to this None will still
|
||||
// get the right type. To do this, we make a None constant that
|
||||
// has the type Optional[T]
|
||||
if (type->kind() == OptionalType::Kind &&
|
||||
if ((type->kind() == OptionalType::Kind ||
|
||||
(type->kind() == UnionType::Kind &&
|
||||
type->expect<UnionType>()->canHoldType(NoneType::get()))) &&
|
||||
expr->type()->isSubtypeOf(NoneType::get())) {
|
||||
Node* none = graph->createNone();
|
||||
none->output()->setType(type);
|
||||
@ -3435,8 +3549,9 @@ struct to_ir {
|
||||
size_t n_binders,
|
||||
const TypePtr& type_hint = nullptr) {
|
||||
switch (tree.kind()) {
|
||||
case TK_VAR:
|
||||
case TK_VAR: {
|
||||
return environment_stack->getSugaredVar(Var(tree).name());
|
||||
}
|
||||
case '.': {
|
||||
auto select = Select(tree);
|
||||
auto sv = emitSugaredExpr(select.value(), 1);
|
||||
@ -3710,7 +3825,7 @@ struct to_ir {
|
||||
type_hint ? type_hint->expect<ListType>()->getElementType() : nullptr;
|
||||
|
||||
c10::optional<TypePtr> unified = unifyTypeList(
|
||||
types, nowhere, /*default_to_any=*/true, element_type_hint);
|
||||
types, nowhere, /*default_to_union=*/true, element_type_hint);
|
||||
|
||||
if (!type_hint && *unified == AnyType::get()) {
|
||||
TORCH_WARN(
|
||||
@ -3881,7 +3996,7 @@ struct to_ir {
|
||||
c10::optional<TypePtr> unified = unifyTypeList(
|
||||
types,
|
||||
/*why_not=*/nowhere,
|
||||
/*default_to_any=*/true,
|
||||
/*default_to_union=*/true,
|
||||
value_type_hint);
|
||||
|
||||
if (!type_hint && *unified == AnyType::get()) {
|
||||
|
@ -8,9 +8,10 @@
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
// try to match a list of inputs and keyword 'attributes' to this schema,
|
||||
// if it works return the flat list of positional inputs to the call
|
||||
// if it returns nullopt, then failure_messages contains a good error report
|
||||
// Try to match a list of inputs and keyword 'attributes' to this
|
||||
// schema. Return the flat list of positional inputs to the call or
|
||||
// `c10::nullopt` on failure (`failure_messages` contains a good error
|
||||
// report in this case)
|
||||
|
||||
struct MatchedSchema {
|
||||
std::vector<Value*> inputs;
|
||||
|
@ -32,6 +32,7 @@ using c10::StringType;
|
||||
using c10::Symbol;
|
||||
using c10::TensorType;
|
||||
using c10::TupleType;
|
||||
using c10::UnionType;
|
||||
using c10::VarType;
|
||||
|
||||
namespace torch {
|
||||
@ -331,6 +332,18 @@ std::pair<TypePtr, c10::optional<AliasInfo>> SchemaTypeParser::parseType() {
|
||||
L.expect(')');
|
||||
alias_info = parseAliasAnnotation();
|
||||
value = DictType::create(key_type, value_type);
|
||||
} else if (L.cur().kind == TK_IDENT && L.cur().text() == "Union") {
|
||||
L.next();
|
||||
L.expect('(');
|
||||
std::vector<TypePtr> types;
|
||||
types.emplace_back(parseType().first);
|
||||
while (L.cur().kind != ')') {
|
||||
L.expect(',');
|
||||
types.emplace_back(parseType().first);
|
||||
}
|
||||
L.expect(')');
|
||||
alias_info = parseAliasAnnotation();
|
||||
value = UnionType::create(types);
|
||||
} else if (
|
||||
complete_tensor_types && L.cur().kind == TK_IDENT &&
|
||||
parseTensorDType(L.cur().text())) {
|
||||
|
@ -42,7 +42,7 @@ TypePtr ScriptTypeParser::subscriptToType(
|
||||
}
|
||||
std::vector<TypePtr> subscript_expr_types;
|
||||
for (auto expr : subscript.subscript_exprs()) {
|
||||
subscript_expr_types.push_back(parseTypeFromExprImpl(expr));
|
||||
subscript_expr_types.emplace_back(parseTypeFromExprImpl(expr));
|
||||
}
|
||||
return TupleType::create(subscript_expr_types);
|
||||
} else if (typeName == "List" || typeName == "list") {
|
||||
@ -65,6 +65,13 @@ TypePtr ScriptTypeParser::subscriptToType(
|
||||
parseTypeFromExprImpl(*subscript.subscript_exprs().begin());
|
||||
return OptionalType::create(elem_type);
|
||||
|
||||
} else if (typeName == "Union") {
|
||||
std::vector<TypePtr> subscript_expr_types;
|
||||
subscript_expr_types.reserve(subscript.subscript_exprs().size());
|
||||
for (auto expr : subscript.subscript_exprs()) {
|
||||
subscript_expr_types.emplace_back(parseTypeFromExprImpl(expr));
|
||||
}
|
||||
return UnionType::create(subscript_expr_types);
|
||||
} else if (typeName == "Future" || typeName == "torch.jit.Future") {
|
||||
if (subscript.subscript_exprs().size() != 1) {
|
||||
throw ErrorReport(subscript)
|
||||
@ -83,30 +90,6 @@ TypePtr ScriptTypeParser::subscriptToType(
|
||||
auto elem_type =
|
||||
parseTypeFromExprImpl(*subscript.subscript_exprs().begin());
|
||||
return RRefType::create(elem_type);
|
||||
} else if (typeName == "Union") {
|
||||
// In Python 3.9+, Union[NoneType, T] or Union[T, NoneType] are
|
||||
// treated as Optional[T]. Adding the same support for Union in Torchscript.
|
||||
const char* const err =
|
||||
"General Union types are not currently supported."
|
||||
" Only Union[T, NoneType] (i.e. Optional[T]) is "
|
||||
"supported.";
|
||||
if (subscript.subscript_exprs().size() != 2) {
|
||||
throw ErrorReport(subscript) << (err);
|
||||
}
|
||||
auto first_type = parseTypeFromExprImpl(subscript.subscript_exprs()[0]);
|
||||
auto second_type = parseTypeFromExprImpl(subscript.subscript_exprs()[1]);
|
||||
|
||||
bool first_none = first_type == NoneType::get();
|
||||
bool second_none = second_type == NoneType::get();
|
||||
|
||||
if (first_none && !second_none) {
|
||||
return OptionalType::create(second_type);
|
||||
} else if (!first_none && second_none) {
|
||||
return OptionalType::create(first_type);
|
||||
} else {
|
||||
throw ErrorReport(subscript.range()) << err;
|
||||
}
|
||||
|
||||
} else if (typeName == "Dict" || typeName == "dict") {
|
||||
if (subscript.subscript_exprs().size() != 2) {
|
||||
throw ErrorReport(subscript)
|
||||
|
@ -13,94 +13,139 @@ namespace jit {
|
||||
|
||||
namespace {
|
||||
|
||||
// For any mutable type, map it to a type such that all other types which it can
|
||||
// alias will be mapped to the same type. This function follows a similar logic
|
||||
// to `unifyTypes` because any two mutable types which can be unified
|
||||
// can alias each other.
|
||||
// getMutableTypePtr(Optional[List[int]]) == getMutableTypePtr([List[int]])
|
||||
// If a type is not mutable, return nullopt
|
||||
// This class helps convert types to their mutable equivalent by looking up
|
||||
// cached conversions.
|
||||
TypePtr toSingleType(AliasTypeSet& mut_types) {
|
||||
return mut_types.size() == 1 ? mut_types[0]
|
||||
: c10::UnionType::create(mut_types);
|
||||
}
|
||||
|
||||
// This class determines whether a type is mutable, and, if so, it maps
|
||||
// the type to its "mutable equivalent" (see definition in
|
||||
// `mapTypeToAliasTypeSet`). It uses a cache of TypePtrs to speed up these
|
||||
// type lookups
|
||||
class MutableTypePtrHelper {
|
||||
public:
|
||||
explicit MutableTypePtrHelper(
|
||||
std::unordered_map<TypePtr, TypePtr>* mutable_type_cache)
|
||||
std::unordered_map<TypePtr, AliasTypeSet>* mutable_type_cache)
|
||||
: mutable_type_cache_(mutable_type_cache) {}
|
||||
|
||||
c10::optional<TypePtr> getMutableType(const TypePtr& type) {
|
||||
// Map any mutable type to a type such that all other types which the
|
||||
// mutable type can alias will be mapped to the same type. For
|
||||
// example, calling this method on `Optional[List[int]]` should be
|
||||
// the same as calling this method on `List[int]`.
|
||||
//
|
||||
// Rules:
|
||||
// - If the type is not mutable, return `nullopt`
|
||||
// - If the type is a `Tuple`, that means that it's an immutable
|
||||
// object that can itself contain mutable objects. We want to make
|
||||
// sure that the mutable objects are correctly aliased, so we
|
||||
// remove the immutable objects. (For example,
|
||||
// `Tuple[int, Tensor]` would become `Tuple[Tensor]`, while
|
||||
// `Tuple[int, str]` would be returned as `nullopt`.) This is a
|
||||
// convenience that makes it easy to check if the `Tuple`
|
||||
// contains only immutable objects, though it's not technically
|
||||
// necessary
|
||||
// - For any Tensor type (including Tensor types that are part of
|
||||
// a larger container, e.g. `List[Tensor]`), return the
|
||||
// "unshaped" version of that Tensor. An "unshaped" Tensor is a
|
||||
// Tensor with shape information removed. For example, a Tensor
|
||||
// of dimension 4 would map to the same type as a Tensor of
|
||||
// dimension 1. This allows us to treat all subclasses of Tensor
|
||||
// as a single, homogenous "Tensor" type.
|
||||
c10::optional<AliasTypeSet> mapTypeToAliasTypeSet(const TypePtr& type) {
|
||||
if (mutable_type_cache_) {
|
||||
auto maybe_type = mutable_type_cache_->find(type);
|
||||
if (maybe_type != mutable_type_cache_->end()) {
|
||||
return maybe_type->second;
|
||||
auto maybe_type_mapping = mutable_type_cache_->find(type);
|
||||
if (maybe_type_mapping != mutable_type_cache_->end()) {
|
||||
return maybe_type_mapping->second;
|
||||
}
|
||||
}
|
||||
auto mutable_type = getMutableTypeImpl(type);
|
||||
if (mutable_type_cache_ && mutable_type) {
|
||||
mutable_type_cache_->emplace(type, *mutable_type);
|
||||
auto mutable_types = mapTypeToAliasTypeSetImpl(type);
|
||||
if (mutable_type_cache_ && mutable_types) {
|
||||
mutable_type_cache_->emplace(type, *mutable_types);
|
||||
}
|
||||
return mutable_type;
|
||||
return mutable_types;
|
||||
}
|
||||
|
||||
private:
|
||||
c10::optional<TypePtr> getMutableTypeImpl(const TypePtr& type) {
|
||||
c10::optional<AliasTypeSet> mapTypeToAliasTypeSetImpl(const TypePtr& type) {
|
||||
switch (type->kind()) {
|
||||
case TypeKind::ListType:
|
||||
case TypeKind::DictType:
|
||||
case TypeKind::ClassType:
|
||||
case TypeKind::TensorType:
|
||||
// TODO: lookup cached contained types. this is kind of tricky
|
||||
// because a List[Optional[T]] should still be
|
||||
// List[Optional[Unshaped(T)]], however the getMutableType(Optional[T])
|
||||
// == T
|
||||
return unshapedType(type);
|
||||
case TypeKind::OptionalType:
|
||||
return getMutableType(type->castRaw<OptionalType>()->getElementType());
|
||||
// TODO: Look up cached contained types. this is kind of tricky
|
||||
// because a `List[Optional[T]]` should still be
|
||||
// `List[Optional[Unshaped(T)]]`, but
|
||||
// `mapTypeToAliasTypeSet(Optional[T])` should be `T`
|
||||
return AliasTypeSet{unshapedType(type)};
|
||||
case TypeKind::UnionType: {
|
||||
AliasTypeSet mutable_types;
|
||||
for (const TypePtr& inner :
|
||||
type->expect<UnionType>()->containedTypes()) {
|
||||
if (auto maybe_inner_types = mapTypeToAliasTypeSet(inner)) {
|
||||
mutable_types.insert(
|
||||
mutable_types.end(),
|
||||
(*maybe_inner_types).begin(),
|
||||
(*maybe_inner_types).end());
|
||||
}
|
||||
}
|
||||
if (mutable_types.size() == 0) {
|
||||
return c10::nullopt;
|
||||
}
|
||||
return mutable_types;
|
||||
}
|
||||
case TypeKind::OptionalType: {
|
||||
auto inner = type->castRaw<OptionalType>()->getElementType();
|
||||
return mapTypeToAliasTypeSet(inner);
|
||||
}
|
||||
case TypeKind::AnyType:
|
||||
return type;
|
||||
return {AliasTypeSet{type}};
|
||||
case TypeKind::FutureType: {
|
||||
if (auto elem =
|
||||
getMutableType(type->castRaw<FutureType>()->getElementType())) {
|
||||
return FutureType::create(*elem);
|
||||
if (auto maybe_mut_types = mapTypeToAliasTypeSet(
|
||||
type->castRaw<FutureType>()->getElementType())) {
|
||||
auto mut_type = toSingleType(*maybe_mut_types);
|
||||
return {AliasTypeSet{FutureType::create(mut_type)}};
|
||||
}
|
||||
return c10::nullopt;
|
||||
}
|
||||
case TypeKind::TupleType: {
|
||||
std::vector<TypePtr> mutable_types;
|
||||
for (const auto& elem : type->expectRef<TupleType>().elements()) {
|
||||
if (auto mut_elem = getMutableType(elem)) {
|
||||
mutable_types.push_back(*mut_elem);
|
||||
for (const TypePtr& inner : type->expectRef<TupleType>().elements()) {
|
||||
if (auto maybe_inner_types = mapTypeToAliasTypeSet(inner)) {
|
||||
mutable_types.insert(
|
||||
mutable_types.end(),
|
||||
(*maybe_inner_types).begin(),
|
||||
(*maybe_inner_types).end());
|
||||
}
|
||||
}
|
||||
if (mutable_types.size() == 0) {
|
||||
return c10::nullopt;
|
||||
} else {
|
||||
return TupleType::create(mutable_types);
|
||||
}
|
||||
return {AliasTypeSet{TupleType::create(mutable_types)}};
|
||||
}
|
||||
default:
|
||||
return c10::nullopt;
|
||||
}
|
||||
}
|
||||
std::unordered_map<TypePtr, TypePtr>* mutable_type_cache_;
|
||||
std::unordered_map<TypePtr, AliasTypeSet>* mutable_type_cache_;
|
||||
};
|
||||
|
||||
bool isMutableTypeImpl(
|
||||
const TypePtr& type,
|
||||
std::unordered_map<TypePtr, TypePtr>* mutable_type_cache) {
|
||||
// check common cases to avoid recursively constructing type in
|
||||
// getMutableTypePtrImpl
|
||||
std::unordered_map<TypePtr, AliasTypeSet>* mutable_type_cache) {
|
||||
// Check common cases to avoid recursively constructing type in
|
||||
// `mapTypeToAliasTypeSetPtrImpl`
|
||||
auto kind = type->kind();
|
||||
if (kind == TypeKind::TensorType || kind == TypeKind::ListType ||
|
||||
kind == TypeKind::ClassType || kind == TypeKind::DictType) {
|
||||
return true;
|
||||
}
|
||||
MutableTypePtrHelper helper(mutable_type_cache);
|
||||
return helper.getMutableType(type) != c10::nullopt;
|
||||
return helper.mapTypeToAliasTypeSet(type) != c10::nullopt;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// static isMutableType does not use cache of type -> mutable type equivalent
|
||||
// Static `isMutableType` does not use cache of type -> mutable type equivalent
|
||||
bool AliasDb::isMutableType(const TypePtr& type) {
|
||||
return isMutableTypeImpl(type, nullptr);
|
||||
}
|
||||
@ -109,7 +154,7 @@ bool AliasDb::isMutableType(const Value* v) {
|
||||
return isMutableType(v->type());
|
||||
}
|
||||
|
||||
// makes use of type -> mutable cache
|
||||
// Make use of type -> mutable cache
|
||||
bool AliasDb::isMutableTypeInternal(const TypePtr& type) const {
|
||||
return isMutableTypeImpl(type, &mapped_mutable_types_);
|
||||
}
|
||||
@ -118,21 +163,17 @@ bool AliasDb::isMutableTypeInternal(const Value* v) const {
|
||||
return isMutableTypeInternal(v->type());
|
||||
}
|
||||
|
||||
c10::optional<TypePtr> AliasDb::getMutableTypePtr(const TypePtr& type) const {
|
||||
c10::optional<AliasTypeSet> AliasDb::mapTypeToAliasTypeSetPtr(
|
||||
const TypePtr& type) const {
|
||||
MutableTypePtrHelper helper(&mapped_mutable_types_);
|
||||
return helper.getMutableType(type);
|
||||
}
|
||||
|
||||
bool AliasDb::isContainerType(const TypePtr& type) const {
|
||||
auto mut_type = getMutableTypePtr(type);
|
||||
return mut_type && (*mut_type)->containedTypes().size() > 0;
|
||||
return helper.mapTypeToAliasTypeSet(type);
|
||||
}
|
||||
|
||||
AliasDb::~AliasDb() = default;
|
||||
|
||||
// Structure used during analysis to keeps track of all writes at a high level.
|
||||
// When analysis is completed this will be used to construct a more efficient
|
||||
// WriteIndex.
|
||||
// Structure used during analysis to keep track of all writes at a high
|
||||
// level. When the analysis is completed, this will be used to construct
|
||||
// a more efficient WriteIndex
|
||||
struct AliasDb::WriteRegistry {
|
||||
void registerWrite(const Value* v, Node* n) {
|
||||
writes_[n].emplace_back(v);
|
||||
@ -170,7 +211,7 @@ AliasDb::AliasDb(std::shared_ptr<Graph> graph, bool isFrozen)
|
||||
writeIndex_ = TWriteIndex();
|
||||
auto& writeIndex = *writeIndex_; // to make operator[] less ugly
|
||||
|
||||
// build the write index
|
||||
// Build the write index
|
||||
for (const auto& write : writeRegistry_->writes_) {
|
||||
Node* node = write.first;
|
||||
const std::vector<const Value*> writtenValues = write.second;
|
||||
@ -207,7 +248,7 @@ AliasDb::AliasDb(std::shared_ptr<Graph> graph, bool isFrozen)
|
||||
// out of sync (since we have no way of registering new writes)
|
||||
writeRegistry_ = nullptr;
|
||||
|
||||
// initialize the write cache
|
||||
// Initialize the write cache
|
||||
buildWrittenToLocationsIndex();
|
||||
GRAPH_DEBUG(toString());
|
||||
}
|
||||
@ -324,10 +365,10 @@ MemoryLocations AliasDb::getReads(Node* n) const {
|
||||
|
||||
std::string AliasDb::getElementName(const Element* e) const {
|
||||
if (e->values.empty()) {
|
||||
// not the most efficient way, but given the fact there are
|
||||
// Not the most efficient way, but given the fact there are
|
||||
// not too many types and even fewer of them will end up in
|
||||
// wildcardIndex_, we should be fine with a linear search
|
||||
// each time we hit a wildcard leaf
|
||||
// `wildcardIndex_`, we should be fine with a linear search
|
||||
// each time we hit a Wildcard leaf
|
||||
for (const auto& ent : wildcardIndex_) {
|
||||
if (ent.second == e) {
|
||||
return std::string("WILDCARD for type ") + ent.first->str();
|
||||
@ -362,17 +403,27 @@ std::string AliasDb::toString() const {
|
||||
ss << "\n===2. ALIAS DB===\n";
|
||||
for (const auto& ptrPair : elementMap_) {
|
||||
const auto element = ptrPair.second;
|
||||
int ct = 0;
|
||||
if (!element->pointsTo.empty()) {
|
||||
ss << getElementName(element) << " points to: ";
|
||||
for (const auto pointedTo : element->pointsTo) {
|
||||
ss << getElementName(memoryDAG_->fromIndex(pointedTo)) << ", ";
|
||||
if (ct > 0) {
|
||||
ss << ", ";
|
||||
}
|
||||
++ct;
|
||||
ss << getElementName(memoryDAG_->fromIndex(pointedTo));
|
||||
}
|
||||
ss << "\n";
|
||||
}
|
||||
ct = 0;
|
||||
if (!element->containedElements.empty()) {
|
||||
ss << getElementName(element) << " contains: ";
|
||||
for (const auto contained : element->containedElements) {
|
||||
ss << getElementName(memoryDAG_->fromIndex(contained)) << ", ";
|
||||
ss << getElementName(memoryDAG_->fromIndex(contained));
|
||||
if (ct > 0) {
|
||||
ss << ", ";
|
||||
}
|
||||
++ct;
|
||||
}
|
||||
ss << "\n";
|
||||
}
|
||||
@ -839,8 +890,7 @@ void AliasDb::analyzeLoop(Node* node) {
|
||||
TORCH_INTERNAL_ASSERT(blockOutputs.size() == node->outputs().size());
|
||||
|
||||
// Run alias analysis on the loop body, iterating until the block output
|
||||
// alias info converges.
|
||||
// Copy node input aliases to block input
|
||||
// alias info converges. Copy node input aliases to block input
|
||||
mapAliases(blockInputs, loopCarriedInputs);
|
||||
|
||||
// Populate block output alias info by analyzing the body
|
||||
@ -996,7 +1046,7 @@ bool AliasDb::functionalNonEscapingListUse(const Use& use) const {
|
||||
return false;
|
||||
}
|
||||
|
||||
// List or dict or tuple: construct: create an aliasing element for the actual
|
||||
// List or dict or tuple construct: create an aliasing element for the actual
|
||||
// container, then mark all inputs as wildcards, since they've gone inside the
|
||||
// container. Then, add the wildcard sets of appropriate type to the contained
|
||||
// elements of the container.
|
||||
@ -1073,52 +1123,50 @@ void AliasDb::makePointerTo(const Value* from, const Value* to) {
|
||||
return;
|
||||
}
|
||||
|
||||
// the contained types of immutable type containers (optional, tuple, future)
|
||||
// are unified, so these types can be mutable or immutable
|
||||
// and point to a type which is mutable or immutable.
|
||||
// Any is mutable but can point to an immutable type through refinement
|
||||
// The contained types of immutable type containers (`Optional`,
|
||||
// `Tuple`, `Future`, and `Union`) are unified, so these types can be
|
||||
// mutable or immutable and point to a type which is mutable or
|
||||
// immutable. `Any` is mutable but can point to an immutable type
|
||||
// through refinement
|
||||
if (isMutableTypeInternal(from) != isMutableTypeInternal(to)) {
|
||||
bool expected_kind = false;
|
||||
for (auto kind : {from->type()->kind(), to->type()->kind()}) {
|
||||
expected_kind = expected_kind ||
|
||||
(kind == TypeKind::OptionalType || kind == TypeKind::FutureType ||
|
||||
kind == TypeKind::TupleType) // immutable type containers
|
||||
kind == TypeKind::TupleType ||
|
||||
kind == TypeKind::UnionType) // immutable type containers
|
||||
|| kind == TypeKind::AnyType;
|
||||
}
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
expected_kind, from->type()->str(), to->type()->str());
|
||||
return;
|
||||
}
|
||||
|
||||
// both immutable
|
||||
if (!isMutableTypeInternal(from)) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (from == to) {
|
||||
return;
|
||||
}
|
||||
|
||||
// At this point, we are dealing with two mutable types.
|
||||
auto fromEl = getOrCreateElement(from);
|
||||
auto toEl = getOrCreateElement(to);
|
||||
// At this point, we are dealing with two mutable types
|
||||
auto from_el = getOrCreateElement(from);
|
||||
auto to_el = getOrCreateElement(to);
|
||||
|
||||
memoryDAGBuilder_->makePointerTo(fromEl, toEl);
|
||||
memoryDAGBuilder_->makePointerTo(from_el, to_el);
|
||||
}
|
||||
|
||||
void AliasDb::addToContainedElements(
|
||||
const Value* elem,
|
||||
const Value* inner,
|
||||
const Value* container) {
|
||||
if (!isMutableTypeInternal(elem)) {
|
||||
if (!isMutableTypeInternal(inner)) {
|
||||
return;
|
||||
}
|
||||
|
||||
TORCH_INTERNAL_ASSERT(isContainerType(container->type()));
|
||||
auto inner_el = getOrCreateElement(inner);
|
||||
auto cont_el = getOrCreateElement(container);
|
||||
|
||||
auto elemEl = getOrCreateElement(elem);
|
||||
auto contEl = getOrCreateElement(container);
|
||||
|
||||
memoryDAGBuilder_->addToContainedElements(elemEl, contEl);
|
||||
memoryDAGBuilder_->addToContainedElements(inner_el, cont_el);
|
||||
}
|
||||
|
||||
bool AliasDb::mayAlias(const Value* a, const Value* b) const {
|
||||
@ -1203,8 +1251,8 @@ void AliasDb::createValue(const Value* value) {
|
||||
void AliasDb::giveFreshAlias(
|
||||
const Value* value,
|
||||
bool add_wildcard_to_contained_elems) {
|
||||
auto maybe_mut_type = getMutableTypePtr(value->type());
|
||||
if (!maybe_mut_type) {
|
||||
auto maybe_mut_types = mapTypeToAliasTypeSetPtr(value->type());
|
||||
if (!maybe_mut_types) {
|
||||
return;
|
||||
}
|
||||
|
||||
@ -1217,7 +1265,11 @@ void AliasDb::giveFreshAlias(
|
||||
auto new_elem = memoryDAGBuilder_->makeFreshValue(value);
|
||||
elementMap_[value] = new_elem;
|
||||
if (add_wildcard_to_contained_elems) {
|
||||
addContainedTypesToFreshElement(new_elem, *maybe_mut_type);
|
||||
if ((*maybe_mut_types).size() > 1) {
|
||||
pointUnionTypeElementToAllContainedTypes(new_elem, *maybe_mut_types);
|
||||
} else {
|
||||
addContainedTypesToFreshElement(new_elem, *maybe_mut_types);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -1639,56 +1691,86 @@ bool AliasDb::mayAliasWildcard(const at::ArrayRef<Value*> vs) const {
|
||||
}
|
||||
|
||||
c10::optional<Element*> AliasDb::tryGetOrCreateWildcard(const TypePtr& type) {
|
||||
auto updated_type = getMutableTypePtr(type);
|
||||
if (!updated_type) {
|
||||
auto maybe_mut_types = mapTypeToAliasTypeSetPtr(type);
|
||||
if (!maybe_mut_types) {
|
||||
return c10::nullopt;
|
||||
}
|
||||
auto mapped_type = *updated_type;
|
||||
auto existing_wildcard = wildcardIndex_.find(mapped_type);
|
||||
auto mut_type = toSingleType(*maybe_mut_types);
|
||||
auto existing_wildcard = wildcardIndex_.find(mut_type);
|
||||
if (existing_wildcard != wildcardIndex_.end()) {
|
||||
return existing_wildcard->second;
|
||||
}
|
||||
|
||||
auto wildcard_elem = memoryDAGBuilder_->makeFreshValue(nullptr);
|
||||
wildcardIndex_.emplace(mapped_type, wildcard_elem);
|
||||
addContainedTypesToFreshElement(wildcard_elem, mapped_type);
|
||||
wildcardIndex_.emplace(mut_type, wildcard_elem);
|
||||
if ((*maybe_mut_types).size() > 1) {
|
||||
pointUnionTypeElementToAllContainedTypes(wildcard_elem, *maybe_mut_types);
|
||||
} else {
|
||||
addContainedTypesToFreshElement(wildcard_elem, *maybe_mut_types);
|
||||
}
|
||||
return wildcard_elem;
|
||||
}
|
||||
|
||||
void AliasDb::pointUnionTypeElementToAllContainedTypes(
|
||||
Element* container_elem,
|
||||
const AliasTypeSet& mut_types) {
|
||||
for (const auto& mut_type : mut_types) {
|
||||
auto maybe_elem = tryGetOrCreateWildcard(mut_type);
|
||||
if (maybe_elem) {
|
||||
TORCH_INTERNAL_ASSERT(*maybe_elem != container_elem);
|
||||
memoryDAGBuilder_->makePointerTo(container_elem, *maybe_elem);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void AliasDb::addContainedTypesToFreshElement(
|
||||
Element* container_elem,
|
||||
const TypePtr& mut_type) {
|
||||
const AliasTypeSet& mut_types) {
|
||||
for (const auto& mut_type : mut_types) {
|
||||
for (const auto& contained : mut_type->containedTypes()) {
|
||||
auto maybe_elem = tryGetOrCreateWildcard(contained);
|
||||
if (maybe_elem) {
|
||||
memoryDAGBuilder_->addToContainedElements(*maybe_elem, container_elem);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Search the wildcard index for an element that corresponds to the given type.
|
||||
// Const version returns nullptr
|
||||
Element* AliasDb::getWildcard(const TypePtr& type) const {
|
||||
auto maybe_mut_type = getMutableTypePtr(type);
|
||||
if (!maybe_mut_type) {
|
||||
return nullptr;
|
||||
auto maybe_mut_types = mapTypeToAliasTypeSetPtr(type);
|
||||
if (!maybe_mut_types) {
|
||||
return {};
|
||||
}
|
||||
TypePtr mut_type = *maybe_mut_type;
|
||||
auto wildcard = wildcardIndex_.find(mut_type);
|
||||
if (wildcard != wildcardIndex_.end()) {
|
||||
return wildcard->second;
|
||||
if ((*maybe_mut_types).size() > 1) {
|
||||
auto union_type = UnionType::create(*maybe_mut_types);
|
||||
// Get a <TypePtr, Element*> pair where the TypePtr is this Union
|
||||
// type and the Element is the corresponding Wildcard
|
||||
auto maybe_union_pair = wildcardIndex_.find(union_type);
|
||||
if (maybe_union_pair != wildcardIndex_.end()) {
|
||||
return (*maybe_union_pair).second;
|
||||
}
|
||||
return nullptr;
|
||||
} else {
|
||||
// Get a <TypePtr, Element*> pair where the TypePtr is the given
|
||||
// type and the Element is the corresponding Wildcard
|
||||
auto type_pair = wildcardIndex_.find((*maybe_mut_types)[0]);
|
||||
if (type_pair != wildcardIndex_.end()) {
|
||||
return type_pair->second;
|
||||
}
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
// Register `v` as a wildcard value.
|
||||
c10::optional<Element*> AliasDb::setWildcard(const Value* v) {
|
||||
auto maybe_wildcardElement = tryGetOrCreateWildcard(v->type());
|
||||
c10::optional<Element*> maybe_wildcardElement =
|
||||
tryGetOrCreateWildcard(v->type());
|
||||
if (!maybe_wildcardElement) {
|
||||
return c10::nullopt;
|
||||
}
|
||||
// Ensure that we create a corresponding element for `v` still, as it is an
|
||||
// invariant that all mutable values have an element.
|
||||
// Ensure that we create a corresponding Element for `v` still, as it is an
|
||||
// invariant that all mutable values have an Element
|
||||
getOrCreateElement(v);
|
||||
wildcards_.insert(v);
|
||||
return *maybe_wildcardElement;
|
||||
|
@ -34,6 +34,12 @@ namespace jit {
|
||||
* Values that contain other mutable types, such as List[Tensor], are
|
||||
* initialized as containing the Wildcard set for all contained mutable types.
|
||||
*
|
||||
* The AliasDb API references the idea of "mutable" vs "immutable"
|
||||
* types. "Mutable" means that the object's value can change, while
|
||||
* "immutable" means that the value is fixed. (For example, `List` is
|
||||
* mutable, so you can add and delete elements from it. On the other
|
||||
* hand, you can't modify a Tuple once you create it, making `Tuple` an
|
||||
* immutable container.)
|
||||
*/
|
||||
class AliasDb {
|
||||
public:
|
||||
@ -95,7 +101,7 @@ class AliasDb {
|
||||
const at::ArrayRef<Value*>& a,
|
||||
const at::ArrayRef<Value*>& b) const;
|
||||
|
||||
// Move 'n' (already in the graph) after 'movePoint' in the topological order.
|
||||
// Move `n` (already in the graph) after `movePoint` in the topological order.
|
||||
//
|
||||
// Tries to preserve value dependencies, so other nodes might be moved. We
|
||||
// make two guarantees about the postcondition of the node list:
|
||||
@ -125,6 +131,10 @@ class AliasDb {
|
||||
TORCH_API bool dumpToGraphvizFile(const char* filename) const;
|
||||
TORCH_API std::string toGraphviz() const;
|
||||
|
||||
// Returns `true` if the given element is mutable or if it is a
|
||||
// container type with an internal mutable element (e.g.
|
||||
// `Tuple[int, Tensor]` has an internal mutable type `Tensor`, so
|
||||
// it would be considered a "mutable type" in AliasDb)
|
||||
static bool isMutableType(const Value* v);
|
||||
static bool isMutableType(const TypePtr& type);
|
||||
|
||||
@ -181,7 +191,7 @@ class AliasDb {
|
||||
// Register `v` as a wildcard value.
|
||||
c10::optional<Element*> setWildcard(const Value* v);
|
||||
|
||||
// Is this a value which will not alias
|
||||
// Is this a value which will not alias?
|
||||
bool nonAliasingValue(const Value* elem) const;
|
||||
|
||||
/**
|
||||
@ -221,11 +231,10 @@ class AliasDb {
|
||||
bool add_wildcard_to_contained_elems = true);
|
||||
Element* getOrCreateElement(const Value* value);
|
||||
|
||||
c10::optional<TypePtr> getMutableTypePtr(const TypePtr& type) const;
|
||||
c10::optional<AliasTypeSet> mapTypeToAliasTypeSetPtr(
|
||||
const TypePtr& type) const;
|
||||
bool functionalNonEscapingListUse(const Use& use) const;
|
||||
|
||||
bool isContainerType(const TypePtr& type) const;
|
||||
|
||||
std::shared_ptr<Graph> graph_;
|
||||
|
||||
// If the Module is frozen then consider attributes as freshly created
|
||||
@ -239,21 +248,24 @@ class AliasDb {
|
||||
|
||||
// Mapping of values to MemoryDAG elements
|
||||
ska::flat_hash_map<const Value*, Element*> elementMap_;
|
||||
// All wildcard elements (one for each unique mutable type).
|
||||
// All wildcard Elements (one for each unique mutable type)
|
||||
std::unordered_map<TypePtr, Element*, HashType, EqualType> wildcardIndex_;
|
||||
Element* getWildcard(const TypePtr& type) const;
|
||||
c10::optional<Element*> tryGetOrCreateWildcard(const TypePtr& type);
|
||||
void addContainedTypesToFreshElement(
|
||||
Element* container_elem,
|
||||
const TypePtr& mut_type);
|
||||
const AliasTypeSet& mut_types);
|
||||
void pointUnionTypeElementToAllContainedTypes(
|
||||
Element* container_elem,
|
||||
const AliasTypeSet& mut_types);
|
||||
|
||||
std::vector<Element*> getElements(at::ArrayRef<Value*> vs) const;
|
||||
bool mayAliasWildcard(const Value* v) const;
|
||||
bool mayAliasWildcard(const at::ArrayRef<Value*> vs) const;
|
||||
bool hasWriters(const at::ArrayRef<Value*>& values) const;
|
||||
|
||||
// cached mapping of type ptrs to their mutable types
|
||||
mutable std::unordered_map<TypePtr, TypePtr> mapped_mutable_types_;
|
||||
// Cached mapping of type ptrs to their mutable types
|
||||
mutable std::unordered_map<TypePtr, AliasTypeSet> mapped_mutable_types_;
|
||||
|
||||
/**
|
||||
* State for tracking write info.
|
||||
|
@ -511,7 +511,7 @@ void Graph::lint() const {
|
||||
// - Params and return do NOT occur in nodes
|
||||
// - next_unique_ is greater than all uniques in graph
|
||||
// - uniques in all_nodes are unique
|
||||
// - every use will occur later in the topsort
|
||||
// - every use will occur later in the toposort
|
||||
|
||||
struct LintScope {
|
||||
LintScope() = default;
|
||||
@ -787,7 +787,9 @@ bool Value::mustBeNone() const {
|
||||
}
|
||||
bool Value::mustNotBeNone() const {
|
||||
return node_->kind() != prim::AutogradAdd && type() != NoneType::get() &&
|
||||
!type()->cast<OptionalType>();
|
||||
!type()->cast<OptionalType>() &&
|
||||
!(type()->cast<UnionType>() &&
|
||||
type()->expect<UnionType>()->canHoldType(NoneType::get()));
|
||||
}
|
||||
|
||||
std::string Value::debugNameBase() const {
|
||||
@ -1765,20 +1767,23 @@ Node* Graph::createEnumValue(Value* e) {
|
||||
return n;
|
||||
}
|
||||
|
||||
Node* Graph::createList(const TypePtr& elem_type, at::ArrayRef<Value*> values) {
|
||||
Node* Graph::createList(
|
||||
const TypePtr& contained_type,
|
||||
at::ArrayRef<Value*> values) {
|
||||
auto n = create(prim::ListConstruct, values);
|
||||
for (const auto& v : values) {
|
||||
TORCH_CHECK(
|
||||
v->type()->isSubtypeOf(elem_type),
|
||||
v->type()->isSubtypeOf(contained_type),
|
||||
"Expected a list element that subtypes '",
|
||||
elem_type->repr_str(),
|
||||
contained_type->repr_str(),
|
||||
"' but got an element of type '",
|
||||
v->type()->repr_str(),
|
||||
"'");
|
||||
}
|
||||
n->output()->setType(ListType::create(elem_type));
|
||||
n->output()->setType(ListType::create(contained_type));
|
||||
return n;
|
||||
}
|
||||
|
||||
Node* Graph::createListUnpack(Value* v, size_t size) {
|
||||
ListTypePtr list_type = v->type()->expect<ListType>();
|
||||
TypePtr elem_type = list_type->getElementType();
|
||||
|
@ -84,7 +84,7 @@ using namespace ::c10::cuda;
|
||||
struct Function;
|
||||
struct MatchedSchema;
|
||||
|
||||
// Graph represents one "function" of computation.
|
||||
// A Graph represents one "function" of computation.
|
||||
// It uses a simple ownership model where the graph owns all the nodes inside
|
||||
// it. All references inside the graph are raw pointers. Destroying the Graph
|
||||
// will invalidate any pointers to nodes in the graph.
|
||||
@ -104,9 +104,9 @@ TORCH_API std::ostream& operator<<(std::ostream& out, const Node& n);
|
||||
// A list of nodes, with inputs and outputs
|
||||
struct Block;
|
||||
|
||||
// Each use is represented by this type, see Node::uses()
|
||||
// 'user' is the consumer of the value, offset is the index into
|
||||
// 'user's input this where the produces will be found.
|
||||
// Each use is represented by this type, see 'Node::uses()'
|
||||
// 'user' is the consumer of the value, 'offset' is the index into
|
||||
// 'user's input this where the producers will be found.
|
||||
struct Use {
|
||||
Use(Node* user, size_t offset) : user(user), offset(offset) {}
|
||||
Node* user;
|
||||
@ -338,14 +338,16 @@ struct TORCH_API Node {
|
||||
protected:
|
||||
Node(Graph* graph_, NodeKind kind_); // defined after graph
|
||||
public:
|
||||
// each node but Return/Param
|
||||
// is associated with exactly one place in the node list...
|
||||
// of the graph_
|
||||
// this circular is a doubly-linked list, the Return node is used as the
|
||||
// sentinel for the beginning and end of the list such that the list never has
|
||||
// null pointers next_in_graph[0] is next pointer next_in_graph[1] is prev
|
||||
// pointer using an array to allow the same iterator class for forward and
|
||||
// reverse node lists This list represents a topological sort
|
||||
// Each Node but Return/Param Nodes are associated with exactly one
|
||||
// place in the Node list of the Graph. The Graph itself is a circular
|
||||
// doubly-linked list. The Return Node is used as the sentinel for the
|
||||
// "beginning"/"end" of the list. This means that you can tell when
|
||||
// you've traversed the entire list without means worrying about null
|
||||
// pointers. `next_in_graph[0]` is the pointer to the next Node, while
|
||||
// `next_in_graph[1]` is the pointer to the previous Node. The
|
||||
// linked list is implemented as an array to allow the same iterator
|
||||
// class for forward and reversed Node lists. Taken together, this
|
||||
// list also represents a topological sort of the Nodes in the Graph.
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-non-private-member-variables-in-classes,modernize-avoid-c-arrays)
|
||||
Node* next_in_graph[2] = {nullptr, nullptr};
|
||||
|
||||
@ -980,7 +982,6 @@ struct TORCH_API Node {
|
||||
// subclasses should extend if they have additional information to copy.
|
||||
// 'this' will be allocated with s->allocNewInstance(g) so it should have
|
||||
// the same concrete type as 's'
|
||||
//
|
||||
virtual void cloneFrom(Node* s);
|
||||
};
|
||||
|
||||
@ -1247,7 +1248,7 @@ struct Graph {
|
||||
TORCH_API Node* createEnumName(Value* e);
|
||||
TORCH_API Node* createEnumValue(Value* e);
|
||||
TORCH_API Node* createList(
|
||||
const TypePtr& elem_type,
|
||||
const TypePtr& contained_type,
|
||||
at::ArrayRef<Value*> values);
|
||||
TORCH_API Node* createListUnpack(Value* v, size_t size);
|
||||
TORCH_API Node* createDict(
|
||||
|
@ -42,6 +42,17 @@ class TypeParser {
|
||||
return simpleTypeIt->second;
|
||||
} else if (token == "List") {
|
||||
return CreateSingleElementType<ListType>();
|
||||
} else if (token == "Union") {
|
||||
std::vector<TypePtr> types;
|
||||
expect("[");
|
||||
while (cur() != "]") {
|
||||
types.emplace_back(parse());
|
||||
if (cur() != "]") {
|
||||
expect(",");
|
||||
}
|
||||
}
|
||||
expect("]");
|
||||
return UnionType::create(types);
|
||||
} else if (token == "Optional") {
|
||||
return CreateSingleElementType<OptionalType>();
|
||||
} else if (token == "Future") {
|
||||
|
@ -288,6 +288,24 @@ class ShapePropagator {
|
||||
return zerodim;
|
||||
}
|
||||
|
||||
bool mergeTypes(
|
||||
ArrayRef<Value*> lhs,
|
||||
ArrayRef<Value*> rhs,
|
||||
ArrayRef<Value*> outputs) {
|
||||
AT_ASSERT(lhs.size() == rhs.size() && rhs.size() == outputs.size());
|
||||
bool changed = false;
|
||||
for (size_t i = 0; i < lhs.size(); ++i) {
|
||||
auto old_output_type = outputs[i]->type();
|
||||
auto new_type =
|
||||
unifyTypes(lhs[i]->type(), rhs[i]->type(), /*default_to_union=*/true);
|
||||
AT_ASSERT(new_type);
|
||||
outputs[i]->setType(*new_type);
|
||||
if (*old_output_type != *outputs[i]->type())
|
||||
changed = true;
|
||||
}
|
||||
return changed;
|
||||
}
|
||||
|
||||
void broadcastBinary(
|
||||
Node* node,
|
||||
std::vector<TensorTypePtr>& types,
|
||||
|
@ -8,6 +8,7 @@
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
namespace {
|
||||
|
||||
void makePointerToImpl(Element* from, Element* to) {
|
||||
from->pointsTo.set(to->index);
|
||||
to->pointedFrom.set(from->index);
|
||||
@ -131,11 +132,13 @@ Element* MemoryDAGBuilder::makeFreshValue(const Value* v) {
|
||||
return makeFreshValueImpl(v, indexToElementMap_);
|
||||
}
|
||||
|
||||
// This function builds up a bitset representing the "alias set" for
|
||||
// `e` (`MemoryLocations` is just a typedef'd c10::SparseBitVector).
|
||||
const MemoryLocations& MemoryDAG::getMemoryLocations(const Element* e) const {
|
||||
// Note on cache invalidation: all mutation should occur through
|
||||
// MemoryDAGBuilder. Thus, once we consume the builder to create an immutable
|
||||
// MemoryDAG, we can cache here without worrying that we might potentially get
|
||||
// invalidated.
|
||||
// MemoryDAGBuilder. Thus, once we consume the builder to create an
|
||||
// immutable MemoryDAG, we can cache here without worrying that we
|
||||
// might potentially get invalidated.
|
||||
if (e->cachedMemoryLocations_) {
|
||||
return *e->cachedMemoryLocations_;
|
||||
}
|
||||
@ -174,7 +177,6 @@ void MemoryDAG::setWildcards(
|
||||
makePointerToImpl(from, wildcardElement);
|
||||
}
|
||||
}
|
||||
|
||||
// Track which memory locations we edited with a new pointer to the wildcard
|
||||
// element.
|
||||
cacheUpdates[wildcardElement] |= pointeeSet;
|
||||
@ -189,7 +191,6 @@ void MemoryDAG::setWildcards(
|
||||
for (const std::unique_ptr<Element>& e : this->indexToElementMap_) {
|
||||
if (e->values.empty()) {
|
||||
// This element is a wildcard element, we can skip it.
|
||||
TORCH_INTERNAL_ASSERT(e->pointsTo.empty());
|
||||
continue;
|
||||
}
|
||||
|
||||
|
@ -1,9 +1,12 @@
|
||||
#pragma once
|
||||
|
||||
#include <ATen/core/jit_type.h>
|
||||
#include <c10/util/ArrayRef.h>
|
||||
#include <c10/util/Optional.h>
|
||||
#include <c10/util/flat_hash_map.h>
|
||||
#include <c10/util/sparse_bitset.h>
|
||||
#include <torch/csrc/jit/ir/ir.h>
|
||||
#include <torch/csrc/jit/ir/type_hashing.h>
|
||||
#include <memory>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
@ -20,6 +23,9 @@ struct Element;
|
||||
struct Value;
|
||||
class MemoryDAG;
|
||||
|
||||
using TypePtr = std::shared_ptr<c10::Type>;
|
||||
using AliasTypeSet = std::vector<TypePtr>;
|
||||
|
||||
/**
|
||||
* Helper to build up the points-to graph.
|
||||
*
|
||||
@ -38,13 +44,15 @@ class TORCH_API MemoryDAGBuilder {
|
||||
|
||||
void addToContainedElements(Element* contained, Element* container);
|
||||
|
||||
// Make a fresh element (i.e. an element that doesn't point to anything) and
|
||||
// Make a fresh Element (i.e. an Element that doesn't point to anything) and
|
||||
// return it.
|
||||
Element* makeFreshValue(const Value* v);
|
||||
|
||||
friend MemoryDAG;
|
||||
|
||||
private:
|
||||
// `MemoryDAGBuilder` builds up `indexToElementMap_`, then uses
|
||||
// the map to construct the `MemoryDAG`
|
||||
std::vector<std::unique_ptr<Element>> indexToElementMap_;
|
||||
};
|
||||
|
||||
@ -54,8 +62,8 @@ class TORCH_API MemoryDAGBuilder {
|
||||
// AliasDb to provide a higher-level API.
|
||||
//
|
||||
// We maintain a DAG where:
|
||||
// - Vertices (called "elements") represent values and
|
||||
// other aliasing entities (e.g. like the stuff inside a list)
|
||||
// - Vertices (called "Elements") represent Values and
|
||||
// other aliasing entities (e.g. the stuff inside a list)
|
||||
// - Edges represent a "points-to" relationship.
|
||||
//
|
||||
// Leaves in this DAG are entities that don't point to anything, and thus
|
||||
@ -80,7 +88,7 @@ class TORCH_API MemoryDAG {
|
||||
bool mayAlias(const Element* a, const Element* b) const;
|
||||
bool mayAlias(Element* a, Element* b) const;
|
||||
|
||||
// Does a hold reference to any memory that is stored in elem, or vice versa?
|
||||
// Does `a` hold reference to any memory that is stored in `b`, or vice versa?
|
||||
bool mayContainAlias(const Element* a, const Element* b) const;
|
||||
bool mayContainAlias(Element* a, Element* b) const;
|
||||
|
||||
@ -96,12 +104,13 @@ class TORCH_API MemoryDAG {
|
||||
MemoryLocations& cont) const;
|
||||
|
||||
/**
|
||||
* The following methods are special cases where we need to reach mutate the
|
||||
* The following methods are special cases where we need to mutate the
|
||||
* internals of MemoryDAG for efficiency reasons. Don't call them unless you
|
||||
* know what you're doing! In particular, don't add new mutating methods
|
||||
* without ensuring that you are maintaining cache consistency for memory
|
||||
* locations.
|
||||
*/
|
||||
|
||||
// Adding wildcards can trigger extremely expensive cache invalidations. This
|
||||
// method adds them in a more efficient cache-aware way.
|
||||
void setWildcards(
|
||||
@ -117,9 +126,10 @@ class TORCH_API MemoryDAG {
|
||||
std::vector<std::unique_ptr<Element>> indexToElementMap_;
|
||||
};
|
||||
|
||||
// `Element` represents the vertex in the points-to graph. It represents
|
||||
// anything that could have an aliasing relationship, mostly IR `Value`s, but
|
||||
// also the "inside of a list", or wildcards.
|
||||
// `Element` represents a vertex in the points-to graph. It represents
|
||||
// anything that could have an aliasing relationship--mostly IR
|
||||
// `Value`s, but also wildcards or the type inside a container (e.g. `T`
|
||||
// in `List[T]`)
|
||||
struct Element {
|
||||
Element(const Value* value_, unsigned index_);
|
||||
// wildcard constructor
|
||||
|
@ -89,6 +89,19 @@ IValue toIValue(py::handle obj, const TypePtr& type, c10::optional<int32_t> N) {
|
||||
? c10::ivalue::Tuple::createNamed(std::move(values), tuple_type)
|
||||
: c10::ivalue::Tuple::create(std::move(values));
|
||||
}
|
||||
case TypeKind::UnionType: {
|
||||
auto actual_type = toTypeInferredIValue(obj);
|
||||
auto actual_type_ptr = actual_type.type();
|
||||
auto union_type = type->expect<UnionType>();
|
||||
if (!actual_type_ptr->isSubtypeOf(union_type)) {
|
||||
throw py::cast_error(c10::str(
|
||||
"Expected a member of ",
|
||||
union_type->annotation_str(),
|
||||
" but instead found type ",
|
||||
actual_type.type()->annotation_str()));
|
||||
}
|
||||
return actual_type;
|
||||
}
|
||||
case TypeKind::StringType:
|
||||
return ConstantString::create(py::cast<std::string>(obj));
|
||||
case TypeKind::DeviceObjType: {
|
||||
|
@ -869,6 +869,12 @@ void initPythonIRBindings(PyObject* module_) {
|
||||
}
|
||||
return types;
|
||||
});
|
||||
py::class_<UnionType, Type, std::shared_ptr<UnionType>>(m, "UnionType")
|
||||
.def(py::init(
|
||||
[](const std::vector<TypePtr>& a) { return UnionType::create(a); }))
|
||||
.def("containedTypes", [](UnionType& self) {
|
||||
return self.containedTypes().vec();
|
||||
});
|
||||
py::class_<ListType, Type, std::shared_ptr<ListType>>(m, "ListType")
|
||||
.def(py::init([](TypePtr a) { return ListType::create(a); }))
|
||||
.def_static("ofInts", &ListType::ofInts)
|
||||
|
@ -47,7 +47,8 @@ void postSetStateValidate(const IValue& v) {
|
||||
// const auto attrType = objType->getAttribute(i);
|
||||
// Verify that all the non-optional attributes have been initialized
|
||||
// TODO: Issue #20497
|
||||
if (attrType->kind() != TypeKind::OptionalType &&
|
||||
if (attrType->kind() != TypeKind::UnionType &&
|
||||
attrType->kind() != TypeKind::OptionalType &&
|
||||
attrType->kind() != TypeKind::NoneType) {
|
||||
TORCH_CHECK(
|
||||
!slot.isNone(),
|
||||
|
@ -482,12 +482,13 @@ void SourceImporterImpl::importClass(
|
||||
} break;
|
||||
case TK_DEF: {
|
||||
Def def = Def(statement);
|
||||
if (pre_hook_names.find(def.name().name()) != pre_hook_names.end()) {
|
||||
pre_hook_def_map.emplace(def.name().name(), def);
|
||||
pre_hook_resolver_map.emplace(def.name().name(), shared_from_this());
|
||||
} else if (hook_names.find(def.name().name()) != hook_names.end()) {
|
||||
hook_def_map.emplace(def.name().name(), def);
|
||||
hook_resolver_map.emplace(def.name().name(), shared_from_this());
|
||||
const auto def_name = def.name().name();
|
||||
if (pre_hook_names.find(def_name) != pre_hook_names.end()) {
|
||||
pre_hook_def_map.emplace(def_name, def);
|
||||
pre_hook_resolver_map.emplace(def_name, shared_from_this());
|
||||
} else if (hook_names.find(def_name) != hook_names.end()) {
|
||||
hook_def_map.emplace(def_name, def);
|
||||
hook_resolver_map.emplace(def_name, shared_from_this());
|
||||
} else {
|
||||
methods.emplace_back(def);
|
||||
method_resolvers.push_back(shared_from_this());
|
||||
|
@ -511,14 +511,32 @@ struct PythonPrintImpl {
|
||||
}
|
||||
indent();
|
||||
printValueList(body_, lhs);
|
||||
// We need to preserve Union/Optional type annotations, but only if
|
||||
// we're not assigning values as part of a tuple unpacking statement
|
||||
// (Python doesn't allow type annotations in multiple assignment)
|
||||
if (lhs.size() == 1) {
|
||||
Value* v = lhs.at(0);
|
||||
if (!annotated_unions_.count(v) && !expr_table_.count(v) &&
|
||||
(v->type()->kind() == UnionType::Kind ||
|
||||
v->type()->kind() == OptionalType::Kind)) {
|
||||
body_ << " : " << v->type()->annotation_str();
|
||||
annotated_unions_.insert(v);
|
||||
}
|
||||
}
|
||||
body_ << " = ";
|
||||
// or if value is being assigned to something of a union type
|
||||
printValueList(body_, rhs);
|
||||
body_ << "\n";
|
||||
}
|
||||
|
||||
bool requiresAnnotation(Value* lhs, Value* rhs) {
|
||||
if (lhs->type()->kind() == UnionType::Kind ||
|
||||
lhs->type()->kind() == OptionalType::Kind) {
|
||||
return annotated_unions_.insert(lhs).second;
|
||||
} else {
|
||||
return *lhs->type() != *rhs->type();
|
||||
}
|
||||
}
|
||||
|
||||
void printAnnotatedAssignment(
|
||||
at::ArrayRef<Value*> lhs,
|
||||
@ -1302,10 +1320,12 @@ struct PythonPrintImpl {
|
||||
body_ << arg_name;
|
||||
if (print_first_argument_type) {
|
||||
body_ << ": " << arg.type()->annotation_str(type_printer_);
|
||||
annotated_unions_.insert(*param_it);
|
||||
}
|
||||
} else {
|
||||
body_ << ",\n " << arg_name << ": "
|
||||
<< arg.type()->annotation_str(type_printer_);
|
||||
annotated_unions_.insert(*param_it);
|
||||
}
|
||||
if (arg.default_value()) {
|
||||
printDefaultValue(arg, body_, *arg.default_value());
|
||||
@ -1559,6 +1579,12 @@ struct PythonPrintImpl {
|
||||
// table.
|
||||
PrintDepsTable& deps_table_;
|
||||
|
||||
// We need to preserve Union/Optional type annotations, but we should
|
||||
// only print the annotation on variable declaration (not on any
|
||||
// following uses). This set tracks the Value*s that we've already
|
||||
// printed with annotations
|
||||
std::unordered_set<Value*> annotated_unions_;
|
||||
|
||||
// A function that, given a named type, returns us the correct string to print
|
||||
// for it.
|
||||
c10::TypePrinter type_printer_;
|
||||
|
@ -23,8 +23,8 @@ static void restoreAccurateTypeTagsIfPossible(const IValue& root) {
|
||||
|
||||
// Pickled objects are stored in a form compatible with Python pickling.
|
||||
// In torchscript List[T]/Dict[K, V] are statically typed and contain
|
||||
// dynamic type tags allow T, K, and V to be recovered. But this info
|
||||
// is not stored in the Python pickling information. However, we
|
||||
// dynamic type tags that allow T, K, and V to be recovered. But this
|
||||
// info is not stored in the Python pickling information. However, we
|
||||
// can recover this information from the static type of the top-level
|
||||
// object being unpickled, because we have a record of the type of the
|
||||
// objects it contains as attributes.
|
||||
@ -108,6 +108,19 @@ void restoreAccurateTypeTags(const IValue& root, const TypePtr& type_tag) {
|
||||
to_process.emplace_back(std::move(elem));
|
||||
}
|
||||
} break;
|
||||
case UnionType::Kind: {
|
||||
auto t = w.static_type->expect<UnionType>();
|
||||
if (t->containedTypes().size() == 2 &&
|
||||
t->canHoldType(NoneType::get())) {
|
||||
if (!w.value.isNone()) {
|
||||
auto inner = t->containedTypes()[0] != NoneType::get()
|
||||
? t->containedTypes()[0]
|
||||
: t->containedTypes()[1];
|
||||
Work elem = {inner, w.value};
|
||||
to_process.emplace_back(std::move(elem));
|
||||
}
|
||||
}
|
||||
} break;
|
||||
case ListType::Kind: {
|
||||
// specialized lists do not need their type refined, so we can exit
|
||||
// early here
|
||||
|
@ -449,7 +449,7 @@ if _enabled:
|
||||
setattr(RecursiveScriptClass, method_name, method_template)
|
||||
|
||||
# this is a Python 'non-data descriptor' that causes the first access
|
||||
# to ScriptModule's forward to lookup the forward method and stash
|
||||
# to ScriptModule's forward to look up the forward method and stash
|
||||
# it in the objects dict. Due to the standard rules for attribute lookup,
|
||||
# subsequent lookups will just directly return the previously looked up method.
|
||||
# This is necessary because nn.Module defines forward as a method. If we
|
||||
|
@ -6,13 +6,13 @@ import builtins
|
||||
import torch
|
||||
import warnings
|
||||
from .._jit_internal import List, Tuple, is_tuple, is_list, Dict, is_dict, Optional, \
|
||||
is_optional, _qualified_name, Any, Future, is_future, is_ignored_fn
|
||||
is_optional, _qualified_name, Any, Future, is_future, is_ignored_fn, Union, is_union
|
||||
from .._jit_internal import BroadcastingList1, BroadcastingList2, BroadcastingList3 # type: ignore[attr-defined]
|
||||
from ._state import _get_script_class
|
||||
|
||||
from torch._C import TensorType, TupleType, FloatType, IntType, ComplexType, \
|
||||
ListType, StringType, DictType, BoolType, OptionalType, InterfaceType, AnyType, NoneType, \
|
||||
DeviceObjType, StreamObjType, FutureType, EnumType
|
||||
ListType, StringType, DictType, BoolType, OptionalType, InterfaceType, AnyType, \
|
||||
NoneType, DeviceObjType, StreamObjType, FutureType, EnumType, UnionType
|
||||
|
||||
|
||||
from textwrap import dedent
|
||||
@ -45,7 +45,8 @@ class EvalEnv(object):
|
||||
'List': List,
|
||||
'Dict': Dict,
|
||||
'Optional': Optional,
|
||||
'Future': Future,
|
||||
'Union': Union,
|
||||
'Future': Future
|
||||
}
|
||||
|
||||
def __init__(self, rcb):
|
||||
@ -245,6 +246,9 @@ def split_type_line(type_line):
|
||||
def try_real_annotations(fn, loc):
|
||||
"""Tries to use the Py3.5+ annotation syntax to get the type."""
|
||||
try:
|
||||
# Note: anything annotated as `Optional[T]` will automatically
|
||||
# be returned as `Union[T, None]` per
|
||||
# https://github.com/python/typing/blob/master/src/typing.py#L850
|
||||
sig = inspect.signature(fn)
|
||||
except ValueError:
|
||||
return None
|
||||
@ -276,7 +280,6 @@ def get_enum_value_type(e: Type[enum.Enum], loc):
|
||||
return torch._C.unify_type_list(ir_types)
|
||||
|
||||
def is_tensor(ann):
|
||||
|
||||
if issubclass(ann, torch.Tensor):
|
||||
return True
|
||||
|
||||
@ -326,6 +329,19 @@ def try_ann_to_type(ann, loc):
|
||||
msg = "Unsupported annotation {} could not be resolved because {} could not be resolved."
|
||||
assert valid_type, msg.format(repr(ann), repr(contained))
|
||||
return OptionalType(valid_type)
|
||||
if is_union(ann):
|
||||
inner: List = []
|
||||
# We need these extra checks because both `None` and invalid
|
||||
# values will return `None`
|
||||
# TODO: Determine if the other cases need to be fixed as well
|
||||
for a in ann.__args__:
|
||||
if a is None:
|
||||
inner.append(NoneType.get())
|
||||
maybe_type = try_ann_to_type(a, loc)
|
||||
msg = "Unsupported annotation {} could not be resolved because {} could not be resolved."
|
||||
assert maybe_type, msg.format(repr(ann), repr(maybe_type))
|
||||
inner.append(maybe_type)
|
||||
return UnionType(inner) # type: ignore[arg-type]
|
||||
if torch.distributed.rpc.is_available() and is_rref(ann):
|
||||
return RRefType(try_ann_to_type(ann.__args__[0], loc))
|
||||
if is_future(ann):
|
||||
@ -390,6 +406,8 @@ __all__ = [
|
||||
'is_list',
|
||||
'Dict',
|
||||
'is_dict',
|
||||
'is_optional',
|
||||
'is_union',
|
||||
'TensorType',
|
||||
'TupleType',
|
||||
'FloatType',
|
||||
|
@ -452,6 +452,7 @@ def get_default_args(fn):
|
||||
return {}
|
||||
|
||||
signature = inspect.signature(fn)
|
||||
|
||||
return {
|
||||
k: v.default
|
||||
for k, v in signature.parameters.items()
|
||||
|
Reference in New Issue
Block a user