Some utility function changes

This commit is contained in:
Yangqing Jia
2015-07-28 14:08:27 -07:00
parent 45355ae79e
commit a07c255d16
3 changed files with 34 additions and 5 deletions

View File

@ -39,15 +39,19 @@ class OperatorBase {
vector<T> GetRepeatedArgument(const string& name);
template <typename MessageType>
MessageType GetAnyMessageArgument(const string& name) {
MessageType GetMessageArgument(const string& name) {
CHECK(arg_map_.count(name)) << "Cannot find parameter named " << name;
MessageType message;
CHECK(message.ParseFromString(arg_map_[name]->s()))
<< "Faild to parse content from the string";
if (arg_map_[name]->has_s()) {
CHECK(message.ParseFromString(arg_map_[name]->s()))
<< "Faild to parse content from the string";
} else {
VLOG(1) << "Return empty message for parameter " << name;
}
return message;
}
template <typename MessageType>
vector<MessageType> GetAnyRepeatedMessageArgument(const string& name) {
vector<MessageType> GetRepeatedMessageArgument(const string& name) {
CHECK(arg_map_.count(name)) << "Cannot find parameter named " << name;
vector<MessageType> messages(arg_map_[name]->strings_size());
for (int i = 0; i < messages.size(); ++i) {

View File

@ -165,7 +165,6 @@ class XavierFillOp final : public FillerOp<dtype, DeviceContext> {
DISABLE_COPY_AND_ASSIGN(XavierFillOp);
};
// This is mostly used just as a debugging purpose stuff: it fills a tensor
// sequentially with values 0, 1, 2..., which can then be used to check e.g.
// reshape operations by allowing one to read the indices more easily.

View File

@ -72,6 +72,32 @@ inline bool ReadProtoFromFile(const string& filename, MessageLite* proto) {
return ReadProtoFromFile(filename.c_str(), proto);
}
inline const Argument& GetArgument(const OperatorDef& def, const string& name) {
for (const Argument& arg : def.arg()) {
if (arg.name() == name) {
return arg;
}
}
LOG(FATAL) << "Argument named " << name << " does not exist.";
}
inline Argument* GetMutableArgument(
const string& name, const bool create_if_missing, OperatorDef* def) {
for (int i = 0; i < def->arg_size(); ++i) {
if (def->arg(i).name() == name) {
return def->mutable_arg(i);
}
}
// If no argument of the right name is found...
if (create_if_missing) {
Argument* arg = def->add_arg();
arg->set_name(name);
return arg;
} else {
return nullptr;
}
}
// A coarse support for the Any message in proto3. I am a bit afraid of going
// directly to proto3 yet, so let's do this first...
class Any {