The strategy is that we will heap allocate a LargeNegativeIntSymNodeImpl whenever we have a large negative int, so that we can keep the old `is_symbolic` test (now called `is_heap_allocated`) on SymInt. Whenever we need to do something with these ints, though, we convert them back into a plain `int64_t` (and then, e.g., wrap it in whatever user specified SymNodeImpl they need.) We cannot wrap directly in the user specified SymNodeImpl as we generally do not know what the "tracing context" is from C++. We expect large negative ints to be rare, so we don't apply optimizations like singleton-ifying INT_MIN.

Here's the order to review:

* c10/core/SymInt.h and cpp
  * `is_symbolic` renamed to `is_heap_allocated` as I needed to audit all use sites: the old `is_symbolic` test would return true for a large negative int, but it would be wrong to then try to dispatch on the LargeNegativeIntSymNodeImpl, which supports very few operations. In this file, I had to update expect_int.
  * If you pass in a large negative integer, we instead heap allocate it in `promote_to_negative`. The function is written in a funny way to keep compact constructor code for SymInt (the heap allocation happens out of line).
  * clone is now moved out-of-line.
  * New method `maybe_as_int` which will give you a constant int if it is possible, either because it's stored inline or in LargeNegativeIntSymNodeImpl. This is the preferred replacement for previous use of is_symbolic() and then as_int_unchecked().
  * Rename toSymNodeImpl to toSymNode, which is more correct (since it returns a SymNode).
  * Complete rewrite of `normalize_symints.cpp` to use the new `maybe_as_int`. Cannot easily use the old code structure, so it's now done with a macro and typing out each case manually (it's actually not that bad).
  * Reimplementations of all the unary operators by hand to use `maybe_as_int`, relatively simple.
* c10/core/LargeNegativeIntSymNodeImpl.h - Just stores an int64_t value, but it has to be big and negative. Most methods are not implemented, since we will rewrap the large negative int in the real SymNodeImpl subclass before doing operations with it.
* The rest of the files are just rewriting code to use `maybe_as_int`. There is a nontrivial comment in c10/core/SymIntArrayRef.h.

Very minor test adjustment in c10/test/core/SymInt_test.cpp. Plan to exercise this properly in next PR.

Companion XLA PR: https://github.com/pytorch/xla/pull/4882

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99157
Approved by: https://github.com/albanD
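Below is a minimal sketch of the `maybe_as_int` pattern described above, assuming the APIs the message names (`SymInt::maybe_as_int`, `toSymNode`, and SymNodeImpl's `wrap_int`/`add`); the helper name and the exact wrapping and overflow handling are illustrative, not the PR's actual implementation:

```cpp
// Hypothetical sketch only; everything beyond the APIs cited in the commit
// message above is illustrative.
#include <c10/core/SymInt.h>
#include <c10/core/SymNodeImpl.h>

c10::SymInt add_sketch(const c10::SymInt& a, const c10::SymInt& b) {
  auto ma = a.maybe_as_int();
  auto mb = b.maybe_as_int();
  if (ma && mb) {
    // Constant fast path: both values are plain ints (stored inline or in a
    // heap-allocated large-negative-int node); no tracing context is needed.
    // (A real implementation would also guard against int64_t overflow.)
    return c10::SymInt(*ma + *mb);
  }
  // Symbolic path: take a SymNode from a symbolic operand, wrap the constant
  // operand with it, and dispatch on the real SymNodeImpl subclass.
  c10::SymNode na, nb;
  if (ma) {
    nb = b.toSymNode();
    na = nb->wrap_int(*ma);
  } else if (mb) {
    na = a.toSymNode();
    nb = na->wrap_int(*mb);
  } else {
    na = a.toSymNode();
    nb = b.toSymNode();
  }
  return c10::SymInt(na->add(nb));
}
```

The point of the pattern is that the constant branch never needs a tracing context, which is exactly why `maybe_as_int` replaces the old `is_symbolic()` + `as_int_unchecked()` idiom.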
1162 lines · 41 KiB · C++
#include <ATen/ATen.h>
#include <ATen/core/Dict.h>
#ifdef USE_RPC
#include <torch/csrc/distributed/rpc/rref_context.h>
#endif
#include <torch/csrc/jit/api/function_impl.h>
#include <torch/csrc/jit/mobile/type_parser.h>
#include <torch/csrc/jit/serialization/pickler.h>
#include <torch/csrc/jit/serialization/storage_context.h>
#include <torch/csrc/jit/serialization/unpickler.h>
#include <torch/csrc/utils/byte_order.h>
#include <string>

namespace torch::jit {

using ::c10::IValue;

static void restoreAccurateTypeTagsIfPossible(const IValue& root) {
  if (root.isObject()) {
    restoreAccurateTypeTags(root, root.type());
  }
}

// Pickled objects are stored in a form compatible with Python pickling.
// In torchscript List[T]/Dict[K, V] are statically typed and contain
// dynamic type tags that allow T, K, and V to be recovered. But this
// info is not stored in the Python pickling information. However, we
// can recover this information from the static type of the top-level
// object being unpickled, because we have a record of the type of the
// objects it contains as attributes.
// `IfPossible` - we can only do this recovery when we have an object as
// the top-level unpickled thing (which is guaranteed for Modules, but
// not for torch.load/torch.save). Otherwise we do not know the types
// of the contained objects and cannot restore the tags.
void restoreAccurateTypeTags(const IValue& root, const TypePtr& type_tag) {
  struct Work {
    TypePtr type;
    IValue value;
  };
  std::vector<Work> to_process = {{type_tag, root}};
  std::unordered_set<const void*> scanned;
  while (!to_process.empty()) {
    Work w = std::move(to_process.back());
    to_process.pop_back();
    // ensure we only scan each pointer value once, otherwise this
    // can become exponential (and if we allow recursive data in the future,
    // it would not terminate).
    if (w.value.isPtrType()) {
      const void* key = w.value.internalToPointer();
      auto it = scanned.find(key);
      if (it != scanned.end()) {
        continue;
      }
      scanned.emplace_hint(it, key);
    }
    auto kind = w.type->kind();
    if (auto dyn = w.type->castRaw<c10::DynamicType>()) {
      kind = dyn->dynamicKind();
    }
    switch (kind) {
      case TensorType::Kind:
      case StorageType::Kind:
      case NumberType::Kind:
      case FloatType::Kind:
      case ComplexType::Kind:
      case IntType::Kind:
      case NoneType::Kind:
      case GeneratorType::Kind:
      case QuantizerType::Kind:
      case BoolType::Kind:
      case VarType::Kind:
      case CapsuleType::Kind:
      case PyObjectType::Kind:
      case StringType::Kind:
      case FunctionType::Kind:
      case DeviceObjType::Kind:
      case StreamObjType::Kind:
      case QSchemeType::Kind:
      case LayoutType::Kind:
      case MemoryFormatType::Kind:
      case ScalarTypeType::Kind:
      case RRefType::Kind:
      case AnyType::Kind:
      case AnyListType::Kind:
      case AnyTupleType::Kind:
      case AnyClassType::Kind:
      case AnyEnumType::Kind:
        // no op, there is nothing to tag
        break;
      case c10::SymIntType::Kind:
        // TODO: Can this really show up though? :think:
        TORCH_CHECK(!w.value.toSymInt().is_heap_allocated());
        // no op, there is nothing to tag
        break;
      case c10::SymFloatType::Kind:
        TORCH_CHECK(!w.value.toSymFloat().is_symbolic());
        // no op, there is nothing to tag
        break;
      case c10::SymBoolType::Kind:
        TORCH_CHECK(!w.value.toSymBool().is_symbolic());
        // no op, there is nothing to tag
        break;
      case DynamicType::Kind:
      case UnionType::Kind:
      case EnumType::Kind:
        // TODO(gmagogsfm): Implement serialization/deserialization of Enum.
        TORCH_INTERNAL_ASSERT(false);
      case TupleType::Kind: {
        auto t = w.value.toTuple();
        for (size_t i = 0; i < w.type->containedTypeSize(); ++i) {
          Work elem = {w.type->containedType(i), t->elements().at(i)};
          to_process.emplace_back(std::move(elem));
        }
      } break;
      case FutureType::Kind: {
        auto f = w.value.toFuture();
        if (f->completed()) {
          Work elem = {w.type->containedType(0), f->value()};
          to_process.emplace_back(std::move(elem));
        }
      } break;
      case AwaitType::Kind: {
        auto aw = w.value.toAwait();
        if (aw->completed()) {
          Work elem = {w.type->containedType(0), aw->wait()};
          to_process.emplace_back(std::move(elem));
        }
      } break;
      case OptionalType::Kind: {
        if (!w.value.isNone()) {
          Work elem = {w.type->containedType(0), w.value};
          to_process.emplace_back(std::move(elem));
        }
      } break;
      case ListType::Kind: {
        // specialized lists do not need their type refined, so we can exit
        // early here
        if (!w.value.isList()) {
          break;
        }
        auto elem_type = w.type->containedType(0);
        auto lst = w.value.toList();
        lst.unsafeSetElementType(elem_type);
        for (const IValue& item : lst) {
          Work elem = {elem_type, item};
          to_process.emplace_back(std::move(elem));
        }
      } break;
      case DictType::Kind: {
        auto d = w.value.toGenericDict();
        auto keyType = w.type->containedType(0);
        auto valType = w.type->containedType(1);
        d.unsafeSetKeyType(keyType);
        d.unsafeSetValueType(valType);
        for (const auto& item : d) {
          Work kelem = {keyType, item.key()};
          Work velem = {valType, item.value()};
          to_process.emplace_back(std::move(kelem));
          to_process.emplace_back(std::move(velem));
        }
      } break;
      // in both cases the dynamic type is a class, and we are going to tag with
      // the dynamic type
      case InterfaceType::Kind:
      case ClassType::Kind: {
        auto obj = w.value.toObject();
        auto typ = obj->type(); // note: intentionally using the dynamic type,
                                // the static type is potentially less accurate
        for (size_t i = 0; i < typ->numAttributes(); ++i) {
          Work elem = {typ->getAttribute(i), obj->getSlot(i)};
          to_process.emplace_back(std::move(elem));
        }
      };
    }
  }
}

namespace {
template <typename T>
bool is(const Type& type) {
  if (type.kind() == T::Kind) {
    return true;
  }
  if (auto dyn = type.castRaw<c10::DynamicType>()) {
    return dyn->tag() == c10::DynamicTypeTrait<T>::tagValue();
  }
  return false;
}
} // namespace

void restoreContainerTypeTags(const IValue& ivalue, const TypePtr& type) {
  if (is<DictType>(*type)) {
    auto dict = ivalue.toGenericDict();
    dict.unsafeSetKeyType(type->containedType(0));
    dict.unsafeSetValueType(type->containedType(1));
  } else if (is<ListType>(*type)) {
    ivalue.toList().unsafeSetElementType(type->containedType(0));
  } else {
    AT_ERROR("Unknown type for tag restoration: " + type->annotation_str());
  }
}

IValue Unpickler::parse_ivalue() {
  run();
  TORCH_CHECK(
      stack_.size() == 1,
      "Unpickler expected 1 element on the stack, but found ",
      stack_.size());
  if (version_ <= 2) {
    // See [type tag serialization]
    restoreAccurateTypeTagsIfPossible(stack_[0]);
  }
  return stack_[0];
}

double Unpickler::readFloat() {
  AT_ASSERT(sizeof(double) == 8);
  double big_endian = read<double>();
#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  double little_endian;

  // Pickle floats are big endian, so reverse the bytes
  auto big_endian_ptr = reinterpret_cast<const char*>(&big_endian);
  std::reverse_copy(
      big_endian_ptr,
      big_endian_ptr + sizeof(big_endian),
      reinterpret_cast<char*>(&little_endian));

  return little_endian;
#else /* __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__ */
  return big_endian;
#endif /* __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__ */
}

void Unpickler::run() {
  // Expect a PROTO opcode and protocol number at the start of blob
  auto opcode = readOpCode();
  TORCH_CHECK(
      opcode == PickleOpCode::PROTO,
      "Expected PROTO opcode at the start"
      " of pickle archive, found ",
      int(static_cast<uint8_t>(opcode)));
  uint8_t protocol = read<uint8_t>();
  TORCH_CHECK(
      protocol == 2,
      "Only Pickle protocol 2 is supported, found protocol = ",
      protocol);

  while (true) {
    PickleOpCode opcode = readInstruction();
    if (opcode == PickleOpCode::STOP) {
      return;
    }
  }
}

void Unpickler::setInput(size_t memo_id) {
  AT_ASSERT(!stack_.empty());
  if (memo_id >= memo_table_.size()) {
    memo_table_.insert(
        memo_table_.end(), memo_id - memo_table_.size(), IValue());
    memo_table_.push_back(stack_.back());
  } else {
    memo_table_[memo_id] = stack_.back();
  }
}

// emplace_back on bool vectors does not exist on some systems
// avoid it by calling push_back for bool
template <typename T>
inline void append(std::vector<T>& a, T&& e) {
  a.emplace_back(std::forward<T>(e));
}
template <>
inline void append<bool>(std::vector<bool>& a, bool&& e) {
  a.push_back(e);
}

static std::vector<int64_t> tupleToIntList(const IValue& v) {
  return fmap(v.toTupleRef().elements(), [](const IValue& v) -> int64_t {
    return v.toInt();
  });
}

// note we cannot use toIntList, toDoubleList because during unpickling the
// lists are not yet tagged
template <typename T>
static std::vector<T> convertList(const IValue& v) {
  return fmap(v.toListRef(), [](const IValue& elem) { return elem.to<T>(); });
}

PickleOpCode Unpickler::readInstruction() {
  auto opcode = readOpCode();
  switch (opcode) {
    case PickleOpCode::EMPTY_LIST: {
      stack_.emplace_back(c10::impl::GenericList(AnyType::get()));
    } break;
    case PickleOpCode::EMPTY_TUPLE: {
      if (empty_tuple_.isNone()) {
        // we only need one object, since tuples are not mutable.
        empty_tuple_ = c10::ivalue::Tuple::create(std::vector<IValue>());
      }
      stack_.emplace_back(empty_tuple_);
    } break;
    case PickleOpCode::BINPUT: {
      size_t memo_id = read<uint8_t>();
      setInput(memo_id);
    } break;
    case PickleOpCode::LONG_BINPUT: {
      TORCH_CHECK(
          std::numeric_limits<size_t>::max() >=
              std::numeric_limits<uint32_t>::max(),
          "Found a LONG_BINPUT opcode, but size_t on this system is "
          "not big enough to decode it");
      size_t memo_id = read<uint32_t>();
      setInput(memo_id);
    } break;
    case PickleOpCode::MARK: {
      // Mark location of the container ivalue in the stack
      marks_.push_back(stack_.size());
    } break;
    case PickleOpCode::NEWTRUE: {
      stack_.emplace_back(true);
    } break;
    case PickleOpCode::NEWFALSE: {
      stack_.emplace_back(false);
    } break;
    case PickleOpCode::NONE: {
      stack_.emplace_back();
    } break;
    case PickleOpCode::BININT1: {
      uint8_t value = read<uint8_t>();
      stack_.emplace_back(int64_t(value));
    } break;
    case PickleOpCode::BININT2: {
      uint16_t value = from_le16(read<uint16_t>());
      stack_.emplace_back(int64_t(value));
    } break;
    case PickleOpCode::BININT: {
      int32_t value = from_le32(read<int32_t>());
      stack_.emplace_back(int64_t(value));
    } break;
    case PickleOpCode::LONG1: {
      // Only read LONG1s with 8 as the length
      uint8_t length = read<uint8_t>();
      TORCH_CHECK(length == 8, "Expected length to be 8, got ", int(length));
      stack_.emplace_back(int64_t(from_le64(read<int64_t>())));
    } break;
    case PickleOpCode::BINUNICODE: {
      uint32_t length = from_le32(read<uint32_t>());
      stack_.emplace_back(readBytes(length));
    } break;
    case PickleOpCode::BINUNICODE8: {
      int64_t length = from_le64(read<int64_t>());
      stack_.emplace_back(readBytes(length));
    } break;
    case PickleOpCode::BINFLOAT:
      stack_.emplace_back(readFloat());
      break;
    case PickleOpCode::TUPLE: {
      TORCH_CHECK(!marks_.empty(), "Parsing error: marks_ is empty");
      size_t start = marks_.back();
      marks_.pop_back();
      std::vector<IValue> elements;
      const auto tupleSize = stack_.size() - start;
      switch (tupleSize) {
        case 3: {
          auto e3 = pop(stack_);
          auto e2 = pop(stack_);
          auto e1 = pop(stack_);
          stack_.emplace_back(c10::ivalue::Tuple::create(
              std::move(e1), std::move(e2), std::move(e3)));
          break;
        }
        case 2: {
          auto e2 = pop(stack_);
          auto e1 = pop(stack_);
          stack_.emplace_back(
              c10::ivalue::Tuple::create(std::move(e1), std::move(e2)));
          break;
        }
        case 1:
          stack_.emplace_back(c10::ivalue::Tuple::create(pop(stack_)));
          break;
        default: {
          elements.reserve(stack_.size() - start);
          auto start_it = stack_.begin() + start;
          for (auto it = start_it; it != stack_.end(); ++it) {
            elements.emplace_back(std::move(*it));
          }
          stack_.erase(start_it, stack_.end());
          stack_.emplace_back(c10::ivalue::Tuple::create(std::move(elements)));
          break;
        }
      }
    } break;
    case PickleOpCode::TUPLE1: {
      TORCH_CHECK(
          stack_.size() > 0,
          "Parsing error: stack_ contains ",
          stack_.size(),
          " elements, at least 1 expected");
      stack_.emplace_back(c10::ivalue::Tuple::create(pop(stack_)));
    } break;
    case PickleOpCode::TUPLE2: {
      TORCH_CHECK(
          stack_.size() > 1,
          "Parsing error: stack_ contains ",
          stack_.size(),
          " elements, at least 2 expected");
      auto e2 = pop(stack_);
      auto e1 = pop(stack_);
      stack_.emplace_back(
          c10::ivalue::Tuple::create(std::move(e1), std::move(e2)));
    } break;
    case PickleOpCode::TUPLE3: {
      TORCH_CHECK(
          stack_.size() > 2,
          "Parsing error: stack_ contains ",
          stack_.size(),
          " elements, at least 3 expected");
      auto e3 = pop(stack_);
      auto e2 = pop(stack_);
      auto e1 = pop(stack_);
      stack_.emplace_back(c10::ivalue::Tuple::create(
          std::move(e1), std::move(e2), std::move(e3)));
    } break;
    case PickleOpCode::EMPTY_DICT:
      stack_.emplace_back(
          c10::impl::GenericDict(AnyType::get(), AnyType::get()));
      break;
    case PickleOpCode::APPENDS: {
      TORCH_CHECK(!marks_.empty(), "Parsing error: marks_ is empty");
      size_t start = marks_.back();
      TORCH_CHECK(
          start > 0 && start <= stack_.size(),
          "Parsing error: wrong start index for stack_");
      auto list_ivalue = stack_.at(start - 1);
      readList(list_ivalue);
    } break;
    case PickleOpCode::LIST: {
      IValue list_ivalue = c10::impl::GenericList(AnyType::get());
      readList(list_ivalue);
      stack_.push_back(std::move(list_ivalue));
    } break;
    case PickleOpCode::DICT: {
      TORCH_CHECK(!marks_.empty(), "Parsing error: marks_ is empty");
      size_t start = marks_.back();
      marks_.pop_back();
      auto dict = c10::impl::GenericDict(AnyType::get(), AnyType::get());
      for (size_t i = start; i < stack_.size(); i += 2) {
        dict.insert_or_assign(stack_[i], stack_[i + 1]);
      }
      stack_.erase(stack_.begin() + start, stack_.end());
      stack_.emplace_back(std::move(dict));
    } break;
    case PickleOpCode::SETITEMS: {
      TORCH_CHECK(!marks_.empty(), "Parsing error: marks_ is empty");
      size_t start = marks_.back();
      marks_.pop_back();
      TORCH_CHECK(
          start > 0 && start <= stack_.size(),
          "Parsing error: wrong start index for stack_");
      auto dict = stack_.at(start - 1).toGenericDict();
      for (size_t i = start; i < stack_.size(); i += 2) {
        dict.insert_or_assign(stack_[i], stack_[i + 1]);
      }
      stack_.erase(stack_.begin() + start, stack_.end());
    } break;
    case PickleOpCode::BINGET: {
      auto pos = read<uint8_t>();
      TORCH_CHECK(
          memo_table_.size() > pos,
          "Parsing error: out of bounds access at ",
          (size_t)pos,
          " to memo_table_ which is of size ",
          memo_table_.size());
      stack_.push_back(memo_table_.at(pos));
    } break;
    case PickleOpCode::LONG_BINGET: {
      auto pos = read<uint32_t>();
      TORCH_CHECK(
          memo_table_.size() > pos,
          "Parsing error: out of bounds access at ",
          (size_t)pos,
          " to memo_table_ which is of size ",
          memo_table_.size());
      stack_.push_back(memo_table_.at(pos));
    } break;
    case PickleOpCode::STOP:
      break;
    case PickleOpCode::GLOBAL: {
      // Module name, it's not needed for anything
      auto module_name = readString();
      auto class_name = readString();
      readGlobal(module_name, class_name);
    } break;
    case PickleOpCode::NEWOBJ: {
      TORCH_CHECK(!stack_.empty(), "Parsing error: stack_ is empty");
      // pop empty tuple, the actual action is stored in the globals_stack_
      stack_.pop_back();
    } break;
    // because we have NEWOBJ do nothing, BUILD and REDUCE end up doing
    // the same thing
    case PickleOpCode::BUILD:
    case PickleOpCode::REDUCE: {
      // stack is: <functor_idx> <functor_arg>
      // extract <functor_idx> and remove from the stack:
      TORCH_CHECK(
          stack_.size() > 1,
          "Parsing error: stack_ contains ",
          stack_.size(),
          " elements, at least 2 expected");
      std::swap(*(stack_.end() - 2), *(stack_.end() - 1));
      size_t idx = stack_.back().toInt();
      stack_.pop_back();
      // stack is: <functor_arg>
      TORCH_CHECK(
          idx < globals_.size(),
          "Parsing error: out of bounds access to globals_");
      globals_.at(idx)();
    } break;
    case PickleOpCode::BINPERSID: {
      TORCH_CHECK(!stack_.empty(), "Parsing error: stack_ is empty");
      auto tuple = pop(stack_).toTuple();
      const auto& args = tuple->elements();
      AT_ASSERT(
          args.at(0).toStringRef() == "storage",
          "unknown PERSID key ",
          args.at(0).toStringRef());
      at::ScalarType type = args.at(1).toScalarType();
      const std::string& key = args.at(2).toStringRef();

      at::Device device(args.at(3).toStringRef());
      if (device_) {
        device = *device_;
      }

      at::Storage storage;
      if (storage_context_ != nullptr && storage_context_->hasStorage(key)) {
        // for torch.package logic where storage may be loaded already
        storage = storage_context_->getStorage(key);
      } else {
        int64_t numel = args.at(4).toInt();
        caffe2::TypeMeta dtype = at::CPU(type).typeMeta();

        at::DataPtr storage_ptr;
        if (numel > 0) {
          // If there are no elements in the tensor, there's no point in
          // reading a zero (0) byte file from the input stream and paying
          // that cost.
          storage_ptr = read_record_(key);
        }

        storage = at::Storage(
            c10::Storage::use_byte_size_t(),
            numel * dtype.itemsize(),
            std::move(storage_ptr),
            /*allocator=*/nullptr,
            /*resizable=*/false); // NB: we didn't set any allocator for the
                                  // tensor
        if (storage_context_ != nullptr) {
          storage_context_->addStorage(key, storage);
        }
      }

      auto options = at::CPU(type).options();
      if (use_storage_device_) {
        options = options.device(storage.device());
        device = storage.device();
      }

      at::Tensor tensor;
      if (options.backend() == c10::Backend::QuantizedCPU) {
        tensor = at::_empty_affine_quantized({}, options, 0, 0)
                     .set_(storage, 0, {}, {});
      } else {
        tensor = at::empty({0}, options).set_(storage);
      }

      if (device.is_cuda() || device.is_xpu() || device.is_meta() ||
          device.is_hpu()) {
        tensor = tensor.to(device, tensor.scalar_type());
      } else if (device.type() != DeviceType::CPU) {
        AT_ERROR(
            "supported devices include CPU, CUDA and HPU, however got ",
            DeviceTypeName(device.type(), false));
      }
      stack_.emplace_back(std::move(tensor));
    } break;
    case PickleOpCode::SETITEM: {
      // At this OpCode, stack looks like
      // | Stack Bottom |
      // | ...... |
      // | Dict | -> (stack_size - 3)
      // | Key | -> (stack_size - 2)
      // | Value | -> (stack_size - 1)
      auto stack_size = stack_.size();
      auto dict_pos = stack_size - 3;
      auto key_pos = stack_size - 2;
      auto val_pos = stack_size - 1;
      auto dict = stack_.at(dict_pos).toGenericDict();
      dict.insert_or_assign(stack_.at(key_pos), stack_.at(val_pos));
      stack_.erase(stack_.begin() + (key_pos), stack_.end());
    } break;
    default: {
      AT_ERROR(
          "Unknown opcode for unpickling at ",
          reinterpret_cast<void*>(opcode),
          ": ",
          int(static_cast<uint8_t>(opcode)));
    } break;
  }
  return opcode;
}

void Unpickler::readGlobal(
    const std::string& module_name,
    const std::string& class_name) {
  if (this->skip_next_read_global) {
    // See [NOTE] skip_next_read_global
    this->skip_next_read_global--;
    if (this->skip_next_read_global == 1) {
      // Pass through to the correct handler
    } else if (this->skip_next_read_global == 0) {
      // Corresponds to the type of `Tensor` being unpickled
      if (module_name != "torch" || class_name != "Tensor") {
        TORCH_WARN(
            "Trying to load a Subclassed Tensor, it will be converted to at::Tensor in C++");
      }
      stack_.emplace_back(int64_t(globals_.size() - 1));
      return;
    } else {
      TORCH_CHECK(false, "INVALID VALUES")
    }
  }
  // TODO [unpickler refactor] __main__ isn't used by the pickler anymore, this
  // is only here for bc-compatibility reasons
  if (module_name == "__main__") {
    if (class_name == "TensorID") {
      globals_.emplace_back([this] {
        auto setitem_data = stack_.back();
        stack_.pop_back();
        TORCH_INTERNAL_ASSERT(
            !tensor_table_.empty(),
            "Pickler tried to write a tensor but had no tensor table to write to");
        stack_.emplace_back(tensor_table_.at(setitem_data.toInt()));
      });
    } else if (class_name == "IntList") {
      globals_.emplace_back([this] {
        stack_.back().toList().unsafeSetElementType(IntType::get());
      });
    } else {
      AT_ERROR("Unknown pickler class id ", class_name);
    }
  } else if (module_name == "torch.jit._pickle") {
    if (class_name == "build_tensor_from_id") {
      globals_.emplace_back([this] {
        // Pop reduce arg off the stack
        auto data = stack_.back().toTupleRef().elements().at(0);
        stack_.pop_back();
        TORCH_CHECK(
            !tensor_table_.empty(),
            "Found a tensor table reference but Unpickler"
            " has no tensor table\n");
        stack_.emplace_back(tensor_table_.at(data.toInt()));
      });
    } else if (class_name == "restore_type_tag") {
      globals_.emplace_back([this] {
        auto tuple = stack_.back().toTuple();
        const auto& data = tuple->elements();
        auto type_str = data.at(1).toStringRef();
        stack_.pop_back();
        TypePtr type = nullptr;
        auto entry = type_cache_.find(type_str);
        if (entry != type_cache_.end()) {
          type = entry->second;
        } else {
          if (type_resolver_ == nullptr) {
            // If we haven't injected a custom way of retrieving types from
            // names, use a barebones type parser.
            type = type_parser_(type_str);
          } else {
            type = type_resolver_(type_str).type_;
          }
          type_cache_[type_str] = type;
        }
        // TODO: Use lookahead to avoid creating the tuple and immediately
        // destroying it here
        restoreContainerTypeTags(data.at(0), type);
        stack_.emplace_back(data.at(0));
      });
    } else {
      TypePtr elem_type = nullptr;
      if (class_name == "build_intlist") {
        elem_type = IntType::get();
      } else if (class_name == "build_tensorlist") {
        elem_type = TensorType::get();
      } else if (class_name == "build_doublelist") {
        elem_type = FloatType::get();
      } else if (class_name == "build_boollist") {
        elem_type = BoolType::get();
      } else {
        AT_ERROR("Unknown pickler class id ", class_name);
      }
      // Unpickle a list specialization (e.g. List[Tensor], List[int], ...)
      globals_.emplace_back([this, elem_type] {
        // Pop reduce arg off the stack
        auto data = stack_.back().toTupleRef().elements().at(0).toList();
        stack_.pop_back();
        data.unsafeSetElementType(elem_type);
        stack_.emplace_back(std::move(data));
      });
    }
  } else if (
      module_name == "torch._utils" &&
      (class_name == "_rebuild_tensor_v2" ||
       class_name == "_rebuild_qtensor")) {
    // Unpickle a tensor
    bool quantized = class_name == "_rebuild_qtensor";
    rebuildTensor(quantized);
  } else if (
      module_name == "torch._tensor" &&
      (class_name == "_rebuild_from_type_v2")) {
    // Unpickle a Tensor with Python attributes or
    // a Subclassed Tensor.
    rebuildTensorFromTypeV2();
  } else if (
      module_name == "torch._utils" && class_name == "_rebuild_sparse_tensor") {
    rebuildSparseTensor();
  } else if (module_name == "builtins" && class_name == "complex") {
    globals_.emplace_back([this] {
      auto tuple = pop(stack_).toTuple();
      const auto& elems = tuple->elements();
      AT_ASSERT(elems.size() == 2);
      auto complex =
          c10::complex<double>(elems.at(0).toDouble(), elems.at(1).toDouble());
      stack_.emplace_back(complex);
    });

  } else if (module_name == "collections" && class_name == "OrderedDict") {
    // collections.OrderedDict is used in tensor serialization for a tensor's
    // backward hooks (but they are not actually saved with this Pickler)
    globals_.emplace_back([this] {
      // drop the Tuple that was argument to OrderedDict, and replace it
      // with None. OrderedDicts only appear in tensor deserialization and
      // their value is never used
      stack_.back() = IValue();
    });
  } else if (module_name == "torch" && class_name == "device") {
    globals_.emplace_back([this] {
      auto device_string = stack_.back().toTupleRef().elements().at(0);
      stack_.pop_back();
      stack_.emplace_back(c10::Device(device_string.toStringRef()));
    });
    stack_.emplace_back(int64_t(globals_.size() - 1));
    return;
  } else if (module_name == "torch.distributed.rpc" && class_name == "rref") {
#ifdef USE_RPC
    return rebuildRRef();
#else
    TORCH_INTERNAL_ASSERT(
        false,
        "RRef unpickling is only supported with the distributed package");
#endif
  } else if (module_name == "torch") {
    // Try to manually resolve several global enums
    // NOTE: this does not put a global into the global table,
    // like the other branches here because no REDUCE or BUILD will
    // be called on this value. Instead, we just put it on the stack
    // and return early
    c10::optional<c10::ScalarType> scalar_type;
#define CHECK_SCALAR(_, name)          \
  if (class_name == #name "Storage") { \
    scalar_type = c10::k##name;        \
  }
    AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(CHECK_SCALAR)
#undef CHECK_SCALAR
    if (scalar_type.has_value()) {
      stack_.emplace_back(int64_t(*scalar_type));
      return;
    }

    c10::optional<at::QScheme> qscheme;
    for (int i = 0; i < at::COMPILE_TIME_NUM_QSCHEMES; ++i) {
      if (class_name == toString(static_cast<at::QScheme>(i))) {
        qscheme = static_cast<at::QScheme>(i);
      }
    }
    if (qscheme.has_value()) {
      stack_.emplace_back(int64_t(*qscheme));
      return;
    }
    TORCH_CHECK(
        false,
        "Unpickler found unknown torch global, 'torch.",
        class_name,
        "'");
  } else {
    TORCH_CHECK(
        type_resolver_,
        "Unpickler found unknown type ",
        module_name,
        ".",
        class_name);
    at::StrongTypePtr type =
        type_resolver_(c10::QualifiedName(module_name, class_name));
    if (auto enum_type = type.type_->cast<c10::EnumType>()) {
      globals_.emplace_back([this, enum_type] {
        auto val = stack_.back();
        stack_.pop_back();
        for (const auto& p : enum_type->enumNamesValues()) {
          if (p.second == val) {
            auto enum_holder = c10::make_intrusive<at::ivalue::EnumHolder>(
                enum_type, p.first, p.second);
            stack_.emplace_back(std::move(enum_holder));
            return;
          }
        }
      });
    } else {
      // Otherwise, global is a class/object type.
      globals_.emplace_back([this, type] {
        auto val = stack_.back();
        stack_.pop_back();
        auto obj = obj_loader_(type, val);
        stack_.emplace_back(std::move(obj));
      });
    }
  }
  stack_.emplace_back(int64_t(globals_.size() - 1));
}

void Unpickler::rebuildSparseTensor() {
  globals_.emplace_back([this] {
    auto tup = pop(stack_).toTuple();
    const auto& elements = tup->elements();
    size_t idx = 0;
    auto layout = elements.at(idx++).toInt();
    at::Tensor result;
    switch (layout) {
      case static_cast<int>(c10::Layout::Sparse): {
        std::vector<int64_t> size = tupleToIntList(elements.at(idx++));
        bool requires_grad = elements.at(idx++).toBool();
        auto& indices_tensor = elements.at(idx++).toTensor();
        auto& values_tensor = elements.at(idx++).toTensor();
        auto options = values_tensor.options()
                           .layout(c10::Layout::Sparse)
                           .requires_grad(requires_grad);
        result = at::_sparse_coo_tensor_unsafe(
            indices_tensor, values_tensor, size, options);
        result = autograd::make_variable(result, options.requires_grad());
        break;
      }
      case static_cast<int>(c10::Layout::SparseCsr): {
        std::vector<int64_t> size = tupleToIntList(elements.at(idx++));
        bool requires_grad = elements.at(idx++).toBool();
        auto& crow_indices = elements.at(idx++).toTensor();
        auto& col_indices = elements.at(idx++).toTensor();
        auto& values_tensor = elements.at(idx++).toTensor();
        auto options = values_tensor.options()
                           .layout(c10::Layout::SparseCsr)
                           .requires_grad(requires_grad);
        result = at::_sparse_csr_tensor_unsafe(
            crow_indices, col_indices, values_tensor, size, options);
        result =
            autograd::make_variable(std::move(result), options.requires_grad());
        break;
      }
      default:
        TORCH_CHECK(
            false,
            "Unsupported sparse tensor layout type in serialization ",
            static_cast<c10::Layout>(layout));
        break;
    }
    stack_.emplace_back(std::move(result));
  });
}

void Unpickler::rebuildTensor(bool quantized) {
  globals_.emplace_back([this, quantized] {
    auto tup = pop(stack_).toTuple();
    const auto& elements = tup->elements();
    size_t idx = 0;
    auto& storage_tensor = elements.at(idx++).toTensor();
    int64_t storage_offset = elements.at(idx++).toInt();
    std::vector<int64_t> size = tupleToIntList(elements.at(idx++));
    std::vector<int64_t> stride = tupleToIntList(elements.at(idx++));
    at::Tensor result;
    if (quantized) {
      auto qparams_tuple = elements.at(idx++).toTuple();
      const auto& qparams = qparams_tuple->elements();
      auto qscheme = static_cast<at::QScheme>(qparams.at(0).toInt());
      switch (qscheme) {
        case at::kPerTensorAffine: {
          double q_scale = qparams.at(1).toDouble();
          int64_t q_zero_point = qparams.at(2).toInt();
          result = at::_empty_affine_quantized(
              {0}, storage_tensor.options(), q_scale, q_zero_point);
        } break;
        case at::kPerChannelAffineFloatQParams:
        case at::kPerChannelAffine: {
          const auto& scales = qparams.at(1).toTensor();
          const auto& zero_points = qparams.at(2).toTensor();
          int64_t axis = qparams.at(3).toInt();
          result = at::_empty_per_channel_affine_quantized(
              {0}, scales, zero_points, axis, storage_tensor.options());
        } break;
        default:
          TORCH_CHECK(
              false,
              "Unsupported tensor quantization type in serialization ",
              toString(qscheme));
          break;
      }
    } else {
      result = at::empty({0}, storage_tensor.options());
    }
    bool requires_grad = elements.at(idx++).toBool();
    idx++; // backwards hooks is empty
    at::TensorImpl* impl = result.unsafeGetTensorImpl();
    impl->set_storage_keep_dtype(storage_tensor.storage());
    impl->set_storage_offset(storage_offset);
    impl->set_sizes_and_strides(size, stride);
    result = autograd::make_variable(result, requires_grad);

    // Handle if math_bits were pickled.
    // See `args` of _reduce_ex_internal
    // for a regular tensor (final else case).
    // Tensors pickled before this patch didn't
    // have this argument for storing MathBits,
    // in that case, we do nothing.
    // NOTE: `math_bits` is the 7th arg.
    // NOTE: This is only meant for regular tensor and not quantized
    // which also has 7 args serialized.
    if (!quantized && elements.size() == 7) {
      auto math_bits = elements.at(idx++).toGenericDict();
      torch::jit::setTensorMetadata(result, math_bits);
    }

    stack_.emplace_back(std::move(result));
  });
}

void Unpickler::rebuildTensorFromTypeV2() {
  // [NOTE] skip_next_read_global
  // When rebuilding Tensor with Python Attr or Subclassed Tensor,
  // we receive `(func, type(self), args, state)` on stack for
  // `rebuildTensorFromTypeV2`.
  // Thus next call to readGlobal corresponds to `func` which is
  // the function to rebuild the base tensor.
  // The call after `func` to readGlobal corresponds to `type` of the
  // Tensor where we raise warning if the type is not `torch.Tensor`.
  this->skip_next_read_global = 2;
  auto curr_globals_idx = globals_.size();
  globals_.emplace_back([this, curr_globals_idx] {
    // args is a tuple with following data
    // (function to rebuild base tensor, type of tensor,
    // arguments to construct base tensor, Python State (as dict))
    auto args = pop(stack_).toTuple();
    size_t tup_idx = 0;
    const auto args_elems = args->elements();
    auto base_tensor_args = args_elems.at(tup_idx + 2).toTuple();
    auto py_state = args_elems.at(tup_idx + 3).toGenericDict();
    if (!py_state.empty()) {
      TORCH_WARN(
          "Loading Tensor with Python attributes will return at::Tensor with Python attributes being discarded");
    }
    // This calls the function to rebuild the
    // base tensor.
    // Eg. `rebuildTensor`, `rebuildSparseTensor`.
    stack_.emplace_back(base_tensor_args);
    globals_[curr_globals_idx + 1]();
    stack_.emplace_back(pop(stack_));
  });
}

#ifdef USE_RPC
void Unpickler::rebuildRRef() {
  globals_.emplace_back([this] {
    // It is the same as how rref is unpickled in python,
    // see PyRRef::unpickle
    auto tuple = std::move(stack_.back()).toTuple();
    const auto& args = tuple->elements();
    stack_.pop_back();
    TORCH_INTERNAL_ASSERT(
        args.size() == distributed::rpc::RFD_TUPLE_SIZE,
        "Pickled RRefForkData must contain 7 numbers.");
    auto ownerId =
        static_cast<int16_t>(args.at(distributed::rpc::OWNER_IDX).toInt());
    // const reference will extend the lifetime of the temporary variable
    const auto& rrefId = distributed::rpc::RRefId(
        static_cast<int16_t>(args.at(distributed::rpc::RREFID_ON_IDX).toInt()),
        static_cast<int64_t>(args.at(distributed::rpc::RREFID_ID_IDX).toInt()));
    const auto& forkId = distributed::rpc::RRefId(
        static_cast<int16_t>(args.at(distributed::rpc::FORKID_ON_IDX).toInt()),
        static_cast<int64_t>(args.at(distributed::rpc::FORKID_ID_IDX).toInt()));
    auto parent =
        static_cast<int16_t>(args.at(distributed::rpc::PARENT_IDX).toInt());
    const auto& typeStr = static_cast<std::string>(
        args.at(distributed::rpc::TYPE_IDX).toStringRef());
    auto rrefForkData = distributed::rpc::RRefForkData(
        ownerId, rrefId, forkId, parent, typeStr);
    auto& ctx = distributed::rpc::RRefContext::getInstance();
    c10::intrusive_ptr<distributed::rpc::RRef> rref;
    TORCH_INTERNAL_ASSERT(
        type_resolver_ != nullptr, "type_resolver_ is nullptr.");
    at::StrongTypePtr type = type_resolver_(c10::QualifiedName(typeStr));
    rref = ctx.getOrCreateRRef(rrefForkData, type.type_);
    ctx.notifyOwnerAndParentOfFork(
        rrefForkData.forkId_, rrefForkData.parent_, rref);
    stack_.emplace_back(
        c10::static_intrusive_pointer_cast<c10::RRefInterface>(rref));
  });
  stack_.emplace_back(int64_t(globals_.size() - 1));
  return;
}
#endif

void Unpickler::readSlowWithBuffer(char* dest, size_t sz) {
  // First, read any partial from buffer (may be 0).
  // We explicitly assume that sz > buffer_remaining_,
  // and that sz is never bigger than buffer_.size().
  AT_ASSERT(sz > buffer_remaining_);
  const size_t from_old_buf = buffer_remaining_;
  if (from_old_buf != 0) {
    memcpy(dest, buffer_.data() + buffer_pos_, from_old_buf);
  }
  const size_t needed = sz - from_old_buf;
  // Full read into the buffer. The calls here all explicitly
  // assume that one buffer will be enough for any sz.
  AT_ASSERT(sz <= buffer_.size());
  buffer_remaining_ = reader_(buffer_.data(), buffer_.size());
  if (buffer_remaining_ < needed) {
    AT_ERROR("Unexpected end of pickler archive.");
  }
  memcpy(dest + from_old_buf, buffer_.data(), needed);
  buffer_pos_ = needed; // assignment (0'ed from read)
  buffer_remaining_ -= needed;
}

// Read a number of bytes from the input stream
std::string Unpickler::readBytes(size_t length) {
  std::string data;
  static const size_t kSmallString = 64;
  if (length <= buffer_remaining_) {
    // Fast-path: entirely in buffer.
    data.assign(buffer_.data() + buffer_pos_, length);
    buffer_pos_ += length;
    buffer_remaining_ -= length;
  } else if (length <= kSmallString) {
    // If the string is smallish, do a full buffer read,
    // and read out of that buffer.
    data.resize(length);
    readSlowWithBuffer(&data[0], length);
  } else {
    // Otherwise, for larger strings, read what we can from
    // the buffer, and then read directly to the destination.
    const size_t from_old_buf = buffer_remaining_;
    if (from_old_buf != 0) {
      data.reserve(length);
      data.append(buffer_.data() + buffer_pos_, from_old_buf);
    }
    data.resize(length);
    const size_t needed = length - from_old_buf;
    size_t nread = reader_(&data[from_old_buf], needed);
    if (nread != needed) {
      AT_ERROR("Unexpected end of pickler archive.");
    }
    buffer_remaining_ = 0;
    // buffer_pos_ has no meaning with buffer_remaining_ == 0.
  }
  return data;
}

// Pop all the list items off of the stack and append them to the list at
// the corresponding MARK
void Unpickler::readList(IValue list_ivalue) {
  TORCH_CHECK(!marks_.empty(), "Parsing error: marks_ is empty");
  size_t start = marks_.back();
  marks_.pop_back();
  auto num_elements = stack_.size() - start;
  auto elements = c10::ArrayRef<IValue>(stack_).slice(start);
  if (list_ivalue.isIntList()) {
    auto list = std::move(list_ivalue).toIntList();
    list.reserve(num_elements);
    for (const auto& elem : elements) {
      list.emplace_back(elem.toInt());
    }
  } else if (list_ivalue.isTensorList()) {
    auto list = std::move(list_ivalue).toTensorList();
    list.reserve(num_elements);
    for (const auto& elem : elements) {
      list.emplace_back(elem.toTensor());
    }
  } else if (list_ivalue.isDoubleList()) {
    auto list = std::move(list_ivalue).toDoubleList();
    list.reserve(num_elements);
    for (const auto& elem : elements) {
      list.emplace_back(elem.toDouble());
    }
  } else if (list_ivalue.isBoolList()) {
    auto list = std::move(list_ivalue).toBoolList();
    list.reserve(num_elements);
    for (const auto& elem : elements) {
      list.push_back(elem.toBool());
    }
  } else if (list_ivalue.isList()) {
    auto list = std::move(list_ivalue).toList();
    list.reserve(num_elements);
    for (const auto& elem : elements) {
      list.emplace_back(elem);
    }
  } else {
    AT_ERROR("Unknown IValue list kind: ", list_ivalue.tagKind());
  }

  stack_.erase(stack_.begin() + start, stack_.end());
}

inline bool is_valid_python_id_char(char c) {
  return c == '_' || c == '.' || (c >= '0' && c <= '9') ||
      (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z');
}

// Read a newline terminated string
std::string Unpickler::readString() {
  std::string ss;
  while (true) {
    auto* const bufferStart = buffer_.data() + buffer_pos_;
    const auto bufferLeft = buffer_.size() - buffer_pos_;
    char* const newlinePtr =
        static_cast<char*>(memchr(bufferStart, '\n', bufferLeft));
    if (newlinePtr) {
      // read up to newline and we are done.
      auto const charsRead = newlinePtr - bufferStart;
      ss.append(bufferStart, charsRead);
      buffer_remaining_ -= charsRead + 1;
      buffer_pos_ += charsRead + 1;
      break;
    } else {
      // read whole buffer, refill
      for (const char* p = bufferStart; p < bufferStart + bufferLeft; ++p) {
        // Simple check just in case there is no terminating '\n'
        TORCH_CHECK(
            is_valid_python_id_char(*p),
            "Found character '",
            int(uint8_t(*p)),
            "' in string, ",
            "strings must be qualified Python identifiers");
      }
      ss.append(bufferStart, bufferLeft);
      buffer_remaining_ = reader_(buffer_.data(), buffer_.size());
      buffer_pos_ = 0;
    }
  }
  return ss;
}

} // namespace torch::jit