mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Hi! I've been fuzzing different pytorch modules, and found a few crashes. Proposed checks fixes multiple segmentation faults and heap buffer overflows that was found during fuzzing pytorch with [sydr-fuzz](https://github.com/ispras/oss-sydr-fuzz/tree/master/projects/pytorch). ### Crash files ### 1) Heap buffer overflow that leads to crash [crash-842314913bf1820ec19cddfbb7400ffdbb756920.zip](https://github.com/pytorch/pytorch/files/9461316/crash-842314913bf1820ec19cddfbb7400ffdbb756920.zip) ``` "AsanReport": [ "==3751==ERROR: AddressSanitizer: heap-buffer-overflow on address 0x619000033478 at pc 0x0000005f9bc3 bp 0x7fffffff1eb0 sp 0x7fffffff1ea8\n", "READ of size 4 at 0x619000033478 thread T0\n", "[Detaching after fork from child process 3762]\n", " #0 0x5f9bc2 in c10::IValue::IValue(c10::IValue&&) /pytorch_fuzz/aten/src/ATen/core/ivalue.h:192:43\n", " #1 0x9ecd0a7 in torch::jit::pop(std::vector<c10::IValue, std::allocator<c10::IValue> >&) /pytorch_fuzz/aten/src/ATen/core/stack.h:102:12\n", " #2 0x9ecd0a7 in torch::jit::Unpickler::readInstruction() /pytorch_fuzz/torch/csrc/jit/serialization/unpickler.cpp:380:17\n", " #3 0x9ecafc7 in torch::jit::Unpickler::run() /pytorch_fuzz/torch/csrc/jit/serialization/unpickler.cpp:226:27\n", " #4 0x9ecac62 in torch::jit::Unpickler::parse_ivalue() /pytorch_fuzz/torch/csrc/jit/serialization/unpickler.cpp:183:3\n", " #5 0x9e45996 in torch::jit::unpickle(std::function<unsigned long (char*, unsigned long)>, std::function<c10::StrongTypePtr (c10::QualifiedName const&)>, c10::ArrayRef<at::Tensor>, c10::Type::SingletonOrSharedTypePtr<c10::Type> (*)(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&)) /pytorch_fuzz/torch/csrc/jit/serialization/pickle.cpp:127:20\n", " #6 0x9e4626d in torch::jit::unpickle(char const*, unsigned long, std::function<c10::StrongTypePtr (c10::QualifiedName const&)>, c10::ArrayRef<at::Tensor>, c10::Type::SingletonOrSharedTypePtr<c10::Type> (*)(std::__cxx11::basic_string<char, 
std::char_traits<char>, std::allocator<char> > const&)) /pytorch_fuzz/torch/csrc/jit/serialization/pickle.cpp:137:10\n", ``` 2) Segmentation fault [crash-e690c58718e88921350562f0b4d9180938145d77.zip](https://github.com/pytorch/pytorch/files/9461331/crash-e690c58718e88921350562f0b4d9180938145d77.zip) ``` "AsanReport": [ "==3744==ERROR: AddressSanitizer: SEGV on unknown address (pc 0x000009122754 bp 0x7fffffff5290 sp 0x7fffffff5270 T0)\n", "==3744==The signal is caused by a READ memory access.\n", "==3744==Hint: this fault was caused by a dereference of a high value address (see register values below). Disassemble the provided pc to learn which register was used.\n", "[Detaching after fork from child process 3763]\n", " #0 0x9122754 in c10::intrusive_ptr<torch::jit::Tree, c10::detail::intrusive_target_default_null_type<torch::jit::Tree> >::retain_() /pytorch_fuzz/c10/util/intrusive_ptr.h:269:54\n", " #1 0x9127929 in c10::intrusive_ptr<torch::jit::Tree, c10::detail::intrusive_target_default_null_type<torch::jit::Tree> >::intrusive_ptr(c10::intrusive_ptr<torch::jit::Tree, c10::detail::intrusive_target_default_null_type<torch::jit::Tree> > const&) /pytorch_fuzz/c10/util/intrusive_ptr.h:352:5\n", " #2 0x9127929 in torch::jit::Expr::Expr(c10::intrusive_ptr<torch::jit::Tree, c10::detail::intrusive_target_default_null_type<torch::jit::Tree> > const&) /pytorch_fuzz/torch/csrc/jit/frontend/tree_views.h:269:49\n", " #3 0x91b1bbb in torch::jit::Maybe<torch::jit::Expr>::get() const /pytorch_fuzz/torch/csrc/jit/frontend/tree_views.h:211:12\n", " #4 0x92a8f74 in torch::jit::ScriptTypeParser::parseClassConstant(torch::jit::Assign const&) /pytorch_fuzz/torch/csrc/jit/frontend/script_type_parser.cpp:461:41\n", " #5 0x9e1c09b in torch::jit::SourceImporterImpl::importClass(c10::QualifiedName const&, torch::jit::ClassDef const&, bool) /pytorch_fuzz/torch/csrc/jit/serialization/import_source.cpp:549:34\n", " #6 0x9e13f00 in 
torch::jit::SourceImporterImpl::importNamedType(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, torch::jit::ClassDef const&) /pytorch_fuzz/torch/csrc/jit/serialization/import_source.cpp:288:5\n", " #7 0x9e11fbc in torch::jit::SourceImporterImpl::findNamedType(c10::QualifiedName const&) /pytorch_fuzz/torch/csrc/jit/serialization/import_source.cpp:140:5\n", ``` 3) Unhandled out of bounds access in a vector [crash-ccd524e7ba19a37982dd91e0d6fc06bb26dd0b10.zip](https://github.com/pytorch/pytorch/files/9461367/crash-ccd524e7ba19a37982dd91e0d6fc06bb26dd0b10.zip) ``` "AsanReport": [ "==3792== ERROR: libFuzzer: deadly signal\n", "[Detaching after fork from child process 3809]\n", " #0 0x59cc11 in __sanitizer_print_stack_trace /llvm-project/compiler-rt/lib/asan/asan_stack.cpp:87:3\n", " #1 0x511547 in fuzzer::PrintStackTrace() /llvm-project/compiler-rt/lib/fuzzer/FuzzerUtil.cpp:210:5\n", " #2 0x4f7753 in fuzzer::Fuzzer::CrashCallback() /llvm-project/compiler-rt/lib/fuzzer/FuzzerLoop.cpp:233:3\n", " #3 0x7ffff7c6741f (/lib/x86_64-linux-gnu/libpthread.so.0+0x1441f)\n", " #4 0x7ffff7a8700a in __libc_signal_restore_set /build/glibc-SzIz7B/glibc-2.31/signal/../sysdeps/unix/sysv/linux/internal-signals.h:86:3\n", " #5 0x7ffff7a8700a in raise /build/glibc-SzIz7B/glibc-2.31/signal/../sysdeps/unix/sysv/linux/raise.c:48:3\n", " #6 0x7ffff7a66858 in abort /build/glibc-SzIz7B/glibc-2.31/stdlib/abort.c:79:7\n", " #7 0x7ffff7e73910 (/lib/x86_64-linux-gnu/libstdc++.so.6+0x9e910)\n", " #8 0x7ffff7e7f38b (/lib/x86_64-linux-gnu/libstdc++.so.6+0xaa38b)\n", " #9 0x7ffff7e7f3f6 in std::terminate() (/lib/x86_64-linux-gnu/libstdc++.so.6+0xaa3f6)\n", " #10 0x7ffff7e7f6a8 in __cxa_throw (/lib/x86_64-linux-gnu/libstdc++.so.6+0xaa6a8)\n", " #11 0x7ffff7e763aa (/lib/x86_64-linux-gnu/libstdc++.so.6+0xa13aa)\n", " #12 0x6aeedf in std::vector<c10::IValue, std::allocator<c10::IValue> >::_M_range_check(unsigned long) const 
/usr/bin/../lib/gcc/x86_64-linux-gnu/10/../../../../include/c++/10/bits/stl_vector.h:1073:4\n", " #13 0x9ecd66c in torch::jit::Unpickler::readInstruction() /pytorch_fuzz/torch/csrc/jit/serialization/unpickler.cpp\n", " #14 0x9ecafc7 in torch::jit::Unpickler::run() /pytorch_fuzz/torch/csrc/jit/serialization/unpickler.cpp:226:27\n", " #15 0x9ecac62 in torch::jit::Unpickler::parse_ivalue() /pytorch_fuzz/torch/csrc/jit/serialization/unpickler.cpp:183:3\n", ``` Some other crashes found by fuzzer: [crash-0cab888cbd1e9fea92ab6ddeadf40b958b87d62b.zip](https://github.com/pytorch/pytorch/files/9461406/crash-0cab888cbd1e9fea92ab6ddeadf40b958b87d62b.zip) [crash-04c9ba8e3b0f15028fd0fb0ed014fd352e182a1d.zip](https://github.com/pytorch/pytorch/files/9461407/crash-04c9ba8e3b0f15028fd0fb0ed014fd352e182a1d.zip) [crash-422ad8c3a3472980ba751f4c7f79cf2b53e49927.zip](https://github.com/pytorch/pytorch/files/9461408/crash-422ad8c3a3472980ba751f4c7f79cf2b53e49927.zip) ### How to reproduce ### 1. To reproduce the crashes, use provided docker: [Dockerfile](https://github.com/ispras/oss-sydr-fuzz/blob/master/projects/pytorch/Dockerfile) 2. Build the container: `docker build -t oss-sydr-fuzz-pytorch-reproduce .` 3. Copy crash file to the current directory 4. Run the container: `` docker run --privileged --network host -v `pwd`:/homedir --rm -it oss-sydr-fuzz-pytorch-reproduce /bin/bash `` 5. And execute fuzz-targets with provided crash-files. After execution completes you will see ASAN reports. Pull Request resolved: https://github.com/pytorch/pytorch/pull/94815 Approved by: https://github.com/davidberard98
1148 lines · 41 KiB · C++
#include <ATen/ATen.h>
|
|
#include <ATen/core/Dict.h>
|
|
#ifdef USE_RPC
|
|
#include <torch/csrc/distributed/rpc/rref_context.h>
|
|
#endif
|
|
#include <torch/csrc/jit/api/function_impl.h>
|
|
#include <torch/csrc/jit/mobile/type_parser.h>
|
|
#include <torch/csrc/jit/serialization/pickler.h>
|
|
#include <torch/csrc/jit/serialization/storage_context.h>
|
|
#include <torch/csrc/jit/serialization/unpickler.h>
|
|
#include <string>
|
|
|
|
namespace torch::jit {
|
|
|
|
using ::c10::IValue;
|
|
|
|
static void restoreAccurateTypeTagsIfPossible(const IValue& root) {
|
|
if (root.isObject()) {
|
|
restoreAccurateTypeTags(root, root.type());
|
|
}
|
|
}
|
|
|
|
// Pickled objects are stored in a form compatible with Python pickling.
|
|
// In torchscript List[T]/Dict[K, V] are statically typed and contain
|
|
// dynamic type tags that allow T, K, and V to be recovered. But this
|
|
// info is not stored in the Python pickling information. However, we
|
|
// can recover this information from the static type of the top-level
|
|
// object being unpickled, because we have a record of the type of the
|
|
// objects it contains as attributes.
|
|
// `IfPossible` - we can only do this recovery when we have an object as
|
|
// the top-level unpickled thing (which is guaranteed for Modules, but
|
|
// not for torch.load/torch.save). Otherwise we do not know the types
|
|
// of the contained objects and cannot restore the tags.
|
|
void restoreAccurateTypeTags(const IValue& root, const TypePtr& type_tag) {
|
|
struct Work {
|
|
TypePtr type;
|
|
IValue value;
|
|
};
|
|
std::vector<Work> to_process = {{type_tag, root}};
|
|
std::unordered_set<const void*> scanned;
|
|
while (!to_process.empty()) {
|
|
Work w = std::move(to_process.back());
|
|
to_process.pop_back();
|
|
// ensure we only scan each pointer value once, otherwise this
|
|
// can become exponential (and if we allow recursive data in the future,
|
|
// it would not terminiate).
|
|
if (w.value.isPtrType()) {
|
|
const void* key = w.value.internalToPointer();
|
|
auto it = scanned.find(key);
|
|
if (it != scanned.end()) {
|
|
continue;
|
|
}
|
|
scanned.emplace_hint(it, key);
|
|
}
|
|
auto kind = w.type->kind();
|
|
if (auto dyn = w.type->castRaw<c10::DynamicType>()) {
|
|
kind = dyn->dynamicKind();
|
|
}
|
|
switch (kind) {
|
|
case TensorType::Kind:
|
|
case StorageType::Kind:
|
|
case NumberType::Kind:
|
|
case FloatType::Kind:
|
|
case ComplexType::Kind:
|
|
case IntType::Kind:
|
|
case NoneType::Kind:
|
|
case GeneratorType::Kind:
|
|
case QuantizerType::Kind:
|
|
case BoolType::Kind:
|
|
case VarType::Kind:
|
|
case CapsuleType::Kind:
|
|
case PyObjectType::Kind:
|
|
case StringType::Kind:
|
|
case FunctionType::Kind:
|
|
case DeviceObjType::Kind:
|
|
case StreamObjType::Kind:
|
|
case QSchemeType::Kind:
|
|
case LayoutType::Kind:
|
|
case MemoryFormatType::Kind:
|
|
case ScalarTypeType::Kind:
|
|
case RRefType::Kind:
|
|
case AnyType::Kind:
|
|
case AnyListType::Kind:
|
|
case AnyTupleType::Kind:
|
|
case AnyClassType::Kind:
|
|
case AnyEnumType::Kind:
|
|
// no op, there is nothing to tag
|
|
break;
|
|
case c10::SymIntType::Kind:
|
|
TORCH_CHECK(!w.value.toSymInt().is_symbolic());
|
|
// no op, there is nothing to tag
|
|
break;
|
|
case c10::SymFloatType::Kind:
|
|
TORCH_CHECK(!w.value.toSymFloat().is_symbolic());
|
|
// no op, there is nothing to tag
|
|
break;
|
|
case DynamicType::Kind:
|
|
case UnionType::Kind:
|
|
case EnumType::Kind:
|
|
// TODO(gmagogsfm): Implement serialization/deserialization of Enum.
|
|
TORCH_INTERNAL_ASSERT(false);
|
|
case TupleType::Kind: {
|
|
auto t = w.value.toTuple();
|
|
for (size_t i = 0; i < w.type->containedTypeSize(); ++i) {
|
|
Work elem = {w.type->containedType(i), t->elements().at(i)};
|
|
to_process.emplace_back(std::move(elem));
|
|
}
|
|
} break;
|
|
case FutureType::Kind: {
|
|
auto f = w.value.toFuture();
|
|
if (f->completed()) {
|
|
Work elem = {w.type->containedType(0), f->value()};
|
|
to_process.emplace_back(std::move(elem));
|
|
}
|
|
} break;
|
|
case AwaitType::Kind: {
|
|
auto aw = w.value.toAwait();
|
|
if (aw->completed()) {
|
|
Work elem = {w.type->containedType(0), aw->wait()};
|
|
to_process.emplace_back(std::move(elem));
|
|
}
|
|
} break;
|
|
case OptionalType::Kind: {
|
|
if (!w.value.isNone()) {
|
|
Work elem = {w.type->containedType(0), w.value};
|
|
to_process.emplace_back(std::move(elem));
|
|
}
|
|
} break;
|
|
case ListType::Kind: {
|
|
// specialized lists do not need their type refined, so we can exit
|
|
// early here
|
|
if (!w.value.isList()) {
|
|
break;
|
|
}
|
|
auto elem_type = w.type->containedType(0);
|
|
auto lst = w.value.toList();
|
|
lst.unsafeSetElementType(elem_type);
|
|
for (const IValue& item : lst) {
|
|
Work elem = {elem_type, item};
|
|
to_process.emplace_back(std::move(elem));
|
|
}
|
|
} break;
|
|
case DictType::Kind: {
|
|
auto d = w.value.toGenericDict();
|
|
auto keyType = w.type->containedType(0);
|
|
auto valType = w.type->containedType(1);
|
|
d.unsafeSetKeyType(keyType);
|
|
d.unsafeSetValueType(valType);
|
|
for (const auto& item : d) {
|
|
Work kelem = {keyType, item.key()};
|
|
Work velem = {valType, item.value()};
|
|
to_process.emplace_back(std::move(kelem));
|
|
to_process.emplace_back(std::move(velem));
|
|
}
|
|
} break;
|
|
// in both cases the dynamic type is a class, and we are going to tag with
|
|
// the dynamic type
|
|
case InterfaceType::Kind:
|
|
case ClassType::Kind: {
|
|
auto obj = w.value.toObject();
|
|
auto typ = obj->type(); // note: intentionally using the dynamic type,
|
|
// the static type is potentially less accurate
|
|
for (size_t i = 0; i < typ->numAttributes(); ++i) {
|
|
Work elem = {typ->getAttribute(i), obj->getSlot(i)};
|
|
to_process.emplace_back(std::move(elem));
|
|
}
|
|
};
|
|
}
|
|
}
|
|
}
|
|
|
|
namespace {
|
|
template <typename T>
|
|
bool is(const Type& type) {
|
|
if (type.kind() == T::Kind) {
|
|
return true;
|
|
}
|
|
if (auto dyn = type.castRaw<c10::DynamicType>()) {
|
|
return dyn->tag() == c10::DynamicTypeTrait<T>::tagValue();
|
|
}
|
|
return false;
|
|
}
|
|
} // namespace
|
|
|
|
void restoreContainerTypeTags(const IValue& ivalue, const TypePtr& type) {
|
|
if (is<DictType>(*type)) {
|
|
auto dict = ivalue.toGenericDict();
|
|
dict.unsafeSetKeyType(type->containedType(0));
|
|
dict.unsafeSetValueType(type->containedType(1));
|
|
} else if (is<ListType>(*type)) {
|
|
ivalue.toList().unsafeSetElementType(type->containedType(0));
|
|
} else {
|
|
AT_ERROR("Unknown type for tag restoration: " + type->annotation_str());
|
|
}
|
|
}
|
|
|
|
// Runs the unpickler to completion and returns the single resulting value.
// For archives written with version <= 2 the container type tags were not
// serialized, so they are reconstructed here when possible.
IValue Unpickler::parse_ivalue() {
  run();
  TORCH_CHECK(
      stack_.size() == 1,
      "Unpickler expected 1 element on the stack, but found ",
      stack_.size());
  const bool needs_tag_restore = version_ <= 2;
  if (needs_tag_restore) {
    // See [type tag serialization]
    restoreAccurateTypeTagsIfPossible(stack_[0]);
  }
  return stack_[0];
}
|
|
|
|
double Unpickler::readFloat() {
|
|
AT_ASSERT(sizeof(double) == 8);
|
|
double big_endian = read<double>();
|
|
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
|
|
double little_endian;
|
|
|
|
// Pickle floats are big endian, so reverse the bytes
|
|
auto big_endian_ptr = reinterpret_cast<const char*>(&big_endian);
|
|
std::reverse_copy(
|
|
big_endian_ptr,
|
|
big_endian_ptr + sizeof(big_endian),
|
|
reinterpret_cast<char*>(&little_endian));
|
|
|
|
return little_endian;
|
|
}
|
|
|
|
void Unpickler::run() {
|
|
// Expect a PROTO opcode and protocol number at the start of blob
|
|
auto opcode = readOpCode();
|
|
TORCH_CHECK(
|
|
opcode == PickleOpCode::PROTO,
|
|
"Expected PROTO opcode at the start"
|
|
" of pickle archive, found ",
|
|
int(static_cast<uint8_t>(opcode)));
|
|
uint8_t protocol = read<uint8_t>();
|
|
TORCH_CHECK(
|
|
protocol == 2,
|
|
"Only Pickle protocol 2 is supported, found protocol = ",
|
|
protocol);
|
|
|
|
while (true) {
|
|
PickleOpCode opcode = readInstruction();
|
|
if (opcode == PickleOpCode::STOP) {
|
|
return;
|
|
}
|
|
}
|
|
}
|
|
void Unpickler::setInput(size_t memo_id) {
|
|
AT_ASSERT(!stack_.empty());
|
|
if (memo_id >= memo_table_.size()) {
|
|
memo_table_.insert(
|
|
memo_table_.end(), memo_id - memo_table_.size(), IValue());
|
|
memo_table_.push_back(stack_.back());
|
|
} else {
|
|
memo_table_[memo_id] = stack_.back();
|
|
}
|
|
}
|
|
|
|
// emplace_back is not available on std::vector<bool> on some standard
// library implementations, so the bool specialization uses push_back.
template <typename T>
inline void append(std::vector<T>& vec, T&& value) {
  vec.emplace_back(std::forward<T>(value));
}
template <>
inline void append<bool>(std::vector<bool>& vec, bool&& value) {
  vec.push_back(value);
}
|
|
|
|
static std::vector<int64_t> tupleToIntList(const IValue& v) {
|
|
return fmap(v.toTupleRef().elements(), [](const IValue& v) -> int64_t {
|
|
return v.toInt();
|
|
});
|
|
}
|
|
|
|
// note we cannot use toIntList, toDoubleList because during unpickling the
|
|
// lists are not yet tagged
|
|
template <typename T>
|
|
static std::vector<T> convertList(const IValue& v) {
|
|
return fmap(v.toListRef(), [](const IValue& elem) { return elem.to<T>(); });
|
|
}
|
|
|
|
// Reads and executes a single pickle opcode, returning it so the caller
// (run()) can stop at STOP.
//
// Because the input is untrusted, every opcode that derives stack indices
// from archive data must validate them before use. The TORCH_CHECKs added
// below (TUPLE start bound, DICT/SETITEMS start + even-size bounds,
// SETITEM minimum stack depth) close size_t-underflow / out-of-bounds
// reads reachable with corrupt archives (fuzzer-found heap-buffer-overflow
// and OOB vector access in this function).
PickleOpCode Unpickler::readInstruction() {
  auto opcode = readOpCode();
  switch (opcode) {
    case PickleOpCode::EMPTY_LIST: {
      stack_.emplace_back(c10::impl::GenericList(AnyType::get()));
    } break;
    case PickleOpCode::EMPTY_TUPLE: {
      if (empty_tuple_.isNone()) {
        // we only need one object, since tuples are not mutable.
        empty_tuple_ = c10::ivalue::Tuple::create(std::vector<IValue>());
      }
      stack_.emplace_back(empty_tuple_);
    } break;
    case PickleOpCode::BINPUT: {
      size_t memo_id = read<uint8_t>();
      setInput(memo_id);
    } break;
    case PickleOpCode::LONG_BINPUT: {
      TORCH_CHECK(
          std::numeric_limits<size_t>::max() >=
              std::numeric_limits<uint32_t>::max(),
          "Found a LONG_BINPUT opcode, but size_t on this system is "
          "not big enough to decode it");
      size_t memo_id = read<uint32_t>();
      setInput(memo_id);
    } break;
    case PickleOpCode::MARK: {
      // Mark location of the container ivalue in the stack
      marks_.push_back(stack_.size());
    } break;
    case PickleOpCode::NEWTRUE: {
      stack_.emplace_back(true);
    } break;
    case PickleOpCode::NEWFALSE: {
      stack_.emplace_back(false);
    } break;
    case PickleOpCode::NONE: {
      stack_.emplace_back();
    } break;
    case PickleOpCode::BININT1: {
      uint8_t value = read<uint8_t>();
      stack_.emplace_back(int64_t(value));
    } break;
    case PickleOpCode::BININT2: {
      uint16_t value = read<uint16_t>();
      stack_.emplace_back(int64_t(value));
    } break;
    case PickleOpCode::BININT: {
      int32_t value = read<int32_t>();
      stack_.emplace_back(int64_t(value));
    } break;
    case PickleOpCode::LONG1: {
      // Only read LONG1s with 8 as the length
      uint8_t length = read<uint8_t>();
      TORCH_CHECK(length == 8, "Expected length to be 8, got ", int(length));
      stack_.emplace_back(int64_t(read<int64_t>()));
    } break;
    case PickleOpCode::BINUNICODE: {
      uint32_t length = read<uint32_t>();
      stack_.emplace_back(readBytes(length));
    } break;
    case PickleOpCode::BINFLOAT:
      stack_.emplace_back(readFloat());
      break;
    case PickleOpCode::TUPLE: {
      TORCH_CHECK(!marks_.empty(), "Parsing error: marks_ is empty");
      size_t start = marks_.back();
      marks_.pop_back();
      // Guard against a corrupt MARK pointing past the end of the stack;
      // otherwise `stack_.size() - start` underflows below and the pops
      // read out of bounds.
      TORCH_CHECK(
          start <= stack_.size(),
          "Parsing error: wrong start index for stack_");
      std::vector<IValue> elements;
      const auto tupleSize = stack_.size() - start;
      switch (tupleSize) {
        case 3: {
          auto e3 = pop(stack_);
          auto e2 = pop(stack_);
          auto e1 = pop(stack_);
          stack_.emplace_back(c10::ivalue::Tuple::create(
              std::move(e1), std::move(e2), std::move(e3)));
          break;
        }
        case 2: {
          auto e2 = pop(stack_);
          auto e1 = pop(stack_);
          stack_.emplace_back(
              c10::ivalue::Tuple::create(std::move(e1), std::move(e2)));
          break;
        }
        case 1:
          stack_.emplace_back(c10::ivalue::Tuple::create(pop(stack_)));
          break;
        default: {
          elements.reserve(stack_.size() - start);
          auto start_it = stack_.begin() + start;
          for (auto it = start_it; it != stack_.end(); ++it) {
            elements.emplace_back(std::move(*it));
          }
          stack_.erase(start_it, stack_.end());
          stack_.emplace_back(c10::ivalue::Tuple::create(std::move(elements)));
          break;
        }
      }
    } break;
    case PickleOpCode::TUPLE1: {
      TORCH_CHECK(
          stack_.size() > 0,
          "Parsing error: stack_ contains ",
          stack_.size(),
          " elements, at least 1 expected");
      stack_.emplace_back(c10::ivalue::Tuple::create(pop(stack_)));
    } break;
    case PickleOpCode::TUPLE2: {
      TORCH_CHECK(
          stack_.size() > 1,
          "Parsing error: stack_ contains ",
          stack_.size(),
          " elements, at least 2 expected");
      auto e2 = pop(stack_);
      auto e1 = pop(stack_);
      stack_.emplace_back(
          c10::ivalue::Tuple::create(std::move(e1), std::move(e2)));
    } break;
    case PickleOpCode::TUPLE3: {
      TORCH_CHECK(
          stack_.size() > 2,
          "Parsing error: stack_ contains ",
          stack_.size(),
          " elements, at least 3 expected");
      auto e3 = pop(stack_);
      auto e2 = pop(stack_);
      auto e1 = pop(stack_);
      stack_.emplace_back(c10::ivalue::Tuple::create(
          std::move(e1), std::move(e2), std::move(e3)));
    } break;
    case PickleOpCode::EMPTY_DICT:
      stack_.emplace_back(
          c10::impl::GenericDict(AnyType::get(), AnyType::get()));
      break;
    case PickleOpCode::APPENDS: {
      TORCH_CHECK(!marks_.empty(), "Parsing error: marks_ is empty");
      size_t start = marks_.back();
      TORCH_CHECK(
          start > 0 && start <= stack_.size(),
          "Parsing error: wrong start index for stack_");
      auto list_ivalue = stack_.at(start - 1);
      readList(list_ivalue);
    } break;
    case PickleOpCode::LIST: {
      IValue list_ivalue = c10::impl::GenericList(AnyType::get());
      readList(list_ivalue);
      stack_.push_back(std::move(list_ivalue));
    } break;
    case PickleOpCode::DICT: {
      TORCH_CHECK(!marks_.empty(), "Parsing error: marks_ is empty");
      size_t start = marks_.back();
      marks_.pop_back();
      // The elements above the mark come in key/value pairs; validate both
      // the mark position and the pairing so `stack_[i + 1]` below cannot
      // read past the end of the stack.
      TORCH_CHECK(
          start <= stack_.size(),
          "Parsing error: wrong start index for stack_");
      TORCH_CHECK(
          (stack_.size() - start) % 2 == 0,
          "Parsing error: stack_ is of odd size");
      auto dict = c10::impl::GenericDict(AnyType::get(), AnyType::get());
      for (size_t i = start; i < stack_.size(); i += 2) {
        dict.insert_or_assign(stack_[i], stack_[i + 1]);
      }
      stack_.erase(stack_.begin() + start, stack_.end());
      stack_.emplace_back(std::move(dict));
    } break;
    case PickleOpCode::SETITEMS: {
      TORCH_CHECK(!marks_.empty(), "Parsing error: marks_ is empty");
      size_t start = marks_.back();
      marks_.pop_back();
      TORCH_CHECK(
          start > 0 && start <= stack_.size(),
          "Parsing error: wrong start index for stack_");
      // Same key/value pairing invariant as DICT: without it,
      // `stack_[i + 1]` can index one past the end.
      TORCH_CHECK(
          (stack_.size() - start) % 2 == 0,
          "Parsing error: stack_ is of odd size");
      auto dict = stack_.at(start - 1).toGenericDict();
      for (size_t i = start; i < stack_.size(); i += 2) {
        dict.insert_or_assign(stack_[i], stack_[i + 1]);
      }
      stack_.erase(stack_.begin() + start, stack_.end());
    } break;
    case PickleOpCode::BINGET: {
      auto pos = read<uint8_t>();
      TORCH_CHECK(
          memo_table_.size() > pos,
          "Parsing error: out of bounds access at ",
          (size_t)pos,
          " to memo_table_ which is of size ",
          memo_table_.size());
      stack_.push_back(memo_table_.at(pos));
    } break;
    case PickleOpCode::LONG_BINGET: {
      auto pos = read<uint32_t>();
      TORCH_CHECK(
          memo_table_.size() > pos,
          "Parsing error: out of bounds access at ",
          (size_t)pos,
          " to memo_table_ which is of size ",
          memo_table_.size());
      stack_.push_back(memo_table_.at(pos));
    } break;
    case PickleOpCode::STOP:
      break;
    case PickleOpCode::GLOBAL: {
      // Module name, it's not needed for anything
      auto module_name = readString();
      auto class_name = readString();
      readGlobal(module_name, class_name);
    } break;
    case PickleOpCode::NEWOBJ: {
      TORCH_CHECK(!stack_.empty(), "Parsing error: stack_ is empty");
      // pop empty tuple, the actual action is stored in the globals_stack_
      stack_.pop_back();
    } break;
    // because we have NEWOBJ do nothing, BUILD and REDUCE end up doing
    // the same thing
    case PickleOpCode::BUILD:
    case PickleOpCode::REDUCE: {
      // stack is: <functor_idx> <functor_arg>
      // extract <functor_idx> and remove from the stack:
      TORCH_CHECK(
          stack_.size() > 1,
          "Parsing error: stack_ contains ",
          stack_.size(),
          " elements, at least 2 expected");
      std::swap(*(stack_.end() - 2), *(stack_.end() - 1));
      size_t idx = stack_.back().toInt();
      stack_.pop_back();
      // stack is: <functor_arg>
      TORCH_CHECK(
          idx < globals_.size(),
          "Parsing error: out of bounds access to globals_");
      globals_.at(idx)();
    } break;
    case PickleOpCode::BINPERSID: {
      TORCH_CHECK(!stack_.empty(), "Parsing error: stack_ is empty");
      // PERSID payload is a tuple:
      // ("storage", scalar_type, key, device, numel)
      auto tuple = pop(stack_).toTuple();
      const auto& args = tuple->elements();
      AT_ASSERT(
          args.at(0).toStringRef() == "storage",
          "unknown PERSID key ",
          args.at(0).toStringRef());
      at::ScalarType type = args.at(1).toScalarType();
      const std::string& key = args.at(2).toStringRef();

      at::Device device(args.at(3).toStringRef());
      // explicit device override takes precedence over the serialized one
      if (device_) {
        device = *device_;
      }

      at::Storage storage;
      if (storage_context_ != nullptr && storage_context_->hasStorage(key)) {
        // for torch.package logic where storage may be loaded already
        storage = storage_context_->getStorage(key);
      } else {
        int64_t numel = args.at(4).toInt();
        caffe2::TypeMeta dtype = at::CPU(type).typeMeta();

        at::DataPtr storage_ptr;
        if (numel > 0) {
          // If there are no elements in the tensor, there's no point in
          // reading a zero (0) byte file from the input stream and paying
          // that cost.
          storage_ptr = read_record_(key);
        }

        storage = at::Storage(
            c10::Storage::use_byte_size_t(),
            numel * dtype.itemsize(),
            std::move(storage_ptr),
            /*allocator=*/nullptr,
            /*resizable=*/false); // NB: we didn't set any allocator for the
                                  // tensor
        if (storage_context_ != nullptr) {
          storage_context_->addStorage(key, storage);
        }
      }

      auto options = at::CPU(type).options();
      if (use_storage_device_) {
        options = options.device(storage.device());
        device = storage.device();
      }

      at::Tensor tensor;
      if (options.backend() == c10::Backend::QuantizedCPU) {
        tensor = at::_empty_affine_quantized({}, options, 0, 0)
                     .set_(storage, 0, {}, {});
      } else {
        tensor = at::empty({0}, options).set_(storage);
      }

      if (device.is_cuda() || device.is_xpu() || device.is_meta() ||
          device.is_hpu()) {
        tensor = tensor.to(device, tensor.scalar_type());
      } else if (device.type() != DeviceType::CPU) {
        AT_ERROR(
            "supported devices include CPU, CUDA and HPU, however got ",
            DeviceTypeName(device.type(), false));
      }
      stack_.emplace_back(std::move(tensor));
    } break;
    case PickleOpCode::SETITEM: {
      // At this OpCode, stack looks like
      // | Stack Bottom |
      // | ...... |
      // | Dict | -> (stack_size - 3)
      // | Key | -> (stack_size - 2)
      // | Value | -> (stack_size - 1)
      // Guard against underflow of `stack_size - 3` on short stacks
      // (fuzzer-found out-of-bounds access).
      TORCH_CHECK(
          stack_.size() >= 3,
          "Parsing error: stack_ contains ",
          stack_.size(),
          " elements, at least 3 expected");
      auto stack_size = stack_.size();
      auto dict_pos = stack_size - 3;
      auto key_pos = stack_size - 2;
      auto val_pos = stack_size - 1;
      auto dict = stack_.at(dict_pos).toGenericDict();
      dict.insert_or_assign(stack_.at(key_pos), stack_.at(val_pos));
      stack_.erase(stack_.begin() + (key_pos), stack_.end());
    } break;
    default: {
      AT_ERROR(
          "Unknown opcode for unpickling at ",
          reinterpret_cast<void*>(opcode),
          ": ",
          int(static_cast<uint8_t>(opcode)));
    } break;
  }
  return opcode;
}
|
|
|
|
void Unpickler::readGlobal(
|
|
const std::string& module_name,
|
|
const std::string& class_name) {
|
|
if (this->skip_next_read_global) {
|
|
// See [NOTE] skip_next_read_global
|
|
this->skip_next_read_global--;
|
|
if (this->skip_next_read_global == 1) {
|
|
// Pass through to the correct handler
|
|
} else if (this->skip_next_read_global == 0) {
|
|
// Corresponds to the type of `Tensor` being unpickled
|
|
if (module_name != "torch" || class_name != "Tensor") {
|
|
TORCH_WARN(
|
|
"Trying to load a Subclassed Tensor, it will be converted to at::Tensor in C++");
|
|
}
|
|
stack_.emplace_back(int64_t(globals_.size() - 1));
|
|
return;
|
|
} else {
|
|
TORCH_CHECK(false, "INVALID VALUES")
|
|
}
|
|
}
|
|
// TODO [unpickler refactor] __main__ isn't used by the pickler anymore, this
|
|
// is only here for bc-compatibility reasons
|
|
if (module_name == "__main__") {
|
|
if (class_name == "TensorID") {
|
|
globals_.emplace_back([this] {
|
|
auto setitem_data = stack_.back();
|
|
stack_.pop_back();
|
|
TORCH_INTERNAL_ASSERT(
|
|
!tensor_table_.empty(),
|
|
"Pickler tried to write a tensor but had no tensor table to write to");
|
|
stack_.emplace_back(tensor_table_.at(setitem_data.toInt()));
|
|
});
|
|
} else if (class_name == "IntList") {
|
|
globals_.emplace_back([this] {
|
|
stack_.back().toList().unsafeSetElementType(IntType::get());
|
|
});
|
|
} else {
|
|
AT_ERROR("Unknown pickler class id", class_name);
|
|
}
|
|
} else if (module_name == "torch.jit._pickle") {
|
|
if (class_name == "build_tensor_from_id") {
|
|
globals_.emplace_back([this] {
|
|
// Pop reduce arg off the stack
|
|
auto data = stack_.back().toTupleRef().elements().at(0);
|
|
stack_.pop_back();
|
|
TORCH_CHECK(
|
|
!tensor_table_.empty(),
|
|
"Found a tensor table reference but Unpickler"
|
|
" has no tensor table\n");
|
|
stack_.emplace_back(tensor_table_.at(data.toInt()));
|
|
});
|
|
} else if (class_name == "restore_type_tag") {
|
|
globals_.emplace_back([this] {
|
|
auto tuple = stack_.back().toTuple();
|
|
const auto& data = tuple->elements();
|
|
auto type_str = data.at(1).toStringRef();
|
|
stack_.pop_back();
|
|
TypePtr type = nullptr;
|
|
auto entry = type_cache_.find(type_str);
|
|
if (entry != type_cache_.end()) {
|
|
type = entry->second;
|
|
} else {
|
|
if (type_resolver_ == nullptr) {
|
|
// If we haven't injected a custom way of retrieving types from
|
|
// names, use a barebones type parser.
|
|
type = type_parser_(type_str);
|
|
} else {
|
|
type = type_resolver_(type_str).type_;
|
|
}
|
|
type_cache_[type_str] = type;
|
|
}
|
|
// TODO: Use lookahead to avoid creating the tuple and immediately
|
|
// destroying it here
|
|
restoreContainerTypeTags(data.at(0), type);
|
|
stack_.emplace_back(data.at(0));
|
|
});
|
|
} else {
|
|
TypePtr elem_type = nullptr;
|
|
if (class_name == "build_intlist") {
|
|
elem_type = IntType::get();
|
|
} else if (class_name == "build_tensorlist") {
|
|
elem_type = TensorType::get();
|
|
} else if (class_name == "build_doublelist") {
|
|
elem_type = FloatType::get();
|
|
} else if (class_name == "build_boollist") {
|
|
elem_type = BoolType::get();
|
|
} else {
|
|
AT_ERROR("Unknown pickler class id ", class_name);
|
|
}
|
|
// Unpickle a list specialization (e.g. List[Tensor], List[int], ...)
|
|
globals_.emplace_back([this, elem_type] {
|
|
// Pop reduce arg off the stack
|
|
auto data = stack_.back().toTupleRef().elements().at(0).toList();
|
|
stack_.pop_back();
|
|
data.unsafeSetElementType(elem_type);
|
|
stack_.emplace_back(std::move(data));
|
|
});
|
|
}
|
|
} else if (
|
|
module_name == "torch._utils" &&
|
|
(class_name == "_rebuild_tensor_v2" ||
|
|
class_name == "_rebuild_qtensor")) {
|
|
// Unpickle a tensor
|
|
bool quantized = class_name == "_rebuild_qtensor";
|
|
rebuildTensor(quantized);
|
|
} else if (
|
|
module_name == "torch._tensor" &&
|
|
(class_name == "_rebuild_from_type_v2")) {
|
|
// Unpickle a Tensor with Python attributes or
|
|
// a Subclassed Tensor.
|
|
rebuildTensorFromTypeV2();
|
|
} else if (
|
|
module_name == "torch._utils" && class_name == "_rebuild_sparse_tensor") {
|
|
rebuildSparseTensor();
|
|
} else if (module_name == "builtins" && class_name == "complex") {
|
|
globals_.emplace_back([this] {
|
|
auto tuple = pop(stack_).toTuple();
|
|
const auto& elems = tuple->elements();
|
|
AT_ASSERT(elems.size() == 2);
|
|
auto complex =
|
|
c10::complex<double>(elems.at(0).toDouble(), elems.at(1).toDouble());
|
|
stack_.emplace_back(complex);
|
|
});
|
|
|
|
} else if (module_name == "collections" && class_name == "OrderedDict") {
|
|
// collections.OrderedDict is used in tensor serialization for a tensor's
|
|
// backward hooks (but they are not actually saved with this Pickler)
|
|
globals_.emplace_back([this] {
|
|
// drop the Tuple that was argument to OrderedDict, and replace it
|
|
// with None OrderedDicts only appear in tensor deserialization and
|
|
// their value is never used
|
|
stack_.back() = IValue();
|
|
});
|
|
} else if (module_name == "torch" && class_name == "device") {
|
|
globals_.emplace_back([this] {
|
|
auto device_string = stack_.back().toTupleRef().elements().at(0);
|
|
stack_.pop_back();
|
|
stack_.emplace_back(c10::Device(device_string.toStringRef()));
|
|
});
|
|
stack_.emplace_back(int64_t(globals_.size() - 1));
|
|
return;
|
|
} else if (module_name == "torch.distributed.rpc" && class_name == "rref") {
|
|
#ifdef USE_RPC
|
|
return rebuildRRef();
|
|
#else
|
|
TORCH_INTERNAL_ASSERT(
|
|
false,
|
|
"RRef unpickling is only supported with the distributed package");
|
|
#endif
|
|
} else if (module_name == "torch") {
|
|
// Try to manually resolve several global enums
|
|
// NOTE: this does not put a global into the global table,
|
|
// like the other branches here because no REDUCE or BUILD will
|
|
// be called on this value. Instead, we just put it on the stack
|
|
// and return early
|
|
c10::optional<c10::ScalarType> scalar_type;
|
|
#define CHECK_SCALAR(_, name) \
|
|
if (class_name == #name "Storage") { \
|
|
scalar_type = c10::k##name; \
|
|
}
|
|
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(CHECK_SCALAR)
|
|
#undef CHECK_SCALAR
|
|
if (scalar_type.has_value()) {
|
|
stack_.emplace_back(int64_t(*scalar_type));
|
|
return;
|
|
}
|
|
|
|
c10::optional<at::QScheme> qscheme;
|
|
for (int i = 0; i < at::COMPILE_TIME_NUM_QSCHEMES; ++i) {
|
|
if (class_name == toString(static_cast<at::QScheme>(i))) {
|
|
qscheme = static_cast<at::QScheme>(i);
|
|
}
|
|
}
|
|
if (qscheme.has_value()) {
|
|
stack_.emplace_back(int64_t(*qscheme));
|
|
return;
|
|
}
|
|
TORCH_CHECK(
|
|
false,
|
|
"Unpickler found unknown torch global, 'torch.",
|
|
class_name,
|
|
"'");
|
|
} else {
|
|
TORCH_CHECK(
|
|
type_resolver_,
|
|
"Unpickler found unknown type ",
|
|
module_name,
|
|
".",
|
|
class_name);
|
|
at::StrongTypePtr type =
|
|
type_resolver_(c10::QualifiedName(module_name, class_name));
|
|
if (auto enum_type = type.type_->cast<c10::EnumType>()) {
|
|
globals_.emplace_back([this, enum_type] {
|
|
auto val = stack_.back();
|
|
stack_.pop_back();
|
|
for (const auto& p : enum_type->enumNamesValues()) {
|
|
if (p.second == val) {
|
|
auto enum_holder = c10::make_intrusive<at::ivalue::EnumHolder>(
|
|
enum_type, p.first, p.second);
|
|
stack_.emplace_back(std::move(enum_holder));
|
|
return;
|
|
}
|
|
}
|
|
});
|
|
} else {
|
|
// Otherwise, global is a class/object type.
|
|
globals_.emplace_back([this, type] {
|
|
auto val = stack_.back();
|
|
stack_.pop_back();
|
|
auto obj = obj_loader_(type, val);
|
|
stack_.emplace_back(std::move(obj));
|
|
});
|
|
}
|
|
}
|
|
stack_.emplace_back(int64_t(globals_.size() - 1));
|
|
}
|
|
|
|
void Unpickler::rebuildSparseTensor() {
  // Registers a deferred builder for `torch._utils._rebuild_sparse_tensor`.
  // The lambda runs later (at REDUCE time) and pops its argument tuple off
  // the unpickling stack; the tuple's layout-dependent fields describe the
  // sparse tensor to reconstruct.
  globals_.emplace_back([this] {
    // Argument tuple: (layout, <layout-specific fields...>).
    auto tup = pop(stack_).toTuple();
    const auto& elements = tup->elements();
    size_t idx = 0;
    auto layout = elements.at(idx++).toInt();
    at::Tensor result;
    switch (layout) {
      case static_cast<int>(c10::Layout::Sparse): {
        // COO layout: (size, requires_grad, indices, values).
        std::vector<int64_t> size = tupleToIntList(elements.at(idx++));
        bool requires_grad = elements.at(idx++).toBool();
        auto& indices_tensor = elements.at(idx++).toTensor();
        auto& values_tensor = elements.at(idx++).toTensor();
        auto options = values_tensor.options()
                           .layout(c10::Layout::Sparse)
                           .requires_grad(requires_grad);
        // "_unsafe" constructor: skips indices/values consistency validation.
        result = at::_sparse_coo_tensor_unsafe(
            indices_tensor, values_tensor, size, options);
        result = autograd::make_variable(result, options.requires_grad());
        break;
      }
      case static_cast<int>(c10::Layout::SparseCsr): {
        // CSR layout: (size, requires_grad, crow_indices, col_indices,
        // values).
        std::vector<int64_t> size = tupleToIntList(elements.at(idx++));
        bool requires_grad = elements.at(idx++).toBool();
        auto& crow_indices = elements.at(idx++).toTensor();
        auto& col_indices = elements.at(idx++).toTensor();
        auto& values_tensor = elements.at(idx++).toTensor();
        auto options = values_tensor.options()
                           .layout(c10::Layout::SparseCsr)
                           .requires_grad(requires_grad);
        result = at::_sparse_csr_tensor_unsafe(
            crow_indices, col_indices, values_tensor, size, options);
        result =
            autograd::make_variable(std::move(result), options.requires_grad());
        break;
      }
      default:
        // Any other layout value in the archive is unsupported input.
        TORCH_CHECK(
            false,
            "Unsupported sparse tensor layout type in serialization ",
            static_cast<c10::Layout>(layout));
        break;
    }
    stack_.emplace_back(std::move(result));
  });
}
|
|
|
|
void Unpickler::rebuildTensor(bool quantized) {
  // Registers a deferred builder for `torch._utils._rebuild_tensor_v2`
  // (quantized == false) or `_rebuild_qtensor` (quantized == true). The
  // lambda runs at REDUCE time and pops its argument tuple off the stack.
  globals_.emplace_back([this, quantized] {
    auto tup = pop(stack_).toTuple();
    const auto& elements = tup->elements();
    size_t idx = 0;
    // Common prefix of both variants:
    // (storage, storage_offset, size, stride, ...).
    auto& storage_tensor = elements.at(idx++).toTensor();
    int64_t storage_offset = elements.at(idx++).toInt();
    std::vector<int64_t> size = tupleToIntList(elements.at(idx++));
    std::vector<int64_t> stride = tupleToIntList(elements.at(idx++));
    at::Tensor result;
    if (quantized) {
      // Quantized tensors carry an extra tuple of quantizer parameters whose
      // contents depend on the qscheme (first entry).
      auto qparams_tuple = elements.at(idx++).toTuple();
      const auto& qparams = qparams_tuple->elements();
      auto qscheme = static_cast<at::QScheme>(qparams.at(0).toInt());
      switch (qscheme) {
        case at::kPerTensorAffine: {
          // (qscheme, scale, zero_point)
          double q_scale = qparams.at(1).toDouble();
          int64_t q_zero_point = qparams.at(2).toInt();
          result = at::_empty_affine_quantized(
              {0}, storage_tensor.options(), q_scale, q_zero_point);
        } break;
        case at::kPerChannelAffineFloatQParams:
        case at::kPerChannelAffine: {
          // (qscheme, scales, zero_points, axis)
          const auto& scales = qparams.at(1).toTensor();
          const auto& zero_points = qparams.at(2).toTensor();
          int64_t axis = qparams.at(3).toInt();
          result = at::_empty_per_channel_affine_quantized(
              {0}, scales, zero_points, axis, storage_tensor.options());
        } break;
        default:
          TORCH_CHECK(
              false,
              "Unsupported tensor quantization type in serialization ",
              toString(qscheme));
          break;
      }
    } else {
      result = at::empty({0}, storage_tensor.options());
    }
    bool requires_grad = elements.at(idx++).toBool();
    idx++; // backwards hooks is empty
    // Attach the deserialized storage/offset/geometry to the empty tensor
    // created above.
    at::TensorImpl* impl = result.unsafeGetTensorImpl();
    impl->set_storage_keep_dtype(storage_tensor.storage());
    impl->set_storage_offset(storage_offset);
    impl->set_sizes_and_strides(size, stride);
    result = autograd::make_variable(result, requires_grad);

    // Handle if math_bits were pickled.
    // See `args` of _reduce_ex_internal
    // for a regular tensor (final else case).
    // Tensors pickled before this patch didn't
    // have this argument for storing MathBits,
    // in that case, we do nothing.
    // NOTE: `math_bits` is the 7th arg.
    // NOTE: This is only meant for regular tensor and not quantized
    // which also has 7 args serialized.
    if (!quantized && elements.size() == 7) {
      auto math_bits = elements.at(idx++).toGenericDict();
      torch::jit::setTensorMetadata(result, math_bits);
    }

    stack_.emplace_back(std::move(result));
  });
}
|
|
|
|
void Unpickler::rebuildTensorFromTypeV2() {
|
|
// [NOTE] skip_next_read_global
|
|
// When rebuilding Tensor with Python Attr or Subclassed Tensor,
|
|
// we receive `(func, type(self), args, state)` on stack for
|
|
// `rebuildTensorFromTypeV2`.
|
|
// Thus next call to readGlobal corresponds to `func` which is
|
|
// the function to rebuild the base tensor.
|
|
// The call after `func` to readGlobal corresponds to `type` of the
|
|
// Tensor where we raise warning if the type is not `torch.Tensor`.
|
|
this->skip_next_read_global = 2;
|
|
auto curr_globals_idx = globals_.size();
|
|
globals_.emplace_back([this, curr_globals_idx] {
|
|
// args is a tuple with following data
|
|
// (function to rebuild base tensor, type of tensor,
|
|
// arguments to construct base tensor, Python State (as dict))
|
|
auto args = pop(stack_).toTuple();
|
|
size_t tup_idx = 0;
|
|
const auto args_elems = args->elements();
|
|
auto base_tensor_args = args_elems.at(tup_idx + 2).toTuple();
|
|
auto py_state = args_elems.at(tup_idx + 3).toGenericDict();
|
|
if (!py_state.empty()) {
|
|
TORCH_WARN(
|
|
"Loading Tensor with Python attributes will return at::Tensor with Python attributes being discarded");
|
|
}
|
|
// This calls the function to rebuild the
|
|
// base tensor.
|
|
// Eg. `rebuildTensor`, `rebuildSpareTensor`.
|
|
stack_.emplace_back(base_tensor_args);
|
|
globals_[curr_globals_idx + 1]();
|
|
stack_.emplace_back(pop(stack_));
|
|
});
|
|
}
|
|
|
|
#ifdef USE_RPC
|
|
void Unpickler::rebuildRRef() {
|
|
globals_.emplace_back([this] {
|
|
// It is the same as how rref is unpickled in python,
|
|
// see PyRRef::unpickle
|
|
auto tuple = std::move(stack_.back()).toTuple();
|
|
const auto& args = tuple->elements();
|
|
stack_.pop_back();
|
|
TORCH_INTERNAL_ASSERT(
|
|
args.size() == distributed::rpc::RFD_TUPLE_SIZE,
|
|
"Pickled RRefForkData must contain 7 numbers.");
|
|
auto ownerId =
|
|
static_cast<int16_t>(args.at(distributed::rpc::OWNER_IDX).toInt());
|
|
// const reference will extend the lifetime of the temporary variable
|
|
const auto& rrefId = distributed::rpc::RRefId(
|
|
static_cast<int16_t>(args.at(distributed::rpc::RREFID_ON_IDX).toInt()),
|
|
static_cast<int64_t>(args.at(distributed::rpc::RREFID_ID_IDX).toInt()));
|
|
const auto& forkId = distributed::rpc::RRefId(
|
|
static_cast<int16_t>(args.at(distributed::rpc::FORKID_ON_IDX).toInt()),
|
|
static_cast<int64_t>(args.at(distributed::rpc::FORKID_ID_IDX).toInt()));
|
|
auto parent =
|
|
static_cast<int16_t>(args.at(distributed::rpc::PARENT_IDX).toInt());
|
|
const auto& typeStr = static_cast<std::string>(
|
|
args.at(distributed::rpc::TYPE_IDX).toStringRef());
|
|
auto rrefForkData = distributed::rpc::RRefForkData(
|
|
ownerId, rrefId, forkId, parent, typeStr);
|
|
auto& ctx = distributed::rpc::RRefContext::getInstance();
|
|
c10::intrusive_ptr<distributed::rpc::RRef> rref;
|
|
TORCH_INTERNAL_ASSERT(
|
|
type_resolver_ != nullptr, "type_resolver_ is nullptr.");
|
|
at::StrongTypePtr type = type_resolver_(c10::QualifiedName(typeStr));
|
|
rref = ctx.getOrCreateRRef(rrefForkData, type.type_);
|
|
ctx.notifyOwnerAndParentOfFork(
|
|
rrefForkData.forkId_, rrefForkData.parent_, rref);
|
|
stack_.emplace_back(
|
|
c10::static_intrusive_pointer_cast<c10::RRefInterface>(rref));
|
|
});
|
|
stack_.emplace_back(int64_t(globals_.size() - 1));
|
|
return;
|
|
}
|
|
#endif
|
|
|
|
void Unpickler::readSlowWithBuffer(char* dest, size_t sz) {
|
|
// First, read any partial from buffer (may be 0).
|
|
// We explicitly assume that sz > buffer_remaining_,
|
|
// and that sz is never bigger than buffer_.size().
|
|
AT_ASSERT(sz > buffer_remaining_);
|
|
const size_t from_old_buf = buffer_remaining_;
|
|
if (from_old_buf != 0) {
|
|
memcpy(dest, buffer_.data() + buffer_pos_, from_old_buf);
|
|
}
|
|
const size_t needed = sz - from_old_buf;
|
|
// Full read into the buffer. The calls here all explicitly
|
|
// assume that one buffer will be enough for any sz.
|
|
AT_ASSERT(sz <= buffer_.size());
|
|
buffer_remaining_ = reader_(buffer_.data(), buffer_.size());
|
|
if (buffer_remaining_ < needed) {
|
|
AT_ERROR("Unexpected end of pickler archive.");
|
|
}
|
|
memcpy(dest + from_old_buf, buffer_.data(), needed);
|
|
buffer_pos_ = needed; // assignment (0'ed from read)
|
|
buffer_remaining_ -= needed;
|
|
}
|
|
|
|
// Read a number of bytes from the input stream
std::string Unpickler::readBytes(size_t length) {
  std::string data;
  // Requests at or below this size take the buffered slow path; larger ones
  // are read straight into the destination string.
  static const size_t kSmallString = 64;
  if (length <= buffer_remaining_) {
    // Fast-path: entirely in buffer.
    data.assign(buffer_.data() + buffer_pos_, length);
    buffer_pos_ += length;
    buffer_remaining_ -= length;
  } else if (length <= kSmallString) {
    // If the string is smallish, do a full buffer read,
    // and read out of that buffer.
    data.resize(length);
    readSlowWithBuffer(&data[0], length);
  } else {
    // Otherwise, for larger strings, read what we can from
    // the buffer, and then read directly to the destination.
    const size_t from_old_buf = buffer_remaining_;
    if (from_old_buf != 0) {
      data.reserve(length);
      data.append(buffer_.data() + buffer_pos_, from_old_buf);
    }
    // Grow to the final size; the bytes past from_old_buf are filled by the
    // direct reader_ call below.
    data.resize(length);
    const size_t needed = length - from_old_buf;
    size_t nread = reader_(&data[from_old_buf], needed);
    if (nread != needed) {
      AT_ERROR("Unexpected end of pickler archive.");
    }
    // The internal buffer was fully drained above.
    buffer_remaining_ = 0;
    // buffer_pos_ has no meaning with buffer_remaining_ == 0.
  }
  return data;
}
|
|
|
|
// Pop all the list items off of the stack and append them to the list at
|
|
// the corresponding MARK
|
|
void Unpickler::readList(IValue list_ivalue) {
|
|
TORCH_CHECK(!marks_.empty(), "Parsing error: marks_ is empty");
|
|
size_t start = marks_.back();
|
|
marks_.pop_back();
|
|
auto num_elements = stack_.size() - start;
|
|
auto elements = c10::ArrayRef<IValue>(stack_).slice(start);
|
|
if (list_ivalue.isIntList()) {
|
|
auto list = std::move(list_ivalue).toIntList();
|
|
list.reserve(num_elements);
|
|
for (const auto& elem : elements) {
|
|
list.emplace_back(elem.toInt());
|
|
}
|
|
} else if (list_ivalue.isTensorList()) {
|
|
auto list = std::move(list_ivalue).toTensorList();
|
|
list.reserve(num_elements);
|
|
for (const auto& elem : elements) {
|
|
list.emplace_back(elem.toTensor());
|
|
}
|
|
} else if (list_ivalue.isDoubleList()) {
|
|
auto list = std::move(list_ivalue).toDoubleList();
|
|
list.reserve(num_elements);
|
|
for (const auto& elem : elements) {
|
|
list.emplace_back(elem.toDouble());
|
|
}
|
|
} else if (list_ivalue.isBoolList()) {
|
|
auto list = std::move(list_ivalue).toBoolList();
|
|
list.reserve(num_elements);
|
|
for (const auto& elem : elements) {
|
|
list.push_back(elem.toBool());
|
|
}
|
|
} else if (list_ivalue.isList()) {
|
|
auto list = std::move(list_ivalue).toList();
|
|
list.reserve(num_elements);
|
|
for (const auto& elem : elements) {
|
|
list.emplace_back(elem);
|
|
}
|
|
} else {
|
|
AT_ERROR("Unknown IValue list kind: ", list_ivalue.tagKind());
|
|
}
|
|
|
|
stack_.erase(stack_.begin() + start, stack_.end());
|
|
}
|
|
|
|
// True iff `c` may appear in a qualified Python identifier:
// underscore, dot (module separator), ASCII digit, or ASCII letter.
inline bool is_valid_python_id_char(char c) {
  if (c == '_' || c == '.') {
    return true;
  }
  if (c >= '0' && c <= '9') {
    return true;
  }
  return (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z');
}
|
|
|
|
// Read a newline terminated string
|
|
std::string Unpickler::readString() {
|
|
std::string ss;
|
|
while (true) {
|
|
auto* const bufferStart = buffer_.data() + buffer_pos_;
|
|
const auto bufferLeft = buffer_.size() - buffer_pos_;
|
|
char* const newlinePtr =
|
|
static_cast<char*>(memchr(bufferStart, '\n', bufferLeft));
|
|
if (newlinePtr) {
|
|
// read up to newline and we are done.
|
|
auto const charsRead = newlinePtr - bufferStart;
|
|
ss.append(bufferStart, charsRead);
|
|
buffer_remaining_ -= charsRead + 1;
|
|
buffer_pos_ += charsRead + 1;
|
|
break;
|
|
} else {
|
|
// read whole buffer, refill
|
|
for (const char* p = bufferStart; p < bufferStart + bufferLeft; ++p) {
|
|
// Simple check just in case there is no terminating '\n'
|
|
TORCH_CHECK(
|
|
is_valid_python_id_char(*p),
|
|
"Found character '",
|
|
int(uint8_t(*p)),
|
|
"' in string, ",
|
|
"strings must be qualified Python identifiers");
|
|
}
|
|
ss.append(bufferStart, bufferLeft);
|
|
buffer_remaining_ = reader_(buffer_.data(), buffer_.size());
|
|
buffer_pos_ = 0;
|
|
}
|
|
}
|
|
return ss;
|
|
}
|
|
|
|
} // namespace torch::jit
|