[jit][edge] Print correct type strings in code file for mobile models. (#71968)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71968

Right now when we output type to python files under `code/`, we directly write the dynamic type representation `Dynamic<>`, which causes server side to load an unsupported type. Instead we should do the fallback in export_module.cpp.
ghstack-source-id: 147856473

Test Plan:
CI
buck test //xplat/pytorch/mobile/test:test_read_all_mobile_model_configs
```
...
[       OK ] GeneralAndSpecial/BackPortTest.BackPortForChunkIdx/37 (39142 ms)
[ RUN      ] GeneralAndSpecial/BackPortTest.BackPortForChunkIdx/38
 total: 6 success: 6 failure: 0
[       OK ] GeneralAndSpecial/BackPortTest.BackPortForChunkIdx/38 (9651 ms)
[ RUN      ] GeneralAndSpecial/BackPortTest.BackPortForChunkIdx/39
 total: 4 success: 4 failure: 0
[       OK ] GeneralAndSpecial/BackPortTest.BackPortForChunkIdx/39 (5509 ms)
[----------] 40 tests from GeneralAndSpecial/BackPortTest (806244 ms total)

[----------] Global test environment tear-down
[==========] 41 tests from 2 test cases ran. (810453 ms total)
[  PASSED  ] 41 tests.
```

Reviewed By: pavithranrao

Differential Revision: D33830355

fbshipit-source-id: 0be608fadf14daa2b703f31118ab648cb7b75f9b
(cherry picked from commit 6d65049ae5ac1ef6a11d19de48dd4d926b793b34)
This commit is contained in:
Zhengxu Chen
2022-01-28 12:01:53 -08:00
committed by PyTorch MergeBot
parent 63429bf4b3
commit bc0e216d1f
3 changed files with 30 additions and 13 deletions

View File

@ -467,7 +467,7 @@ std::ostream& printMaybeAnnotatedList(
auto list_elem_type = the_list.type()->containedType(0);
if (the_list.toListRef().size() == 0 ||
!elementTypeCanBeInferredFromMembers(list_elem_type)) {
out << "annotate(" << the_list.type()->annotation_str() << ", ";
out << "annotate(" << the_list.type<c10::Type>()->annotation_str() << ", ";
printList(out, the_list.toListRef(), "[", "]", formatter);
out << ")";
return out;
@ -508,7 +508,7 @@ std::ostream& printMaybeAnnotatedDict(
auto value_type = the_dict.type()->castRaw<DictType>()->getValueType();
if (the_dict.toGenericDict().size() == 0 ||
!elementTypeCanBeInferredFromMembers(value_type)) {
out << "annotate(" << the_dict.type()->annotation_str() << ",";
out << "annotate(" << the_dict.type<c10::Type>()->annotation_str() << ",";
printDict(out, the_dict.toGenericDict(), formatter) << ")";
} else {
return printDict(out, the_dict.toGenericDict(), formatter);

View File

@ -706,6 +706,24 @@ void ScriptModuleSerializer::writeByteCode(
}
}
namespace {
c10::optional<std::string> type_printer(
const c10::Type& type,
torch::jit::TypeNameUniquer& type_name_uniquer) {
if (auto dyn = type.castRaw<c10::DynamicType>()) {
return dyn->fallback()->annotation_str(
[&](auto&& t) { return type_printer(t, type_name_uniquer); });
}
auto namedType = type.cast<c10::NamedType>();
if (namedType && namedType->name()) {
return type_name_uniquer.getUniqueName(namedType).qualifiedName();
}
return c10::nullopt;
}
} // namespace
void ScriptModuleSerializer::convertNamedType(
const c10::NamedTypePtr& class_type) {
if (converted_types_.count(class_type)) {
@ -716,20 +734,15 @@ void ScriptModuleSerializer::convertNamedType(
std::string qualifier = qualname.prefix();
PythonPrint* pp = file_streams_.find(qualifier);
auto type_printer = [&](const c10::Type& t) -> c10::optional<std::string> {
auto namedType = t.cast<c10::NamedType>();
if (namedType && namedType->name()) {
return type_name_uniquer_.getUniqueName(namedType).qualifiedName();
}
return c10::nullopt;
};
if (!pp) {
pp = &file_streams_.insert(
std::move(qualifier),
PythonPrint(
constant_table_,
class_deps_,
type_printer,
[&](const c10::Type& t) {
return type_printer(t, type_name_uniquer_);
},
/*enforce_importable=*/true));
}
pp->printNamedType(class_type);

View File

@ -922,15 +922,19 @@ struct PythonPrintImpl {
void printConstant(TaggedStringStream& stmt, const IValue& v) {
const auto customFormatter = [&](std::ostream& ss, const IValue& v) {
if (v.isTensor() || containsNonASCIIString(v) || v.isObject()) {
TORCH_INTERNAL_ASSERT(!v.type()->is_module());
TORCH_INTERNAL_ASSERT(!v.type<c10::Type>()->is_module());
ss << "CONSTANTS.c" << getOrAddConstant(v);
return true;
}
if (v.isTuple() && v.type()->expectRef<TupleType>().schema()) {
auto type = v.type();
if (auto dyn = type->castRaw<c10::DynamicType>()) {
type = dyn->fallback();
}
if (v.isTuple() && type->expectRef<TupleType>().schema()) {
// print the namedtuple constructor and let rest of tuple printing
// continue
ss << v.type()->expectRef<TupleType>().annotation_str(type_printer_);
ss << type->expectRef<TupleType>().annotation_str(type_printer_);
}
return false;
};