mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-06 00:54:56 +08:00
Add python bindings (#12253)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/12253 Adding python bindings to unblock DAI development Reviewed By: duc0 Differential Revision: D10141621 fbshipit-source-id: efac7fb8a0cc787e1c4cc94515e673812529a997
This commit is contained in:
committed by
Facebook Github Bot
parent
e7653c7561
commit
7103d0d938
@ -12,12 +12,24 @@ import errno
|
||||
|
||||
|
||||
class NNModule(object):
|
||||
def __init__(self, net=None):
|
||||
def __init__(self, net=None, device_map=None):
|
||||
if net is not None:
|
||||
serialized_proto = None
|
||||
if isinstance(net, core.Net):
|
||||
self._NNModule = C.NNModuleFromProtobuf(net.Proto().SerializeToString())
|
||||
serialized_proto = net.Proto().SerializeToString()
|
||||
elif isinstance(net, caffe2_pb2.NetDef):
|
||||
self._NNModule = C.NNModuleFromProtobuf(net.SerializeToString())
|
||||
serialized_proto = net.SerializeToString()
|
||||
|
||||
# Distributed
|
||||
if device_map is not None:
|
||||
serialized_device_map = {}
|
||||
for k in device_map:
|
||||
serialized_device_map[k] = device_map[k].SerializeToString()
|
||||
self._NNModule = C.NNModuleFromProtobufDistributed(serialized_proto,
|
||||
serialized_device_map)
|
||||
# Default
|
||||
elif serialized_proto:
|
||||
self._NNModule = C.NNModuleFromProtobuf(serialized_proto)
|
||||
else:
|
||||
raise Exception(
|
||||
"NNModule can be constructed with core.Net or caffe2_pb2.NetDef types"
|
||||
|
||||
Reference in New Issue
Block a user