Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 12:54:11 +08:00.
See https://github.com/pytorch/pytorch/pull/129751#issue-2380881501. Most changes are auto-generated by linter. You can review these PRs via: ```bash git diff --ignore-all-space --ignore-blank-lines HEAD~1 ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/129759 Approved by: https://github.com/justinchuby, https://github.com/ezyang
288 lines
11 KiB
Python
288 lines
11 KiB
Python
# Owner(s): ["module: onnx"]
|
|
|
|
import unittest
|
|
|
|
import pytorch_test_common
|
|
from model_defs.dcgan import _netD, _netG, bsz, imgsz, nz, weights_init
|
|
from model_defs.emb_seq import EmbeddingNetwork1, EmbeddingNetwork2
|
|
from model_defs.mnist import MNIST
|
|
from model_defs.op_test import ConcatNet, DummyNet, FakeQuantNet, PermuteNet, PReluNet
|
|
from model_defs.squeezenet import SqueezeNet
|
|
from model_defs.srresnet import SRResNet
|
|
from model_defs.super_resolution import SuperResolutionNet
|
|
from pytorch_test_common import skipIfUnsupportedMinOpsetVersion, skipScriptTest
|
|
from torchvision.models import shufflenet_v2_x1_0
|
|
from torchvision.models.alexnet import alexnet
|
|
from torchvision.models.densenet import densenet121
|
|
from torchvision.models.googlenet import googlenet
|
|
from torchvision.models.inception import inception_v3
|
|
from torchvision.models.mnasnet import mnasnet1_0
|
|
from torchvision.models.mobilenet import mobilenet_v2
|
|
from torchvision.models.resnet import resnet50
|
|
from torchvision.models.segmentation import deeplabv3_resnet101, fcn_resnet101
|
|
from torchvision.models.vgg import vgg16, vgg16_bn, vgg19, vgg19_bn
|
|
from torchvision.models.video import mc3_18, r2plus1d_18, r3d_18
|
|
from verify import verify
|
|
|
|
import torch
|
|
from torch.ao import quantization
|
|
from torch.autograd import Variable
|
|
from torch.onnx import OperatorExportTypes
|
|
from torch.testing._internal import common_utils
|
|
from torch.testing._internal.common_utils import skipIfNoLapack
|
|
|
|
|
|
# Pick the device-transfer helper once at import time: `toC` moves a tensor
# to CUDA when a GPU is present, and is a no-op otherwise.
if not torch.cuda.is_available():

    def toC(x):
        """Identity: no CUDA device is available, so keep *x* where it is."""
        return x

else:

    def toC(x):
        """Move *x* onto the default CUDA device."""
        return x.cuda()
