Guard all Caffe2 protobuf string serializations with CAFFE_ENFORCE (fixed reverted bug) (#12848)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/12848

Updated all non-test uses of protobuf::MessageLite::SerializeAsString to call
SerializeAsString_EnforceCheck, so that the serialization result is checked and
an exception is thrown if serialization fails.

Most of the affected code is called from classes derived from BlobSerializerBase.
Most tests and ENFORCE calls were left untouched because they usually perform
checks anyway.

Original commit changeset: c0760e73ecc7

Reviewed By: dzhulgakov

Differential Revision: D10453456

fbshipit-source-id: d2f2b7b4578e721924354149f08f627c7e3bf070
This commit is contained in:
Michael Antonov
2018-10-23 16:19:23 -07:00
committed by Facebook Github Bot
parent dd00c2997f
commit a6949abb15
14 changed files with 58 additions and 18 deletions

View File

@ -79,7 +79,7 @@ int main(int argc, char** argv) {
data->add_dims(datum.channels());
data->set_byte_data(buffer, datum.data().size());
}
transaction->Put(cursor->key(), protos.SerializeAsString());
transaction->Put(cursor->key(), SerializeAsString_EnforceCheck(protos));
if (++count % FLAGS_batch_size == 0) {
transaction->Commit();
LOG(INFO) << "Converted " << count << " items so far.";
@ -88,4 +88,3 @@ int main(int argc, char** argv) {
LOG(INFO) << "A total of " << count << " items processed.";
return 0;
}

View File

@ -47,7 +47,7 @@ class StringSerializer : public BlobSerializerBase {
blob_proto.set_name(name);
blob_proto.set_type("std::string");
blob_proto.set_content(*static_cast<const std::string*>(pointer));
acceptor(name, blob_proto.SerializeAsString());
acceptor(name, SerializeBlobProtoAsString_EnforceCheck(blob_proto));
}
};
@ -134,7 +134,7 @@ void TensorSerializer::SerializeWithChunkSize(
tensor, name, blob_proto.mutable_tensor(), chunkStart, chunk_size);
acceptor(
c10::str(name, kChunkIdSeparator, chunkStart / chunk_size),
blob_proto.SerializeAsString());
SerializeBlobProtoAsString_EnforceCheck(blob_proto));
};
#ifndef __ANDROID__
@ -543,6 +543,25 @@ void TensorDeserializer::Deserialize(const TensorProto& proto, Tensor* tensor) {
context->FinishDeviceComputation();
}
////////////////////////////////////////////////////////////////////////////////
// Serialization Helpers
////////////////////////////////////////////////////////////////////////////////
// Serializes `msg` into a string and raises an enforce failure if protobuf
// reports that serialization did not succeed. When `error_location` is
// non-null it is appended to the failure message to identify the caller.
std::string SerializeAsString_EnforceCheck(
    const google::protobuf::MessageLite& msg,
    const char* error_location) {
  std::string serialized;
  // NOTE: the flag keeps the name `result` because CAFFE_ENFORCE embeds the
  // text of its condition in the failure message.
  const bool result = msg.SerializeToString(&serialized);
  if (error_location != nullptr) {
    CAFFE_ENFORCE(
        result, "protobuf::SerializeToString failed for ", error_location);
  } else {
    CAFFE_ENFORCE(result, "protobuf::SerializeToString failed");
  }
  return serialized;
}
namespace {
// Serialize Tensor
REGISTER_BLOB_SERIALIZER((TypeMeta::Id<Tensor>()), TensorSerializer);

View File

@ -184,6 +184,24 @@ inline void CopyFromProtoWithCast(
}
} // namespace detail
////////////////////////////////////////////////////////////////////////////////
// Serialization Helpers
////////////////////////////////////////////////////////////////////////////////
// Serializes a MessageLite to a string and enforces that serialization
// succeeded: the implementation calls SerializeToString and raises a
// CAFFE_ENFORCE failure if it reports an error.
//
// error_location is an optional, caller-supplied description (for example the
// class or function making the call). When it is non-null it is appended to
// the failure message; when it is nullptr (the default) a generic message is
// used.
//
// Returns the serialized bytes of the message.
CAFFE2_API std::string SerializeAsString_EnforceCheck(
const google::protobuf::MessageLite&,
const char* error_location = nullptr);
// Serializes a BlobProto to a string through SerializeAsString_EnforceCheck,
// passing the blob's own name as the location reported on failure.
inline std::string SerializeBlobProtoAsString_EnforceCheck(
    const BlobProto& blob) {
  const auto& blob_name = blob.name();
  return SerializeAsString_EnforceCheck(blob, blob_name.c_str());
}
} // namespace caffe2
#endif // CAFFE2_CORE_BLOB_SERIALIZATION_H_

View File

@ -65,7 +65,7 @@ class BlobTestFooSerializer : public BlobSerializerBase {
reinterpret_cast<const char*>(
&static_cast<const BlobTestFoo*>(pointer)->val),
sizeof(int32_t)));
acceptor(name, blob_proto.SerializeAsString());
acceptor(name, SerializeBlobProtoAsString_EnforceCheck(blob_proto));
}
};

