Add all list specializations to pickler (#20191)

Summary:
TensorList, DoubleList, and BoolList were missing from the pickler, so
this adds them.

As a follow up a lot of the code for these could be templated and cut
down

](https://our.intern.facebook.com/intern/diff/15299106/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/20191

Pulled By: driazati

Differential Revision: D15299106

fbshipit-source-id: f10c0c9af9d60a6b7fb8d93cea9f550b1a7e2415
This commit is contained in:
davidriazati
2019-05-10 17:01:17 -07:00
committed by Facebook Github Bot
parent 6197eed409
commit 00d0ddb140
6 changed files with 235 additions and 62 deletions

View File

@ -12235,22 +12235,35 @@ a")
m = self.createFunctionFromGraph(foo.graph)
self.getExportImportCopy(m)
def get_pickle_values(self):
return (('dict', {"I": "am", "a test": "test"}, Dict[str, str]),
('float', 2.3, float),
('int', 99, int),
('bool', False, bool),
('tuple', (1, 2, 3, 4), Tuple[int, int, int, int]),
('list', [(1, 2), (3, 4)], List[Tuple[int, int]]),
('tensor', torch.randn(2, 2), torch.Tensor),
('int_list', [1, 2, 3, 4], List[int]),
('tensor_list', [torch.ones(2, 2) + i for i in range(4)], List[torch.Tensor]),
('bool_list', [True, True, False, True], List[bool]),
('float_list', [1., 2., 3., 4.], List[float]),
('str_list', ['hello', 'bye'], List[str]),
('none', None, Optional[int]),)
def test_attribute_serialization(self):
tester = self
class M(torch.jit.ScriptModule):
def __init__(self):
super(M, self).__init__()
self.table = torch.jit.Attribute({"I": "am", "a test": "test"}, Dict[str, str])
self.float = torch.jit.Attribute(2.3, float)
self.int = torch.jit.Attribute(99, int)
self.bool = torch.jit.Attribute(False, bool)
self.tuple = torch.jit.Attribute((1, 2, 3, 4), Tuple[int, int, int, int])
self.list = torch.jit.Attribute([(1, 2), (3, 4)], List[Tuple[int, int]])
self.tensor = torch.jit.Attribute(torch.randn(2, 2), torch.Tensor)
self.int_list = torch.jit.Attribute([1, 2, 3, 4], List[int])
for name, value, the_type in tester.get_pickle_values():
setattr(self, name, torch.jit.Attribute(value, the_type))
@torch.jit.script_method
def forward(self):
return (self.table, self.float, self.int, self.bool, self.tuple, self.list, self.int_list)
return (self.dict, self.float, self.int, self.bool, self.tuple,
self.list, self.int_list, self.tensor_list, self.bool_list,
self.float_list, self.str_list, self.none)
m = M()
imported_m = self.getExportImportCopy(m)
@ -12268,21 +12281,19 @@ a")
@unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: TemporaryFileName support for Windows or Sandcastle")
def test_attribute_unpickling(self):
tensor = torch.randn(2, 2)
tester = self
class M(torch.jit.ScriptModule):
def __init__(self):
super(M, self).__init__()
self.table = torch.jit.Attribute({"I": "am", "a test": "test"}, Dict[str, str])
self.float = torch.jit.Attribute(2.3, float)
self.int = torch.jit.Attribute(99, int)
self.tuple = torch.jit.Attribute((1, 2, 3, 4), Tuple[int, int, int, int])
self.list = torch.jit.Attribute([(1, 2), (3, 4)], List[Tuple[int, int]])
self.tensor = torch.jit.Attribute(tensor, torch.Tensor)
self.int_list = torch.jit.Attribute([1, 2, 3, 4], List[int])
for name, value, the_type in tester.get_pickle_values():
setattr(self, name, torch.jit.Attribute(value, the_type))
@torch.jit.script_method
def forward(self):
return (self.table, self.float, self.int, self.tuple, self.list, self.int_list)
return (self.dict, self.float, self.int, self.bool, self.tuple,
self.list, self.int_list, self.tensor_list, self.bool_list,
self.float_list, self.str_list, self.none)
with TemporaryFileName() as fname:
M().save(fname)
@ -12291,10 +12302,17 @@ a")
pickled_data = archive.read(os.path.join(archive_name, 'attributes.pkl'))
out = pickle.load(io.BytesIO(pickled_data))
self.assertEqual(out[0], {"I": "am", "a test": "test"})
self.assertEqual(out[1], 2.3)
self.assertEqual(out[2], 99)
self.assertEqual(out[6], [1, 2, 3, 4])
def is_tensor_value(item):
if isinstance(item, torch.Tensor):
return True
if isinstance(item, list):
return is_tensor_value(item[0])
return False
for loaded_item, item in zip(out, self.get_pickle_values()):
if is_tensor_value(item[1]):
continue
self.assertEqual(item[1], loaded_item)
@unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: TemporaryFileName support for Windows or Sandcastle")
def test_old_models_bc(self):

View File

@ -179,7 +179,7 @@ def _parameter_list(parameter_names_fn):
try:
import typing
from typing import Tuple, List, Dict
from typing import Tuple, List, Dict, Optional
def is_tuple(ann):
# For some reason Python 3.7 violates the Type[A, B].__origin__ == Type rule
@ -196,6 +196,22 @@ try:
return ann.__module__ == 'typing' and \
(getattr(ann, '__origin__', None) is typing.Dict or
getattr(ann, '__origin__', None) is dict)
def is_optional(ann):
# Optional[T] is just shorthand for Union[T, None], so check for both
union_optional = False
if ann.__module__ == 'typing' and \
(getattr(ann, '__origin__', None) is typing.Union):
args = getattr(ann, '__args__', ())
if len(args) == 2:
union_optional = (issubclass(args[1], type(None)) and not issubclass(args[0], type(None))) \
or (issubclass(args[0], type(None)) and not issubclass(args[1], type(None)))
optional = ann.__module__ == 'typing' and \
(getattr(ann, '__origin__', None) is typing.Optional)
return optional or union_optional
except ImportError:
# A minimal polyfill for versions of Python that don't have typing.
# Note that this means that they also don't support the fancy annotation syntax, so
@ -232,9 +248,20 @@ except ImportError:
def __getitem__(self, types):
return DictInstance(types)
class OptionalInstance(object):
__slots__ = ['__args__']
def __init__(self, types):
self.__args__ = types
class OptionalCls(object):
def __getitem__(self, types):
return OptionalInstance(types)
Tuple = TupleCls() # noqa: T484
List = ListCls() # noqa: T484
Dict = DictCls() # noqa: T484
Optional = DictCls() # noqa: T484
def is_tuple(ann):
return isinstance(ann, TupleInstance)
@ -245,6 +272,9 @@ except ImportError:
def is_dict(ann):
return isinstance(ann, DictInstance)
def is_optional(ann):
return isinstance(ann, OptionalInstance)
# allows BroadcastingList instance to be subscriptable
class BroadcastingListCls(object):

View File

@ -16,6 +16,12 @@ PicklerClass getClass(const std::string& str) {
return PicklerClass::TENSOR;
} else if (str == "build_intlist") {
return PicklerClass::INTLIST;
} else if (str == "build_tensorlist") {
return PicklerClass::TENSORLIST;
} else if (str == "build_doublelist") {
return PicklerClass::DOUBLELIST;
} else if (str == "build_boollist") {
return PicklerClass::BOOLLIST;
}
// TODO [unpickler refactor]
@ -30,11 +36,20 @@ PicklerClass getClass(const std::string& str) {
const std::string& getClassName(PicklerClass cls) {
static const std::string tensor_class("build_tensor_from_id\n");
static const std::string intlist_class("build_intlist\n");
static const std::string tensorlist_class("build_tensorlist\n");
static const std::string doublelist_class("build_doublelist\n");
static const std::string boollist_class("build_boollist\n");
switch (cls) {
case PicklerClass::TENSOR:
return tensor_class;
case PicklerClass::INTLIST:
return intlist_class;
case PicklerClass::TENSORLIST:
return tensorlist_class;
case PicklerClass::DOUBLELIST:
return doublelist_class;
case PicklerClass::BOOLLIST:
return boollist_class;
default:
AT_ERROR("Unknown class for pickler");
}
@ -158,13 +173,39 @@ void Pickler::addIValue(const IValue& ivalue) {
} else if (ivalue.isString()) {
pushMemoizedString(ivalue);
} else if (ivalue.isGenericList()) {
pushList(ivalue);
pushGenericList(ivalue);
} else if (ivalue.isGenericDict()) {
pushDict(ivalue);
} else if (ivalue.isNone()) {
push<OpCode>(OpCode::NONE);
} else if (ivalue.isIntList()) {
pushIntList(ivalue);
pushSpecializedList(
ivalue, PicklerClass::INTLIST, [=](const IValue& ivalue) {
for (const auto& item : ivalue.toIntListRef()) {
addIValue(item);
}
});
} else if (ivalue.isTensorList()) {
pushSpecializedList(
ivalue, PicklerClass::TENSORLIST, [=](const IValue& ivalue) {
for (const auto& item : ivalue.toTensorListRef()) {
addIValue(item);
}
});
} else if (ivalue.isDoubleList()) {
pushSpecializedList(
ivalue, PicklerClass::DOUBLELIST, [=](const IValue& ivalue) {
for (const auto& item : ivalue.toDoubleListRef()) {
addIValue(item);
}
});
} else if (ivalue.isBoolList()) {
pushSpecializedList(
ivalue, PicklerClass::BOOLLIST, [=](const IValue& ivalue) {
for (const auto& item : ivalue.toBoolListRef()) {
addIValue(bool(item));
}
});
} else {
AT_ERROR("Unknown IValue type for pickling: ", ivalue.tagKind());
}
@ -185,6 +226,12 @@ const void* Pickler::getPointer(const IValue& ivalue) {
return ivalue.toString().get();
} else if (ivalue.isIntList()) {
return ivalue.toIntList().get();
} else if (ivalue.isTensorList()) {
return ivalue.toTensorList().get();
} else if (ivalue.isDoubleList()) {
return ivalue.toDoubleList().get();
} else if (ivalue.isBoolList()) {
return ivalue.toBoolList().get();
}
return nullptr;
@ -342,9 +389,11 @@ void Pickler::pushTensorReference(const IValue& ivalue) {
push<OpCode>(OpCode::REDUCE);
}
void Pickler::pushIntList(const IValue& ivalue) {
pushClass(PicklerClass::INTLIST);
void Pickler::pushSpecializedList(
const IValue& ivalue,
PicklerClass cls,
const std::function<void(const IValue&)>& item_pusher) {
pushClass(cls);
// Reduce arguments are spread (e.g. `*args`) before calling the global,
// so wrap in a tuple
@ -354,10 +403,8 @@ void Pickler::pushIntList(const IValue& ivalue) {
// Mark list
push<OpCode>(OpCode::MARK);
// Add items
for (const auto& item : ivalue.toIntListRef()) {
addIValue(item);
}
// Add all items
item_pusher(ivalue);
// Finish list
push<OpCode>(OpCode::APPENDS);
@ -428,7 +475,7 @@ void Pickler::pushMemoization(const IValue& ivalue) {
pushMemoization(ptr);
}
void Pickler::pushList(const IValue& ivalue) {
void Pickler::pushGenericList(const IValue& ivalue) {
auto list = ivalue.toGenericListRef();
push<OpCode>(OpCode::EMPTY_LIST);
pushMemoization(ivalue);
@ -508,7 +555,6 @@ void Unpickler::run() {
AT_ERROR("Overran buffer while unpickling data, didn't find STOP opcode");
}
OpCode Unpickler::readInstruction() {
auto opcode = readOpCode();
switch (opcode) {
@ -533,6 +579,14 @@ OpCode Unpickler::readInstruction() {
// specialization
if (stack_.back().pickler_class() == PicklerClass::INTLIST) {
stack_.emplace_back(std::vector<int64_t>());
} else if (stack_.back().pickler_class() == PicklerClass::INTLIST) {
stack_.emplace_back(std::vector<int64_t>());
} else if (stack_.back().pickler_class() == PicklerClass::TENSORLIST) {
stack_.emplace_back(std::vector<at::Tensor>());
} else if (stack_.back().pickler_class() == PicklerClass::DOUBLELIST) {
stack_.emplace_back(std::vector<double>());
} else if (stack_.back().pickler_class() == PicklerClass::BOOLLIST) {
stack_.emplace_back(std::vector<bool>());
} else {
AT_ERROR("Unknown list specialization");
}
@ -572,6 +626,9 @@ OpCode Unpickler::readInstruction() {
case OpCode::NEWFALSE: {
stack_.emplace_back(false);
} break;
case OpCode::NONE: {
stack_.emplace_back(IValue());
} break;
case OpCode::BININT1: {
int8_t value = read<int8_t>();
stack_.emplace_back(int64_t(value));
@ -651,20 +708,19 @@ OpCode Unpickler::readInstruction() {
auto setitem_data = stack_.back().ivalue();
stack_.pop_back();
auto class_name =
static_cast<PicklerClass>(uint8_t(stack_.back().ivalue().toInt()));
static_cast<PicklerClass>(uint8_t(stack_.back().ivalue().toInt()));
stack_.pop_back();
switch (class_name) {
case PicklerClass::TENSOR:
stack_.emplace_back(tensor_table_->at(setitem_data.toInt()));
break;
case PicklerClass::INTLIST:
stack_.emplace_back(setitem_data);
break;
default:
AT_ERROR("Unknown pickler class id");
case PicklerClass::TENSOR:
stack_.emplace_back(tensor_table_->at(setitem_data.toInt()));
break;
case PicklerClass::INTLIST:
stack_.emplace_back(setitem_data);
break;
default:
AT_ERROR("Unknown pickler class id");
}
} break;
case OpCode::REDUCE: {
@ -684,33 +740,69 @@ OpCode Unpickler::readInstruction() {
case PicklerClass::INTLIST:
stack_.emplace_back(data->elements().at(0).toIntListRef());
break;
case PicklerClass::TENSORLIST:
stack_.emplace_back(data->elements().at(0).toTensorListRef());
break;
case PicklerClass::DOUBLELIST:
stack_.emplace_back(data->elements().at(0).toDoubleListRef());
break;
case PicklerClass::BOOLLIST:
stack_.emplace_back(data->elements().at(0).toBoolListRef());
break;
default:
AT_ERROR("Unknown pickler class id");
}
} break;
default:
AT_ERROR("Unknown opcode for unpickling at ", reinterpret_cast<void*>(opcode),": ", static_cast<uint8_t>(opcode));
AT_ERROR(
"Unknown opcode for unpickling at ",
reinterpret_cast<void*>(opcode),
": ",
static_cast<uint8_t>(opcode));
}
return opcode;
}
// Pop all the list items off of the stack and append them to the list at the
// corresponding MARK
void Unpickler::readList() {
size_t start = marks_.back();
marks_.pop_back();
auto list_ivalue = stack_.at(start - 1);
auto list_ivalue = stack_.at(start - 1).ivalue();
auto num_elements = stack_.size() - start;
if (list_ivalue.ivalue().isIntList()) {
auto list = stack_.at(start - 1).ivalue().toIntList();
list->elements().reserve(num_elements);
for (auto it = stack_.begin() + start; it != stack_.end(); ++it) {
list->elements().emplace_back(it->ivalue().toInt());
auto elements = at::ArrayRef<StackItem>(stack_).slice(start);
if (list_ivalue.isIntList()) {
auto& list = list_ivalue.toIntList()->elements();
list.reserve(num_elements);
for (const auto& elem : elements) {
list.emplace_back(elem.ivalue().toInt());
}
} else if (list_ivalue.isTensorList()) {
auto& list = list_ivalue.toTensorList()->elements();
list.reserve(num_elements);
for (const auto& elem : elements) {
list.emplace_back(elem.ivalue().toTensor());
}
} else if (list_ivalue.isDoubleList()) {
auto& list = list_ivalue.toDoubleList()->elements();
list.reserve(num_elements);
for (const auto& elem : elements) {
list.emplace_back(elem.ivalue().toDouble());
}
} else if (list_ivalue.isBoolList()) {
auto& list = list_ivalue.toBoolList()->elements();
list.reserve(num_elements);
for (const auto& elem : elements) {
list.push_back(elem.ivalue().toBool());
}
} else if (list_ivalue.isGenericList()) {
auto& list = list_ivalue.toGenericList()->elements();
list.reserve(num_elements);
for (const auto& elem : elements) {
list.emplace_back(elem.ivalue());
}
} else {
auto list = stack_.at(start - 1).ivalue().toGenericList();
list->elements().reserve(num_elements);
for (auto it = stack_.begin() + start; it != stack_.end(); ++it) {
list->elements().emplace_back(it->ivalue());
}
AT_ERROR("Unknown IValue list kind: ", list_ivalue.tagKind());
}
stack_.erase(stack_.begin() + start, stack_.end());

View File

@ -90,6 +90,12 @@ enum PicklerClass : uint8_t {
TENSOR = 0,
// List[int]
INTLIST = 1,
// List[Tensor]
TENSORLIST = 2,
// List[float]
DOUBLELIST = 3,
// List[bool]
BOOLLIST = 4
};
using ::c10::IValue;
@ -122,6 +128,7 @@ class Pickler {
private:
void pushDict(const IValue& ivalue);
void pushDouble(const IValue& ivalue);
void pushGenericList(const IValue& ivalue);
void pushInt(const IValue& ivalue);
void pushIntList(const IValue& ivalue);
void pushList(const IValue& ivalue);
@ -134,6 +141,10 @@ class Pickler {
void pushBinGet(uint32_t memo_id);
void pushClass(PicklerClass cls);
void pushSpecializedList(
const IValue& ivalue,
PicklerClass cls,
const std::function<void(const IValue&)>& item_pusher);
void pushGlobal(const std::string& name);
void pushMemoization(const void* item);
void pushString(const std::string& string);
@ -186,19 +197,19 @@ struct StackItem {
StackItem(PicklerClass pickler_class)
: pickler_class_(pickler_class), ivalue_(c10::nullopt) {}
IValue ivalue() {
IValue ivalue() const {
return *ivalue_;
}
PicklerClass pickler_class() {
PicklerClass pickler_class() const {
return *pickler_class_;
}
c10::optional<IValue> ivalue_opt() {
c10::optional<IValue> ivalue_opt() const {
return ivalue_;
}
c10::optional<PicklerClass> pickler_class_opt() {
c10::optional<PicklerClass> pickler_class_opt() const {
return pickler_class_;
}

View File

@ -1,7 +1,22 @@
# These functions are referenced from the pickle archives produced by
# ScriptModule.save()
def build_intlist(data):
return data
def build_tensorlist(data):
return data
def build_doublelist(data):
return data
def build_boollist(data):
return data
def build_tensor_from_id(data):
if isinstance(data, int):
# just the id, can't really do anything

View File

@ -3,9 +3,10 @@ import ast
import inspect
import torch
from .._jit_internal import List, BroadcastingList1, BroadcastingList2, \
BroadcastingList3, Tuple, is_tuple, is_list, Dict, is_dict
BroadcastingList3, Tuple, is_tuple, is_list, Dict, is_dict, Optional, \
is_optional
from torch._C import TensorType, TupleType, FloatType, IntType, \
ListType, StringType, DictType, BoolType
ListType, StringType, DictType, BoolType, OptionalType
from textwrap import dedent
@ -31,6 +32,7 @@ _eval_env = {
'Tuple': Tuple,
'List': List,
'Dict': Dict,
'Optional': Optional,
}
@ -173,6 +175,11 @@ def ann_to_type(ann):
key = ann_to_type(ann.__args__[0])
value = ann_to_type(ann.__args__[1])
return DictType(key, value)
elif is_optional(ann):
if issubclass(ann.__args__[1], type(None)):
return OptionalType(ann_to_type(ann.__args__[0]))
else:
return OptionalType(ann_to_type(ann.__args__[1]))
elif ann is float:
return FloatType.get()
elif ann is int:
@ -181,7 +188,7 @@ def ann_to_type(ann):
return StringType.get()
elif ann is bool:
return BoolType.get()
raise ValueError("Unknown type annotation: '{}'".format(ann.__name__))
raise ValueError("Unknown type annotation: '{}'".format(ann))
__all__ = [