Fix bug in exporting node with multiple outputs by scripting

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/20256

Differential Revision: D15422040

Pulled By: houseroad

fbshipit-source-id: 5de2a992d7d99a48905c39a1878eb0b3b68d6a3f
Author: BowenBao
Date: 2019-05-22 16:22:01 -07:00
Committed by: Facebook Github Bot
Parent: c2e3e79afc
Commit: 28be521e39

4 changed files with 28 additions and 8 deletions
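For context, a rough sketch of the scripted, multi-output export path this change targets, modeled on the new test_topk_script case further down. The tensor shape and the use of the public torch.onnx.export entry point (rather than the internal torch.onnx._export helper the tests call) are illustrative assumptions, not part of this diff:

import io

import torch


class TopKModel(torch.jit.ScriptModule):
    @torch.jit.script_method
    def forward(self, input):
        # torch.topk is a single node with two outputs: (values, indices)
        return torch.topk(input, 3, dim=0)


x = torch.randn(4, 3)
f = io.BytesIO()
# A ScriptModule is exported without tracing, so example_outputs supplies the
# expected output structure; here it is the (values, indices) pair from topk.
torch.onnx.export(TopKModel(), (x,), f, example_outputs=torch.topk(x, 3, dim=0))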

Changed file 1 of 4:

@@ -77,6 +77,11 @@ def skipIfEmbed(func):
 def do_export(model, inputs, *args, **kwargs):
     f = io.BytesIO()
     out = torch.onnx._export(model, inputs, f, *args, **kwargs)
+    if isinstance(model, torch.jit.ScriptModule):
+        # Special case for common case of passing a single Tensor
+        if isinstance(inputs, torch.Tensor):
+            inputs = (inputs,)
+        out = model(*inputs)
     return f.getvalue(), out
@@ -178,7 +183,7 @@ class TestCaffe2Backend(unittest.TestCase):
         # Verify the model runs the same in Caffe2
         verify.verify(model, input, c2, rtol=rtol, atol=atol,
-                      do_constant_folding=do_constant_folding)
+                      example_outputs=example_outputs, do_constant_folding=do_constant_folding)

     def run_model_test(self, model, train, batch_size, state_dict=None,
                        input=None, use_gpu=True, rtol=0.001, atol=1e-7,
@@ -1592,10 +1597,19 @@ class TestCaffe2Backend(unittest.TestCase):
         class TopKModel(torch.nn.Module):
             def forward(self, input):
                 return torch.topk(input, 3)

         model = TopKModel()
         x = torch.arange(1., 6.)
         self.run_model_test(TopKModel(), train=False, input=x, batch_size=BATCH_SIZE)

+    def test_topk_script(self):
+        class TopKModel(torch.jit.ScriptModule):
+            @torch.jit.script_method
+            def forward(self, input):
+                return torch.topk(input, 3, dim=0)
+
+        x = torch.randn(4, 3, requires_grad=True)
+        self.run_model_test(TopKModel(), train=False, input=(x,), batch_size=BATCH_SIZE, example_outputs=torch.topk(x, 3, dim=0))
+
     def test_floor(self):
         class FloorModel(torch.nn.Module):
             def forward(self, input):

Changed file 2 of 4:

@@ -244,7 +244,7 @@ def set_training(model, mode):
 def verify(model, args, backend, verbose=False, training=False, rtol=1e-3, atol=1e-7,
-           test_args=2, do_constant_folding=False):
+           test_args=2, do_constant_folding=False, example_outputs=None):
     """
     Export a model into ONNX, import it into a specified ONNX backend, and then
     on a few random inputs verify that PyTorch and the backend produced the same
@@ -358,14 +358,18 @@ def verify(model, args, backend, verbose=False, training=False, rtol=1e-3, atol=
     with set_training(model, training):
         proto_bytes = io.BytesIO()
         torch_out = torch.onnx._export(model, args, proto_bytes, verbose=verbose,
-                                       do_constant_folding=do_constant_folding)
+                                       do_constant_folding=do_constant_folding, example_outputs=example_outputs)
+        if isinstance(model, torch.jit.ScriptModule):
+            torch_out = model(*args)
         proto = load_bytes(proto_bytes)
         prepared = backend.prepare(proto)

         def run(args):
             alt_proto_bytes = io.BytesIO()
             torch_out = torch.onnx._export(model, args, alt_proto_bytes, verbose=verbose,
-                                           do_constant_folding=do_constant_folding)
+                                           do_constant_folding=do_constant_folding, example_outputs=example_outputs)
+            if isinstance(model, torch.jit.ScriptModule):
+                torch_out = model(*args)
             alt_proto = load_bytes(alt_proto_bytes)
             if proto.SerializeToString() != alt_proto.SerializeToString():
                 # OK, let's try to figure out what happened.
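A sketch of how the updated verify() helper would be invoked for a scripted model. The caffe2.python.onnx.backend import and the side-by-side verify module import mirror how the Caffe2 ONNX tests use the helper elsewhere; they are assumptions, not shown in this diff:

import caffe2.python.onnx.backend as c2  # assumed Caffe2 ONNX backend

import torch
import verify  # the helper module patched above


class TopKModel(torch.jit.ScriptModule):
    @torch.jit.script_method
    def forward(self, input):
        return torch.topk(input, 3, dim=0)


x = torch.randn(4, 3)
# For a ScriptModule, verify() now forwards example_outputs to the exporter and
# recomputes the reference outputs by calling the model directly.
verify.verify(TopKModel(), (x,), c2, example_outputs=torch.topk(x, 3, dim=0))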

Changed file 3 of 4:

@@ -223,7 +223,7 @@ static std::shared_ptr<Graph> _propagate_and_assign_input_and_output_shapes(
     output_values = output_values.at(0)->node()->inputs();
   }
   AT_ASSERT(output_values.size() == outputs.size());
-  for (size_t i = 0; i < retval->outputs().size(); ++i) {
+  for (size_t i = 0; i < outputs.size(); ++i) {
     auto scalar_type = outputs[i].scalar_type();
     auto sizes = outputs[i].sizes();
     auto type =

Changed file 4 of 4:

@@ -56,7 +56,7 @@ def set_training(model, mode):
 def export(model, args, f, export_params=True, verbose=False, training=False,
            input_names=None, output_names=None, aten=False, export_raw_ir=False,
            operator_export_type=None, opset_version=None, _retain_param_name=True,
-           do_constant_folding=False, strip_doc_string=True):
+           do_constant_folding=False, example_outputs=None, strip_doc_string=True):
     r"""
     Export a model into ONNX format. This exporter runs your model
     once in order to get a trace of its execution to be exported;
@@ -112,6 +112,8 @@ def export(model, args, f, export_params=True, verbose=False, training=False,
             optimization is applied to the model during export. Constant-folding
             optimization will replace some of the ops that have all constant
             inputs, with pre-computed constant nodes.
+        example_outputs (tuple of Tensors, default None): example_outputs must be provided
+            when exporting a ScriptModule or TorchScript Function.
         strip_doc_string (bool, default True): if True, strips the field
             "doc_string" from the exported model, which information about the stack
             trace.
@@ -128,7 +130,7 @@ def export(model, args, f, export_params=True, verbose=False, training=False,
     _export(model, args, f, export_params, verbose, training, input_names, output_names,
             operator_export_type=operator_export_type, opset_version=opset_version,
             _retain_param_name=_retain_param_name, do_constant_folding=do_constant_folding,
-            strip_doc_string=strip_doc_string)
+            example_outputs=example_outputs, strip_doc_string=strip_doc_string)

 # ONNX can't handle constants that are lists of tensors, which can