mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Fix the Problems About Defining Static Variable in Inline Function (#147095)
Refer to https://github.com/pytorch/pytorch/issues/125465 for more informations - Remove unused header files - Move common functionality to separate files to reduce dependencies between picklers and unpicklers - Move the inline function that defines the static variable to .cc Differential Revision: [D76266755](https://our.internmc.facebook.com/intern/diff/D76266755) Pull Request resolved: https://github.com/pytorch/pytorch/pull/147095 Approved by: https://github.com/cyyever, https://github.com/albanD Co-authored-by: Edward Yang <ezyang@meta.com>
This commit is contained in:
@ -1244,6 +1244,7 @@ def define_buck_targets(
|
|||||||
"torch/csrc/jit/mobile/parse_operators.cpp",
|
"torch/csrc/jit/mobile/parse_operators.cpp",
|
||||||
"torch/csrc/jit/mobile/upgrader_mobile.cpp",
|
"torch/csrc/jit/mobile/upgrader_mobile.cpp",
|
||||||
"torch/csrc/jit/serialization/import_read.cpp",
|
"torch/csrc/jit/serialization/import_read.cpp",
|
||||||
|
"torch/csrc/jit/serialization/pickler_helper.cpp",
|
||||||
"torch/csrc/jit/serialization/unpickler.cpp",
|
"torch/csrc/jit/serialization/unpickler.cpp",
|
||||||
],
|
],
|
||||||
header_namespace = "",
|
header_namespace = "",
|
||||||
|
@ -89,6 +89,7 @@ core_sources_common = [
|
|||||||
|
|
||||||
torch_unpickler_common = [
|
torch_unpickler_common = [
|
||||||
"torch/csrc/jit/serialization/import_read.cpp",
|
"torch/csrc/jit/serialization/import_read.cpp",
|
||||||
|
"torch/csrc/jit/serialization/pickler_helper.cpp",
|
||||||
"torch/csrc/jit/serialization/unpickler.cpp",
|
"torch/csrc/jit/serialization/unpickler.cpp",
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -637,6 +638,7 @@ libtorch_lite_eager_symbolication = [
|
|||||||
# Later we can split serialization and deserialization logic
|
# Later we can split serialization and deserialization logic
|
||||||
# to have better separation within build and only build relevant parts.
|
# to have better separation within build and only build relevant parts.
|
||||||
"torch/csrc/jit/serialization/pickle.cpp",
|
"torch/csrc/jit/serialization/pickle.cpp",
|
||||||
|
"torch/csrc/jit/serialization/pickler_helper.cpp",
|
||||||
"torch/csrc/jit/serialization/pickler.cpp",
|
"torch/csrc/jit/serialization/pickler.cpp",
|
||||||
"torch/csrc/jit/serialization/unpickler.cpp",
|
"torch/csrc/jit/serialization/unpickler.cpp",
|
||||||
]
|
]
|
||||||
|
@ -1,5 +1,4 @@
|
|||||||
#include <torch/csrc/jit/serialization/pickle.h>
|
#include <torch/csrc/jit/serialization/pickle.h>
|
||||||
#include <torch/csrc/jit/serialization/pickler.h>
|
|
||||||
#include <torch/serialize.h>
|
#include <torch/serialize.h>
|
||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
@ -3,7 +3,6 @@
|
|||||||
#include <torch/csrc/distributed/rpc/message.h>
|
#include <torch/csrc/distributed/rpc/message.h>
|
||||||
#include <torch/csrc/distributed/rpc/rpc_command_base.h>
|
#include <torch/csrc/distributed/rpc/rpc_command_base.h>
|
||||||
#include <torch/csrc/distributed/rpc/types.h>
|
#include <torch/csrc/distributed/rpc/types.h>
|
||||||
#include <torch/csrc/jit/serialization/pickler.h>
|
|
||||||
namespace torch::distributed::rpc {
|
namespace torch::distributed::rpc {
|
||||||
|
|
||||||
class TORCH_API PythonRemoteCall : public RpcCommandBase {
|
class TORCH_API PythonRemoteCall : public RpcCommandBase {
|
||||||
|
@ -4,7 +4,6 @@
|
|||||||
#include <torch/csrc/distributed/rpc/rpc_command_base.h>
|
#include <torch/csrc/distributed/rpc/rpc_command_base.h>
|
||||||
#include <torch/csrc/distributed/rpc/types.h>
|
#include <torch/csrc/distributed/rpc/types.h>
|
||||||
#include <torch/csrc/jit/runtime/operator.h>
|
#include <torch/csrc/jit/runtime/operator.h>
|
||||||
#include <torch/csrc/jit/serialization/pickler.h>
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
namespace torch::distributed::rpc {
|
namespace torch::distributed::rpc {
|
||||||
|
@ -3,7 +3,6 @@
|
|||||||
#include <torch/csrc/distributed/rpc/message.h>
|
#include <torch/csrc/distributed/rpc/message.h>
|
||||||
#include <torch/csrc/distributed/rpc/rpc_command_base.h>
|
#include <torch/csrc/distributed/rpc/rpc_command_base.h>
|
||||||
#include <torch/csrc/jit/runtime/operator.h>
|
#include <torch/csrc/jit/runtime/operator.h>
|
||||||
#include <torch/csrc/jit/serialization/pickler.h>
|
|
||||||
#include <optional>
|
#include <optional>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
@ -3,7 +3,6 @@
|
|||||||
#include <torch/csrc/distributed/rpc/script_call.h>
|
#include <torch/csrc/distributed/rpc/script_call.h>
|
||||||
#include <torch/csrc/distributed/rpc/types.h>
|
#include <torch/csrc/distributed/rpc/types.h>
|
||||||
#include <torch/csrc/jit/runtime/operator.h>
|
#include <torch/csrc/jit/runtime/operator.h>
|
||||||
#include <torch/csrc/jit/serialization/pickler.h>
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
namespace torch::distributed::rpc {
|
namespace torch::distributed::rpc {
|
||||||
|
@ -2,7 +2,6 @@
|
|||||||
|
|
||||||
#include <torch/csrc/distributed/rpc/message.h>
|
#include <torch/csrc/distributed/rpc/message.h>
|
||||||
#include <torch/csrc/distributed/rpc/rpc_command_base.h>
|
#include <torch/csrc/distributed/rpc/rpc_command_base.h>
|
||||||
#include <torch/csrc/jit/serialization/pickler.h>
|
|
||||||
|
|
||||||
namespace torch::distributed::rpc {
|
namespace torch::distributed::rpc {
|
||||||
|
|
||||||
|
@ -17,6 +17,7 @@
|
|||||||
#include <torch/csrc/jit/serialization/import_export_functions.h>
|
#include <torch/csrc/jit/serialization/import_export_functions.h>
|
||||||
#include <torch/csrc/jit/serialization/import_export_helpers.h>
|
#include <torch/csrc/jit/serialization/import_export_helpers.h>
|
||||||
#include <torch/csrc/jit/serialization/onnx.h>
|
#include <torch/csrc/jit/serialization/onnx.h>
|
||||||
|
#include <torch/csrc/jit/serialization/pickler.h>
|
||||||
#include <torch/csrc/onnx/back_compat.h>
|
#include <torch/csrc/onnx/back_compat.h>
|
||||||
#include <torch/csrc/onnx/onnx.h>
|
#include <torch/csrc/onnx/onnx.h>
|
||||||
#include <torch/version.h>
|
#include <torch/version.h>
|
||||||
|
@ -5,7 +5,6 @@
|
|||||||
#include <torch/csrc/jit/ir/ir.h>
|
#include <torch/csrc/jit/ir/ir.h>
|
||||||
#include <torch/csrc/jit/serialization/export_bytecode.h>
|
#include <torch/csrc/jit/serialization/export_bytecode.h>
|
||||||
#include <torch/csrc/jit/serialization/flatbuffer_serializer.h>
|
#include <torch/csrc/jit/serialization/flatbuffer_serializer.h>
|
||||||
#include <torch/csrc/jit/serialization/pickler.h>
|
|
||||||
#include <torch/csrc/jit/serialization/python_print.h>
|
#include <torch/csrc/jit/serialization/python_print.h>
|
||||||
#include <torch/csrc/jit/serialization/storage_context.h>
|
#include <torch/csrc/jit/serialization/storage_context.h>
|
||||||
#include <torch/csrc/jit/serialization/type_name_uniquer.h>
|
#include <torch/csrc/jit/serialization/type_name_uniquer.h>
|
||||||
|
@ -1,20 +1,20 @@
|
|||||||
|
#include <string>
|
||||||
|
#include <type_traits>
|
||||||
|
|
||||||
#include <ATen/ATen.h>
|
#include <ATen/ATen.h>
|
||||||
#include <ATen/core/Dict.h>
|
#include <ATen/core/Dict.h>
|
||||||
#ifdef USE_RPC
|
|
||||||
#include <torch/csrc/distributed/rpc/rref_context.h>
|
|
||||||
#endif
|
|
||||||
#include <ATen/quantized/Quantizer.h>
|
#include <ATen/quantized/Quantizer.h>
|
||||||
|
|
||||||
#include <c10/util/irange.h>
|
#include <c10/util/irange.h>
|
||||||
#include <torch/csrc/jit/api/function_impl.h>
|
#include <torch/csrc/jit/api/function_impl.h>
|
||||||
#include <torch/csrc/jit/serialization/pickler.h>
|
#include <torch/csrc/jit/serialization/pickler.h>
|
||||||
#include <torch/csrc/utils/byte_order.h>
|
#include <torch/csrc/utils/byte_order.h>
|
||||||
#include <string>
|
#ifdef USE_RPC
|
||||||
#include <type_traits>
|
#include <torch/csrc/distributed/rpc/rref_context.h>
|
||||||
|
#endif
|
||||||
|
|
||||||
namespace torch::jit {
|
namespace torch::jit {
|
||||||
|
|
||||||
using ::c10::IValue;
|
|
||||||
|
|
||||||
// Protocol 2 is the highest that can be decoded by Python 2
|
// Protocol 2 is the highest that can be decoded by Python 2
|
||||||
// See https://docs.python.org/3/library/pickle.html#data-stream-format
|
// See https://docs.python.org/3/library/pickle.html#data-stream-format
|
||||||
constexpr static uint8_t PROTOCOL_VERSION = 2;
|
constexpr static uint8_t PROTOCOL_VERSION = 2;
|
||||||
@ -719,92 +719,4 @@ void Pickler::pushTuple(const IValue& ivalue) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
WriteableTensorData getWriteableTensorData(
|
|
||||||
const at::Tensor& tensor,
|
|
||||||
bool to_cpu) {
|
|
||||||
WriteableTensorData result;
|
|
||||||
result.tensor_ = tensor;
|
|
||||||
result.size_ = tensor.storage().nbytes();
|
|
||||||
// TODO HIP support
|
|
||||||
if (tensor.storage().device_type() != DeviceType::CPU && to_cpu) {
|
|
||||||
// NB: This new tensor is created to support cuda tensors.
|
|
||||||
// Storages can be mutated when converting tensors from cuda to cpu,
|
|
||||||
// and we need a cpu tensor to copy data from.
|
|
||||||
result.tensor_ =
|
|
||||||
at::empty({0}, tensor.options())
|
|
||||||
.set_(
|
|
||||||
tensor.storage(),
|
|
||||||
/* storage_offset = */ 0,
|
|
||||||
/* size = */
|
|
||||||
{static_cast<int64_t>(
|
|
||||||
tensor.storage().nbytes() / tensor.element_size())},
|
|
||||||
/* stride = */ {1})
|
|
||||||
.cpu();
|
|
||||||
TORCH_CHECK(
|
|
||||||
result.tensor_.storage().nbytes() == result.size_,
|
|
||||||
"Storage tensor size did not match record size");
|
|
||||||
}
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool checkHasValidSetGetState(const std::shared_ptr<c10::ClassType>& cls) {
|
|
||||||
// Check that the schemas for __getstate__ and __setstate__ are correct
|
|
||||||
auto getstate = cls->findMethod("__getstate__");
|
|
||||||
if (getstate == nullptr) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
auto get_schema = getstate->getSchema();
|
|
||||||
|
|
||||||
// Check __getstate__
|
|
||||||
// __getstate__ is expected to be (self) -> T
|
|
||||||
TORCH_CHECK(
|
|
||||||
get_schema.arguments().size() == 1,
|
|
||||||
"'__getstate__' must have 'self' as its only argument, but found ",
|
|
||||||
get_schema.arguments().size(),
|
|
||||||
" arguments");
|
|
||||||
TORCH_CHECK(
|
|
||||||
get_schema.returns().size() == 1,
|
|
||||||
"'__getstate__' must return 1 value, but found ",
|
|
||||||
get_schema.returns().size());
|
|
||||||
|
|
||||||
// Check __setstate__ if the method exists
|
|
||||||
// __setstate__ is expected to be (self, T) -> None
|
|
||||||
auto setstate = cls->findMethod("__setstate__");
|
|
||||||
if (!setstate) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
auto set_schema = setstate->getSchema();
|
|
||||||
|
|
||||||
TORCH_CHECK(
|
|
||||||
set_schema.arguments().size() == 2,
|
|
||||||
"'__setstate__' must have 'self' and the state as its "
|
|
||||||
"only arguments, but found ",
|
|
||||||
set_schema.arguments().size(),
|
|
||||||
" arguments");
|
|
||||||
TORCH_CHECK(
|
|
||||||
set_schema.returns().size() == 1,
|
|
||||||
"'__setstate__' must return None, but found ",
|
|
||||||
set_schema.returns().size(),
|
|
||||||
" return values");
|
|
||||||
TORCH_CHECK(
|
|
||||||
set_schema.returns().at(0).type()->isSubtypeOf(*NoneType::get()),
|
|
||||||
"'__setstate__' must return None, but found value of type",
|
|
||||||
set_schema.returns().at(0).type()->annotation_str());
|
|
||||||
|
|
||||||
// Check that the return type of __getstate__ matches the input to
|
|
||||||
// __setstate__
|
|
||||||
auto get_type = get_schema.returns().at(0).type();
|
|
||||||
auto set_type = set_schema.arguments().at(1).type();
|
|
||||||
|
|
||||||
TORCH_CHECK(
|
|
||||||
get_type->isSubtypeOf(*set_type),
|
|
||||||
"'__getstate__'s return type (",
|
|
||||||
get_type->annotation_str(),
|
|
||||||
") does not match '__setstate__'s argument type (",
|
|
||||||
set_type->annotation_str(),
|
|
||||||
")");
|
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace torch::jit
|
} // namespace torch::jit
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <ATen/core/qualified_name.h>
|
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <string_view>
|
#include <string_view>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
@ -9,112 +8,17 @@
|
|||||||
#include <ATen/Utils.h>
|
#include <ATen/Utils.h>
|
||||||
#include <ATen/core/ivalue.h>
|
#include <ATen/core/ivalue.h>
|
||||||
#include <ATen/core/jit_type.h>
|
#include <ATen/core/jit_type.h>
|
||||||
|
#include <ATen/core/qualified_name.h>
|
||||||
#include <c10/util/ArrayRef.h>
|
#include <c10/util/ArrayRef.h>
|
||||||
#include <c10/util/FbcodeMaps.h>
|
#include <c10/util/FbcodeMaps.h>
|
||||||
#include <c10/util/intrusive_ptr.h>
|
#include <c10/util/intrusive_ptr.h>
|
||||||
#include <torch/csrc/Export.h>
|
#include <torch/csrc/Export.h>
|
||||||
|
#include <torch/csrc/jit/serialization/pickler_helper.h>
|
||||||
|
|
||||||
namespace torch::jit {
|
namespace torch::jit {
|
||||||
|
|
||||||
// See Python's pickletools.py for a detailed description of each of these codes
|
|
||||||
enum class PickleOpCode : char {
|
|
||||||
MARK = '(',
|
|
||||||
STOP = '.',
|
|
||||||
POP = '0',
|
|
||||||
POP_MARK = '1',
|
|
||||||
DUP = '2',
|
|
||||||
FLOAT = 'F',
|
|
||||||
INT = 'I',
|
|
||||||
BININT = 'J',
|
|
||||||
BININT1 = 'K',
|
|
||||||
LONG = 'L',
|
|
||||||
BININT2 = 'M',
|
|
||||||
NONE = 'N',
|
|
||||||
PERSID = 'P',
|
|
||||||
BINPERSID = 'Q',
|
|
||||||
REDUCE = 'R',
|
|
||||||
STRING = 'S',
|
|
||||||
BINSTRING = 'T',
|
|
||||||
SHORT_BINSTRING = 'U',
|
|
||||||
// NB: Avoid using UNICODE as it is a macro in the Windows API
|
|
||||||
UNICODE_ = 'V',
|
|
||||||
BINUNICODE = 'X',
|
|
||||||
APPEND = 'a',
|
|
||||||
BUILD = 'b',
|
|
||||||
GLOBAL = 'c',
|
|
||||||
DICT = 'd',
|
|
||||||
EMPTY_DICT = '}',
|
|
||||||
APPENDS = 'e',
|
|
||||||
GET = 'g',
|
|
||||||
BINGET = 'h',
|
|
||||||
INST = 'i',
|
|
||||||
LONG_BINGET = 'j',
|
|
||||||
LIST = 'l',
|
|
||||||
EMPTY_LIST = ']',
|
|
||||||
OBJ = 'o',
|
|
||||||
PUT = 'p',
|
|
||||||
BINPUT = 'q',
|
|
||||||
LONG_BINPUT = 'r',
|
|
||||||
SETITEM = 's',
|
|
||||||
TUPLE = 't',
|
|
||||||
EMPTY_TUPLE = ')',
|
|
||||||
SETITEMS = 'u',
|
|
||||||
BINFLOAT = 'G',
|
|
||||||
|
|
||||||
// Protocol 2
|
|
||||||
PROTO = char('\x80'),
|
|
||||||
NEWOBJ = '\x81',
|
|
||||||
EXT1 = '\x82',
|
|
||||||
EXT2 = '\x83',
|
|
||||||
EXT4 = '\x84',
|
|
||||||
TUPLE1 = '\x85',
|
|
||||||
TUPLE2 = '\x86',
|
|
||||||
TUPLE3 = '\x87',
|
|
||||||
NEWTRUE = '\x88',
|
|
||||||
NEWFALSE = '\x89',
|
|
||||||
LONG1 = '\x8a',
|
|
||||||
LONG4 = '\x8b',
|
|
||||||
|
|
||||||
// Protocol 3 (Python 3.x)
|
|
||||||
BINBYTES = 'B',
|
|
||||||
SHORT_BINBYTES = 'C',
|
|
||||||
|
|
||||||
// Protocol 4
|
|
||||||
SHORT_BINUNICODE = char('\x8c'),
|
|
||||||
BINUNICODE8 = '\x8d',
|
|
||||||
BINBYTES8 = '\x8e',
|
|
||||||
EMPTY_SET = '\x8f',
|
|
||||||
ADDITEMS = '\x90',
|
|
||||||
FROZENSET = '\x91',
|
|
||||||
NEWOBJ_EX = '\x92',
|
|
||||||
STACK_GLOBAL = '\x93',
|
|
||||||
MEMOIZE = '\x94',
|
|
||||||
FRAME = '\x95'
|
|
||||||
};
|
|
||||||
|
|
||||||
using ::c10::IValue;
|
using ::c10::IValue;
|
||||||
|
|
||||||
struct WriteableTensorData {
|
|
||||||
const char* data() const {
|
|
||||||
return static_cast<const char*>(tensor_.storage().data());
|
|
||||||
}
|
|
||||||
size_t sizeInBytes() const {
|
|
||||||
return size_;
|
|
||||||
}
|
|
||||||
size_t nbytes() const {
|
|
||||||
return tensor_.storage().nbytes();
|
|
||||||
}
|
|
||||||
bool storageHasDeleter() const {
|
|
||||||
return tensor_.storage().data_ptr().get_context() != nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
friend TORCH_API WriteableTensorData
|
|
||||||
getWriteableTensorData(const at::Tensor& tensor, bool to_cpu);
|
|
||||||
at::Tensor tensor_;
|
|
||||||
uint64_t size_;
|
|
||||||
};
|
|
||||||
|
|
||||||
class TORCH_API Pickler {
|
class TORCH_API Pickler {
|
||||||
AT_DISALLOW_COPY_AND_ASSIGN(Pickler);
|
AT_DISALLOW_COPY_AND_ASSIGN(Pickler);
|
||||||
|
|
||||||
@ -278,142 +182,4 @@ class TORCH_API Pickler {
|
|||||||
bool tag_aggregates_;
|
bool tag_aggregates_;
|
||||||
};
|
};
|
||||||
|
|
||||||
// returns a (tensor, record_size) for a tensor, converting it to a CPU tensor
|
|
||||||
// if it was CUDA and to_cpu is True.
|
|
||||||
TORCH_API WriteableTensorData
|
|
||||||
getWriteableTensorData(const at::Tensor& tensor, bool to_cpu = true);
|
|
||||||
|
|
||||||
// if the cls has __getstate__/__setstate__
|
|
||||||
// assert they have the right schema and return true,
|
|
||||||
// otherwise return false
|
|
||||||
bool checkHasValidSetGetState(const std::shared_ptr<c10::ClassType>& cls);
|
|
||||||
|
|
||||||
// Declare BackendMeta serialization and deserialization function pointer types.
|
|
||||||
using BackendMetaPtr = std::function<
|
|
||||||
void(const at::Tensor&, std::unordered_map<std::string, bool>&)>;
|
|
||||||
|
|
||||||
// A allowlist of device type, currently available is PrivateUse1
|
|
||||||
inline std::unordered_set<c10::DeviceType>& GetBackendMetaAllowlist() {
|
|
||||||
static std::unordered_set<c10::DeviceType> DeviceTypeAllowlist{
|
|
||||||
c10::DeviceType::PrivateUse1};
|
|
||||||
return DeviceTypeAllowlist;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Dynamically obtain serialization function pairs
|
|
||||||
// that require the corresponding backend.
|
|
||||||
inline std::array<
|
|
||||||
std::optional<std::pair<BackendMetaPtr, BackendMetaPtr>>,
|
|
||||||
at::COMPILE_TIME_MAX_DEVICE_TYPES>&
|
|
||||||
GetBackendMetaSerialization() {
|
|
||||||
// The array to save function pointer for BackendMeta serialization.
|
|
||||||
// key is the DeviceType, value is std::pair obj.
|
|
||||||
// value.first represent get function and value.seconde represent set function
|
|
||||||
static std::array<
|
|
||||||
std::optional<std::pair<BackendMetaPtr, BackendMetaPtr>>,
|
|
||||||
at::COMPILE_TIME_MAX_DEVICE_TYPES>
|
|
||||||
BackendMetaSerialization;
|
|
||||||
return BackendMetaSerialization;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Register function pointer of Tensor BackendMetadata for serialization.
|
|
||||||
TORCH_API inline void TensorBackendMetaRegistry(
|
|
||||||
c10::DeviceType t,
|
|
||||||
const BackendMetaPtr& get_fptr,
|
|
||||||
const BackendMetaPtr& set_fptr) {
|
|
||||||
// allowlist verification
|
|
||||||
// Only if the devicetype is in the allowlist,
|
|
||||||
// we allow the serialization extension to be registered for backendmeta data.
|
|
||||||
const auto& DeviceTypeAllowlist = GetBackendMetaAllowlist();
|
|
||||||
TORCH_CHECK(
|
|
||||||
DeviceTypeAllowlist.find(t) != DeviceTypeAllowlist.end(),
|
|
||||||
"It is not allowed to register the serialization method ",
|
|
||||||
"of backendMeta data for PrivateUse1. ",
|
|
||||||
"If you have related serialization requirements, ",
|
|
||||||
"please expand the allowlist");
|
|
||||||
// Register function pointer
|
|
||||||
int device_type = static_cast<int>(t);
|
|
||||||
auto& BackendMetaSerialization = GetBackendMetaSerialization();
|
|
||||||
TORCH_CHECK(
|
|
||||||
!BackendMetaSerialization[device_type].has_value(),
|
|
||||||
"The tensor BackendMeta serialization function pointer for ",
|
|
||||||
t,
|
|
||||||
" has been registered.");
|
|
||||||
BackendMetaSerialization[device_type] =
|
|
||||||
std::optional<std::pair<BackendMetaPtr, BackendMetaPtr>>(
|
|
||||||
std::make_pair(get_fptr, set_fptr));
|
|
||||||
}
|
|
||||||
|
|
||||||
// Return a map of Tensor Metadata which including BackendMetaData for
|
|
||||||
// serialization. For now, it only takes care of `conj` and `neg` bit.
|
|
||||||
inline std::unordered_map<std::string, bool> getTensorMetadata(
|
|
||||||
const at::Tensor& t) {
|
|
||||||
// We don't support serializing `ZeroTensor` as it is not public
|
|
||||||
// facing yet.
|
|
||||||
TORCH_CHECK(
|
|
||||||
!t._is_zerotensor(),
|
|
||||||
"ZeroTensor is not serializable,",
|
|
||||||
" please file an issue if required.");
|
|
||||||
std::unordered_map<std::string, bool> metadata{};
|
|
||||||
|
|
||||||
// Only add meta-data if the value is not default.
|
|
||||||
if (t.is_conj()) {
|
|
||||||
metadata["conj"] = true;
|
|
||||||
}
|
|
||||||
if (t.is_neg()) {
|
|
||||||
metadata["neg"] = true;
|
|
||||||
}
|
|
||||||
// Only add BackendMetaData for custom backend if the function pointer is
|
|
||||||
// registered.
|
|
||||||
int device_type = static_cast<int>(t.device().type());
|
|
||||||
const auto& BackendMetaSerialization = GetBackendMetaSerialization();
|
|
||||||
if (BackendMetaSerialization[device_type].has_value()) {
|
|
||||||
// Pass the tensor and metadata map references as parameters to the custom
|
|
||||||
// serialization function.
|
|
||||||
BackendMetaPtr fptr = BackendMetaSerialization[device_type].value().first;
|
|
||||||
fptr(t, metadata);
|
|
||||||
}
|
|
||||||
return metadata;
|
|
||||||
}
|
|
||||||
|
|
||||||
// set Tensor Metadata based on the map.
|
|
||||||
// Refer: getTensorMetadata
|
|
||||||
inline void setTensorMetadata(
|
|
||||||
const at::Tensor& t,
|
|
||||||
std::unordered_map<std::string, bool> metadata) {
|
|
||||||
auto iter_end = metadata.end();
|
|
||||||
auto iter_temp = metadata.find("conj");
|
|
||||||
if (iter_temp != iter_end) {
|
|
||||||
t._set_conj(true);
|
|
||||||
metadata.erase(iter_temp);
|
|
||||||
}
|
|
||||||
iter_temp = metadata.find("neg");
|
|
||||||
if (iter_temp != iter_end) {
|
|
||||||
t._set_neg(true);
|
|
||||||
metadata.erase(iter_temp);
|
|
||||||
}
|
|
||||||
// Only set BackendMetaData for custom backend if the function pointer is
|
|
||||||
// registered.
|
|
||||||
int device_type = static_cast<int>(t.device().type());
|
|
||||||
const auto& BackendMetaSerialization = GetBackendMetaSerialization();
|
|
||||||
if (BackendMetaSerialization[device_type].has_value()) {
|
|
||||||
// Pass the tensor and metadata map references as parameters to the custom
|
|
||||||
// deserialization function.
|
|
||||||
BackendMetaPtr fptr = BackendMetaSerialization[device_type].value().second;
|
|
||||||
fptr(t, metadata);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// set Tensor metadata based on the map.
|
|
||||||
// NOTE: This overload is required by unpickler.cpp
|
|
||||||
inline void setTensorMetadata(
|
|
||||||
const at::Tensor& t,
|
|
||||||
const c10::Dict<c10::IValue, c10::IValue>& metadata_idict) {
|
|
||||||
std::unordered_map<std::string, bool> metadata;
|
|
||||||
for (auto& pair : metadata_idict) {
|
|
||||||
auto key = *pair.key().toString();
|
|
||||||
metadata[key] = pair.value().toBool();
|
|
||||||
}
|
|
||||||
setTensorMetadata(t, std::move(metadata));
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace torch::jit
|
} // namespace torch::jit
|
||||||
|
117
torch/csrc/jit/serialization/pickler_helper.cpp
Normal file
117
torch/csrc/jit/serialization/pickler_helper.cpp
Normal file
@ -0,0 +1,117 @@
|
|||||||
|
#include <ATen/ATen.h>
|
||||||
|
#include <ATen/core/jit_type.h>
|
||||||
|
|
||||||
|
#include <torch/csrc/jit/api/function_impl.h>
|
||||||
|
#include <torch/csrc/jit/serialization/pickler_helper.h>
|
||||||
|
|
||||||
|
namespace torch::jit {
|
||||||
|
|
||||||
|
WriteableTensorData getWriteableTensorData(
|
||||||
|
const at::Tensor& tensor,
|
||||||
|
bool to_cpu) {
|
||||||
|
WriteableTensorData result;
|
||||||
|
result.tensor_ = tensor;
|
||||||
|
result.size_ = tensor.storage().nbytes();
|
||||||
|
// TODO HIP support
|
||||||
|
if (tensor.storage().device_type() != DeviceType::CPU && to_cpu) {
|
||||||
|
// NB: This new tensor is created to support cuda tensors.
|
||||||
|
// Storages can be mutated when converting tensors from cuda to cpu,
|
||||||
|
// and we need a cpu tensor to copy data from.
|
||||||
|
result.tensor_ =
|
||||||
|
at::empty({0}, tensor.options())
|
||||||
|
.set_(
|
||||||
|
tensor.storage(),
|
||||||
|
/* storage_offset = */ 0,
|
||||||
|
/* size = */
|
||||||
|
{static_cast<int64_t>(
|
||||||
|
tensor.storage().nbytes() / tensor.element_size())},
|
||||||
|
/* stride = */ {1})
|
||||||
|
.cpu();
|
||||||
|
TORCH_CHECK(
|
||||||
|
result.tensor_.storage().nbytes() == result.size_,
|
||||||
|
"Storage tensor size did not match record size");
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool checkHasValidSetGetState(const std::shared_ptr<c10::ClassType>& cls) {
|
||||||
|
// Check that the schemas for __getstate__ and __setstate__ are correct
|
||||||
|
auto getstate = cls->findMethod("__getstate__");
|
||||||
|
if (getstate == nullptr) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
auto get_schema = getstate->getSchema();
|
||||||
|
|
||||||
|
// Check __getstate__
|
||||||
|
// __getstate__ is expected to be (self) -> T
|
||||||
|
TORCH_CHECK(
|
||||||
|
get_schema.arguments().size() == 1,
|
||||||
|
"'__getstate__' must have 'self' as its only argument, but found ",
|
||||||
|
get_schema.arguments().size(),
|
||||||
|
" arguments");
|
||||||
|
TORCH_CHECK(
|
||||||
|
get_schema.returns().size() == 1,
|
||||||
|
"'__getstate__' must return 1 value, but found ",
|
||||||
|
get_schema.returns().size());
|
||||||
|
|
||||||
|
// Check __setstate__ if the method exists
|
||||||
|
// __setstate__ is expected to be (self, T) -> None
|
||||||
|
auto setstate = cls->findMethod("__setstate__");
|
||||||
|
if (!setstate) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
auto set_schema = setstate->getSchema();
|
||||||
|
|
||||||
|
TORCH_CHECK(
|
||||||
|
set_schema.arguments().size() == 2,
|
||||||
|
"'__setstate__' must have 'self' and the state as its "
|
||||||
|
"only arguments, but found ",
|
||||||
|
set_schema.arguments().size(),
|
||||||
|
" arguments");
|
||||||
|
TORCH_CHECK(
|
||||||
|
set_schema.returns().size() == 1,
|
||||||
|
"'__setstate__' must return None, but found ",
|
||||||
|
set_schema.returns().size(),
|
||||||
|
" return values");
|
||||||
|
TORCH_CHECK(
|
||||||
|
set_schema.returns().at(0).type()->isSubtypeOf(*NoneType::get()),
|
||||||
|
"'__setstate__' must return None, but found value of type",
|
||||||
|
set_schema.returns().at(0).type()->annotation_str());
|
||||||
|
|
||||||
|
// Check that the return type of __getstate__ matches the input to
|
||||||
|
// __setstate__
|
||||||
|
auto get_type = get_schema.returns().at(0).type();
|
||||||
|
auto set_type = set_schema.arguments().at(1).type();
|
||||||
|
|
||||||
|
TORCH_CHECK(
|
||||||
|
get_type->isSubtypeOf(*set_type),
|
||||||
|
"'__getstate__'s return type (",
|
||||||
|
get_type->annotation_str(),
|
||||||
|
") does not match '__setstate__'s argument type (",
|
||||||
|
set_type->annotation_str(),
|
||||||
|
")");
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::unordered_set<c10::DeviceType>& GetBackendMetaAllowlist() {
|
||||||
|
static std::unordered_set<c10::DeviceType> DeviceTypeAllowlist{
|
||||||
|
c10::DeviceType::PrivateUse1};
|
||||||
|
return DeviceTypeAllowlist;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::array<
|
||||||
|
std::optional<std::pair<BackendMetaPtr, BackendMetaPtr>>,
|
||||||
|
at::COMPILE_TIME_MAX_DEVICE_TYPES>&
|
||||||
|
GetBackendMetaSerialization() {
|
||||||
|
// The array to save function pointer for BackendMeta serialization.
|
||||||
|
// key is the DeviceType, value is std::pair obj.
|
||||||
|
// value.first represent get function and value.seconde represent set function
|
||||||
|
static std::array<
|
||||||
|
std::optional<std::pair<BackendMetaPtr, BackendMetaPtr>>,
|
||||||
|
at::COMPILE_TIME_MAX_DEVICE_TYPES>
|
||||||
|
BackendMetaSerialization;
|
||||||
|
return BackendMetaSerialization;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace torch::jit
|
232
torch/csrc/jit/serialization/pickler_helper.h
Normal file
232
torch/csrc/jit/serialization/pickler_helper.h
Normal file
@ -0,0 +1,232 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include <ATen/Utils.h>
|
||||||
|
#include <ATen/core/ivalue.h>
|
||||||
|
|
||||||
|
namespace torch::jit {
|
||||||
|
|
||||||
|
// See Python's pickletools.py for a detailed description of each of these codes
|
||||||
|
enum class PickleOpCode : char {
|
||||||
|
MARK = '(',
|
||||||
|
STOP = '.',
|
||||||
|
POP = '0',
|
||||||
|
POP_MARK = '1',
|
||||||
|
DUP = '2',
|
||||||
|
FLOAT = 'F',
|
||||||
|
INT = 'I',
|
||||||
|
BININT = 'J',
|
||||||
|
BININT1 = 'K',
|
||||||
|
LONG = 'L',
|
||||||
|
BININT2 = 'M',
|
||||||
|
NONE = 'N',
|
||||||
|
PERSID = 'P',
|
||||||
|
BINPERSID = 'Q',
|
||||||
|
REDUCE = 'R',
|
||||||
|
STRING = 'S',
|
||||||
|
BINSTRING = 'T',
|
||||||
|
SHORT_BINSTRING = 'U',
|
||||||
|
// NB: Avoid using UNICODE as it is a macro in the Windows API
|
||||||
|
UNICODE_ = 'V',
|
||||||
|
BINUNICODE = 'X',
|
||||||
|
APPEND = 'a',
|
||||||
|
BUILD = 'b',
|
||||||
|
GLOBAL = 'c',
|
||||||
|
DICT = 'd',
|
||||||
|
EMPTY_DICT = '}',
|
||||||
|
APPENDS = 'e',
|
||||||
|
GET = 'g',
|
||||||
|
BINGET = 'h',
|
||||||
|
INST = 'i',
|
||||||
|
LONG_BINGET = 'j',
|
||||||
|
LIST = 'l',
|
||||||
|
EMPTY_LIST = ']',
|
||||||
|
OBJ = 'o',
|
||||||
|
PUT = 'p',
|
||||||
|
BINPUT = 'q',
|
||||||
|
LONG_BINPUT = 'r',
|
||||||
|
SETITEM = 's',
|
||||||
|
TUPLE = 't',
|
||||||
|
EMPTY_TUPLE = ')',
|
||||||
|
SETITEMS = 'u',
|
||||||
|
BINFLOAT = 'G',
|
||||||
|
|
||||||
|
// Protocol 2
|
||||||
|
PROTO = char('\x80'),
|
||||||
|
NEWOBJ = '\x81',
|
||||||
|
EXT1 = '\x82',
|
||||||
|
EXT2 = '\x83',
|
||||||
|
EXT4 = '\x84',
|
||||||
|
TUPLE1 = '\x85',
|
||||||
|
TUPLE2 = '\x86',
|
||||||
|
TUPLE3 = '\x87',
|
||||||
|
NEWTRUE = '\x88',
|
||||||
|
NEWFALSE = '\x89',
|
||||||
|
LONG1 = '\x8a',
|
||||||
|
LONG4 = '\x8b',
|
||||||
|
|
||||||
|
// Protocol 3 (Python 3.x)
|
||||||
|
BINBYTES = 'B',
|
||||||
|
SHORT_BINBYTES = 'C',
|
||||||
|
|
||||||
|
// Protocol 4
|
||||||
|
SHORT_BINUNICODE = char('\x8c'),
|
||||||
|
BINUNICODE8 = '\x8d',
|
||||||
|
BINBYTES8 = '\x8e',
|
||||||
|
EMPTY_SET = '\x8f',
|
||||||
|
ADDITEMS = '\x90',
|
||||||
|
FROZENSET = '\x91',
|
||||||
|
NEWOBJ_EX = '\x92',
|
||||||
|
STACK_GLOBAL = '\x93',
|
||||||
|
MEMOIZE = '\x94',
|
||||||
|
FRAME = '\x95'
|
||||||
|
};
|
||||||
|
|
||||||
|
struct WriteableTensorData {
|
||||||
|
const char* data() const {
|
||||||
|
return static_cast<const char*>(tensor_.storage().data());
|
||||||
|
}
|
||||||
|
size_t sizeInBytes() const {
|
||||||
|
return size_;
|
||||||
|
}
|
||||||
|
size_t nbytes() const {
|
||||||
|
return tensor_.storage().nbytes();
|
||||||
|
}
|
||||||
|
bool storageHasDeleter() const {
|
||||||
|
return tensor_.storage().data_ptr().get_context() != nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
friend TORCH_API WriteableTensorData
|
||||||
|
getWriteableTensorData(const at::Tensor& tensor, bool to_cpu);
|
||||||
|
at::Tensor tensor_;
|
||||||
|
uint64_t size_;
|
||||||
|
};
|
||||||
|
|
||||||
|
// returns a (tensor, record_size) for a tensor, converting it to a CPU tensor
|
||||||
|
// if it was CUDA and to_cpu is True.
|
||||||
|
TORCH_API WriteableTensorData
|
||||||
|
getWriteableTensorData(const at::Tensor& tensor, bool to_cpu = true);
|
||||||
|
|
||||||
|
// if the cls has __getstate__/__setstate__
|
||||||
|
// assert they have the right schema and return true,
|
||||||
|
// otherwise return false
|
||||||
|
bool checkHasValidSetGetState(const std::shared_ptr<c10::ClassType>& cls);
|
||||||
|
|
||||||
|
// Declare BackendMeta serialization and deserialization function pointer types.
|
||||||
|
using BackendMetaPtr = std::function<
|
||||||
|
void(const at::Tensor&, std::unordered_map<std::string, bool>&)>;
|
||||||
|
|
||||||
|
// A allowlist of device type, currently available is PrivateUse1
|
||||||
|
TORCH_API std::unordered_set<c10::DeviceType>& GetBackendMetaAllowlist();
|
||||||
|
|
||||||
|
// Dynamically obtain serialization function pairs
|
||||||
|
// that require the corresponding backend.
|
||||||
|
TORCH_API std::array<
|
||||||
|
std::optional<std::pair<BackendMetaPtr, BackendMetaPtr>>,
|
||||||
|
at::COMPILE_TIME_MAX_DEVICE_TYPES>&
|
||||||
|
GetBackendMetaSerialization();
|
||||||
|
|
||||||
|
// Return a map of Tensor Metadata which including BackendMetaData for
|
||||||
|
// serialization. For now, it only takes care of `conj` and `neg` bit.
|
||||||
|
inline std::unordered_map<std::string, bool> getTensorMetadata(
|
||||||
|
const at::Tensor& t) {
|
||||||
|
// We don't support serializing `ZeroTensor` as it is not public
|
||||||
|
// facing yet.
|
||||||
|
TORCH_CHECK(
|
||||||
|
!t._is_zerotensor(),
|
||||||
|
"ZeroTensor is not serializable,",
|
||||||
|
" please file an issue if required.");
|
||||||
|
std::unordered_map<std::string, bool> metadata{};
|
||||||
|
|
||||||
|
// Only add meta-data if the value is not default.
|
||||||
|
if (t.is_conj()) {
|
||||||
|
metadata["conj"] = true;
|
||||||
|
}
|
||||||
|
if (t.is_neg()) {
|
||||||
|
metadata["neg"] = true;
|
||||||
|
}
|
||||||
|
// Only add BackendMetaData for custom backend if the function pointer is
|
||||||
|
// registered.
|
||||||
|
int device_type = static_cast<int>(t.device().type());
|
||||||
|
const auto& BackendMetaSerialization = GetBackendMetaSerialization();
|
||||||
|
if (BackendMetaSerialization[device_type].has_value()) {
|
||||||
|
// Pass the tensor and metadata map references as parameters to the custom
|
||||||
|
// serialization function.
|
||||||
|
BackendMetaPtr fptr = BackendMetaSerialization[device_type].value().first;
|
||||||
|
fptr(t, metadata);
|
||||||
|
}
|
||||||
|
return metadata;
|
||||||
|
}
|
||||||
|
|
||||||
|
// set Tensor Metadata based on the map.
|
||||||
|
// Refer: getTensorMetadata
|
||||||
|
inline void setTensorMetadata(
|
||||||
|
const at::Tensor& t,
|
||||||
|
std::unordered_map<std::string, bool> metadata) {
|
||||||
|
auto iter_end = metadata.end();
|
||||||
|
auto iter_temp = metadata.find("conj");
|
||||||
|
if (iter_temp != iter_end) {
|
||||||
|
t._set_conj(true);
|
||||||
|
metadata.erase(iter_temp);
|
||||||
|
}
|
||||||
|
iter_temp = metadata.find("neg");
|
||||||
|
if (iter_temp != iter_end) {
|
||||||
|
t._set_neg(true);
|
||||||
|
metadata.erase(iter_temp);
|
||||||
|
}
|
||||||
|
// Only set BackendMetaData for custom backend if the function pointer is
|
||||||
|
// registered.
|
||||||
|
int device_type = static_cast<int>(t.device().type());
|
||||||
|
const auto& BackendMetaSerialization = GetBackendMetaSerialization();
|
||||||
|
if (BackendMetaSerialization[device_type].has_value()) {
|
||||||
|
// Pass the tensor and metadata map references as parameters to the custom
|
||||||
|
// deserialization function.
|
||||||
|
BackendMetaPtr fptr = BackendMetaSerialization[device_type].value().second;
|
||||||
|
fptr(t, metadata);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// set Tensor metadata based on the map.
|
||||||
|
// NOTE: This overload is required by unpickler.cpp
|
||||||
|
inline void setTensorMetadata(
|
||||||
|
const at::Tensor& t,
|
||||||
|
const c10::Dict<c10::IValue, c10::IValue>& metadata_idict) {
|
||||||
|
std::unordered_map<std::string, bool> metadata;
|
||||||
|
for (auto& pair : metadata_idict) {
|
||||||
|
auto key = *pair.key().toString();
|
||||||
|
metadata[key] = pair.value().toBool();
|
||||||
|
}
|
||||||
|
setTensorMetadata(t, std::move(metadata));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Register function pointer of Tensor BackendMetadata for serialization.
|
||||||
|
TORCH_API inline void TensorBackendMetaRegistry(
|
||||||
|
c10::DeviceType t,
|
||||||
|
const BackendMetaPtr& get_fptr,
|
||||||
|
const BackendMetaPtr& set_fptr) {
|
||||||
|
// allowlist verification
|
||||||
|
// Only if the devicetype is in the allowlist,
|
||||||
|
// we allow the serialization extension to be registered for backendmeta data.
|
||||||
|
const auto& DeviceTypeAllowlist = GetBackendMetaAllowlist();
|
||||||
|
TORCH_CHECK(
|
||||||
|
DeviceTypeAllowlist.find(t) != DeviceTypeAllowlist.end(),
|
||||||
|
"It is not allowed to register the serialization method ",
|
||||||
|
"of backendMeta data for PrivateUse1. ",
|
||||||
|
"If you have related serialization requirements, ",
|
||||||
|
"please expand the allowlist");
|
||||||
|
// Register function pointer
|
||||||
|
int device_type = static_cast<int>(t);
|
||||||
|
auto& BackendMetaSerialization = GetBackendMetaSerialization();
|
||||||
|
TORCH_CHECK(
|
||||||
|
!BackendMetaSerialization[device_type].has_value(),
|
||||||
|
"The tensor BackendMeta serialization function pointer for ",
|
||||||
|
t,
|
||||||
|
" has been registered.");
|
||||||
|
BackendMetaSerialization[device_type] =
|
||||||
|
std::optional<std::pair<BackendMetaPtr, BackendMetaPtr>>(
|
||||||
|
std::make_pair(get_fptr, set_fptr));
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace torch::jit
|
@ -5,7 +5,6 @@
|
|||||||
#endif
|
#endif
|
||||||
#include <torch/csrc/jit/api/function_impl.h>
|
#include <torch/csrc/jit/api/function_impl.h>
|
||||||
#include <torch/csrc/jit/mobile/type_parser.h>
|
#include <torch/csrc/jit/mobile/type_parser.h>
|
||||||
#include <torch/csrc/jit/serialization/pickler.h>
|
|
||||||
#include <torch/csrc/jit/serialization/storage_context.h>
|
#include <torch/csrc/jit/serialization/storage_context.h>
|
||||||
#include <torch/csrc/jit/serialization/unpickler.h>
|
#include <torch/csrc/jit/serialization/unpickler.h>
|
||||||
#include <torch/csrc/utils/byte_order.h>
|
#include <torch/csrc/utils/byte_order.h>
|
||||||
|
@ -3,9 +3,10 @@
|
|||||||
#include <ATen/core/ivalue.h>
|
#include <ATen/core/ivalue.h>
|
||||||
#include <c10/util/ArrayRef.h>
|
#include <c10/util/ArrayRef.h>
|
||||||
#include <caffe2/serialize/inline_container.h>
|
#include <caffe2/serialize/inline_container.h>
|
||||||
|
|
||||||
#include <torch/csrc/Export.h>
|
#include <torch/csrc/Export.h>
|
||||||
#include <torch/csrc/jit/frontend/script_type_parser.h>
|
#include <torch/csrc/jit/frontend/script_type_parser.h>
|
||||||
#include <torch/csrc/jit/serialization/pickler.h>
|
#include <torch/csrc/jit/serialization/pickler_helper.h>
|
||||||
|
|
||||||
namespace torch::jit {
|
namespace torch::jit {
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user