mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[jit][edge] Use dynamic type instead of union types for schema parsers. (#70509)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/70509 TypeFactory will construct DynamicType when building on Edge platforms. We use this facility to make FunctionSchema return DynamicType all the time for OptionalType. We don't explicitly use DynamicTypeFactory everywhere because that requires too many changes and will split the entire aten codebase. ghstack-source-id: 146818621 Test Plan: CI Reviewed By: iseeyuan Differential Revision: D33306737 fbshipit-source-id: d7ce00b438f7c03b43945d578280cfd254b1f634
This commit is contained in:
committed by
Facebook GitHub Bot
parent
40121456af
commit
9465c24245
@ -1,6 +1,7 @@
|
||||
#include <torch/csrc/jit/frontend/function_schema_parser.h>
|
||||
|
||||
#include <ATen/core/Reduction.h>
|
||||
#include <ATen/core/type_factory.h>
|
||||
#include <c10/util/string_utils.h>
|
||||
#include <torch/csrc/jit/frontend/lexer.h>
|
||||
#include <torch/csrc/jit/frontend/parse_string_literal.h>
|
||||
@ -148,7 +149,7 @@ struct SchemaParser {
|
||||
}
|
||||
alias_info = std::move(container);
|
||||
if (L.nextIf('?')) {
|
||||
type = OptionalType::create(std::move(type));
|
||||
type = c10::TypeFactory::create<c10::OptionalType>(std::move(type));
|
||||
}
|
||||
}
|
||||
if (is_return) {
|
||||
@ -161,7 +162,7 @@ struct SchemaParser {
|
||||
} else {
|
||||
name = L.expect(TK_IDENT).text();
|
||||
if (L.nextIf('=')) {
|
||||
default_value = parseDefaultValue(type, N);
|
||||
default_value = parseDefaultValue(*type, type->kind(), N);
|
||||
}
|
||||
}
|
||||
return Argument(
|
||||
@ -172,7 +173,11 @@ struct SchemaParser {
|
||||
!is_return && kwarg_only,
|
||||
std::move(alias_info));
|
||||
}
|
||||
IValue parseSingleConstant(TypeKind kind) {
|
||||
IValue parseSingleConstant(const c10::Type& type, TypeKind kind) {
|
||||
if (kind == c10::TypeKind::DynamicType) {
|
||||
return parseSingleConstant(
|
||||
type, type.expectRef<c10::DynamicType>().dynamicKind());
|
||||
}
|
||||
switch (L.cur().kind) {
|
||||
case TK_TRUE:
|
||||
L.next();
|
||||
@ -227,6 +232,7 @@ struct SchemaParser {
|
||||
}
|
||||
}
|
||||
IValue convertToList(
|
||||
const c10::Type& type,
|
||||
TypeKind kind,
|
||||
const SourceRange& range,
|
||||
const std::vector<IValue>& vs) {
|
||||
@ -239,21 +245,24 @@ struct SchemaParser {
|
||||
return fmap(vs, [](const IValue& v) { return v.toInt(); });
|
||||
case TypeKind::BoolType:
|
||||
return fmap(vs, [](const IValue& v) { return v.toBool(); });
|
||||
case TypeKind::DynamicType:
|
||||
return convertToList(
|
||||
type, type.expectRef<c10::DynamicType>().dynamicKind(), range, vs);
|
||||
default:
|
||||
throw ErrorReport(range)
|
||||
<< "lists are only supported for float, int and complex types";
|
||||
}
|
||||
}
|
||||
IValue parseConstantList(TypeKind kind) {
|
||||
IValue parseConstantList(const c10::Type& type, TypeKind kind) {
|
||||
auto tok = L.expect('[');
|
||||
std::vector<IValue> vs;
|
||||
if (L.cur().kind != ']') {
|
||||
do {
|
||||
vs.push_back(parseSingleConstant(kind));
|
||||
vs.push_back(parseSingleConstant(type, kind));
|
||||
} while (L.nextIf(','));
|
||||
}
|
||||
L.expect(']');
|
||||
return convertToList(kind, tok.range, vs);
|
||||
return convertToList(type, kind, tok.range, vs);
|
||||
}
|
||||
|
||||
IValue parseTensorDefault(const SourceRange& range) {
|
||||
@ -261,10 +270,11 @@ struct SchemaParser {
|
||||
return IValue();
|
||||
}
|
||||
IValue parseDefaultValue(
|
||||
const TypePtr& arg_type,
|
||||
const c10::Type& arg_type,
|
||||
TypeKind kind,
|
||||
c10::optional<int32_t> arg_N) {
|
||||
auto range = L.cur().range;
|
||||
switch (arg_type->kind()) {
|
||||
switch (kind) {
|
||||
case TypeKind::TensorType:
|
||||
case TypeKind::GeneratorType:
|
||||
case TypeKind::QuantizerType: {
|
||||
@ -277,7 +287,7 @@ struct SchemaParser {
|
||||
case TypeKind::BoolType:
|
||||
case TypeKind::FloatType:
|
||||
case TypeKind::ComplexType:
|
||||
return parseSingleConstant(arg_type->kind());
|
||||
return parseSingleConstant(arg_type, kind);
|
||||
break;
|
||||
case TypeKind::DeviceObjType: {
|
||||
auto device_text =
|
||||
@ -286,17 +296,22 @@ struct SchemaParser {
|
||||
break;
|
||||
}
|
||||
case TypeKind::ListType: {
|
||||
auto elem_kind = arg_type->castRaw<ListType>()->getElementType();
|
||||
auto elem_type = arg_type.containedType(0);
|
||||
if (L.cur().kind == TK_IDENT) {
|
||||
return parseTensorDefault(range);
|
||||
} else if (arg_N && L.cur().kind != '[') {
|
||||
IValue v = parseSingleConstant(elem_kind->kind());
|
||||
IValue v = parseSingleConstant(*elem_type, elem_type->kind());
|
||||
std::vector<IValue> repeated(*arg_N, v);
|
||||
return convertToList(elem_kind->kind(), range, repeated);
|
||||
return convertToList(*elem_type, elem_type->kind(), range, repeated);
|
||||
} else {
|
||||
return parseConstantList(elem_kind->kind());
|
||||
return parseConstantList(*elem_type, elem_type->kind());
|
||||
}
|
||||
} break;
|
||||
case TypeKind::DynamicType:
|
||||
return parseDefaultValue(
|
||||
arg_type,
|
||||
arg_type.expectRef<c10::DynamicType>().dynamicKind(),
|
||||
arg_N);
|
||||
default:
|
||||
throw ErrorReport(range) << "unexpected type, file a bug report";
|
||||
}
|
||||
|
Reference in New Issue
Block a user