Files
pytorch/caffe2/core/int8_serialization.cc
Michael Antonov a6949abb15 Guard all Caffe2 protobuf string serializations with CAFFE_ENFORCE (fixed reverted bug) (#12848)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/12848

Updated all non-test uses of protobuf::MessageLite::SerializeAsString to call
SerializeAsString_EnforceCheck so that the return value is checked and can
throw an exception if failing.

Most of the affected code was called from classes derived from BlobSerializerBase.
Most tests and ENFORCE calls were left untouched, because they usually perform
checks anyway.

Original commit changeset: c0760e73ecc7

Reviewed By: dzhulgakov

Differential Revision: D10453456

fbshipit-source-id: d2f2b7b4578e721924354149f08f627c7e3bf070
2018-10-23 16:21:26 -07:00

107 lines
3.1 KiB
C++

#include "caffe2/core/blob_serialization.h"
#include "caffe2/core/common.h"
#include "caffe2/core/context.h"
#include "caffe2/core/tensor_int8.h"
#include "caffe2/core/typeid.h"
#include "caffe2/core/types.h"
namespace caffe2 {
namespace int8 {
class Int8TensorCPUSerializer : public BlobSerializerBase {
public:
void Serialize(
const void* pointer,
TypeMeta typeMeta,
const string& name,
SerializationAcceptor acceptor) override {
CAFFE_ENFORCE(typeMeta.Match<Int8TensorCPU>());
const auto& tensor = *static_cast<const Int8TensorCPU*>(pointer);
BlobProto blob_proto;
blob_proto.set_name(name);
blob_proto.set_type("Int8TensorCPU");
QTensorProto& proto = *blob_proto.mutable_qtensor();
proto.set_name(name);
for (int i = 0; i < tensor.t.ndim(); ++i) {
proto.add_dims(tensor.t.dim32(i));
}
proto.set_precision(8);
proto.set_scale(tensor.scale);
proto.set_bias(tensor.zero_point);
proto.set_is_signed(false);
const TensorProto::DataType data_type = TypeMetaToDataType(tensor.t.meta());
proto.set_data_type(data_type);
switch (data_type) {
case TensorProto_DataType_INT32:
detail::CopyToProtoAsIs(
tensor.t.size(),
tensor.t.template data<int32_t>(),
proto.mutable_data(),
&this->context_);
break;
case TensorProto_DataType_UINT8:
detail::CopyToProtoWithCast(
tensor.t.size(),
tensor.t.template data<uint8_t>(),
proto.mutable_data(),
&this->context_);
break;
default:
CAFFE_ENFORCE(false, "Unsupported data type in Int8TensorCPU");
}
acceptor(name, SerializeBlobProtoAsString_EnforceCheck(blob_proto));
}
private:
CPUContext context_;
};
class Int8TensorCPUDeserializer : public TensorDeserializer {
public:
void Deserialize(const BlobProto& blob_proto, Blob* blob) override {
const QTensorProto& proto = blob_proto.qtensor();
Int8TensorCPU* tensor = blob->template GetMutable<Int8TensorCPU>();
tensor->scale = proto.scale();
tensor->zero_point = proto.bias();
vector<int> dims;
for (const int d : proto.dims()) {
dims.push_back(d);
}
tensor->t.Resize(dims);
switch (proto.data_type()) {
case TensorProto_DataType_INT32:
detail::CopyFromProtoAsIs(
tensor->t.size(),
proto.data(),
tensor->t.template mutable_data<int32_t>(),
&this->context_);
break;
case TensorProto_DataType_UINT8:
detail::CopyFromProtoWithCast(
tensor->t.size(),
proto.data(),
tensor->t.template mutable_data<uint8_t>(),
&this->context_);
break;
default:
CAFFE_ENFORCE(false, "Unsupported data type in Int8TensorCPU");
}
}
private:
CPUContext context_;
};
} // namespace int8
namespace {
// Registers the Int8TensorCPU serializer and deserializer defined in this
// file with Caffe2's blob (de)serialization registries.
//
// The serializer is keyed by the runtime TypeMeta id of int8::Int8TensorCPU.
REGISTER_BLOB_SERIALIZER(
(TypeMeta::Id<int8::Int8TensorCPU>()),
int8::Int8TensorCPUSerializer);
// The deserializer is registered under the name Int8TensorCPU. It is
// presumably matched against BlobProto::type(), which the serializer in this
// file sets to "Int8TensorCPU"; confirm against the REGISTER_BLOB_DESERIALIZER
// definition.
REGISTER_BLOB_DESERIALIZER(Int8TensorCPU, int8::Int8TensorCPUDeserializer);
} // namespace
} // namespace caffe2