|
|
|
|
|
|
BATCH_SIZE = 2
|
|
|
|
|
|
class TestModels(pytorch_test_common.ExportTestCase):
    """End-to-end ONNX export tests for a zoo of vision models.

    Each test traces a model, lints the resulting JIT graph, and numerically
    verifies the ONNX export against the Caffe2 ONNX backend via ``verify``.
    """

    opset_version = 9  # Caffe2 doesn't support the default.
    keep_initializers_as_inputs = False

    def exportTest(self, model, inputs, rtol=1e-2, atol=1e-7, **kwargs):
        """Trace ``model`` on ``inputs`` and verify the export against Caffe2.

        Args:
            model: the ``torch.nn.Module`` to export.
            inputs: example input(s) used for tracing and verification.
            rtol/atol: numerical tolerances forwarded to ``verify``.

        NOTE(review): extra ``**kwargs`` (e.g. ``acceptable_error_percentage``
        passed by ``test_inception``) are accepted but never forwarded to
        ``verify`` — confirm whether they should be passed through.
        """
        # Imported lazily so that collecting these tests does not require a
        # Caffe2 installation.
        import caffe2.python.onnx.backend as backend

        with torch.onnx.select_model_mode_for_export(
            model, torch.onnx.TrainingMode.EVAL
        ):
            graph = torch.onnx.utils._trace(model, inputs, OperatorExportTypes.ONNX)
            torch._C._jit_pass_lint(graph)

            verify(
                model,
                inputs,
                backend,
                rtol=rtol,
                atol=atol,
                opset_version=self.opset_version,
            )

    def test_ops(self):
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
        self.exportTest(toC(DummyNet()), toC(x))

    def test_prelu(self):
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
        self.exportTest(PReluNet(), x)

    @skipScriptTest()
    def test_concat(self):
        input_a = Variable(torch.randn(BATCH_SIZE, 3))
        input_b = Variable(torch.randn(BATCH_SIZE, 3))
        inputs = ((toC(input_a), toC(input_b)),)
        self.exportTest(toC(ConcatNet()), inputs)

    def test_permute(self):
        x = Variable(torch.randn(BATCH_SIZE, 3, 10, 12))
        self.exportTest(PermuteNet(), x)

    @skipScriptTest()
    def test_embedding_sequential_1(self):
        x = Variable(torch.randint(0, 10, (BATCH_SIZE, 3)))
        self.exportTest(EmbeddingNetwork1(), x)

    @skipScriptTest()
    def test_embedding_sequential_2(self):
        x = Variable(torch.randint(0, 10, (BATCH_SIZE, 3)))
        self.exportTest(EmbeddingNetwork2(), x)

    @unittest.skip("This model takes too much memory")
    def test_srresnet(self):
        x = Variable(torch.randn(1, 3, 224, 224).fill_(1.0))
        self.exportTest(
            toC(SRResNet(rescale_factor=4, n_filters=64, n_blocks=8)), toC(x)
        )

    @skipIfNoLapack
    def test_super_resolution(self):
        x = Variable(torch.randn(BATCH_SIZE, 1, 224, 224).fill_(1.0))
        self.exportTest(toC(SuperResolutionNet(upscale_factor=3)), toC(x), atol=1e-6)

    def test_alexnet(self):
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
        self.exportTest(toC(alexnet()), toC(x))

    def test_mnist(self):
        x = Variable(torch.randn(BATCH_SIZE, 1, 28, 28).fill_(1.0))
        self.exportTest(toC(MNIST()), toC(x))

    @unittest.skip("This model takes too much memory")
    def test_vgg16(self):
        # VGG 16-layer model (configuration "D")
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
        self.exportTest(toC(vgg16()), toC(x))

    @unittest.skip("This model takes too much memory")
    def test_vgg16_bn(self):
        # VGG 16-layer model (configuration "D") with batch normalization
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
        self.exportTest(toC(vgg16_bn()), toC(x))

    @unittest.skip("This model takes too much memory")
    def test_vgg19(self):
        # VGG 19-layer model (configuration "E")
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
        self.exportTest(toC(vgg19()), toC(x))

    @unittest.skip("This model takes too much memory")
    def test_vgg19_bn(self):
        # VGG 19-layer model (configuration "E") with batch normalization
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
        self.exportTest(toC(vgg19_bn()), toC(x))

    def test_resnet(self):
        # ResNet50 model
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
        self.exportTest(toC(resnet50()), toC(x), atol=1e-6)

    # This test is numerically unstable. Sporadic single element mismatch occurs occasionally.
    def test_inception(self):
        x = Variable(torch.randn(BATCH_SIZE, 3, 299, 299))
        self.exportTest(toC(inception_v3()), toC(x), acceptable_error_percentage=0.01)

    def test_squeezenet(self):
        # SqueezeNet: AlexNet-level accuracy with 50x fewer parameters and
        # <0.5MB model size
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
        # Fixed: previously constructed with version=1.1, so the v1.0
        # architecture was never actually exported (v1.1 was tested twice).
        sqnet_v1_0 = SqueezeNet(version=1.0)
        self.exportTest(toC(sqnet_v1_0), toC(x))

        # SqueezeNet 1.1 has 2.4x less computation and slightly fewer params
        # than SqueezeNet 1.0, without sacrificing accuracy.
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
        sqnet_v1_1 = SqueezeNet(version=1.1)
        self.exportTest(toC(sqnet_v1_1), toC(x))

    def test_densenet(self):
        # Densenet-121 model
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
        self.exportTest(toC(densenet121()), toC(x), rtol=1e-2, atol=1e-5)

    @skipScriptTest()
    def test_dcgan_netD(self):
        netD = _netD(1)
        netD.apply(weights_init)
        input = Variable(torch.empty(bsz, 3, imgsz, imgsz).normal_(0, 1))
        self.exportTest(toC(netD), toC(input))

    @skipScriptTest()
    def test_dcgan_netG(self):
        netG = _netG(1)
        netG.apply(weights_init)
        input = Variable(torch.empty(bsz, nz, 1, 1).normal_(0, 1))
        self.exportTest(toC(netG), toC(input))

    @skipIfUnsupportedMinOpsetVersion(10)
    def test_fake_quant(self):
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
        self.exportTest(toC(FakeQuantNet()), toC(x))

    @skipIfUnsupportedMinOpsetVersion(10)
    def test_qat_resnet_pertensor(self):
        # Quantize ResNet50 model
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
        qat_resnet50 = resnet50()

        # Use per tensor for weight. Per channel support will come with opset 13
        qat_resnet50.qconfig = quantization.QConfig(
            activation=quantization.default_fake_quant,
            weight=quantization.default_fake_quant,
        )
        quantization.prepare_qat(qat_resnet50, inplace=True)
        qat_resnet50.apply(torch.ao.quantization.enable_observer)
        qat_resnet50.apply(torch.ao.quantization.enable_fake_quant)

        # One forward pass so the observers record ranges, then freeze qparams.
        _ = qat_resnet50(x)
        for module in qat_resnet50.modules():
            if isinstance(module, quantization.FakeQuantize):
                module.calculate_qparams()
        qat_resnet50.apply(torch.ao.quantization.disable_observer)

        self.exportTest(toC(qat_resnet50), toC(x))

    @skipIfUnsupportedMinOpsetVersion(13)
    def test_qat_resnet_per_channel(self):
        # Quantize ResNet50 model
        x = torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0)
        qat_resnet50 = resnet50()

        qat_resnet50.qconfig = quantization.QConfig(
            activation=quantization.default_fake_quant,
            weight=quantization.default_per_channel_weight_fake_quant,
        )
        quantization.prepare_qat(qat_resnet50, inplace=True)
        qat_resnet50.apply(torch.ao.quantization.enable_observer)
        qat_resnet50.apply(torch.ao.quantization.enable_fake_quant)

        # One forward pass so the observers record ranges, then freeze qparams.
        _ = qat_resnet50(x)
        for module in qat_resnet50.modules():
            if isinstance(module, quantization.FakeQuantize):
                module.calculate_qparams()
        qat_resnet50.apply(torch.ao.quantization.disable_observer)

        self.exportTest(toC(qat_resnet50), toC(x))

    @skipScriptTest(skip_before_opset_version=15, reason="None type in outputs")
    def test_googlenet(self):
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
        self.exportTest(toC(googlenet()), toC(x), rtol=1e-3, atol=1e-5)

    def test_mnasnet(self):
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
        self.exportTest(toC(mnasnet1_0()), toC(x), rtol=1e-3, atol=1e-5)

    def test_mobilenet(self):
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
        self.exportTest(toC(mobilenet_v2()), toC(x), rtol=1e-3, atol=1e-5)

    @skipScriptTest()  # prim_data
    def test_shufflenet(self):
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
        self.exportTest(toC(shufflenet_v2_x1_0()), toC(x), rtol=1e-3, atol=1e-5)

    @skipIfUnsupportedMinOpsetVersion(11)
    def test_fcn(self):
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
        self.exportTest(
            toC(fcn_resnet101(weights=None, weights_backbone=None)),
            toC(x),
            rtol=1e-3,
            atol=1e-5,
        )

    @skipIfUnsupportedMinOpsetVersion(11)
    def test_deeplab(self):
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
        self.exportTest(
            toC(deeplabv3_resnet101(weights=None, weights_backbone=None)),
            toC(x),
            rtol=1e-3,
            atol=1e-5,
        )

    def test_r3d_18_video(self):
        x = Variable(torch.randn(1, 3, 4, 112, 112).fill_(1.0))
        self.exportTest(toC(r3d_18()), toC(x), rtol=1e-3, atol=1e-5)

    def test_mc3_18_video(self):
        x = Variable(torch.randn(1, 3, 4, 112, 112).fill_(1.0))
        self.exportTest(toC(mc3_18()), toC(x), rtol=1e-3, atol=1e-5)

    def test_r2plus1d_18_video(self):
        x = Variable(torch.randn(1, 3, 4, 112, 112).fill_(1.0))
        self.exportTest(toC(r2plus1d_18()), toC(x), rtol=1e-3, atol=1e-5)
|
|
|
|
|
|
# Standard PyTorch test entry point: discovers and runs the TestCase above.
if __name__ == "__main__":
    common_utils.run_tests()
|