mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Fix the Problems About Defining Static Variable in Inline Function (#147095)
Refer to https://github.com/pytorch/pytorch/issues/125465 for more information - Remove unused header files - Move common functionality to separate files to reduce dependencies between picklers and unpicklers - Move the inline function that defines the static variable to .cc Differential Revision: [D76266755](https://our.internmc.facebook.com/intern/diff/D76266755) Pull Request resolved: https://github.com/pytorch/pytorch/pull/147095 Approved by: https://github.com/cyyever, https://github.com/albanD Co-authored-by: Edward Yang <ezyang@meta.com>
This commit is contained in:
@ -1244,6 +1244,7 @@ def define_buck_targets(
|
||||
"torch/csrc/jit/mobile/parse_operators.cpp",
|
||||
"torch/csrc/jit/mobile/upgrader_mobile.cpp",
|
||||
"torch/csrc/jit/serialization/import_read.cpp",
|
||||
"torch/csrc/jit/serialization/pickler_helper.cpp",
|
||||
"torch/csrc/jit/serialization/unpickler.cpp",
|
||||
],
|
||||
header_namespace = "",
|
||||
|
@ -89,6 +89,7 @@ core_sources_common = [
|
||||
|
||||
torch_unpickler_common = [
|
||||
"torch/csrc/jit/serialization/import_read.cpp",
|
||||
"torch/csrc/jit/serialization/pickler_helper.cpp",
|
||||
"torch/csrc/jit/serialization/unpickler.cpp",
|
||||
]
|
||||
|
||||
@ -637,6 +638,7 @@ libtorch_lite_eager_symbolication = [
|
||||
# Later we can split serialization and deserialization logic
|
||||
# to have better separation within build and only build relevant parts.
|
||||
"torch/csrc/jit/serialization/pickle.cpp",
|
||||
"torch/csrc/jit/serialization/pickler_helper.cpp",
|
||||
"torch/csrc/jit/serialization/pickler.cpp",
|
||||
"torch/csrc/jit/serialization/unpickler.cpp",
|
||||
]
|
||||
|
@ -1,5 +1,4 @@
|
||||
#include <torch/csrc/jit/serialization/pickle.h>
|
||||
#include <torch/csrc/jit/serialization/pickler.h>
|
||||
#include <torch/serialize.h>
|
||||
|
||||
#include <vector>
|
||||
|
@ -3,7 +3,6 @@
|
||||
#include <torch/csrc/distributed/rpc/message.h>
|
||||
#include <torch/csrc/distributed/rpc/rpc_command_base.h>
|
||||
#include <torch/csrc/distributed/rpc/types.h>
|
||||
#include <torch/csrc/jit/serialization/pickler.h>
|
||||
namespace torch::distributed::rpc {
|
||||
|
||||
class TORCH_API PythonRemoteCall : public RpcCommandBase {
|
||||
|
@ -4,7 +4,6 @@
|
||||
#include <torch/csrc/distributed/rpc/rpc_command_base.h>
|
||||
#include <torch/csrc/distributed/rpc/types.h>
|
||||
#include <torch/csrc/jit/runtime/operator.h>
|
||||
#include <torch/csrc/jit/serialization/pickler.h>
|
||||
#include <vector>
|
||||
|
||||
namespace torch::distributed::rpc {
|
||||
|
@ -3,7 +3,6 @@
|
||||
#include <torch/csrc/distributed/rpc/message.h>
|
||||
#include <torch/csrc/distributed/rpc/rpc_command_base.h>
|
||||
#include <torch/csrc/jit/runtime/operator.h>
|
||||
#include <torch/csrc/jit/serialization/pickler.h>
|
||||
#include <optional>
|
||||
#include <vector>
|
||||
|
||||
|
@ -3,7 +3,6 @@
|
||||
#include <torch/csrc/distributed/rpc/script_call.h>
|
||||
#include <torch/csrc/distributed/rpc/types.h>
|
||||
#include <torch/csrc/jit/runtime/operator.h>
|
||||
#include <torch/csrc/jit/serialization/pickler.h>
|
||||
#include <vector>
|
||||
|
||||
namespace torch::distributed::rpc {
|
||||
|
@ -2,7 +2,6 @@
|
||||
|
||||
#include <torch/csrc/distributed/rpc/message.h>
|
||||
#include <torch/csrc/distributed/rpc/rpc_command_base.h>
|
||||
#include <torch/csrc/jit/serialization/pickler.h>
|
||||
|
||||
namespace torch::distributed::rpc {
|
||||
|
||||
|
@ -17,6 +17,7 @@
|
||||
#include <torch/csrc/jit/serialization/import_export_functions.h>
|
||||
#include <torch/csrc/jit/serialization/import_export_helpers.h>
|
||||
#include <torch/csrc/jit/serialization/onnx.h>
|
||||
#include <torch/csrc/jit/serialization/pickler.h>
|
||||
#include <torch/csrc/onnx/back_compat.h>
|
||||
#include <torch/csrc/onnx/onnx.h>
|
||||
#include <torch/version.h>
|
||||
|
@ -5,7 +5,6 @@
|
||||
#include <torch/csrc/jit/ir/ir.h>
|
||||
#include <torch/csrc/jit/serialization/export_bytecode.h>
|
||||
#include <torch/csrc/jit/serialization/flatbuffer_serializer.h>
|
||||
#include <torch/csrc/jit/serialization/pickler.h>
|
||||
#include <torch/csrc/jit/serialization/python_print.h>
|
||||
#include <torch/csrc/jit/serialization/storage_context.h>
|
||||
#include <torch/csrc/jit/serialization/type_name_uniquer.h>
|
||||
|
@ -1,20 +1,20 @@
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/core/Dict.h>
|
||||
#ifdef USE_RPC
|
||||
#include <torch/csrc/distributed/rpc/rref_context.h>
|
||||
#endif
|
||||
#include <ATen/quantized/Quantizer.h>
|
||||
|
||||
#include <c10/util/irange.h>
|
||||
#include <torch/csrc/jit/api/function_impl.h>
|
||||
#include <torch/csrc/jit/serialization/pickler.h>
|
||||
#include <torch/csrc/utils/byte_order.h>
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
#ifdef USE_RPC
|
||||
#include <torch/csrc/distributed/rpc/rref_context.h>
|
||||
#endif
|
||||
|
||||
namespace torch::jit {
|
||||
|
||||
using ::c10::IValue;
|
||||
|
||||
// Protocol 2 is the highest that can be decoded by Python 2
|
||||
// See https://docs.python.org/3/library/pickle.html#data-stream-format
|
||||
constexpr static uint8_t PROTOCOL_VERSION = 2;
|
||||
@ -719,92 +719,4 @@ void Pickler::pushTuple(const IValue& ivalue) {
|
||||
}
|
||||
}
|
||||
|
||||
WriteableTensorData getWriteableTensorData(
|
||||
const at::Tensor& tensor,
|
||||
bool to_cpu) {
|
||||
WriteableTensorData result;
|
||||
result.tensor_ = tensor;
|
||||
result.size_ = tensor.storage().nbytes();
|
||||
// TODO HIP support
|
||||
if (tensor.storage().device_type() != DeviceType::CPU && to_cpu) {
|
||||
// NB: This new tensor is created to support cuda tensors.
|
||||
// Storages can be mutated when converting tensors from cuda to cpu,
|
||||
// and we need a cpu tensor to copy data from.
|
||||
result.tensor_ =
|
||||
at::empty({0}, tensor.options())
|
||||
.set_(
|
||||
tensor.storage(),
|
||||
/* storage_offset = */ 0,
|
||||
/* size = */
|
||||
{static_cast<int64_t>(
|
||||
tensor.storage().nbytes() / tensor.element_size())},
|
||||
/* stride = */ {1})
|
||||
.cpu();
|
||||
TORCH_CHECK(
|
||||
result.tensor_.storage().nbytes() == result.size_,
|
||||
"Storage tensor size did not match record size");
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
bool checkHasValidSetGetState(const std::shared_ptr<c10::ClassType>& cls) {
|
||||
// Check that the schemas for __getstate__ and __setstate__ are correct
|
||||
auto getstate = cls->findMethod("__getstate__");
|
||||
if (getstate == nullptr) {
|
||||
return false;
|
||||
}
|
||||
auto get_schema = getstate->getSchema();
|
||||
|
||||
// Check __getstate__
|
||||
// __getstate__ is expected to be (self) -> T
|
||||
TORCH_CHECK(
|
||||
get_schema.arguments().size() == 1,
|
||||
"'__getstate__' must have 'self' as its only argument, but found ",
|
||||
get_schema.arguments().size(),
|
||||
" arguments");
|
||||
TORCH_CHECK(
|
||||
get_schema.returns().size() == 1,
|
||||
"'__getstate__' must return 1 value, but found ",
|
||||
get_schema.returns().size());
|
||||
|
||||
// Check __setstate__ if the method exists
|
||||
// __setstate__ is expected to be (self, T) -> None
|
||||
auto setstate = cls->findMethod("__setstate__");
|
||||
if (!setstate) {
|
||||
return false;
|
||||
}
|
||||
auto set_schema = setstate->getSchema();
|
||||
|
||||
TORCH_CHECK(
|
||||
set_schema.arguments().size() == 2,
|
||||
"'__setstate__' must have 'self' and the state as its "
|
||||
"only arguments, but found ",
|
||||
set_schema.arguments().size(),
|
||||
" arguments");
|
||||
TORCH_CHECK(
|
||||
set_schema.returns().size() == 1,
|
||||
"'__setstate__' must return None, but found ",
|
||||
set_schema.returns().size(),
|
||||
" return values");
|
||||
TORCH_CHECK(
|
||||
set_schema.returns().at(0).type()->isSubtypeOf(*NoneType::get()),
|
||||
"'__setstate__' must return None, but found value of type",
|
||||
set_schema.returns().at(0).type()->annotation_str());
|
||||
|
||||
// Check that the return type of __getstate__ matches the input to
|
||||
// __setstate__
|
||||
auto get_type = get_schema.returns().at(0).type();
|
||||
auto set_type = set_schema.arguments().at(1).type();
|
||||
|
||||
TORCH_CHECK(
|
||||
get_type->isSubtypeOf(*set_type),
|
||||
"'__getstate__'s return type (",
|
||||
get_type->annotation_str(),
|
||||
") does not match '__setstate__'s argument type (",
|
||||
set_type->annotation_str(),
|
||||
")");
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace torch::jit
|
||||
|
@ -1,6 +1,5 @@
|
||||
#pragma once
|
||||
|
||||
#include <ATen/core/qualified_name.h>
|
||||
#include <string>
|
||||
#include <string_view>
|
||||
#include <utility>
|
||||
@ -9,112 +8,17 @@
|
||||
#include <ATen/Utils.h>
|
||||
#include <ATen/core/ivalue.h>
|
||||
#include <ATen/core/jit_type.h>
|
||||
#include <ATen/core/qualified_name.h>
|
||||
#include <c10/util/ArrayRef.h>
|
||||
#include <c10/util/FbcodeMaps.h>
|
||||
#include <c10/util/intrusive_ptr.h>
|
||||
#include <torch/csrc/Export.h>
|
||||
#include <torch/csrc/jit/serialization/pickler_helper.h>
|
||||
|
||||
namespace torch::jit {
|
||||
|
||||
// See Python's pickletools.py for a detailed description of each of these codes
|
||||
enum class PickleOpCode : char {
|
||||
MARK = '(',
|
||||
STOP = '.',
|
||||
POP = '0',
|
||||
POP_MARK = '1',
|
||||
DUP = '2',
|
||||
FLOAT = 'F',
|
||||
INT = 'I',
|
||||
BININT = 'J',
|
||||
BININT1 = 'K',
|
||||
LONG = 'L',
|
||||
BININT2 = 'M',
|
||||
NONE = 'N',
|
||||
PERSID = 'P',
|
||||
BINPERSID = 'Q',
|
||||
REDUCE = 'R',
|
||||
STRING = 'S',
|
||||
BINSTRING = 'T',
|
||||
SHORT_BINSTRING = 'U',
|
||||
// NB: Avoid using UNICODE as it is a macro in the Windows API
|
||||
UNICODE_ = 'V',
|
||||
BINUNICODE = 'X',
|
||||
APPEND = 'a',
|
||||
BUILD = 'b',
|
||||
GLOBAL = 'c',
|
||||
DICT = 'd',
|
||||
EMPTY_DICT = '}',
|
||||
APPENDS = 'e',
|
||||
GET = 'g',
|
||||
BINGET = 'h',
|
||||
INST = 'i',
|
||||
LONG_BINGET = 'j',
|
||||
LIST = 'l',
|
||||
EMPTY_LIST = ']',
|
||||
OBJ = 'o',
|
||||
PUT = 'p',
|
||||
BINPUT = 'q',
|
||||
LONG_BINPUT = 'r',
|
||||
SETITEM = 's',
|
||||
TUPLE = 't',
|
||||
EMPTY_TUPLE = ')',
|
||||
SETITEMS = 'u',
|
||||
BINFLOAT = 'G',
|
||||
|
||||
// Protocol 2
|
||||
PROTO = char('\x80'),
|
||||
NEWOBJ = '\x81',
|
||||
EXT1 = '\x82',
|
||||
EXT2 = '\x83',
|
||||
EXT4 = '\x84',
|
||||
TUPLE1 = '\x85',
|
||||
TUPLE2 = '\x86',
|
||||
TUPLE3 = '\x87',
|
||||
NEWTRUE = '\x88',
|
||||
NEWFALSE = '\x89',
|
||||
LONG1 = '\x8a',
|
||||
LONG4 = '\x8b',
|
||||
|
||||
// Protocol 3 (Python 3.x)
|
||||
BINBYTES = 'B',
|
||||
SHORT_BINBYTES = 'C',
|
||||
|
||||
// Protocol 4
|
||||
SHORT_BINUNICODE = char('\x8c'),
|
||||
BINUNICODE8 = '\x8d',
|
||||
BINBYTES8 = '\x8e',
|
||||
EMPTY_SET = '\x8f',
|
||||
ADDITEMS = '\x90',
|
||||
FROZENSET = '\x91',
|
||||
NEWOBJ_EX = '\x92',
|
||||
STACK_GLOBAL = '\x93',
|
||||
MEMOIZE = '\x94',
|
||||
FRAME = '\x95'
|
||||
};
|
||||
|
||||
using ::c10::IValue;
|
||||
|
||||
struct WriteableTensorData {
|
||||
const char* data() const {
|
||||
return static_cast<const char*>(tensor_.storage().data());
|
||||
}
|
||||
size_t sizeInBytes() const {
|
||||
return size_;
|
||||
}
|
||||
size_t nbytes() const {
|
||||
return tensor_.storage().nbytes();
|
||||
}
|
||||
bool storageHasDeleter() const {
|
||||
return tensor_.storage().data_ptr().get_context() != nullptr;
|
||||
}
|
||||
|
||||
private:
|
||||
friend TORCH_API WriteableTensorData
|
||||
getWriteableTensorData(const at::Tensor& tensor, bool to_cpu);
|
||||
at::Tensor tensor_;
|
||||
uint64_t size_;
|
||||
};
|
||||
|
||||
class TORCH_API Pickler {
|
||||
AT_DISALLOW_COPY_AND_ASSIGN(Pickler);
|
||||
|
||||
@ -278,142 +182,4 @@ class TORCH_API Pickler {
|
||||
bool tag_aggregates_;
|
||||
};
|
||||
|
||||
// returns a (tensor, record_size) for a tensor, converting it to a CPU tensor
|
||||
// if it was CUDA and to_cpu is True.
|
||||
TORCH_API WriteableTensorData
|
||||
getWriteableTensorData(const at::Tensor& tensor, bool to_cpu = true);
|
||||
|
||||
// if the cls has __getstate__/__setstate__
|
||||
// assert they have the right schema and return true,
|
||||
// otherwise return false
|
||||
bool checkHasValidSetGetState(const std::shared_ptr<c10::ClassType>& cls);
|
||||
|
||||
// Declare BackendMeta serialization and deserialization function pointer types.
|
||||
using BackendMetaPtr = std::function<
|
||||
void(const at::Tensor&, std::unordered_map<std::string, bool>&)>;
|
||||
|
||||
// A allowlist of device type, currently available is PrivateUse1
|
||||
inline std::unordered_set<c10::DeviceType>& GetBackendMetaAllowlist() {
|
||||
static std::unordered_set<c10::DeviceType> DeviceTypeAllowlist{
|
||||
c10::DeviceType::PrivateUse1};
|
||||
return DeviceTypeAllowlist;
|
||||
}
|
||||
|
||||
// Dynamically obtain serialization function pairs
|
||||
// that require the corresponding backend.
|
||||
inline std::array<
|
||||
std::optional<std::pair<BackendMetaPtr, BackendMetaPtr>>,
|
||||
at::COMPILE_TIME_MAX_DEVICE_TYPES>&
|
||||
GetBackendMetaSerialization() {
|
||||
// The array to save function pointer for BackendMeta serialization.
|
||||
// key is the DeviceType, value is std::pair obj.
|
||||
// value.first represent get function and value.seconde represent set function
|
||||
static std::array<
|
||||
std::optional<std::pair<BackendMetaPtr, BackendMetaPtr>>,
|
||||
at::COMPILE_TIME_MAX_DEVICE_TYPES>
|
||||
BackendMetaSerialization;
|
||||
return BackendMetaSerialization;
|
||||
}
|
||||
|
||||
// Register function pointer of Tensor BackendMetadata for serialization.
|
||||
TORCH_API inline void TensorBackendMetaRegistry(
|
||||
c10::DeviceType t,
|
||||
const BackendMetaPtr& get_fptr,
|
||||
const BackendMetaPtr& set_fptr) {
|
||||
// allowlist verification
|
||||
// Only if the devicetype is in the allowlist,
|
||||
// we allow the serialization extension to be registered for backendmeta data.
|
||||
const auto& DeviceTypeAllowlist = GetBackendMetaAllowlist();
|
||||
TORCH_CHECK(
|
||||
DeviceTypeAllowlist.find(t) != DeviceTypeAllowlist.end(),
|
||||
"It is not allowed to register the serialization method ",
|
||||
"of backendMeta data for PrivateUse1. ",
|
||||
"If you have related serialization requirements, ",
|
||||
"please expand the allowlist");
|
||||
// Register function pointer
|
||||
int device_type = static_cast<int>(t);
|
||||
auto& BackendMetaSerialization = GetBackendMetaSerialization();
|
||||
TORCH_CHECK(
|
||||
!BackendMetaSerialization[device_type].has_value(),
|
||||
"The tensor BackendMeta serialization function pointer for ",
|
||||
t,
|
||||
" has been registered.");
|
||||
BackendMetaSerialization[device_type] =
|
||||
std::optional<std::pair<BackendMetaPtr, BackendMetaPtr>>(
|
||||
std::make_pair(get_fptr, set_fptr));
|
||||
}
|
||||
|
||||
// Return a map of Tensor Metadata which including BackendMetaData for
|
||||
// serialization. For now, it only takes care of `conj` and `neg` bit.
|
||||
inline std::unordered_map<std::string, bool> getTensorMetadata(
|
||||
const at::Tensor& t) {
|
||||
// We don't support serializing `ZeroTensor` as it is not public
|
||||
// facing yet.
|
||||
TORCH_CHECK(
|
||||
!t._is_zerotensor(),
|
||||
"ZeroTensor is not serializable,",
|
||||
" please file an issue if required.");
|
||||
std::unordered_map<std::string, bool> metadata{};
|
||||
|
||||
// Only add meta-data if the value is not default.
|
||||
if (t.is_conj()) {
|
||||
metadata["conj"] = true;
|
||||
}
|
||||
if (t.is_neg()) {
|
||||
metadata["neg"] = true;
|
||||
}
|
||||
// Only add BackendMetaData for custom backend if the function pointer is
|
||||
// registered.
|
||||
int device_type = static_cast<int>(t.device().type());
|
||||
const auto& BackendMetaSerialization = GetBackendMetaSerialization();
|
||||
if (BackendMetaSerialization[device_type].has_value()) {
|
||||
// Pass the tensor and metadata map references as parameters to the custom
|
||||
// serialization function.
|
||||
BackendMetaPtr fptr = BackendMetaSerialization[device_type].value().first;
|
||||
fptr(t, metadata);
|
||||
}
|
||||
return metadata;
|
||||
}
|
||||
|
||||
// set Tensor Metadata based on the map.
|
||||
// Refer: getTensorMetadata
|
||||
inline void setTensorMetadata(
|
||||
const at::Tensor& t,
|
||||
std::unordered_map<std::string, bool> metadata) {
|
||||
auto iter_end = metadata.end();
|
||||
auto iter_temp = metadata.find("conj");
|
||||
if (iter_temp != iter_end) {
|
||||
t._set_conj(true);
|
||||
metadata.erase(iter_temp);
|
||||
}
|
||||
iter_temp = metadata.find("neg");
|
||||
if (iter_temp != iter_end) {
|
||||
t._set_neg(true);
|
||||
metadata.erase(iter_temp);
|
||||
}
|
||||
// Only set BackendMetaData for custom backend if the function pointer is
|
||||
// registered.
|
||||
int device_type = static_cast<int>(t.device().type());
|
||||
const auto& BackendMetaSerialization = GetBackendMetaSerialization();
|
||||
if (BackendMetaSerialization[device_type].has_value()) {
|
||||
// Pass the tensor and metadata map references as parameters to the custom
|
||||
// deserialization function.
|
||||
BackendMetaPtr fptr = BackendMetaSerialization[device_type].value().second;
|
||||
fptr(t, metadata);
|
||||
}
|
||||
}
|
||||
|
||||
// set Tensor metadata based on the map.
|
||||
// NOTE: This overload is required by unpickler.cpp
|
||||
inline void setTensorMetadata(
|
||||
const at::Tensor& t,
|
||||
const c10::Dict<c10::IValue, c10::IValue>& metadata_idict) {
|
||||
std::unordered_map<std::string, bool> metadata;
|
||||
for (auto& pair : metadata_idict) {
|
||||
auto key = *pair.key().toString();
|
||||
metadata[key] = pair.value().toBool();
|
||||
}
|
||||
setTensorMetadata(t, std::move(metadata));
|
||||
}
|
||||
|
||||
} // namespace torch::jit
|
||||
|
117
torch/csrc/jit/serialization/pickler_helper.cpp
Normal file
117
torch/csrc/jit/serialization/pickler_helper.cpp
Normal file
@ -0,0 +1,117 @@
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/core/jit_type.h>
|
||||
|
||||
#include <torch/csrc/jit/api/function_impl.h>
|
||||
#include <torch/csrc/jit/serialization/pickler_helper.h>
|
||||
|
||||
namespace torch::jit {
|
||||
|
||||
// Prepare a tensor's storage bytes for writing to a serialized record.
// Captures the full storage size in bytes; when `to_cpu` is true and the
// storage lives on a non-CPU device, the bytes are first copied into a
// fresh CPU tensor so they can be written out directly.
WriteableTensorData getWriteableTensorData(
    const at::Tensor& tensor,
    bool to_cpu) {
  WriteableTensorData result;
  result.tensor_ = tensor;
  // Record size is the whole underlying storage, not just this view.
  result.size_ = tensor.storage().nbytes();
  // TODO HIP support
  if (tensor.storage().device_type() != DeviceType::CPU && to_cpu) {
    // NB: This new tensor is created to support cuda tensors.
    // Storages can be mutated when converting tensors from cuda to cpu,
    // and we need a cpu tensor to copy data from.
    // The empty tensor is re-pointed at the ORIGINAL storage (offset 0,
    // flat 1-D view over every element) before the .cpu() copy.
    result.tensor_ =
        at::empty({0}, tensor.options())
            .set_(
                tensor.storage(),
                /* storage_offset = */ 0,
                /* size = */
                {static_cast<int64_t>(
                    tensor.storage().nbytes() / tensor.element_size())},
                /* stride = */ {1})
            .cpu();
    // The CPU copy must hold exactly the byte count recorded above.
    TORCH_CHECK(
        result.tensor_.storage().nbytes() == result.size_,
        "Storage tensor size did not match record size");
  }
  return result;
}
|
||||
|
||||
// Validate a class's __getstate__/__setstate__ pair for serialization.
// Returns false when either method is missing; when both exist, asserts
// (via TORCH_CHECK) that their schemas are mutually compatible:
//   __getstate__ : (self) -> T
//   __setstate__ : (self, T) -> None
// and returns true on success.
bool checkHasValidSetGetState(const std::shared_ptr<c10::ClassType>& cls) {
  auto getstate_method = cls->findMethod("__getstate__");
  if (!getstate_method) {
    return false;
  }
  auto getstate_schema = getstate_method->getSchema();

  // __getstate__ must take exactly `self` and produce a single value.
  TORCH_CHECK(
      getstate_schema.arguments().size() == 1,
      "'__getstate__' must have 'self' as its only argument, but found ",
      getstate_schema.arguments().size(),
      " arguments");
  TORCH_CHECK(
      getstate_schema.returns().size() == 1,
      "'__getstate__' must return 1 value, but found ",
      getstate_schema.returns().size());

  // Without a matching __setstate__ the pair is incomplete.
  auto setstate_method = cls->findMethod("__setstate__");
  if (setstate_method == nullptr) {
    return false;
  }
  auto setstate_schema = setstate_method->getSchema();

  // __setstate__ must take (self, state) and return None.
  TORCH_CHECK(
      setstate_schema.arguments().size() == 2,
      "'__setstate__' must have 'self' and the state as its "
      "only arguments, but found ",
      setstate_schema.arguments().size(),
      " arguments");
  TORCH_CHECK(
      setstate_schema.returns().size() == 1,
      "'__setstate__' must return None, but found ",
      setstate_schema.returns().size(),
      " return values");
  TORCH_CHECK(
      setstate_schema.returns().at(0).type()->isSubtypeOf(*NoneType::get()),
      "'__setstate__' must return None, but found value of type",
      setstate_schema.returns().at(0).type()->annotation_str());

  // The state produced by __getstate__ must be acceptable to __setstate__.
  auto produced_type = getstate_schema.returns().at(0).type();
  auto consumed_type = setstate_schema.arguments().at(1).type();

  TORCH_CHECK(
      produced_type->isSubtypeOf(*consumed_type),
      "'__getstate__'s return type (",
      produced_type->annotation_str(),
      ") does not match '__setstate__'s argument type (",
      consumed_type->annotation_str(),
      ")");

  return true;
}
|
||||
|
||||
std::unordered_set<c10::DeviceType>& GetBackendMetaAllowlist() {
|
||||
static std::unordered_set<c10::DeviceType> DeviceTypeAllowlist{
|
||||
c10::DeviceType::PrivateUse1};
|
||||
return DeviceTypeAllowlist;
|
||||
}
|
||||
|
||||
// Accessor for the per-DeviceType registry of BackendMeta hooks.
// Each slot optionally holds a (get, set) pair of BackendMetaPtr:
// `first` is the serialization ("get") hook, `second` the
// deserialization ("set") hook. Defined in this .cpp so the static
// below lives in exactly one translation unit (see issue #125465).
std::array<
    std::optional<std::pair<BackendMetaPtr, BackendMetaPtr>>,
    at::COMPILE_TIME_MAX_DEVICE_TYPES>&
GetBackendMetaSerialization() {
  // The array saving the function pointers for BackendMeta serialization.
  // Index is the DeviceType, value is an optional std::pair;
  // value.first is the get function and value.second the set function.
  static std::array<
      std::optional<std::pair<BackendMetaPtr, BackendMetaPtr>>,
      at::COMPILE_TIME_MAX_DEVICE_TYPES>
      BackendMetaSerialization;
  return BackendMetaSerialization;
}
|
||||
|
||||
} // namespace torch::jit
|
232
torch/csrc/jit/serialization/pickler_helper.h
Normal file
232
torch/csrc/jit/serialization/pickler_helper.h
Normal file
@ -0,0 +1,232 @@
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
|
||||
#include <ATen/Utils.h>
|
||||
#include <ATen/core/ivalue.h>
|
||||
|
||||
namespace torch::jit {
|
||||
|
||||
// See Python's pickletools.py for a detailed description of each of these codes
|
||||
// See Python's pickletools.py for a detailed description of each of these
// codes. Opcodes are grouped below by the pickle protocol version that
// introduced them; the underlying value is the literal byte emitted into /
// read from the pickle stream.
enum class PickleOpCode : char {
  MARK = '(',
  STOP = '.',
  POP = '0',
  POP_MARK = '1',
  DUP = '2',
  FLOAT = 'F',
  INT = 'I',
  BININT = 'J',
  BININT1 = 'K',
  LONG = 'L',
  BININT2 = 'M',
  NONE = 'N',
  PERSID = 'P',
  BINPERSID = 'Q',
  REDUCE = 'R',
  STRING = 'S',
  BINSTRING = 'T',
  SHORT_BINSTRING = 'U',
  // NB: Avoid using UNICODE as it is a macro in the Windows API
  UNICODE_ = 'V',
  BINUNICODE = 'X',
  APPEND = 'a',
  BUILD = 'b',
  GLOBAL = 'c',
  DICT = 'd',
  EMPTY_DICT = '}',
  APPENDS = 'e',
  GET = 'g',
  BINGET = 'h',
  INST = 'i',
  LONG_BINGET = 'j',
  LIST = 'l',
  EMPTY_LIST = ']',
  OBJ = 'o',
  PUT = 'p',
  BINPUT = 'q',
  LONG_BINPUT = 'r',
  SETITEM = 's',
  TUPLE = 't',
  EMPTY_TUPLE = ')',
  SETITEMS = 'u',
  BINFLOAT = 'G',

  // Protocol 2
  PROTO = char('\x80'),
  NEWOBJ = '\x81',
  EXT1 = '\x82',
  EXT2 = '\x83',
  EXT4 = '\x84',
  TUPLE1 = '\x85',
  TUPLE2 = '\x86',
  TUPLE3 = '\x87',
  NEWTRUE = '\x88',
  NEWFALSE = '\x89',
  LONG1 = '\x8a',
  LONG4 = '\x8b',

  // Protocol 3 (Python 3.x)
  BINBYTES = 'B',
  SHORT_BINBYTES = 'C',

  // Protocol 4
  SHORT_BINUNICODE = char('\x8c'),
  BINUNICODE8 = '\x8d',
  BINBYTES8 = '\x8e',
  EMPTY_SET = '\x8f',
  ADDITEMS = '\x90',
  FROZENSET = '\x91',
  NEWOBJ_EX = '\x92',
  STACK_GLOBAL = '\x93',
  MEMOIZE = '\x94',
  FRAME = '\x95'
};
|
||||
|
||||
// A (tensor, record_size) pair prepared for writing a tensor's storage
// bytes into a serialized record. Constructed only by
// getWriteableTensorData() (a friend), which fills tensor_ and size_.
struct WriteableTensorData {
  // Raw pointer to the first byte of the underlying storage.
  const char* data() const {
    return static_cast<const char*>(tensor_.storage().data());
  }
  // Byte count captured when this object was created.
  size_t sizeInBytes() const {
    return size_;
  }
  // Current storage size in bytes (normally equal to sizeInBytes()).
  size_t nbytes() const {
    return tensor_.storage().nbytes();
  }
  // Whether the storage's DataPtr carries a context (i.e. deleter state).
  bool storageHasDeleter() const {
    return tensor_.storage().data_ptr().get_context() != nullptr;
  }

 private:
  friend TORCH_API WriteableTensorData
  getWriteableTensorData(const at::Tensor& tensor, bool to_cpu);
  at::Tensor tensor_;
  uint64_t size_;
};
|
||||
|
||||
// returns a (tensor, record_size) for a tensor, converting it to a CPU tensor
|
||||
// if it was CUDA and to_cpu is True.
|
||||
TORCH_API WriteableTensorData
|
||||
getWriteableTensorData(const at::Tensor& tensor, bool to_cpu = true);
|
||||
|
||||
// if the cls has __getstate__/__setstate__
|
||||
// assert they have the right schema and return true,
|
||||
// otherwise return false
|
||||
bool checkHasValidSetGetState(const std::shared_ptr<c10::ClassType>& cls);
|
||||
|
||||
// Declare BackendMeta serialization and deserialization function pointer types.
|
||||
using BackendMetaPtr = std::function<
|
||||
void(const at::Tensor&, std::unordered_map<std::string, bool>&)>;
|
||||
|
||||
// A allowlist of device type, currently available is PrivateUse1
|
||||
TORCH_API std::unordered_set<c10::DeviceType>& GetBackendMetaAllowlist();
|
||||
|
||||
// Dynamically obtain serialization function pairs
|
||||
// that require the corresponding backend.
|
||||
TORCH_API std::array<
|
||||
std::optional<std::pair<BackendMetaPtr, BackendMetaPtr>>,
|
||||
at::COMPILE_TIME_MAX_DEVICE_TYPES>&
|
||||
GetBackendMetaSerialization();
|
||||
|
||||
// Return a map of Tensor Metadata which including BackendMetaData for
|
||||
// serialization. For now, it only takes care of `conj` and `neg` bit.
|
||||
inline std::unordered_map<std::string, bool> getTensorMetadata(
|
||||
const at::Tensor& t) {
|
||||
// We don't support serializing `ZeroTensor` as it is not public
|
||||
// facing yet.
|
||||
TORCH_CHECK(
|
||||
!t._is_zerotensor(),
|
||||
"ZeroTensor is not serializable,",
|
||||
" please file an issue if required.");
|
||||
std::unordered_map<std::string, bool> metadata{};
|
||||
|
||||
// Only add meta-data if the value is not default.
|
||||
if (t.is_conj()) {
|
||||
metadata["conj"] = true;
|
||||
}
|
||||
if (t.is_neg()) {
|
||||
metadata["neg"] = true;
|
||||
}
|
||||
// Only add BackendMetaData for custom backend if the function pointer is
|
||||
// registered.
|
||||
int device_type = static_cast<int>(t.device().type());
|
||||
const auto& BackendMetaSerialization = GetBackendMetaSerialization();
|
||||
if (BackendMetaSerialization[device_type].has_value()) {
|
||||
// Pass the tensor and metadata map references as parameters to the custom
|
||||
// serialization function.
|
||||
BackendMetaPtr fptr = BackendMetaSerialization[device_type].value().first;
|
||||
fptr(t, metadata);
|
||||
}
|
||||
return metadata;
|
||||
}
|
||||
|
||||
// set Tensor Metadata based on the map.
|
||||
// Refer: getTensorMetadata
|
||||
inline void setTensorMetadata(
|
||||
const at::Tensor& t,
|
||||
std::unordered_map<std::string, bool> metadata) {
|
||||
auto iter_end = metadata.end();
|
||||
auto iter_temp = metadata.find("conj");
|
||||
if (iter_temp != iter_end) {
|
||||
t._set_conj(true);
|
||||
metadata.erase(iter_temp);
|
||||
}
|
||||
iter_temp = metadata.find("neg");
|
||||
if (iter_temp != iter_end) {
|
||||
t._set_neg(true);
|
||||
metadata.erase(iter_temp);
|
||||
}
|
||||
// Only set BackendMetaData for custom backend if the function pointer is
|
||||
// registered.
|
||||
int device_type = static_cast<int>(t.device().type());
|
||||
const auto& BackendMetaSerialization = GetBackendMetaSerialization();
|
||||
if (BackendMetaSerialization[device_type].has_value()) {
|
||||
// Pass the tensor and metadata map references as parameters to the custom
|
||||
// deserialization function.
|
||||
BackendMetaPtr fptr = BackendMetaSerialization[device_type].value().second;
|
||||
fptr(t, metadata);
|
||||
}
|
||||
}
|
||||
|
||||
// set Tensor metadata based on the map.
|
||||
// NOTE: This overload is required by unpickler.cpp
|
||||
inline void setTensorMetadata(
|
||||
const at::Tensor& t,
|
||||
const c10::Dict<c10::IValue, c10::IValue>& metadata_idict) {
|
||||
std::unordered_map<std::string, bool> metadata;
|
||||
for (auto& pair : metadata_idict) {
|
||||
auto key = *pair.key().toString();
|
||||
metadata[key] = pair.value().toBool();
|
||||
}
|
||||
setTensorMetadata(t, std::move(metadata));
|
||||
}
|
||||
|
||||
// Register function pointer of Tensor BackendMetadata for serialization.
|
||||
// Register the pair of Tensor BackendMetadata function pointers for
// serialization: `get_fptr` is invoked when serializing a tensor on
// device type `t`, `set_fptr` when deserializing. Registration is
// one-shot per device type and restricted to the allowlist
// (currently PrivateUse1 only).
TORCH_API inline void TensorBackendMetaRegistry(
    c10::DeviceType t,
    const BackendMetaPtr& get_fptr,
    const BackendMetaPtr& set_fptr) {
  // allowlist verification
  // Only if the devicetype is in the allowlist,
  // we allow the serialization extension to be registered for backendmeta data.
  const auto& DeviceTypeAllowlist = GetBackendMetaAllowlist();
  // NOTE(review): the message below says "for PrivateUse1", but this check
  // fires precisely when `t` is NOT in the allowlist — consider rewording
  // the runtime string in a follow-up.
  TORCH_CHECK(
      DeviceTypeAllowlist.find(t) != DeviceTypeAllowlist.end(),
      "It is not allowed to register the serialization method ",
      "of backendMeta data for PrivateUse1. ",
      "If you have related serialization requirements, ",
      "please expand the allowlist");
  // Register function pointer
  int device_type = static_cast<int>(t);
  auto& BackendMetaSerialization = GetBackendMetaSerialization();
  // Double registration is an error: each device type holds at most one
  // (get, set) pair for its entire lifetime.
  TORCH_CHECK(
      !BackendMetaSerialization[device_type].has_value(),
      "The tensor BackendMeta serialization function pointer for ",
      t,
      " has been registered.");
  BackendMetaSerialization[device_type] =
      std::optional<std::pair<BackendMetaPtr, BackendMetaPtr>>(
          std::make_pair(get_fptr, set_fptr));
}
|
||||
|
||||
} // namespace torch::jit
|
@ -5,7 +5,6 @@
|
||||
#endif
|
||||
#include <torch/csrc/jit/api/function_impl.h>
|
||||
#include <torch/csrc/jit/mobile/type_parser.h>
|
||||
#include <torch/csrc/jit/serialization/pickler.h>
|
||||
#include <torch/csrc/jit/serialization/storage_context.h>
|
||||
#include <torch/csrc/jit/serialization/unpickler.h>
|
||||
#include <torch/csrc/utils/byte_order.h>
|
||||
|
@ -3,9 +3,10 @@
|
||||
#include <ATen/core/ivalue.h>
|
||||
#include <c10/util/ArrayRef.h>
|
||||
#include <caffe2/serialize/inline_container.h>
|
||||
|
||||
#include <torch/csrc/Export.h>
|
||||
#include <torch/csrc/jit/frontend/script_type_parser.h>
|
||||
#include <torch/csrc/jit/serialization/pickler.h>
|
||||
#include <torch/csrc/jit/serialization/pickler_helper.h>
|
||||
|
||||
namespace torch::jit {
|
||||
|
||||
|
Reference in New Issue
Block a user