mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
118 lines
4.0 KiB
C++
118 lines
4.0 KiB
C++
#ifndef CAFFE2_UTILS_PROTO_UTILS_H_
|
|
#define CAFFE2_UTILS_PROTO_UTILS_H_
|
|
|
|
#include "caffe2/proto/caffe2.pb.h"
|
|
#include "google/protobuf/message.h"
|
|
#include "glog/logging.h"
|
|
|
|
namespace caffe2 {
|
|
|
|
using std::string;
|
|
using ::google::protobuf::Message;
|
|
using ::google::protobuf::MessageLite;
|
|
using std::string;
|
|
|
|
bool ReadProtoFromTextFile(const char* filename, Message* proto);
|
|
inline bool ReadProtoFromTextFile(const string filename, Message* proto) {
|
|
return ReadProtoFromTextFile(filename.c_str(), proto);
|
|
}
|
|
|
|
void WriteProtoToTextFile(const Message& proto, const char* filename);
|
|
inline void WriteProtoToTextFile(const Message& proto, const string& filename) {
|
|
return WriteProtoToTextFile(proto, filename.c_str());
|
|
}
|
|
|
|
// Text format MessageLite wrappers: these functions do nothing but just
|
|
// allowing things to compile. It will produce a runtime error if you are using
|
|
// MessageLite but still want text support.
|
|
inline bool ReadProtoFromTextFile(const char* filename, MessageLite* proto) {
|
|
LOG(FATAL) << "If you are running lite version, you should not be "
|
|
<< "calling any text-format protobuffers.";
|
|
return false; // Just to suppress compiler warning.
|
|
}
|
|
inline bool ReadProtoFromTextFile(const string filename, MessageLite* proto) {
|
|
return ReadProtoFromTextFile(filename.c_str(), proto);
|
|
}
|
|
|
|
inline void WriteProtoToTextFile(const MessageLite& proto,
|
|
const char* filename) {
|
|
LOG(FATAL) << "If you are running lite version, you should not be "
|
|
<< "calling any text-format protobuffers.";
|
|
}
|
|
inline void WriteProtoToTextFile(const MessageLite& proto,
|
|
const string& filename) {
|
|
return WriteProtoToTextFile(proto, filename.c_str());
|
|
}
|
|
|
|
bool ReadProtoFromBinaryFile(const char* filename, MessageLite* proto);
|
|
inline bool ReadProtoFromBinaryFile(const string filename, MessageLite* proto) {
|
|
return ReadProtoFromBinaryFile(filename.c_str(), proto);
|
|
}
|
|
|
|
void WriteProtoToBinaryFile(const MessageLite& proto, const char* filename);
|
|
inline void WriteProtoToBinaryFile(const MessageLite& proto,
|
|
const string& filename) {
|
|
return WriteProtoToBinaryFile(proto, filename.c_str());
|
|
}
|
|
|
|
// Read Proto from a file, letting the code figure out if it is text or binary.
|
|
inline bool ReadProtoFromFile(const char* filename, Message* proto) {
|
|
return (ReadProtoFromBinaryFile(filename, proto) ||
|
|
ReadProtoFromTextFile(filename, proto));
|
|
}
|
|
inline bool ReadProtoFromFile(const string& filename, Message* proto) {
|
|
return ReadProtoFromFile(filename.c_str(), proto);
|
|
}
|
|
|
|
inline bool ReadProtoFromFile(const char* filename, MessageLite* proto) {
|
|
return (ReadProtoFromBinaryFile(filename, proto) ||
|
|
ReadProtoFromTextFile(filename, proto));
|
|
}
|
|
inline bool ReadProtoFromFile(const string& filename, MessageLite* proto) {
|
|
return ReadProtoFromFile(filename.c_str(), proto);
|
|
}
|
|
|
|
// Returns the first argument of `def` whose name equals `name`.
// A missing argument is a fatal error (LOG(FATAL)).
inline const Argument& GetArgument(const OperatorDef& def, const string& name) {
  for (const Argument& arg : def.arg()) {
    if (arg.name() == name) {
      return arg;
    }
  }
  LOG(FATAL) << "Argument named " << name << " does not exist.";
  // Not reached. Without this, control could flow off the end of a function
  // returning a reference (undefined behavior, and a -Wreturn-type warning
  // wherever LOG(FATAL) is not recognized as noreturn). Mirrors the
  // warning-suppressing return used by the lite text-format stub.
  return Argument::default_instance();
}
|
|
|
|
// Looks up the argument named `name` in `def` and returns a mutable pointer to
// it. When no such argument exists, a new one carrying that name is appended
// if `create_if_missing` is true; otherwise nullptr is returned.
inline Argument* GetMutableArgument(
    const string& name, const bool create_if_missing, OperatorDef* def) {
  for (Argument& existing : *def->mutable_arg()) {
    if (existing.name() == name) {
      return &existing;
    }
  }
  if (!create_if_missing) {
    return nullptr;
  }
  Argument* created = def->add_arg();
  created->set_name(name);
  return created;
}
|
|
|
|
// A coarse support for the Any message in proto3. I am a bit afraid of going
|
|
// directly to proto3 yet, so let's do this first...
|
|
class Any {
|
|
template <typename MessageType>
|
|
static MessageType Parse(const Argument& arg) {
|
|
CHECK_EQ(arg.strings_size(), 1)
|
|
<< "An Any object should parse from a single string.";
|
|
MessageType message;
|
|
CHECK(message.ParseFromString(arg.strings(0)))
|
|
<< "Faild to parse from the string";
|
|
return message;
|
|
}
|
|
};
|
|
|
|
} // namespace caffe2
|
|
|
|
#endif // CAFFE2_UTILS_PROTO_UTILS_H_
|