mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-11-04 08:00:58 +08:00 
			
		
		
		
	Revert uuid change to OperatorDef protobuf
Summary: A few issues: 1. Randomization hurts memoization. 2. Even if we make it non-random, we can still get key collisions when loading it back. 3. RNNs use prototxt for the step net, and apparently it's not forward compatible like normal protobuf is. I am thinking of a better, less invasive solution now. Reviewed By: jamesr66a Differential Revision: D5272118 fbshipit-source-id: ab577fad04fbfc632e1fceffa923377a0d3da1be
This commit is contained in:
		
				
					committed by
					
						
						Facebook Github Bot
					
				
			
			
				
	
			
			
			
						parent
						
							a6fcecaa71
						
					
				
				
					commit
					83e6a0bec8
				
			@ -15,9 +15,7 @@
 | 
				
			|||||||
namespace caffe2 {
 | 
					namespace caffe2 {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
OperatorBase::OperatorBase(const OperatorDef& operator_def, Workspace* ws)
 | 
					OperatorBase::OperatorBase(const OperatorDef& operator_def, Workspace* ws)
 | 
				
			||||||
    : operator_ws_(ws),
 | 
					    : operator_def_(operator_def), arg_helper_(operator_def_) {
 | 
				
			||||||
      operator_def_(operator_def),
 | 
					 | 
				
			||||||
      arg_helper_(operator_def_) {
 | 
					 | 
				
			||||||
  for (const string& input_str : operator_def_.input()) {
 | 
					  for (const string& input_str : operator_def_.input()) {
 | 
				
			||||||
    auto* blob = ws->GetBlob(input_str);
 | 
					    auto* blob = ws->GetBlob(input_str);
 | 
				
			||||||
    CAFFE_ENFORCE(
 | 
					    CAFFE_ENFORCE(
 | 
				
			||||||
@ -57,9 +55,10 @@ unique_ptr<OperatorBase> TryCreateOperator(
 | 
				
			|||||||
  }
 | 
					  }
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
unique_ptr<OperatorBase> _CreateOperator(
 | 
					}  // namespace
 | 
				
			||||||
    const OperatorDef& operator_def,
 | 
					
 | 
				
			||||||
    Workspace* ws) {
 | 
					unique_ptr<OperatorBase> CreateOperator(
 | 
				
			||||||
 | 
					    const OperatorDef& operator_def, Workspace* ws) {
 | 
				
			||||||
  static StaticLinkingProtector g_protector;
 | 
					  static StaticLinkingProtector g_protector;
 | 
				
			||||||
  // first, check with OpSchema if the operator is legal.
 | 
					  // first, check with OpSchema if the operator is legal.
 | 
				
			||||||
  auto* schema = OpSchemaRegistry::Schema(operator_def.type());
 | 
					  auto* schema = OpSchemaRegistry::Schema(operator_def.type());
 | 
				
			||||||
@ -112,26 +111,6 @@ unique_ptr<OperatorBase> _CreateOperator(
 | 
				
			|||||||
  return op;
 | 
					  return op;
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
} // namespace
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
unique_ptr<OperatorBase> CreateOperator(
 | 
					 | 
				
			||||||
    const OperatorDef& operator_def,
 | 
					 | 
				
			||||||
    Workspace* ws) {
 | 
					 | 
				
			||||||
  try {
 | 
					 | 
				
			||||||
    auto op = _CreateOperator(operator_def, ws);
 | 
					 | 
				
			||||||
    return op;
 | 
					 | 
				
			||||||
  } catch (...) {
 | 
					 | 
				
			||||||
    if (operator_def.has_uuid()) {
 | 
					 | 
				
			||||||
      auto uuid = operator_def.uuid();
 | 
					 | 
				
			||||||
      VLOG(1) << "Operator constructor with uuid " << uuid << " failed";
 | 
					 | 
				
			||||||
      ws->last_failed_op_uuid = uuid;
 | 
					 | 
				
			||||||
    } else {
 | 
					 | 
				
			||||||
      VLOG(1) << "Failed operator constructor doesn't have uuid set";
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
    throw;
 | 
					 | 
				
			||||||
  }
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
std::map<int32_t, OperatorRegistry*>* gDeviceTypeRegistry() {
 | 
					std::map<int32_t, OperatorRegistry*>* gDeviceTypeRegistry() {
 | 
				
			||||||
  static std::map<int32_t, OperatorRegistry*> g_device_type_registry;
 | 
					  static std::map<int32_t, OperatorRegistry*> g_device_type_registry;
 | 
				
			||||||
  return &g_device_type_registry;
 | 
					  return &g_device_type_registry;
 | 
				
			||||||
 | 
				
			|||||||
@ -137,19 +137,8 @@ class OperatorBase {
 | 
				
			|||||||
    observer_ = nullptr;
 | 
					    observer_ = nullptr;
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  void RecordLastFailedUuid() {
 | 
					 | 
				
			||||||
    if (this->def().has_uuid()) {
 | 
					 | 
				
			||||||
      auto uuid = this->def().uuid();
 | 
					 | 
				
			||||||
      VLOG(1) << "Operator with uuid " << uuid << " failed";
 | 
					 | 
				
			||||||
      operator_ws_->last_failed_op_uuid = uuid;
 | 
					 | 
				
			||||||
    } else {
 | 
					 | 
				
			||||||
      VLOG(1) << "Failed operator doesn't have uuid set";
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
  }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 protected:
 | 
					 protected:
 | 
				
			||||||
  ObserverBase<OperatorBase>* observer_ = nullptr;
 | 
					  ObserverBase<OperatorBase>* observer_ = nullptr;
 | 
				
			||||||
  Workspace* operator_ws_;
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 private:
 | 
					 private:
 | 
				
			||||||
  OperatorDef operator_def_;
 | 
					  OperatorDef operator_def_;
 | 
				
			||||||
@ -220,10 +209,6 @@ class Operator : public OperatorBase {
 | 
				
			|||||||
      context_.SwitchToDevice(stream_id);
 | 
					      context_.SwitchToDevice(stream_id);
 | 
				
			||||||
      bool started = RunOnDevice();
 | 
					      bool started = RunOnDevice();
 | 
				
			||||||
      bool finished = context_.FinishDeviceComputation();
 | 
					      bool finished = context_.FinishDeviceComputation();
 | 
				
			||||||
      auto result = started && finished;
 | 
					 | 
				
			||||||
      if (!result) {
 | 
					 | 
				
			||||||
        this->RecordLastFailedUuid();
 | 
					 | 
				
			||||||
      }
 | 
					 | 
				
			||||||
      if (!finished) {
 | 
					      if (!finished) {
 | 
				
			||||||
        // FinishDeviceComputation() returning error basically means that there
 | 
					        // FinishDeviceComputation() returning error basically means that there
 | 
				
			||||||
        // is something wrong with the device (like CUDA) that usually cannot be
 | 
					        // is something wrong with the device (like CUDA) that usually cannot be
 | 
				
			||||||
@ -234,14 +219,10 @@ class Operator : public OperatorBase {
 | 
				
			|||||||
      if (observer_) {
 | 
					      if (observer_) {
 | 
				
			||||||
        observer_->Stop();
 | 
					        observer_->Stop();
 | 
				
			||||||
      }
 | 
					      }
 | 
				
			||||||
      return result;
 | 
					      return (started && finished);
 | 
				
			||||||
    } catch (EnforceNotMet& err) {
 | 
					    } catch (EnforceNotMet& err) {
 | 
				
			||||||
      err.AppendMessage("Error from operator: \n" + ProtoDebugString(def()));
 | 
					      err.AppendMessage("Error from operator: \n" + ProtoDebugString(def()));
 | 
				
			||||||
      AddRelatedBlobInfo(&err);
 | 
					      AddRelatedBlobInfo(&err);
 | 
				
			||||||
      this->RecordLastFailedUuid();
 | 
					 | 
				
			||||||
      throw;
 | 
					 | 
				
			||||||
    } catch (...) {
 | 
					 | 
				
			||||||
      this->RecordLastFailedUuid();
 | 
					 | 
				
			||||||
      throw;
 | 
					      throw;
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
@ -249,18 +230,10 @@ class Operator : public OperatorBase {
 | 
				
			|||||||
  bool RunAsync(int stream_id = 0) final {
 | 
					  bool RunAsync(int stream_id = 0) final {
 | 
				
			||||||
    try {
 | 
					    try {
 | 
				
			||||||
      context_.SwitchToDevice(stream_id);
 | 
					      context_.SwitchToDevice(stream_id);
 | 
				
			||||||
      auto result = RunOnDevice();
 | 
					      return RunOnDevice();
 | 
				
			||||||
      if (!result) {
 | 
					 | 
				
			||||||
        this->RecordLastFailedUuid();
 | 
					 | 
				
			||||||
      }
 | 
					 | 
				
			||||||
      return result;
 | 
					 | 
				
			||||||
    } catch (EnforceNotMet& err) {
 | 
					    } catch (EnforceNotMet& err) {
 | 
				
			||||||
      err.AppendMessage("Error from operator: \n" + ProtoDebugString(def()));
 | 
					      err.AppendMessage("Error from operator: \n" + ProtoDebugString(def()));
 | 
				
			||||||
      AddRelatedBlobInfo(&err);
 | 
					      AddRelatedBlobInfo(&err);
 | 
				
			||||||
      this->RecordLastFailedUuid();
 | 
					 | 
				
			||||||
      throw;
 | 
					 | 
				
			||||||
    } catch (...) {
 | 
					 | 
				
			||||||
      this->RecordLastFailedUuid();
 | 
					 | 
				
			||||||
      throw;
 | 
					      throw;
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
				
			|||||||
@ -208,9 +208,6 @@ class Workspace {
 | 
				
			|||||||
  bool RunOperatorOnce(const OperatorDef& op_def);
 | 
					  bool RunOperatorOnce(const OperatorDef& op_def);
 | 
				
			||||||
  bool RunNetOnce(const NetDef& net_def);
 | 
					  bool RunNetOnce(const NetDef& net_def);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 public:
 | 
					 | 
				
			||||||
  std::atomic<uint64_t> last_failed_op_uuid;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 private:
 | 
					 private:
 | 
				
			||||||
  BlobMap blob_map_;
 | 
					  BlobMap blob_map_;
 | 
				
			||||||
  NetMap net_map_;
 | 
					  NetMap net_map_;
 | 
				
			||||||
 | 
				
			|||||||
@ -160,8 +160,6 @@ message OperatorDef {
 | 
				
			|||||||
  // is_gradient_op argument is only used as a hint in shape inference
 | 
					  // is_gradient_op argument is only used as a hint in shape inference
 | 
				
			||||||
  // and has no runtime significance
 | 
					  // and has no runtime significance
 | 
				
			||||||
  optional bool is_gradient_op = 9 [default = false];
 | 
					  optional bool is_gradient_op = 9 [default = false];
 | 
				
			||||||
  // a random uuid used for joining with extra meta info
 | 
					 | 
				
			||||||
  optional uint64 uuid = 10 [default = 0];
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Network definition.
 | 
					// Network definition.
 | 
				
			||||||
 | 
				
			|||||||
@ -21,10 +21,9 @@ from caffe2.python import scope, utils, workspace
 | 
				
			|||||||
import caffe2.python._import_c_extension as C
 | 
					import caffe2.python._import_c_extension as C
 | 
				
			||||||
import pickle
 | 
					import pickle
 | 
				
			||||||
import numpy as np
 | 
					import numpy as np
 | 
				
			||||||
import uuid
 | 
					 | 
				
			||||||
import traceback
 | 
					 | 
				
			||||||
import sys
 | 
					import sys
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# Mac os specific message
 | 
					# Mac os specific message
 | 
				
			||||||
if (sys.platform == 'darwin' and 'leveldb' in C.registered_dbs()):
 | 
					if (sys.platform == 'darwin' and 'leveldb' in C.registered_dbs()):
 | 
				
			||||||
    print('If you are using homebrew leveldb on a Mac OS, you might see an '
 | 
					    print('If you are using homebrew leveldb on a Mac OS, you might see an '
 | 
				
			||||||
@ -319,18 +318,7 @@ def CreateOperator(
 | 
				
			|||||||
    # Add all other arguments
 | 
					    # Add all other arguments
 | 
				
			||||||
    for key, value in kwargs.items():
 | 
					    for key, value in kwargs.items():
 | 
				
			||||||
        operator.arg.add().CopyFrom(utils.MakeArgument(key, value))
 | 
					        operator.arg.add().CopyFrom(utils.MakeArgument(key, value))
 | 
				
			||||||
    operator.uuid = uuid.uuid4().int >> 64
 | 
					 | 
				
			||||||
    stack = traceback.extract_stack()
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # string part of the stack that belongs to this file
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    for i, line in reversed(list(enumerate(stack))):
 | 
					 | 
				
			||||||
        # get path of the core.py file
 | 
					 | 
				
			||||||
        name = __name__.replace('.', '/') + ".py"
 | 
					 | 
				
			||||||
        if name not in ' '.join(map(str, line)):
 | 
					 | 
				
			||||||
            break
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    workspace.operator_tracebacks[operator.uuid] = stack[:i + 1]
 | 
					 | 
				
			||||||
    if workspace.IsImmediate():
 | 
					    if workspace.IsImmediate():
 | 
				
			||||||
        workspace.RunOperatorImmediate(operator)
 | 
					        workspace.RunOperatorImmediate(operator)
 | 
				
			||||||
    return operator
 | 
					    return operator
 | 
				
			||||||
@ -1564,12 +1552,6 @@ class Net(object):
 | 
				
			|||||||
        self._InvalidateLookupTables()
 | 
					        self._InvalidateLookupTables()
 | 
				
			||||||
        return self._net
 | 
					        return self._net
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def PopulateProtoWithFileName(self):
 | 
					 | 
				
			||||||
        for op in self.Proto().op:
 | 
					 | 
				
			||||||
            if op.uuid in workspace.operator_tracebacks:
 | 
					 | 
				
			||||||
                tb = workspace.operator_tracebacks[op.uuid]
 | 
					 | 
				
			||||||
                op.name = ':'.join(map(str, tb[-1][:2]))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def NextScopedBlob(self, prefix='unnamed'):
 | 
					    def NextScopedBlob(self, prefix='unnamed'):
 | 
				
			||||||
        """Return the blob that has not been defined or registered in the
 | 
					        """Return the blob that has not been defined or registered in the
 | 
				
			||||||
        current net. It returns `ScopedBlobReference(prefix)`, if it's valid,
 | 
					        current net. It returns `ScopedBlobReference(prefix)`, if it's valid,
 | 
				
			||||||
 | 
				
			|||||||
@ -85,14 +85,6 @@ def AddNogradient(op, g_output):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class TestGradientCalculation(test_util.TestCase):
 | 
					class TestGradientCalculation(test_util.TestCase):
 | 
				
			||||||
    def assertEqual(self, op_list1, op_list2):
 | 
					 | 
				
			||||||
        if isinstance(op_list1, list) and isinstance(op_list2, list):
 | 
					 | 
				
			||||||
            for op in op_list1 + op_list2:
 | 
					 | 
				
			||||||
                if isinstance(op, caffe2_pb2.OperatorDef):
 | 
					 | 
				
			||||||
                    op.ClearField(bytes_to_native_str(b'uuid'))
 | 
					 | 
				
			||||||
        return super(TestGradientCalculation, self).assertEqual(
 | 
					 | 
				
			||||||
            op_list1, op_list2)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    @given(device_option=st.sampled_from([
 | 
					    @given(device_option=st.sampled_from([
 | 
				
			||||||
        None,
 | 
					        None,
 | 
				
			||||||
        core.DeviceOption(caffe2_pb2.CUDA, 1)]))
 | 
					        core.DeviceOption(caffe2_pb2.CUDA, 1)]))
 | 
				
			||||||
 | 
				
			|||||||
@ -1,23 +1,9 @@
 | 
				
			|||||||
from __future__ import unicode_literals
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
from inspect import currentframe, getframeinfo
 | 
					 | 
				
			||||||
import unittest
 | 
					import unittest
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from future.utils import bytes_to_native_str
 | 
					 | 
				
			||||||
import numpy as np
 | 
					import numpy as np
 | 
				
			||||||
 | 
					 | 
				
			||||||
from caffe2.proto import caffe2_pb2
 | 
					from caffe2.proto import caffe2_pb2
 | 
				
			||||||
from caffe2.python import core, workspace, test_util
 | 
					from caffe2.python import core, workspace, test_util
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 | 
				
			||||||
def _remove_uuid(proto):
 | 
					 | 
				
			||||||
    if isinstance(proto, caffe2_pb2.NetDef):
 | 
					 | 
				
			||||||
        for op in proto.op:
 | 
					 | 
				
			||||||
            op.ClearField(bytes_to_native_str(b'uuid'))
 | 
					 | 
				
			||||||
    elif isinstance(proto, caffe2_pb2.OperatorDef):
 | 
					 | 
				
			||||||
        proto.ClearField(bytes_to_native_str(b'uuid'))
 | 
					 | 
				
			||||||
    return proto
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class TestScopes(test_util.TestCase):
 | 
					class TestScopes(test_util.TestCase):
 | 
				
			||||||
    def testBlobReferenceIsIndependentFromNameScope(self):
 | 
					    def testBlobReferenceIsIndependentFromNameScope(self):
 | 
				
			||||||
        blob_v = core.BlobReference("v")
 | 
					        blob_v = core.BlobReference("v")
 | 
				
			||||||
@ -255,8 +241,7 @@ class TestAutoNaming(test_util.TestCase):
 | 
				
			|||||||
        net_a = create_net()
 | 
					        net_a = create_net()
 | 
				
			||||||
        net_b = create_net()
 | 
					        net_b = create_net()
 | 
				
			||||||
        # created net proto is predicatable.
 | 
					        # created net proto is predicatable.
 | 
				
			||||||
        self.assertEqual(_remove_uuid(net_a.Proto()).op,
 | 
					        self.assertEqual(net_a.Proto().op, net_b.Proto().op)
 | 
				
			||||||
                         _remove_uuid(net_b.Proto()).op)
 | 
					 | 
				
			||||||
        self.assertEqual(net_a.Proto().op[0].output[0], 'foo/ab')
 | 
					        self.assertEqual(net_a.Proto().op[0].output[0], 'foo/ab')
 | 
				
			||||||
        self.assertEqual(net_a.Proto().op[1].output[0], 'cd')
 | 
					        self.assertEqual(net_a.Proto().op[1].output[0], 'cd')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -384,63 +369,6 @@ class TestExtractPredictorNet(test_util.TestCase):
 | 
				
			|||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class TestOperatorTraceback(test_util.TestCase):
 | 
					 | 
				
			||||||
    def test_operator_constructor_traceback(self):
 | 
					 | 
				
			||||||
        net = core.Net("test")
 | 
					 | 
				
			||||||
        a, b = net.AddExternalInput("a", "b")
 | 
					 | 
				
			||||||
        net.Mul([a, b], "c")
 | 
					 | 
				
			||||||
        with self.assertRaises(Exception):
 | 
					 | 
				
			||||||
            workspace.RunNetOnce(net)
 | 
					 | 
				
			||||||
        with self.assertRaises(Exception):
 | 
					 | 
				
			||||||
            workspace.CreateNet(net)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def test_operator_runtime_traceback(self):
 | 
					 | 
				
			||||||
        net = core.Net("test")
 | 
					 | 
				
			||||||
        a = net.AddExternalInput("a")
 | 
					 | 
				
			||||||
        workspace.blobs[a] = np.array([1, 2, 3], dtype=np.float32)
 | 
					 | 
				
			||||||
        net.Split(a, ["b", "c"], axis=0)
 | 
					 | 
				
			||||||
        with self.assertRaises(Exception):
 | 
					 | 
				
			||||||
            workspace.RunNetOnce(net)
 | 
					 | 
				
			||||||
        workspace.CreateNet(net)
 | 
					 | 
				
			||||||
        with self.assertRaises(Exception):
 | 
					 | 
				
			||||||
            workspace.RunNet(net)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def test_name_population(self):
 | 
					 | 
				
			||||||
        net = core.Net("test")
 | 
					 | 
				
			||||||
        # capture line number on which operator is added
 | 
					 | 
				
			||||||
        net.Mul(["a", "b"], "c"); cf = currentframe(); line = cf.f_lineno
 | 
					 | 
				
			||||||
        net.PopulateProtoWithFileName()
 | 
					 | 
				
			||||||
        print(net.Proto())
 | 
					 | 
				
			||||||
        filename = getframeinfo(cf).filename
 | 
					 | 
				
			||||||
        self.assertEqual(net.Proto().op[0].name, '{}:{}'.format(filename, line))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def test_c_workspace_constructor(self):
 | 
					 | 
				
			||||||
        net = core.Net("test")
 | 
					 | 
				
			||||||
        a, b = net.AddExternalInput("a", "b")
 | 
					 | 
				
			||||||
        net.Mul([a, b], "c"); cf = currentframe(); line = cf.f_lineno
 | 
					 | 
				
			||||||
        ws = workspace.C.Workspace()
 | 
					 | 
				
			||||||
        with self.assertRaises(Exception):
 | 
					 | 
				
			||||||
            ws.run(net)
 | 
					 | 
				
			||||||
        with self.assertRaises(Exception):
 | 
					 | 
				
			||||||
            ws.create_net(net)
 | 
					 | 
				
			||||||
        net.PopulateProtoWithFileName()
 | 
					 | 
				
			||||||
        filename = getframeinfo(cf).filename
 | 
					 | 
				
			||||||
        self.assertEqual(net.Proto().op[0].name, '{}:{}'.format(filename, line))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def test_c_workspace_runtime(self):
 | 
					 | 
				
			||||||
        net = core.Net("test")
 | 
					 | 
				
			||||||
        a = net.AddExternalInput("a")
 | 
					 | 
				
			||||||
        net.Split(a, ["b", "c"], axis=0); cf = currentframe(); line = cf.f_lineno
 | 
					 | 
				
			||||||
        ws = workspace.C.Workspace()
 | 
					 | 
				
			||||||
        ws.create_blob(str(a)).feed(np.array([1, 2, 3], dtype=np.float32))
 | 
					 | 
				
			||||||
        ws.create_net(net)
 | 
					 | 
				
			||||||
        with self.assertRaises(Exception):
 | 
					 | 
				
			||||||
            ws.run(net)
 | 
					 | 
				
			||||||
        net.PopulateProtoWithFileName()
 | 
					 | 
				
			||||||
        filename = getframeinfo(cf).filename
 | 
					 | 
				
			||||||
        self.assertEqual(net.Proto().op[0].name, '{}:{}'.format(filename, line))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
@unittest.skipIf(not workspace.has_gpu_support, 'No GPU support')
 | 
					@unittest.skipIf(not workspace.has_gpu_support, 'No GPU support')
 | 
				
			||||||
class TestInferDevice(test_util.TestCase):
 | 
					class TestInferDevice(test_util.TestCase):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -559,6 +487,45 @@ class TestInferDevice(test_util.TestCase):
 | 
				
			|||||||
        self.assertEqual(op.input[2], "fc_b_cuda_1")
 | 
					        self.assertEqual(op.input[2], "fc_b_cuda_1")
 | 
				
			||||||
        self.assertEqual(op.device_option.device_type, 1)
 | 
					        self.assertEqual(op.device_option.device_type, 1)
 | 
				
			||||||
        self.assertEqual(op.device_option.cuda_gpu_id, 1)
 | 
					        self.assertEqual(op.device_option.cuda_gpu_id, 1)
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					For reference, net.Proto() should be like:
 | 
				
			||||||
 | 
					name: ""
 | 
				
			||||||
 | 
					op {
 | 
				
			||||||
 | 
					  input: "fc_w"
 | 
				
			||||||
 | 
					  output: "fc_w_cuda_1"
 | 
				
			||||||
 | 
					  name: ""
 | 
				
			||||||
 | 
					  type: "CopyCPUToGPU"
 | 
				
			||||||
 | 
					  device_option {
 | 
				
			||||||
 | 
					    device_type: 1
 | 
				
			||||||
 | 
					    cuda_gpu_id: 1
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					op {
 | 
				
			||||||
 | 
					  input: "fc_b"
 | 
				
			||||||
 | 
					  output: "fc_b_cuda_1"
 | 
				
			||||||
 | 
					  name: ""
 | 
				
			||||||
 | 
					  type: "CopyCPUToGPU"
 | 
				
			||||||
 | 
					  device_option {
 | 
				
			||||||
 | 
					    device_type: 1
 | 
				
			||||||
 | 
					    cuda_gpu_id: 1
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					op {
 | 
				
			||||||
 | 
					  input: "data"
 | 
				
			||||||
 | 
					  input: "fc_w_cuda_1"
 | 
				
			||||||
 | 
					  input: "fc_b_cuda_1"
 | 
				
			||||||
 | 
					  output: "fc1"
 | 
				
			||||||
 | 
					  name: ""
 | 
				
			||||||
 | 
					  type: "FC"
 | 
				
			||||||
 | 
					  device_option {
 | 
				
			||||||
 | 
					    device_type: 1
 | 
				
			||||||
 | 
					    cuda_gpu_id: 1
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					external_input: "data"
 | 
				
			||||||
 | 
					external_input: "fc_w"
 | 
				
			||||||
 | 
					external_input: "fc_b"
 | 
				
			||||||
 | 
					"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_cross_nets_no_change(self):
 | 
					    def test_cross_nets_no_change(self):
 | 
				
			||||||
        net = core.Net("test")
 | 
					        net = core.Net("test")
 | 
				
			||||||
@ -583,6 +550,25 @@ class TestInferDevice(test_util.TestCase):
 | 
				
			|||||||
        self.assertEqual(op.input[2], "fc_b")
 | 
					        self.assertEqual(op.input[2], "fc_b")
 | 
				
			||||||
        self.assertEqual(op.device_option.device_type, 1)
 | 
					        self.assertEqual(op.device_option.device_type, 1)
 | 
				
			||||||
        self.assertEqual(op.device_option.cuda_gpu_id, 1)
 | 
					        self.assertEqual(op.device_option.cuda_gpu_id, 1)
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					For reference, net.Proto() should be like:
 | 
				
			||||||
 | 
					name: ""
 | 
				
			||||||
 | 
					op {
 | 
				
			||||||
 | 
					  input: "data"
 | 
				
			||||||
 | 
					  input: "fc_w"
 | 
				
			||||||
 | 
					  input: "fc_b"
 | 
				
			||||||
 | 
					  output: "fc1"
 | 
				
			||||||
 | 
					  name: ""
 | 
				
			||||||
 | 
					  type: "FC"
 | 
				
			||||||
 | 
					  device_option {
 | 
				
			||||||
 | 
					    device_type: 1
 | 
				
			||||||
 | 
					    cuda_gpu_id: 1
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					external_input: "data"
 | 
				
			||||||
 | 
					external_input: "fc_w"
 | 
				
			||||||
 | 
					external_input: "fc_b"
 | 
				
			||||||
 | 
					"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_inject_copy_multi_use(self):
 | 
					    def test_inject_copy_multi_use(self):
 | 
				
			||||||
        net = core.Net("test")
 | 
					        net = core.Net("test")
 | 
				
			||||||
@ -645,6 +631,83 @@ class TestInferDevice(test_util.TestCase):
 | 
				
			|||||||
        self.assertEqual(op.device_option.cuda_gpu_id, 1)
 | 
					        self.assertEqual(op.device_option.cuda_gpu_id, 1)
 | 
				
			||||||
        self.assertEqual(op.input[0], "data_cuda_1")
 | 
					        self.assertEqual(op.input[0], "data_cuda_1")
 | 
				
			||||||
        self.assertEqual(op.output[0], "relu6")
 | 
					        self.assertEqual(op.output[0], "relu6")
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					For reference, net.Proto() should be like:
 | 
				
			||||||
 | 
					name: ""
 | 
				
			||||||
 | 
					op {
 | 
				
			||||||
 | 
					  input: "data"
 | 
				
			||||||
 | 
					  output: "data_cuda_1"
 | 
				
			||||||
 | 
					  name: ""
 | 
				
			||||||
 | 
					  type: "CopyCPUToGPU"
 | 
				
			||||||
 | 
					  device_option {
 | 
				
			||||||
 | 
					    device_type: 1
 | 
				
			||||||
 | 
					    cuda_gpu_id: 1
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					op {
 | 
				
			||||||
 | 
					  input: "data_cuda_1"
 | 
				
			||||||
 | 
					  output: "relu1"
 | 
				
			||||||
 | 
					  name: ""
 | 
				
			||||||
 | 
					  type: "Relu"
 | 
				
			||||||
 | 
					  device_option {
 | 
				
			||||||
 | 
					    device_type: 1
 | 
				
			||||||
 | 
					    cuda_gpu_id: 1
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					op {
 | 
				
			||||||
 | 
					  input: "data"
 | 
				
			||||||
 | 
					  output: "relu2"
 | 
				
			||||||
 | 
					  name: ""
 | 
				
			||||||
 | 
					  type: "Relu"
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					op {
 | 
				
			||||||
 | 
					  input: "data_cuda_1"
 | 
				
			||||||
 | 
					  output: "relu3"
 | 
				
			||||||
 | 
					  name: ""
 | 
				
			||||||
 | 
					  type: "Relu"
 | 
				
			||||||
 | 
					  device_option {
 | 
				
			||||||
 | 
					    device_type: 1
 | 
				
			||||||
 | 
					    cuda_gpu_id: 1
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					op {
 | 
				
			||||||
 | 
					  input: "data"
 | 
				
			||||||
 | 
					  output: "relu4"
 | 
				
			||||||
 | 
					  name: ""
 | 
				
			||||||
 | 
					  type: "Relu"
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					op {
 | 
				
			||||||
 | 
					  input: "data"
 | 
				
			||||||
 | 
					  output: "data_cuda_0"
 | 
				
			||||||
 | 
					  name: ""
 | 
				
			||||||
 | 
					  type: "CopyCPUToGPU"
 | 
				
			||||||
 | 
					  device_option {
 | 
				
			||||||
 | 
					    device_type: 1
 | 
				
			||||||
 | 
					    cuda_gpu_id: 0
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					op {
 | 
				
			||||||
 | 
					  input: "data_cuda_0"
 | 
				
			||||||
 | 
					  output: "relu5"
 | 
				
			||||||
 | 
					  name: ""
 | 
				
			||||||
 | 
					  type: "Relu"
 | 
				
			||||||
 | 
					  device_option {
 | 
				
			||||||
 | 
					    device_type: 1
 | 
				
			||||||
 | 
					    cuda_gpu_id: 0
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					op {
 | 
				
			||||||
 | 
					  input: "data_cuda_1"
 | 
				
			||||||
 | 
					  output: "relu6"
 | 
				
			||||||
 | 
					  name: ""
 | 
				
			||||||
 | 
					  type: "Relu"
 | 
				
			||||||
 | 
					  device_option {
 | 
				
			||||||
 | 
					    device_type: 1
 | 
				
			||||||
 | 
					    cuda_gpu_id: 1
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					external_input: "data"
 | 
				
			||||||
 | 
					"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
if __name__ == '__main__':
 | 
					if __name__ == '__main__':
 | 
				
			||||||
 | 
				
			|||||||
@ -541,12 +541,6 @@ void addObjectMethods(py::module& m) {
 | 
				
			|||||||
            py::gil_scoped_release g;
 | 
					            py::gil_scoped_release g;
 | 
				
			||||||
            CAFFE_ENFORCE(self->RunPlan(proto));
 | 
					            CAFFE_ENFORCE(self->RunPlan(proto));
 | 
				
			||||||
          })
 | 
					          })
 | 
				
			||||||
      .def(
 | 
					 | 
				
			||||||
          "last_failed_op_uuid",
 | 
					 | 
				
			||||||
          [](Workspace* self) {
 | 
					 | 
				
			||||||
            CAFFE_ENFORCE(self);
 | 
					 | 
				
			||||||
            return (uint64_t)self->last_failed_op_uuid;
 | 
					 | 
				
			||||||
          })
 | 
					 | 
				
			||||||
      .def_property_readonly_static("current", [](py::object /* type */) {
 | 
					      .def_property_readonly_static("current", [](py::object /* type */) {
 | 
				
			||||||
        auto ws = gWorkspaces.find(gCurrentWorkspaceName);
 | 
					        auto ws = gWorkspaces.find(gCurrentWorkspaceName);
 | 
				
			||||||
        CAFFE_ENFORCE(ws != gWorkspaces.end());
 | 
					        CAFFE_ENFORCE(ws != gWorkspaces.end());
 | 
				
			||||||
@ -963,10 +957,7 @@ void addGlobalMethods(py::module& m) {
 | 
				
			|||||||
        gRegistery()[token] = Func{func, pass_workspace};
 | 
					        gRegistery()[token] = Func{func, pass_workspace};
 | 
				
			||||||
        return token;
 | 
					        return token;
 | 
				
			||||||
      });
 | 
					      });
 | 
				
			||||||
  m.def("last_failed_op_uuid", []() {
 | 
					
 | 
				
			||||||
    CAFFE_ENFORCE(gWorkspace);
 | 
					 | 
				
			||||||
    return (uint64_t)gWorkspace->last_failed_op_uuid;
 | 
					 | 
				
			||||||
  });
 | 
					 | 
				
			||||||
  m.def(
 | 
					  m.def(
 | 
				
			||||||
      "register_python_gradient_op",
 | 
					      "register_python_gradient_op",
 | 
				
			||||||
      [](const std::string& token, py::object func) {
 | 
					      [](const std::string& token, py::object func) {
 | 
				
			||||||
 | 
				
			|||||||
@ -43,8 +43,6 @@ Workspaces = C.workspaces
 | 
				
			|||||||
BenchmarkNet = C.benchmark_net
 | 
					BenchmarkNet = C.benchmark_net
 | 
				
			||||||
Predictor = C.Predictor
 | 
					Predictor = C.Predictor
 | 
				
			||||||
 | 
					
 | 
				
			||||||
operator_tracebacks = {}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
is_asan = C.is_asan
 | 
					is_asan = C.is_asan
 | 
				
			||||||
has_gpu_support = C.has_gpu_support
 | 
					has_gpu_support = C.has_gpu_support
 | 
				
			||||||
if has_gpu_support:
 | 
					if has_gpu_support:
 | 
				
			||||||
@ -146,8 +144,7 @@ def CreateNet(net, overwrite=False, input_blobs=None):
 | 
				
			|||||||
        input_blobs = []
 | 
					        input_blobs = []
 | 
				
			||||||
    for input_blob in input_blobs:
 | 
					    for input_blob in input_blobs:
 | 
				
			||||||
        C.create_blob(input_blob)
 | 
					        C.create_blob(input_blob)
 | 
				
			||||||
    return CallWithExceptionIntercept(
 | 
					    return C.create_net(StringifyProto(net), overwrite)
 | 
				
			||||||
        C.create_net, C.last_failed_op_uuid, StringifyProto(net), overwrite)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def RunOperatorOnce(operator):
 | 
					def RunOperatorOnce(operator):
 | 
				
			||||||
@ -162,20 +159,8 @@ def RunOperatorsOnce(operators):
 | 
				
			|||||||
    return True
 | 
					    return True
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def CallWithExceptionIntercept(func, uuid_fetcher, *args, **kwargs):
 | 
					 | 
				
			||||||
    try:
 | 
					 | 
				
			||||||
        return func(*args, **kwargs)
 | 
					 | 
				
			||||||
    except Exception as ex:
 | 
					 | 
				
			||||||
        uuid = uuid_fetcher()
 | 
					 | 
				
			||||||
        if uuid in operator_tracebacks:
 | 
					 | 
				
			||||||
            for line in operator_tracebacks[uuid]:
 | 
					 | 
				
			||||||
                print(':'.join(map(str, line)))
 | 
					 | 
				
			||||||
        raise ex
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def RunNetOnce(net):
 | 
					def RunNetOnce(net):
 | 
				
			||||||
    return CallWithExceptionIntercept(
 | 
					    return C.run_net_once(StringifyProto(net))
 | 
				
			||||||
        C.run_net_once, C.last_failed_op_uuid, StringifyProto(net))
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def RunNet(name, num_iter=1, allow_fail=False):
 | 
					def RunNet(name, num_iter=1, allow_fail=False):
 | 
				
			||||||
@ -188,9 +173,7 @@ def RunNet(name, num_iter=1, allow_fail=False):
 | 
				
			|||||||
    Returns:
 | 
					    Returns:
 | 
				
			||||||
      True or an exception.
 | 
					      True or an exception.
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    return CallWithExceptionIntercept(
 | 
					    return C.run_net(StringifyNetName(name), num_iter, allow_fail)
 | 
				
			||||||
        C.run_net, C.last_failed_op_uuid,
 | 
					 | 
				
			||||||
        StringifyNetName(name), num_iter, allow_fail)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def RunPlan(plan_or_step):
 | 
					def RunPlan(plan_or_step):
 | 
				
			||||||
@ -454,12 +437,11 @@ def FeedImmediate(*args, **kwargs):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
# CWorkspace utilities
 | 
					# CWorkspace utilities
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def _Workspace_create_net_with_exception_intercept(ws, net, overwrite=False):
 | 
					def _Workspace_create_net(ws, net, overwrite=False):
 | 
				
			||||||
    return CallWithExceptionIntercept(
 | 
					    return ws._create_net(StringifyProto(net), overwrite)
 | 
				
			||||||
        ws._create_net, ws.last_failed_op_uuid, StringifyProto(net), overwrite)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
C.Workspace.create_net = _Workspace_create_net_with_exception_intercept
 | 
					C.Workspace.create_net = _Workspace_create_net
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def _Workspace_run(ws, obj):
 | 
					def _Workspace_run(ws, obj):
 | 
				
			||||||
@ -475,12 +457,7 @@ def _Workspace_run(ws, obj):
 | 
				
			|||||||
        "Don't know how to do Workspace.run() on {}".format(type(obj)))
 | 
					        "Don't know how to do Workspace.run() on {}".format(type(obj)))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def _Workspace_run_with_exception_intercept(ws, obj):
 | 
					C.Workspace.run = _Workspace_run
 | 
				
			||||||
    return CallWithExceptionIntercept(
 | 
					 | 
				
			||||||
        _Workspace_run, ws.last_failed_op_uuid, ws, obj)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
C.Workspace.run = _Workspace_run_with_exception_intercept
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def _Blob_feed(blob, arg, device_option=None):
 | 
					def _Blob_feed(blob, arg, device_option=None):
 | 
				
			||||||
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user