Fix some bug in the old code: the net device option overwite used to not work.

This commit is contained in:
Yangqing Jia
2015-07-01 13:37:40 -07:00
parent 8834e3eb13
commit e94be296b7
3 changed files with 16 additions and 20 deletions

View File

@ -21,13 +21,18 @@ NetBase* CreateNet(const NetDef& net_def, Workspace* ws) {
SimpleNet::SimpleNet(const NetDef& net_def, Workspace* ws)
: NetBase(net_def, ws) {
bool net_def_has_device_option = net_def.has_device_option();
// Initialize the operators
for (const OperatorDef& operator_def : net_def.operators()) {
VLOG(1) << "Creating operator " << operator_def.name()
<< ":" << operator_def.type();
if (!operator_def.has_device_option()) {
operators_.emplace_back(
CreateOperator(operator_def, net_def.device_option(), ws));
if (!operator_def.has_device_option() && net_def_has_device_option) {
// In the case that the operator def does not specify a device option but
// the net def has a default option, we copy the device option over to the
// operator def.
OperatorDef temp_def(operator_def);
temp_def.mutable_device_option()->CopyFrom(net_def.device_option());
operators_.emplace_back(CreateOperator(temp_def, ws));
} else {
operators_.emplace_back(CreateOperator(operator_def, ws));
}
@ -60,14 +65,16 @@ ParallelNet::ParallelNet(const NetDef& net_def, Workspace* ws)
: NetBase(net_def, ws), operator_nodes_(net_def.operators_size()) {
// Blob creator allows us to track which operator created which blob.
std::map<string, int> blob_creator;
bool net_def_has_device_option = net_def.has_device_option();
// Initialize the operators
for (int idx = 0; idx < net_def.operators_size(); ++idx) {
const OperatorDef& op_def = net_def.operators(idx);
VLOG(1) << "Creating operator #" << idx << ": "
<< op_def.name() << ":" << op_def.type();
if (!op_def.has_device_option()) {
operator_nodes_[idx].operator_.reset(
CreateOperator(op_def, net_def.device_option(), ws));
if (!op_def.has_device_option() && net_def_has_device_option) {
OperatorDef temp_def(op_def);
temp_def.mutable_device_option()->CopyFrom(net_def.device_option());
operator_nodes_[idx].operator_.reset(CreateOperator(temp_def, ws));
} else {
operator_nodes_[idx].operator_.reset(CreateOperator(op_def, ws));
}

View File

@ -90,9 +90,7 @@ bool OperatorBase::Verify() {
return true;
}
OperatorBase* CreateOperator(const OperatorDef& operator_def,
const DeviceOption& device_option,
Workspace* ws) {
OperatorBase* CreateOperator(const OperatorDef& operator_def, Workspace* ws) {
const string& key = operator_def.type();
switch (operator_def.device_option().device_type()) {
case CPU:

View File

@ -216,17 +216,8 @@ DECLARE_REGISTRY(CUDNNOperatorRegistry, OperatorBase,
#define REGISTER_CUDNN_OPERATOR(name, ...) \
REGISTER_CLASS(CUDNNOperatorRegistry, name, __VA_ARGS__)
// Creates an operator with the given operator definition and device option.
OperatorBase* CreateOperator(const OperatorDef& operator_def,
const DeviceOption& device_option,
Workspace* ws);
// Create an operator with the given operator definition, and the device
// option that is specified in the operator definition.
inline OperatorBase* CreateOperator(const OperatorDef& operator_def,
Workspace* ws) {
return CreateOperator(operator_def, operator_def.device_option(), ws);
}
// Creates an operator with the given operator definition.
OperatorBase* CreateOperator(const OperatorDef& operator_def, Workspace* ws);
} // namespace caffe2