mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/12848 Updated all non-test uses of protobuf::MessageLite::SerializeAsString to call SerializeAsString_EnforceCheck so that the return value is checked and can throw an exception if failing. Most of the affected code was called from classes derived from BlobSerializeBase. Didn't touch most tests and ENFORCE calls because they usually do checks anyway. Original commit changeset: c0760e73ecc7 Reviewed By: dzhulgakov Differential Revision: D10453456 fbshipit-source-id: d2f2b7b4578e721924354149f08f627c7e3bf070
107 lines
3.1 KiB
C++
#include "caffe2/core/blob_serialization.h"
|
|
#include "caffe2/core/common.h"
|
|
#include "caffe2/core/context.h"
|
|
#include "caffe2/core/tensor_int8.h"
|
|
#include "caffe2/core/typeid.h"
|
|
#include "caffe2/core/types.h"
|
|
|
|
namespace caffe2 {
|
|
namespace int8 {
|
|
|
|
class Int8TensorCPUSerializer : public BlobSerializerBase {
|
|
public:
|
|
void Serialize(
|
|
const void* pointer,
|
|
TypeMeta typeMeta,
|
|
const string& name,
|
|
SerializationAcceptor acceptor) override {
|
|
CAFFE_ENFORCE(typeMeta.Match<Int8TensorCPU>());
|
|
const auto& tensor = *static_cast<const Int8TensorCPU*>(pointer);
|
|
BlobProto blob_proto;
|
|
blob_proto.set_name(name);
|
|
blob_proto.set_type("Int8TensorCPU");
|
|
QTensorProto& proto = *blob_proto.mutable_qtensor();
|
|
proto.set_name(name);
|
|
for (int i = 0; i < tensor.t.ndim(); ++i) {
|
|
proto.add_dims(tensor.t.dim32(i));
|
|
}
|
|
proto.set_precision(8);
|
|
proto.set_scale(tensor.scale);
|
|
proto.set_bias(tensor.zero_point);
|
|
proto.set_is_signed(false);
|
|
|
|
const TensorProto::DataType data_type = TypeMetaToDataType(tensor.t.meta());
|
|
proto.set_data_type(data_type);
|
|
switch (data_type) {
|
|
case TensorProto_DataType_INT32:
|
|
detail::CopyToProtoAsIs(
|
|
tensor.t.size(),
|
|
tensor.t.template data<int32_t>(),
|
|
proto.mutable_data(),
|
|
&this->context_);
|
|
break;
|
|
case TensorProto_DataType_UINT8:
|
|
detail::CopyToProtoWithCast(
|
|
tensor.t.size(),
|
|
tensor.t.template data<uint8_t>(),
|
|
proto.mutable_data(),
|
|
&this->context_);
|
|
break;
|
|
default:
|
|
CAFFE_ENFORCE(false, "Unsupported data type in Int8TensorCPU");
|
|
}
|
|
|
|
acceptor(name, SerializeBlobProtoAsString_EnforceCheck(blob_proto));
|
|
}
|
|
|
|
private:
|
|
CPUContext context_;
|
|
};
|
|
|
|
class Int8TensorCPUDeserializer : public TensorDeserializer {
|
|
public:
|
|
void Deserialize(const BlobProto& blob_proto, Blob* blob) override {
|
|
const QTensorProto& proto = blob_proto.qtensor();
|
|
Int8TensorCPU* tensor = blob->template GetMutable<Int8TensorCPU>();
|
|
tensor->scale = proto.scale();
|
|
tensor->zero_point = proto.bias();
|
|
vector<int> dims;
|
|
for (const int d : proto.dims()) {
|
|
dims.push_back(d);
|
|
}
|
|
tensor->t.Resize(dims);
|
|
switch (proto.data_type()) {
|
|
case TensorProto_DataType_INT32:
|
|
detail::CopyFromProtoAsIs(
|
|
tensor->t.size(),
|
|
proto.data(),
|
|
tensor->t.template mutable_data<int32_t>(),
|
|
&this->context_);
|
|
break;
|
|
case TensorProto_DataType_UINT8:
|
|
detail::CopyFromProtoWithCast(
|
|
tensor->t.size(),
|
|
proto.data(),
|
|
tensor->t.template mutable_data<uint8_t>(),
|
|
&this->context_);
|
|
break;
|
|
default:
|
|
CAFFE_ENFORCE(false, "Unsupported data type in Int8TensorCPU");
|
|
}
|
|
}
|
|
|
|
private:
|
|
CPUContext context_;
|
|
};
|
|
|
|
} // namespace int8
|
|
|
|
namespace {
// Register the pair with the blob (de)serialization registries so that
// Blob::Serialize / Blob::Deserialize dispatch to the classes above for
// Int8TensorCPU-typed blobs.
REGISTER_BLOB_SERIALIZER(
    (TypeMeta::Id<int8::Int8TensorCPU>()),
    int8::Int8TensorCPUSerializer);
REGISTER_BLOB_DESERIALIZER(Int8TensorCPU, int8::Int8TensorCPUDeserializer);
} // namespace
|
|
|
|
} // namespace caffe2
|