mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[jit][edge] Print correct type strings in code file for mobile models. (#71968)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/71968 Right now when we output type to python files under `code/`, we directly write the dynamic type representation `Dynamic<>`, which causes server side to load an unsupported type. Instead we should do the fallback in export_module.cpp. ghstack-source-id: 147856473 Test Plan: CI buck test //xplat/pytorch/mobile/test:test_read_all_mobile_model_configs ``` ... [ OK ] GeneralAndSpecial/BackPortTest.BackPortForChunkIdx/37 (39142 ms) [ RUN ] GeneralAndSpecial/BackPortTest.BackPortForChunkIdx/38 total: 6 success: 6 failure: 0 [ OK ] GeneralAndSpecial/BackPortTest.BackPortForChunkIdx/38 (9651 ms) [ RUN ] GeneralAndSpecial/BackPortTest.BackPortForChunkIdx/39 total: 4 success: 4 failure: 0 [ OK ] GeneralAndSpecial/BackPortTest.BackPortForChunkIdx/39 (5509 ms) [----------] 40 tests from GeneralAndSpecial/BackPortTest (806244 ms total) [----------] Global test environment tear-down [==========] 41 tests from 2 test cases ran. (810453 ms total) [ PASSED ] 41 tests. ``` Reviewed By: pavithranrao Differential Revision: D33830355 fbshipit-source-id: 0be608fadf14daa2b703f31118ab648cb7b75f9b (cherry picked from commit 6d65049ae5ac1ef6a11d19de48dd4d926b793b34)
This commit is contained in:
committed by
PyTorch MergeBot
parent
63429bf4b3
commit
bc0e216d1f
@ -467,7 +467,7 @@ std::ostream& printMaybeAnnotatedList(
|
||||
auto list_elem_type = the_list.type()->containedType(0);
|
||||
if (the_list.toListRef().size() == 0 ||
|
||||
!elementTypeCanBeInferredFromMembers(list_elem_type)) {
|
||||
out << "annotate(" << the_list.type()->annotation_str() << ", ";
|
||||
out << "annotate(" << the_list.type<c10::Type>()->annotation_str() << ", ";
|
||||
printList(out, the_list.toListRef(), "[", "]", formatter);
|
||||
out << ")";
|
||||
return out;
|
||||
@ -508,7 +508,7 @@ std::ostream& printMaybeAnnotatedDict(
|
||||
auto value_type = the_dict.type()->castRaw<DictType>()->getValueType();
|
||||
if (the_dict.toGenericDict().size() == 0 ||
|
||||
!elementTypeCanBeInferredFromMembers(value_type)) {
|
||||
out << "annotate(" << the_dict.type()->annotation_str() << ",";
|
||||
out << "annotate(" << the_dict.type<c10::Type>()->annotation_str() << ",";
|
||||
printDict(out, the_dict.toGenericDict(), formatter) << ")";
|
||||
} else {
|
||||
return printDict(out, the_dict.toGenericDict(), formatter);
|
||||
|
@ -706,6 +706,24 @@ void ScriptModuleSerializer::writeByteCode(
|
||||
}
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
c10::optional<std::string> type_printer(
|
||||
const c10::Type& type,
|
||||
torch::jit::TypeNameUniquer& type_name_uniquer) {
|
||||
if (auto dyn = type.castRaw<c10::DynamicType>()) {
|
||||
return dyn->fallback()->annotation_str(
|
||||
[&](auto&& t) { return type_printer(t, type_name_uniquer); });
|
||||
}
|
||||
auto namedType = type.cast<c10::NamedType>();
|
||||
if (namedType && namedType->name()) {
|
||||
return type_name_uniquer.getUniqueName(namedType).qualifiedName();
|
||||
}
|
||||
return c10::nullopt;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void ScriptModuleSerializer::convertNamedType(
|
||||
const c10::NamedTypePtr& class_type) {
|
||||
if (converted_types_.count(class_type)) {
|
||||
@ -716,20 +734,15 @@ void ScriptModuleSerializer::convertNamedType(
|
||||
std::string qualifier = qualname.prefix();
|
||||
PythonPrint* pp = file_streams_.find(qualifier);
|
||||
|
||||
auto type_printer = [&](const c10::Type& t) -> c10::optional<std::string> {
|
||||
auto namedType = t.cast<c10::NamedType>();
|
||||
if (namedType && namedType->name()) {
|
||||
return type_name_uniquer_.getUniqueName(namedType).qualifiedName();
|
||||
}
|
||||
return c10::nullopt;
|
||||
};
|
||||
if (!pp) {
|
||||
pp = &file_streams_.insert(
|
||||
std::move(qualifier),
|
||||
PythonPrint(
|
||||
constant_table_,
|
||||
class_deps_,
|
||||
type_printer,
|
||||
[&](const c10::Type& t) {
|
||||
return type_printer(t, type_name_uniquer_);
|
||||
},
|
||||
/*enforce_importable=*/true));
|
||||
}
|
||||
pp->printNamedType(class_type);
|
||||
|
@ -922,15 +922,19 @@ struct PythonPrintImpl {
|
||||
void printConstant(TaggedStringStream& stmt, const IValue& v) {
|
||||
const auto customFormatter = [&](std::ostream& ss, const IValue& v) {
|
||||
if (v.isTensor() || containsNonASCIIString(v) || v.isObject()) {
|
||||
TORCH_INTERNAL_ASSERT(!v.type()->is_module());
|
||||
TORCH_INTERNAL_ASSERT(!v.type<c10::Type>()->is_module());
|
||||
ss << "CONSTANTS.c" << getOrAddConstant(v);
|
||||
return true;
|
||||
}
|
||||
|
||||
if (v.isTuple() && v.type()->expectRef<TupleType>().schema()) {
|
||||
auto type = v.type();
|
||||
if (auto dyn = type->castRaw<c10::DynamicType>()) {
|
||||
type = dyn->fallback();
|
||||
}
|
||||
if (v.isTuple() && type->expectRef<TupleType>().schema()) {
|
||||
// print the namedtuple constructor and let rest of tuple printing
|
||||
// continue
|
||||
ss << v.type()->expectRef<TupleType>().annotation_str(type_printer_);
|
||||
ss << type->expectRef<TupleType>().annotation_str(type_printer_);
|
||||
}
|
||||
return false;
|
||||
};
|
||||
|
Reference in New Issue
Block a user