Fix problems with defining static variables in inline functions (#147095)

See https://github.com/pytorch/pytorch/issues/125465 for more information.

- Remove unused header includes
- Move common functionality into a separate pickler_helper.h/.cpp pair to reduce the dependencies between the pickler and the unpickler
- Move the inline functions that define static variables out of the header and into a .cpp file, so each static has exactly one definition across shared libraries (see the sketch after this list)
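
The core of the fix is visible in the diff below: `GetBackendMetaAllowlist()` and `GetBackendMetaSerialization()` used to be inline functions in pickler.h, each defining a function-local static. An inline function can be instantiated in every shared library that includes the header, and (particularly under hidden symbol visibility) each library may then end up with its own copy of the static, so a registration made through one library is invisible through another. A minimal sketch of the before/after pattern follows; the file and function names here are illustrative, not taken from the patch:

// registry.h -- BEFORE (problematic): the accessor is inline, so each
// shared library that includes this header may own a separate copy of
// `allowlist`, splitting the intended process-wide singleton.
//
//   inline std::unordered_set<int>& GetAllowlist() {
//     static std::unordered_set<int> allowlist{1};
//     return allowlist;
//   }

// registry.h -- AFTER: the header only declares the accessor.
#pragma once
#include <unordered_set>

std::unordered_set<int>& GetAllowlist();

// registry.cpp -- AFTER: the definition, and therefore the static, lives
// in exactly one translation unit, so every shared library that links
// against it sees the same instance.
#include "registry.h"

std::unordered_set<int>& GetAllowlist() {
  static std::unordered_set<int> allowlist{1}; // single program-wide instance
  return allowlist;
}

This is the shape of the change below: pickler_helper.h keeps only TORCH_API declarations for these accessors, and pickler_helper.cpp holds the definitions.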

Differential Revision: [D76266755](https://our.internmc.facebook.com/intern/diff/D76266755)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147095
Approved by: https://github.com/cyyever, https://github.com/albanD

Co-authored-by: Edward Yang <ezyang@meta.com>
Author: FFFrog
Date: 2025-06-24 20:19:39 +08:00
Committed by: PyTorch MergeBot
Parent: 41910d7a94
Commit: e8cf5ff564

16 changed files with 364 additions and 340 deletions


@@ -1244,6 +1244,7 @@ def define_buck_targets(
         "torch/csrc/jit/mobile/parse_operators.cpp",
         "torch/csrc/jit/mobile/upgrader_mobile.cpp",
         "torch/csrc/jit/serialization/import_read.cpp",
+        "torch/csrc/jit/serialization/pickler_helper.cpp",
         "torch/csrc/jit/serialization/unpickler.cpp",
     ],
     header_namespace = "",


@@ -89,6 +89,7 @@ core_sources_common = [
 torch_unpickler_common = [
     "torch/csrc/jit/serialization/import_read.cpp",
+    "torch/csrc/jit/serialization/pickler_helper.cpp",
     "torch/csrc/jit/serialization/unpickler.cpp",
 ]

@@ -637,6 +638,7 @@ libtorch_lite_eager_symbolication = [
     # Later we can split serialization and deserialization logic
     # to have better separation within build and only build relevant parts.
     "torch/csrc/jit/serialization/pickle.cpp",
+    "torch/csrc/jit/serialization/pickler_helper.cpp",
     "torch/csrc/jit/serialization/pickler.cpp",
     "torch/csrc/jit/serialization/unpickler.cpp",
 ]


@@ -1,5 +1,4 @@
 #include <torch/csrc/jit/serialization/pickle.h>
-#include <torch/csrc/jit/serialization/pickler.h>
 #include <torch/serialize.h>
 #include <vector>


@@ -3,7 +3,6 @@
 #include <torch/csrc/distributed/rpc/message.h>
 #include <torch/csrc/distributed/rpc/rpc_command_base.h>
 #include <torch/csrc/distributed/rpc/types.h>
-#include <torch/csrc/jit/serialization/pickler.h>
 
 namespace torch::distributed::rpc {
 
 class TORCH_API PythonRemoteCall : public RpcCommandBase {


@@ -4,7 +4,6 @@
 #include <torch/csrc/distributed/rpc/rpc_command_base.h>
 #include <torch/csrc/distributed/rpc/types.h>
 #include <torch/csrc/jit/runtime/operator.h>
-#include <torch/csrc/jit/serialization/pickler.h>
 #include <vector>
 
 namespace torch::distributed::rpc {


@@ -3,7 +3,6 @@
 #include <torch/csrc/distributed/rpc/message.h>
 #include <torch/csrc/distributed/rpc/rpc_command_base.h>
 #include <torch/csrc/jit/runtime/operator.h>
-#include <torch/csrc/jit/serialization/pickler.h>
 #include <optional>
 #include <vector>


@@ -3,7 +3,6 @@
 #include <torch/csrc/distributed/rpc/script_call.h>
 #include <torch/csrc/distributed/rpc/types.h>
 #include <torch/csrc/jit/runtime/operator.h>
-#include <torch/csrc/jit/serialization/pickler.h>
 #include <vector>
 
 namespace torch::distributed::rpc {


@@ -2,7 +2,6 @@
 #include <torch/csrc/distributed/rpc/message.h>
 #include <torch/csrc/distributed/rpc/rpc_command_base.h>
-#include <torch/csrc/jit/serialization/pickler.h>
 
 namespace torch::distributed::rpc {


@@ -17,6 +17,7 @@
 #include <torch/csrc/jit/serialization/import_export_functions.h>
 #include <torch/csrc/jit/serialization/import_export_helpers.h>
 #include <torch/csrc/jit/serialization/onnx.h>
+#include <torch/csrc/jit/serialization/pickler.h>
 #include <torch/csrc/onnx/back_compat.h>
 #include <torch/csrc/onnx/onnx.h>
 #include <torch/version.h>


@@ -5,7 +5,6 @@
 #include <torch/csrc/jit/ir/ir.h>
 #include <torch/csrc/jit/serialization/export_bytecode.h>
 #include <torch/csrc/jit/serialization/flatbuffer_serializer.h>
-#include <torch/csrc/jit/serialization/pickler.h>
 #include <torch/csrc/jit/serialization/python_print.h>
 #include <torch/csrc/jit/serialization/storage_context.h>
 #include <torch/csrc/jit/serialization/type_name_uniquer.h>


@@ -1,20 +1,20 @@
+#include <string>
+#include <type_traits>
+
 #include <ATen/ATen.h>
 #include <ATen/core/Dict.h>
-#ifdef USE_RPC
-#include <torch/csrc/distributed/rpc/rref_context.h>
-#endif
 #include <ATen/quantized/Quantizer.h>
 #include <c10/util/irange.h>
 #include <torch/csrc/jit/api/function_impl.h>
 #include <torch/csrc/jit/serialization/pickler.h>
 #include <torch/csrc/utils/byte_order.h>
-#include <string>
-#include <type_traits>
+
+#ifdef USE_RPC
+#include <torch/csrc/distributed/rpc/rref_context.h>
+#endif
 
 namespace torch::jit {
 
-using ::c10::IValue;
-
 // Protocol 2 is the highest that can be decoded by Python 2
 // See https://docs.python.org/3/library/pickle.html#data-stream-format
 constexpr static uint8_t PROTOCOL_VERSION = 2;
@@ -719,92 +719,4 @@ void Pickler::pushTuple(const IValue& ivalue) {
   }
 }
 
-WriteableTensorData getWriteableTensorData(
-    const at::Tensor& tensor,
-    bool to_cpu) {
-  WriteableTensorData result;
-  result.tensor_ = tensor;
-  result.size_ = tensor.storage().nbytes();
-  // TODO HIP support
-  if (tensor.storage().device_type() != DeviceType::CPU && to_cpu) {
-    // NB: This new tensor is created to support cuda tensors.
-    // Storages can be mutated when converting tensors from cuda to cpu,
-    // and we need a cpu tensor to copy data from.
-    result.tensor_ =
-        at::empty({0}, tensor.options())
-            .set_(
-                tensor.storage(),
-                /* storage_offset = */ 0,
-                /* size = */
-                {static_cast<int64_t>(
-                    tensor.storage().nbytes() / tensor.element_size())},
-                /* stride = */ {1})
-            .cpu();
-    TORCH_CHECK(
-        result.tensor_.storage().nbytes() == result.size_,
-        "Storage tensor size did not match record size");
-  }
-  return result;
-}
-
-bool checkHasValidSetGetState(const std::shared_ptr<c10::ClassType>& cls) {
-  // Check that the schemas for __getstate__ and __setstate__ are correct
-  auto getstate = cls->findMethod("__getstate__");
-  if (getstate == nullptr) {
-    return false;
-  }
-  auto get_schema = getstate->getSchema();
-
-  // Check __getstate__
-  //   __getstate__ is expected to be (self) -> T
-  TORCH_CHECK(
-      get_schema.arguments().size() == 1,
-      "'__getstate__' must have 'self' as its only argument, but found ",
-      get_schema.arguments().size(),
-      " arguments");
-  TORCH_CHECK(
-      get_schema.returns().size() == 1,
-      "'__getstate__' must return 1 value, but found ",
-      get_schema.returns().size());
-
-  // Check __setstate__ if the method exists
-  //   __setstate__ is expected to be (self, T) -> None
-  auto setstate = cls->findMethod("__setstate__");
-  if (!setstate) {
-    return false;
-  }
-  auto set_schema = setstate->getSchema();
-
-  TORCH_CHECK(
-      set_schema.arguments().size() == 2,
-      "'__setstate__' must have 'self' and the state as its "
-      "only arguments, but found ",
-      set_schema.arguments().size(),
-      " arguments");
-  TORCH_CHECK(
-      set_schema.returns().size() == 1,
-      "'__setstate__' must return None, but found ",
-      set_schema.returns().size(),
-      " return values");
-  TORCH_CHECK(
-      set_schema.returns().at(0).type()->isSubtypeOf(*NoneType::get()),
-      "'__setstate__' must return None, but found value of type",
-      set_schema.returns().at(0).type()->annotation_str());
-
-  // Check that the return type of __getstate__ matches the input to
-  // __setstate__
-  auto get_type = get_schema.returns().at(0).type();
-  auto set_type = set_schema.arguments().at(1).type();
-
-  TORCH_CHECK(
-      get_type->isSubtypeOf(*set_type),
-      "'__getstate__'s return type (",
-      get_type->annotation_str(),
-      ") does not match '__setstate__'s argument type (",
-      set_type->annotation_str(),
-      ")");
-
-  return true;
-}
-
 } // namespace torch::jit


@@ -1,6 +1,5 @@
 #pragma once
 
-#include <ATen/core/qualified_name.h>
 #include <string>
 #include <string_view>
 #include <utility>
@@ -9,112 +8,17 @@
 #include <ATen/Utils.h>
 #include <ATen/core/ivalue.h>
 #include <ATen/core/jit_type.h>
+#include <ATen/core/qualified_name.h>
 #include <c10/util/ArrayRef.h>
 #include <c10/util/FbcodeMaps.h>
 #include <c10/util/intrusive_ptr.h>
 #include <torch/csrc/Export.h>
+#include <torch/csrc/jit/serialization/pickler_helper.h>
 
 namespace torch::jit {
 
-// See Python's pickletools.py for a detailed description of each of these codes
-enum class PickleOpCode : char {
-  MARK = '(',
-  STOP = '.',
-  POP = '0',
-  POP_MARK = '1',
-  DUP = '2',
-  FLOAT = 'F',
-  INT = 'I',
-  BININT = 'J',
-  BININT1 = 'K',
-  LONG = 'L',
-  BININT2 = 'M',
-  NONE = 'N',
-  PERSID = 'P',
-  BINPERSID = 'Q',
-  REDUCE = 'R',
-  STRING = 'S',
-  BINSTRING = 'T',
-  SHORT_BINSTRING = 'U',
-  // NB: Avoid using UNICODE as it is a macro in the Windows API
-  UNICODE_ = 'V',
-  BINUNICODE = 'X',
-  APPEND = 'a',
-  BUILD = 'b',
-  GLOBAL = 'c',
-  DICT = 'd',
-  EMPTY_DICT = '}',
-  APPENDS = 'e',
-  GET = 'g',
-  BINGET = 'h',
-  INST = 'i',
-  LONG_BINGET = 'j',
-  LIST = 'l',
-  EMPTY_LIST = ']',
-  OBJ = 'o',
-  PUT = 'p',
-  BINPUT = 'q',
-  LONG_BINPUT = 'r',
-  SETITEM = 's',
-  TUPLE = 't',
-  EMPTY_TUPLE = ')',
-  SETITEMS = 'u',
-  BINFLOAT = 'G',
-
-  // Protocol 2
-  PROTO = char('\x80'),
-  NEWOBJ = '\x81',
-  EXT1 = '\x82',
-  EXT2 = '\x83',
-  EXT4 = '\x84',
-  TUPLE1 = '\x85',
-  TUPLE2 = '\x86',
-  TUPLE3 = '\x87',
-  NEWTRUE = '\x88',
-  NEWFALSE = '\x89',
-  LONG1 = '\x8a',
-  LONG4 = '\x8b',
-
-  // Protocol 3 (Python 3.x)
-  BINBYTES = 'B',
-  SHORT_BINBYTES = 'C',
-
-  // Protocol 4
-  SHORT_BINUNICODE = char('\x8c'),
-  BINUNICODE8 = '\x8d',
-  BINBYTES8 = '\x8e',
-  EMPTY_SET = '\x8f',
-  ADDITEMS = '\x90',
-  FROZENSET = '\x91',
-  NEWOBJ_EX = '\x92',
-  STACK_GLOBAL = '\x93',
-  MEMOIZE = '\x94',
-  FRAME = '\x95'
-};
-
 using ::c10::IValue;
 
-struct WriteableTensorData {
-  const char* data() const {
-    return static_cast<const char*>(tensor_.storage().data());
-  }
-  size_t sizeInBytes() const {
-    return size_;
-  }
-  size_t nbytes() const {
-    return tensor_.storage().nbytes();
-  }
-  bool storageHasDeleter() const {
-    return tensor_.storage().data_ptr().get_context() != nullptr;
-  }
-
- private:
-  friend TORCH_API WriteableTensorData
-  getWriteableTensorData(const at::Tensor& tensor, bool to_cpu);
-  at::Tensor tensor_;
-  uint64_t size_;
-};
-
 class TORCH_API Pickler {
   AT_DISALLOW_COPY_AND_ASSIGN(Pickler);
@@ -278,142 +182,4 @@ class TORCH_API Pickler {
   bool tag_aggregates_;
 };
 
-// returns a (tensor, record_size) for a tensor, converting it to a CPU tensor
-// if it was CUDA and to_cpu is True.
-TORCH_API WriteableTensorData
-getWriteableTensorData(const at::Tensor& tensor, bool to_cpu = true);
-
-// if the cls has __getstate__/__setstate__,
-// assert they have the right schema and return true,
-// otherwise return false
-bool checkHasValidSetGetState(const std::shared_ptr<c10::ClassType>& cls);
-
-// Declare BackendMeta serialization and deserialization function pointer types.
-using BackendMetaPtr = std::function<
-    void(const at::Tensor&, std::unordered_map<std::string, bool>&)>;
-
-// An allowlist of device types; currently only PrivateUse1 is available.
-inline std::unordered_set<c10::DeviceType>& GetBackendMetaAllowlist() {
-  static std::unordered_set<c10::DeviceType> DeviceTypeAllowlist{
-      c10::DeviceType::PrivateUse1};
-  return DeviceTypeAllowlist;
-}
-
-// Dynamically obtain serialization function pairs
-// that require the corresponding backend.
-inline std::array<
-    std::optional<std::pair<BackendMetaPtr, BackendMetaPtr>>,
-    at::COMPILE_TIME_MAX_DEVICE_TYPES>&
-GetBackendMetaSerialization() {
-  // The array to save function pointers for BackendMeta serialization.
-  // key is the DeviceType, value is a std::pair:
-  // value.first is the get function and value.second is the set function.
-  static std::array<
-      std::optional<std::pair<BackendMetaPtr, BackendMetaPtr>>,
-      at::COMPILE_TIME_MAX_DEVICE_TYPES>
-      BackendMetaSerialization;
-  return BackendMetaSerialization;
-}
-
-// Register function pointers of Tensor BackendMetadata for serialization.
-TORCH_API inline void TensorBackendMetaRegistry(
-    c10::DeviceType t,
-    const BackendMetaPtr& get_fptr,
-    const BackendMetaPtr& set_fptr) {
-  // allowlist verification
-  // Only if the devicetype is in the allowlist
-  // do we allow the serialization extension to be registered for backendmeta data.
-  const auto& DeviceTypeAllowlist = GetBackendMetaAllowlist();
-  TORCH_CHECK(
-      DeviceTypeAllowlist.find(t) != DeviceTypeAllowlist.end(),
-      "It is not allowed to register the serialization method ",
-      "of backendMeta data for PrivateUse1. ",
-      "If you have related serialization requirements, ",
-      "please expand the allowlist");
-  // Register function pointer
-  int device_type = static_cast<int>(t);
-  auto& BackendMetaSerialization = GetBackendMetaSerialization();
-  TORCH_CHECK(
-      !BackendMetaSerialization[device_type].has_value(),
-      "The tensor BackendMeta serialization function pointer for ",
-      t,
-      " has been registered.");
-  BackendMetaSerialization[device_type] =
-      std::optional<std::pair<BackendMetaPtr, BackendMetaPtr>>(
-          std::make_pair(get_fptr, set_fptr));
-}
-
-// Return a map of Tensor Metadata which includes BackendMetaData for
-// serialization. For now, it only takes care of the `conj` and `neg` bits.
-inline std::unordered_map<std::string, bool> getTensorMetadata(
-    const at::Tensor& t) {
-  // We don't support serializing `ZeroTensor` as it is not public
-  // facing yet.
-  TORCH_CHECK(
-      !t._is_zerotensor(),
-      "ZeroTensor is not serializable,",
-      " please file an issue if required.");
-  std::unordered_map<std::string, bool> metadata{};
-
-  // Only add meta-data if the value is not default.
-  if (t.is_conj()) {
-    metadata["conj"] = true;
-  }
-  if (t.is_neg()) {
-    metadata["neg"] = true;
-  }
-  // Only add BackendMetaData for a custom backend if the function pointer is
-  // registered.
-  int device_type = static_cast<int>(t.device().type());
-  const auto& BackendMetaSerialization = GetBackendMetaSerialization();
-  if (BackendMetaSerialization[device_type].has_value()) {
-    // Pass the tensor and metadata map references as parameters to the custom
-    // serialization function.
-    BackendMetaPtr fptr = BackendMetaSerialization[device_type].value().first;
-    fptr(t, metadata);
-  }
-  return metadata;
-}
-
-// set Tensor Metadata based on the map.
-// Refer: getTensorMetadata
-inline void setTensorMetadata(
-    const at::Tensor& t,
-    std::unordered_map<std::string, bool> metadata) {
-  auto iter_end = metadata.end();
-  auto iter_temp = metadata.find("conj");
-  if (iter_temp != iter_end) {
-    t._set_conj(true);
-    metadata.erase(iter_temp);
-  }
-  iter_temp = metadata.find("neg");
-  if (iter_temp != iter_end) {
-    t._set_neg(true);
-    metadata.erase(iter_temp);
-  }
-  // Only set BackendMetaData for a custom backend if the function pointer is
-  // registered.
-  int device_type = static_cast<int>(t.device().type());
-  const auto& BackendMetaSerialization = GetBackendMetaSerialization();
-  if (BackendMetaSerialization[device_type].has_value()) {
-    // Pass the tensor and metadata map references as parameters to the custom
-    // deserialization function.
-    BackendMetaPtr fptr = BackendMetaSerialization[device_type].value().second;
-    fptr(t, metadata);
-  }
-}
-
-// set Tensor metadata based on the map.
-// NOTE: This overload is required by unpickler.cpp
-inline void setTensorMetadata(
-    const at::Tensor& t,
-    const c10::Dict<c10::IValue, c10::IValue>& metadata_idict) {
-  std::unordered_map<std::string, bool> metadata;
-  for (auto& pair : metadata_idict) {
-    auto key = *pair.key().toString();
-    metadata[key] = pair.value().toBool();
-  }
-  setTensorMetadata(t, std::move(metadata));
-}
-
 } // namespace torch::jit


@@ -0,0 +1,117 @@
#include <ATen/ATen.h>
#include <ATen/core/jit_type.h>
#include <torch/csrc/jit/api/function_impl.h>
#include <torch/csrc/jit/serialization/pickler_helper.h>

namespace torch::jit {

WriteableTensorData getWriteableTensorData(
    const at::Tensor& tensor,
    bool to_cpu) {
  WriteableTensorData result;
  result.tensor_ = tensor;
  result.size_ = tensor.storage().nbytes();
  // TODO HIP support
  if (tensor.storage().device_type() != DeviceType::CPU && to_cpu) {
    // NB: This new tensor is created to support cuda tensors.
    // Storages can be mutated when converting tensors from cuda to cpu,
    // and we need a cpu tensor to copy data from.
    result.tensor_ =
        at::empty({0}, tensor.options())
            .set_(
                tensor.storage(),
                /* storage_offset = */ 0,
                /* size = */
                {static_cast<int64_t>(
                    tensor.storage().nbytes() / tensor.element_size())},
                /* stride = */ {1})
            .cpu();
    TORCH_CHECK(
        result.tensor_.storage().nbytes() == result.size_,
        "Storage tensor size did not match record size");
  }
  return result;
}

bool checkHasValidSetGetState(const std::shared_ptr<c10::ClassType>& cls) {
  // Check that the schemas for __getstate__ and __setstate__ are correct
  auto getstate = cls->findMethod("__getstate__");
  if (getstate == nullptr) {
    return false;
  }
  auto get_schema = getstate->getSchema();

  // Check __getstate__
  //   __getstate__ is expected to be (self) -> T
  TORCH_CHECK(
      get_schema.arguments().size() == 1,
      "'__getstate__' must have 'self' as its only argument, but found ",
      get_schema.arguments().size(),
      " arguments");
  TORCH_CHECK(
      get_schema.returns().size() == 1,
      "'__getstate__' must return 1 value, but found ",
      get_schema.returns().size());

  // Check __setstate__ if the method exists
  //   __setstate__ is expected to be (self, T) -> None
  auto setstate = cls->findMethod("__setstate__");
  if (!setstate) {
    return false;
  }
  auto set_schema = setstate->getSchema();

  TORCH_CHECK(
      set_schema.arguments().size() == 2,
      "'__setstate__' must have 'self' and the state as its "
      "only arguments, but found ",
      set_schema.arguments().size(),
      " arguments");
  TORCH_CHECK(
      set_schema.returns().size() == 1,
      "'__setstate__' must return None, but found ",
      set_schema.returns().size(),
      " return values");
  TORCH_CHECK(
      set_schema.returns().at(0).type()->isSubtypeOf(*NoneType::get()),
      "'__setstate__' must return None, but found value of type",
      set_schema.returns().at(0).type()->annotation_str());

  // Check that the return type of __getstate__ matches the input to
  // __setstate__
  auto get_type = get_schema.returns().at(0).type();
  auto set_type = set_schema.arguments().at(1).type();

  TORCH_CHECK(
      get_type->isSubtypeOf(*set_type),
      "'__getstate__'s return type (",
      get_type->annotation_str(),
      ") does not match '__setstate__'s argument type (",
      set_type->annotation_str(),
      ")");

  return true;
}

std::unordered_set<c10::DeviceType>& GetBackendMetaAllowlist() {
  static std::unordered_set<c10::DeviceType> DeviceTypeAllowlist{
      c10::DeviceType::PrivateUse1};
  return DeviceTypeAllowlist;
}

std::array<
    std::optional<std::pair<BackendMetaPtr, BackendMetaPtr>>,
    at::COMPILE_TIME_MAX_DEVICE_TYPES>&
GetBackendMetaSerialization() {
  // The array to save function pointers for BackendMeta serialization.
  // key is the DeviceType, value is a std::pair:
  // value.first is the get function and value.second is the set function.
  static std::array<
      std::optional<std::pair<BackendMetaPtr, BackendMetaPtr>>,
      at::COMPILE_TIME_MAX_DEVICE_TYPES>
      BackendMetaSerialization;
  return BackendMetaSerialization;
}

} // namespace torch::jit


@@ -0,0 +1,232 @@
#pragma once

#include <string>

#include <ATen/Utils.h>
#include <ATen/core/ivalue.h>

namespace torch::jit {

// See Python's pickletools.py for a detailed description of each of these codes
enum class PickleOpCode : char {
  MARK = '(',
  STOP = '.',
  POP = '0',
  POP_MARK = '1',
  DUP = '2',
  FLOAT = 'F',
  INT = 'I',
  BININT = 'J',
  BININT1 = 'K',
  LONG = 'L',
  BININT2 = 'M',
  NONE = 'N',
  PERSID = 'P',
  BINPERSID = 'Q',
  REDUCE = 'R',
  STRING = 'S',
  BINSTRING = 'T',
  SHORT_BINSTRING = 'U',
  // NB: Avoid using UNICODE as it is a macro in the Windows API
  UNICODE_ = 'V',
  BINUNICODE = 'X',
  APPEND = 'a',
  BUILD = 'b',
  GLOBAL = 'c',
  DICT = 'd',
  EMPTY_DICT = '}',
  APPENDS = 'e',
  GET = 'g',
  BINGET = 'h',
  INST = 'i',
  LONG_BINGET = 'j',
  LIST = 'l',
  EMPTY_LIST = ']',
  OBJ = 'o',
  PUT = 'p',
  BINPUT = 'q',
  LONG_BINPUT = 'r',
  SETITEM = 's',
  TUPLE = 't',
  EMPTY_TUPLE = ')',
  SETITEMS = 'u',
  BINFLOAT = 'G',

  // Protocol 2
  PROTO = char('\x80'),
  NEWOBJ = '\x81',
  EXT1 = '\x82',
  EXT2 = '\x83',
  EXT4 = '\x84',
  TUPLE1 = '\x85',
  TUPLE2 = '\x86',
  TUPLE3 = '\x87',
  NEWTRUE = '\x88',
  NEWFALSE = '\x89',
  LONG1 = '\x8a',
  LONG4 = '\x8b',

  // Protocol 3 (Python 3.x)
  BINBYTES = 'B',
  SHORT_BINBYTES = 'C',

  // Protocol 4
  SHORT_BINUNICODE = char('\x8c'),
  BINUNICODE8 = '\x8d',
  BINBYTES8 = '\x8e',
  EMPTY_SET = '\x8f',
  ADDITEMS = '\x90',
  FROZENSET = '\x91',
  NEWOBJ_EX = '\x92',
  STACK_GLOBAL = '\x93',
  MEMOIZE = '\x94',
  FRAME = '\x95'
};

struct WriteableTensorData {
  const char* data() const {
    return static_cast<const char*>(tensor_.storage().data());
  }
  size_t sizeInBytes() const {
    return size_;
  }
  size_t nbytes() const {
    return tensor_.storage().nbytes();
  }
  bool storageHasDeleter() const {
    return tensor_.storage().data_ptr().get_context() != nullptr;
  }

 private:
  friend TORCH_API WriteableTensorData
  getWriteableTensorData(const at::Tensor& tensor, bool to_cpu);
  at::Tensor tensor_;
  uint64_t size_;
};

// returns a (tensor, record_size) for a tensor, converting it to a CPU tensor
// if it was CUDA and to_cpu is True.
TORCH_API WriteableTensorData
getWriteableTensorData(const at::Tensor& tensor, bool to_cpu = true);

// if the cls has __getstate__/__setstate__,
// assert they have the right schema and return true,
// otherwise return false
bool checkHasValidSetGetState(const std::shared_ptr<c10::ClassType>& cls);

// Declare BackendMeta serialization and deserialization function pointer types.
using BackendMetaPtr = std::function<
    void(const at::Tensor&, std::unordered_map<std::string, bool>&)>;

// An allowlist of device types; currently only PrivateUse1 is available.
TORCH_API std::unordered_set<c10::DeviceType>& GetBackendMetaAllowlist();

// Dynamically obtain serialization function pairs
// that require the corresponding backend.
TORCH_API std::array<
    std::optional<std::pair<BackendMetaPtr, BackendMetaPtr>>,
    at::COMPILE_TIME_MAX_DEVICE_TYPES>&
GetBackendMetaSerialization();

// Return a map of Tensor Metadata which includes BackendMetaData for
// serialization. For now, it only takes care of the `conj` and `neg` bits.
inline std::unordered_map<std::string, bool> getTensorMetadata(
    const at::Tensor& t) {
  // We don't support serializing `ZeroTensor` as it is not public
  // facing yet.
  TORCH_CHECK(
      !t._is_zerotensor(),
      "ZeroTensor is not serializable,",
      " please file an issue if required.");
  std::unordered_map<std::string, bool> metadata{};

  // Only add meta-data if the value is not default.
  if (t.is_conj()) {
    metadata["conj"] = true;
  }
  if (t.is_neg()) {
    metadata["neg"] = true;
  }
  // Only add BackendMetaData for a custom backend if the function pointer is
  // registered.
  int device_type = static_cast<int>(t.device().type());
  const auto& BackendMetaSerialization = GetBackendMetaSerialization();
  if (BackendMetaSerialization[device_type].has_value()) {
    // Pass the tensor and metadata map references as parameters to the custom
    // serialization function.
    BackendMetaPtr fptr = BackendMetaSerialization[device_type].value().first;
    fptr(t, metadata);
  }
  return metadata;
}

// set Tensor Metadata based on the map.
// Refer: getTensorMetadata
inline void setTensorMetadata(
    const at::Tensor& t,
    std::unordered_map<std::string, bool> metadata) {
  auto iter_end = metadata.end();
  auto iter_temp = metadata.find("conj");
  if (iter_temp != iter_end) {
    t._set_conj(true);
    metadata.erase(iter_temp);
  }
  iter_temp = metadata.find("neg");
  if (iter_temp != iter_end) {
    t._set_neg(true);
    metadata.erase(iter_temp);
  }
  // Only set BackendMetaData for a custom backend if the function pointer is
  // registered.
  int device_type = static_cast<int>(t.device().type());
  const auto& BackendMetaSerialization = GetBackendMetaSerialization();
  if (BackendMetaSerialization[device_type].has_value()) {
    // Pass the tensor and metadata map references as parameters to the custom
    // deserialization function.
    BackendMetaPtr fptr = BackendMetaSerialization[device_type].value().second;
    fptr(t, metadata);
  }
}

// set Tensor metadata based on the map.
// NOTE: This overload is required by unpickler.cpp
inline void setTensorMetadata(
    const at::Tensor& t,
    const c10::Dict<c10::IValue, c10::IValue>& metadata_idict) {
  std::unordered_map<std::string, bool> metadata;
  for (auto& pair : metadata_idict) {
    auto key = *pair.key().toString();
    metadata[key] = pair.value().toBool();
  }
  setTensorMetadata(t, std::move(metadata));
}

// Register function pointers of Tensor BackendMetadata for serialization.
TORCH_API inline void TensorBackendMetaRegistry(
    c10::DeviceType t,
    const BackendMetaPtr& get_fptr,
    const BackendMetaPtr& set_fptr) {
  // allowlist verification
  // Only if the devicetype is in the allowlist
  // do we allow the serialization extension to be registered for backendmeta data.
  const auto& DeviceTypeAllowlist = GetBackendMetaAllowlist();
  TORCH_CHECK(
      DeviceTypeAllowlist.find(t) != DeviceTypeAllowlist.end(),
      "It is not allowed to register the serialization method ",
      "of backendMeta data for PrivateUse1. ",
      "If you have related serialization requirements, ",
      "please expand the allowlist");
  // Register function pointer
  int device_type = static_cast<int>(t);
  auto& BackendMetaSerialization = GetBackendMetaSerialization();
  TORCH_CHECK(
      !BackendMetaSerialization[device_type].has_value(),
      "The tensor BackendMeta serialization function pointer for ",
      t,
      " has been registered.");
  BackendMetaSerialization[device_type] =
      std::optional<std::pair<BackendMetaPtr, BackendMetaPtr>>(
          std::make_pair(get_fptr, set_fptr));
}

} // namespace torch::jit


@@ -5,7 +5,6 @@
 #endif
 #include <torch/csrc/jit/api/function_impl.h>
 #include <torch/csrc/jit/mobile/type_parser.h>
-#include <torch/csrc/jit/serialization/pickler.h>
 #include <torch/csrc/jit/serialization/storage_context.h>
 #include <torch/csrc/jit/serialization/unpickler.h>
 #include <torch/csrc/utils/byte_order.h>


@@ -3,9 +3,10 @@
 #include <ATen/core/ivalue.h>
 #include <c10/util/ArrayRef.h>
 #include <caffe2/serialize/inline_container.h>
 #include <torch/csrc/Export.h>
 #include <torch/csrc/jit/frontend/script_type_parser.h>
-#include <torch/csrc/jit/serialization/pickler.h>
+#include <torch/csrc/jit/serialization/pickler_helper.h>
 
 namespace torch::jit {