mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add all list specializations to pickler (#20191)
Summary: TensorList, DoubleList, and BoolList were missing from the pickler, so this adds them. As a follow up a lot of the code for these could be templated and cut down ](https://our.intern.facebook.com/intern/diff/15299106/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/20191 Pulled By: driazati Differential Revision: D15299106 fbshipit-source-id: f10c0c9af9d60a6b7fb8d93cea9f550b1a7e2415
This commit is contained in:
committed by
Facebook Github Bot
parent
6197eed409
commit
00d0ddb140
@ -12235,22 +12235,35 @@ a")
|
||||
m = self.createFunctionFromGraph(foo.graph)
|
||||
self.getExportImportCopy(m)
|
||||
|
||||
def get_pickle_values(self):
|
||||
return (('dict', {"I": "am", "a test": "test"}, Dict[str, str]),
|
||||
('float', 2.3, float),
|
||||
('int', 99, int),
|
||||
('bool', False, bool),
|
||||
('tuple', (1, 2, 3, 4), Tuple[int, int, int, int]),
|
||||
('list', [(1, 2), (3, 4)], List[Tuple[int, int]]),
|
||||
('tensor', torch.randn(2, 2), torch.Tensor),
|
||||
('int_list', [1, 2, 3, 4], List[int]),
|
||||
('tensor_list', [torch.ones(2, 2) + i for i in range(4)], List[torch.Tensor]),
|
||||
('bool_list', [True, True, False, True], List[bool]),
|
||||
('float_list', [1., 2., 3., 4.], List[float]),
|
||||
('str_list', ['hello', 'bye'], List[str]),
|
||||
('none', None, Optional[int]),)
|
||||
|
||||
def test_attribute_serialization(self):
|
||||
tester = self
|
||||
|
||||
class M(torch.jit.ScriptModule):
|
||||
def __init__(self):
|
||||
super(M, self).__init__()
|
||||
self.table = torch.jit.Attribute({"I": "am", "a test": "test"}, Dict[str, str])
|
||||
self.float = torch.jit.Attribute(2.3, float)
|
||||
self.int = torch.jit.Attribute(99, int)
|
||||
self.bool = torch.jit.Attribute(False, bool)
|
||||
self.tuple = torch.jit.Attribute((1, 2, 3, 4), Tuple[int, int, int, int])
|
||||
self.list = torch.jit.Attribute([(1, 2), (3, 4)], List[Tuple[int, int]])
|
||||
self.tensor = torch.jit.Attribute(torch.randn(2, 2), torch.Tensor)
|
||||
self.int_list = torch.jit.Attribute([1, 2, 3, 4], List[int])
|
||||
for name, value, the_type in tester.get_pickle_values():
|
||||
setattr(self, name, torch.jit.Attribute(value, the_type))
|
||||
|
||||
@torch.jit.script_method
|
||||
def forward(self):
|
||||
return (self.table, self.float, self.int, self.bool, self.tuple, self.list, self.int_list)
|
||||
return (self.dict, self.float, self.int, self.bool, self.tuple,
|
||||
self.list, self.int_list, self.tensor_list, self.bool_list,
|
||||
self.float_list, self.str_list, self.none)
|
||||
|
||||
m = M()
|
||||
imported_m = self.getExportImportCopy(m)
|
||||
@ -12268,21 +12281,19 @@ a")
|
||||
@unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: TemporaryFileName support for Windows or Sandcastle")
|
||||
def test_attribute_unpickling(self):
|
||||
tensor = torch.randn(2, 2)
|
||||
tester = self
|
||||
|
||||
class M(torch.jit.ScriptModule):
|
||||
def __init__(self):
|
||||
super(M, self).__init__()
|
||||
self.table = torch.jit.Attribute({"I": "am", "a test": "test"}, Dict[str, str])
|
||||
self.float = torch.jit.Attribute(2.3, float)
|
||||
self.int = torch.jit.Attribute(99, int)
|
||||
self.tuple = torch.jit.Attribute((1, 2, 3, 4), Tuple[int, int, int, int])
|
||||
self.list = torch.jit.Attribute([(1, 2), (3, 4)], List[Tuple[int, int]])
|
||||
self.tensor = torch.jit.Attribute(tensor, torch.Tensor)
|
||||
self.int_list = torch.jit.Attribute([1, 2, 3, 4], List[int])
|
||||
for name, value, the_type in tester.get_pickle_values():
|
||||
setattr(self, name, torch.jit.Attribute(value, the_type))
|
||||
|
||||
@torch.jit.script_method
|
||||
def forward(self):
|
||||
return (self.table, self.float, self.int, self.tuple, self.list, self.int_list)
|
||||
return (self.dict, self.float, self.int, self.bool, self.tuple,
|
||||
self.list, self.int_list, self.tensor_list, self.bool_list,
|
||||
self.float_list, self.str_list, self.none)
|
||||
|
||||
with TemporaryFileName() as fname:
|
||||
M().save(fname)
|
||||
@ -12291,10 +12302,17 @@ a")
|
||||
pickled_data = archive.read(os.path.join(archive_name, 'attributes.pkl'))
|
||||
out = pickle.load(io.BytesIO(pickled_data))
|
||||
|
||||
self.assertEqual(out[0], {"I": "am", "a test": "test"})
|
||||
self.assertEqual(out[1], 2.3)
|
||||
self.assertEqual(out[2], 99)
|
||||
self.assertEqual(out[6], [1, 2, 3, 4])
|
||||
def is_tensor_value(item):
|
||||
if isinstance(item, torch.Tensor):
|
||||
return True
|
||||
if isinstance(item, list):
|
||||
return is_tensor_value(item[0])
|
||||
return False
|
||||
|
||||
for loaded_item, item in zip(out, self.get_pickle_values()):
|
||||
if is_tensor_value(item[1]):
|
||||
continue
|
||||
self.assertEqual(item[1], loaded_item)
|
||||
|
||||
@unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: TemporaryFileName support for Windows or Sandcastle")
|
||||
def test_old_models_bc(self):
|
||||
|
@ -179,7 +179,7 @@ def _parameter_list(parameter_names_fn):
|
||||
|
||||
try:
|
||||
import typing
|
||||
from typing import Tuple, List, Dict
|
||||
from typing import Tuple, List, Dict, Optional
|
||||
|
||||
def is_tuple(ann):
|
||||
# For some reason Python 3.7 violates the Type[A, B].__origin__ == Type rule
|
||||
@ -196,6 +196,22 @@ try:
|
||||
return ann.__module__ == 'typing' and \
|
||||
(getattr(ann, '__origin__', None) is typing.Dict or
|
||||
getattr(ann, '__origin__', None) is dict)
|
||||
|
||||
def is_optional(ann):
|
||||
# Optional[T] is just shorthand for Union[T, None], so check for both
|
||||
union_optional = False
|
||||
if ann.__module__ == 'typing' and \
|
||||
(getattr(ann, '__origin__', None) is typing.Union):
|
||||
args = getattr(ann, '__args__', ())
|
||||
if len(args) == 2:
|
||||
union_optional = (issubclass(args[1], type(None)) and not issubclass(args[0], type(None))) \
|
||||
or (issubclass(args[0], type(None)) and not issubclass(args[1], type(None)))
|
||||
|
||||
optional = ann.__module__ == 'typing' and \
|
||||
(getattr(ann, '__origin__', None) is typing.Optional)
|
||||
|
||||
return optional or union_optional
|
||||
|
||||
except ImportError:
|
||||
# A minimal polyfill for versions of Python that don't have typing.
|
||||
# Note that this means that they also don't support the fancy annotation syntax, so
|
||||
@ -232,9 +248,20 @@ except ImportError:
|
||||
def __getitem__(self, types):
|
||||
return DictInstance(types)
|
||||
|
||||
class OptionalInstance(object):
|
||||
__slots__ = ['__args__']
|
||||
|
||||
def __init__(self, types):
|
||||
self.__args__ = types
|
||||
|
||||
class OptionalCls(object):
|
||||
def __getitem__(self, types):
|
||||
return OptionalInstance(types)
|
||||
|
||||
Tuple = TupleCls() # noqa: T484
|
||||
List = ListCls() # noqa: T484
|
||||
Dict = DictCls() # noqa: T484
|
||||
Optional = DictCls() # noqa: T484
|
||||
|
||||
def is_tuple(ann):
|
||||
return isinstance(ann, TupleInstance)
|
||||
@ -245,6 +272,9 @@ except ImportError:
|
||||
def is_dict(ann):
|
||||
return isinstance(ann, DictInstance)
|
||||
|
||||
def is_optional(ann):
|
||||
return isinstance(ann, OptionalInstance)
|
||||
|
||||
|
||||
# allows BroadcastingList instance to be subscriptable
|
||||
class BroadcastingListCls(object):
|
||||
|
@ -16,6 +16,12 @@ PicklerClass getClass(const std::string& str) {
|
||||
return PicklerClass::TENSOR;
|
||||
} else if (str == "build_intlist") {
|
||||
return PicklerClass::INTLIST;
|
||||
} else if (str == "build_tensorlist") {
|
||||
return PicklerClass::TENSORLIST;
|
||||
} else if (str == "build_doublelist") {
|
||||
return PicklerClass::DOUBLELIST;
|
||||
} else if (str == "build_boollist") {
|
||||
return PicklerClass::BOOLLIST;
|
||||
}
|
||||
|
||||
// TODO [unpickler refactor]
|
||||
@ -30,11 +36,20 @@ PicklerClass getClass(const std::string& str) {
|
||||
const std::string& getClassName(PicklerClass cls) {
|
||||
static const std::string tensor_class("build_tensor_from_id\n");
|
||||
static const std::string intlist_class("build_intlist\n");
|
||||
static const std::string tensorlist_class("build_tensorlist\n");
|
||||
static const std::string doublelist_class("build_doublelist\n");
|
||||
static const std::string boollist_class("build_boollist\n");
|
||||
switch (cls) {
|
||||
case PicklerClass::TENSOR:
|
||||
return tensor_class;
|
||||
case PicklerClass::INTLIST:
|
||||
return intlist_class;
|
||||
case PicklerClass::TENSORLIST:
|
||||
return tensorlist_class;
|
||||
case PicklerClass::DOUBLELIST:
|
||||
return doublelist_class;
|
||||
case PicklerClass::BOOLLIST:
|
||||
return boollist_class;
|
||||
default:
|
||||
AT_ERROR("Unknown class for pickler");
|
||||
}
|
||||
@ -158,13 +173,39 @@ void Pickler::addIValue(const IValue& ivalue) {
|
||||
} else if (ivalue.isString()) {
|
||||
pushMemoizedString(ivalue);
|
||||
} else if (ivalue.isGenericList()) {
|
||||
pushList(ivalue);
|
||||
pushGenericList(ivalue);
|
||||
} else if (ivalue.isGenericDict()) {
|
||||
pushDict(ivalue);
|
||||
} else if (ivalue.isNone()) {
|
||||
push<OpCode>(OpCode::NONE);
|
||||
} else if (ivalue.isIntList()) {
|
||||
pushIntList(ivalue);
|
||||
pushSpecializedList(
|
||||
ivalue, PicklerClass::INTLIST, [=](const IValue& ivalue) {
|
||||
for (const auto& item : ivalue.toIntListRef()) {
|
||||
addIValue(item);
|
||||
}
|
||||
});
|
||||
} else if (ivalue.isTensorList()) {
|
||||
pushSpecializedList(
|
||||
ivalue, PicklerClass::TENSORLIST, [=](const IValue& ivalue) {
|
||||
for (const auto& item : ivalue.toTensorListRef()) {
|
||||
addIValue(item);
|
||||
}
|
||||
});
|
||||
} else if (ivalue.isDoubleList()) {
|
||||
pushSpecializedList(
|
||||
ivalue, PicklerClass::DOUBLELIST, [=](const IValue& ivalue) {
|
||||
for (const auto& item : ivalue.toDoubleListRef()) {
|
||||
addIValue(item);
|
||||
}
|
||||
});
|
||||
} else if (ivalue.isBoolList()) {
|
||||
pushSpecializedList(
|
||||
ivalue, PicklerClass::BOOLLIST, [=](const IValue& ivalue) {
|
||||
for (const auto& item : ivalue.toBoolListRef()) {
|
||||
addIValue(bool(item));
|
||||
}
|
||||
});
|
||||
} else {
|
||||
AT_ERROR("Unknown IValue type for pickling: ", ivalue.tagKind());
|
||||
}
|
||||
@ -185,6 +226,12 @@ const void* Pickler::getPointer(const IValue& ivalue) {
|
||||
return ivalue.toString().get();
|
||||
} else if (ivalue.isIntList()) {
|
||||
return ivalue.toIntList().get();
|
||||
} else if (ivalue.isTensorList()) {
|
||||
return ivalue.toTensorList().get();
|
||||
} else if (ivalue.isDoubleList()) {
|
||||
return ivalue.toDoubleList().get();
|
||||
} else if (ivalue.isBoolList()) {
|
||||
return ivalue.toBoolList().get();
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
@ -342,9 +389,11 @@ void Pickler::pushTensorReference(const IValue& ivalue) {
|
||||
push<OpCode>(OpCode::REDUCE);
|
||||
}
|
||||
|
||||
void Pickler::pushIntList(const IValue& ivalue) {
|
||||
pushClass(PicklerClass::INTLIST);
|
||||
|
||||
void Pickler::pushSpecializedList(
|
||||
const IValue& ivalue,
|
||||
PicklerClass cls,
|
||||
const std::function<void(const IValue&)>& item_pusher) {
|
||||
pushClass(cls);
|
||||
|
||||
// Reduce arguments are spread (e.g. `*args`) before calling the global,
|
||||
// so wrap in a tuple
|
||||
@ -354,10 +403,8 @@ void Pickler::pushIntList(const IValue& ivalue) {
|
||||
// Mark list
|
||||
push<OpCode>(OpCode::MARK);
|
||||
|
||||
// Add items
|
||||
for (const auto& item : ivalue.toIntListRef()) {
|
||||
addIValue(item);
|
||||
}
|
||||
// Add all items
|
||||
item_pusher(ivalue);
|
||||
|
||||
// Finish list
|
||||
push<OpCode>(OpCode::APPENDS);
|
||||
@ -428,7 +475,7 @@ void Pickler::pushMemoization(const IValue& ivalue) {
|
||||
pushMemoization(ptr);
|
||||
}
|
||||
|
||||
void Pickler::pushList(const IValue& ivalue) {
|
||||
void Pickler::pushGenericList(const IValue& ivalue) {
|
||||
auto list = ivalue.toGenericListRef();
|
||||
push<OpCode>(OpCode::EMPTY_LIST);
|
||||
pushMemoization(ivalue);
|
||||
@ -508,7 +555,6 @@ void Unpickler::run() {
|
||||
AT_ERROR("Overran buffer while unpickling data, didn't find STOP opcode");
|
||||
}
|
||||
|
||||
|
||||
OpCode Unpickler::readInstruction() {
|
||||
auto opcode = readOpCode();
|
||||
switch (opcode) {
|
||||
@ -533,6 +579,14 @@ OpCode Unpickler::readInstruction() {
|
||||
// specialization
|
||||
if (stack_.back().pickler_class() == PicklerClass::INTLIST) {
|
||||
stack_.emplace_back(std::vector<int64_t>());
|
||||
} else if (stack_.back().pickler_class() == PicklerClass::INTLIST) {
|
||||
stack_.emplace_back(std::vector<int64_t>());
|
||||
} else if (stack_.back().pickler_class() == PicklerClass::TENSORLIST) {
|
||||
stack_.emplace_back(std::vector<at::Tensor>());
|
||||
} else if (stack_.back().pickler_class() == PicklerClass::DOUBLELIST) {
|
||||
stack_.emplace_back(std::vector<double>());
|
||||
} else if (stack_.back().pickler_class() == PicklerClass::BOOLLIST) {
|
||||
stack_.emplace_back(std::vector<bool>());
|
||||
} else {
|
||||
AT_ERROR("Unknown list specialization");
|
||||
}
|
||||
@ -572,6 +626,9 @@ OpCode Unpickler::readInstruction() {
|
||||
case OpCode::NEWFALSE: {
|
||||
stack_.emplace_back(false);
|
||||
} break;
|
||||
case OpCode::NONE: {
|
||||
stack_.emplace_back(IValue());
|
||||
} break;
|
||||
case OpCode::BININT1: {
|
||||
int8_t value = read<int8_t>();
|
||||
stack_.emplace_back(int64_t(value));
|
||||
@ -651,20 +708,19 @@ OpCode Unpickler::readInstruction() {
|
||||
auto setitem_data = stack_.back().ivalue();
|
||||
stack_.pop_back();
|
||||
|
||||
|
||||
auto class_name =
|
||||
static_cast<PicklerClass>(uint8_t(stack_.back().ivalue().toInt()));
|
||||
static_cast<PicklerClass>(uint8_t(stack_.back().ivalue().toInt()));
|
||||
stack_.pop_back();
|
||||
|
||||
switch (class_name) {
|
||||
case PicklerClass::TENSOR:
|
||||
stack_.emplace_back(tensor_table_->at(setitem_data.toInt()));
|
||||
break;
|
||||
case PicklerClass::INTLIST:
|
||||
stack_.emplace_back(setitem_data);
|
||||
break;
|
||||
default:
|
||||
AT_ERROR("Unknown pickler class id");
|
||||
case PicklerClass::TENSOR:
|
||||
stack_.emplace_back(tensor_table_->at(setitem_data.toInt()));
|
||||
break;
|
||||
case PicklerClass::INTLIST:
|
||||
stack_.emplace_back(setitem_data);
|
||||
break;
|
||||
default:
|
||||
AT_ERROR("Unknown pickler class id");
|
||||
}
|
||||
} break;
|
||||
case OpCode::REDUCE: {
|
||||
@ -684,33 +740,69 @@ OpCode Unpickler::readInstruction() {
|
||||
case PicklerClass::INTLIST:
|
||||
stack_.emplace_back(data->elements().at(0).toIntListRef());
|
||||
break;
|
||||
case PicklerClass::TENSORLIST:
|
||||
stack_.emplace_back(data->elements().at(0).toTensorListRef());
|
||||
break;
|
||||
case PicklerClass::DOUBLELIST:
|
||||
stack_.emplace_back(data->elements().at(0).toDoubleListRef());
|
||||
break;
|
||||
case PicklerClass::BOOLLIST:
|
||||
stack_.emplace_back(data->elements().at(0).toBoolListRef());
|
||||
break;
|
||||
default:
|
||||
AT_ERROR("Unknown pickler class id");
|
||||
}
|
||||
} break;
|
||||
default:
|
||||
AT_ERROR("Unknown opcode for unpickling at ", reinterpret_cast<void*>(opcode),": ", static_cast<uint8_t>(opcode));
|
||||
AT_ERROR(
|
||||
"Unknown opcode for unpickling at ",
|
||||
reinterpret_cast<void*>(opcode),
|
||||
": ",
|
||||
static_cast<uint8_t>(opcode));
|
||||
}
|
||||
return opcode;
|
||||
}
|
||||
|
||||
// Pop all the list items off of the stack and append them to the list at the
|
||||
// corresponding MARK
|
||||
void Unpickler::readList() {
|
||||
size_t start = marks_.back();
|
||||
marks_.pop_back();
|
||||
auto list_ivalue = stack_.at(start - 1);
|
||||
auto list_ivalue = stack_.at(start - 1).ivalue();
|
||||
auto num_elements = stack_.size() - start;
|
||||
if (list_ivalue.ivalue().isIntList()) {
|
||||
auto list = stack_.at(start - 1).ivalue().toIntList();
|
||||
list->elements().reserve(num_elements);
|
||||
for (auto it = stack_.begin() + start; it != stack_.end(); ++it) {
|
||||
list->elements().emplace_back(it->ivalue().toInt());
|
||||
auto elements = at::ArrayRef<StackItem>(stack_).slice(start);
|
||||
if (list_ivalue.isIntList()) {
|
||||
auto& list = list_ivalue.toIntList()->elements();
|
||||
list.reserve(num_elements);
|
||||
for (const auto& elem : elements) {
|
||||
list.emplace_back(elem.ivalue().toInt());
|
||||
}
|
||||
} else if (list_ivalue.isTensorList()) {
|
||||
auto& list = list_ivalue.toTensorList()->elements();
|
||||
list.reserve(num_elements);
|
||||
for (const auto& elem : elements) {
|
||||
list.emplace_back(elem.ivalue().toTensor());
|
||||
}
|
||||
} else if (list_ivalue.isDoubleList()) {
|
||||
auto& list = list_ivalue.toDoubleList()->elements();
|
||||
list.reserve(num_elements);
|
||||
for (const auto& elem : elements) {
|
||||
list.emplace_back(elem.ivalue().toDouble());
|
||||
}
|
||||
} else if (list_ivalue.isBoolList()) {
|
||||
auto& list = list_ivalue.toBoolList()->elements();
|
||||
list.reserve(num_elements);
|
||||
for (const auto& elem : elements) {
|
||||
list.push_back(elem.ivalue().toBool());
|
||||
}
|
||||
} else if (list_ivalue.isGenericList()) {
|
||||
auto& list = list_ivalue.toGenericList()->elements();
|
||||
list.reserve(num_elements);
|
||||
for (const auto& elem : elements) {
|
||||
list.emplace_back(elem.ivalue());
|
||||
}
|
||||
} else {
|
||||
auto list = stack_.at(start - 1).ivalue().toGenericList();
|
||||
list->elements().reserve(num_elements);
|
||||
for (auto it = stack_.begin() + start; it != stack_.end(); ++it) {
|
||||
list->elements().emplace_back(it->ivalue());
|
||||
}
|
||||
AT_ERROR("Unknown IValue list kind: ", list_ivalue.tagKind());
|
||||
}
|
||||
|
||||
stack_.erase(stack_.begin() + start, stack_.end());
|
||||
|
@ -90,6 +90,12 @@ enum PicklerClass : uint8_t {
|
||||
TENSOR = 0,
|
||||
// List[int]
|
||||
INTLIST = 1,
|
||||
// List[Tensor]
|
||||
TENSORLIST = 2,
|
||||
// List[float]
|
||||
DOUBLELIST = 3,
|
||||
// List[bool]
|
||||
BOOLLIST = 4
|
||||
};
|
||||
|
||||
using ::c10::IValue;
|
||||
@ -122,6 +128,7 @@ class Pickler {
|
||||
private:
|
||||
void pushDict(const IValue& ivalue);
|
||||
void pushDouble(const IValue& ivalue);
|
||||
void pushGenericList(const IValue& ivalue);
|
||||
void pushInt(const IValue& ivalue);
|
||||
void pushIntList(const IValue& ivalue);
|
||||
void pushList(const IValue& ivalue);
|
||||
@ -134,6 +141,10 @@ class Pickler {
|
||||
|
||||
void pushBinGet(uint32_t memo_id);
|
||||
void pushClass(PicklerClass cls);
|
||||
void pushSpecializedList(
|
||||
const IValue& ivalue,
|
||||
PicklerClass cls,
|
||||
const std::function<void(const IValue&)>& item_pusher);
|
||||
void pushGlobal(const std::string& name);
|
||||
void pushMemoization(const void* item);
|
||||
void pushString(const std::string& string);
|
||||
@ -186,19 +197,19 @@ struct StackItem {
|
||||
StackItem(PicklerClass pickler_class)
|
||||
: pickler_class_(pickler_class), ivalue_(c10::nullopt) {}
|
||||
|
||||
IValue ivalue() {
|
||||
IValue ivalue() const {
|
||||
return *ivalue_;
|
||||
}
|
||||
|
||||
PicklerClass pickler_class() {
|
||||
PicklerClass pickler_class() const {
|
||||
return *pickler_class_;
|
||||
}
|
||||
|
||||
c10::optional<IValue> ivalue_opt() {
|
||||
c10::optional<IValue> ivalue_opt() const {
|
||||
return ivalue_;
|
||||
}
|
||||
|
||||
c10::optional<PicklerClass> pickler_class_opt() {
|
||||
c10::optional<PicklerClass> pickler_class_opt() const {
|
||||
return pickler_class_;
|
||||
}
|
||||
|
||||
|
@ -1,7 +1,22 @@
|
||||
# These functions are referenced from the pickle archives produced by
|
||||
# ScriptModule.save()
|
||||
|
||||
def build_intlist(data):
|
||||
return data
|
||||
|
||||
|
||||
def build_tensorlist(data):
|
||||
return data
|
||||
|
||||
|
||||
def build_doublelist(data):
|
||||
return data
|
||||
|
||||
|
||||
def build_boollist(data):
|
||||
return data
|
||||
|
||||
|
||||
def build_tensor_from_id(data):
|
||||
if isinstance(data, int):
|
||||
# just the id, can't really do anything
|
||||
|
@ -3,9 +3,10 @@ import ast
|
||||
import inspect
|
||||
import torch
|
||||
from .._jit_internal import List, BroadcastingList1, BroadcastingList2, \
|
||||
BroadcastingList3, Tuple, is_tuple, is_list, Dict, is_dict
|
||||
BroadcastingList3, Tuple, is_tuple, is_list, Dict, is_dict, Optional, \
|
||||
is_optional
|
||||
from torch._C import TensorType, TupleType, FloatType, IntType, \
|
||||
ListType, StringType, DictType, BoolType
|
||||
ListType, StringType, DictType, BoolType, OptionalType
|
||||
from textwrap import dedent
|
||||
|
||||
|
||||
@ -31,6 +32,7 @@ _eval_env = {
|
||||
'Tuple': Tuple,
|
||||
'List': List,
|
||||
'Dict': Dict,
|
||||
'Optional': Optional,
|
||||
}
|
||||
|
||||
|
||||
@ -173,6 +175,11 @@ def ann_to_type(ann):
|
||||
key = ann_to_type(ann.__args__[0])
|
||||
value = ann_to_type(ann.__args__[1])
|
||||
return DictType(key, value)
|
||||
elif is_optional(ann):
|
||||
if issubclass(ann.__args__[1], type(None)):
|
||||
return OptionalType(ann_to_type(ann.__args__[0]))
|
||||
else:
|
||||
return OptionalType(ann_to_type(ann.__args__[1]))
|
||||
elif ann is float:
|
||||
return FloatType.get()
|
||||
elif ann is int:
|
||||
@ -181,7 +188,7 @@ def ann_to_type(ann):
|
||||
return StringType.get()
|
||||
elif ann is bool:
|
||||
return BoolType.get()
|
||||
raise ValueError("Unknown type annotation: '{}'".format(ann.__name__))
|
||||
raise ValueError("Unknown type annotation: '{}'".format(ann))
|
||||
|
||||
|
||||
__all__ = [
|
||||
|
Reference in New Issue
Block a user