View File

@ -186,8 +186,8 @@ void DBReaderSerializer::Serialize(
BlobProto blob_proto;
blob_proto.set_name(name);
blob_proto.set_type("DBReader");
blob_proto.set_content(proto.SerializeAsString());
acceptor(name, blob_proto.SerializeAsString());
blob_proto.set_content(SerializeAsString_EnforceCheck(proto));
acceptor(name, SerializeBlobProtoAsString_EnforceCheck(blob_proto));
}
void DBReaderDeserializer::Deserialize(const BlobProto& proto, Blob* blob) {

View File

@ -51,7 +51,7 @@ class Int8TensorCPUSerializer : public BlobSerializerBase {
CAFFE_ENFORCE(false, "Unsupported data type in Int8TensorCPU");
}
acceptor(name, blob_proto.SerializeAsString());
acceptor(name, SerializeBlobProtoAsString_EnforceCheck(blob_proto));
}
private:

View File

@ -55,7 +55,7 @@ void QTensorSerializer<Context>::Serialize(
proto.set_is_signed(qtensor.is_signed());
detail::CopyToProtoWithCast(
qtensor.nbytes(), qtensor.data(), proto.mutable_data(), &this->context_);
acceptor(name, blob_proto.SerializeAsString());
acceptor(name, SerializeBlobProtoAsString_EnforceCheck(blob_proto));
}
template <class Context>

View File

@ -20,7 +20,10 @@ class ProtoDBCursor : public Cursor {
void SeekToFirst() override { iter_ = 0; }
void Next() override { ++iter_; }
string key() override { return proto_->protos(iter_).name(); }
string value() override { return proto_->protos(iter_).SerializeAsString(); }
// Returns the serialized form of the proto at the current cursor position.
// A serialization failure is reported with the "ProtoDBCursor" location tag.
string value() override {
  const auto& current = proto_->protos(iter_);
  return SerializeAsString_EnforceCheck(current, "ProtoDBCursor");
}
bool Valid() override { return iter_ < proto_->protos_size(); }
private:

View File

@ -155,7 +155,7 @@ class CounterSerializer : public BlobSerializerBase {
proto.add_int64_data(
(*static_cast<const std::unique_ptr<Counter<int64_t>>*>(pointer))
->retrieve());
acceptor(name, blob_proto.SerializeAsString());
acceptor(name, SerializeBlobProtoAsString_EnforceCheck(blob_proto));
}
};

View File

@ -1451,7 +1451,7 @@ class TreeCursorSerializer : public BlobSerializerBase {
}
blob_proto.set_content(os.str());
acceptor(name, blob_proto.SerializeAsString());
acceptor(name, SerializeBlobProtoAsString_EnforceCheck(blob_proto));
}
};
@ -1513,7 +1513,7 @@ void SharedTensorVectorPtrSerializer::Serialize(
blob_proto.set_name(name);
blob_proto.set_type("std::shared_ptr<std::vector<TensorCPU>>");
blob_proto.set_content("");
acceptor(name, blob_proto.SerializeAsString());
acceptor(name, SerializeBlobProtoAsString_EnforceCheck(blob_proto));
};
void SharedTensorVectorPtrDeserializer::Deserialize(

View File

@ -381,7 +381,7 @@ class IndexSerializer : public BlobSerializerBase {
os << base->maxElements() << " " << base->isFrozen();
blob_proto.set_content(os.str());
acceptor(name, blob_proto.SerializeAsString());
acceptor(name, SerializeBlobProtoAsString_EnforceCheck(blob_proto));
}
private:

View File

@ -225,8 +225,8 @@ class MapSerializer : public BlobSerializerBase {
BlobProto blob_proto;
blob_proto.set_name(name);
blob_proto.set_type(MapTypeTraits<KEY_T, VALUE_T>::MapTypeName());
blob_proto.set_content(tensor_protos.SerializeAsString());
acceptor(name, blob_proto.SerializeAsString());
blob_proto.set_content(SerializeAsString_EnforceCheck(tensor_protos));
acceptor(name, SerializeBlobProtoAsString_EnforceCheck(blob_proto));
}
};

View File

@ -586,7 +586,8 @@ void addObjectMethods(py::module& m) {
const auto& meta = GetGradientForOp(def, output_gradients);
std::vector<py::bytes> grad_ops;
for (const auto& op : meta.ops_) {
grad_ops.push_back(op.SerializeAsString());
grad_ops.push_back(
SerializeAsString_EnforceCheck(op, "addObjectMethods"));
}
return std::pair<std::vector<py::bytes>, std::vector<GradientWrapper>>{
grad_ops, meta.g_input_};

View File

@ -17,7 +17,7 @@ void MutexSerializer::Serialize(
blob_proto.set_name(name);
blob_proto.set_type("std::unique_ptr<std::mutex>");
blob_proto.set_content("");
acceptor(name, blob_proto.SerializeAsString());
acceptor(name, SerializeBlobProtoAsString_EnforceCheck(blob_proto));
}
void MutexDeserializer::Deserialize(const BlobProto& /* unused */, Blob* blob) {