mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
PyTorch restricts activations to be in the range (0, 127). In ONNX, the supported ranges are (0, 255) and (-128, 127), respectively, uint8 and int8. This PR extends support for range (0, 127), by adding additional clipping when detected. Pull Request resolved: https://github.com/pytorch/pytorch/pull/76055 Approved by: https://github.com/garymm
11129 lines
420 KiB
Python
11129 lines
420 KiB
Python
# Owner(s): ["module: onnx"]
|
|
|
|
import unittest
|
|
import onnxruntime
|
|
import torch
|
|
import torchvision
|
|
|
|
import numpy as np
|
|
import io
|
|
import itertools
|
|
import copy
|
|
import os
|
|
import random
|
|
|
|
import model_defs.word_language_model as word_language_model
|
|
import onnx
|
|
|
|
import torch.nn.functional as F
|
|
from torch.nn.utils import rnn as rnn_utils
|
|
from model_defs.lstm_flattening_result import (LstmFlatteningResultWithSeqLength,
|
|
LstmFlatteningResultWithoutSeqLength)
|
|
from model_defs.rnn_model_with_packed_sequence import (RnnModelWithPackedSequence,
|
|
RnnModelWithPackedSequenceWithState,
|
|
RnnModelWithPackedSequenceWithoutState)
|
|
from test_pytorch_common import (skipIfUnsupportedMinOpsetVersion, skipIfUnsupportedOpsetVersion,
|
|
skipIfNoLapack, disableScriptTest, skipIfUnsupportedMaxOpsetVersion)
|
|
from test_pytorch_common import BATCH_SIZE
|
|
from test_pytorch_common import RNN_BATCH_SIZE, RNN_SEQUENCE_LENGTH, RNN_INPUT_SIZE, RNN_HIDDEN_SIZE
|
|
from typing import List, Tuple, Optional, Dict, Union
|
|
from torch import Tensor
|
|
|
|
from torchvision import ops
|
|
from torchvision.models.detection.image_list import ImageList
|
|
from torchvision.models.detection.transform import GeneralizedRCNNTransform
|
|
from torchvision.models.detection.rpn import AnchorGenerator, RPNHead, RegionProposalNetwork
|
|
from torchvision.models.detection.roi_heads import RoIHeads
|
|
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor, TwoMLPHead
|
|
from collections import OrderedDict
|
|
|
|
from torch.nn.utils.rnn import PackedSequence
|
|
from torch.onnx import CheckerError, register_custom_op_symbolic, unregister_custom_op_symbolic
|
|
from torch.onnx.symbolic_helper import _unimplemented
|
|
from torch.onnx.utils import unpack_quantized_tensor
|
|
|
|
|
|
_ORT_PROVIDERS = ["CPUExecutionProvider"]
|
|
|
|
def flatten_tuples(elem):
    """Recursively expand the tuples nested in ``elem`` into one flat list.

    Only ``tuple`` members are descended into; every other item (lists included)
    is kept as a single entry. The original order is preserved.
    """
    flattened = []
    for item in elem:
        if not isinstance(item, tuple):
            flattened.append(item)
            continue
        flattened.extend(flatten_tuples(item))
    return flattened
|
|
|
|
|
|
def to_numpy(elem):
    """Convert a (possibly nested) test value into NumPy form.

    Args:
        elem: A ``torch.Tensor``, a list/tuple of supported values, a Python
            ``bool``/``int``/``float``, or a ``dict`` of supported keys and values.

    Returns:
        A ``numpy.ndarray`` for tensors and scalars, a list of converted items for
        lists and tuples, and a flat ``[key, value, key, value, ...]`` list of
        converted entries for dicts.

    Raises:
        RuntimeError: If ``elem`` has a type that is not handled above.
    """
    if isinstance(elem, torch.Tensor):
        # Tensors that track gradients must be detached before numpy() is called.
        if elem.requires_grad:
            return elem.detach().cpu().numpy()
        return elem.cpu().numpy()
    if isinstance(elem, (list, tuple)):
        return [to_numpy(inp) for inp in elem]
    # bool has to be tested before int because bool is a subclass of int.
    if isinstance(elem, bool):
        return np.array(elem, dtype=bool)
    if isinstance(elem, int):
        return np.array(elem, dtype=int)
    if isinstance(elem, float):
        return np.array(elem, dtype=float)
    if isinstance(elem, dict):
        flattened = []
        for key, value in elem.items():
            flattened += [to_numpy(key), to_numpy(value)]
        return flattened
    # Raise rather than return: a returned exception object would silently be
    # passed on to the caller as if it were converted data.
    raise RuntimeError("Input has unknown type.")
|
|
|
|
|
|
def convert_to_onnx(model, input=None, opset_version=9, do_constant_folding=True,
                    keep_initializers_as_inputs=True, dynamic_axes=None,
                    input_names=None, output_names=None,
                    fixed_batch_size=False, training=None,
                    verbose=False):
    """Export ``model`` to ONNX in memory and open the result with ONNX Runtime.

    The export goes through ``torch.onnx._export`` into a ``BytesIO`` buffer, using
    a deep copy of ``input``. All remaining keyword arguments are forwarded to the
    exporter unchanged.

    Returns:
        An ``onnxruntime.InferenceSession`` created from the exported bytes with
        the providers listed in ``_ORT_PROVIDERS``.
    """
    export_options = dict(
        opset_version=opset_version,
        do_constant_folding=do_constant_folding,
        keep_initializers_as_inputs=keep_initializers_as_inputs,
        dynamic_axes=dynamic_axes,
        input_names=input_names,
        output_names=output_names,
        fixed_batch_size=fixed_batch_size,
        training=training,
        verbose=verbose,
    )
    buffer = io.BytesIO()
    torch.onnx._export(model, copy.deepcopy(input), buffer, **export_options)

    # compute onnxruntime output prediction
    session_options = onnxruntime.SessionOptions()
    # suppress ort warnings.
    # 0:Verbose, 1:Info, 2:Warning. 3:Error, 4:Fatal. Default is 2.
    session_options.log_severity_level = 3
    return onnxruntime.InferenceSession(buffer.getvalue(), session_options, providers=_ORT_PROVIDERS)
|
|
|
|
|
|
def inline_flatten_list(inputs, res_list):
    """Append the leaves of ``inputs`` to ``res_list`` depth-first and return it.

    Lists and tuples are descended into recursively; anything else is treated as
    a leaf. ``res_list`` is modified in place and is also the return value.
    """
    for entry in inputs:
        if isinstance(entry, (list, tuple)):
            inline_flatten_list(entry, res_list)
        else:
            res_list.append(entry)
    return res_list
|
|
|
|
|
|
def unpack_to_numpy(value):
    """Unpack every item of ``value`` and convert the resulting pieces to NumPy.

    Each item is passed through ``unpack_quantized_tensor``; the pieces it yields
    are concatenated in order and then converted one by one with ``to_numpy``.
    """
    # Finish all unpacking before any conversion starts.
    pieces = [piece for item in value for piece in unpack_quantized_tensor(item)]
    return [to_numpy(piece) for piece in pieces]
|
|
|
|
|
|
def run_ort(ort_sess, input):
    """Run ``ort_sess`` on ``input`` and return its outputs as one flat list.

    Nested tuples in ``input`` are flattened and converted to NumPy, then bound by
    position to the inputs reported by ``ort_sess.get_inputs()``.
    """
    np_inputs = unpack_to_numpy(flatten_tuples(input))
    feeds = {
        ort_sess.get_inputs()[position].name: array
        for position, array in enumerate(np_inputs)
    }
    raw_outputs = ort_sess.run(None, feeds)
    return inline_flatten_list(raw_outputs, [])
|
|
|
|
|
|
def ort_compare_with_pytorch(ort_outs, output, rtol, atol):
    """Assert that the ONNX Runtime outputs match the PyTorch outputs.

    ``output`` is flattened with ``torch.jit._flatten`` and converted to NumPy
    before being compared pairwise with ``ort_outs``; ``rtol`` and ``atol`` are
    passed to ``np.testing.assert_allclose``.
    """
    flat_output, _ = torch.jit._flatten(output)
    expected = unpack_to_numpy(flat_output)

    # compare onnxruntime and PyTorch results
    assert len(expected) == len(ort_outs), "number of outputs differ"

    for torch_out, ort_out in zip(expected, ort_outs):
        np.testing.assert_allclose(torch_out, ort_out, rtol=rtol, atol=atol)
|
|
|
|
|
|
def run_model_test(self, model, batch_size=2, state_dict=None,
                   input=None, use_gpu=True, rtol=0.001, atol=1e-7,
                   do_constant_folding=True, dynamic_axes=None,
                   test_with_inputs=None, input_names=None,
                   output_names=None, fixed_batch_size=False,
                   dict_check=True, training=None,
                   remained_onnx_input_idx=None, flatten=True,
                   verbose=False):
    """Export ``model`` to ONNX and check ONNX Runtime against PyTorch.

    The model is run in PyTorch on a copy of ``input``, exported with
    ``convert_to_onnx`` (using ``self.opset_version`` and
    ``self.keep_initializers_as_inputs``), executed with ONNX Runtime, and the two
    sets of outputs are compared with ``rtol``/``atol``. Every entry of
    ``test_with_inputs`` is then run through the same exported session.

    Notes:
        * ``batch_size`` is only used to build the default random input.
        * ``state_dict`` and ``use_gpu`` are accepted but not used by this function.
        * With ``dict_check`` set, a trailing dict in ``input`` is passed to the
          model as keyword arguments.
        * ``remained_onnx_input_idx`` selects which positional inputs are fed to
          the exported model; ``None`` feeds all of them.
        * ``flatten`` controls whether inputs go through ``torch.jit._flatten``
          before being handed to ONNX Runtime.
    """
    if training == torch.onnx.TrainingMode.TRAINING:
        model.train()
    elif training is None or training == torch.onnx.TrainingMode.EVAL:
        model.eval()

    if input is None:
        input = torch.randn(batch_size, 3, 224, 224, requires_grad=True)

    def _select_onnx_inputs(values):
        # Keep only the positions that remain inputs of the exported model.
        if remained_onnx_input_idx is None:
            return values
        return [values[idx] for idx in remained_onnx_input_idx]

    with torch.no_grad():
        # A lone tensor or dict is wrapped so that the input is always a tuple.
        if isinstance(input, (torch.Tensor, dict)):
            input = (input,)
        # In-place operators will update input tensor data as well.
        # Thus inputs are replicated before every forward call.
        input_args = copy.deepcopy(input)
        input_kwargs = {}
        if dict_check and isinstance(input_args[-1], dict):
            input_args, input_kwargs = input_args[:-1], input_args[-1]
        try:
            model_copy = copy.deepcopy(model)
            output = model_copy(*input_args, **input_kwargs)
        except Exception:
            # Fall back to the original model when the copy cannot be made or run.
            output = model(*input_args, **input_kwargs)
        if isinstance(output, torch.Tensor):
            output = (output,)

        if not dict_check and isinstance(input[-1], dict):
            input = input + ({},)

        ort_sess = convert_to_onnx(model, input=input, opset_version=self.opset_version,
                                   do_constant_folding=do_constant_folding,
                                   keep_initializers_as_inputs=self.keep_initializers_as_inputs,
                                   dynamic_axes=dynamic_axes, input_names=input_names,
                                   output_names=output_names, fixed_batch_size=fixed_batch_size,
                                   training=training, verbose=verbose)

        # compute onnxruntime output prediction
        input_copy = copy.deepcopy(_select_onnx_inputs(input))
        if flatten:
            input_copy, _ = torch.jit._flatten(input_copy)
        ort_compare_with_pytorch(run_ort(ort_sess, input_copy), output, rtol, atol)

        # if additional test inputs are provided run the onnx
        # model with these inputs and check the outputs
        if test_with_inputs is not None:
            for test_input in test_with_inputs:
                if isinstance(test_input, torch.Tensor):
                    test_input = (test_input,)
                output = model(*copy.deepcopy(test_input))
                if isinstance(output, torch.Tensor):
                    output = (output,)
                test_input = _select_onnx_inputs(test_input)
                if flatten:
                    test_input, _ = torch.jit._flatten(test_input)
                ort_compare_with_pytorch(run_ort(ort_sess, test_input), output, rtol, atol)
|
|
|
|
|
|
def _init_test_generalized_rcnn_transform():
    """Build the ``GeneralizedRCNNTransform`` fixture with fixed size and normalisation values."""
    return GeneralizedRCNNTransform(
        100,  # min_size
        200,  # max_size
        [0.485, 0.456, 0.406],  # image_mean
        [0.229, 0.224, 0.225],  # image_std
    )
|
|
|
|
|
|
def _init_test_rpn():
    """Build the ``RegionProposalNetwork`` fixture with fixed hyper-parameters."""
    anchor_sizes = ((32,), (64,), (128,), (256,), (512,))
    # One identical set of aspect ratios per anchor size.
    aspect_ratios = tuple((0.5, 1.0, 2.0) for _ in anchor_sizes)
    anchor_generator = AnchorGenerator(anchor_sizes, aspect_ratios)
    head = RPNHead(256, anchor_generator.num_anchors_per_location()[0])

    pre_nms_top_n = {"training": 2000, "testing": 1000}
    post_nms_top_n = {"training": 2000, "testing": 1000}

    return RegionProposalNetwork(
        anchor_generator, head,
        0.7,  # foreground IoU threshold
        0.3,  # background IoU threshold
        256,  # batch size per image
        0.5,  # positive fraction
        pre_nms_top_n, post_nms_top_n,
        0.7,  # NMS threshold
        score_thresh=0.0)
|
|
|
|
def _init_test_roi_heads_faster_rcnn():
    """Build the ``RoIHeads`` fixture (RoI align, two-layer MLP head, box predictor)."""
    out_channels = 256
    num_classes = 91
    representation_size = 1024

    box_roi_pool = ops.MultiScaleRoIAlign(
        featmap_names=["0", "1", "2", "3"], output_size=7, sampling_ratio=2)
    resolution = box_roi_pool.output_size[0]

    box_head = TwoMLPHead(out_channels * resolution ** 2, representation_size)
    box_predictor = FastRCNNPredictor(representation_size, num_classes)

    return RoIHeads(
        box_roi_pool, box_head, box_predictor,
        0.5,   # foreground IoU threshold
        0.5,   # background IoU threshold
        512,   # batch size per image
        0.25,  # positive fraction
        None,  # bbox regression weights
        0.05,  # score threshold
        0.5,   # NMS threshold
        100)   # detections per image
|
|
|
|
def _construct_tensor_for_quantization_test(shape: Tuple[int],
|
|
offset: Optional[Union[int, float]] = None,
|
|
max_val: Optional[Union[int, float]] = None
|
|
) -> torch.Tensor:
|
|
"""Helper function to generate weights and test inputs in a deterministic way.
|
|
|
|
Due to difference in implementation details between PyTorch and ONNXRuntime, randomly generated
|
|
test data for quantization tests can be flaky. To help stablize the test, this helper function is
|
|
used to generate weights and test inputs in a deterministic way.
|
|
|
|
Args:
|
|
shape (Tuple[int]): Shape for tensor to construct.
|
|
offset (Optional[Union[int, float]]): Offset to be added to the generated tensor.
|
|
max_val (Optional[Union[int, float]]): If any element within tensor has a larger absolute value than
|
|
max_val, the tensor will be scaled by max_val / tensor.abs().max(). This step is done after
|
|
applying offset.
|
|
"""
|
|
tensor = torch.arange(np.prod(shape), dtype=torch.float).view(shape)
|
|
if offset is not None:
|
|
tensor = tensor + offset
|
|
if max_val is not None and tensor.abs().max() > max_val:
|
|
tensor = tensor * max_val / tensor.abs().max()
|
|
return tensor
|
|
|
|
def set_rng_seed(seed):
    """Seed the PyTorch, ``random`` and NumPy generators with ``seed``, in that order."""
    for seed_fn in (torch.manual_seed, random.seed, np.random.seed):
        seed_fn(seed)
|
|
|
|
class _TestONNXRuntime:
|
|
"""Abstract base class for test cases.
|
|
|
|
Intentionally not a sub-class of unittest.TestCase so that unittest / pytest
|
|
don't run it directly. unittest.TestCase is mixed in as another base class when
|
|
creating concrete sub-types. See MakeTestCase().
|
|
"""
|
|
opset_version = -1 # Sub-classes must override
|
|
keep_initializers_as_inputs = True # For IR version 3 type export.
|
|
|
|
def setUp(self):
    """Seed the random generators and enable script-mode testing before each test."""
    seed = 0
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    np.random.seed(seed=seed)
    os.environ["ALLOW_RELEASED_ONNX_OPSET_ONLY"] = "0"
    self.is_script_test_enabled = True
|
|
|
|
# The exported ONNX model may have less inputs than the pytorch model because of const folding.
|
|
# This mostly happens in unit test, where we widely use torch.size or torch.shape.
|
|
# So the output is only dependent on the input shape, not value.
|
|
# remained_onnx_input_idx is used to indicate which pytorch model input idx is remained in ONNX model.
|
|
def run_test(self, model, input, rtol=1e-3, atol=1e-7, do_constant_folding=True,
             batch_size=2, use_gpu=True, dynamic_axes=None, test_with_inputs=None,
             input_names=None, output_names=None, fixed_batch_size=False, dict_check=True,
             training=None, remained_onnx_input_idx=None, verbose=False):
    """Check ``model`` with ``run_model_test`` in scripting mode and then in tracing mode.

    The scripting pass runs only when ``self.is_script_test_enabled`` is set and
    ``model`` is not already a ``torch.jit.ScriptModule``; it uses
    ``torch.jit.script(model)`` with ``flatten=False``. The tracing pass always runs.

    ``remained_onnx_input_idx`` may be one value used for both passes, or a dict
    with ``'scripting'`` and ``'tracing'`` entries. All other arguments are
    forwarded to ``run_model_test`` unchanged.
    """
    shared_kwargs = dict(
        batch_size=batch_size, input=input, use_gpu=use_gpu, rtol=rtol, atol=atol,
        do_constant_folding=do_constant_folding,
        dynamic_axes=dynamic_axes, test_with_inputs=test_with_inputs,
        input_names=input_names, output_names=output_names,
        fixed_batch_size=fixed_batch_size, dict_check=dict_check,
        training=training, verbose=verbose)

    if isinstance(remained_onnx_input_idx, dict):
        scripting_idx = remained_onnx_input_idx['scripting']
        tracing_idx = remained_onnx_input_idx['tracing']
    else:
        scripting_idx = tracing_idx = remained_onnx_input_idx

    if self.is_script_test_enabled and not isinstance(model, torch.jit.ScriptModule):
        script_model = torch.jit.script(model)
        run_model_test(self, script_model, remained_onnx_input_idx=scripting_idx,
                       flatten=False, **shared_kwargs)

    run_model_test(self, model, remained_onnx_input_idx=tracing_idx,
                   flatten=True, **shared_kwargs)
|
|
|
|
def run_model_test_with_external_data(self, model, input, rtol=0.001, atol=1e-7,
                                      do_constant_folding=True, dynamic_axes=None,
                                      input_names=None, output_names=None,
                                      ort_optim_on=True, training=None):
    """Export ``model`` to a file in a temporary directory and check it with ONNX Runtime.

    Unlike the in-memory path, the model is written with ``torch.onnx.export`` to
    ``model.onnx`` inside a ``tempfile.TemporaryDirectory`` and loaded from that
    path. ``ort_optim_on`` selects between the ``ORT_ENABLE_EXTENDED`` and
    ``ORT_DISABLE_ALL`` graph optimization levels. Outputs are compared with
    ``rtol``/``atol``.
    """
    import os
    import tempfile

    if training == torch.onnx.TrainingMode.TRAINING:
        model.train()
    elif training is None or training == torch.onnx.TrainingMode.EVAL:
        model.eval()

    with torch.no_grad():
        if isinstance(input, torch.Tensor):
            input = (input,)
        # In-place operators will update input tensor data as well.
        # Thus inputs are replicated before every forward call.
        output = model(*copy.deepcopy(input))
        if isinstance(output, torch.Tensor):
            output = (output,)

        # export the model to ONNX
        with tempfile.TemporaryDirectory() as tmpdirname:
            model_file_name = os.path.join(tmpdirname, "model.onnx")
            torch.onnx.export(model, copy.deepcopy(input), model_file_name,
                              opset_version=self.opset_version,
                              verbose=False,
                              do_constant_folding=do_constant_folding,
                              keep_initializers_as_inputs=self.keep_initializers_as_inputs,
                              dynamic_axes=dynamic_axes,
                              input_names=input_names, output_names=output_names)

            # compute onnxruntime output prediction
            session_options = onnxruntime.SessionOptions()
            if ort_optim_on:
                session_options.graph_optimization_level = \
                    onnxruntime.GraphOptimizationLevel.ORT_ENABLE_EXTENDED
            else:
                session_options.graph_optimization_level = \
                    onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL
            # suppress ort warnings.
            # 0:Verbose, 1:Info, 2:Warning. 3:Error, 4:Fatal. Default is 2.
            session_options.log_severity_level = 3
            ort_sess = onnxruntime.InferenceSession(model_file_name,
                                                    sess_options=session_options,
                                                    providers=_ORT_PROVIDERS)
            ort_outs = run_ort(ort_sess, copy.deepcopy(input))
            ort_compare_with_pytorch(ort_outs, output, rtol, atol)
|
|
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)  # Because external data format was released with Opset 9.
def test_embedding_model_with_external_data(self):
    """Run a small Embedding + Linear model through the file-based export check."""
    class LargeModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            dim = 15
            n = 4 * 100
            self.emb = torch.nn.Embedding(n, dim)
            self.lin1 = torch.nn.Linear(dim, 1)
            self.seq = torch.nn.Sequential(self.emb, self.lin1)

        def forward(self, input):
            return self.seq(input)

    indices = torch.tensor([2], dtype=torch.long)
    self.run_model_test_with_external_data(LargeModel(), indices)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)  # Because external data format was released with Opset 9.
def test_large_model_with_external_data(self):
    """Run an Embedding + Linear model with a very large embedding table through the file-based export check."""
    class LargeModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            dim = 5
            n = 40 * 4 * 10 ** 6
            self.emb = torch.nn.Embedding(n, dim)
            self.lin1 = torch.nn.Linear(dim, 1)
            self.seq = torch.nn.Sequential(self.emb, self.lin1)

        def forward(self, input):
            return self.seq(input)

    indices = torch.tensor([2], dtype=torch.long)
    self.run_model_test_with_external_data(LargeModel(), indices)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)  # Because external data format was released with Opset 9.
def test_large_model_with_non_str_file(self):
    """Exporting the large model into an in-memory buffer must raise the expected RuntimeError."""
    class LargeModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            dim = 5
            n = 40 * 4 * 10 ** 6
            self.emb = torch.nn.Embedding(n, dim)
            self.lin1 = torch.nn.Linear(dim, 1)
            self.seq = torch.nn.Sequential(self.emb, self.lin1)

        def forward(self, input):
            return self.seq(input)

    indices = torch.tensor([2], dtype=torch.long)
    buffer = io.BytesIO()
    err_msg = ("The serialized model is larger than the 2GiB limit imposed by the protobuf library. "
               "Therefore the output file must be a file path, so that the ONNX external data can be written to "
               "the same directory. Please specify the output file name.")
    with self.assertRaisesRegex(RuntimeError, err_msg):
        torch.onnx.export(LargeModel(), indices, buffer)
|
|
|
|
def test_fuse_conv_bn1d(self):
    """Export a Conv1d followed by BatchNorm1d and compare with ONNX Runtime."""
    class Fuse(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv1d(16, 33, 3, stride=2)
            self.bn = torch.nn.BatchNorm1d(33)

        def forward(self, x):
            return self.bn(self.conv(x))

    module = Fuse()
    sample = torch.randn(20, 16, 50, requires_grad=True)
    self.run_test(module, (sample,))
|
|
|
|
def test_fuse_conv_bn2d(self):
    """Export a bias-free Conv2d followed by BatchNorm2d and compare with ONNX Runtime."""
    class Fuse(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 2, kernel_size=1, stride=2, padding=3, bias=False)
            self.bn = torch.nn.BatchNorm2d(2)

        def forward(self, x):
            return self.bn(self.conv(x))

    module = Fuse()
    sample = torch.randn(2, 3, 2, 2, requires_grad=True)
    self.run_test(module, (sample,))
|
|
|
|
def test_fuse_conv_bn3d(self):
    """Export a bias-free Conv3d followed by BatchNorm3d and compare with ONNX Runtime."""
    class Fuse(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv3d(3, 2, (3, 5, 2), stride=(2, 1, 1), padding=(3, 2, 0), bias=False)
            self.bn = torch.nn.BatchNorm3d(2)

        def forward(self, x):
            return self.bn(self.conv(x))

    module = Fuse()
    sample = torch.randn(2, 3, 10, 50, 100, requires_grad=True)
    self.run_test(module, (sample,), rtol=1e-3, atol=1e-6)
|
|
|
|
def test_fuse_conv_in_block(self):
    """Export a scripted module whose Conv1d + BatchNorm1d sit inside a conditional block."""
    class Fuse(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv1d(in_channels=5, out_channels=5, kernel_size=3,
                                        stride=1, padding=2, dilation=1)
            self.bn = torch.nn.BatchNorm1d(5)

        # forward is compiled by torch.jit.script below, so its control flow is
        # kept exactly as written.
        def forward(self, x):
            results_available = True

            if x.sum() > -1:
                results_available = False

            if results_available:
                x = self.conv(x)
                x = self.bn(x)

            return x

    scripted = torch.jit.script(Fuse())
    sample = torch.randn(2, 5, 9, requires_grad=True)
    self.run_test(scripted, (sample,),
                  input_names=['x'], dynamic_axes={'x': [0, 2]},
                  rtol=1e-3, atol=1e-6)
|
|
|
|
def test_conv_tbc(self):
    """Export a module that wraps ``torch.conv_tbc`` and compare with ONNX Runtime."""
    from torch.nn.modules.utils import _single

    class ConvTBC(torch.nn.Module):
        def __init__(self, in_channels, out_channels, kernel_size, padding=0):
            super().__init__()
            self.in_channels = in_channels
            self.out_channels = out_channels
            self.kernel_size = _single(kernel_size)
            self.padding = _single(padding)

            weight_shape = (self.kernel_size[0], in_channels, out_channels)
            self.weight = torch.nn.Parameter(torch.Tensor(*weight_shape))
            self.bias = torch.nn.Parameter(torch.Tensor(out_channels))
            self.reset_parameters()

        def reset_parameters(self):
            torch.nn.init.xavier_normal_(self.weight)
            torch.nn.init.zeros_(self.bias)

        def conv_tbc(self, input):
            return torch.conv_tbc(
                input.contiguous(), self.weight, self.bias, self.padding[0]
            )

        def forward(self, input):
            return self.conv_tbc(input)

    in_channels, out_channels, kernel_size = 3, 5, 5
    module = ConvTBC(in_channels, out_channels, kernel_size, padding=0)
    sample = torch.randn(10, 7, in_channels, requires_grad=True)
    self.run_test(module, (sample,), atol=1e-5)
|
|
|
|
def test_reshape_constant_fold(self):
    """Export a module that multiplies its input by a reshaped registered buffer."""
    class Reshape(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.register_buffer("weight", torch.ones(5))

        def forward(self, x):
            scale_1 = self.weight.reshape(1, -1, 1, 1)
            return x * scale_1

    sample = torch.randn(4, 5)
    self.run_test(Reshape(), (sample,), rtol=1e-3, atol=1e-5)
|
|
|
|
def run_word_language_model(self, model_name):
    """Build the word-language model selected by ``model_name`` and run the export check.

    ``"GRU"`` selects ``RNNModelWithTensorHidden``, ``"LSTM"`` selects
    ``RNNModelWithTupleHidden``, and any other name selects ``RNNModel``.
    """
    ntokens, emsize, nhid, nlayers = 50, 5, 5, 5
    dropout = 0.2
    tied = False
    batchsize = 5

    if model_name == "GRU":
        model_cls = word_language_model.RNNModelWithTensorHidden
    elif model_name == "LSTM":
        model_cls = word_language_model.RNNModelWithTupleHidden
    else:
        model_cls = word_language_model.RNNModel
    model = model_cls(model_name, ntokens, emsize, nhid, nlayers, dropout, tied, batchsize)

    x = torch.arange(0, ntokens).long().view(-1, batchsize)
    # Only support CPU version, since tracer is not working in GPU RNN.
    self.run_test(model, (x, model.hidden))
|
|
|
|
def get_image(self, rel_path: str, size: Tuple[int, int]) -> torch.Tensor:
    """Load an image from the ``assets`` directory next to this file as a tensor.

    Args:
        rel_path: ``/``-separated path of the image below ``assets``.
        size: Size passed to ``Image.resize`` (bilinear resampling).

    Returns:
        The RGB-converted, resized image converted with ``transforms.ToTensor``.
    """
    import os
    from PIL import Image
    from torchvision import transforms

    asset_path = os.path.join(os.path.dirname(__file__), "assets", *rel_path.split("/"))
    rgb_image = Image.open(asset_path).convert("RGB")
    resized = rgb_image.resize(size, Image.BILINEAR)
    return transforms.ToTensor()(resized)
|
|
|
|
def get_test_images(self) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
    """Return two single-image lists loaded through ``get_image``."""
    images = [self.get_image("grace_hopper_517x606.jpg", (100, 320))]
    test_images = [self.get_image("rgb_pytorch.png", (250, 380))]
    return (images, test_images)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(11)
@disableScriptTest()  # Faster RCNN model is not scriptable
def test_faster_rcnn(self):
    """Export torchvision's ``fasterrcnn_resnet50_fpn`` and check it on several image inputs."""
    def named_io():
        # Built afresh for every call so that no two exports share argument objects.
        return dict(input_names=["images_tensors"], output_names=["outputs"],
                    dynamic_axes={"images_tensors": [0, 1, 2], "outputs": [0, 1, 2]})

    tolerances = dict(rtol=1e-3, atol=1e-5)

    model = torchvision.models.detection.faster_rcnn.fasterrcnn_resnet50_fpn(
        pretrained=False, pretrained_backbone=True, min_size=200, max_size=300)
    model.eval()
    x1 = torch.randn(3, 200, 300, requires_grad=True)
    x2 = torch.randn(3, 200, 300, requires_grad=True)
    self.run_test(model, ([x1, x2],), **tolerances)
    self.run_test(model, ([x1, x2],), **named_io(), **tolerances)

    dummy_image = [torch.ones(3, 100, 100) * 0.3]
    images, test_images = self.get_test_images()
    self.run_test(model, (images,),
                  test_with_inputs=[(images, ), (test_images, ), (dummy_image, )],
                  **named_io(), **tolerances)
    self.run_test(model, (dummy_image,),
                  test_with_inputs=[(dummy_image, ), (images, )],
                  **named_io(), **tolerances)
|
|
|
|
def test_paste_mask_in_image(self):
    """Check that a traced ``paste_masks_in_image`` matches eager results.

    The function is traced once on the first set of inputs, and the same trace is
    then reused for a second set with different sizes.
    """
    from torchvision.models.detection.roi_heads import paste_masks_in_image

    def as_size_tensors(size):
        return [torch.tensor(size[0]), torch.tensor(size[1])]

    masks = torch.rand(10, 1, 26, 26)
    boxes = torch.rand(10, 4)
    boxes[:, 2:] += torch.rand(10, 2)
    boxes *= 50
    o_im_s = (100, 100)
    out = paste_masks_in_image(masks, boxes, o_im_s)
    jit_trace = torch.jit.trace(paste_masks_in_image,
                                (masks, boxes, as_size_tensors(o_im_s)))
    out_trace = jit_trace(masks, boxes, as_size_tensors(o_im_s))

    assert torch.all(out.eq(out_trace))

    masks2 = torch.rand(20, 1, 26, 26)
    boxes2 = torch.rand(20, 4)
    boxes2[:, 2:] += torch.rand(20, 2)
    boxes2 *= 100
    o_im_s2 = (200, 200)
    out2 = paste_masks_in_image(masks2, boxes2, o_im_s2)
    out_trace2 = jit_trace(masks2, boxes2, as_size_tensors(o_im_s2))

    assert torch.all(out2.eq(out_trace2))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(11)
@disableScriptTest()
def test_mask_rcnn(self):
    """Export torchvision's ``maskrcnn_resnet50_fpn`` and check it on several image inputs."""
    def named_io():
        # Built afresh for every call so that no two exports share argument objects.
        return dict(input_names=["images_tensors"],
                    output_names=["boxes", "labels", "scores", "masks"],
                    dynamic_axes={"images_tensors": [0, 1, 2], "boxes": [0, 1], "labels": [0],
                                  "scores": [0], "masks": [0, 1, 2]})

    tolerances = dict(rtol=1e-3, atol=1e-5)

    model = torchvision.models.detection.mask_rcnn.maskrcnn_resnet50_fpn(
        pretrained=False, pretrained_backbone=True, min_size=200, max_size=300)
    images, test_images = self.get_test_images()
    self.run_test(model, (images,), **tolerances)
    self.run_test(model, (images,), **named_io(), **tolerances)

    dummy_image = [torch.ones(3, 100, 100) * 0.3]
    self.run_test(model, (images,),
                  test_with_inputs=[(images,), (test_images,), (dummy_image,)],
                  **named_io(), **tolerances)
    self.run_test(model, (dummy_image,),
                  test_with_inputs=[(dummy_image,), (images,)],
                  **named_io(), **tolerances)
|
|
|
|
def test_heatmaps_to_keypoints(self):
    """Check that a traced ``heatmaps_to_keypoints`` matches eager results.

    The function is traced once on the first inputs, and the same trace is reused
    for a second pair of inputs with different shapes. Both returned values are
    compared.
    """
    from torchvision.models.detection.roi_heads import heatmaps_to_keypoints

    maps = torch.rand(10, 1, 26, 26)
    rois = torch.rand(10, 4)
    out = heatmaps_to_keypoints(maps, rois)
    jit_trace = torch.jit.trace(heatmaps_to_keypoints, (maps, rois))
    out_trace = jit_trace(maps, rois)

    for position in (0, 1):
        assert torch.all(out[position].eq(out_trace[position]))

    maps2 = torch.rand(20, 2, 21, 21)
    rois2 = torch.rand(20, 4)
    out2 = heatmaps_to_keypoints(maps2, rois2)
    out_trace2 = jit_trace(maps2, rois2)

    for position in (0, 1):
        assert torch.all(out2[position].eq(out_trace2[position]))
|
|
|
|
@unittest.skip("Failing, see https://github.com/pytorch/pytorch/issues/66528")
@skipIfUnsupportedMinOpsetVersion(11)
@disableScriptTest()
def test_keypoint_rcnn(self):
    """Export torchvision's ``keypointrcnn_resnet50_fpn`` and check it on several image inputs."""
    def named_io():
        # Built afresh for every call so that no two exports share argument objects.
        return dict(input_names=["images_tensors"],
                    output_names=["outputs1", "outputs2", "outputs3", "outputs4"],
                    dynamic_axes={"images_tensors": [0, 1, 2]})

    model = torchvision.models.detection.keypoint_rcnn.keypointrcnn_resnet50_fpn(
        pretrained=False, pretrained_backbone=False, min_size=200, max_size=300)
    images, test_images = self.get_test_images()
    self.run_test(model, (images,), rtol=1e-3, atol=1e-5)
    self.run_test(model, (images,), rtol=1e-3, atol=1e-5, **named_io())

    dummy_images = [torch.ones(3, 100, 100) * 0.3]
    # The runs with additional test inputs use a looser relative tolerance.
    self.run_test(model, (images,),
                  test_with_inputs=[(images, ), (test_images, ), (dummy_images, )],
                  rtol=5e-3, atol=1e-5, **named_io())
    self.run_test(model, (dummy_images,),
                  test_with_inputs=[(dummy_images, ), (test_images, )],
                  rtol=5e-3, atol=1e-5, **named_io())
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(11)
@disableScriptTest()
def test_shufflenet_v2_dynamic_axes(self):
    """Export ``shufflenet_v2_x0_5`` with a dynamic batch axis and re-run it with another batch size.

    The model is exported with a batch of 1, then also executed with a batch of 3
    through ``test_with_inputs``.
    """
    model = torchvision.models.shufflenet_v2_x0_5(pretrained=False)
    dummy_input = torch.randn(1, 3, 224, 224, requires_grad=True)
    test_inputs = torch.randn(3, 3, 224, 224, requires_grad=True)
    self.run_test(model, (dummy_input,), test_with_inputs=[(dummy_input,), (test_inputs,)],
                  input_names=["input_images"], output_names=["outputs"],
                  # dynamic_axes keys must match the declared input/output names:
                  # the output is named "outputs", so that is the key to use.
                  dynamic_axes={"input_images": {0: "batch_size"}, "outputs": {0: "batch_size"}},
                  rtol=1e-3, atol=1e-5)
|
|
|
|
@disableScriptTest()
def test_mobilenet_v3(self):
    """Export the ``mobilenet_v3_large`` model from ``torchvision.models.quantization`` without pretrained weights."""
    net = torchvision.models.quantization.mobilenet_v3_large(pretrained=False)
    sample = torch.randn(1, 3, 224, 224)
    self.run_test(net, (sample,))
|
|
|
|
@unittest.skip("Unstable loading pretrained quantized mobilenet v3: https://github.com/pytorch/vision/issues/5303")
@skipIfUnsupportedMinOpsetVersion(10)
@disableScriptTest()
def test_mobilenet_v3_quant(self):
    """Export a pretrained quantized ``mobilenet_v3_large`` and compare the top prediction on an asset image."""
    quantized_net = torchvision.models.quantization.mobilenet_v3_large(pretrained=True, quantize=True)
    from PIL import Image
    from torchvision import transforms
    image_path = os.path.join(os.path.dirname(__file__), "assets", "grace_hopper_517x606.jpg")
    input_image = Image.open(image_path)
    # Based on example from https://pytorch.org/hub/pytorch_vision_resnet/
    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    input_tensor = preprocess(input_image).unsqueeze(0)

    # Due to precision error from quantization, check only that the top prediction matches.
    class TopPredictor(torch.nn.Module):
        def __init__(self, mobilenet):
            super().__init__()
            self.mobilenet = mobilenet

        def forward(self, x):
            x = self.mobilenet(x)
            _, topk_catid = torch.topk(x[0], 1)
            return topk_catid

    # Currently, we need convert the model to ScriptModule before export.
    # The reason is that PackedParams contains int (not tensor).
    # Then it fails when the exporter calls _trace_and_get_graph_from_model().
    # TODO: https://msdata.visualstudio.com/Vienna/_workitems/edit/1547858
    traced = torch.jit.trace(TopPredictor(quantized_net), input_tensor)
    self.run_test(traced, (input_tensor, ))
|
|
|
|
@disableScriptTest()
def test_word_language_model_RNN_TANH(self):
    """Run the word-language-model export check with model name ``"RNN_TANH"``."""
    self.run_word_language_model("RNN_TANH")
|
|
|
|
@disableScriptTest()
def test_word_language_model_RNN_RELU(self):
    """Run the word-language-model export check with model name ``"RNN_RELU"``."""
    self.run_word_language_model("RNN_RELU")
|
|
|
|
@disableScriptTest()  # scripting prim::unchecked_cast prim::setattr
def test_word_language_model_LSTM(self):
    """Run the word-language-model export check with model name ``"LSTM"``."""
    self.run_word_language_model("LSTM")
|
|
|
|
def test_word_language_model_GRU(self):
    # Calls the run_word_language_model helper with the "GRU" argument.
    self.run_word_language_model("GRU")
|
|
|
|
def test_index_1d(self):
    # Model under test selects element 0 along the first dimension.
    class MyModel(torch.nn.Module):
        def forward(self, input):
            return input[0]

    sample = torch.randn(3, 4, 5, 6, 7)
    self.run_test(MyModel(), sample)
|
|
|
|
def test_index_2d_1dimslice(self):
    # Model under test takes the slice 0:1 on dim 0 and a full slice on dim 1.
    class MyModel(torch.nn.Module):
        def forward(self, input):
            return input[0:1, :]

    sample = torch.randn(3, 4, 5, 6, 7)
    self.run_test(MyModel(), sample)
|
|
|
|
def test_index_2d_sliceint(self):
    # Model under test mixes an integer index (dim 0) with a full slice (dim 1).
    class MyModel(torch.nn.Module):
        def forward(self, input):
            return input[1, :]

    sample = torch.randn(3, 4, 5, 6, 7)
    self.run_test(MyModel(), sample)
|
|
|
|
def test_index_2d_neg_slice(self):
    # Model under test slices dim 0 with a negative end bound (0:-1).
    class MyModel(torch.nn.Module):
        def forward(self, input):
            return input[0:-1, :]

    sample = torch.randn(3, 4, 5, 6, 7)
    self.run_test(MyModel(), sample)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_index_mask(self):
    # Part 1: index with a constant uint8 mask tensor.
    class MyModel(torch.nn.Module):
        def forward(self, input):
            return input[torch.tensor([0, 1, 0], dtype=torch.uint8)]

    sample = torch.randn(3, 4, 5, 6, 7)
    self.run_test(MyModel(), sample)

    # Part 2: same mask values, but as a bool tensor.
    class MyModel(torch.nn.Module):
        def forward(self, input):
            return input[torch.tensor([0, 1, 0], dtype=torch.bool)]

    sample = torch.randn(3, 4, 5, 6, 7)
    self.run_test(MyModel(), sample)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_data(self):
    # Scripted module that builds zeros shaped like x, reading the size via x.data.
    class Data(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, x):
            return x.new_zeros(x.data.size())

    tensor_in = torch.randn(3, 4)
    # First run declares both axes of "x" dynamic; second run uses the default
    # axes and passes an empty remained_onnx_input_idx list.
    self.run_test(Data(), tensor_in, input_names=["x"], dynamic_axes={"x": [0, 1]})
    self.run_test(Data(), tensor_in, remained_onnx_input_idx=[])
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(11)
def test_index_mask_nd(self):
    # Model under test indexes the input with a same-shaped boolean mask (input > 0).
    class MyModel(torch.nn.Module):
        def forward(self, input):
            return input[input > 0]

    sample = torch.randn(3, 4, 5, 6, 7)
    self.run_test(MyModel(), sample)
|
|
|
|
@disableScriptTest()
def test_dict(self):
    # Dict input keyed by a tensor; the model adds the first key to its value
    # and returns the sum in a new dict.
    class MyModel(torch.nn.Module):
        def forward(self, x_in):
            x_out = {}
            summed = torch.add(x_in[list(x_in.keys())[0]], list(x_in.keys())[0])
            x_out["test_key_out"] = summed
            return x_out

    dict_in = {torch.tensor(1.): torch.randn(1, 2, 3)}
    # The trailing empty dict is passed as the last element of the args tuple.
    self.run_test(MyModel(), (dict_in, {}))
|
|
|
|
@disableScriptTest()
def test_dict_str(self):
    # Dict input keyed by a string; the model adds 2. to the stored tensor.
    class MyModel(torch.nn.Module):
        def forward(self, x_in):
            x_out = {}
            shifted = torch.add(x_in["test_key_in"], 2.)
            x_out["test_key_out"] = shifted
            return x_out

    dict_in = {"test_key_in": torch.randn(1, 2, 3)}
    # The trailing empty dict is passed as the last element of the args tuple.
    self.run_test(MyModel(), (dict_in, {}))
|
|
|
|
@disableScriptTest()  # User-defined class not supported
def test_dict_output(self):
    # OrderedDict subclass used as the model's return container.
    class DictModelOutput(OrderedDict):
        tensor_out: torch.Tensor
        tuple_out: Optional[Tuple[torch.Tensor]] = None
        list_out: Optional[List[torch.Tensor]] = None

    class MyModel(torch.nn.Module):
        def forward(self, a, b, c, d):
            # One tensor, one tuple of tensors and one list of tensors.
            return DictModelOutput(tensor_out=a, tuple_out=(b, c), list_out=[d])

    a = torch.randn(2, 3)
    b = torch.randn(2, 3)
    c = torch.randn(2, 3)
    d = torch.randn(2, 3)
    inputs = (a, b, c, d)
    self.run_test(MyModel(), inputs)
|
|
|
|
def test_tuple_output(self):
    # Model returns a tuple that contains a nested tuple in the middle.
    class MyModel(torch.nn.Module):
        def forward(self, a, b, c, d):
            return a, (b, c), d

    a = torch.randn(2, 3)
    b = torch.randn(2, 3)
    c = torch.randn(2, 3)
    d = torch.randn(2, 3)
    inputs = (a, b, c, d)
    self.run_test(MyModel(), inputs)
|
|
|
|
def test_nested_tuple_output(self):
    # Model returns a tensor plus a tuple of tuples (two nesting levels).
    class MyModel(torch.nn.Module):
        def forward(self, a, b, c, d):
            return a, ((b,), (c, d))

    a = torch.randn(2, 3)
    b = torch.randn(2, 3)
    c = torch.randn(2, 3)
    d = torch.randn(2, 3)
    inputs = (a, b, c, d)
    self.run_test(MyModel(), inputs)
|
|
|
|
def test_tuple_input(self):
    # Identity model whose single argument is a tuple of two tensors.
    class TupleModel(torch.nn.Module):
        def forward(self, a: Tuple[torch.Tensor, torch.Tensor]):
            return a

    pair = (torch.randn(3, 4), torch.randn(4, 3))
    self.run_test(TupleModel(), input=(pair,))
|
|
|
|
def test_tuple_primitive_input(self):
    # Tuple argument mixing a Python int with a tensor, plus a second tensor.
    class TupleModel(torch.nn.Module):
        def forward(self, a: Tuple[int, torch.Tensor], b):
            return a[0], a[1] + b

    mixed = (3, torch.randn(4, 3))
    other = torch.randn(4, 3)
    self.run_test(TupleModel(), input=(mixed, other))
|
|
|
|
def test_nested_tuple_input(self):
    # Second argument is a tuple that itself contains a tuple; all four
    # tensors are summed.
    class NestedTupleModel(torch.nn.Module):
        def forward(self, a, b: Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]):
            return a + b[0] + b[1][0] + b[1][1]

    base = torch.randn(4, 5)
    nested = (torch.randn(4, 5), (torch.randn(1, 5), torch.randn(4, 1)))
    self.run_test(NestedTupleModel(), input=(base, nested))
|
|
|
|
@disableScriptTest()
def test_optional_inputs_with_no_optionals(self):
    # Identity model without optional parameters.
    class NoOptionalModel(torch.nn.Module):
        def forward(self, input):
            return input

    # Without empty optional arguments dictionary
    plain_arg = torch.randn(2, 3)
    self.run_test(NoOptionalModel(), (plain_arg,))
    # With empty optional arguments dictionary
    arg_with_dict = torch.randn(2, 3)
    self.run_test(NoOptionalModel(), (arg_with_dict, {}))
|
|
|
|
@disableScriptTest()  # ScriptModule could not be exported without the Input Descriptor for optional inputs
def test_optional_inputs_with_mixed_optionals(self):
    # One required input and two optional ones; the first non-None optional
    # (y before z) is added to x.
    class MixedModel(torch.nn.Module):
        def forward(self, x, y=None, z=None):
            if y is not None:
                return x + y
            if z is not None:
                return x + z
            return x

    x = torch.randn(2, 3)
    y = torch.randn(2, 3)
    z = torch.randn(2, 3)
    arg_variants = (
        # Without optional arguments dictionary
        (x, y, None),
        (x, None, z),
        # With optional arguments dictionary
        (x, {"y": y, "z": None}),
        (x, {"y": None, "z": z}),
        (x, {"z": z}),
        (x, {"y": y}),
    )
    # A fresh model instance is used for every variant.
    for args in arg_variants:
        self.run_test(MixedModel(), args)
|
|
|
|
@disableScriptTest()  # ScriptModule could not be exported without the Input Descriptor for optional inputs
def test_optional_inputs_with_all_optionals(self):
    # Every parameter is optional; the first non-None one is returned
    # (the forward falls through without a return when both are None).
    class AllOptionalModel(torch.nn.Module):
        def forward(self, y=None, z=None):
            if y is not None:
                return y
            if z is not None:
                return z

    tensor_y = torch.randn(2, 3)
    # Without optional arguments dictionary
    self.run_test(AllOptionalModel(), (tensor_y, None))
    # With optional arguments dictionary
    self.run_test(AllOptionalModel(), {"y": tensor_y, "z": None})
|
|
|
|
@disableScriptTest()
def test_input_names_with_optional_args(self):
    # --- Model without optional parameters ---
    class NoOptionalModel(torch.nn.Module):
        def forward(self, input):
            return input

    # Without empty optional arguments dictionary
    x = torch.randn(2, 3)
    self.run_test(NoOptionalModel(), (x,), input_names=["input_x"])
    # With empty optional arguments dictionary
    y = torch.randn(2, 3)
    self.run_test(NoOptionalModel(), (y, {}))

    # --- Model with one required and two optional parameters ---
    class MixedModel(torch.nn.Module):
        def forward(self, x, y=None, z=None):
            if y is not None:
                return x + y
            if z is not None:
                return x + z
            return x

    x = torch.randn(2, 3)
    y = torch.randn(2, 3)
    z = torch.randn(2, 3)
    mixed_cases = (
        # Without optional arguments dictionary
        ((x, y, None), ["input_x", "input_y"]),
        ((x, None, z), ["input_x", "input_z"]),
        # With optional arguments dictionary
        ((x, {"y": y, "z": None}), ["input_x", "input_y"]),
        ((x, {"y": None, "z": z}), ["input_x", "input_z"]),
    )
    for args, names in mixed_cases:
        self.run_test(MixedModel(), args, input_names=names)

    # --- Model with only optional parameters ---
    class AllOptionalModel(torch.nn.Module):
        def forward(self, y=None, z=None):
            if y is not None:
                return y
            if z is not None:
                return z

    y = torch.randn(2, 3)
    z = torch.randn(2, 3)
    all_optional_cases = (
        # Without optional arguments dictionary
        ((y, None), ["input_y"]),
        ((None, z), ["input_z"]),
        # With optional arguments dictionary
        ({"y": y, "z": None}, ["input_y"]),
        ({"y": None, "z": z}, ["input_z"]),
    )
    for args, names in all_optional_cases:
        self.run_test(AllOptionalModel(), args, input_names=names)
|
|
|
|
def test_input_as_output(self):
    # Model returns its two inputs untouched; inputs and outputs get explicit names.
    class Model(torch.nn.Module):
        def forward(self, x, y):
            return x, y

    first = torch.randn(2, 3)
    second = torch.randn(3, 4)
    self.run_test(
        Model(),
        (first, second),
        input_names=["x", "y"],
        output_names=["x_out", "y_out"],
    )
|
|
|
|
@disableScriptTest()
def test_none_as_input(self):
    # None is passed positionally for y, so the model takes the `return x` path.
    class Model(torch.nn.Module):
        def forward(self, x, y):
            if y is not None:
                return x + y
            return x

    tensor_x = torch.randn(2, 3)
    self.run_test(Model(), (tensor_x, None))
|
|
|
|
@disableScriptTest()  # ScriptModule could not be exported without the Input Descriptor for optional inputs
def test_none_as_tuple_input(self):
    # y is a 2-tuple whose entries may be None; the first non-None entry is added to x.
    class Model(torch.nn.Module):
        def forward(self, x, y):
            if y[0] is not None:
                return x + y[0]
            if y[1] is not None:
                return x + y[1]
            return x

    tensor_x = torch.randn(2, 3)
    tensor_y = torch.randn(2, 3)
    self.run_test(Model(), (tensor_x, (None, tensor_y)))
|
|
|
|
@disableScriptTest()  # ScriptModule could not be exported without the Input Descriptor for optional inputs
def test_none_as_named_input(self):
    # y is given as None and z as a tensor, so the `x + z` branch is taken.
    class Model(torch.nn.Module):
        def forward(self, x, y=None, z=None):
            if y is not None:
                return x + y
            if z is not None:
                return x + z
            return x

    tensor_x = torch.randn(2, 3)
    tensor_z = torch.randn(2, 3)
    self.run_test(Model(), (tensor_x, None, tensor_z))
|
|
|
|
def test_primitive_input_integer(self):
    # Python int passed as a model input and added to an integer tensor.
    class Model(torch.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x: int, y):
            return x + y

    scalar = 3
    tensor = torch.randint(10, (2, 3, 4))
    self.run_test(Model(), (scalar, tensor))
|
|
|
|
def test_primitive_input_floating(self):
    # Python float passed as a model input and added to a float tensor.
    class Model(torch.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x: float, y):
            return x + y

    scalar = 3.0
    tensor = torch.randn(2, 3, 4)
    self.run_test(Model(), (scalar, tensor))
|
|
|
|
def test_primitive_input_bool(self):
    # Python bool input selects which of the two tensors is returned.
    class Model(torch.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, flag: bool, x, y):
            if flag:
                return x
            else:
                return y

    choose_first = True
    first = torch.randn(2, 3, 4)
    second = torch.randn(2, 3, 4)
    # The model is scripted explicitly before being handed to run_test.
    scripted = torch.jit.script(Model())
    self.run_test(scripted, (choose_first, first, second))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_cste_script(self):
    # Scripted module that builds zeros/ones tensors whose sizes come from x's shape.
    class MyModel(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, x):
            return torch.zeros(x.size(0)), torch.ones((x.size(1), x.size(0)), dtype=torch.int64)

    tensor_in = torch.randn(3, 4)
    # Once with both axes of "x" dynamic, once with remained_onnx_input_idx=[].
    self.run_test(MyModel(), tensor_in, input_names=["x"], dynamic_axes={"x": [0, 1]})
    self.run_test(MyModel(), tensor_in, remained_onnx_input_idx=[])
|
|
|
|
def test_scalar_tensor(self):
    # torch.scalar_tensor built from two input sizes (default dtype and int64).
    class test(torch.nn.Module):
        def forward(self, input):
            return (
                torch.scalar_tensor(input.size(0)),
                torch.scalar_tensor(input.size(1), dtype=torch.int64),
            )

    export_input = torch.randn(2, 3, 4)
    extra_input = torch.randn(7, 8, 9)
    model = test()
    # All three axes are dynamic; a differently shaped tensor is passed as an
    # additional test input.
    self.run_test(
        model,
        export_input,
        test_with_inputs=[extra_input],
        input_names=["input_1"],
        dynamic_axes={"input_1": [0, 1, 2]},
    )
|
|
|
|
def test_tensor(self):
    # torch.tensor(...) built inside scripted modules from various sources.

    # A single shape entry.
    class ScalarInputModel(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, input):
            return torch.tensor(input.shape[1])

    matrix = torch.randn(3, 4)
    self.run_test(ScalarInputModel(), matrix, input_names=["x"], dynamic_axes={"x": [0, 1]})
    self.run_test(ScalarInputModel(), matrix, remained_onnx_input_idx=[])

    # A list of two shape entries.
    class TensorInputModel(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, input):
            return torch.tensor([input.shape[0], input.shape[1]])

    matrix = torch.randn(3, 4)
    self.run_test(TensorInputModel(), matrix, input_names=["x"], dynamic_axes={"x": [0, 1]})
    self.run_test(TensorInputModel(), matrix, remained_onnx_input_idx=[])

    # The input converted to a Python float.
    class FloatInputModel(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, input):
            return torch.tensor([float(input)])

    single = torch.randn(1)
    self.run_test(FloatInputModel(), single)

    # A shape entry with an explicit dtype.
    class InputWithDtypeModel(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, input):
            return torch.tensor(input.shape[1], dtype=torch.long)

    matrix = torch.randn(3, 4)
    self.run_test(InputWithDtypeModel(), matrix, input_names=["x"], dynamic_axes={"x": [0, 1]})
    self.run_test(InputWithDtypeModel(), matrix, remained_onnx_input_idx=[])

    # A shape entry mixed with the input converted to a Python int.
    class MixedInputModel(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, input):
            return torch.tensor([input.shape[0], int(input)])

    single = torch.randn(1)
    self.run_test(MixedInputModel(), single)
|
|
|
|
def test_hardtanh(self):
    # Hardtanh with explicit bounds (-1.5, 2.5) on the values -5..4.
    model = torch.nn.Hardtanh(-1.5, 2.5)
    values = torch.arange(-5, 5).to(dtype=torch.float32)
    self.run_test(model, values)
|
|
|
|
def test_hardtanh_script_with_default_values(self):
    # Scripted functional hardtanh called without explicit bounds.
    class MyModel(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, x):
            return torch.nn.functional.hardtanh(x)

    values = torch.arange(-5, 5).to(dtype=torch.float32)
    self.run_test(MyModel(), values)
|
|
|
|
def test_hardswish(self):
    model = torch.nn.Hardswish()

    # Random 3x3 input first.
    sample = torch.rand(3, 3).to(dtype=torch.float32)
    self.run_test(model, sample)

    # Testing edge cases
    for boundary in (3, -3):
        sample = torch.tensor(boundary).to(dtype=torch.float32)
        self.run_test(model, sample)
|
|
|
|
def test_hardswish_script(self):
    # Scripted functional hardswish on a random 3x3 input.
    class MyModel(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, x):
            return torch.nn.functional.hardswish(x)

    sample = torch.rand(3, 3).to(dtype=torch.float32)
    self.run_test(MyModel(), sample)
|
|
|
|
def test_hardsigmoid(self):
    model = torch.nn.Hardsigmoid()

    # Random 3x3 input first.
    sample = torch.rand(3, 3).to(dtype=torch.float32)
    self.run_test(model, sample)

    # corner cases
    for boundary in (3, -3):
        sample = torch.tensor(boundary).to(dtype=torch.float32)
        self.run_test(model, sample)
|
|
|
|
def test_tanhshrink(self):
    # Tanhshrink module on a random 3x3 float32 input.
    model = torch.nn.Tanhshrink()

    sample = torch.rand(3, 3).to(dtype=torch.float32)
    self.run_test(model, sample)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_hardshrink(self):
    model = torch.nn.Hardshrink()

    # Random 3x3 input first.
    sample = torch.rand(3, 3).to(dtype=torch.float32)
    self.run_test(model, sample)

    # Testing edge cases
    for boundary in (0.5, -0.5):
        sample = torch.tensor(boundary).to(dtype=torch.float32)
        self.run_test(model, sample)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_softshrink(self):
    model = torch.nn.Softshrink()

    # Random 3x3 input first.
    sample = torch.rand(3, 3).to(dtype=torch.float32)
    self.run_test(model, sample)

    # Testing edge cases
    for boundary in (0.5, -0.5):
        sample = torch.tensor(boundary).to(dtype=torch.float32)
        self.run_test(model, sample)
|
|
|
|
def test_clamp(self):
    # clamp with constant bounds: both, min only, max only.
    class ClampModel(torch.nn.Module):
        def forward(self, x):
            return x.clamp(-0.5, 0.5)

    class ClampMinModel(torch.nn.Module):
        def forward(self, x):
            return x.clamp(min=-0.5)

    class ClampMaxModel(torch.nn.Module):
        def forward(self, x):
            return x.clamp(max=0.5)

    # Each model gets its own freshly drawn input, in the original order.
    for model_cls in (ClampModel, ClampMinModel, ClampMaxModel):
        sample = torch.randn(3, 4)
        self.run_test(model_cls(), sample)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(8)
def test_clamp_dyn(self):
    # --- Scripted modules: bounds derived from the input's sizes ---
    class ClampMaxModel(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, x):
            return x.clamp(None, x.size(0))

    grid = torch.arange(16).view(4, 4).float()
    self.run_test(ClampMaxModel(), grid)

    class ClampMinModel(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, x):
            return x.clamp(x.size(0), None)

    grid = torch.arange(16).view(4, 4).float()
    self.run_test(ClampMinModel(), grid)

    class ClampMinMaxModel(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, x):
            return x.clamp(x.size(0), x.size(1))

    grid = torch.arange(16).view(2, 8).float()
    self.run_test(ClampMinMaxModel(), grid)

    # --- Plain modules: bounds supplied as tensor inputs ---
    class ClampTensorModel(torch.nn.Module):
        def forward(self, x, min, max):
            return x.clamp(min, max)

    data = torch.randn(3, 4)
    lower = torch.randn(3, 4)
    upper = torch.randn(3, 4)
    self.run_test(ClampTensorModel(), (data, lower, upper))

    class ClampTensorMinModel(torch.nn.Module):
        def forward(self, x, min):
            return x.clamp(min=min)

    self.run_test(ClampTensorMinModel(), (data, lower))

    class ClampTensorMaxModel(torch.nn.Module):
        def forward(self, x, max):
            return x.clamp(max=max)

    self.run_test(ClampTensorMaxModel(), (data, upper))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_full_trace(self):
    # torch.full with a fixed (3, 4) size and the fill value taken from the input.
    class FullModel(torch.nn.Module):
        def forward(self, x):
            return torch.full((3, 4), x, dtype=torch.long)

    fill_value = torch.tensor(12)
    self.run_test(FullModel(), fill_value)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_full_script(self):
    # Scripted variant: torch.full with a fixed (3, 4) size and input-driven fill value.
    class FullModelScripting(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, x):
            return torch.full((3, 4), x, dtype=torch.long)

    fill_value = torch.tensor(12)
    self.run_test(FullModelScripting(), fill_value)
|
|
|
|
def test_fuse_addmm(self):
    # Matrix product of x with itself, followed by adding x.
    class AddmmModel(torch.nn.Module):
        def forward(self, x):
            return torch.mm(x, x) + x

    square = torch.ones(3, 3)
    self.run_test(AddmmModel(), square)
|
|
|
|
def test_maxpool(self):
    # MaxPool1d with kernel size 2 and stride 1.
    model = torch.nn.MaxPool1d(2, stride=1)
    sample = torch.randn(20, 16, 50)
    self.run_test(model, sample)
|
|
|
|
def test_conv(self):
    # One module holding a 1-D, a 2-D and a 3-D convolution, each fed its own input.
    class TraceModel(torch.nn.Module):
        def __init__(self):
            super(TraceModel, self).__init__()
            self.conv1 = torch.nn.Conv1d(16, 33, 3, stride=2)
            self.conv2 = torch.nn.Conv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(3, 1))
            self.conv3 = torch.nn.Conv3d(16, 33, (3, 5, 2), stride=(2, 1, 1), padding=(4, 2, 0))

        def forward(self, input1, input2, input3):
            out1 = self.conv1(input1)
            out2 = self.conv2(input2)
            out3 = self.conv3(input3)
            return out1, out2, out3

    # Inputs are drawn before the model is constructed, as in the original order.
    in_1d = torch.randn(20, 16, 50)
    in_2d = torch.randn(20, 16, 50, 100)
    in_3d = torch.randn(20, 16, 10, 50, 100)

    self.run_test(TraceModel(), (in_1d, in_2d, in_3d), atol=10e-5)
|
|
|
|
def test_conv_shape_inference(self):
    # 2-D convolution followed by adding a constant; axis 0 of "x" is dynamic.
    class Model(torch.nn.Module):
        def __init__(self):
            super(Model, self).__init__()
            self.conv2 = torch.nn.Conv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(3, 1))

        def forward(self, input):
            return self.conv2(input) + 2

    sample = torch.randn(20, 16, 50, 100)
    self.run_test(
        Model(),
        sample,
        atol=10e-5,
        input_names=["x"],
        dynamic_axes={"x": [0]},
    )
|
|
|
|
def test_conv_transpose(self):
    # One module holding 1-D, 2-D and 3-D transposed convolutions.
    class TraceModel(torch.nn.Module):
        def __init__(self):
            super(TraceModel, self).__init__()
            self.conv1 = torch.nn.ConvTranspose1d(16, 33, 3, stride=2)
            self.conv2 = torch.nn.ConvTranspose2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(3, 1))
            self.conv3 = torch.nn.ConvTranspose3d(16, 33, (3, 5, 2), stride=(2, 1, 1), padding=(4, 2, 0))

        def forward(self, input1, input2, input3):
            out1 = self.conv1(input1)
            out2 = self.conv2(input2)
            out3 = self.conv3(input3)
            return out1, out2, out3

    # Inputs are drawn before the model is constructed, as in the original order.
    in_1d = torch.randn(20, 16, 50)
    in_2d = torch.randn(20, 16, 50, 100)
    in_3d = torch.randn(20, 16, 10, 50, 100)

    self.run_test(TraceModel(), (in_1d, in_2d, in_3d), atol=10e-5)
|
|
|
|
# Conversion of Transpose depends on input shape to be known.
|
|
# The following test only works when onnx shape inference is enabled.
|
|
def test_transpose_infer_shape(self):
    # Scripted module: a convolution whose output is transposed on dims 0 and 1.
    class TransposeModule(torch.jit.ScriptModule):
        def __init__(self):
            super(TransposeModule, self).__init__()
            self.conv = torch.nn.Conv2d(3, 1, 3, stride=2)

        @torch.jit.script_method
        def forward(self, x):
            x = self.conv(x)
            return x.transpose(0, 1)

    export_input = torch.randn(32, 3, 64, 64)
    # Second input differs on axes 0 and 2, the ones declared dynamic.
    extra_input = torch.randn(16, 3, 8, 64)
    self.run_test(
        TransposeModule(),
        export_input,
        input_names=["x"],
        dynamic_axes={"x": [0, 2]},
        test_with_inputs=[extra_input],
    )
|
|
|
|
def squeeze_model_tests(self, d, x1, x2):
    # Shared helper: runs a squeeze model on x1 and, if x2 is given, also on x2
    # with three named dynamic axes.
    #   d  - dim to squeeze, or None to squeeze all size-1 dims
    #   x1 - input handed to run_test as the main input
    #   x2 - optional additional input (None to skip)
    class Squeeze(torch.nn.Module):
        def __init__(self, d):
            super(Squeeze, self).__init__()
            self.d = d

        def forward(self, x):
            if self.d is None:
                return torch.squeeze(x)
            return torch.squeeze(x, dim=self.d)

    extra_inputs = [] if x2 is None else [x2]
    if not extra_inputs:
        self.run_test(Squeeze(d), x1)
        return

    self.run_test(
        Squeeze(d),
        x1,
        input_names=["input"],
        dynamic_axes={"input": {0: "0", 1: "1", 2: "2"}},
        test_with_inputs=extra_inputs,
    )
|
|
|
|
def test_squeeze_without_no_op(self):
    # Squeeze dim 1, which has size 1 in the only input.
    sample = torch.randn(2, 1, 4)
    self.squeeze_model_tests(1, sample, None)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(11)
def test_squeeze_dynamic(self):
    # Dim 1 has size 1 in the first input and size 2 in the second one.
    squeezable = torch.randn(2, 1, 4)
    not_squeezable = torch.randn(2, 2, 3)
    self.squeeze_model_tests(1, squeezable, not_squeezable)
|
|
|
|
def test_squeeze_neg_without_no_op(self):
    # Squeeze with a negative dim (-2), which has size 1 in the only input.
    sample = torch.randn(2, 1, 4)
    self.squeeze_model_tests(-2, sample, None)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(11)
def test_squeeze_neg(self):
    # Negative dim (-2): size 1 in the first input, size 2 in the second one.
    squeezable = torch.randn(2, 1, 4)
    not_squeezable = torch.randn(2, 2, 3)
    self.squeeze_model_tests(-2, squeezable, not_squeezable)
|
|
|
|
def test_squeeze_all_dims(self):
    # d=None makes the helper's model call torch.squeeze without a dim.
    squeezable = torch.randn(2, 1, 4)
    not_squeezable = torch.randn(2, 2, 3)
    self.squeeze_model_tests(None, squeezable, not_squeezable)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(11)
def test_squeeze_no_op(self):
    # Dim 2 has size 4 in the first input (nothing to squeeze) and size 1 in the second.
    not_squeezable = torch.randn(2, 1, 4)
    squeezable = torch.randn(2, 2, 1)
    self.squeeze_model_tests(2, not_squeezable, squeezable)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(11)
def test_squeeze_runtime_dim(self):
    # The squeezed tensor's sizes come from the values of the two inputs.
    class Squeeze(torch.nn.Module):
        def forward(self, d1, d2):
            t = torch.zeros(d1[0], d2[0])
            return t.squeeze(0)

    one = torch.tensor([1])
    three = torch.tensor([3])
    four = torch.tensor([4])
    # Main input has a leading size of 1, the additional input does not ...
    self.run_test(Squeeze(), (one, four), test_with_inputs=[(three, four)])
    # ... and the other way round.
    self.run_test(Squeeze(), (three, four), test_with_inputs=[(one, three)])
|
|
|
|
def test_squeeze(self):
    # torch.squeeze with a constant negative dim (-2).
    class Squeeze(torch.nn.Module):
        def forward(self, x):
            return torch.squeeze(x, dim=-2)

    sample = torch.randn(2, 1, 4)
    self.run_test(Squeeze(), sample)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(13)
def test_squeeze_dynamic_dim(self):
    # The dim to squeeze is itself a model input (Python int).
    class Squeeze(torch.nn.Module):
        def forward(self, x, dim: int):
            return torch.squeeze(x, dim)

    sample = torch.randn(2, 1, 4)
    axis = 1
    self.run_test(Squeeze(), (sample, axis))
|
|
|
|
def test_unsqueeze(self):
    # torch.unsqueeze with a constant negative dim (-2).
    class Unsqueeze(torch.nn.Module):
        def forward(self, x):
            return torch.unsqueeze(x, dim=-2)

    sample = torch.randn(2, 3, 4)
    self.run_test(Unsqueeze(), sample)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(13)
def test_unsqueeze_dynamic_dim(self):
    # The dim to unsqueeze is itself a model input (Python int).
    class Unsqueeze(torch.nn.Module):
        def forward(self, x, dim: int):
            return torch.unsqueeze(x, dim)

    sample = torch.randn(2, 1, 4)
    axis = -1
    self.run_test(Unsqueeze(), (sample, axis))
|
|
|
|
def test_maxpool_default_stride(self):
    # Functional max_pool2d with kernel size 2 and no explicit stride.
    class MaxPoolModel(torch.nn.Module):
        def forward(self, x):
            return torch.nn.functional.max_pool2d(x, 2)

    model = MaxPoolModel()
    sample = torch.randn(10, 20, 16, 50)
    self.run_test(model, sample)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(8)
def test_maxpool_adaptive(self):
    # AdaptiveMaxPool1d to output size 5, without indices; axis 0 is dynamic.
    model = torch.nn.AdaptiveMaxPool1d((5), return_indices=False)
    export_input = torch.randn(20, 16, 50, requires_grad=True)
    # Additional input with a different size on axis 0.
    extra_input = torch.randn(32, 16, 50, requires_grad=True)
    self.run_test(
        model,
        export_input,
        input_names=["x"],
        dynamic_axes={"x": [0]},
        test_with_inputs=[extra_input],
    )
|
|
|
|
def test_maxpool_2d(self):
    # MaxPool2d with kernel size 5 and asymmetric padding (1, 2).
    model = torch.nn.MaxPool2d(5, padding=(1, 2))
    sample = torch.randn(1, 20, 16, 50, requires_grad=True)
    self.run_test(model, sample)
|
|
|
|
def test_maxpool_1d_ceil(self):
    # MaxPool1d (kernel 3, stride 2) with ceil_mode enabled.
    model = torch.nn.MaxPool1d(3, 2, ceil_mode=True)
    sample = torch.randn(20, 16, 50)
    self.run_test(model, sample)
|
|
|
|
def test_maxpool_2d_ceil(self):
    # MaxPool2d (kernel 3, stride 2) with ceil_mode enabled.
    model = torch.nn.MaxPool2d(3, 2, ceil_mode=True)
    sample = torch.randn(20, 16, 50, 32)
    self.run_test(model, sample)
|
|
|
|
def test_maxpool_3d_ceil(self):
    # MaxPool3d (kernel 3, stride 2) with ceil_mode enabled.
    model = torch.nn.MaxPool3d(3, 2, ceil_mode=True)
    sample = torch.randn(20, 16, 50, 44, 31)
    self.run_test(model, sample)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(8)
def test_maxpool_with_indices(self):
    # MaxPool1d (kernel 2, stride 1) that also returns the argmax indices.
    model = torch.nn.MaxPool1d(2, stride=1, return_indices=True)
    sample = torch.randn(20, 16, 50)
    self.run_test(model, sample)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(10)
def test_maxpool_dilation(self):
    # MaxPool1d (kernel 2, stride 1) with dilation 2.
    model = torch.nn.MaxPool1d(2, stride=1, dilation=2)
    sample = torch.randn(20, 16, 50)
    self.run_test(model, sample)
|
|
|
|
def test_avgpool_default_stride(self):
    # Functional avg_pool2d with kernel size 2 and no explicit stride.
    class AvgPoolModel(torch.nn.Module):
        def forward(self, x):
            return torch.nn.functional.avg_pool2d(x, 2)

    model = AvgPoolModel()
    sample = torch.randn(10, 20, 16, 50)
    self.run_test(model, sample)
|
|
|
|
def test_avgpool(self):
    # AvgPool1d with kernel size 2 and stride 1.
    model = torch.nn.AvgPool1d(2, stride=1)
    sample = torch.randn(20, 16, 50)
    self.run_test(model, sample)
|
|
|
|
def test_avgpool_1d_ceil(self):
    # AvgPool1d (kernel 3, stride 2) with ceil_mode enabled on a length-7 input.
    model = torch.nn.AvgPool1d(3, 2, ceil_mode=True)
    sample = torch.randn(1, 1, 7)
    self.run_test(model, sample)
|
|
|
|
def test_avgpool_2d_ceil(self):
    # AvgPool2d (kernel 3, stride 2) with ceil_mode enabled.
    model = torch.nn.AvgPool2d(3, 2, ceil_mode=True)
    sample = torch.randn(20, 16, 50, 32)
    self.run_test(model, sample)
|
|
|
|
def test_avgpool_3d_ceil(self):
    # AvgPool3d (kernel 3, stride 2) with ceil_mode enabled; axes 0 and 1 are dynamic.
    model = torch.nn.AvgPool3d(3, 2, ceil_mode=True)
    export_input = torch.randn(20, 16, 50, 44, 31)
    # Additional input with different sizes on the two dynamic axes.
    extra_input = torch.randn(32, 8, 50, 44, 31)
    self.run_test(
        model,
        export_input,
        input_names=["x"],
        dynamic_axes={"x": [0, 1]},
        test_with_inputs=[extra_input],
    )
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_floating_point(self):
    # Case 1: is_floating_point() guards a branch; both paths return the same zeros.
    class FloatingPoint(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, x):
            if x.is_floating_point():
                return x.new_zeros(x.shape)
            return x.new_zeros(x.shape)

    sample = torch.randn(2, 3, 4)
    self.run_test(FloatingPoint(), sample, input_names=["x"], dynamic_axes={"x": [0, 1, 2]})
    self.run_test(FloatingPoint(), sample, remained_onnx_input_idx=[])

    # Case 2: is_floating_point() on an intermediate value inside a size-dependent branch.
    class FloatingPoint(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, x):
            if x.size(0) > 1:
                a = x + 2
                if a.is_floating_point():
                    return x + 1
                return x + 1
            return x

    sample = torch.randn(2, 3, 4)
    self.run_test(FloatingPoint(), sample)
|
|
|
|
# Operator rank mismatch between outputs of two branches for opsets below 11.
|
|
@skipIfUnsupportedMinOpsetVersion(11)
def test_floating_point_infer_dtype(self):
    # Case 1: the two inner branches return zeros of different rank
    # (x.shape[1:] versus x.shape).
    class FloatingPoint(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, x):
            if x.size(0) > 1:
                a = x + 2
                if a.is_floating_point():
                    return x.new_zeros(x.shape[1:])
                return x.new_zeros(x.shape)
            return x

    sample = torch.randn(2, 3, 4)
    self.run_test(FloatingPoint(), sample, input_names=["x"], dynamic_axes={"x": [0, 1, 2]})
    self.run_test(FloatingPoint(), sample, remained_onnx_input_idx=[])

    # Case 2: same control flow, fed an int32 input.
    class FloatingPoint(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, x):
            if x.size(0) > 1:
                a = x + 2
                if a.is_floating_point():
                    return x + 1
                return x
            return x

    sample = torch.randn(2, 3, 4).to(torch.int32)
    self.run_test(FloatingPoint(), sample)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(12)
def test_prim_min(self):
    # Case 1: scripted function iterating a tensor list with enumerate().
    @torch.jit.script
    def list_append(boxes: List[torch.Tensor]):
        temp = []
        for i, b in enumerate(boxes):  # enumerate is creating a prim::min op in torch graph
            temp.append(torch.full_like(b[:, 1], i))
        return temp[0]

    class Min(torch.nn.Module):
        def forward(self, x):
            # Three references to the same tensor.
            boxes = [x for _ in range(3)]
            return list_append(boxes)

    square = torch.rand(5, 5)
    self.run_test(Min(), (square,))

    # Case 2: builtin min() between a tensor element and a Python int.
    class M(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, x):
            i = 3
            return min(x[i], i)

    sequence = torch.arange(6, dtype=torch.int64)
    self.run_test(M(), (sequence,))
|
|
|
|
def test_arithmetic(self):
    # Add, subtract, multiply and divide by Python constants, in that order.
    class ArithmeticModule(torch.nn.Module):
        def forward(self, x):
            x = x + 2
            x = x - 4
            x = x * 6
            x = x / 8
            return x

    sample = torch.randn(2, 3, 4)
    self.run_test(ArithmeticModule(), sample)
|
|
|
|
def test_arithmetic_prim_long(self):
    # Case 1: tensor arithmetic against a Python int input and products of it.
    class ArithmeticModule(torch.nn.Module):
        def forward(self, x, y: int):
            x = x + y
            x = x - y
            x = x * (y * 3)
            x = x / (y * 4)
            return x

    tensor = torch.randn(2, 3, 4)
    scalar = 2
    self.run_test(ArithmeticModule(), (tensor, scalar))

    # Case 2: the result is a shape entry, not the arithmetic result.
    class ArithmeticModule(torch.nn.Module):
        def forward(self, x):
            x = x + 2
            x = x - 3
            return x.shape[0]

    tensor = torch.randn(2, 3, 4)
    self.run_test(ArithmeticModule(), tensor, remained_onnx_input_idx=[])
|
|
|
|
def test_arithmetic_prim_float(self):
    # Case 1: tensor arithmetic against a Python float input and products of it.
    class ArithmeticModule(torch.nn.Module):
        def forward(self, x, y: float):
            x = x + y
            x = x - y
            x = x * (y * 3)
            x = x / (y * 4)
            return x

    tensor = torch.randn(2, 3, 4)
    scalar = 2.5
    self.run_test(ArithmeticModule(), (tensor, scalar))

    # Case 2: the result is a shape entry divided by 2.
    class ArithmeticModule(torch.nn.Module):
        def forward(self, x):
            x = x + 2
            x = x - 3
            return x.shape[1] / 2

    tensor = torch.randn(2, 3, 4)
    self.run_test(ArithmeticModule(), tensor, remained_onnx_input_idx=[])
|
|
|
|
def test_arithmetic_prim_bool(self):
    # Case 1: int, bool and float primitives as inputs; the bool gates the
    # multiply/divide step and is also returned.
    class ArithmeticModule(torch.nn.Module):
        def forward(self, x, y: int, z: bool, t: float):
            x = x + y
            x = x - y
            if z:
                x = x * (y * 3)
                x = x / (y * 4)
            return x / t, z

    tensor = torch.randn(2, 3, 4)
    int_arg = 2
    bool_arg = False
    float_arg = 2.5
    self.run_test(ArithmeticModule(), (tensor, int_arg, bool_arg, float_arg))

    # Case 2: equality comparison of two Python ints.
    class ArithmeticModule(torch.nn.Module):
        def forward(self, x: int, y: int):
            return x == y

    left = 3
    right = 2
    self.run_test(ArithmeticModule(), (left, right))
|
|
|
|
@disableScriptTest()
def test_tuple_with_none_outputs(self):
    # Output tuple containing None entries at two nesting levels.
    class TupleModel(torch.nn.Module):
        def forward(self, x):
            nested = (x, None, (x, None))
            return (x, nested)

    sample = torch.randn(3, 4)
    self.run_test(TupleModel(), (sample,))
|
|
|
|
# In scripting the first transpose node does not carry shape and dtype info.
|
|
# The following test only works when onnx shape inference is enabled.
|
|
def test_arithmetic_infer_dtype(self):
    # Scripted module: transpose first, then the four arithmetic ops with constants.
    class ArithmeticModule(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, x):
            x = x.t()
            x = x + 2
            x = x - 4
            x = x * 6
            x = x / 8
            return x

    matrix = torch.randn(2, 3)
    self.run_test(ArithmeticModule(), matrix)
|
|
|
|
def test_floor_div(self):
    # Floor division across dtype combinations: tensor // scalar and tensor // tensor,
    # with int64 and float64 casts on either side.
    class FloorDivModule(torch.nn.Module):
        def forward(self, x, y):
            return (
                # tensor // Python scalar
                x // 3,
                x // 2.,
                x.to(dtype=torch.float64) // 3,
                x.to(dtype=torch.float64) // 2.,
                x.to(dtype=torch.int64) // 3,
                x.to(dtype=torch.int64) // 2.,
                # tensor // tensor
                x // (y + 1.).to(dtype=torch.int64),
                x // y,
                x.to(dtype=torch.float64) // y.to(dtype=torch.int64),
                x.to(dtype=torch.float64) // y.to(dtype=torch.float64),
                x.to(dtype=torch.int64) // y.to(dtype=torch.int64),
                x.to(dtype=torch.int64) // y,
            )

    # Numerator covers negative, zero and positive values; denominator starts at 1.
    numerator = torch.arange(-2, 4).reshape(2, 3, 1)
    denominator = torch.arange(1, 2 * 3 * 4 + 1).reshape(2, 3, 4)
    self.run_test(FloorDivModule(), (numerator, denominator))
|
|
|
|
def test_floor_div_script(self):
    # Scripted floor division by an int, a float and another tensor.
    class FloorDivModule(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, x, y):
            return x // 3, x // 2., x // y

    numerator = torch.arange(-2, 4).reshape(2, 3, 1)
    denominator = torch.randn(2, 3, 4)
    self.run_test(FloorDivModule(), (numerator, denominator))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_floordiv(self):
    # Floor division of two input sizes, used as the length of a zeros tensor.
    class FloordivModule(torch.nn.Module):
        def forward(self, x):
            return x.new_zeros(x.size(2) // x.size(1))

    sample = torch.randn(2, 3, 4)
    # Once with all axes of "x" dynamic, once with remained_onnx_input_idx=[].
    self.run_test(FloordivModule(), sample, input_names=["x"], dynamic_axes={"x": [0, 1, 2]})
    self.run_test(FloordivModule(), (sample,), remained_onnx_input_idx=[])
|
|
|
|
def test_div(self):
|
|
class DivModule(torch.nn.Module):
|
|
def forward(self, x, y):
|
|
return x / y, torch.true_divide(x, y)
|
|
|
|
x = torch.randn(2, 3, 4).to(torch.int)
|
|
y = torch.arange(1, 2 * 3 * 4 + 1).reshape(2, 3, 4).to(torch.int)
|
|
self.run_test(DivModule(), (x, y))
|
|
self.run_test(DivModule(), (x.float(), y.float()))
|
|
|
|
# Note: div cannot (generally) be exported via scripting
|
|
# since its type promotion logic is dependent on knowing the scalar types
|
|
# of the input tensors. That is, the ONNX graph is dependent on the
|
|
# data type of the inputs. This makes it appropriate for tracing only.
|
|
def test_div_promotion_trace(self):
|
|
class DivModule(torch.nn.Module):
|
|
def forward(self, x, y):
|
|
return x / y, torch.true_divide(x, y)
|
|
|
|
x = torch.randn(2, 3, 4).to(torch.int)
|
|
y = torch.arange(1, 2 * 3 * 4 + 1).reshape(2, 3, 4).to(torch.int)
|
|
|
|
prev_default = torch.get_default_dtype()
|
|
|
|
torch.set_default_dtype(torch.float)
|
|
self.run_test(torch.jit.trace(DivModule(), (x, y)), (x, y))
|
|
|
|
torch.set_default_dtype(torch.double)
|
|
self.run_test(torch.jit.trace(DivModule(), (x, y)), (x, y))
|
|
|
|
torch.set_default_dtype(prev_default)
|
|
|
|
# In scripting x, y do not carry shape and dtype info.
|
|
# The following test only works when onnx shape inference is enabled.
|
|
def test_div_promotion_script(self):
|
|
class DivModule(torch.nn.Module):
|
|
def forward(self, x, y):
|
|
# Add transpose to hide shape/type information
|
|
# Otherwise shape and type are still avaiable from input.
|
|
x = x.transpose(1, 2)
|
|
y = y.transpose(1, 2)
|
|
return x / y, torch.true_divide(x, y)
|
|
|
|
x = torch.randn(2, 3, 4).to(torch.int)
|
|
y = torch.arange(1, 2 * 3 * 4 + 1).reshape(2, 3, 4).to(torch.int)
|
|
|
|
prev_default = torch.get_default_dtype()
|
|
|
|
# 1. x,y are int, and output is float.
|
|
# This can be handled by the default case, where both are cast to float.
|
|
# It works even if type of x, y are unknown.
|
|
torch.set_default_dtype(torch.float)
|
|
self.run_test(torch.jit.script(DivModule()), (x, y))
|
|
|
|
# 2. x,y are int, and output is double.
|
|
# This can be handled by the default case, where both are cast to double.
|
|
# It works even if type of x, y are unknown.
|
|
torch.set_default_dtype(torch.double)
|
|
self.run_test(torch.jit.script(DivModule()), (x, y))
|
|
|
|
# 3. x is int, y is double, and output is double.
|
|
# This can only be handled when both type of x and y are known.
|
|
torch.set_default_dtype(prev_default)
|
|
x = torch.randn(2, 3, 4).to(torch.int)
|
|
y = torch.arange(1, 2 * 3 * 4 + 1).reshape(2, 3, 4).to(torch.double)
|
|
self.run_test(torch.jit.script(DivModule()), (x, y))
|
|
|
|
def test_div_rounding_mode(self):
|
|
class TrueDivModule(torch.nn.Module):
|
|
def forward(self, x, y):
|
|
return (x.div(y, rounding_mode=None),
|
|
torch.div(x, y, rounding_mode=None))
|
|
|
|
class TruncDivModule(torch.nn.Module):
|
|
def forward(self, x, y):
|
|
return (x.div(y, rounding_mode="trunc"),
|
|
torch.div(x, y, rounding_mode="trunc"))
|
|
|
|
class FloorDivModule(torch.nn.Module):
|
|
def forward(self, x, y):
|
|
return (x.div(y, rounding_mode="floor"),
|
|
torch.div(x, y, rounding_mode="floor"))
|
|
|
|
modules = [TrueDivModule(), TruncDivModule(), FloorDivModule()]
|
|
|
|
x = (torch.randn(2, 3, 4) * 100).to(torch.int)
|
|
y = torch.arange(1, 2 * 3 * 4 + 1).reshape(2, 3, 4).to(torch.int)
|
|
|
|
for module in modules:
|
|
self.run_test(module, (x, y))
|
|
self.run_test(torch.jit.trace(module, (x, y)), (x, y))
|
|
self.run_test(torch.jit.script(module), (x, y))
|
|
|
|
x = torch.randn(2, 3, 4)
|
|
y = torch.rand(2, 3, 4) * 10.0 + 0.1
|
|
|
|
for module in modules:
|
|
self.run_test(module, (x, y))
|
|
self.run_test(torch.jit.trace(module, (x, y)), (x, y))
|
|
self.run_test(torch.jit.script(module), (x, y))
|
|
|
|
def test_slice_trace(self):
|
|
class MyModule(torch.nn.Module):
|
|
def forward(self, x):
|
|
return x[0:1]
|
|
|
|
x = torch.randn(3)
|
|
self.run_test(MyModule(), x)
|
|
|
|
def test_slice_neg(self):
|
|
class NegSlice(torch.nn.Module):
|
|
def forward(self, x):
|
|
return x[-1:]
|
|
|
|
x = torch.randn(3, 4, 5)
|
|
self.run_test(NegSlice(), x)
|
|
|
|
def test_slice_neg_large(self):
|
|
class NegSlice(torch.nn.Module):
|
|
def forward(self, x):
|
|
return x[:, :, -3:-1, :, -1]
|
|
|
|
x = torch.randn(3, 4, 5, 6, 7)
|
|
self.run_test(NegSlice(), x)
|
|
|
|
def test_slice_neg_large_negone(self):
|
|
class NegSlice(torch.nn.Module):
|
|
def forward(self, x):
|
|
return x[:, :, :, :, -1]
|
|
|
|
x = torch.randn(3, 4, 5, 6, 7)
|
|
self.run_test(NegSlice(), x)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(11)
def test_slice_with_input_index(self):
    class AssignUpToInputSize(torch.nn.Module):
        def forward(self, x, y):
            # The slice end comes from the first dimension of ``y``.
            rows = y.size(0)
            x[:rows, 0, :] = y
            return x

    target = torch.zeros((56, 6, 256))
    source = torch.rand((22, 256))
    self.run_test(AssignUpToInputSize(), (target, source))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(10)
@disableScriptTest()  # scripting tuple/list append
def test_slice_dynamic(self):
    class SizeBoundedSlices(torch.nn.Module):
        def forward(self, x):
            # Four slices whose bounds mix the loop index with input sizes.
            pieces = []
            for i in range(4):
                pieces.append(x[:x.size(0) - i, i:x.size(2), i:3])
            return tuple(pieces)

    export_input = torch.rand(5, 5, 5)
    extra_input = torch.randn(6, 7, 8)
    all_axes_dynamic = {"input_1": [0, 1, 2], "output_1": [0, 1, 2]}
    self.run_test(
        SizeBoundedSlices(),
        export_input,
        test_with_inputs=[extra_input],
        input_names=["input_1"],
        output_names=["output_1"],
        dynamic_axes=all_axes_dynamic,
    )
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(10)
def test_slice_dynamic_script(self):
    class SliceToSecondDimSize(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, x):
            # The slice end is read from the size of dimension 1.
            stop = x.size(1)
            return x[1:stop]

    sample = torch.rand(1, 2)
    self.run_test(SliceToSecondDimSize(), sample)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(10)
def test_slice_dynamic_shape_script(self):
    class ZerosFromShapeSlice(torch.nn.Module):
        def forward(self, x):
            # Slice the shape itself; the slice end is the size of dim 2.
            return x.new_zeros(x.shape[1:x.size(2)])

    sample = torch.rand(1, 2, 3, 4)
    # Once with all four axes of "x" listed in ``dynamic_axes`` ...
    self.run_test(ZerosFromShapeSlice(), sample, input_names=["x"], dynamic_axes={"x": [0, 1, 2, 3]})
    # ... and once without, passing an empty ``remained_onnx_input_idx``.
    self.run_test(ZerosFromShapeSlice(), sample, remained_onnx_input_idx=[])
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(10)
@disableScriptTest()  # scripting tuple/list append
def test_slice_dynamic_to_end(self):
    class OpenEndedSlices(torch.nn.Module):
        def forward(self, x):
            # Open-ended slices on dim 1 plus a size-derived index on dim 2.
            pieces = []
            for i in range(4):
                pieces.append(x[:, i:, x.size(2) - 5])
            return tuple(pieces)

    sample = torch.rand(5, 5, 5)
    # NOTE(review): the keys "input_1"/"output_1" are used here without
    # passing ``input_names``/``output_names`` (unlike test_slice_dynamic) --
    # confirm whether the names should be supplied as well.
    all_axes_dynamic = {"input_1": [0, 1, 2], "output_1": [0, 1, 2]}
    self.run_test(OpenEndedSlices(), sample, dynamic_axes=all_axes_dynamic)
|
|
|
|
def test_square(self):
|
|
class Square(torch.nn.Module):
|
|
def forward(self, x):
|
|
return torch.square(x)
|
|
|
|
x = torch.randn(2, 3, 4)
|
|
self.run_test(Square(), x)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_arange_dynamic(self):
    class ArangeFromShape(torch.nn.Module):
        def forward(self, input):
            # A size-dependent end, a constant end, and a size-dependent
            # start/end pair.
            return (
                torch.arange(input.shape[0]),
                torch.arange(12),
                torch.arange(start=input.shape[0], end=input.shape[0] + 5),
            )

    export_input = torch.randn(5, 3, 2)
    extra_input = torch.randn(8, 3, 2)

    def check(model):
        # Fresh argument containers are built for every call.
        self.run_test(
            model,
            export_input,
            test_with_inputs=[extra_input],
            input_names=["input_1"],
            output_names=["output_1", "output_2", "output_3"],
            dynamic_axes={"input_1": [0], "output_1": [0]},
        )

    check(ArangeFromShape())
    check(torch.jit.script(ArangeFromShape()))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_dynamic_arange_out(self):
    class ArangeIntoInt64Out(torch.nn.Module):
        def forward(self, end):
            # ``arange`` writes into a pre-made int64 tensor via ``out=``.
            out_t = torch.tensor([1], dtype=torch.int64)
            return torch.arange(end, out=out_t)

    end = torch.tensor(8)
    self.run_test(ArangeIntoInt64Out(), end)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_dynamic_arange_start_out(self):
    class ArangeStartIntoInt64Out(torch.nn.Module):
        def forward(self, start, end):
            # The start value is the first dimension size of ``start``.
            out_t = torch.tensor([1], dtype=torch.int64)
            return torch.arange(start.size(0), end, out=out_t)

    shape_source = torch.randn(2, 3, 4)
    end = torch.tensor(8)
    operands = (shape_source, end)
    self.run_test(
        ArangeStartIntoInt64Out(),
        operands,
        input_names=["x", "y"],
        dynamic_axes={"x": [0, 1, 2]},
    )
    self.run_test(ArangeStartIntoInt64Out(), operands, remained_onnx_input_idx=[1])
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_linspace(self):
    class LinspaceFromTensors(torch.nn.Module):
        def forward(self, start, end, steps):
            return torch.linspace(start, end, steps)

    # All three arguments are passed as 0-d tensors.
    start = torch.tensor(3, dtype=torch.float)
    end = torch.tensor(10, dtype=torch.float)
    steps = torch.tensor(5, dtype=torch.int)
    self.run_test(LinspaceFromTensors(), (start, end, steps))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_linspace_negative_start(self):
    class LinspaceFromTensors(torch.nn.Module):
        def forward(self, start, end, steps):
            return torch.linspace(start, end, steps)

    # Same as test_linspace but the range starts below zero.
    start = torch.tensor(-1, dtype=torch.float)
    end = torch.tensor(1, dtype=torch.float)
    steps = torch.tensor(6, dtype=torch.int)
    self.run_test(LinspaceFromTensors(), (start, end, steps))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_arange_with_floats_out(self):
    # Case 1: only a float ``end``, written into a float ``out`` tensor.
    class ArangeEndIntoOut(torch.nn.Module):
        def forward(self, end):
            out_t = torch.tensor([1], dtype=torch.float)
            return torch.arange(end, out=out_t)

    end = torch.tensor(8.5, dtype=torch.float)
    self.run_test(ArangeEndIntoOut(), end)

    # Case 2: size-derived start, float end and a fractional step.
    class ArangeStepIntoOut(torch.nn.Module):
        def forward(self, start, end):
            out_t = torch.tensor([1], dtype=torch.float)
            return torch.arange(start.size(0), end, 1.5, out=out_t)

    shape_source = torch.randn(2, 3, 4)
    end = torch.tensor(8.5, dtype=torch.float)
    operands = (shape_source, end)
    self.run_test(
        ArangeStepIntoOut(),
        operands,
        input_names=["x", "y"],
        dynamic_axes={"x": [0, 1, 2]},
    )
    self.run_test(ArangeStepIntoOut(), operands, remained_onnx_input_idx=[1])
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_arange_with_floats(self):
    def check_start_end(model_cls):
        # Shared driver for the two-input models: one run with the axes of
        # "x" listed in ``dynamic_axes``, one with ``remained_onnx_input_idx``.
        shape_source = torch.randn(2, 3, 4)
        end = torch.tensor(8.5, dtype=torch.float)
        self.run_test(
            model_cls(),
            (shape_source, end),
            input_names=["x", "y"],
            dynamic_axes={"x": [0, 1, 2]},
        )
        self.run_test(model_cls(), (shape_source, end), remained_onnx_input_idx=[1])

    # Float ``end`` only.
    class ArangeEnd(torch.nn.Module):
        def forward(self, end):
            return torch.arange(end)

    end_only = torch.tensor(8.5, dtype=torch.float)
    self.run_test(ArangeEnd(), end_only)

    # Size-derived start with a positive fractional step.
    class ArangeStep(torch.nn.Module):
        def forward(self, start, end):
            return torch.arange(start.size(0), end, 1.5)

    check_start_end(ArangeStep)

    # Descending range: float start, size-derived end, negative step.
    class ArangeNegativeStep(torch.nn.Module):
        def forward(self, start, end):
            return torch.arange(end, start.size(0), -1.5)

    check_start_end(ArangeNegativeStep)

    # Size-derived start with the default step.
    class ArangeStart(torch.nn.Module):
        def forward(self, start, end):
            return torch.arange(start.size(0), end)

    check_start_end(ArangeStart)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_arange_with_floats_override(self):
    # Float arguments combined with an explicit ``dtype=torch.int64``.
    class ArangeEndAsInt64(torch.nn.Module):
        def forward(self, end):
            return torch.arange(end, dtype=torch.int64)

    end = torch.tensor(8.5, dtype=torch.float)
    self.run_test(ArangeEndAsInt64(), end)

    class ArangeStepAsInt64(torch.nn.Module):
        def forward(self, start, end):
            return torch.arange(start.size(0), end, 1.5, dtype=torch.int64)

    shape_source = torch.randn(2, 3, 4)
    end = torch.tensor(8.5, dtype=torch.float)
    operands = (shape_source, end)
    self.run_test(
        ArangeStepAsInt64(),
        operands,
        input_names=["x", "y"],
        dynamic_axes={"x": [0, 1, 2]},
    )
    self.run_test(ArangeStepAsInt64(), operands, remained_onnx_input_idx=[1])
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(11)
def test_arange_out(self):
    class ArangeIntoFloatOut(torch.nn.Module):
        def forward(self, end):
            # ``arange`` writes into a pre-made float tensor via ``out=``.
            out_t = torch.tensor([1], dtype=torch.float)
            return torch.arange(end, out=out_t)

    end = torch.tensor(8.5, dtype=torch.float)
    self.run_test(ArangeIntoFloatOut(), end)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(11)
def test_arange_start_out(self):
    class ArangeStartIntoFloatOut(torch.nn.Module):
        def forward(self, start, end):
            # The start value is the first dimension size of ``start``.
            out_t = torch.tensor([1], dtype=torch.float)
            return torch.arange(start.size(0), end, out=out_t)

    shape_source = torch.randn(2, 3, 4)
    end = torch.tensor(8.5, dtype=torch.float)
    operands = (shape_source, end)
    self.run_test(
        ArangeStartIntoFloatOut(),
        operands,
        input_names=["x", "y"],
        dynamic_axes={"x": [0, 1, 2]},
    )
    self.run_test(ArangeStartIntoFloatOut(), operands, remained_onnx_input_idx=[1])
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(11)
def test_arange_no_type(self):
    class ArangeWithoutDtype(torch.nn.Module):
        def forward(self, end):
            # Neither call passes ``dtype``; one also passes a start of 0.
            return (
                torch.arange(end),
                torch.arange(0, end),
            )

    end = torch.tensor(6.2, dtype=torch.float)
    self.run_test(ArangeWithoutDtype(), end)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_size(self):
    class SizeConsumers(torch.nn.Module):
        def forward(self, input):
            # Uses the first size, the last size and the full shape.
            return (
                torch.arange(input.size(0)),
                torch.arange(input.size(-1)),
                torch.ones(input.shape),
            )

    sample = torch.randn(5, 3, 2)
    self.run_test(SizeConsumers(), sample, input_names=["x"], dynamic_axes={"x": [0, 1, 2]})
    self.run_test(SizeConsumers(), sample, remained_onnx_input_idx=[])
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
@disableScriptTest()  # x.stride() not scriptable
def test_as_strided(self):
    class StridedViews(torch.nn.Module):
        def forward(self, x):
            # Derive a size/stride pair from the input: dim 1 is grown to
            # ``2 * n - 1`` entries while its stride is halved.
            chunk_size = list(x.size())
            chunk_size[1] = chunk_size[1] * 2 - 1
            chunk_stride = list(x.stride())
            chunk_stride[1] = chunk_stride[1] // 2

            # One view with constant size/stride/offset, one with the
            # derived values.
            constant_view = x.as_strided((3, 3, 3), (1, 4, 2), storage_offset=2)
            derived_view = x.as_strided(chunk_size, chunk_stride)
            return constant_view, derived_view

    sample = torch.randn(5, 8, 7)
    self.run_test(StridedViews(), sample)
|
|
|
|
@disableScriptTest()  # Ellipses followed by tensor indexing not scriptable
def test_tensor_index_advanced_indexing_ellipsis(self):
    class EllipsisThenTensorIndices(torch.nn.Module):
        def forward(self, input):
            # Two index tensors applied to the trailing dimensions.
            return input[..., torch.tensor([2, 1]), torch.tensor([0, 3])]

    sample = torch.randn(3, 4, 5, 6, 7)
    self.run_test(EllipsisThenTensorIndices(), (sample,))
|
|
|
|
def test_tensor_index_advanced_indexing(self):
|
|
class MyModel(torch.nn.Module):
|
|
def forward(self, input):
|
|
return input[:, torch.tensor([[0, 2], [1, 1]]), :, torch.tensor([2, 1]), torch.tensor([0, 3])]
|
|
|
|
m1 = torch.randn(3, 4, 5, 6, 7)
|
|
self.run_test(MyModel(), (m1,))
|
|
|
|
class MyModel(torch.nn.Module):
|
|
def forward(self, input):
|
|
return input[:, torch.tensor([0, 2]), None, 2:4, torch.tensor([[1, 3], [4, 0]])]
|
|
|
|
self.run_test(MyModel(), (m1,))
|
|
|
|
class MyModel(torch.nn.Module):
|
|
def forward(self, input):
|
|
return input[:, torch.tensor([0, 2]), torch.tensor([1]), 2:4, torch.tensor([[1], [4]])]
|
|
|
|
self.run_test(MyModel(), (m1,))
|
|
|
|
def test_tensor_index_advanced_indexing_consecutive(self):
|
|
class MyModel(torch.nn.Module):
|
|
def forward(self, input):
|
|
return input[:, torch.tensor([0, 2]), torch.tensor([[1, 3], [4, 0]]), None]
|
|
|
|
m1 = torch.randn(3, 4, 5, 6, 7)
|
|
self.run_test(MyModel(), (m1,))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(11)
def test_index_put(self):
    class AssignByIndexTensor(torch.nn.Module):
        def forward(self, x, ind, update):
            # In-place assignment through an index tensor.
            x[ind] = update
            return x

    base = torch.randn(3, 4)
    row_index = torch.tensor([1], dtype=torch.long)
    new_row = torch.ones(4)
    self.run_test(AssignByIndexTensor(), (base, row_index, new_row))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(11)
def test_index_put_singular(self):
    # A Python bool assigned through an index tensor.
    class AssignBoolScalar(torch.nn.Module):
        def forward(self, mask, indices):
            mask[indices] = True
            return mask

    bool_mask = torch.zeros(100, dtype=torch.bool)
    positions = (torch.rand(25) * bool_mask.shape[0]).to(torch.int64)
    self.run_test(AssignBoolScalar(), (bool_mask, positions))

    # A 0-d float tensor assigned through an index tensor.
    class AssignFloatScalar(torch.nn.Module):
        def forward(self, mask, indices):
            mask[indices] = torch.tensor(5.5)
            return mask

    float_mask = torch.rand(100, dtype=torch.float)
    positions = (torch.rand(50) * float_mask.shape[0]).to(torch.int64)
    self.run_test(AssignFloatScalar(), (float_mask, positions))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(11)
def test_index_put_accumulate(self):
    class AccumulatingIndexPut(torch.nn.Module):
        def forward(self, x, ind, update):
            # Out-of-place ``index_put`` with ``accumulate=True``.
            return x.index_put((ind, ), update, accumulate=True)

    base = torch.randn(3, 4)
    row_index = torch.tensor([2], dtype=torch.long)
    increment = torch.ones(4)
    self.run_test(AccumulatingIndexPut(), (base, row_index, increment))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(11)
def test_index_put_slice_index(self):
    # Each case below mixes slices, integer selects and index tensors in
    # an in-place assignment (``=``) or accumulation (``+=``).

    # Two slices followed by an index tensor.
    class SlicesThenTensor(torch.nn.Module):
        def forward(self, x, update):
            x[1:2, 1:3, torch.tensor([1])] += update
            return x

    base = torch.randn(3, 4, 5)
    delta = torch.tensor([10, 15]).view(1, 2, 1)
    self.run_test(SlicesThenTensor(), (base, delta))

    # Two index tensors.
    class TwoTensors(torch.nn.Module):
        def forward(self, x, update):
            x[torch.tensor([0, 2]), torch.tensor([1, 2])] += update
            return x

    base = torch.randn(3, 4, 5)
    delta = torch.randn(2, 5)
    self.run_test(TwoTensors(), (base, delta))

    # Index tensor followed by a slice.
    class TensorThenSlice(torch.nn.Module):
        def forward(self, x, update):
            x[torch.tensor([0, 2]), 1:2] += update
            return x

    base = torch.randn(3, 4, 5)
    delta = torch.tensor([10, 15]).view(2, 1, 1)
    self.run_test(TensorThenSlice(), (base, delta))

    # Index tensor followed by an integer select.
    class TensorThenSelect(torch.nn.Module):
        def forward(self, x, update):
            x[torch.tensor([0, 2]), 2] += update
            return x

    base = torch.randn(3, 4, 5)
    delta = torch.tensor([10, 15]).view(2, 1)
    self.run_test(TensorThenSelect(), (base, delta))

    # Slice, index tensor, integer select.
    class SliceTensorSelect(torch.nn.Module):
        def forward(self, x, update):
            x[1:3, torch.tensor([0, 2]), 2] += update
            return x

    base = torch.randn(3, 4, 5)
    delta = torch.tensor([10, 15]).view(2, 1)
    self.run_test(SliceTensorSelect(), (base, delta))

    # Bounded slice and integer select, plain assignment.
    class BoundedSliceSelect(torch.nn.Module):
        def forward(self, x, update):
            x[1:3, 0] = update
            return x

    base = torch.randn(3, 4, 5)
    replacement = torch.arange(2 * 5).to(torch.float).view(2, 5)
    self.run_test(BoundedSliceSelect(), (base, replacement))

    # Slice with an open end.
    class OpenEndSliceSelect(torch.nn.Module):
        def forward(self, x, update):
            x[1:, 0] = update
            return x

    base = torch.randn(3, 4, 5)
    replacement = torch.arange(2 * 5).to(torch.float).view(2, 5)
    self.run_test(OpenEndSliceSelect(), (base, replacement))

    # Slice with an open start.
    class OpenStartSliceSelect(torch.nn.Module):
        def forward(self, x, update):
            x[:3, 0] = update
            return x

    base = torch.randn(3, 4, 5)
    replacement = torch.arange(3 * 5).to(torch.float).view(3, 5)
    self.run_test(OpenStartSliceSelect(), (base, replacement))

    # Assignment into a freshly created tensor whose first dimension is
    # taken from the input shape.
    class AssignIntoNewTensor(torch.nn.Module):
        def forward(self, poses):
            w = 32
            x = poses[:, :, 0] - (w - 1) // 2
            boxes = torch.zeros([poses.shape[0], 17, 4])
            boxes[:, :, 0] = x
            return boxes

    poses = torch.zeros([2, 17, 3], dtype=torch.int64)
    self.run_test(AssignIntoNewTensor(), (poses,))

    # 2-d index tensor plus slice, with an expanded right-hand side.
    class ExpandedUpdate(torch.nn.Module):
        def forward(self, x, ind, update):
            x[ind, 1:3] = update.view(1, 1, 1, 5).expand(2, 2, 2, 5)
            return x

    base = torch.randn(3, 4, 5)
    index_grid = torch.tensor([[0, 2], [1, 1]])
    row_values = torch.randn(5)
    self.run_test(ExpandedUpdate(), (base, index_grid, row_values))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(11)
@disableScriptTest()  # Ellipses followed by tensor indexing not scriptable
def test_index_put_ellipsis(self):
    # Leading ellipsis, then an index tensor and a slice.
    class LeadingEllipsis(torch.nn.Module):
        def forward(self, x, update):
            x[..., torch.tensor([2, 1, 3]), 2:4] += update
            return x

    base = torch.randn(3, 4, 5, 6, 7)
    delta = torch.randn(3, 1, 1, 3, 2)
    self.run_test(LeadingEllipsis(), (base, delta))

    # Integer select before the ellipsis.
    class SelectThenEllipsis(torch.nn.Module):
        def forward(self, x, update):
            x[2, ..., torch.tensor([2, 1, 3]), 2:4] += update
            return x

    base = torch.randn(3, 4, 5, 6, 7)
    delta = torch.randn(4, 1, 3, 2)
    self.run_test(SelectThenEllipsis(), (base, delta))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(11)
def test_index_put_loop(self):
    # Scripted helper that repeatedly rebinds ``bias`` and assigns into it
    # by index inside nested loops, then does a second pass of in-place
    # assignments only.
    @torch.jit.script
    def ngram_attention_bias(sequence_length: int, ngram: int, device: torch.device, dtype: torch.dtype):
        bias = torch.ones((ngram, sequence_length), device=device, dtype=dtype) * float("-inf")
        # First pass: out-of-place rescaling interleaved with indexed writes.
        for stream_idx in range(ngram):
            for i in range(sequence_length):
                bias = bias * 2
                bias[stream_idx, i] = 5
                bias = bias * 5
                bias[0, 0] = 5

        # Second pass: indexed writes only.
        for stream_idx in range(ngram):
            for i in range(sequence_length):
                bias[stream_idx, i] = 5
                bias[0, i] = 5
        return bias

    class ScriptModel(torch.nn.Module):
        def __init__(self):
            super(ScriptModel, self).__init__()
            self.ngram = 2
            self.max_target_positions = 512

        def forward(self, hidden_states):
            # Only ``seq_length`` is used below; ``batch_size`` is unpacked
            # but not read.
            seq_length, batch_size = hidden_states.shape[:2]
            predict_causal_mask = ngram_attention_bias(
                self.max_target_positions, self.ngram, hidden_states.device, hidden_states.dtype
            )
            # Trim the fixed-width mask to the input's first dimension size.
            predict_causal_mask = predict_causal_mask[:, :seq_length]
            return predict_causal_mask

    x = torch.randn(6, 2)
    y = torch.randn(4, 1)
    # Both axes of "x" are named in ``dynamic_axes``; ``y`` is passed as an
    # additional input via ``test_with_inputs``.
    self.run_test(ScriptModel(), x, input_names=["x"],
                  dynamic_axes={"x": {0: "seq_length", 1: "batch_size"}}, test_with_inputs=[y])
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(11)
def test_copy_(self):
    # Assignment into a row slice.
    class AssignRowSlice(torch.nn.Module):
        def forward(self, x, data):
            x[1:3] = data
            return x

    base = torch.randn(3, 4)
    values = torch.randn(2, 4)
    self.run_test(AssignRowSlice(), (base, values))

    # mixed slice and select
    class AssignSliceThenSelect(torch.nn.Module):
        def forward(self, x, data):
            x[1:3, 0] = data
            return x

    base = torch.randn(3, 4)
    # Single-element source.
    values = torch.tensor([0], dtype=torch.float32)
    self.run_test(AssignSliceThenSelect(), (base, values))

    # Two-element source with fixed values.
    values = torch.tensor([2, 3], dtype=torch.float32)
    self.run_test(AssignSliceThenSelect(), (base, values))

    # Two-element source with random values.
    values = torch.randn(2)
    self.run_test(AssignSliceThenSelect(), (base, values))

    # Select first, slice second.
    class AssignSelectThenSlice(torch.nn.Module):
        def forward(self, x, data):
            x[1, 1:3] = data
            return x

    base = torch.randn(3, 4)
    values = torch.tensor([0], dtype=torch.float32)
    self.run_test(AssignSelectThenSlice(), (base, values))

    values = torch.tensor([2, 3], dtype=torch.float32)
    self.run_test(AssignSelectThenSlice(), (base, values))

    values = torch.randn(2)
    self.run_test(AssignSelectThenSlice(), (base, values))

    # Index supplied as a 0-d tensor input.
    class AssignByTensorIndex(torch.nn.Module):
        def forward(self, x, ind, data):
            x[ind] = data
            return x

    base = torch.randn(3, 4)
    row_index = torch.tensor(2)
    row_values = torch.randn(4)
    self.run_test(AssignByTensorIndex(), (base, row_index, row_values))

    # Explicit ``copy_`` guarded by a ``None`` check.
    class GuardedCopy(torch.nn.Module):
        def forward(self, x, mask):
            if mask is not None:
                x.copy_(mask)
            return x

    base = torch.randn(3, 4)
    column = torch.randn(3, 1)
    self.run_test(GuardedCopy(), (base, column))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(11)
@disableScriptTest()  # Model not scriptable (output with shape doesn't match the broadcast shape)
def test_copy_tracing(self):
    class AssignSelectThenSlice(torch.nn.Module):
        def forward(self, x, data):
            x[1, 1:3] = data
            return x

    base = torch.randn(3, 4)
    # The source carries an extra leading dimension of size one.
    values = torch.randn(1, 2)
    self.run_test(AssignSelectThenSlice(), (base, values))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(11)
def test_copy_ellipsis(self):
    class AssignAfterEllipsis(torch.nn.Module):
        def forward(self, x, update):
            x[..., 1] = update
            return x

    # 3-d input.
    base = torch.randn(2, 3, 4)
    value = torch.ones(1)
    self.run_test(AssignAfterEllipsis(), (base, value))

    # 5-d input.
    base = torch.randn(2, 3, 4, 5, 6)
    value = torch.ones(1)
    self.run_test(AssignAfterEllipsis(), (base, value))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(11)
def test_copy_ellipsis_script(self):
    class ReshapeThenAssign(torch.nn.Module):
        def forward(self, x, update):
            # Insert reshape node to ensure no shape/type info for
            # x in scripting, without onnx shape inference.
            x = x.reshape(4, 3, 5, 6)
            x[2, ..., 1:3] = update
            return x

    base = torch.randn(3, 4, 5, 6)
    value = torch.ones(1)
    self.run_test(ReshapeThenAssign(), (base, value))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(10)
def test_flip(self):
    class FlipFirstDim(torch.nn.Module):
        def forward(self, x):
            return torch.flip(x, dims=[0])

    # Input is built from a NumPy array of floats.
    sample = torch.tensor(np.arange(6.0).reshape(2, 3))
    self.run_test(FlipFirstDim(), sample)
|
|
|
|
def test_random(self):
|
|
class RandN(torch.nn.Module):
|
|
def forward(self, x):
|
|
return torch.mul(x, (torch.randn(2, 3, 4) + x).size(0))
|
|
|
|
x = torch.randn(2, 3, 4)
|
|
self.run_test(RandN(), x)
|
|
|
|
class Rand(torch.nn.Module):
|
|
def forward(self, x):
|
|
return torch.mul(x, (torch.rand(2, 3, 4) + x).size(0))
|
|
|
|
x = torch.randn(2, 3, 4)
|
|
self.run_test(Rand(), x)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_random_dynamic_size(self):
    # The random tensor's shape is taken from the input; only its size
    # along dim 1 feeds into the result.
    class NormalSampleLikeInput(torch.nn.Module):
        def forward(self, x):
            return torch.mul(x, torch.randn(x.size()).size(1))

    sample = torch.randn(2, 3, 4)
    self.run_test(NormalSampleLikeInput(), sample)

    class UniformSampleLikeInput(torch.nn.Module):
        def forward(self, x):
            return torch.mul(x, torch.rand(x.size()).size(1))

    sample = torch.randn(2, 3, 4)
    self.run_test(UniformSampleLikeInput(), sample)
|
|
|
|
def test_random_like(self):
|
|
class RandNLike(torch.nn.Module):
|
|
def forward(self, x):
|
|
return torch.mul(x, torch.randn_like(x).size(0))
|
|
|
|
x = torch.randn(2, 3, 4)
|
|
self.run_test(RandNLike(), x)
|
|
self.run_test(torch.jit.script(RandNLike()), x)
|
|
|
|
class RandLike(torch.nn.Module):
|
|
def forward(self, x):
|
|
return torch.mul(x, torch.rand_like(x).size(0))
|
|
|
|
x = torch.randn(2, 3, 4)
|
|
self.run_test(RandLike(), x)
|
|
self.run_test(torch.jit.script(RandLike()), x)
|
|
|
|
def test_random_like_dtype(self):
|
|
class RandNLike(torch.nn.Module):
|
|
def forward(self, x):
|
|
return torch.mul(x.to(torch.double), torch.randn_like(x, dtype=torch.double).size(0))
|
|
|
|
x = torch.randn(2, 3, 4)
|
|
self.run_test(RandNLike(), x)
|
|
|
|
class RandLike(torch.nn.Module):
|
|
def forward(self, x):
|
|
return torch.mul(x.to(torch.double), torch.rand_like(x, dtype=torch.double).size(0))
|
|
|
|
x = torch.randn(2, 3, 4)
|
|
self.run_test(RandLike(), x)
|
|
|
|
def test_bernoulli(self):
|
|
class Bernoulli(torch.nn.Module):
|
|
def forward(self, x):
|
|
return torch.mul(x, torch.bernoulli(x).size(0))
|
|
|
|
x = torch.empty(3, 3).uniform_(0, 1)
|
|
self.run_test(Bernoulli(), x)
|
|
|
|
x = torch.empty(2, 3, 3, dtype=torch.double).uniform_(0, 1)
|
|
self.run_test(Bernoulli(), x)
|
|
|
|
@unittest.skip("Bug in ORT, skip test until rel-1.11.")
@skipIfUnsupportedMinOpsetVersion(14)
def test_reshape_allowzero(self):
    class ReshapeWithZeroDim(torch.nn.Module):
        def forward(self, x):
            # The target shape contains a literal zero-sized dimension.
            return x.reshape(3, 4, 0)

    empty_input = torch.randn(0, 3, 4)
    self.run_test(ReshapeWithZeroDim(), empty_input)
|
|
|
|
def test_reshape_different_rank(self):
|
|
class ReshapeModel(torch.nn.Module):
|
|
def forward(self, x):
|
|
x = x.reshape(-1, 2, 4, 4, 5, 5)
|
|
return x
|
|
|
|
x = torch.randn(1, 32, 5, 5)
|
|
self.run_test(ReshapeModel(), x)
|
|
|
|
def _interpolate(self, x, mode, use_size, is_upsample, align_corners=False):
    # Build a small interpolation module configured by the arguments and
    # hand it, together with ``x``, to ``self.run_test`` (``atol=1e-6``).
    #   x             -- input tensor; its ``dim()`` selects how many entries
    #                    the per-dimension size/scale lists have
    #   mode          -- forwarded as ``mode=`` to ``interpolate``
    #   use_size      -- truthy: resize via ``size=``; falsy: via ``scale_factor=``
    #   is_upsample   -- picks the scalar scale (2.0 vs 0.5) and size (24 vs 2)
    #   align_corners -- truthy: the ``size=`` calls pass ``align_corners=True``
    class MyModel(torch.nn.Module):
        __constants__ = ["mode", "use_size", "is_upsample", "size", "scale", "size_array", "scale_array", "align_corners"]

        def __init__(self, mode, use_size, is_upsample, align_corners):
            super(MyModel, self).__init__()
            self.mode = mode
            self.use_size = use_size
            self.is_upsample = is_upsample
            self.align_corners = align_corners
            self.scale = 2.0 if self.is_upsample else 0.5
            self.size = 24 if self.is_upsample else 2
            # ``x`` here is the outer function's argument: one size/scale
            # entry per dimension after the first two.
            if x.dim() == 3:
                self.scale_array = [2.3]
                self.size_array = [16]
            elif x.dim() == 4:
                self.scale_array = [2.3, 3.1]
                self.size_array = [16, 32]
            else:
                self.scale_array = [2.3, 3.1, 4.6]
                self.size_array = [16, 32, 64]

        def forward(self, x):
            # Every branch returns a pair: scalar size/scale first, then the
            # per-dimension list.
            if self.use_size:
                if self.align_corners:
                    return torch.nn.functional.interpolate(x, mode=self.mode, size=self.size, align_corners=True), \
                        torch.nn.functional.interpolate(x, mode=self.mode, size=self.size_array, align_corners=True)
                return torch.nn.functional.interpolate(x, mode=self.mode, size=self.size), \
                    torch.nn.functional.interpolate(x, mode=self.mode, size=self.size_array)
            # NOTE(review): the two ``scale_factor`` branches below are
            # identical -- ``align_corners`` is not forwarded here, unlike in
            # the ``size`` branch above. Confirm whether that is intended.
            if self.align_corners:
                return torch.nn.functional.interpolate(x, mode=self.mode,
                                                       scale_factor=self.scale, recompute_scale_factor=False), \
                    torch.nn.functional.interpolate(x, mode=self.mode,
                                                    scale_factor=self.scale_array, recompute_scale_factor=False)
            return torch.nn.functional.interpolate(x, mode=self.mode,
                                                   scale_factor=self.scale, recompute_scale_factor=False), \
                torch.nn.functional.interpolate(x, mode=self.mode,
                                                scale_factor=self.scale_array, recompute_scale_factor=False)

    model = MyModel(mode, use_size, is_upsample, align_corners)
    self.run_test(model, x, atol=1e-6)
|
|
|
|
def _interpolate_tests(self, is_upsample):
    # Drive ``self._interpolate`` over every supported combination of
    # interpolation mode and input rank (3-d, 4-d, 5-d).
    #
    # - cubic mode is not supported for opsets below 11;
    # - linear mode does not match for opsets below 11;
    if self.opset_version < 11:
        modes = ["nearest"]
    else:
        modes = ["nearest", "linear", "bicubic"]
    samples = [
        torch.randn(1, 2, 6, requires_grad=True),
        torch.randn(1, 2, 4, 6, requires_grad=True),
        torch.randn(1, 2, 4, 4, 6, requires_grad=True),
    ]

    for mode in modes:
        # ``align_corners`` is only exercised for the non-nearest modes.
        with_align_corners = mode != "nearest"
        for xi in samples:
            rank = xi.dim()
            if mode == "bicubic" and rank != 4:
                # TODO: enable bicubic downsample when ORT precision loss fixed
                continue
            if mode == "linear":
                if rank != 4:
                    # TODO : enable when linear mode is implemented for 1d
                    # and 3d inputs in ORT
                    continue
                mode_i = "bilinear"
            else:
                mode_i = mode

            self._interpolate(xi, mode_i, True, is_upsample)
            # test with align_corners if supported
            if with_align_corners:
                self._interpolate(xi, mode_i, True, is_upsample, True)

            # the following cases require dynamic sizes/scales,
            # which is not supported for opset_version < 9
            if self.opset_version < 9:
                continue
            # NOTE(review): this repeats the first call above with the same
            # arguments; kept as-is.
            self._interpolate(xi, mode_i, True, is_upsample)
            # test with align_corners if supported
            if with_align_corners:
                self._interpolate(xi, mode_i, False, is_upsample, True)
            self._interpolate(xi, mode_i, False, is_upsample)
|
|
|
|
# ONNX export failed on interpolate scripting because dynamic size not supported for opsets below 9.
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_interpolate_upsample(self):
    # Run the shared interpolate matrix in its upsampling configuration.
    self._interpolate_tests(is_upsample=True)
|
|
|
|
@skipIfUnsupportedMaxOpsetVersion(8)
@disableScriptTest()  # Scripting supported for opsets > 8. See test_interpolate_upsample
def test_interpolate_upsample_trace(self):
    # Same upsampling matrix as test_interpolate_upsample, restricted by
    # the decorators above to opsets up to 8 with script testing disabled.
    self._interpolate_tests(is_upsample=True)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_interpolate_function_substitution(self):
    # Case 1: ``interpolate`` is called from a script submodule that is
    # itself wrapped by another script module.
    class InnerScripted(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, x):
            return torch.nn.functional.interpolate(x, mode="nearest", scale_factor=2.0)

    class OuterScripted(torch.jit.ScriptModule):
        def __init__(self):
            super(OuterScripted, self).__init__()
            self.submodule = InnerScripted()

        @torch.jit.script_method
        def forward(self, input):
            return self.submodule(input)

    sample = torch.randn(1, 2, 4, 4, 6)
    self.run_test(OuterScripted(), (sample,))

    # Case 2: ``interpolate`` is called from a scripted free function that
    # a plain (non-script) module invokes.
    @torch.jit.script
    def scripted_interpolate(x):
        return torch.nn.functional.interpolate(x, mode="nearest", scale_factor=2.0)

    class CallsScriptedFunction(torch.nn.Module):
        def forward(self, x):
            return scripted_interpolate(x)

    self.run_test(CallsScriptedFunction(), (sample,))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(10)
def test_interpolate_downsample(self):
    """Run the shared interpolate checks with the upsample flag set to False."""
    self._interpolate_tests(False)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(11)
def test_interpolate_no_shape(self):
    """Check interpolate inside a scripted forward, with fixed and input-derived sizes.

    ``out1`` uses a literal ``size=(16, 16)``; ``out2`` takes its size from the
    first two dimensions of ``y``. The model is passed to ``self.run_test``
    twice: once with all axes of both inputs declared dynamic, and once with
    ``remained_onnx_input_idx=[0]``.
    """
    class MyModel(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, x, y):
            x = torch.add(x, x)
            out1 = torch.nn.functional.interpolate(x, mode="bilinear", size=(16, 16), align_corners=False)
            out2 = torch.nn.functional.interpolate(x, mode="nearest", size=(int(y.size(0)), int(y.size(1))))
            return out1, out2

    x = torch.randn(1, 2, 4, 4, requires_grad=True)
    y = torch.randn(16, 16, requires_grad=True)
    self.run_test(MyModel(), (x, y), input_names=["x", "y"], dynamic_axes={"x": [0, 1, 2, 3], "y": [0, 1]})
    self.run_test(MyModel(), (x, y), remained_onnx_input_idx=[0])
|
|
|
|
@disableScriptTest()  # scripting throws the ONNXRuntimeError
def test_interpolate_adaptive_pooling_error(self):
    """Expect ``RuntimeError`` from ``self._interpolate`` in "area" mode.

    The helper is invoked twice on the same 3-D input; the calls differ only
    in the third positional argument (True, then False).
    """
    x = torch.randn(1, 2, 6, requires_grad=True)

    # The raised exception is never inspected, so no ``as`` binding is needed.
    with self.assertRaises(RuntimeError):
        self._interpolate(x, "area", True, True)

    with self.assertRaises(RuntimeError):
        self._interpolate(x, "area", False, True)
|
|
|
|
def test_groupnorm(self):
    """Pass ``GroupNorm`` (6 channels, eps 0.002) to ``self.run_test`` for three group counts."""
    # (number of groups, input shape), in the order they are exercised.
    cases = (
        (3, (4, 6, 180, 180, 180)),
        (1, (4, 6, 180, 180)),
        (6, (4, 6, 180, 180)),
    )
    for num_groups, shape in cases:
        model = torch.nn.GroupNorm(num_groups, 6, 0.002)
        x = torch.randn(*shape)
        self.run_test(model, x)
|
|
|
|
def test_groupnorm_noaffine(self):
    """Pass ``GroupNorm`` with ``affine=False`` to ``self.run_test`` for three configurations."""
    # (number of groups, number of channels, input shape).
    cases = (
        (4, 8, (3, 8, 224, 224)),
        (1, 6, (4, 6, 180, 180)),
        (6, 6, (4, 6, 180, 180)),
    )
    for num_groups, num_channels, shape in cases:
        model = torch.nn.GroupNorm(num_groups, num_channels, 0.002, affine=False)
        x = torch.randn(*shape)
        self.run_test(model, x)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_listunpack(self):
    """Check unpacking of ``x.shape`` (whole and sliced) inside scripted forwards.

    Each module is passed to ``self.run_test`` twice: once with every input
    axis declared dynamic, and once with ``remained_onnx_input_idx=[]``.
    """
    # Unpacks the full shape of a 2-D input.
    class ListUnpack(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, x):
            a, b = x.shape
            return x.new_zeros((a, b))

    x = torch.randn(2, 3)
    self.run_test(ListUnpack(), x, input_names=["x"], dynamic_axes={"x": [0, 1]})
    self.run_test(ListUnpack(), x, remained_onnx_input_idx=[])

    # Unpacks only the trailing two dimensions of a 4-D input.
    class ListUnpackSlice(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, x):
            a, b = x.shape[2:]
            return x.new_zeros((a, b))

    x = torch.randn(2, 3, 4, 5)
    self.run_test(ListUnpackSlice(), x, input_names=["x"], dynamic_axes={"x": [0, 1, 2, 3]})
    self.run_test(ListUnpackSlice(), x, remained_onnx_input_idx=[])
|
|
|
|
def test_pow(self):
    """Pass several ``pow`` variants and dtype combinations to ``self.run_test``."""
    class PowModule(torch.nn.Module):
        def forward(self, x, y):
            return x.pow(y)

    class PowModule2(torch.nn.Module):
        def forward(self, x):
            return torch.pow(2, x)

    class PowModule3(torch.nn.Module):
        def forward(self, x, y):
            return y[torch.pow(2, x)]

    shape = (2, 3, 4)

    # Tensor ** tensor: float/float, int64/int32, int64/int64, float64/int64.
    base = torch.randn(*shape)
    exponent = torch.randn(*shape)
    self.run_test(PowModule(), (base, exponent))

    base = torch.randint(10, shape)
    exponent = torch.randint(10, shape).to(dtype=torch.int32)
    self.run_test(PowModule(), (base, exponent))

    base = torch.randint(10, shape)
    exponent = torch.randint(10, shape)
    self.run_test(PowModule(), (base, exponent))

    base = torch.randn(*shape).to(dtype=torch.float64)
    exponent = torch.randint(10, shape)
    self.run_test(PowModule(), (base, exponent))

    # Python scalar base with a tensor exponent.
    self.run_test(PowModule2(), (torch.randn(1, 10),))
    self.run_test(PowModule2(), (torch.randint(10, shape),))
    self.run_test(PowModule2(), (torch.randn(1, 10).to(dtype=torch.float64),))

    # Result of pow used as an index into another tensor.
    exponent = torch.randint(5, shape)
    table = torch.rand(100)
    self.run_test(PowModule3(), (exponent, table))
|
|
|
|
# the arithmeticOps(Add\Sub\Mul\Div\Gemm\Pow\Mod) with low-precision dtypes, including uint8, fail in ORT
|
|
# add to(dtype=torch.long) to avoid ORT output type does not match expected type.
|
|
# will be fixed in ONNX version 14.
|
|
@skipIfUnsupportedMaxOpsetVersion(13)
def test_arithmeticOps_with_low_precision(self):
    """Pass add/sub/mul/div/pow modules to ``self.run_test`` for several dtype mixes.

    For every dtype triple, ``x`` and ``y`` hold ``[2, 3, 5]`` and ``z`` holds
    ``[1]``; the four binary modules receive ``(x, y)`` and pow receives ``(x, z)``.
    """
    class AddModule(torch.nn.Module):
        def forward(self, x, y):
            return x + y

    class SubModule(torch.nn.Module):
        def forward(self, x, y):
            return x - y

    class MulModule(torch.nn.Module):
        def forward(self, x, y):
            return x * y

    class DivModule(torch.nn.Module):
        def forward(self, x, y):
            return x / y

    class PowModule(torch.nn.Module):
        def forward(self, x, y):
            return x.pow(y)

    # (dtype of x, dtype of y, dtype of z), in the order they are exercised.
    dtype_triples = (
        (torch.uint8, torch.uint8, torch.uint8),
        (torch.int8, torch.int8, torch.int8),
        (torch.int16, torch.int16, torch.int16),
        (torch.uint8, torch.float32, torch.float64),
        (torch.uint8, torch.int64, torch.int32),
    )
    for x_dtype, y_dtype, z_dtype in dtype_triples:
        x = torch.tensor([2, 3, 5], dtype=x_dtype)
        y = torch.tensor([2, 3, 5], dtype=y_dtype)
        z = torch.tensor([1], dtype=z_dtype)
        for module_cls in (AddModule, SubModule, MulModule, DivModule):
            self.run_test(module_cls(), (x, y))
        self.run_test(PowModule(), (x, z))
|
|
|
|
# fmod was added in version 10
|
|
@skipIfUnsupportedMinOpsetVersion(10)
@skipIfUnsupportedMaxOpsetVersion(13)
def test_mod_with_low_precision(self):
    """Pass ``torch.fmod`` (result cast to long) to ``self.run_test`` for several dtype pairs."""
    class ModModule(torch.nn.Module):
        def forward(self, x, y):
            return torch.fmod(x, y).to(dtype=torch.long)

    # (dtype of x, dtype of y); both operands always hold [2, 3, 5].
    dtype_pairs = (
        (torch.uint8, torch.uint8),
        (torch.int8, torch.int8),
        (torch.int16, torch.int16),
        (torch.uint8, torch.int32),
        (torch.uint8, torch.float64),
    )
    for x_dtype, y_dtype in dtype_pairs:
        x = torch.tensor([2, 3, 5], dtype=x_dtype)
        y = torch.tensor([2, 3, 5], dtype=y_dtype)
        self.run_test(ModModule(), (x, y))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_empty_constant_shape(self):
    """Pass modules that build a 0-dim constant (shape ``()``) to ``self.run_test``.

    Each module creates the constant with a different factory function and
    adds the scalar input to it in place.
    """
    class Zeros(torch.nn.Module):
        def forward(self, x):
            y = torch.zeros(())
            y += x
            return y

    class Ones(torch.nn.Module):
        def forward(self, x):
            y = torch.ones(())
            y += x
            return y

    class Full(torch.nn.Module):
        def forward(self, x):
            y = torch.full((), 1.)
            y += x
            return y

    class Empty(torch.nn.Module):
        def forward(self, x):
            y = torch.empty(()).fill_(0)
            y += x
            return y

    for module_cls in (Zeros, Ones, Full, Empty):
        # A fresh scalar input is built for every module.
        scalar = torch.tensor(42.)
        self.run_test(module_cls(), scalar)
|
|
|
|
def test_std(self):
    """Pass full-tensor ``torch.std`` (biased and unbiased) to ``self.run_test``."""
    class StandardDeviation(torch.nn.Module):
        def forward(self, input):
            return torch.std(input, unbiased=False)

    class StandardDeviationUnbiased(torch.nn.Module):
        def forward(self, input):
            return torch.std(input, unbiased=True)

    # Both variants share one random input.
    x = torch.randn(2, 3, 4)
    for module_cls in (StandardDeviation, StandardDeviationUnbiased):
        self.run_test(module_cls(), x)
|
|
|
|
def test_std_along_dims(self):
    """Pass ``torch.std`` over ``dim=(0, 1)`` (biased and unbiased) to ``self.run_test``."""
    class StandardDeviation(torch.nn.Module):
        def forward(self, input):
            return torch.std(input, dim=(0, 1), unbiased=False)

    class StandardDeviationUnbiased(torch.nn.Module):
        def forward(self, input):
            return torch.std(input, dim=(0, 1), unbiased=True)

    for module_cls in (StandardDeviation, StandardDeviationUnbiased):
        # Each variant gets its own freshly drawn input.
        x = torch.randn(2, 3, 4)
        self.run_test(module_cls(), x)
|
|
|
|
def test_std_keepdim(self):
    """Pass ``torch.std`` over ``dim=(0, 1)`` with ``keepdim=True`` to ``self.run_test``."""
    class StandardDeviation(torch.nn.Module):
        def forward(self, input):
            return torch.std(input, dim=(0, 1), unbiased=False, keepdim=True)

    class StandardDeviationUnbiased(torch.nn.Module):
        def forward(self, input):
            return torch.std(input, dim=(0, 1), unbiased=True, keepdim=True)

    for module_cls in (StandardDeviation, StandardDeviationUnbiased):
        # Each variant gets its own freshly drawn input.
        x = torch.randn(2, 3, 4)
        self.run_test(module_cls(), x)
|
|
|
|
def test_std_correction(self):
    """Pass ``torch.std`` with ``correction=3`` over ``dim=(0, 1)`` to ``self.run_test``."""
    class StandardDeviation(torch.nn.Module):
        def forward(self, input):
            return torch.std(input, dim=(0, 1), correction=3, keepdim=True)

    sample = torch.randn(2, 3, 4)
    self.run_test(StandardDeviation(), sample)
|
|
|
|
def test_var(self):
    """Pass ``torch.var`` (biased, unbiased, and followed by sqrt) to ``self.run_test``."""
    class Variance(torch.nn.Module):
        def forward(self, input):
            return torch.var(input, unbiased=False)

    class VarianceUnbiased(torch.nn.Module):
        def forward(self, input):
            return torch.var(input, unbiased=True)

    class VarianceSqrt(torch.nn.Module):
        def forward(self, input):
            y = torch.var(input, 1)
            return torch.sqrt(y + 1e-8)

    # The two full-tensor variants share one small input.
    small = torch.randn(2, 3, 4)
    for module_cls in (Variance, VarianceUnbiased):
        self.run_test(module_cls(), small)

    # Variance along dim 1 followed by sqrt, on a larger 5-D input.
    large = torch.randn(1, 2, 3, 300, 300)
    self.run_test(VarianceSqrt(), large)
|
|
|
|
def test_var_along_dims(self):
    """Pass ``torch.var`` over ``dim=(0, 1)`` (biased and unbiased) to ``self.run_test``."""
    class Variance(torch.nn.Module):
        def forward(self, input):
            return torch.var(input, dim=(0, 1), unbiased=False)

    class VarianceUnbiased(torch.nn.Module):
        def forward(self, input):
            return torch.var(input, dim=(0, 1), unbiased=True)

    for module_cls in (Variance, VarianceUnbiased):
        # Each variant gets its own freshly drawn input.
        x = torch.randn(2, 3, 4)
        self.run_test(module_cls(), x)
|
|
|
|
def test_var_keepdim(self):
    """Pass ``torch.var`` over ``dim=(0, 1)`` with ``keepdim=True`` to ``self.run_test``."""
    class Variance(torch.nn.Module):
        def forward(self, input):
            return torch.var(input, dim=(0, 1), unbiased=False, keepdim=True)

    class VarianceUnbiased(torch.nn.Module):
        def forward(self, input):
            return torch.var(input, dim=(0, 1), unbiased=True, keepdim=True)

    for module_cls in (Variance, VarianceUnbiased):
        # Each variant gets its own freshly drawn input.
        x = torch.randn(2, 3, 4)
        self.run_test(module_cls(), x)
|
|
|
|
def test_var_correction(self):
    """Pass ``torch.var`` with ``correction=3`` over ``dim=(0, 1)`` to ``self.run_test``."""
    class Variance(torch.nn.Module):
        def forward(self, input):
            return torch.var(input, dim=(0, 1), correction=3, keepdim=True)

    sample = torch.randn(2, 3, 4)
    self.run_test(Variance(), sample)
|
|
|
|
def test_var_mean(self):
    """Pass full-tensor ``torch.var_mean`` (biased and unbiased) to ``self.run_test``."""
    class Variance(torch.nn.Module):
        def forward(self, input):
            return torch.var_mean(input, unbiased=False)

    class VarianceUnbiased(torch.nn.Module):
        def forward(self, input):
            return torch.var_mean(input, unbiased=True)

    # Both variants share one random input.
    x = torch.randn(2, 3, 4)
    for module_cls in (Variance, VarianceUnbiased):
        self.run_test(module_cls(), x)
|
|
|
|
def test_var_mean_along_dims(self):
    """Pass ``torch.var_mean`` over ``dim=(0, 1)`` (biased and unbiased) to ``self.run_test``."""
    class Variance(torch.nn.Module):
        def forward(self, input):
            return torch.var_mean(input, dim=(0, 1), unbiased=False)

    class VarianceUnbiased(torch.nn.Module):
        def forward(self, input):
            return torch.var_mean(input, dim=(0, 1), unbiased=True)

    for module_cls in (Variance, VarianceUnbiased):
        # Each variant gets its own freshly drawn input.
        x = torch.randn(2, 3, 4)
        self.run_test(module_cls(), x)
|
|
|
|
def test_var_mean_mixed_dims(self):
    """Pass biased ``torch.var_mean`` to ``self.run_test`` for three ``dim`` tuples.

    The tuples are ``(2, 1)`` (reversed order), ``(0, 2)`` (skipping a
    dimension) and ``(1, 2)`` (not starting at dimension 0).
    """
    class ReverseDims(torch.nn.Module):
        def forward(self, input):
            return torch.var_mean(input, dim=(2, 1), unbiased=False)

    class SkipDims(torch.nn.Module):
        def forward(self, input):
            return torch.var_mean(input, dim=(0, 2), unbiased=False)

    class NonZeroDims(torch.nn.Module):
        def forward(self, input):
            return torch.var_mean(input, dim=(1, 2), unbiased=False)

    for module_cls in (ReverseDims, SkipDims, NonZeroDims):
        # Each variant gets its own freshly drawn input.
        x = torch.randn(2, 3, 4)
        self.run_test(module_cls(), x)
|
|
|
|
def test_var_mean_keepdim(self):
    """Pass ``torch.var_mean`` over ``dim=(0, 1)`` with ``keepdim=True`` to ``self.run_test``."""
    class Variance(torch.nn.Module):
        def forward(self, input):
            return torch.var_mean(input, dim=(0, 1), unbiased=False, keepdim=True)

    class VarianceUnbiased(torch.nn.Module):
        def forward(self, input):
            return torch.var_mean(input, dim=(0, 1), unbiased=True, keepdim=True)

    for module_cls in (Variance, VarianceUnbiased):
        # Each variant gets its own freshly drawn input.
        x = torch.randn(2, 3, 4)
        self.run_test(module_cls(), x)
|
|
|
|
def test_var_mean_correction(self):
    """Pass ``torch.var_mean`` with ``correction=3`` over ``dim=(0, 1)`` to ``self.run_test``."""
    class Variance(torch.nn.Module):
        def forward(self, input):
            return torch.var_mean(input, dim=(0, 1), correction=3, keepdim=True)

    sample = torch.randn(2, 3, 4)
    self.run_test(Variance(), sample)
|
|
|
|
def test_std_mean(self):
    """Pass full-tensor ``torch.std_mean`` (biased and unbiased) to ``self.run_test``."""
    class StandardDeviation(torch.nn.Module):
        def forward(self, input):
            return torch.std_mean(input, unbiased=False)

    class StandardDeviationUnbiased(torch.nn.Module):
        def forward(self, input):
            return torch.std_mean(input, unbiased=True)

    # Both variants share one random input.
    x = torch.randn(2, 3, 4)
    for module_cls in (StandardDeviation, StandardDeviationUnbiased):
        self.run_test(module_cls(), x)
|
|
|
|
def test_std_mean_along_dims(self):
    """Pass ``torch.std_mean`` over ``dim=(0, 1)`` (biased and unbiased) to ``self.run_test``."""
    class StandardDeviation(torch.nn.Module):
        def forward(self, input):
            return torch.std_mean(input, dim=(0, 1), unbiased=False)

    class StandardDeviationUnbiased(torch.nn.Module):
        def forward(self, input):
            return torch.std_mean(input, dim=(0, 1), unbiased=True)

    for module_cls in (StandardDeviation, StandardDeviationUnbiased):
        # Each variant gets its own freshly drawn input.
        x = torch.randn(2, 3, 4)
        self.run_test(module_cls(), x)
|
|
|
|
def test_std_mean_keepdim(self):
    """Pass ``torch.std_mean`` over ``dim=(0, 1)`` with ``keepdim=True`` to ``self.run_test``."""
    class StandardDeviation(torch.nn.Module):
        def forward(self, input):
            return torch.std_mean(input, dim=(0, 1), unbiased=False, keepdim=True)

    class StandardDeviationUnbiased(torch.nn.Module):
        def forward(self, input):
            return torch.std_mean(input, dim=(0, 1), unbiased=True, keepdim=True)

    for module_cls in (StandardDeviation, StandardDeviationUnbiased):
        # Each variant gets its own freshly drawn input.
        x = torch.randn(2, 3, 4)
        self.run_test(module_cls(), x)
|
|
|
|
def test_std_mean_correction(self):
    """Pass ``torch.std_mean`` with ``correction=3`` over ``dim=(0, 1)`` to ``self.run_test``.

    The forward calls ``torch.std_mean`` so that this test covers the op its
    name refers to, rather than repeating the ``var_mean`` correction test.
    """
    class StandardDeviation(torch.nn.Module):
        def forward(self, input):
            return torch.std_mean(input, dim=(0, 1), correction=3, keepdim=True)

    x = torch.randn(2, 3, 4)
    model = StandardDeviation()
    self.run_test(model, x)
|
|
|
|
def test_bitshift(self):
    """Pass shift operators with int, float and tensor shift amounts to ``self.run_test``.

    The first input is float32 and the second is int64; both are
    ``arange(24)`` reshaped to ``(3, 4, 2)``.
    """
    class BitshiftModel(torch.nn.Module):
        def forward(self, input, input2):
            return (
                input >> 1,
                input << 3.1,
                input2 >> torch.tensor([1, 2]),
                input2 << 4.2,
            )

    float_data = torch.arange(24, dtype=torch.float32).reshape(3, 4, 2)
    int_data = torch.arange(24, dtype=torch.int64).reshape(3, 4, 2)
    self.run_test(BitshiftModel(), (float_data, int_data))
|
|
|
|
def test_bitshift_other_fp(self):
    """Pass a left shift of an int64 tensor by the float ``2.4`` to ``self.run_test``."""
    class BitshiftModel(torch.nn.Module):
        def forward(self, input):
            return input << 2.4

    int_data = torch.arange(24, dtype=torch.int64).reshape(3, 4, 2)
    self.run_test(BitshiftModel(), int_data)
|
|
|
|
# uint8 not implemented in ORT for Mul used in
|
|
# exporting bitshift for opset_version < 10
|
|
@skipIfUnsupportedMinOpsetVersion(11)
def test_bitshift_uint8(self):
    """Pass shift operators on uint8 tensors to ``self.run_test``.

    Both inputs are ``arange(24)`` as uint8, reshaped to ``(3, 4, 2)``.
    """
    class BitshiftModel(torch.nn.Module):
        def forward(self, input, input2):
            return (
                input >> 1,
                input << 3.,
                input2 >> torch.tensor([1, 2], dtype=torch.uint8),
                input2 << 4.,
            )

    first = torch.arange(24, dtype=torch.uint8).reshape(3, 4, 2)
    second = torch.arange(24, dtype=torch.uint8).reshape(3, 4, 2)
    self.run_test(BitshiftModel(), (first, second))
|
|
|
|
def test_narrow(self):
    """Pass ``torch.narrow`` with constant start 0 and length 2 on dim 0 to ``self.run_test``."""
    class NarrowModel(torch.nn.Module):
        def forward(self, input):
            return torch.narrow(input, 0, 0, 2)

    sample = torch.randn(3, 3, requires_grad=True)
    self.run_test(NarrowModel(), sample)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(11)
def test_narrow_dynamic(self):
    """Pass ``torch.narrow`` whose length is ``input.shape[0] - 1`` to ``self.run_test``."""
    class NarrowModel(torch.nn.Module):
        def forward(self, input):
            # Length comes from the input's own leading dimension.
            length = input.shape[0] - 1
            return torch.narrow(input, 0, 0, length)

    sample = torch.randn(3, 3, requires_grad=True)
    self.run_test(NarrowModel(), sample)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_index_fill(self):
    """Pass ``Tensor.index_fill`` along dim 2 with a constant index to ``self.run_test``."""
    class IndexFillModel(torch.nn.Module):
        def forward(self, input):
            # Positions 2 and 0 along the last dimension are overwritten with -1.
            return input.index_fill(2, torch.tensor([2, 0]), -1)

    sample = torch.randn(3, 4, 5, requires_grad=True)
    self.run_test(IndexFillModel(), sample)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_index_copy(self):
    """Pass ``Tensor.index_copy`` along dim 1 with a constant index to ``self.run_test``."""
    class IndexCopyModel(torch.nn.Module):
        def forward(self, input):
            positions = torch.tensor([2, 0])
            replacement = torch.ones(3, 2, 5)
            return input.index_copy(1, positions, replacement)

    sample = torch.randn(3, 4, 5, requires_grad=True)
    self.run_test(IndexCopyModel(), sample)
|
|
|
|
def test_select(self):
    """Pass column selection ``x[:, 1]`` to ``self.run_test``."""
    class Select(torch.nn.Module):
        def forward(self, x):
            return x[:, 1]

    sample = torch.randn(3, 4)
    self.run_test(Select(), sample)
|
|
|
|
def test_select_negative_index(self):
    """Pass column selection with a negative index, ``x[:, -1]``, to ``self.run_test``."""
    class Select(torch.nn.Module):
        def forward(self, x):
            return x[:, -1]

    sample = torch.randn(3, 4)
    self.run_test(Select(), sample)
|
|
|
|
def test_index_select_constant_scaler_index(self):
    """Pass ``torch.index_select`` with a constant 0-dim index tensor to ``self.run_test``."""
    class IndexSelectScalerIndexModel(torch.nn.Module):
        def forward(self, x):
            # The index is a Python int wrapped into a 0-dim tensor.
            return torch.index_select(x, 1, torch.tensor(2))

    sample = torch.randn(3, 4)
    self.run_test(IndexSelectScalerIndexModel(), sample)
|
|
|
|
def test_index_select_scaler_index(self):
    """Pass ``torch.index_select`` with a computed 0-dim index to ``self.run_test``.

    The index is the sum of a tensor stored on the module (base 1) and a
    0-dim tensor supplied as a second model input (offset 2).
    """
    class IndexSelectScalerIndexModel(torch.nn.Module):
        def __init__(self, index_base):
            super().__init__()
            self.index_base = torch.tensor(index_base)

        def forward(self, x, index_offset):
            return torch.index_select(x, 1, self.index_base + index_offset)

    x = torch.randn(3, 4)
    index_offset = torch.tensor(2)
    self.run_test(IndexSelectScalerIndexModel(1), (x, index_offset))
|
|
|
|
def test_take(self):
    """Pass ``torch.take`` on a 4-D tensor with flat indices to ``self.run_test``."""
    class TakeModel(torch.nn.Module):
        def forward(self, x, y):
            return torch.take(x, y)

    data = torch.randn(6, 4, 3, 3)
    flat_indices = torch.tensor([4, 1, 7, 15, 63])
    self.run_test(TakeModel(), (data, flat_indices))
|
|
|
|
def test_topk(self):
    """Pass ``torch.topk`` with a constant ``k=3`` to ``self.run_test``."""
    class MyModule(torch.nn.Module):
        def forward(self, x):
            return torch.topk(x, 3)

    values = torch.arange(1., 6., requires_grad=True)
    self.run_test(MyModule(), values)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(10)
def test_topk_int32_k(self):
    """Pass ``torch.topk`` with ``k`` supplied as an int32 tensor input to ``self.run_test``."""
    class Model(torch.nn.Module):
        def forward(self, x, k):
            return torch.topk(x, k)

    values = torch.arange(1., 6.)
    count = torch.tensor(3, dtype=torch.int32)
    self.run_test(Model(), (values, count))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(11)
def test_topk_smallest_unsorted(self):
    """Pass smallest-k ``torch.topk`` with ``sorted=False`` and ``sorted=True`` to ``self.run_test``."""
    class MyModule(torch.nn.Module):
        def forward(self, x, k):
            # When sorted=False, the order of elements in the output tensors
            # is not expected to match between PyTorch and ORT, so the
            # unsorted values are sorted before being returned.
            topk_unsorted = torch.topk(x, k, largest=False, sorted=False)
            topk_sorted = torch.topk(x, k, largest=False, sorted=True)
            return topk_sorted, torch.sort(topk_unsorted.values).values

    values = torch.arange(1., 6., requires_grad=True)
    count = torch.tensor(3)
    self.run_test(MyModule(), (values, count))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(10)
def test_topk_script(self):
    """Pass ``torch.topk`` inside a scripted forward, with ``k`` as an input, to ``self.run_test``."""
    class MyModuleDynamic(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, x, k):
            return torch.topk(x, k)

    x = torch.arange(1., 6., requires_grad=True)
    k = torch.tensor(3)
    # Inputs are passed as a list here rather than a tuple.
    self.run_test(MyModuleDynamic(), [x, k])
|
|
|
|
@disableScriptTest()  # Python builtin apply of FunctionMeta object is currently not supported in Torchscript.
@skipIfUnsupportedMinOpsetVersion(11)  # Clip op min is an input since opset 11.
def test_auto_grad(self):
    """Check custom ``torch.autograd.Function`` ops routed through a custom symbolic.

    A symbolic function is registered for "prim::PythonOp" that maps the
    functions named "MyClip" and "MyRelu" to the "Clip" and "Relu" ops. Modules
    that call each function via ``.apply`` are then passed to ``self.run_test``.
    """
    # Clamps the input from below by a caller-supplied tensor.
    class MyClip(torch.autograd.Function):
        @staticmethod
        def forward(ctx, input, scalar):
            ctx.save_for_backward(input)
            return input.clamp(min=scalar)

    # Clamps the input from below by 0.
    class MyRelu(torch.autograd.Function):
        @staticmethod
        def forward(ctx, input):
            ctx.save_for_backward(input)
            return input.clamp(min=0)

    def symbolic_python_op(ctx: torch.onnx.SymbolicContext, g: torch._C.Graph, *args, **kwargs):
        # Dispatch on the name passed in kwargs["name"]; the number of outputs
        # is taken from the node currently being converted.
        n = ctx.cur_node
        name = kwargs["name"]
        if name == "MyClip":
            return g.op("Clip", args[0], args[1], outputs=n.outputsSize())
        elif name == "MyRelu":
            return g.op("Relu", args[0], outputs=n.outputsSize())
        else:
            return _unimplemented("prim::PythonOp", "unknown node kind: " + name)

    # Register the symbolic and make sure it is unregistered when the test ends.
    register_custom_op_symbolic("prim::PythonOp", symbolic_python_op, 1)
    self.addCleanup(unregister_custom_op_symbolic, "prim::PythonOp", 1)

    class MyClipModule(torch.nn.Module):
        def forward(self, x, min):
            return MyClip.apply(x, min)

    x = torch.randn(3, 3)
    # NOTE(review): this local shadows the builtin ``min`` for the rest of the method.
    min = torch.tensor([0.0])
    self.run_test(MyClipModule(), (x, min))

    class MyReluModule(torch.nn.Module):
        def forward(self, x):
            return MyRelu.apply(x)

    x = torch.randn(3, 3)
    self.run_test(MyReluModule(), x)
|
|
|
|
def test_clip_int(self):
    """Pass ``torch.clamp(x, 0, 1)`` on an int64 tensor to ``self.run_test``."""
    class MyClipInt(torch.nn.Module):
        def forward(self, x):
            return torch.clamp(x, 0, 1)

    # Random floats cast to int64.
    int_input = torch.randn(3, 3).to(torch.int64)
    self.run_test(MyClipInt(), int_input)
|
|
|
|
def test_relu_int(self):
    """Pass ``torch.nn.ReLU`` on an int32 tensor to ``self.run_test``."""
    relu = torch.nn.ReLU()
    # Random floats cast to int32.
    int_input = torch.randn(3, 3).to(torch.int32)
    self.run_test(relu, int_input)
|
|
|
|
def test_pad_int(self):
    """Pass ``torch.nn.functional.pad`` with padding ``(1, 1)`` on an int32 tensor to ``self.run_test``."""
    class MyPadInt(torch.nn.Module):
        def forward(self, x):
            return torch.nn.functional.pad(x, (1, 1))

    # Random floats cast to int32.
    int_input = torch.randn(3, 3).to(torch.int32)
    self.run_test(MyPadInt(), int_input)
|
|
|
|
def test_min_int(self):
    """Pass element-wise ``torch.min(x, x + 1)`` on an int32 tensor to ``self.run_test``."""
    class MyMinInt(torch.nn.Module):
        def forward(self, x):
            return torch.min(x, x + 1)

    # Random floats cast to int32.
    int_input = torch.randn(3, 3).to(torch.int32)
    self.run_test(MyMinInt(), int_input)
|
|
|
|
def test_max_int(self):
    """Pass element-wise ``torch.max(x, x + 1)`` on an int32 tensor to ``self.run_test``."""
    class MyMaxInt(torch.nn.Module):
        def forward(self, x):
            return torch.max(x, x + 1)

    # Random floats cast to int32.
    int_input = torch.randn(3, 3).to(torch.int32)
    self.run_test(MyMaxInt(), int_input)
|
|
|
|
@skipIfUnsupportedOpsetVersion([7])
def test_normalize(self):
    """Pass ``torch.nn.functional.normalize`` with default arguments to ``self.run_test``."""
    class Model(torch.nn.Module):
        def forward(self, x):
            return torch.nn.functional.normalize(x)

    sample = torch.randn(3, 3)
    self.run_test(Model(), sample)
|
|
|
|
def test_layer_norm(self):
    """Pass ``LayerNorm([10, 10])`` with a ``(20, 5, 10, 10)`` input to ``self.run_test``."""
    layer_norm = torch.nn.LayerNorm([10, 10])
    sample = torch.randn(20, 5, 10, 10)
    self.run_test(layer_norm, sample)
|
|
|
|
def test_batchnorm1d(self):
    """Pass one affine ``BatchNorm1d(10)`` instance to ``self.run_test`` with 2-D and 3-D inputs."""
    model = torch.nn.BatchNorm1d(10, affine=True)
    # The same module instance is reused for both input shapes.
    for shape in ((10, 10), (10, 10, 128)):
        self.run_test(model, torch.randn(*shape))
|
|
|
|
def test_batchnorm1d_noaffine(self):
    """Pass one non-affine ``BatchNorm1d(10)`` instance to ``self.run_test`` with 2-D and 3-D inputs."""
    model = torch.nn.BatchNorm1d(10, affine=False)
    # The same module instance is reused for both input shapes.
    for shape in ((10, 10), (10, 10, 128)):
        self.run_test(model, torch.randn(*shape))
|
|
|
|
def test_batchnorm1d_norunningstats(self):
    """Pass ``BatchNorm1d(10, track_running_stats=False)`` to ``self.run_test`` with 2-D and 3-D inputs."""
    model = torch.nn.BatchNorm1d(10, track_running_stats=False)
    # The same module instance is reused for both input shapes.
    for shape in ((10, 10), (10, 10, 128)):
        self.run_test(model, torch.randn(*shape))
|
|
|
|
def test_batchnorm2d(self):
    """Pass an affine ``BatchNorm2d(3)`` with a ``(10, 3, 128, 128)`` input to ``self.run_test``."""
    batch = torch.randn(10, 3, 128, 128)
    self.run_test(torch.nn.BatchNorm2d(3, affine=True), batch)
|
|
|
|
def test_batchnorm2d_noaffine(self):
    """Pass a non-affine ``BatchNorm2d(3)`` with a ``(10, 3, 128, 128)`` input to ``self.run_test``."""
    batch = torch.randn(10, 3, 128, 128)
    self.run_test(torch.nn.BatchNorm2d(3, affine=False), batch)
|
|
|
|
def test_batchnorm2d_norunningstats(self):
    """Pass ``BatchNorm2d(3, track_running_stats=False)`` with a ``(10, 3, 128, 128)`` input to ``self.run_test``."""
    batch = torch.randn(10, 3, 128, 128)
    self.run_test(torch.nn.BatchNorm2d(3, track_running_stats=False), batch)
|
|
|
|
def test_batchnorm3d(self):
    """Pass an affine ``BatchNorm3d(3)`` with a ``(10, 3, 128, 128, 128)`` input to ``self.run_test``."""
    batch = torch.randn(10, 3, 128, 128, 128)
    self.run_test(torch.nn.BatchNorm3d(3, affine=True), batch)
|
|
|
|
def test_batchnorm3d_noaffine(self):
    """Pass a non-affine ``BatchNorm3d(3)`` with a ``(10, 3, 128, 128, 128)`` input to ``self.run_test``."""
    batch = torch.randn(10, 3, 128, 128, 128)
    self.run_test(torch.nn.BatchNorm3d(3, affine=False), batch)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)  # Because ConstantOfShape op is not supported for opset < 9
def test_instancenorm1d_runningstats(self):
    """Pass ``InstanceNorm1d(5)`` with running stats, affine on then off, to ``self.run_test``."""
    x = torch.randn(10, 5, 128)
    # One shared input; a new module is built for each affine setting.
    for affine in (True, False):
        model = torch.nn.InstanceNorm1d(5, affine=affine, track_running_stats=True)
        self.run_test(model, x)
|
|
|
|
def test_instancenorm1d_norunningstats(self):
    """Pass ``InstanceNorm1d(5)`` without running stats, affine on then off, to ``self.run_test``."""
    x = torch.randn(10, 5, 128)
    # One shared input; a new module is built for each affine setting.
    for affine in (True, False):
        model = torch.nn.InstanceNorm1d(5, affine=affine, track_running_stats=False)
        self.run_test(model, x)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)  # Because ConstantOfShape op is not supported for opset < 9
def test_instancenorm2d_runningstats(self):
    """Pass ``InstanceNorm2d(3)`` with running stats, affine on then off, to ``self.run_test``."""
    x = torch.randn(10, 3, 128, 128)
    # One shared input; a new module is built for each affine setting.
    for affine in (True, False):
        model = torch.nn.InstanceNorm2d(3, affine=affine, track_running_stats=True)
        self.run_test(model, x)
|
|
|
|
def test_instancenorm2d_norunningstats(self):
    """Pass ``InstanceNorm2d(3)`` without running stats, affine on then off, to ``self.run_test``."""
    x = torch.randn(10, 3, 128, 128)
    # One shared input; a new module is built for each affine setting.
    for affine in (True, False):
        model = torch.nn.InstanceNorm2d(3, affine=affine, track_running_stats=False)
        self.run_test(model, x)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)  # Because ConstantOfShape op is not supported for opset < 9
def test_instancenorm3d_runningstats(self):
    """Pass ``InstanceNorm3d(3)`` with running stats, affine on then off, to ``self.run_test``."""
    x = torch.randn(10, 3, 128, 128, 128)
    # One shared input; a new module is built for each affine setting.
    for affine in (True, False):
        model = torch.nn.InstanceNorm3d(3, affine=affine, track_running_stats=True)
        self.run_test(model, x)
|
|
|
|
def test_instancenorm3d_norunningstats(self):
    """Pass ``InstanceNorm3d(3)`` without running stats, affine on then off, to ``self.run_test``."""
    x = torch.randn(10, 3, 128, 128, 128)
    # One shared input; a new module is built for each affine setting.
    for affine in (True, False):
        model = torch.nn.InstanceNorm3d(3, affine=affine, track_running_stats=False)
        self.run_test(model, x)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_scatter_with_scalar(self):
    """Pass ``Tensor.scatter`` with the Python float ``1.0`` as source to ``self.run_test``.

    The target is a 3x3 float64 tensor of zeros.
    """
    class ScatterModel(torch.nn.Module):
        def forward(self, input, indices):
            return input.scatter(1, indices, 1.0)

    target = torch.zeros(3, 3, dtype=torch.float64)
    positions = torch.tensor([[1, 0], [0, 1], [0, 1]], dtype=torch.int64)
    self.run_test(ScatterModel(), input=(target, positions))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_scatter_with_scalar_different_types(self):
    """Pass ``Tensor.scatter`` with a scalar source onto a float32 target to ``self.run_test``.

    Covers the case where the scalar src (update value) type differs from the
    type of self. This only happens with a scalar src; PyTorch does not allow
    it when src is a tensor.
    """
    class ScatterModel(torch.nn.Module):
        def forward(self, input, indices):
            return input.scatter(1, indices, 1.0)

    target = torch.zeros(3, 3, dtype=torch.float32)
    positions = torch.tensor([[1, 0], [0, 1], [0, 1]], dtype=torch.int64)
    self.run_test(ScatterModel(), input=(target, positions))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_scatter(self):
    """Pass ``Tensor.scatter`` along dim 1 with tensor sources to ``self.run_test``.

    Four input sets are used: two 2-D cases, a 4-D case with expanded
    indices, and a 3-D case.
    """
    class ScatterModel(torch.nn.Module):
        def forward(self, input, indices, values):
            return input.scatter(1, indices, values)

    updates_2d = [[1.0, 1.1], [2.0, 2.1], [3.0, 3.1]]

    # 2-D target, indices restricted to columns 0 and 1.
    target = torch.zeros(3, 3)
    positions = torch.tensor([[1, 0], [0, 1], [0, 1]], dtype=torch.int64)
    updates = torch.tensor(updates_2d)
    self.run_test(ScatterModel(), input=(target, positions, updates))

    # 2-D target, indices reaching column 2.
    target = torch.zeros(3, 3)
    positions = torch.tensor([[1, 0], [0, 2], [0, 1]], dtype=torch.int64)
    updates = torch.tensor(updates_2d)
    self.run_test(ScatterModel(), (target, positions, updates))

    # 4-D target; the 2-D index pattern is expanded over the trailing dims.
    target = torch.zeros(3, 4, 5, 6)
    positions = torch.tensor([[1, 0], [0, 2], [0, 1]], dtype=torch.int64)
    positions = positions.view(3, 2, 1, 1).expand(3, 2, 5, 6)
    updates = torch.arange(3 * 2 * 5 * 6, dtype=torch.float32).view(3, 2, 5, 6)
    self.run_test(ScatterModel(), (target, positions, updates))

    # 3-D target with a 3-D index tensor.
    target = torch.zeros(3, 4, 2)
    positions = torch.tensor([[[1, 0], [0, 2]], [[1, 1], [0, 1]], [[2, 1], [2, 2]]])
    updates = torch.arange(3 * 2 * 2, dtype=torch.float32).view(3, 2, 2)
    self.run_test(ScatterModel(), (target, positions, updates))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_scatter_add(self):
    """Pass ``scatter_add`` to ``self.run_test`` in out-of-place and in-place forms.

    The first model calls ``Tensor.scatter_add`` directly; the second calls a
    scripted helper that applies ``scatter_add_`` to a fresh zeros tensor.
    """
    class ScatterModel(torch.nn.Module):
        def forward(self, input, indices, values):
            return input.scatter_add(1, indices, values)

    input = torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
    indices = torch.tensor([[1, 0], [0, 1], [0, 1]], dtype=torch.int64)
    values = torch.tensor([[1.0, 1.1], [2.0, 2.1], [3.0, 3.1]])
    self.run_test(ScatterModel(), input=(input, indices, values))

    # Scripted helper: accumulate src into zeros of the same size and dtype.
    @torch.jit.script
    def scatter_sum(src: torch.Tensor, index: torch.Tensor):
        size = src.size()
        out = torch.zeros(size, dtype=src.dtype)
        return out.scatter_add_(1, index, src)

    # Redefines ScatterModel; from here on the name refers to this wrapper.
    class ScatterModel(torch.nn.Module):
        def forward(self, src, index):
            return scatter_sum(src, index)

    src = torch.rand(3, 2)
    index = torch.tensor([[0, 1], [0, 1], [0, 1]], dtype=torch.int64)
    self.run_test(ScatterModel(), (src, index))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_bucketize(self):
    """Pass ``torch.bucketize`` (default and ``right=True``) to ``self.run_test``."""
    class BucketModel(torch.nn.Module):
        def forward(self, input, boundaries):
            return (
                torch.bucketize(input, boundaries),
                torch.bucketize(input, boundaries, right=True),
            )

    values = torch.tensor([[2, 5, 10], [6, 8, 3]])
    edges = torch.tensor([1, 5, 7, 8, 10])
    self.run_test(BucketModel(), (values, edges))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_one_hot(self):
    # Case 1: the class count is a plain Python attribute fixed at construction.
    class OneHot(torch.nn.Module):
        def __init__(self, num_classes):
            super().__init__()
            self.num_classes = num_classes

        def forward(self, x):
            return torch.nn.functional.one_hot(x, self.num_classes)

    labels = torch.arange(10)
    self.run_test(OneHot(15), labels)

    # Case 2: the class count arrives as a tensor input and is cast to int32 first.
    class OneHot(torch.nn.Module):
        def forward(self, x, num_classes):
            num_classes = num_classes.to(torch.int32)
            return torch.nn.functional.one_hot(x, num_classes[0])

    labels = torch.arange(10)
    class_count = 15 * torch.ones(1)
    self.run_test(OneHot(), (labels, class_count))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_gather(self):
    # Gathers along dim 1 of a 3x3 float tensor using a 3x2 int64 index tensor.
    class GatherModel(torch.nn.Module):
        def forward(self, input, indices):
            return input.gather(1, indices)

    table = torch.tensor([[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]])
    picks = torch.tensor([[1, 0], [0, 1], [0, 1]], dtype=torch.int64)
    self.run_test(GatherModel(), input=(table, picks))
|
|
|
|
@disableScriptTest()  # Scripting error: Cannot instantiate nn module
def test_gather_constant_fold(self):
    # Variant 1: a value derived from a buffer's shape feeds clamp, next to an Embedding lookup.
    class GatherModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.register_buffer("weight", torch.ones(5))
            # torch.nn.Embedding is converted to ONNX::Gather.
            # Constant folding will be triggered for constant inputs.
            # This pattern is common for constant mask inputs in transformer models.
            self.embed = torch.nn.Embedding(8, 3)

        def forward(self, x):
            # shape is of rank 0
            shape = self.weight.shape[0]
            lower_bound = 5 - shape
            token_ids = torch.ones(1, 4, dtype=torch.long)
            return x.clamp(min=lower_bound), self.embed(token_ids)

    sample = torch.randn(1)
    self.run_test(GatherModule(), (sample,))

    # Variant 2: the buffer's length is reused as three of the four ZeroPad2d padding amounts.
    class GatherModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.register_buffer("weight", torch.ones(2))

        def forward(self, x):
            # shape is of rank 0
            shape = self.weight.shape[0]
            zero_pad = torch.nn.ZeroPad2d([1, shape, shape, shape])
            return zero_pad(x)

    sample = torch.randn(1, 3, 2)
    self.run_test(GatherModule(), (sample,))

    # Variant 3: in-place add of an indexed buffer, run with named dynamic axes.
    class GatherModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.register_buffer("rb", torch.randn(1, 1, 3, 1, 1))

        def forward(self, x):
            x += self.rb[0]
            return x

    sample = torch.randn(1, 3, 224, 224)
    axes = {
        "input": {0: "batch", 2: "height", 3: "width"},
        "output": {0: "batch", 1: "class", 2: "height", 3: "width"},
    }
    self.run_test(GatherModule(), (sample,),
                  dynamic_axes=axes,
                  input_names=['input'], output_names=['output'])
|
|
|
|
@skipIfUnsupportedOpsetVersion([13])
@skipIfUnsupportedMinOpsetVersion(9)
def test_expand(self):
    # Case 1: literal target sizes, with -1 passed for the last dimension.
    class ExpandModel(torch.nn.Module):
        def forward(self, input):
            return input.expand(2, 3, -1)

    data = torch.randn(2, 1, 4)
    self.run_test(ExpandModel(), input=data)

    # Case 2: one target size is read from the input's own first dimension.
    class ExpandInferDimModel(torch.nn.Module):
        def forward(self, input):
            return input.expand(-1, input.size(0))

    data = torch.randn(3, 1)
    self.run_test(ExpandInferDimModel(), input=data)

    # Case 3: the target size is supplied as a tensor argument.
    class ExpandTensorSizeModel(torch.nn.Module):
        def forward(self, input, size):
            return input.expand(size)

    data = torch.randn(3,)
    target_size = torch.tensor(-1)
    self.run_test(ExpandTensorSizeModel(), input=(data, target_size))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(11)  # index_put is supported in opsets >= 11
def test_dynamic_expand_as(self):
    # Case 1: slice assignment of a scalar; the slice start depends on x.size(0).
    class Model(torch.nn.Module):
        def forward(self, x):
            x[:, x.size(0):] = 0
            return x

    first_input = torch.ones(2, 5)
    second_input = torch.randn(3, 4)
    self.run_test(Model(), (first_input, ),
                  input_names=["x"],
                  dynamic_axes={"x": [0, 1]},
                  test_with_inputs=[second_input])

    # Case 2: same slice, but the assigned value is a 3-element tensor.
    class Model(torch.nn.Module):
        def forward(self, x):
            x[:, x.size(0):] = torch.tensor([1, 2, 3])
            return x

    first_input = torch.ones(2, 5, 3)
    second_input = torch.randn(3, 4, 3)
    self.run_test(Model(), (first_input, ),
                  input_names=["x"],
                  dynamic_axes={"x": [0, 1, 2]},
                  test_with_inputs=[second_input])
|
|
|
|
def test_multinomial(self):
    # Each weight row has a single non-zero entry, so only one index can be drawn per row.
    class Multinomial(torch.nn.Module):
        def forward(self, weight):
            return torch.multinomial(weight, 3, replacement=True)

    class MultinomialNoReplacement(torch.nn.Module):
        def forward(self, weight):
            return torch.multinomial(weight, 1)

    probs = torch.tensor([[0, 10, 0, 0], [0, 0, 100, 0]], dtype=torch.float)
    for model_cls in (Multinomial, MultinomialNoReplacement):
        self.run_test(model_cls(), (probs,))
|
|
|
|
def _test_reduced_ops(self, op):
    # Shared driver: applies `op` over the last dimension for a range of input dtypes.
    class ReducedOpModule(torch.nn.Module):
        def forward(self, input):
            return op(input, dim=-1)

    is_mean = op == torch.mean
    is_prod = op == torch.prod

    if not is_mean:  # torch.mean only supports float types
        integer_dtypes = (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64)
        for int_dtype in integer_dtypes:
            sample = torch.randint(10, (4, 4), dtype=int_dtype)
            self.run_test(ReducedOpModule(), sample)

    # The double-precision case is skipped for torch.mean and torch.prod
    # (original note: ORT does not support ReduceProd for double).
    if not (is_prod or is_mean):
        sample = torch.randn(4, 5, dtype=torch.double)
        self.run_test(ReducedOpModule(), sample)

    if not is_prod:  # torch.prod not implemented for Half
        sample = torch.randn(4, 4, dtype=torch.half)
        self.run_test(ReducedOpModule(), sample)

    sample = torch.randn(4, 5, dtype=torch.float)
    self.run_test(ReducedOpModule(), sample)
|
|
|
|
def test_reduced_sum(self):
    # Delegates to the shared reduction driver with torch.sum.
    reduction = torch.sum
    return self._test_reduced_ops(op=reduction)
|
|
|
|
def test_reduced_mean(self):
    # Delegates to the shared reduction driver with torch.mean.
    reduction = torch.mean
    return self._test_reduced_ops(op=reduction)
|
|
|
|
def test_reduced_prod(self):
    # Delegates to the shared reduction driver with torch.prod.
    reduction = torch.prod
    return self._test_reduced_ops(op=reduction)
|
|
|
|
def test_reduced_sum_dtypes(self):
    # Sums a half-precision input while requesting a float result, with and without a dim.
    class NoDimModel(torch.nn.Module):
        def forward(self, input):
            return input.sum(dtype=torch.float)

    class DimModel(torch.nn.Module):
        def forward(self, input):
            return input.sum(dim=-1, dtype=torch.float)

    half_input = torch.randn((4, 4), dtype=torch.half)
    for model_cls in (NoDimModel, DimModel):
        self.run_test(model_cls(), half_input)
|
|
|
|
def test_reduced_min_max(self):
    # Returns the values (element [0]) of min over the last dim and max over dim 0.
    class ReducedMinMaxModule(torch.nn.Module):
        def forward(self, input):
            min_values = torch.min(input, dim=-1)[0]
            max_values = torch.max(input, dim=0)[0]
            return min_values, max_values

    for int_dtype in (torch.int32, torch.int64):
        sample = torch.randint(10, (4, 4), dtype=int_dtype)
        self.run_test(ReducedMinMaxModule(), sample)

    sample = torch.randn(4, 5, dtype=torch.float)
    self.run_test(ReducedMinMaxModule(), sample)
|
|
|
|
def test_reduce_log_sum_exp(self):
    # Adds a single-dim logsumexp to a logsumexp taken over both dims.
    class ReduceLogSumExpModel(torch.nn.Module):
        def forward(self, input):
            over_dim0 = torch.logsumexp(input, dim=0)
            over_both = torch.logsumexp(input, dim=(0, 1))
            return over_dim0 + over_both

    sample = torch.randn(4, 4, requires_grad=True)
    self.run_test(ReduceLogSumExpModel(), sample)
|
|
|
|
def test_softmax(self):
    # For each dim in [-4, 2]: run the plain module, then a scripted wrapper
    # that reshapes its input before applying Softmax.
    for dim in range(-4, 3):
        plain = torch.nn.Softmax(dim=dim)
        sample = torch.randn(3, 4, 5, 6)
        self.run_test(plain, sample)

        class SoftmaxUnknownRank(torch.nn.Module):
            def __init__(self, i):
                super().__init__()
                self.softmax = torch.nn.Softmax(dim=i)

            def forward(self, x):
                return self.softmax(x.reshape(3, 4, 5, 6))

        scripted = torch.jit.script(SoftmaxUnknownRank(dim))
        self.run_test(scripted, sample)
|
|
|
|
def test_softmax_large_values(self):
    # Fixed 3x3 input mixing +/-1e12 magnitudes with ordinary values.
    extreme = torch.tensor([[-1e12, -1e12, -1e12], [1e12, 0.0, -5.0], [3.0, 4.0, 5.0]])
    for dim in range(-2, 1):
        plain = torch.nn.Softmax(dim=dim)
        self.run_test(plain, extreme)

        class SoftmaxUnknownRank(torch.nn.Module):
            def __init__(self, i):
                super().__init__()
                self.softmax = torch.nn.Softmax(dim=i)

            def forward(self, x):
                return self.softmax(x.reshape(3, 3))

        scripted = torch.jit.script(SoftmaxUnknownRank(dim))
        self.run_test(scripted, extreme)
|
|
|
|
def test_logsoftmax(self):
    # Inputs of rank 2 through 6; LogSoftmax is always applied over the last dimension.
    for rank in range(2, 7):
        model = torch.nn.LogSoftmax(dim=rank - 1)
        shape = [2] * (rank - 2) + [3, 4]
        sample = torch.ones(*shape, requires_grad=True)
        self.run_test(model, sample)
|
|
|
|
def test_logsoftmax_dim(self):
    # Sweeps dim from -4 to 2 on a fresh 4-D random input each time.
    for dim in range(-4, 3):
        model = torch.nn.LogSoftmax(dim=dim)
        sample = torch.randn(3, 4, 5, 6)
        self.run_test(model, sample)
|
|
|
|
def test_logsoftmax_dtype(self):
    # log_softmax over dim 1 with an explicit float64 output dtype.
    class Model(torch.nn.Module):
        def forward(self, x):
            return torch.nn.functional.log_softmax(x, dim=1, dtype=torch.float64)

    sample = torch.randn(3, 4, 5, requires_grad=True)
    self.run_test(Model(), sample)
|
|
|
|
def test_softplus(self):
    # Three softplus variants: default beta, integer beta=2, float beta=1.7.
    class BetaOneModel(torch.nn.Module):
        def forward(self, x):
            return torch.nn.functional.softplus(x)

    sample = torch.randn(3, 4, 5, requires_grad=True)
    self.run_test(BetaOneModel(), sample)

    class BetaModel(torch.nn.Module):
        def forward(self, x):
            return torch.nn.functional.softplus(x, beta=2)

    sample = torch.randn(3, 4, 5, requires_grad=True)
    self.run_test(BetaModel(), sample)

    class BetaFloatModel(torch.nn.Module):
        def forward(self, x):
            return torch.nn.functional.softplus(x, beta=1.7)

    sample = torch.randn(3, 4, 5, requires_grad=True)
    self.run_test(BetaFloatModel(), sample)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_lstm_no_hidden(self):
    # LSTM called with the sequence only; no initial (h0, c0) is passed.
    class LSTMModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.rnn = torch.nn.LSTM(input_size=16, hidden_size=16)

        def forward(self, x):
            return self.rnn(x)

    sequence = torch.randn((10, 16, 16))
    self.run_test(LSTMModel(), (sequence,))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_lstm_proj_no_hidden(self):
    # An LSTM with proj_size set is expected to make run_test raise RuntimeError.
    class LSTMModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.rnn = torch.nn.LSTM(input_size=16, hidden_size=16, proj_size=8)

        def forward(self, x):
            return self.rnn(x)

    sequence = torch.randn((10, 16, 16))
    with self.assertRaises(RuntimeError):
        self.run_test(LSTMModel(), (sequence,))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_lstm(self):
    # Single-layer, unidirectional LSTM with explicit initial hidden and cell states.
    class LSTMModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.rnn = torch.nn.LSTM(RNN_INPUT_SIZE, RNN_HIDDEN_SIZE, 1, bidirectional=False)

        def forward(self, x, h0, c0):
            return self.rnn(x, (h0, c0))

    sequence = torch.randn(RNN_SEQUENCE_LENGTH, BATCH_SIZE, RNN_INPUT_SIZE)
    hidden = torch.randn(1, BATCH_SIZE, RNN_HIDDEN_SIZE)
    cell = torch.randn(1, BATCH_SIZE, RNN_HIDDEN_SIZE)
    self.run_test(LSTMModel(), (sequence, hidden, cell))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_lstm_cell(self):
    # LSTMCell with and without bias, fed the same input and initial states.
    class LSTMCellModel(torch.nn.Module):
        def __init__(self, bias):
            super().__init__()
            self.lstm_cell = torch.nn.LSTMCell(RNN_INPUT_SIZE, RNN_HIDDEN_SIZE, bias=bias)

        def forward(self, x, h0, c0):
            return self.lstm_cell(x, (h0, c0))

    step_input = torch.randn(BATCH_SIZE, RNN_INPUT_SIZE)
    hidden = torch.randn(BATCH_SIZE, RNN_HIDDEN_SIZE)
    cell = torch.randn(BATCH_SIZE, RNN_HIDDEN_SIZE)
    for use_bias in (True, False):
        self.run_test(LSTMCellModel(use_bias), (step_input, hidden, cell))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_lstm_default_init_state(self):
    # Same LSTM configuration as test_lstm, but no initial state is supplied.
    class LSTMModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.rnn = torch.nn.LSTM(RNN_INPUT_SIZE, RNN_HIDDEN_SIZE, 1, bidirectional=False)

        def forward(self, x):
            return self.rnn(x)

    sequence = torch.randn(RNN_SEQUENCE_LENGTH, BATCH_SIZE, RNN_INPUT_SIZE)
    self.run_test(LSTMModel(), sequence)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_lstm_fixed_batch_size(self):
    # Initial states are built inside forward from the input's batch dimension.
    class LSTMModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.lstm = torch.nn.LSTM(RNN_INPUT_SIZE, RNN_HIDDEN_SIZE, 1, bidirectional=False)
            self.RNN_HIDDEN_SIZE = RNN_HIDDEN_SIZE

        def forward(self, input):
            batch_size = input.size()[1]
            h0 = torch.ones([1, batch_size, self.RNN_HIDDEN_SIZE])
            c0 = torch.ones([1, batch_size, self.RNN_HIDDEN_SIZE])
            return self.lstm(input, (h0, c0))

    first_batch = torch.randn(RNN_SEQUENCE_LENGTH, BATCH_SIZE, RNN_INPUT_SIZE)
    # verify with different input of same batch size
    second_batch = torch.randn(RNN_SEQUENCE_LENGTH, BATCH_SIZE, RNN_INPUT_SIZE)
    self.run_test(LSTMModel(), first_batch, fixed_batch_size=True, test_with_inputs=[second_batch])
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_lstm_post_fix_init_state(self):
    # First input has batch size 1; a second input with BATCH_SIZE is run through dynamic axes.
    class LSTMModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.lstm = torch.nn.LSTM(RNN_INPUT_SIZE, RNN_HIDDEN_SIZE, 1, bidirectional=False)
            self.RNN_HIDDEN_SIZE = RNN_HIDDEN_SIZE

        def forward(self, input):
            batch_size = input.size()[1]
            h0 = torch.ones([1, batch_size, self.RNN_HIDDEN_SIZE])
            c0 = torch.ones([1, batch_size, self.RNN_HIDDEN_SIZE])
            return self.lstm(input, (h0, c0))

    model = LSTMModel()
    single_batch = torch.randn(RNN_SEQUENCE_LENGTH, 1, RNN_INPUT_SIZE)
    # verify with different input of different batch size
    full_batch = torch.randn(RNN_SEQUENCE_LENGTH, BATCH_SIZE, RNN_INPUT_SIZE)
    axes = {"input.1": {0: "seq", 1: "batch"}}
    self.run_test(model, single_batch, input_names=["input.1"], dynamic_axes=axes,
                  test_with_inputs=[full_batch])
|
|
|
|
def test_lstm_constant_folding(self):
    # Runs two LSTM configurations (one bidirectional, one not) with do_constant_folding=True.
    class LstmNet(torch.nn.Module):
        def __init__(self, input_size, hidden_size, num_layers, bidirectional):
            super().__init__()
            self.lstm = torch.nn.LSTM(input_size, hidden_size, num_layers, bidirectional=bidirectional)

        def forward(self, input, initial_state: Tuple[torch.Tensor, torch.Tensor]):
            return self.lstm(input, initial_state)

    def get_LstmNet_model_and_inputs(input_size, hidden_size, num_layers, batch_size,
                                     seq_len, bidirectional):
        # Builds the model first, then the sequence and both initial states.
        num_directions = 2 if bidirectional else 1
        state_shape = (num_layers * num_directions, batch_size, hidden_size)
        model = LstmNet(input_size, hidden_size, num_layers, bidirectional)
        sequence = torch.randn(seq_len, batch_size, input_size)
        h0 = torch.randn(*state_shape)
        c0 = torch.randn(*state_shape)
        return model, (sequence, (h0, c0))

    batch_size1 = 3
    model1, input1 = get_LstmNet_model_and_inputs(7, 3, 2, batch_size1, 5, True)
    self.run_test(model1, input1, do_constant_folding=True)

    batch_size2 = 4
    model2, input2 = get_LstmNet_model_and_inputs(5, 4, 3, batch_size2, 7, False)
    self.run_test(model2, input2, do_constant_folding=True)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_lstm_no_bias(self):
    # bias=False LSTMs across four (num_layers, bidirectional) combinations.
    class LstmNet(torch.nn.Module):
        def __init__(self, num_layers, bidirectional):
            super().__init__()
            self.lstm = torch.nn.LSTM(RNN_INPUT_SIZE, RNN_HIDDEN_SIZE, num_layers,
                                      bias=False, bidirectional=bidirectional)

        def forward(self, input, initial_state: Tuple[torch.Tensor, torch.Tensor]):
            return self.lstm(input, initial_state)

    def get_LstmNet_model_and_inputs(num_layers, bidirectional):
        # The sequence is drawn before the model is built, then both initial states.
        sequence = torch.randn(RNN_SEQUENCE_LENGTH, BATCH_SIZE, RNN_INPUT_SIZE)
        num_directions = 2 if bidirectional else 1
        model = LstmNet(num_layers, bidirectional)
        state_shape = (num_layers * num_directions, BATCH_SIZE, RNN_HIDDEN_SIZE)
        h0 = torch.randn(*state_shape)
        c0 = torch.randn(*state_shape)
        return model, (sequence, (h0, c0))

    configurations = [(1, True), (1, False), (2, True), (3, False)]
    # All models and inputs are built before any of them is run.
    models_and_inputs = [get_LstmNet_model_and_inputs(n, b) for n, b in configurations]
    for model, model_input in models_and_inputs:
        self.run_test(model, model_input)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_lstm_sequence(self):
    # Two bidirectional, batch_first LSTMs chained through Linear layers.
    class LstmNet(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.rnn1 = torch.nn.LSTM(8, 8, bidirectional=True, batch_first=True)
            self.linear1 = torch.nn.Linear(8 * 2, 8)
            self.rnn2 = torch.nn.LSTM(8, 8, bidirectional=True, batch_first=True)
            self.linear2 = torch.nn.Linear(8 * 2, 8)

        def forward(self, input):
            # Only the output sequence of each LSTM is used; the final states are dropped.
            rnn_output1, _ = self.rnn1(input)
            linear_output1 = self.linear1(rnn_output1)
            rnn_output2, _ = self.rnn2(linear_output1)
            return self.linear2(rnn_output2)

    sequence = torch.zeros((1, 100, 8), dtype=torch.float32)
    axes = {
        'input': {0: 'batch_size', 1: 'w', 2: 'h'},
        'output': {0: 'batch_size', 1: 'w', 2: 'h'},
    }
    self.run_test(LstmNet(), sequence, input_names=['input'], output_names=['output'],
                  dynamic_axes=axes)
|
|
|
|
@disableScriptTest()
def test_rnn_no_bias(self):
    # Runs bias=False torch.nn.RNN models for 1 and 3 layers in three input modes.
    # packed_sequence codes used below:
    #   0 - the RNN is used directly on padded input
    #   1 - wrapped in RnnModelWithPackedSequence with batch_first False
    #   2 - wrapped in RnnModelWithPackedSequence with batch_first True
    def make_model(layers, packed_sequence):
        batch_first = packed_sequence == 2
        model = torch.nn.RNN(RNN_INPUT_SIZE, RNN_HIDDEN_SIZE, layers, bidirectional=False,
                             batch_first=batch_first, bias=False)
        if packed_sequence in (1, 2):
            model = RnnModelWithPackedSequence(model, batch_first)
        return model

    def make_input(batch_size, layers, packed_sequence):
        batch_first = packed_sequence == 2
        seq_lengths = np.random.randint(1, RNN_SEQUENCE_LENGTH + 1, size=batch_size)
        # Longest sequence first.
        seq_lengths = sorted(map(int, seq_lengths), reverse=True)
        sequences = [torch.randn(length, RNN_INPUT_SIZE) for length in seq_lengths]
        inputs = [rnn_utils.pad_sequence(sequences, batch_first=batch_first)]

        inputs.append(torch.randn(layers, batch_size, RNN_HIDDEN_SIZE))
        if packed_sequence != 0:
            inputs.append(torch.IntTensor(seq_lengths))
        # h0 is always appended, so there are always at least two inputs and the
        # result is always a tuple.
        return tuple(inputs)

    layers = [1, 3, 1, 3, 1, 3]
    packed_sequence = [0, 0, 1, 1, 2, 2]
    # All models are built first, then all inputs, then each pair is run.
    models = [make_model(n, p) for n, p in zip(layers, packed_sequence)]
    inputs = [make_input(RNN_BATCH_SIZE, n, p) for n, p in zip(layers, packed_sequence)]

    for model, model_input in zip(models, inputs):
        self.run_test(model, model_input, batch_size=RNN_BATCH_SIZE)
|
|
|
|
def test_gru_no_bias(self):
    # bias=False GRUs in two configurations, run with do_constant_folding=True.
    class GruNet(torch.nn.Module):
        def __init__(self, input_size, hidden_size, num_layers, bidirectional):
            super().__init__()
            self.mygru = torch.nn.GRU(input_size, hidden_size, num_layers,
                                      bidirectional=bidirectional, bias=False)

        def forward(self, input, initial_state):
            return self.mygru(input, initial_state)

    def get_GruNet_model_and_inputs(input_size, hidden_size, num_layers, batch_size,
                                    seq_len, bidirectional):
        # Builds the model first, then the sequence and the initial hidden state.
        num_directions = 2 if bidirectional else 1
        model = GruNet(input_size, hidden_size, num_layers, bidirectional)
        sequence = torch.randn(seq_len, batch_size, input_size)
        h0 = torch.randn(num_layers * num_directions, batch_size, hidden_size)
        return model, (sequence, h0)

    # (input_size, hidden_size, num_layers, batch_size, seq_len, bidirectional)
    configurations = [
        (7, 3, 2, 3, 5, True),
        (5, 4, 3, 4, 7, False),
    ]
    models_and_inputs = [get_GruNet_model_and_inputs(*config) for config in configurations]
    for model, model_input in models_and_inputs:
        self.run_test(model, model_input, do_constant_folding=True)
|
|
|
|
def test_gru_constant_folding(self):
    # Two GRU configurations (default bias), each run with do_constant_folding=True.
    class GruNet(torch.nn.Module):
        def __init__(self, input_size, hidden_size, num_layers, bidirectional):
            super().__init__()
            self.mygru = torch.nn.GRU(input_size, hidden_size, num_layers, bidirectional=bidirectional)

        def forward(self, input, initial_state):
            return self.mygru(input, initial_state)

    def get_GruNet_model_and_inputs(input_size, hidden_size, num_layers, batch_size,
                                    seq_len, bidirectional):
        # Builds the model first, then the sequence and the initial hidden state.
        num_directions = 2 if bidirectional else 1
        model = GruNet(input_size, hidden_size, num_layers, bidirectional)
        sequence = torch.randn(seq_len, batch_size, input_size)
        h0 = torch.randn(num_layers * num_directions, batch_size, hidden_size)
        return model, (sequence, h0)

    batch_size1 = 3
    model1, input1 = get_GruNet_model_and_inputs(7, 3, 2, batch_size1, 5, True)
    self.run_test(model1, input1, do_constant_folding=True)

    batch_size2 = 4
    model2, input2 = get_GruNet_model_and_inputs(5, 4, 3, batch_size2, 7, False)
    self.run_test(model2, input2, do_constant_folding=True)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(8)
def test_max_tensors(self):
    # Two-tensor torch.max with shapes (4, 4) and (4, 1).
    class MaxModel(torch.nn.Module):
        def forward(self, input, other):
            return torch.max(input, other)

    model = MaxModel()
    lhs = torch.randn(4, 4, requires_grad=True)
    rhs = torch.randn(4, 1, requires_grad=True)
    self.run_test(model, (lhs, rhs))
|
|
|
|
def test_amax_amin(self):
    # amax over dim 0 with keepdim=True, and amin over dims [0, 1] with keepdim=False.
    class Model(torch.nn.Module):
        def forward(self, x):
            largest = torch.amax(x, dim=0, keepdim=True)
            smallest = torch.amin(x, dim=[0, 1], keepdim=False)
            return largest, smallest

    model = Model()
    sample = torch.randn(4, 4)
    self.run_test(model, sample)
|
|
|
|
def test_aminmax(self):
    # aminmax once along dim 1 with keepdim=True and once with no dim given.
    class Model(torch.nn.Module):
        def forward(self, x):
            along_dim = torch.aminmax(x, dim=1, keepdim=True)
            no_dim = torch.aminmax(x, keepdim=False)
            return along_dim, no_dim

    model = Model()
    sample = torch.randn(3, 4)
    self.run_test(model, sample)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_arange_end(self):
    # torch.arange(end) with an explicit float dtype, where `end` comes from the
    # input's first dimension. Checked for a ScriptModule and a plain nn.Module.
    class ArangeScript(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, a):
            return torch.arange(a.size(0), dtype=torch.float).view(-1, 1) + a

    x = torch.randn(3, 4, requires_grad=True)
    self.run_test(ArangeScript(), x)

    class ArangeModel(torch.nn.Module):
        def forward(self, a):
            return torch.arange(a.size(0), dtype=torch.float).view(-1, 1) + a

    self.run_test(ArangeModel(), x)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(11)
def test_arange_end_notype(self):
    # torch.arange(end) without a dtype, where `end` comes from the input's first
    # dimension. Each model is run twice: once with dynamic axes on "x", and once
    # with remained_onnx_input_idx=[].
    class ArangeScript(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, a):
            return torch.arange(a.size(0))

    x = torch.randn(3, 4, requires_grad=True)
    self.run_test(ArangeScript(), x, input_names=["x"], dynamic_axes={"x": [0, 1]})
    self.run_test(ArangeScript(), x, remained_onnx_input_idx=[])

    class ArangeModel(torch.nn.Module):
        def forward(self, a):
            return torch.arange(a.size(0))

    self.run_test(ArangeModel(), x, input_names=["x"], dynamic_axes={"x": [0, 1]})
    self.run_test(ArangeModel(), x, remained_onnx_input_idx=[])
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_arange_start_end(self):
    # torch.arange(start, end) with float dtype; `end` depends on the input's first dimension.
    class ArangeScript(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, a):
            return torch.arange(2, a.size(0) + 2, dtype=torch.float).view(-1, 1) + a

    sample = torch.randn(3, 4, requires_grad=True)
    self.run_test(ArangeScript(), sample)

    # Same computation as a plain nn.Module.
    class ArangeModel(torch.nn.Module):
        def forward(self, a):
            return torch.arange(2, a.size(0) + 2, dtype=torch.float).view(-1, 1) + a

    self.run_test(ArangeModel(), sample)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(11)
def test_arange_start_end_notype(self):
    # torch.arange with a float start (2.7) and no explicit dtype.
    class ArangeScript(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, a):
            return torch.arange(2.7, a.size(0) + 2).view(-1, 1) + a

    sample = torch.randn(3, 4, requires_grad=True)
    self.run_test(ArangeScript(), sample)

    # Same computation as a plain nn.Module.
    class ArangeModel(torch.nn.Module):
        def forward(self, a):
            return torch.arange(2.7, a.size(0) + 2).view(-1, 1) + a

    self.run_test(ArangeModel(), sample)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_arange_start_end_step(self):
    # torch.arange(start, end, step) with float dtype; end and step depend on the input's sizes.
    class ArangeScript(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, a):
            return torch.arange(2, a.size(0) * a.size(1) + 2, a.size(1), dtype=torch.float).view(-1, 1) + a

    sample = torch.randn(3, 4, requires_grad=True)
    self.run_test(ArangeScript(), sample)

    # Same computation as a plain nn.Module.
    class ArangeModel(torch.nn.Module):
        def forward(self, a):
            return torch.arange(2, a.size(0) * a.size(1) + 2, a.size(1), dtype=torch.float).view(-1, 1) + a

    self.run_test(ArangeModel(), sample)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(11)
def test_arange_start_end_step_notype(self):
    # torch.arange(start, end, step) with a float start and no explicit dtype.
    class ArangeScript(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, a):
            return torch.arange(2.7, a.size(0) * a.size(1) + 2, a.size(1)).view(-1, 1) + a

    sample = torch.randn(3, 4, requires_grad=True)
    self.run_test(ArangeScript(), sample)

    # Same computation as a plain nn.Module.
    class ArangeModel(torch.nn.Module):
        def forward(self, a):
            return torch.arange(2.7, a.size(0) * a.size(1) + 2, a.size(1)).view(-1, 1) + a

    self.run_test(ArangeModel(), sample)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test__dim_arange(self):
    # torch._dim_arange over dim 1, run with dynamic axes and then with
    # remained_onnx_input_idx chosen from the opset version.
    class DimArange(torch.nn.Module):
        def forward(self, input):
            return torch._dim_arange(input, 1)

    sample = torch.ones(5, 6)
    self.run_test(DimArange(), sample, input_names=["x"], dynamic_axes={"x": [0, 1]})
    if self.opset_version < 11:
        remained_onnx_input_idx = None
    else:
        remained_onnx_input_idx = []
    self.run_test(DimArange(), sample, remained_onnx_input_idx=remained_onnx_input_idx)
|
|
|
|
def _test_compare_ops(self, model, num_inputs):
    # Shared driver for comparison tests: runs `model` on float and int32 operands.
    # With more than one input, every float/int pairing of the two operands is run.
    x_float = torch.randn(1, 2, 3, 4, requires_grad=True)
    x_int = torch.randint(10, (3, 4), dtype=torch.int32)

    if num_inputs <= 1:
        for operand in (x_float, x_int):
            self.run_test(model, operand)
        return

    y_float = torch.randn(1, 2, 3, 4, requires_grad=True)
    y_int = torch.randint(10, (3, 4), dtype=torch.int32)
    # Order: (float, float), (float, int), (int, float), (int, int).
    for lhs in (x_float, x_int):
        for rhs in (y_float, y_int):
            self.run_test(model, (lhs, rhs))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_and_or_xor(self):
    # The ^, |, & and ~ operators on boolean tensors.
    class MyModel(torch.nn.Module):
        def forward(self, x, y):
            return x ^ y, x | y, x & y, ~x

    lhs = torch.randint(0, 2, (5, 5), dtype=torch.bool)
    rhs = torch.randint(0, 2, (5, 5), dtype=torch.bool)
    self.run_test(MyModel(), input=(lhs, rhs))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_logical_and(self):
    # torch.logical_and on bool inputs, then on int32, double and mixed float32/long inputs.
    class AndModel(torch.nn.Module):
        def forward(self, x, y):
            return torch.logical_and(x, y)

    lhs = torch.randint(0, 2, (5, 5), dtype=torch.bool)
    rhs = torch.randint(0, 2, (5, 5), dtype=torch.bool)
    self.run_test(AndModel(), input=(lhs, rhs))

    numeric_cases = (
        ((5, 5), torch.int32, torch.int32),
        ((5, 5), torch.double, torch.double),
        ((2, 3, 5), torch.float32, torch.long),
    )
    for shape, lhs_dtype, rhs_dtype in numeric_cases:
        lhs = torch.randint(10, shape, dtype=lhs_dtype)
        rhs = torch.randint(10, shape, dtype=rhs_dtype)
        self.run_test(AndModel(), input=(lhs, rhs))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_logical_or(self):
    # torch.logical_or on bool inputs, then on int32, double and mixed float32/long inputs.
    class OrModel(torch.nn.Module):
        def forward(self, x, y):
            return torch.logical_or(x, y)

    lhs = torch.randint(0, 2, (5, 5), dtype=torch.bool)
    rhs = torch.randint(0, 2, (5, 5), dtype=torch.bool)
    self.run_test(OrModel(), input=(lhs, rhs))

    numeric_cases = (
        ((5, 5), torch.int32, torch.int32),
        ((5, 5), torch.double, torch.double),
        ((2, 3, 5), torch.float32, torch.long),
    )
    for shape, lhs_dtype, rhs_dtype in numeric_cases:
        lhs = torch.randint(10, shape, dtype=lhs_dtype)
        rhs = torch.randint(10, shape, dtype=rhs_dtype)
        self.run_test(OrModel(), input=(lhs, rhs))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_logical_xor(self):
    # torch.logical_xor on bool inputs, then on int32, double and mixed float32/long inputs.
    class XorModel(torch.nn.Module):
        def forward(self, x, y):
            return torch.logical_xor(x, y)

    lhs = torch.randint(0, 2, (5, 5), dtype=torch.bool)
    rhs = torch.randint(0, 2, (5, 5), dtype=torch.bool)
    self.run_test(XorModel(), input=(lhs, rhs))

    numeric_cases = (
        ((5, 5), torch.int32, torch.int32),
        ((5, 5), torch.double, torch.double),
        ((2, 3, 5), torch.float32, torch.long),
    )
    for shape, lhs_dtype, rhs_dtype in numeric_cases:
        lhs = torch.randint(10, shape, dtype=lhs_dtype)
        rhs = torch.randint(10, shape, dtype=rhs_dtype)
        self.run_test(XorModel(), input=(lhs, rhs))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(11)  # float equal added after opset 11
def test_eq(self):
    # Element-wise `==` between two tensor inputs.
    class EqualModel(torch.nn.Module):
        def forward(self, input, other):
            return input == other

    comparison = EqualModel()
    self._test_compare_ops(comparison, 2)
|
|
|
|
def test_gt(self):
    # Element-wise `>` between two tensor inputs.
    class GreaterModel(torch.nn.Module):
        def forward(self, input, other):
            return input > other

    comparison = GreaterModel()
    self._test_compare_ops(comparison, 2)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_ge(self):
    # Element-wise `>=` between two tensor inputs.
    class GreaterOrEqualModel(torch.nn.Module):
        def forward(self, input, other):
            return input >= other

    comparison = GreaterOrEqualModel()
    self._test_compare_ops(comparison, 2)
|
|
|
|
def test_gt_scalar(self):
    # Element-wise `>` against the Python scalar 1.
    class GreaterModel(torch.nn.Module):
        def forward(self, input):
            return input > 1

    comparison = GreaterModel()
    self._test_compare_ops(comparison, 1)
|
|
|
|
def test_gt_primitive(self):
    # `>` between an int attribute of the module and an int argument (no tensors involved).
    class GreaterModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.y: int = 2

        def forward(self, x: int):
            return self.y > x

    value = 3
    self.run_test(GreaterModel(), (value, ))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_ge_scalar(self):
    # Element-wise `>=` against the Python scalar 1.
    class GreaterOrEqualModel(torch.nn.Module):
        def forward(self, input):
            return input >= 1

    comparison = GreaterOrEqualModel()
    self._test_compare_ops(comparison, 1)
|
|
|
|
def test_lt(self):
|
|
class LessModel(torch.nn.Module):
|
|
def forward(self, input, other):
|
|
return input > other
|
|
self._test_compare_ops(LessModel(), 2)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_le(self):
    # Element-wise `<=` between two tensor inputs.
    class LessOrEqualModel(torch.nn.Module):
        def forward(self, input, other):
            return input <= other

    comparison = LessOrEqualModel()
    self._test_compare_ops(comparison, 2)
|
|
|
|
def test_lt_scalar(self):
    # Element-wise `<` against the Python scalar 1.
    class LessModel(torch.nn.Module):
        def forward(self, input):
            return input < 1

    comparison = LessModel()
    self._test_compare_ops(comparison, 1)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_le_scalar(self):
    # Element-wise `<=` against the Python scalar 1.
    class LessOrEqualModel(torch.nn.Module):
        def forward(self, input):
            return input <= 1

    comparison = LessOrEqualModel()
    self._test_compare_ops(comparison, 1)
|
|
|
|
def test_matmul(self):
    # 2-D torch.matmul, first with float inputs and then with integer inputs.
    class MatmulModel(torch.nn.Module):
        def forward(self, input, other):
            return torch.matmul(input, other)

    left = torch.randn(3, 4, requires_grad=True)
    right = torch.randn(4, 5, requires_grad=True)
    self.run_test(MatmulModel(), (left, right))

    left = torch.randint(10, (3, 4))
    right = torch.randint(10, (4, 5))
    self.run_test(MatmulModel(), (left, right))
|
|
|
|
def test_matmul_batch(self):
    # 3-D torch.matmul, first with float inputs and then with integer inputs.
    class MatmulModel(torch.nn.Module):
        def forward(self, input, other):
            return torch.matmul(input, other)

    left = torch.randn(2, 3, 4, requires_grad=True)
    right = torch.randn(2, 4, 5, requires_grad=True)
    self.run_test(MatmulModel(), (left, right))

    left = torch.randint(10, (2, 3, 4))
    right = torch.randint(10, (2, 4, 5))
    self.run_test(MatmulModel(), (left, right))
|
|
|
|
def _argmin_argmax_model(self, input):
    # Shared helper: runs argmin/argmax without a dim, with and without keepdim, on `input`.
    class ArgminArgmaxModel(torch.nn.Module):
        def forward(self, input):
            return (
                torch.argmin(input),
                torch.argmax(input),
                torch.argmin(input, keepdim=True),
                torch.argmax(input, keepdim=True),
            )

    self.run_test(ArgminArgmaxModel(), input)
|
|
|
|
def test_argmin_argmax(self):
    # argmin/argmax on a random 7x3x5 tensor via the shared helper.
    random_input = torch.randn(7, 3, 5)
    self._argmin_argmax_model(random_input)
|
|
|
|
# Argmin and Argmax with "select_last_index" is not supported before opset 12
|
|
# "select_last_index" was added in opset 12 to deal with corner case where the
|
|
# same value appears multiple times in the tensor
|
|
@skipIfUnsupportedMinOpsetVersion(12)
def test_argmin_argmax_select_last_index(self):
    # Inputs in which the minimum/maximum value occurs more than once.
    tied_values = torch.tensor([[1., 2., 3.],
                                [1., 1., 2.]])
    self._argmin_argmax_model(tied_values)

    # Every element equal: all positions tie for both min and max.
    all_equal = torch.ones(7, 3, 5)
    self._argmin_argmax_model(all_equal)
|
|
|
|
def test_repeat(self):
    # Tensor.repeat where the repeat count comes from another input's shape.
    class RepeatModel(torch.nn.Module):
        def forward(self, x, y):
            tiled = x.repeat(y.shape[0], 1)
            column = y.view(-1, 1)
            return tiled + column

    x = torch.tensor([1, 2, 3])
    y = torch.tensor([4, 5, 8, 9])
    inputs = (x, y)
    self.run_test(RepeatModel(), inputs)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_repeat_interleave(self):
    # repeat_interleave with repeats fixed inside the model: an int or a
    # constant tensor, with and without an explicit `dim`.

    # Integer repeats, no dim.
    class ScalarRepeatModel(torch.nn.Module):
        def forward(self, x):
            return x.repeat_interleave(2)

    vector = torch.tensor([1, 2, 3])
    self.run_test(ScalarRepeatModel(), (vector,))

    # Integer repeats along dim=1.
    class DimsModel(torch.nn.Module):
        def forward(self, x):
            return x.repeat_interleave(4, dim=1)

    matrix = torch.tensor([[1, 2], [3, 4]])
    self.run_test(DimsModel(), (matrix,))

    # Single-element tensor repeats along dim=1.
    class DimsModel2(torch.nn.Module):
        def forward(self, x):
            repeats = torch.tensor([4])
            return torch.repeat_interleave(x, repeats, dim=1)

    matrix = torch.tensor([[1, 2], [3, 4]])
    self.run_test(DimsModel2(), (matrix,))

    # Per-element tensor repeats along dim=0.
    class RepeatsDimsModel(torch.nn.Module):
        def forward(self, x):
            repeats = torch.tensor([1, 2])
            return torch.repeat_interleave(x, repeats, dim=0)

    matrix = torch.tensor([[1, 2], [3, 4]])
    self.run_test(RepeatsDimsModel(), (matrix,))

    # Per-element tensor repeats along dim=1.
    class RepeatsDimsModel2(torch.nn.Module):
        def forward(self, x):
            repeats = torch.tensor([1, 2])
            return torch.repeat_interleave(x, repeats, dim=1)

    matrix = torch.tensor([[1, 2], [3, 4]])
    self.run_test(RepeatsDimsModel2(), (matrix,))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_repeat_interleave_noop(self):
    # repeat_interleave with a repeat count of 1 along dim=1.
    class Model(torch.nn.Module):
        def forward(self, x):
            repeated = x.repeat_interleave(1, dim=1)
            return repeated

    sample = torch.randn(4, 1, 8)
    self.run_test(Model(), (sample,))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(13)
def test_dynamic_repeat_interleave(self):
    # repeat_interleave with a single repeat value while the input (and, in
    # the last two cases, the repeats tensor) is declared with dynamic axes.

    # 0-d tensor repeats along dim=1; axis 1 of the input is dynamic.
    class SingleDynamicModel(torch.nn.Module):
        def forward(self, x):
            repeats = torch.tensor(4)
            return torch.repeat_interleave(x, repeats, dim=1)

    x = torch.tensor([[1, 2, 4], [3, 4, 7]])
    another_x = torch.tensor([[7, 8], [5, 6]])
    self.run_test(
        SingleDynamicModel(),
        x,
        test_with_inputs=[another_x],
        input_names=["input_1"],
        dynamic_axes={"input_1": {1: "w"}},
    )

    # Same as above but the dimension is given as dim=-1.
    class NegDynamicModel(torch.nn.Module):
        def forward(self, x):
            repeats = torch.tensor(4)
            return torch.repeat_interleave(x, repeats, dim=-1)

    x = torch.tensor([[1, 2, 4], [3, 4, 7]])
    another_x = torch.tensor([[7, 8], [5, 6]])
    self.run_test(
        NegDynamicModel(),
        x,
        test_with_inputs=[another_x],
        input_names=["input_1"],
        dynamic_axes={"input_1": {1: "w"}},
    )

    # Float input, 1-element tensor repeats along dim=0; axis 0 is dynamic.
    class SingleDynamicModelFloat(torch.nn.Module):
        def forward(self, x):
            repeats = torch.tensor([4])
            return torch.repeat_interleave(x, repeats, dim=0)

    x = torch.tensor([[1.1, 2.1], [3.1, 4.1]])
    another_x = torch.tensor([[7.1, 8.1], [5.1, 6.1]])
    self.run_test(
        SingleDynamicModelFloat(),
        x,
        test_with_inputs=[another_x],
        input_names=["input_1"],
        dynamic_axes={"input_1": {0: "h"}},
    )

    # Repeats supplied as a model input; both inputs have a dynamic axis.
    class DynamicRepeatsModel(torch.nn.Module):
        def forward(self, x, repeats):
            return torch.repeat_interleave(x, repeats, dim=1)

    x = torch.tensor([[1, 2, 4], [3, 4, 7]])
    another_x = torch.tensor([[7, 8], [5, 6]])
    repeats = torch.tensor([2])
    another_repeats = torch.tensor([4])
    self.run_test(
        DynamicRepeatsModel(),
        (x, repeats),
        test_with_inputs=[(another_x, another_repeats)],
        input_names=["input_1", "repeats_1"],
        dynamic_axes={"input_1": {1: "w"}, "repeats_1": {0: "r"}},
    )

    # Repeats supplied as a model input; only the repeats axis is dynamic.
    class DynamicRepeatsModel2(torch.nn.Module):
        def forward(self, x, repeats):
            return torch.repeat_interleave(x, repeats, dim=1)

    x = torch.tensor([[1, 2, 4], [3, 4, 7]])
    repeats = torch.tensor([2])
    another_repeats = torch.tensor([4])
    self.run_test(
        DynamicRepeatsModel2(),
        (x, repeats),
        test_with_inputs=[(x, another_repeats)],
        input_names=["input_1", "repeats_1"],
        dynamic_axes={"repeats_1": {0: "r"}},
    )
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(13)
def test_multiple_dynamic_repeat_interleave(self):
    # repeat_interleave where the repeats input holds one count per element
    # along the chosen dim, and the repeats axis is declared dynamic.

    # Three counts along dim=1 (the input has 3 columns).
    class DynamicRepeatsModel(torch.nn.Module):
        def forward(self, x, repeats):
            return torch.repeat_interleave(x, repeats, dim=1)

    x = torch.tensor([[1, 2, 4], [3, 4, 7]])
    repeats = torch.tensor([2, 3, 4])
    another_repeats = torch.tensor([4, 3, 2])
    self.run_test(
        DynamicRepeatsModel(),
        (x, repeats),
        test_with_inputs=[(x, another_repeats)],
        input_names=["input_1", "repeats_1"],
        dynamic_axes={"repeats_1": {0: "r"}},
    )

    # Two counts along dim=0 (the input has 2 rows).
    class DynamicRepeatsModel2(torch.nn.Module):
        def forward(self, x, repeats):
            return torch.repeat_interleave(x, repeats, dim=0)

    x = torch.tensor([[1, 2, 4], [3, 4, 7]])
    repeats = torch.tensor([2, 3])
    another_repeats = torch.tensor([4, 3])
    self.run_test(
        DynamicRepeatsModel2(),
        (x, repeats),
        test_with_inputs=[(x, another_repeats)],
        input_names=["input_1", "repeats_1"],
        dynamic_axes={"repeats_1": {0: "r"}},
    )
|
|
|
|
def test_view(self):
    # Tensor.view to a fixed (4, 24) shape on an int32 input.
    class ViewModel(torch.nn.Module):
        def forward(self, input):
            reshaped = input.view(4, 24)
            return reshaped

    sample = torch.randint(10, (4, 2, 3, 4), dtype=torch.int32)
    self.run_test(ViewModel(), sample)
|
|
|
|
def test_view_dynamic(self):
    # Tensor.view where the target shape is read from a second input.
    class ViewModel(torch.nn.Module):
        def forward(self, input, other):
            target_shape = other.shape
            return input.view(target_shape)

    x = torch.randn(2, 3, 4)
    # Only the shape of this tensor is used by the model.
    shape = torch.randn(6, 4)

    # First with every axis of both inputs marked dynamic ...
    self.run_test(
        ViewModel(),
        (x, shape),
        input_names=["x", "shape"],
        dynamic_axes={"x": [0, 1, 2], "shape": [0, 1]},
    )
    # ... then without dynamic axes, passing remained_onnx_input_idx=[0].
    self.run_test(ViewModel(), (x, shape), remained_onnx_input_idx=[0])
|
|
|
|
def test_view_dynamic_zero_dim(self):
    # Two chained views using -1; the extra test input is an empty tensor.
    class ViewModel(torch.nn.Module):
        def forward(self, input):
            input = input.view(-1, 2)
            return input.view(1, -1)

    x = torch.ones(2)
    empty_x = torch.empty((0,))
    self.run_test(
        ViewModel(),
        x,
        test_with_inputs=[empty_x],
        input_names=["input_1"],
        dynamic_axes={"input_1": [0]},
    )
|
|
|
|
def test_view_as(self):
    # Tensor.view_as: view a (2, 3, 4) tensor with the shape of a (6, 4) one.
    class ViewModel(torch.nn.Module):
        def forward(self, input, other):
            reshaped = input.view_as(other)
            return reshaped

    source = torch.randn(2, 3, 4)
    template = torch.randn(6, 4)
    self.run_test(ViewModel(), (source, template))
|
|
|
|
def test_linear(self):
    # Linear layers: first as an nn.Linear module applied twice, then via
    # torch.nn.functional.linear with weight and bias passed as inputs.
    class LinearModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = torch.nn.Linear(16, 16)

        def forward(self, x):
            # The same layer is applied twice in a row.
            hidden = self.fc(x)
            return self.fc(hidden)

    x = torch.randn(3, 16)
    self.run_test(LinearModel(), (x,))

    # Rebinds the local name `LinearModel` to the functional variant.
    class LinearModel(torch.nn.Module):
        def forward(self, input, weight, bias):
            return torch.nn.functional.linear(input, weight, bias)

    # input of rank 2
    x = torch.randn(2, 2)
    weight = torch.randn(2, 2)
    bias = torch.randn(1)
    self.run_test(LinearModel(), (x, weight, bias))

    # input of rank 3
    x = torch.randn(3, 3, 3)
    weight = torch.randn(3, 3)
    bias = torch.randn(1)
    self.run_test(LinearModel(), (x, weight, bias))
|
|
|
|
@disableScriptTest()
def test_weight_norm(self):
    # weight_norm-wrapped Linear and Conv1d layers with various `dim`/`name`
    # arguments.
    weight_norm = torch.nn.utils.weight_norm

    # addmm for 3-d inputs converts to onnx::MatMul
    model = weight_norm(torch.nn.Linear(5, 10), dim=1)
    x = torch.randn(3, 4, 5, requires_grad=True)
    self.run_test(model, x)

    # addmm for 2-d inputs converts to onnx::Gemm
    model = weight_norm(torch.nn.Linear(5, 10), dim=1)
    x = torch.randn(4, 5, requires_grad=True)
    self.run_test(model, x)

    # Conv1d with weight_norm's default arguments.
    model = weight_norm(torch.nn.Conv1d(1, 1, 3))
    x = torch.randn(1, 1, 5, requires_grad=True)
    self.run_test(model, x)

    # Conv1d with a negative dim.
    model = weight_norm(torch.nn.Conv1d(1, 1, 3), dim=-2)
    x = torch.randn(1, 1, 5, requires_grad=True)
    self.run_test(model, x)

    # Conv1d with the parameter name given explicitly.
    model = weight_norm(torch.nn.Conv1d(3, 6, 3), name="weight")
    x = torch.randn(3, 3, 5, requires_grad=True)
    self.run_test(model, x)
|
|
|
|
@disableScriptTest()
def test_weight_norm_nodim(self):
    # weight_norm on a Linear layer with dim=None, for 3-D and 2-D inputs.
    weight_norm = torch.nn.utils.weight_norm

    # addmm for 3-d inputs converts to onnx::MatMul
    model = weight_norm(torch.nn.Linear(5, 10), dim=None)
    x = torch.randn(3, 4, 5, requires_grad=True)
    self.run_test(model, x)

    # addmm for 2-d inputs converts to onnx::Gemm
    model = weight_norm(torch.nn.Linear(5, 10), dim=None)
    x = torch.randn(4, 5, requires_grad=True)
    self.run_test(model, x)
|
|
|
|
def test_flatten(self):
    # torch.flatten with default dims on a 4-D randint tensor and on a 1-D
    # float tensor.
    class FlattenModel(torch.nn.Module):
        def forward(self, input):
            flat = torch.flatten(input)
            return flat

    four_dim = torch.randint(10, (1, 2, 3, 4))
    self.run_test(FlattenModel(), four_dim)

    one_dim = torch.randn(4)
    self.run_test(FlattenModel(), one_dim)
|
|
|
|
def test_flatten2d(self):
    # torch.flatten starting at dim 1.
    class FlattenModel(torch.nn.Module):
        def forward(self, input):
            flat = torch.flatten(input, 1)
            return flat

    sample = torch.randint(10, (1, 2, 3, 4))
    self.run_test(FlattenModel(), sample)
|
|
|
|
def test_flatten2d_neg(self):
    # torch.flatten with negative end dims: (1, -1), (0, -2) and (1, -2).
    class FlattenModel(torch.nn.Module):
        def forward(self, x):
            return (
                torch.flatten(x, 1, -1),
                torch.flatten(x, 0, -2),
                torch.flatten(x, 1, -2),
            )

    sample = torch.randint(10, (1, 2, 3, 4))
    self.run_test(FlattenModel(), sample)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_flatten_dynamic_axes(self):
    # torch.flatten over dims 2..3 while axis 0 of input and output is
    # declared dynamic; the extra test input uses a different size on axis 0.
    class MyModule(torch.nn.Module):
        def forward(self, x):
            return torch.flatten(x, start_dim=2, end_dim=3)

    batch_size = 3
    x = torch.randn(batch_size, 5, 4, 5)
    y = torch.randn(5, 5, 4, 5)
    model = MyModule()
    axes = {
        "input": {0: "batch_size"},
        "output": {0: "batch_size"},
    }
    self.run_test(
        model,
        x,
        test_with_inputs=[y],
        input_names=["input"],
        output_names=["output"],
        dynamic_axes=axes,
    )
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(11)
def test_getitem(self):
    # Indexing a Python list of tensors with a tensor index inside a scripted
    # module, using a positive and then a negative index.
    class GetItemModel(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, x, y, z, ind):
            # this will create prim::ListConstruct(x, y, z) + aten::__getitem__
            arr = [x, y, z]
            return arr[ind]

    x = torch.randn(3, 4, 5)
    y = torch.randn(1, 4, 5)
    z = torch.randn(2, 4, 5)

    positive_index = torch.tensor(1, dtype=torch.long)
    self.run_test(GetItemModel(), (x, y, z, positive_index))

    negative_index = torch.tensor(-2, dtype=torch.long)
    self.run_test(GetItemModel(), (x, y, z, negative_index))
|
|
|
|
def test_item(self):
    # Tensor.item() on an element selected through a two-level index lookup,
    # in a scripted module that also takes a plain int argument.
    class M(torch.nn.Module):
        def forward(self, x, y, i: int):
            return int(x[y[i]].item())

    values = torch.arange(6, dtype=torch.float)
    lookup = torch.tensor([0, 1, 2, 3, 4], dtype=torch.long)
    position = 3
    scripted = torch.jit.script(M())
    self.run_test(scripted, (values, lookup, position))
|
|
|
|
@disableScriptTest()  # torch.nonzero(x, as_tuple=True) is not scriptable.
@skipIfUnsupportedMinOpsetVersion(9)
def test_nonzero(self):
    # Tensor.nonzero in both its default and as_tuple=True forms.
    class NonzeroModel(torch.nn.Module):
        def forward(self, x):
            return x.nonzero(), x.nonzero(as_tuple=True)

    # 60 random values with zeros written at 20 randomly drawn positions
    # (positions may repeat), viewed as 3x4x5.
    values = torch.randn(60)
    zero_positions = torch.randint(0, 60, (20,))
    x = values.index_fill_(0, zero_positions, 0).view(3, 4, 5)
    self.run_test(NonzeroModel(), (x,))
|
|
|
|
def test_unbind(self):
    # Tensor.unbind with tuple unpacking, along the default dim, dim 1 and
    # dim -2; the second unpacked slice is returned each time.
    class UnbindModel(torch.nn.Module):
        def forward(self, input):
            _, second, _ = input.unbind()
            return second

    x = torch.randn(3, 4, 5)
    self.run_test(UnbindModel(), x)

    class UnbindModel2(torch.nn.Module):
        def forward(self, input):
            _, second, _, _ = input.unbind(1)
            return second

    x = torch.randn(3, 4, 5)
    self.run_test(UnbindModel2(), x)

    class UnbindModel3(torch.nn.Module):
        def forward(self, input):
            _, second, _, _ = input.unbind(-2)
            return second

    x = torch.randn(3, 4, 5)
    self.run_test(UnbindModel3(), x)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(11)
def test_len(self):
    # len() of the unbind() result inside a scripted module, with axis 0 of
    # the input dynamic and an extra input of a different axis-0 size.
    class LenModel(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, input):
            return len(input.unbind()) + input

    x = torch.randn(4, 5)
    self.run_test(
        LenModel(),
        x,
        input_names=["input"],
        dynamic_axes={"input": {0: "seq"}},
        test_with_inputs=(torch.randn(5, 5),),
    )
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_len_list(self):
    # len() of the input's shape list inside a scripted module; the output
    # depends only on the input's rank.
    class LenListModel(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, input):
            return torch.ones(len(input.shape))

    sample = torch.randn(4, 5)
    self.run_test(LenListModel(), sample, remained_onnx_input_idx=[])
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(11)
def test_unbind_dynamic(self):
    # Indexing into the unbind() result inside scripted modules, along the
    # default dim and along dim -1.
    class UnbindModel(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, input):
            return input.unbind()[1]

    sample = torch.randn(3, 4, 5)
    self.run_test(UnbindModel(), sample)

    class UnbindModel2(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, input):
            return input.unbind(-1)[1]

    sample = torch.randn(3, 4, 5)
    self.run_test(UnbindModel2(), sample)
|
|
|
|
@disableScriptTest()  # scripting tests run for opsets > 11. See: test_split_script
def test_split(self):
    # Tensor.split with explicit section sizes, along the default dim and
    # along dim -2, returning whole results and/or single pieces.
    class SplitModel(torch.nn.Module):
        def forward(self, input):
            all_pieces = input.split([2, 1, 2])
            first_piece = input.split([3, 2])[0]
            return all_pieces, first_piece

    x = torch.randn(5, 4, 3)
    self.run_test(SplitModel(), x)

    class SplitModel2(torch.nn.Module):
        def forward(self, input):
            all_pieces = input.split([2, 1, 1], -2)
            last_piece = input.split([2, 2], -2)[-1]
            return all_pieces, last_piece

    x = torch.randn(5, 4, 3)
    self.run_test(SplitModel2(), x)

    class SplitModel3(torch.nn.Module):
        def forward(self, input):
            all_pieces = input.split([2, 1, 2])
            return all_pieces

    x = torch.randn(5, 4, 3)
    self.run_test(SplitModel3(), x)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(11)
def test_split_script(self):
    # Same models as test_split, but gated on opset >= 11 and without the
    # disableScriptTest decorator.
    class SplitModel(torch.nn.Module):
        def forward(self, input):
            all_pieces = input.split([2, 1, 2])
            first_piece = input.split([3, 2])[0]
            return all_pieces, first_piece

    x = torch.randn(5, 4, 3)
    self.run_test(SplitModel(), x)

    class SplitModel2(torch.nn.Module):
        def forward(self, input):
            all_pieces = input.split([2, 1, 1], -2)
            last_piece = input.split([2, 2], -2)[-1]
            return all_pieces, last_piece

    x = torch.randn(5, 4, 3)
    self.run_test(SplitModel2(), x)

    class SplitModel3(torch.nn.Module):
        def forward(self, input):
            all_pieces = input.split([2, 1, 2])
            return all_pieces

    x = torch.randn(5, 4, 3)
    self.run_test(SplitModel3(), x)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(11)
@disableScriptTest()
def test_split_size_as_list(self):
    # Tensor.split where the section sizes arrive as a list argument; the
    # pieces are collected one by one and concatenated back along dim 0.
    class SplitModel(torch.nn.Module):
        def forward(self, input, split_sizes: List[int]):
            out = []
            split_list: List[torch.Tensor] = input.split(split_sizes)

            for piece in split_list:
                out.append(piece)
            return torch.cat(out, dim=0)

    x = torch.randn(6, 4, 3)
    # The sizes are passed as 0-d tensors rather than Python ints.
    split_sizes = [torch.tensor(2), torch.tensor(4)]
    self.run_test(SplitModel(), (x, split_sizes))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(11)
def test_split_size_with_slice(self):
    # torch.split where the section sizes are read from the second
    # dimension of two other inputs.
    class SplitModule(torch.nn.Module):
        def forward(self, x, y, t):
            splits = (x.size(1), y.size(1))
            first, second = torch.split(t, splits, dim=1)
            return first, second

    x = torch.randn(2, 3)
    y = torch.randn(2, 4)
    t = torch.randn(2, 7)
    inputs = (x, y, t)

    # With all axes dynamic ...
    self.run_test(
        SplitModule(),
        inputs,
        input_names=["x", "y", "t"],
        dynamic_axes={"x": [0, 1], "y": [0, 1], "t": [0, 1]},
    )
    # ... and without dynamic axes, passing remained_onnx_input_idx=[2].
    self.run_test(SplitModule(), inputs, remained_onnx_input_idx=[2])
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(11)
def test_split_dynamic(self):
    # Tensor.split with an integer chunk size inside scripted modules,
    # indexing the second piece; default dim and dim -3.
    class SplitModel(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, input):
            return input.split(2)[1]

    sample = torch.randn(5, 4, 3)
    self.run_test(SplitModel(), sample)

    class SplitModel2(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, input):
            return input.split(2, -3)[1]

    sample = torch.randn(5, 4, 3)
    self.run_test(SplitModel2(), sample)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(11)
def test_split_dynamic_axes(self):
    # Tensor.split along the last dim while axis 0 of the input is dynamic.
    class Split(torch.nn.Module):
        def forward(self, x):
            return x.split(1, dim=-1)

    x = torch.randn(4, 384, 2)
    input_names = ["logits"]
    axes = {input_names[0]: {0: 'batch'}}
    self.run_test(Split(), x, input_names=input_names, dynamic_axes=axes)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(11)
def test_chunk(self):
    # torch.chunk into 3 chunks along dim 1 and dim -1, with both input axes
    # dynamic and extra inputs whose second dimension is 13, 14 and 15.
    class ChunkModel(torch.nn.Module):
        def __init__(self, dim=1):
            super().__init__()
            self.dim = dim

        def forward(self, x):
            return torch.chunk(x, 3, dim=self.dim)

    model = ChunkModel()
    model.eval()
    model_neg_dim = ChunkModel(-1)
    model_neg_dim.eval()
    x = torch.randn(1, 18)

    for width in range(13, 16):
        y = torch.randn(1, width)
        self.run_test(
            model,
            x,
            test_with_inputs=[y],
            input_names=["x"],
            dynamic_axes={"x": {0: "batch_size", 1: "dims"}},
        )

        self.run_test(
            model_neg_dim,
            x,
            test_with_inputs=[y],
            input_names=["x"],
            dynamic_axes={"x": {0: "batch_size", 1: "dims"}},
        )
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(11)
def test_dynamic_chunk(self):
    # torch.chunk where the number of chunks is x.size(0), along dim 1 and
    # dim -1, with extra inputs whose second dimension is 13, 14 and 15.
    class ChunkModel(torch.nn.Module):
        def __init__(self, dim=1):
            super().__init__()
            self.dim = dim

        def forward(self, x):
            return torch.chunk(x, x.size(0), dim=self.dim)

    model = ChunkModel()
    model.eval()
    model_neg_dim = ChunkModel(-1)
    model_neg_dim.eval()
    x = torch.randn(3, 18)

    for width in range(13, 16):
        y = torch.randn(3, width)
        self.run_test(
            model,
            x,
            test_with_inputs=[y],
            input_names=["x"],
            dynamic_axes={"x": {0: "batch_size", 1: "dims"}},
        )

        self.run_test(
            model_neg_dim,
            x,
            test_with_inputs=[y],
            input_names=["x"],
            dynamic_axes={"x": {0: "batch_size", 1: "dims"}},
        )
|
|
|
|
def test_concat(self):
    # torch.cat of three tensors whose first dimensions differ (3, 1, 2).
    class ConcatModel(torch.nn.Module):
        def forward(self, x, y, z):
            joined = torch.cat((x, y, z))
            return joined

    x = torch.randn(3, 4, 5)
    y = torch.randn(1, 4, 5)
    z = torch.randn(2, 4, 5)
    self.run_test(ConcatModel(), (x, y, z))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(11)
def test_concat_dynamic(self):
    # torch.cat applied directly to the result of unbind() in a scripted
    # module.
    class ConcatDynamicModel(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, x):
            return torch.cat(x.unbind())

    sample = torch.randn(4, 5, 6)
    self.run_test(ConcatDynamicModel(), sample)
|
|
|
|
def test_stack(self):
    # torch.stack of three equally shaped tensors along dim 1.
    class StackModel(torch.nn.Module):
        def forward(self, x, y, z):
            stacked = torch.stack((x, y, z), 1)
            return stacked

    x = torch.randn(3, 4, 5)
    y = torch.randn(3, 4, 5)
    z = torch.randn(3, 4, 5)
    self.run_test(StackModel(), (x, y, z))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(11)
def test_stack_dynamic(self):
    # torch.stack along dim 1 applied to the result of unbind() in a
    # scripted module.
    class StackDynamicModel(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, x):
            return torch.stack(x.unbind(), 1)

    sample = torch.randn(4, 5, 6)
    self.run_test(StackDynamicModel(), sample)
|
|
|
|
def test_loop_dynamic(self):
    # Scripted `for` loop whose trip count is the input's size along dim 2;
    # the loop index is added to the tensor on every iteration.
    class LoopModel(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, x):
            for i in range(x.size(2)):
                x = x + i
            return x

    model = LoopModel()
    zeros = torch.zeros(1, 2, 3, dtype=torch.long)
    self.run_test(model, zeros)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
|
|
def test_loop_nested(self):
# Scripted module whose forward combines a `for` loop over range(5) with a
# `while a < 4` counter loop, run on an all-zero int64 tensor.
# NOTE(review): indentation is lost in this copy of the file, so the exact
# nesting of `x = x + a` relative to the two loops cannot be read off here --
# confirm against the upstream file before restructuring.
|
|
class NestedLoopsModel(torch.jit.ScriptModule):
|
|
@torch.jit.script_method
|
|
def forward(self, x):
|
|
for i in range(5):
|
|
a = 0
|
|
while a < 4:
|
|
a += 1
|
|
x = x + a
|
|
return x
|
|
|
|
model = NestedLoopsModel()
|
|
inputs = torch.zeros(1, 2, 3, dtype=torch.long)
|
|
self.run_test(model, inputs)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(11)
def test_loop_with_list(self):
    # Scripted loop over the pieces of a split, exercising several ways of
    # growing a list (append, `+`, `+=`) plus an in-place tensor update.
    class ListLoopModel(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, x):
            res = []
            res1 = []
            # Section sizes sum to 16, matching the input length below.
            arr = x.split([3, 4, 1, 1, 2, 3, 2], 0)
            res2 = torch.zeros(3, 4, dtype=torch.long)
            res3 = []
            res4 = []
            for i in range(len(arr)):
                # Forward order via append.
                res.append(arr[i].sum(0, False))
                # Reverse order via negative indexing.
                res1.append(arr[-1 - i].sum(0, False))
                res2 += 1
                # Forward order via list concatenation.
                res3 = res3 + [arr[i].sum(0, False)]
                # Reverse order via in-place list concatenation.
                res4 += [arr[-1 - i].sum(0, False)]
            return res, res1, res2, torch.stack(res3), torch.stack(res4)

    model = ListLoopModel()
    sample = torch.randn(16)
    self.run_test(model, sample)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(11)
def test_loop_transpose(self):
    # Transpose inside a scripted loop whose trip count is x.size(0).
    class LoopModel(torch.nn.Module):
        def forward(self, x):
            res = torch.zeros_like(x[0])
            for i in range(x.size(0)):
                # NOTE(review): the loop index is unused -- x[0] is added on
                # every iteration. Confirm whether x[i] was intended.
                res += x[0].transpose(0, 1)
            return res

    scripted = torch.jit.script(LoopModel())
    sample = torch.randn(5, 3, 3)
    self.run_test(scripted, sample)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(11)
def test_loop_multi_dim(self):
    # Scripted loop iterating directly over a tensor: the first 7 entries
    # along dim 0, flipped, each used to index-lookup with the running `y`.
    class LoopMultiDimModel(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, x, y):
            for x_ in torch.flip(x.narrow(0, 0, 7), [0]):
                y = x_[0][y]
            return y

    model = LoopMultiDimModel()
    table = torch.randint(0, 5, (8, 1, 17), dtype=torch.long)
    start = torch.ones(1, dtype=torch.long)
    self.run_test(model, (table, start))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(11)
def test_list(self):
    # Sequence of list mutations (append, pop, insert, `+=`, `+`) in a
    # scripted module; the output depends only on the final list length.
    class ListModel(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, x):
            tensors = x.unbind()
            res = []
            res.append(tensors[0])
            res.append(tensors[1])
            res.pop(1)

            res.insert(0, tensors[1])
            res.append(tensors[2])
            res += [tensors[3], tensors[4]]
            res = res + [tensors[5]]
            return torch.ones(len(res))

    model = ListModel()
    sample = torch.randn(16, 1)
    self.run_test(model, sample)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(11)
def test_list_append(self):
    # Growing a list with `+=` inside a scripted loop; one matmul result per
    # slice of x along dim 0.
    class ListModel(torch.nn.Module):
        def forward(self, x, y):
            res = []
            for i in range(x.size(0)):
                res += [torch.matmul(x[i], y)]
            return res

    scripted = torch.jit.script(ListModel())
    x = torch.randn(16, 3, 4)
    y = torch.randn(4, 5)
    self.run_test(scripted, (x, y))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(13)
def test_list_append_nested(self):
    # Growing a list with `+=` inside two nested scripted loops.
    class ListModel(torch.nn.Module):
        def forward(self, x, y):
            res = []
            for i in range(x.size(0)):
                for j in range(x.size(1)):
                    res += [torch.matmul(x[i][j], y)]
            return res

    scripted = torch.jit.script(ListModel())
    x = torch.randn(4, 4, 3, 4)
    y = torch.randn(4, 5)
    self.run_test(scripted, (x, y))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(14) # Need onnx::Identity of sequence in opset 14
|
|
def test_list_append_nested_2(self):
# Scripted module that maintains two lists (`res`, `res_replicate`) across a
# loop containing an `if len(res) > 2` check and an inner loop, and returns
# both lists.
# NOTE(review): indentation is lost in this copy of the file, so which
# statements belong to the `if` / inner `for` cannot be read off here --
# confirm against the upstream file before restructuring.
|
|
class ListModel(torch.nn.Module):
|
|
def forward(self, x):
|
|
res = []
|
|
res_replicate = []
|
|
for i in range(x.size(0)):
|
|
if len(res) > 2:
|
|
for j in range(x.size(1)):
|
|
res.append(x[i][j])
|
|
res_replicate.append(res[-1])
|
|
res.append(res_replicate[-1])
|
|
return res, res_replicate
|
|
|
|
model = torch.jit.script(ListModel())
|
|
x = torch.randn(4, 4, 3, 4)
|
|
self.run_test(model, (x, ))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(13)
def test_list_append_nested_mixed_dtype(self):
    # Appending comparison results to a list inside nested scripted loops,
    # choosing `==` on the diagonal (i == j) and `!=` elsewhere.
    class ListModel(torch.nn.Module):
        def forward(self, x, y):
            res = []
            for i in range(x.size(0)):
                for j in range(x.size(1)):
                    if i == j:
                        res.append(x == y)
                    else:
                        res.append(x != y)
            return res

    scripted = torch.jit.script(ListModel())
    x = torch.randn(4, 4, 3, 4)
    y = torch.randn(3, 4)
    self.run_test(scripted, (x, y))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(11)
def test_list_pop(self):
    # list.pop() on a list built inside a scripted loop.
    class ListModel(torch.nn.Module):
        def forward(self, x, y):
            res = []
            for i in range(x.size(0)):
                res += [torch.matmul(x[i], y)]
            # Drop the last element once the list is complete.
            res.pop()
            return res

    scripted = torch.jit.script(ListModel())
    x = torch.randn(16, 3, 4)
    y = torch.randn(4, 5)
    self.run_test(scripted, (x, y))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(13)
|
|
def test_list_pop_nested(self):
# Scripted module combining `+=` list growth and list.pop() within two nested
# loops over the first two dimensions of x.
# NOTE(review): indentation is lost in this copy of the file, so whether
# `res.pop()` and the second `res +=` sit in the inner or the outer loop
# cannot be read off here -- confirm against the upstream file.
|
|
class ListModel(torch.nn.Module):
|
|
def forward(self, x, y):
|
|
res = []
|
|
for i in range(x.size(0)):
|
|
for j in range(x.size(1)):
|
|
res += [torch.matmul(x[i][j], y)]
|
|
res.pop()
|
|
res += [torch.matmul(x[i][0], y)]
|
|
return res
|
|
|
|
model = torch.jit.script(ListModel())
|
|
x = torch.randn(4, 4, 3, 4)
|
|
y = torch.randn(4, 5)
|
|
self.run_test(model, (x, y))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(11)
def test_list_del(self):
    # `del` of a fixed list position on a list built inside a scripted loop.
    class ListModel(torch.nn.Module):
        def forward(self, x, y):
            res = []
            for i in range(x.size(0)):
                res += [torch.matmul(x[i], y)]
            # Remove the third element once the list is complete.
            del res[2]
            return res

    scripted = torch.jit.script(ListModel())
    x = torch.randn(16, 3, 4)
    y = torch.randn(4, 5)
    self.run_test(scripted, (x, y))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(13)
|
|
def test_list_del_nested(self):
# Scripted module combining `+=` list growth and `del res[i]` within two
# nested loops over the first two dimensions of x.
# NOTE(review): indentation is lost in this copy of the file, so whether
# `del res[i]` and the second `res +=` sit in the inner or the outer loop
# cannot be read off here -- confirm against the upstream file.
|
|
class ListModel(torch.nn.Module):
|
|
def forward(self, x, y):
|
|
res = []
|
|
for i in range(x.size(0)):
|
|
for j in range(x.size(1)):
|
|
res += [torch.matmul(x[i][j], y)]
|
|
del res[i]
|
|
res += [torch.matmul(x[i][0], y)]
|
|
return res
|
|
|
|
model = torch.jit.script(ListModel())
|
|
x = torch.randn(4, 4, 3, 4)
|
|
y = torch.randn(4, 5)
|
|
self.run_test(model, (x, y))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(11)
def test_list_set(self):
    # Item assignment into a list using a tensor index, after the list has
    # been filled in a scripted loop.
    class ListModel(torch.nn.Module):
        def forward(self, x, y):
            res = []
            for i in range(x.size(0)):
                res.append(x[i])
            res[y] = x[y]
            return res

    scripted = torch.jit.script(ListModel())
    x = torch.randn(12, 4)
    index = torch.tensor(2, dtype=torch.long)
    self.run_test(scripted, (x, index))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(13)
def test_list_idx_sum(self):
    # Indexing a list with a computed tensor value: the sum of the first `y`
    # entries of arange(x.size(0)).
    class ListModel(torch.nn.Module):
        def forward(self, x, y):
            indices = torch.arange(x.size(0))
            res = []
            for i in range(x.size(0)):
                res.append(x[i])
            return res[torch.sum(indices[:y])]

    scripted = torch.jit.script(ListModel())
    x = torch.randn(12, 4)
    count = torch.tensor(2, dtype=torch.long)
    self.run_test(scripted, (x, count))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_tensor_factories(self):
    # torch.zeros / torch.ones sized from the input's size().
    class TensorFactory(torch.nn.Module):
        def forward(self, x):
            zeros = torch.zeros(x.size())
            ones = torch.ones(x.size())
            return zeros + ones

    x = torch.randn(2, 3, 4)
    # Once with all axes dynamic, once with remained_onnx_input_idx=[].
    self.run_test(
        TensorFactory(),
        x,
        input_names=["x"],
        dynamic_axes={"x": [0, 1, 2]},
    )
    self.run_test(TensorFactory(), x, remained_onnx_input_idx=[])
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_tensor_factories_script(self):
    # torch.zeros / torch.ones with explicit float dtype, sized from
    # x.shape, inside a scripted module.
    class TensorFactory(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, x):
            return torch.zeros(x.shape, dtype=torch.float) + torch.ones(x.shape, dtype=torch.float)

    x = torch.randn(2, 3, 4)
    # Once with all axes dynamic, once with remained_onnx_input_idx=[].
    self.run_test(
        TensorFactory(),
        x,
        input_names=["x"],
        dynamic_axes={"x": [0, 1, 2]},
    )
    self.run_test(TensorFactory(), x, remained_onnx_input_idx=[])
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_tensor_like_factories_script(self):
    # zeros_like / ones_like with dtype, layout and device all spelled out,
    # inside a scripted module.
    class TensorFactory(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, x):
            zeros = torch.zeros_like(x, dtype=torch.float, layout=torch.strided, device=torch.device("cpu"))
            ones = torch.ones_like(x, dtype=torch.float, layout=torch.strided, device=torch.device("cpu"))
            return zeros + ones

    x = torch.randn(2, 3, 4)
    # Once with all axes dynamic, once with remained_onnx_input_idx=[].
    self.run_test(
        TensorFactory(),
        x,
        input_names=["x"],
        dynamic_axes={"x": [0, 1, 2]},
    )
    self.run_test(TensorFactory(), x, remained_onnx_input_idx=[])
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_eye(self):
    # torch.eye with sizes that are constants, taken from the input's shape,
    # or a mix, and with default / long / float64 dtypes.
    class TensorFactory(torch.nn.Module):
        def forward(self, x):
            return (
                torch.eye(x.size()[1], 3),
                torch.eye(4, 4, dtype=torch.long),
                torch.eye(x.size()[1], 2, dtype=torch.long),
                torch.eye(x.shape[0]),
                torch.eye(x.shape[0], dtype=torch.float64),
            )

    x = torch.randn(2, 3, 4)
    another_x = torch.randn(5, 6, 7)
    self.run_test(
        TensorFactory(),
        x,
        test_with_inputs=[another_x],
        input_names=["input_1"],
        dynamic_axes={"input_1": [0, 1, 2]},
    )
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(13)
def test_diagonal(self):
    # torch.diagonal with default arguments, negative/positive offsets,
    # explicit dims, and offsets of -2 and 5 in one model.

    def check(model):
        # Every variant uses the same input shapes and dynamic axes.
        x = torch.randn(2, 4, 5, 2)
        # Other test inputs to test dynamic behavior
        another_x = torch.randn(5, 6, 7, 8)
        self.run_test(
            model,
            x,
            test_with_inputs=[another_x],
            input_names=["input_1"],
            dynamic_axes={"input_1": [0, 1, 2, 3]},
        )

    class DiagonalModel(torch.nn.Module):
        def forward(self, x):
            return torch.diagonal(x)

    check(DiagonalModel())

    class DiagonalModelNegOffset(torch.nn.Module):
        def forward(self, x):
            return torch.diagonal(x, offset=-1)

    check(DiagonalModelNegOffset())

    class DiagonalModelPosOffset(torch.nn.Module):
        def forward(self, x):
            return torch.diagonal(x, offset=1)

    check(DiagonalModelPosOffset())

    class DiagonalModelWithDims(torch.nn.Module):
        def forward(self, x):
            return torch.diagonal(x, offset=-1, dim1=1, dim2=2)

    check(DiagonalModelWithDims())

    class DiagonalModelOffsetOverrun(torch.nn.Module):
        def forward(self, x):
            return torch.diagonal(x, offset=-2), torch.diagonal(x, offset=5)

    check(DiagonalModelOffsetOverrun())
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_inplace_zero(self):
    # In-place Tensor.zero_(); both the call result and the input tensor
    # are returned.
    class Zero_(torch.nn.Module):
        def forward(self, x):
            return x.zero_(), x

    x = torch.randn(2, 3, 4)
    # Once with all axes dynamic, once with remained_onnx_input_idx=[].
    self.run_test(
        Zero_(),
        x,
        input_names=["x"],
        dynamic_axes={"x": [0, 1, 2]},
    )
    self.run_test(Zero_(), x, remained_onnx_input_idx=[])
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_new_zeros(self):
    # Tensor.new_zeros sized from slices of the input's shape, with the
    # default dtype and with dtype=torch.long.
    class Zero_(torch.nn.Module):
        def forward(self, x):
            middle = x.new_zeros(x.shape[1:2])
            tail = x.new_zeros(x.shape[2:], dtype=torch.long)
            return middle, tail

    x = torch.randn(2, 3, 4)
    # Once with all axes dynamic, once with remained_onnx_input_idx=[].
    self.run_test(
        Zero_(),
        x,
        input_names=["x"],
        dynamic_axes={"x": [0, 1, 2]},
    )
    self.run_test(Zero_(), x, remained_onnx_input_idx=[])
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_new_ones(self):
    # Tensor.new_ones sized from slices of the input's shape, with the
    # default dtype and with dtype=torch.long.
    class OnesModel(torch.nn.Module):
        def forward(self, x):
            middle = x.new_ones(x.shape[1:2])
            tail = x.new_ones(x.shape[2:], dtype=torch.long)
            return middle, tail

    x = torch.randn(2, 3, 4)
    # Once with all axes dynamic, once with remained_onnx_input_idx=[].
    self.run_test(
        OnesModel(),
        x,
        input_names=["x"],
        dynamic_axes={"x": [0, 1, 2]},
    )
    self.run_test(OnesModel(), x, remained_onnx_input_idx=[])
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
@disableScriptTest()  # torch.zeros/torch.ones with size tensor of dim != 0 not scriptable.
def test_zeros_ones_with_tensor_input(self):
    # torch.zeros / torch.ones where the first size argument is a 1-element
    # tensor rather than a Python int.
    class ZeroAndOnes(torch.nn.Module):
        def forward(self, x):
            zeros = torch.zeros(x, 1)
            ones = torch.ones(x, 1)
            return zeros, ones

    size = torch.tensor([2])
    self.run_test(ZeroAndOnes(), (size,))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_tolist(self):
    # Tensor.tolist() annotated as List[int] inside a scripted module.
    # The local class is deliberately not called `List`, so the name used in
    # the annotation below is not shadowed.
    class ToListModel(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, input):
            res: List[int] = input.tolist()
            return res

    sample = torch.randint(100, (1,))
    self.run_test(ToListModel(), (sample,))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_list_pass(self):
    # Each model builds the size argument of new_zeros by concatenating
    # shape-derived sequences in a different way (slices, full shapes,
    # hand-built int lists, list() of shapes).

    def run_both(model_cls, x, y, axes):
        # First with every listed axis of both inputs dynamic, then with
        # remained_onnx_input_idx=[].
        self.run_test(model_cls(), (x, y), input_names=["x", "y"],
                      dynamic_axes={"x": list(axes), "y": list(axes)})
        self.run_test(model_cls(), (x, y), remained_onnx_input_idx=[])

    class Slice(torch.nn.Module):
        def forward(self, x, y):
            return x.new_zeros(x.shape[2:] + y.shape[1:])

    first = torch.randn(2, 3, 4, 5)
    second = torch.randn(1, 2, 3, 4)
    run_both(Slice, first, second, [0, 1, 2, 3])

    class Size(torch.nn.Module):
        def forward(self, x, y):
            return x.new_zeros(x.shape + y.shape)

    first = torch.randn(2, 3, 4)
    second = torch.randn(1, 2, 3)
    run_both(Size, first, second, [0, 1, 2])

    class Array(torch.nn.Module):
        def forward(self, x, y):
            arr1 = [x.shape[0], x.shape[1], 2]
            arr2 = [y.shape[0], y.shape[1]]
            return x.new_zeros(arr1 + arr2)

    first = torch.randn(2, 3, 4)
    second = torch.randn(1, 2, 3)
    run_both(Array, first, second, [0, 1, 2])

    class ListConcat(torch.nn.Module):
        def forward(self, x, y):
            l1 = list(x.shape)
            l2 = list(y.shape)
            return x.new_zeros(l1 + l2)

    first = torch.randn(2, 3, 4)
    second = torch.randn(1, 2, 3)
    run_both(ListConcat, first, second, [0, 1, 2])
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_new_empty(self):
    # new_empty returns uninitialised memory, so both outputs are forced to a
    # known value (fill_(0) and "* 0") before being returned.
    class Empty(torch.nn.Module):
        def forward(self, x):
            filled = x.new_empty(x.shape[0]).fill_(0)
            scaled = x.new_empty(x.shape[0], dtype=torch.long) * 0
            return filled, scaled

    data = torch.randn(2, 3, 4)
    all_axes_dynamic = {"x": [0, 1, 2]}
    self.run_test(Empty(), data, input_names=["x"], dynamic_axes=all_axes_dynamic)
    self.run_test(Empty(), data, remained_onnx_input_idx=[])
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_new_full(self):
    # Tensor.new_full with an int fill value, and with a float fill value
    # combined with an integer dtype.
    class Full(torch.nn.Module):
        def forward(self, x):
            int_fill = x.new_full(x.shape[1:2], 5)
            float_fill_as_long = x.new_full(x.shape[0:1], 1.3, dtype=torch.long)
            return int_fill, float_fill_as_long

    data = torch.randn(2, 3, 4)
    all_axes_dynamic = {"x": [0, 1, 2]}
    self.run_test(Full(), data, input_names=["x"], dynamic_axes=all_axes_dynamic)
    self.run_test(Full(), data, remained_onnx_input_idx=[])
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_inplace_list(self):
    # Concatenates the results of two in-place ops (add_ on x, fill_ on y)
    # inside a scripted module.
    class Arithmetic(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, x, y):
            shifted = x.add_(3)
            cleared = y.fill_(0)
            return torch.cat([shifted, cleared])

    lhs = torch.randn(2, 3)
    rhs = torch.randn(2, 3)
    both_axes_dynamic = {"x": [0, 1], "y": [0, 1]}
    self.run_test(Arithmetic(), (lhs, rhs), input_names=["x", "y"],
                  dynamic_axes=both_axes_dynamic)
    self.run_test(Arithmetic(), (lhs, rhs), remained_onnx_input_idx=[0])
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_inplace_fill(self):
    # In-place Tensor.fill_(3); the module returns both the op result and the
    # input it was applied to.
    class Fill_(torch.nn.Module):
        def forward(self, x):
            filled = x.fill_(3)
            return filled, x

    data = torch.randn(2, 3, 4)
    all_axes_dynamic = {"x": [0, 1, 2]}
    self.run_test(Fill_(), data, input_names=["x"], dynamic_axes=all_axes_dynamic)
    self.run_test(Fill_(), data, remained_onnx_input_idx=[])
|
|
|
|
def test_inplace_arithmetic(self):
    # Scripted module that mutates both inputs in place: x += 3, then y *= x
    # (the multiply sees the already-updated x).
    class Arithmetic(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, x, y):
            x.add_(3)
            y.mul_(x)
            return x, y

    lhs = torch.randn(2, 3, 4)
    rhs = torch.randn(2, 3, 4)
    model = Arithmetic()
    self.run_test(model, (lhs, rhs))
|
|
|
|
def test_inplace_arithmetic_half(self):
    # In-place add_/mul_ where the mutated tensor is float16 and the other
    # operand is float32; tolerances are loosened to 1e-2.
    class InplaceAddModel(torch.nn.Module):
        def forward(self, x, y):
            return x.add_(y)

    class InplaceMulModel(torch.nn.Module):
        def forward(self, x, y):
            return x.mul_(y)

    half_input = torch.randn(2, 2, dtype=torch.half)
    float_input = torch.randn(2, 2, dtype=torch.float)
    tolerances = dict(rtol=1e-2, atol=1e-2)
    self.run_test(InplaceAddModel(), (half_input, float_input), **tolerances)
    self.run_test(InplaceMulModel(), (half_input, float_input), **tolerances)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_inplace_with_loop(self):
    # In-place add_ on a tensor created outside a fixed-count loop, run
    # through self.run_test as a scripted module.
    class M(torch.nn.Module):
        def forward(self, x):
            a = torch.ones(12,)
            for i in range(10):
                a.add_(torch.ones(12,))
            return a + x

    # The original also built an unused eager instance (`m = M()`); only the
    # scripted module below is ever exercised, so it has been dropped.
    x = torch.randn(12,)
    self.run_test(torch.jit.script(M()), (x))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_inplace_with_loop_2(self):
    # Mixes in-place updates (a, _bias), out-of-place rebinding (b) and
    # aliases (a_ref, b_ref) across nested loops in a scripted module.
    class M(torch.nn.Module):
        def forward(self, x):
            _bias = torch.ones(12,)
            a = torch.ones(12,)  # used in loop, altered.
            a_ref = a  # not used in loop, should be altered.
            b = x.clone()  # used in loop, not be altered.
            b_ref = b  # not used in loop, should not be altered.
            for i in range(10):
                if i == 3:
                    for j in range(5):
                        a += _bias
                        _bias.add_(torch.ones(12,))
                        b = b + torch.ones(12,)

                _bias.add_(torch.ones(12,))
                a += _bias
            # TODO: value for a_ref is incorrect.
            # a_ref += torch.ones(12,)
            b_ref += torch.ones(12,)
            return _bias + x, a, b, b_ref

    # The original also built an unused eager instance (`m = M()`); only the
    # scripted module below is ever exercised, so it has been dropped.
    x = torch.zeros(12,)
    self.run_test(torch.jit.script(M()), (x))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(11)
def test_inplace_attr_with_loop(self):
    # In-place "+=" on a module attribute from inside nested loops, run
    # through self.run_test as a scripted module.
    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self._bias = torch.arange(12,)

        def forward(self, x):
            # Re-assigned on every call, so the result does not depend on
            # state left over from an earlier forward.
            self._bias = torch.arange(12,)
            for i in range(10):
                if i == 3:
                    for j in range(5):
                        self._bias += torch.arange(12,)
            return self._bias + x

    # The original also built an unused eager instance (`m = M()`); only the
    # scripted module below is ever exercised, so it has been dropped.
    x = torch.zeros(12,)
    self.run_test(torch.jit.script(M()), (x))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(11)
def test_inplace_attr_copy_with_loop(self):
    # Tensor.copy_ into a module attribute from inside nested loops, run
    # through self.run_test as a scripted module.
    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self._bias = torch.arange(12,)

        def forward(self, x):
            # Re-assigned on every call, so the result does not depend on
            # state left over from an earlier forward.
            self._bias = torch.arange(12,)
            for i in range(10):
                if i == 3:
                    for j in range(5):
                        self._bias.copy_(torch.arange(12,))
                    self._bias.copy_(self._bias + torch.arange(12,))

                self._bias.copy_(self._bias + torch.arange(12,))
            return self._bias + x

    # The original also built an unused eager instance (`m = M()`); only the
    # scripted module below is ever exercised, so it has been dropped.
    x = torch.zeros(12,)
    self.run_test(torch.jit.script(M()), (x))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(14)  # Need onnx::Identity of sequence in opset 14
def test_inplace_sequence_with_loop(self):
    # A tensor list is appended to in place inside nested loops (with
    # continue/break), while a bool tensor is updated by index.
    class M(torch.nn.Module):
        def process(self, beam_hyps: List[torch.Tensor], done: torch.Tensor, x):
            batch_size = x.shape[0]
            for i in range(batch_size):
                if done[i]:
                    continue

                beam_idx = 0
                for _, token in enumerate(x[i]):
                    beam_hyps.append(token)
                    beam_idx += 1

                    if beam_idx == 6:
                        break

                done[i] = len(beam_hyps) > 4

            return beam_hyps, done

        def forward(self, x):
            beam_hyps: List[torch.Tensor] = []
            batch_size = x.shape[0]
            cur_len = 0
            max_len = x.shape[1]
            done = torch.zeros(batch_size, dtype=torch.bool)
            while cur_len < max_len:
                beam_hyps, done = self.process(beam_hyps, done, x[:, 0, :])
                cur_len = cur_len + 1

            return beam_hyps

    # The original scripted M() twice and discarded the first result
    # (`m = torch.jit.script(M())`); it is scripted once here.
    x = torch.randn(8, 4, 3)
    self.run_test(torch.jit.script(M()), (x))
|
|
|
|
|
|
@disableScriptTest()  # Sort with dynamic dim not supported in ONNX
def test_sort(self):
    # Descending torch.sort along every dim index in range(-2, 2).
    class SortModel(torch.nn.Module):
        def forward(self, x):
            return [torch.sort(x, dim=axis, descending=True) for axis in range(-2, 2)]

    data = torch.randn(3, 4)
    model = SortModel()
    self.run_test(model, data)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(11)
@disableScriptTest()  # Sort with dynamic dim not supported in ONNX
def test_sort_ascending(self):
    # Ascending torch.sort along every dim index in range(-2, 2).
    class SortModel(torch.nn.Module):
        def forward(self, x):
            return [torch.sort(x, dim=axis, descending=False) for axis in range(-2, 2)]

    data = torch.randn(3, 4)
    model = SortModel()
    self.run_test(model, data)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_masked_fill(self):
    # Case 1: constant uint8 mask of shape (2, 3) applied to a (4, 2, 3) input.
    class MaskedFillModel(torch.nn.Module):
        def forward(self, x):
            mask = torch.tensor([[0, 0, 1], [1, 1, 0]], dtype=torch.uint8)
            return x.masked_fill(mask, 2)

    zeros_input = torch.zeros(4, 2, 3, requires_grad=True)
    self.run_test(MaskedFillModel(), zeros_input)

    # Case 2: mask computed from the input itself.
    class MaskedFillModel2(torch.nn.Module):
        def forward(self, x):
            return x.masked_fill(x > 3, -1)

    ramp_input = torch.arange(16).view(2, 2, 4).to(torch.float32)
    self.run_test(MaskedFillModel2(), ramp_input)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_masked_fill_inplace(self):
    # Case 1: in-place masked_fill_ with a constant uint8 mask, scripted.
    class MaskedFillModel(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, x):
            mask = torch.tensor([[0, 0, 1], [1, 1, 0]], dtype=torch.uint8)
            x.masked_fill_(mask, 2)
            return x

    zeros_input = torch.zeros(4, 2, 3, requires_grad=True)
    self.run_test(MaskedFillModel(), zeros_input)

    # Case 2: in-place masked_fill_ with a mask computed from the input.
    class MaskedFillModel2(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, x):
            x.masked_fill_(x > 3, -1)
            return x

    ramp_input = torch.arange(16).view(2, 2, 4).to(torch.float32)
    self.run_test(MaskedFillModel2(), ramp_input)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(11)
def test_masked_scatter(self):
    # torch.masked_scatter with an input-derived mask and a constant source
    # tensor (100 x 100, all fives).
    class MaskedScatterModel(torch.nn.Module):
        def forward(self, x):
            mask = x.ge(0.5)
            source = torch.ones(100, 100) * 5
            return torch.masked_scatter(x, mask, source)

    data = torch.randn(3, 4, 5, requires_grad=True)
    model = MaskedScatterModel()
    self.run_test(model, data)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(11)
def test_masked_select(self):
    # torch.masked_select with a mask computed from the input (x >= 0.5).
    class MaskedSelectModel(torch.nn.Module):
        def forward(self, x):
            mask = x.ge(0.5)
            return torch.masked_select(x, mask)

    data = torch.randn(3, 4, 5, requires_grad=True)
    model = MaskedSelectModel()
    self.run_test(model, data)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(11)
def test_index_put_to_masked_fill(self):
    # Boolean-mask assignment of a scalar (mask[cond] = value), applied twice
    # to a clone of the input. The forward body is kept verbatim because the
    # assignment form itself is what is under test.
    class MaskedFillModel(torch.nn.Module):
        def forward(self, input_mask, some_const):
            mask = input_mask.clone()
            mask[mask != some_const] = 1
            mask[mask == some_const] = 0
            return mask

    mask_input = torch.randn(2, 2, 2, requires_grad=True)
    const_input = torch.tensor(5, dtype=torch.float)
    model = MaskedFillModel()
    self.run_test(model, (mask_input, const_input))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(11)
def test_index_put_to_masked_scatter(self):
    # Boolean-mask assignment of a tensor (mask[cond] = torch.ones(8)) on a
    # clone of the input. The forward body is kept verbatim because the
    # assignment form itself is what is under test.
    class MaskedScatterModel(torch.nn.Module):
        def forward(self, input_mask, some_const):
            mask = input_mask.clone()
            mask[mask != some_const] = torch.ones(8)
            return mask

    mask_input = torch.randn(2, 2, 2, requires_grad=True)
    const_input = torch.tensor(5, dtype=torch.float)
    model = MaskedScatterModel()
    self.run_test(model, (mask_input, const_input))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_pixel_shuffle(self):
    # torch.pixel_shuffle with upscale_factor=2: once with default export
    # settings, once with all four axes dynamic and a second, differently
    # shaped input.
    class PixelShuffle(torch.nn.Module):
        def forward(self, x):
            return torch.pixel_shuffle(x, upscale_factor=2)

    primary = torch.randn(2, 16, 4, 3, requires_grad=True)
    alternate = torch.randn(4, 32, 8, 4, requires_grad=True)
    self.run_test(PixelShuffle(), primary)
    self.run_test(
        PixelShuffle(),
        primary,
        input_names=["x"],
        dynamic_axes={"x": [0, 1, 2, 3]},
        test_with_inputs=[alternate],
    )
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_pixel_unshuffle(self):
    # torch.pixel_unshuffle with downscale_factor=2: once with default export
    # settings, once with all four axes dynamic and a second, differently
    # shaped input.
    class PixelUnshuffle(torch.nn.Module):
        def forward(self, x):
            return torch.pixel_unshuffle(x, downscale_factor=2)

    primary = torch.randn(2, 16, 4, 6, requires_grad=True)
    alternate = torch.randn(4, 32, 8, 4, requires_grad=True)
    self.run_test(PixelUnshuffle(), primary)
    self.run_test(
        PixelUnshuffle(),
        primary,
        input_names=["x"],
        dynamic_axes={"x": [0, 1, 2, 3]},
        test_with_inputs=[alternate],
    )
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_reciprocal(self):
    # torch.reciprocal on the same values cast to long, float and double,
    # reusing a single model instance.
    class ReciprocalModel(torch.nn.Module):
        def forward(self, x):
            return torch.reciprocal(x)

    model = ReciprocalModel()
    base = torch.tensor([2, 4])
    for dtype in (torch.long, torch.float, torch.double):
        self.run_test(model, base.to(dtype))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_scalar_type(self):
    # A series of small models mixing Python scalars, size() values and
    # tensors of different dtypes. The forward expressions are kept verbatim
    # because the operand forms themselves are what is under test.

    # --- scalar * tensor and scalar - tensor ---
    class ArithmeticModel(torch.nn.Module):
        def forward(self, x):
            return x.size(0) * 2 * x, 2 - x

    float_ones = torch.ones(2, 3, dtype=torch.float32)
    self.run_test(ArithmeticModel(), float_ones)

    # --- comparisons between int32 / float32 tensors and scalars ---
    class ComparisonModel(torch.nn.Module):
        def forward(self, x, y):
            a = torch.tensor([12.0])
            return x.lt(1.5) & y.le(2) & x.le(1), x.gt(y), x.lt(y), a.ge(x.size(0))

    int_ones = torch.ones(2, 3, dtype=torch.int32)
    float_ones = torch.ones(2, 3, dtype=torch.float32)
    self.run_test(ComparisonModel(), (int_ones, float_ones))

    # --- mm + tensor, chained ---
    class MatMulModel(torch.nn.Module):
        def forward(self, x):
            return (torch.mm(x, x) + x + torch.mm(x, x) + x)

    square = torch.ones(3, 3)
    self.run_test(MatMulModel(), square)

    # --- single mm + tensor ---
    class AddMMModel(torch.nn.Module):
        def forward(self, x):
            return torch.mm(x, x) + x

    square = torch.ones(3, 3)
    self.run_test(AddMMModel(), square)

    # --- torch.full with a tensor fill value ---
    class FullModel(torch.nn.Module):
        # add is used for exporting full
        def forward(self, x):
            return torch.full((3, 4), x)

    fill_value = torch.tensor(12.)
    self.run_test(FullModel(), fill_value)

    # --- cat of a float16 tensor with a float32 tensor ---
    class CatModel(torch.nn.Module):
        def forward(self, fp16, fp32):
            return torch.cat([fp16, fp32])

    half_value = torch.Tensor([0.5]).half()
    float_value = torch.Tensor([1.5])
    self.run_test(CatModel(), (half_value, float_value))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_full_like(self):
    # torch.full_like with a float fill value and an explicit int dtype, on a
    # 0-dim input.
    class FullLikeModel(torch.nn.Module):
        def forward(self, x):
            return torch.full_like(x, 1.3, dtype=torch.int)

    scalar_input = torch.tensor(12)
    model = FullLikeModel()
    self.run_test(model, scalar_input)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_full_like_value(self):
    # torch.full_like where the fill value is itself computed from a second
    # model input (y + 2).
    class FullLikeModel(torch.nn.Module):
        def forward(self, x, y):
            out = y + 2
            return torch.full_like(x, out)

    like_input = torch.tensor(12)
    value_input = torch.tensor(2)
    model = FullLikeModel()
    self.run_test(model, (like_input, value_input))
|
|
|
|
def test_l1_norm(self):
    # torch.norm with p=1 reduced over the last dim, keepdim off.
    class NormModel(torch.nn.Module):
        def forward(self, x):
            return torch.norm(x, p=1, dim=-1, keepdim=False)

    data = torch.randn(4, 2, 3, requires_grad=True)
    model = NormModel()
    self.run_test(model, data)
|
|
|
|
def test_l2_norm(self):
    # torch.norm with p=2 reduced over dim=-2, keepdim off.
    class NormModel(torch.nn.Module):
        def forward(self, x):
            return torch.norm(x, p=2, dim=-2, keepdim=False)

    data = torch.randn(4, 2, 3, requires_grad=True)
    model = NormModel()
    self.run_test(model, data)
|
|
|
|
def test_frobenius_norm(self):
    # torch.norm with p="fro" reduced over dim=0, keepdim off.
    class NormModel(torch.nn.Module):
        def forward(self, x):
            return torch.norm(x, p="fro", dim=0, keepdim=False)

    data = torch.randn(4, 2, 3, requires_grad=True)
    model = NormModel()
    self.run_test(model, data)
|
|
|
|
def test_frobenius_norm_keepdim(self):
    # torch.norm with p="fro" reduced over dims (0, 1), keepdim on.
    class NormModel(torch.nn.Module):
        def forward(self, x):
            return torch.norm(x, p="fro", dim=(0, 1), keepdim=True)

    data = torch.randn(4, 2, 3, requires_grad=True)
    model = NormModel()
    self.run_test(model, data)
|
|
|
|
def test_unfold(self):
    # Tensor.unfold along dim 2 with constant size/step; axes 0 and 1 of the
    # input are dynamic and a second input with different leading dims is
    # also passed via test_with_inputs.
    class UnfoldModel(torch.nn.Module):
        def forward(self, x):
            return x.unfold(dimension=2, size=2, step=2)

    primary = torch.randn(4, 2, 3, requires_grad=True)
    alternate = torch.randn(2, 1, 3, requires_grad=True)
    self.run_test(
        UnfoldModel(),
        primary,
        dynamic_axes={"x": [0, 1]},
        input_names=["x"],
        test_with_inputs=[alternate],
    )
|
|
|
|
def test_unfold_infer_shape(self):
    # Tensor.unfold applied to the output of a Conv1d inside a scripted
    # module.
    class UnfoldModule(torch.jit.ScriptModule):
        def __init__(self):
            super(UnfoldModule, self).__init__()
            self.conv = torch.nn.Conv1d(3, 1, 3, stride=2)

        @torch.jit.script_method
        def forward(self, x):
            x = self.conv(x)
            return x.unfold(dimension=2, size=2, step=2)

    # Input is drawn before the module is built, as in the original, so the
    # random stream is consumed in the same order.
    data = torch.randn(32, 3, 64)
    model = UnfoldModule()
    self.run_test(model, data)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(12)
def test_unfold_dynamic_inputs(self):
    # Case 1: both size and step are derived from the input shape.
    class UnfoldModel(torch.nn.Module):
        def forward(self, x):
            return x.unfold(dimension=2, size=x.shape[1], step=x.shape[1] - 1)

    data = torch.randn(4, 2, 4, requires_grad=True)
    self.run_test(UnfoldModel(), data)

    # Case 2: size is derived from the input shape, step is the constant 1.
    # The class name is intentionally reused, as in the original.
    class UnfoldModel(torch.nn.Module):
        def forward(self, x):
            return x.unfold(dimension=2, size=x.shape[1], step=1)

    data = torch.randn(4, 2, 4, requires_grad=True)
    self.run_test(UnfoldModel(), data)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)  # MatMul long inputs is added in ONNX opset 9.
def test_mv(self):
    # torch.mv with float operands, then with integer (randint) operands.
    class MatmulModel(torch.nn.Module):
        def forward(self, input, other):
            return torch.mv(input, other)

    float_matrix = torch.randn(4, 5, requires_grad=True)
    float_vector = torch.randn(5, requires_grad=True)
    self.run_test(MatmulModel(), (float_matrix, float_vector))

    int_matrix = torch.randint(10, (4, 5))
    int_vector = torch.randint(10, (5, ))
    self.run_test(MatmulModel(), (int_matrix, int_vector))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)  # MatMul long inputs is added in ONNX opset 9.
def test_dot(self):
    # torch.dot with float operands, then with integer (randint) operands.
    class MatmulModel(torch.nn.Module):
        def forward(self, input, other):
            return torch.dot(input, other)

    float_lhs = torch.randn(5, requires_grad=True)
    float_rhs = torch.randn(5, requires_grad=True)
    self.run_test(MatmulModel(), (float_lhs, float_rhs))

    int_lhs = torch.randint(10, (5, ))
    int_rhs = torch.randint(10, (5, ))
    self.run_test(MatmulModel(), (int_lhs, int_rhs))
|
|
|
|
@disableScriptTest()  # SpectralNorm not TorchScript compatible.
def test_spectral_norm(self):
    # A Linear(2, 4) layer wrapped with torch.nn.utils.spectral_norm.
    # The layer is built before the input is drawn, as in the original.
    linear = torch.nn.Linear(2, 4)
    model = torch.nn.utils.spectral_norm(linear)

    data = torch.randn(6, 2)
    self.run_test(model, (data, ))
|
|
|
|
def test_prelu(self):
    # torch.nn.PReLU with axes 1 and 2 of the input dynamic, plus a second
    # input of a different shape.
    class PReluModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.prelu = torch.nn.PReLU()

        def forward(self, x):
            return self.prelu(x)

    primary = torch.randn(2, 3, 4)
    alternate = torch.randn(2, 4, 5)
    self.run_test(
        PReluModel(),
        primary,
        input_names=["x"],
        dynamic_axes={"x": [1, 2]},
        test_with_inputs=[alternate],
    )
|
|
|
|
def test_prelu_scalar(self):
    # torch.nn.PReLU applied to a 0-dim tensor.
    scalar_input = torch.scalar_tensor(1.)
    model = torch.nn.PReLU()
    self.run_test(model, scalar_input, input_names=["x"])
|
|
|
|
def test_relu6(self):
    # torch.nn.ReLU6 on inputs scaled by 100, with axes 1 and 2 dynamic and a
    # second input of a different shape.
    class Relu6Model(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.relu6 = torch.nn.ReLU6()

        def forward(self, x):
            return self.relu6(x)

    primary = torch.randn(2, 3, 4) * 100.0
    alternate = torch.randn(2, 4, 5) * 100.0
    self.run_test(
        Relu6Model(),
        primary,
        input_names=['x'],
        dynamic_axes={'x': [1, 2]},
        test_with_inputs=[alternate],
    )
|
|
|
|
def test_silu(self):
    # torch.nn.SiLU wrapped in a small module.
    class SiLUModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.silu = torch.nn.SiLU()

        def forward(self, x):
            return self.silu(x)

    data = torch.randn(2, 3, 4)
    model = SiLUModel()
    self.run_test(model, data)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(14)
def test_tril(self):
    # torch.tril with the default diagonal, diagonal=1 and diagonal=-1.
    # A fresh random input is drawn for each case.
    class trilModel(torch.nn.Module):
        def forward(self, x):
            return torch.tril(x)

    data = torch.randn(2, 3, 4)
    self.run_test(trilModel(), data)

    class trilModelwithDiagonal(torch.nn.Module):
        def forward(self, x):
            return torch.tril(x, diagonal=1)

    data = torch.randn(2, 3, 4)
    self.run_test(trilModelwithDiagonal(), data)

    class trilModelwithNegDiagonal(torch.nn.Module):
        def forward(self, x):
            return torch.tril(x, diagonal=-1)

    data = torch.randn(2, 3, 4)
    self.run_test(trilModelwithNegDiagonal(), data)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(14)
def test_triu(self):
    # torch.triu with the default diagonal, diagonal=1 and diagonal=-1.
    # A fresh random input is drawn for each case.
    class triuModel(torch.nn.Module):
        def forward(self, x):
            return torch.triu(x)

    x = torch.randn(2, 3, 4)
    self.run_test(triuModel(), (x))

    class triuModelwithDiagonal(torch.nn.Module):
        def forward(self, x):
            return torch.triu(x, diagonal=1)

    x = torch.randn(2, 3, 4)
    self.run_test(triuModelwithDiagonal(), (x))

    # The original third case was a copy of the tril test (torch.tril with
    # diagonal=-1), so triu with a negative diagonal was never exercised.
    class triuModelwithNegDiagonal(torch.nn.Module):
        def forward(self, x):
            return torch.triu(x, diagonal=-1)

    x = torch.randn(2, 3, 4)
    self.run_test(triuModelwithNegDiagonal(), (x))
|
|
|
|
def test_mish(self):
    # torch.nn.Mish wrapped in a small module.
    class MishModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.mish = torch.nn.Mish()

        def forward(self, x):
            return self.mish(x)

    data = torch.randn(2, 3, 4)
    model = MishModel()
    self.run_test(model, data)
|
|
|
|
def test_remainder(self):
    # torch.remainder over several dtype combinations of (input, other).
    class RemainderModel(torch.nn.Module):
        def forward(self, input, other):
            return torch.remainder(input, other)

    # float / float with broadcasting
    dividend = torch.randn(4, 2, 3)
    divisor = torch.randn(1, 2, 1)
    self.run_test(RemainderModel(), (dividend, divisor))

    # long / long, including negative dividends
    long_dividend = torch.tensor([7, 6, -7, -6], dtype=torch.long)
    long_divisor = torch.tensor([2], dtype=torch.long)
    self.run_test(RemainderModel(), (long_dividend, long_divisor))

    # float / long
    float_dividend = long_dividend.to(torch.float)
    self.run_test(RemainderModel(), (float_dividend, long_divisor))

    # float / float
    float_divisor = long_divisor.to(torch.float)
    self.run_test(RemainderModel(), (float_dividend, float_divisor))

    # int32 / float
    int32_dividend = float_dividend.to(torch.int32)
    self.run_test(RemainderModel(), (int32_dividend, float_divisor))
|
|
|
|
def test_remainder_scalar(self):
    # torch.remainder with a Python scalar divisor held on the module:
    # the default float 2.55, then the int 2.
    class RemainderModel(torch.nn.Module):
        def __init__(self, scalar=2.55):
            super().__init__()
            self.scalar = scalar

        def forward(self, input):
            return torch.remainder(input, self.scalar)

    random_ints = torch.randint(10, (2, 3))
    self.run_test(RemainderModel(), random_ints)

    signed_ints = torch.tensor([7, 6, -7, -6], dtype=torch.long)
    self.run_test(RemainderModel(2), signed_ints)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(10)
def test_fmod(self):
    # torch.fmod on two float tensors with broadcasting.
    class FModModel(torch.nn.Module):
        def forward(self, input, other):
            return torch.fmod(input, other)

    dividend = torch.randn(4, 2, 3)
    divisor = torch.randn(1, 2, 1)
    model = FModModel()
    self.run_test(model, (dividend, divisor))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(10)
def test_fmod_scalar(self):
    # torch.fmod of an integer tensor by the float scalar 2.55.
    class FModModel(torch.nn.Module):
        def forward(self, input):
            return torch.fmod(input, 2.55)

    random_ints = torch.randint(10, (2, 3))
    model = FModModel()
    self.run_test(model, random_ints)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_glu(self):
    # torch.nn.functional.glu with its default dim.
    class GluModel(torch.nn.Module):
        def forward(self, x):
            return torch.nn.functional.glu(x)

    data = torch.randn(2, 4, 5, 6, requires_grad=True)
    model = GluModel()
    self.run_test(model, data)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_gelu(self):
    # torch.nn.functional.gelu with approximate='none'.
    class GeluModel(torch.nn.Module):
        def forward(self, x):
            return torch.nn.functional.gelu(x, approximate='none')

    data = torch.randn(2, 4, 5, 6, requires_grad=True)
    model = GeluModel()
    self.run_test(model, data)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_tanh_gelu(self):
    # torch.nn.functional.gelu with approximate='tanh'.
    class GeluModel(torch.nn.Module):
        def forward(self, x):
            return torch.nn.functional.gelu(x, approximate='tanh')

    data = torch.randn(2, 4, 5, 6, requires_grad=True)
    model = GeluModel()
    self.run_test(model, data)
|
|
|
|
def test_add_inplace(self):
    # Augmented assignment (x += 12) on the model input. The forward body is
    # kept verbatim because the "+=" form itself is what is under test.
    class InplaceAddModel(torch.nn.Module):
        def forward(self, x):
            x += 12
            return x

    data = torch.randn(4, 2, 3, requires_grad=True)
    model = InplaceAddModel()
    self.run_test(model, data)
|
|
|
|
def test_addcmul(self):
    # torch.addcmul with the default value and with value=2.2, on operands
    # whose shapes require broadcasting.
    class AddcmulModel(torch.nn.Module):
        def forward(self, x, t1, t2):
            default_scale = torch.addcmul(x, t1, t2)
            custom_scale = torch.addcmul(x, t1, t2, value=2.2)
            return default_scale, custom_scale

    base = torch.randn(1, 3)
    factor_a = torch.randn(3, 1)
    factor_b = torch.randn(1, 3)
    model = AddcmulModel()
    self.run_test(model, (base, factor_a, factor_b))
|
|
|
|
def test_rsqrt(self):
    # Tensor.rsqrt on a float64 input drawn from randn.
    class RsqrtModel(torch.nn.Module):
        def forward(self, x):
            return x.rsqrt()

    data = torch.randn(4, 2, 3, requires_grad=True, dtype=torch.float64)
    model = RsqrtModel()
    self.run_test(model, data)
|
|
|
|
def test_rsqrt_zeros(self):
    # Tensor.rsqrt on an all-zero float64 input.
    class RsqrtModel(torch.nn.Module):
        def forward(self, x):
            return x.rsqrt()

    zeros_input = torch.zeros(4, 2, 3, requires_grad=True, dtype=torch.float64)
    model = RsqrtModel()
    self.run_test(model, zeros_input)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(11)
def test_unique(self):
    # torch.unique over the flattened input: sorted, with counts, without
    # the inverse index.
    class UniqueModel(torch.nn.Module):
        def forward(self, x):
            return torch.unique(x, sorted=True, return_inverse=False, return_counts=True)

    values = torch.tensor([1, 3, 2, 3], dtype=torch.long)
    model = UniqueModel()
    self.run_test(model, values)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(11)
def test_unique_along_dim(self):
    # torch.unique along dim=0: sorted, with the inverse index, without
    # counts.
    class UniqueModel(torch.nn.Module):
        def forward(self, x):
            return torch.unique(x, dim=0, sorted=True, return_inverse=True, return_counts=False)

    values = torch.tensor([1, 3, 2, 3], dtype=torch.long)
    model = UniqueModel()
    self.run_test(model, values)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(11)
def test_cumsum(self):
    # torch.cumsum along dim 0 of a float input.
    class CumSum(torch.nn.Module):
        def forward(self, input):
            return torch.cumsum(input, dim=0)

    data = torch.randn(2, 3, 4)
    self.run_test(CumSum(), data)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(11)
def test_cumsum_with_cast(self):
    # torch.cumsum with dtype=torch.float32 on an int32 input and on a bool
    # input, reusing a single model instance.
    class CumSum(torch.nn.Module):
        def forward(self, input):
            return torch.cumsum(input, dim=0, dtype=torch.float32)

    model = CumSum()
    int_input = torch.tensor([2, 3, 4], dtype=torch.int32)
    self.run_test(model, int_input)
    bool_input = torch.tensor([False, True, True])
    self.run_test(model, bool_input)
|
|
|
|
@disableScriptTest()  # error in propagate as assign input shape
@skipIfUnsupportedMinOpsetVersion(10)
def test_embedding_bag(self):
    # Three torch.nn.EmbeddingBag configurations. Each model is built before
    # its inputs are drawn, as in the original.

    # mode="sum" with scale_grad_by_freq, 1-D indices + offsets
    sum_scaled = torch.nn.EmbeddingBag(10, 5, mode="sum", scale_grad_by_freq=True)
    indices = torch.randint(10, (7,))
    offsets = torch.tensor([0, 2, 5, 6])
    self.run_test(sum_scaled, (indices, offsets))

    # mode="sum" with include_last_offset, 1-D indices + offsets
    sum_last_offset = torch.nn.EmbeddingBag(10, 5, mode="sum", include_last_offset=True)
    indices = torch.randint(10, (7,))
    offsets = torch.tensor([0, 2, 5, 6])
    self.run_test(sum_last_offset, (indices, offsets))

    # mode="max", 2-D indices and no offsets
    max_bag = torch.nn.EmbeddingBag(10, 5, mode="max")
    indices_2d = torch.randint(10, (7, 5))
    self.run_test(max_bag, indices_2d)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(11)
def test_embedding_bag_1d_per_sample_weights(self):
    # Functional embedding_bag in "sum" mode with 1-D indices, explicit
    # offsets and per-sample weights; the embedding matrix is a model input.
    class EmbeddingModel(torch.nn.Module):
        def forward(self, embedding_matrix, input, offset, weights):
            return torch.nn.functional.embedding_bag(
                input,
                embedding_matrix,
                offsets=offset,
                mode="sum",
                per_sample_weights=weights,
            )

    model = EmbeddingModel()
    indices = torch.randint(7, (6,))
    sample_weights = torch.randn(6, )
    offsets = torch.tensor([0, 2, 5])
    table = torch.rand(10, 15)
    self.run_test(model, (table, indices, offsets, sample_weights))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(11)
def test_embedding_bag_2d_per_sample_weights(self):
    # Functional embedding_bag in "sum" mode with 2-D indices and per-sample
    # weights; axis 0 of indices and weights is dynamic and a second input
    # set with a different leading size is also passed.
    class EmbeddingModel(torch.nn.Module):
        def forward(self, embedding_matrix, input, weights):
            return torch.nn.functional.embedding_bag(
                input,
                embedding_matrix,
                mode="sum",
                per_sample_weights=weights,
            )

    table = torch.rand(10, 15)
    model = EmbeddingModel()
    indices = torch.randint(7, (2, 3))
    sample_weights = torch.randn(2, 3)

    alt_indices = torch.randint(7, (4, 3))
    alt_weights = torch.randn(4, 3)
    self.run_test(
        model,
        (table, indices, sample_weights),
        input_names=['embed', 'x', 'w'],
        dynamic_axes={'x': [0], 'w': [0]},
        test_with_inputs=[(table, alt_indices, alt_weights)],
    )
|
|
|
|
@disableScriptTest()  # scripting prim::Uninitialized, prim::dtype, prim::unchecked_cast
@skipIfUnsupportedMinOpsetVersion(11)
@unittest.skip("Due to ONNX Loop shape inference issue. "
               "https://msdata.visualstudio.com/Vienna/_workitems/edit/1352001")
def test_embedding_bag_dynamic_input(self):
    # Functional embedding_bag with every input axis declared dynamic, in a
    # 1-D (indices + offsets) and a 2-D (indices only) variant. Currently
    # skipped unconditionally by the decorator above.
    class EmbeddingModel1D(torch.nn.Module):
        def forward(self, embedding_matrix, input, weights, offsets):
            return torch.nn.functional.embedding_bag(input, embedding_matrix, offsets=offsets,
                                                     mode="sum", per_sample_weights=weights)

    model = EmbeddingModel1D()
    x = torch.randint(7, (6,))
    w = torch.randn(6, )
    offsets = torch.tensor([0, 2, 5], dtype=torch.long)
    embedding_matrix = torch.rand(10, 15)
    x2 = torch.randint(7, (2,))
    w2 = torch.randn(2, )
    embedding_matrix2 = torch.rand(12, 25)
    offsets2 = torch.tensor([0, ], dtype=torch.long)
    # input_names must follow the positional order of the inputs
    # (embedding_matrix, x, w, offsets); the original listed "offsets" and
    # "w" swapped, so each name was attached to the other tensor.
    self.run_test(model, (embedding_matrix, x, w, offsets),
                  test_with_inputs=[(embedding_matrix2, x2, w2, offsets2)],
                  input_names=["embedding_matrix", "x", "w", "offsets"],
                  dynamic_axes={"embedding_matrix": [0, 1], "x": [0], "w": [0], "offsets": [0]})

    class EmbeddingModel2D(torch.nn.Module):
        def forward(self, embedding_matrix, input, weights):
            return torch.nn.functional.embedding_bag(input, embedding_matrix,
                                                     mode="sum", per_sample_weights=weights)

    model = EmbeddingModel2D()
    x = torch.randint(7, (2, 3))
    w = torch.randn(2, 3)
    embedding_matrix = torch.rand(10, 15)
    x2 = torch.randint(7, (3, 5))
    w2 = torch.randn(3, 5)
    embedding_matrix2 = torch.rand(12, 25)
    self.run_test(model, (embedding_matrix, x, w),
                  test_with_inputs=[(embedding_matrix2, x2, w2)],
                  input_names=["embedding_matrix", "x", "w"],
                  dynamic_axes={"embedding_matrix": [0, 1], "x": [0, 1], "w": [0, 1]})
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(8)
def test_meshgrid(self):
    # torch.meshgrid over three 1-D tensors of lengths 3, 4 and 5.
    class Meshgrid(torch.nn.Module):
        def forward(self, x, y, z):
            output1, output2, output3 = torch.meshgrid(x, y, z)
            return output1, output2, output3

    first = torch.randn(3, requires_grad=True)
    second = torch.zeros(4, requires_grad=True)
    third = torch.randn(5, requires_grad=True)
    model = Meshgrid()
    self.run_test(model, (first, second, third))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(8)
def test_meshgrid_scalar(self):
    # torch.meshgrid where the third argument is a 0-dim tensor.
    class Meshgrid(torch.nn.Module):
        def forward(self, x, y, z):
            output1, output2, output3 = torch.meshgrid(x, y, z)
            return output1, output2, output3

    first = torch.ones(3, requires_grad=True)
    second = torch.zeros(4, requires_grad=True)
    scalar = torch.tensor(2.0)
    model = Meshgrid()
    self.run_test(model, (first, second, scalar))
|
|
|
|
def test_baddbmm(self):
    # torch.baddbmm with alpha given as a constant 0-dim tensor and beta as a
    # Python float.
    class MyModule(torch.nn.Module):
        def forward(self, input, batch1, batch2):
            return torch.baddbmm(input, batch1, batch2, alpha=torch.tensor(5), beta=3.5)

    bias = torch.randn(10, 3, 5)
    lhs_batch = torch.randn(10, 3, 4)
    rhs_batch = torch.randn(10, 4, 5)
    self.run_test(MyModule(), (bias, lhs_batch, rhs_batch))
|
|
|
|
def test_baddbmm_dynamic(self):
    # torch.baddbmm where alpha and beta are themselves model inputs.
    class MyModule(torch.nn.Module):
        def forward(self, input, batch1, batch2, alpha, beta):
            return torch.baddbmm(input, batch1, batch2, alpha=alpha, beta=beta)

    bias = torch.randn(10, 3, 5)
    lhs_batch = torch.randn(10, 3, 4)
    rhs_batch = torch.randn(10, 4, 5)
    alpha_input = torch.tensor(5)
    beta_input = torch.tensor(3.5)
    self.run_test(MyModule(), (bias, lhs_batch, rhs_batch, alpha_input, beta_input))
|
|
|
|
def test_numel(self):
    # Tensor.numel() used as a multiplier, with all three axes dynamic and a
    # second input of a different shape.
    class MyModule(torch.nn.Module):
        def forward(self, input):
            return input.numel() * input

    primary = torch.randn(2, 3, 5)
    alternate = torch.randn(4, 5, 6)
    model = MyModule()
    self.run_test(
        model,
        (primary,),
        input_names=['x'],
        dynamic_axes={'x': [0, 1, 2]},
        test_with_inputs=[(alternate,)],
    )
|
|
|
|
def test_numel_empty(self):
    # Tensor.numel() used as a multiplier on an empty (size-0) input, with a
    # non-empty second input passed via test_with_inputs.
    class MyModule(torch.nn.Module):
        def forward(self, input):
            return input.numel() * input

    empty_input = torch.randn(0)
    alternate = torch.randn(4)
    model = MyModule()
    self.run_test(
        model,
        (empty_input,),
        input_names=['x'],
        dynamic_axes={'x': [0]},
        test_with_inputs=[(alternate,)],
    )
|
|
|
|
def test_dtype(self):
    # Scripted module that casts one input to the dtype of the other via
    # Tensor.to(dtype=other.dtype).
    class MyModel(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, input, other):
            return input.to(dtype=other.dtype) + other

    lhs = torch.randn(2, 3)
    rhs = torch.randn(2, 3)
    model = MyModel()
    self.run_test(model, (lhs, rhs))
|
|
|
|
def test_dtype_eq(self):
    # Scripted module that branches on a dtype equality comparison between
    # its two inputs.
    class MyModel(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, input, other):
            if input.dtype == other.dtype:
                return input + other
            return input

    lhs = torch.randn(2, 3)
    rhs = torch.randn(2, 3)
    model = MyModel()
    self.run_test(model, (lhs, rhs))
|
|
|
|
def test_cast_to(self):
    # Scripted Tensor.to(other_tensor): a float input is passed together
    # with an int64 tensor as the cast target.
    class MyModule(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, input, other):
            return input.to(other) + other

    float_input = torch.randn(2, 3, 4)
    long_target = torch.tensor([1], dtype=torch.int64)
    model = MyModule()
    self.run_test(model, (float_input, long_target))
|
|
|
|
def test_cast_to_bool(self):
    # Tensor.to(other_tensor) with a bool tensor as the cast target; the
    # converted input is then concatenated with that tensor along dim 0.
    class MyModule(torch.nn.Module):
        def forward(self, input, other):
            converted = input.to(other)
            return torch.cat((converted, other), 0)

    float_input = torch.randn(2, 3, 4)
    bool_target = torch.zeros([2, 3, 4], dtype=torch.bool)
    model = MyModule()
    self.run_test(model, (float_input, bool_target))
|
|
|
|
# ONNX supports bfloat16 for opsets >= 13
@skipIfUnsupportedMinOpsetVersion(13)
def test_cast_type_as_with_bfloat16(self):
    # float16 input -> type_as(bfloat16 tensor) -> cast back to float16.
    class BFloat16RoundTrip(torch.nn.Module):
        def forward(self, x):
            y = torch.ones((3, 4), dtype=torch.bfloat16)
            x = x.type_as(y)
            return x.to(dtype=torch.float16)

    half_input = torch.ones(3, 4, dtype=torch.float16)
    self.run_test(BFloat16RoundTrip(), half_input)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_type_as(self):
    # type_as against a float tensor, fed a bool, a double and an int64
    # input in turn through the same module instance.
    class TypeAsFloat(torch.nn.Module):
        def forward(self, x):
            y = torch.tensor([1.0])
            return x.type_as(y)

    samples = (
        torch.tensor([True, False], dtype=torch.bool),
        torch.randn(3, 4, dtype=torch.double),
        torch.ones((2, 2), dtype=torch.int64),
    )
    module = TypeAsFloat()
    for sample in samples:
        self.run_test(module, sample)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_ones_bool(self):
    # Builds a bool tensor of ones shaped like the input, converts the input
    # to it and ANDs the two.
    class AndWithBoolOnes(torch.nn.Module):
        def forward(self, input):
            true = torch.ones(input.shape, dtype=torch.bool)
            return input.to(true) & true

    sample = torch.randn(2, 3, 4)
    self.run_test(AndWithBoolOnes(), sample)
|
|
|
|
def test_log(self):
    # torch.log on a uniform [0, 1) sample.
    class Log(torch.nn.Module):
        def forward(self, input):
            return torch.log(input)

    sample = torch.rand(2, 3, 4)
    self.run_test(Log(), sample)
|
|
|
|
def test_log1p(self):
    # torch.log1p on a uniform [0, 1) sample.
    class Log1p(torch.nn.Module):
        def forward(self, input):
            return torch.log1p(input)

    sample = torch.rand(2, 3, 4)
    self.run_test(Log1p(), sample)
|
|
|
|
def test_log10(self):
    # torch.log10 on a uniform [0, 1) sample.
    class Log10(torch.nn.Module):
        def forward(self, input):
            return torch.log10(input)

    sample = torch.rand(2, 3, 4)
    self.run_test(Log10(), sample)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(11)
def test_round(self):
    # torch.round on fixed values, including the exact halves -1.5 and 3.5.
    class Round(torch.nn.Module):
        def forward(self, x):
            return torch.round(x)

    values = [0.9920, -1.0362, -1.5000, 3.5000]
    sample = torch.tensor(values, requires_grad=True)
    self.run_test(Round(), sample)
|
|
|
|
def test_constant_pad(self):
    # ConstantPad1d on a 3-d input, then ConstantPad2d on a 4-d input; both
    # pad with the value 3.5.
    cases = (
        (torch.nn.ConstantPad1d(2, 3.5), (2, 4, 4)),
        (torch.nn.ConstantPad2d((3, 0, 2, 1), 3.5), (2, 2, 4, 4)),
    )
    for pad_module, shape in cases:
        sample = torch.randn(*shape)
        self.run_test(pad_module, sample)
|
|
|
|
# Dynamic padding is added in opset 11
@skipIfUnsupportedMinOpsetVersion(11)
def test_pad_types(self):
    # Test for different pad integer types: the pad amounts are passed once
    # as plain Python ints and once as 0-d int64 tensors.
    class Pad(torch.nn.Module):
        def forward(self, x, pad: List[int]):
            return torch.nn.functional.pad(x, pad)

    x = torch.randn(2, 2, 4, 4)
    # The original bound each list to two names (`y = pad = ...`) and only
    # ever read `y`; a single loop variable replaces both.
    pad_variants = (
        [2, 4],
        [torch.tensor(2, dtype=torch.int64), torch.tensor(4, dtype=torch.int64)],
    )
    for pad in pad_variants:
        self.run_test(Pad(), (x, pad))
|
|
|
|
|
|
@skipIfUnsupportedMaxOpsetVersion(10)
@disableScriptTest()  # TODO: the logic in symbolic_opset9 doesn't handle script
def test_unsupported_pad(self):
    # Passing the pad sizes as a runtime argument is expected to raise a
    # RuntimeError whose message matches the pattern below.
    class Pad(torch.nn.Module):
        def forward(self, x, pad: List[int]):
            return torch.nn.functional.pad(x, pad)

    sample = torch.randn(2, 2, 4, 4)
    pad_sizes = [2, 4]
    expected_message = (
        "Unsupported: ONNX export of Pad.*"
        "The sizes of the padding must be constant"
    )
    with self.assertRaisesRegex(RuntimeError, expected_message):
        self.run_test(Pad(), (sample, pad_sizes))
|
|
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_if_fold(self):
    # Ten small modules whose `if` condition depends only on dim(), numel()
    # or dtype of the inputs. Every one is run on freshly built
    # torch.ones((3, 4), dtype=torch.int) input(s).
    def int_ones():
        return torch.ones((3, 4), dtype=torch.int)

    class DimEqualsTwo(torch.nn.Module):
        def forward(self, y):
            if y.dim() == 2:
                y = y + 4
                y = y + 2
            else:
                y = y - 1
            return y

    self.run_test(DimEqualsTwo(), int_ones())

    class NumelGreaterThanOne(torch.nn.Module):
        def forward(self, y):
            if y.numel() > 1:
                y = y + 4
            else:
                y = y + 2
            return y

    self.run_test(NumelGreaterThanOne(), int_ones())

    class DimNotThree(torch.nn.Module):
        def forward(self, y):
            if y.dim() != 3:
                y = y + 4
                y = y + 2
            else:
                return y
            return y

    self.run_test(DimNotThree(), int_ones())

    class DimAtLeastOne(torch.nn.Module):
        def forward(self, y):
            if y.dim() >= 1:
                y = y + 4
            else:
                y = y - 1
            return y

    self.run_test(DimAtLeastOne(), int_ones())

    class DimAtMostOne(torch.nn.Module):
        def forward(self, y):
            if y.dim() <= 1:
                y = y + 4
            else:
                y = y + 2
            return y

    self.run_test(DimAtMostOne(), int_ones())

    class DimBelowThreeAndInt(torch.nn.Module):
        def forward(self, y):
            if y.dim() < 3 and y.dtype == torch.int:
                y = y + 4
                y = y + 2
            else:
                return y
            return y

    self.run_test(DimBelowThreeAndInt(), int_ones())

    class DimEqualsThreeAndInt(torch.nn.Module):
        def forward(self, y):
            if y.dim() == 3 and y.dtype == torch.int:
                y = y + 4
                y = y + 2
            else:
                y = y + 1
            return y

    self.run_test(DimEqualsThreeAndInt(), int_ones())

    class NonEmptyAndDimTwo(torch.nn.Module):
        def forward(self, y):
            if y.numel() != 0 and y.dim() == 2:
                y = y + 4
                y = y + 2
            else:
                return y
            return y

    self.run_test(NonEmptyAndDimTwo(), int_ones())

    class NumelEqual(torch.nn.Module):
        def forward(self, x, y):
            if x.numel() == y.numel():
                y = x + y
            else:
                y = y - x
            return y

    first = int_ones()
    second = int_ones()
    self.run_test(NumelEqual(), (first, second))

    class NumelNotEqual(torch.nn.Module):
        def forward(self, x, y):
            if x.numel() != y.numel():
                y = x + y
            else:
                y = y - x
            return y

    first = int_ones()
    second = int_ones()
    self.run_test(NumelNotEqual(), (first, second))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(11)
def test_uninitialized(self):
    # Nested shape-dependent branches where only the inner `else` returns
    # early.
    class UninitializedModel(torch.nn.Module):
        def forward(self, y):
            if y.shape[1] < 5:
                if y.size(0) == 1:
                    y = y + 4
                else:
                    return y
            return y

    sample = torch.ones((3, 4), dtype=torch.int)
    self.run_test(UninitializedModel(), sample)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(11)
def test_uninitialized_dynamic(self):
    # Same module as test_uninitialized, with both axes of the input declared
    # dynamic and a (6, 7) input supplied through test_with_inputs.
    class UninitializedModel(torch.nn.Module):
        def forward(self, y):
            if y.shape[1] < 5:
                if y.size(0) == 1:
                    y = y + 4
                else:
                    return y
            return y

    export_input = torch.ones((3, 4), dtype=torch.int)
    extra_input = torch.ones((6, 7), dtype=torch.int)
    self.run_test(
        UninitializedModel(),
        export_input,
        test_with_inputs=[extra_input],
        input_names=["input_1"],
        dynamic_axes={"input_1": [0, 1]},
    )
|
|
|
|
# onnx::Identity of sequence supported for ONNX opset >= 14
@skipIfUnsupportedMinOpsetVersion(14)
def test_uninitialized_tensorList(self):
    # Scripted module returning a one-element tensor list from both the
    # early-return branch and the fall-through path.
    class UninitializedTensorListModel(torch.nn.Module):
        def forward(self, x):
            if x[0].shape[0] < 5:
                if x.size(0) == 1:
                    x = x + 4
                else:
                    return [x]
            return [x]

    sample = torch.ones((3, 4), dtype=torch.int)
    scripted = torch.jit.script(UninitializedTensorListModel())
    self.run_test(scripted, sample)
|
|
|
|
# onnx::Identity of sequence supported for ONNX opset >= 14
@skipIfUnsupportedMinOpsetVersion(14)
def test_uninitialized_tensorList_dynamic(self):
    # Scripted module returning list(x) from both paths; both input axes are
    # declared dynamic.
    class UninitializedTensorListModel(torch.nn.Module):
        def forward(self, x):
            if x[0].shape[0] < 5:
                if x.size(0) == 1:
                    x += x
                else:
                    return list(x)
            return list(x)

    sample = torch.ones((3, 4), dtype=torch.double)
    scripted = torch.jit.script(UninitializedTensorListModel())
    self.run_test(
        scripted,
        sample,
        input_names=["input_1"],
        dynamic_axes={"input_1": [0, 1]},
    )
|
|
|
|
# onnx::Identity of sequence supported for ONNX opset >= 14
@skipIfUnsupportedMinOpsetVersion(14)
def test_uninitialized_intList(self):
    # Scripted module returning a list of ints built from range(x.size(0)).
    class UninitializedListModel(torch.nn.Module):
        def forward(self, x):
            y = list(range(x.size(0)))
            if y[0] < 5:
                # if x.size(0) != 3, ORT will throw type error.
                if x.size(0) == 3:
                    y.append(10)
                else:
                    return y
            return y

    sample = torch.ones((3, 4), dtype=torch.int)
    scripted = torch.jit.script(UninitializedListModel())
    self.run_test(
        scripted,
        sample,
        input_names=["input_1"],
        dynamic_axes={"input_1": [0, 1]},
    )
|
|
|
|
# onnx::Identity of sequence supported for ONNX opset >= 14
@skipIfUnsupportedMinOpsetVersion(14)
def test_uninitialized_tensorList_shape(self):
    # The early-return branch yields list(x) plus x itself, while the
    # fall-through path yields [x, x]. A (4, 6) input is supplied through
    # test_with_inputs.
    class UninitializedModel(torch.nn.Module):
        def forward(self, x):
            if x.shape[1] < 5:
                if x.size(0) == 1:
                    x = x + 4
                else:
                    x_list = list(x)
                    x_list.append(x)
                    return x_list
            return [x, x]

    export_input = torch.ones((3, 4), dtype=torch.int)
    extra_input = torch.ones((4, 6), dtype=torch.int)
    scripted = torch.jit.script(UninitializedModel())
    self.run_test(
        scripted,
        export_input,
        test_with_inputs=[extra_input],
        input_names=["input_1"],
        dynamic_axes={"input_1": [0, 1]},
    )
|
|
|
|
# Sequence type as loop-carried dependencies only supported for ONNX opset >= 13
@skipIfUnsupportedMinOpsetVersion(13)
def test_sequance_loopcarried(self):
    # A list is grown inside a fixed 3-iteration loop, then stacked and
    # transposed.
    class SequanceLoopModel(torch.nn.Module):
        def forward(self, x):
            outputs = []
            for i in range(3):
                outputs += [x]
            return torch.stack(outputs).transpose(0, 1)

    sample = torch.ones((3, 4), dtype=torch.int)
    scripted = torch.jit.script(SequanceLoopModel())
    self.run_test(scripted, sample)
|
|
|
|
def test_reflection_pad(self):
    # ReflectionPad1d on a 3-d input, then ReflectionPad2d on a 4-d input.
    cases = (
        (torch.nn.ReflectionPad1d(2), (2, 4, 4)),
        (torch.nn.ReflectionPad2d((3, 0, 2, 1)), (2, 2, 4, 4)),
    )
    for pad_module, shape in cases:
        sample = torch.randn(*shape)
        self.run_test(pad_module, sample)
|
|
|
|
def test_replication_pad(self):
    # ReplicationPad1d on a 3-d input, then ReplicationPad2d on a 4-d input.
    cases = (
        (torch.nn.ReplicationPad1d(2), (2, 4, 4)),
        (torch.nn.ReplicationPad2d((3, 0, 2, 1)), (2, 2, 4, 4)),
    )
    for pad_module, shape in cases:
        sample = torch.randn(*shape)
        self.run_test(pad_module, sample)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(11)
def test_im2col(self):
    # Three unfold calls with different kernel/dilation/padding settings,
    # returned together as a tuple.
    class Unfold(torch.nn.Module):
        def forward(self, input):
            return (
                torch.nn.functional.unfold(input, kernel_size=(10, 15), dilation=2, padding=5, stride=3),
                torch.nn.functional.unfold(input, kernel_size=(2, 2), dilation=1, padding=0, stride=3),
                torch.nn.functional.unfold(input, kernel_size=(1, 1), dilation=5, padding=2, stride=3),
            )

    image = torch.rand(1, 1, 200, 100)
    self.run_test(Unfold(), image)
|
|
|
|
@skipIfNoLapack
@skipIfUnsupportedMinOpsetVersion(11)
def test_det(self):
    # torch.linalg.det over a (2, 3) batch of 5x5 matrices.
    class Det(torch.nn.Module):
        def forward(self, x):
            return torch.linalg.det(x)

    matrices = torch.randn(2, 3, 5, 5)
    self.run_test(Det(), matrices)
|
|
|
|
def test_linalg_norm(self):
    # torch.linalg.norm with a single dim, with a pair of dims, with no
    # dim/ord at all, and with only `ord` on 1-d and 2-d inputs.
    inf = float('inf')
    vector_ords = (None, 2, inf, -inf, -4, 1.5)
    matrix_ords = ('fro', inf, -inf, 1, -1)

    class LinalgSingleDimModel(torch.nn.Module):
        def __init__(self, ord_val):
            super().__init__()
            self.ord = ord_val

        def forward(self, x):
            return torch.linalg.norm(x, ord=self.ord, dim=1)

    x = torch.randn(2, 3, 5, 5)
    for ord_val in vector_ords:
        self.run_test(LinalgSingleDimModel(ord_val), x)

    class LinalgMultiDimModel(torch.nn.Module):
        def __init__(self, ord_val):
            super().__init__()
            self.ord = ord_val

        def forward(self, x):
            return torch.linalg.norm(x, ord=self.ord, dim=(0, 2))

    x = torch.randn(2, 3, 5, 5)
    for ord_val in matrix_ords:
        self.run_test(LinalgMultiDimModel(ord_val), x)

    class LinalgNoDimNoOrdModel(torch.nn.Module):
        def forward(self, x):
            return torch.linalg.norm(x)

    # Each input is drawn immediately before its own run, in 4-d, 2-d, 1-d
    # order.
    for shape in ((2, 3, 5, 5), (2, 3), (2,)):
        sample = torch.randn(*shape)
        self.run_test(LinalgNoDimNoOrdModel(), sample)

    class LinalgNoDim1DModel(torch.nn.Module):
        def __init__(self, ord_val):
            super().__init__()
            self.ord = ord_val

        def forward(self, x):
            return torch.linalg.norm(x, ord=self.ord)

    x = torch.randn(2)
    for ord_val in vector_ords:
        self.run_test(LinalgNoDim1DModel(ord_val), x)

    class LinalgNoDim2DModel(torch.nn.Module):
        def __init__(self, ord_val):
            super().__init__()
            self.ord = ord_val

        def forward(self, x):
            return torch.linalg.norm(x, ord=self.ord)

    x = torch.randn(2, 3)
    for ord_val in matrix_ords:
        self.run_test(LinalgNoDim2DModel(ord_val), x)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(11)
def test_linalg_vector_norm_zero(self):
    # torch.linalg.vector_norm with ord=0 and no dim.
    class LinalgVectorNormModel(torch.nn.Module):
        def __init__(self, ord_val):
            super().__init__()
            self.ord = ord_val

        def forward(self, x):
            return torch.linalg.vector_norm(x, ord=self.ord)

    sample = torch.randn(2, 3, 5, 5)
    self.run_test(LinalgVectorNormModel(0), sample)
|
|
|
|
def test_linalg_vector_norm(self):
    # Every combination of the ord values and (dim, keepdim) pairs below,
    # with ord varying slowest.
    class LinalgVectorNormModel(torch.nn.Module):
        def __init__(self, ord_val, dim_info):
            super().__init__()
            self.ord = ord_val
            self.dim, self.keepdim = dim_info

        def forward(self, x):
            return torch.linalg.vector_norm(x, ord=self.ord, dim=self.dim, keepdim=self.keepdim)

    sample = torch.randn(2, 3, 5, 5)
    ord_options = [2, float('inf'), -float('inf'), -4, 1.5]
    dim_options = [(None, False), (1, False), ((1, 2), False), ((1, 2), True)]
    for ord_val, dim_info in itertools.product(ord_options, dim_options):
        self.run_test(LinalgVectorNormModel(ord_val, dim_info), sample)
|
|
|
|
def test_linalg_matrix_norm(self):
    # For each ord: default dims, then dims (0, 2), then dims (0, 2) with
    # keepdim=True.
    class LinalgMatrixNormModel(torch.nn.Module):
        def __init__(self, ord_val, dim_val=(-2, -1), keepdim_val=False):
            super().__init__()
            self.ord = ord_val
            self.dim = dim_val
            self.keepdim = keepdim_val

        def forward(self, x):
            return torch.linalg.matrix_norm(x, ord=self.ord, dim=self.dim, keepdim=self.keepdim)

    sample = torch.randn(2, 3, 5, 5)
    ord_options = ['fro', float('inf'), -float('inf'), 1, -1]
    extra_arg_sets = ((), ((0, 2),), ((0, 2), True))
    for ord_val in ord_options:
        for extra_args in extra_arg_sets:
            self.run_test(LinalgMatrixNormModel(ord_val, *extra_args), sample)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_linalg_cross(self):
    # torch.linalg.cross with an explicit dim and with the default dim; the
    # second operand has size-1 axes where the first has sizes 5 and 2.
    class Cross(torch.nn.Module):
        def forward(self, x, y):
            return (
                torch.linalg.cross(x, y, dim=1),
                torch.linalg.cross(x, y),
            )

    lhs = torch.randn(5, 3, 2, 3)
    rhs = torch.randn(1, 3, 1, 3)
    self.run_test(Cross(), input=(lhs, rhs))
|
|
|
|
# This test checks output scalar type in the ONNX graph should not be null
# https://github.com/pytorch/pytorch/issues/28607
@skipIfUnsupportedMinOpsetVersion(10)
def test_trace_script(self):
    # A scripted slicing helper is called from an ordinary module's forward.
    @torch.jit.script
    def center_slice_helper(input, h_offset):
        return input[:, h_offset:]

    class CenterCrop(torch.nn.Module):
        def forward(self, input):
            return center_slice_helper(input, torch.tensor(input.shape[1] - 1))

    sample = torch.randn(3, 4)
    self.run_test(CenterCrop(), sample)
|
|
|
|
@skipIfNoLapack
@skipIfUnsupportedMinOpsetVersion(11)
def test_logdet(self):
    # torch.logdet over a (2, 3) batch of 5x5 matrices.
    class LogDet(torch.nn.Module):
        def forward(self, x):
            return torch.logdet(x)

    matrices = torch.randn(2, 3, 5, 5)
    self.run_test(LogDet(), matrices)
|
|
|
|
def test_dim(self):
    # Scripted module that multiplies in place by the tensor's dim(); run on
    # a zero-element input and on a 3-d input.
    class DimModel(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, input):
            out = input * 2
            out *= out.dim()
            return out

    samples = (
        torch.randn(0, requires_grad=True),
        torch.randn(1, 2, 3, requires_grad=True),
    )
    for sample in samples:
        self.run_test(DimModel(), sample)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(11)
def test_dim_1(self):
    # Scripted module that iterates over a zeros tensor sized from the
    # input's first dimension, clipping each slice with
    # torchvision.ops.clip_boxes_to_image and collecting the results.
    class M(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, poses):
            boxes = torch.zeros([poses.shape[0], 2, 4])
            batch_boxes = []
            for kp_boxes in boxes:
                kp_boxes = torchvision.ops.clip_boxes_to_image(kp_boxes, (2, 3))
                batch_boxes.append(kp_boxes)
            return batch_boxes

    poses = torch.rand(2, 2, 3)
    self.run_test(
        M(),
        (poses,),
        input_names=['x'],
        dynamic_axes={"x": [0]},
    )
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(12)
def test_outer(self):
    # torch.outer on four operand pairs: default arange dtypes, float32 x
    # long, float32 x float64, and int32 x long.
    class Outer(torch.nn.Module):
        def forward(self, x, y):
            return torch.outer(x, y)

    operand_pairs = (
        (torch.arange(1, 5), torch.arange(1, 4)),
        (torch.arange(1, 6).to(dtype=torch.float32), torch.arange(1, 4).to(dtype=torch.long)),
        (torch.arange(2, 5).to(dtype=torch.float32), torch.arange(2, 4).to(dtype=torch.float64)),
        (torch.arange(3, 6).to(dtype=torch.int32), torch.arange(4, 7).to(dtype=torch.long)),
    )
    for lhs, rhs in operand_pairs:
        self.run_test(Outer(), input=(lhs, rhs))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(12)
def test_einsum(self):
    # Four einsum equations: batch diagonal, batch matmul, inner product and
    # transpose. The single-operand equations are also run on bool inputs.
    class EinsumModelBatchDiagonal(torch.nn.Module):
        def forward(self, x):
            eqn = "...ii ->...i"
            return torch.einsum(eqn, x)

    diagonal_inputs = [torch.randn(3, 5, 5), torch.randn(3, 5, 5).to(dtype=torch.bool)]
    for operand in diagonal_inputs:
        self.run_test(EinsumModelBatchDiagonal(), input=(operand,))

    class EinsumModelBatchMatmul(torch.nn.Module):
        def forward(self, x, y):
            eqn = "bij, bjk -> bik"
            return torch.einsum(eqn, x, y)

    lhs = torch.randn(5, 2, 3)
    rhs = torch.randn(5, 3, 4)
    self.run_test(EinsumModelBatchMatmul(), input=(lhs, rhs))

    class EinsumModelInnerProd(torch.nn.Module):
        def forward(self, x, y):
            eqn = "i,i"
            return torch.einsum(eqn, x, y)

    lhs = torch.randn(5)
    rhs = torch.randn(5)
    self.run_test(EinsumModelInnerProd(), input=(lhs, rhs))

    class EinsumModelTranspose(torch.nn.Module):
        def forward(self, x):
            eqn = "ij->ji"
            return torch.einsum(eqn, x)

    transpose_inputs = [torch.randn(3, 4), torch.randn(3, 4).to(dtype=torch.bool)]
    for operand in transpose_inputs:
        self.run_test(EinsumModelTranspose(), input=(operand,))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_cosine_similarity(self):
    # CosineSimilarity along dim 2 of two (5, 3, 2) inputs.
    first = torch.randn(5, 3, 2)
    second = torch.randn(5, 3, 2)
    similarity = torch.nn.CosineSimilarity(dim=2)
    self.run_test(similarity, input=(first, second))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_pairwise_distance(self):
    # PairwiseDistance with p=2.0 on two (5, 3, 2) inputs.
    first = torch.randn(5, 3, 2)
    second = torch.randn(5, 3, 2)
    distance = torch.nn.PairwiseDistance(p=2.0)
    self.run_test(distance, input=(first, second))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_cross(self):
    # torch.cross with dim=3 and with the dim argument omitted.
    class Cross(torch.nn.Module):
        def forward(self, x, y):
            return (
                torch.cross(x, y, dim=3),
                torch.cross(x, y),
            )

    lhs = torch.randn(5, 3, 2, 3)
    rhs = torch.randn(5, 3, 2, 3)
    self.run_test(Cross(), input=(lhs, rhs))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_cdist(self):
    # torch.cdist between a (5, 3, 3) and a (5, 2, 3) input.
    class Cdist(torch.nn.Module):
        def forward(self, x, y):
            return torch.cdist(x, y)

    lhs = torch.randn(5, 3, 3)
    rhs = torch.randn(5, 2, 3)
    self.run_test(Cdist(), input=(lhs, rhs))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(12)
def test_crossentropyloss(self):
    # For ignore_index in (-100, 1): build inputs of shape (3, 5),
    # (3, 5, 2) and (3, 5, 2, 7) with targets drawn from 5 classes, replace
    # every target equal to 1 by ignore_index, and hand the pair to
    # _crossentropyloss.
    trailing_dims = ((), (2,), (2, 7))
    for ignore_index in [-100, 1]:
        for extra in trailing_dims:
            x = torch.randn(3, 5, *extra)
            y = torch.empty(3, *extra, dtype=torch.long).random_(5)
            y[y == 1] = ignore_index
            self._crossentropyloss(x, y, ignore_index)
|
|
|
|
def _crossentropyloss(self, x, y, ignore_index):
    # Runs CrossEntropyLoss on (x, y) for reduction "none", "sum" and the
    # default, each without and then with a random 5-element class weight.
    # ignore_index is passed to the loss only when it differs from -100.
    class CrossEntropyLossModel(torch.nn.Module):
        def __init__(self, ignore_index, loss_kwargs, weighted):
            super().__init__()
            options = dict(loss_kwargs)
            if weighted:
                options["weight"] = torch.randn(5)
            if ignore_index != -100:
                options["ignore_index"] = ignore_index
            self.loss = torch.nn.CrossEntropyLoss(**options)

        def forward(self, input, target):
            return self.loss(input, target)

    # (keyword arguments, use weight) in the order the variants are run.
    variants = (
        ({"reduction": "none"}, False),
        ({"reduction": "none"}, True),
        ({"reduction": "sum"}, False),
        ({"reduction": "sum"}, True),
        ({}, False),
        ({}, True),
    )
    for loss_kwargs, weighted in variants:
        model = CrossEntropyLossModel(ignore_index, loss_kwargs, weighted)
        self.run_test(model, input=(x, y))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_kldiv_loss(self):
    # Feeds _kldiv_loss with random input/target pairs of 1, 3 and 4
    # dimensions.
    for shape in ((5,), (2, 3, 5), (2, 3, 5, 7)):
        x = torch.randn(*shape)
        y = torch.randn(*shape)
        self._kldiv_loss(x, y)
|
|
|
|
def _kldiv_loss(self, x, y):
    # Runs KLDivLoss on (x, y) once per keyword-argument set listed below,
    # in order.
    class KLDivLossModel(torch.nn.Module):
        def __init__(self, loss_kwargs):
            super().__init__()
            self.loss = torch.nn.KLDivLoss(**loss_kwargs)

        def forward(self, input, target):
            return self.loss(input, target)

    variants = (
        {"reduction": "none", "log_target": True},
        {"reduction": "mean", "log_target": False},
        {"reduction": "sum", "log_target": True},
        {"reduction": "batchmean", "log_target": False},
        {"reduction": "batchmean", "size_average": False, "log_target": True},
    )
    for loss_kwargs in variants:
        self.run_test(KLDivLossModel(loss_kwargs), input=(x, y))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(12)
def test_nllloss(self):
    # NLLLoss(reduction="none") over LogSoftmax(dim=1) of the doubled input.
    class NLLModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.loss = torch.nn.NLLLoss(reduction="none")
            self.m = torch.nn.LogSoftmax(dim=1)

        def forward(self, input, target):
            return self.loss(self.m(2 * input), target)

    batch = 5
    num_classes = 4
    data = torch.randn(batch, 16)
    target = torch.empty(batch, dtype=torch.long).random_(0, num_classes)

    # using test data containing default ignore_index=-100
    target[target == 1] = -100
    self.run_test(NLLModel(), (data, target))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(12)
def test_nllloss_2d_none(self):
    # Conv2d -> LogSoftmax(dim=1) -> NLLLoss(reduction="none").
    batch = 5
    num_classes = 4

    class NLLModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.loss = torch.nn.NLLLoss(reduction="none")
            self.conv = torch.nn.Conv2d(16, num_classes, (3, 3))
            self.m = torch.nn.LogSoftmax(dim=1)

        def forward(self, input, target):
            return self.loss(self.m(self.conv(input)), target)

    data = torch.randn(batch, 16, 10, 10)
    target = torch.empty(batch, 8, 8, dtype=torch.long).random_(0, num_classes)

    # using test data containing default ignore_index=-100
    target[target == 1] = -100
    self.run_test(NLLModel(), (data, target))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(12)
def test_nllloss_2d_mean(self):
    # Conv2d -> LogSoftmax(dim=1) -> NLLLoss(reduction="mean").
    batch = 5
    num_classes = 4

    class NLLModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.loss = torch.nn.NLLLoss(reduction="mean")
            self.conv = torch.nn.Conv2d(16, num_classes, (3, 3))
            self.m = torch.nn.LogSoftmax(dim=1)

        def forward(self, input, target):
            return self.loss(self.m(self.conv(input)), target)

    data = torch.randn(batch, 16, 10, 10)
    target = torch.empty(batch, 8, 8, dtype=torch.long).random_(0, num_classes)

    # using test data containing default ignore_index=-100
    target[target == 1] = -100
    self.run_test(NLLModel(), (data, target))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(12)
def test_nllloss_2d_sum(self):
    # Conv2d -> LogSoftmax(dim=1) -> NLLLoss(reduction="sum").
    batch = 5
    num_classes = 4

    class NLLModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.loss = torch.nn.NLLLoss(reduction="sum")
            self.conv = torch.nn.Conv2d(16, num_classes, (3, 3))
            self.m = torch.nn.LogSoftmax(dim=1)

        def forward(self, input, target):
            return self.loss(self.m(self.conv(input)), target)

    data = torch.randn(batch, 16, 10, 10)
    target = torch.empty(batch, 8, 8, dtype=torch.long).random_(0, num_classes)

    # using test data containing default ignore_index=-100
    target[target == 1] = -100
    self.run_test(NLLModel(), (data, target))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(12)
def test_nllloss_2d_mean_weights(self):
    # Conv2d -> LogSoftmax(dim=1) -> NLLLoss(reduction="mean") with a random
    # per-class weight.
    batch = 5
    num_classes = 4

    class NLLModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.loss = torch.nn.NLLLoss(reduction="mean", weight=torch.randn(num_classes))
            self.conv = torch.nn.Conv2d(16, num_classes, (3, 3))
            self.m = torch.nn.LogSoftmax(dim=1)

        def forward(self, input, target):
            return self.loss(self.m(self.conv(input)), target)

    data = torch.randn(batch, 16, 10, 10)
    target = torch.empty(batch, 8, 8, dtype=torch.long).random_(0, num_classes)

    # using test data containing default ignore_index=-100
    target[target == 1] = -100
    self.run_test(NLLModel(), (data, target))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(12)
def test_nllloss_2d_mean_ignore_index(self):
    # Conv2d -> LogSoftmax(dim=1) -> NLLLoss(reduction="mean",
    # ignore_index=1); targets are left unmodified.
    batch = 5
    num_classes = 4

    class NLLModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.loss = torch.nn.NLLLoss(reduction="mean", ignore_index=1)
            self.conv = torch.nn.Conv2d(16, num_classes, (3, 3))
            self.m = torch.nn.LogSoftmax(dim=1)

        def forward(self, input, target):
            return self.loss(self.m(self.conv(input)), target)

    data = torch.randn(batch, 16, 10, 10)
    target = torch.empty(batch, 8, 8, dtype=torch.long).random_(0, num_classes)
    self.run_test(NLLModel(), (data, target))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(12)
def test_nllloss_dynamic_ignore_index(self):
    # nll_loss where ignore_index is read from a tensor size inside forward
    # rather than given as a literal.
    #
    # The original also defined two helpers that were never called, computed
    # an unused local `n` in forward, and re-imported F locally although the
    # file already imports it at module level. All of that has been removed.
    class LabelSmoothingCrossEntropy(torch.nn.Module):
        def __init__(self, epsilon: float = 0.1, reduction='mean'):
            super().__init__()
            self.epsilon = epsilon
            self.reduction = reduction

        def forward(self, preds, target, start_position):
            log_preds = F.log_softmax(preds, dim=-1)
            # ignore_index comes from the second dimension of start_position.
            ignore_index = start_position.size(1)
            nll = F.nll_loss(log_preds, target, reduction=self.reduction, ignore_index=ignore_index)
            return nll + start_position.float()

    N = 5
    preds = torch.randn(N, 16)
    target = torch.randint(5, (N,))
    start_position = torch.randint(10, (N, N))
    self.run_test(LabelSmoothingCrossEntropy(), (preds, target, start_position))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(12)
def test_nllloss_2d_mean_ignore_index_weights(self):
    # Conv2d -> LogSoftmax(dim=1) -> NLLLoss(reduction="mean") with both a
    # random per-class weight and ignore_index=1.
    batch = 5
    num_classes = 4

    class NLLModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.loss = torch.nn.NLLLoss(reduction="mean", weight=torch.randn(num_classes), ignore_index=1)
            self.conv = torch.nn.Conv2d(16, num_classes, (3, 3))
            self.m = torch.nn.LogSoftmax(dim=1)

        def forward(self, input, target):
            return self.loss(self.m(self.conv(input)), target)

    data = torch.randn(batch, 16, 10, 10)
    target = torch.empty(batch, 8, 8, dtype=torch.long).random_(0, num_classes)
    self.run_test(NLLModel(), (data, target))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(12)
def test_binary_cross_entropy_with_logits(self):
    # Drives the four _bce_logits* helpers: plain, with weight, with
    # pos_weight, and with both. Targets are drawn with random_(2).
    logits = torch.randn(5)
    labels = torch.empty(5).random_(2)
    self._bce_logits(logits, labels)

    logits = torch.randn(3, 4)
    labels = torch.empty(3, 4).random_(2)
    loss_weight = torch.tensor([3])
    self._bce_logits_wegiht(logits, labels, loss_weight)

    logits = torch.randn(3, 2, 4)
    labels = torch.empty(3, 2, 4).random_(2)
    positive_weight = torch.empty([2, 4]).random_(2)
    self._bce_logits_posweight(logits, labels, positive_weight)

    logits = torch.randn(3, 3, 4)
    labels = torch.empty(3, 3, 4).random_(2)
    loss_weight = torch.tensor([3])
    positive_weight = torch.empty([3, 4]).random_(2)
    self._bce_logits_loss_weight_posweight(logits, labels, loss_weight, positive_weight)
|
|
|
|
def _bce_logits(self, x, y):
|
|
class BCEWithLogitsLossNone(torch.nn.Module):
|
|
def forward(self, input, target):
|
|
return torch.nn.functional.binary_cross_entropy_with_logits(input, target, reduction="none")
|
|
|
|
self.run_test(BCEWithLogitsLossNone(), input=(x, y))
|
|
|
|
class BCEWithLogitsLossMean(torch.nn.Module):
|
|
def forward(self, input, target):
|
|
return torch.nn.functional.binary_cross_entropy_with_logits(input, target, reduction="mean")
|
|
|
|
self.run_test(BCEWithLogitsLossMean(), input=(x, y))
|
|
|
|
class BCEWithLogitsLossSum(torch.nn.Module):
|
|
def forward(self, input, target):
|
|
return torch.nn.functional.binary_cross_entropy_with_logits(input, target, reduction="sum")
|
|
|
|
self.run_test(BCEWithLogitsLossSum(), input=(x, y))
|
|
|
|
def _bce_logits_wegiht(self, x, y, weight):
|
|
class BCEWithLogitsLossWegihtNone(torch.nn.Module):
|
|
def forward(self, input, target, weight):
|
|
return torch.nn.functional.binary_cross_entropy_with_logits(input, target, weight=weight, reduction="none")
|
|
self.run_test(BCEWithLogitsLossWegihtNone(), input=(x, y, weight))
|
|
|
|
class BCEWithLogitsLossWegihtMean(torch.nn.Module):
|
|
def forward(self, input, target, weight):
|
|
return torch.nn.functional.binary_cross_entropy_with_logits(input, target, weight=weight, reduction="mean")
|
|
|
|
self.run_test(BCEWithLogitsLossWegihtMean(), input=(x, y, weight))
|
|
|
|
class BCEWithLogitsLossWegihtSum(torch.nn.Module):
|
|
def forward(self, input, target, weight):
|
|
return torch.nn.functional.binary_cross_entropy_with_logits(input, target, weight=weight, reduction="sum")
|
|
|
|
self.run_test(BCEWithLogitsLossWegihtSum(), input=(x, y, weight))
|
|
|
|
def _bce_logits_posweight(self, x, y, pos_weight):
|
|
class BCEWithLogitsLossPosWegihtNone(torch.nn.Module):
|
|
def forward(self, input, target, pos_weight):
|
|
return torch.nn.functional.binary_cross_entropy_with_logits(input, target, pos_weight=pos_weight, reduction="none")
|
|
self.run_test(BCEWithLogitsLossPosWegihtNone(), input=(x, y, pos_weight))
|
|
|
|
class BCEWithLogitsLossPosWegihtMean(torch.nn.Module):
|
|
def forward(self, input, target, pos_weight):
|
|
return torch.nn.functional.binary_cross_entropy_with_logits(input, target, pos_weight=pos_weight, reduction="mean")
|
|
|
|
self.run_test(BCEWithLogitsLossPosWegihtMean(), input=(x, y, pos_weight))
|
|
|
|
class BCEWithLogitsLossPosWegihtSum(torch.nn.Module):
|
|
def forward(self, input, target, pos_weight):
|
|
return torch.nn.functional.binary_cross_entropy_with_logits(input, target, pos_weight=pos_weight, reduction="sum")
|
|
|
|
self.run_test(BCEWithLogitsLossPosWegihtSum(), input=(x, y, pos_weight))
|
|
|
|
def _bce_logits_loss_weight_posweight(self, x, y, weight, pos_weight):
|
|
class BCEWithLogitsLossWeightPosweightNone(torch.nn.Module):
|
|
def forward(self, input, target, weight, pos_weight):
|
|
return torch.nn.functional.binary_cross_entropy_with_logits(input, target, weight=weight,
|
|
pos_weight=pos_weight, reduction="none")
|
|
|
|
self.run_test(BCEWithLogitsLossWeightPosweightNone(), input=(x, y, weight, pos_weight))
|
|
|
|
class BCEWithLogitsLossWeightPosweightMean(torch.nn.Module):
|
|
def forward(self, input, target, weight, pos_weight):
|
|
return torch.nn.functional.binary_cross_entropy_with_logits(input, target, weight=weight,
|
|
pos_weight=pos_weight, reduction="mean")
|
|
|
|
self.run_test(BCEWithLogitsLossWeightPosweightMean(), input=(x, y, weight, pos_weight))
|
|
|
|
class BCEWithLogitsLossWeightPosweightSum(torch.nn.Module):
|
|
def forward(self, input, target, weight, pos_weight):
|
|
return torch.nn.functional.binary_cross_entropy_with_logits(input, target, weight=weight,
|
|
pos_weight=pos_weight, reduction="sum")
|
|
|
|
self.run_test(BCEWithLogitsLossWeightPosweightSum(), input=(x, y, weight, pos_weight))
|
|
|
|
|
|
def test_torch_mm(self):
|
|
class M(torch.nn.Module):
|
|
def forward(self, mat1, mat2):
|
|
mm = torch.mm(mat1, mat2)
|
|
return mm
|
|
|
|
mat1 = torch.randn(2, 3)
|
|
mat2 = torch.randn(3, 3)
|
|
self.run_test(M(), input=(mat1, mat2))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)  # Because where op is not supported for opset < 9.
def test_where_with_bool_tensor(self):
    """Export ``torch.where`` whose condition is a bool tensor computed
    inside the model (``mat1 > 0``)."""

    class M(torch.nn.Module):
        def forward(self, mat1, mat2):
            return torch.where(mat1 > 0, mat1, mat2)

    values = torch.randn(2, 3)
    fallback = torch.ones(2, 3)
    self.run_test(M(), input=(values, fallback))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)  # Because where op is not supported for opset < 9.
def test_where_with_byte_tensor(self):
    """Export ``torch.where`` with a ``uint8`` condition tensor passed as
    a model input."""

    class M(torch.nn.Module):
        def forward(self, cond, mat1, mat2):
            return torch.where(cond, mat1, mat2)

    # All-ones mask with one zero entry so both branches are selected.
    mask = torch.ones(2, 3, dtype=torch.uint8)
    mask[1, 2] = 0
    values = torch.randn(2, 3)
    fallback = torch.ones(2, 3)
    self.run_test(M(), input=(mask, values, fallback))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(10)  # ONNX IsInf op is added in opset 10.
def test_isinf(self):
    """Export ``Tensor.isinf`` on data containing finite, inf and nan values."""

    class M(torch.nn.Module):
        def forward(self, x):
            return x.isinf()

    inf, nan = float("inf"), float("nan")
    data = torch.tensor([[1, 2, inf], [2, nan, inf]])
    self.run_test(M(), (data, ))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(10)
def test_isfinite(self):
    """Export ``Tensor.isfinite`` on data containing finite, +inf, -inf
    and nan values."""

    class M(torch.nn.Module):
        def forward(self, x):
            return x.isfinite()

    inf, nan = float("inf"), float("nan")
    data = torch.tensor([[1, 2, inf], [2, nan, -inf]])
    self.run_test(M(), (data, ))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)  # ONNX IsNaN op is added in opset 9.
def test_isnan(self):
    """Export ``Tensor.isnan`` on data containing finite, inf and nan values."""

    class M(torch.nn.Module):
        def forward(self, x):
            return x.isnan()

    inf, nan = float("inf"), float("nan")
    data = torch.tensor([[1, 2, inf], [2, nan, inf]])
    self.run_test(M(), (data, ))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(10)  # ONNX IsNaN, IsInf op is added in opset 9, 10 respectively.
def test_nan_to_num(self):
    """Export ``Tensor.nan_to_num`` with default and with explicit
    replacement values, on float, int and half inputs."""

    class NoParams(torch.nn.Module):
        def forward(self, x):
            return x.nan_to_num()

    class WithParams(torch.nn.Module):
        def forward(self, x):
            return x.nan_to_num(nan=2.3, posinf=4.5, neginf=6.7)

    inf, nan = float("inf"), float("nan")

    special_values = torch.tensor([[1, 2, inf], [2, nan, -inf]])
    int_ones = torch.ones((2, 4), dtype=torch.int)
    half_ones = torch.ones((2, 4), dtype=torch.half)
    for sample in (special_values, int_ones, half_ones):
        self.run_test(NoParams(), (sample, ))

    special_values = torch.tensor([[1, 2, inf], [2, nan, -inf]])
    self.run_test(WithParams(), (special_values, ))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_maximum_minimum(self):
    """Export ``torch.maximum`` / ``torch.minimum`` on broadcast operands,
    one of which contains a nan."""

    class ModelWithNan(torch.nn.Module):
        def forward(self, x, y):
            largest = torch.maximum(x, y)
            smallest = torch.minimum(x, y)
            return largest, smallest

    with_nan = torch.tensor([-2, -2, float("nan")])
    other = torch.rand(1, 3)
    self.run_test(ModelWithNan(), (with_nan, other))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(12)
def test_minimum_dtypes(self):
    """Export ``torch.minimum`` for several mixed-dtype operand pairs."""

    class MinimumModel(torch.nn.Module):
        def forward(self, x, y):
            return torch.minimum(x, y)

    shape = (5, 5)

    # float16 vs float32
    lhs = torch.randn(shape, dtype=torch.float16)
    rhs = torch.randn(shape, dtype=torch.float)
    self.run_test(MinimumModel(), (lhs, rhs))

    # float16 vs int16
    lhs = torch.randn(shape, dtype=torch.float16)
    rhs = torch.randint(10, shape, dtype=torch.int16)
    self.run_test(MinimumModel(), (lhs, rhs))

    # int16 vs int32
    lhs = torch.randint(10, shape, dtype=torch.int16)
    rhs = torch.randint(10, shape, dtype=torch.int32)
    self.run_test(MinimumModel(), (lhs, rhs))

    # int32 vs a same-dtype tensor filled from the Python bool True
    lhs = torch.randint(10, shape, dtype=torch.int)
    rhs = torch.full_like(lhs, True)
    self.run_test(MinimumModel(), (lhs, rhs))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_any(self):
    """Export ``Tensor.any`` in its full-reduction, ``dim`` and
    ``dim`` + ``keepdim`` forms."""

    class M(torch.nn.Module):
        def forward(self, x):
            return x.any()

    class MDim(torch.nn.Module):
        def forward(self, x):
            return x.any(dim=1)

    class MKeepdim(torch.nn.Module):
        def forward(self, x):
            return x.any(dim=1, keepdim=True)

    flags = torch.tensor([[True, False], [False, False]])
    self.run_test(M(), (flags, ))

    # NOTE(review): rand() samples are cast with .bool(), so any non-zero
    # sample becomes True; these inputs are presumably (almost) all True.
    flags = torch.rand(3, 4).bool()
    self.run_test(MDim(), (flags, ))

    flags = torch.rand(3, 4).bool()
    self.run_test(MKeepdim(), (flags, ))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_all(self):
    """Export ``Tensor.all`` in its full-reduction, ``dim`` and
    ``dim`` + ``keepdim`` forms."""

    class M(torch.nn.Module):
        def forward(self, x):
            return x.all()

    class MDim(torch.nn.Module):
        def forward(self, x):
            return x.all(dim=1)

    class MKeepdim(torch.nn.Module):
        def forward(self, x):
            return x.all(dim=1, keepdim=True)

    flags = torch.tensor([[True, False], [False, False]])
    self.run_test(M(), (flags, ))

    # NOTE(review): rand() samples are cast with .bool(), so any non-zero
    # sample becomes True; these inputs are presumably (almost) all True.
    flags = torch.rand(3, 4).bool()
    self.run_test(MDim(), (flags, ))

    flags = torch.rand(3, 4).bool()
    self.run_test(MKeepdim(), (flags, ))
|
|
|
|
def test_dropout(self):
|
|
class M(torch.nn.Module):
|
|
def __init__(self):
|
|
super(M, self).__init__()
|
|
self.dropout = torch.nn.Dropout(0.3)
|
|
|
|
def forward(self, x):
|
|
dropout = self.dropout(x)
|
|
return dropout
|
|
|
|
x = torch.randn(10, 3, 53)
|
|
self.run_test(M(), (x))
|
|
|
|
def test_shape_constant_fold(self):
|
|
class ShapeModule(torch.nn.Module):
|
|
def __init__(self):
|
|
super(ShapeModule, self).__init__()
|
|
self.register_buffer("weight", torch.ones(5))
|
|
|
|
def forward(self, x):
|
|
shape = self.weight.shape[0]
|
|
return x + shape
|
|
|
|
x = torch.randn(2, 5)
|
|
self.run_test(ShapeModule(), (x,), rtol=1e-3, atol=1e-5)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(12)
def test_celu(self):
    """Export ``torch.nn.CELU`` with an explicit ``alpha=1.0``."""

    class Celu(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.celu = torch.nn.CELU(alpha=1.0)

        def forward(self, input):
            activated = self.celu(input)
            return activated

    sample = torch.randn(2)
    self.run_test(Celu(), (sample,))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(12)
def test_celu_default(self):
    """Export ``torch.nn.CELU`` constructed with its default arguments."""

    class Celu(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.celu = torch.nn.CELU()

        def forward(self, input):
            activated = self.celu(input)
            return activated

    sample = torch.randn(2)
    self.run_test(Celu(), (sample,))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(12)
def test_celu_alpha(self):
    """Export ``torch.nn.CELU`` with ``alpha=2.``."""

    class Celu(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.celu = torch.nn.CELU(alpha=2.)

        def forward(self, input):
            activated = self.celu(input)
            return activated

    sample = torch.randn(2)
    self.run_test(Celu(), (sample,))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(12)
def test_celu_cast(self):
    """Export default ``torch.nn.CELU`` applied to a float64 input."""

    class Celu(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.celu = torch.nn.CELU()

        def forward(self, input):
            activated = self.celu(input)
            return activated

    sample = torch.randn(2, 5, 7, dtype=torch.float64)
    self.run_test(Celu(), (sample,))
|
|
|
|
def test_lower_tuple(self):
|
|
class TupleModule(torch.nn.Module):
|
|
def forward(self, input1: Tensor, input2: Tensor, input3: Tensor) -> Tensor:
|
|
a = (input1, input2)
|
|
b = a
|
|
c = (input1, input2, input3)
|
|
for i in range(5):
|
|
d = a[0]
|
|
for j in range(2):
|
|
e, f = a
|
|
a = (d, f)
|
|
f = c[2]
|
|
if f.size(0) != input1.size(-1):
|
|
g = b[1]
|
|
b = (g, f)
|
|
else:
|
|
k = c[1:]
|
|
b = (f, k[0])
|
|
m, n = b
|
|
c = (input1, n, m)
|
|
p, q, r = c
|
|
return p + q + r
|
|
|
|
input1 = torch.randn(2)
|
|
input2 = torch.randn(2)
|
|
input3 = torch.randn(2)
|
|
self.run_test(TupleModule(), (input1, input2, input3))
|
|
|
|
def test_lower_tuple_2(self):
|
|
class TupleModule(torch.nn.Module):
|
|
def forward(self, input1: Tensor, input2: Tensor) -> Tuple[Tensor, Tensor]:
|
|
a = (input1, input2)
|
|
for x in range(5):
|
|
c, d = a
|
|
a = (c, d)
|
|
return a
|
|
|
|
input1 = torch.randn(2)
|
|
input2 = torch.randn(2)
|
|
self.run_test(TupleModule(), (input1, input2))
|
|
|
|
def test_lower_tuple_3(self):
|
|
class TupleModule(torch.nn.Module):
|
|
def forward(
|
|
self,
|
|
input1: Tuple[Tensor, Tensor],
|
|
input2: Tuple[Tensor, Tensor],
|
|
) -> Tuple[Tuple[Tensor, Tensor], Tuple[Tensor, Tensor]]:
|
|
a = input1
|
|
b = input2
|
|
for x in range(5):
|
|
c, d = a
|
|
e, f = b
|
|
if c.shape[0] == e.shape[0]:
|
|
e = e + c
|
|
else:
|
|
f = f + d
|
|
a = (e, f)
|
|
b = (c, d)
|
|
return a , b
|
|
|
|
input1 = (torch.randn(2), torch.randn(2))
|
|
input2 = (torch.randn(2), torch.randn(2))
|
|
self.run_test(TupleModule(), (input1, input2))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_where(self):
    """Export three-argument ``torch.where`` with broadcasting operands.

    The condition is (2, 3, 4) while the two value tensors are (2, 1, 4)
    and (2, 3, 1), so all three are broadcast against each other.
    """

    class Model(torch.nn.Module):
        def forward(self, cond, input, other):
            return torch.where(cond, input, other)

    # randint's upper bound is exclusive: randint(0, 1, ...) only ever
    # produces 0, i.e. an all-False mask that never selects `input`.
    # Sample from {0, 1} so both branches of where() are exercised.
    x = torch.randint(0, 2, (2, 3, 4), dtype=torch.bool)
    y = torch.randn(2, 1, 4)
    z = torch.ones(2, 3, 1)
    self.run_test(Model(), (x, y, z))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
@disableScriptTest()  # scripting tests run for opsets > 11. See: test_where_condition_script
def test_where_condition(self):
    """Export single-argument ``torch.where`` (index form) stacked into
    one tensor, for a scalar and for a tensor comparison."""

    class Model1(torch.nn.Module):
        def forward(self, input):
            indices = torch.where(input > 0.5)
            return torch.stack(indices, dim=1)

    class Model2(torch.nn.Module):
        def forward(self, input, other):
            indices = torch.where(input > other)
            return torch.stack(indices, dim=1)

    shape = (2, 3, 4)

    mask = torch.randint(0, 2, shape, dtype=bool)
    self.run_test(Model1(), mask)

    # NOTE(review): randint(0, 1) yields only 0 and randint(1, 2) only 1,
    # so `input > other` looks always False here -- confirm that an empty
    # index result is the intended case.
    mask = torch.randint(0, 1, shape, dtype=bool)
    threshold = torch.randint(1, 2, shape, dtype=bool)
    self.run_test(Model2(), (mask, threshold))
|
|
|
|
@skipIfUnsupportedOpsetVersion([13])
@skipIfUnsupportedMinOpsetVersion(11)
def test_where_condition_script(self):
    """Same models and inputs as ``test_where_condition``, but without
    the script-test opt-out and gated on opset >= 11 (13 excluded)."""

    class Model1(torch.nn.Module):
        def forward(self, input):
            indices = torch.where(input > 0.5)
            return torch.stack(indices, dim=1)

    class Model2(torch.nn.Module):
        def forward(self, input, other):
            indices = torch.where(input > other)
            return torch.stack(indices, dim=1)

    shape = (2, 3, 4)

    mask = torch.randint(0, 2, shape, dtype=bool)
    self.run_test(Model1(), mask)

    # NOTE(review): randint(0, 1) yields only 0 and randint(1, 2) only 1,
    # so `input > other` looks always False here -- confirm that an empty
    # index result is the intended case.
    mask = torch.randint(0, 1, shape, dtype=bool)
    threshold = torch.randint(1, 2, shape, dtype=bool)
    self.run_test(Model2(), (mask, threshold))
|
|
|
|
def test_empty_branch(self):
|
|
class EmptyBranchModel(torch.jit.ScriptModule):
|
|
@torch.jit.script_method
|
|
def forward(self, input):
|
|
out = input + 1
|
|
if out.dim() > 2:
|
|
if out.dim() > 3:
|
|
out += 3
|
|
else:
|
|
pass
|
|
else:
|
|
pass
|
|
return out
|
|
|
|
x = torch.randn(1, 2, 3, requires_grad=True)
|
|
self.run_test(EmptyBranchModel(), x)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(11)
def test_derive_index_scripting(self):
    """Export loops whose index comes from ``range`` with a non-unit step
    (-2, 2, -3, 3), collecting ``x * x[idx]`` into a list."""

    class BackwardStep2(torch.nn.Module):
        def forward(self, x: torch.Tensor):
            j = []
            for idx in range(len(x) - 1, -len(x), -2):
                y = x[idx]
                j += [x * y]
            return j

    class ForwardStep2(torch.nn.Module):
        def forward(self, x: torch.Tensor):
            j = []
            for idx in range(-len(x), len(x) - 1, 2):
                y = x[idx]
                j += [x * y]
            return j

    class BackwardStep3(torch.nn.Module):
        def forward(self, x: torch.Tensor):
            j = []
            for idx in range(len(x) - 1, -len(x), -3):
                y = x[idx]
                j += [x * y]
            return j

    class ForwardStep3(torch.nn.Module):
        def forward(self, x: torch.Tensor):
            j = []
            for idx in range(-len(x), len(x) - 1, 3):
                y = x[idx]
                j += [x * y]
            return j

    sample = torch.randn(5, 13)
    self.run_test(BackwardStep2(), sample)

    sample = torch.randn(5, 13)
    self.run_test(ForwardStep2(), sample)

    # The step-3 variants reuse the previous sample.
    self.run_test(BackwardStep3(), sample)
    self.run_test(ForwardStep3(), sample)
|
|
|
|
@disableScriptTest()  # Scripting fails for add lists for opsets < 11. Check test_derive_index_scripting
def test_derive_index(self):
    """Same models as ``test_derive_index_scripting`` (``range`` loops
    with steps -2, 2, -3, 3), run with script tests disabled."""

    class BackwardStep2(torch.nn.Module):
        def forward(self, x: torch.Tensor):
            j = []
            for idx in range(len(x) - 1, -len(x), -2):
                y = x[idx]
                j += [x * y]
            return j

    class ForwardStep2(torch.nn.Module):
        def forward(self, x: torch.Tensor):
            j = []
            for idx in range(-len(x), len(x) - 1, 2):
                y = x[idx]
                j += [x * y]
            return j

    class BackwardStep3(torch.nn.Module):
        def forward(self, x: torch.Tensor):
            j = []
            for idx in range(len(x) - 1, -len(x), -3):
                y = x[idx]
                j += [x * y]
            return j

    class ForwardStep3(torch.nn.Module):
        def forward(self, x: torch.Tensor):
            j = []
            for idx in range(-len(x), len(x) - 1, 3):
                y = x[idx]
                j += [x * y]
            return j

    sample = torch.randn(5, 13)
    self.run_test(BackwardStep2(), sample)

    sample = torch.randn(5, 13)
    self.run_test(ForwardStep2(), sample)

    # The step-3 variants reuse the previous sample.
    self.run_test(BackwardStep3(), sample)
    self.run_test(ForwardStep3(), sample)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(11)
def test_if_transpose(self):
    """Export a scripted module whose two ``if`` branches return tensors
    of different layouts (transposed back vs. left transposed)."""

    class IfModel(torch.nn.Module):
        def forward(self, x):
            x = x.transpose(0, 1)
            if x.size(0) == 2:
                return x.transpose(0, 1)
            else:
                return x

    sample = torch.randn(2, 3)
    scripted = torch.jit.script(IfModel())
    # Both output axes are declared dynamic.
    self.run_test(
        scripted,
        sample,
        output_names=["output_1"],
        dynamic_axes={"output_1": [0, 1]},
    )
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(13)
def test_if_list(self):
    """Export a scripted module that appends a different tensor to a list
    depending on a boolean tensor input."""

    class IfModel(torch.nn.Module):
        def forward(self, x, y, cond):
            res = []
            if cond:
                res = res + [x]
            else:
                res = res + [y]
            return res

    first = torch.randn(2, 3)
    second = torch.randn(3, 3)
    flag = torch.tensor(1, dtype=torch.bool)
    scripted = torch.jit.script(IfModel())
    self.run_test(scripted, (first, second, flag))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(13)
def test_if_view(self):
    """Export a scripted module that either reshapes ``x`` with ``view``
    or passes ``y`` through, then transposes the result."""

    class IfModel(torch.nn.Module):
        def forward(self, x, y, cond):
            bs, seq = y.shape[:2]
            if cond:
                res = x.view(bs, seq, -1)
            else:
                res = y
            return res.transpose(1, 2)

    to_reshape = torch.randn(2, 16, 2, 2)
    reference = torch.randn(2, 16, 8)
    flag = torch.tensor(1, dtype=torch.bool)
    scripted = torch.jit.script(IfModel())
    self.run_test(
        scripted,
        (to_reshape, reference, flag),
        output_names=["output_1"],
        dynamic_axes={"output_1": [1]},
    )
|
|
|
|
def test_onnx_proto_checker(self):
    """Check that ``torch._C._check_onnx_proto`` rejects an exported model
    whose ``ir_version`` has been overwritten with 0.

    Raises (expected inside the test): ``RuntimeError`` from the checker.
    """

    class Model(torch.nn.Module):
        def forward(self, x):
            return 2 * x

    x = torch.randn(1, 2, 3, requires_grad=True)
    f = io.BytesIO()
    torch.onnx._export(Model(), x, f)

    # Writing advances the BytesIO position to the end of the data.
    # Rewind so onnx.load parses the exported model, not an empty byte
    # string (which would make the check below pass vacuously).
    f.seek(0)
    model = onnx.load(f)

    # Corrupt the otherwise valid model with an invalid IR version.
    model.ir_version = 0

    def check_proto():
        torch._C._check_onnx_proto(model.SerializeToString())

    self.assertRaises(RuntimeError, check_proto)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(11)
def test_split_tensor_scalar_scripting(self):
    """Export ``torch.split`` whose split size is taken from the input's
    own shape (``x.size(1)``)."""

    class SplitModel(torch.nn.Module):
        def forward(self, x):
            chunk = x.size(1)
            return torch.split(x, chunk)

    sample = torch.randn(1, 2, 3, requires_grad=True)
    self.run_test(SplitModel(), sample)
|
|
|
|
@disableScriptTest()  # Scripting fails to export dynamic split for opsets < 11
def test_split_tensor_scalar(self):
    """Export ``torch.split`` whose split size is taken from the input's
    own shape (``x.size(1)``), with script tests disabled."""

    class SplitModel(torch.nn.Module):
        def forward(self, x):
            chunk = x.size(1)
            return torch.split(x, chunk)

    sample = torch.randn(1, 2, 3, requires_grad=True)
    self.run_test(SplitModel(), sample)
|
|
|
|
def test_split_tensor_multi(self):
|
|
class SplitModel(torch.nn.Module):
|
|
def forward(self, x):
|
|
return torch.split(x, torch.ones(3))
|
|
|
|
x = torch.randn(1, 2, 3, requires_grad=True)
|
|
|
|
def run_model():
|
|
SplitModel(x)
|
|
|
|
self.assertRaises(TypeError, run_model)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_embedding(self):
    """Export functional ``embedding`` with and without ``padding_idx``,
    for 1-D and 3-D index tensors, including TRAINING mode export."""

    class EmbedModel(torch.nn.Module):
        def forward(self, input, emb):
            return torch.nn.functional.embedding(input, emb, padding_idx=1)

    class EmbedModelWithoutPaddingIdx(torch.nn.Module):
        def forward(self, input, emb):
            return torch.nn.functional.embedding(input, emb)

    padded_model = EmbedModel()

    # 1-D indices; positions 0 and 2 are forced to the padding index.
    indices = torch.randint(4, (4,))
    indices[2] = indices[0] = 1
    table = torch.rand(10, 3)
    self.run_test(padded_model, (indices, table))

    # 3-D indices with whole slices forced to the padding index.
    indices = torch.randint(4, (4, 3, 2))
    indices[2] = 1
    indices[0][1] = 1
    self.run_test(padded_model, (indices, table))
    self.run_test(padded_model, (indices, table), training=torch.onnx.TrainingMode.TRAINING)

    plain_model = EmbedModelWithoutPaddingIdx()
    indices = torch.randint(4, (4, 3, 2))
    self.run_test(plain_model, (indices, table))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_embedding_module(self):
    """Export ``torch.nn.Embedding`` modules with and without
    ``padding_idx``, including one whose padding row was overwritten."""

    class EmbedModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.emb = torch.nn.Embedding(4, 3, padding_idx=1)
            self.emb2 = torch.nn.Embedding(4, 3, padding_idx=1)
            # Overwrite the padding row of the second table with ones.
            with torch.no_grad():
                self.emb2.weight[1] = torch.ones(3)

        def forward(self, input):
            first = self.emb(input)
            second = self.emb2(input)
            return first, second

    class EmbedModelWithoutPaddingIdx(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.emb = torch.nn.Embedding(4, 3)

        def forward(self, input):
            return self.emb(input)

    padded_model = EmbedModel()

    indices = torch.randint(4, (4,))
    indices[2] = indices[0] = 1
    self.run_test(padded_model, (indices,))

    indices = torch.randint(4, (4, 3, 2))
    indices[2] = 1
    indices[0][1] = 1
    self.run_test(padded_model, (indices,))

    plain_model = EmbedModelWithoutPaddingIdx()
    indices = torch.randint(4, (4, 3, 2))
    self.run_test(plain_model, (indices,))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(11)
def test_embedding_renorm(self):
    """Export ``torch.nn.Embedding`` with ``max_norm`` set, first with the
    default norm type and then with ``norm_type=1.``."""
    num_embeddings, dim = 7, 5

    module = torch.nn.Embedding(num_embeddings, dim, max_norm=0.2)
    lookup = torch.tensor([2, 1])
    self.run_test(module, lookup)

    module = torch.nn.Embedding(num_embeddings, dim, max_norm=0.5, norm_type=1.)
    lookup = torch.tensor([4, 3, 4, 2])  # contains a repeated index
    self.run_test(module, lookup)
|
|
|
|
def _dispatch_rnn_test(self, name, *args, **kwargs):
|
|
if name == "elman":
|
|
self._elman_rnn_test(*args, **kwargs)
|
|
if name == "lstm":
|
|
self._lstm_test(*args, **kwargs)
|
|
if name == "gru":
|
|
self._gru_test(*args, **kwargs)
|
|
|
|
def _elman_rnn_test(self, layers, nonlinearity, bidirectional,
                    initial_state, packed_sequence, dropout):
    """Build a ``torch.nn.RNN`` test model and run it at two batch sizes.

    Args:
        layers: number of stacked RNN layers.
        nonlinearity: value forwarded as ``nonlinearity`` to ``torch.nn.RNN``.
        bidirectional: whether the RNN is bidirectional.
        initial_state: if truthy, an initial hidden state ``h0`` is fed.
        packed_sequence: 0 feeds the padded tensor only; any non-zero value
            also feeds the sequence lengths and wraps the model in a
            packed-sequence adapter; the value 2 additionally selects
            ``batch_first`` layout.
        dropout: dropout value forwarded to ``torch.nn.RNN``.
    """

    class ElmanWithStateModel(torch.nn.Module):
        def __init__(self, layers, nonlinearity, bidirect, dropout, batch_first):
            super().__init__()

            self.batch_first = batch_first
            # Use the constructor argument, not the enclosing test's variable.
            self.inner_model = torch.nn.RNN(RNN_INPUT_SIZE, RNN_HIDDEN_SIZE, layers, nonlinearity=nonlinearity,
                                            bidirectional=bidirect, dropout=dropout, batch_first=batch_first)

        def forward(self, input: PackedSequence, hx=None):
            return self.inner_model(input, hx)

    class ElmanWithoutStateModel(torch.nn.Module):
        def __init__(self, layers, nonlinearity, bidirect, dropout, batch_first):
            super().__init__()
            self.batch_first = batch_first
            self.inner_model = torch.nn.RNN(RNN_INPUT_SIZE, RNN_HIDDEN_SIZE, layers, nonlinearity=nonlinearity,
                                            bidirectional=bidirect, dropout=dropout, batch_first=batch_first)

        def forward(self, input: PackedSequence):
            return self.inner_model(input)

    batch_first = packed_sequence == 2

    if initial_state:
        model = ElmanWithStateModel(layers=layers, bidirect=bidirectional, nonlinearity=nonlinearity,
                                    dropout=dropout, batch_first=batch_first)
        if packed_sequence:
            model = RnnModelWithPackedSequenceWithState(model, batch_first)
    else:
        # No initial state is supplied, so use the stateless module
        # (previously the with-state module was built here as well and
        # ElmanWithoutStateModel was never used).
        model = ElmanWithoutStateModel(layers=layers, bidirect=bidirectional,
                                       nonlinearity=nonlinearity, dropout=dropout,
                                       batch_first=batch_first)
        if packed_sequence:
            model = RnnModelWithPackedSequenceWithoutState(model, batch_first)

    def make_input(batch_size):
        """Return the model input(s) for ``batch_size`` random sequences."""
        seq_lengths = np.random.randint(1, RNN_SEQUENCE_LENGTH + 1, size=batch_size)
        # Lengths are sorted in decreasing order before padding.
        seq_lengths = sorted(map(int, seq_lengths), reverse=True)
        inputs = [torch.randn(length, RNN_INPUT_SIZE) for length in seq_lengths]
        inputs = rnn_utils.pad_sequence(inputs, batch_first=batch_first)
        inputs = [inputs]

        directions = 2 if bidirectional else 1

        if initial_state:
            h0 = torch.randn(directions * layers, batch_size, RNN_HIDDEN_SIZE)
            inputs.append(h0)
        if packed_sequence != 0:
            inputs.append(torch.IntTensor(seq_lengths))
        # A lone tensor is passed bare rather than as a 1-tuple.
        if len(inputs) == 1:
            return inputs[0]
        return tuple(inputs)

    input = make_input(RNN_BATCH_SIZE)
    self.run_test(model, input, batch_size=RNN_BATCH_SIZE)

    # test that the model still runs with a different batch size
    other_input = make_input(RNN_BATCH_SIZE + 1)
    self.run_test(model, other_input, batch_size=RNN_BATCH_SIZE + 1)
|
|
|
|
def _lstm_test(self, layers, bidirectional, initial_state,
               packed_sequence, dropout):
    """Build an LSTM test model and run it at two batch sizes.

    Args:
        layers: number of stacked LSTM layers.
        bidirectional: whether the LSTM is bidirectional.
        initial_state: if truthy, an ``(h0, c0)`` pair is fed.
        packed_sequence: 0 feeds the padded tensor only; any non-zero value
            also feeds the sequence lengths; the value 2 additionally
            selects ``batch_first`` layout.
        dropout: dropout value forwarded to the model.
    """
    batch_first = packed_sequence == 2

    if not packed_sequence:
        model = LstmFlatteningResultWithoutSeqLength(RNN_INPUT_SIZE, RNN_HIDDEN_SIZE, layers,
                                                     bidirectional, dropout, batch_first)
    else:
        model = LstmFlatteningResultWithSeqLength(RNN_INPUT_SIZE, RNN_HIDDEN_SIZE, layers,
                                                  bidirectional, dropout, batch_first)
        if initial_state:
            model = RnnModelWithPackedSequenceWithState(model, batch_first)
        else:
            model = RnnModelWithPackedSequenceWithoutState(model, batch_first)

    def make_input(batch_size):
        """Return the model input(s) for ``batch_size`` random sequences."""
        lengths = np.random.randint(1, RNN_SEQUENCE_LENGTH + 1, size=batch_size)
        # Lengths are sorted in decreasing order before padding.
        lengths = sorted(map(int, lengths), reverse=True)
        sequences = [torch.randn(length, RNN_INPUT_SIZE) for length in lengths]
        pieces = [rnn_utils.pad_sequence(sequences, batch_first=batch_first)]

        num_directions = 2 if bidirectional else 1

        if initial_state:
            h0 = torch.randn(num_directions * layers, batch_size, RNN_HIDDEN_SIZE)
            c0 = torch.randn(num_directions * layers, batch_size, RNN_HIDDEN_SIZE)
            pieces.append((h0, c0))
        if packed_sequence != 0:
            pieces.append(torch.IntTensor(lengths))
        # A lone tensor is passed bare rather than as a 1-tuple.
        return pieces[0] if len(pieces) == 1 else tuple(pieces)

    first_input = make_input(RNN_BATCH_SIZE)
    self.run_test(model, first_input, batch_size=RNN_BATCH_SIZE)

    # test that the model still runs with a different batch size
    second_input = make_input(RNN_BATCH_SIZE + 1)
    self.run_test(model, second_input, batch_size=RNN_BATCH_SIZE + 1)
|
|
|
|
def _gru_test(self, layers, bidirectional, initial_state,
              packed_sequence, dropout):
    """Build a ``torch.nn.GRU`` test model and run it at two batch sizes.

    Args:
        layers: number of stacked GRU layers.
        bidirectional: whether the GRU is bidirectional.
        initial_state: if truthy, an initial hidden state ``h0`` is fed.
        packed_sequence: 0 feeds the padded tensor only; any non-zero value
            also feeds the sequence lengths and wraps the model in a
            packed-sequence adapter; the value 2 additionally selects
            ``batch_first`` layout.
        dropout: dropout value forwarded to ``torch.nn.GRU``.
    """

    # The nested modules build the GRU from their own constructor
    # arguments (``bidirect``), not from the enclosing test's variables.

    class GRUWithStateModel(torch.nn.Module):
        def __init__(self, layers, bidirect, dropout, batch_first):
            super().__init__()

            self.batch_first = batch_first
            self.inner_model = torch.nn.GRU(RNN_INPUT_SIZE, RNN_HIDDEN_SIZE, num_layers=layers,
                                            bidirectional=bidirect, dropout=dropout,
                                            batch_first=batch_first)

        def forward(self, input: PackedSequence, hx):
            return self.inner_model(input, hx)

    class GRUWithoutStateModel(torch.nn.Module):
        def __init__(self, layers, bidirect, dropout, batch_first):
            super().__init__()
            self.batch_first = batch_first
            self.inner_model = torch.nn.GRU(RNN_INPUT_SIZE, RNN_HIDDEN_SIZE, num_layers=layers,
                                            bidirectional=bidirect, dropout=dropout,
                                            batch_first=batch_first)

        def forward(self, input: PackedSequence):
            return self.inner_model(input)

    class GRUNoSeqLengthWithoutStateModel(torch.nn.Module):
        def __init__(self, layers, bidirect, dropout, batch_first):
            super().__init__()
            self.batch_first = batch_first
            self.inner_model = torch.nn.GRU(RNN_INPUT_SIZE, RNN_HIDDEN_SIZE, num_layers=layers,
                                            bidirectional=bidirect, dropout=dropout,
                                            batch_first=batch_first)

        def forward(self, input):
            return self.inner_model(input)

    class GRUNoSeqLengthWithStateModel(torch.nn.Module):
        def __init__(self, layers, bidirect, dropout, batch_first):
            super().__init__()
            self.batch_first = batch_first
            self.inner_model = torch.nn.GRU(RNN_INPUT_SIZE, RNN_HIDDEN_SIZE, num_layers=layers,
                                            bidirectional=bidirect, dropout=dropout,
                                            batch_first=batch_first)

        def forward(self, input, hx):
            return self.inner_model(input, hx)

    batch_first = packed_sequence == 2

    if packed_sequence:
        if initial_state:
            model = GRUWithStateModel(layers=layers, bidirect=bidirectional, dropout=dropout,
                                      batch_first=batch_first)
            model = RnnModelWithPackedSequenceWithState(model, batch_first)
        else:
            model = GRUWithoutStateModel(layers=layers, bidirect=bidirectional, dropout=dropout,
                                         batch_first=batch_first)
            model = RnnModelWithPackedSequenceWithoutState(model, batch_first)
    else:
        if initial_state:
            model = GRUNoSeqLengthWithStateModel(layers=layers, bidirect=bidirectional,
                                                 dropout=dropout, batch_first=batch_first)
        else:
            model = GRUNoSeqLengthWithoutStateModel(layers=layers, bidirect=bidirectional,
                                                    dropout=dropout, batch_first=batch_first)

    def make_input(batch_size):
        """Return the model input(s) for ``batch_size`` random sequences."""
        seq_lengths = np.random.randint(1, RNN_SEQUENCE_LENGTH + 1, size=batch_size)
        # Lengths are sorted in decreasing order before padding.
        seq_lengths = sorted(map(int, seq_lengths), reverse=True)
        inputs = [torch.randn(length, RNN_INPUT_SIZE) for length in seq_lengths]
        inputs = rnn_utils.pad_sequence(inputs, batch_first=batch_first)
        inputs = [inputs]

        directions = 2 if bidirectional else 1

        if initial_state:
            h0 = torch.randn(directions * layers, batch_size, RNN_HIDDEN_SIZE)
            inputs.append(h0)
        if packed_sequence != 0:
            inputs.append(torch.IntTensor(seq_lengths))
        # A lone tensor is passed bare rather than as a 1-tuple.
        if len(inputs) == 1:
            return inputs[0]
        return tuple(inputs)

    input = make_input(RNN_BATCH_SIZE)
    self.run_test(model, input, batch_size=RNN_BATCH_SIZE)

    # test that the model still runs with a different batch size
    other_input = make_input(RNN_BATCH_SIZE + 1)
    self.run_test(model, other_input, batch_size=RNN_BATCH_SIZE + 1)
|
|
|
|
@disableScriptTest()  # TODO: RuntimeError: Exporting the operator __is_ to ONNX is not supported
def test_transformer_encoder(self):
    """Export a three-layer ``TransformerEncoder`` with dropout 0."""
    from torch.nn import TransformerEncoderLayer, TransformerEncoder

    class MyModule(torch.nn.Module):
        def __init__(self, ninp, nhead, nhid, dropout, nlayers):
            super().__init__()
            layer = TransformerEncoderLayer(ninp, nhead, nhid, dropout)
            self.transformer_encoder = TransformerEncoder(layer, nlayers)

        def forward(self, input):
            return self.transformer_encoder(input)

    sample = torch.rand(10, 32, 512)
    model = MyModule(512, 8, 2048, 0., 3)
    self.run_test(model, (sample,), atol=1e-6)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(10)
def test_fake_quantize_per_tensor(self):
    """Export ``fake_quantize_per_tensor_affine`` with constant scale and
    zero point over the range [-128, 127]."""

    class FakeQuantizePerTensorModel(torch.nn.Module):
        def forward(self, input):
            scale = 1. / 127
            zero_point = 0
            quant_min = -128
            quant_max = 127
            return torch.fake_quantize_per_tensor_affine(input, scale, zero_point, quant_min, quant_max)

    sample = torch.randn(6, 4, 3, 3)
    # The tensor is passed bare (not wrapped in a tuple), as in the
    # original call ``(x)``.
    self.run_test(FakeQuantizePerTensorModel(), sample)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(13)
def test_fake_quantize_per_tensor_dynamic_scale_zeropoint(self):
    """Export ``fake_quantize_per_tensor_affine`` where scale and zero
    point are tensor inputs of the model."""

    class FakeQuantizePerTensorModel(torch.nn.Module):
        def forward(self, input, scale, zero_point):
            quant_min = -128
            quant_max = 127
            return torch.fake_quantize_per_tensor_affine(input, scale, zero_point, quant_min, quant_max)

    sample = torch.randn(6, 4, 3, 3)
    scale_tensor = torch.tensor(1. / 127)
    zero_point_tensor = torch.tensor(0)
    self.run_test(FakeQuantizePerTensorModel(), (sample, scale_tensor, zero_point_tensor))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(13)
def test_fake_quantize_per_channel(self):
    """Export ``fake_quantize_per_channel_affine`` along axis 1, applied
    once with range [0, 255] and once with range [-128, 127]."""

    class FakeQuantizePerChannelModel(torch.nn.Module):
        def forward(self, input):
            amax = torch.ones(4)
            scale = amax / 127.
            zero_point = torch.zeros_like(amax, dtype=torch.int)
            # Quantize twice to test different branches
            y = torch.fake_quantize_per_channel_affine(input, scale, zero_point, 1, 0, 255)
            return torch.fake_quantize_per_channel_affine(y, scale, zero_point, 1, -128, 127)

    sample = torch.randn(6, 4, 3, 3)
    # The tensor is passed bare (not wrapped in a tuple), as in the
    # original call ``(x)``.
    self.run_test(FakeQuantizePerChannelModel(), sample)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(13)
@disableScriptTest()  # RuntimeError: Can't redefine method: forward on class: __torch__.torch.nn.modules.linear.Linear
def test_fake_quantize_activation(self):
    """Export a QAT-prepared ``Linear`` whose activation fake-quantizer
    uses ``default_fake_quant``, with fixed weights and inputs."""
    from torch import quantization

    linear = torch.nn.Linear(1, 1)
    linear.qconfig = quantization.QConfig(
        activation=quantization.default_fake_quant,
        weight=quantization.default_per_channel_weight_fake_quant)
    quantization.prepare_qat(linear.train(), inplace=True)
    linear.apply(quantization.enable_observer)
    linear.apply(quantization.enable_fake_quant)
    for submodule in linear.modules():
        if isinstance(submodule, quantization.FakeQuantize):
            submodule.calculate_qparams()

    linear.apply(quantization.disable_observer)
    linear.eval()

    # Fake quantize activation is a special case, as it restricts quantized range to be (0, 127),
    # while standard 8bit quantization range is (-128, 127) or (0, 255).
    # Set fixed weight, bias and inputs to test if ONNX handles the overflow correctly.
    linear.weight = torch.nn.Parameter(torch.tensor([[1.], [1.], [1.]]))
    linear.bias = torch.nn.Parameter(torch.tensor([0.]))
    sample = torch.tensor([[150.], [127.], [-5.]])
    self.run_test(linear, sample)
|
|
|
|
def test_batchnorm_training(self):
    # BatchNorm2d/Conv2d stack exported with TrainingMode.TRAINING, then again
    # with TrainingMode.PRESERVE after the module is put into train().
    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.bn1 = torch.nn.BatchNorm2d(3, affine=False)
            self.cv1 = torch.nn.Conv2d(3, 3, 10)
            self.bn2 = torch.nn.BatchNorm2d(3, affine=True)
            self.cv2 = torch.nn.Conv2d(3, 3, 10)
            self.bn3 = torch.nn.BatchNorm2d(3, affine=False)

        def forward(self, x):
            stage1 = self.cv1(self.bn1(x))
            stage2 = self.cv2(self.bn2(stage1))
            return self.bn3(stage2)

    # The input is drawn before the module is built, as in the original order.
    x = torch.randn(10, 3, 20, 20) * 2
    model_export = MyModule()
    tolerances = {"rtol": 1e-3, "atol": 1e-5}
    self.run_test(model_export, (x,), training=torch.onnx.TrainingMode.TRAINING, **tolerances)
    model_export.train()
    self.run_test(model_export, (x,), training=torch.onnx.TrainingMode.PRESERVE, **tolerances)
|
|
|
|
def test_batchnorm_training_mode_fix_layer(self):
    # Like test_batchnorm_training, but the last BatchNorm layer is switched
    # to eval() inside the constructor.
    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.bn1 = torch.nn.BatchNorm2d(3, affine=True)
            self.cv1 = torch.nn.Conv2d(3, 3, 10)
            self.bn2 = torch.nn.BatchNorm2d(3, affine=False)
            self.cv2 = torch.nn.Conv2d(3, 3, 10)
            self.bn3 = torch.nn.BatchNorm2d(3, affine=True)
            self.bn3.eval()

        def forward(self, x):
            stage1 = self.cv1(self.bn1(x))
            stage2 = self.cv2(self.bn2(stage1))
            return self.bn3(stage2)

    x = torch.randn(10, 3, 128, 128)
    model_export = MyModule()
    tolerances = {"rtol": 1e-3, "atol": 1e-5}
    self.run_test(model_export, (x,), training=torch.onnx.TrainingMode.TRAINING, **tolerances)
    model_export.train()
    self.run_test(model_export, (x,), training=torch.onnx.TrainingMode.PRESERVE, **tolerances)
|
|
|
|
def test_batchnorm_eval_mode_train_layer(self):
    # Export with TrainingMode.EVAL while the last BatchNorm layer was put
    # into train() in the constructor; then eval() the module and PRESERVE.
    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.bn1 = torch.nn.BatchNorm2d(3, affine=True)
            self.cv1 = torch.nn.Conv2d(3, 3, 10)
            self.bn2 = torch.nn.BatchNorm2d(3, affine=False)
            self.cv2 = torch.nn.Conv2d(3, 3, 10)
            self.bn3 = torch.nn.BatchNorm2d(3, affine=True)
            self.bn3.train()

        def forward(self, x):
            stage1 = self.cv1(self.bn1(x))
            stage2 = self.cv2(self.bn2(stage1))
            return self.bn3(stage2)

    x = torch.randn(10, 3, 128, 128)
    model_export = MyModule()
    tolerances = {"rtol": 1e-3, "atol": 1e-5}
    self.run_test(model_export, (x,), training=torch.onnx.TrainingMode.EVAL, **tolerances)
    model_export.eval()
    self.run_test(model_export, (x,), training=torch.onnx.TrainingMode.PRESERVE, **tolerances)
|
|
|
|
def test_instancenorm_training(self):
    # InstanceNorm2d/Conv2d stack exported with TrainingMode.TRAINING, then
    # with TrainingMode.PRESERVE after train().
    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.in1 = torch.nn.InstanceNorm2d(3, affine=True)
            self.cv1 = torch.nn.Conv2d(3, 3, 10)
            self.in2 = torch.nn.InstanceNorm2d(3, affine=False)
            self.cv2 = torch.nn.Conv2d(3, 3, 10)
            self.in3 = torch.nn.InstanceNorm2d(3, affine=True)

        def forward(self, x):
            stage1 = self.cv1(self.in1(x))
            stage2 = self.cv2(self.in2(stage1))
            return self.in3(stage2)

    x = torch.randn(10, 3, 128, 128)
    model_export = MyModule()
    tolerances = {"rtol": 1e-3, "atol": 1e-5}
    self.run_test(model_export, (x,), training=torch.onnx.TrainingMode.TRAINING, **tolerances)
    model_export.train()
    self.run_test(model_export, (x,), training=torch.onnx.TrainingMode.PRESERVE, **tolerances)
|
|
|
|
def test_instancenorm_training_mode_fix_layer(self):
    # Like test_instancenorm_training, but the last InstanceNorm layer is
    # switched to eval() inside the constructor.
    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.in1 = torch.nn.InstanceNorm2d(3, affine=True)
            self.cv1 = torch.nn.Conv2d(3, 3, 10)
            self.in2 = torch.nn.InstanceNorm2d(3, affine=False)
            self.cv2 = torch.nn.Conv2d(3, 3, 10)
            self.in3 = torch.nn.InstanceNorm2d(3, affine=True)
            self.in3.eval()

        def forward(self, x):
            stage1 = self.cv1(self.in1(x))
            stage2 = self.cv2(self.in2(stage1))
            return self.in3(stage2)

    x = torch.randn(10, 3, 128, 128)
    model_export = MyModule()
    tolerances = {"rtol": 1e-3, "atol": 1e-5}
    self.run_test(model_export, (x,), training=torch.onnx.TrainingMode.TRAINING, **tolerances)
    model_export.train()
    self.run_test(model_export, (x,), training=torch.onnx.TrainingMode.PRESERVE, **tolerances)
|
|
|
|
def test_instancenorm_eval_mode_train_layer(self):
    # Export with TrainingMode.EVAL while the last InstanceNorm layer was put
    # into train() in the constructor; then eval() the module and PRESERVE.
    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.in1 = torch.nn.InstanceNorm2d(8, affine=True)
            self.cv1 = torch.nn.Conv2d(8, 8, 10)
            self.in2 = torch.nn.InstanceNorm2d(8, affine=False)
            self.cv2 = torch.nn.Conv2d(8, 8, 10)
            self.in3 = torch.nn.InstanceNorm2d(8, affine=True)
            self.in3.train()

        def forward(self, x):
            stage1 = self.cv1(self.in1(x))
            stage2 = self.cv2(self.in2(stage1))
            return self.in3(stage2)

    x = torch.randn(10, 8, 128, 128)
    model_export = MyModule()
    tolerances = {"rtol": 1e-3, "atol": 1e-5}
    self.run_test(model_export, (x,), training=torch.onnx.TrainingMode.EVAL, **tolerances)
    model_export.eval()
    self.run_test(model_export, (x,), training=torch.onnx.TrainingMode.PRESERVE, **tolerances)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(12)
def test_dropout_training(self):
    # A Dropout(0.4) module exported in TRAINING mode must not return its
    # input unchanged when run through ONNX Runtime. Checked for both the
    # eager module and its TorchScript form.
    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.dropout = torch.nn.Dropout(0.4)

        def forward(self, x):
            dropout = self.dropout(x)
            return dropout

    model = MyModule()
    x = torch.randn(10)
    model.train()

    def check_output_differs_from_input(exported):
        ort_sess = convert_to_onnx(exported, input=(x,), opset_version=self.opset_version,
                                   training=torch.onnx.TrainingMode.TRAINING)
        ort_outs = run_ort(ort_sess, input=(x,))
        # self.assertFalse instead of a bare `assert`: the check survives
        # `python -O` and reports through the unittest machinery.
        self.assertFalse(bool(torch.all(torch.eq(x, torch.from_numpy(ort_outs[0])))))

    check_output_differs_from_input(model)
    # The eager forward pass whose result was never used has been removed.
    check_output_differs_from_input(torch.jit.script(model))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(12)
def test_dropout_training_zero(self):
    # Compare the fraction of non-zero outputs of Dropout(0.5) between PyTorch
    # and ONNX Runtime, for both the eager module and its TorchScript form.
    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.dropout = torch.nn.Dropout(0.5)

        def forward(self, x):
            dropout = self.dropout(x)
            return dropout

    model = MyModule()

    # ensure there are no zeros in the input: the kept ratio is measured by
    # counting non-zero outputs, so an input zero would be miscounted as dropped.
    x = torch.randn(10, 3, 128, 128)
    raw = x.numpy()
    nonzero_input = torch.from_numpy(np.where(raw == 0, 1, raw))
    nb_elements = torch.numel(nonzero_input)

    def kept_ratio(values):
        # Fraction of elements that are non-zero in the output.
        return np.sum(np.where(values != 0, 1, 0)) / nb_elements

    def check_kept_ratio(exported):
        # Both runtimes now receive the same zero-free tensor. Previously
        # ONNX Runtime was exported and run with the unsanitized `x`.
        ort_sess = convert_to_onnx(exported, input=(nonzero_input,), opset_version=self.opset_version,
                                   training=torch.onnx.TrainingMode.TRAINING)
        ort_outs = run_ort(ort_sess, input=(nonzero_input,))
        # The PyTorch reference comes from the eager module in both passes,
        # as in the original test.
        pyt_out = model(nonzero_input).cpu().numpy()
        np.testing.assert_allclose(kept_ratio(pyt_out), kept_ratio(ort_outs[0]), rtol=0.01, atol=0.01)

    model.train()
    check_kept_ratio(model)
    check_kept_ratio(torch.jit.script(model))
|
|
|
|
def test_conv_bn(self):
    # Conv2d followed by BatchNorm2d, exported in EVAL mode (default
    # tolerances) and in TRAINING mode (explicit tolerances).
    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 16, kernel_size=1, stride=2, padding=3, bias=True)
            self.bn = torch.nn.BatchNorm2d(16, affine=True)

        def forward(self, x):
            return self.bn(self.conv(x))

    # The module is built before the input is drawn, as in the original order.
    model_export = MyModule()
    x = torch.randn(10, 3, 128, 128)
    self.run_test(model_export, (x,), training=torch.onnx.TrainingMode.EVAL)
    self.run_test(model_export, (x,), training=torch.onnx.TrainingMode.TRAINING, rtol=1e-3, atol=1e-5)
|
|
|
|
def test_multiple_conv_bn(self):
    # Three Conv2d layers with BatchNorm2d (bn2 is applied twice), in-place
    # ReLU and max pooling; exported in TRAINING and then EVAL mode.
    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv1 = torch.nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
            self.conv2 = torch.nn.Conv2d(64, 2, kernel_size=1, stride=1, padding=0, bias=False)
            self.conv3 = torch.nn.Conv2d(2, 2, kernel_size=3, stride=1, padding=1, bias=False)
            self.bn = torch.nn.BatchNorm2d(64)
            self.bn2 = torch.nn.BatchNorm2d(2)
            self.relu = torch.nn.ReLU(inplace=True)
            self.maxpool = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        def forward(self, x):
            x = self.maxpool(self.relu(self.bn(self.conv1(x))))
            x = self.relu(self.bn2(self.conv2(x)))
            return self.relu(self.bn2(self.conv3(x)))

    model_export = MyModule()
    x = torch.randn(2, 3, 224, 224)
    self.run_test(model_export, (x,), training=torch.onnx.TrainingMode.TRAINING, rtol=1e-3, atol=1e-5)
    self.run_test(model_export, (x,), training=torch.onnx.TrainingMode.EVAL)
|
|
|
|
def test_script_custom_class_error(self):
    # A scripted module that holds a plain-Python custom class attribute is
    # expected to raise RuntimeError when passed to convert_to_onnx.
    # The scripted classes are kept exactly as written.
    class BoxCoder(object):
        def __init__(self, bbox_xform_clip: float) -> None:
            self.bbox_xform_clip = bbox_xform_clip

        def decode(self, rel_codes: Tensor, boxes: List[Tensor]) -> Tensor:
            boxes = torch.cat(boxes, dim=0)
            pred_ctr_x = torch.clamp(rel_codes[:, 0::4], max=self.bbox_xform_clip) * boxes[:, 2]
            return pred_ctr_x

    class MyModule(torch.nn.Module):
        __annotations__ = {
            "box_coder": BoxCoder,
        }

        def __init__(self):
            super().__init__()
            self.box_coder = BoxCoder(1.4)

        def forward(self, box_regression: torch.Tensor, proposals: List[torch.Tensor]):
            return self.box_coder.decode(box_regression, proposals)

    model = torch.jit.script(MyModule())
    box_regression = torch.randn([4, 4])
    proposal = [torch.randn(2, 4), torch.randn(2, 4)]

    with self.assertRaises(RuntimeError):
        convert_to_onnx(model, input=(box_regression, proposal))
|
|
|
|
def test_initializer_sequence(self):
    # The initializers of the exported ONNX graph must appear in the same
    # order as the module's state_dict() keys and named_parameters().
    class MyModule(torch.nn.Module):
        def __init__(self, input_size, hidden_size, num_classes):
            super().__init__()
            self.fc1 = torch.nn.Linear(input_size, hidden_size)
            self.relu = torch.nn.ReLU()
            self.fc2 = torch.nn.Linear(hidden_size, num_classes)

        def forward(self, x):
            out = self.fc1(x)
            out = self.relu(out)
            out = self.fc2(out)
            return out

    test_model = MyModule(3, 4, 10)
    state_dict_list = list(test_model.state_dict().keys())
    named_params_list = [name for name, _ in test_model.named_parameters()]

    x = torch.randn(32, 3)
    f = io.BytesIO()
    torch.onnx._export(test_model, (x,), f, do_constant_folding=False)
    loaded_model = onnx.load_from_string(f.getvalue())

    actual_list = [p.name for p in loaded_model.graph.initializer]

    def mismatch_message(source, expected):
        return ("Initializers' sequence is not as same as " + source + ". Expected: ("
                + ", ".join(expected) + "). Actual:(" + ", ".join(actual_list) + ").")

    # unittest assertions instead of bare `assert`: they are not stripped
    # under `python -O` and report through the test framework.
    self.assertEqual(actual_list, state_dict_list,
                     mismatch_message("state_dict()", state_dict_list))
    self.assertEqual(actual_list, named_params_list,
                     mismatch_message("named_parameters()", named_params_list))
|
|
|
|
def test_initializer_sequence_script_model(self):
    # Scripted variant of the initializer-order check. Each expected name only
    # has to be a substring of the initializer name at the same position.
    def list_is_expected(short_list, long_list) -> bool:
        # True when every entry of short_list is contained in the entry of
        # long_list at the same index; long_list may have extra trailing items.
        if len(short_list) > len(long_list):
            return False
        return all(expected in actual for expected, actual in zip(short_list, long_list))

    def loop(x, y):
        for i in range(int(y)):
            x = x + i
        return x

    class MyModule(torch.nn.Module):
        def __init__(self, input_size, hidden_size, num_classes):
            super().__init__()
            self.fc1 = torch.nn.Linear(input_size, hidden_size)
            self.relu = torch.nn.ReLU()
            self.fc2 = torch.nn.Linear(hidden_size, num_classes)

        def forward(self, x, y):
            x = loop(x, y)
            out = self.fc1(x)
            out = self.relu(out)
            out = self.fc2(out)
            return out

    test_model = torch.jit.script(MyModule(3, 4, 10))
    state_dict_list = list(test_model.state_dict().keys())
    named_params_list = [name for name, _ in test_model.named_parameters()]

    x = torch.ones(2, 3, dtype=torch.float)
    y = torch.tensor(5, dtype=torch.long)
    f = io.BytesIO()

    torch.onnx.export(test_model, (x, y), f, do_constant_folding=False)
    loaded_model = onnx.load_from_string(f.getvalue())

    actual_list = [p.name for p in loaded_model.graph.initializer]

    def mismatch_message(source, expected):
        return ("ScriptModel - Initializers' sequence is not as same as " + source + ". Expected: ("
                + ", ".join(expected) + "). Actual:(" + ", ".join(actual_list) + ").")

    # unittest assertions instead of bare `assert`: they are not stripped
    # under `python -O` and report through the test framework.
    self.assertTrue(list_is_expected(state_dict_list, actual_list),
                    mismatch_message("state_dict()", state_dict_list))
    self.assertTrue(list_is_expected(named_params_list, actual_list),
                    mismatch_message("named_parameters()", named_params_list))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(11)
def test_nms(self):
    # torchvision ops.nms with IoU threshold 0.5 on 100 random boxes.
    box_count = 100
    boxes = torch.rand(box_count, 4)
    # Add the first two columns onto the last two so they are never smaller.
    boxes[:, 2:] += boxes[:, :2]
    scores = torch.randn(box_count)

    class Module(torch.nn.Module):
        def forward(self, boxes, scores):
            return ops.nms(boxes, scores, 0.5)

    self.run_test(Module(), (boxes, scores))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(11)
def test_batched_nms(self):
    # torchvision ops.batched_nms with IoU threshold 0.5; every box gets a
    # random index in [0, 5).
    box_count = 100
    boxes = torch.rand(box_count, 4)
    # Add the first two columns onto the last two so they are never smaller.
    boxes[:, 2:] += boxes[:, :2]
    scores = torch.randn(box_count)
    idxs = torch.randint(0, 5, size=(box_count,))

    class Module(torch.nn.Module):
        def forward(self, boxes, scores, idxs):
            return ops.batched_nms(boxes, scores, idxs, 0.5)

    self.run_test(Module(), (boxes, scores, idxs))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(11)
def test_clip_boxes_to_image(self):
    # Clip boxes to an image size taken from the shape of a second input.
    # The size input has dynamic axes and is re-tested with another shape.
    boxes = torch.randn(5, 4) * 500
    boxes[:, 2:] += boxes[:, :2]
    size = torch.randn(200, 300)
    size_2 = torch.randn(300, 400)

    class Module(torch.nn.Module):
        def forward(self, boxes, size):
            # Only the shape of `size` is used, never its values.
            shape = (size.shape[0], size.shape[1])
            return ops.boxes.clip_boxes_to_image(boxes, shape)

    self.run_test(
        Module(),
        (boxes, size),
        input_names=["boxes", "size"],
        dynamic_axes={"size": [0, 1]},
        test_with_inputs=[(boxes, size), (boxes, size_2)],
    )
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(11)
def test_roi_align(self):
    # ops.RoIAlign with positional arguments (5, 5), 1., 2 on a single ROI.
    feature_map = torch.rand(1, 1, 10, 10, dtype=torch.float32)
    single_roi = torch.tensor([[0, 0, 0, 4, 4]], dtype=torch.float32)
    self.run_test(ops.RoIAlign((5, 5), 1., 2), (feature_map, single_roi))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(11)
def test_roi_align_aligned(self):
    # ops.RoIAlign with aligned=True over four configurations.
    # A fresh random feature map is drawn for every case.
    cases = [
        # (roi row, first arg, second arg, third arg) as passed to ops.RoIAlign
        ([0, 1.5, 1.5, 3, 3], (5, 5), 1., 2),
        ([0, 0.2, 0.3, 4.5, 3.5], (5, 5), 0.5, 3),
        ([0, 0.2, 0.3, 4.5, 3.5], (5, 5), 1.8, 2),
        ([0, 0.2, 0.3, 4.5, 3.5], (2, 2), 2.5, 0),
    ]
    for roi_row, output_size, spatial_scale, sampling_ratio in cases:
        x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
        single_roi = torch.tensor([roi_row], dtype=torch.float32)
        model = ops.RoIAlign(output_size, spatial_scale, sampling_ratio, aligned=True)
        self.run_test(model, (x, single_roi))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(11)
def test_roi_pool(self):
    # ops.RoIPool with a 5x5 output size and 2. as second argument.
    feature_map = torch.rand(1, 1, 10, 10, dtype=torch.float32)
    rois = torch.tensor([[0, 0, 0, 4, 4]], dtype=torch.float32)
    pooled_size = (5, 5)
    self.run_test(ops.RoIPool(pooled_size, 2.), (feature_map, rois))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(11)
def test_resize_images(self):
    # Export the `resize` step of the transform built by
    # _init_test_generalized_rcnn_transform() (defined elsewhere in this file),
    # with all input axes dynamic, and re-test with a second image size.
    class TransformModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.transform = _init_test_generalized_rcnn_transform()

        def forward(self, images):
            # resize is called with None as second argument; only the first
            # element of its result is returned.
            return self.transform.resize(images, None)[0]

    export_image = torch.rand(3, 10, 20)
    other_image = torch.rand(3, 100, 150)
    self.run_test(
        TransformModule(),
        (export_image,),
        input_names=["input1"],
        dynamic_axes={"input1": [0, 1, 2]},
        test_with_inputs=[(export_image,), (other_image,)],
    )
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(11)
@disableScriptTest()
def test_transform_images(self):
    # Export the full transform on a pair of images and return the batched
    # `.tensors` of its first result.
    class TransformModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.transform = _init_test_generalized_rcnn_transform()

        def forward(self, images: List[torch.Tensor]):
            return self.transform(images)[0].tensors

    export_images = torch.rand(3, 100, 200), torch.rand(3, 200, 200)
    other_images = torch.rand(3, 100, 200), torch.rand(3, 200, 200)
    self.run_test(
        TransformModule(),
        (export_images,),
        test_with_inputs=[(export_images,), (other_images,)],
    )
|
|
|
|
def get_features(self, images):
    # Build an OrderedDict of five random feature maps keyed "0".."4".
    # Each has shape (2, 256, H // s, W // s), where (H, W) are the last two
    # dimensions of `images` and s is 4, 8, 16, 32, 64 respectively.
    height, width = images.shape[-2:]
    strides = (4, 8, 16, 32, 64)
    return OrderedDict(
        (str(level), torch.rand(2, 256, height // stride, width // stride))
        for level, stride in enumerate(strides)
    )
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(11)
@disableScriptTest()
def test_rpn(self):
    # Export the RPN built by _init_test_rpn() (defined elsewhere in this
    # file) with fully dynamic inputs and re-test with a second image size.
    set_rng_seed(0)

    class RPNModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.rpn = _init_test_rpn()

        def forward(self, images, features: Dict[str, torch.Tensor]):
            images_m = ImageList(images, [(i.shape[-1], i.shape[-2]) for i in images])
            return self.rpn(images_m, features)

    images = torch.rand(2, 3, 150, 150)
    features = self.get_features(images)
    images2 = torch.rand(2, 3, 80, 80)
    test_features = self.get_features(images2)

    model = RPNModule()
    model.eval()
    # One eager forward pass before export, as in the original test.
    model(images, features)

    # Six named inputs, each with its first four axes dynamic.
    input_names = [f"input{index}" for index in range(1, 7)]
    self.run_test(
        model,
        (images, features),
        input_names=input_names,
        dynamic_axes={name: [0, 1, 2, 3] for name in input_names},
        test_with_inputs=[(images, features), (images2, test_features)],
        dict_check=False,
    )
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(11)
@disableScriptTest()
def test_multi_scale_roi_align(self):
    # ops.MultiScaleRoIAlign over two named feature maps with a fixed image
    # size of (512, 512); exported with one input set, re-tested with another.
    class TransformModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.model = ops.MultiScaleRoIAlign(["feat1", "feat2"], 3, 2)
            self.image_sizes = [(512, 512)]

        def forward(self, input: Dict[str, Tensor], boxes: List[Tensor]) -> Tensor:
            return self.model(input, boxes, self.image_sizes)

    def make_inputs():
        # Random draws happen in the order feat1, feat2, boxes.
        feats = OrderedDict()
        feats["feat1"] = torch.rand(1, 5, 64, 64)
        feats["feat2"] = torch.rand(1, 5, 16, 16)
        rois = torch.rand(6, 4) * 256
        rois[:, 2:] += rois[:, :2]
        return feats, rois

    i, boxes = make_inputs()
    i1, boxes1 = make_inputs()

    self.run_test(
        TransformModule(),
        (i, [boxes]),
        test_with_inputs=[(i, [boxes]), (i1, [boxes1])],
    )
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(11)
@disableScriptTest()
def test_roi_heads(self):
    # Chain RPN -> ROI heads -> transform.postprocess, using components built
    # by helper factories defined elsewhere in this file.
    class RoiHeadsModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.transform = _init_test_generalized_rcnn_transform()
            self.rpn = _init_test_rpn()
            self.roi_heads = _init_test_roi_heads_faster_rcnn()

        def forward(self, images, features: Dict[str, torch.Tensor]):
            original_image_sizes = [(img.shape[-1], img.shape[-2]) for img in images]

            images_m = ImageList(images, [(i.shape[-1], i.shape[-2]) for i in images])
            proposals, _ = self.rpn(images_m, features)
            detections, _ = self.roi_heads(features, proposals, images_m.image_sizes)
            detections = self.transform.postprocess(detections,
                                                    images_m.image_sizes,
                                                    original_image_sizes)
            return detections

    images = torch.rand(2, 3, 100, 100)
    features = self.get_features(images)
    images2 = torch.rand(2, 3, 150, 150)
    test_features = self.get_features(images2)

    model = RoiHeadsModule()
    model.eval()
    # One eager forward pass before export, as in the original test.
    model(images, features)

    # Six named inputs, each with its first four axes dynamic.
    input_names = [f"input{index}" for index in range(1, 7)]
    self.run_test(
        model,
        (images, features),
        input_names=input_names,
        dynamic_axes={name: [0, 1, 2, 3] for name in input_names},
        test_with_inputs=[(images, features), (images2, test_features)],
        dict_check=False,
    )
|
|
|
|
def test_set_(self):
    # Tensor.set_ makes x take over y. Both runs pass
    # remained_onnx_input_idx=[1], i.e. only the second input.
    class M(torch.nn.Module):
        def forward(self, x, y):
            x.set_(y)
            return x

    x = torch.ones(2, 3)
    y = torch.randn(4, 6)
    self.run_test(M(), (x, y), remained_onnx_input_idx=[1])

    # Second run: dynamic axes on both inputs, re-tested with another shape.
    y2 = torch.randn(5, 2)
    self.run_test(
        M(),
        (x, y),
        remained_onnx_input_idx=[1],
        input_names=["x", "y"],
        dynamic_axes={"x": [0, 1], "y": [0, 1]},
        test_with_inputs=[(y, y2)],
    )
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_set_attr_modules(self):
    # Nested modules that reassign their own attributes (`const`, `weights`)
    # inside forward(). The module bodies are kept exactly as written because
    # they are compiled by TorchScript.
    class InnerModule2(torch.nn.Module):
        def __init__(self, embedding_dim):
            super().__init__()
            self.weights = InnerModule2.get_embedding(embedding_dim)
            self.register_buffer("_float_tensor", torch.FloatTensor(1))
            self.const = 2

        @staticmethod
        def get_embedding(embedding_dim: int):
            emb = 4 / ((embedding_dim // 2) - 1)
            emb = torch.exp(torch.arange((embedding_dim // 2), dtype=torch.float) * -emb)
            return emb

        def forward(self, input, incremental_state: Optional[torch.Tensor] = None):
            bsz, seq_len = input.shape[0], input.shape[1]
            self.const = 3
            # NOTE(review): this branch references InnerModule and
            # self.embedding_dim, which this class never sets; weights is
            # assigned in __init__, so it is presumably never taken - confirm.
            if self.weights is None:
                self.weights = InnerModule.get_embedding(self.embedding_dim)
            self.weights = self.weights.to(self._float_tensor)
            self.weights = self.weights * self.const
            if incremental_state is not None:
                pos = seq_len
                return self.weights[1 + pos, :].expand(bsz, 1, -1)
            return (
                self.weights.index_select(0, torch.ones((bsz * seq_len), dtype=torch.int64)).view(bsz, seq_len, -1)
            )

    class InnerModule(torch.nn.Module):
        def __init__(self, embedding_dim):
            super().__init__()
            self.weights = InnerModule.get_embedding(embedding_dim)
            self.module = InnerModule2(embedding_dim=8)

        @staticmethod
        def get_embedding(embedding_dim: int):
            emb = 4 / ((embedding_dim // 2) - 1)
            emb = torch.exp(torch.arange((embedding_dim // 2), dtype=torch.float) * -emb)
            return emb

        def forward(self, x):
            return self.module(x) + self.weights

    class Module(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.module = InnerModule(embedding_dim=8)

        def forward(self, x):
            return self.module(x)

    x = torch.randn(3, 256)
    # First with dynamic input axes, then with no remaining ONNX inputs.
    self.run_test(Module(), (x,), input_names=["x"], dynamic_axes={"x": [0, 1]})
    self.run_test(Module(), (x,), remained_onnx_input_idx=[])
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_set_attr_modules_2(self):
    # An inner module overwrites `const` and recomputes `weights` on every
    # forward call. Its body is kept exactly as written because it is
    # compiled by TorchScript.
    class InnerModule(torch.nn.Module):
        def __init__(self, embedding_dim):
            super().__init__()
            self.embedding_dim = embedding_dim
            self.const = 2.5
            self.weights = InnerModule.get_embedding(self.embedding_dim)
            self.register_buffer("_float_tensor", torch.FloatTensor(1))

        @staticmethod
        def get_embedding(embedding_dim: int):
            emb = 4 / ((embedding_dim // 2) - 1)
            emb = torch.exp(torch.arange((embedding_dim // 2), dtype=torch.float) * -emb)
            return emb

        def forward(self, input, incremental_state: Optional[torch.Tensor] = None):
            bsz, seq_len = input.shape[0], input.shape[1]
            self.const = 1.5
            self.weights = InnerModule.get_embedding(self.embedding_dim)
            return (
                self.weights.index_select(0, torch.ones((bsz * seq_len), dtype=torch.int64)).view(bsz, seq_len, -1)
            ) * self.const

    class Module(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.module = InnerModule(embedding_dim=8)

        def forward(self, x):
            return self.module(x)

    x = torch.randn(3, 256)
    # First with dynamic input axes, then with no remaining ONNX inputs.
    self.run_test(Module(), (x,), input_names=["x"], dynamic_axes={"x": [0, 1]})
    self.run_test(Module(), (x,), remained_onnx_input_idx=[])
|
|
|
|
def test_set_attr(self):
    # Scripted module that overwrites a bool attribute and the weight of a
    # sub-module inside forward(), then branches on the bool.
    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv1d(3, 10, 2)
            self.b = False

        def forward(self, box_regression, weight):
            self.b = True
            self.conv.weight = weight
            w = torch.softmax(self.conv.weight, dim=0)
            self.conv.weight = w + w
            if self.b:
                return box_regression + self.conv.weight
            else:
                return box_regression - self.conv.weight

    scripted = torch.jit.script(MyModule())
    weight = torch.ones(3, 2)
    box_regression = torch.randn(3, 2)
    self.run_test(scripted, (box_regression, weight))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(11)
def test_set_attr_2(self):
    # Scripted module whose helper method conditionally reassigns
    # conv.bias / conv.weight; forward() returns the resulting bias.
    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv1d(10, 3, 3)
            self.conv.bias = torch.nn.Parameter(torch.zeros(3, 10, 3))

        def set_cell_anchors(self, anchors):
            if self.conv.bias is not None:
                b = self.conv.bias
                assert b is not None
                self.conv.bias = anchors + b
            elif self.conv.weight is not None:
                self.conv.weight = torch.randn(3, 10)
                self.conv.bias = self.conv.weight[:]

        def forward(self, anchors) -> Optional[torch.Tensor]:
            self.set_cell_anchors(anchors)
            return self.conv.bias

    scripted = torch.jit.script(MyModule())
    anchors = torch.ones(3, 10, 3)
    self.run_test(scripted, anchors)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(11)
def test_set_attr_3(self):
    # Scripted module combining attribute reassignment on a sub-module with
    # an in-place slice write to a tensor argument.
    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv1d(10, 3, 3)
            self.conv.weight = torch.nn.Parameter(torch.zeros(3, 10))
            self.conv.bias = torch.nn.Parameter(torch.zeros(3, 10, 3))

        def set_cell_anchors(self, anchors, boxes):
            self.conv.weight = torch.ones(3, 10)
            if self.conv.bias is not None:
                self.conv.bias = torch.randn(3, 10, 3)
                self.conv.weight = anchors + self.conv.weight
                boxes[:] = torch.zeros(2, 3)

        def forward(self, anchors) -> Tuple[torch.Tensor, torch.Tensor]:
            boxes = torch.ones(2, 2, 3)
            self.set_cell_anchors(anchors, boxes)
            if self.conv.bias is not None:
                return self.conv.weight, boxes
            return anchors, boxes

    scripted = torch.jit.script(MyModule())
    anchors = torch.rand(3, 10)
    self.run_test(scripted, anchors)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(11)
def test_set_attr_4(self):
    # Scripted module that reassigns conv.weight / conv.bias in a helper and
    # then builds its outputs through a list in forward().
    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv1d(10, 3, 3)
            self.conv.bias = torch.nn.Parameter(torch.zeros(3, 10, 3))

        def set_cell_anchors(self, anchors):
            self.conv.weight = torch.zeros(10, 3)
            if self.conv.bias is not None:
                w = self.conv.bias
                assert w is not None
                self.conv.bias = anchors + w
            else:
                self.conv.bias = torch.ones(3, 10, 3)

        def forward(self, feature_maps, anchors) -> Tuple[torch.Tensor, torch.Tensor]:
            self.set_cell_anchors(anchors)
            result = []
            if self.conv.bias is not None:
                a = self.conv.bias
                assert a is not None
                result += [a]
            result += [feature_maps]
            return result[0], result[1]

    scripted = torch.jit.script(MyModule())
    feature_maps = torch.rand(5, 11, 30)
    anchors = torch.ones(3, 10, 3)
    self.run_test(scripted, (feature_maps, anchors))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(11)
def test_set_attr_5(self):
    # Scripted module that reassigns conv.weight inside nested loops guarded
    # by a condition, and conv.bias afterwards.
    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv1d(10, 3, 3)
            self.conv.bias = torch.nn.Parameter(torch.zeros(3, 10, 3))

        def set_cell_anchors(self, anchors):
            self.conv.weight = torch.arange(10)
            for i in range(10):
                if i == 3:
                    for j in range(10):
                        w = self.conv.weight
                        self.conv.weight = torch.arange(10) + w

                self.conv.weight = self.conv.weight + torch.arange(10)
                # NOTE: `is not None` and `assert` is for passing torchscript.
                if self.conv.bias is not None:
                    a = self.conv.bias
                    assert a is not None
                    self.conv.bias = anchors + a

        def forward(self, anchors):
            self.set_cell_anchors(anchors)
            return self.conv.weight, self.conv.bias

    scripted = torch.jit.script(MyModule())
    anchors = torch.ones(3, 10, 3)
    self.run_test(scripted, anchors)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(11)
def test_set_attr_in_loop(self):
    # Scripted module that reassigns sub-module attributes and writes into a
    # tensor argument in place, all inside a nested loop.
    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv1d(10, 3, 3)
            self.conv.weight = torch.nn.Parameter(torch.zeros(3, 10))
            self.conv.bias = torch.nn.Parameter(torch.zeros(3, 10, 3))

        def set_cell_anchors(self, anchors, boxes):
            self.conv.weight = torch.randn(3, 10)
            for i in range(self.conv.weight.size(0)):
                for j in range(10):
                    self.conv.bias = torch.randn(3, 10, 3)
                    self.conv.weight = anchors * i
                    boxes[j] += torch.ones(3, 3)

        def forward(self, anchors) -> Tuple[torch.Tensor, torch.Tensor]:
            boxes = torch.ones(10, 3, 3)
            self.set_cell_anchors(anchors, boxes)
            if self.conv.bias is not None:
                return self.conv.weight, boxes
            return anchors, boxes

    scripted = torch.jit.script(MyModule())
    anchors = torch.rand(10)
    self.run_test(scripted, anchors)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(13)
def test_set_attr_in_loop_with_list(self):
    # Variant of test_set_attr_in_loop where the loop appends tensors to a
    # list attribute instead of writing into a tensor.
    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv1d(10, 3, 3)
            self.conv.weight = torch.nn.Parameter(torch.zeros(3, 10))
            self.conv.bias = torch.nn.Parameter(torch.zeros(3, 10, 3))
            self.boxes: List[torch.Tensor] = [torch.ones(1)]  # Workaround placeholder for TorchScript

        def set_cell_anchors(self, anchors):
            self.conv.weight = torch.randn(3, 10)
            for i in range(self.conv.weight.size(0)):
                for j in range(10):
                    self.conv.bias = torch.randn(3, 10, 3)
                    self.conv.weight = anchors * i
                    self.boxes.append(torch.ones(3, 3))

        def forward(self, anchors) -> Tuple[Tensor, List[Tensor]]:
            # The placeholder list from __init__ is discarded on every call.
            self.boxes = []
            self.set_cell_anchors(anchors)
            if self.conv.bias is not None:
                return self.conv.weight, self.boxes
            return anchors, self.boxes

    scripted = torch.jit.script(MyModule())
    anchors = torch.rand(10)
    self.run_test(scripted, anchors)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(11)
def test_index_put_if(self):
    # Slice assignments (`state[:] = ...`) placed in both branches of a
    # scripted `if`, with two tensors being written.
    @torch.jit.script
    def check_init(input_data: Tensor, hidden_size: int, prev_state: Tensor) -> Tuple[Tensor, Tensor]:
        batch_size = input_data.size(0)
        spatial_size_0 = input_data.size(2)
        spatial_size_1 = input_data.size(3)
        # generate empty prev_state, if None is provided
        state_size = (2, batch_size, hidden_size, spatial_size_0, spatial_size_1)
        state = torch.zeros(state_size, device=input_data.device)
        state_copy = torch.zeros(state_size, device=input_data.device)
        if prev_state.size(0) == 0:
            state[:] = torch.zeros(batch_size, hidden_size, spatial_size_0, spatial_size_1) + state[:]
            # state_copy is deliberately written twice in this branch.
            state_copy[:] = torch.ones(batch_size, hidden_size, spatial_size_0, spatial_size_1) * 2
            state_copy[:] = torch.zeros(batch_size, hidden_size, spatial_size_0, spatial_size_1) * 2
        else:
            state[:] = torch.ones(batch_size, hidden_size, spatial_size_0, spatial_size_1) * 4
        return state, state_copy

    class Example(torch.nn.Module):
        def __init__(self, hidden_size):
            super().__init__()
            self.hidden_size = hidden_size

        def forward(self, input_data, prev_state):
            prev_state = check_init(input_data, self.hidden_size, prev_state)
            return prev_state[0], prev_state[1]

    model = Example(10)
    inputs = (
        torch.rand((1, 5, 30, 30)),
        # Zero-sized state so the `prev_state.size(0) == 0` branch is taken.
        torch.tensor([], dtype=torch.float).view(0, 0, 0, 0, 0),
    )
    axes = {"random_data": [0, 1, 2, 3], "empty_tensor": [0, 1, 2, 3, 4]}
    self.run_test(model, inputs, input_names=["random_data", "empty_tensor"], dynamic_axes=axes)
    self.run_test(model, inputs, remained_onnx_input_idx=[])
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(11)
def test_index_put_if_2(self):
    # Slice assignments spread over an if / elif / elif chain, one branch of
    # which performs them inside a loop.
    @torch.jit.script
    def check_init(input_data: Tensor, hidden_size: int, prev_state: Tensor) -> Tuple[Tensor, Tensor]:
        batch_size = input_data.size(0)
        spatial_size_0 = input_data.size(2)
        spatial_size_1 = input_data.size(3)
        # generate empty prev_state, if None is provided
        state_size = (2, batch_size, hidden_size, spatial_size_0, spatial_size_1)
        state = torch.zeros(state_size, device=input_data.device)
        state_copy = torch.zeros(state_size, device=input_data.device)
        if prev_state.size(0) == 0:
            for i in range(2):
                state[:] = torch.ones(batch_size, hidden_size, spatial_size_0, spatial_size_1) * i
                state_copy[:] = torch.ones(batch_size, hidden_size, spatial_size_0, spatial_size_1) * i
        elif prev_state.size(0) == 1:
            s = state[:]
            state[:] = prev_state + s
        elif prev_state.size(0) == 2:
            state[:] = torch.ones(batch_size, hidden_size, spatial_size_0, spatial_size_1) * 4
        return state, state_copy

    class Example(torch.nn.Module):
        def __init__(self, hidden_size):
            super().__init__()
            self.hidden_size = hidden_size

        def forward(self, input_data, prev_state):
            prev_state = check_init(input_data, self.hidden_size, prev_state)
            return prev_state[0], prev_state[1]

    model = Example(10)
    random_data = torch.rand((1, 5, 30, 30))
    empty_tensor = torch.tensor([], dtype=torch.float).view(0, 0, 0, 0, 0)
    random_state = torch.rand((1, 1, 10, 30, 30))
    export_inputs = (random_data, empty_tensor)

    # Both inputs dynamic; additionally checked with a size-1 state.
    self.run_test(
        model,
        export_inputs,
        input_names=["data", "state"],
        dynamic_axes={"data": [0, 1, 2], "state": [0, 1, 2, 3, 4]},
        test_with_inputs=[(random_data, random_state)],
    )
    # Only `state` dynamic.
    self.run_test(
        model,
        export_inputs,
        input_names=["data", "state"],
        dynamic_axes={"state": [0, 1, 2, 3, 4]},
        test_with_inputs=[(random_data, random_state)],
        remained_onnx_input_idx=[1],
    )
    # No dynamic axes.
    self.run_test(model, export_inputs, remained_onnx_input_idx=[])
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(11)
def test_index_put_if_3(self):
    # A slice assignment inside a nested `if`, mixed with out-of-place
    # rebinding of `state` in the sibling branches.
    @torch.jit.script
    def check_init(input_data: Tensor, hidden_size: int, prev_state: Tensor) -> Tensor:
        batch_size = input_data.size(0)
        spatial_size_0 = input_data.size(2)
        spatial_size_1 = input_data.size(3)
        # generate empty prev_state, if None is provided
        state_size = (2, batch_size, hidden_size, spatial_size_0, spatial_size_1)
        state = torch.zeros(state_size, device=input_data.device)
        if prev_state.size(0) < 2:
            state = state * 3
            if prev_state.size(0) == 0:
                state[:] = torch.ones(batch_size, hidden_size, spatial_size_0, spatial_size_1) * 3
            else:
                state = state + 2

        return state

    class Example(torch.nn.Module):
        def __init__(self, hidden_size):
            super().__init__()
            self.hidden_size = hidden_size

        def forward(self, input_data, prev_state):
            prev_state = check_init(input_data, self.hidden_size, prev_state)
            return prev_state

    model = Example(4)
    inputs = (
        torch.rand((1, 5, 4, 4)),
        torch.tensor([], dtype=torch.float).view(0, 0, 0, 0, 0),
    )
    axes = {"random_data": [0, 1, 2, 3], "empty_tensor": [0, 1, 2, 3, 4]}
    self.run_test(model, inputs, input_names=["random_data", "empty_tensor"], dynamic_axes=axes)
    self.run_test(model, inputs, remained_onnx_input_idx=[])
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(11)
def test_index_put_if_4(self):
    # Alternates out-of-place rebinding (`state = state + 3`) with slice
    # assignment (`state[:] = ...`) inside one branch of a scripted `if`.
    @torch.jit.script
    def check_init(input_data: Tensor, hidden_size: int, prev_state: Tensor) -> Tensor:
        batch_size = input_data.size(0)
        spatial_size_0 = input_data.size(2)
        spatial_size_1 = input_data.size(3)
        # generate empty prev_state, if None is provided
        state_size = (2, batch_size, hidden_size, spatial_size_0, spatial_size_1)
        state = torch.zeros(state_size, device=input_data.device)
        if prev_state.size(0) == 0:
            state = state + 3
            state[:] = torch.ones(batch_size, hidden_size, spatial_size_0, spatial_size_1) * 3
            state = state + 3
            state[:] = torch.ones(batch_size, hidden_size, spatial_size_0, spatial_size_1) * 4
        else:
            state = state + 2
        return state

    class Example(torch.nn.Module):
        def __init__(self, hidden_size):
            super().__init__()
            self.hidden_size = hidden_size

        def forward(self, input_data, prev_state):
            prev_state = check_init(input_data, self.hidden_size, prev_state)
            return prev_state

    model = Example(4)
    inputs = (
        torch.rand((1, 5, 4, 4)),
        torch.tensor([], dtype=torch.float).view(0, 0, 0, 0, 0),
    )
    axes = {"random_data": [0, 1, 2, 3], "empty_tensor": [0, 1, 2, 3, 4]}
    self.run_test(model, inputs, input_names=["random_data", "empty_tensor"], dynamic_axes=axes)
    self.run_test(model, inputs, remained_onnx_input_idx=[])
|
|
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(11)
def test_index_put_if_5(self):
    # Like test_index_put_if_4, but a second name (`state_ref`) is bound to
    # the original tensor before the branch and is returned as well.
    @torch.jit.script
    def check_init(input_data: Tensor, hidden_size: int, prev_state: Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        batch_size = input_data.size(0)
        spatial_size_0 = input_data.size(2)
        spatial_size_1 = input_data.size(3)
        # generate empty prev_state, if None is provided
        state_size = (2, batch_size, hidden_size, spatial_size_0, spatial_size_1)
        state = torch.zeros(state_size, device=input_data.device)
        state_ref = state
        if prev_state.size(0) == 0:
            state[:] = torch.ones(batch_size, hidden_size, spatial_size_0, spatial_size_1) * 3
            state = state + 3
            state[:] = torch.ones(batch_size, hidden_size, spatial_size_0, spatial_size_1) * 4
        else:
            state = state + 2
        return state, state_ref

    class Example(torch.nn.Module):
        def __init__(self, hidden_size):
            super().__init__()
            self.hidden_size = hidden_size

        def forward(self, input_data, prev_state):
            prev_state, state_ref = check_init(input_data, self.hidden_size, prev_state)
            return prev_state, state_ref

    model = Example(4)
    inputs = (
        torch.rand((1, 5, 4, 4)),
        torch.tensor([], dtype=torch.float).view(0, 0, 0, 0, 0),
    )
    axes = {"random_data": [0, 1, 2, 3], "empty_tensor": [0, 1, 2, 3, 4]}
    self.run_test(model, inputs, input_names=["random_data", "empty_tensor"], dynamic_axes=axes)
    self.run_test(model, inputs, remained_onnx_input_idx=[])
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(11)
def test_list_append_in_block(self):
    # Scripted forward appends one matmul result per leading-dim slice of x
    # inside a loop and returns the list.
    class ListModel(torch.nn.Module):
        def forward(self, x, y):
            res = []
            for i in range(x.size(0)):
                res.append(torch.matmul(x[i], y))
            return res

    scripted = torch.jit.script(ListModel())
    lhs = torch.randn(16, 3, 4)
    rhs = torch.randn(4, 5)
    self.run_test(scripted, (lhs, rhs))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(13)
def test_list_append_in_nested_block(self):
    # List append performed inside two nested scripted loops.
    class ListModel(torch.nn.Module):
        def forward(self, x, y):
            res = []
            for i in range(x.size(0)):
                for j in range(x.size(1)):
                    res.append(torch.matmul(x[i][j], y))
            return res

    scripted = torch.jit.script(ListModel())
    lhs = torch.randn(4, 4, 3, 4)
    rhs = torch.randn(4, 5)
    self.run_test(scripted, (lhs, rhs))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(13)
def test_list_pop_in_block(self):
    # Exercises list.pop() inside scripted loops: fill the list, drain it,
    # then append-and-pop once per iteration.
    # NOTE(review): forward returns the result of `res.append(elem)` rather
    # than `res` itself -- kept as in the original; confirm that is intended.
    class ListModel(torch.nn.Module):
        def forward(self, x, y):
            res = []
            elem = torch.matmul(x[0], y)
            for i in range(x.size(0)):
                res.append(torch.matmul(x[i], y))
            for i in range(x.size(0)):
                elem = res.pop()
            for i in range(x.size(0)):
                res.append(torch.matmul(x[i], y))
                elem = res.pop()
            return res.append(elem)

    scripted = torch.jit.script(ListModel())
    lhs = torch.randn(16, 3, 4)
    rhs = torch.randn(4, 5)
    self.run_test(scripted, (lhs, rhs))
|
|
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(13)
def test_list_del_in_block(self):
    # Exercises `del res[0]` inside scripted loops: fill the list, delete
    # from the front, then append-and-delete once per iteration.
    # NOTE(review): forward returns the result of `res.append(elem)` rather
    # than `res` itself -- kept as in the original; confirm that is intended.
    class ListModel(torch.nn.Module):
        def forward(self, x, y):
            res = []
            elem = torch.matmul(x[0], y)
            for i in range(x.size(0)):
                res.append(torch.matmul(x[i], y))
            for i in range(x.size(0)):
                del res[0]
            for i in range(x.size(0)):
                res.append(torch.matmul(x[i], y))
                del res[0]
            return res.append(elem)

    scripted = torch.jit.script(ListModel())
    lhs = torch.randn(16, 3, 4)
    rhs = torch.randn(4, 5)
    self.run_test(scripted, (lhs, rhs))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(11)
def test_list_unpack(self):
    # Builds a list in a scripted loop and unpacks it into three names; the
    # input's leading dimension is 3 so the unpack length matches.
    class ListModel(torch.nn.Module):
        def forward(self, x, y):
            res = []
            elem = torch.matmul(x[0], y)
            for i in range(x.size(0)):
                res.append(torch.matmul(x[i], y))
            a, b, c = res
            return a, b

    scripted = torch.jit.script(ListModel())
    lhs = torch.randn(3, 3, 4)
    rhs = torch.randn(4, 5)
    self.run_test(scripted, (lhs, rhs))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(11)
def test_index_put_inplace_ops(self):
    # In-place `+=` / `/=` on an indexed slice (`state[1]`), both inside an
    # `if` and inside a loop of a scripted function.
    @torch.jit.script
    def check_init(input_data: Tensor, hidden_size: int) -> Tensor:
        batch_size = input_data.size(0)
        spatial_size_0 = input_data.size(2)
        spatial_size_1 = input_data.size(3)
        # generate empty prev_state, if None is provided
        state_size = (2, batch_size, hidden_size, spatial_size_0, spatial_size_1)
        state = torch.zeros(state_size, device=input_data.device)
        if input_data.size(0) == 1:
            state[1] += torch.ones(batch_size, hidden_size, spatial_size_0, spatial_size_1) * 2
            state[1] /= torch.ones(batch_size, hidden_size, spatial_size_0, spatial_size_1) * 3
        for i in range(input_data.size(0)):
            state[1] += torch.ones(batch_size, hidden_size, spatial_size_0, spatial_size_1)
            state[1] /= torch.ones(batch_size, hidden_size, spatial_size_0, spatial_size_1) * i
        return state

    class Example(torch.nn.Module):
        def __init__(self, hidden_size):
            super().__init__()
            self.hidden_size = hidden_size

        def forward(self, input_data):
            state = check_init(input_data, self.hidden_size)
            return state

    model = Example(10)
    random_data = torch.rand((1, 5, 30, 30))
    # The input is passed as a bare tensor, not wrapped in a tuple.
    self.run_test(
        model,
        random_data,
        input_names=["random_data"],
        dynamic_axes={"random_data": [0, 1, 2, 3]},
    )
    self.run_test(model, random_data, remained_onnx_input_idx=[])
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(11)
def test_input_mask_model(self):
    # Two boolean-mask assignment cases:
    #   1. `y[mask, :] = 0.0` with a single mask, followed by a bias add;
    #   2. `y[mask_1, mask_2] = 0.0` with two masks used together.
    class InputMaskModel(torch.nn.Module):
        def __init__(self, output_size):
            super().__init__()
            self.bias = torch.nn.Parameter(torch.empty(output_size, dtype=torch.float))
            with torch.no_grad():
                self.bias.zero_()

        def forward(self, model_input, y):
            input_mask = (model_input <= 0) | (model_input > 25)
            y[input_mask, :] = 0.0
            return y + self.bias

    output_size = 4
    m = InputMaskModel(output_size)
    x = torch.tensor([0, 4, 24, 25], dtype=torch.int64)
    # Four identical rows.
    y = torch.tensor([[0.1, 0.2, 0.3, 0.4]] * 4, dtype=torch.float)
    self.run_test(m, (x, y))

    # Second variant; the class name is intentionally reused as in the
    # original, and `output_size` is accepted but not stored.
    class InputMaskModel(torch.nn.Module):
        def __init__(self, output_size):
            super().__init__()

        def forward(self, model_input_1, model_input_2, y):
            input_mask_1 = (model_input_1 <= 0) | (model_input_1 > 25)
            input_mask_2 = (model_input_2 < 1) | (model_input_2 >= 12)
            y[input_mask_1, input_mask_2] = 0.0
            return y

    output_size = 4
    m = InputMaskModel(output_size)
    x1 = torch.tensor([0, 4, 24, 25], dtype=torch.int64)
    x2 = torch.tensor([0, 3, 12, 15], dtype=torch.int64)
    y = torch.tensor([[0.1, 0.2, 0.3, 0.4]] * 4, dtype=torch.float)
    self.run_test(m, (x1, x2, y))
|
|
|
|
@disableScriptTest()
def test_unsafe_chunk(self):
    # torch.unsafe_chunk splitting an 18-wide input into 3 chunks along dim 1.
    class ChunkModel(torch.nn.Module):
        def forward(self, x):
            return torch.unsafe_chunk(x, 3, dim=1)

    model = ChunkModel().eval()
    sample = torch.randn(1, 18)
    self.run_test(model, sample, input_names=["x"])
|
|
|
|
def test_symbolic_shape_inference(self):
    # Covers Shape, Reshape, Transpose and Gather.
    # (ConstantOfShape is covered by test_embedding_bag, Tile by test_repeat.)
    class ShapeModel(torch.nn.Module):
        def forward(self, x, y):
            shape = x.size()[:3] + (-1,)  # shape [4], ("batch", 3, 4, -1)
            y = y.reshape(shape)  # batch, 3, 4, 10/batch
            return y.transpose(1, 2)

    model = ShapeModel().eval()
    x = torch.ones(2, 3, 4, 5)
    y = torch.ones(3, 4, 5, 2)
    self.run_test(
        model,
        (x, y),
        input_names=["x", "y"],
        dynamic_axes={"x": [0, 1, 2, 3], "y": [0, 1, 2, 3]},
    )
    self.run_test(model, (x, y), remained_onnx_input_idx=[1])

    # Flattening a 0-dim tensor with view(-1).
    class ViewModel(torch.nn.Module):
        def forward(self, x):
            return x.view(-1)

    model = ViewModel().eval()
    x = torch.tensor(2.)
    self.run_test(model, (x,))

    # test prim::ListConstruct for Reshape input 1
    class ViewModel_2(torch.nn.Module):
        def forward(self, x):
            N, C, H, W = x.shape[0], x.shape[2], x.shape[3], x.shape[4]
            x1 = x.view(N, -1, C, H, W)
            x2 = x1.permute(0, 3, 4, 1, 2)
            return x2.reshape(N, -1, C)

    model = ViewModel_2().eval()
    x = torch.ones(2, 3, 4, 5, 6)
    self.run_test(model, x)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_symbolic_shape_inference_arange(self):
    # test Range: the arange length is derived from the input's sizes through
    # integer arithmetic.
    class ArangeModel(torch.nn.Module):
        def forward(self, signal):
            frame_step = 2
            outer_dimensions = signal.size()[:-2]
            frames, frame_length = signal.size()[-2:]

            subframe_length = signal.size()[0]
            subframe_step = frame_step // subframe_length
            subframes_per_frame = frame_length // subframe_length
            output_size = frame_step * (frames - 1) + frame_length
            output_subframes = output_size // subframe_length

            frame = torch.arange(0, output_subframes)
            return frame

    model = ArangeModel().eval()
    x = torch.randint(5, (1, 2, 3, 4))
    # Second input with every dimension but the first grown by one.
    y = torch.randint(5, (1, 3, 4, 5))
    all_axes = [0, 1, 2, 3]
    self.run_test(model, x, input_names=["x"], dynamic_axes={"x": list(all_axes)})
    self.run_test(model, x, remained_onnx_input_idx=[])
    self.run_test(
        model,
        x,
        input_names=["x"],
        dynamic_axes={"x": list(all_axes)},
        test_with_inputs=[(x,), (y,)],
    )
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(11)
def test_symbolic_shape_inference_box(self):
    # test NonZero: indices of boxes whose width and height both reach a
    # minimum size, obtained through torch.where on a boolean mask.
    class BoxModel(torch.nn.Module):
        def forward(self, boxes):
            min_size = 1e-2
            ws = boxes[:, 2] - boxes[:, 0]
            hs = boxes[:, 3] - boxes[:, 1]
            keep = (ws >= min_size) & (hs >= min_size)
            keep = torch.where(keep)[0]
            return keep

    model = BoxModel().eval()
    x = torch.ones(2, 4)
    y = torch.ones(3, 5)
    self.run_test(model, x)
    self.run_test(
        model,
        x,
        input_names=["x"],
        dynamic_axes={"x": [0, 1]},
        test_with_inputs=[(x,), (y,)],
    )
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(11)
def test_symbolic_shape_inference_box_if(self):
    # test If: the branch taken depends on numel() of a tensor selected by
    # data-dependent indices.
    class BoxIfModel(torch.nn.Module):
        def forward(self, boxes, scores):
            score_thresh = 0.0
            inds = torch.where(scores > score_thresh)[0]
            boxes_1 = boxes[inds]
            if boxes_1.numel() > 3:
                return boxes_1
            else:
                return boxes_1 * 2

    model = BoxIfModel().eval()
    boxes = torch.ones(2, 4)
    scores = torch.ones(1, 4)
    self.run_test(model, (boxes, scores))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(11)
def test_symbolic_shape_inference_arange_2(self):
    # test Range: start taken from an input size, float end/step, with an
    # integer result dtype and then a double result dtype.
    class ArangeModel(torch.nn.Module):
        def forward(self, start):
            return torch.arange(start.size(0), 8.5, 1.5, dtype=torch.int64)

    class ArangeModel2(torch.nn.Module):
        def forward(self, start):
            return torch.arange(start.size(0), 8.5, 1.5, dtype=torch.double)

    for model_cls in (ArangeModel, ArangeModel2):
        # A fresh random input and fresh model instances for each variant.
        x = torch.randn(2, 3, 4)
        self.run_test(model_cls(), (x,), input_names=['x'], dynamic_axes={"x": [0, 1, 2]})
        self.run_test(model_cls(), (x,), remained_onnx_input_idx=[])
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_symbolic_shape_inference_nonzero(self):
    # torch.nonzero applied to an all-ones and an all-zeros tensor shaped
    # like the input, for a 1-D and a 3-D input.
    class OneLikeModel(torch.nn.Module):
        def forward(self, x):
            ones = torch.ones_like(x, dtype=torch.float, layout=torch.strided, device=torch.device("cpu"))
            return torch.nonzero(ones)

    class ZeroLikeModel(torch.nn.Module):
        def forward(self, x):
            zeros = torch.zeros_like(x, dtype=torch.float, layout=torch.strided, device=torch.device("cpu"))
            return torch.nonzero(zeros)

    for model_cls in (OneLikeModel, ZeroLikeModel):
        for shape in ((2,), (2, 3, 4)):
            x = torch.randn(*shape)
            # Mark every axis of the input as dynamic.
            axes = list(range(len(shape)))
            self.run_test(model_cls(), x, input_names=['x'], dynamic_axes={"x": axes})
            self.run_test(model_cls(), x, remained_onnx_input_idx=[])
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_symbolic_shape_inference_expand_1(self):
    # expand() of a (6, 1) tensor to the constant shape (4, 6, 2).
    class ExpandModel(torch.nn.Module):
        def forward(self, x):
            return x.expand(4, 6, 2)

    sample = torch.randn(6, 1, requires_grad=True)
    self.run_test(ExpandModel(), (sample,))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
@disableScriptTest()  # Test code not scriptable
def test_symbolic_shape_inference_expand_2(self):
    # Builds a mask by comparing a repeated arange against a broadcast
    # arange; both extents come from the input's sizes.
    class M(torch.nn.Module):
        def forward(self, x):
            batch_size, seq_length = x.size()
            seq_ids = torch.arange(seq_length)
            repeated = seq_ids[None, None, :].repeat(batch_size, seq_length, 1)
            causal_mask = repeated <= seq_ids[None, :, None]
            return causal_mask.transpose(0, 1)

    sample = torch.randn(3, 16)
    self.run_test(M(), (sample,), input_names=["x"], dynamic_axes={"x": [0, 1]})
    self.run_test(M(), (sample,), remained_onnx_input_idx=[])
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(10)
@disableScriptTest()  # Test code not scriptable
def test_symbolic_shape_inference_slice(self):
    # Slice whose (negative) start comes from another input's size.
    class M(torch.nn.Module):
        def forward(self, x, position_bias):
            batch_size, seq_length = x.size()
            position_bias = position_bias[:, :, -seq_length:, :]
            return position_bias.transpose(0, 1)

    x = torch.randn(3, 16)
    position_bias = torch.randn(1, 3, 20, 8)
    inputs = (x, position_bias)
    self.run_test(
        M(),
        inputs,
        input_names=["x", "position_bias"],
        dynamic_axes={"x": [0, 1], "position_bias": [0, 1, 2, 3]},
    )
    self.run_test(M(), inputs, remained_onnx_input_idx=[1])
|
|
|
|
def test_symbolic_shape_inference_slice_2(self):
    # Slice with a constant negative start (-2) followed by a transpose.
    class M(torch.nn.Module):
        def forward(self, position_bias):
            position_bias = position_bias[:, :, -2:, :]
            return position_bias.transpose(0, 1)

    sample = torch.randn(1, 3, 20, 8)
    self.run_test(M(), (sample,))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
@disableScriptTest()
def test_symbolic_shape_inference_time(self):
    # Single-layer, unidirectional LSTM / GRU / RNN with the first two axes
    # of the sequence input marked dynamic.
    seq = torch.randn(RNN_SEQUENCE_LENGTH, BATCH_SIZE, RNN_INPUT_SIZE)
    h0 = torch.randn(1, BATCH_SIZE, RNN_HIDDEN_SIZE)
    c0 = torch.randn(1, BATCH_SIZE, RNN_HIDDEN_SIZE)

    lstm = torch.nn.LSTM(RNN_INPUT_SIZE, RNN_HIDDEN_SIZE, 1, bidirectional=False)
    self.run_test(lstm, (seq, (h0, c0)), input_names=["x", "y"], dynamic_axes={"x": [0, 1]})

    gru = torch.nn.GRU(RNN_INPUT_SIZE, RNN_HIDDEN_SIZE, 1, bidirectional=False, bias=False)
    self.run_test(gru, (seq, h0), input_names=["x", "y"], dynamic_axes={"x": [0, 1]})

    rnn = torch.nn.RNN(RNN_INPUT_SIZE, RNN_HIDDEN_SIZE, 1, bidirectional=False, bias=False)
    self.run_test(rnn, (seq, h0), input_names=["x", "y"], dynamic_axes={"x": [0, 1]})
|
|
|
|
def test_symbolic_shape_inference_dynamic_axes(self):
    # view() using the input's last size, with dynamic axes given as a
    # {axis: name} mapping rather than a list.
    class M(torch.nn.Module):
        def forward(self, input_ids):
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
            return input_ids.transpose(0, 1)

    sample = torch.randn(3, 16)
    named_axes = {"input_ids": {0: "batch", 1: "sequence"}}
    self.run_test(M(), (sample,), input_names=["input_ids"], dynamic_axes=named_axes)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_hann_window_periodic(self):
    # hann_window with periodic=True, its length supplied as a forward
    # argument and stored on the module before use.
    class HannWindowModule_Periodic(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.window_length = 0

        def forward(self, x, window_length: int):
            self.window_length = window_length
            return torch.add(x, torch.hann_window(self.window_length, periodic=True, dtype=torch.float))

    win_length = 100
    sample = torch.randn(win_length)

    self.run_test(HannWindowModule_Periodic(), (sample, win_length))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_hann_window_not_periodic(self):
    # hann_window with periodic=False, its length supplied as a forward
    # argument and stored on the module before use.
    class HannWindowModule_NotPeriodic(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.window_length = 0

        def forward(self, x, window_length: int):
            self.window_length = window_length
            return torch.add(x, torch.hann_window(self.window_length, periodic=False, dtype=torch.float))

    win_length = 100
    sample = torch.randn(win_length)

    self.run_test(HannWindowModule_NotPeriodic(), (sample, win_length))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
@disableScriptTest()
def test_hann_window_default_values(self):
    # hann_window called with only a window length (no `periodic` / `dtype`
    # arguments), passed through relu and added to the input.
    class HannWindowModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.window_length = 0

        def forward(self, x, window_length: int):
            # `F` is the file-level `torch.nn.functional` import; the
            # function-scope re-import was redundant and has been dropped.
            self.window_length = window_length
            return torch.add(x, F.relu(torch.hann_window(self.window_length)))

    win_length = 100
    x = torch.randn(win_length, dtype=torch.float)
    module = HannWindowModule()

    # Eager call kept from the original; its result was never read, so it is
    # no longer bound to an unused local.
    module(x, win_length)
    self.run_test(module, (x, win_length))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(12)
@disableScriptTest()
def test_tensordot_dim_count(self):
    # tensordot with an integer `dims` (2), on integer-valued inputs.
    class M(torch.nn.Module):
        def forward(self, x, y):
            return torch.tensordot(x, y, 2)

    lhs = torch.randint(6, (7, 5, 3, 4))
    rhs = torch.randint(6, (3, 4, 9, 2))

    self.run_test(M(), (lhs, rhs))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(12)
def test_tensordot_dim_list(self):
    # tensordot with explicit per-operand dimension lists, including
    # negative indices for the first operand.
    class M(torch.nn.Module):
        def forward(self, x, y):
            output = torch.tensordot(x, y, ([1, -2, -1], [1, 0, 3]))
            return output

    lhs = torch.randint(6, (7, 4, 3, 5, 2))
    rhs = torch.randint(6, (5, 4, 4, 2, 6))

    self.run_test(M(), (lhs, rhs))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(12)
@disableScriptTest()
def test_tensordot_dynamic_dim(self):
    # tensordot with integer `dims`, all input axes dynamic, re-checked with
    # a second pair of differently shaped inputs.
    class M(torch.nn.Module):
        def forward(self, x, y):
            return torch.tensordot(x, y, 2)

    x = torch.randint(6, (7, 5, 3, 4))
    y = torch.randint(6, (3, 4, 9, 2))

    new_x = torch.randint(6, (8, 6, 2, 5))
    new_y = torch.randint(6, (2, 5, 3, 4))

    self.run_test(
        M(),
        (x, y),
        test_with_inputs=[(new_x, new_y)],
        input_names=["input_x", "input_y"],
        dynamic_axes={"input_x": [0, 1, 2, 3], "input_y": [0, 1, 2, 3]},
    )
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_to_device(self):
    # Tensor.to() given another tensor's device, without and with a dtype.
    class M_ToDevice(torch.nn.Module):
        def forward(self, x, y):
            return x.to(y.device), y

    class M_ToDeviceDtype(torch.nn.Module):
        def forward(self, x, y):
            return x.to(y.device, dtype=torch.long), y

    x = torch.randn(6)
    y = torch.randn(6)

    for module in (M_ToDevice(), M_ToDeviceDtype()):
        self.run_test(module, (x, y))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
@disableScriptTest()
def test_fill(self):
    # fill_ with a value passed as a forward argument ...
    class FillModule(torch.nn.Module):
        def forward(self, x, filled_value: int):
            return x.fill_(filled_value)

    x = torch.randn((4, 5, 6))
    filled_value = 7
    self.run_test(FillModule(), (x, filled_value))

    # ... and fill_ with a float constant on an intermediate computed from an
    # integer (long) input.
    class FillScalarModule(torch.nn.Module):
        def forward(self, x):
            res = x + 2
            res.fill_(2.5)
            return res, x

    x = torch.ones(2, 3, 4, dtype=torch.long)
    self.run_test(FillScalarModule(), x)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_index_add_normal(self):
    # In-place index_add_ along dim 0, 1 and 2 of the same (5, 4, 3) input,
    # with index/updates stored on the module.
    class M(torch.nn.Module):
        def __init__(self, dim, index, updates):
            super().__init__()
            self.dim = dim
            self.index = index
            self.updates = updates

        def forward(self, x):
            x.index_add_(self.dim, self.index, self.updates)
            return x

    x = torch.ones(5, 4, 3)

    # dim 0
    updates = torch.tensor([[1], [4], [7], [3], [2]], dtype=torch.float)
    index = torch.tensor([0, 2, 3, 1, 4])
    self.run_test(M(0, index, updates), (x,))

    # dim 1
    updates = torch.tensor([[[1, 5, 7], [2, 4, 5], [5, 5, 6], [2, 3, 4]]], dtype=torch.float)
    index = torch.tensor([0, 2, 3, 1])
    self.run_test(M(1, index, updates), (x,))

    # dim 2
    updates = torch.tensor([[[1, 2, 3], [4, 5, 6], [7, 8, 9], [2, 3, 4]]], dtype=torch.float)
    index = torch.tensor([0, 2, 1])
    self.run_test(M(2, index, updates), (x,))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_index_add_dim_size_differ(self):
    # index_add_ along dim 1 where only three of the input's four positions
    # along that dim are indexed.
    class M(torch.nn.Module):
        def __init__(self, dim, index, updates):
            super().__init__()
            self.dim = dim
            self.index = index
            self.updates = updates

        def forward(self, x):
            x.index_add_(self.dim, self.index, self.updates)
            return x

    target = torch.ones(5, 4, 3)
    updates = torch.tensor([[[1, 5, 7], [2, 4, 5], [5, 5, 6]]], dtype=torch.float)
    index = torch.tensor([0, 2, 1])
    self.run_test(M(1, index, updates), (target,))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_index_add_in_loop(self):
    # index_add_ applied repeatedly inside a loop; the iteration count is a
    # random integer in [0, 20) drawn when the test runs.
    class M(torch.nn.Module):
        def __init__(self, dim, index, updates, loop_count):
            super().__init__()
            self.dim = dim
            self.index = index
            self.updates = updates
            self.loop_count = loop_count

        def forward(self, x):
            for i in range(self.loop_count):
                x.index_add_(self.dim, self.index, self.updates)
            return x

    target = torch.ones(5, 4, 3)
    updates = torch.tensor([[[1, 5, 7], [2, 4, 5], [5, 5, 6], [2, 3, 4]]], dtype=torch.float)
    index = torch.tensor([0, 2, 3, 1])
    loop_count = torch.randint(20, (1, ))[0].item()
    self.run_test(M(1, index, updates, loop_count), (target,))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_index_add_if(self):
    # index_add_ in both branches of a scripted `if`, each branch using a
    # different index tensor.
    class M(torch.nn.Module):
        def __init__(self, dim, updates, index_true, index_false):
            super().__init__()
            self.dim = dim
            self.updates = updates
            self.index_true = index_true
            self.index_false = index_false

        def forward(self, x, cond):
            if cond:
                x.index_add_(self.dim, self.index_true, self.updates)
            else:
                x.index_add_(self.dim, self.index_false, self.updates)
            return x

    target = torch.ones(5, 4, 3)
    updates = torch.tensor([[[1, 5, 7], [2, 4, 5], [5, 5, 6], [2, 3, 4]]], dtype=torch.float)
    index_true = torch.tensor([0, 2, 3, 1])
    index_false = torch.tensor([1, 0, 2, 3])
    cond = torch.tensor(1, dtype=torch.bool)
    scripted = torch.jit.script(M(1, updates, index_true, index_false))
    self.run_test(scripted, (target, cond))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_index_add_dynamic_axes(self):
    # index_add_ along dim 1 with the first two input axes dynamic,
    # re-checked with a larger (7, 8, 3) input.
    class M(torch.nn.Module):
        def __init__(self, dim, index, updates):
            super().__init__()
            self.dim = dim
            self.index = index
            self.updates = updates

        def forward(self, x):
            x.index_add_(self.dim, self.index, self.updates)
            return x

    x = torch.ones(5, 4, 3)
    y = torch.ones(7, 8, 3)
    updates = torch.tensor([[[1, 5, 7], [2, 4, 5], [5, 5, 6], [2, 3, 4]]], dtype=torch.float)
    index = torch.tensor([0, 2, 3, 1])

    self.run_test(
        M(1, index, updates),
        (x,),
        test_with_inputs=[y],
        input_names=['input_1'],
        dynamic_axes={'input_1': [0, 1]},
    )
|
|
|
|
def test_roll(self):
    # torch.roll with list and scalar shifts/dims, including negative values.
    class M(torch.nn.Module):
        def __init__(self, shifts, dims):
            super().__init__()
            self.shifts = shifts
            self.dims = dims

        def forward(self, x):
            return torch.roll(x, self.shifts, self.dims)

    sample = torch.randn(2, 3, 4)
    cases = (
        ([1, 1], [1, 0]),
        ([0, 1, 2], [1, 0, 2]),
        (2, 1),
        ([-1, 3], [-2, -1]),
    )
    for shifts, dims in cases:
        self.run_test(M(shifts, dims), (sample,))
|
|
|
|
def test_sum(self):
    # Full reduction with torch.sum, first input axis dynamic.
    class M(torch.nn.Module):
        def forward(self, x):
            return torch.sum(x)

    sample = torch.ones(12, 3)
    self.run_test(M(), (sample,), input_names=['x'], dynamic_axes={'x': [0]})
|
|
|
|
def test_sum_empty_tensor(self):
    # sum() over an empty slice (`x[0:0]`) and over the whole input, for a
    # non-empty input and two inputs that contain a zero-sized dimension.
    class M(torch.nn.Module):
        def forward(self, x):
            return x[0:0].sum(), x.sum()

    for shape in ((12,), (2, 0, 3), (0,)):
        self.run_test(M(), (torch.ones(shape),))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(11)
def test_broad_cast_tensors(self):
    # torch.broadcast_tensors on three pairs of broadcast-compatible shapes
    # (two integer pairs, one float pair).
    class M(torch.nn.Module):
        def forward(self, x, y):
            m = torch.broadcast_tensors(x, y)
            return m

    lhs = torch.randint(5, (1,))
    rhs = torch.randint(5, (5,))
    self.run_test(M(), (lhs, rhs))

    lhs = torch.randint(5, (4, 2, 1, 4))
    rhs = torch.randint(5, (2, 3, 1))
    self.run_test(M(), (lhs, rhs))

    lhs = torch.randn(2, 1, 4)
    rhs = torch.randn(5, 2, 3, 1)
    self.run_test(M(), (lhs, rhs))
|
|
|
|
@disableScriptTest()
@skipIfUnsupportedMinOpsetVersion(11)
def test_dist_normal(self):
    # Only the leading size of the drawn sample is returned (together with
    # the untouched inputs), so the comparison does not depend on the
    # random values themselves.
    class NormalSampleSize(torch.nn.Module):
        def forward(self, x, y):
            return torch.distributions.Normal(x, y).sample().size(0), x, y

    loc_scale_cases = (
        (torch.tensor([0.0]), torch.tensor([[1.0], [2.0]])),
        (torch.tensor([0.0]), torch.tensor([1.0])),
        (torch.tensor([[[0.0], [10.0]], [[2.0], [8.0]], [[2.0], [8.0]]]),
         torch.tensor([[1.0], [3.0]])),
    )
    for loc_scale in loc_scale_cases:
        self.run_test(NormalSampleSize(), loc_scale)
|
|
|
|
@disableScriptTest()
@skipIfUnsupportedMinOpsetVersion(11)
def test_dist_normal_correctness(self):
    # Statistical sanity check: draw 20000 normal samples through the
    # exported model and require the empirical mean and standard deviation
    # to be within 10% of the requested parameters.
    class M(torch.nn.Module):
        def forward(self, x, y):
            return torch.distributions.Normal(x, y).sample([20000])

    expected_mean = 5.0
    expected_std = 10.0

    model_export = M()
    dummy_input = (torch.tensor([expected_mean]), torch.tensor([expected_std]))
    ort_sess = convert_to_onnx(model_export, input=dummy_input, opset_version=self.opset_version,
                               training=torch.onnx.TrainingMode.EVAL)

    ort_out = run_ort(ort_sess, input=dummy_input)

    actual_std = np.std(ort_out)
    actual_mean = np.mean(ort_out)

    # Compare the signed mean: taking abs() of the measured mean first would
    # let a sign-flipped distribution (mean of -5) pass the check.
    assert abs(actual_mean - expected_mean) <= expected_mean * 0.1, \
        "the gap of mean between ort outputs and expected one is unacceptable."
    # np.std is never negative, so no abs() is needed on the measured value.
    assert abs(actual_std - expected_std) <= expected_std * 0.1, \
        "the gap of standard deviation between ort outputs and expected one is unacceptable."
|
|
|
|
@disableScriptTest()
@skipIfUnsupportedMinOpsetVersion(11)
def test_dist_uniform(self):
    # Only the leading size of the drawn sample is returned (together with
    # the untouched inputs), so the comparison does not depend on the
    # random values themselves.
    class UniformSampleSize(torch.nn.Module):
        def forward(self, x, y):
            return torch.distributions.Uniform(x, y).sample().size(0), x, y

    low_high_cases = (
        (torch.tensor([0.0]), torch.tensor([10.0])),
        (torch.tensor([[0.0], [6.0]]), torch.tensor([[1.0], [7.0]])),
        (torch.tensor([1.0]), torch.tensor([[10.0], [7.0], [9.0], [20.0]])),
    )
    for low_high in low_high_cases:
        self.run_test(UniformSampleSize(), low_high)
|
|
|
|
@disableScriptTest()
@skipIfUnsupportedMinOpsetVersion(11)
def test_dist_uniform_correctness(self):
    # Statistical sanity check: 10000 uniform samples drawn through the
    # exported model must stay inside [min, max] and have an empirical mean
    # within 5% of the midpoint.
    class UniformSampler(torch.nn.Module):
        def forward(self, x, y):
            return torch.distributions.Uniform(x, y).sample([10000])

    expected_min, expected_max = 5.0, 10.0
    expected_mean = (expected_min + expected_max) / 2

    bounds = (torch.tensor([expected_min]), torch.tensor([expected_max]))
    session = convert_to_onnx(
        UniformSampler(),
        input=bounds,
        opset_version=self.opset_version,
        training=torch.onnx.TrainingMode.EVAL,
    )
    samples = run_ort(session, input=bounds)

    smallest = np.min(samples)
    largest = np.max(samples)
    average = np.mean(samples)

    assert smallest >= expected_min, "the minimum value of ort outputs is out of scope."
    assert largest <= expected_max, "the maximum value of ort outputs is out of scope."
    assert abs(average - expected_mean) <= expected_mean * 0.05, \
        "the mean value of ort outputs is out of scope."
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(13)
def test_sequence_to_int(self):
    # A Python list of ints, one entry per row of the input, is turned into
    # an int tensor inside forward.
    class IntSequenceModule(torch.nn.Module):
        def forward(self, x):
            result = torch.tensor([2 for i in range(x.size()[0])], dtype=torch.int)
            return x, result

    sample = torch.randn(10, 5)
    self.run_test(IntSequenceModule(), (sample,))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(13)
def test_sequence_to_float(self):
    # A Python list of floats, one entry per row of the input, is turned
    # into a float tensor inside forward.
    class FloatSequenceModule(torch.nn.Module):
        def forward(self, x):
            result = torch.tensor([1.1 for i in range(x.size()[0])], dtype=torch.float)
            return x, result

    sample = torch.randn(10, 5)
    self.run_test(FloatSequenceModule(), (sample,))
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(13)
def test_sequence_to_bool(self):
    # A Python list of bools, one entry per row of the input, is turned
    # into a bool tensor inside forward.
    class BoolSequenceModule(torch.nn.Module):
        def forward(self, x):
            result = torch.tensor([False for i in range(x.size()[0])], dtype=torch.bool)
            return x, result

    sample = torch.randn(10, 5)
    self.run_test(BoolSequenceModule(), (sample,))
|
|
|
|
def test_onnx_checker_invalid_graph(self):
    # Export a model through a custom symbolic that attaches an attribute
    # named "invalid_attr" to the Add node. The export is expected to raise
    # CheckerError while still having written a loadable model to the buffer.
    class CustomAddModule(torch.nn.Module):
        def forward(self, x, y):
            return torch.add(x, y)

    def symbolic_custom_invalid_add(g, input, other, alpha=None):
        return g.op("Add", input, other, invalid_attr_i=1)

    register_custom_op_symbolic("::add", symbolic_custom_invalid_add, 1)

    x = torch.randn(2, 3, 4)
    y = torch.randn(2, 3, 4)

    test_model = CustomAddModule()
    f = io.BytesIO()

    try:
        # The raised exception object itself is not inspected, so it is not bound.
        with self.assertRaises(CheckerError):
            torch.onnx.export(test_model, (x, y), f)
    finally:
        # Always drop the custom symbolic so later tests see the regular "add".
        unregister_custom_op_symbolic("::add", 1)

    self.assertTrue(f.getvalue(), "ONNX graph was not exported.")
    # Loading is the check: it raises if the bytes are not a parseable model.
    onnx.load_from_string(f.getvalue())
|
|
|
|
|
|
def test_tuple_output_from_if_with_raised_exception(self):
    # Scripted module whose `if` raises in one branch and returns a tuple
    # in the other; the if/else shape is what is being exported.
    class RaiseOrTuple(torch.nn.Module):
        def forward(self, t: Tensor) -> Tuple[Tensor, Tensor]:
            if float(t) < 0:
                raise Exception("Negative input")
            else:
                return torch.zeros(5), torch.zeros(5)

    scripted = torch.jit.script(RaiseOrTuple())
    non_negative = torch.zeros(1)
    self.run_test(scripted, (non_negative,))
|
|
|
|
def test_shape_value_map(self):
    # With only the batch axis dynamic, the second output dimension of the
    # exported graph must be recorded as the static value 128.
    class RSoftMax(torch.nn.Module):
        def __init__(self, radix, cardinality):
            super().__init__()
            self.radix, self.cardinality = radix, cardinality

        def forward(self, x):
            batch = x.size(0)
            x = x.view(batch, self.cardinality, self.radix, -1).transpose(1, 2)
            x = F.softmax(x, dim=1)
            x = x.reshape(batch, -1)
            return x

    model = RSoftMax(radix=2, cardinality=1)
    sample = torch.randn(10, 1, 128, 1)

    buffer = io.BytesIO()
    torch.onnx.export(model, (sample, ), buffer, input_names=["x"], dynamic_axes={"x": [0]})

    exported = onnx.load_from_string(buffer.getvalue())
    output_dims = exported.graph.output[0].type.tensor_type.shape.dim
    self.assertEqual(output_dims[1].dim_value, 128)
|
|
|
|
# NOTE: For quantization tests, choose scale and zero point carefully
|
|
# such that inputs and outputs do not always overflow/underflow.
|
|
# Otherwise test results could be inaccurate.
|
|
@skipIfUnsupportedMinOpsetVersion(10)
def test_quantized_linear(self):
    # Quantized Linear with fully deterministic weight, bias and input.
    model = torch.nn.quantized.Linear(4, 8)
    # Set fixed weight to avoid flaky test.
    weight = torch.quantize_per_tensor(
        torch.arange(32, dtype=torch.float).view(8, 4),
        0.5, 0, torch.qint8)
    # Set non-zero bias.
    bias = torch.arange(8, dtype=torch.float)
    model.set_weight_bias(weight, bias)
    # Set fixed input to avoid flaky test. No random tensor is drawn here:
    # a random value would be discarded immediately, and the local name no
    # longer shadows the builtin `input`.
    float_input = torch.arange(16, dtype=torch.float).view(4, 4) - 8
    input_tensor = torch.quantize_per_tensor(float_input, 0.5, 128, torch.quint8)
    self.run_test(model, input_tensor)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(10)
def test_quantized_conv2d(self):
    # Quantized Conv2d with a random quantized weight and a ramp bias.
    model = torch.nn.quantized.Conv2d(16, 33, 3, stride=2)
    # Manually initialize model weight and bias to random numbers.
    # By default all zeros.
    float_weight = torch.randn(33, 16, 3, 3)
    q_weight = torch.quantize_per_tensor(float_weight, scale=0.5, zero_point=0, dtype=torch.qint8)
    bias = torch.arange(33).to(torch.float) - 16
    model.set_weight_bias(q_weight, bias)

    float_input = torch.randn(3, 16, 32, 32)
    q_input = torch.quantize_per_tensor(float_input, scale=0.5, zero_point=128, dtype=torch.quint8)
    self.run_test(model, q_input)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(10)
def test_quantized_adaptive_avg_pool2d(self):
    # Float AdaptiveAvgPool2d module applied to a quint8 input.
    pool = torch.nn.AdaptiveAvgPool2d((5, 7))
    float_input = torch.randn(4, 3, 10, 14)
    q_input = torch.quantize_per_tensor(float_input, scale=0.2, zero_point=128, dtype=torch.quint8)
    self.run_test(pool, q_input)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(10)
def test_quantized_conv2d_relu(self):
    # Quantized fused ConvReLU2d with a random quantized weight and a ramp bias.
    model = torch.nn.intrinsic.quantized.ConvReLU2d(16, 33, 3, stride=2)
    # Manually initialize model weight and bias to random numbers.
    # By default all zeros.
    float_weight = torch.randn(33, 16, 3, 3)
    q_weight = torch.quantize_per_tensor(float_weight, scale=0.5, zero_point=0, dtype=torch.qint8)
    bias = torch.arange(33).to(torch.float) - 16
    model.set_weight_bias(q_weight, bias)

    float_input = torch.randn(3, 16, 32, 32)
    q_input = torch.quantize_per_tensor(float_input, scale=0.5, zero_point=128, dtype=torch.quint8)
    self.run_test(model, q_input)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(10)
def test_quantized_hardswish(self):
    # Quantized Hardswish module applied to a quint8 input.
    hardswish = torch.nn.quantized.Hardswish(1., 0)
    float_input = torch.randn(2, 6)
    q_input = torch.quantize_per_tensor(float_input, scale=0.26, zero_point=128, dtype=torch.quint8)
    self.run_test(hardswish, q_input)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(10)
def test_quantized_hardsigmoid(self):
    # Float Hardsigmoid module applied to a quint8 input.
    hardsigmoid = torch.nn.Hardsigmoid()
    float_input = torch.randn(2, 6)
    q_input = torch.quantize_per_tensor(float_input, scale=0.26, zero_point=128, dtype=torch.quint8)
    self.run_test(hardsigmoid, q_input)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(10)
def test_quantized_flatten(self):
    # torch.flatten applied directly to a quint8 tensor.
    class FlattenModel(torch.nn.Module):
        def forward(self, input):
            return torch.flatten(input)

    float_values = torch.randn(1, 2, 3, 4)
    q_values = torch.quantize_per_tensor(float_values, 1, 0, torch.quint8)
    self.run_test(FlattenModel(), q_values)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(10)
@disableScriptTest()  # torch.jit.frontend.FrontendError: Cannot instantiate class 'QFunctional' in a script function:
def test_quantized_arithmetic_qfunctional(self):
    # Quantized add followed by mul, both issued through QFunctional.
    class ArithmeticModel(torch.nn.Module):
        def forward(self, x, y):
            o = torch.nn.quantized.QFunctional().add(x, y)
            o = torch.nn.quantized.QFunctional().mul(o, x)
            return o

    # Both operands share the same quantization parameters; x is drawn
    # before y.
    operands = tuple(
        torch.quantize_per_tensor(torch.randn(3, 4), 0.2, 128, torch.quint8)
        for _ in range(2)
    )
    self.run_test(ArithmeticModel(), operands)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(10)
def test_quantized_arithmetic(self):
    # Quantized add followed by mul through the torch.ops.quantized ops,
    # with explicit output scale and zero point.
    class ArithmeticModel2(torch.nn.Module):
        def forward(self, x, y):
            o = torch.ops.quantized.add(x, y, 0.4, 100)
            o = torch.ops.quantized.mul(o, x, 0.4, 100)
            return o

    # Both operands share the same quantization parameters; x is drawn
    # before y.
    operands = tuple(
        torch.quantize_per_tensor(torch.randn(3, 4), 0.2, 128, torch.quint8)
        for _ in range(2)
    )
    self.run_test(ArithmeticModel2(), operands)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(10)
def test_quantize_per_tensor(self):
    # The same float input is quantized to qint8 and to quint8 inside forward.
    class QuantizeBothDtypes(torch.nn.Module):
        def forward(self, x):
            return (torch.quantize_per_tensor(x, 0.2, 0, torch.qint8),
                    torch.quantize_per_tensor(x, 0.2, 128, torch.quint8))

    float_input = torch.randn(4, 6)
    self.run_test(QuantizeBothDtypes(), float_input)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(10)
def test_dequantize(self):
    # torch.dequantize applied to a qint8 input.
    class DequantizeModule(torch.nn.Module):
        def forward(self, x):
            return torch.dequantize(x)

    float_values = torch.randn(3, 4)
    q_values = torch.quantize_per_tensor(float_values, 0.2, 0, torch.qint8)
    self.run_test(DequantizeModule(), q_values)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(13)
def test_qat_linear_per_channel(self):
    # quant -> Linear -> dequant, run through prepare_qat/convert with the
    # "fbgemm" default qconfig before export.
    class QuantLinearModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.quant = torch.quantization.QuantStub()
            self.linear = torch.nn.Linear(4, 3)
            self.dequant = torch.quantization.DeQuantStub()

        def forward(self, x):
            return self.dequant(self.linear(self.quant(x)))

    prepared = QuantLinearModule()
    prepared.qconfig = torch.quantization.get_default_qconfig("fbgemm")
    prepared = torch.quantization.prepare_qat(prepared)
    # Set fixed weight and bias to avoid flaky test.
    fixed_weight = _construct_tensor_for_quantization_test((3, 4))
    prepared.linear.weight = torch.nn.Parameter(fixed_weight)
    prepared.linear.bias = torch.nn.Parameter(torch.arange(3, dtype=torch.float))
    converted = torch.quantization.convert(prepared)

    # Set fixed input to avoid flaky test.
    fixed_input = _construct_tensor_for_quantization_test((4, 4), offset=-8)
    self.run_test(converted, fixed_input)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(13)
def test_qat_relu(self):
    # quant -> ReLU -> dequant, run through prepare_qat/convert with the
    # "fbgemm" default qconfig before export.
    class QuantReluModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.quant = torch.quantization.QuantStub()
            self.relu = torch.nn.ReLU()
            self.dequant = torch.quantization.DeQuantStub()

        def forward(self, x):
            return self.dequant(self.relu(self.quant(x)))

    prepared = QuantReluModule()
    prepared.qconfig = torch.quantization.get_default_qconfig("fbgemm")
    prepared = torch.quantization.prepare_qat(prepared)
    converted = torch.quantization.convert(prepared)

    sample = torch.randn(8, 4)
    self.run_test(converted, sample)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(13)
def test_qat_conv2d(self):
    # quant -> Conv2d -> dequant, run through prepare_qat/convert with the
    # "fbgemm" default qconfig before export.
    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.quant = torch.quantization.QuantStub()
            # in_channels=4, out_channels=2: this matches the (2, 4, 3, 3)
            # weight, the 2-element bias and the 4-channel input used below.
            # Declaring the layer as (2, 4) left its channel metadata
            # inconsistent with the tensors actually assigned.
            self.conv = torch.nn.Conv2d(4, 2, 3, stride=2)
            self.dequant = torch.quantization.DeQuantStub()

        def forward(self, x):
            x = self.quant(x)
            x = self.conv(x)
            x = self.dequant(x)
            return x

    model = M()
    model.qconfig = torch.quantization.get_default_qconfig("fbgemm")
    model = torch.quantization.prepare_qat(model)
    # Set fixed weight and bias to avoid flaky test.
    model.conv.weight = torch.nn.Parameter(_construct_tensor_for_quantization_test((2, 4, 3, 3), max_val=2))
    model.conv.bias = torch.nn.Parameter(torch.tensor([0., 1.]))
    model = torch.quantization.convert(model)

    # Set fixed input to avoid flaky test.
    input = _construct_tensor_for_quantization_test((3, 4, 8, 8), offset=-384, max_val=12)
    self.run_test(model, input)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(13)
def test_qat_conv2d_relu(self):
    # quant -> Conv2d -> ReLU -> dequant (unfused), run through
    # prepare_qat/convert with the "fbgemm" default qconfig before export.
    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.quant = torch.quantization.QuantStub()
            # in_channels=4, out_channels=2: this matches the (2, 4, 3, 3)
            # weight, the 2-element bias and the 4-channel input used below.
            # Declaring the layer as (2, 4) left its channel metadata
            # inconsistent with the tensors actually assigned.
            self.conv = torch.nn.Conv2d(4, 2, 3, stride=2)
            self.relu = torch.nn.ReLU()
            self.dequant = torch.quantization.DeQuantStub()

        def forward(self, x):
            x = self.quant(x)
            x = self.conv(x)
            x = self.relu(x)
            x = self.dequant(x)
            return x

    model = M()
    model.qconfig = torch.quantization.get_default_qconfig("fbgemm")
    model = torch.quantization.prepare_qat(model)
    # Set fixed weight and bias to avoid flaky test.
    model.conv.weight = torch.nn.Parameter(_construct_tensor_for_quantization_test((2, 4, 3, 3), max_val=2))
    model.conv.bias = torch.nn.Parameter(torch.tensor([0., 1.]))
    model = torch.quantization.convert(model)

    # Set fixed input to avoid flaky test.
    input = _construct_tensor_for_quantization_test((3, 4, 8, 8), offset=-384, max_val=12)
    self.run_test(model, input)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(13)
def test_qat_conv2d_relu_fused(self):
    # quant -> Conv2d -> ReLU -> dequant with conv and relu fused before
    # prepare_qat/convert ("fbgemm" default qconfig), then exported.
    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.quant = torch.quantization.QuantStub()
            # in_channels=4, out_channels=2: this matches the (2, 4, 3, 3)
            # weight, the 2-element bias and the 4-channel input used below.
            # Declaring the layer as (2, 4) left its channel metadata
            # inconsistent with the tensors actually assigned.
            self.conv = torch.nn.Conv2d(4, 2, 3, stride=2)
            self.relu = torch.nn.ReLU()
            self.dequant = torch.quantization.DeQuantStub()

        def forward(self, x):
            x = self.quant(x)
            x = self.conv(x)
            x = self.relu(x)
            x = self.dequant(x)
            return x

    model = M()
    model.qconfig = torch.quantization.get_default_qconfig("fbgemm")
    # Fusion is applied to the eval-mode model; prepare_qat then receives
    # the model switched back to train mode.
    model = torch.quantization.fuse_modules(model.eval(), [["conv", "relu"]])
    model = torch.quantization.prepare_qat(model.train())
    # Set fixed weight and bias to avoid flaky test.
    model.conv.weight = torch.nn.Parameter(_construct_tensor_for_quantization_test((2, 4, 3, 3), max_val=2))
    model.conv.bias = torch.nn.Parameter(torch.tensor([0., 1.]))
    model = torch.quantization.convert(model)

    # Set fixed input to avoid flaky test.
    input = _construct_tensor_for_quantization_test((3, 4, 8, 8), offset=-384, max_val=12)
    self.run_test(model, input)
|
|
|
|
@skipIfUnsupportedMinOpsetVersion(9)
def test_convolution_allow_tf32(self):
    # torch._convolution called with and without the trailing extra
    # positional argument, selected by the module's allow_tf32 flag.
    class Module(torch.nn.Module):
        def __init__(self, allow_tf32):
            super().__init__()
            self.allow_tf32 = allow_tf32
            self.weight = torch.nn.Parameter(torch.rand(32, 3, 3, 3))

        def forward(self, x):
            if self.allow_tf32:
                return torch._convolution(x, self.weight, None, [2, 2], [0, 0], [1, 1], False, [0, 0],
                                          1, False, False, True, True)
            else:
                return torch._convolution(x, self.weight, None, [2, 2], [0, 0], [1, 1], False, [0, 0],
                                          1, False, False, True)

    x = torch.randn(1, 3, 224, 224)
    for allow_tf32 in (False, True):
        self.run_test(Module(allow_tf32), x, rtol=1e-3, atol=1e-6)
|
|
|
|
|
|
def make_test(name, base, layer, bidirectional, initial_state,
              variable_length, dropout, script_test_min_opset_version,
              **extra_kwargs):
    """Generate one RNN test method and attach it to ``_TestONNXRuntime``.

    ``layer``, ``bidirectional``, ``initial_state``, ``variable_length`` and
    ``dropout`` are indexable pairs: element 0 is the value forwarded to
    ``_dispatch_rnn_test`` and element 1 is the label used in the generated
    test name. ``extra_kwargs`` are forwarded to ``_dispatch_rnn_test``
    unchanged.
    """
    option_pairs = (layer, bidirectional, initial_state, variable_length, dropout)
    test_name = "_".join(["test", name] + [pair[1] for pair in option_pairs])

    # Keyword arguments for the dispatcher, taken from element 0 of each pair.
    rnn_options = dict(
        layers=layer[0],
        bidirectional=bidirectional[0],
        initial_state=initial_state[0],
        packed_sequence=variable_length[0],
        dropout=dropout[0],
    )

    # Cannot export with older opsets because of "ConstantFill" op
    # ConstantFill was a temp op removed at opset 8. This is no longer supported by onnxruntime
    # There are still some issues prevent us from enabling script test for these scenarios:
    # test_gru_*:
    #   Operator aten::as_tensor is not supported by exporter yet.
    #       - https://msdata.visualstudio.com/Vienna/_workitems/edit/1055382
    #   Operator aten::_pack_padded_sequence is not supported by exporter yet.
    #       - https://msdata.visualstudio.com/Vienna/_workitems/edit/1055384
    @disableScriptTest()
    @skipIfUnsupportedMinOpsetVersion(9)
    def f(self):
        self.is_script_test_enabled = self.opset_version >= script_test_min_opset_version
        self._dispatch_rnn_test(base, **rnn_options, **extra_kwargs)

    f.__name__ = test_name
    setattr(_TestONNXRuntime, test_name, f)
|
|
|
|
def setup_rnn_tests():
    """Register every RNN test variant on ``_TestONNXRuntime``.

    The variants are the cross product of the option tables below with the
    four RNN flavours (elman/relu, elman/tanh, lstm, gru). Raises
    ``ValueError`` when the number of generated tests is not the expected 192.
    """
    # Each option is a (value, label) pair; the label ends up in the test name.
    layers_opts = [(1, "unilayer"), (3, "trilayer")]
    bidirectional_opts = [(False, "forward"), (True, "bidirectional")]
    initial_state_opts = [(True, "with_initial_state"), (False, "no_initial_state")]
    variable_length_opts = [
        (0, "without_sequence_lengths"),
        (1, "with_variable_length_sequences"),
        (2, "with_batch_first_sequence_lengths"),
    ]
    dropout_opts = [(0.2, "with_dropout"), (0.0, "without_dropout")]

    # (base, test name component, extra kwargs for the dispatcher)
    rnn_flavours = (
        ("elman", "elman_relu", {"nonlinearity": "relu"}),
        ("elman", "elman_tanh", {"nonlinearity": "tanh"}),
        ("lstm", "lstm", {}),
        ("gru", "gru", {}),
    )

    # compiling in script mode fails for these bases:
    #   elman: torch.jit.frontend.UnsupportedNodeError: annotated assignments
    #          without assigned value aren't supported
    #          https://msdata.visualstudio.com/Vienna/_workitems/edit/1160723
    #   lstm:  RuntimeError: Arguments for call are not valid.
    #          https://msdata.visualstudio.com/Vienna/_workitems/edit/1160723
    script_unsupported_bases = ("elman", "lstm")

    option_tables = (
        layers_opts,
        bidirectional_opts,
        initial_state_opts,
        variable_length_opts,
        dropout_opts,
    )
    test_count = 0
    for layer, bidirectional, initial_state, variable_length, dropout in itertools.product(*option_tables):
        for base, name, extra_kwargs in rnn_flavours:
            if base in script_unsupported_bases:
                # No opset is ever >= infinity, so script testing stays off.
                script_test_min_opset_version = float("inf")
            else:
                # Need Add between list of tensors
                script_test_min_opset_version = 11
            make_test(name, base, layer, bidirectional, initial_state,
                      variable_length, dropout, script_test_min_opset_version,
                      **extra_kwargs)
            test_count += 1

    # sanity check that a representative example does exist
    _TestONNXRuntime.test_gru_trilayer_forward_with_initial_state_without_sequence_lengths_with_dropout

    # make sure no one accidentally disables all the tests without
    # noticing
    if test_count != 192:
        raise ValueError("Expected 192 tests but found {}".format(test_count))
|
|
|
|
# Generate the RNN tests at import time, so they are already attributes of
# _TestONNXRuntime when MakeTestCase copies its __dict__ below.
setup_rnn_tests()
|
|
|
|
|
|
def MakeTestCase(opset_version: int, keep_initializers_as_inputs: bool = True) -> type:
    """Build a ``unittest.TestCase`` subclass bound to one opset version.

    The new class gets a copy of every attribute of ``_TestONNXRuntime`` plus
    the ``opset_version`` and ``keep_initializers_as_inputs`` class attributes.
    Its name is ``TestONNXRuntime_opset<N>``, with an ``_IRv4`` suffix when
    ``keep_initializers_as_inputs`` is false.
    """
    suffix = "" if keep_initializers_as_inputs else "_IRv4"
    class_name = f"TestONNXRuntime_opset{opset_version}{suffix}"

    namespace = dict(_TestONNXRuntime.__dict__)
    namespace.update(
        opset_version=opset_version,
        keep_initializers_as_inputs=keep_initializers_as_inputs,
    )
    return type(class_name, (unittest.TestCase,), namespace)
|
|
|
|
|
|
# Concrete TestCase classes, one per exported opset. Calls without the
# keyword use MakeTestCase's default keep_initializers_as_inputs=True.
TestONNXRuntime_opset7 = MakeTestCase(7)

TestONNXRuntime_opset8 = MakeTestCase(8)

# Opsets 9-12 are generated twice: once with the default and once with
# keep_initializers_as_inputs=False (the "_IRv4" variants).
TestONNXRuntime_opset9 = MakeTestCase(9)

TestONNXRuntime_opset9_IRv4 = MakeTestCase(9, keep_initializers_as_inputs=False)

TestONNXRuntime_opset10 = MakeTestCase(10)

TestONNXRuntime_opset10_IRv4 = MakeTestCase(10, keep_initializers_as_inputs=False)

TestONNXRuntime_opset11 = MakeTestCase(11)

TestONNXRuntime_opset11_IRv4 = MakeTestCase(11, keep_initializers_as_inputs=False)

TestONNXRuntime_opset12 = MakeTestCase(12)

TestONNXRuntime_opset12_IRv4 = MakeTestCase(12, keep_initializers_as_inputs=False)

# Opsets 13-15 are only generated with keep_initializers_as_inputs=False.
# MakeTestCase still appends "_IRv4" to the class __name__, although the
# module-level names here carry no such suffix.
TestONNXRuntime_opset13 = MakeTestCase(13, keep_initializers_as_inputs=False)

TestONNXRuntime_opset14 = MakeTestCase(14, keep_initializers_as_inputs=False)

TestONNXRuntime_opset15 = MakeTestCase(15, keep_initializers_as_inputs=False)
|
|
|
|
|
|
# Run the unittest command-line runner when this file is executed as a script.
if __name__ == "__main__":
    unittest.main()
|