mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
* Move ONNX integration tests from onnx-fb-universe to PyTorch repo
* Switch to use torchvision
* Delete single rnn operator tests, they have been covered in e2e tests in test_caffe2.py
* Mirror the fix in onnx-fb-universe to bypass cuda check
667326d84b
86 lines
3.3 KiB
Python
86 lines
3.3 KiB
Python
import io
|
|
import torch.onnx
|
|
import onnx
|
|
from caffe2.python.onnx.backend import Caffe2Backend
|
|
from caffe2.python.core import BlobReference, Net
|
|
|
|
|
|
# Module-level counter used by PyTorchModule to build a unique blob-name
# prefix ("pytorch_import_<N>/") whenever the caller passes no prefix_name.
_next_idx = 0
|
|
# Clone net takes a dict instead of a lambda
|
|
# It should probably take a lambda, it is more flexible
|
|
# We fake dict here
|
|
|
|
|
|
class _FakeDict(object):
|
|
def __init__(self, fn):
|
|
self.fn = fn
|
|
|
|
def get(self, name, _):
|
|
return self.fn(name)
|
|
|
|
|
|
def PyTorchModule(helper, model, sample_arguments, caffe2_inputs, prefix_name=None):
    """
    Embed an ONNX-exportable PyTorch Model into a Caffe2 model being built.

    Arguments:
        helper (caffe2.python.core.ModelHelper): the model helper where
            this imported network should be inserted
        model (torch.nn.Module): the model to be exported
        sample_arguments (tuple of arguments): the inputs to
            the model, e.g., such that ``model(*args)`` is a valid
            invocation of the model. Any non-Variable arguments will
            be hard-coded into the exported model; any Variable arguments
            will become inputs of the exported model, in the order they
            occur in args. If args is a Variable, this is equivalent
            to having called it with a 1-ary tuple of that Variable.
            (Note: passing keyword arguments to the model is not currently
            supported. Give us a shout if you need it.)
        caffe2_inputs (list of str or caffe2.python.core.BlobReference): the
            caffe2 Blobs that should be inputs to this network, in order.
            Must have the same length as the number of inputs of the
            exported ONNX graph that are not initializers (i.e. the
            inputs that become real model inputs, per the note above).
        prefix_name (str, optional): prefix prepended to the name of every
            blob of the imported network other than the remapped inputs.
            If None, a fresh prefix ``pytorch_import_N/`` is used, where N
            is a module-level counter incremented on each such call.

    Returns:
        A tuple of caffe2.python.core.BlobReference objects referring to the
        model's outputs, in ONNX graph output order. A tuple is returned
        even when the model has a single output (unpack it with
        ``out, = PyTorchModule(...)``).

    Raises:
        ValueError: if ``len(caffe2_inputs)`` does not match the number of
            non-initializer inputs of the exported graph.
    """
    global _next_idx
    if prefix_name is None:
        prefix_name = 'pytorch_import_' + str(_next_idx) + '/'
        _next_idx += 1

    # TODO: handle the case where model cannot be exported
    # and embed as a Python op in Caffe2
    f = io.BytesIO()
    torch.onnx.export(model, sample_arguments, f, export_params=True)
    onnx_model = onnx.load(io.BytesIO(f.getvalue()))
    init_net, predict_net = Caffe2Backend.onnx_graph_to_caffe2_net(onnx_model)

    # Graph inputs that have an initializer are parameters produced by
    # init_net; the remaining ("real") inputs are fed from caffe2_inputs.
    initialized = {x.name for x in onnx_model.graph.initializer}
    real_input_names = [x.name for x in onnx_model.graph.input
                        if x.name not in initialized]
    # Number real inputs by their position *among real inputs*, so that the
    # index lines up with caffe2_inputs even when initializers precede or are
    # interleaved with them in graph.input.
    uninitialized_inputs = {name: i for i, name in enumerate(real_input_names)}

    if len(uninitialized_inputs) != len(caffe2_inputs):
        raise ValueError('Expected {} inputs but found {}'.format(
            len(uninitialized_inputs), len(caffe2_inputs)))

    def remap_blob_name(name):
        # Real inputs are wired to the caller's blobs; every other blob
        # (parameters, intermediates, outputs) is namespaced by prefix_name.
        if name in uninitialized_inputs:
            idx = uninitialized_inputs[name]
            return str(caffe2_inputs[idx])
        return prefix_name + name

    # Net.Clone takes a dict for remapping; _FakeDict adapts remap_blob_name.
    predict_net = Net(predict_net).Clone('anon', _FakeDict(remap_blob_name))
    helper.net.AppendNet(predict_net)

    init_net = Net(init_net).Clone('anon', _FakeDict(remap_blob_name))
    helper.param_init_net.AppendNet(init_net)

    results = tuple(BlobReference(remap_blob_name(x.name), helper.net)
                    for x in onnx_model.graph.output)
    return results
|