diff --git a/aten/src/ATen/ParallelNative.cpp b/aten/src/ATen/ParallelNative.cpp index 81cc66abcf1b..899a98eace7c 100644 --- a/aten/src/ATen/ParallelNative.cpp +++ b/aten/src/ATen/ParallelNative.cpp @@ -308,7 +308,7 @@ c10::intrusive_ptr intraop_launch_future( #else // TODO: caffe2::PThreadPool only provides a data-parallel API. // Task parallelism is not currently supported. - auto future = c10::make_intrusive(NoneType::get()); + auto future = c10::make_intrusive(c10::dynT()); func(); future->markCompleted(); return future; diff --git a/aten/src/ATen/core/custom_class.cpp b/aten/src/ATen/core/custom_class.cpp index d78599c7b46f..f61766c0cef3 100644 --- a/aten/src/ATen/core/custom_class.cpp +++ b/aten/src/ATen/core/custom_class.cpp @@ -3,6 +3,7 @@ #include #include #include +#include #include #include @@ -102,7 +103,7 @@ class_base::class_base( { detail::checkValidIdent(namespaceName, "Namespace name"); detail::checkValidIdent(className, "Class name"); - classTypePtr->addAttribute("capsule", at::CapsuleType::get()); + classTypePtr->addAttribute("capsule", c10::TypeFactory::get()); c10::getCustomClassTypeMap().insert( {std::type_index(intrusivePtrClassTypeid), classTypePtr}); c10::getCustomClassTypeMap().insert( diff --git a/aten/src/ATen/core/dynamic_type.cpp b/aten/src/ATen/core/dynamic_type.cpp index 88bc63a37bb6..95050da593eb 100644 --- a/aten/src/ATen/core/dynamic_type.cpp +++ b/aten/src/ATen/core/dynamic_type.cpp @@ -2,6 +2,7 @@ #include +#include #include #include #include @@ -198,6 +199,11 @@ TypePtr DynamicType::containedType(size_t i) const { return arguments_.elems.at(i).ty; } +size_t DynamicType::containedTypeSize() const { + TORCH_INTERNAL_ASSERT(tag_ != Tag::Class); + return arguments_.elems.size(); +} + TypeKind DynamicType::dynamicKind() const { switch (tag_) { #define CASE_TYPE(T, _, __) \ @@ -271,6 +277,16 @@ TypePtr DynamicType::fallback() const { return VarType::create(*name_); case Tag::AnyClass: return AnyClassType::get(); + case Tag::QScheme: + return QSchemeType::get(); + case Tag::Quantizer: + return QuantizerType::get(); + case Tag::AnyEnum: + return AnyEnumType::get(); + case Tag::RRef: + return RRefType::create(arguments_.elems[0].ty->fallback()); + case Tag::Future: + return FutureType::create(arguments_.elems[0].ty->fallback()); case Tag::Any: return AnyType::get(); } diff --git a/aten/src/ATen/core/dynamic_type.h b/aten/src/ATen/core/dynamic_type.h index 2b53aad3559b..d5551c9a5e51 100644 --- a/aten/src/ATen/core/dynamic_type.h +++ b/aten/src/ATen/core/dynamic_type.h @@ -3,8 +3,6 @@ #include #include -#include -#include #include #include @@ -53,8 +51,17 @@ constexpr DynamicTypeBits kDynamicClassTypeBit = DYNAMIC_TYPE_BIT(10); _(Storage, DYNAMIC_TYPE_BIT(16), 1) \ _(Var, DYNAMIC_TYPE_BIT(17), 0) \ _(AnyClass, (kDynamicClassTypeBit | kDynamicAnyTypeBit), 1) \ + _(QScheme, DYNAMIC_TYPE_BIT(18), 1) \ + _(Quantizer, DYNAMIC_TYPE_BIT(19), 1) \ + _(AnyEnum, DYNAMIC_TYPE_BIT(20), 1) \ + _(RRef, DYNAMIC_TYPE_BIT(21), 0) \ + _(Future, DYNAMIC_TYPE_BIT(22), 0) \ _(Any, 0xffffffff, 1) +#define FORWARD_DECL_TYPE(NAME, _, __) struct NAME ## Type; + FORALL_DYNAMIC_TYPES(FORWARD_DECL_TYPE) +#undef FORWARD_DECL_TYPE + class DynamicType; using DynamicTypePtr = std::shared_ptr; @@ -142,6 +149,7 @@ class DynamicType : public SharedType { explicit DynamicType(Tag, c10::string_view, Arguments); TypePtr containedType(size_t) const override; + size_t containedTypeSize() const override; Tag tag() const { return tag_; } @@ -154,6 +162,9 @@ class DynamicType : public SharedType { TypeKind dynamicKind() const; // Should be used only on the server side to restore static type information. +#ifndef C10_MOBILE + TORCH_API +#endif TypePtr fallback() const; private: @@ -188,7 +199,7 @@ class DynamicType : public SharedType { template struct DynamicTypeTrait { - static auto tagValue() { + C10_NOINLINE static auto tagValue() { TORCH_CHECK(false); return DynamicType::Tag::Any; } @@ -201,7 +212,7 @@ C10_NOINLINE DynamicTypePtr makeBaseType(DynamicType::Tag tag); #define DYNAMIC_TYPE_TAG_VALUE(NAME, _, IS_BASE_TYPE) \ template <> \ struct TORCH_API DynamicTypeTrait { \ - static auto tagValue() { \ + C10_ERASE static auto tagValue() { \ return DynamicType::Tag::NAME; \ } \ static constexpr bool isBaseType = IS_BASE_TYPE; \ @@ -214,19 +225,4 @@ C10_NOINLINE DynamicTypePtr makeBaseType(DynamicType::Tag tag); FORALL_DYNAMIC_TYPES(DYNAMIC_TYPE_TAG_VALUE) #undef DYNAMIC_TYPE_TAG_VALUE -template <> -struct IValue::TagType { - static DynamicType::Ptr get(const c10::IValue& v); -}; - -namespace ivalue { - -template <> -struct TORCH_API TupleTypeFactory { - static DynamicTypePtr create(std::vector elemTypes); - static DynamicTypePtr fallback(const Type&); -}; - -} // namespace ivalue - } // namespace c10 diff --git a/aten/src/ATen/core/function_schema.h b/aten/src/ATen/core/function_schema.h index 3494fcbd29d2..737a98ab088f 100644 --- a/aten/src/ATen/core/function_schema.h +++ b/aten/src/ATen/core/function_schema.h @@ -390,7 +390,7 @@ struct FunctionSchema { // Check that inputs have the correct types and appends any missing default // values. - template + template void checkAndNormalizeInputs( std::vector& inputs, const std::unordered_map& kwargs = diff --git a/aten/src/ATen/core/function_schema_inl.h b/aten/src/ATen/core/function_schema_inl.h index a814399a0801..5d58ee88a418 100644 --- a/aten/src/ATen/core/function_schema_inl.h +++ b/aten/src/ATen/core/function_schema_inl.h @@ -293,7 +293,7 @@ inline void FunctionSchema::checkArg( TORCH_CHECK( false, formatTypeMismatchMsg( - argument, value.type()->repr_str(), pos)); + argument, value.type()->repr_str(), pos)); } } diff --git a/aten/src/ATen/core/ivalue.cpp b/aten/src/ATen/core/ivalue.cpp index 966c9d7d48b3..3140fe8f3972 100644 --- a/aten/src/ATen/core/ivalue.cpp +++ b/aten/src/ATen/core/ivalue.cpp @@ -6,6 +6,7 @@ #include #include #include +#include #include #include #include @@ -403,6 +404,39 @@ bool IValue::is(const IValue& rhs) const { return lhs == rhs; } +template +inline bool IValue::isListOf() const { + // note: avoids calling type() to avoid extra referencing counting for the returned type. + if (!isList()) { + return false; + } + const auto& ty = static_cast(payload.u.as_intrusive_ptr)->elementType; + if (ty->kind() == T::Kind) { + return true; + } + return *ty == *TypeFactory::get(); +} + +bool IValue::isDoubleList() const { + return isListOf(); +} + +bool IValue::isComplexDoubleList() const { + return isListOf(); +} + +bool IValue::isTensorList() const { + return isListOf(); +} + +bool IValue::isIntList() const { + return isListOf(); +} + +bool IValue::isBoolList() const { + return isListOf(); +} + namespace { using IValueFormatter = std::function; @@ -430,7 +464,7 @@ std::ostream& printMaybeAnnotatedList( std::ostream& out, const IValue& the_list, IValueFormatter formatter) { - auto list_elem_type = the_list.type()->expectRef().getElementType(); + auto list_elem_type = the_list.type()->containedType(0); if (the_list.toListRef().size() == 0 || !elementTypeCanBeInferredFromMembers(list_elem_type)) { out << "annotate(" << the_list.type()->annotation_str() << ", "; @@ -925,7 +959,7 @@ c10::intrusive_ptr ivalue::Object::deepcopy(IValue::HashAliasedI auto cu = type_.cu_; auto object = ivalue::Object::create(WeakOrStrongTypePtr(type_.cu_, type_.type_), type()->numAttributes()); for (const auto i : c10::irange(slots_.size())) { - if (slots_[i].type() == c10::CapsuleType::get()) { + if (*slots_[i].type() == *c10::TypeFactory::get()) { // If we've gotten here, it means that we have *not* copied this // class via __getstate__ and __setstate__. That fact and the // fact that we have a Capsule attribute mean that this is a diff --git a/aten/src/ATen/core/ivalue.h b/aten/src/ATen/core/ivalue.h index 1b155866cb8f..cb0d433a693e 100644 --- a/aten/src/ATen/core/ivalue.h +++ b/aten/src/ATen/core/ivalue.h @@ -6,6 +6,7 @@ #include #include #include +#include #include #include #include @@ -895,8 +896,8 @@ public: } } - template - typename T::Ptr type() const; + template + TypePtr type() const; // Detect aliased tensors. struct HashAliasedIValue { diff --git a/aten/src/ATen/core/ivalue_inl.h b/aten/src/ATen/core/ivalue_inl.h index 87a3229139c1..6c524da40ed2 100644 --- a/aten/src/ATen/core/ivalue_inl.h +++ b/aten/src/ATen/core/ivalue_inl.h @@ -586,6 +586,12 @@ struct TORCH_API TupleTypeFactory { static TupleTypePtr fallback(const Type& type); }; +template <> +struct TORCH_API TupleTypeFactory { + static DynamicTypePtr create(std::vector elemTypes); + static DynamicTypePtr fallback(const Type&); +}; + struct TORCH_API Tuple : c10::intrusive_ptr_target { private: TupleElements elements_; @@ -1915,39 +1921,6 @@ inline ivalue::Tuple& IValue::toTupleRef() const { payload.u.as_intrusive_ptr); } -template -inline bool IValue::isListOf() const { - // note: avoids calling type() to avoid extra referencing counting for the returned type. - if (!isList()) { - return false; - } - const auto& ty = static_cast(payload.u.as_intrusive_ptr)->elementType; - if (ty->kind() == T::Kind) { - return true; - } - return *ty == *T::get(); -} - -inline bool IValue::isDoubleList() const { - return isListOf(); -} - -inline bool IValue::isComplexDoubleList() const { - return isListOf(); -} - -inline bool IValue::isTensorList() const { - return isListOf(); -} - -inline bool IValue::isIntList() const { - return isListOf(); -} - -inline bool IValue::isBoolList() const { - return isListOf(); -} - inline IValue::IValue(c10::intrusive_ptr v) : tag(Tag::Tuple), is_intrusive_ptr(true) { payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release()); @@ -2285,8 +2258,13 @@ struct IValue::TagType { static TORCH_API c10::TypePtr get(const IValue&); }; +template <> +struct IValue::TagType { + static TORCH_API c10::TypePtr get(const IValue&); +}; + template -typename T::Ptr IValue::type() const { +TypePtr IValue::type() const { return IValue::TagType::get(*this); } diff --git a/aten/src/ATen/core/jit_type.h b/aten/src/ATen/core/jit_type.h index 358230cb2aff..6478e18b9a03 100644 --- a/aten/src/ATen/core/jit_type.h +++ b/aten/src/ATen/core/jit_type.h @@ -5,6 +5,7 @@ #include #include #include +#include #include #include #include @@ -1730,7 +1731,8 @@ struct getTypePtr_ final { template <> struct getTypePtr_ final { static decltype(auto) call() { - return OptionalType::create(GeneratorType::get()); + return TypeFactory::create( + TypeFactory::get()); } }; template <> @@ -1798,7 +1800,8 @@ struct getTypePtr_> final { template struct getTypePtr_> final { static const auto& call() { - static auto type = OptionalType::create(getTypePtr_::call()); + static auto type = TypeFactory::create( + getTypePtr_::call()); return type; } }; diff --git a/aten/src/ATen/core/jit_type_base.h b/aten/src/ATen/core/jit_type_base.h index 6c69bf367dc1..3eb741d051c3 100644 --- a/aten/src/ATen/core/jit_type_base.h +++ b/aten/src/ATen/core/jit_type_base.h @@ -558,6 +558,9 @@ struct TORCH_API Type { virtual TypePtr containedType(size_t i) const { return containedTypes().at(i); } + virtual size_t containedTypeSize() const { + return containedTypes().size(); + } // create a new version of this type, replacing its contained types with // contained_types TypePtr withContained(std::vector contained_types); diff --git a/aten/src/ATen/core/type_factory.cpp b/aten/src/ATen/core/type_factory.cpp index 11ee72c14bd0..78c5a31b86ef 100644 --- a/aten/src/ATen/core/type_factory.cpp +++ b/aten/src/ATen/core/type_factory.cpp @@ -1,5 +1,7 @@ #include +#include + namespace c10 { // Dtype constraints are not constrained in compilation. Therefore, we map @@ -56,4 +58,11 @@ const std::unordered_map& DefaultTypeFactory:: return map; } +c10::TypePtr DefaultTypeFactory::createNamedTuple( + const std::string& name, + const std::vector& fields, + const std::vector& types) { + return c10::TupleType::createNamed(name, fields, types); +} + } // namespace c10 diff --git a/aten/src/ATen/core/type_factory.h b/aten/src/ATen/core/type_factory.h index b979d6bb17c2..5718f79efff2 100644 --- a/aten/src/ATen/core/type_factory.h +++ b/aten/src/ATen/core/type_factory.h @@ -1,12 +1,19 @@ #pragma once -#include -#include #include +#include + +#include +#include +#include namespace c10 { -struct TORCH_API DynamicTypeFactory { +template +struct TORCH_API TypeFactoryBase {}; + +template <> +struct TORCH_API TypeFactoryBase { template static c10::DynamicTypePtr create(TypePtr ty, Args&&... args) { return std::make_shared( @@ -29,26 +36,40 @@ struct TORCH_API DynamicTypeFactory { name, c10::DynamicType::Arguments(fields, types)); } + template + C10_ERASE static c10::DynamicTypePtr createNamed(const std::string& name) { + return std::make_shared( + c10::DynamicTypeTrait::tagValue(), + name, + c10::DynamicType::Arguments{}); + } + template + C10_ERASE static c10::DynamicTypePtr get() { + return DynamicTypeTrait::getBaseType(); + } static const std::unordered_map& basePythonTypes(); }; +using DynamicTypeFactory = TypeFactoryBase; + // Helper functions for constructing DynamicTypes inline. template < typename T, std::enable_if_t::isBaseType, int> = 0> -DynamicTypePtr dynT() { - return DynamicTypeTrait::getBaseType(); +C10_ERASE DynamicTypePtr dynT() { + return DynamicTypeFactory::get(); } template < typename T, typename... Args, std::enable_if_t::isBaseType, int> = 0> -DynamicTypePtr dynT(Args&&... args) { +C10_ERASE DynamicTypePtr dynT(Args&&... args) { return DynamicTypeFactory::create(std::forward(args)...); } -struct TORCH_API DefaultTypeFactory { +template <> +struct TORCH_API TypeFactoryBase { template static c10::TypePtr create(TypePtr ty, Args&&... args) { return T::create(std::move(ty), std::forward(args)...); @@ -60,18 +81,28 @@ struct TORCH_API DefaultTypeFactory { static c10::TypePtr createNamedTuple( const std::string& name, const std::vector& fields, - const std::vector& types) { - return c10::TupleType::createNamed(name, fields, types); + const std::vector& types); + template + C10_ERASE static c10::TypePtr createNamed(const std::string& name) { + return T::create(name); } static const std::unordered_map& basePythonTypes(); + template + C10_ERASE static c10::TypePtr get() { + return T::get(); + } }; -using TypeFactory = +using DefaultTypeFactory = TypeFactoryBase; + +using PlatformType = #ifdef C10_MOBILE - DynamicTypeFactory + c10::DynamicType #else - DefaultTypeFactory + c10::Type #endif ; +using TypeFactory = TypeFactoryBase; + } // namespace c10 diff --git a/c10/macros/Macros.h b/c10/macros/Macros.h index a10806cf6716..28dc1df9430e 100644 --- a/c10/macros/Macros.h +++ b/c10/macros/Macros.h @@ -225,6 +225,16 @@ using namespace c10::hip; #define C10_ALWAYS_INLINE inline #endif +#if defined(_MSC_VER) +#define C10_ATTR_VISIBILITY_HIDDEN +#elif defined(__GNUC__) +#define C10_ATTR_VISIBILITY_HIDDEN __attribute__((__visibility__("hidden"))) +#else +#define C10_ATTR_VISIBILITY_HIDDEN +#endif + +#define C10_ERASE C10_ALWAYS_INLINE C10_ATTR_VISIBILITY_HIDDEN + // C10_FALLTHROUGH - Annotate fallthrough to the next case in a switch. #if C10_HAS_CPP_ATTRIBUTE(fallthrough) #define C10_FALLTHROUGH [[fallthrough]] diff --git a/torch/csrc/jit/frontend/schema_matching.cpp b/torch/csrc/jit/frontend/schema_matching.cpp index 5fd4d8ec267f..7190eb166c1c 100644 --- a/torch/csrc/jit/frontend/schema_matching.cpp +++ b/torch/csrc/jit/frontend/schema_matching.cpp @@ -16,6 +16,9 @@ namespace torch { namespace jit { static inline TypePtr unwrapOptional(TypePtr opt_type) { + if (auto dyn = opt_type->castRaw()) { + return unwrapOptional(dyn->fallback()); + } if (auto unwrap_list_type = opt_type->cast()) { return unwrap_list_type->getElementType(); } @@ -282,12 +285,17 @@ static bool varargsCanBeUsedAsList( bool is_last_argument = arg_index + 1 == schema.arguments().size() || schema.arguments()[arg_index + 1].kwarg_only(); + auto arg_type = arg.type(); + if (auto dyn = arg_type->castRaw()) { + arg_type = dyn->fallback(); + } + // The formal must be a list - bool argument_is_list = arg.type()->kind() == TypeKind::ListType; + bool argument_is_list = arg_type->kind() == TypeKind::ListType; // matching varargs of typevar list nyi bool typevar_list = argument_is_list && - arg.type()->castRaw()->getElementType()->cast(); + arg_type->castRaw()->getElementType()->cast(); // it must not be a broadcasting list like int[3], // otherwise a single int is a valid input diff --git a/torch/csrc/jit/frontend/schema_type_parser.cpp b/torch/csrc/jit/frontend/schema_type_parser.cpp index 9f7c4b436b33..252b5e2370ca 100644 --- a/torch/csrc/jit/frontend/schema_type_parser.cpp +++ b/torch/csrc/jit/frontend/schema_type_parser.cpp @@ -41,32 +41,33 @@ namespace jit { TypePtr SchemaTypeParser::parseBaseType() { static std::unordered_map type_map = { - {"Generator", GeneratorType::get()}, - {"Dimname", StringType::get()}, - {"ScalarType", IntType::get()}, - {"Layout", IntType::get()}, - {"MemoryFormat", IntType::get()}, - {"Storage", StorageType::get()}, - {"QScheme", QSchemeType::get()}, - {"Quantizer", QuantizerType::get()}, + {"Generator", c10::TypeFactory::get()}, + {"Dimname", c10::TypeFactory::get()}, + {"ScalarType", c10::TypeFactory::get()}, + {"Layout", c10::TypeFactory::get()}, + {"MemoryFormat", c10::TypeFactory::get()}, + {"Storage", c10::TypeFactory::get()}, + {"QScheme", c10::TypeFactory::get()}, + {"Quantizer", c10::TypeFactory::get()}, {"ConstQuantizerPtr", - IntType::get()}, // TODO This type should be removed from the schema - // parser, it should use the custom class mechanism - // instead. @jerryzh - {"Device", DeviceObjType::get()}, - {"Stream", StreamObjType::get()}, - {"Scalar", NumberType::get()}, - {"str", StringType::get()}, - {"float", FloatType::get()}, - {"complex", ComplexType::get()}, - {"int", IntType::get()}, - {"bool", BoolType::get()}, - {"None", NoneType::get()}, - {"NoneType", NoneType::get()}, - {"Capsule", CapsuleType::get()}, - {"Any", at::AnyType::get()}, - {"AnyClassType", at::AnyClassType::get()}, - {"AnyEnumType", at::AnyEnumType::get()}, + c10::TypeFactory::get()}, // TODO This type should be removed + // from the schema parser, it should + // use the custom class mechanism + // instead. @jerryzh + {"Device", c10::TypeFactory::get()}, + {"Stream", c10::TypeFactory::get()}, + {"Scalar", c10::TypeFactory::get()}, + {"str", c10::TypeFactory::get()}, + {"float", c10::TypeFactory::get()}, + {"complex", c10::TypeFactory::get()}, + {"int", c10::TypeFactory::get()}, + {"bool", c10::TypeFactory::get()}, + {"None", c10::TypeFactory::get()}, + {"NoneType", c10::TypeFactory::get()}, + {"Capsule", c10::TypeFactory::get()}, + {"Any", c10::TypeFactory::get()}, + {"AnyClassType", c10::TypeFactory::get()}, + {"AnyEnumType", c10::TypeFactory::get()}, }; auto tok = L.cur(); if (!L.nextIf(TK_NONE) && !L.nextIf(TK_NONE_TYPE)) { @@ -79,7 +80,7 @@ TypePtr SchemaTypeParser::parseBaseType() { if (text.size() > 0 && islower(text[0])) { // lower case identifiers that are not otherwise valid types // are treated as type variables - return VarType::create(text); + return c10::TypeFactory::createNamed(text); } throw ErrorReport(tok.range) << "unknown type specifier"; } @@ -313,7 +314,7 @@ std::pair> SchemaTypeParser::parseType() { alias_info->addContainedType(std::move(*r.second)); } }); - value = TupleType::create(std::move(types)); + value = c10::TypeFactory::create(std::move(types)); } else if (L.cur().kind == TK_IDENT && L.cur().text() == "Future") { L.next(); // Future L.expect('('); @@ -321,7 +322,7 @@ std::pair> SchemaTypeParser::parseType() { auto subtype = std::move(p.first); auto subalias = std::move(p.second); L.expect(')'); - value = FutureType::create(subtype); + value = c10::TypeFactory::create(subtype); } else if (L.cur().kind == TK_IDENT && L.cur().text() == "RRef") { L.next(); // RRef L.expect('('); @@ -329,10 +330,10 @@ std::pair> SchemaTypeParser::parseType() { auto subtype = std::move(p.first); auto subalias = std::move(p.second); L.expect(')'); - value = RRefType::create(subtype); + value = c10::TypeFactory::create(subtype); } else if (L.cur().kind == TK_IDENT && L.cur().text() == "Tensor") { L.next(); - value = TensorType::get(); + value = c10::TypeFactory::get(); alias_info = parseAliasAnnotation(); } else if (L.cur().kind == TK_IDENT && L.cur().text() == "Dict") { L.next(); @@ -342,7 +343,7 @@ std::pair> SchemaTypeParser::parseType() { auto value_type = parseType().first; L.expect(')'); alias_info = parseAliasAnnotation(); - value = DictType::create(key_type, value_type); + value = c10::TypeFactory::create(key_type, value_type); } else if (L.cur().kind == TK_IDENT && L.cur().text() == "Union") { L.next(); L.expect('('); @@ -395,7 +396,7 @@ std::pair> SchemaTypeParser::parseType() { if (L.cur().kind == '[' && L.lookahead().kind == ']') { L.next(); // [ L.next(); // ] - value = ListType::create(value); + value = c10::TypeFactory::create(value); auto container = parseAliasAnnotation(); if (container && alias_info) { container->addContainedType(std::move(*alias_info)); diff --git a/torch/csrc/jit/ir/ir.h b/torch/csrc/jit/ir/ir.h index d363a988fb22..f5de995a9748 100644 --- a/torch/csrc/jit/ir/ir.h +++ b/torch/csrc/jit/ir/ir.h @@ -1485,6 +1485,9 @@ inline Value::Value(Node* node_, size_t offset_) inline Value* Value::setType(TypePtr type) { AT_ASSERT(type); + if (auto dyn = type->castRaw()) { + type = dyn->fallback(); + } type_ = std::move(type); for (Use& use : uses_) { use.user->op_ = nullptr; diff --git a/torch/csrc/jit/serialization/export_module.cpp b/torch/csrc/jit/serialization/export_module.cpp index 6a09d6a5cc3c..ccdccf4eb1d9 100644 --- a/torch/csrc/jit/serialization/export_module.cpp +++ b/torch/csrc/jit/serialization/export_module.cpp @@ -108,7 +108,11 @@ std::pair getFunctionTuple( static const std::string torch_prefix("__torch__"); static const std::string class_prefix("__torch__.torch.classes"); - for (const TypePtr& t : mobile_code.types_) { + for (const TypePtr& ty : mobile_code.types_) { + auto t = ty; + if (auto dyn = t->castRaw()) { + t = dyn->fallback(); + } std::string type_str = t->annotation_str(); if (t->kind() == TypeKind::TupleType) { TORCH_CHECK( @@ -216,9 +220,13 @@ std::pair getFunctionTuple( arg.type()->annotation_str(type_printer) => mangled unique name of the module/submodule */ + auto arg_type = arg.type(); + if (auto dyn = arg_type->castRaw()) { + arg_type = dyn->fallback(); + } argTables.emplace_back(Table({ {"name", arg.name()}, - {"type", arg.type()->annotation_str(type_printer)}, + {"type", arg_type->annotation_str(type_printer)}, {"default_value", arg.default_value()}, })); } diff --git a/torch/csrc/jit/serialization/pickler.cpp b/torch/csrc/jit/serialization/pickler.cpp index 3f384f023b84..ac365142411e 100644 --- a/torch/csrc/jit/serialization/pickler.cpp +++ b/torch/csrc/jit/serialization/pickler.cpp @@ -575,7 +575,11 @@ void Pickler::endTypeTag(const IValue& ivalue) { // Push the dict type TORCH_INTERNAL_ASSERT(ivalue.type()); - pushString(ivalue.type()->annotation_str()); + auto type = ivalue.type(); + if (auto dyn = type->castRaw()) { + type = dyn->fallback(); + } + pushString(type->annotation_str()); // Pop the dict and type into a tuple push(PickleOpCode::TUPLE2); diff --git a/torch/csrc/jit/serialization/unpickler.cpp b/torch/csrc/jit/serialization/unpickler.cpp index 1d7a7f5eba1f..391849ace7c7 100644 --- a/torch/csrc/jit/serialization/unpickler.cpp +++ b/torch/csrc/jit/serialization/unpickler.cpp @@ -34,7 +34,7 @@ static void restoreAccurateTypeTagsIfPossible(const IValue& root) { // of the contained objects and cannot restore the tags. void restoreAccurateTypeTags(const IValue& root, const TypePtr& type_tag) { struct Work { - TypePtr static_type; + TypePtr type; IValue value; }; std::vector to_process = {{type_tag, root}}; @@ -53,7 +53,11 @@ void restoreAccurateTypeTags(const IValue& root, const TypePtr& type_tag) { } scanned.emplace_hint(it, key); } - switch (w.static_type->kind()) { + auto kind = w.type->kind(); + if (auto dyn = w.type->castRaw()) { + kind = dyn->dynamicKind(); + } + switch (kind) { case TensorType::Kind: case StorageType::Kind: case NumberType::Kind: @@ -83,52 +87,37 @@ void restoreAccurateTypeTags(const IValue& root, const TypePtr& type_tag) { // no op, there is nothing to tag break; case DynamicType::Kind: + case UnionType::Kind: case EnumType::Kind: // TODO(gmagogsfm): Implement serialization/deserialization of Enum. TORCH_INTERNAL_ASSERT(false); case TupleType::Kind: { auto t = w.value.toTuple(); - auto ttype = w.static_type->expect(); - for (size_t i = 0; i < ttype->containedTypes().size(); ++i) { - Work elem = {ttype->containedTypes().at(i), t->elements().at(i)}; + for (size_t i = 0; i < w.type->containedTypeSize(); ++i) { + Work elem = {w.type->containedType(i), t->elements().at(i)}; to_process.emplace_back(std::move(elem)); } } break; case FutureType::Kind: { auto f = w.value.toFuture(); - auto t = w.static_type->expect(); if (f->completed()) { - Work elem = {t->getElementType(), f->value()}; + Work elem = {w.type->containedType(0), f->value()}; to_process.emplace_back(std::move(elem)); } } break; case OptionalType::Kind: { if (!w.value.isNone()) { - auto t = w.static_type->expect(); - Work elem = {t->getElementType(), w.value}; + Work elem = {w.type->containedType(0), w.value}; to_process.emplace_back(std::move(elem)); } } break; - case UnionType::Kind: { - auto t = w.static_type->expect(); - if (t->containedTypes().size() == 2 && - t->canHoldType(*NoneType::get())) { - if (!w.value.isNone()) { - auto inner = t->containedTypes()[0] != NoneType::get() - ? t->containedTypes()[0] - : t->containedTypes()[1]; - Work elem = {inner, w.value}; - to_process.emplace_back(std::move(elem)); - } - } - } break; case ListType::Kind: { // specialized lists do not need their type refined, so we can exit // early here if (!w.value.isList()) { break; } - auto elem_type = w.static_type->castRaw()->getElementType(); + auto elem_type = w.type->containedType(0); auto lst = w.value.toList(); lst.unsafeSetElementType(elem_type); for (const IValue item : lst) { @@ -137,13 +126,14 @@ void restoreAccurateTypeTags(const IValue& root, const TypePtr& type_tag) { } } break; case DictType::Kind: { - auto dt = w.static_type->cast(); auto d = w.value.toGenericDict(); - d.unsafeSetKeyType(dt->getKeyType()); - d.unsafeSetValueType(dt->getValueType()); + auto keyType = w.type->containedType(0); + auto valType = w.type->containedType(1); + d.unsafeSetKeyType(keyType); + d.unsafeSetValueType(valType); for (const auto& item : d) { - Work kelem = {dt->getKeyType(), item.key()}; - Work velem = {dt->getValueType(), item.value()}; + Work kelem = {keyType, item.key()}; + Work velem = {valType, item.value()}; to_process.emplace_back(std::move(kelem)); to_process.emplace_back(std::move(velem)); }