diff --git a/test/onnx/test_pytorch_onnx_caffe2.py b/test/onnx/test_pytorch_onnx_caffe2.py index 2504a8778c9a..deecd3630838 100644 --- a/test/onnx/test_pytorch_onnx_caffe2.py +++ b/test/onnx/test_pytorch_onnx_caffe2.py @@ -77,6 +77,11 @@ def skipIfEmbed(func): def do_export(model, inputs, *args, **kwargs): f = io.BytesIO() out = torch.onnx._export(model, inputs, f, *args, **kwargs) + if isinstance(model, torch.jit.ScriptModule): + # Special case for common case of passing a single Tensor + if isinstance(inputs, torch.Tensor): + inputs = (inputs,) + out = model(*inputs) return f.getvalue(), out @@ -178,7 +183,7 @@ class TestCaffe2Backend(unittest.TestCase): # Verify the model runs the same in Caffe2 verify.verify(model, input, c2, rtol=rtol, atol=atol, - do_constant_folding=do_constant_folding) + example_outputs=example_outputs, do_constant_folding=do_constant_folding) def run_model_test(self, model, train, batch_size, state_dict=None, input=None, use_gpu=True, rtol=0.001, atol=1e-7, @@ -1592,10 +1597,19 @@ class TestCaffe2Backend(unittest.TestCase): class TopKModel(torch.nn.Module): def forward(self, input): return torch.topk(input, 3) - model = TopKModel() + x = torch.arange(1., 6.) self.run_model_test(TopKModel(), train=False, input=x, batch_size=BATCH_SIZE) + def test_topk_script(self): + class TopKModel(torch.jit.ScriptModule): + @torch.jit.script_method + def forward(self, input): + return torch.topk(input, 3, dim=0) + + x = torch.randn(4, 3, requires_grad=True) + self.run_model_test(TopKModel(), train=False, input=(x,), batch_size=BATCH_SIZE, example_outputs=torch.topk(x, 3, dim=0)) + def test_floor(self): class FloorModel(torch.nn.Module): def forward(self, input): diff --git a/test/onnx/verify.py b/test/onnx/verify.py index b687a99962c1..b104dca726cb 100644 --- a/test/onnx/verify.py +++ b/test/onnx/verify.py @@ -244,7 +244,7 @@ def set_training(model, mode): def verify(model, args, backend, verbose=False, training=False, rtol=1e-3, atol=1e-7, - test_args=2, do_constant_folding=False): + test_args=2, do_constant_folding=False, example_outputs=None): """ Export a model into ONNX, import it into a specified ONNX backend, and then on a few random inputs verify that PyTorch and the backend produced the same @@ -358,14 +358,18 @@ def verify(model, args, backend, verbose=False, training=False, rtol=1e-3, atol= with set_training(model, training): proto_bytes = io.BytesIO() torch_out = torch.onnx._export(model, args, proto_bytes, verbose=verbose, - do_constant_folding=do_constant_folding) + do_constant_folding=do_constant_folding, example_outputs=example_outputs) + if isinstance(model, torch.jit.ScriptModule): + torch_out = model(*args) proto = load_bytes(proto_bytes) prepared = backend.prepare(proto) def run(args): alt_proto_bytes = io.BytesIO() torch_out = torch.onnx._export(model, args, alt_proto_bytes, verbose=verbose, - do_constant_folding=do_constant_folding) + do_constant_folding=do_constant_folding, example_outputs=example_outputs) + if isinstance(model, torch.jit.ScriptModule): + torch_out = model(*args) alt_proto = load_bytes(alt_proto_bytes) if proto.SerializeToString() != alt_proto.SerializeToString(): # OK, let's try to figure out what happened. diff --git a/torch/csrc/jit/script/init.cpp b/torch/csrc/jit/script/init.cpp index 19b5aa4c7ce8..5003ea21b4ab 100644 --- a/torch/csrc/jit/script/init.cpp +++ b/torch/csrc/jit/script/init.cpp @@ -223,7 +223,7 @@ static std::shared_ptr _propagate_and_assign_input_and_output_shapes( output_values = output_values.at(0)->node()->inputs(); } AT_ASSERT(output_values.size() == outputs.size()); - for (size_t i = 0; i < retval->outputs().size(); ++i) { + for (size_t i = 0; i < outputs.size(); ++i) { auto scalar_type = outputs[i].scalar_type(); auto sizes = outputs[i].sizes(); auto type = diff --git a/torch/onnx/utils.py b/torch/onnx/utils.py index 0a8b884c5749..970fda479850 100644 --- a/torch/onnx/utils.py +++ b/torch/onnx/utils.py @@ -56,7 +56,7 @@ def set_training(model, mode): def export(model, args, f, export_params=True, verbose=False, training=False, input_names=None, output_names=None, aten=False, export_raw_ir=False, operator_export_type=None, opset_version=None, _retain_param_name=True, - do_constant_folding=False, strip_doc_string=True): + do_constant_folding=False, example_outputs=None, strip_doc_string=True): r""" Export a model into ONNX format. This exporter runs your model once in order to get a trace of its execution to be exported; @@ -112,6 +112,8 @@ def export(model, args, f, export_params=True, verbose=False, training=False, optimization is applied to the model during export. Constant-folding optimization will replace some of the ops that have all constant inputs, with pre-computed constant nodes. + example_outputs (tuple of Tensors, default None): example_outputs must be provided + when exporting a ScriptModule or TorchScript Function. strip_doc_string (bool, default True): if True, strips the field "doc_string" from the exported model, which information about the stack trace. @@ -128,7 +130,7 @@ def export(model, args, f, export_params=True, verbose=False, training=False, _export(model, args, f, export_params, verbose, training, input_names, output_names, operator_export_type=operator_export_type, opset_version=opset_version, _retain_param_name=_retain_param_name, do_constant_folding=do_constant_folding, - strip_doc_string=strip_doc_string) + example_outputs=example_outputs, strip_doc_string=strip_doc_string) # ONNX can't handle constants that are lists of tensors, which can