diff --git a/aten/src/ATen/BatchedFallback.cpp b/aten/src/ATen/BatchedFallback.cpp index 1eb0b4abdc70..ebc1396ca35e 100644 --- a/aten/src/ATen/BatchedFallback.cpp +++ b/aten/src/ATen/BatchedFallback.cpp @@ -36,7 +36,7 @@ static bool areAnyArgumentsTensorList(const FunctionSchema& schema) { return std::any_of( schema.arguments().begin(), schema.arguments().end(), - [] (const Argument& arg) { return arg.type()->isSubtypeOf(ListType::ofTensors()); }); + [] (const Argument& arg) { return arg.type()->isSubtypeOf(*ListType::ofTensors()); }); } // Returns if an operator is in-place. An operator is inplace if: diff --git a/aten/src/ATen/core/List_inl.h b/aten/src/ATen/core/List_inl.h index fa6d5c01f26e..1253ec9265c5 100644 --- a/aten/src/ATen/core/List_inl.h +++ b/aten/src/ATen/core/List_inl.h @@ -63,7 +63,7 @@ List toTypedList(impl::GenericList list) { // as List before we changed that argument to be List>. When deserializing, we // have list.use_count() == 1 and can deserialize the List directly as List>. TORCH_CHECK(*list.impl_->elementType == *getTypePtr() - || (list.use_count() == 1 && list.impl_->elementType->isSubtypeOf(getTypePtr())) + || (list.use_count() == 1 && list.impl_->elementType->isSubtypeOf(*getTypePtr())) , "Tried to cast a List<", toString(list.impl_->elementType), "> to a List<", toString(getTypePtr()), ">. Types mismatch."); return List(std::move(list.impl_)); } diff --git a/aten/src/ATen/core/dispatch/DispatchKeyExtractor.h b/aten/src/ATen/core/dispatch/DispatchKeyExtractor.h index ef1780ffa72b..5289f9fa0114 100644 --- a/aten/src/ATen/core/dispatch/DispatchKeyExtractor.h +++ b/aten/src/ATen/core/dispatch/DispatchKeyExtractor.h @@ -172,13 +172,13 @@ private: " arguments but this PyTorch build only supports ", c10::utils::bitset::NUM_BITS()); c10::utils::bitset dispatch_arg_indices_reverse; for (size_t index = 0; index < schema.arguments().size(); ++index) { - if (schema.arguments()[index].type()->isSubtypeOf(TensorType::get()) || + if (schema.arguments()[index].type()->isSubtypeOf(*TensorType::get()) || schema.arguments()[index].type()->isSubtypeOf( - ListType::ofTensors()) || + *ListType::ofTensors()) || schema.arguments()[index].type()->isSubtypeOf( - ListType::ofOptionalTensors()) || + *ListType::ofOptionalTensors()) || schema.arguments()[index].type()->isSubtypeOf( - OptionalType::ofTensor())) { + *OptionalType::ofTensor())) { dispatch_arg_indices_reverse.set(schema.arguments().size() - 1 - index); } } diff --git a/aten/src/ATen/core/function_schema_inl.h b/aten/src/ATen/core/function_schema_inl.h index 3769e8a10704..712192b08230 100644 --- a/aten/src/ATen/core/function_schema_inl.h +++ b/aten/src/ATen/core/function_schema_inl.h @@ -76,7 +76,7 @@ inline bool Argument::isBackwardCompatibleWith( if (lhs->kwarg_only() && !rhs->kwarg_only()) { return false; } - if (!rhs->type()->isSubtypeOfExt(lhs->type(), why_not)) { + if (!rhs->type()->isSubtypeOfExt(*lhs->type(), why_not)) { return false; } if (rhs->default_value().has_value() && @@ -179,7 +179,7 @@ inline void FunctionSchema::checkArg( // Fast-path for the common case return; } - if (!value.type()->isSubtypeOf(argument.type())) { + if (!value.type()->isSubtypeOf(*argument.type())) { TORCH_CHECK( false, formatTypeMismatchMsg( @@ -304,7 +304,7 @@ inline bool isSubtypeOfList( if (c.name() != p.name()) { return false; } - if (!c.type()->isSubtypeOfExt(p.type(), why_not)) { + if (!c.type()->isSubtypeOfExt(*p.type(), why_not)) { return false; } } diff --git a/aten/src/ATen/core/jit_type.h b/aten/src/ATen/core/jit_type.h index 
e07d64a1c829..d78dd30c885f 100644 --- a/aten/src/ATen/core/jit_type.h +++ b/aten/src/ATen/core/jit_type.h @@ -46,7 +46,7 @@ struct TORCH_API AnyType : public Type { } static const TypeKind Kind = TypeKind::AnyType; // global singleton - static AnyTypePtr get(); + static const AnyTypePtr& get(); private: AnyType() : Type(TypeKind::AnyType) {} @@ -66,7 +66,7 @@ template struct SingleElementType : public Type { static const TypeKind Kind = K; - TypePtr getElementType() const { + const TypePtr& getElementType() const { return elem; } @@ -104,7 +104,7 @@ struct TORCH_API UnionType : public Type { static const TypeKind Kind = TypeKind::UnionType; - bool isSubtypeOfExt(const TypePtr& rhs_, std::ostream* why_not) const override; + bool isSubtypeOfExt(const Type& rhs_, std::ostream* why_not) const override; std::string str() const override; @@ -169,7 +169,7 @@ struct TORCH_API OptionalType : public UnionType { bool operator==(const Type& rhs) const override; - TypePtr getElementType() const { + const TypePtr& getElementType() const { return contained_; } @@ -189,7 +189,7 @@ struct TORCH_API OptionalType : public UnionType { return create(contained_types[0]); } - bool isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const override; + bool isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const override; // common cast Optional[Tensor] for undefined tensor type static OptionalTypePtr ofTensor(); @@ -578,7 +578,7 @@ struct TORCH_API TensorType : public Type { } bool operator==(const Type& rhs) const override; - bool isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const override; + bool isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const override; std::string str() const override; @@ -710,7 +710,7 @@ struct TORCH_API TensorType : public Type { c10::optional undefined() const { return undefined_; } - static TensorTypePtr get(); + static const TensorTypePtr& get(); static const TypeKind Kind = TypeKind::TensorType; @@ -788,7 +788,7 @@ struct TORCH_API ListType return create(contained_types.at(0)); } - bool isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const override; + bool isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const override; // common cast List[Tensor] static ListTypePtr ofTensors(); @@ -914,12 +914,12 @@ struct TORCH_API FutureType return create(contained_types.at(0)); } - bool isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const override { + bool isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const override { if (Type::isSubtypeOfExt(rhs, why_not)) { return true; } - if (auto rhs_ = rhs->cast()) { - return getElementType()->isSubtypeOfExt(rhs_->getElementType(), why_not); + if (auto rhs_ = rhs.castRaw()) { + return getElementType()->isSubtypeOfExt(*rhs_->getElementType(), why_not); } return false; } @@ -1034,7 +1034,7 @@ struct TORCH_API TupleType : public NamedType { } bool operator==(const Type& rhs) const override; - bool isSubtypeOfExt(const TypePtr& rhs_, std::ostream* why_not) const override; + bool isSubtypeOfExt(const Type& rhs_, std::ostream* why_not) const override; std::string str() const override; bool hasFreeVariables() const override { @@ -1129,7 +1129,7 @@ struct TORCH_API EnumType : public NamedType { return false; } - bool isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const override; + bool isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const override; std::shared_ptr compilation_unit() const { auto cu = cu_.lock(); @@ -1182,7 +1182,7 @@ struct TORCH_API AnyEnumType : public Type { } 
static const TypeKind Kind = TypeKind::AnyEnumType; // global singleton - static AnyEnumTypePtr get(); + static const AnyEnumTypePtr& get(); private: AnyEnumType() : Type(TypeKind::AnyEnumType) {} @@ -1198,14 +1198,14 @@ using NumberTypePtr = std::shared_ptr; struct TORCH_API NumberType : public Type { bool operator==(const Type& rhs) const override; - bool isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const override; + bool isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const override; std::string str() const override { return "Scalar"; // match what PythonArgParser says for clarity } static const TypeKind Kind = TypeKind::NumberType; // global singleton - static NumberTypePtr get(); + static const NumberTypePtr& get(); protected: NumberType(TypeKind kind = TypeKind::NumberType) : Type(kind) {} @@ -1227,13 +1227,13 @@ struct TORCH_API FloatType : public NumberType { std::string str() const override { return "float"; } - bool isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const override { + bool isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const override { // NOLINTNEXTLINE(bugprone-parent-virtual-call) - return rhs->kind() == TypeKind::NumberType || Type::isSubtypeOfExt(rhs, why_not); + return rhs.kind() == TypeKind::NumberType || Type::isSubtypeOfExt(rhs, why_not); } static const TypeKind Kind = TypeKind::FloatType; // global singleton - static FloatTypePtr get(); + static const FloatTypePtr& get(); private: FloatType() : NumberType(TypeKind::FloatType) {} @@ -1252,13 +1252,13 @@ struct TORCH_API ComplexType : public NumberType { std::string str() const override { return "complex"; } - bool isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const override { + bool isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const override { // NOLINTNEXTLINE(bugprone-parent-virtual-call) - return rhs->kind() == TypeKind::NumberType || Type::isSubtypeOfExt(rhs, why_not); + return rhs.kind() == TypeKind::NumberType || Type::isSubtypeOfExt(rhs, why_not); } static const TypeKind Kind = TypeKind::ComplexType; // global singleton - static ComplexTypePtr get(); + static const ComplexTypePtr& get(); private: ComplexType() : NumberType(TypeKind::ComplexType) {} @@ -1277,13 +1277,13 @@ struct TORCH_API IntType : public NumberType { std::string str() const override { return "int"; } - bool isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const override { + bool isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const override { // NOLINTNEXTLINE(bugprone-parent-virtual-call) - return rhs->kind() == TypeKind::NumberType || Type::isSubtypeOfExt(rhs, why_not); + return rhs.kind() == TypeKind::NumberType || Type::isSubtypeOfExt(rhs, why_not); } static const TypeKind Kind = TypeKind::IntType; // global singleton - static IntTypePtr get(); + static const IntTypePtr& get(); private: IntType() : NumberType(TypeKind::IntType) {} @@ -1304,7 +1304,7 @@ struct TORCH_API BoolType : public Type { } static const TypeKind Kind = TypeKind::BoolType; // global singleton - static BoolTypePtr get(); + static const BoolTypePtr& get(); private: BoolType() : Type(TypeKind::BoolType) {} @@ -1326,7 +1326,7 @@ struct TORCH_API StringType : public Type { } static const TypeKind Kind = TypeKind::StringType; // global singleton - static StringTypePtr get(); + static const StringTypePtr& get(); private: StringType() : Type(TypeKind::StringType) {} @@ -1346,7 +1346,7 @@ struct TORCH_API StorageType : public Type { } static const TypeKind Kind = TypeKind::StorageType; // 
global singleton - static StorageTypePtr get(); + static const StorageTypePtr& get(); private: StorageType() : Type(TypeKind::StorageType) {} @@ -1393,11 +1393,11 @@ struct TORCH_API NoneType : public Type { std::string str() const override { return "NoneType"; } - bool isSubtypeOfExt(const TypePtr& rhs, std::ostream *why_not) const override; + bool isSubtypeOfExt(const Type& rhs, std::ostream *why_not) const override; static const TypeKind Kind = TypeKind::NoneType; // global singleton - static NoneTypePtr get(); + static const NoneTypePtr& get(); private: NoneType() : Type(TypeKind::NoneType) {} @@ -1415,7 +1415,7 @@ struct TORCH_API GeneratorType : public Type { } static const TypeKind Kind = TypeKind::GeneratorType; // global singleton - static GeneratorTypePtr get(); + static const GeneratorTypePtr& get(); private: GeneratorType() : Type(TypeKind::GeneratorType) {} @@ -1433,7 +1433,7 @@ struct TORCH_API QuantizerType : public Type { } static const TypeKind Kind = TypeKind::QuantizerType; // global singleton - static QuantizerTypePtr get(); + static const QuantizerTypePtr& get(); private: QuantizerType() : Type(TypeKind::QuantizerType) {} @@ -1451,7 +1451,7 @@ struct TORCH_API QSchemeType : public Type { } static const TypeKind Kind = TypeKind::QSchemeType; // global singleton - static QSchemeTypePtr get(); + static const QSchemeTypePtr& get(); private: QSchemeType() : Type(TypeKind::QSchemeType) {} @@ -1469,7 +1469,7 @@ struct TORCH_API DeviceObjType : public Type { } static const TypeKind Kind = TypeKind::DeviceObjType; // global singleton - static DeviceObjTypePtr get(); + static const DeviceObjTypePtr& get(); private: DeviceObjType() : Type(TypeKind::DeviceObjType) {} @@ -1487,7 +1487,7 @@ struct TORCH_API StreamObjType : public Type { } static const TypeKind Kind = TypeKind::StreamObjType; // global singleton - static StreamObjTypePtr get(); + static const StreamObjTypePtr& get(); private: StreamObjType() : Type(TypeKind::StreamObjType) {} @@ -1533,7 +1533,7 @@ struct TORCH_API CapsuleType : public Type { } static const TypeKind Kind = TypeKind::CapsuleType; // global singleton - static CapsuleTypePtr get(); + static const CapsuleTypePtr& get(); private: CapsuleType() : Type(TypeKind::CapsuleType) {} @@ -1551,7 +1551,7 @@ struct TORCH_API PyObjectType : public Type { } static const TypeKind Kind = TypeKind::PyObjectType; // global singleton - static PyObjectTypePtr get(); + static const PyObjectTypePtr& get(); private: PyObjectType() : Type(TypeKind::PyObjectType) {} @@ -1589,18 +1589,18 @@ TORCH_API std::ostream& operator<<(std::ostream& os, const Stride& s); // Be careful with calls because this can be very slow. 
If calling this // on a graph, use `EraseShapeInformation` in shape_analysis.h inline TypePtr unshapedType(const TypePtr& type) { - if (type->isSubtypeOf(TensorType::get())) { + if (type->isSubtypeOf(*TensorType::get())) { return TensorType::get(); } return type->withContained(fmap(type->containedTypes(), unshapedType)); } inline TypePtr TensorType::fromNumberType(TypePtr typ) { - if (typ->isSubtypeOf(IntType::get())) { + if (typ->isSubtypeOf(*IntType::get())) { return TensorType::createContiguous(at::kLong, at::kCPU, {}); - } else if (typ->isSubtypeOf(FloatType::get())) { + } else if (typ->isSubtypeOf(*FloatType::get())) { return TensorType::createContiguous(at::kDouble, at::kCPU, {}); - } else if (typ->isSubtypeOf(BoolType::get())) { + } else if (typ->isSubtypeOf(*BoolType::get())) { return TensorType::createContiguous(at::kBool, at::kCPU, {}); } else if (typ->kind() == NumberType::Kind) { return TensorType::create(c10::nullopt, at::kCPU, {}, c10::nullopt); @@ -2013,7 +2013,7 @@ struct TORCH_API ClassType : public NamedType { return attributes_.size(); } - const TypePtr getAttribute(size_t slot) const { + TypePtr getAttribute(size_t slot) const { AT_ASSERT(slot < attributes_.size()); return attributes_.at(slot).getType(); } @@ -2104,7 +2104,7 @@ struct TORCH_API ClassType : public NamedType { "'"); TypePtr atype = getAttribute(*slot_idx); TORCH_CHECK( - ty->isSubtypeOf(atype), + ty->isSubtypeOf(*atype), ty->repr_str(), " is not compatible with the type ", atype->repr_str(), @@ -2198,7 +2198,7 @@ struct TORCH_API ClassType : public NamedType { auto ptr = ClassType::create(name(), compilation_unit_, is_module()); AT_ASSERT(numAttributes() == contained_types.size()); for(size_t i = 0; i < attributes_.size(); ++i) { - AT_ASSERT(attributes_[i].getType()->isSubtypeOf(contained_types[i])); + AT_ASSERT(attributes_[i].getType()->isSubtypeOf(*contained_types[i])); ptr->addAttribute(attributes_[i].getName(), contained_types[i]); } // Copy methods over @@ -2273,7 +2273,7 @@ struct TORCH_API ClassType : public NamedType { // These variants are not registered in the global class table. ClassTypePtr refine(at::ArrayRef refined_slots) const; - bool isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const override; + bool isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const override; static const TypeKind Kind = TypeKind::ClassType; @@ -2360,13 +2360,13 @@ struct TORCH_API InterfaceType : public NamedType { return std::string("InterfaceType<") + name()->name() + ">"; } - bool isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const override; + bool isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const override; // try to find a method of this interface, // returns nullptr if not found. 
const FunctionSchema* getMethod(const std::string& name) const; void addMethod(FunctionSchema schema); - const std::vector& methods() { + const std::vector& methods() const { return *methods_; } @@ -2414,7 +2414,7 @@ return "Layout"; } static const TypeKind Kind = TypeKind::LayoutType; // global singleton -static LayoutTypePtr get(); +static const LayoutTypePtr& get(); private: LayoutType() : EnumerationType() {} @@ -2429,7 +2429,7 @@ return "ScalarType"; } static const TypeKind Kind = TypeKind::ScalarTypeType; // global singleton -static ScalarTypeTypePtr get(); +static const ScalarTypeTypePtr& get(); private: ScalarTypeType() : EnumerationType() {} @@ -2448,7 +2448,7 @@ struct TORCH_API AnyListType : public Type { } static const TypeKind Kind = TypeKind::AnyListType; // global singleton - static AnyListTypePtr get(); + static const AnyListTypePtr& get(); private: AnyListType() : Type(TypeKind::AnyListType) {} @@ -2469,7 +2469,7 @@ struct TORCH_API AnyTupleType : public Type { static const TypeKind Kind = TypeKind::AnyTupleType; // global singleton - static AnyTupleTypePtr get(); + static const AnyTupleTypePtr& get(); private: AnyTupleType() : Type(TypeKind::AnyTupleType) {} @@ -2488,7 +2488,7 @@ struct TORCH_API AnyClassType : public Type { } static const TypeKind Kind = TypeKind::AnyClassType; // global singleton - static AnyClassTypePtr get(); + static const AnyClassTypePtr& get(); private: AnyClassType() : Type(TypeKind::AnyClassType) {} diff --git a/aten/src/ATen/core/jit_type_base.h b/aten/src/ATen/core/jit_type_base.h index a9be1e8d6865..b34788ba802d 100644 --- a/aten/src/ATen/core/jit_type_base.h +++ b/aten/src/ATen/core/jit_type_base.h @@ -87,11 +87,24 @@ struct TORCH_API Type : std::enable_shared_from_this { // This additional information should only contain details that are not obvious // from the annotation_str() that describes the type. For instance it is clear that `int <: str` is false // but not clear why `Foo <: InterfaceBar` might be false. - virtual bool isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const; + virtual bool isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const; virtual bool is_module() const; - bool isSubtypeOf(const TypePtr& rhs) const { + bool isSubtypeOf(const Type& rhs) const { return isSubtypeOfExt(rhs, nullptr); } + // Compatibility shims to accommodate existing code that passes shared_ptrs around. + // Ideally, we would just delete this, but it should be harmless. 
+ template + typename std::enable_if::value, bool>::type + isSubtypeOf(const std::shared_ptr& rhs) const { + return isSubtypeOf(*rhs); + } + + template + typename std::enable_if::value, bool>::type + isSubtypeOfExt(const std::shared_ptr& rhs, std::ostream* why_not) const { + return isSubtypeOfExt(*rhs, why_not); + } // How this type will appear in FunctionSchema declarations virtual std::string str() const = 0; diff --git a/aten/src/ATen/core/type.cpp b/aten/src/ATen/core/type.cpp index cad3969d200b..515afdca77b5 100644 --- a/aten/src/ATen/core/type.cpp +++ b/aten/src/ATen/core/type.cpp @@ -153,74 +153,74 @@ std::ostream& operator<<(std::ostream & out, const Type & t) { return out; } -AnyTypePtr AnyType::get() { +const AnyTypePtr& AnyType::get() { static AnyTypePtr value(new AnyType()); return value; } -TensorTypePtr TensorType::get() { +const TensorTypePtr& TensorType::get() { static auto value = TensorType::create( {}, {}, SymbolicShape(), VaryingShape{}, {}); return value; } -NumberTypePtr NumberType::get() { +const NumberTypePtr& NumberType::get() { static NumberTypePtr value(new NumberType()); return value; } -IntTypePtr IntType::get() { +const IntTypePtr& IntType::get() { static IntTypePtr value(new IntType()); return value; } -FloatTypePtr FloatType::get() { +const FloatTypePtr& FloatType::get() { static FloatTypePtr value(new FloatType()); return value; } -ComplexTypePtr ComplexType::get() { +const ComplexTypePtr& ComplexType::get() { static ComplexTypePtr value(new ComplexType()); return value; } -BoolTypePtr BoolType::get() { +const BoolTypePtr& BoolType::get() { static BoolTypePtr value(new BoolType()); return value; } -StorageTypePtr StorageType::get() { +const StorageTypePtr& StorageType::get() { static StorageTypePtr value(new StorageType()); return value; } -NoneTypePtr NoneType::get() { +const NoneTypePtr& NoneType::get() { static NoneTypePtr value(new NoneType()); return value; } -GeneratorTypePtr GeneratorType::get() { +const GeneratorTypePtr& GeneratorType::get() { static GeneratorTypePtr value(new GeneratorType()); return value; } -QuantizerTypePtr QuantizerType::get() { +const QuantizerTypePtr& QuantizerType::get() { static QuantizerTypePtr value(new QuantizerType()); return value; } -QSchemeTypePtr QSchemeType::get() { +const QSchemeTypePtr& QSchemeType::get() { static QSchemeTypePtr value(new QSchemeType()); return value; } -StringTypePtr StringType::get() { +const StringTypePtr& StringType::get() { static StringTypePtr value(new StringType()); return value; } -DeviceObjTypePtr DeviceObjType::get() { +const DeviceObjTypePtr& DeviceObjType::get() { static DeviceObjTypePtr value(new DeviceObjType()); return value; } -StreamObjTypePtr StreamObjType::get() { +const StreamObjTypePtr& StreamObjType::get() { static StreamObjTypePtr value(new StreamObjType()); return value; } -ScalarTypeTypePtr ScalarTypeType::get() { +const ScalarTypeTypePtr& ScalarTypeType::get() { static ScalarTypeTypePtr value(new ScalarTypeType()); return value; } -LayoutTypePtr LayoutType::get() { +const LayoutTypePtr& LayoutType::get() { static LayoutTypePtr value(new LayoutType()); return value; } @@ -228,11 +228,11 @@ OptionalTypePtr OptionalType::ofTensor() { static auto value = OptionalType::create(TensorType::get()); return value; } -PyObjectTypePtr PyObjectType::get() { +const PyObjectTypePtr& PyObjectType::get() { static PyObjectTypePtr value(new PyObjectType()); return value; } -CapsuleTypePtr CapsuleType::get() { +const CapsuleTypePtr& CapsuleType::get() { static CapsuleTypePtr value(new 
CapsuleType()); return value; } @@ -265,31 +265,31 @@ ListTypePtr ListType::ofStrings() { return value; } -AnyListTypePtr AnyListType::get() { +const AnyListTypePtr& AnyListType::get() { static AnyListTypePtr value(new AnyListType()); return value; } -AnyTupleTypePtr AnyTupleType::get() { +const AnyTupleTypePtr& AnyTupleType::get() { static AnyTupleTypePtr value(new AnyTupleType()); return value; } -AnyClassTypePtr AnyClassType::get() { +const AnyClassTypePtr& AnyClassType::get() { static AnyClassTypePtr value(new AnyClassType()); return value; } -AnyEnumTypePtr AnyEnumType::get() { +const AnyEnumTypePtr& AnyEnumType::get() { static AnyEnumTypePtr value(new AnyEnumType()); return value; } c10::optional unifyTypesImpl(const TypePtr& t1, const TypePtr& t2, bool default_to_union=false, TypePtr type_hint=nullptr) { // check direct subtyping relation - if (t1->isSubtypeOf(t2)) { + if (t1->isSubtypeOf(*t2)) { return t2; - } else if (t2->isSubtypeOf(t1)) { + } else if (t2->isSubtypeOf(*t1)) { return t1; } @@ -298,9 +298,9 @@ c10::optional unifyTypesImpl(const TypePtr& t1, const TypePtr& t2, bool return t1->expectRef().merge(*t2->expect()); } - if (t1->isSubtypeOf(NoneType::get()) && !t2->isSubtypeOf(NoneType::get())) { + if (t1->isSubtypeOf(*NoneType::get()) && !t2->isSubtypeOf(*NoneType::get())) { return OptionalType::create(t2); - } else if (t2->isSubtypeOf(NoneType::get()) && !t1->isSubtypeOf(NoneType::get())) { + } else if (t2->isSubtypeOf(*NoneType::get()) && !t1->isSubtypeOf(*NoneType::get())) { return OptionalType::create(t1); } @@ -351,16 +351,16 @@ c10::optional unifyTypesImpl(const TypePtr& t1, const TypePtr& t2, bool auto t1_unshaped = unshapedType(t1); auto t2_unshaped = unshapedType(t2); - if (t1_unshaped->isSubtypeOf(t2_unshaped)) { + if (t1_unshaped->isSubtypeOf(*t2_unshaped)) { return t2_unshaped; - } else if (t2_unshaped->isSubtypeOf(t1_unshaped)) { + } else if (t2_unshaped->isSubtypeOf(*t1_unshaped)) { return t1_unshaped; } // Check whether or not `type_hint` is a common parent. This case // could occur if we had two class types that had been annotated with // a common interface - if (type_hint && t1->isSubtypeOf(type_hint) && t2->isSubtypeOf(type_hint)) { + if (type_hint && t1->isSubtypeOf(*type_hint) && t2->isSubtypeOf(*type_hint)) { return type_hint; } @@ -505,7 +505,7 @@ MatchTypeReturn matchTypeVariables( // NOLINTNEXTLINE(performance-no-automatic-move) return optionedMatch; } - } else if (!actual->isSubtypeOf(NoneType::get())) { + } else if (!actual->isSubtypeOf(*NoneType::get())) { // If the actual type is a non-optional, allow matching to the formal if // its element type matches the actual. 
// Don't match None because it is already an optional (but one of @@ -594,19 +594,19 @@ const char * typeKindToString(TypeKind kind) { return ""; } -bool Type::isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const { - if (rhs->kind() == TypeKind::AnyType || *this == *rhs) { +bool Type::isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const { + if (rhs.kind() == TypeKind::AnyType || *this == rhs) { return true; } - if (auto opt_rhs = rhs->cast()) { - return this->isSubtypeOfExt(opt_rhs->getElementType(), why_not); + if (auto opt_rhs = rhs.castRaw()) { + return this->isSubtypeOfExt(*opt_rhs->getElementType(), why_not); } - if (auto union_rhs = rhs->cast()) { + if (auto union_rhs = rhs.castRaw()) { // Check if `this` is a subtype of any of the types within the Union return std::any_of(union_rhs->containedTypes().begin(), union_rhs->containedTypes().end(), - [&](TypePtr inner) { - return this->isSubtypeOfExt(inner, why_not); + [&](const TypePtr& inner) { + return this->isSubtypeOfExt(*inner, why_not); }); } return false; @@ -837,8 +837,8 @@ TupleTypePtr TupleType::createNamed(const c10::optional& qua field_types, qualName, schema)); // NOLINT(modernize-make-shared) } -bool NoneType::isSubtypeOfExt(const TypePtr& rhs, std::ostream *why_not) const { - if (rhs->kind() == OptionalType::Kind) { +bool NoneType::isSubtypeOfExt(const Type& rhs, std::ostream *why_not) const { + if (rhs.kind() == OptionalType::Kind) { return true; } return Type::isSubtypeOfExt(rhs, why_not); @@ -882,8 +882,8 @@ void filterDuplicateSubtypes(std::vector* types) { auto get_supertype = [](const TypePtr t1, const TypePtr t2) -> c10::optional { // We don't want nested Optionals. Also, prematurely unifying to // `Optional` could prevent us from coalescing other types - if ((t1->isSubtypeOf(NoneType::get()) && !t2->isSubtypeOf(NoneType::get())) - || (!t1->isSubtypeOf(NoneType::get()) && t2->isSubtypeOf(NoneType::get()))) { + if ((t1->isSubtypeOf(*NoneType::get()) && !t2->isSubtypeOf(*NoneType::get())) + || (!t1->isSubtypeOf(*NoneType::get()) && t2->isSubtypeOf(*NoneType::get()))) { return c10::nullopt; } else { return unifyTypes(t1, t2, /*default_to_union=*/false); @@ -1064,34 +1064,36 @@ bool UnionType::operator==(const Type& rhs) const { } } -bool UnionType::isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const { - std::vector rhs_types; - if (const auto union_rhs = rhs->cast()) { +bool UnionType::isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const { + std::vector rhs_types; + if (const auto union_rhs = rhs.cast()) { // Fast path - if (this->containedTypes() == rhs->containedTypes()) { + if (this->containedTypes() == rhs.containedTypes()) { return true; } - rhs_types = rhs->containedTypes().vec(); - } else if (const auto optional_rhs = rhs->cast()) { - rhs_types.push_back(NoneType::get()); + for (const auto& typePtr: rhs.containedTypes()) { + rhs_types.push_back(typePtr.get()); + } + } else if (const auto optional_rhs = rhs.cast()) { + rhs_types.push_back(NoneType::get().get()); if (optional_rhs->getElementType() == NumberType::get()) { - std::vector number_types{IntType::get(), FloatType::get(), ComplexType::get()}; + std::array number_types{IntType::get().get(), FloatType::get().get(), ComplexType::get().get()}; rhs_types.insert(rhs_types.end(), number_types.begin(), number_types.end()); } else { - rhs_types.push_back(optional_rhs->getElementType()); + rhs_types.push_back(optional_rhs->getElementType().get()); } - } else if (const auto number_rhs = rhs->cast()) { - std::vector 
number_types{IntType::get(), FloatType::get(), ComplexType::get()}; + } else if (const auto number_rhs = rhs.cast()) { + std::array number_types{IntType::get().get(), FloatType::get().get(), ComplexType::get().get()}; rhs_types.insert(rhs_types.end(), number_types.begin(), number_types.end()); } else { - rhs_types.push_back(rhs); + rhs_types.push_back(&rhs); } return std::all_of(this->containedTypes().begin(), this->containedTypes().end(), - [&](TypePtr lhs_type) -> bool { + [&](const TypePtr& lhs_type) -> bool { return std::any_of(rhs_types.begin(), rhs_types.end(), - [&](TypePtr rhs_type) -> bool { - return lhs_type->isSubtypeOfExt(rhs_type, why_not); + [&](const Type* rhs_type) -> bool { + return lhs_type->isSubtypeOfExt(*rhs_type, why_not); }); }); } @@ -1157,8 +1159,8 @@ bool UnionType::canHoldType(TypePtr type) const { && canHoldType(ComplexType::get()); } else { return std::any_of(this->containedTypes().begin(), this->containedTypes().end(), - [&](TypePtr inner) { - return type->isSubtypeOf(inner); + [&](const TypePtr& inner) { + return type->isSubtypeOf(*inner); }); } } @@ -1184,10 +1186,10 @@ c10::optional UnionType::subtractTypeSet(std::vector& to_subtr // Given a TypePtr `lhs`, this function says whether or not `lhs` (or // one of its parent types) is in the `to_subtract` vector - auto should_subtract = [&](TypePtr lhs) -> bool { + auto should_subtract = [&](const TypePtr& lhs) -> bool { return std::any_of(to_subtract.begin(), to_subtract.end(), - [&](TypePtr rhs) { - return lhs->isSubtypeOf(rhs); + [&](const TypePtr& rhs) { + return lhs->isSubtypeOf(*rhs); }); }; @@ -1195,7 +1197,7 @@ c10::optional UnionType::subtractTypeSet(std::vector& to_subtr // vector std::copy_if(this->containedTypes().begin(), this->containedTypes().end(), std::back_inserter(types), - [&](const TypePtr t) { + [&](const TypePtr& t) { return !should_subtract(t); }); @@ -1245,18 +1247,18 @@ bool OptionalType::operator==(const Type& rhs) const { } } -bool OptionalType::isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const { - if (OptionalTypePtr optional_rhs = rhs->cast()) { - return getElementType()->isSubtypeOfExt(optional_rhs->getElementType(), why_not); - } else if (UnionTypePtr union_rhs = rhs->cast()) { +bool OptionalType::isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const { + if (auto optional_rhs = rhs.castRaw()) { + return getElementType()->isSubtypeOfExt(*optional_rhs->getElementType(), why_not); + } else if (auto union_rhs = rhs.castRaw()) { if (!union_rhs->canHoldType(NoneType::get())) { if (why_not) { - *why_not << rhs->repr_str() << " cannot hold None"; + *why_not << rhs.repr_str() << " cannot hold None"; } return false; } else if (!union_rhs->canHoldType(this->getElementType())) { if (why_not) { - *why_not << rhs->repr_str() << " cannot hold " << this->getElementType(); + *why_not << rhs.repr_str() << " cannot hold " << this->getElementType(); } return false; } else { @@ -1276,8 +1278,8 @@ bool NumberType::operator==(const Type& rhs) const { } } -bool NumberType::isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const { - if (auto union_type = rhs->cast()) { +bool NumberType::isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const { + if (auto union_type = rhs.cast()) { return union_type->canHoldType(NumberType::get()); } else { return Type::isSubtypeOfExt(rhs, why_not); @@ -1305,14 +1307,14 @@ TupleType::TupleType( } } -bool TupleType::isSubtypeOfExt(const TypePtr& rhs_, std::ostream* why_not) const { +bool TupleType::isSubtypeOfExt(const Type& rhs_, 
std::ostream* why_not) const { if (Type::isSubtypeOfExt(rhs_, why_not)) { return true; } - if (rhs_->kind() == AnyTupleType::Kind) { + if (rhs_.kind() == AnyTupleType::Kind) { return true; } - auto rhs = rhs_->cast(); + auto rhs = rhs_.cast(); if (!rhs) return false; // unnamed tuple is not a subtype of nametuple @@ -1336,15 +1338,15 @@ bool TupleType::isSubtypeOfExt(const TypePtr& rhs_, std::ostream* why_not) const bool names_match = !rhs->schema() || test_names_match(schema(), rhs->schema()); // co-variant rules for tuples return names_match && compare(*rhs, [&](const TypePtr a, const TypePtr b) { - return a->isSubtypeOfExt(b, why_not); + return a->isSubtypeOfExt(*b, why_not); }); } -bool ListType::isSubtypeOfExt(const TypePtr& rhs_, std::ostream* why_not) const { +bool ListType::isSubtypeOfExt(const Type& rhs_, std::ostream* why_not) const { if (Type::isSubtypeOfExt(rhs_, why_not)) { return true; } - if (rhs_->kind() == AnyListType::Kind) { + if (rhs_.kind() == AnyListType::Kind) { return true; } return false; @@ -1622,8 +1624,8 @@ const SymbolicShape& TensorType::symbolic_sizes() const { return sizes_; } -bool TensorType::isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const { - if (auto rhs_p = rhs->cast()) { +bool TensorType::isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const { + if (auto rhs_p = rhs.cast()) { // if we have the same pointer, avoid computing the merge if (this == rhs_p.get()) { return true; @@ -2034,7 +2036,7 @@ ClassTypePtr ClassType::refine(at::ArrayRef refined_slots) const { auto ptr = ClassType::create(name(), compilation_unit_, is_module()); AT_ASSERT(numAttributes() == refined_slots.size()); for (size_t i = 0; i < attributes_.size(); ++i) { - AT_ASSERT(refined_slots[i]->isSubtypeOf(attributes_[i].getType())); + AT_ASSERT(refined_slots[i]->isSubtypeOf(*attributes_[i].getType())); ptr->addAttribute(attributes_[i].getName(), refined_slots[i], (attributes_[i].getKind() == AttributeKind::PARAMETER), (attributes_[i].getKind() == AttributeKind::BUFFER)); } @@ -2045,18 +2047,18 @@ ClassTypePtr ClassType::refine(at::ArrayRef refined_slots) const { return ptr; } -bool ClassType::isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const { - if (rhs->cast()) { +bool ClassType::isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const { + if (rhs.castRaw()) { return true; } // to improve performance, this check can be cached - if (auto iface = rhs->cast()) { + if (auto iface = rhs.cast()) { // ClassType is not a subtype of InterfaceType if the InterfaceType is a // Module Interface Type but the Class Type is not a Module Class Type if (!is_module() && iface->is_module()) { if (why_not) { *why_not << "Class '" << repr_str() << "' is not a subtype of " - << "the module interface '" << rhs->repr_str() + << "the module interface '" << rhs.repr_str() << "' , only ScriptModule class can be subtype of module" << " interface.\n"; } @@ -2067,7 +2069,7 @@ bool ClassType::isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const if (!self_method) { if (why_not) { *why_not << "Class '" << repr_str() << "' does not have method '" - << schema.name() << "' but '" << rhs->repr_str() + << schema.name() << "' but '" << rhs.repr_str() << "' does.\n"; } return false; @@ -2078,7 +2080,7 @@ bool ClassType::isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const if (why_not) { *why_not << "Method on class '" << repr_str() << "' (1) is not compatible with interface '" - << rhs->repr_str() << "' (2)\n" + << rhs.repr_str() << "' (2)\n" << " (1) " << 
self_method->getSchema() << "\n" << " (2) " << schema << "\n"; } @@ -2131,9 +2133,9 @@ bool InterfaceType::isSubTypeImpl( return true; } -bool InterfaceType::isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const { +bool InterfaceType::isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const { // to improve performance this check can be cached - if (auto iface = rhs->cast()) { + if (auto iface = rhs.cast()) { return isSubTypeImpl(*this, *iface, why_not); } return Type::isSubtypeOfExt(rhs, why_not); @@ -2257,7 +2259,7 @@ size_t ClassType::addAttribute( type->expect()->getElementType()->kind() == TensorType::Kind) || (type->kind() == UnionType::Kind && - TensorType::get()->isSubtypeOf(type->expect())) || + TensorType::get()->isSubtypeOf(type->expectRef())) || (type->kind() == NoneType::Kind), "Expecting parameter or buffer to have either None, Tensor or Optional[Tensor] type, but got: ", toString(type)); @@ -2402,10 +2404,10 @@ void SymbolicShape::dump() const { std::cout << *this << "\n"; } -bool EnumType::isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const { - return rhs->kind() == TypeKind::AnyType || - rhs->kind() == TypeKind::AnyEnumType || - *this == *rhs || +bool EnumType::isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const { + return rhs.kind() == TypeKind::AnyType || + rhs.kind() == TypeKind::AnyEnumType || + *this == rhs || Type::isSubtypeOfExt(rhs, why_not); } diff --git a/caffe2/core/export_c10_op_to_caffe2.h b/caffe2/core/export_c10_op_to_caffe2.h index 767177accec4..9dda158c63d6 100644 --- a/caffe2/core/export_c10_op_to_caffe2.h +++ b/caffe2/core/export_c10_op_to_caffe2.h @@ -46,7 +46,7 @@ class C10OperatorWrapper final : public Operator { AT_ASSERT( !has_preallocated_outputs_ || op_.schema().arguments().back().type()->isSubtypeOf( - OptionalType::create(ListType::ofTensors()))); + *OptionalType::create(ListType::ofTensors()))); // NOLINTNEXTLINE(clang-diagnostic-sign-compare) AT_ASSERT(operator_def.output_size() == op_.schema().returns().size()); @@ -89,13 +89,13 @@ class C10OperatorWrapper final : public Operator { AT_ASSERTM( argument.type()->isSubtypeOf( - OptionalType::create(ListType::ofTensors())), + *OptionalType::create(ListType::ofTensors())), "Error in caffe2->c10 wrapper: Operator schema has a parameter named ", detail::PREALLOCATED_OUTPUT_ARGNAME, ", but it's not of type TensorList?"); stack_.emplace_back(preallocated_outputs_()); - } else if (argument.type()->isSubtypeOf(TensorType::get())) { + } else if (argument.type()->isSubtypeOf(*TensorType::get())) { AT_ASSERTM( // NOLINTNEXTLINE(clang-diagnostic-sign-compare) input_tensor_index < InputSize(), @@ -103,13 +103,13 @@ class C10OperatorWrapper final : public Operator { InputSize(), "), operator schema expected more."); stack_.emplace_back(at::Tensor(Input(input_tensor_index++))); - } else if (argument.type()->isSubtypeOf(OptionalType::ofTensor())) { + } else if (argument.type()->isSubtypeOf(*OptionalType::ofTensor())) { if (input_tensor_index < InputSize()) { stack_.emplace_back(at::Tensor(Input(input_tensor_index++))); } else { stack_.emplace_back(IValue()); } - } else if (argument.type()->isSubtypeOf(ListType::ofTensors())) { + } else if (argument.type()->isSubtypeOf(*ListType::ofTensors())) { AT_ASSERTM( input_tensor_index == 0, "Error in caffe2->c10 wrapper: Schema can only have either one or more Tensor inputs or one TensorList input."); @@ -163,13 +163,13 @@ class C10OperatorWrapper final : public Operator { } IValue get_nontensor_argument_(const c10::Argument& argument) { 
- if (argument.type()->isSubtypeOf(IntType::get())) { + if (argument.type()->isSubtypeOf(*IntType::get())) { return get_nontensor_argument_( argument.name(), argument.default_value()); - } else if (argument.type()->isSubtypeOf(FloatType::get())) { + } else if (argument.type()->isSubtypeOf(*FloatType::get())) { return get_nontensor_argument_( argument.name(), argument.default_value()); - } else if (argument.type()->isSubtypeOf(BoolType::get())) { + } else if (argument.type()->isSubtypeOf(*BoolType::get())) { return get_nontensor_argument_( argument.name(), argument.default_value()); } else { diff --git a/caffe2/core/export_caffe2_op_to_c10.h b/caffe2/core/export_caffe2_op_to_c10.h index fdb06f6ff2c0..bac3b0fd5846 100644 --- a/caffe2/core/export_caffe2_op_to_c10.h +++ b/caffe2/core/export_caffe2_op_to_c10.h @@ -57,7 +57,7 @@ inline void _call_caffe2_op_from_c10( AT_ASSERT( schema.arguments().size() != 0 && schema.arguments().back().type()->isSubtypeOf( - OptionalType::create(ListType::ofTensors()))); + *OptionalType::create(ListType::ofTensors()))); IValue preallocated_outputs = torch::jit::pop(*stack); const size_t num_outputs = schema.returns().size(); diff --git a/test/cpp/jit/test_custom_operators.cpp b/test/cpp/jit/test_custom_operators.cpp index 39be82ea2343..58f87717844d 100644 --- a/test/cpp/jit/test_custom_operators.cpp +++ b/test/cpp/jit/test_custom_operators.cpp @@ -86,20 +86,20 @@ TEST(CustomOperatorTest, ListParameters) { ASSERT_EQ(op->schema().arguments().size(), 4); ASSERT_EQ(op->schema().arguments()[0].name(), "ints"); ASSERT_TRUE( - op->schema().arguments()[0].type()->isSubtypeOf(ListType::ofInts())); + op->schema().arguments()[0].type()->isSubtypeOf(*ListType::ofInts())); ASSERT_EQ(op->schema().arguments()[1].name(), "floats"); ASSERT_TRUE( - op->schema().arguments()[1].type()->isSubtypeOf(ListType::ofFloats())); + op->schema().arguments()[1].type()->isSubtypeOf(*ListType::ofFloats())); ASSERT_EQ(op->schema().arguments()[2].name(), "complexdoubles"); ASSERT_TRUE(op->schema().arguments()[2].type()->isSubtypeOf( - ListType::ofComplexDoubles())); + *ListType::ofComplexDoubles())); ASSERT_EQ(op->schema().arguments()[3].name(), "tensors"); ASSERT_TRUE( - op->schema().arguments()[3].type()->isSubtypeOf(ListType::ofTensors())); + op->schema().arguments()[3].type()->isSubtypeOf(*ListType::ofTensors())); ASSERT_EQ(op->schema().returns().size(), 1); ASSERT_TRUE( - op->schema().returns()[0].type()->isSubtypeOf(ListType::ofFloats())); + op->schema().returns()[0].type()->isSubtypeOf(*ListType::ofFloats())); Stack stack; push(stack, c10::List({1, 2})); @@ -132,11 +132,11 @@ TEST(CustomOperatorTest, ListParameters2) { ASSERT_EQ(op->schema().arguments().size(), 1); ASSERT_EQ(op->schema().arguments()[0].name(), "tensors"); ASSERT_TRUE( - op->schema().arguments()[0].type()->isSubtypeOf(ListType::ofTensors())); + op->schema().arguments()[0].type()->isSubtypeOf(*ListType::ofTensors())); ASSERT_EQ(op->schema().returns().size(), 1); ASSERT_TRUE( - op->schema().returns()[0].type()->isSubtypeOf(ListType::ofTensors())); + op->schema().returns()[0].type()->isSubtypeOf(*ListType::ofTensors())); Stack stack; push(stack, c10::List({at::ones(5)})); diff --git a/test/cpp/jit/test_irparser.cpp b/test/cpp/jit/test_irparser.cpp index b83959455dd1..08edef0ceec6 100644 --- a/test/cpp/jit/test_irparser.cpp +++ b/test/cpp/jit/test_irparser.cpp @@ -145,7 +145,7 @@ TEST(IRParserTest, InferredTypeIsTensor) { graph(%a): return (%a))IR", &*graph); - 
AT_ASSERT(graph->inputs()[0]->type()->isSubtypeOf(TensorType::get())); + AT_ASSERT(graph->inputs()[0]->type()->isSubtypeOf(*TensorType::get())); } TEST(IRParserTest, ValueReuse) { @@ -260,7 +260,7 @@ TEST(IRParserTest, FileCheck) { return (%a))IR"; parseIR(text, &*graph); - AT_ASSERT(graph->inputs()[0]->type()->isSubtypeOf(TensorType::get())); + AT_ASSERT(graph->inputs()[0]->type()->isSubtypeOf(*TensorType::get())); torch::jit::testing::FileCheck().run(text, *graph); } diff --git a/test/cpp/jit/test_jit_type.cpp b/test/cpp/jit/test_jit_type.cpp index ed9a99481408..606c1b0fa36e 100644 --- a/test/cpp/jit/test_jit_type.cpp +++ b/test/cpp/jit/test_jit_type.cpp @@ -26,14 +26,14 @@ TEST(JitTypeTest, UnifyTypes) { auto bool_tensor = TensorType::get()->withScalarType(at::kBool); auto opt_bool_tensor = OptionalType::create(bool_tensor); auto unified_opt_bool = unifyTypes(bool_tensor, opt_bool_tensor); - TORCH_INTERNAL_ASSERT(opt_bool_tensor->isSubtypeOf(*unified_opt_bool)); + TORCH_INTERNAL_ASSERT(opt_bool_tensor->isSubtypeOf(**unified_opt_bool)); auto tensor = TensorType::get(); - TORCH_INTERNAL_ASSERT(!tensor->isSubtypeOf(opt_bool_tensor)); + TORCH_INTERNAL_ASSERT(!tensor->isSubtypeOf(*opt_bool_tensor)); auto unified = unifyTypes(opt_bool_tensor, tensor); TORCH_INTERNAL_ASSERT(unified); auto elem = (*unified)->expectRef().getElementType(); - TORCH_INTERNAL_ASSERT(elem->isSubtypeOf(TensorType::get())); + TORCH_INTERNAL_ASSERT(elem->isSubtypeOf(*TensorType::get())); auto opt_tuple_none_int = OptionalType::create( TupleType::create({NoneType::get(), IntType::get()})); @@ -52,7 +52,7 @@ TEST(JitTypeTest, UnifyTypes) { auto fut_out = unifyTypes(fut_1, fut_2); TORCH_INTERNAL_ASSERT(fut_out); TORCH_INTERNAL_ASSERT((*fut_out)->isSubtypeOf( - FutureType::create(OptionalType::create(IntType::get())))); + *FutureType::create(OptionalType::create(IntType::get())))); auto dict_1 = DictType::create(IntType::get(), NoneType::get()); auto dict_2 = DictType::create(IntType::get(), IntType::get()); diff --git a/test/cpp/jit/test_misc.cpp b/test/cpp/jit/test_misc.cpp index 6807dc2bc87b..0053e00c4816 100644 --- a/test/cpp/jit/test_misc.cpp +++ b/test/cpp/jit/test_misc.cpp @@ -498,21 +498,21 @@ TEST(SchemaParserTest, NestedArrays) { // nested arrays auto s = parseSchema("at::what(int[][4] foo) -> ()"); ASSERT_TRUE(s.arguments().at(0).N() == 4); - ASSERT_TRUE(IntType::get()->isSubtypeOf(s.arguments() - .at(0) - .type() - ->expectRef() - .getElementType() - ->expectRef() - .getElementType())); + ASSERT_TRUE(IntType::get()->isSubtypeOf(*s.arguments() + .at(0) + .type() + ->expectRef() + .getElementType() + ->expectRef() + .getElementType())); auto s2 = parseSchema("at::what(int[][] foo) -> ()"); - ASSERT_TRUE(IntType::get()->isSubtypeOf(s2.arguments() - .at(0) - .type() - ->expectRef() - .getElementType() - ->expectRef() - .getElementType())); + ASSERT_TRUE(IntType::get()->isSubtypeOf(*s2.arguments() + .at(0) + .type() + ->expectRef() + .getElementType() + ->expectRef() + .getElementType())); } TEST(SchemaParserTest, OutVariant) { @@ -550,7 +550,7 @@ TEST(SchemaParserTest, Futures) { // futures auto s4 = parseSchema("at::what(Future(int) foo) -> ()"); ASSERT_TRUE(IntType::get()->isSubtypeOf( - s4.arguments().at(0).type()->expectRef().getElementType())); + *s4.arguments().at(0).type()->expectRef().getElementType())); } TEST(SchemaParserTest, AnnotatedAliasSets) { diff --git a/test/cpp/jit/test_save_load.cpp b/test/cpp/jit/test_save_load.cpp index 658762530df3..705073b10696 100644 --- a/test/cpp/jit/test_save_load.cpp +++ 
b/test/cpp/jit/test_save_load.cpp @@ -116,8 +116,8 @@ TEST(SerializationTest, TypeTags) { for (auto item : items) { auto bytes = torch::pickle_save(item.value); auto loaded = torch::pickle_load(bytes); - ASSERT_TRUE(loaded.type()->isSubtypeOf(item.expected_type)); - ASSERT_TRUE(item.expected_type->isSubtypeOf(loaded.type())); + ASSERT_TRUE(loaded.type()->isSubtypeOf(*item.expected_type)); + ASSERT_TRUE(item.expected_type->isSubtypeOf(*loaded.type())); } } diff --git a/test/cpp/jit/test_union.cpp b/test/cpp/jit/test_union.cpp index f35acd35d1ed..5a183674615e 100644 --- a/test/cpp/jit/test_union.cpp +++ b/test/cpp/jit/test_union.cpp @@ -105,12 +105,12 @@ TEST_F(UnionTypeTest, Subtyping_NumberType) { const NumberTypePtr num = NumberType::get(); - ASSERT_TRUE(num->isSubtypeOf(union1)); - ASSERT_TRUE(union1->isSubtypeOf(num)); + ASSERT_TRUE(num->isSubtypeOf(*union1)); + ASSERT_TRUE(union1->isSubtypeOf(*num)); ASSERT_TRUE(*num == *union1); - ASSERT_TRUE(num->isSubtypeOf(union2)); - ASSERT_FALSE(union2->isSubtypeOf(num)); + ASSERT_TRUE(num->isSubtypeOf(*union2)); + ASSERT_FALSE(union2->isSubtypeOf(*num)); ASSERT_FALSE(*num == *union2); } diff --git a/torch/csrc/autograd/TraceTypeManual.cpp b/torch/csrc/autograd/TraceTypeManual.cpp index 845f9df67ed1..6426cf3c1eb8 100644 --- a/torch/csrc/autograd/TraceTypeManual.cpp +++ b/torch/csrc/autograd/TraceTypeManual.cpp @@ -185,7 +185,7 @@ void general_trace_function( type = type->expectRef().getElementType(); } } - if (type->isSubtypeOf(TensorType::get())) { + if (type->isSubtypeOf(*TensorType::get())) { AT_ASSERT(iter->isTensor()); tracer::addInputs(node, args[i].name().c_str(), iter->toTensor()); } else if (type->kind() == TypeKind::FloatType) { @@ -204,7 +204,7 @@ void general_trace_function( tracer::addInputs(node, args[i].name().c_str(), iter->toScalar()); } else if (type->kind() == TypeKind::ListType) { const auto& elem_type = type->expectRef().getElementType(); - if (elem_type->isSubtypeOf(TensorType::get())) { + if (elem_type->isSubtypeOf(*TensorType::get())) { AT_ASSERT(iter->isTensorList()); auto list = iter->toTensorVector(); tracer::addInputs(node, args[i].name().c_str(), list); @@ -265,12 +265,12 @@ void general_trace_function( for (auto iter = stack->end() - output_size; iter != stack->end(); ++iter, ++i) { const auto& type = op.schema().returns()[i].type(); - if (type->isSubtypeOf(TensorType::get())) { + if (type->isSubtypeOf(*TensorType::get())) { AT_ASSERT(iter->isTensor()); tracer::addOutput(node, iter->toTensor()); } else if (type->kind() == TypeKind::ListType) { const auto& elem_type = type->expectRef().getElementType(); - if (elem_type->isSubtypeOf(TensorType::get())) { + if (elem_type->isSubtypeOf(*TensorType::get())) { AT_ASSERT(iter->isTensorList()); tracer::addOutput(node, iter->toTensorList()); } else { diff --git a/torch/csrc/distributed/rpc/py_rref.cpp b/torch/csrc/distributed/rpc/py_rref.cpp index 0fbbe9a75d26..49eaf3a6e2fd 100644 --- a/torch/csrc/distributed/rpc/py_rref.cpp +++ b/torch/csrc/distributed/rpc/py_rref.cpp @@ -72,7 +72,7 @@ TypePtr tryInferTypeWithTypeHint( TORCH_CHECK( type_hint_ptr != nullptr && module.value().type()->isSubtypeOfExt( - type_hint_ptr, &subtype_check_msg), + *type_hint_ptr, &subtype_check_msg), module.value().type()->repr_str(), " is not a subtype of the type hint: ", type_qualified_name.qualifiedName(), diff --git a/torch/csrc/distributed/rpc/rref_context.cpp b/torch/csrc/distributed/rpc/rref_context.cpp index 7e68d90965a0..004e9422be42 100644 --- a/torch/csrc/distributed/rpc/rref_context.cpp +++ 
b/torch/csrc/distributed/rpc/rref_context.cpp @@ -348,9 +348,9 @@ c10::intrusive_ptr RRefContext::getOrCreateOwnerRRef( // since Tensor can only get specialized with a previous run of local // JIT function, and we shouldn't preserve the specialized SubTensorType // information on other workers because it's only information only. - if (type->isSubtypeOf(TensorType::get())) { + if (type->isSubtypeOf(*TensorType::get())) { TORCH_INTERNAL_ASSERT( - ownerRRef->type()->isSubtypeOf(TensorType::get()), + ownerRRef->type()->isSubtypeOf(*TensorType::get()), "Expect OwnerRRef to be a sub-type of TensorType, but got ", ownerRRef->type()->repr_str()); } else { diff --git a/torch/csrc/jit/api/module.h b/torch/csrc/jit/api/module.h index e79149b39428..511e6d22d6e4 100644 --- a/torch/csrc/jit/api/module.h +++ b/torch/csrc/jit/api/module.h @@ -535,7 +535,7 @@ struct TORCH_API BufferPolicy { return std::move(v).toTensor(); } static bool valid(const ClassTypePtr& typ, size_t i, const IValue& v) { - return typ->getAttribute(i)->isSubtypeOf(TensorType::get()) && + return typ->getAttribute(i)->isSubtypeOf(*TensorType::get()) && typ->is_buffer(i); } static CONSTEXPR_EXCEPT_WIN_CUDA bool all_slots = false; diff --git a/torch/csrc/jit/api/object.h b/torch/csrc/jit/api/object.h index 9871c1c56664..406647d11156 100644 --- a/torch/csrc/jit/api/object.h +++ b/torch/csrc/jit/api/object.h @@ -54,7 +54,7 @@ struct TORCH_API Object { } else if (auto slot = _ivalue()->type()->findAttributeSlot(name)) { const c10::TypePtr& expected = _ivalue()->type()->getAttribute(*slot); TORCH_CHECK( - v.type()->isSubtypeOf(expected), + v.type()->isSubtypeOf(*expected), "Expected a value of type '", expected->repr_str(), "' for field '", diff --git a/torch/csrc/jit/codegen/cuda/graph_fuser.cpp b/torch/csrc/jit/codegen/cuda/graph_fuser.cpp index e1d67f3b088b..4d6153d733a6 100644 --- a/torch/csrc/jit/codegen/cuda/graph_fuser.cpp +++ b/torch/csrc/jit/codegen/cuda/graph_fuser.cpp @@ -114,7 +114,7 @@ struct CudaGraphFuser { value_list tensorInputs(Node* node) { return filter(node->inputs(), [](Value* v) { - return v->type()->isSubtypeOf(TensorType::get()); + return v->type()->isSubtypeOf(*TensorType::get()); }); } @@ -210,7 +210,7 @@ struct CudaGraphFuser { size_t tensor_insert_idx = 0; for (auto input : group->inputs()) { inputs_map[input] = subgraph.inputs()[i++]; - if (input->type()->isSubtypeOf(TensorType::get())) + if (input->type()->isSubtypeOf(*TensorType::get())) tensor_insert_idx = i; } // add n's inputs to the fusion group's input list if we don't already have @@ -222,7 +222,7 @@ struct CudaGraphFuser { if (inputs_map.count(input) == 0) { // TODO: we are following the convention for no good reason; // we don't need tensor to come before any other inputs. - if (input->type()->isSubtypeOf(TensorType::get())) { + if (input->type()->isSubtypeOf(*TensorType::get())) { auto in_group = subgraph.insertInput(tensor_insert_idx); in_group->setType(input->type()); inputs_map[input] = in_group; @@ -230,7 +230,7 @@ struct CudaGraphFuser { tensor_insert_idx++; } else if ( // TODO: extend the supporting inputs here. 
- (input->type()->isSubtypeOf(FloatType::get()) && + (input->type()->isSubtypeOf(*FloatType::get()) && input->node()->kind() != prim::Constant)) { auto in_group = subgraph.addInput(); in_group->setType(input->type()); @@ -440,7 +440,7 @@ struct CudaGraphFuser { // Replace tensors inputs with broadcasted values auto new_tensors_it = new_tensors.begin(); for (size_t i = 0; i < node->inputs().size(); ++i) { - if (node->inputs()[i]->type()->isSubtypeOf(TensorType::get())) { + if (node->inputs()[i]->type()->isSubtypeOf(*TensorType::get())) { AT_ASSERT(new_tensors_it != new_tensors.end()); node->replaceInput(i, *(new_tensors_it++)); } @@ -595,7 +595,7 @@ struct CudaGraphFuser { // XXX: we only work with pointwise ops in here, so we know it is valid to // push the concat only through tensor arguments (and all other args can // be safely ignored). - if (!input->type()->isSubtypeOf(TensorType::get())) + if (!input->type()->isSubtypeOf(*TensorType::get())) continue; // if 'input' is already an input to the bchunk, reuse it. @@ -688,7 +688,7 @@ struct CudaGraphFuser { chunked_op->output()->setType(chunk_sel->type()); auto chunked_inputs_it = chunked_inputs.begin(); for (Value* original_input : original_inputs) { - if (original_input->type()->isSubtypeOf(TensorType::get())) { + if (original_input->type()->isSubtypeOf(*TensorType::get())) { AT_ASSERT(chunked_inputs_it != chunked_inputs.end()); chunked_op->addInput( // NOLINTNEXTLINE(clang-analyzer-core.DivideZero) @@ -720,7 +720,7 @@ struct CudaGraphFuser { if (!size_calc_uses.empty()) { auto tensor_inputs = filter( producer_for_chunk_node->inputs(), - [](Value* v) { return v->type()->isSubtypeOf(TensorType::get()); }); + [](Value* v) { return v->type()->isSubtypeOf(*TensorType::get()); }); auto tensor_sizes = fmap(tensor_inputs, [](Value* v) { return v->owningGraph()->insert(aten::size, {v}); }); @@ -824,7 +824,7 @@ struct CudaGraphFuser { auto sinputs = subgraph->inputs(); AT_ASSERT(inputs.size() == sinputs.size()); for (const auto i : c10::irange(inputs.size())) { - if (inputs[i]->type()->isSubtypeOf(TensorType::get())) { + if (inputs[i]->type()->isSubtypeOf(*TensorType::get())) { shape_of[sinputs[i]] = graph->insert(aten::size, {inputs[i]}); } } @@ -963,7 +963,7 @@ struct CudaGraphFuser { continue; } auto tensor_inputs = filter(n->inputs(), [](Value* v) { - return v->type()->isSubtypeOf(TensorType::get()); + return v->type()->isSubtypeOf(*TensorType::get()); }); auto shapes = fmap(tensor_inputs, [&](Value* v) { TORCH_INTERNAL_ASSERT( diff --git a/torch/csrc/jit/codegen/cuda/partition.cpp b/torch/csrc/jit/codegen/cuda/partition.cpp index 3167c27561d6..47168d2bea6b 100644 --- a/torch/csrc/jit/codegen/cuda/partition.cpp +++ b/torch/csrc/jit/codegen/cuda/partition.cpp @@ -34,7 +34,7 @@ bool hasNonElementWiseOperation(const Node* node) { // 2. on the same device; // TODO: update this when codegen can output scalar static c10::optional getDevice(const Value* value) { - if (!value->type()->isSubtypeOf(TensorType::get())) { + if (!value->type()->isSubtypeOf(*TensorType::get())) { // not tensor type, return false as the op is not outputing scalar. 
return c10::nullopt; } diff --git a/torch/csrc/jit/codegen/cuda/shape_inference.cpp b/torch/csrc/jit/codegen/cuda/shape_inference.cpp index 4a646362620c..fd433c472d8d 100644 --- a/torch/csrc/jit/codegen/cuda/shape_inference.cpp +++ b/torch/csrc/jit/codegen/cuda/shape_inference.cpp @@ -44,7 +44,7 @@ class NaiveTypePropagator { switch (node->kind()) { // Constant: case prim::Constant: { - if (node->output()->type()->isSubtypeOf(TensorType::get())) { + if (node->output()->type()->isSubtypeOf(*TensorType::get())) { node->output()->inferTypeFrom(node->t(attr::value)); } break; diff --git a/torch/csrc/jit/codegen/fuser/compiler.cpp b/torch/csrc/jit/codegen/fuser/compiler.cpp index 3ffdf5a93970..407f0e195fbf 100644 --- a/torch/csrc/jit/codegen/fuser/compiler.cpp +++ b/torch/csrc/jit/codegen/fuser/compiler.cpp @@ -123,7 +123,7 @@ static std::vector getInputDependencies(const Value* output) { // This needs to be revisited when you start allowing // other things e.g. nonconstant scalars. if (producer->kind() == prim::Param && - val->type()->isSubtypeOf(TensorType::get())) { + val->type()->isSubtypeOf(*TensorType::get())) { inputs.insert(val); continue; } @@ -230,10 +230,10 @@ std::shared_ptr compileKernel( { size_t input_index = 0; for (const auto& p : graph->inputs()) { - if (p->type()->isSubtypeOf(FloatType::get())) { + if (p->type()->isSubtypeOf(*FloatType::get())) { flat_inputs.emplace_back(p, c10::nullopt); } - if (!p->type()->isSubtypeOf(TensorType::get())) { + if (!p->type()->isSubtypeOf(*TensorType::get())) { continue; } if (const Node* chunk = usedInFusedChunk(p)) { diff --git a/torch/csrc/jit/codegen/fuser/kernel_spec.h b/torch/csrc/jit/codegen/fuser/kernel_spec.h index d1f5f5c3fc08..f54812c210ef 100644 --- a/torch/csrc/jit/codegen/fuser/kernel_spec.h +++ b/torch/csrc/jit/codegen/fuser/kernel_spec.h @@ -77,7 +77,7 @@ struct TORCH_API KernelSpec { } nTensorInputs_ = std::count_if( graph_->inputs().begin(), graph_->inputs().end(), [](const Value* v) { - return v->type()->isSubtypeOf(TensorType::get()); + return v->type()->isSubtypeOf(*TensorType::get()); }); } diff --git a/torch/csrc/jit/frontend/ir_emitter.cpp b/torch/csrc/jit/frontend/ir_emitter.cpp index 52ad5fd47d5b..a6b4c996c8de 100644 --- a/torch/csrc/jit/frontend/ir_emitter.cpp +++ b/torch/csrc/jit/frontend/ir_emitter.cpp @@ -385,7 +385,7 @@ struct Environment { as_simple_value, /*allow_conversions=*/true); std::stringstream why_not; - if (!as_simple_value->type()->isSubtypeOfExt(parent_type, &why_not)) { + if (!as_simple_value->type()->isSubtypeOfExt(*parent_type, &why_not)) { auto error = ErrorReport(loc); error << "Variable '" << name << "' previously had type " << simple_parent->type()->repr_str() @@ -406,7 +406,7 @@ struct Environment { } if (as_simple_value) { if (annotated_type && - !as_simple_value->type()->isSubtypeOf(annotated_type)) { + !as_simple_value->type()->isSubtypeOf(*annotated_type)) { throw ErrorReport(loc) << "Variable '" << name << "' is annotated with type " << annotated_type->repr_str() @@ -603,8 +603,8 @@ static Value* materializeConstant( } inline bool isSupportedListElementType(const TypePtr& type) { - return type->isSubtypeOf(TensorType::get()) || - type->isSubtypeOf(NumberType::get()); + return type->isSubtypeOf(*TensorType::get()) || + type->isSubtypeOf(*NumberType::get()); } // Information for each def being emitted. @@ -1023,8 +1023,8 @@ struct to_ir { // this guard skips implicit conversion from None -> Tensor for the return // type. 
-      if (!(actual_return->type()->isSubtypeOf(TensorType::get()) &&
-            actual_return->type()->isSubtypeOf(NoneType::get()))) {
+      if (!(actual_return->type()->isSubtypeOf(*TensorType::get()) &&
+            actual_return->type()->isSubtypeOf(*NoneType::get()))) {
         actual_return = tryConvertToType(
             stmt.range(),
             *graph,
@@ -1032,7 +1032,7 @@
             actual_return,
             /*allow_conversions=*/true);
       }
-      if (!actual_return->type()->isSubtypeOf(declared_return_type)) {
+      if (!actual_return->type()->isSubtypeOf(*declared_return_type)) {
         throw ErrorReport(stmt.range())
             << "Return value was annotated as having type "
             << declared_return_type->repr_str()
             << " but is actually of type "
@@ -1423,8 +1423,8 @@ struct to_ir {
     if (all_candidates.empty() && refined_type_hint &&
         !(*unified_elem_type)
-             ->isSubtypeOf(
-                 refined_type_hint->expect<ListType>()->getElementType())) {
+             ->isSubtypeOf(*refined_type_hint->expectRef<ListType>()
+                                .getElementType())) {
       throw ErrorReport(lc)
           << "List type annotation `" << refined_type_hint->repr_str()
           << "` did not match the types of the given list elements,"
@@ -1439,7 +1439,7 @@
           [&](TypePtr candidate) {
             auto candidate_elem_type =
                 candidate->expect<ListType>()->getElementType();
-            if ((*unified_elem_type)->isSubtypeOf(candidate_elem_type)) {
+            if ((*unified_elem_type)->isSubtypeOf(*candidate_elem_type)) {
               if (!greatest_elem_type) {
                 greatest_elem_type = candidate_elem_type;
               } else {
@@ -1582,7 +1582,7 @@ struct to_ir {
     std::stringstream err;
     bool is_key_subtype =
-        k->type()->isSubtypeOfExt(dict_type_hint->getKeyType(), &ss);
+        k->type()->isSubtypeOfExt(*dict_type_hint->getKeyType(), &ss);
     if (!is_key_subtype) {
       err << "Dict type annotation `" << dict_type_hint->repr_str()
@@ -1594,7 +1594,7 @@
     ss.str(std::string());
     bool is_value_subtype =
-        v->type()->isSubtypeOfExt(dict_type_hint->getValueType(), &ss);
+        v->type()->isSubtypeOfExt(*dict_type_hint->getValueType(), &ss);
     if (!is_value_subtype) {
       err << "Dict type annotation `" << dict_type_hint->repr_str()
@@ -1651,11 +1651,11 @@ struct to_ir {
             current_candidate->expect<DictType>()->getKeyType();
         auto current_value_type =
             current_candidate->expect<DictType>()->getValueType();
-        if (known_key_type->isSubtypeOf(current_key_type) &&
-            known_value_type->isSubtypeOf(current_value_type)) {
+        if (known_key_type->isSubtypeOf(*current_key_type) &&
+            known_value_type->isSubtypeOf(*current_value_type)) {
           if (!candidate ||
-              (candidate_key_type->isSubtypeOf(current_key_type) &&
-               candidate_value_type->isSubtypeOf(current_value_type))) {
+              (candidate_key_type->isSubtypeOf(*current_key_type) &&
+               candidate_value_type->isSubtypeOf(*current_value_type))) {
             candidate_key_type = current_key_type;
             candidate_value_type = current_value_type;
             candidate = current_candidate;
@@ -1819,7 +1819,7 @@ struct to_ir {
           << v->type()->repr_str() << " to bool";
     }
     // cast value not response for checking output type
-    if (!out->type()->isSubtypeOf(BoolType::get())) {
+    if (!out->type()->isSubtypeOf(*BoolType::get())) {
       throw ErrorReport(loc)
           << "expected a bool expression for condition but found "
           << out->type()->repr_str();
@@ -2088,10 +2088,11 @@ struct to_ir {
         break;
       }

-      auto get_smaller_type = [&](TypePtr t1, TypePtr t2) -> TypePtr {
-        if (t1->isSubtypeOf(t2)) {
+      auto get_smaller_type = [&](const TypePtr& t1,
+                                  const TypePtr& t2) -> TypePtr {
+        if (t1->isSubtypeOf(*t2)) {
           return t1;
-        } else if (t2->isSubtypeOf(t1)) {
+        } else if (t2->isSubtypeOf(*t1)) {
           return t2;
         } else {
          return nullptr;
@@ -2416,7 +2417,7 @@ struct to_ir {
           << "exceptions must derive from BaseException";
     }

-    if (!error_message->type()->isSubtypeOf(StringType::get())) {
+    if (!error_message->type()->isSubtypeOf(*StringType::get())) {
       error_message = graph->insert(aten::str, {error_message});
     }
@@ -2487,7 +2488,7 @@
   // If the RHS is a tensor, return the corresponding ATen in-place op
   // If it's a list of scalars, then return the corresponding list augment op
   Symbol getAugOp(const AugAssign& stmt, const TypePtr& type) {
-    bool use_inplace_op = type->isSubtypeOf(TensorType::get()) ||
+    bool use_inplace_op = type->isSubtypeOf(*TensorType::get()) ||
         type->kind() == TypeKind::ListType;
     switch (stmt.aug_op()) {
       case '+':
@@ -2678,7 +2679,7 @@ struct to_ir {
     const auto lhs = Subscript(stmt.lhs());
     const auto sliceable = emitExpr(lhs.value());

-    if (sliceable->type()->isSubtypeOf(TensorType::get())) {
+    if (sliceable->type()->isSubtypeOf(*TensorType::get())) {
       // If it's a tensor, just fully evaluate the subscript operation and emit
       // an in-place assignment
       std::vector<Value*> tensorIndices;
@@ -2765,7 +2766,7 @@ struct to_ir {
     auto sliceable = emitExpr(lhs.value());

     // If it's a tensor, copy the RHS data into it
-    if (sliceable->type()->isSubtypeOf(TensorType::get())) {
+    if (sliceable->type()->isSubtypeOf(*TensorType::get())) {
       std::vector<Value*> tensorIndices;
       // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
       Value* sliced;
@@ -2810,7 +2811,7 @@ struct to_ir {
           << " subscripted assignment. "
           << "File a bug if you want this";
     }
-    if (sliceable->type()->isSubtypeOf(AnyTupleType::get())) {
+    if (sliceable->type()->isSubtypeOf(*AnyTupleType::get())) {
       throw ErrorReport(lhs) << sliceable->type()->repr_str()
                              << " does not support subscripted assignment";
     }
@@ -3270,7 +3271,7 @@ struct to_ir {
             /*allow_conversions=*/true);

         std::stringstream why_not;
-        if (!expr->type()->isSubtypeOfExt(type, &why_not)) {
+        if (!expr->type()->isSubtypeOfExt(*type, &why_not)) {
           throw ErrorReport(apply.inputs())
               << "expected an expression of type " << type->repr_str()
               << " but found " << expr->type()->repr_str() << "\n"
@@ -3284,7 +3285,7 @@ struct to_ir {
         if ((type->kind() == OptionalType::Kind ||
              (type->kind() == UnionType::Kind &&
               type->expect<UnionType>()->canHoldType(NoneType::get()))) &&
-            expr->type()->isSubtypeOf(NoneType::get())) {
+            expr->type()->isSubtypeOf(*NoneType::get())) {
           Node* none = graph->createNone();
           none->output()->setType(type);
           graph->insertNode(none);
@@ -4094,7 +4095,7 @@ struct to_ir {
     }

     if (all_candidates.empty() && refined_type_hint &&
-        !(*unified_elem_type)->isSubtypeOf(inferred_elem_type)) {
+        !(*unified_elem_type)->isSubtypeOf(*inferred_elem_type)) {
       throw ErrorReport(ll)
           << "List type annotation `" << refined_type_hint->repr_str()
           << "` did not match the types of the given list elements,"
@@ -4109,7 +4110,7 @@ struct to_ir {
           [&](TypePtr candidate) {
             auto candidate_elem_type =
                 candidate->expect<ListType>()->getElementType();
-            if ((*unified_elem_type)->isSubtypeOf(candidate_elem_type)) {
+            if ((*unified_elem_type)->isSubtypeOf(*candidate_elem_type)) {
               if (!greatest_elem_type) {
                 greatest_elem_type = candidate_elem_type;
               } else {
@@ -4349,8 +4350,8 @@ struct to_ir {
     for (const auto i : c10::irange(keys.size())) {
       std::stringstream ss;
-      if (!keys[i]->type()->isSubtypeOfExt(first_key_type, &ss) &&
-          !first_key_type->isSubtypeOfExt(keys[i]->type(), &ss)) {
+      if (!keys[i]->type()->isSubtypeOfExt(*first_key_type, &ss) &&
+          !first_key_type->isSubtypeOfExt(*keys[i]->type(), &ss)) {
         throw ErrorReport(key_trees[i])
"Dict keys must contain " << "only a single type. Expected: " @@ -4386,11 +4387,11 @@ struct to_ir { current_candidate->expect()->getKeyType(); auto current_value_type = current_candidate->expect()->getValueType(); - if (known_key_type->isSubtypeOf(current_key_type) && - known_value_type->isSubtypeOf(current_value_type)) { + if (known_key_type->isSubtypeOf(*current_key_type) && + known_value_type->isSubtypeOf(*current_value_type)) { if (!candidate || - (candidate_key_type->isSubtypeOf(current_key_type) && - candidate_value_type->isSubtypeOf(current_value_type))) { + (candidate_key_type->isSubtypeOf(*current_key_type) && + candidate_value_type->isSubtypeOf(*current_value_type))) { candidate_key_type = current_key_type; candidate_value_type = current_value_type; candidate = current_candidate; @@ -4439,7 +4440,7 @@ struct to_ir { refined_type_hint->expect()->getValueType(); for (const auto i : c10::irange(value_types.size())) { TORCH_CHECK( - value_types[i]->isSubtypeOf(value_type_hint), + value_types[i]->isSubtypeOf(*value_type_hint), "Type " "hint for dict was ", refined_type_hint->repr_str(), @@ -4525,11 +4526,11 @@ struct to_ir { // XXX: If list slicing becomes more complicated or stops using // aten::slice, we should separate it from this function. if (dim) { - AT_ASSERT(sliceable->type()->isSubtypeOf(TensorType::get())); + AT_ASSERT(sliceable->type()->isSubtypeOf(*TensorType::get())); args.emplace_back(dim); } else { - AT_ASSERT(!sliceable->type()->isSubtypeOf(TensorType::get())); + AT_ASSERT(!sliceable->type()->isSubtypeOf(*TensorType::get())); } if (sliceable->type()->cast()) { @@ -4691,7 +4692,7 @@ struct to_ir { } exprs[expr_idx] = index; - if (index->type()->isSubtypeOf(NoneType::get())) { + if (index->type()->isSubtypeOf(*NoneType::get())) { if (is_reverse) { return dim; } else { @@ -4703,7 +4704,7 @@ struct to_ir { } else { return dim; } - } else if (index->type()->isSubtypeOf(OptionalType::ofTensor())) { + } else if (index->type()->isSubtypeOf(*OptionalType::ofTensor())) { if (is_reverse) { throw ErrorReport(loc) << "Ellipses followed by tensor indexing is currently not supported"; @@ -4768,13 +4769,13 @@ struct to_ir { continue; } auto expr = exprs[i].value(); - if (expr->type()->isSubtypeOf(NoneType::get())) { + if (expr->type()->isSubtypeOf(*NoneType::get())) { sliceable = emitUnsqueeze(loc, sliceable, insert_value_for_dim(dims[i])); } else if (expr->type() == IntType::get()) { sliceable = emitSelect(loc, sliceable, insert_value_for_dim(dims[i]), expr); - } else if (expr->type()->isSubtypeOf(OptionalType::ofTensor())) { + } else if (expr->type()->isSubtypeOf(*OptionalType::ofTensor())) { tensor_indices.resize(dims[i] + 1); tensor_indices[dims[i]] = expr; } else { @@ -4814,7 +4815,7 @@ struct to_ir { const SourceRange& loc, Value* sliceable, const List& subscript_exprs) { - if (!sliceable->type()->isSubtypeOf(TensorType::get())) { + if (!sliceable->type()->isSubtypeOf(*TensorType::get())) { throw ErrorReport(loc) << "Unsupported operation: attempted to use multidimensional " << "indexing on a non-tensor type"; @@ -4842,7 +4843,7 @@ struct to_ir { AT_ASSERT(subscript_exprs[0].kind() == TK_SLICE_EXPR); auto slice_exp = SliceExpr(subscript_exprs[0]); Value* maybe_dim = nullptr; - if (sliceable->type()->isSubtypeOf(TensorType::get())) { + if (sliceable->type()->isSubtypeOf(*TensorType::get())) { // If the sliceable object is a tensor, specify a default dimension maybe_dim = graph->insertConstant(0, loc); } @@ -5009,7 +5010,7 @@ struct to_ir { dynamic_cast(subscript_sv.get())) { 
      Value* dim = nullptr;
      // aten::slice.tensor needs an additional `dim` input.
-      if (sliceable->type()->isSubtypeOf(TensorType::get())) {
+      if (sliceable->type()->isSubtypeOf(*TensorType::get())) {
        dim = method.graph()->insertConstant(0, val_range);
      }
@@ -5030,7 +5031,7 @@ struct to_ir {
    if (sliceable->type()->cast<TupleType>()) {
      return std::make_shared<SimpleValue>(
          emitTupleIndex(range, sv->asValue(val_range, method), idx));
-    } else if (sliceable->type()->isSubtypeOf(TensorType::get())) {
+    } else if (sliceable->type()->isSubtypeOf(*TensorType::get())) {
      return std::make_shared<SimpleValue>(
          emitMultidimSlicing(range, sliceable, subscript_exprs));
    } else {
diff --git a/torch/csrc/jit/frontend/schema_matching.cpp b/torch/csrc/jit/frontend/schema_matching.cpp
index 8f75e7c73491..b4c0b1af4868 100644
--- a/torch/csrc/jit/frontend/schema_matching.cpp
+++ b/torch/csrc/jit/frontend/schema_matching.cpp
@@ -30,21 +30,21 @@ static inline bool isIntOrFloatUsedAsList(

 /// Returns true if `type` is a Tuple in which all the elements have the
 /// same type or if it's a subtype of `list_type_`.
-inline bool convertibleToList(const TypePtr& type, const TypePtr& list_type_) {
-  auto list_type = list_type_->cast<ListType>();
+bool convertibleToList(const TypePtr& type, const TypePtr& list_type_) {
+  auto list_type = list_type_->castRaw<ListType>();
   if (!list_type) {
     return false;
   }
-  if (type->isSubtypeOf(list_type_)) {
+  if (type->isSubtypeOf(*list_type_)) {
     return true;
   }
-  if (auto tuple = type->cast<TupleType>()) {
+  if (auto tuple = type->castRaw<TupleType>()) {
     return std::all_of(
         tuple->elements().begin(),
         tuple->elements().end(),
         [&](const TypePtr& t) {
           // TODO: resolve VarType if necessary
-          return t->isSubtypeOf(list_type->getElementType());
+          return t->isSubtypeOf(*list_type->getElementType());
         });
   }
   return false;
@@ -61,7 +61,7 @@ Value* tryConvertToType(
   // treat conversion to Optional[T] as conversions to T
   if (OptionalTypePtr op = concrete_type->cast<OptionalType>()) {
     if (value->type()->kind() != OptionalType::Kind &&
-        !value->type()->isSubtypeOf(NoneType::get())) {
+        !value->type()->isSubtypeOf(*NoneType::get())) {
       return tryConvertToType(
           loc, graph, op->getElementType(), value, allow_conversions);
     }
@@ -79,7 +79,7 @@ Value* tryConvertToType(
   // inductively apply implicit conversions to tuples
   if (auto concrete_tuple = concrete_type->cast<TupleType>()) {
-    if (!value_tuple->isSubtypeOf(concrete_tuple) &&
+    if (!value_tuple->isSubtypeOf(*concrete_tuple) &&
         concrete_tuple->elements().size() == value_tuple->elements().size()) {
       auto unpacked = createTupleUnpack(value);
       std::vector<Value*> converted;
@@ -99,7 +99,7 @@ Value* tryConvertToType(
   // implicit conversions
   if (allow_conversions) {
     // Convert tensor or number to concrete int/float types
-    bool value_isa_tensor = value->type()->isSubtypeOf(TensorType::get());
+    bool value_isa_tensor = value->type()->isSubtypeOf(*TensorType::get());
     bool value_equals_number = *value->type() == *NumberType::get();
     bool concrete_float = *concrete_type == *FloatType::get();
     bool concrete_complex = *concrete_type == *ComplexType::get();
@@ -126,8 +126,8 @@ Value* tryConvertToType(
     }

     // Convert strings to device
-    if (value->type()->isSubtypeOf(StringType::get()) &&
-        concrete_type->isSubtypeOf(DeviceObjType::get())) {
+    if (value->type()->isSubtypeOf(*StringType::get()) &&
+        concrete_type->isSubtypeOf(*DeviceObjType::get())) {
       return graph.insert(aten::device, {value}, {}, loc);
     }
   }
@@ -185,7 +185,7 @@ static Value* tryMatchArgument(
   value = tryConvertToType(loc, graph, concrete_type, value, allow_conversions);
   std::stringstream ss;
   if (!value->type()->isSubtypeOfExt(
-          concrete_type, /*why_not=*/(failure_messages) ? &ss : nullptr)) {
+          *concrete_type, /*why_not=*/(failure_messages) ? &ss : nullptr)) {
     if (failure_messages) {
       auto& ostream = err()
           << arg.formatTypeMismatchMsg(value->type()->repr_str());
@@ -203,7 +203,7 @@ static Value* tryMatchArgument(
       }

       if (auto v = value->type()->cast<ListType>()) {
-        if (v->getElementType()->isSubtypeOf(TensorType::get())) {
+        if (v->getElementType()->isSubtypeOf(*TensorType::get())) {
           ostream << "Empty lists default to List[Tensor]. Add a variable "
                      "annotation to the assignment to create an empty list "
                      "of another type (torch.jit.annotate(List[T, []]) where T "
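
// ---------------------------------------------------------------------------
// Not part of the patch: a sketch of the tuple-to-list convertibility rule in
// convertibleToList above, under the new API. A Tuple[int, int] converts to
// List[int] because every element type is a subtype of the list element type.
// Assumes <ATen/core/jit_type.h>; the function name is illustrative only.
#include <ATen/core/jit_type.h>

bool tupleOfIntsConvertsToIntList() {
  auto tuple =
      c10::TupleType::create({c10::IntType::get(), c10::IntType::get()});
  auto list = c10::ListType::ofInts();
  const auto& elem = list->getElementType();
  for (const auto& t : tuple->elements()) {
    if (!t->isSubtypeOf(*elem)) { // note the dereference introduced by this PR
      return false;
    }
  }
  return true;
}
// ---------------------------------------------------------------------------
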
Expected " << expectedType->repr_str() << " but got " << newType->repr_str(); @@ -390,7 +390,7 @@ Value* SimpleValue::len(const SourceRange& loc, Function& m) { TypePtr val_type = val->type(); Graph& g = *m.graph(); if (val_type->cast() || val_type->cast() || - val_type->isSubtypeOf(TensorType::get())) { + val_type->isSubtypeOf(*TensorType::get())) { return g.insert(aten::len, {val}, {}, loc); } else { throw ErrorReport(loc) << "'" << val_type->repr_str() << "'" @@ -415,7 +415,7 @@ SugaredValuePtr SimpleValue::getitem( } else if (auto dict_type = val_type->cast()) { return std::make_shared( g.insert(aten::__getitem__, {val, idx}, {}, loc)); - } else if (val_type->isSubtypeOf(TensorType::get())) { + } else if (val_type->isSubtypeOf(*TensorType::get())) { return std::make_shared( g.insert(aten::select, {val, 0, idx}, {}, loc)); } else if (auto class_type = val_type->cast()) { @@ -702,7 +702,7 @@ std::shared_ptr BuiltinFunction::tryCreate( continue; } const auto concrete_type = tryEvalTypeVariables(formal_type, type_env); - if (!concrete_type || !self->type()->isSubtypeOf(concrete_type)) { + if (!concrete_type || !self->type()->isSubtypeOf(*concrete_type)) { continue; } return std::make_shared(symbol, self); diff --git a/torch/csrc/jit/frontend/sugared_value.h b/torch/csrc/jit/frontend/sugared_value.h index d43198951fce..374d692559cc 100644 --- a/torch/csrc/jit/frontend/sugared_value.h +++ b/torch/csrc/jit/frontend/sugared_value.h @@ -517,12 +517,12 @@ struct TORCH_API CastValue : public BuiltinFunction { auto zero = m.graph()->insertConstant(0); auto v = args[0].value(*m.graph()); - if (v->type()->isSubtypeOf(type_)) { + if (v->type()->isSubtypeOf(*type_)) { return std::make_shared(v); } else if ( *type_ == *BoolType::get() && - (v->type()->isSubtypeOf(AnyListType::get()) || - v->type()->isSubtypeOf(StringType::get()) || + (v->type()->isSubtypeOf(*AnyListType::get()) || + v->type()->isSubtypeOf(*StringType::get()) || v->type()->cast())) { auto len = len_op->call(loc, m, {v}, {}, 1); return gt_op->call(loc, m, {len->asValue(loc, m), zero}, {}, 1); @@ -766,7 +766,7 @@ struct TORCH_API ExceptionValue : public SugaredValue { auto exception_message = insertConstant(*m.graph(), message_ + ": ", loc); for (auto& input : args) { auto input_str = input.value(*m.graph()); - if (!input_str->type()->isSubtypeOf(StringType::get())) { + if (!input_str->type()->isSubtypeOf(*StringType::get())) { input_str = emitBuiltinCall(loc, *m.graph(), aten::str, {input_str}, {}); } diff --git a/torch/csrc/jit/frontend/tracer.cpp b/torch/csrc/jit/frontend/tracer.cpp index d0e5edd9bc39..cb814b1cdb36 100644 --- a/torch/csrc/jit/frontend/tracer.cpp +++ b/torch/csrc/jit/frontend/tracer.cpp @@ -285,15 +285,15 @@ Value* TracingState::getOutput(const IValue& iv, size_t i) { TypePtr key_type = dict.keyType(); TypePtr value_type = dict.valueType(); - bool key_type_valid = key_type->isSubtypeOf(StringType::get()) || - key_type->isSubtypeOf(TensorType::get()); - bool value_type_valid = value_type->isSubtypeOf(TensorType::get()); + bool key_type_valid = key_type->isSubtypeOf(*StringType::get()) || + key_type->isSubtypeOf(*TensorType::get()); + bool value_type_valid = value_type->isSubtypeOf(*TensorType::get()); // Support tuple values that contain only tensors - if (value_type->isSubtypeOf(AnyTupleType::get())) { + if (value_type->isSubtypeOf(*AnyTupleType::get())) { value_type_valid = true; for (const auto& type : value_type->containedTypes()) { - if (!type->isSubtypeOf(TensorType::get())) { + if 
+        if (!type->isSubtypeOf(*TensorType::get())) {
           value_type_valid = false;
           break;
         }
@@ -330,7 +330,7 @@ static IValue addInput(
     const TypePtr& type,
     Value* value) {
   value->setType(type);
-  if (type->isSubtypeOf(TensorType::get())) {
+  if (type->isSubtypeOf(*TensorType::get())) {
     auto input_tensor = input.toTensor();
     auto name = Variable(input_tensor).name();
     if (state->hasValue(input)) {
@@ -426,7 +426,7 @@ static void gatherParametersAndBuffers(
             ->s_(attr::scope, qualname)
             ->output()
             ->setType(s.value.type());
-    if (s.value.type()->isSubtypeOf(TensorType::get())) {
+    if (s.value.type()->isSubtypeOf(*TensorType::get())) {
       addInput(state, s.value, s.value.type(), trace_get_attr);
     }
     if (isCustomClass(s.value)) {
diff --git a/torch/csrc/jit/ir/constants.cpp b/torch/csrc/jit/ir/constants.cpp
index 8a29b2f79698..bdd11ed9eaa5 100644
--- a/torch/csrc/jit/ir/constants.cpp
+++ b/torch/csrc/jit/ir/constants.cpp
@@ -153,20 +153,20 @@ c10::optional<IValue> toIValue(const Value* v) {
   }
   const Node* node = v->node();
   const TypePtr& type = v->type();
-  if (type->isSubtypeOf(TensorType::get())) {
+  if (type->isSubtypeOf(*TensorType::get())) {
     return node->t(attr::value);
-  } else if (type->isSubtypeOf(BoolType::get())) {
+  } else if (type->isSubtypeOf(*BoolType::get())) {
     return (bool)node->i(attr::value);
   } else if (
-      type->isSubtypeOf(NumberType::get()) &&
+      type->isSubtypeOf(*NumberType::get()) &&
       node->kindOf(attr::value) == AttributeKind::i) {
     return node->i(attr::value);
   } else if (
-      type->isSubtypeOf(NumberType::get()) &&
+      type->isSubtypeOf(*NumberType::get()) &&
       node->kindOf(attr::value) == AttributeKind::f) {
     return node->f(attr::value);
   } else if (
-      type->isSubtypeOf(NumberType::get()) &&
+      type->isSubtypeOf(*NumberType::get()) &&
       node->kindOf(attr::value) == AttributeKind::c) {
     return node->c(attr::value);
   } else if (
diff --git a/torch/csrc/jit/ir/ir.cpp b/torch/csrc/jit/ir/ir.cpp
index b9f9833a4088..fc57acbec9d2 100644
--- a/torch/csrc/jit/ir/ir.cpp
+++ b/torch/csrc/jit/ir/ir.cpp
@@ -1001,7 +1001,7 @@ bool Node::matches(const FunctionSchema& schema) const {
     // we will not succeed at matching T. However None <: Optional[T] so this
     // check can still succeed.
-    if (!actuals[i]->type()->isSubtypeOf(formal)) {
+    if (!actuals[i]->type()->isSubtypeOf(*formal)) {
       return false;
     }
   }
@@ -1773,7 +1773,7 @@ Node* Graph::createList(
   auto n = create(prim::ListConstruct, values);
   for (const auto& v : values) {
     TORCH_CHECK(
-        v->type()->isSubtypeOf(contained_type),
+        v->type()->isSubtypeOf(*contained_type),
         "Expected a list element that subtypes '",
         contained_type->repr_str(),
         "' but got an element of type '",
@@ -1803,8 +1803,8 @@ Node* Graph::createDict(
   AT_ASSERT(keys.size() == values.size());
   auto n = create(prim::DictConstruct, 1);
   for (const auto i : c10::irange(keys.size())) {
-    AT_ASSERT(keys[i]->type()->isSubtypeOf(key_type));
-    AT_ASSERT(values[i]->type()->isSubtypeOf(value_type));
+    AT_ASSERT(keys[i]->type()->isSubtypeOf(*key_type));
+    AT_ASSERT(values[i]->type()->isSubtypeOf(*value_type));

     n->addInput(keys[i]);
     n->addInput(values[i]);
diff --git a/torch/csrc/jit/ir/irparser.cpp b/torch/csrc/jit/ir/irparser.cpp
index b1f22fdd00db..db6a7660f1f0 100644
--- a/torch/csrc/jit/ir/irparser.cpp
+++ b/torch/csrc/jit/ir/irparser.cpp
@@ -406,7 +406,7 @@ void IRParser::parseOperator(Block* b) {
     // Don't currently support checking against type variables
     // TODO: support?
     if (!schema_return_type->hasFreeVariables() &&
-        !v.type->isSubtypeOf(schema_return_type)) {
+        !v.type->isSubtypeOf(*schema_return_type)) {
       throw ErrorReport(source_range)
           << "Annotated type " << v.type->repr_str()
           << " does not match schema type "
diff --git a/torch/csrc/jit/ir/node_hashing.cpp b/torch/csrc/jit/ir/node_hashing.cpp
index 277b66669c74..37925ec523a8 100644
--- a/torch/csrc/jit/ir/node_hashing.cpp
+++ b/torch/csrc/jit/ir/node_hashing.cpp
@@ -209,18 +209,18 @@ size_t HashNode::operator()(const Node* k) const {
   size_t constant_hash = 0;
   if (k->kind() == prim::Constant) {
     TypePtr type = k->output()->type();
-    if (type->isSubtypeOf(NumberType::get()) &&
+    if (type->isSubtypeOf(*NumberType::get()) &&
         k->kindOf(attr::value) == AttributeKind::i) {
       constant_hash = std::hash<int64_t>{}(k->i(attr::value));
     } else if (
-        type->isSubtypeOf(NumberType::get()) &&
+        type->isSubtypeOf(*NumberType::get()) &&
         k->kindOf(attr::value) == AttributeKind::f) {
       constant_hash = std::hash<double>{}(k->f(attr::value));
     } else if (
-        type->isSubtypeOf(NumberType::get()) &&
+        type->isSubtypeOf(*NumberType::get()) &&
         k->kindOf(attr::value) == AttributeKind::c) {
       constant_hash = c10::hash<c10::complex<double>>{}(k->c(attr::value));
-    } else if (type->isSubtypeOf(BoolType::get())) {
+    } else if (type->isSubtypeOf(*BoolType::get())) {
       constant_hash = std::hash<bool>{}(k->i(attr::value));
     }
   }
diff --git a/torch/csrc/jit/mobile/interpreter.cpp b/torch/csrc/jit/mobile/interpreter.cpp
index 5aaec866f684..90bc11587a8f 100644
--- a/torch/csrc/jit/mobile/interpreter.cpp
+++ b/torch/csrc/jit/mobile/interpreter.cpp
@@ -35,7 +35,7 @@ void createObject(Stack& stack, const at::ClassTypePtr& type) {

 void isinstance(Stack& stack, at::ArrayRef<at::TypePtr> types) {
   at::TypePtr ty = pop(stack).type();
   for (const at::TypePtr& candidate : types) {
-    if (ty->isSubtypeOf(candidate)) {
+    if (ty->isSubtypeOf(*candidate)) {
       push(stack, true);
       return;
     }
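
// ---------------------------------------------------------------------------
// Not part of the patch: a free-standing restatement of what the mobile
// isinstance op above computes - true iff the popped value's type is a
// subtype of any candidate type. Assumes <ATen/core/jit_type.h>; the helper
// name is hypothetical.
#include <ATen/core/jit_type.h>
#include <vector>

bool isSubtypeOfAny(
    const c10::TypePtr& ty,
    const std::vector<c10::TypePtr>& candidates) {
  for (const c10::TypePtr& candidate : candidates) {
    if (ty->isSubtypeOf(*candidate)) {
      return true; // mirrors push(stack, true)
    }
  }
  return false; // mirrors the fallthrough that pushes false
}
// ---------------------------------------------------------------------------
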
diff --git a/torch/csrc/jit/passes/clear_profiling.cpp b/torch/csrc/jit/passes/clear_profiling.cpp
index 9acb9fbc3129..405b110b2efe 100644
--- a/torch/csrc/jit/passes/clear_profiling.cpp
+++ b/torch/csrc/jit/passes/clear_profiling.cpp
@@ -8,7 +8,7 @@ namespace jit {

 static void unprofileGraphInputs(const std::shared_ptr<Graph>& graph) {
   for (auto i : graph->inputs()) {
-    if (i->type()->isSubtypeOf(TensorType::get())) {
+    if (i->type()->isSubtypeOf(*TensorType::get())) {
       i->setType(unshapedType(i->type()));
     }
   }
@@ -24,7 +24,7 @@ static void unprofileBlock(Block* start_block) {

     for (auto n : block->nodes()) {
       for (auto o : n->outputs()) {
-        if (o->type()->isSubtypeOf(TensorType::get())) {
+        if (o->type()->isSubtypeOf(*TensorType::get())) {
           o->setType(unshapedType(o->type()));
         }
       }
diff --git a/torch/csrc/jit/passes/decompose_ops.cpp b/torch/csrc/jit/passes/decompose_ops.cpp
index 0706c9c14ae9..836367ad1c8a 100644
--- a/torch/csrc/jit/passes/decompose_ops.cpp
+++ b/torch/csrc/jit/passes/decompose_ops.cpp
@@ -21,7 +21,7 @@ c10::AliasAnalysisKind aliasAnalysisFromSchema() {
 // statically defined (neither a None constant nor a Optional[Tensor] type)
 // return yes, no, or no value if we can't tell
 c10::optional<bool> isDefined(Value* tensor) {
-  if (tensor->type()->isSubtypeOf(TensorType::get())) {
+  if (tensor->type()->isSubtypeOf(*TensorType::get())) {
     return true;
   }
   if (tensor->node()->mustBeNone()) {
@@ -36,7 +36,7 @@ bool isDecomposableNorm(Node* normalize_op) {
       "aten::layer_norm(Tensor input, int[] normalized_shape, Tensor? weight, Tensor? bias, float eps, bool cudnn_enable) -> Tensor",
   };
   Value* input = normalize_op->namedInput(attr::input);
-  if (!input->type()->isSubtypeOf(TensorType::get())) {
+  if (!input->type()->isSubtypeOf(*TensorType::get())) {
     return false;
   }
   auto device = input->type()->expectRef<TensorType>().device();
diff --git a/torch/csrc/jit/passes/erase_number_types.cpp b/torch/csrc/jit/passes/erase_number_types.cpp
index 2cd39aaf1a00..f365cdedd94d 100644
--- a/torch/csrc/jit/passes/erase_number_types.cpp
+++ b/torch/csrc/jit/passes/erase_number_types.cpp
@@ -8,9 +8,9 @@ namespace torch {
 namespace jit {

 void SetNumTypeToTensorType(Value* v) {
-  if (v->type()->isSubtypeOf(NumberType::get())) {
+  if (v->type()->isSubtypeOf(*NumberType::get())) {
     v->setType(TensorType::fromNumberType(v->type()));
-  } else if (v->type()->isSubtypeOf(BoolType::get())) {
+  } else if (v->type()->isSubtypeOf(*BoolType::get())) {
     v->setType(TensorType::fromBoolType());
   }
 }
@@ -28,10 +28,10 @@ void EraseNumberTypesOnBlock(Block* block) {
       case prim::Constant: {
         // remove primitive constants, replacing with tensor equivalent
         // ONNX does not support non-tensor constants
-        if (it->output()->type()->isSubtypeOf(NumberType::get()) ||
-            it->output()->type()->isSubtypeOf(BoolType::get())) {
+        if (it->output()->type()->isSubtypeOf(*NumberType::get()) ||
+            it->output()->type()->isSubtypeOf(*BoolType::get())) {
           at::Scalar s;
-          if (it->output()->type()->isSubtypeOf(BoolType::get())) {
+          if (it->output()->type()->isSubtypeOf(*BoolType::get())) {
             s = *constant_as<bool>(it->output());
           } else {
             s = *constant_as<at::Scalar>(it->output());
diff --git a/torch/csrc/jit/passes/freeze_module.cpp b/torch/csrc/jit/passes/freeze_module.cpp
index d7f89e3e67d3..df684e14d4f2 100644
--- a/torch/csrc/jit/passes/freeze_module.cpp
+++ b/torch/csrc/jit/passes/freeze_module.cpp
@@ -543,7 +543,7 @@ class AttributePropagator {
   bool moduleEscapes(Module& subModule, std::shared_ptr<Graph>& graph) {
     for (auto& output : graph->outputs()) {
-      if (subModule.type()->isSubtypeOf(output->type())) {
+      if (subModule.type()->isSubtypeOf(*output->type())) {
         return true;
       }
     }
diff --git a/torch/csrc/jit/passes/frozen_ops_to_mkldnn.cpp b/torch/csrc/jit/passes/frozen_ops_to_mkldnn.cpp
index 542e13628052..e8bac2412bda 100644
--- a/torch/csrc/jit/passes/frozen_ops_to_mkldnn.cpp
+++ b/torch/csrc/jit/passes/frozen_ops_to_mkldnn.cpp
@@ -755,7 +755,7 @@ void ComputeSubgraphInMKLDNN(Node* subgraph_node) {
       body_node->replaceInput(1, node->outputs().at(1));
     }
     if (body_node->kind() == aten::mul &&
-        body_node->input(1)->type()->isSubtypeOf(NumberType::get())) {
+        body_node->input(1)->type()->isSubtypeOf(*NumberType::get())) {
       body_node->replaceWithNewSymbol(Symbol::prim("MKLDNNScalarMul"));
       body_node->destroy();
       continue;
@@ -1022,7 +1022,7 @@ class MKLDNNSubgraphSlicer {
     if (n->kind() == aten::mul) {
       return n->input(0)->type()->cast<TensorType>() &&
           (n->input(1)->type()->cast<TensorType>() ||
-           n->input(1)->type()->isSubtypeOf(NumberType::get()));
+           n->input(1)->type()->isSubtypeOf(*NumberType::get()));
     }

     if (n->kind() == aten::dropout) {
diff --git a/torch/csrc/jit/passes/graph_fuser.cpp b/torch/csrc/jit/passes/graph_fuser.cpp
index 653f9fec08b3..3b5dc448eda5 100644
--- a/torch/csrc/jit/passes/graph_fuser.cpp
+++ b/torch/csrc/jit/passes/graph_fuser.cpp
@@ -111,8 +111,8 @@ bool isSimpleMap(Node* node) {
     return false;
   }
   for (Value* input : node->inputs()) {
-    if (input->type()->isSubtypeOf(TensorType::get()) ||
-        input->type()->isSubtypeOf(FloatType::get())) {
+    if (input->type()->isSubtypeOf(*TensorType::get()) ||
+        input->type()->isSubtypeOf(*FloatType::get())) {
       continue;
     }
     if (input->node()->kind() != prim::Constant) {
@@ -166,7 +166,7 @@ struct GraphFuser {

   value_list tensorInputs(Node* node) {
     return filter(node->inputs(), [](Value* v) {
-      return v->type()->isSubtypeOf(TensorType::get());
+      return v->type()->isSubtypeOf(*TensorType::get());
     });
   }
@@ -175,7 +175,7 @@ struct GraphFuser {
   }

   bool isFusableDevice(Value* v, bool strict_fuser_check) {
-    if (!v->type()->isSubtypeOf(TensorType::get())) {
+    if (!v->type()->isSubtypeOf(*TensorType::get())) {
       return true;
     }
     auto device = v->type()->expectRef<TensorType>().device();
@@ -332,7 +332,7 @@ struct GraphFuser {
     AT_ASSERT(group->inputs().size() == subgraph.inputs().size());
     for (auto input : group->inputs()) {
       inputs_map[input] = subgraph.inputs()[i++];
-      if (input->type()->isSubtypeOf(TensorType::get()))
+      if (input->type()->isSubtypeOf(*TensorType::get()))
         tensor_insert_idx = i;
     }
     // add n's inputs to the fusion group's input list if we don't already have
@@ -342,17 +342,17 @@ struct GraphFuser {
     WithInsertPoint guard(*subgraph.nodes().begin());
     for (auto input : n->inputs()) {
       if (inputs_map.count(input) == 0) {
-        if (input->type()->isSubtypeOf(TensorType::get())) {
+        if (input->type()->isSubtypeOf(*TensorType::get())) {
           auto in_group = subgraph.insertInput(tensor_insert_idx);
           in_group->setType(input->type());
           inputs_map[input] = in_group;
           group->insertInput(tensor_insert_idx, input);
           tensor_insert_idx++;
         } else if (
-            (input->type()->isSubtypeOf(FloatType::get()) &&
+            (input->type()->isSubtypeOf(*FloatType::get()) &&
              input->node()->kind() != prim::Constant) ||
             (n->kind() == aten::_grad_sum_to_size &&
-             input->type()->isSubtypeOf(ListType::ofInts()))) {
+             input->type()->isSubtypeOf(*ListType::ofInts()))) {
           auto in_group = subgraph.addInput();
           in_group->setType(input->type());
           inputs_map[input] = in_group;
@@ -618,7 +618,7 @@ struct GraphFuser {
     // Replace tensors inputs with broadcasted values
     auto new_tensors_it = new_tensors.begin();
     for (size_t i = 0; i < node->inputs().size(); ++i) {
-      if (node->inputs()[i]->type()->isSubtypeOf(TensorType::get())) {
+      if (node->inputs()[i]->type()->isSubtypeOf(*TensorType::get())) {
         AT_ASSERT(new_tensors_it != new_tensors.end());
         node->replaceInput(i, *(new_tensors_it++));
       }
@@ -768,7 +768,7 @@ struct GraphFuser {
       // XXX: we only work with pointwise ops in here, so we know it is valid to
       // push the concat only through tensor arguments (and all other args can
       // be safely ignored).
-      if (!input->type()->isSubtypeOf(TensorType::get()))
+      if (!input->type()->isSubtypeOf(*TensorType::get()))
         continue;

       // if 'input' is already an input to the bchunk, reuse it.
@@ -813,7 +813,7 @@ struct GraphFuser {
     chunked_op->output()->setType(chunk_sel->type());
     auto chunked_inputs_it = chunked_inputs.begin();
     for (Value* original_input : original_inputs) {
-      if (original_input->type()->isSubtypeOf(TensorType::get())) {
+      if (original_input->type()->isSubtypeOf(*TensorType::get())) {
         AT_ASSERT(chunked_inputs_it != chunked_inputs.end());
         chunked_op->addInput(
             // NOLINTNEXTLINE(clang-analyzer-core.DivideZero)
@@ -845,7 +845,7 @@ struct GraphFuser {
     if (!size_calc_uses.empty()) {
       auto tensor_inputs = filter(
           producer_for_chunk_node->inputs(),
-          [](Value* v) { return v->type()->isSubtypeOf(TensorType::get()); });
+          [](Value* v) { return v->type()->isSubtypeOf(*TensorType::get()); });
       auto tensor_sizes = fmap(tensor_inputs, [&](Value* v) {
         Value* output = v->owningGraph()->insert(aten::size, {v});
         aliasDb_->createValue(output);
@@ -937,7 +937,7 @@ struct GraphFuser {
     auto sinputs = subgraph->inputs();
     AT_ASSERT(inputs.size() == sinputs.size());
     for (const auto i : c10::irange(inputs.size())) {
-      if (inputs[i]->type()->isSubtypeOf(TensorType::get())) {
+      if (inputs[i]->type()->isSubtypeOf(*TensorType::get())) {
         Value* soutput = graph->insert(aten::size, {inputs[i]});
         aliasDb_->createValue(soutput);
         shape_of[sinputs[i]] = soutput;
@@ -992,7 +992,7 @@ struct GraphFuser {
       continue;
     }
     auto tensor_inputs = filter(n->inputs(), [](Value* v) {
-      return v->type()->isSubtypeOf(TensorType::get());
+      return v->type()->isSubtypeOf(*TensorType::get());
     });
     auto shapes =
         fmap(tensor_inputs, [&](Value* v) { return shape_of.at(v); });
diff --git a/torch/csrc/jit/passes/guard_elimination.cpp b/torch/csrc/jit/passes/guard_elimination.cpp
index adb4b6fa853c..abc7c25738bb 100644
--- a/torch/csrc/jit/passes/guard_elimination.cpp
+++ b/torch/csrc/jit/passes/guard_elimination.cpp
@@ -234,7 +234,7 @@ struct GuardElimination {
       if ((input->node()->kind() == prim::Guard &&
            !input->type()->expectRef<TensorType>().isSummarized()) ||
           input->node()->kind() == prim::Constant ||
-          (allow_numbers && input->type()->isSubtypeOf(NumberType::get())) ||
+          (allow_numbers && input->type()->isSubtypeOf(*NumberType::get())) ||
           except.count(i) != 0) {
         AT_ASSERT(
             input->node()->kind() != prim::Guard ||
diff --git a/torch/csrc/jit/passes/loop_unrolling.cpp b/torch/csrc/jit/passes/loop_unrolling.cpp
index 051a3ce56b26..3cc5d30336c3 100644
--- a/torch/csrc/jit/passes/loop_unrolling.cpp
+++ b/torch/csrc/jit/passes/loop_unrolling.cpp
@@ -304,7 +304,7 @@ void LoopsPeeler::peelLoops() {

 bool PeelProfilingLoops(const std::shared_ptr<Graph>& graph) {
   auto peel_predicate = [](Node* n) {
     for (auto i : n->inputs()) {
-      if (i->type()->isSubtypeOf(TensorType::get())) {
+      if (i->type()->isSubtypeOf(*TensorType::get())) {
         return true;
       }
     }
diff --git a/torch/csrc/jit/passes/onnx.cpp b/torch/csrc/jit/passes/onnx.cpp
index bef66f1f9e36..550cb24aa50a 100644
--- a/torch/csrc/jit/passes/onnx.cpp
+++ b/torch/csrc/jit/passes/onnx.cpp
@@ -63,7 +63,7 @@ void checkONNXCompatibility(const c10::FunctionSchema& schema) {
     if (type->kind() == TypeKind::ListType) {
       const auto& elem_type =
           reinterpret_cast<ListType*>(type.get())->getElementType();
-      if (elem_type->isSubtypeOf(TensorType::get())) {
+      if (elem_type->isSubtypeOf(*TensorType::get())) {
         AT_ASSERTM(
             !has_tensor_list,
             "ONNX export supports at most one TensorList as input.");
@@ -97,7 +97,7 @@ void preprocessCaffe2Ops(Block* block) {
           origin_input->mustBeNone()) {
         continue;
       }
-      if (type->isSubtypeOf(TensorType::get())) {
+      if (type->isSubtypeOf(*TensorType::get())) {
         it->addInput(origin_input);
       } else if (
          type->kind() == TypeKind::BoolType ||
@@ -119,7 +119,7 @@ void preprocessCaffe2Ops(Block* block) {
         AT_ASSERT(
             list_node->kind() == prim::ListConstruct ||
             list_node->kind() == prim::Constant);
-        if (elem_type->isSubtypeOf(TensorType::get())) {
+        if (elem_type->isSubtypeOf(*TensorType::get())) {
           AT_ASSERT(list_node->kind(), prim::ListConstruct);
           const auto& tensor_list = origin_input->node()->inputs();
           for (const auto& t : tensor_list) {
diff --git a/torch/csrc/jit/passes/onnx/fixup_onnx_controlflow.cpp b/torch/csrc/jit/passes/onnx/fixup_onnx_controlflow.cpp
index 05ea871eaa6e..2e271b635fc4 100644
--- a/torch/csrc/jit/passes/onnx/fixup_onnx_controlflow.cpp
+++ b/torch/csrc/jit/passes/onnx/fixup_onnx_controlflow.cpp
@@ -47,7 +47,7 @@ bool IsCondCastRequired(Value* cond_val) {
       return *scalar_type != c10::kBool;
     }
   }
-  return !type->isSubtypeOf(BoolType::get());
+  return !type->isSubtypeOf(*BoolType::get());
 }

 bool IsErasableSequence(const Node* loop_node, size_t i) {
@@ -337,7 +337,7 @@ void ONNXFixupUninitializedOutput(Node* node, int opset_version) {

   // Check if the input to ONNX If node is node Bool, and insert
   // cast to Bool if needed.
-  if (!if_node->input()->type()->isSubtypeOf(BoolType::get())) {
+  if (!if_node->input()->type()->isSubtypeOf(*BoolType::get())) {
     Node* cast_node = CreateCastToBoolNode(if_node->input(), graph);
     cast_node->insertBefore(if_node);
     if_node->replaceInputWith(if_node->input(), cast_node->output());
diff --git a/torch/csrc/jit/passes/onnx/unpack_quantized_weights.cpp b/torch/csrc/jit/passes/onnx/unpack_quantized_weights.cpp
index a2dd752e86c6..92605c699307 100644
--- a/torch/csrc/jit/passes/onnx/unpack_quantized_weights.cpp
+++ b/torch/csrc/jit/passes/onnx/unpack_quantized_weights.cpp
@@ -389,7 +389,7 @@ void unpackQuantizedWeightsHelper(
     auto input_val = match_vmap.at(vmap.at("r"))->node()->inputs()[0];
     TORCH_INTERNAL_ASSERT(
-        input_val->type()->isSubtypeOf(TensorType::get()),
+        input_val->type()->isSubtypeOf(*TensorType::get()),
         "Unsupported input type. Expected TensorType, got ",
         input_val->type()->str());
diff --git a/torch/csrc/jit/passes/peephole.cpp b/torch/csrc/jit/passes/peephole.cpp
index efb759735c50..63cc0f8d0c42 100644
--- a/torch/csrc/jit/passes/peephole.cpp
+++ b/torch/csrc/jit/passes/peephole.cpp
@@ -74,7 +74,7 @@ struct PeepholeOptimizeImpl {
       for (Use u : uses) {
         if (u.user->matches(
                 "aten::_grad_sum_to_size(Tensor(a) self, int[]? size) -> Tensor(a)") &&
-            u.user->input(1)->type()->isSubtypeOf(ListType::ofInts())) {
+            u.user->input(1)->type()->isSubtypeOf(*ListType::ofInts())) {
           GRAPH_UPDATE(
               getHeader(node),
               " (x._grad_sum_to_size(y)._grad_sum_to_size(z) == x._grad_sum_to_size(z)) is replaced with ",
diff --git a/torch/csrc/jit/passes/peephole_non_tensor.cpp b/torch/csrc/jit/passes/peephole_non_tensor.cpp
index cffb4bb612ed..e6024253fe1b 100644
--- a/torch/csrc/jit/passes/peephole_non_tensor.cpp
+++ b/torch/csrc/jit/passes/peephole_non_tensor.cpp
@@ -185,7 +185,7 @@ struct PeepholeOptimizeNonTensorImpl {
       // losing anything by calling unshapedType here
       auto input_type = unshapedType(node->input()->type());
       auto output_type = unshapedType(node->output()->type());
-      if (input_type->isSubtypeOf(output_type)) {
+      if (input_type->isSubtypeOf(*output_type)) {
         GRAPH_UPDATE(
             "Removing ",
             getHeader(node),
diff --git a/torch/csrc/jit/passes/quantization/helper.cpp b/torch/csrc/jit/passes/quantization/helper.cpp
index b96c5c4d30d5..5cc720a55781 100644
--- a/torch/csrc/jit/passes/quantization/helper.cpp
+++ b/torch/csrc/jit/passes/quantization/helper.cpp
@@ -345,14 +345,14 @@ std::vector<Value*> getPassThroughInputs(Value* v) {
     return inputs;
   } else if (n->kind() == prim::ListUnpack || n->kind() == prim::TupleUnpack) {
     // only propagate dequantize for Tensor
-    if (v->type()->isSubtypeOf(TensorType::get())) {
+    if (v->type()->isSubtypeOf(*TensorType::get())) {
       return {n->input(0)};
     } else {
       return {};
     }
   } else if (
       n->kind() == prim::ListConstruct &&
-      v->type()->isSubtypeOf(ListType::ofTensors())) {
+      v->type()->isSubtypeOf(*ListType::ofTensors())) {
     std::vector<Value*> inputs;
     for (auto* v : n->inputs()) {
       inputs.push_back(v);
@@ -361,7 +361,7 @@ std::vector<Value*> getPassThroughInputs(Value* v) {
   } else if (n->kind() == prim::TupleConstruct) {
     std::vector<Value*> inputs;
     for (auto* input : n->inputs()) {
-      if (input->type()->isSubtypeOf(TensorType::get())) {
+      if (input->type()->isSubtypeOf(*TensorType::get())) {
         inputs.push_back(input);
       }
     }
@@ -560,8 +560,8 @@ bool alwaysRaisesException(Block* block) {

 // Check if a value in the graph is a Scalar value
 bool isScalar(Value* v) {
   auto iv = toIValue(v);
-  return v->type()->isSubtypeOf(NumberType::get()) ||
-      (v->type()->isSubtypeOf(TensorType::get()) && iv && iv->isTensor() &&
+  return v->type()->isSubtypeOf(*NumberType::get()) ||
+      (v->type()->isSubtypeOf(*TensorType::get()) && iv && iv->isTensor() &&
        iv->toTensor().dim() == 0);
 }
diff --git a/torch/csrc/jit/passes/quantization/insert_observers.cpp b/torch/csrc/jit/passes/quantization/insert_observers.cpp
index e1cc2eb38a4f..740442bd64ec 100644
--- a/torch/csrc/jit/passes/quantization/insert_observers.cpp
+++ b/torch/csrc/jit/passes/quantization/insert_observers.cpp
@@ -1185,8 +1185,8 @@ bool InsertObserversHelper::valueNeedsToBeQuantized(
     Value* v,
     const QConfig& qconfig) {
   if (isBiasOfConvOrLinear(v) ||
-      !(v->type()->isSubtypeOf(TensorType::get()) ||
-        v->type()->isSubtypeOf(ListType::ofTensors())) ||
+      !(v->type()->isSubtypeOf(*TensorType::get()) ||
+        v->type()->isSubtypeOf(*ListType::ofTensors())) ||
       isEmbeddingBagNonInput(v)) {
     return false;
   }
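
// ---------------------------------------------------------------------------
// Not part of the patch: a sketch of the type gate inside
// valueNeedsToBeQuantized above - only values statically typed as Tensor or
// List[Tensor] are candidates for observer insertion. Assumes
// <ATen/core/jit_type.h>; the helper name is hypothetical.
#include <ATen/core/jit_type.h>

bool isObservableType(const c10::TypePtr& t) {
  return t->isSubtypeOf(*c10::TensorType::get()) ||
      t->isSubtypeOf(*c10::ListType::ofTensors());
}
// ---------------------------------------------------------------------------
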
diff --git a/torch/csrc/jit/passes/quantization/insert_quant_dequant.cpp b/torch/csrc/jit/passes/quantization/insert_quant_dequant.cpp
index 85d7cbd08d53..3b4bc7a78100 100644
--- a/torch/csrc/jit/passes/quantization/insert_quant_dequant.cpp
+++ b/torch/csrc/jit/passes/quantization/insert_quant_dequant.cpp
@@ -1024,7 +1024,7 @@ std::tuple<c10::QScheme, std::vector<IValue>> InsertQuantDeQuantHelper::
   // TODO: refactor findObserverName to take Node* as input
   Value* v = n->output();
   TORCH_INTERNAL_ASSERT(
-      v->type()->isSubtypeOf(TensorType::get()),
+      v->type()->isSubtypeOf(*TensorType::get()),
       "Expected output of observer node to be Tensor");
   auto observer_name = findObserverName(v);
   TORCH_INTERNAL_ASSERT(
@@ -1352,7 +1352,7 @@ void InsertQuantDeQuantHelper::runWeightObserver(
     blocks_to_visit.pop();
     for (auto n : b->nodes()) {
       for (Value* v : n->outputs()) {
-        if (!v->type()->isSubtypeOf(TensorType::get())) {
+        if (!v->type()->isSubtypeOf(*TensorType::get())) {
           continue;
         }
         auto observer_name = findObserverName(v);
@@ -1416,7 +1416,7 @@ void InsertQuantDeQuantHelper::run(
   std::vector<Value*> input_values;
   for (const auto idx : c10::irange(1, method.num_inputs())) {
     auto& v = graph->inputs()[idx];
-    if (v->type()->isSubtypeOf(TensorType::get())) {
+    if (v->type()->isSubtypeOf(*TensorType::get())) {
       input_values.push_back(v);
     }
   }
@@ -1429,7 +1429,7 @@ void InsertQuantDeQuantHelper::run(
   for (auto it = b->nodes().begin(), end = b->nodes().end(); it != end;) {
     Node* n = *it++;
     for (Value* v : n->outputs()) {
-      if (!v->type()->isSubtypeOf(TensorType::get())) {
+      if (!v->type()->isSubtypeOf(*TensorType::get())) {
         continue;
       }
       collectObserverNodesAndValueToQuantize(module, v);
diff --git a/torch/csrc/jit/passes/shape_analysis.cpp b/torch/csrc/jit/passes/shape_analysis.cpp
index c74c6ee40221..382d09561b40 100644
--- a/torch/csrc/jit/passes/shape_analysis.cpp
+++ b/torch/csrc/jit/passes/shape_analysis.cpp
@@ -63,18 +63,18 @@ bool isValidArgumentForRunning(Value* v) {
     }
     return !at::isIntegralType(*tt->scalarType(), /*includeBool=*/false);
   }
-  return v->type()->isSubtypeOf(FloatType::get());
+  return v->type()->isSubtypeOf(*FloatType::get());
 }

 bool isValidReturnForRunning(Value* v) {
-  return v->type()->isSubtypeOf(TensorType::get()) ||
-      v->type()->isSubtypeOf(NumberType::get());
+  return v->type()->isSubtypeOf(*TensorType::get()) ||
+      v->type()->isSubtypeOf(*NumberType::get());
 }

 bool containsTensorType(const TypePtr& t) {
   auto n_contained = t->containedTypes().size();
   if (n_contained == 1) {
-    return t->containedTypes().at(0)->isSubtypeOf(TensorType::get());
+    return t->containedTypes().at(0)->isSubtypeOf(*TensorType::get());
   } else if (n_contained > 1) {
     return std::any_of(
         t->containedTypes().begin(),
@@ -129,7 +129,7 @@ class ShapePropagator {
     // ops which take the result and write to input "out"
     if (auto out_arg_index = n->schema().argumentIndexWithName("out")) {
       auto arg = n->schema().arguments().at(*out_arg_index);
-      return arg.kwarg_only() && arg.type()->isSubtypeOf(TensorType::get());
+      return arg.kwarg_only() && arg.type()->isSubtypeOf(*TensorType::get());
     }
     return false;
   }
@@ -183,7 +183,7 @@ class ShapePropagator {
             .zero_();
       }
       // fallthrough
-    } else if (type_->isSubtypeOf(FloatType::get())) {
+    } else if (type_->isSubtypeOf(*FloatType::get())) {
       return 0.f;
     }
     // we should not get here because isValidArgumentForRunning should have
@@ -214,9 +214,9 @@ class ShapePropagator {
       return c10::nullopt;
     }
     for (const auto i : c10::irange(args.size())) {
-      if (args[i].type()->isSubtypeOf(ListType::ofTensors())) {
+      if (args[i].type()->isSubtypeOf(*ListType::ofTensors())) {
         return c10::nullopt;
-      } else if (args[i].type()->isSubtypeOf(TensorType::get())) {
+      } else if (args[i].type()->isSubtypeOf(*TensorType::get())) {
         if (auto type = node->input(i)->type()->cast<TensorType>()) {
           if (complete && !type->isComplete()) {
             return c10::nullopt;
@@ -617,11 +617,11 @@ class ShapePropagator {
         return; // correct num type is already set
       case prim::NumToTensor: {
        TypePtr typ = node->input()->type();
-        if (typ->isSubtypeOf(IntType::get()) ||
-            typ->isSubtypeOf(BoolType::get())) {
+        if (typ->isSubtypeOf(*IntType::get()) ||
+            typ->isSubtypeOf(*BoolType::get())) {
          node->output()->setType(TensorType::create(
              at::kLong, at::kCPU, 0, /*requires_grad=*/c10::nullopt));
-        } else if (node->input()->type()->isSubtypeOf(FloatType::get())) {
+        } else if (node->input()->type()->isSubtypeOf(*FloatType::get())) {
          node->output()->setType(TensorType::create(
              at::kDouble, at::kCPU, 0, /*requires_grad=*/c10::nullopt));
        }
@@ -632,7 +632,7 @@ class ShapePropagator {
        // as_tensor has an overloaded schema and can either have a tensor or
        // a list as the first input, if the input is a tensor, we delegate
        // the shape propagation in PropagateTensorShapeOnNode
-        if (node->inputs().at(0)->type()->isSubtypeOf(TensorType::get())) {
+        if (node->inputs().at(0)->type()->isSubtypeOf(*TensorType::get())) {
          break;
        }
        return propagateTorchTensorShape(node);
@@ -659,7 +659,7 @@ class ShapePropagator {
        return;
      }
      case prim::Constant: {
-        if (node->output()->type()->isSubtypeOf(TensorType::get())) {
+        if (node->output()->type()->isSubtypeOf(*TensorType::get())) {
          node->output()->inferTypeFrom(node->t(attr::value));
        }
        return;
@@ -668,7 +668,7 @@ class ShapePropagator {
        // If we have specialized the optional type to the element type,
        // we want to pass it down. We write this as input.isSubtypeOf(output)
        // to be sure that we don't screw up nested optionals.
-        if (node->input()->type()->isSubtypeOf(node->output()->type())) {
+        if (node->input()->type()->isSubtypeOf(*node->output()->type())) {
          node->output()->setType(node->input()->type());
        }
        return;
@@ -709,7 +709,7 @@ class ShapePropagator {
        // If we have specialized the optional type to the element type,
        // we want to pass it down. We write this as input.isSubtypeOf(output)
        // to be sure that we don't screw up nested optionals.
-        if (node->input()->type()->isSubtypeOf(node->output()->type())) {
+        if (node->input()->type()->isSubtypeOf(*node->output()->type())) {
          node->output()->setType(node->input()->type());
        }
        return;
@@ -1610,7 +1610,7 @@ class ShapePropagator {
      auto outputs = node->outputs();
      AT_ASSERT(types.size() == outputs.size());
      for (const auto i : c10::irange(types.size())) {
-        AT_ASSERT(outputs[i]->type()->isSubtypeOf(TensorType::get()));
+        AT_ASSERT(outputs[i]->type()->isSubtypeOf(*TensorType::get()));
        outputs[i]->setType(types[i]);
      }
      return true;
@@ -2194,7 +2194,7 @@
 using TypeCache = std::unordered_map<TypePtr, TypePtr>;

 TypePtr getOrCreateUnshapedType(TypePtr type, TypeCache& unshaped_type_cache);

 TypePtr unshapedTypeImpl(TypePtr type, TypeCache& unshaped_type_cache) {
-  if (type->isSubtypeOf(TensorType::get())) {
+  if (type->isSubtypeOf(*TensorType::get())) {
     return TensorType::get();
   }
   std::vector<TypePtr> unshaped_contained_types;
diff --git a/torch/csrc/jit/passes/specialize_autogradzero.cpp b/torch/csrc/jit/passes/specialize_autogradzero.cpp
index d7b31cae851a..ae0b7175f9d4 100644
--- a/torch/csrc/jit/passes/specialize_autogradzero.cpp
+++ b/torch/csrc/jit/passes/specialize_autogradzero.cpp
@@ -142,8 +142,8 @@ struct AutogradZeroSpecializer {
        state_[input] = State::Unknown;
      }
    } else if (
-        tp->isSubtypeOf(TensorType::get()) ||
-        tp->isSubtypeOf(ListType::ofTensors())) {
+        tp->isSubtypeOf(*TensorType::get()) ||
+        tp->isSubtypeOf(*ListType::ofTensors())) {
      state_[input] = State::Nonzero;
    } else {
      state_[input] = State::Unknown;
diff --git a/torch/csrc/jit/passes/symbolic_shape_analysis.cpp b/torch/csrc/jit/passes/symbolic_shape_analysis.cpp
index e88d653e96ed..9c131df545ea 100644
--- a/torch/csrc/jit/passes/symbolic_shape_analysis.cpp
+++ b/torch/csrc/jit/passes/symbolic_shape_analysis.cpp
@@ -189,7 +189,7 @@ struct SymbolicShapeAnalyzer {
            graph_->inputs().at(i)->type()->cast<OptionalType>()) {
      // None will get handled with constant substitution later
      if (!type->cast<OptionalType>() &&
-          !NoneType::get()->isSubtypeOf(type)) {
+          !NoneType::get()->isSubtypeOf(*type)) {
        graph_->inputs().at(i)->setType(opt_type->getElementType());
      }
    } else if (graph_->inputs().at(i)->type()->cast()) {
diff --git a/torch/csrc/jit/passes/tensorexpr_fuser.cpp b/torch/csrc/jit/passes/tensorexpr_fuser.cpp
index b505ce836d89..30988cb7f72f 100644
--- a/torch/csrc/jit/passes/tensorexpr_fuser.cpp
+++ b/torch/csrc/jit/passes/tensorexpr_fuser.cpp
@@ -485,7 +485,7 @@ class TensorExprFuser {
    auto sinputs = subgraph->inputs();
    AT_ASSERT(inputs.size() == sinputs.size());
    for (const auto i : c10::irange(inputs.size())) {
-      if (inputs[i]->type()->isSubtypeOf(TensorType::get())) {
+      if (inputs[i]->type()->isSubtypeOf(*TensorType::get())) {
        Value* soutput = graph->insert(aten::size, {inputs[i]});
        aliasDb_->createValue(soutput);
        GRAPH_DEBUG(
@@ -542,7 +542,7 @@ class TensorExprFuser {
    }

    auto tensor_inputs = filter(n->inputs(), [](Value* v) {
-      return v->type()->isSubtypeOf(TensorType::get());
+      return v->type()->isSubtypeOf(*TensorType::get());
    });
    GRAPH_DEBUG("Building sizes for ", getHeader(n));
    bool all_inputs_have_sizes = true;
diff --git a/torch/csrc/jit/passes/utils/check_alias_annotation.cpp b/torch/csrc/jit/passes/utils/check_alias_annotation.cpp
index 28b70b0d137c..0d5dde95e5aa 100644
--- a/torch/csrc/jit/passes/utils/check_alias_annotation.cpp
+++ b/torch/csrc/jit/passes/utils/check_alias_annotation.cpp
@@ -209,7 +209,7 @@ c10::optional<IValue> toIValueProp(const Value* v) {
  } else if (containedType == FloatType::get()) {
    return IValue(
        fmap(genericList, [](const IValue& v) { return v.toDouble(); }));
-  } else if (containedType->isSubtypeOf(TensorType::get())) {
+  } else if (containedType->isSubtypeOf(*TensorType::get())) {
    return IValue(
        fmap(genericList, [](const IValue& v) { return v.toTensor(); }));
  } else {
diff --git a/torch/csrc/jit/python/pybind_utils.cpp b/torch/csrc/jit/python/pybind_utils.cpp
index f8fae19ed8f5..673b139bd759 100644
--- a/torch/csrc/jit/python/pybind_utils.cpp
+++ b/torch/csrc/jit/python/pybind_utils.cpp
@@ -270,7 +270,7 @@ IValue toIValue(py::handle obj, const TypePtr& type, c10::optional<int32_t> N) {
      }
      // check if the classType conform with the interface or not
      std::stringstream why_not;
-      if (!classType->isSubtypeOfExt(interfaceType, &why_not)) {
+      if (!classType->isSubtypeOfExt(*interfaceType, &why_not)) {
        throw py::cast_error(c10::str(
            "Object of type ",
            classType->repr_str(),
diff --git a/torch/csrc/jit/python/pybind_utils.h b/torch/csrc/jit/python/pybind_utils.h
index eff1ddc24399..ba8b2dc0b0a5 100644
--- a/torch/csrc/jit/python/pybind_utils.h
+++ b/torch/csrc/jit/python/pybind_utils.h
@@ -510,7 +510,7 @@ inline InferredType tryToInferContainerType(py::handle input) {
 }

 inline bool isTraceableType(const TypePtr& type) {
-  if (type->isSubtypeOf(TensorType::get())) {
+  if (type->isSubtypeOf(*TensorType::get())) {
     return true;
   }
diff --git a/torch/csrc/jit/python/python_sugared_value.cpp b/torch/csrc/jit/python/python_sugared_value.cpp
index f07b57504924..e1a9e9ce9acd 100644
--- a/torch/csrc/jit/python/python_sugared_value.cpp
+++ b/torch/csrc/jit/python/python_sugared_value.cpp
@@ -1010,7 +1010,7 @@ std::shared_ptr<SugaredValue> PythonSliceClass::call(
   Graph& graph = *(caller.graph());

   auto ValOr = [&](Value* given, int64_t default_val) {
-    if (!given || given->type()->isSubtypeOf(NoneType::get())) {
+    if (!given || given->type()->isSubtypeOf(*NoneType::get())) {
       return graph.insertConstant(default_val, loc);
     }
     return given;
diff --git a/torch/csrc/jit/runtime/argument_spec.cpp b/torch/csrc/jit/runtime/argument_spec.cpp
index b074921a187a..48db6ac39787 100644
--- a/torch/csrc/jit/runtime/argument_spec.cpp
+++ b/torch/csrc/jit/runtime/argument_spec.cpp
@@ -29,10 +29,10 @@ void ArgumentSpecCreator::scan(
   if (depth >= ARG_SPEC_DEPTH_LIMIT) {
     instructions_.emplace_back(SKIP);
   }
-  if (typ->isSubtypeOf(TensorType::get())) {
+  if (typ->isSubtypeOf(*TensorType::get())) {
     num_tensors_++;
     instructions_.emplace_back(SPECIALIZE_TENSOR);
-  } else if (typ->isSubtypeOf(OptionalType::ofTensor())) {
+  } else if (typ->isSubtypeOf(*OptionalType::ofTensor())) {
     num_tensors_++;
     num_optionals_++;
     instructions_.emplace_back(SPECIALIZE_OPTIONAL_TENSOR);
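
// ---------------------------------------------------------------------------
// Not part of the patch: a sketch of how ArgumentSpecCreator::scan above
// classifies graph inputs - Tensor inputs get SPECIALIZE_TENSOR and
// Optional[Tensor] inputs get SPECIALIZE_OPTIONAL_TENSOR. The enum and
// function name are hypothetical; assumes <ATen/core/jit_type.h>.
#include <ATen/core/jit_type.h>

enum class SpecKind { Tensor, OptionalTensor, Other }; // hypothetical

SpecKind classifyInput(const c10::TypePtr& typ) {
  if (typ->isSubtypeOf(*c10::TensorType::get())) {
    return SpecKind::Tensor; // counted in num_tensors_
  }
  if (typ->isSubtypeOf(*c10::OptionalType::ofTensor())) {
    return SpecKind::OptionalTensor; // counted in num_tensors_ and num_optionals_
  }
  return SpecKind::Other;
}
// ---------------------------------------------------------------------------
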
diff --git a/torch/csrc/jit/runtime/autodiff.cpp b/torch/csrc/jit/runtime/autodiff.cpp
index 3dffcc7f3cf5..e0caf304febd 100644
--- a/torch/csrc/jit/runtime/autodiff.cpp
+++ b/torch/csrc/jit/runtime/autodiff.cpp
@@ -218,7 +218,7 @@ class GradientHelper {
       // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
       Value* input_list;
       if (grad_values.size() == 1 &&
-          grad_values[0]->type()->isSubtypeOf(ListType::ofTensors())) {
+          grad_values[0]->type()->isSubtypeOf(*ListType::ofTensors())) {
         input_list = grad_values[0];
       } else {
         input_list =
diff --git a/torch/csrc/jit/runtime/profiling_record.cpp b/torch/csrc/jit/runtime/profiling_record.cpp
index 400b54eb2c70..94c4b00a0421 100644
--- a/torch/csrc/jit/runtime/profiling_record.cpp
+++ b/torch/csrc/jit/runtime/profiling_record.cpp
@@ -108,7 +108,7 @@ ProfileIValueOp* ProfilingRecord::createProfileIValueNode(Value* in_val) {

 static void unprofileGraphInputs(const std::shared_ptr<Graph>& graph) {
   for (auto i : graph->inputs()) {
-    if (i->type()->isSubtypeOf(TensorType::get())) {
+    if (i->type()->isSubtypeOf(*TensorType::get())) {
       i->setType(unshapedType(i->type()));
     }
   }
@@ -124,7 +124,7 @@ static void unprofileBlock(Block* start_block) {

     for (auto n : block->nodes()) {
       for (auto o : n->outputs()) {
-        if (o->type()->isSubtypeOf(TensorType::get())) {
+        if (o->type()->isSubtypeOf(*TensorType::get())) {
           o->setType(unshapedType(o->type()));
         }
       }
@@ -298,7 +298,7 @@ void ProfilingRecord::instrumentBlock(Block* block) {
   for (size_t offset = 0; offset < block->return_node()->inputs().size();
        offset++) {
     auto i = block->return_node()->input(offset);
-    if (i->type()->isSubtypeOf(TensorType::get())) {
+    if (i->type()->isSubtypeOf(*TensorType::get())) {
       insertShapeProfile(block->return_node(), offset);
     }
   }
diff --git a/torch/csrc/jit/runtime/register_special_ops.cpp b/torch/csrc/jit/runtime/register_special_ops.cpp
index 015d607044dd..719c4c36a2b5 100644
--- a/torch/csrc/jit/runtime/register_special_ops.cpp
+++ b/torch/csrc/jit/runtime/register_special_ops.cpp
@@ -32,13 +32,13 @@ c10::AliasAnalysisKind aliasAnalysisConservative() {
 }

 void checkListInputType(const c10::TypePtr& elem_type, bool empty_list) {
-  if (!elem_type->isSubtypeOf(NumberType::get()) &&
+  if (!elem_type->isSubtypeOf(*NumberType::get()) &&
       elem_type != BoolType::get()) {
     std::stringstream error;
     error << "Input must be of ints, floats, or bools, "
           << "got " << elem_type->repr_str();
     // special case empty list torch.tensor([])
-    if (elem_type->isSubtypeOf(TensorType::get())) {
+    if (elem_type->isSubtypeOf(*TensorType::get())) {
       if (empty_list) {
         error << "\nEmpty lists default to List[Tensor]. Add a variable "
                  "annotation to the assignment to create an empty list "
diff --git a/torch/csrc/jit/runtime/static/native_ops.cpp b/torch/csrc/jit/runtime/static/native_ops.cpp
index 0ded27c5c465..097cf7883b12 100644
--- a/torch/csrc/jit/runtime/static/native_ops.cpp
+++ b/torch/csrc/jit/runtime/static/native_ops.cpp
@@ -432,7 +432,7 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR(
       auto* node = p_node->node();
       std::vector<TypePtr> candidates = node->tys(attr::types);
       for (const auto& candidate_type : candidates) {
-        if (input_type->isSubtypeOf(candidate_type)) {
+        if (input_type->isSubtypeOf(*candidate_type)) {
           p_node->Output(0) = true;
           return;
         }
diff --git a/torch/csrc/jit/runtime/vararg_functions.cpp b/torch/csrc/jit/runtime/vararg_functions.cpp
index e9df69845a2c..b7d442648dde 100644
--- a/torch/csrc/jit/runtime/vararg_functions.cpp
+++ b/torch/csrc/jit/runtime/vararg_functions.cpp
@@ -341,7 +341,7 @@ void createObject(

 void isinstance(Stack& stack, at::ArrayRef<at::TypePtr> types) {
   at::TypePtr ty = pop(stack).type();
   for (const at::TypePtr& candidate : types) {
-    if (ty->isSubtypeOf(candidate)) {
+    if (ty->isSubtypeOf(*candidate)) {
       push(stack, true);
       return;
     }
diff --git a/torch/csrc/jit/serialization/pickler.cpp b/torch/csrc/jit/serialization/pickler.cpp
index f465eaf4dff0..03846b4b3de7 100644
--- a/torch/csrc/jit/serialization/pickler.cpp
+++ b/torch/csrc/jit/serialization/pickler.cpp
@@ -724,7 +724,7 @@ bool checkHasValidSetGetState(const std::shared_ptr<ClassType>& cls) {
       set_schema.returns().size(),
       " return values");
   TORCH_CHECK(
-      set_schema.returns().at(0).type()->isSubtypeOf(NoneType::get()),
+      set_schema.returns().at(0).type()->isSubtypeOf(*NoneType::get()),
       "'__setstate__' must return None, but found value of type",
       set_schema.returns().at(0).type()->annotation_str());

@@ -734,7 +734,7 @@ bool checkHasValidSetGetState(const std::shared_ptr<ClassType>& cls) {
   auto set_type = set_schema.arguments().at(1).type();
   TORCH_CHECK(
-      get_type->isSubtypeOf(set_type),
+      get_type->isSubtypeOf(*set_type),
       "'__getstate__'s return type (",
       get_type->annotation_str(),
       ") does not match '__setstate__'s argument type (",
diff --git a/torch/csrc/jit/serialization/python_print.cpp b/torch/csrc/jit/serialization/python_print.cpp
index 6b1bf1530462..73b9d17bb107 100644
--- a/torch/csrc/jit/serialization/python_print.cpp
+++ b/torch/csrc/jit/serialization/python_print.cpp
@@ -1117,7 +1117,7 @@ struct PythonPrintImpl {
       // we cannot recover the type of unwrap_optional(None),
       // using normal schema matching, so we route around this by rewriting
       // the call to unwrap_optional(annotated(Optional[T], None))
-      if (node->input()->type()->isSubtypeOf(NoneType::get()) ||
+      if (node->input()->type()->isSubtypeOf(*NoneType::get()) ||
           node->input()->mustBeNone()) {
         auto input_type = OptionalType::create(node->output()->type());
         stmt << "annotate(" << input_type->annotation_str(type_printer_)
diff --git a/torch/csrc/jit/tensorexpr/kernel.cpp b/torch/csrc/jit/tensorexpr/kernel.cpp
index 11d48eff796b..0e810ce3bb8a 100644
--- a/torch/csrc/jit/tensorexpr/kernel.cpp
+++ b/torch/csrc/jit/tensorexpr/kernel.cpp
@@ -377,11 +377,11 @@ std::vector<ExprHandle> TensorExprKernel::sizesForValue(
     }
   }

-  if (v->type()->isSubtypeOf(FloatType::get()) ||
-      v->type()->isSubtypeOf(IntType::get())) {
+  if (v->type()->isSubtypeOf(*FloatType::get()) ||
+      v->type()->isSubtypeOf(*IntType::get())) {
     return {int64_t{1}};
   }
-  if (v->type()->isSubtypeOf(NoneType::get())) {
+  if (v->type()->isSubtypeOf(*NoneType::get())) {
     return {};
   }
diff --git a/torch/custom_class.h b/torch/custom_class.h
index f39695a89881..f1aed8f9524b 100644
--- a/torch/custom_class.h
+++ b/torch/custom_class.h
@@ -344,7 +344,7 @@ class class_ : public ::torch::detail::class_base {
     auto setstate_schema = classTypePtr->getMethod("__setstate__").getSchema();
     auto arg_type = setstate_schema.arguments().at(1).type();
     TORCH_CHECK(
-        ser_type->isSubtypeOf(arg_type),
+        ser_type->isSubtypeOf(*arg_type),
         "__getstate__'s return type should be a subtype of "
         "input argument of __setstate__. Got ",
         ser_type->repr_str(),