mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Fix some bug in the old code: the net device option overwite used to not work.
This commit is contained in:
@ -21,13 +21,18 @@ NetBase* CreateNet(const NetDef& net_def, Workspace* ws) {
|
||||
|
||||
SimpleNet::SimpleNet(const NetDef& net_def, Workspace* ws)
|
||||
: NetBase(net_def, ws) {
|
||||
bool net_def_has_device_option = net_def.has_device_option();
|
||||
// Initialize the operators
|
||||
for (const OperatorDef& operator_def : net_def.operators()) {
|
||||
VLOG(1) << "Creating operator " << operator_def.name()
|
||||
<< ":" << operator_def.type();
|
||||
if (!operator_def.has_device_option()) {
|
||||
operators_.emplace_back(
|
||||
CreateOperator(operator_def, net_def.device_option(), ws));
|
||||
if (!operator_def.has_device_option() && net_def_has_device_option) {
|
||||
// In the case that the operator def does not specify a device option but
|
||||
// the net def has a default option, we copy the device option over to the
|
||||
// operator def.
|
||||
OperatorDef temp_def(operator_def);
|
||||
temp_def.mutable_device_option()->CopyFrom(net_def.device_option());
|
||||
operators_.emplace_back(CreateOperator(temp_def, ws));
|
||||
} else {
|
||||
operators_.emplace_back(CreateOperator(operator_def, ws));
|
||||
}
|
||||
@ -60,14 +65,16 @@ ParallelNet::ParallelNet(const NetDef& net_def, Workspace* ws)
|
||||
: NetBase(net_def, ws), operator_nodes_(net_def.operators_size()) {
|
||||
// Blob creator allows us to track which operator created which blob.
|
||||
std::map<string, int> blob_creator;
|
||||
bool net_def_has_device_option = net_def.has_device_option();
|
||||
// Initialize the operators
|
||||
for (int idx = 0; idx < net_def.operators_size(); ++idx) {
|
||||
const OperatorDef& op_def = net_def.operators(idx);
|
||||
VLOG(1) << "Creating operator #" << idx << ": "
|
||||
<< op_def.name() << ":" << op_def.type();
|
||||
if (!op_def.has_device_option()) {
|
||||
operator_nodes_[idx].operator_.reset(
|
||||
CreateOperator(op_def, net_def.device_option(), ws));
|
||||
if (!op_def.has_device_option() && net_def_has_device_option) {
|
||||
OperatorDef temp_def(op_def);
|
||||
temp_def.mutable_device_option()->CopyFrom(net_def.device_option());
|
||||
operator_nodes_[idx].operator_.reset(CreateOperator(temp_def, ws));
|
||||
} else {
|
||||
operator_nodes_[idx].operator_.reset(CreateOperator(op_def, ws));
|
||||
}
|
||||
|
@ -90,9 +90,7 @@ bool OperatorBase::Verify() {
|
||||
return true;
|
||||
}
|
||||
|
||||
OperatorBase* CreateOperator(const OperatorDef& operator_def,
|
||||
const DeviceOption& device_option,
|
||||
Workspace* ws) {
|
||||
OperatorBase* CreateOperator(const OperatorDef& operator_def, Workspace* ws) {
|
||||
const string& key = operator_def.type();
|
||||
switch (operator_def.device_option().device_type()) {
|
||||
case CPU:
|
||||
|
@ -216,17 +216,8 @@ DECLARE_REGISTRY(CUDNNOperatorRegistry, OperatorBase,
|
||||
#define REGISTER_CUDNN_OPERATOR(name, ...) \
|
||||
REGISTER_CLASS(CUDNNOperatorRegistry, name, __VA_ARGS__)
|
||||
|
||||
// Creates an operator with the given operator definition and device option.
|
||||
OperatorBase* CreateOperator(const OperatorDef& operator_def,
|
||||
const DeviceOption& device_option,
|
||||
Workspace* ws);
|
||||
|
||||
// Create an operator with the given operator definition, and the device
|
||||
// option that is specified in the operator definition.
|
||||
inline OperatorBase* CreateOperator(const OperatorDef& operator_def,
|
||||
Workspace* ws) {
|
||||
return CreateOperator(operator_def, operator_def.device_option(), ws);
|
||||
}
|
||||
// Creates an operator with the given operator definition.
|
||||
OperatorBase* CreateOperator(const OperatorDef& operator_def, Workspace* ws);
|
||||
|
||||
} // namespace caffe2
|
||||
|
||||
|
Reference in New Issue
Block a user