// pytorch/torch/csrc/jit/serialization/pickle.cpp
#include <torch/csrc/jit/serialization/pickle.h>

#include <ATen/core/ivalue.h>
#include <caffe2/serialize/inline_container.h>
#include <torch/csrc/Export.h>
#include <torch/csrc/jit/frontend/script_type_parser.h>
#include <torch/csrc/jit/serialization/export.h>
#include <torch/csrc/jit/serialization/import.h>
#include <torch/csrc/jit/serialization/import_read.h>
#include <torch/custom_class.h>

#include <algorithm>
#include <cstring>
#include <memory>

namespace torch::jit {

namespace {
c10::StrongTypePtr customClassResolver(const c10::QualifiedName& qn) {
  at::TypePtr type = nullptr;
  if (c10::QualifiedName("__torch__").isPrefixOf(qn)) {
    type = torch::getCustomClass(qn.qualifiedName());
  } else {
    // This is a regular type; fall back to the default type parser.
    torch::jit::ScriptTypeParser parser;
    type = parser.parseType(qn.qualifiedName());
    return c10::StrongTypePtr(nullptr, std::move(type));
  }
  // Note: TORCH_CHECK concatenates its message arguments rather than doing
  // "{}"-style substitution, so the name is spliced in directly.
  TORCH_CHECK(
      type != nullptr,
      "Couldn't resolve type '",
      qn.qualifiedName(),
      "', did you forget to add its build dependency?");
  // Passing nullptr is a little bit sus, but should be fine:
  // 1. The lifetime of the class type is not tied to a specific
  //    CompilationUnit, but rather to the global custom class registry.
  // 2. We never access the `cu_` field, and this StrongTypePtr is discarded
  //    immediately after deserialization.
  return c10::StrongTypePtr(nullptr, std::move(type));
}
} // namespace
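
// Illustrative sketch of the resolver's two paths (the class name below is
// hypothetical; customClassResolver is internal to this translation unit):
//
//   // A name under "__torch__" is looked up in the custom class registry:
//   customClassResolver(
//       c10::QualifiedName("__torch__.torch.classes.my_ns.MyStack"));
//   // Anything else falls back to ScriptTypeParser, e.g. a builtin type:
//   customClassResolver(c10::QualifiedName("int"));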

void pickle(
    std::function<void(const char* data_start, size_t data_len)> writer,
    const IValue& ivalue,
    std::vector<at::Tensor>* tensor_table) {
  Pickler pickler(std::move(writer), tensor_table, nullptr, nullptr);
  pickler.protocol();
  pickler.pushIValue(ivalue);
  pickler.stop();
}

std::vector<char> pickle(
    const IValue& ivalue,
    std::vector<at::Tensor>* tensor_table) {
  std::vector<char> data;
  pickle(
      [&](const char* bytes, size_t len) {
        data.insert(data.end(), bytes, bytes + len);
      },
      ivalue,
      tensor_table);
  return data;
}
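
// Usage sketch (illustrative only; assumes the nullptr default for
// tensor_table in pickle.h): serialize a simple IValue to raw pickle bytes.
// The writer-callback overload above streams the same bytes incrementally.
//
//   std::vector<char> bytes = torch::jit::pickle(c10::IValue(42));
//
// When a tensor table is supplied, tensors are recorded as indices into the
// table rather than in the byte stream, and the caller must transport the
// table alongside the bytes (see the unpickle overloads below).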

// This lives here instead of the C++ API because it mirrors torch.save, and
// the mobile build excludes the C++ API.
std::vector<char> pickle_save(const at::IValue& ivalue) {
#ifndef C10_MOBILE
  // Pickle the IValue into an array of bytes
  std::vector<char> pickle_data;
  Pickler pickler([&](const char* buf, size_t size) {
    pickle_data.insert(pickle_data.end(), buf, buf + size);
  });
  pickler.protocol();
  pickler.pushIValue(ivalue);
  pickler.stop();

  std::vector<char> container_data;
  container_data.reserve(pickle_data.size());
  caffe2::serialize::PyTorchStreamWriter writer(
      [&](const void* void_bytes, size_t len) {
        const char* bytes = reinterpret_cast<const char*>(void_bytes);
        container_data.insert(container_data.end(), bytes, bytes + len);
        return len;
      });
  // Write the generated bytes into a data.pkl file, and the associated
  // tensors into data/0, data/1, data/2... files, one per tensor.
  writeArchiveAndTensors(
      "data",
      pickle_data.data(),
      pickle_data.size(),
      pickler.tensorData(),
      writer);
  return container_data;
#else
  TORCH_CHECK(
      false,
      "pickle_save not supported on mobile "
      "(see https://github.com/pytorch/pytorch/pull/30108)");
#endif
}
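
// Round-trip sketch (illustrative, not exercised by this file): pickle_save
// produces a ZIP container in the torch.save format, so its output can be
// read back with pickle_load below, or from Python via torch.load. Note that
// pickle_load expects this container format, not the raw bytes produced by
// the pickle() overloads above.
//
//   at::Tensor t = at::ones({2, 2});
//   std::vector<char> saved = torch::jit::pickle_save(t);
//   c10::IValue restored = torch::jit::pickle_load(saved);
//   at::Tensor t2 = restored.toTensor();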

#ifndef C10_MOBILE
// In-memory read adapters (declared in pickle.h) that let PyTorchStreamReader
// consume a std::vector<char> or std::string_view directly. The `what`
// argument is part of the read-adapter contract and is unused here.
size_t VectorReader::read(uint64_t pos, void* buf, size_t n, const char* what)
    const {
  std::copy(
      data_.data() + pos, data_.data() + pos + n, reinterpret_cast<char*>(buf));
  return n;
}

size_t StringViewReader::read(
    uint64_t pos,
    void* buf,
    size_t n,
    const char* what) const {
  std::copy(
      data_.data() + pos, data_.data() + pos + n, reinterpret_cast<char*>(buf));
  return n;
}
#endif

IValue pickle_load(const std::vector<char>& data) {
  // Read in the pickle data
#ifndef C10_MOBILE
  caffe2::serialize::PyTorchStreamReader reader(
      std::make_unique<VectorReader>(data));

  return readArchiveAndTensors(
      "data",
      /*pickle_prefix=*/"",
      /*tensor_prefix=*/"",
      /*type_resolver=*/std::nullopt,
      /*obj_loader=*/std::nullopt,
      /*device=*/std::nullopt,
      reader);
#else
  TORCH_CHECK(
      false,
      "pickle_load not supported on mobile "
      "(see https://github.com/pytorch/pytorch/pull/30108)");
#endif
}

// A specialized version of pickle_load that can also load custom objects
// (classes registered via torch::class_), resolved through
// customClassResolver above.
c10::IValue pickle_load_obj(std::string_view data) {
#ifndef C10_MOBILE
  caffe2::serialize::PyTorchStreamReader reader(
      std::make_unique<torch::jit::StringViewReader>(data));
  return torch::jit::readArchiveAndTensors(
      "data",
      /*pickle_prefix=*/"",
      /*tensor_prefix=*/"",
      /*type_resolver=*/customClassResolver,
      /*obj_loader=*/torch::jit::ObjLoaderFunc,
      /*device=*/std::nullopt,
      reader);
#else
  TORCH_CHECK(
      false,
      "pickle_load_obj not supported on mobile "
      "(see https://github.com/pytorch/pytorch/pull/30108)");
#endif
}
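
// Illustrative sketch: loading an archive that contains a custom class
// instance. Assumes a class was registered elsewhere, e.g. via
// torch::class_<MyStack>("my_ns", "MyStack") (hypothetical names), and that
// the bytes came from a matching save path.
//
//   std::string saved = /* container bytes holding a custom object */;
//   c10::IValue obj = torch::jit::pickle_load_obj(saved);
//   c10::intrusive_ptr<MyStack> stack = obj.toCustomClass<MyStack>();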

IValue unpickle(
    std::function<size_t(char*, size_t)> reader,
    TypeResolver type_resolver,
    c10::ArrayRef<at::Tensor> tensor_table,
    c10::TypePtr (*type_parser)(const std::string&),
    ObjLoader obj_loader) {
  Unpickler unpickler(
      std::move(reader),
      std::move(type_resolver),
      tensor_table,
      std::move(obj_loader),
      type_parser);
  return unpickler.parse_ivalue();
}

IValue unpickle(
    const char* data,
    size_t size,
    TypeResolver type_resolver,
    c10::ArrayRef<at::Tensor> tensor_table,
    c10::TypePtr (*type_parser)(const std::string&)) {
  return unpickle(
      data, size, nullptr, std::move(type_resolver), tensor_table, type_parser);
}

IValue unpickle(
    const char* data,
    size_t size,
    ObjLoader obj_loader,
    TypeResolver type_resolver,
    c10::ArrayRef<at::Tensor> tensor_table,
    c10::TypePtr (*type_parser)(const std::string&)) {
  size_t bytes_read = 0;
  return unpickle(
      // Reader that serves successive chunks of the in-memory buffer and
      // signals EOF by returning 0.
      [&](char* buffer, size_t len) -> size_t {
        if (bytes_read >= size) {
          return 0;
        }
        len = std::min(size - bytes_read, len);
        // Copy len bytes into buffer
        const char* start = data + bytes_read;
        std::memcpy(buffer, start, len);
        bytes_read += len;
        return len;
      },
      std::move(type_resolver),
      tensor_table,
      type_parser,
      std::move(obj_loader));
}
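
// Usage sketch (illustrative; assumes the defaulted type_parser in pickle.h):
// unpickle inverts the raw pickle() overloads above. Bytes produced with a
// tensor table must be decoded with the same table.
//
//   std::vector<at::Tensor> table;
//   std::vector<char> bytes =
//       torch::jit::pickle(c10::IValue(at::ones({2})), &table);
//   c10::IValue back = torch::jit::unpickle(
//       bytes.data(), bytes.size(), /*type_resolver=*/nullptr, table);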
} // namespace torch::jit