# Owner(s): ["module: onnx"]

import os
import unittest

import onnx_test_common
import parameterized

# `import PIL` alone does not guarantee that the `PIL.Image` submodule is
# loaded; import it explicitly (this still binds the top-level `PIL` name).
import PIL.Image
import torchvision

import torch
from torch import nn
from torch.testing._internal import common_utils


def _get_test_image_tensor():
    """Load the bundled Grace Hopper test image as a preprocessed model input.

    Reads ``assets/grace_hopper_517x606.jpg`` (relative to this file) and
    applies the preprocessing below: resize to 256, center-crop to 224,
    convert to a tensor, and normalize with the given per-channel mean/std.

    Returns:
        A float tensor of shape ``(1, 3, 224, 224)`` (batch of one RGB image).
    """
    data_dir = os.path.join(os.path.dirname(__file__), "assets")
    img_path = os.path.join(data_dir, "grace_hopper_517x606.jpg")
    # Based on example from https://pytorch.org/hub/pytorch_vision_resnet/
    preprocess = torchvision.transforms.Compose(
        [
            torchvision.transforms.Resize(256),
            torchvision.transforms.CenterCrop(224),
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize(
                mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
            ),
        ]
    )
    # Close the file handle deterministically. ``convert("RGB")`` loads the
    # pixel data before the file is closed and guarantees the 3 channels that
    # the 3-element Normalize mean/std expects (a copy for RGB inputs).
    with PIL.Image.open(img_path) as input_image:
        rgb_image = input_image.convert("RGB")
    return preprocess(rgb_image).unsqueeze(0)


# Due to precision error from quantization, check only that the top prediction matches.
class _TopPredictor(nn.Module):
    """Wrap a classifier so that it returns only its top-1 class index.

    Comparing the top prediction instead of the raw outputs tolerates the
    small numeric differences introduced by quantization.
    """

    def __init__(self, base_model):
        super().__init__()
        self.base_model = base_model

    def forward(self, x):
        x = self.base_model(x)
        # Only the first batch element is inspected; ``topk`` with k=1 yields a
        # one-element tensor holding the index of the largest score.
        _, topk_id = torch.topk(x[0], 1)
        return topk_id


# TODO: All torchvision quantized model tests can be written as a single parameterized test case,
# after per-parameter test decoration is supported via #79979, or after they are all enabled,
# whichever comes first.
@parameterized.parameterized_class(
    ("is_script",),
    [(True,), (False,)],
    class_name_func=onnx_test_common.parameterize_class_name,
)
class TestQuantizedModelsONNXRuntime(onnx_test_common._TestONNXRuntime):
    """Run torchvision's pretrained quantized classifiers through the
    ``onnx_test_common._TestONNXRuntime`` harness.

    ``parameterized_class`` generates one variant of this class with
    ``is_script=True`` and one with ``is_script=False``.
    """

    def run_test(self, model, inputs, *args, **kwargs):
        # Route every model through _TopPredictor so only the top-1 class
        # index is compared, not the raw quantized scores.
        top1_model = _TopPredictor(model)
        return super().run_test(top1_model, inputs, *args, **kwargs)

    def _run_quantized_classifier(self, model_factory):
        """Build a pretrained, quantized model from *model_factory* and test it
        on the bundled sample image."""
        quantized_model = model_factory(pretrained=True, quantize=True)
        self.run_test(quantized_model, _get_test_image_tensor())

    def test_mobilenet_v3(self):
        self._run_quantized_classifier(
            torchvision.models.quantization.mobilenet_v3_large
        )

    @unittest.skip("quantized::cat not supported")
    def test_inception_v3(self):
        self._run_quantized_classifier(torchvision.models.quantization.inception_v3)

    @unittest.skip("quantized::cat not supported")
    def test_googlenet(self):
        self._run_quantized_classifier(torchvision.models.quantization.googlenet)

    @unittest.skip("quantized::cat not supported")
    def test_shufflenet_v2_x0_5(self):
        self._run_quantized_classifier(
            torchvision.models.quantization.shufflenet_v2_x0_5
        )

    def test_resnet18(self):
        self._run_quantized_classifier(torchvision.models.quantization.resnet18)

    def test_resnet50(self):
        self._run_quantized_classifier(torchvision.models.quantization.resnet50)

    def test_resnext101_32x8d(self):
        self._run_quantized_classifier(
            torchvision.models.quantization.resnext101_32x8d
        )


if __name__ == "__main__":
    common_utils.run_tests()