diff --git a/binaries/convert_caffe_image_db.cc b/binaries/convert_caffe_image_db.cc index de7efbf65b24..dca13d6e9737 100644 --- a/binaries/convert_caffe_image_db.cc +++ b/binaries/convert_caffe_image_db.cc @@ -79,7 +79,7 @@ int main(int argc, char** argv) { data->add_dims(datum.channels()); data->set_byte_data(buffer, datum.data().size()); } - transaction->Put(cursor->key(), protos.SerializeAsString()); + transaction->Put(cursor->key(), SerializeAsString_EnforceCheck(protos)); if (++count % FLAGS_batch_size == 0) { transaction->Commit(); LOG(INFO) << "Converted " << count << " items so far."; @@ -88,4 +88,3 @@ int main(int argc, char** argv) { LOG(INFO) << "A total of " << count << " items processed."; return 0; } - diff --git a/caffe2/core/blob_serialization.cc b/caffe2/core/blob_serialization.cc index 281b5bd4f782..82dedf63d652 100644 --- a/caffe2/core/blob_serialization.cc +++ b/caffe2/core/blob_serialization.cc @@ -47,7 +47,7 @@ class StringSerializer : public BlobSerializerBase { blob_proto.set_name(name); blob_proto.set_type("std::string"); blob_proto.set_content(*static_cast(pointer)); - acceptor(name, blob_proto.SerializeAsString()); + acceptor(name, SerializeBlobProtoAsString_EnforceCheck(blob_proto)); } }; @@ -134,7 +134,7 @@ void TensorSerializer::SerializeWithChunkSize( tensor, name, blob_proto.mutable_tensor(), chunkStart, chunk_size); acceptor( c10::str(name, kChunkIdSeparator, chunkStart / chunk_size), - blob_proto.SerializeAsString()); + SerializeBlobProtoAsString_EnforceCheck(blob_proto)); }; #ifndef __ANDROID__ @@ -543,6 +543,25 @@ void TensorDeserializer::Deserialize(const TensorProto& proto, Tensor* tensor) { context->FinishDeviceComputation(); } +//////////////////////////////////////////////////////////////////////////////// +// Serialization Helpers +//////////////////////////////////////////////////////////////////////////////// + +std::string SerializeAsString_EnforceCheck( + const google::protobuf::MessageLite& msg, + const char* error_location) { + std::string serialize_output; + bool result = msg.SerializeToString(&serialize_output); + if (!error_location) { + CAFFE_ENFORCE(result, "protobuf::SerializeToString failed"); + } else { + CAFFE_ENFORCE(result, + "protobuf::SerializeToString failed for ", error_location); + } + return serialize_output; +} + + namespace { // Serialize Tensor REGISTER_BLOB_SERIALIZER((TypeMeta::Id()), TensorSerializer); diff --git a/caffe2/core/blob_serialization.h b/caffe2/core/blob_serialization.h index 2c2d59071158..ccb86f5dae52 100644 --- a/caffe2/core/blob_serialization.h +++ b/caffe2/core/blob_serialization.h @@ -184,6 +184,24 @@ inline void CopyFromProtoWithCast( } } // namespace detail + +//////////////////////////////////////////////////////////////////////////////// +// Serialization Helpers +//////////////////////////////////////////////////////////////////////////////// + +// Converts MessageLite to string while also checking that SerializeAsString +// succeeds. Pass description of class/function of the call if you'd +// like it appended to the error message. +CAFFE2_API std::string SerializeAsString_EnforceCheck( + const google::protobuf::MessageLite&, + const char* error_location = nullptr); + +// Convert BlobProto to string with success checks. +inline std::string SerializeBlobProtoAsString_EnforceCheck( + const BlobProto& blob) { + return SerializeAsString_EnforceCheck(blob, blob.name().c_str()); +} + } // namespace caffe2 #endif // CAFFE2_CORE_BLOB_SERIALIZATION_H_ diff --git a/caffe2/core/blob_test.cc b/caffe2/core/blob_test.cc index c65b860bcb41..662eab767378 100644 --- a/caffe2/core/blob_test.cc +++ b/caffe2/core/blob_test.cc @@ -65,7 +65,7 @@ class BlobTestFooSerializer : public BlobSerializerBase { reinterpret_cast( &static_cast(pointer)->val), sizeof(int32_t))); - acceptor(name, blob_proto.SerializeAsString()); + acceptor(name, SerializeBlobProtoAsString_EnforceCheck(blob_proto)); } }; diff --git a/caffe2/core/db.cc b/caffe2/core/db.cc index 67b0f1ffe2c1..16c6509299ea 100644 --- a/caffe2/core/db.cc +++ b/caffe2/core/db.cc @@ -186,8 +186,8 @@ void DBReaderSerializer::Serialize( BlobProto blob_proto; blob_proto.set_name(name); blob_proto.set_type("DBReader"); - blob_proto.set_content(proto.SerializeAsString()); - acceptor(name, blob_proto.SerializeAsString()); + blob_proto.set_content(SerializeAsString_EnforceCheck(proto)); + acceptor(name, SerializeBlobProtoAsString_EnforceCheck(blob_proto)); } void DBReaderDeserializer::Deserialize(const BlobProto& proto, Blob* blob) { diff --git a/caffe2/core/int8_serialization.cc b/caffe2/core/int8_serialization.cc index 7a18e16a2b30..dc22b12a9913 100644 --- a/caffe2/core/int8_serialization.cc +++ b/caffe2/core/int8_serialization.cc @@ -51,7 +51,7 @@ class Int8TensorCPUSerializer : public BlobSerializerBase { CAFFE_ENFORCE(false, "Unsupported data type in Int8TensorCPU"); } - acceptor(name, blob_proto.SerializeAsString()); + acceptor(name, SerializeBlobProtoAsString_EnforceCheck(blob_proto)); } private: diff --git a/caffe2/core/qtensor_serialization.h b/caffe2/core/qtensor_serialization.h index d9881030f1b8..007174368a44 100644 --- a/caffe2/core/qtensor_serialization.h +++ b/caffe2/core/qtensor_serialization.h @@ -55,7 +55,7 @@ void QTensorSerializer::Serialize( proto.set_is_signed(qtensor.is_signed()); detail::CopyToProtoWithCast( qtensor.nbytes(), qtensor.data(), proto.mutable_data(), &this->context_); - acceptor(name, blob_proto.SerializeAsString()); + acceptor(name, SerializeBlobProtoAsString_EnforceCheck(blob_proto)); } template diff --git a/caffe2/db/protodb.cc b/caffe2/db/protodb.cc index fdaaaf57f171..68b74724a7cd 100644 --- a/caffe2/db/protodb.cc +++ b/caffe2/db/protodb.cc @@ -20,7 +20,10 @@ class ProtoDBCursor : public Cursor { void SeekToFirst() override { iter_ = 0; } void Next() override { ++iter_; } string key() override { return proto_->protos(iter_).name(); } - string value() override { return proto_->protos(iter_).SerializeAsString(); } + string value() override { + return + SerializeAsString_EnforceCheck(proto_->protos(iter_), "ProtoDBCursor"); + } bool Valid() override { return iter_ < proto_->protos_size(); } private: diff --git a/caffe2/operators/counter_ops.cc b/caffe2/operators/counter_ops.cc index 79a6b51057e4..2a2278c3132c 100644 --- a/caffe2/operators/counter_ops.cc +++ b/caffe2/operators/counter_ops.cc @@ -155,7 +155,7 @@ class CounterSerializer : public BlobSerializerBase { proto.add_int64_data( (*static_cast>*>(pointer)) ->retrieve()); - acceptor(name, blob_proto.SerializeAsString()); + acceptor(name, SerializeBlobProtoAsString_EnforceCheck(blob_proto)); } }; diff --git a/caffe2/operators/dataset_ops.cc b/caffe2/operators/dataset_ops.cc index 87ed0433c2c5..b0a34f813b5f 100644 --- a/caffe2/operators/dataset_ops.cc +++ b/caffe2/operators/dataset_ops.cc @@ -1451,7 +1451,7 @@ class TreeCursorSerializer : public BlobSerializerBase { } blob_proto.set_content(os.str()); - acceptor(name, blob_proto.SerializeAsString()); + acceptor(name, SerializeBlobProtoAsString_EnforceCheck(blob_proto)); } }; @@ -1513,7 +1513,7 @@ void SharedTensorVectorPtrSerializer::Serialize( blob_proto.set_name(name); blob_proto.set_type("std::shared_ptr>"); blob_proto.set_content(""); - acceptor(name, blob_proto.SerializeAsString()); + acceptor(name, SerializeBlobProtoAsString_EnforceCheck(blob_proto)); }; void SharedTensorVectorPtrDeserializer::Deserialize( diff --git a/caffe2/operators/index_ops.cc b/caffe2/operators/index_ops.cc index b6da99e99e08..5c6488d59eac 100644 --- a/caffe2/operators/index_ops.cc +++ b/caffe2/operators/index_ops.cc @@ -381,7 +381,7 @@ class IndexSerializer : public BlobSerializerBase { os << base->maxElements() << " " << base->isFrozen(); blob_proto.set_content(os.str()); - acceptor(name, blob_proto.SerializeAsString()); + acceptor(name, SerializeBlobProtoAsString_EnforceCheck(blob_proto)); } private: diff --git a/caffe2/operators/map_ops.h b/caffe2/operators/map_ops.h index 7b64808709ee..fa2c4a865bde 100644 --- a/caffe2/operators/map_ops.h +++ b/caffe2/operators/map_ops.h @@ -225,8 +225,8 @@ class MapSerializer : public BlobSerializerBase { BlobProto blob_proto; blob_proto.set_name(name); blob_proto.set_type(MapTypeTraits::MapTypeName()); - blob_proto.set_content(tensor_protos.SerializeAsString()); - acceptor(name, blob_proto.SerializeAsString()); + blob_proto.set_content(SerializeAsString_EnforceCheck(tensor_protos)); + acceptor(name, SerializeBlobProtoAsString_EnforceCheck(blob_proto)); } }; diff --git a/caffe2/python/pybind_state.cc b/caffe2/python/pybind_state.cc index cd8549f86682..78fc118d6fd6 100644 --- a/caffe2/python/pybind_state.cc +++ b/caffe2/python/pybind_state.cc @@ -586,7 +586,8 @@ void addObjectMethods(py::module& m) { const auto& meta = GetGradientForOp(def, output_gradients); std::vector grad_ops; for (const auto& op : meta.ops_) { - grad_ops.push_back(op.SerializeAsString()); + grad_ops.push_back( + SerializeAsString_EnforceCheck(op, "addObjectMethods")); } return std::pair, std::vector>{ grad_ops, meta.g_input_}; diff --git a/caffe2/sgd/iter_op.cc b/caffe2/sgd/iter_op.cc index 222b20fb6f62..8b851e0ba151 100644 --- a/caffe2/sgd/iter_op.cc +++ b/caffe2/sgd/iter_op.cc @@ -17,7 +17,7 @@ void MutexSerializer::Serialize( blob_proto.set_name(name); blob_proto.set_type("std::unique_ptr"); blob_proto.set_content(""); - acceptor(name, blob_proto.SerializeAsString()); + acceptor(name, SerializeBlobProtoAsString_EnforceCheck(blob_proto)); } void MutexDeserializer::Deserialize(const BlobProto& /* unused */, Blob* blob) {