diff --git a/caffe2/contrib/tensorboard/tensorboard_test.py b/caffe2/contrib/tensorboard/tensorboard_test.py index 943e75fe1dd0..489be5d43ff2 100644 --- a/caffe2/contrib/tensorboard/tensorboard_test.py +++ b/caffe2/contrib/tensorboard/tensorboard_test.py @@ -91,9 +91,11 @@ class TensorboardTest(unittest.TestCase): for i, (event, net) in enumerate(zip(events, nets), start=1): self.assertEqual(event.step, i) self.assertEqual(event.wall_time, i) - self.assertEqual( - event.graph_def, - tb_exporter.nets_to_graph_def([net]).SerializeToString()) + g = tf.GraphDef() + g.ParseFromString(event.graph_def) + self.assertMultiLineEqual( + str(g), + str(tb_exporter.nets_to_graph_def([net]))) if __name__ == "__main__": diff --git a/caffe2/proto/__init__.py b/caffe2/proto/__init__.py index e69de29bb2d1..a753f26c5380 100644 --- a/caffe2/proto/__init__.py +++ b/caffe2/proto/__init__.py @@ -0,0 +1,15 @@ +# NOTE: we have to import python protobuf here **before** we load cpp extension. +# Otherwise it breaks under certain build conditions if cpp implementation of +# protobuf is used. Presumably there's some registry in protobuf library and +# python side has to initialize the dictionary first, before static +# initialization in python extension does so. Otherwise, duplicated protobuf +# descriptors will be created and it can lead to obscure errors like +# "Parameter to MergeFrom() must be instance of same class: +# expected caffe2.NetDef got caffe2.NetDef." +# +# This has to be done for all python targets, so listing them here +from caffe2.proto import caffe2_pb2, metanet_pb2, torch_pb2 +try: + from caffe2.caffe2.fb.session.proto import session_pb2 +except ImportError: + pass diff --git a/caffe2/python/_import_c_extension.py b/caffe2/python/_import_c_extension.py index 4a10d79477d9..dc3e6ea2a2fd 100644 --- a/caffe2/python/_import_c_extension.py +++ b/caffe2/python/_import_c_extension.py @@ -5,6 +5,16 @@ import logging import sys from caffe2.python import extension_loader +# NOTE: we have to import python protobuf here **before** we load cpp extension. +# Otherwise it breaks under certain build conditions if cpp implementation of +# protobuf is used. Presumably there's some registry in protobuf library and +# python side has to initialize the dictionary first, before static +# initialization in python extension does so. Otherwise, duplicated protobuf +# descriptors will be created and it can lead to obscure errors like +# "Parameter to MergeFrom() must be instance of same class: +# expected caffe2.NetDef got caffe2.NetDef." +import caffe2.proto + # We will first try to load the gpu-enabled caffe2. If it fails, we will then # attempt to load the cpu version. The cpu backend is the minimum required, so # if that still fails, we will exit loud. diff --git a/caffe2/python/test/python_protobuf_test.py b/caffe2/python/test/python_protobuf_test.py new file mode 100644 index 000000000000..817f5e21a563 --- /dev/null +++ b/caffe2/python/test/python_protobuf_test.py @@ -0,0 +1,24 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# make sure we use cpp implementation of protobuf +import os +os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "cpp" + +# import cpp extension first +from caffe2.python import core +# then import protobuf +from caffe2.proto import caffe2_pb2, metanet_pb2 + +import unittest + + +class TestCrossProtoCalls(unittest.TestCase): + def testSimple(self): + net = caffe2_pb2.NetDef() + meta = metanet_pb2.MetaNetDef() + # if metanet_pb2 wasn't initialized properly the following fails with a + # cryptic message: "Parameter to MergeFrom() must be instance of same + # class: expected caffe2.NetDef got caffe2.NetDef." + meta.nets.add(key="foo", value=net)