mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Some utility function changes
This commit is contained in:
@ -39,15 +39,19 @@ class OperatorBase {
|
||||
vector<T> GetRepeatedArgument(const string& name);
|
||||
|
||||
template <typename MessageType>
|
||||
MessageType GetAnyMessageArgument(const string& name) {
|
||||
MessageType GetMessageArgument(const string& name) {
|
||||
CHECK(arg_map_.count(name)) << "Cannot find parameter named " << name;
|
||||
MessageType message;
|
||||
CHECK(message.ParseFromString(arg_map_[name]->s()))
|
||||
<< "Faild to parse content from the string";
|
||||
if (arg_map_[name]->has_s()) {
|
||||
CHECK(message.ParseFromString(arg_map_[name]->s()))
|
||||
<< "Faild to parse content from the string";
|
||||
} else {
|
||||
VLOG(1) << "Return empty message for parameter " << name;
|
||||
}
|
||||
return message;
|
||||
}
|
||||
template <typename MessageType>
|
||||
vector<MessageType> GetAnyRepeatedMessageArgument(const string& name) {
|
||||
vector<MessageType> GetRepeatedMessageArgument(const string& name) {
|
||||
CHECK(arg_map_.count(name)) << "Cannot find parameter named " << name;
|
||||
vector<MessageType> messages(arg_map_[name]->strings_size());
|
||||
for (int i = 0; i < messages.size(); ++i) {
|
||||
|
@ -165,7 +165,6 @@ class XavierFillOp final : public FillerOp<dtype, DeviceContext> {
|
||||
DISABLE_COPY_AND_ASSIGN(XavierFillOp);
|
||||
};
|
||||
|
||||
|
||||
// This is mostly used just as a debugging purpose stuff: it fills a tensor
|
||||
// sequentially with values 0, 1, 2..., which can then be used to check e.g.
|
||||
// reshape operations by allowing one to read the indices more easily.
|
||||
|
@ -72,6 +72,32 @@ inline bool ReadProtoFromFile(const string& filename, MessageLite* proto) {
|
||||
return ReadProtoFromFile(filename.c_str(), proto);
|
||||
}
|
||||
|
||||
inline const Argument& GetArgument(const OperatorDef& def, const string& name) {
|
||||
for (const Argument& arg : def.arg()) {
|
||||
if (arg.name() == name) {
|
||||
return arg;
|
||||
}
|
||||
}
|
||||
LOG(FATAL) << "Argument named " << name << " does not exist.";
|
||||
}
|
||||
|
||||
inline Argument* GetMutableArgument(
|
||||
const string& name, const bool create_if_missing, OperatorDef* def) {
|
||||
for (int i = 0; i < def->arg_size(); ++i) {
|
||||
if (def->arg(i).name() == name) {
|
||||
return def->mutable_arg(i);
|
||||
}
|
||||
}
|
||||
// If no argument of the right name is found...
|
||||
if (create_if_missing) {
|
||||
Argument* arg = def->add_arg();
|
||||
arg->set_name(name);
|
||||
return arg;
|
||||
} else {
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
// A coarse support for the Any message in proto3. I am a bit afraid of going
|
||||
// directly to proto3 yet, so let's do this first...
|
||||
class Any {
|
||||
|
Reference in New Issue
Block a user