Add python bindings (#12253)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/12253

Adding python bindings to unblock DAI development

Reviewed By: duc0

Differential Revision: D10141621

fbshipit-source-id: efac7fb8a0cc787e1c4cc94515e673812529a997
This commit is contained in:
Bram Wasti
2018-10-08 12:20:34 -07:00
committed by Facebook Github Bot
parent e7653c7561
commit 7103d0d938
3 changed files with 45 additions and 4 deletions

View File

@ -12,12 +12,24 @@ import errno
class NNModule(object):
def __init__(self, net=None):
def __init__(self, net=None, device_map=None):
if net is not None:
serialized_proto = None
if isinstance(net, core.Net):
self._NNModule = C.NNModuleFromProtobuf(net.Proto().SerializeToString())
serialized_proto = net.Proto().SerializeToString()
elif isinstance(net, caffe2_pb2.NetDef):
self._NNModule = C.NNModuleFromProtobuf(net.SerializeToString())
serialized_proto = net.SerializeToString()
# Distributed
if device_map is not None:
serialized_device_map = {}
for k in device_map:
serialized_device_map[k] = device_map[k].SerializeToString()
self._NNModule = C.NNModuleFromProtobufDistributed(serialized_proto,
serialized_device_map)
# Default
elif serialized_proto:
self._NNModule = C.NNModuleFromProtobuf(serialized_proto)
else:
raise Exception(
"NNModule can be constructed with core.Net or caffe2_pb2.NetDef types"