[jit] Reduce refcounting of Types (#65345)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65345

FooType::get() can return a const reference. Inconveniently, converting shared_ptr<FooType> to shared_ptr<Type> requires a copy and a refcount bump, so to take full advantage of this in unshapedType() we need isSubtypeOf() to take a const Type&. That is good practice anyway: don't require a shared_ptr if you don't need to take ownership.
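
A minimal standalone sketch of the cost being avoided (simplified names, not the real c10 hierarchy): binding a shared_ptr<TensorType> to a const shared_ptr<Type>& parameter materializes a temporary shared_ptr<Type>, a pointer copy plus an atomic refcount increment, whereas a const Type& parameter is just a dereference.

#include <memory>

struct Type {
  virtual ~Type() = default;
};

struct TensorType : Type {
  // Returning a const reference to a function-local static avoids a
  // shared_ptr copy (and refcount bump) at every call site.
  static const std::shared_ptr<TensorType>& get() {
    static std::shared_ptr<TensorType> value(new TensorType());
    return value;
  }
};

// Old-style signature: the call below must construct a temporary
// shared_ptr<Type> from shared_ptr<TensorType>, costing one atomic
// increment and one decrement per call.
bool isSubtypeOfOld(const std::shared_ptr<Type>& rhs) { return rhs != nullptr; }

// New-style signature: no ownership taken, no refcount traffic.
bool isSubtypeOfNew(const Type& rhs) { (void)rhs; return true; }

int main() {
  isSubtypeOfOld(TensorType::get());  // refcount bump on the conversion
  isSubtypeOfNew(*TensorType::get()); // just a dereference
}
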
ghstack-source-id: 140044165

Test Plan:
CI

perf shows that time spent in c10::unshapedType during static runtime startup decreased from 2.8% to 2.2%, though I expect this change to be generally beneficial beyond that workload.

Reviewed By: hlu1

Differential Revision: D31027361

fbshipit-source-id: 676feb81db9f74ad7b8651d8774f4ecb4cfa6ab8
Author: Scott Wolchok
Date: 2021-10-08 09:01:42 -07:00
Committed by: Facebook GitHub Bot
Parent: 1ae468a484
Commit: 2d885ab73d
69 changed files with 421 additions and 405 deletions

View File

@ -36,7 +36,7 @@ static bool areAnyArgumentsTensorList(const FunctionSchema& schema) {
return std::any_of(
schema.arguments().begin(),
schema.arguments().end(),
[] (const Argument& arg) { return arg.type()->isSubtypeOf(ListType::ofTensors()); });
[] (const Argument& arg) { return arg.type()->isSubtypeOf(*ListType::ofTensors()); });
}
// Returns if an operator is in-place. An operator is inplace if:

View File

@ -63,7 +63,7 @@ List<T> toTypedList(impl::GenericList list) {
// as List<Tensor> before we changed that argument to be List<optional<Tensor>>. When deserializing, we
// have list.use_count() == 1 and can deserialize the List<Tensor> directly as List<optional<Tensor>>.
TORCH_CHECK(*list.impl_->elementType == *getTypePtr<T>()
|| (list.use_count() == 1 && list.impl_->elementType->isSubtypeOf(getTypePtr<T>()))
|| (list.use_count() == 1 && list.impl_->elementType->isSubtypeOf(*getTypePtr<T>()))
, "Tried to cast a List<", toString(list.impl_->elementType), "> to a List<", toString(getTypePtr<T>()), ">. Types mismatch.");
return List<T>(std::move(list.impl_));
}

View File

@ -172,13 +172,13 @@ private:
" arguments but this PyTorch build only supports ", c10::utils::bitset::NUM_BITS());
c10::utils::bitset dispatch_arg_indices_reverse;
for (size_t index = 0; index < schema.arguments().size(); ++index) {
if (schema.arguments()[index].type()->isSubtypeOf(TensorType::get()) ||
if (schema.arguments()[index].type()->isSubtypeOf(*TensorType::get()) ||
schema.arguments()[index].type()->isSubtypeOf(
ListType::ofTensors()) ||
*ListType::ofTensors()) ||
schema.arguments()[index].type()->isSubtypeOf(
ListType::ofOptionalTensors()) ||
*ListType::ofOptionalTensors()) ||
schema.arguments()[index].type()->isSubtypeOf(
OptionalType::ofTensor())) {
*OptionalType::ofTensor())) {
dispatch_arg_indices_reverse.set(schema.arguments().size() - 1 - index);
}
}

View File

@ -76,7 +76,7 @@ inline bool Argument::isBackwardCompatibleWith(
if (lhs->kwarg_only() && !rhs->kwarg_only()) {
return false;
}
if (!rhs->type()->isSubtypeOfExt(lhs->type(), why_not)) {
if (!rhs->type()->isSubtypeOfExt(*lhs->type(), why_not)) {
return false;
}
if (rhs->default_value().has_value() &&
@ -179,7 +179,7 @@ inline void FunctionSchema::checkArg(
// Fast-path for the common case
return;
}
if (!value.type()->isSubtypeOf(argument.type())) {
if (!value.type()->isSubtypeOf(*argument.type())) {
TORCH_CHECK(
false,
formatTypeMismatchMsg(
@ -304,7 +304,7 @@ inline bool isSubtypeOfList(
if (c.name() != p.name()) {
return false;
}
if (!c.type()->isSubtypeOfExt(p.type(), why_not)) {
if (!c.type()->isSubtypeOfExt(*p.type(), why_not)) {
return false;
}
}

View File

@ -46,7 +46,7 @@ struct TORCH_API AnyType : public Type {
}
static const TypeKind Kind = TypeKind::AnyType;
// global singleton
static AnyTypePtr get();
static const AnyTypePtr& get();
private:
AnyType() : Type(TypeKind::AnyType) {}
@ -66,7 +66,7 @@ template <TypeKind K, typename T>
struct SingleElementType : public Type {
static const TypeKind Kind = K;
TypePtr getElementType() const {
const TypePtr& getElementType() const {
return elem;
}
@ -104,7 +104,7 @@ struct TORCH_API UnionType : public Type {
static const TypeKind Kind = TypeKind::UnionType;
bool isSubtypeOfExt(const TypePtr& rhs_, std::ostream* why_not) const override;
bool isSubtypeOfExt(const Type& rhs_, std::ostream* why_not) const override;
std::string str() const override;
@ -169,7 +169,7 @@ struct TORCH_API OptionalType : public UnionType {
bool operator==(const Type& rhs) const override;
TypePtr getElementType() const {
const TypePtr& getElementType() const {
return contained_;
}
@ -189,7 +189,7 @@ struct TORCH_API OptionalType : public UnionType {
return create(contained_types[0]);
}
bool isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const override;
bool isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const override;
// common cast Optional[Tensor] for undefined tensor type
static OptionalTypePtr ofTensor();
@ -578,7 +578,7 @@ struct TORCH_API TensorType : public Type {
}
bool operator==(const Type& rhs) const override;
bool isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const override;
bool isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const override;
std::string str() const override;
@ -710,7 +710,7 @@ struct TORCH_API TensorType : public Type {
c10::optional<bool> undefined() const { return undefined_; }
static TensorTypePtr get();
static const TensorTypePtr& get();
static const TypeKind Kind = TypeKind::TensorType;
@ -788,7 +788,7 @@ struct TORCH_API ListType
return create(contained_types.at(0));
}
bool isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const override;
bool isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const override;
// common cast List[Tensor]
static ListTypePtr ofTensors();
@ -914,12 +914,12 @@ struct TORCH_API FutureType
return create(contained_types.at(0));
}
bool isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const override {
bool isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const override {
if (Type::isSubtypeOfExt(rhs, why_not)) {
return true;
}
if (auto rhs_ = rhs->cast<FutureType>()) {
return getElementType()->isSubtypeOfExt(rhs_->getElementType(), why_not);
if (auto rhs_ = rhs.castRaw<FutureType>()) {
return getElementType()->isSubtypeOfExt(*rhs_->getElementType(), why_not);
}
return false;
}
@ -1034,7 +1034,7 @@ struct TORCH_API TupleType : public NamedType {
}
bool operator==(const Type& rhs) const override;
bool isSubtypeOfExt(const TypePtr& rhs_, std::ostream* why_not) const override;
bool isSubtypeOfExt(const Type& rhs_, std::ostream* why_not) const override;
std::string str() const override;
bool hasFreeVariables() const override {
@ -1129,7 +1129,7 @@ struct TORCH_API EnumType : public NamedType {
return false;
}
bool isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const override;
bool isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const override;
std::shared_ptr<const ::torch::jit::CompilationUnit> compilation_unit() const {
auto cu = cu_.lock();
@ -1182,7 +1182,7 @@ struct TORCH_API AnyEnumType : public Type {
}
static const TypeKind Kind = TypeKind::AnyEnumType;
// global singleton
static AnyEnumTypePtr get();
static const AnyEnumTypePtr& get();
private:
AnyEnumType()
: Type(TypeKind::AnyEnumType) {}
@ -1198,14 +1198,14 @@ using NumberTypePtr = std::shared_ptr<NumberType>;
struct TORCH_API NumberType : public Type {
bool operator==(const Type& rhs) const override;
bool isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const override;
bool isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const override;
std::string str() const override {
return "Scalar"; // match what PythonArgParser says for clarity
}
static const TypeKind Kind = TypeKind::NumberType;
// global singleton
static NumberTypePtr get();
static const NumberTypePtr& get();
protected:
NumberType(TypeKind kind = TypeKind::NumberType) : Type(kind) {}
@ -1227,13 +1227,13 @@ struct TORCH_API FloatType : public NumberType {
std::string str() const override {
return "float";
}
bool isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const override {
bool isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const override {
// NOLINTNEXTLINE(bugprone-parent-virtual-call)
return rhs->kind() == TypeKind::NumberType || Type::isSubtypeOfExt(rhs, why_not);
return rhs.kind() == TypeKind::NumberType || Type::isSubtypeOfExt(rhs, why_not);
}
static const TypeKind Kind = TypeKind::FloatType;
// global singleton
static FloatTypePtr get();
static const FloatTypePtr& get();
private:
FloatType() : NumberType(TypeKind::FloatType) {}
@ -1252,13 +1252,13 @@ struct TORCH_API ComplexType : public NumberType {
std::string str() const override {
return "complex";
}
bool isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const override {
bool isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const override {
// NOLINTNEXTLINE(bugprone-parent-virtual-call)
return rhs->kind() == TypeKind::NumberType || Type::isSubtypeOfExt(rhs, why_not);
return rhs.kind() == TypeKind::NumberType || Type::isSubtypeOfExt(rhs, why_not);
}
static const TypeKind Kind = TypeKind::ComplexType;
// global singleton
static ComplexTypePtr get();
static const ComplexTypePtr& get();
private:
ComplexType() : NumberType(TypeKind::ComplexType) {}
@ -1277,13 +1277,13 @@ struct TORCH_API IntType : public NumberType {
std::string str() const override {
return "int";
}
bool isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const override {
bool isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const override {
// NOLINTNEXTLINE(bugprone-parent-virtual-call)
return rhs->kind() == TypeKind::NumberType || Type::isSubtypeOfExt(rhs, why_not);
return rhs.kind() == TypeKind::NumberType || Type::isSubtypeOfExt(rhs, why_not);
}
static const TypeKind Kind = TypeKind::IntType;
// global singleton
static IntTypePtr get();
static const IntTypePtr& get();
private:
IntType() : NumberType(TypeKind::IntType) {}
@ -1304,7 +1304,7 @@ struct TORCH_API BoolType : public Type {
}
static const TypeKind Kind = TypeKind::BoolType;
// global singleton
static BoolTypePtr get();
static const BoolTypePtr& get();
private:
BoolType() : Type(TypeKind::BoolType) {}
@ -1326,7 +1326,7 @@ struct TORCH_API StringType : public Type {
}
static const TypeKind Kind = TypeKind::StringType;
// global singleton
static StringTypePtr get();
static const StringTypePtr& get();
private:
StringType() : Type(TypeKind::StringType) {}
@ -1346,7 +1346,7 @@ struct TORCH_API StorageType : public Type {
}
static const TypeKind Kind = TypeKind::StorageType;
// global singleton
static StorageTypePtr get();
static const StorageTypePtr& get();
private:
StorageType() : Type(TypeKind::StorageType) {}
@ -1393,11 +1393,11 @@ struct TORCH_API NoneType : public Type {
std::string str() const override {
return "NoneType";
}
bool isSubtypeOfExt(const TypePtr& rhs, std::ostream *why_not) const override;
bool isSubtypeOfExt(const Type& rhs, std::ostream *why_not) const override;
static const TypeKind Kind = TypeKind::NoneType;
// global singleton
static NoneTypePtr get();
static const NoneTypePtr& get();
private:
NoneType() : Type(TypeKind::NoneType) {}
@ -1415,7 +1415,7 @@ struct TORCH_API GeneratorType : public Type {
}
static const TypeKind Kind = TypeKind::GeneratorType;
// global singleton
static GeneratorTypePtr get();
static const GeneratorTypePtr& get();
private:
GeneratorType() : Type(TypeKind::GeneratorType) {}
@ -1433,7 +1433,7 @@ struct TORCH_API QuantizerType : public Type {
}
static const TypeKind Kind = TypeKind::QuantizerType;
// global singleton
static QuantizerTypePtr get();
static const QuantizerTypePtr& get();
private:
QuantizerType() : Type(TypeKind::QuantizerType) {}
@ -1451,7 +1451,7 @@ struct TORCH_API QSchemeType : public Type {
}
static const TypeKind Kind = TypeKind::QSchemeType;
// global singleton
static QSchemeTypePtr get();
static const QSchemeTypePtr& get();
private:
QSchemeType() : Type(TypeKind::QSchemeType) {}
@ -1469,7 +1469,7 @@ struct TORCH_API DeviceObjType : public Type {
}
static const TypeKind Kind = TypeKind::DeviceObjType;
// global singleton
static DeviceObjTypePtr get();
static const DeviceObjTypePtr& get();
private:
DeviceObjType() : Type(TypeKind::DeviceObjType) {}
@ -1487,7 +1487,7 @@ struct TORCH_API StreamObjType : public Type {
}
static const TypeKind Kind = TypeKind::StreamObjType;
// global singleton
static StreamObjTypePtr get();
static const StreamObjTypePtr& get();
private:
StreamObjType() : Type(TypeKind::StreamObjType) {}
@ -1533,7 +1533,7 @@ struct TORCH_API CapsuleType : public Type {
}
static const TypeKind Kind = TypeKind::CapsuleType;
// global singleton
static CapsuleTypePtr get();
static const CapsuleTypePtr& get();
private:
CapsuleType()
: Type(TypeKind::CapsuleType) {}
@ -1551,7 +1551,7 @@ struct TORCH_API PyObjectType : public Type {
}
static const TypeKind Kind = TypeKind::PyObjectType;
// global singleton
static PyObjectTypePtr get();
static const PyObjectTypePtr& get();
private:
PyObjectType()
: Type(TypeKind::PyObjectType) {}
@ -1589,18 +1589,18 @@ TORCH_API std::ostream& operator<<(std::ostream& os, const Stride& s);
// Be careful with calls because this can be very slow. If calling this
// on a graph, use `EraseShapeInformation` in shape_analysis.h
inline TypePtr unshapedType(const TypePtr& type) {
if (type->isSubtypeOf(TensorType::get())) {
if (type->isSubtypeOf(*TensorType::get())) {
return TensorType::get();
}
return type->withContained(fmap(type->containedTypes(), unshapedType));
}
inline TypePtr TensorType::fromNumberType(TypePtr typ) {
if (typ->isSubtypeOf(IntType::get())) {
if (typ->isSubtypeOf(*IntType::get())) {
return TensorType::createContiguous(at::kLong, at::kCPU, {});
} else if (typ->isSubtypeOf(FloatType::get())) {
} else if (typ->isSubtypeOf(*FloatType::get())) {
return TensorType::createContiguous(at::kDouble, at::kCPU, {});
} else if (typ->isSubtypeOf(BoolType::get())) {
} else if (typ->isSubtypeOf(*BoolType::get())) {
return TensorType::createContiguous(at::kBool, at::kCPU, {});
} else if (typ->kind() == NumberType::Kind) {
return TensorType::create(c10::nullopt, at::kCPU, {}, c10::nullopt);
@ -2013,7 +2013,7 @@ struct TORCH_API ClassType : public NamedType {
return attributes_.size();
}
const TypePtr getAttribute(size_t slot) const {
TypePtr getAttribute(size_t slot) const {
AT_ASSERT(slot < attributes_.size());
return attributes_.at(slot).getType();
}
@ -2104,7 +2104,7 @@ struct TORCH_API ClassType : public NamedType {
"'");
TypePtr atype = getAttribute(*slot_idx);
TORCH_CHECK(
ty->isSubtypeOf(atype),
ty->isSubtypeOf(*atype),
ty->repr_str(),
" is not compatible with the type ",
atype->repr_str(),
@ -2198,7 +2198,7 @@ struct TORCH_API ClassType : public NamedType {
auto ptr = ClassType::create(name(), compilation_unit_, is_module());
AT_ASSERT(numAttributes() == contained_types.size());
for(size_t i = 0; i < attributes_.size(); ++i) {
AT_ASSERT(attributes_[i].getType()->isSubtypeOf(contained_types[i]));
AT_ASSERT(attributes_[i].getType()->isSubtypeOf(*contained_types[i]));
ptr->addAttribute(attributes_[i].getName(), contained_types[i]);
}
// Copy methods over
@ -2273,7 +2273,7 @@ struct TORCH_API ClassType : public NamedType {
// These variants are not registered in the global class table.
ClassTypePtr refine(at::ArrayRef<TypePtr> refined_slots) const;
bool isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const override;
bool isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const override;
static const TypeKind Kind = TypeKind::ClassType;
@ -2360,13 +2360,13 @@ struct TORCH_API InterfaceType : public NamedType {
return std::string("InterfaceType<") + name()->name() + ">";
}
bool isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const override;
bool isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const override;
// try to find a method of this interface,
// returns nullptr if not found.
const FunctionSchema* getMethod(const std::string& name) const;
void addMethod(FunctionSchema schema);
const std::vector<FunctionSchema>& methods() {
const std::vector<FunctionSchema>& methods() const {
return *methods_;
}
@ -2414,7 +2414,7 @@ return "Layout";
}
static const TypeKind Kind = TypeKind::LayoutType;
// global singleton
static LayoutTypePtr get();
static const LayoutTypePtr& get();
private:
LayoutType() : EnumerationType() {}
@ -2429,7 +2429,7 @@ return "ScalarType";
}
static const TypeKind Kind = TypeKind::ScalarTypeType;
// global singleton
static ScalarTypeTypePtr get();
static const ScalarTypeTypePtr& get();
private:
ScalarTypeType() : EnumerationType() {}
@ -2448,7 +2448,7 @@ struct TORCH_API AnyListType : public Type {
}
static const TypeKind Kind = TypeKind::AnyListType;
// global singleton
static AnyListTypePtr get();
static const AnyListTypePtr& get();
private:
AnyListType()
: Type(TypeKind::AnyListType) {}
@ -2469,7 +2469,7 @@ struct TORCH_API AnyTupleType : public Type {
static const TypeKind Kind = TypeKind::AnyTupleType;
// global singleton
static AnyTupleTypePtr get();
static const AnyTupleTypePtr& get();
private:
AnyTupleType()
: Type(TypeKind::AnyTupleType) {}
@ -2488,7 +2488,7 @@ struct TORCH_API AnyClassType : public Type {
}
static const TypeKind Kind = TypeKind::AnyClassType;
// global singleton
static AnyClassTypePtr get();
static const AnyClassTypePtr& get();
private:
AnyClassType()
: Type(TypeKind::AnyClassType) {}

View File

@ -87,11 +87,24 @@ struct TORCH_API Type : std::enable_shared_from_this<Type> {
// This additional information should only contain details that are not obvious
// from the annotation_str() that describes the type. For instance it is clear that `int <: str` is false
// but not clear why `Foo <: InterfaceBar` might be false.
virtual bool isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const;
virtual bool isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const;
virtual bool is_module() const;
bool isSubtypeOf(const TypePtr& rhs) const {
bool isSubtypeOf(const Type& rhs) const {
return isSubtypeOfExt(rhs, nullptr);
}
// Compatibility shims to accommodate existing code that passes shared_ptrs around.
// Ideally, we would just delete this, but it should be harmless.
template <typename T>
typename std::enable_if<std::is_base_of<Type, T>::value, bool>::type
isSubtypeOf(const std::shared_ptr<T>& rhs) const {
return isSubtypeOf(*rhs);
}
template <typename T>
typename std::enable_if<std::is_base_of<Type, T>::value, bool>::type
isSubtypeOfExt(const std::shared_ptr<T>& rhs, std::ostream* why_not) const {
return isSubtypeOfExt(*rhs, why_not);
}
// How this type will appear in FunctionSchema declarations
virtual std::string str() const = 0;
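
A usage sketch of the new overload set above (explainMismatch is a hypothetical helper, and the ATen/core/jit_type.h include is an assumption about where these declarations live): the const Type& overloads are the fast path, the template shims keep existing shared_ptr-passing call sites compiling, and the Ext variant can report why a check failed.

#include <ATen/core/jit_type.h>
#include <sstream>
#include <string>

// Hypothetical helper: returns true if `actual` is a subtype of `expected`,
// otherwise fills `msg` with the explanation collected via why_not.
bool explainMismatch(
    const c10::TypePtr& actual,
    const c10::TypePtr& expected,
    std::string* msg) {
  std::ostringstream why_not;
  // Preferred form: dereference once and pass const Type&; no refcount bump.
  if (actual->isSubtypeOfExt(*expected, &why_not)) {
    return true;
  }
  // The template shim also accepts the shared_ptr directly, so existing
  // call sites like actual->isSubtypeOf(expected) keep compiling.
  *msg = why_not.str();
  return false;
}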

View File

@ -153,74 +153,74 @@ std::ostream& operator<<(std::ostream & out, const Type & t) {
return out;
}
AnyTypePtr AnyType::get() {
const AnyTypePtr& AnyType::get() {
static AnyTypePtr value(new AnyType());
return value;
}
TensorTypePtr TensorType::get() {
const TensorTypePtr& TensorType::get() {
static auto value = TensorType::create(
{}, {}, SymbolicShape(), VaryingShape<Stride>{}, {});
return value;
}
NumberTypePtr NumberType::get() {
const NumberTypePtr& NumberType::get() {
static NumberTypePtr value(new NumberType());
return value;
}
IntTypePtr IntType::get() {
const IntTypePtr& IntType::get() {
static IntTypePtr value(new IntType());
return value;
}
FloatTypePtr FloatType::get() {
const FloatTypePtr& FloatType::get() {
static FloatTypePtr value(new FloatType());
return value;
}
ComplexTypePtr ComplexType::get() {
const ComplexTypePtr& ComplexType::get() {
static ComplexTypePtr value(new ComplexType());
return value;
}
BoolTypePtr BoolType::get() {
const BoolTypePtr& BoolType::get() {
static BoolTypePtr value(new BoolType());
return value;
}
StorageTypePtr StorageType::get() {
const StorageTypePtr& StorageType::get() {
static StorageTypePtr value(new StorageType());
return value;
}
NoneTypePtr NoneType::get() {
const NoneTypePtr& NoneType::get() {
static NoneTypePtr value(new NoneType());
return value;
}
GeneratorTypePtr GeneratorType::get() {
const GeneratorTypePtr& GeneratorType::get() {
static GeneratorTypePtr value(new GeneratorType());
return value;
}
QuantizerTypePtr QuantizerType::get() {
const QuantizerTypePtr& QuantizerType::get() {
static QuantizerTypePtr value(new QuantizerType());
return value;
}
QSchemeTypePtr QSchemeType::get() {
const QSchemeTypePtr& QSchemeType::get() {
static QSchemeTypePtr value(new QSchemeType());
return value;
}
StringTypePtr StringType::get() {
const StringTypePtr& StringType::get() {
static StringTypePtr value(new StringType());
return value;
}
DeviceObjTypePtr DeviceObjType::get() {
const DeviceObjTypePtr& DeviceObjType::get() {
static DeviceObjTypePtr value(new DeviceObjType());
return value;
}
StreamObjTypePtr StreamObjType::get() {
const StreamObjTypePtr& StreamObjType::get() {
static StreamObjTypePtr value(new StreamObjType());
return value;
}
ScalarTypeTypePtr ScalarTypeType::get() {
const ScalarTypeTypePtr& ScalarTypeType::get() {
static ScalarTypeTypePtr value(new ScalarTypeType());
return value;
}
LayoutTypePtr LayoutType::get() {
const LayoutTypePtr& LayoutType::get() {
static LayoutTypePtr value(new LayoutType());
return value;
}
@ -228,11 +228,11 @@ OptionalTypePtr OptionalType::ofTensor() {
static auto value = OptionalType::create(TensorType::get());
return value;
}
PyObjectTypePtr PyObjectType::get() {
const PyObjectTypePtr& PyObjectType::get() {
static PyObjectTypePtr value(new PyObjectType());
return value;
}
CapsuleTypePtr CapsuleType::get() {
const CapsuleTypePtr& CapsuleType::get() {
static CapsuleTypePtr value(new CapsuleType());
return value;
}
@ -265,31 +265,31 @@ ListTypePtr ListType::ofStrings() {
return value;
}
AnyListTypePtr AnyListType::get() {
const AnyListTypePtr& AnyListType::get() {
static AnyListTypePtr value(new AnyListType());
return value;
}
AnyTupleTypePtr AnyTupleType::get() {
const AnyTupleTypePtr& AnyTupleType::get() {
static AnyTupleTypePtr value(new AnyTupleType());
return value;
}
AnyClassTypePtr AnyClassType::get() {
const AnyClassTypePtr& AnyClassType::get() {
static AnyClassTypePtr value(new AnyClassType());
return value;
}
AnyEnumTypePtr AnyEnumType::get() {
const AnyEnumTypePtr& AnyEnumType::get() {
static AnyEnumTypePtr value(new AnyEnumType());
return value;
}
c10::optional<TypePtr> unifyTypesImpl(const TypePtr& t1, const TypePtr& t2, bool default_to_union=false, TypePtr type_hint=nullptr) {
// check direct subtyping relation
if (t1->isSubtypeOf(t2)) {
if (t1->isSubtypeOf(*t2)) {
return t2;
} else if (t2->isSubtypeOf(t1)) {
} else if (t2->isSubtypeOf(*t1)) {
return t1;
}
@ -298,9 +298,9 @@ c10::optional<TypePtr> unifyTypesImpl(const TypePtr& t1, const TypePtr& t2, bool
return t1->expectRef<TensorType>().merge(*t2->expect<TensorType>());
}
if (t1->isSubtypeOf(NoneType::get()) && !t2->isSubtypeOf(NoneType::get())) {
if (t1->isSubtypeOf(*NoneType::get()) && !t2->isSubtypeOf(*NoneType::get())) {
return OptionalType::create(t2);
} else if (t2->isSubtypeOf(NoneType::get()) && !t1->isSubtypeOf(NoneType::get())) {
} else if (t2->isSubtypeOf(*NoneType::get()) && !t1->isSubtypeOf(*NoneType::get())) {
return OptionalType::create(t1);
}
@ -351,16 +351,16 @@ c10::optional<TypePtr> unifyTypesImpl(const TypePtr& t1, const TypePtr& t2, bool
auto t1_unshaped = unshapedType(t1);
auto t2_unshaped = unshapedType(t2);
if (t1_unshaped->isSubtypeOf(t2_unshaped)) {
if (t1_unshaped->isSubtypeOf(*t2_unshaped)) {
return t2_unshaped;
} else if (t2_unshaped->isSubtypeOf(t1_unshaped)) {
} else if (t2_unshaped->isSubtypeOf(*t1_unshaped)) {
return t1_unshaped;
}
// Check whether or not `type_hint` is a common parent. This case
// could occur if we had two class types that had been annotated with
// a common interface
if (type_hint && t1->isSubtypeOf(type_hint) && t2->isSubtypeOf(type_hint)) {
if (type_hint && t1->isSubtypeOf(*type_hint) && t2->isSubtypeOf(*type_hint)) {
return type_hint;
}
@ -505,7 +505,7 @@ MatchTypeReturn matchTypeVariables(
// NOLINTNEXTLINE(performance-no-automatic-move)
return optionedMatch;
}
} else if (!actual->isSubtypeOf(NoneType::get())) {
} else if (!actual->isSubtypeOf(*NoneType::get())) {
// If the actual type is a non-optional, allow matching to the formal if
// its element type matches the actual.
// Don't match None because it is already an optional (but one of
@ -594,19 +594,19 @@ const char * typeKindToString(TypeKind kind) {
return "";
}
bool Type::isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const {
if (rhs->kind() == TypeKind::AnyType || *this == *rhs) {
bool Type::isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const {
if (rhs.kind() == TypeKind::AnyType || *this == rhs) {
return true;
}
if (auto opt_rhs = rhs->cast<OptionalType>()) {
return this->isSubtypeOfExt(opt_rhs->getElementType(), why_not);
if (auto opt_rhs = rhs.castRaw<OptionalType>()) {
return this->isSubtypeOfExt(*opt_rhs->getElementType(), why_not);
}
if (auto union_rhs = rhs->cast<UnionType>()) {
if (auto union_rhs = rhs.castRaw<UnionType>()) {
// Check if `this` is a subtype of any of the types within the Union
return std::any_of(union_rhs->containedTypes().begin(),
union_rhs->containedTypes().end(),
[&](TypePtr inner) {
return this->isSubtypeOfExt(inner, why_not);
[&](const TypePtr& inner) {
return this->isSubtypeOfExt(*inner, why_not);
});
}
return false;
@ -837,8 +837,8 @@ TupleTypePtr TupleType::createNamed(const c10::optional<c10::QualifiedName>& qua
field_types, qualName, schema)); // NOLINT(modernize-make-shared)
}
bool NoneType::isSubtypeOfExt(const TypePtr& rhs, std::ostream *why_not) const {
if (rhs->kind() == OptionalType::Kind) {
bool NoneType::isSubtypeOfExt(const Type& rhs, std::ostream *why_not) const {
if (rhs.kind() == OptionalType::Kind) {
return true;
}
return Type::isSubtypeOfExt(rhs, why_not);
@ -882,8 +882,8 @@ void filterDuplicateSubtypes(std::vector<TypePtr>* types) {
auto get_supertype = [](const TypePtr t1, const TypePtr t2) -> c10::optional<TypePtr> {
// We don't want nested Optionals. Also, prematurely unifying to
// `Optional` could prevent us from coalescing other types
if ((t1->isSubtypeOf(NoneType::get()) && !t2->isSubtypeOf(NoneType::get()))
|| (!t1->isSubtypeOf(NoneType::get()) && t2->isSubtypeOf(NoneType::get()))) {
if ((t1->isSubtypeOf(*NoneType::get()) && !t2->isSubtypeOf(*NoneType::get()))
|| (!t1->isSubtypeOf(*NoneType::get()) && t2->isSubtypeOf(*NoneType::get()))) {
return c10::nullopt;
} else {
return unifyTypes(t1, t2, /*default_to_union=*/false);
@ -1064,34 +1064,36 @@ bool UnionType::operator==(const Type& rhs) const {
}
}
bool UnionType::isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const {
std::vector<TypePtr> rhs_types;
if (const auto union_rhs = rhs->cast<UnionType>()) {
bool UnionType::isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const {
std::vector<const Type*> rhs_types;
if (const auto union_rhs = rhs.cast<UnionType>()) {
// Fast path
if (this->containedTypes() == rhs->containedTypes()) {
if (this->containedTypes() == rhs.containedTypes()) {
return true;
}
rhs_types = rhs->containedTypes().vec();
} else if (const auto optional_rhs = rhs->cast<OptionalType>()) {
rhs_types.push_back(NoneType::get());
for (const auto& typePtr: rhs.containedTypes()) {
rhs_types.push_back(typePtr.get());
}
} else if (const auto optional_rhs = rhs.cast<OptionalType>()) {
rhs_types.push_back(NoneType::get().get());
if (optional_rhs->getElementType() == NumberType::get()) {
std::vector<TypePtr> number_types{IntType::get(), FloatType::get(), ComplexType::get()};
std::array<const Type*, 3> number_types{IntType::get().get(), FloatType::get().get(), ComplexType::get().get()};
rhs_types.insert(rhs_types.end(), number_types.begin(), number_types.end());
} else {
rhs_types.push_back(optional_rhs->getElementType());
rhs_types.push_back(optional_rhs->getElementType().get());
}
} else if (const auto number_rhs = rhs->cast<NumberType>()) {
std::vector<TypePtr> number_types{IntType::get(), FloatType::get(), ComplexType::get()};
} else if (const auto number_rhs = rhs.cast<NumberType>()) {
std::array<const Type*, 3> number_types{IntType::get().get(), FloatType::get().get(), ComplexType::get().get()};
rhs_types.insert(rhs_types.end(), number_types.begin(), number_types.end());
} else {
rhs_types.push_back(rhs);
rhs_types.push_back(&rhs);
}
return std::all_of(this->containedTypes().begin(), this->containedTypes().end(),
[&](TypePtr lhs_type) -> bool {
[&](const TypePtr& lhs_type) -> bool {
return std::any_of(rhs_types.begin(),
rhs_types.end(),
[&](TypePtr rhs_type) -> bool {
return lhs_type->isSubtypeOfExt(rhs_type, why_not);
[&](const Type* rhs_type) -> bool {
return lhs_type->isSubtypeOfExt(*rhs_type, why_not);
});
});
}
@ -1157,8 +1159,8 @@ bool UnionType::canHoldType(TypePtr type) const {
&& canHoldType(ComplexType::get());
} else {
return std::any_of(this->containedTypes().begin(), this->containedTypes().end(),
[&](TypePtr inner) {
return type->isSubtypeOf(inner);
[&](const TypePtr& inner) {
return type->isSubtypeOf(*inner);
});
}
}
@ -1184,10 +1186,10 @@ c10::optional<TypePtr> UnionType::subtractTypeSet(std::vector<TypePtr>& to_subtr
// Given a TypePtr `lhs`, this function says whether or not `lhs` (or
// one of its parent types) is in the `to_subtract` vector
auto should_subtract = [&](TypePtr lhs) -> bool {
auto should_subtract = [&](const TypePtr& lhs) -> bool {
return std::any_of(to_subtract.begin(), to_subtract.end(),
[&](TypePtr rhs) {
return lhs->isSubtypeOf(rhs);
[&](const TypePtr& rhs) {
return lhs->isSubtypeOf(*rhs);
});
};
@ -1195,7 +1197,7 @@ c10::optional<TypePtr> UnionType::subtractTypeSet(std::vector<TypePtr>& to_subtr
// vector
std::copy_if(this->containedTypes().begin(), this->containedTypes().end(),
std::back_inserter(types),
[&](const TypePtr t) {
[&](const TypePtr& t) {
return !should_subtract(t);
});
@ -1245,18 +1247,18 @@ bool OptionalType::operator==(const Type& rhs) const {
}
}
bool OptionalType::isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const {
if (OptionalTypePtr optional_rhs = rhs->cast<OptionalType>()) {
return getElementType()->isSubtypeOfExt(optional_rhs->getElementType(), why_not);
} else if (UnionTypePtr union_rhs = rhs->cast<UnionType>()) {
bool OptionalType::isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const {
if (auto optional_rhs = rhs.castRaw<OptionalType>()) {
return getElementType()->isSubtypeOfExt(*optional_rhs->getElementType(), why_not);
} else if (auto union_rhs = rhs.castRaw<UnionType>()) {
if (!union_rhs->canHoldType(NoneType::get())) {
if (why_not) {
*why_not << rhs->repr_str() << " cannot hold None";
*why_not << rhs.repr_str() << " cannot hold None";
}
return false;
} else if (!union_rhs->canHoldType(this->getElementType())) {
if (why_not) {
*why_not << rhs->repr_str() << " cannot hold " << this->getElementType();
*why_not << rhs.repr_str() << " cannot hold " << this->getElementType();
}
return false;
} else {
@ -1276,8 +1278,8 @@ bool NumberType::operator==(const Type& rhs) const {
}
}
bool NumberType::isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const {
if (auto union_type = rhs->cast<UnionType>()) {
bool NumberType::isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const {
if (auto union_type = rhs.cast<UnionType>()) {
return union_type->canHoldType(NumberType::get());
} else {
return Type::isSubtypeOfExt(rhs, why_not);
@ -1305,14 +1307,14 @@ TupleType::TupleType(
}
}
bool TupleType::isSubtypeOfExt(const TypePtr& rhs_, std::ostream* why_not) const {
bool TupleType::isSubtypeOfExt(const Type& rhs_, std::ostream* why_not) const {
if (Type::isSubtypeOfExt(rhs_, why_not)) {
return true;
}
if (rhs_->kind() == AnyTupleType::Kind) {
if (rhs_.kind() == AnyTupleType::Kind) {
return true;
}
auto rhs = rhs_->cast<TupleType>();
auto rhs = rhs_.cast<TupleType>();
if (!rhs)
return false;
// unnamed tuple is not a subtype of nametuple
@ -1336,15 +1338,15 @@ bool TupleType::isSubtypeOfExt(const TypePtr& rhs_, std::ostream* why_not) const
bool names_match = !rhs->schema() || test_names_match(schema(), rhs->schema());
// co-variant rules for tuples
return names_match && compare(*rhs, [&](const TypePtr a, const TypePtr b) {
return a->isSubtypeOfExt(b, why_not);
return a->isSubtypeOfExt(*b, why_not);
});
}
bool ListType::isSubtypeOfExt(const TypePtr& rhs_, std::ostream* why_not) const {
bool ListType::isSubtypeOfExt(const Type& rhs_, std::ostream* why_not) const {
if (Type::isSubtypeOfExt(rhs_, why_not)) {
return true;
}
if (rhs_->kind() == AnyListType::Kind) {
if (rhs_.kind() == AnyListType::Kind) {
return true;
}
return false;
@ -1622,8 +1624,8 @@ const SymbolicShape& TensorType::symbolic_sizes() const {
return sizes_;
}
bool TensorType::isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const {
if (auto rhs_p = rhs->cast<TensorType>()) {
bool TensorType::isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const {
if (auto rhs_p = rhs.cast<TensorType>()) {
// if we have the same pointer, avoid computing the merge
if (this == rhs_p.get()) {
return true;
@ -2034,7 +2036,7 @@ ClassTypePtr ClassType::refine(at::ArrayRef<TypePtr> refined_slots) const {
auto ptr = ClassType::create(name(), compilation_unit_, is_module());
AT_ASSERT(numAttributes() == refined_slots.size());
for (size_t i = 0; i < attributes_.size(); ++i) {
AT_ASSERT(refined_slots[i]->isSubtypeOf(attributes_[i].getType()));
AT_ASSERT(refined_slots[i]->isSubtypeOf(*attributes_[i].getType()));
ptr->addAttribute(attributes_[i].getName(), refined_slots[i], (attributes_[i].getKind() == AttributeKind::PARAMETER),
(attributes_[i].getKind() == AttributeKind::BUFFER));
}
@ -2045,18 +2047,18 @@ ClassTypePtr ClassType::refine(at::ArrayRef<TypePtr> refined_slots) const {
return ptr;
}
bool ClassType::isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const {
if (rhs->cast<AnyClassType>()) {
bool ClassType::isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const {
if (rhs.castRaw<AnyClassType>()) {
return true;
}
// to improve performance, this check can be cached
if (auto iface = rhs->cast<InterfaceType>()) {
if (auto iface = rhs.cast<InterfaceType>()) {
// ClassType is not a subtype of InterfaceType if the InterfaceType is a
// Module Interface Type but the Class Type is not a Module Class Type
if (!is_module() && iface->is_module()) {
if (why_not) {
*why_not << "Class '" << repr_str() << "' is not a subtype of "
<< "the module interface '" << rhs->repr_str()
<< "the module interface '" << rhs.repr_str()
<< "' , only ScriptModule class can be subtype of module"
<< " interface.\n";
}
@ -2067,7 +2069,7 @@ bool ClassType::isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const
if (!self_method) {
if (why_not) {
*why_not << "Class '" << repr_str() << "' does not have method '"
<< schema.name() << "' but '" << rhs->repr_str()
<< schema.name() << "' but '" << rhs.repr_str()
<< "' does.\n";
}
return false;
@ -2078,7 +2080,7 @@ bool ClassType::isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const
if (why_not) {
*why_not << "Method on class '" << repr_str()
<< "' (1) is not compatible with interface '"
<< rhs->repr_str() << "' (2)\n"
<< rhs.repr_str() << "' (2)\n"
<< " (1) " << self_method->getSchema() << "\n"
<< " (2) " << schema << "\n";
}
@ -2131,9 +2133,9 @@ bool InterfaceType::isSubTypeImpl(
return true;
}
bool InterfaceType::isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const {
bool InterfaceType::isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const {
// to improve performance this check can be cached
if (auto iface = rhs->cast<InterfaceType>()) {
if (auto iface = rhs.cast<InterfaceType>()) {
return isSubTypeImpl(*this, *iface, why_not);
}
return Type::isSubtypeOfExt(rhs, why_not);
@ -2257,7 +2259,7 @@ size_t ClassType::addAttribute(
type->expect<OptionalType>()->getElementType()->kind() ==
TensorType::Kind) ||
(type->kind() == UnionType::Kind &&
TensorType::get()->isSubtypeOf(type->expect<UnionType>())) ||
TensorType::get()->isSubtypeOf(type->expectRef<UnionType>())) ||
(type->kind() == NoneType::Kind),
"Expecting parameter or buffer to have either None, Tensor or Optional[Tensor] type, but got: ",
toString(type));
@ -2402,10 +2404,10 @@ void SymbolicShape::dump() const {
std::cout << *this << "\n";
}
bool EnumType::isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const {
return rhs->kind() == TypeKind::AnyType ||
rhs->kind() == TypeKind::AnyEnumType ||
*this == *rhs ||
bool EnumType::isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const {
return rhs.kind() == TypeKind::AnyType ||
rhs.kind() == TypeKind::AnyEnumType ||
*this == rhs ||
Type::isSubtypeOfExt(rhs, why_not);
}
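
All of the getters above follow the same shape: a function-local static shared_ptr returned by const reference. A minimal standalone sketch of the pattern (Widget is a placeholder type, not part of c10) showing why it is cheap and thread-safe: magic-static initialization has been thread-safe since C++11, and callers only pay for a refcount bump when they explicitly copy.

#include <memory>

struct Widget {
  static const std::shared_ptr<Widget>& get() {
    // Thread-safe lazy initialization (C++11 magic static); the singleton
    // lives for the remainder of the program.
    static std::shared_ptr<Widget> value(new Widget());
    return value;
  }
};

int main() {
  const Widget& w = *Widget::get();               // no refcount traffic
  std::shared_ptr<Widget> owned = Widget::get();  // copy only when ownership is needed
  (void)w;
  (void)owned;
}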

View File

@ -46,7 +46,7 @@ class C10OperatorWrapper final : public Operator<Context> {
AT_ASSERT(
!has_preallocated_outputs_ ||
op_.schema().arguments().back().type()->isSubtypeOf(
OptionalType::create(ListType::ofTensors())));
*OptionalType::create(ListType::ofTensors())));
// NOLINTNEXTLINE(clang-diagnostic-sign-compare)
AT_ASSERT(operator_def.output_size() == op_.schema().returns().size());
@ -89,13 +89,13 @@ class C10OperatorWrapper final : public Operator<Context> {
AT_ASSERTM(
argument.type()->isSubtypeOf(
OptionalType::create(ListType::ofTensors())),
*OptionalType::create(ListType::ofTensors())),
"Error in caffe2->c10 wrapper: Operator schema has a parameter named ",
detail::PREALLOCATED_OUTPUT_ARGNAME,
", but it's not of type TensorList?");
stack_.emplace_back(preallocated_outputs_());
} else if (argument.type()->isSubtypeOf(TensorType::get())) {
} else if (argument.type()->isSubtypeOf(*TensorType::get())) {
AT_ASSERTM(
// NOLINTNEXTLINE(clang-diagnostic-sign-compare)
input_tensor_index < InputSize(),
@ -103,13 +103,13 @@ class C10OperatorWrapper final : public Operator<Context> {
InputSize(),
"), operator schema expected more.");
stack_.emplace_back(at::Tensor(Input(input_tensor_index++)));
} else if (argument.type()->isSubtypeOf(OptionalType::ofTensor())) {
} else if (argument.type()->isSubtypeOf(*OptionalType::ofTensor())) {
if (input_tensor_index < InputSize()) {
stack_.emplace_back(at::Tensor(Input(input_tensor_index++)));
} else {
stack_.emplace_back(IValue());
}
} else if (argument.type()->isSubtypeOf(ListType::ofTensors())) {
} else if (argument.type()->isSubtypeOf(*ListType::ofTensors())) {
AT_ASSERTM(
input_tensor_index == 0,
"Error in caffe2->c10 wrapper: Schema can only have either one or more Tensor inputs or one TensorList input.");
@ -163,13 +163,13 @@ class C10OperatorWrapper final : public Operator<Context> {
}
IValue get_nontensor_argument_(const c10::Argument& argument) {
if (argument.type()->isSubtypeOf(IntType::get())) {
if (argument.type()->isSubtypeOf(*IntType::get())) {
return get_nontensor_argument_<int>(
argument.name(), argument.default_value());
} else if (argument.type()->isSubtypeOf(FloatType::get())) {
} else if (argument.type()->isSubtypeOf(*FloatType::get())) {
return get_nontensor_argument_<double>(
argument.name(), argument.default_value());
} else if (argument.type()->isSubtypeOf(BoolType::get())) {
} else if (argument.type()->isSubtypeOf(*BoolType::get())) {
return get_nontensor_argument_<bool>(
argument.name(), argument.default_value());
} else {

View File

@ -57,7 +57,7 @@ inline void _call_caffe2_op_from_c10(
AT_ASSERT(
schema.arguments().size() != 0 &&
schema.arguments().back().type()->isSubtypeOf(
OptionalType::create(ListType::ofTensors())));
*OptionalType::create(ListType::ofTensors())));
IValue preallocated_outputs = torch::jit::pop(*stack);
const size_t num_outputs = schema.returns().size();

View File

@ -86,20 +86,20 @@ TEST(CustomOperatorTest, ListParameters) {
ASSERT_EQ(op->schema().arguments().size(), 4);
ASSERT_EQ(op->schema().arguments()[0].name(), "ints");
ASSERT_TRUE(
op->schema().arguments()[0].type()->isSubtypeOf(ListType::ofInts()));
op->schema().arguments()[0].type()->isSubtypeOf(*ListType::ofInts()));
ASSERT_EQ(op->schema().arguments()[1].name(), "floats");
ASSERT_TRUE(
op->schema().arguments()[1].type()->isSubtypeOf(ListType::ofFloats()));
op->schema().arguments()[1].type()->isSubtypeOf(*ListType::ofFloats()));
ASSERT_EQ(op->schema().arguments()[2].name(), "complexdoubles");
ASSERT_TRUE(op->schema().arguments()[2].type()->isSubtypeOf(
ListType::ofComplexDoubles()));
*ListType::ofComplexDoubles()));
ASSERT_EQ(op->schema().arguments()[3].name(), "tensors");
ASSERT_TRUE(
op->schema().arguments()[3].type()->isSubtypeOf(ListType::ofTensors()));
op->schema().arguments()[3].type()->isSubtypeOf(*ListType::ofTensors()));
ASSERT_EQ(op->schema().returns().size(), 1);
ASSERT_TRUE(
op->schema().returns()[0].type()->isSubtypeOf(ListType::ofFloats()));
op->schema().returns()[0].type()->isSubtypeOf(*ListType::ofFloats()));
Stack stack;
push(stack, c10::List<int64_t>({1, 2}));
@ -132,11 +132,11 @@ TEST(CustomOperatorTest, ListParameters2) {
ASSERT_EQ(op->schema().arguments().size(), 1);
ASSERT_EQ(op->schema().arguments()[0].name(), "tensors");
ASSERT_TRUE(
op->schema().arguments()[0].type()->isSubtypeOf(ListType::ofTensors()));
op->schema().arguments()[0].type()->isSubtypeOf(*ListType::ofTensors()));
ASSERT_EQ(op->schema().returns().size(), 1);
ASSERT_TRUE(
op->schema().returns()[0].type()->isSubtypeOf(ListType::ofTensors()));
op->schema().returns()[0].type()->isSubtypeOf(*ListType::ofTensors()));
Stack stack;
push(stack, c10::List<at::Tensor>({at::ones(5)}));

View File

@ -145,7 +145,7 @@ TEST(IRParserTest, InferredTypeIsTensor) {
graph(%a):
return (%a))IR",
&*graph);
AT_ASSERT(graph->inputs()[0]->type()->isSubtypeOf(TensorType::get()));
AT_ASSERT(graph->inputs()[0]->type()->isSubtypeOf(*TensorType::get()));
}
TEST(IRParserTest, ValueReuse) {
@ -260,7 +260,7 @@ TEST(IRParserTest, FileCheck) {
return (%a))IR";
parseIR(text, &*graph);
AT_ASSERT(graph->inputs()[0]->type()->isSubtypeOf(TensorType::get()));
AT_ASSERT(graph->inputs()[0]->type()->isSubtypeOf(*TensorType::get()));
torch::jit::testing::FileCheck().run(text, *graph);
}

View File

@ -26,14 +26,14 @@ TEST(JitTypeTest, UnifyTypes) {
auto bool_tensor = TensorType::get()->withScalarType(at::kBool);
auto opt_bool_tensor = OptionalType::create(bool_tensor);
auto unified_opt_bool = unifyTypes(bool_tensor, opt_bool_tensor);
TORCH_INTERNAL_ASSERT(opt_bool_tensor->isSubtypeOf(*unified_opt_bool));
TORCH_INTERNAL_ASSERT(opt_bool_tensor->isSubtypeOf(**unified_opt_bool));
auto tensor = TensorType::get();
TORCH_INTERNAL_ASSERT(!tensor->isSubtypeOf(opt_bool_tensor));
TORCH_INTERNAL_ASSERT(!tensor->isSubtypeOf(*opt_bool_tensor));
auto unified = unifyTypes(opt_bool_tensor, tensor);
TORCH_INTERNAL_ASSERT(unified);
auto elem = (*unified)->expectRef<OptionalType>().getElementType();
TORCH_INTERNAL_ASSERT(elem->isSubtypeOf(TensorType::get()));
TORCH_INTERNAL_ASSERT(elem->isSubtypeOf(*TensorType::get()));
auto opt_tuple_none_int = OptionalType::create(
TupleType::create({NoneType::get(), IntType::get()}));
@ -52,7 +52,7 @@ TEST(JitTypeTest, UnifyTypes) {
auto fut_out = unifyTypes(fut_1, fut_2);
TORCH_INTERNAL_ASSERT(fut_out);
TORCH_INTERNAL_ASSERT((*fut_out)->isSubtypeOf(
FutureType::create(OptionalType::create(IntType::get()))));
*FutureType::create(OptionalType::create(IntType::get()))));
auto dict_1 = DictType::create(IntType::get(), NoneType::get());
auto dict_2 = DictType::create(IntType::get(), IntType::get());

View File

@ -498,21 +498,21 @@ TEST(SchemaParserTest, NestedArrays) {
// nested arrays
auto s = parseSchema("at::what(int[][4] foo) -> ()");
ASSERT_TRUE(s.arguments().at(0).N() == 4);
ASSERT_TRUE(IntType::get()->isSubtypeOf(s.arguments()
.at(0)
.type()
->expectRef<ListType>()
.getElementType()
->expectRef<ListType>()
.getElementType()));
ASSERT_TRUE(IntType::get()->isSubtypeOf(*s.arguments()
.at(0)
.type()
->expectRef<ListType>()
.getElementType()
->expectRef<ListType>()
.getElementType()));
auto s2 = parseSchema("at::what(int[][] foo) -> ()");
ASSERT_TRUE(IntType::get()->isSubtypeOf(s2.arguments()
.at(0)
.type()
->expectRef<ListType>()
.getElementType()
->expectRef<ListType>()
.getElementType()));
ASSERT_TRUE(IntType::get()->isSubtypeOf(*s2.arguments()
.at(0)
.type()
->expectRef<ListType>()
.getElementType()
->expectRef<ListType>()
.getElementType()));
}
TEST(SchemaParserTest, OutVariant) {
@ -550,7 +550,7 @@ TEST(SchemaParserTest, Futures) {
// futures
auto s4 = parseSchema("at::what(Future(int) foo) -> ()");
ASSERT_TRUE(IntType::get()->isSubtypeOf(
s4.arguments().at(0).type()->expectRef<FutureType>().getElementType()));
*s4.arguments().at(0).type()->expectRef<FutureType>().getElementType()));
}
TEST(SchemaParserTest, AnnotatedAliasSets) {

View File

@ -116,8 +116,8 @@ TEST(SerializationTest, TypeTags) {
for (auto item : items) {
auto bytes = torch::pickle_save(item.value);
auto loaded = torch::pickle_load(bytes);
ASSERT_TRUE(loaded.type()->isSubtypeOf(item.expected_type));
ASSERT_TRUE(item.expected_type->isSubtypeOf(loaded.type()));
ASSERT_TRUE(loaded.type()->isSubtypeOf(*item.expected_type));
ASSERT_TRUE(item.expected_type->isSubtypeOf(*loaded.type()));
}
}

View File

@ -105,12 +105,12 @@ TEST_F(UnionTypeTest, Subtyping_NumberType) {
const NumberTypePtr num = NumberType::get();
ASSERT_TRUE(num->isSubtypeOf(union1));
ASSERT_TRUE(union1->isSubtypeOf(num));
ASSERT_TRUE(num->isSubtypeOf(*union1));
ASSERT_TRUE(union1->isSubtypeOf(*num));
ASSERT_TRUE(*num == *union1);
ASSERT_TRUE(num->isSubtypeOf(union2));
ASSERT_FALSE(union2->isSubtypeOf(num));
ASSERT_TRUE(num->isSubtypeOf(*union2));
ASSERT_FALSE(union2->isSubtypeOf(*num));
ASSERT_FALSE(*num == *union2);
}

View File

@ -185,7 +185,7 @@ void general_trace_function(
type = type->expectRef<OptionalType>().getElementType();
}
}
if (type->isSubtypeOf(TensorType::get())) {
if (type->isSubtypeOf(*TensorType::get())) {
AT_ASSERT(iter->isTensor());
tracer::addInputs(node, args[i].name().c_str(), iter->toTensor());
} else if (type->kind() == TypeKind::FloatType) {
@ -204,7 +204,7 @@ void general_trace_function(
tracer::addInputs(node, args[i].name().c_str(), iter->toScalar());
} else if (type->kind() == TypeKind::ListType) {
const auto& elem_type = type->expectRef<ListType>().getElementType();
if (elem_type->isSubtypeOf(TensorType::get())) {
if (elem_type->isSubtypeOf(*TensorType::get())) {
AT_ASSERT(iter->isTensorList());
auto list = iter->toTensorVector();
tracer::addInputs(node, args[i].name().c_str(), list);
@ -265,12 +265,12 @@ void general_trace_function(
for (auto iter = stack->end() - output_size; iter != stack->end();
++iter, ++i) {
const auto& type = op.schema().returns()[i].type();
if (type->isSubtypeOf(TensorType::get())) {
if (type->isSubtypeOf(*TensorType::get())) {
AT_ASSERT(iter->isTensor());
tracer::addOutput(node, iter->toTensor());
} else if (type->kind() == TypeKind::ListType) {
const auto& elem_type = type->expectRef<ListType>().getElementType();
if (elem_type->isSubtypeOf(TensorType::get())) {
if (elem_type->isSubtypeOf(*TensorType::get())) {
AT_ASSERT(iter->isTensorList());
tracer::addOutput(node, iter->toTensorList());
} else {

View File

@ -72,7 +72,7 @@ TypePtr tryInferTypeWithTypeHint(
TORCH_CHECK(
type_hint_ptr != nullptr &&
module.value().type()->isSubtypeOfExt(
type_hint_ptr, &subtype_check_msg),
*type_hint_ptr, &subtype_check_msg),
module.value().type()->repr_str(),
" is not a subtype of the type hint: ",
type_qualified_name.qualifiedName(),

View File

@ -348,9 +348,9 @@ c10::intrusive_ptr<OwnerRRef> RRefContext::getOrCreateOwnerRRef(
// since Tensor can only get specialized with a previous run of local
// JIT function, and we shouldn't preserve the specialized SubTensorType
// information on other workers because it's only information only.
if (type->isSubtypeOf(TensorType::get())) {
if (type->isSubtypeOf(*TensorType::get())) {
TORCH_INTERNAL_ASSERT(
ownerRRef->type()->isSubtypeOf(TensorType::get()),
ownerRRef->type()->isSubtypeOf(*TensorType::get()),
"Expect OwnerRRef to be a sub-type of TensorType, but got ",
ownerRRef->type()->repr_str());
} else {

View File

@ -535,7 +535,7 @@ struct TORCH_API BufferPolicy {
return std::move(v).toTensor();
}
static bool valid(const ClassTypePtr& typ, size_t i, const IValue& v) {
return typ->getAttribute(i)->isSubtypeOf(TensorType::get()) &&
return typ->getAttribute(i)->isSubtypeOf(*TensorType::get()) &&
typ->is_buffer(i);
}
static CONSTEXPR_EXCEPT_WIN_CUDA bool all_slots = false;

View File

@ -54,7 +54,7 @@ struct TORCH_API Object {
} else if (auto slot = _ivalue()->type()->findAttributeSlot(name)) {
const c10::TypePtr& expected = _ivalue()->type()->getAttribute(*slot);
TORCH_CHECK(
v.type()->isSubtypeOf(expected),
v.type()->isSubtypeOf(*expected),
"Expected a value of type '",
expected->repr_str(),
"' for field '",

View File

@ -114,7 +114,7 @@ struct CudaGraphFuser {
value_list tensorInputs(Node* node) {
return filter(node->inputs(), [](Value* v) {
return v->type()->isSubtypeOf(TensorType::get());
return v->type()->isSubtypeOf(*TensorType::get());
});
}
@ -210,7 +210,7 @@ struct CudaGraphFuser {
size_t tensor_insert_idx = 0;
for (auto input : group->inputs()) {
inputs_map[input] = subgraph.inputs()[i++];
if (input->type()->isSubtypeOf(TensorType::get()))
if (input->type()->isSubtypeOf(*TensorType::get()))
tensor_insert_idx = i;
}
// add n's inputs to the fusion group's input list if we don't already have
@ -222,7 +222,7 @@ struct CudaGraphFuser {
if (inputs_map.count(input) == 0) {
// TODO: we are following the convention for no good reason;
// we don't need tensor to come before any other inputs.
if (input->type()->isSubtypeOf(TensorType::get())) {
if (input->type()->isSubtypeOf(*TensorType::get())) {
auto in_group = subgraph.insertInput(tensor_insert_idx);
in_group->setType(input->type());
inputs_map[input] = in_group;
@ -230,7 +230,7 @@ struct CudaGraphFuser {
tensor_insert_idx++;
} else if (
// TODO: extend the supporting inputs here.
(input->type()->isSubtypeOf(FloatType::get()) &&
(input->type()->isSubtypeOf(*FloatType::get()) &&
input->node()->kind() != prim::Constant)) {
auto in_group = subgraph.addInput();
in_group->setType(input->type());
@ -440,7 +440,7 @@ struct CudaGraphFuser {
// Replace tensors inputs with broadcasted values
auto new_tensors_it = new_tensors.begin();
for (size_t i = 0; i < node->inputs().size(); ++i) {
if (node->inputs()[i]->type()->isSubtypeOf(TensorType::get())) {
if (node->inputs()[i]->type()->isSubtypeOf(*TensorType::get())) {
AT_ASSERT(new_tensors_it != new_tensors.end());
node->replaceInput(i, *(new_tensors_it++));
}
@ -595,7 +595,7 @@ struct CudaGraphFuser {
// XXX: we only work with pointwise ops in here, so we know it is valid to
// push the concat only through tensor arguments (and all other args can
// be safely ignored).
if (!input->type()->isSubtypeOf(TensorType::get()))
if (!input->type()->isSubtypeOf(*TensorType::get()))
continue;
// if 'input' is already an input to the bchunk, reuse it.
@ -688,7 +688,7 @@ struct CudaGraphFuser {
chunked_op->output()->setType(chunk_sel->type());
auto chunked_inputs_it = chunked_inputs.begin();
for (Value* original_input : original_inputs) {
if (original_input->type()->isSubtypeOf(TensorType::get())) {
if (original_input->type()->isSubtypeOf(*TensorType::get())) {
AT_ASSERT(chunked_inputs_it != chunked_inputs.end());
chunked_op->addInput(
// NOLINTNEXTLINE(clang-analyzer-core.DivideZero)
@ -720,7 +720,7 @@ struct CudaGraphFuser {
if (!size_calc_uses.empty()) {
auto tensor_inputs = filter(
producer_for_chunk_node->inputs(),
[](Value* v) { return v->type()->isSubtypeOf(TensorType::get()); });
[](Value* v) { return v->type()->isSubtypeOf(*TensorType::get()); });
auto tensor_sizes = fmap(tensor_inputs, [](Value* v) {
return v->owningGraph()->insert(aten::size, {v});
});
@ -824,7 +824,7 @@ struct CudaGraphFuser {
auto sinputs = subgraph->inputs();
AT_ASSERT(inputs.size() == sinputs.size());
for (const auto i : c10::irange(inputs.size())) {
if (inputs[i]->type()->isSubtypeOf(TensorType::get())) {
if (inputs[i]->type()->isSubtypeOf(*TensorType::get())) {
shape_of[sinputs[i]] = graph->insert(aten::size, {inputs[i]});
}
}
@ -963,7 +963,7 @@ struct CudaGraphFuser {
continue;
}
auto tensor_inputs = filter(n->inputs(), [](Value* v) {
return v->type()->isSubtypeOf(TensorType::get());
return v->type()->isSubtypeOf(*TensorType::get());
});
auto shapes = fmap(tensor_inputs, [&](Value* v) {
TORCH_INTERNAL_ASSERT(

View File

@ -34,7 +34,7 @@ bool hasNonElementWiseOperation(const Node* node) {
// 2. on the same device;
// TODO: update this when codegen can output scalar
static c10::optional<c10::Device> getDevice(const Value* value) {
if (!value->type()->isSubtypeOf(TensorType::get())) {
if (!value->type()->isSubtypeOf(*TensorType::get())) {
// not tensor type, return false as the op is not outputing scalar.
return c10::nullopt;
}

View File

@ -44,7 +44,7 @@ class NaiveTypePropagator {
switch (node->kind()) {
// Constant:
case prim::Constant: {
if (node->output()->type()->isSubtypeOf(TensorType::get())) {
if (node->output()->type()->isSubtypeOf(*TensorType::get())) {
node->output()->inferTypeFrom(node->t(attr::value));
}
break;

View File

@ -123,7 +123,7 @@ static std::vector<int64_t> getInputDependencies(const Value* output) {
// This needs to be revisited when you start allowing
// other things e.g. nonconstant scalars.
if (producer->kind() == prim::Param &&
val->type()->isSubtypeOf(TensorType::get())) {
val->type()->isSubtypeOf(*TensorType::get())) {
inputs.insert(val);
continue;
}
@ -230,10 +230,10 @@ std::shared_ptr<FusedKernel> compileKernel(
{
size_t input_index = 0;
for (const auto& p : graph->inputs()) {
if (p->type()->isSubtypeOf(FloatType::get())) {
if (p->type()->isSubtypeOf(*FloatType::get())) {
flat_inputs.emplace_back(p, c10::nullopt);
}
if (!p->type()->isSubtypeOf(TensorType::get())) {
if (!p->type()->isSubtypeOf(*TensorType::get())) {
continue;
}
if (const Node* chunk = usedInFusedChunk(p)) {

View File

@ -77,7 +77,7 @@ struct TORCH_API KernelSpec {
}
nTensorInputs_ = std::count_if(
graph_->inputs().begin(), graph_->inputs().end(), [](const Value* v) {
return v->type()->isSubtypeOf(TensorType::get());
return v->type()->isSubtypeOf(*TensorType::get());
});
}

View File

@ -385,7 +385,7 @@ struct Environment {
as_simple_value,
/*allow_conversions=*/true);
std::stringstream why_not;
if (!as_simple_value->type()->isSubtypeOfExt(parent_type, &why_not)) {
if (!as_simple_value->type()->isSubtypeOfExt(*parent_type, &why_not)) {
auto error = ErrorReport(loc);
error << "Variable '" << name << "' previously had type "
<< simple_parent->type()->repr_str()
@ -406,7 +406,7 @@ struct Environment {
}
if (as_simple_value) {
if (annotated_type &&
!as_simple_value->type()->isSubtypeOf(annotated_type)) {
!as_simple_value->type()->isSubtypeOf(*annotated_type)) {
throw ErrorReport(loc)
<< "Variable '" << name << "' is annotated with type "
<< annotated_type->repr_str()
@ -603,8 +603,8 @@ static Value* materializeConstant(
}
inline bool isSupportedListElementType(const TypePtr& type) {
return type->isSubtypeOf(TensorType::get()) ||
type->isSubtypeOf(NumberType::get());
return type->isSubtypeOf(*TensorType::get()) ||
type->isSubtypeOf(*NumberType::get());
}
// Information for each def being emitted.
@ -1023,8 +1023,8 @@ struct to_ir {
// this guard skips implicit conversion from None -> Tensor for the return
// type. otherwise forgetting a return a function returning a tensor will
// cause a None to be converted to a tensor.
if (!(actual_return->type()->isSubtypeOf(TensorType::get()) &&
actual_return->type()->isSubtypeOf(NoneType::get()))) {
if (!(actual_return->type()->isSubtypeOf(*TensorType::get()) &&
actual_return->type()->isSubtypeOf(*NoneType::get()))) {
actual_return = tryConvertToType(
stmt.range(),
*graph,
@ -1032,7 +1032,7 @@ struct to_ir {
actual_return,
/*allow_conversions=*/true);
}
if (!actual_return->type()->isSubtypeOf(declared_return_type)) {
if (!actual_return->type()->isSubtypeOf(*declared_return_type)) {
throw ErrorReport(stmt.range())
<< "Return value was annotated as having type "
<< declared_return_type->repr_str() << " but is actually of type "
@ -1423,8 +1423,8 @@ struct to_ir {
if (all_candidates.empty() && refined_type_hint &&
!(*unified_elem_type)
->isSubtypeOf(
refined_type_hint->expect<ListType>()->getElementType())) {
->isSubtypeOf(*refined_type_hint->expectRef<ListType>()
.getElementType())) {
throw ErrorReport(lc)
<< "List type annotation `" << refined_type_hint->repr_str()
<< "` did not match the types of the given list elements,"
@ -1439,7 +1439,7 @@ struct to_ir {
[&](TypePtr candidate) {
auto candidate_elem_type =
candidate->expect<ListType>()->getElementType();
if ((*unified_elem_type)->isSubtypeOf(candidate_elem_type)) {
if ((*unified_elem_type)->isSubtypeOf(*candidate_elem_type)) {
if (!greatest_elem_type) {
greatest_elem_type = candidate_elem_type;
} else {
@ -1582,7 +1582,7 @@ struct to_ir {
std::stringstream err;
bool is_key_subtype =
k->type()->isSubtypeOfExt(dict_type_hint->getKeyType(), &ss);
k->type()->isSubtypeOfExt(*dict_type_hint->getKeyType(), &ss);
if (!is_key_subtype) {
err << "Dict type annotation `" << dict_type_hint->repr_str()
@ -1594,7 +1594,7 @@ struct to_ir {
ss.str(std::string());
bool is_value_subtype =
v->type()->isSubtypeOfExt(dict_type_hint->getValueType(), &ss);
v->type()->isSubtypeOfExt(*dict_type_hint->getValueType(), &ss);
if (!is_value_subtype) {
err << "Dict type annotation `" << dict_type_hint->repr_str()
@ -1651,11 +1651,11 @@ struct to_ir {
current_candidate->expect<DictType>()->getKeyType();
auto current_value_type =
current_candidate->expect<DictType>()->getValueType();
if (known_key_type->isSubtypeOf(current_key_type) &&
known_value_type->isSubtypeOf(current_value_type)) {
if (known_key_type->isSubtypeOf(*current_key_type) &&
known_value_type->isSubtypeOf(*current_value_type)) {
if (!candidate ||
(candidate_key_type->isSubtypeOf(current_key_type) &&
candidate_value_type->isSubtypeOf(current_value_type))) {
(candidate_key_type->isSubtypeOf(*current_key_type) &&
candidate_value_type->isSubtypeOf(*current_value_type))) {
candidate_key_type = current_key_type;
candidate_value_type = current_value_type;
candidate = current_candidate;
@ -1819,7 +1819,7 @@ struct to_ir {
<< v->type()->repr_str() << " to bool";
}
// cast value not response for checking output type
if (!out->type()->isSubtypeOf(BoolType::get())) {
if (!out->type()->isSubtypeOf(*BoolType::get())) {
throw ErrorReport(loc)
<< "expected a bool expression for condition but found "
<< out->type()->repr_str();
@ -2088,10 +2088,11 @@ struct to_ir {
break;
}
auto get_smaller_type = [&](TypePtr t1, TypePtr t2) -> TypePtr {
if (t1->isSubtypeOf(t2)) {
auto get_smaller_type = [&](const TypePtr& t1,
const TypePtr& t2) -> TypePtr {
if (t1->isSubtypeOf(*t2)) {
return t1;
} else if (t2->isSubtypeOf(t1)) {
} else if (t2->isSubtypeOf(*t1)) {
return t2;
} else {
return nullptr;
@ -2416,7 +2417,7 @@ struct to_ir {
<< "exceptions must derive from BaseException";
}
if (!error_message->type()->isSubtypeOf(StringType::get())) {
if (!error_message->type()->isSubtypeOf(*StringType::get())) {
error_message = graph->insert(aten::str, {error_message});
}
@ -2487,7 +2488,7 @@ struct to_ir {
// If the RHS is a tensor, return the corresponding ATen in-place op
// If it's a list of scalars, then return the corresponding list augment op
Symbol getAugOp(const AugAssign& stmt, const TypePtr& type) {
bool use_inplace_op = type->isSubtypeOf(TensorType::get()) ||
bool use_inplace_op = type->isSubtypeOf(*TensorType::get()) ||
type->kind() == TypeKind::ListType;
switch (stmt.aug_op()) {
case '+':
@ -2678,7 +2679,7 @@ struct to_ir {
const auto lhs = Subscript(stmt.lhs());
const auto sliceable = emitExpr(lhs.value());
if (sliceable->type()->isSubtypeOf(TensorType::get())) {
if (sliceable->type()->isSubtypeOf(*TensorType::get())) {
// If it's a tensor, just fully evaluate the subscript operation and emit
// an in-place assignment
std::vector<Value*> tensorIndices;
@ -2765,7 +2766,7 @@ struct to_ir {
auto sliceable = emitExpr(lhs.value());
// If it's a tensor, copy the RHS data into it
if (sliceable->type()->isSubtypeOf(TensorType::get())) {
if (sliceable->type()->isSubtypeOf(*TensorType::get())) {
std::vector<Value*> tensorIndices;
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
Value* sliced;
@ -2810,7 +2811,7 @@ struct to_ir {
<< " subscripted assignment. "
<< "File a bug if you want this";
}
if (sliceable->type()->isSubtypeOf(AnyTupleType::get())) {
if (sliceable->type()->isSubtypeOf(*AnyTupleType::get())) {
throw ErrorReport(lhs) << sliceable->type()->repr_str()
<< " does not support subscripted assignment";
}
@ -3270,7 +3271,7 @@ struct to_ir {
/*allow_conversions=*/true);
std::stringstream why_not;
if (!expr->type()->isSubtypeOfExt(type, &why_not)) {
if (!expr->type()->isSubtypeOfExt(*type, &why_not)) {
throw ErrorReport(apply.inputs())
<< "expected an expression of type " << type->repr_str()
<< " but found " << expr->type()->repr_str() << "\n"
@ -3284,7 +3285,7 @@ struct to_ir {
if ((type->kind() == OptionalType::Kind ||
(type->kind() == UnionType::Kind &&
type->expect<UnionType>()->canHoldType(NoneType::get()))) &&
expr->type()->isSubtypeOf(NoneType::get())) {
expr->type()->isSubtypeOf(*NoneType::get())) {
Node* none = graph->createNone();
none->output()->setType(type);
graph->insertNode(none);
@ -4094,7 +4095,7 @@ struct to_ir {
}
if (all_candidates.empty() && refined_type_hint &&
!(*unified_elem_type)->isSubtypeOf(inferred_elem_type)) {
!(*unified_elem_type)->isSubtypeOf(*inferred_elem_type)) {
throw ErrorReport(ll)
<< "List type annotation `" << refined_type_hint->repr_str()
<< "` did not match the types of the given list elements,"
@ -4109,7 +4110,7 @@ struct to_ir {
[&](TypePtr candidate) {
auto candidate_elem_type =
candidate->expect<ListType>()->getElementType();
if ((*unified_elem_type)->isSubtypeOf(candidate_elem_type)) {
if ((*unified_elem_type)->isSubtypeOf(*candidate_elem_type)) {
if (!greatest_elem_type) {
greatest_elem_type = candidate_elem_type;
} else {
@ -4349,8 +4350,8 @@ struct to_ir {
for (const auto i : c10::irange(keys.size())) {
std::stringstream ss;
if (!keys[i]->type()->isSubtypeOfExt(first_key_type, &ss) &&
!first_key_type->isSubtypeOfExt(keys[i]->type(), &ss)) {
if (!keys[i]->type()->isSubtypeOfExt(*first_key_type, &ss) &&
!first_key_type->isSubtypeOfExt(*keys[i]->type(), &ss)) {
throw ErrorReport(key_trees[i])
<< "Dict keys must contain "
<< "only a single type. Expected: "
@ -4386,11 +4387,11 @@ struct to_ir {
current_candidate->expect<DictType>()->getKeyType();
auto current_value_type =
current_candidate->expect<DictType>()->getValueType();
if (known_key_type->isSubtypeOf(current_key_type) &&
known_value_type->isSubtypeOf(current_value_type)) {
if (known_key_type->isSubtypeOf(*current_key_type) &&
known_value_type->isSubtypeOf(*current_value_type)) {
if (!candidate ||
(candidate_key_type->isSubtypeOf(current_key_type) &&
candidate_value_type->isSubtypeOf(current_value_type))) {
(candidate_key_type->isSubtypeOf(*current_key_type) &&
candidate_value_type->isSubtypeOf(*current_value_type))) {
candidate_key_type = current_key_type;
candidate_value_type = current_value_type;
candidate = current_candidate;
@ -4439,7 +4440,7 @@ struct to_ir {
refined_type_hint->expect<DictType>()->getValueType();
for (const auto i : c10::irange(value_types.size())) {
TORCH_CHECK(
value_types[i]->isSubtypeOf(value_type_hint),
value_types[i]->isSubtypeOf(*value_type_hint),
"Type "
"hint for dict was ",
refined_type_hint->repr_str(),
@ -4525,11 +4526,11 @@ struct to_ir {
// XXX: If list slicing becomes more complicated or stops using
// aten::slice, we should separate it from this function.
if (dim) {
AT_ASSERT(sliceable->type()->isSubtypeOf(TensorType::get()));
AT_ASSERT(sliceable->type()->isSubtypeOf(*TensorType::get()));
args.emplace_back(dim);
} else {
AT_ASSERT(!sliceable->type()->isSubtypeOf(TensorType::get()));
AT_ASSERT(!sliceable->type()->isSubtypeOf(*TensorType::get()));
}
if (sliceable->type()->cast<TupleType>()) {
@ -4691,7 +4692,7 @@ struct to_ir {
}
exprs[expr_idx] = index;
if (index->type()->isSubtypeOf(NoneType::get())) {
if (index->type()->isSubtypeOf(*NoneType::get())) {
if (is_reverse) {
return dim;
} else {
@ -4703,7 +4704,7 @@ struct to_ir {
} else {
return dim;
}
} else if (index->type()->isSubtypeOf(OptionalType::ofTensor())) {
} else if (index->type()->isSubtypeOf(*OptionalType::ofTensor())) {
if (is_reverse) {
throw ErrorReport(loc)
<< "Ellipses followed by tensor indexing is currently not supported";
@ -4768,13 +4769,13 @@ struct to_ir {
continue;
}
auto expr = exprs[i].value();
if (expr->type()->isSubtypeOf(NoneType::get())) {
if (expr->type()->isSubtypeOf(*NoneType::get())) {
sliceable =
emitUnsqueeze(loc, sliceable, insert_value_for_dim(dims[i]));
} else if (expr->type() == IntType::get()) {
sliceable =
emitSelect(loc, sliceable, insert_value_for_dim(dims[i]), expr);
} else if (expr->type()->isSubtypeOf(OptionalType::ofTensor())) {
} else if (expr->type()->isSubtypeOf(*OptionalType::ofTensor())) {
tensor_indices.resize(dims[i] + 1);
tensor_indices[dims[i]] = expr;
} else {
@ -4814,7 +4815,7 @@ struct to_ir {
const SourceRange& loc,
Value* sliceable,
const List<Expr>& subscript_exprs) {
if (!sliceable->type()->isSubtypeOf(TensorType::get())) {
if (!sliceable->type()->isSubtypeOf(*TensorType::get())) {
throw ErrorReport(loc)
<< "Unsupported operation: attempted to use multidimensional "
<< "indexing on a non-tensor type";
@ -4842,7 +4843,7 @@ struct to_ir {
AT_ASSERT(subscript_exprs[0].kind() == TK_SLICE_EXPR);
auto slice_exp = SliceExpr(subscript_exprs[0]);
Value* maybe_dim = nullptr;
if (sliceable->type()->isSubtypeOf(TensorType::get())) {
if (sliceable->type()->isSubtypeOf(*TensorType::get())) {
// If the sliceable object is a tensor, specify a default dimension
maybe_dim = graph->insertConstant(0, loc);
}
@ -5009,7 +5010,7 @@ struct to_ir {
dynamic_cast<SliceValue*>(subscript_sv.get())) {
Value* dim = nullptr;
// aten::slice.tensor needs an additional `dim` input.
if (sliceable->type()->isSubtypeOf(TensorType::get())) {
if (sliceable->type()->isSubtypeOf(*TensorType::get())) {
dim = method.graph()->insertConstant(0, val_range);
}
@ -5030,7 +5031,7 @@ struct to_ir {
if (sliceable->type()->cast<TupleType>()) {
return std::make_shared<SimpleValue>(
emitTupleIndex(range, sv->asValue(val_range, method), idx));
} else if (sliceable->type()->isSubtypeOf(TensorType::get())) {
} else if (sliceable->type()->isSubtypeOf(*TensorType::get())) {
return std::make_shared<SimpleValue>(
emitMultidimSlicing(range, sliceable, subscript_exprs));
} else {


@ -30,21 +30,21 @@ static inline bool isIntOrFloatUsedAsList(
/// Returns true if `type` is a Tuple in which all the elements have the
/// same type or if it's a subtype of `list_type_`.
inline bool convertibleToList(const TypePtr& type, const TypePtr& list_type_) {
auto list_type = list_type_->cast<ListType>();
bool convertibleToList(const TypePtr& type, const TypePtr& list_type_) {
auto list_type = list_type_->castRaw<ListType>();
if (!list_type) {
return false;
}
if (type->isSubtypeOf(list_type_)) {
if (type->isSubtypeOf(*list_type_)) {
return true;
}
if (auto tuple = type->cast<TupleType>()) {
if (auto tuple = type->castRaw<TupleType>()) {
return std::all_of(
tuple->elements().begin(),
tuple->elements().end(),
[&](const TypePtr& t) {
// TODO: resolve VarType if necessary
return t->isSubtypeOf(list_type->getElementType());
return t->isSubtypeOf(*list_type->getElementType());
});
}
return false;
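convertibleToList above makes the matching change for failable casts: cast<ListType>() materializes a shared_ptr<ListType> (or nullptr), while castRaw<ListType>() returns a plain ListType* borrowed from the argument, which is all a local type test needs. A sketch of the contrast; isListOfInts is a made-up example:

    // Made-up example; only the cast vs castRaw contrast matters here.
    bool isListOfInts(const c10::TypePtr& t) {
      //   auto list = t->cast<c10::ListType>();  // shared_ptr<ListType>, refcount bump
      auto* list = t->castRaw<c10::ListType>();   // plain ListType*, no ownership taken
      return list != nullptr &&
          list->getElementType()->isSubtypeOf(*c10::IntType::get());
    }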
@ -61,7 +61,7 @@ Value* tryConvertToType(
// treat conversion to Optional[T] as conversions to T
if (OptionalTypePtr op = concrete_type->cast<OptionalType>()) {
if (value->type()->kind() != OptionalType::Kind &&
!value->type()->isSubtypeOf(NoneType::get())) {
!value->type()->isSubtypeOf(*NoneType::get())) {
return tryConvertToType(
loc, graph, op->getElementType(), value, allow_conversions);
}
@ -79,7 +79,7 @@ Value* tryConvertToType(
// inductively apply implicit conversions to tuples
if (auto concrete_tuple = concrete_type->cast<TupleType>()) {
if (!value_tuple->isSubtypeOf(concrete_tuple) &&
if (!value_tuple->isSubtypeOf(*concrete_tuple) &&
concrete_tuple->elements().size() == value_tuple->elements().size()) {
auto unpacked = createTupleUnpack(value);
std::vector<Value*> converted;
@ -99,7 +99,7 @@ Value* tryConvertToType(
// implicit conversions
if (allow_conversions) {
// Convert tensor or number to concrete int/float types
bool value_isa_tensor = value->type()->isSubtypeOf(TensorType::get());
bool value_isa_tensor = value->type()->isSubtypeOf(*TensorType::get());
bool value_equals_number = *value->type() == *NumberType::get();
bool concrete_float = *concrete_type == *FloatType::get();
bool concrete_complex = *concrete_type == *ComplexType::get();
@ -126,8 +126,8 @@ Value* tryConvertToType(
}
// Convert strings to device
if (value->type()->isSubtypeOf(StringType::get()) &&
concrete_type->isSubtypeOf(DeviceObjType::get())) {
if (value->type()->isSubtypeOf(*StringType::get()) &&
concrete_type->isSubtypeOf(*DeviceObjType::get())) {
return graph.insert(aten::device, {value}, {}, loc);
}
}
@ -185,7 +185,7 @@ static Value* tryMatchArgument(
value = tryConvertToType(loc, graph, concrete_type, value, allow_conversions);
std::stringstream ss;
if (!value->type()->isSubtypeOfExt(
concrete_type, /*why_not=*/(failure_messages) ? &ss : nullptr)) {
*concrete_type, /*why_not=*/(failure_messages) ? &ss : nullptr)) {
if (failure_messages) {
auto& ostream = err()
<< arg.formatTypeMismatchMsg(value->type()->repr_str());
@ -203,7 +203,7 @@ static Value* tryMatchArgument(
}
if (auto v = value->type()->cast<ListType>()) {
if (v->getElementType()->isSubtypeOf(TensorType::get())) {
if (v->getElementType()->isSubtypeOf(*TensorType::get())) {
ostream << "Empty lists default to List[Tensor]. Add a variable "
"annotation to the assignment to create an empty list "
"of another type (torch.jit.annotate(List[T], []) where T "


@ -88,7 +88,7 @@ std::shared_ptr<SugaredValue> SimpleValue::attr(
Function& m,
const std::string& field) {
// Allow method-style casts on Tensor types. e.g. x.int()
if (value_->type()->isSubtypeOf(TensorType::get())) {
if (value_->type()->isSubtypeOf(*TensorType::get())) {
if (builtin_cast_method_to_scalar_type().count(field)) {
return std::make_shared<TensorCastValue>(
builtin_cast_method_to_scalar_type().at(field),
@ -202,7 +202,7 @@ std::shared_ptr<SugaredValue> SimpleValue::attr(
}
// Handle calling tolist() on a Tensor.
if (value_->type()->isSubtypeOf(TensorType::get()) && field == "tolist") {
if (value_->type()->isSubtypeOf(*TensorType::get()) && field == "tolist") {
return SpecialFormValue::create(prim::tolist);
}
@ -252,7 +252,7 @@ std::vector<std::shared_ptr<SugaredValue>> SimpleValue::asTuple(
}
static bool isRecursive(const TypePtr& classType, const TypePtr& attrType) {
if (attrType->isSubtypeOf(classType)) {
if (attrType->isSubtypeOf(*classType)) {
return true;
}
@ -334,7 +334,7 @@ void SimpleValue::setAttr(
// Check type correctness
const auto newType = newValue->type();
if (!newType->isSubtypeOf(expectedType)) {
if (!newType->isSubtypeOf(*expectedType)) {
throw ErrorReport(loc) << "Wrong type for attribute assignment. Expected "
<< expectedType->repr_str() << " but got "
<< newType->repr_str();
@ -390,7 +390,7 @@ Value* SimpleValue::len(const SourceRange& loc, Function& m) {
TypePtr val_type = val->type();
Graph& g = *m.graph();
if (val_type->cast<ListType>() || val_type->cast<StringType>() ||
val_type->isSubtypeOf(TensorType::get())) {
val_type->isSubtypeOf(*TensorType::get())) {
return g.insert(aten::len, {val}, {}, loc);
} else {
throw ErrorReport(loc) << "'" << val_type->repr_str() << "'"
@ -415,7 +415,7 @@ SugaredValuePtr SimpleValue::getitem(
} else if (auto dict_type = val_type->cast<DictType>()) {
return std::make_shared<SimpleValue>(
g.insert(aten::__getitem__, {val, idx}, {}, loc));
} else if (val_type->isSubtypeOf(TensorType::get())) {
} else if (val_type->isSubtypeOf(*TensorType::get())) {
return std::make_shared<SimpleValue>(
g.insert(aten::select, {val, 0, idx}, {}, loc));
} else if (auto class_type = val_type->cast<ClassType>()) {
@ -702,7 +702,7 @@ std::shared_ptr<BuiltinFunction> BuiltinFunction::tryCreate(
continue;
}
const auto concrete_type = tryEvalTypeVariables(formal_type, type_env);
if (!concrete_type || !self->type()->isSubtypeOf(concrete_type)) {
if (!concrete_type || !self->type()->isSubtypeOf(*concrete_type)) {
continue;
}
return std::make_shared<BuiltinFunction>(symbol, self);


@ -517,12 +517,12 @@ struct TORCH_API CastValue : public BuiltinFunction {
auto zero = m.graph()->insertConstant(0);
auto v = args[0].value(*m.graph());
if (v->type()->isSubtypeOf(type_)) {
if (v->type()->isSubtypeOf(*type_)) {
return std::make_shared<SimpleValue>(v);
} else if (
*type_ == *BoolType::get() &&
(v->type()->isSubtypeOf(AnyListType::get()) ||
v->type()->isSubtypeOf(StringType::get()) ||
(v->type()->isSubtypeOf(*AnyListType::get()) ||
v->type()->isSubtypeOf(*StringType::get()) ||
v->type()->cast<DictType>())) {
auto len = len_op->call(loc, m, {v}, {}, 1);
return gt_op->call(loc, m, {len->asValue(loc, m), zero}, {}, 1);
@ -766,7 +766,7 @@ struct TORCH_API ExceptionValue : public SugaredValue {
auto exception_message = insertConstant(*m.graph(), message_ + ": ", loc);
for (auto& input : args) {
auto input_str = input.value(*m.graph());
if (!input_str->type()->isSubtypeOf(StringType::get())) {
if (!input_str->type()->isSubtypeOf(*StringType::get())) {
input_str =
emitBuiltinCall(loc, *m.graph(), aten::str, {input_str}, {});
}


@ -285,15 +285,15 @@ Value* TracingState::getOutput(const IValue& iv, size_t i) {
TypePtr key_type = dict.keyType();
TypePtr value_type = dict.valueType();
bool key_type_valid = key_type->isSubtypeOf(StringType::get()) ||
key_type->isSubtypeOf(TensorType::get());
bool value_type_valid = value_type->isSubtypeOf(TensorType::get());
bool key_type_valid = key_type->isSubtypeOf(*StringType::get()) ||
key_type->isSubtypeOf(*TensorType::get());
bool value_type_valid = value_type->isSubtypeOf(*TensorType::get());
// Support tuple values that contain only tensors
if (value_type->isSubtypeOf(AnyTupleType::get())) {
if (value_type->isSubtypeOf(*AnyTupleType::get())) {
value_type_valid = true;
for (const auto& type : value_type->containedTypes()) {
if (!type->isSubtypeOf(TensorType::get())) {
if (!type->isSubtypeOf(*TensorType::get())) {
value_type_valid = false;
break;
}
@ -330,7 +330,7 @@ static IValue addInput(
const TypePtr& type,
Value* value) {
value->setType(type);
if (type->isSubtypeOf(TensorType::get())) {
if (type->isSubtypeOf(*TensorType::get())) {
auto input_tensor = input.toTensor();
auto name = Variable(input_tensor).name();
if (state->hasValue(input)) {
@ -426,7 +426,7 @@ static void gatherParametersAndBuffers(
->s_(attr::scope, qualname)
->output()
->setType(s.value.type());
if (s.value.type()->isSubtypeOf(TensorType::get())) {
if (s.value.type()->isSubtypeOf(*TensorType::get())) {
addInput(state, s.value, s.value.type(), trace_get_attr);
}
if (isCustomClass(s.value)) {


@ -153,20 +153,20 @@ c10::optional<IValue> toIValue(const Value* v) {
}
const Node* node = v->node();
const TypePtr& type = v->type();
if (type->isSubtypeOf(TensorType::get())) {
if (type->isSubtypeOf(*TensorType::get())) {
return node->t(attr::value);
} else if (type->isSubtypeOf(BoolType::get())) {
} else if (type->isSubtypeOf(*BoolType::get())) {
return (bool)node->i(attr::value);
} else if (
type->isSubtypeOf(NumberType::get()) &&
type->isSubtypeOf(*NumberType::get()) &&
node->kindOf(attr::value) == AttributeKind::i) {
return node->i(attr::value);
} else if (
type->isSubtypeOf(NumberType::get()) &&
type->isSubtypeOf(*NumberType::get()) &&
node->kindOf(attr::value) == AttributeKind::f) {
return node->f(attr::value);
} else if (
type->isSubtypeOf(NumberType::get()) &&
type->isSubtypeOf(*NumberType::get()) &&
node->kindOf(attr::value) == AttributeKind::c) {
return node->c(attr::value);
} else if (


@ -1001,7 +1001,7 @@ bool Node::matches(const FunctionSchema& schema) const {
// we will not succeed at matching T. However None <: Optional[T] so this
// check can still succeed.
if (!actuals[i]->type()->isSubtypeOf(formal)) {
if (!actuals[i]->type()->isSubtypeOf(*formal)) {
return false;
}
}
@ -1773,7 +1773,7 @@ Node* Graph::createList(
auto n = create(prim::ListConstruct, values);
for (const auto& v : values) {
TORCH_CHECK(
v->type()->isSubtypeOf(contained_type),
v->type()->isSubtypeOf(*contained_type),
"Expected a list element that subtypes '",
contained_type->repr_str(),
"' but got an element of type '",
@ -1803,8 +1803,8 @@ Node* Graph::createDict(
AT_ASSERT(keys.size() == values.size());
auto n = create(prim::DictConstruct, 1);
for (const auto i : c10::irange(keys.size())) {
AT_ASSERT(keys[i]->type()->isSubtypeOf(key_type));
AT_ASSERT(values[i]->type()->isSubtypeOf(value_type));
AT_ASSERT(keys[i]->type()->isSubtypeOf(*key_type));
AT_ASSERT(values[i]->type()->isSubtypeOf(*value_type));
n->addInput(keys[i]);
n->addInput(values[i]);


@ -406,7 +406,7 @@ void IRParser::parseOperator(Block* b) {
// Don't currently support checking against type variables
// TODO: support?
if (!schema_return_type->hasFreeVariables() &&
!v.type->isSubtypeOf(schema_return_type)) {
!v.type->isSubtypeOf(*schema_return_type)) {
throw ErrorReport(source_range)
<< "Annotated type " << v.type->repr_str()
<< " does not match schema type "


@ -209,18 +209,18 @@ size_t HashNode::operator()(const Node* k) const {
size_t constant_hash = 0;
if (k->kind() == prim::Constant) {
TypePtr type = k->output()->type();
if (type->isSubtypeOf(NumberType::get()) &&
if (type->isSubtypeOf(*NumberType::get()) &&
k->kindOf(attr::value) == AttributeKind::i) {
constant_hash = std::hash<int64_t>{}(k->i(attr::value));
} else if (
type->isSubtypeOf(NumberType::get()) &&
type->isSubtypeOf(*NumberType::get()) &&
k->kindOf(attr::value) == AttributeKind::f) {
constant_hash = std::hash<double>{}(k->f(attr::value));
} else if (
type->isSubtypeOf(NumberType::get()) &&
type->isSubtypeOf(*NumberType::get()) &&
k->kindOf(attr::value) == AttributeKind::c) {
constant_hash = c10::hash<c10::complex<double>>{}(k->c(attr::value));
} else if (type->isSubtypeOf(BoolType::get())) {
} else if (type->isSubtypeOf(*BoolType::get())) {
constant_hash = std::hash<bool>{}(k->i(attr::value));
}
}


@ -35,7 +35,7 @@ void createObject(Stack& stack, const at::ClassTypePtr& type) {
void isinstance(Stack& stack, at::ArrayRef<at::TypePtr> types) {
at::TypePtr ty = pop(stack).type();
for (const at::TypePtr& candidate : types) {
if (ty->isSubtypeOf(candidate)) {
if (ty->isSubtypeOf(*candidate)) {
push(stack, true);
return;
}


@ -8,7 +8,7 @@ namespace jit {
static void unprofileGraphInputs(const std::shared_ptr<Graph>& graph) {
for (auto i : graph->inputs()) {
if (i->type()->isSubtypeOf(TensorType::get())) {
if (i->type()->isSubtypeOf(*TensorType::get())) {
i->setType(unshapedType(i->type()));
}
}
@ -24,7 +24,7 @@ static void unprofileBlock(Block* start_block) {
for (auto n : block->nodes()) {
for (auto o : n->outputs()) {
if (o->type()->isSubtypeOf(TensorType::get())) {
if (o->type()->isSubtypeOf(*TensorType::get())) {
o->setType(unshapedType(o->type()));
}
}


@ -21,7 +21,7 @@ c10::AliasAnalysisKind aliasAnalysisFromSchema() {
// statically defined (neither a None constant nor an Optional[Tensor] type)
// return yes, no, or no value if we can't tell
c10::optional<bool> isDefined(Value* tensor) {
if (tensor->type()->isSubtypeOf(TensorType::get())) {
if (tensor->type()->isSubtypeOf(*TensorType::get())) {
return true;
}
if (tensor->node()->mustBeNone()) {
@ -36,7 +36,7 @@ bool isDecomposableNorm(Node* normalize_op) {
"aten::layer_norm(Tensor input, int[] normalized_shape, Tensor? weight, Tensor? bias, float eps, bool cudnn_enable) -> Tensor",
};
Value* input = normalize_op->namedInput(attr::input);
if (!input->type()->isSubtypeOf(TensorType::get())) {
if (!input->type()->isSubtypeOf(*TensorType::get())) {
return false;
}
auto device = input->type()->expectRef<TensorType>().device();


@ -8,9 +8,9 @@ namespace torch {
namespace jit {
void SetNumTypeToTensorType(Value* v) {
if (v->type()->isSubtypeOf(NumberType::get())) {
if (v->type()->isSubtypeOf(*NumberType::get())) {
v->setType(TensorType::fromNumberType(v->type()));
} else if (v->type()->isSubtypeOf(BoolType::get())) {
} else if (v->type()->isSubtypeOf(*BoolType::get())) {
v->setType(TensorType::fromBoolType());
}
}
@ -28,10 +28,10 @@ void EraseNumberTypesOnBlock(Block* block) {
case prim::Constant: {
// remove primitive constants, replacing with tensor equivalent
// ONNX does not support non-tensor constants
if (it->output()->type()->isSubtypeOf(NumberType::get()) ||
it->output()->type()->isSubtypeOf(BoolType::get())) {
if (it->output()->type()->isSubtypeOf(*NumberType::get()) ||
it->output()->type()->isSubtypeOf(*BoolType::get())) {
at::Scalar s;
if (it->output()->type()->isSubtypeOf(BoolType::get())) {
if (it->output()->type()->isSubtypeOf(*BoolType::get())) {
s = *constant_as<bool>(it->output());
} else {
s = *constant_as<at::Scalar>(it->output());


@ -543,7 +543,7 @@ class AttributePropagator {
bool moduleEscapes(Module& subModule, std::shared_ptr<Graph>& graph) {
for (auto& output : graph->outputs()) {
if (subModule.type()->isSubtypeOf(output->type())) {
if (subModule.type()->isSubtypeOf(*output->type())) {
return true;
}
}


@ -755,7 +755,7 @@ void ComputeSubgraphInMKLDNN(Node* subgraph_node) {
body_node->replaceInput(1, node->outputs().at(1));
}
if (body_node->kind() == aten::mul &&
body_node->input(1)->type()->isSubtypeOf(NumberType::get())) {
body_node->input(1)->type()->isSubtypeOf(*NumberType::get())) {
body_node->replaceWithNewSymbol(Symbol::prim("MKLDNNScalarMul"));
body_node->destroy();
continue;
@ -1022,7 +1022,7 @@ class MKLDNNSubgraphSlicer {
if (n->kind() == aten::mul) {
return n->input(0)->type()->cast<TensorType>() &&
(n->input(1)->type()->cast<TensorType>() ||
n->input(1)->type()->isSubtypeOf(NumberType::get()));
n->input(1)->type()->isSubtypeOf(*NumberType::get()));
}
if (n->kind() == aten::dropout) {


@ -111,8 +111,8 @@ bool isSimpleMap(Node* node) {
return false;
}
for (Value* input : node->inputs()) {
if (input->type()->isSubtypeOf(TensorType::get()) ||
input->type()->isSubtypeOf(FloatType::get())) {
if (input->type()->isSubtypeOf(*TensorType::get()) ||
input->type()->isSubtypeOf(*FloatType::get())) {
continue;
}
if (input->node()->kind() != prim::Constant) {
@ -166,7 +166,7 @@ struct GraphFuser {
value_list tensorInputs(Node* node) {
return filter(node->inputs(), [](Value* v) {
return v->type()->isSubtypeOf(TensorType::get());
return v->type()->isSubtypeOf(*TensorType::get());
});
}
@ -175,7 +175,7 @@ struct GraphFuser {
}
bool isFusableDevice(Value* v, bool strict_fuser_check) {
if (!v->type()->isSubtypeOf(TensorType::get())) {
if (!v->type()->isSubtypeOf(*TensorType::get())) {
return true;
}
auto device = v->type()->expectRef<TensorType>().device();
@ -332,7 +332,7 @@ struct GraphFuser {
AT_ASSERT(group->inputs().size() == subgraph.inputs().size());
for (auto input : group->inputs()) {
inputs_map[input] = subgraph.inputs()[i++];
if (input->type()->isSubtypeOf(TensorType::get()))
if (input->type()->isSubtypeOf(*TensorType::get()))
tensor_insert_idx = i;
}
// add n's inputs to the fusion group's input list if we don't already have
@ -342,17 +342,17 @@ struct GraphFuser {
WithInsertPoint guard(*subgraph.nodes().begin());
for (auto input : n->inputs()) {
if (inputs_map.count(input) == 0) {
if (input->type()->isSubtypeOf(TensorType::get())) {
if (input->type()->isSubtypeOf(*TensorType::get())) {
auto in_group = subgraph.insertInput(tensor_insert_idx);
in_group->setType(input->type());
inputs_map[input] = in_group;
group->insertInput(tensor_insert_idx, input);
tensor_insert_idx++;
} else if (
(input->type()->isSubtypeOf(FloatType::get()) &&
(input->type()->isSubtypeOf(*FloatType::get()) &&
input->node()->kind() != prim::Constant) ||
(n->kind() == aten::_grad_sum_to_size &&
input->type()->isSubtypeOf(ListType::ofInts()))) {
input->type()->isSubtypeOf(*ListType::ofInts()))) {
auto in_group = subgraph.addInput();
in_group->setType(input->type());
inputs_map[input] = in_group;
@ -618,7 +618,7 @@ struct GraphFuser {
// Replace tensors inputs with broadcasted values
auto new_tensors_it = new_tensors.begin();
for (size_t i = 0; i < node->inputs().size(); ++i) {
if (node->inputs()[i]->type()->isSubtypeOf(TensorType::get())) {
if (node->inputs()[i]->type()->isSubtypeOf(*TensorType::get())) {
AT_ASSERT(new_tensors_it != new_tensors.end());
node->replaceInput(i, *(new_tensors_it++));
}
@ -768,7 +768,7 @@ struct GraphFuser {
// XXX: we only work with pointwise ops in here, so we know it is valid to
// push the concat only through tensor arguments (and all other args can
// be safely ignored).
if (!input->type()->isSubtypeOf(TensorType::get()))
if (!input->type()->isSubtypeOf(*TensorType::get()))
continue;
// if 'input' is already an input to the bchunk, reuse it.
@ -813,7 +813,7 @@ struct GraphFuser {
chunked_op->output()->setType(chunk_sel->type());
auto chunked_inputs_it = chunked_inputs.begin();
for (Value* original_input : original_inputs) {
if (original_input->type()->isSubtypeOf(TensorType::get())) {
if (original_input->type()->isSubtypeOf(*TensorType::get())) {
AT_ASSERT(chunked_inputs_it != chunked_inputs.end());
chunked_op->addInput(
// NOLINTNEXTLINE(clang-analyzer-core.DivideZero)
@ -845,7 +845,7 @@ struct GraphFuser {
if (!size_calc_uses.empty()) {
auto tensor_inputs = filter(
producer_for_chunk_node->inputs(),
[](Value* v) { return v->type()->isSubtypeOf(TensorType::get()); });
[](Value* v) { return v->type()->isSubtypeOf(*TensorType::get()); });
auto tensor_sizes = fmap(tensor_inputs, [&](Value* v) {
Value* output = v->owningGraph()->insert(aten::size, {v});
aliasDb_->createValue(output);
@ -937,7 +937,7 @@ struct GraphFuser {
auto sinputs = subgraph->inputs();
AT_ASSERT(inputs.size() == sinputs.size());
for (const auto i : c10::irange(inputs.size())) {
if (inputs[i]->type()->isSubtypeOf(TensorType::get())) {
if (inputs[i]->type()->isSubtypeOf(*TensorType::get())) {
Value* soutput = graph->insert(aten::size, {inputs[i]});
aliasDb_->createValue(soutput);
shape_of[sinputs[i]] = soutput;
@ -992,7 +992,7 @@ struct GraphFuser {
continue;
}
auto tensor_inputs = filter(n->inputs(), [](Value* v) {
return v->type()->isSubtypeOf(TensorType::get());
return v->type()->isSubtypeOf(*TensorType::get());
});
auto shapes =
fmap(tensor_inputs, [&](Value* v) { return shape_of.at(v); });


@ -234,7 +234,7 @@ struct GuardElimination {
if ((input->node()->kind() == prim::Guard &&
!input->type()->expectRef<TensorType>().isSummarized()) ||
input->node()->kind() == prim::Constant ||
(allow_numbers && input->type()->isSubtypeOf(NumberType::get())) ||
(allow_numbers && input->type()->isSubtypeOf(*NumberType::get())) ||
except.count(i) != 0) {
AT_ASSERT(
input->node()->kind() != prim::Guard ||


@ -304,7 +304,7 @@ void LoopsPeeler::peelLoops() {
bool PeelProfilingLoops(const std::shared_ptr<Graph>& graph) {
auto peel_predicate = [](Node* n) {
for (auto i : n->inputs()) {
if (i->type()->isSubtypeOf(TensorType::get())) {
if (i->type()->isSubtypeOf(*TensorType::get())) {
return true;
}
}


@ -63,7 +63,7 @@ void checkONNXCompatibility(const c10::FunctionSchema& schema) {
if (type->kind() == TypeKind::ListType) {
const auto& elem_type =
reinterpret_cast<ListType*>(type.get())->getElementType();
if (elem_type->isSubtypeOf(TensorType::get())) {
if (elem_type->isSubtypeOf(*TensorType::get())) {
AT_ASSERTM(
!has_tensor_list,
"ONNX export supports at most one TensorList as input.");
@ -97,7 +97,7 @@ void preprocessCaffe2Ops(Block* block) {
origin_input->mustBeNone()) {
continue;
}
if (type->isSubtypeOf(TensorType::get())) {
if (type->isSubtypeOf(*TensorType::get())) {
it->addInput(origin_input);
} else if (
type->kind() == TypeKind::BoolType ||
@ -119,7 +119,7 @@ void preprocessCaffe2Ops(Block* block) {
AT_ASSERT(
list_node->kind() == prim::ListConstruct ||
list_node->kind() == prim::Constant);
if (elem_type->isSubtypeOf(TensorType::get())) {
if (elem_type->isSubtypeOf(*TensorType::get())) {
AT_ASSERT(list_node->kind() == prim::ListConstruct);
const auto& tensor_list = origin_input->node()->inputs();
for (const auto& t : tensor_list) {


@ -47,7 +47,7 @@ bool IsCondCastRequired(Value* cond_val) {
return *scalar_type != c10::kBool;
}
}
return !type->isSubtypeOf(BoolType::get());
return !type->isSubtypeOf(*BoolType::get());
}
bool IsErasableSequence(const Node* loop_node, size_t i) {
@ -337,7 +337,7 @@ void ONNXFixupUninitializedOutput(Node* node, int opset_version) {
// Check whether the input to the ONNX If node is Bool, and insert a
// cast to Bool if needed.
if (!if_node->input()->type()->isSubtypeOf(BoolType::get())) {
if (!if_node->input()->type()->isSubtypeOf(*BoolType::get())) {
Node* cast_node = CreateCastToBoolNode(if_node->input(), graph);
cast_node->insertBefore(if_node);
if_node->replaceInputWith(if_node->input(), cast_node->output());


@ -389,7 +389,7 @@ void unpackQuantizedWeightsHelper(
auto input_val = match_vmap.at(vmap.at("r"))->node()->inputs()[0];
TORCH_INTERNAL_ASSERT(
input_val->type()->isSubtypeOf(TensorType::get()),
input_val->type()->isSubtypeOf(*TensorType::get()),
"Unsupported input type. Expected TensorType, got ",
input_val->type()->str());


@ -74,7 +74,7 @@ struct PeepholeOptimizeImpl {
for (Use u : uses) {
if (u.user->matches(
"aten::_grad_sum_to_size(Tensor(a) self, int[]? size) -> Tensor(a)") &&
u.user->input(1)->type()->isSubtypeOf(ListType::ofInts())) {
u.user->input(1)->type()->isSubtypeOf(*ListType::ofInts())) {
GRAPH_UPDATE(
getHeader(node),
" (x._grad_sum_to_size(y)._grad_sum_to_size(z) == x._grad_sum_to_size(z)) is replaced with ",


@ -185,7 +185,7 @@ struct PeepholeOptimizeNonTensorImpl {
// losing anything by calling unshapedType here
auto input_type = unshapedType(node->input()->type());
auto output_type = unshapedType(node->output()->type());
if (input_type->isSubtypeOf(output_type)) {
if (input_type->isSubtypeOf(*output_type)) {
GRAPH_UPDATE(
"Removing ",
getHeader(node),


@ -345,14 +345,14 @@ std::vector<Value*> getPassThroughInputs(Value* v) {
return inputs;
} else if (n->kind() == prim::ListUnpack || n->kind() == prim::TupleUnpack) {
// only propagate dequantize for Tensor
if (v->type()->isSubtypeOf(TensorType::get())) {
if (v->type()->isSubtypeOf(*TensorType::get())) {
return {n->input(0)};
} else {
return {};
}
} else if (
n->kind() == prim::ListConstruct &&
v->type()->isSubtypeOf(ListType::ofTensors())) {
v->type()->isSubtypeOf(*ListType::ofTensors())) {
std::vector<Value*> inputs;
for (auto* v : n->inputs()) {
inputs.push_back(v);
@ -361,7 +361,7 @@ std::vector<Value*> getPassThroughInputs(Value* v) {
} else if (n->kind() == prim::TupleConstruct) {
std::vector<Value*> inputs;
for (auto* input : n->inputs()) {
if (input->type()->isSubtypeOf(TensorType::get())) {
if (input->type()->isSubtypeOf(*TensorType::get())) {
inputs.push_back(input);
}
}
@ -560,8 +560,8 @@ bool alwaysRaisesException(Block* block) {
// Check if a value in the graph is a Scalar value
bool isScalar(Value* v) {
auto iv = toIValue(v);
return v->type()->isSubtypeOf(NumberType::get()) ||
(v->type()->isSubtypeOf(TensorType::get()) && iv && iv->isTensor() &&
return v->type()->isSubtypeOf(*NumberType::get()) ||
(v->type()->isSubtypeOf(*TensorType::get()) && iv && iv->isTensor() &&
iv->toTensor().dim() == 0);
}


@ -1185,8 +1185,8 @@ bool InsertObserversHelper::valueNeedsToBeQuantized(
Value* v,
const QConfig& qconfig) {
if (isBiasOfConvOrLinear(v) ||
!(v->type()->isSubtypeOf(TensorType::get()) ||
v->type()->isSubtypeOf(ListType::ofTensors())) ||
!(v->type()->isSubtypeOf(*TensorType::get()) ||
v->type()->isSubtypeOf(*ListType::ofTensors())) ||
isEmbeddingBagNonInput(v)) {
return false;
}


@ -1024,7 +1024,7 @@ std::tuple<c10::QScheme, QParamVector> InsertQuantDeQuantHelper::
// TODO: refactor findObserverName to take Node* as input
Value* v = n->output();
TORCH_INTERNAL_ASSERT(
v->type()->isSubtypeOf(TensorType::get()),
v->type()->isSubtypeOf(*TensorType::get()),
"Expected output of observer node to be Tensor");
auto observer_name = findObserverName(v);
TORCH_INTERNAL_ASSERT(
@ -1352,7 +1352,7 @@ void InsertQuantDeQuantHelper::runWeightObserver(
blocks_to_visit.pop();
for (auto n : b->nodes()) {
for (Value* v : n->outputs()) {
if (!v->type()->isSubtypeOf(TensorType::get())) {
if (!v->type()->isSubtypeOf(*TensorType::get())) {
continue;
}
auto observer_name = findObserverName(v);
@ -1416,7 +1416,7 @@ void InsertQuantDeQuantHelper::run(
std::vector<Value*> input_values;
for (const auto idx : c10::irange(1, method.num_inputs())) {
auto& v = graph->inputs()[idx];
if (v->type()->isSubtypeOf(TensorType::get())) {
if (v->type()->isSubtypeOf(*TensorType::get())) {
input_values.push_back(v);
}
}
@ -1429,7 +1429,7 @@ void InsertQuantDeQuantHelper::run(
for (auto it = b->nodes().begin(), end = b->nodes().end(); it != end;) {
Node* n = *it++;
for (Value* v : n->outputs()) {
if (!v->type()->isSubtypeOf(TensorType::get())) {
if (!v->type()->isSubtypeOf(*TensorType::get())) {
continue;
}
collectObserverNodesAndValueToQuantize(module, v);


@ -63,18 +63,18 @@ bool isValidArgumentForRunning(Value* v) {
}
return !at::isIntegralType(*tt->scalarType(), /*includeBool=*/false);
}
return v->type()->isSubtypeOf(FloatType::get());
return v->type()->isSubtypeOf(*FloatType::get());
}
bool isValidReturnForRunning(Value* v) {
return v->type()->isSubtypeOf(TensorType::get()) ||
v->type()->isSubtypeOf(NumberType::get());
return v->type()->isSubtypeOf(*TensorType::get()) ||
v->type()->isSubtypeOf(*NumberType::get());
}
bool containsTensorType(const TypePtr& t) {
auto n_contained = t->containedTypes().size();
if (n_contained == 1) {
return t->containedTypes().at(0)->isSubtypeOf(TensorType::get());
return t->containedTypes().at(0)->isSubtypeOf(*TensorType::get());
} else if (n_contained > 1) {
return std::any_of(
t->containedTypes().begin(),
@ -129,7 +129,7 @@ class ShapePropagator {
// ops which take the result and write to input "out"
if (auto out_arg_index = n->schema().argumentIndexWithName("out")) {
auto arg = n->schema().arguments().at(*out_arg_index);
return arg.kwarg_only() && arg.type()->isSubtypeOf(TensorType::get());
return arg.kwarg_only() && arg.type()->isSubtypeOf(*TensorType::get());
}
return false;
}
@ -183,7 +183,7 @@ class ShapePropagator {
.zero_();
}
// fallthrough
} else if (type_->isSubtypeOf(FloatType::get())) {
} else if (type_->isSubtypeOf(*FloatType::get())) {
return 0.f;
}
// we should not get here because isValidArgumentForRunning should have
@ -214,9 +214,9 @@ class ShapePropagator {
return c10::nullopt;
}
for (const auto i : c10::irange(args.size())) {
if (args[i].type()->isSubtypeOf(ListType::ofTensors())) {
if (args[i].type()->isSubtypeOf(*ListType::ofTensors())) {
return c10::nullopt;
} else if (args[i].type()->isSubtypeOf(TensorType::get())) {
} else if (args[i].type()->isSubtypeOf(*TensorType::get())) {
if (auto type = node->input(i)->type()->cast<TensorType>()) {
if (complete && !type->isComplete()) {
return c10::nullopt;
@ -617,11 +617,11 @@ class ShapePropagator {
return; // correct num type is already set
case prim::NumToTensor: {
TypePtr typ = node->input()->type();
if (typ->isSubtypeOf(IntType::get()) ||
typ->isSubtypeOf(BoolType::get())) {
if (typ->isSubtypeOf(*IntType::get()) ||
typ->isSubtypeOf(*BoolType::get())) {
node->output()->setType(TensorType::create(
at::kLong, at::kCPU, 0, /*requires_grad=*/c10::nullopt));
} else if (node->input()->type()->isSubtypeOf(FloatType::get())) {
} else if (node->input()->type()->isSubtypeOf(*FloatType::get())) {
node->output()->setType(TensorType::create(
at::kDouble, at::kCPU, 0, /*requires_grad=*/c10::nullopt));
}
@ -632,7 +632,7 @@ class ShapePropagator {
// as_tensor has an overloaded schema and can take either a tensor or
// a list as the first input; if the input is a tensor, we delegate
// the shape propagation to PropagateTensorShapeOnNode
if (node->inputs().at(0)->type()->isSubtypeOf(TensorType::get())) {
if (node->inputs().at(0)->type()->isSubtypeOf(*TensorType::get())) {
break;
}
return propagateTorchTensorShape(node);
@ -659,7 +659,7 @@ class ShapePropagator {
return;
}
case prim::Constant: {
if (node->output()->type()->isSubtypeOf(TensorType::get())) {
if (node->output()->type()->isSubtypeOf(*TensorType::get())) {
node->output()->inferTypeFrom(node->t(attr::value));
}
return;
@ -668,7 +668,7 @@ class ShapePropagator {
// If we have specialized the optional type to the element type,
// we want to pass it down. We write this as input.isSubtypeOf(output)
// to be sure that we don't screw up nested optionals.
if (node->input()->type()->isSubtypeOf(node->output()->type())) {
if (node->input()->type()->isSubtypeOf(*node->output()->type())) {
node->output()->setType(node->input()->type());
}
return;
@ -709,7 +709,7 @@ class ShapePropagator {
// If we have specialized the optional type to the element type,
// we want to pass it down. We write this as input.isSubtypeOf(output)
// to be sure that we don't screw up nested optionals.
if (node->input()->type()->isSubtypeOf(node->output()->type())) {
if (node->input()->type()->isSubtypeOf(*node->output()->type())) {
node->output()->setType(node->input()->type());
}
return;
@ -1610,7 +1610,7 @@ class ShapePropagator {
auto outputs = node->outputs();
AT_ASSERT(types.size() == outputs.size());
for (const auto i : c10::irange(types.size())) {
AT_ASSERT(outputs[i]->type()->isSubtypeOf(TensorType::get()));
AT_ASSERT(outputs[i]->type()->isSubtypeOf(*TensorType::get()));
outputs[i]->setType(types[i]);
}
return true;
@ -2194,7 +2194,7 @@ using TypeCache = std::unordered_map<TypePtr, TypePtr>;
TypePtr getOrCreateUnshapedType(TypePtr type, TypeCache& unshaped_type_cache);
TypePtr unshapedTypeImpl(TypePtr type, TypeCache& unshaped_type_cache) {
if (type->isSubtypeOf(TensorType::get())) {
if (type->isSubtypeOf(*TensorType::get())) {
return TensorType::get();
}
std::vector<TypePtr> unshaped_contained_types;
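For context, the TypeCache above memoizes unshaping across a nested type. One plausible body for the getOrCreateUnshapedType helper it declares, written as an assumption for illustration rather than what this commit necessarily ships:

    // Assumed sketch: look up the memoized result, computing and caching it
    // on a miss. Relies on unshapedTypeImpl and TypeCache from this file.
    TypePtr getOrCreateUnshapedType(TypePtr type, TypeCache& unshaped_type_cache) {
      auto it = unshaped_type_cache.find(type);
      if (it != unshaped_type_cache.end()) {
        return it->second;
      }
      auto unshaped = unshapedTypeImpl(type, unshaped_type_cache);
      unshaped_type_cache.emplace(std::move(type), unshaped);
      return unshaped;
    }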


@ -142,8 +142,8 @@ struct AutogradZeroSpecializer {
state_[input] = State::Unknown;
}
} else if (
tp->isSubtypeOf(TensorType::get()) ||
tp->isSubtypeOf(ListType::ofTensors())) {
tp->isSubtypeOf(*TensorType::get()) ||
tp->isSubtypeOf(*ListType::ofTensors())) {
state_[input] = State::Nonzero;
} else {
state_[input] = State::Unknown;


@ -189,7 +189,7 @@ struct SymbolicShapeAnalyzer {
graph_->inputs().at(i)->type()->cast<OptionalType>()) {
// None will get handled with constant substitution later
if (!type->cast<OptionalType>() &&
!NoneType::get()->isSubtypeOf(type)) {
!NoneType::get()->isSubtypeOf(*type)) {
graph_->inputs().at(i)->setType(opt_type->getElementType());
}
} else if (graph_->inputs().at(i)->type()->cast<NumberType>()) {


@ -485,7 +485,7 @@ class TensorExprFuser {
auto sinputs = subgraph->inputs();
AT_ASSERT(inputs.size() == sinputs.size());
for (const auto i : c10::irange(inputs.size())) {
if (inputs[i]->type()->isSubtypeOf(TensorType::get())) {
if (inputs[i]->type()->isSubtypeOf(*TensorType::get())) {
Value* soutput = graph->insert(aten::size, {inputs[i]});
aliasDb_->createValue(soutput);
GRAPH_DEBUG(
@ -542,7 +542,7 @@ class TensorExprFuser {
}
auto tensor_inputs = filter(n->inputs(), [](Value* v) {
return v->type()->isSubtypeOf(TensorType::get());
return v->type()->isSubtypeOf(*TensorType::get());
});
GRAPH_DEBUG("Building sizes for ", getHeader(n));
bool all_inputs_have_sizes = true;


@ -209,7 +209,7 @@ c10::optional<IValue> toIValueProp(const Value* v) {
} else if (containedType == FloatType::get()) {
return IValue(
fmap(genericList, [](const IValue& v) { return v.toDouble(); }));
} else if (containedType->isSubtypeOf(TensorType::get())) {
} else if (containedType->isSubtypeOf(*TensorType::get())) {
return IValue(
fmap(genericList, [](const IValue& v) { return v.toTensor(); }));
} else {


@ -270,7 +270,7 @@ IValue toIValue(py::handle obj, const TypePtr& type, c10::optional<int32_t> N) {
}
// check if the classType conform with the interface or not
std::stringstream why_not;
if (!classType->isSubtypeOfExt(interfaceType, &why_not)) {
if (!classType->isSubtypeOfExt(*interfaceType, &why_not)) {
throw py::cast_error(c10::str(
"Object of type ",
classType->repr_str(),

View File

@ -510,7 +510,7 @@ inline InferredType tryToInferContainerType(py::handle input) {
}
inline bool isTraceableType(const TypePtr& type) {
if (type->isSubtypeOf(TensorType::get())) {
if (type->isSubtypeOf(*TensorType::get())) {
return true;
}


@ -1010,7 +1010,7 @@ std::shared_ptr<SugaredValue> PythonSliceClass::call(
Graph& graph = *(caller.graph());
auto ValOr = [&](Value* given, int64_t default_val) {
if (!given || given->type()->isSubtypeOf(NoneType::get())) {
if (!given || given->type()->isSubtypeOf(*NoneType::get())) {
return graph.insertConstant(default_val, loc);
}
return given;

View File

@ -29,10 +29,10 @@ void ArgumentSpecCreator::scan(
if (depth >= ARG_SPEC_DEPTH_LIMIT) {
instructions_.emplace_back(SKIP);
}
if (typ->isSubtypeOf(TensorType::get())) {
if (typ->isSubtypeOf(*TensorType::get())) {
num_tensors_++;
instructions_.emplace_back(SPECIALIZE_TENSOR);
} else if (typ->isSubtypeOf(OptionalType::ofTensor())) {
} else if (typ->isSubtypeOf(*OptionalType::ofTensor())) {
num_tensors_++;
num_optionals_++;
instructions_.emplace_back(SPECIALIZE_OPTIONAL_TENSOR);


@ -218,7 +218,7 @@ class GradientHelper {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
Value* input_list;
if (grad_values.size() == 1 &&
grad_values[0]->type()->isSubtypeOf(ListType::ofTensors())) {
grad_values[0]->type()->isSubtypeOf(*ListType::ofTensors())) {
input_list = grad_values[0];
} else {
input_list =


@ -108,7 +108,7 @@ ProfileIValueOp* ProfilingRecord::createProfileIValueNode(Value* in_val) {
static void unprofileGraphInputs(const std::shared_ptr<Graph>& graph) {
for (auto i : graph->inputs()) {
if (i->type()->isSubtypeOf(TensorType::get())) {
if (i->type()->isSubtypeOf(*TensorType::get())) {
i->setType(unshapedType(i->type()));
}
}
@ -124,7 +124,7 @@ static void unprofileBlock(Block* start_block) {
for (auto n : block->nodes()) {
for (auto o : n->outputs()) {
if (o->type()->isSubtypeOf(TensorType::get())) {
if (o->type()->isSubtypeOf(*TensorType::get())) {
o->setType(unshapedType(o->type()));
}
}
@ -298,7 +298,7 @@ void ProfilingRecord::instrumentBlock(Block* block) {
for (size_t offset = 0; offset < block->return_node()->inputs().size();
offset++) {
auto i = block->return_node()->input(offset);
if (i->type()->isSubtypeOf(TensorType::get())) {
if (i->type()->isSubtypeOf(*TensorType::get())) {
insertShapeProfile(block->return_node(), offset);
}
}


@ -32,13 +32,13 @@ c10::AliasAnalysisKind aliasAnalysisConservative() {
}
void checkListInputType(const c10::TypePtr& elem_type, bool empty_list) {
if (!elem_type->isSubtypeOf(NumberType::get()) &&
if (!elem_type->isSubtypeOf(*NumberType::get()) &&
elem_type != BoolType::get()) {
std::stringstream error;
error << "Input must be of ints, floats, or bools, "
<< "got " << elem_type->repr_str();
// special case empty list torch.tensor([])
if (elem_type->isSubtypeOf(TensorType::get())) {
if (elem_type->isSubtypeOf(*TensorType::get())) {
if (empty_list) {
error << "\nEmpty lists default to List[Tensor]. Add a variable "
"annotation to the assignment to create an empty list "


@ -432,7 +432,7 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR(
auto* node = p_node->node();
std::vector<TypePtr> candidates = node->tys(attr::types);
for (const auto& candidate_type : candidates) {
if (input_type->isSubtypeOf(candidate_type)) {
if (input_type->isSubtypeOf(*candidate_type)) {
p_node->Output(0) = true;
return;
}


@ -341,7 +341,7 @@ void createObject(
void isinstance(Stack& stack, at::ArrayRef<at::TypePtr> types) {
at::TypePtr ty = pop(stack).type();
for (const at::TypePtr& candidate : types) {
if (ty->isSubtypeOf(candidate)) {
if (ty->isSubtypeOf(*candidate)) {
push(stack, true);
return;
}


@ -724,7 +724,7 @@ bool checkHasValidSetGetState(const std::shared_ptr<c10::ClassType>& cls) {
set_schema.returns().size(),
" return values");
TORCH_CHECK(
set_schema.returns().at(0).type()->isSubtypeOf(NoneType::get()),
set_schema.returns().at(0).type()->isSubtypeOf(*NoneType::get()),
"'__setstate__' must return None, but found value of type",
set_schema.returns().at(0).type()->annotation_str());
@ -734,7 +734,7 @@ bool checkHasValidSetGetState(const std::shared_ptr<c10::ClassType>& cls) {
auto set_type = set_schema.arguments().at(1).type();
TORCH_CHECK(
get_type->isSubtypeOf(set_type),
get_type->isSubtypeOf(*set_type),
"'__getstate__'s return type (",
get_type->annotation_str(),
") does not match '__setstate__'s argument type (",


@ -1117,7 +1117,7 @@ struct PythonPrintImpl {
// we cannot recover the type of unwrap_optional(None),
// using normal schema matching, so we route around this by rewriting
// the call to unwrap_optional(annotated(Optional[T], None))
if (node->input()->type()->isSubtypeOf(NoneType::get()) ||
if (node->input()->type()->isSubtypeOf(*NoneType::get()) ||
node->input()->mustBeNone()) {
auto input_type = OptionalType::create(node->output()->type());
stmt << "annotate(" << input_type->annotation_str(type_printer_)


@ -377,11 +377,11 @@ std::vector<ExprHandle> TensorExprKernel::sizesForValue(
}
}
if (v->type()->isSubtypeOf(FloatType::get()) ||
v->type()->isSubtypeOf(IntType::get())) {
if (v->type()->isSubtypeOf(*FloatType::get()) ||
v->type()->isSubtypeOf(*IntType::get())) {
return {int64_t{1}};
}
if (v->type()->isSubtypeOf(NoneType::get())) {
if (v->type()->isSubtypeOf(*NoneType::get())) {
return {};
}


@ -344,7 +344,7 @@ class class_ : public ::torch::detail::class_base {
auto setstate_schema = classTypePtr->getMethod("__setstate__").getSchema();
auto arg_type = setstate_schema.arguments().at(1).type();
TORCH_CHECK(
ser_type->isSubtypeOf(arg_type),
ser_type->isSubtypeOf(*arg_type),
"__getstate__'s return type should be a subtype of "
"input argument of __setstate__. Got ",
ser_type->repr_str(),