mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Export NamedTuple when it's nested in first type layer Dict (#75996)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/75996 Nested NamedTuple is supported when loading the model. However one case is missing when exporting the model. if it's the first layer, we haven't covered the `Dict` type yet. Before: ``` // ty is a generic type pointer and can be any type for (const TypePtr& ty : mobile_code.types_) { std::string type_str = get_type_str(t); if (t is TupleType) do B } ``` After: ``` for (const TypePtr& ty : mobile_code.types_) { std::string type_str = get_type_str(t); if (t is DictType) do A else if (t is TupleType) do B } ``` ghstack-source-id: 154292348 Test Plan: Use the uploaded model from Everstore: `GBE5xgh6J6T0ZfsAAAhQ7n_pxB90br0LAAAP`. Get it by `clowder get GBE5xgh6J6T0ZfsAAAhQ7n_pxB90br0LAAAP namedtuple.ptl`. ``` TEST(LiteInterpreterTest, DebugDper) { std::string path = "/data/sandcastle/boxes/fbsource/fbcode/caffe2/test/cpp/jit/namedtuple.ptl"; // mobile::Module bc = _load_for_mobile(path); Module jit_m = load(path); std::string resave_path = "/data/sandcastle/boxes/fbsource/fbcode/caffe2/test/cpp/jit/namedtuple_reave.ptl"; jit_m._save_for_mobile(resave_path); mobile::Module bc = _load_for_mobile(resave_path); } ``` ``` buck test //caffe2/test/cpp/jit:jit -- --exact 'caffe2/test/cpp/jit:jit - LiteInterpreterTest.DebugDper' buck test mode/opt-split-dwarf //dper3/dper3/modules/tests:id_score_list_to_id_list_test ``` Reviewed By: iseeyuan Differential Revision: D35705480 fbshipit-source-id: b8da2e720b8ca247bb40f13b67b75b5a04709f7a (cherry picked from commit 73bb6f9ddbefcd7e55e8660a9b55ae6b9eb9759c)
This commit is contained in:
committed by
PyTorch MergeBot
parent
317b8fa7ae
commit
d938867f91
@ -3,7 +3,7 @@
|
||||
import torch
|
||||
import torch.utils.bundled_inputs
|
||||
import io
|
||||
from typing import List, NamedTuple
|
||||
from typing import Dict, List, NamedTuple
|
||||
|
||||
from torch.jit.mobile import _load_for_lite_interpreter
|
||||
from torch.testing._internal.common_utils import TestCase, run_tests
|
||||
@ -33,6 +33,69 @@ class TestLiteScriptModule(TestCase):
|
||||
mobile_module_result
|
||||
)
|
||||
|
||||
|
||||
def test_typing_dict_with_namedtuple(self):
|
||||
class Foo(NamedTuple):
|
||||
id: torch.Tensor
|
||||
|
||||
class Bar(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super(Bar, self).__init__()
|
||||
self.foo = Foo(torch.tensor(1))
|
||||
|
||||
def forward(self, a: torch.Tensor):
|
||||
self.foo = Foo(a)
|
||||
re: Dict[str, Foo] = dict()
|
||||
re["test"] = Foo(a)
|
||||
return self.foo, re["test"]
|
||||
|
||||
# The corresponding bytecode is
|
||||
# (8,
|
||||
# ('__torch__.___torch_mangle_2.Bar.forward',
|
||||
# (('instructions',
|
||||
# (('STOREN', 1, 2),
|
||||
# ('DROPR', 1, 0),
|
||||
# ('DICT_CONSTRUCT', 0, 0),
|
||||
# ('STORE', 3, 0),
|
||||
# ('LOAD', 3, 0),
|
||||
# ('LOADC', 1, 0),
|
||||
# ('MOVE', 2, 0),
|
||||
# ('NAMED_TUPLE_CONSTRUCT', 1, 1),
|
||||
# ('OP', 0, 0),
|
||||
# ('MOVE', 3, 0),
|
||||
# ('LOADC', 1, 0),
|
||||
# ('DICT_INDEX', 0, 0),
|
||||
# ('LOADC', 0, 0),
|
||||
# ('TUPLE_INDEX', 0, 0),
|
||||
# ('RET', 0, 0))),
|
||||
# ('operators', (('aten::_set_item', 'str', 3),)),
|
||||
# ('constants', (0, 'test')),
|
||||
# ('types',
|
||||
# ('Dict[str,__torch__.Foo[NamedTuple, [[id, Tensor]]]]',
|
||||
# '__torch__.Foo[NamedTuple, [[id, Tensor]]]')),
|
||||
# ('register_size', 3)),
|
||||
# (('arguments',
|
||||
# ((('name', 'self'),
|
||||
# ('type', '__torch__.___torch_mangle_2.Bar'),
|
||||
# ('default_value', None)),
|
||||
# (('name', 'a'), ('type', 'Tensor'), ('default_value', None)))),
|
||||
# ('returns',
|
||||
# ((('name', ''), ('type', 'Tensor'), ('default_value', None)),)))))
|
||||
|
||||
sample_input = torch.tensor(5)
|
||||
script_module = torch.jit.script(Bar())
|
||||
|
||||
script_module_result = script_module(sample_input)
|
||||
|
||||
buffer_mobile = io.BytesIO(script_module._save_to_buffer_for_lite_interpreter())
|
||||
buffer_mobile.seek(0)
|
||||
mobile_module = _load_for_lite_interpreter(buffer_mobile)
|
||||
mobile_module_result = mobile_module(sample_input)
|
||||
torch.testing.assert_allclose(
|
||||
script_module_result,
|
||||
mobile_module_result
|
||||
)
|
||||
|
||||
def test_typing_namedtuple_custom_classtype(self):
|
||||
class Foo(NamedTuple):
|
||||
id: torch.Tensor
|
||||
|
@ -80,6 +80,88 @@ ExportModuleExtraFilesHook& GetExtraFilesHook() {
|
||||
return func;
|
||||
}
|
||||
|
||||
/**
|
||||
* If the type is not NamedTuple, it will return default_type_str. If the type
|
||||
* is a NamedTuple, it will return a string with following structure to describe
|
||||
* the content in the NamedTuple: "qualified_named[ NamedTuple, [ [filed_name_1,
|
||||
* field_type_1], [filed_name_2, field_type_2]
|
||||
* ]
|
||||
* ]"
|
||||
* Example NamedTuple type:
|
||||
* "__torch__.base_models.sparse_nn.pytorch_preproc_types.PreprocOutputType[
|
||||
* NamedTuple, [
|
||||
* [float_features, Tensor],
|
||||
* [id_list_features, List[Tensor]],
|
||||
* [label, Tensor],
|
||||
* [weight, Tensor],
|
||||
* ]
|
||||
* ]"
|
||||
*
|
||||
* @param compilation_unit Jit compilcation unit to look up function schema.
|
||||
* @param type_ptr A type pointer and it can be possibly any type.
|
||||
* @param default_type_str The default string representation. The string can
|
||||
* either from type_ptr->str(), type_ptr->annotation_str(), or
|
||||
* type_ptr->repr_str(). In some cases, they could be different in different
|
||||
* scenario. For example, Tensor type can be "Tensor", "Tensor (inferred)" and
|
||||
* "Tensor[]", and we only want "Tensor". Leave it as part of arguments as the
|
||||
* default return, when type_ptr is not a NamedTuple.
|
||||
* @return string representation.
|
||||
*/
|
||||
std::string get_named_tuple_str_or_default(
|
||||
const CompilationUnit& compilation_unit,
|
||||
const TypePtr& type_ptr,
|
||||
std::string default_type_str) {
|
||||
if (type_ptr->kind() == TypeKind::TupleType) {
|
||||
TORCH_CHECK(
|
||||
compilation_unit.get_named_tuple(type_ptr->str()),
|
||||
"Can't find definition for the qualified name: ",
|
||||
type_ptr->str(),
|
||||
"(TypeKind::TupleType) in compilation unit.",
|
||||
"Please report a bug to PyTorch.");
|
||||
auto named_tuple_ptr = compilation_unit.get_named_tuple(type_ptr->str());
|
||||
if (named_tuple_ptr != nullptr) {
|
||||
std::string named_tuple_str = type_ptr->str();
|
||||
named_tuple_str.append("[NamedTuple, [");
|
||||
std::vector<IValue> name_type_pairs;
|
||||
|
||||
// Get the field name and field type for the NamedTuple
|
||||
for (auto it = named_tuple_ptr->schema()->arguments().begin();
|
||||
it != named_tuple_ptr->schema()->arguments().end();
|
||||
it++) {
|
||||
const std::string named_tuple_name = it->name();
|
||||
const c10::TypePtr& named_tuple_type = it->type();
|
||||
// When it->type() is Tensor type, in Python, if it's inferred type,
|
||||
// str() return "Tensor" and repr_str() return "Tensor (inferred)". If
|
||||
// it's not inferred type, str() return "Tensor[]" and repr_str()
|
||||
// return "Tensor". In cpp, repr_str() will always return "Tensor"
|
||||
// regardless inferred type. When exporing custom type in bytecode,
|
||||
// "Tensor" is the preferred way to deserialize Tensor type
|
||||
std::string named_tuple_type_str = it->is_inferred_type()
|
||||
? named_tuple_type->str()
|
||||
: named_tuple_type->repr_str();
|
||||
// The type can also be NamedTuple. Will parse it recursively and get
|
||||
// it's string representation.
|
||||
named_tuple_type_str = get_named_tuple_str_or_default(
|
||||
compilation_unit, named_tuple_type, named_tuple_type->repr_str());
|
||||
name_type_pairs.emplace_back(
|
||||
c10::ivalue::Tuple::create({it->name(), named_tuple_type_str}));
|
||||
|
||||
named_tuple_str.append("[")
|
||||
.append(named_tuple_name)
|
||||
.append(", ")
|
||||
.append(named_tuple_type_str)
|
||||
.append("]");
|
||||
if (it != named_tuple_ptr->schema()->arguments().end() - 1) {
|
||||
named_tuple_str.append(",");
|
||||
}
|
||||
}
|
||||
named_tuple_str.append("]]");
|
||||
return named_tuple_str;
|
||||
}
|
||||
}
|
||||
return default_type_str;
|
||||
}
|
||||
|
||||
std::pair<IValue, IValue> getFunctionTuple(
|
||||
const CompilationUnit& compilation_unit,
|
||||
const mobile::Function& func,
|
||||
@ -120,59 +202,36 @@ std::pair<IValue, IValue> getFunctionTuple(
|
||||
t = dyn->fallback();
|
||||
}
|
||||
std::string type_str = t->annotation_str();
|
||||
if (t->kind() == TypeKind::TupleType) {
|
||||
TORCH_CHECK(
|
||||
compilation_unit.get_named_tuple(t->str()),
|
||||
"Can't find definition for the qualified name: ",
|
||||
t->str(),
|
||||
"(TypeKind::TupleType) in compilation unit.",
|
||||
"Please report a bug to PyTorch.");
|
||||
auto named_tuple_type = compilation_unit.get_named_tuple(t->str());
|
||||
if (named_tuple_type != nullptr) {
|
||||
std::string named_tuple_str = t->str();
|
||||
named_tuple_str.append("[NamedTuple, [");
|
||||
std::vector<IValue> name_type_pairs;
|
||||
if (t->kind() == TypeKind::DictType) {
|
||||
// For DictType, there are two items in t->containedTypes(), the first one
|
||||
// is key and the second one is value. Both of them could be NamedTuple
|
||||
// type.
|
||||
const TypePtr& key_type = t->containedTypes()[0];
|
||||
const TypePtr& value_type = t->containedTypes()[1];
|
||||
std::string key_type_str = get_named_tuple_str_or_default(
|
||||
compilation_unit, key_type, key_type->annotation_str());
|
||||
std::string value_type_str = get_named_tuple_str_or_default(
|
||||
compilation_unit, value_type, value_type->annotation_str());
|
||||
|
||||
// Get the field name and field type for the NamedTuple
|
||||
for (auto it = named_tuple_type->schema()->arguments().begin();
|
||||
it != named_tuple_type->schema()->arguments().end();
|
||||
it++) {
|
||||
name_type_pairs.emplace_back(
|
||||
c10::ivalue::Tuple::create({it->name(), it->type()->repr_str()}));
|
||||
|
||||
// When it->type() is Tensor type, in Python, if it's inferred type,
|
||||
// str() return "Tensor" and repr_str() return "Tensor (inferred)". If
|
||||
// it's not inferred type, str() return "Tensor[]" and repr_str()
|
||||
// return "Tensor". In cpp, repr_str() will always return "Tensor"
|
||||
// regardless inferred type. When exporing custom type in bytecode,
|
||||
// "Tensor" is the preferred way to deserialize Tensor type
|
||||
type_str = it->is_inferred_type() ? it->type()->str()
|
||||
: it->type()->repr_str();
|
||||
named_tuple_str.append("[" + it->name() + ", " + type_str + "]");
|
||||
if (it != named_tuple_type->schema()->arguments().end() - 1) {
|
||||
named_tuple_str.append(",");
|
||||
}
|
||||
}
|
||||
named_tuple_str.append("]]");
|
||||
// Create a named_tuple type with following structure
|
||||
// "qualified_named[
|
||||
// NamedTuple, [
|
||||
// [filed_name_1, field_type_1],
|
||||
// [filed_name_2, field_type_2]
|
||||
// ]
|
||||
// ]"
|
||||
// Example NamedTuple type:
|
||||
// "__torch__.base_models.sparse_nn.pytorch_preproc_types.PreprocOutputType[
|
||||
// NamedTuple, [
|
||||
// [float_features, Tensor],
|
||||
// [id_list_features, List[Tensor]],
|
||||
// [label, Tensor],
|
||||
// [weight, Tensor],
|
||||
// ]
|
||||
// ]"
|
||||
types.emplace_back(named_tuple_str);
|
||||
continue;
|
||||
}
|
||||
// Construct the dict representation after achieving correct string
|
||||
// representation for both key and value, like
|
||||
// "Dict[str,__torch__.dper3.core.pytorch_schema_utils.IdScoreListFeatureTuple[NamedTuple,
|
||||
// [[lengths, Tensor],[values,
|
||||
// __torch__.dper3.core.pytorch_schema_utils.IdScoreTuple[NamedTuple,
|
||||
// [[ids, Tensor],[scores, Tensor]]]],[offsets, Optional[Tensor]]]]]"
|
||||
std::string dict_str;
|
||||
dict_str.append("Dict[")
|
||||
.append(key_type_str)
|
||||
.append(",")
|
||||
.append(value_type_str)
|
||||
.append("]");
|
||||
types.emplace_back(dict_str);
|
||||
continue;
|
||||
} else if (t->kind() == TypeKind::TupleType) {
|
||||
std::string named_tuple_str =
|
||||
get_named_tuple_str_or_default(compilation_unit, t, type_str);
|
||||
types.emplace_back(named_tuple_str);
|
||||
continue;
|
||||
} else if (type_str.find(torch_prefix) == 0) {
|
||||
TORCH_CHECK(
|
||||
type_str.find(class_prefix) == 0,
|
||||
|
Reference in New Issue
Block a user