[jit][edge] Use dynamic type instead of union types for schema parsers. (#70509)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/70509

TypeFactory will construct DynamicType when building on Edge platforms. We use this facility to make FunctionSchema return DynamicType all the time for OptionalType. We don't explicitly use DynamicTypeFactory everywhere because that requires too many changes and will split the entire aten codebase.
ghstack-source-id: 146818621

Test Plan: CI

Reviewed By: iseeyuan

Differential Revision: D33306737

fbshipit-source-id: d7ce00b438f7c03b43945d578280cfd254b1f634
This commit is contained in:
Zhengxu Chen
2022-01-11 20:12:53 -08:00
committed by Facebook GitHub Bot
parent 40121456af
commit 9465c24245
9 changed files with 146 additions and 21 deletions

View File

@ -1,6 +1,7 @@
#include <torch/csrc/jit/frontend/function_schema_parser.h>
#include <ATen/core/Reduction.h>
#include <ATen/core/type_factory.h>
#include <c10/util/string_utils.h>
#include <torch/csrc/jit/frontend/lexer.h>
#include <torch/csrc/jit/frontend/parse_string_literal.h>
@ -148,7 +149,7 @@ struct SchemaParser {
}
alias_info = std::move(container);
if (L.nextIf('?')) {
type = OptionalType::create(std::move(type));
type = c10::TypeFactory::create<c10::OptionalType>(std::move(type));
}
}
if (is_return) {
@ -161,7 +162,7 @@ struct SchemaParser {
} else {
name = L.expect(TK_IDENT).text();
if (L.nextIf('=')) {
default_value = parseDefaultValue(type, N);
default_value = parseDefaultValue(*type, type->kind(), N);
}
}
return Argument(
@ -172,7 +173,11 @@ struct SchemaParser {
!is_return && kwarg_only,
std::move(alias_info));
}
IValue parseSingleConstant(TypeKind kind) {
IValue parseSingleConstant(const c10::Type& type, TypeKind kind) {
if (kind == c10::TypeKind::DynamicType) {
return parseSingleConstant(
type, type.expectRef<c10::DynamicType>().dynamicKind());
}
switch (L.cur().kind) {
case TK_TRUE:
L.next();
@ -227,6 +232,7 @@ struct SchemaParser {
}
}
IValue convertToList(
const c10::Type& type,
TypeKind kind,
const SourceRange& range,
const std::vector<IValue>& vs) {
@ -239,21 +245,24 @@ struct SchemaParser {
return fmap(vs, [](const IValue& v) { return v.toInt(); });
case TypeKind::BoolType:
return fmap(vs, [](const IValue& v) { return v.toBool(); });
case TypeKind::DynamicType:
return convertToList(
type, type.expectRef<c10::DynamicType>().dynamicKind(), range, vs);
default:
throw ErrorReport(range)
<< "lists are only supported for float, int and complex types";
}
}
IValue parseConstantList(TypeKind kind) {
IValue parseConstantList(const c10::Type& type, TypeKind kind) {
auto tok = L.expect('[');
std::vector<IValue> vs;
if (L.cur().kind != ']') {
do {
vs.push_back(parseSingleConstant(kind));
vs.push_back(parseSingleConstant(type, kind));
} while (L.nextIf(','));
}
L.expect(']');
return convertToList(kind, tok.range, vs);
return convertToList(type, kind, tok.range, vs);
}
IValue parseTensorDefault(const SourceRange& range) {
@ -261,10 +270,11 @@ struct SchemaParser {
return IValue();
}
IValue parseDefaultValue(
const TypePtr& arg_type,
const c10::Type& arg_type,
TypeKind kind,
c10::optional<int32_t> arg_N) {
auto range = L.cur().range;
switch (arg_type->kind()) {
switch (kind) {
case TypeKind::TensorType:
case TypeKind::GeneratorType:
case TypeKind::QuantizerType: {
@ -277,7 +287,7 @@ struct SchemaParser {
case TypeKind::BoolType:
case TypeKind::FloatType:
case TypeKind::ComplexType:
return parseSingleConstant(arg_type->kind());
return parseSingleConstant(arg_type, kind);
break;
case TypeKind::DeviceObjType: {
auto device_text =
@ -286,17 +296,22 @@ struct SchemaParser {
break;
}
case TypeKind::ListType: {
auto elem_kind = arg_type->castRaw<ListType>()->getElementType();
auto elem_type = arg_type.containedType(0);
if (L.cur().kind == TK_IDENT) {
return parseTensorDefault(range);
} else if (arg_N && L.cur().kind != '[') {
IValue v = parseSingleConstant(elem_kind->kind());
IValue v = parseSingleConstant(*elem_type, elem_type->kind());
std::vector<IValue> repeated(*arg_N, v);
return convertToList(elem_kind->kind(), range, repeated);
return convertToList(*elem_type, elem_type->kind(), range, repeated);
} else {
return parseConstantList(elem_kind->kind());
return parseConstantList(*elem_type, elem_type->kind());
}
} break;
case TypeKind::DynamicType:
return parseDefaultValue(
arg_type,
arg_type.expectRef<c10::DynamicType>().dynamicKind(),
arg_N);
default:
throw ErrorReport(range) << "unexpected type, file a bug report";
}