mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[jit][edge] Migrate to TypeFactory for jit types on mobile (#71516)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/71516 Mobile should be able to contruct dynamic types by default. ghstack-source-id: 147498365 Test Plan: CI. **-48KB** binary size reduction for igios BSB. UMBEX link: https://www.internalfb.com/intern/unigraph/explorer/?jsgq_traversal_spec=%7B%22builds%22%3A[%22bsb%3A422553426218394%5Cu0040base%22%2C%22bsb%3A422553426218394%5Cu0040diff%22]%7D&unigraph_project=UnigraphProjectMbex&is_mbex_redirected Reviewed By: iseeyuan Differential Revision: D33673958 fbshipit-source-id: 8600c04ae929283681971aae264d3774188df9cd (cherry picked from commit 64ebcec09e69d2eff64fdbf926fb43d3b67f99b2)
This commit is contained in:
committed by
PyTorch MergeBot
parent
e5794974cb
commit
fe277b8717
@ -308,7 +308,7 @@ c10::intrusive_ptr<c10::ivalue::Future> intraop_launch_future(
|
||||
#else
|
||||
// TODO: caffe2::PThreadPool only provides a data-parallel API.
|
||||
// Task parallelism is not currently supported.
|
||||
auto future = c10::make_intrusive<c10::ivalue::Future>(NoneType::get());
|
||||
auto future = c10::make_intrusive<c10::ivalue::Future>(c10::dynT<NoneType>());
|
||||
func();
|
||||
future->markCompleted();
|
||||
return future;
|
||||
|
@ -3,6 +3,7 @@
|
||||
#include <ATen/core/jit_type.h>
|
||||
#include <ATen/core/function_schema.h>
|
||||
#include <ATen/core/functional.h>
|
||||
#include <ATen/core/type_factory.h>
|
||||
|
||||
#include <atomic>
|
||||
#include <unordered_map>
|
||||
@ -102,7 +103,7 @@ class_base::class_base(
|
||||
{
|
||||
detail::checkValidIdent(namespaceName, "Namespace name");
|
||||
detail::checkValidIdent(className, "Class name");
|
||||
classTypePtr->addAttribute("capsule", at::CapsuleType::get());
|
||||
classTypePtr->addAttribute("capsule", c10::TypeFactory::get<c10::CapsuleType>());
|
||||
c10::getCustomClassTypeMap().insert(
|
||||
{std::type_index(intrusivePtrClassTypeid), classTypePtr});
|
||||
c10::getCustomClassTypeMap().insert(
|
||||
|
@ -2,6 +2,7 @@
|
||||
|
||||
#include <string>
|
||||
|
||||
#include <ATen/core/class_type.h>
|
||||
#include <ATen/core/ivalue.h>
|
||||
#include <ATen/core/jit_type.h>
|
||||
#include <ATen/core/type_factory.h>
|
||||
@ -198,6 +199,11 @@ TypePtr DynamicType::containedType(size_t i) const {
|
||||
return arguments_.elems.at(i).ty;
|
||||
}
|
||||
|
||||
size_t DynamicType::containedTypeSize() const {
|
||||
TORCH_INTERNAL_ASSERT(tag_ != Tag::Class);
|
||||
return arguments_.elems.size();
|
||||
}
|
||||
|
||||
TypeKind DynamicType::dynamicKind() const {
|
||||
switch (tag_) {
|
||||
#define CASE_TYPE(T, _, __) \
|
||||
@ -271,6 +277,16 @@ TypePtr DynamicType::fallback() const {
|
||||
return VarType::create(*name_);
|
||||
case Tag::AnyClass:
|
||||
return AnyClassType::get();
|
||||
case Tag::QScheme:
|
||||
return QSchemeType::get();
|
||||
case Tag::Quantizer:
|
||||
return QuantizerType::get();
|
||||
case Tag::AnyEnum:
|
||||
return AnyEnumType::get();
|
||||
case Tag::RRef:
|
||||
return RRefType::create(arguments_.elems[0].ty->fallback());
|
||||
case Tag::Future:
|
||||
return FutureType::create(arguments_.elems[0].ty->fallback());
|
||||
case Tag::Any:
|
||||
return AnyType::get();
|
||||
}
|
||||
|
@ -3,8 +3,6 @@
|
||||
#include <memory>
|
||||
#include <type_traits>
|
||||
|
||||
#include <ATen/core/class_type.h>
|
||||
#include <ATen/core/ivalue.h>
|
||||
#include <ATen/core/jit_type_base.h>
|
||||
#include <c10/util/Optional.h>
|
||||
|
||||
@ -53,8 +51,17 @@ constexpr DynamicTypeBits kDynamicClassTypeBit = DYNAMIC_TYPE_BIT(10);
|
||||
_(Storage, DYNAMIC_TYPE_BIT(16), 1) \
|
||||
_(Var, DYNAMIC_TYPE_BIT(17), 0) \
|
||||
_(AnyClass, (kDynamicClassTypeBit | kDynamicAnyTypeBit), 1) \
|
||||
_(QScheme, DYNAMIC_TYPE_BIT(18), 1) \
|
||||
_(Quantizer, DYNAMIC_TYPE_BIT(19), 1) \
|
||||
_(AnyEnum, DYNAMIC_TYPE_BIT(20), 1) \
|
||||
_(RRef, DYNAMIC_TYPE_BIT(21), 0) \
|
||||
_(Future, DYNAMIC_TYPE_BIT(22), 0) \
|
||||
_(Any, 0xffffffff, 1)
|
||||
|
||||
#define FORWARD_DECL_TYPE(NAME, _, __) struct NAME ## Type;
|
||||
FORALL_DYNAMIC_TYPES(FORWARD_DECL_TYPE)
|
||||
#undef FORWARD_DECL_TYPE
|
||||
|
||||
class DynamicType;
|
||||
using DynamicTypePtr = std::shared_ptr<DynamicType>;
|
||||
|
||||
@ -142,6 +149,7 @@ class DynamicType : public SharedType {
|
||||
explicit DynamicType(Tag, c10::string_view, Arguments);
|
||||
|
||||
TypePtr containedType(size_t) const override;
|
||||
size_t containedTypeSize() const override;
|
||||
Tag tag() const {
|
||||
return tag_;
|
||||
}
|
||||
@ -154,6 +162,9 @@ class DynamicType : public SharedType {
|
||||
TypeKind dynamicKind() const;
|
||||
|
||||
// Should be used only on the server side to restore static type information.
|
||||
#ifndef C10_MOBILE
|
||||
TORCH_API
|
||||
#endif
|
||||
TypePtr fallback() const;
|
||||
|
||||
private:
|
||||
@ -188,7 +199,7 @@ class DynamicType : public SharedType {
|
||||
|
||||
template <typename T>
|
||||
struct DynamicTypeTrait {
|
||||
static auto tagValue() {
|
||||
C10_NOINLINE static auto tagValue() {
|
||||
TORCH_CHECK(false);
|
||||
return DynamicType::Tag::Any;
|
||||
}
|
||||
@ -201,7 +212,7 @@ C10_NOINLINE DynamicTypePtr makeBaseType(DynamicType::Tag tag);
|
||||
#define DYNAMIC_TYPE_TAG_VALUE(NAME, _, IS_BASE_TYPE) \
|
||||
template <> \
|
||||
struct TORCH_API DynamicTypeTrait<NAME##Type> { \
|
||||
static auto tagValue() { \
|
||||
C10_ERASE static auto tagValue() { \
|
||||
return DynamicType::Tag::NAME; \
|
||||
} \
|
||||
static constexpr bool isBaseType = IS_BASE_TYPE; \
|
||||
@ -214,19 +225,4 @@ C10_NOINLINE DynamicTypePtr makeBaseType(DynamicType::Tag tag);
|
||||
FORALL_DYNAMIC_TYPES(DYNAMIC_TYPE_TAG_VALUE)
|
||||
#undef DYNAMIC_TYPE_TAG_VALUE
|
||||
|
||||
template <>
|
||||
struct IValue::TagType<c10::DynamicType> {
|
||||
static DynamicType::Ptr get(const c10::IValue& v);
|
||||
};
|
||||
|
||||
namespace ivalue {
|
||||
|
||||
template <>
|
||||
struct TORCH_API TupleTypeFactory<c10::DynamicType> {
|
||||
static DynamicTypePtr create(std::vector<TypePtr> elemTypes);
|
||||
static DynamicTypePtr fallback(const Type&);
|
||||
};
|
||||
|
||||
} // namespace ivalue
|
||||
|
||||
} // namespace c10
|
||||
|
@ -390,7 +390,7 @@ struct FunctionSchema {
|
||||
|
||||
// Check that inputs have the correct types and appends any missing default
|
||||
// values.
|
||||
template <typename T = c10::Type>
|
||||
template <typename T = c10::PlatformType>
|
||||
void checkAndNormalizeInputs(
|
||||
std::vector<IValue>& inputs,
|
||||
const std::unordered_map<std::string, IValue>& kwargs =
|
||||
|
@ -293,7 +293,7 @@ inline void FunctionSchema::checkArg(
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
formatTypeMismatchMsg(
|
||||
argument, value.type()->repr_str(), pos));
|
||||
argument, value.type<T>()->repr_str(), pos));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -6,6 +6,7 @@
|
||||
#include <ATen/core/function.h>
|
||||
#include <ATen/core/jit_type.h>
|
||||
#include <ATen/core/stack.h>
|
||||
#include <ATen/core/type_factory.h>
|
||||
#include <c10/util/irange.h>
|
||||
#include <c10/util/StringUtil.h>
|
||||
#include <c10/util/hash.h>
|
||||
@ -403,6 +404,39 @@ bool IValue::is(const IValue& rhs) const {
|
||||
return lhs == rhs;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline bool IValue::isListOf() const {
|
||||
// note: avoids calling type() to avoid extra referencing counting for the returned type.
|
||||
if (!isList()) {
|
||||
return false;
|
||||
}
|
||||
const auto& ty = static_cast<detail::ListImpl*>(payload.u.as_intrusive_ptr)->elementType;
|
||||
if (ty->kind() == T::Kind) {
|
||||
return true;
|
||||
}
|
||||
return *ty == *TypeFactory::get<T>();
|
||||
}
|
||||
|
||||
bool IValue::isDoubleList() const {
|
||||
return isListOf<c10::FloatType>();
|
||||
}
|
||||
|
||||
bool IValue::isComplexDoubleList() const {
|
||||
return isListOf<c10::ComplexType>();
|
||||
}
|
||||
|
||||
bool IValue::isTensorList() const {
|
||||
return isListOf<c10::TensorType>();
|
||||
}
|
||||
|
||||
bool IValue::isIntList() const {
|
||||
return isListOf<c10::IntType>();
|
||||
}
|
||||
|
||||
bool IValue::isBoolList() const {
|
||||
return isListOf<c10::BoolType>();
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
using IValueFormatter = std::function<void(std::ostream&, const IValue&)>;
|
||||
@ -430,7 +464,7 @@ std::ostream& printMaybeAnnotatedList(
|
||||
std::ostream& out,
|
||||
const IValue& the_list,
|
||||
IValueFormatter formatter) {
|
||||
auto list_elem_type = the_list.type()->expectRef<ListType>().getElementType();
|
||||
auto list_elem_type = the_list.type()->containedType(0);
|
||||
if (the_list.toListRef().size() == 0 ||
|
||||
!elementTypeCanBeInferredFromMembers(list_elem_type)) {
|
||||
out << "annotate(" << the_list.type()->annotation_str() << ", ";
|
||||
@ -925,7 +959,7 @@ c10::intrusive_ptr<ivalue::Object> ivalue::Object::deepcopy(IValue::HashAliasedI
|
||||
auto cu = type_.cu_;
|
||||
auto object = ivalue::Object::create(WeakOrStrongTypePtr(type_.cu_, type_.type_), type()->numAttributes());
|
||||
for (const auto i : c10::irange(slots_.size())) {
|
||||
if (slots_[i].type() == c10::CapsuleType::get()) {
|
||||
if (*slots_[i].type() == *c10::TypeFactory::get<CapsuleType>()) {
|
||||
// If we've gotten here, it means that we have *not* copied this
|
||||
// class via __getstate__ and __setstate__. That fact and the
|
||||
// fact that we have a Capsule attribute mean that this is a
|
||||
|
@ -6,6 +6,7 @@
|
||||
#include <ATen/core/custom_class.h>
|
||||
#include <ATen/core/ivalue_to.h>
|
||||
#include <ATen/core/jit_type_base.h>
|
||||
#include <ATen/core/type_factory.h>
|
||||
#include <c10/util/C++17.h>
|
||||
#include <c10/util/MaybeOwned.h>
|
||||
#include <c10/util/intrusive_ptr.h>
|
||||
@ -895,8 +896,8 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T = c10::Type>
|
||||
typename T::Ptr type() const;
|
||||
template <typename T = c10::PlatformType>
|
||||
TypePtr type() const;
|
||||
|
||||
// Detect aliased tensors.
|
||||
struct HashAliasedIValue {
|
||||
|
@ -586,6 +586,12 @@ struct TORCH_API TupleTypeFactory<TupleType> {
|
||||
static TupleTypePtr fallback(const Type& type);
|
||||
};
|
||||
|
||||
template <>
|
||||
struct TORCH_API TupleTypeFactory<c10::DynamicType> {
|
||||
static DynamicTypePtr create(std::vector<TypePtr> elemTypes);
|
||||
static DynamicTypePtr fallback(const Type&);
|
||||
};
|
||||
|
||||
struct TORCH_API Tuple : c10::intrusive_ptr_target {
|
||||
private:
|
||||
TupleElements elements_;
|
||||
@ -1915,39 +1921,6 @@ inline ivalue::Tuple& IValue::toTupleRef() const {
|
||||
payload.u.as_intrusive_ptr);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline bool IValue::isListOf() const {
|
||||
// note: avoids calling type() to avoid extra referencing counting for the returned type.
|
||||
if (!isList()) {
|
||||
return false;
|
||||
}
|
||||
const auto& ty = static_cast<detail::ListImpl*>(payload.u.as_intrusive_ptr)->elementType;
|
||||
if (ty->kind() == T::Kind) {
|
||||
return true;
|
||||
}
|
||||
return *ty == *T::get();
|
||||
}
|
||||
|
||||
inline bool IValue::isDoubleList() const {
|
||||
return isListOf<c10::FloatType>();
|
||||
}
|
||||
|
||||
inline bool IValue::isComplexDoubleList() const {
|
||||
return isListOf<c10::ComplexType>();
|
||||
}
|
||||
|
||||
inline bool IValue::isTensorList() const {
|
||||
return isListOf<c10::TensorType>();
|
||||
}
|
||||
|
||||
inline bool IValue::isIntList() const {
|
||||
return isListOf<c10::IntType>();
|
||||
}
|
||||
|
||||
inline bool IValue::isBoolList() const {
|
||||
return isListOf<c10::BoolType>();
|
||||
}
|
||||
|
||||
inline IValue::IValue(c10::intrusive_ptr<ivalue::Tuple> v)
|
||||
: tag(Tag::Tuple), is_intrusive_ptr(true) {
|
||||
payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release());
|
||||
@ -2285,8 +2258,13 @@ struct IValue::TagType<c10::Type> {
|
||||
static TORCH_API c10::TypePtr get(const IValue&);
|
||||
};
|
||||
|
||||
template <>
|
||||
struct IValue::TagType<c10::DynamicType> {
|
||||
static TORCH_API c10::TypePtr get(const IValue&);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
typename T::Ptr IValue::type() const {
|
||||
TypePtr IValue::type() const {
|
||||
return IValue::TagType<T>::get(*this);
|
||||
}
|
||||
|
||||
|
@ -5,6 +5,7 @@
|
||||
#include <ATen/core/TensorBody.h>
|
||||
#include <ATen/core/functional.h>
|
||||
#include <ATen/core/symbol.h>
|
||||
#include <ATen/core/type_factory.h>
|
||||
#include <ATen/core/qualified_name.h>
|
||||
#include <c10/util/TypeList.h>
|
||||
#include <c10/util/Optional.h>
|
||||
@ -1730,7 +1731,8 @@ struct getTypePtr_<c10::QScheme> final {
|
||||
template <>
|
||||
struct getTypePtr_<at::Generator> final {
|
||||
static decltype(auto) call() {
|
||||
return OptionalType::create(GeneratorType::get());
|
||||
return TypeFactory::create<OptionalType>(
|
||||
TypeFactory::get<GeneratorType>());
|
||||
}
|
||||
};
|
||||
template <>
|
||||
@ -1798,7 +1800,8 @@ struct getTypePtr_<c10::Dict<K, V>> final {
|
||||
template <class T>
|
||||
struct getTypePtr_<at::optional<T>> final {
|
||||
static const auto& call() {
|
||||
static auto type = OptionalType::create(getTypePtr_<T>::call());
|
||||
static auto type = TypeFactory::create<OptionalType>(
|
||||
getTypePtr_<T>::call());
|
||||
return type;
|
||||
}
|
||||
};
|
||||
|
@ -558,6 +558,9 @@ struct TORCH_API Type {
|
||||
virtual TypePtr containedType(size_t i) const {
|
||||
return containedTypes().at(i);
|
||||
}
|
||||
virtual size_t containedTypeSize() const {
|
||||
return containedTypes().size();
|
||||
}
|
||||
// create a new version of this type, replacing its contained types with
|
||||
// contained_types
|
||||
TypePtr withContained(std::vector<TypePtr> contained_types);
|
||||
|
@ -1,5 +1,7 @@
|
||||
#include <ATen/core/type_factory.h>
|
||||
|
||||
#include <ATen/core/jit_type.h>
|
||||
|
||||
namespace c10 {
|
||||
|
||||
// Dtype constraints are not constrained in compilation. Therefore, we map
|
||||
@ -56,4 +58,11 @@ const std::unordered_map<std::string, c10::TypePtr>& DefaultTypeFactory::
|
||||
return map;
|
||||
}
|
||||
|
||||
c10::TypePtr DefaultTypeFactory::createNamedTuple(
|
||||
const std::string& name,
|
||||
const std::vector<c10::string_view>& fields,
|
||||
const std::vector<c10::TypePtr>& types) {
|
||||
return c10::TupleType::createNamed(name, fields, types);
|
||||
}
|
||||
|
||||
} // namespace c10
|
||||
|
@ -1,12 +1,19 @@
|
||||
#pragma once
|
||||
|
||||
#include <ATen/core/dynamic_type.h>
|
||||
#include <ATen/core/jit_type.h>
|
||||
#include <type_traits>
|
||||
#include <unordered_map>
|
||||
|
||||
#include <ATen/core/dynamic_type.h>
|
||||
#include <ATen/core/jit_type_base.h>
|
||||
#include <c10/macros/Macros.h>
|
||||
|
||||
namespace c10 {
|
||||
|
||||
struct TORCH_API DynamicTypeFactory {
|
||||
template <typename T>
|
||||
struct TORCH_API TypeFactoryBase {};
|
||||
|
||||
template <>
|
||||
struct TORCH_API TypeFactoryBase<c10::DynamicType> {
|
||||
template <typename T, typename... Args>
|
||||
static c10::DynamicTypePtr create(TypePtr ty, Args&&... args) {
|
||||
return std::make_shared<c10::DynamicType>(
|
||||
@ -29,26 +36,40 @@ struct TORCH_API DynamicTypeFactory {
|
||||
name,
|
||||
c10::DynamicType::Arguments(fields, types));
|
||||
}
|
||||
template <typename T>
|
||||
C10_ERASE static c10::DynamicTypePtr createNamed(const std::string& name) {
|
||||
return std::make_shared<c10::DynamicType>(
|
||||
c10::DynamicTypeTrait<T>::tagValue(),
|
||||
name,
|
||||
c10::DynamicType::Arguments{});
|
||||
}
|
||||
template <typename T>
|
||||
C10_ERASE static c10::DynamicTypePtr get() {
|
||||
return DynamicTypeTrait<T>::getBaseType();
|
||||
}
|
||||
static const std::unordered_map<std::string, c10::TypePtr>& basePythonTypes();
|
||||
};
|
||||
|
||||
using DynamicTypeFactory = TypeFactoryBase<c10::DynamicType>;
|
||||
|
||||
// Helper functions for constructing DynamicTypes inline.
|
||||
template <
|
||||
typename T,
|
||||
std::enable_if_t<DynamicTypeTrait<T>::isBaseType, int> = 0>
|
||||
DynamicTypePtr dynT() {
|
||||
return DynamicTypeTrait<T>::getBaseType();
|
||||
C10_ERASE DynamicTypePtr dynT() {
|
||||
return DynamicTypeFactory::get<T>();
|
||||
}
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename... Args,
|
||||
std::enable_if_t<!DynamicTypeTrait<T>::isBaseType, int> = 0>
|
||||
DynamicTypePtr dynT(Args&&... args) {
|
||||
C10_ERASE DynamicTypePtr dynT(Args&&... args) {
|
||||
return DynamicTypeFactory::create<T>(std::forward<Args>(args)...);
|
||||
}
|
||||
|
||||
struct TORCH_API DefaultTypeFactory {
|
||||
template <>
|
||||
struct TORCH_API TypeFactoryBase<c10::Type> {
|
||||
template <typename T, typename... Args>
|
||||
static c10::TypePtr create(TypePtr ty, Args&&... args) {
|
||||
return T::create(std::move(ty), std::forward<Args>(args)...);
|
||||
@ -60,18 +81,28 @@ struct TORCH_API DefaultTypeFactory {
|
||||
static c10::TypePtr createNamedTuple(
|
||||
const std::string& name,
|
||||
const std::vector<c10::string_view>& fields,
|
||||
const std::vector<c10::TypePtr>& types) {
|
||||
return c10::TupleType::createNamed(name, fields, types);
|
||||
const std::vector<c10::TypePtr>& types);
|
||||
template <typename T>
|
||||
C10_ERASE static c10::TypePtr createNamed(const std::string& name) {
|
||||
return T::create(name);
|
||||
}
|
||||
static const std::unordered_map<std::string, c10::TypePtr>& basePythonTypes();
|
||||
template <typename T>
|
||||
C10_ERASE static c10::TypePtr get() {
|
||||
return T::get();
|
||||
}
|
||||
};
|
||||
|
||||
using TypeFactory =
|
||||
using DefaultTypeFactory = TypeFactoryBase<c10::Type>;
|
||||
|
||||
using PlatformType =
|
||||
#ifdef C10_MOBILE
|
||||
DynamicTypeFactory
|
||||
c10::DynamicType
|
||||
#else
|
||||
DefaultTypeFactory
|
||||
c10::Type
|
||||
#endif
|
||||
;
|
||||
|
||||
using TypeFactory = TypeFactoryBase<PlatformType>;
|
||||
|
||||
} // namespace c10
|
||||
|
@ -225,6 +225,16 @@ using namespace c10::hip;
|
||||
#define C10_ALWAYS_INLINE inline
|
||||
#endif
|
||||
|
||||
#if defined(_MSC_VER)
|
||||
#define C10_ATTR_VISIBILITY_HIDDEN
|
||||
#elif defined(__GNUC__)
|
||||
#define C10_ATTR_VISIBILITY_HIDDEN __attribute__((__visibility__("hidden")))
|
||||
#else
|
||||
#define C10_ATTR_VISIBILITY_HIDDEN
|
||||
#endif
|
||||
|
||||
#define C10_ERASE C10_ALWAYS_INLINE C10_ATTR_VISIBILITY_HIDDEN
|
||||
|
||||
// C10_FALLTHROUGH - Annotate fallthrough to the next case in a switch.
|
||||
#if C10_HAS_CPP_ATTRIBUTE(fallthrough)
|
||||
#define C10_FALLTHROUGH [[fallthrough]]
|
||||
|
@ -16,6 +16,9 @@ namespace torch {
|
||||
namespace jit {
|
||||
|
||||
static inline TypePtr unwrapOptional(TypePtr opt_type) {
|
||||
if (auto dyn = opt_type->castRaw<c10::DynamicType>()) {
|
||||
return unwrapOptional(dyn->fallback());
|
||||
}
|
||||
if (auto unwrap_list_type = opt_type->cast<OptionalType>()) {
|
||||
return unwrap_list_type->getElementType();
|
||||
}
|
||||
@ -282,12 +285,17 @@ static bool varargsCanBeUsedAsList(
|
||||
bool is_last_argument = arg_index + 1 == schema.arguments().size() ||
|
||||
schema.arguments()[arg_index + 1].kwarg_only();
|
||||
|
||||
auto arg_type = arg.type();
|
||||
if (auto dyn = arg_type->castRaw<c10::DynamicType>()) {
|
||||
arg_type = dyn->fallback();
|
||||
}
|
||||
|
||||
// The formal must be a list
|
||||
bool argument_is_list = arg.type()->kind() == TypeKind::ListType;
|
||||
bool argument_is_list = arg_type->kind() == TypeKind::ListType;
|
||||
|
||||
// matching varargs of typevar list nyi
|
||||
bool typevar_list = argument_is_list &&
|
||||
arg.type()->castRaw<ListType>()->getElementType()->cast<VarType>();
|
||||
arg_type->castRaw<ListType>()->getElementType()->cast<VarType>();
|
||||
|
||||
// it must not be a broadcasting list like int[3],
|
||||
// otherwise a single int is a valid input
|
||||
|
@ -41,32 +41,33 @@ namespace jit {
|
||||
|
||||
TypePtr SchemaTypeParser::parseBaseType() {
|
||||
static std::unordered_map<std::string, TypePtr> type_map = {
|
||||
{"Generator", GeneratorType::get()},
|
||||
{"Dimname", StringType::get()},
|
||||
{"ScalarType", IntType::get()},
|
||||
{"Layout", IntType::get()},
|
||||
{"MemoryFormat", IntType::get()},
|
||||
{"Storage", StorageType::get()},
|
||||
{"QScheme", QSchemeType::get()},
|
||||
{"Quantizer", QuantizerType::get()},
|
||||
{"Generator", c10::TypeFactory::get<GeneratorType>()},
|
||||
{"Dimname", c10::TypeFactory::get<StringType>()},
|
||||
{"ScalarType", c10::TypeFactory::get<IntType>()},
|
||||
{"Layout", c10::TypeFactory::get<IntType>()},
|
||||
{"MemoryFormat", c10::TypeFactory::get<IntType>()},
|
||||
{"Storage", c10::TypeFactory::get<StorageType>()},
|
||||
{"QScheme", c10::TypeFactory::get<QSchemeType>()},
|
||||
{"Quantizer", c10::TypeFactory::get<QuantizerType>()},
|
||||
{"ConstQuantizerPtr",
|
||||
IntType::get()}, // TODO This type should be removed from the schema
|
||||
// parser, it should use the custom class mechanism
|
||||
// instead. @jerryzh
|
||||
{"Device", DeviceObjType::get()},
|
||||
{"Stream", StreamObjType::get()},
|
||||
{"Scalar", NumberType::get()},
|
||||
{"str", StringType::get()},
|
||||
{"float", FloatType::get()},
|
||||
{"complex", ComplexType::get()},
|
||||
{"int", IntType::get()},
|
||||
{"bool", BoolType::get()},
|
||||
{"None", NoneType::get()},
|
||||
{"NoneType", NoneType::get()},
|
||||
{"Capsule", CapsuleType::get()},
|
||||
{"Any", at::AnyType::get()},
|
||||
{"AnyClassType", at::AnyClassType::get()},
|
||||
{"AnyEnumType", at::AnyEnumType::get()},
|
||||
c10::TypeFactory::get<IntType>()}, // TODO This type should be removed
|
||||
// from the schema parser, it should
|
||||
// use the custom class mechanism
|
||||
// instead. @jerryzh
|
||||
{"Device", c10::TypeFactory::get<DeviceObjType>()},
|
||||
{"Stream", c10::TypeFactory::get<StreamObjType>()},
|
||||
{"Scalar", c10::TypeFactory::get<NumberType>()},
|
||||
{"str", c10::TypeFactory::get<StringType>()},
|
||||
{"float", c10::TypeFactory::get<FloatType>()},
|
||||
{"complex", c10::TypeFactory::get<ComplexType>()},
|
||||
{"int", c10::TypeFactory::get<IntType>()},
|
||||
{"bool", c10::TypeFactory::get<BoolType>()},
|
||||
{"None", c10::TypeFactory::get<NoneType>()},
|
||||
{"NoneType", c10::TypeFactory::get<NoneType>()},
|
||||
{"Capsule", c10::TypeFactory::get<CapsuleType>()},
|
||||
{"Any", c10::TypeFactory::get<c10::AnyType>()},
|
||||
{"AnyClassType", c10::TypeFactory::get<c10::AnyClassType>()},
|
||||
{"AnyEnumType", c10::TypeFactory::get<c10::AnyEnumType>()},
|
||||
};
|
||||
auto tok = L.cur();
|
||||
if (!L.nextIf(TK_NONE) && !L.nextIf(TK_NONE_TYPE)) {
|
||||
@ -79,7 +80,7 @@ TypePtr SchemaTypeParser::parseBaseType() {
|
||||
if (text.size() > 0 && islower(text[0])) {
|
||||
// lower case identifiers that are not otherwise valid types
|
||||
// are treated as type variables
|
||||
return VarType::create(text);
|
||||
return c10::TypeFactory::createNamed<VarType>(text);
|
||||
}
|
||||
throw ErrorReport(tok.range) << "unknown type specifier";
|
||||
}
|
||||
@ -313,7 +314,7 @@ std::pair<TypePtr, c10::optional<AliasInfo>> SchemaTypeParser::parseType() {
|
||||
alias_info->addContainedType(std::move(*r.second));
|
||||
}
|
||||
});
|
||||
value = TupleType::create(std::move(types));
|
||||
value = c10::TypeFactory::create<TupleType>(std::move(types));
|
||||
} else if (L.cur().kind == TK_IDENT && L.cur().text() == "Future") {
|
||||
L.next(); // Future
|
||||
L.expect('(');
|
||||
@ -321,7 +322,7 @@ std::pair<TypePtr, c10::optional<AliasInfo>> SchemaTypeParser::parseType() {
|
||||
auto subtype = std::move(p.first);
|
||||
auto subalias = std::move(p.second);
|
||||
L.expect(')');
|
||||
value = FutureType::create(subtype);
|
||||
value = c10::TypeFactory::create<FutureType>(subtype);
|
||||
} else if (L.cur().kind == TK_IDENT && L.cur().text() == "RRef") {
|
||||
L.next(); // RRef
|
||||
L.expect('(');
|
||||
@ -329,10 +330,10 @@ std::pair<TypePtr, c10::optional<AliasInfo>> SchemaTypeParser::parseType() {
|
||||
auto subtype = std::move(p.first);
|
||||
auto subalias = std::move(p.second);
|
||||
L.expect(')');
|
||||
value = RRefType::create(subtype);
|
||||
value = c10::TypeFactory::create<RRefType>(subtype);
|
||||
} else if (L.cur().kind == TK_IDENT && L.cur().text() == "Tensor") {
|
||||
L.next();
|
||||
value = TensorType::get();
|
||||
value = c10::TypeFactory::get<TensorType>();
|
||||
alias_info = parseAliasAnnotation();
|
||||
} else if (L.cur().kind == TK_IDENT && L.cur().text() == "Dict") {
|
||||
L.next();
|
||||
@ -342,7 +343,7 @@ std::pair<TypePtr, c10::optional<AliasInfo>> SchemaTypeParser::parseType() {
|
||||
auto value_type = parseType().first;
|
||||
L.expect(')');
|
||||
alias_info = parseAliasAnnotation();
|
||||
value = DictType::create(key_type, value_type);
|
||||
value = c10::TypeFactory::create<DictType>(key_type, value_type);
|
||||
} else if (L.cur().kind == TK_IDENT && L.cur().text() == "Union") {
|
||||
L.next();
|
||||
L.expect('(');
|
||||
@ -395,7 +396,7 @@ std::pair<TypePtr, c10::optional<AliasInfo>> SchemaTypeParser::parseType() {
|
||||
if (L.cur().kind == '[' && L.lookahead().kind == ']') {
|
||||
L.next(); // [
|
||||
L.next(); // ]
|
||||
value = ListType::create(value);
|
||||
value = c10::TypeFactory::create<ListType>(value);
|
||||
auto container = parseAliasAnnotation();
|
||||
if (container && alias_info) {
|
||||
container->addContainedType(std::move(*alias_info));
|
||||
|
@ -1485,6 +1485,9 @@ inline Value::Value(Node* node_, size_t offset_)
|
||||
|
||||
inline Value* Value::setType(TypePtr type) {
|
||||
AT_ASSERT(type);
|
||||
if (auto dyn = type->castRaw<c10::DynamicType>()) {
|
||||
type = dyn->fallback();
|
||||
}
|
||||
type_ = std::move(type);
|
||||
for (Use& use : uses_) {
|
||||
use.user->op_ = nullptr;
|
||||
|
@ -108,7 +108,11 @@ std::pair<IValue, IValue> getFunctionTuple(
|
||||
static const std::string torch_prefix("__torch__");
|
||||
static const std::string class_prefix("__torch__.torch.classes");
|
||||
|
||||
for (const TypePtr& t : mobile_code.types_) {
|
||||
for (const TypePtr& ty : mobile_code.types_) {
|
||||
auto t = ty;
|
||||
if (auto dyn = t->castRaw<c10::DynamicType>()) {
|
||||
t = dyn->fallback();
|
||||
}
|
||||
std::string type_str = t->annotation_str();
|
||||
if (t->kind() == TypeKind::TupleType) {
|
||||
TORCH_CHECK(
|
||||
@ -216,9 +220,13 @@ std::pair<IValue, IValue> getFunctionTuple(
|
||||
arg.type()->annotation_str(type_printer) => mangled unique name of the
|
||||
module/submodule
|
||||
*/
|
||||
auto arg_type = arg.type();
|
||||
if (auto dyn = arg_type->castRaw<c10::DynamicType>()) {
|
||||
arg_type = dyn->fallback();
|
||||
}
|
||||
argTables.emplace_back(Table({
|
||||
{"name", arg.name()},
|
||||
{"type", arg.type()->annotation_str(type_printer)},
|
||||
{"type", arg_type->annotation_str(type_printer)},
|
||||
{"default_value", arg.default_value()},
|
||||
}));
|
||||
}
|
||||
|
@ -575,7 +575,11 @@ void Pickler::endTypeTag(const IValue& ivalue) {
|
||||
|
||||
// Push the dict type
|
||||
TORCH_INTERNAL_ASSERT(ivalue.type());
|
||||
pushString(ivalue.type()->annotation_str());
|
||||
auto type = ivalue.type();
|
||||
if (auto dyn = type->castRaw<c10::DynamicType>()) {
|
||||
type = dyn->fallback();
|
||||
}
|
||||
pushString(type->annotation_str());
|
||||
|
||||
// Pop the dict and type into a tuple
|
||||
push<PickleOpCode>(PickleOpCode::TUPLE2);
|
||||
|
@ -34,7 +34,7 @@ static void restoreAccurateTypeTagsIfPossible(const IValue& root) {
|
||||
// of the contained objects and cannot restore the tags.
|
||||
void restoreAccurateTypeTags(const IValue& root, const TypePtr& type_tag) {
|
||||
struct Work {
|
||||
TypePtr static_type;
|
||||
TypePtr type;
|
||||
IValue value;
|
||||
};
|
||||
std::vector<Work> to_process = {{type_tag, root}};
|
||||
@ -53,7 +53,11 @@ void restoreAccurateTypeTags(const IValue& root, const TypePtr& type_tag) {
|
||||
}
|
||||
scanned.emplace_hint(it, key);
|
||||
}
|
||||
switch (w.static_type->kind()) {
|
||||
auto kind = w.type->kind();
|
||||
if (auto dyn = w.type->castRaw<c10::DynamicType>()) {
|
||||
kind = dyn->dynamicKind();
|
||||
}
|
||||
switch (kind) {
|
||||
case TensorType::Kind:
|
||||
case StorageType::Kind:
|
||||
case NumberType::Kind:
|
||||
@ -83,52 +87,37 @@ void restoreAccurateTypeTags(const IValue& root, const TypePtr& type_tag) {
|
||||
// no op, there is nothing to tag
|
||||
break;
|
||||
case DynamicType::Kind:
|
||||
case UnionType::Kind:
|
||||
case EnumType::Kind:
|
||||
// TODO(gmagogsfm): Implement serialization/deserialization of Enum.
|
||||
TORCH_INTERNAL_ASSERT(false);
|
||||
case TupleType::Kind: {
|
||||
auto t = w.value.toTuple();
|
||||
auto ttype = w.static_type->expect<TupleType>();
|
||||
for (size_t i = 0; i < ttype->containedTypes().size(); ++i) {
|
||||
Work elem = {ttype->containedTypes().at(i), t->elements().at(i)};
|
||||
for (size_t i = 0; i < w.type->containedTypeSize(); ++i) {
|
||||
Work elem = {w.type->containedType(i), t->elements().at(i)};
|
||||
to_process.emplace_back(std::move(elem));
|
||||
}
|
||||
} break;
|
||||
case FutureType::Kind: {
|
||||
auto f = w.value.toFuture();
|
||||
auto t = w.static_type->expect<FutureType>();
|
||||
if (f->completed()) {
|
||||
Work elem = {t->getElementType(), f->value()};
|
||||
Work elem = {w.type->containedType(0), f->value()};
|
||||
to_process.emplace_back(std::move(elem));
|
||||
}
|
||||
} break;
|
||||
case OptionalType::Kind: {
|
||||
if (!w.value.isNone()) {
|
||||
auto t = w.static_type->expect<OptionalType>();
|
||||
Work elem = {t->getElementType(), w.value};
|
||||
Work elem = {w.type->containedType(0), w.value};
|
||||
to_process.emplace_back(std::move(elem));
|
||||
}
|
||||
} break;
|
||||
case UnionType::Kind: {
|
||||
auto t = w.static_type->expect<UnionType>();
|
||||
if (t->containedTypes().size() == 2 &&
|
||||
t->canHoldType(*NoneType::get())) {
|
||||
if (!w.value.isNone()) {
|
||||
auto inner = t->containedTypes()[0] != NoneType::get()
|
||||
? t->containedTypes()[0]
|
||||
: t->containedTypes()[1];
|
||||
Work elem = {inner, w.value};
|
||||
to_process.emplace_back(std::move(elem));
|
||||
}
|
||||
}
|
||||
} break;
|
||||
case ListType::Kind: {
|
||||
// specialized lists do not need their type refined, so we can exit
|
||||
// early here
|
||||
if (!w.value.isList()) {
|
||||
break;
|
||||
}
|
||||
auto elem_type = w.static_type->castRaw<ListType>()->getElementType();
|
||||
auto elem_type = w.type->containedType(0);
|
||||
auto lst = w.value.toList();
|
||||
lst.unsafeSetElementType(elem_type);
|
||||
for (const IValue item : lst) {
|
||||
@ -137,13 +126,14 @@ void restoreAccurateTypeTags(const IValue& root, const TypePtr& type_tag) {
|
||||
}
|
||||
} break;
|
||||
case DictType::Kind: {
|
||||
auto dt = w.static_type->cast<DictType>();
|
||||
auto d = w.value.toGenericDict();
|
||||
d.unsafeSetKeyType(dt->getKeyType());
|
||||
d.unsafeSetValueType(dt->getValueType());
|
||||
auto keyType = w.type->containedType(0);
|
||||
auto valType = w.type->containedType(1);
|
||||
d.unsafeSetKeyType(keyType);
|
||||
d.unsafeSetValueType(valType);
|
||||
for (const auto& item : d) {
|
||||
Work kelem = {dt->getKeyType(), item.key()};
|
||||
Work velem = {dt->getValueType(), item.value()};
|
||||
Work kelem = {keyType, item.key()};
|
||||
Work velem = {valType, item.value()};
|
||||
to_process.emplace_back(std::move(kelem));
|
||||
to_process.emplace_back(std::move(velem));
|
||||
}
|
||||
|
Reference in New Issue
Block a user