[jit][edge] Migrate to TypeFactory for jit types on mobile (#71516)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71516

Mobile should be able to contruct dynamic types by default.
ghstack-source-id: 147498365

Test Plan:
CI.

**-48KB** binary size reduction for igios BSB.
UMBEX link: https://www.internalfb.com/intern/unigraph/explorer/?jsgq_traversal_spec=%7B%22builds%22%3A[%22bsb%3A422553426218394%5Cu0040base%22%2C%22bsb%3A422553426218394%5Cu0040diff%22]%7D&unigraph_project=UnigraphProjectMbex&is_mbex_redirected

Reviewed By: iseeyuan

Differential Revision: D33673958

fbshipit-source-id: 8600c04ae929283681971aae264d3774188df9cd
(cherry picked from commit 64ebcec09e69d2eff64fdbf926fb43d3b67f99b2)
This commit is contained in:
Zhengxu Chen
2022-01-25 22:58:45 -08:00
committed by PyTorch MergeBot
parent e5794974cb
commit fe277b8717
20 changed files with 236 additions and 140 deletions

View File

@ -308,7 +308,7 @@ c10::intrusive_ptr<c10::ivalue::Future> intraop_launch_future(
#else
// TODO: caffe2::PThreadPool only provides a data-parallel API.
// Task parallelism is not currently supported.
auto future = c10::make_intrusive<c10::ivalue::Future>(NoneType::get());
auto future = c10::make_intrusive<c10::ivalue::Future>(c10::dynT<NoneType>());
func();
future->markCompleted();
return future;

View File

@ -3,6 +3,7 @@
#include <ATen/core/jit_type.h>
#include <ATen/core/function_schema.h>
#include <ATen/core/functional.h>
#include <ATen/core/type_factory.h>
#include <atomic>
#include <unordered_map>
@ -102,7 +103,7 @@ class_base::class_base(
{
detail::checkValidIdent(namespaceName, "Namespace name");
detail::checkValidIdent(className, "Class name");
classTypePtr->addAttribute("capsule", at::CapsuleType::get());
classTypePtr->addAttribute("capsule", c10::TypeFactory::get<c10::CapsuleType>());
c10::getCustomClassTypeMap().insert(
{std::type_index(intrusivePtrClassTypeid), classTypePtr});
c10::getCustomClassTypeMap().insert(

View File

@ -2,6 +2,7 @@
#include <string>
#include <ATen/core/class_type.h>
#include <ATen/core/ivalue.h>
#include <ATen/core/jit_type.h>
#include <ATen/core/type_factory.h>
@ -198,6 +199,11 @@ TypePtr DynamicType::containedType(size_t i) const {
return arguments_.elems.at(i).ty;
}
size_t DynamicType::containedTypeSize() const {
TORCH_INTERNAL_ASSERT(tag_ != Tag::Class);
return arguments_.elems.size();
}
TypeKind DynamicType::dynamicKind() const {
switch (tag_) {
#define CASE_TYPE(T, _, __) \
@ -271,6 +277,16 @@ TypePtr DynamicType::fallback() const {
return VarType::create(*name_);
case Tag::AnyClass:
return AnyClassType::get();
case Tag::QScheme:
return QSchemeType::get();
case Tag::Quantizer:
return QuantizerType::get();
case Tag::AnyEnum:
return AnyEnumType::get();
case Tag::RRef:
return RRefType::create(arguments_.elems[0].ty->fallback());
case Tag::Future:
return FutureType::create(arguments_.elems[0].ty->fallback());
case Tag::Any:
return AnyType::get();
}

View File

@ -3,8 +3,6 @@
#include <memory>
#include <type_traits>
#include <ATen/core/class_type.h>
#include <ATen/core/ivalue.h>
#include <ATen/core/jit_type_base.h>
#include <c10/util/Optional.h>
@ -53,8 +51,17 @@ constexpr DynamicTypeBits kDynamicClassTypeBit = DYNAMIC_TYPE_BIT(10);
_(Storage, DYNAMIC_TYPE_BIT(16), 1) \
_(Var, DYNAMIC_TYPE_BIT(17), 0) \
_(AnyClass, (kDynamicClassTypeBit | kDynamicAnyTypeBit), 1) \
_(QScheme, DYNAMIC_TYPE_BIT(18), 1) \
_(Quantizer, DYNAMIC_TYPE_BIT(19), 1) \
_(AnyEnum, DYNAMIC_TYPE_BIT(20), 1) \
_(RRef, DYNAMIC_TYPE_BIT(21), 0) \
_(Future, DYNAMIC_TYPE_BIT(22), 0) \
_(Any, 0xffffffff, 1)
#define FORWARD_DECL_TYPE(NAME, _, __) struct NAME ## Type;
FORALL_DYNAMIC_TYPES(FORWARD_DECL_TYPE)
#undef FORWARD_DECL_TYPE
class DynamicType;
using DynamicTypePtr = std::shared_ptr<DynamicType>;
@ -142,6 +149,7 @@ class DynamicType : public SharedType {
explicit DynamicType(Tag, c10::string_view, Arguments);
TypePtr containedType(size_t) const override;
size_t containedTypeSize() const override;
Tag tag() const {
return tag_;
}
@ -154,6 +162,9 @@ class DynamicType : public SharedType {
TypeKind dynamicKind() const;
// Should be used only on the server side to restore static type information.
#ifndef C10_MOBILE
TORCH_API
#endif
TypePtr fallback() const;
private:
@ -188,7 +199,7 @@ class DynamicType : public SharedType {
template <typename T>
struct DynamicTypeTrait {
static auto tagValue() {
C10_NOINLINE static auto tagValue() {
TORCH_CHECK(false);
return DynamicType::Tag::Any;
}
@ -201,7 +212,7 @@ C10_NOINLINE DynamicTypePtr makeBaseType(DynamicType::Tag tag);
#define DYNAMIC_TYPE_TAG_VALUE(NAME, _, IS_BASE_TYPE) \
template <> \
struct TORCH_API DynamicTypeTrait<NAME##Type> { \
static auto tagValue() { \
C10_ERASE static auto tagValue() { \
return DynamicType::Tag::NAME; \
} \
static constexpr bool isBaseType = IS_BASE_TYPE; \
@ -214,19 +225,4 @@ C10_NOINLINE DynamicTypePtr makeBaseType(DynamicType::Tag tag);
FORALL_DYNAMIC_TYPES(DYNAMIC_TYPE_TAG_VALUE)
#undef DYNAMIC_TYPE_TAG_VALUE
template <>
struct IValue::TagType<c10::DynamicType> {
static DynamicType::Ptr get(const c10::IValue& v);
};
namespace ivalue {
template <>
struct TORCH_API TupleTypeFactory<c10::DynamicType> {
static DynamicTypePtr create(std::vector<TypePtr> elemTypes);
static DynamicTypePtr fallback(const Type&);
};
} // namespace ivalue
} // namespace c10

View File

@ -390,7 +390,7 @@ struct FunctionSchema {
// Check that inputs have the correct types and appends any missing default
// values.
template <typename T = c10::Type>
template <typename T = c10::PlatformType>
void checkAndNormalizeInputs(
std::vector<IValue>& inputs,
const std::unordered_map<std::string, IValue>& kwargs =

View File

@ -293,7 +293,7 @@ inline void FunctionSchema::checkArg(
TORCH_CHECK(
false,
formatTypeMismatchMsg(
argument, value.type()->repr_str(), pos));
argument, value.type<T>()->repr_str(), pos));
}
}

View File

@ -6,6 +6,7 @@
#include <ATen/core/function.h>
#include <ATen/core/jit_type.h>
#include <ATen/core/stack.h>
#include <ATen/core/type_factory.h>
#include <c10/util/irange.h>
#include <c10/util/StringUtil.h>
#include <c10/util/hash.h>
@ -403,6 +404,39 @@ bool IValue::is(const IValue& rhs) const {
return lhs == rhs;
}
template <typename T>
inline bool IValue::isListOf() const {
// note: avoids calling type() to avoid extra referencing counting for the returned type.
if (!isList()) {
return false;
}
const auto& ty = static_cast<detail::ListImpl*>(payload.u.as_intrusive_ptr)->elementType;
if (ty->kind() == T::Kind) {
return true;
}
return *ty == *TypeFactory::get<T>();
}
bool IValue::isDoubleList() const {
return isListOf<c10::FloatType>();
}
bool IValue::isComplexDoubleList() const {
return isListOf<c10::ComplexType>();
}
bool IValue::isTensorList() const {
return isListOf<c10::TensorType>();
}
bool IValue::isIntList() const {
return isListOf<c10::IntType>();
}
bool IValue::isBoolList() const {
return isListOf<c10::BoolType>();
}
namespace {
using IValueFormatter = std::function<void(std::ostream&, const IValue&)>;
@ -430,7 +464,7 @@ std::ostream& printMaybeAnnotatedList(
std::ostream& out,
const IValue& the_list,
IValueFormatter formatter) {
auto list_elem_type = the_list.type()->expectRef<ListType>().getElementType();
auto list_elem_type = the_list.type()->containedType(0);
if (the_list.toListRef().size() == 0 ||
!elementTypeCanBeInferredFromMembers(list_elem_type)) {
out << "annotate(" << the_list.type()->annotation_str() << ", ";
@ -925,7 +959,7 @@ c10::intrusive_ptr<ivalue::Object> ivalue::Object::deepcopy(IValue::HashAliasedI
auto cu = type_.cu_;
auto object = ivalue::Object::create(WeakOrStrongTypePtr(type_.cu_, type_.type_), type()->numAttributes());
for (const auto i : c10::irange(slots_.size())) {
if (slots_[i].type() == c10::CapsuleType::get()) {
if (*slots_[i].type() == *c10::TypeFactory::get<CapsuleType>()) {
// If we've gotten here, it means that we have *not* copied this
// class via __getstate__ and __setstate__. That fact and the
// fact that we have a Capsule attribute mean that this is a

View File

@ -6,6 +6,7 @@
#include <ATen/core/custom_class.h>
#include <ATen/core/ivalue_to.h>
#include <ATen/core/jit_type_base.h>
#include <ATen/core/type_factory.h>
#include <c10/util/C++17.h>
#include <c10/util/MaybeOwned.h>
#include <c10/util/intrusive_ptr.h>
@ -895,8 +896,8 @@ public:
}
}
template <typename T = c10::Type>
typename T::Ptr type() const;
template <typename T = c10::PlatformType>
TypePtr type() const;
// Detect aliased tensors.
struct HashAliasedIValue {

View File

@ -586,6 +586,12 @@ struct TORCH_API TupleTypeFactory<TupleType> {
static TupleTypePtr fallback(const Type& type);
};
template <>
struct TORCH_API TupleTypeFactory<c10::DynamicType> {
static DynamicTypePtr create(std::vector<TypePtr> elemTypes);
static DynamicTypePtr fallback(const Type&);
};
struct TORCH_API Tuple : c10::intrusive_ptr_target {
private:
TupleElements elements_;
@ -1915,39 +1921,6 @@ inline ivalue::Tuple& IValue::toTupleRef() const {
payload.u.as_intrusive_ptr);
}
template <typename T>
inline bool IValue::isListOf() const {
// note: avoids calling type() to avoid extra referencing counting for the returned type.
if (!isList()) {
return false;
}
const auto& ty = static_cast<detail::ListImpl*>(payload.u.as_intrusive_ptr)->elementType;
if (ty->kind() == T::Kind) {
return true;
}
return *ty == *T::get();
}
inline bool IValue::isDoubleList() const {
return isListOf<c10::FloatType>();
}
inline bool IValue::isComplexDoubleList() const {
return isListOf<c10::ComplexType>();
}
inline bool IValue::isTensorList() const {
return isListOf<c10::TensorType>();
}
inline bool IValue::isIntList() const {
return isListOf<c10::IntType>();
}
inline bool IValue::isBoolList() const {
return isListOf<c10::BoolType>();
}
inline IValue::IValue(c10::intrusive_ptr<ivalue::Tuple> v)
: tag(Tag::Tuple), is_intrusive_ptr(true) {
payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release());
@ -2285,8 +2258,13 @@ struct IValue::TagType<c10::Type> {
static TORCH_API c10::TypePtr get(const IValue&);
};
template <>
struct IValue::TagType<c10::DynamicType> {
static TORCH_API c10::TypePtr get(const IValue&);
};
template <typename T>
typename T::Ptr IValue::type() const {
TypePtr IValue::type() const {
return IValue::TagType<T>::get(*this);
}

View File

@ -5,6 +5,7 @@
#include <ATen/core/TensorBody.h>
#include <ATen/core/functional.h>
#include <ATen/core/symbol.h>
#include <ATen/core/type_factory.h>
#include <ATen/core/qualified_name.h>
#include <c10/util/TypeList.h>
#include <c10/util/Optional.h>
@ -1730,7 +1731,8 @@ struct getTypePtr_<c10::QScheme> final {
template <>
struct getTypePtr_<at::Generator> final {
static decltype(auto) call() {
return OptionalType::create(GeneratorType::get());
return TypeFactory::create<OptionalType>(
TypeFactory::get<GeneratorType>());
}
};
template <>
@ -1798,7 +1800,8 @@ struct getTypePtr_<c10::Dict<K, V>> final {
template <class T>
struct getTypePtr_<at::optional<T>> final {
static const auto& call() {
static auto type = OptionalType::create(getTypePtr_<T>::call());
static auto type = TypeFactory::create<OptionalType>(
getTypePtr_<T>::call());
return type;
}
};

View File

@ -558,6 +558,9 @@ struct TORCH_API Type {
virtual TypePtr containedType(size_t i) const {
return containedTypes().at(i);
}
virtual size_t containedTypeSize() const {
return containedTypes().size();
}
// create a new version of this type, replacing its contained types with
// contained_types
TypePtr withContained(std::vector<TypePtr> contained_types);

View File

@ -1,5 +1,7 @@
#include <ATen/core/type_factory.h>
#include <ATen/core/jit_type.h>
namespace c10 {
// Dtype constraints are not constrained in compilation. Therefore, we map
@ -56,4 +58,11 @@ const std::unordered_map<std::string, c10::TypePtr>& DefaultTypeFactory::
return map;
}
c10::TypePtr DefaultTypeFactory::createNamedTuple(
const std::string& name,
const std::vector<c10::string_view>& fields,
const std::vector<c10::TypePtr>& types) {
return c10::TupleType::createNamed(name, fields, types);
}
} // namespace c10

View File

@ -1,12 +1,19 @@
#pragma once
#include <ATen/core/dynamic_type.h>
#include <ATen/core/jit_type.h>
#include <type_traits>
#include <unordered_map>
#include <ATen/core/dynamic_type.h>
#include <ATen/core/jit_type_base.h>
#include <c10/macros/Macros.h>
namespace c10 {
struct TORCH_API DynamicTypeFactory {
template <typename T>
struct TORCH_API TypeFactoryBase {};
template <>
struct TORCH_API TypeFactoryBase<c10::DynamicType> {
template <typename T, typename... Args>
static c10::DynamicTypePtr create(TypePtr ty, Args&&... args) {
return std::make_shared<c10::DynamicType>(
@ -29,26 +36,40 @@ struct TORCH_API DynamicTypeFactory {
name,
c10::DynamicType::Arguments(fields, types));
}
template <typename T>
C10_ERASE static c10::DynamicTypePtr createNamed(const std::string& name) {
return std::make_shared<c10::DynamicType>(
c10::DynamicTypeTrait<T>::tagValue(),
name,
c10::DynamicType::Arguments{});
}
template <typename T>
C10_ERASE static c10::DynamicTypePtr get() {
return DynamicTypeTrait<T>::getBaseType();
}
static const std::unordered_map<std::string, c10::TypePtr>& basePythonTypes();
};
using DynamicTypeFactory = TypeFactoryBase<c10::DynamicType>;
// Helper functions for constructing DynamicTypes inline.
template <
typename T,
std::enable_if_t<DynamicTypeTrait<T>::isBaseType, int> = 0>
DynamicTypePtr dynT() {
return DynamicTypeTrait<T>::getBaseType();
C10_ERASE DynamicTypePtr dynT() {
return DynamicTypeFactory::get<T>();
}
template <
typename T,
typename... Args,
std::enable_if_t<!DynamicTypeTrait<T>::isBaseType, int> = 0>
DynamicTypePtr dynT(Args&&... args) {
C10_ERASE DynamicTypePtr dynT(Args&&... args) {
return DynamicTypeFactory::create<T>(std::forward<Args>(args)...);
}
struct TORCH_API DefaultTypeFactory {
template <>
struct TORCH_API TypeFactoryBase<c10::Type> {
template <typename T, typename... Args>
static c10::TypePtr create(TypePtr ty, Args&&... args) {
return T::create(std::move(ty), std::forward<Args>(args)...);
@ -60,18 +81,28 @@ struct TORCH_API DefaultTypeFactory {
static c10::TypePtr createNamedTuple(
const std::string& name,
const std::vector<c10::string_view>& fields,
const std::vector<c10::TypePtr>& types) {
return c10::TupleType::createNamed(name, fields, types);
const std::vector<c10::TypePtr>& types);
template <typename T>
C10_ERASE static c10::TypePtr createNamed(const std::string& name) {
return T::create(name);
}
static const std::unordered_map<std::string, c10::TypePtr>& basePythonTypes();
template <typename T>
C10_ERASE static c10::TypePtr get() {
return T::get();
}
};
using TypeFactory =
using DefaultTypeFactory = TypeFactoryBase<c10::Type>;
using PlatformType =
#ifdef C10_MOBILE
DynamicTypeFactory
c10::DynamicType
#else
DefaultTypeFactory
c10::Type
#endif
;
using TypeFactory = TypeFactoryBase<PlatformType>;
} // namespace c10

View File

@ -225,6 +225,16 @@ using namespace c10::hip;
#define C10_ALWAYS_INLINE inline
#endif
#if defined(_MSC_VER)
#define C10_ATTR_VISIBILITY_HIDDEN
#elif defined(__GNUC__)
#define C10_ATTR_VISIBILITY_HIDDEN __attribute__((__visibility__("hidden")))
#else
#define C10_ATTR_VISIBILITY_HIDDEN
#endif
#define C10_ERASE C10_ALWAYS_INLINE C10_ATTR_VISIBILITY_HIDDEN
// C10_FALLTHROUGH - Annotate fallthrough to the next case in a switch.
#if C10_HAS_CPP_ATTRIBUTE(fallthrough)
#define C10_FALLTHROUGH [[fallthrough]]

View File

@ -16,6 +16,9 @@ namespace torch {
namespace jit {
static inline TypePtr unwrapOptional(TypePtr opt_type) {
if (auto dyn = opt_type->castRaw<c10::DynamicType>()) {
return unwrapOptional(dyn->fallback());
}
if (auto unwrap_list_type = opt_type->cast<OptionalType>()) {
return unwrap_list_type->getElementType();
}
@ -282,12 +285,17 @@ static bool varargsCanBeUsedAsList(
bool is_last_argument = arg_index + 1 == schema.arguments().size() ||
schema.arguments()[arg_index + 1].kwarg_only();
auto arg_type = arg.type();
if (auto dyn = arg_type->castRaw<c10::DynamicType>()) {
arg_type = dyn->fallback();
}
// The formal must be a list
bool argument_is_list = arg.type()->kind() == TypeKind::ListType;
bool argument_is_list = arg_type->kind() == TypeKind::ListType;
// matching varargs of typevar list nyi
bool typevar_list = argument_is_list &&
arg.type()->castRaw<ListType>()->getElementType()->cast<VarType>();
arg_type->castRaw<ListType>()->getElementType()->cast<VarType>();
// it must not be a broadcasting list like int[3],
// otherwise a single int is a valid input

View File

@ -41,32 +41,33 @@ namespace jit {
TypePtr SchemaTypeParser::parseBaseType() {
static std::unordered_map<std::string, TypePtr> type_map = {
{"Generator", GeneratorType::get()},
{"Dimname", StringType::get()},
{"ScalarType", IntType::get()},
{"Layout", IntType::get()},
{"MemoryFormat", IntType::get()},
{"Storage", StorageType::get()},
{"QScheme", QSchemeType::get()},
{"Quantizer", QuantizerType::get()},
{"Generator", c10::TypeFactory::get<GeneratorType>()},
{"Dimname", c10::TypeFactory::get<StringType>()},
{"ScalarType", c10::TypeFactory::get<IntType>()},
{"Layout", c10::TypeFactory::get<IntType>()},
{"MemoryFormat", c10::TypeFactory::get<IntType>()},
{"Storage", c10::TypeFactory::get<StorageType>()},
{"QScheme", c10::TypeFactory::get<QSchemeType>()},
{"Quantizer", c10::TypeFactory::get<QuantizerType>()},
{"ConstQuantizerPtr",
IntType::get()}, // TODO This type should be removed from the schema
// parser, it should use the custom class mechanism
c10::TypeFactory::get<IntType>()}, // TODO This type should be removed
// from the schema parser, it should
// use the custom class mechanism
// instead. @jerryzh
{"Device", DeviceObjType::get()},
{"Stream", StreamObjType::get()},
{"Scalar", NumberType::get()},
{"str", StringType::get()},
{"float", FloatType::get()},
{"complex", ComplexType::get()},
{"int", IntType::get()},
{"bool", BoolType::get()},
{"None", NoneType::get()},
{"NoneType", NoneType::get()},
{"Capsule", CapsuleType::get()},
{"Any", at::AnyType::get()},
{"AnyClassType", at::AnyClassType::get()},
{"AnyEnumType", at::AnyEnumType::get()},
{"Device", c10::TypeFactory::get<DeviceObjType>()},
{"Stream", c10::TypeFactory::get<StreamObjType>()},
{"Scalar", c10::TypeFactory::get<NumberType>()},
{"str", c10::TypeFactory::get<StringType>()},
{"float", c10::TypeFactory::get<FloatType>()},
{"complex", c10::TypeFactory::get<ComplexType>()},
{"int", c10::TypeFactory::get<IntType>()},
{"bool", c10::TypeFactory::get<BoolType>()},
{"None", c10::TypeFactory::get<NoneType>()},
{"NoneType", c10::TypeFactory::get<NoneType>()},
{"Capsule", c10::TypeFactory::get<CapsuleType>()},
{"Any", c10::TypeFactory::get<c10::AnyType>()},
{"AnyClassType", c10::TypeFactory::get<c10::AnyClassType>()},
{"AnyEnumType", c10::TypeFactory::get<c10::AnyEnumType>()},
};
auto tok = L.cur();
if (!L.nextIf(TK_NONE) && !L.nextIf(TK_NONE_TYPE)) {
@ -79,7 +80,7 @@ TypePtr SchemaTypeParser::parseBaseType() {
if (text.size() > 0 && islower(text[0])) {
// lower case identifiers that are not otherwise valid types
// are treated as type variables
return VarType::create(text);
return c10::TypeFactory::createNamed<VarType>(text);
}
throw ErrorReport(tok.range) << "unknown type specifier";
}
@ -313,7 +314,7 @@ std::pair<TypePtr, c10::optional<AliasInfo>> SchemaTypeParser::parseType() {
alias_info->addContainedType(std::move(*r.second));
}
});
value = TupleType::create(std::move(types));
value = c10::TypeFactory::create<TupleType>(std::move(types));
} else if (L.cur().kind == TK_IDENT && L.cur().text() == "Future") {
L.next(); // Future
L.expect('(');
@ -321,7 +322,7 @@ std::pair<TypePtr, c10::optional<AliasInfo>> SchemaTypeParser::parseType() {
auto subtype = std::move(p.first);
auto subalias = std::move(p.second);
L.expect(')');
value = FutureType::create(subtype);
value = c10::TypeFactory::create<FutureType>(subtype);
} else if (L.cur().kind == TK_IDENT && L.cur().text() == "RRef") {
L.next(); // RRef
L.expect('(');
@ -329,10 +330,10 @@ std::pair<TypePtr, c10::optional<AliasInfo>> SchemaTypeParser::parseType() {
auto subtype = std::move(p.first);
auto subalias = std::move(p.second);
L.expect(')');
value = RRefType::create(subtype);
value = c10::TypeFactory::create<RRefType>(subtype);
} else if (L.cur().kind == TK_IDENT && L.cur().text() == "Tensor") {
L.next();
value = TensorType::get();
value = c10::TypeFactory::get<TensorType>();
alias_info = parseAliasAnnotation();
} else if (L.cur().kind == TK_IDENT && L.cur().text() == "Dict") {
L.next();
@ -342,7 +343,7 @@ std::pair<TypePtr, c10::optional<AliasInfo>> SchemaTypeParser::parseType() {
auto value_type = parseType().first;
L.expect(')');
alias_info = parseAliasAnnotation();
value = DictType::create(key_type, value_type);
value = c10::TypeFactory::create<DictType>(key_type, value_type);
} else if (L.cur().kind == TK_IDENT && L.cur().text() == "Union") {
L.next();
L.expect('(');
@ -395,7 +396,7 @@ std::pair<TypePtr, c10::optional<AliasInfo>> SchemaTypeParser::parseType() {
if (L.cur().kind == '[' && L.lookahead().kind == ']') {
L.next(); // [
L.next(); // ]
value = ListType::create(value);
value = c10::TypeFactory::create<ListType>(value);
auto container = parseAliasAnnotation();
if (container && alias_info) {
container->addContainedType(std::move(*alias_info));

View File

@ -1485,6 +1485,9 @@ inline Value::Value(Node* node_, size_t offset_)
inline Value* Value::setType(TypePtr type) {
AT_ASSERT(type);
if (auto dyn = type->castRaw<c10::DynamicType>()) {
type = dyn->fallback();
}
type_ = std::move(type);
for (Use& use : uses_) {
use.user->op_ = nullptr;

View File

@ -108,7 +108,11 @@ std::pair<IValue, IValue> getFunctionTuple(
static const std::string torch_prefix("__torch__");
static const std::string class_prefix("__torch__.torch.classes");
for (const TypePtr& t : mobile_code.types_) {
for (const TypePtr& ty : mobile_code.types_) {
auto t = ty;
if (auto dyn = t->castRaw<c10::DynamicType>()) {
t = dyn->fallback();
}
std::string type_str = t->annotation_str();
if (t->kind() == TypeKind::TupleType) {
TORCH_CHECK(
@ -216,9 +220,13 @@ std::pair<IValue, IValue> getFunctionTuple(
arg.type()->annotation_str(type_printer) => mangled unique name of the
module/submodule
*/
auto arg_type = arg.type();
if (auto dyn = arg_type->castRaw<c10::DynamicType>()) {
arg_type = dyn->fallback();
}
argTables.emplace_back(Table({
{"name", arg.name()},
{"type", arg.type()->annotation_str(type_printer)},
{"type", arg_type->annotation_str(type_printer)},
{"default_value", arg.default_value()},
}));
}

View File

@ -575,7 +575,11 @@ void Pickler::endTypeTag(const IValue& ivalue) {
// Push the dict type
TORCH_INTERNAL_ASSERT(ivalue.type());
pushString(ivalue.type()->annotation_str());
auto type = ivalue.type();
if (auto dyn = type->castRaw<c10::DynamicType>()) {
type = dyn->fallback();
}
pushString(type->annotation_str());
// Pop the dict and type into a tuple
push<PickleOpCode>(PickleOpCode::TUPLE2);

View File

@ -34,7 +34,7 @@ static void restoreAccurateTypeTagsIfPossible(const IValue& root) {
// of the contained objects and cannot restore the tags.
void restoreAccurateTypeTags(const IValue& root, const TypePtr& type_tag) {
struct Work {
TypePtr static_type;
TypePtr type;
IValue value;
};
std::vector<Work> to_process = {{type_tag, root}};
@ -53,7 +53,11 @@ void restoreAccurateTypeTags(const IValue& root, const TypePtr& type_tag) {
}
scanned.emplace_hint(it, key);
}
switch (w.static_type->kind()) {
auto kind = w.type->kind();
if (auto dyn = w.type->castRaw<c10::DynamicType>()) {
kind = dyn->dynamicKind();
}
switch (kind) {
case TensorType::Kind:
case StorageType::Kind:
case NumberType::Kind:
@ -83,52 +87,37 @@ void restoreAccurateTypeTags(const IValue& root, const TypePtr& type_tag) {
// no op, there is nothing to tag
break;
case DynamicType::Kind:
case UnionType::Kind:
case EnumType::Kind:
// TODO(gmagogsfm): Implement serialization/deserialization of Enum.
TORCH_INTERNAL_ASSERT(false);
case TupleType::Kind: {
auto t = w.value.toTuple();
auto ttype = w.static_type->expect<TupleType>();
for (size_t i = 0; i < ttype->containedTypes().size(); ++i) {
Work elem = {ttype->containedTypes().at(i), t->elements().at(i)};
for (size_t i = 0; i < w.type->containedTypeSize(); ++i) {
Work elem = {w.type->containedType(i), t->elements().at(i)};
to_process.emplace_back(std::move(elem));
}
} break;
case FutureType::Kind: {
auto f = w.value.toFuture();
auto t = w.static_type->expect<FutureType>();
if (f->completed()) {
Work elem = {t->getElementType(), f->value()};
Work elem = {w.type->containedType(0), f->value()};
to_process.emplace_back(std::move(elem));
}
} break;
case OptionalType::Kind: {
if (!w.value.isNone()) {
auto t = w.static_type->expect<OptionalType>();
Work elem = {t->getElementType(), w.value};
Work elem = {w.type->containedType(0), w.value};
to_process.emplace_back(std::move(elem));
}
} break;
case UnionType::Kind: {
auto t = w.static_type->expect<UnionType>();
if (t->containedTypes().size() == 2 &&
t->canHoldType(*NoneType::get())) {
if (!w.value.isNone()) {
auto inner = t->containedTypes()[0] != NoneType::get()
? t->containedTypes()[0]
: t->containedTypes()[1];
Work elem = {inner, w.value};
to_process.emplace_back(std::move(elem));
}
}
} break;
case ListType::Kind: {
// specialized lists do not need their type refined, so we can exit
// early here
if (!w.value.isList()) {
break;
}
auto elem_type = w.static_type->castRaw<ListType>()->getElementType();
auto elem_type = w.type->containedType(0);
auto lst = w.value.toList();
lst.unsafeSetElementType(elem_type);
for (const IValue item : lst) {
@ -137,13 +126,14 @@ void restoreAccurateTypeTags(const IValue& root, const TypePtr& type_tag) {
}
} break;
case DictType::Kind: {
auto dt = w.static_type->cast<DictType>();
auto d = w.value.toGenericDict();
d.unsafeSetKeyType(dt->getKeyType());
d.unsafeSetValueType(dt->getValueType());
auto keyType = w.type->containedType(0);
auto valType = w.type->containedType(1);
d.unsafeSetKeyType(keyType);
d.unsafeSetValueType(valType);
for (const auto& item : d) {
Work kelem = {dt->getKeyType(), item.key()};
Work velem = {dt->getValueType(), item.value()};
Work kelem = {keyType, item.key()};
Work velem = {valType, item.value()};
to_process.emplace_back(std::move(kelem));
to_process.emplace_back(std::move(velem));
}