mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Enforce import order to make protobuf cpp implementation in python work (#18560)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/18560 We have to import python protobuf here **before** we load cpp extension. Otherwise it breaks under certain build conditions if cpp implementation of protobuf is used. Presumably there's some registry in protobuf library and python side has to initialize the dictionary first, before static initialization in python extension does so. Otherwise, duplicated protobuf descriptors will be created and it can lead to obscure errors like Parameter to MergeFrom() must be instance of same class: expected caffe2.NetDef got caffe2.NetDef. I think it also fixes https://github.com/facebookarchive/caffe2/issues/1573 Reviewed By: ezyang, iroot900 Differential Revision: D14622054 fbshipit-source-id: 2499eb88ecdee85ff8d845859048f7ae5da2a480
This commit is contained in:
committed by
Facebook Github Bot
parent
3b71f2e1f2
commit
3af2d6d904
@ -91,9 +91,11 @@ class TensorboardTest(unittest.TestCase):
|
||||
for i, (event, net) in enumerate(zip(events, nets), start=1):
|
||||
self.assertEqual(event.step, i)
|
||||
self.assertEqual(event.wall_time, i)
|
||||
self.assertEqual(
|
||||
event.graph_def,
|
||||
tb_exporter.nets_to_graph_def([net]).SerializeToString())
|
||||
g = tf.GraphDef()
|
||||
g.ParseFromString(event.graph_def)
|
||||
self.assertMultiLineEqual(
|
||||
str(g),
|
||||
str(tb_exporter.nets_to_graph_def([net])))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -0,0 +1,15 @@
|
||||
# NOTE: we have to import python protobuf here **before** we load cpp extension.
|
||||
# Otherwise it breaks under certain build conditions if cpp implementation of
|
||||
# protobuf is used. Presumably there's some registry in protobuf library and
|
||||
# python side has to initialize the dictionary first, before static
|
||||
# initialization in python extension does so. Otherwise, duplicated protobuf
|
||||
# descriptors will be created and it can lead to obscure errors like
|
||||
# "Parameter to MergeFrom() must be instance of same class:
|
||||
# expected caffe2.NetDef got caffe2.NetDef."
|
||||
#
|
||||
# This has to be done for all python targets, so listing them here
|
||||
from caffe2.proto import caffe2_pb2, metanet_pb2, torch_pb2
|
||||
try:
|
||||
from caffe2.caffe2.fb.session.proto import session_pb2
|
||||
except ImportError:
|
||||
pass
|
||||
|
@ -5,6 +5,16 @@ import logging
|
||||
import sys
|
||||
from caffe2.python import extension_loader
|
||||
|
||||
# NOTE: we have to import python protobuf here **before** we load cpp extension.
|
||||
# Otherwise it breaks under certain build conditions if cpp implementation of
|
||||
# protobuf is used. Presumably there's some registry in protobuf library and
|
||||
# python side has to initialize the dictionary first, before static
|
||||
# initialization in python extension does so. Otherwise, duplicated protobuf
|
||||
# descriptors will be created and it can lead to obscure errors like
|
||||
# "Parameter to MergeFrom() must be instance of same class:
|
||||
# expected caffe2.NetDef got caffe2.NetDef."
|
||||
import caffe2.proto
|
||||
|
||||
# We will first try to load the gpu-enabled caffe2. If it fails, we will then
|
||||
# attempt to load the cpu version. The cpu backend is the minimum required, so
|
||||
# if that still fails, we will exit loud.
|
||||
|
24
caffe2/python/test/python_protobuf_test.py
Normal file
24
caffe2/python/test/python_protobuf_test.py
Normal file
@ -0,0 +1,24 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
# make sure we use cpp implementation of protobuf
|
||||
import os
|
||||
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "cpp"
|
||||
|
||||
# import cpp extension first
|
||||
from caffe2.python import core
|
||||
# then import protobuf
|
||||
from caffe2.proto import caffe2_pb2, metanet_pb2
|
||||
|
||||
import unittest
|
||||
|
||||
|
||||
class TestCrossProtoCalls(unittest.TestCase):
|
||||
def testSimple(self):
|
||||
net = caffe2_pb2.NetDef()
|
||||
meta = metanet_pb2.MetaNetDef()
|
||||
# if metanet_pb2 wasn't initialized properly the following fails with a
|
||||
# cryptic message: "Parameter to MergeFrom() must be instance of same
|
||||
# class: expected caffe2.NetDef got caffe2.NetDef."
|
||||
meta.nets.add(key="foo", value=net)
|
Reference in New Issue
Block a user