mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Guard all Caffe2 protobuf string serializations with CAFFE_ENFORCE (fixed reverted bug) (#12848)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/12848 Updated all non-test uses of protobuf::MessageLite::SerializeAsString to call SerializeAsString_EnforceCheck so that the return value is checked and can throw an exception if failing. Most of the affected code was called from classes derived from BlobSerializeBase. Didn't touch most tests and ENFORCE calls because they usually do checks anyway. Original commit changeset: c0760e73ecc7 Reviewed By: dzhulgakov Differential Revision: D10453456 fbshipit-source-id: d2f2b7b4578e721924354149f08f627c7e3bf070
This commit is contained in:
committed by
Facebook Github Bot
parent
dd00c2997f
commit
a6949abb15
@ -79,7 +79,7 @@ int main(int argc, char** argv) {
|
||||
data->add_dims(datum.channels());
|
||||
data->set_byte_data(buffer, datum.data().size());
|
||||
}
|
||||
transaction->Put(cursor->key(), protos.SerializeAsString());
|
||||
transaction->Put(cursor->key(), SerializeAsString_EnforceCheck(protos));
|
||||
if (++count % FLAGS_batch_size == 0) {
|
||||
transaction->Commit();
|
||||
LOG(INFO) << "Converted " << count << " items so far.";
|
||||
@ -88,4 +88,3 @@ int main(int argc, char** argv) {
|
||||
LOG(INFO) << "A total of " << count << " items processed.";
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
@ -47,7 +47,7 @@ class StringSerializer : public BlobSerializerBase {
|
||||
blob_proto.set_name(name);
|
||||
blob_proto.set_type("std::string");
|
||||
blob_proto.set_content(*static_cast<const std::string*>(pointer));
|
||||
acceptor(name, blob_proto.SerializeAsString());
|
||||
acceptor(name, SerializeBlobProtoAsString_EnforceCheck(blob_proto));
|
||||
}
|
||||
};
|
||||
|
||||
@ -134,7 +134,7 @@ void TensorSerializer::SerializeWithChunkSize(
|
||||
tensor, name, blob_proto.mutable_tensor(), chunkStart, chunk_size);
|
||||
acceptor(
|
||||
c10::str(name, kChunkIdSeparator, chunkStart / chunk_size),
|
||||
blob_proto.SerializeAsString());
|
||||
SerializeBlobProtoAsString_EnforceCheck(blob_proto));
|
||||
};
|
||||
|
||||
#ifndef __ANDROID__
|
||||
@ -543,6 +543,25 @@ void TensorDeserializer::Deserialize(const TensorProto& proto, Tensor* tensor) {
|
||||
context->FinishDeviceComputation();
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Serialization Helpers
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
std::string SerializeAsString_EnforceCheck(
|
||||
const google::protobuf::MessageLite& msg,
|
||||
const char* error_location) {
|
||||
std::string serialize_output;
|
||||
bool result = msg.SerializeToString(&serialize_output);
|
||||
if (!error_location) {
|
||||
CAFFE_ENFORCE(result, "protobuf::SerializeToString failed");
|
||||
} else {
|
||||
CAFFE_ENFORCE(result,
|
||||
"protobuf::SerializeToString failed for ", error_location);
|
||||
}
|
||||
return serialize_output;
|
||||
}
|
||||
|
||||
|
||||
namespace {
|
||||
// Serialize Tensor
|
||||
REGISTER_BLOB_SERIALIZER((TypeMeta::Id<Tensor>()), TensorSerializer);
|
||||
|
@ -184,6 +184,24 @@ inline void CopyFromProtoWithCast(
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Serialization Helpers
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Converts MessageLite to string while also checking that SerializeAsString
|
||||
// succeeds. Pass description of class/function of the call if you'd
|
||||
// like it appended to the error message.
|
||||
CAFFE2_API std::string SerializeAsString_EnforceCheck(
|
||||
const google::protobuf::MessageLite&,
|
||||
const char* error_location = nullptr);
|
||||
|
||||
// Convert BlobProto to string with success checks.
|
||||
inline std::string SerializeBlobProtoAsString_EnforceCheck(
|
||||
const BlobProto& blob) {
|
||||
return SerializeAsString_EnforceCheck(blob, blob.name().c_str());
|
||||
}
|
||||
|
||||
} // namespace caffe2
|
||||
|
||||
#endif // CAFFE2_CORE_BLOB_SERIALIZATION_H_
|
||||
|
@ -65,7 +65,7 @@ class BlobTestFooSerializer : public BlobSerializerBase {
|
||||
reinterpret_cast<const char*>(
|
||||
&static_cast<const BlobTestFoo*>(pointer)->val),
|
||||
sizeof(int32_t)));
|
||||
acceptor(name, blob_proto.SerializeAsString());
|
||||
acceptor(name, SerializeBlobProtoAsString_EnforceCheck(blob_proto));
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -186,8 +186,8 @@ void DBReaderSerializer::Serialize(
|
||||
BlobProto blob_proto;
|
||||
blob_proto.set_name(name);
|
||||
blob_proto.set_type("DBReader");
|
||||
blob_proto.set_content(proto.SerializeAsString());
|
||||
acceptor(name, blob_proto.SerializeAsString());
|
||||
blob_proto.set_content(SerializeAsString_EnforceCheck(proto));
|
||||
acceptor(name, SerializeBlobProtoAsString_EnforceCheck(blob_proto));
|
||||
}
|
||||
|
||||
void DBReaderDeserializer::Deserialize(const BlobProto& proto, Blob* blob) {
|
||||
|
@ -51,7 +51,7 @@ class Int8TensorCPUSerializer : public BlobSerializerBase {
|
||||
CAFFE_ENFORCE(false, "Unsupported data type in Int8TensorCPU");
|
||||
}
|
||||
|
||||
acceptor(name, blob_proto.SerializeAsString());
|
||||
acceptor(name, SerializeBlobProtoAsString_EnforceCheck(blob_proto));
|
||||
}
|
||||
|
||||
private:
|
||||
|
@ -55,7 +55,7 @@ void QTensorSerializer<Context>::Serialize(
|
||||
proto.set_is_signed(qtensor.is_signed());
|
||||
detail::CopyToProtoWithCast(
|
||||
qtensor.nbytes(), qtensor.data(), proto.mutable_data(), &this->context_);
|
||||
acceptor(name, blob_proto.SerializeAsString());
|
||||
acceptor(name, SerializeBlobProtoAsString_EnforceCheck(blob_proto));
|
||||
}
|
||||
|
||||
template <class Context>
|
||||
|
@ -20,7 +20,10 @@ class ProtoDBCursor : public Cursor {
|
||||
void SeekToFirst() override { iter_ = 0; }
|
||||
void Next() override { ++iter_; }
|
||||
string key() override { return proto_->protos(iter_).name(); }
|
||||
string value() override { return proto_->protos(iter_).SerializeAsString(); }
|
||||
string value() override {
|
||||
return
|
||||
SerializeAsString_EnforceCheck(proto_->protos(iter_), "ProtoDBCursor");
|
||||
}
|
||||
bool Valid() override { return iter_ < proto_->protos_size(); }
|
||||
|
||||
private:
|
||||
|
@ -155,7 +155,7 @@ class CounterSerializer : public BlobSerializerBase {
|
||||
proto.add_int64_data(
|
||||
(*static_cast<const std::unique_ptr<Counter<int64_t>>*>(pointer))
|
||||
->retrieve());
|
||||
acceptor(name, blob_proto.SerializeAsString());
|
||||
acceptor(name, SerializeBlobProtoAsString_EnforceCheck(blob_proto));
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -1451,7 +1451,7 @@ class TreeCursorSerializer : public BlobSerializerBase {
|
||||
}
|
||||
blob_proto.set_content(os.str());
|
||||
|
||||
acceptor(name, blob_proto.SerializeAsString());
|
||||
acceptor(name, SerializeBlobProtoAsString_EnforceCheck(blob_proto));
|
||||
}
|
||||
};
|
||||
|
||||
@ -1513,7 +1513,7 @@ void SharedTensorVectorPtrSerializer::Serialize(
|
||||
blob_proto.set_name(name);
|
||||
blob_proto.set_type("std::shared_ptr<std::vector<TensorCPU>>");
|
||||
blob_proto.set_content("");
|
||||
acceptor(name, blob_proto.SerializeAsString());
|
||||
acceptor(name, SerializeBlobProtoAsString_EnforceCheck(blob_proto));
|
||||
};
|
||||
|
||||
void SharedTensorVectorPtrDeserializer::Deserialize(
|
||||
|
@ -381,7 +381,7 @@ class IndexSerializer : public BlobSerializerBase {
|
||||
os << base->maxElements() << " " << base->isFrozen();
|
||||
blob_proto.set_content(os.str());
|
||||
|
||||
acceptor(name, blob_proto.SerializeAsString());
|
||||
acceptor(name, SerializeBlobProtoAsString_EnforceCheck(blob_proto));
|
||||
}
|
||||
|
||||
private:
|
||||
|
@ -225,8 +225,8 @@ class MapSerializer : public BlobSerializerBase {
|
||||
BlobProto blob_proto;
|
||||
blob_proto.set_name(name);
|
||||
blob_proto.set_type(MapTypeTraits<KEY_T, VALUE_T>::MapTypeName());
|
||||
blob_proto.set_content(tensor_protos.SerializeAsString());
|
||||
acceptor(name, blob_proto.SerializeAsString());
|
||||
blob_proto.set_content(SerializeAsString_EnforceCheck(tensor_protos));
|
||||
acceptor(name, SerializeBlobProtoAsString_EnforceCheck(blob_proto));
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -586,7 +586,8 @@ void addObjectMethods(py::module& m) {
|
||||
const auto& meta = GetGradientForOp(def, output_gradients);
|
||||
std::vector<py::bytes> grad_ops;
|
||||
for (const auto& op : meta.ops_) {
|
||||
grad_ops.push_back(op.SerializeAsString());
|
||||
grad_ops.push_back(
|
||||
SerializeAsString_EnforceCheck(op, "addObjectMethods"));
|
||||
}
|
||||
return std::pair<std::vector<py::bytes>, std::vector<GradientWrapper>>{
|
||||
grad_ops, meta.g_input_};
|
||||
|
@ -17,7 +17,7 @@ void MutexSerializer::Serialize(
|
||||
blob_proto.set_name(name);
|
||||
blob_proto.set_type("std::unique_ptr<std::mutex>");
|
||||
blob_proto.set_content("");
|
||||
acceptor(name, blob_proto.SerializeAsString());
|
||||
acceptor(name, SerializeBlobProtoAsString_EnforceCheck(blob_proto));
|
||||
}
|
||||
|
||||
void MutexDeserializer::Deserialize(const BlobProto& /* unused */, Blob* blob) {
|
||||
|
Reference in New Issue
Block a user