mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-11-04 16:04:58 +08:00 
			
		
		
		
	Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/63914 Test Plan: Imported from OSS Reviewed By: dreiss Differential Revision: D30531889 fbshipit-source-id: a65e389da2722efbc62e3fe1edf503732326350d
		
			
				
	
	
		
			675 lines
		
	
	
		
			24 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			675 lines
		
	
	
		
			24 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
#!/usr/bin/env python3
 | 
						|
import os
 | 
						|
import ctypes
 | 
						|
import torch
 | 
						|
from typing import Tuple
 | 
						|
from torch.backends._nnapi.prepare import convert_model_to_nnapi
 | 
						|
from torch.testing._internal.common_utils import TestCase, run_tests
 | 
						|
 | 
						|
 | 
						|
def qpt(t, scale, zero_point, dtype=torch.quint8):
    """Quantize ``t`` per-tensor with the given ``scale`` and ``zero_point``.

    ``t`` may be an existing tensor or any data ``torch.tensor`` accepts
    (e.g. a nested list of floats). Existing tensors are used directly
    (detached) rather than copy-constructed via ``torch.tensor``, which
    emits a UserWarning for tensor inputs.

    Returns:
        A new quantized tensor of ``dtype`` (``torch.quint8`` by default).
    """
    if isinstance(t, torch.Tensor):
        # quantize_per_tensor returns a fresh tensor, so no defensive copy is
        # needed; detach() matches torch.tensor's autograd-free result.
        t = t.detach()
    else:
        t = torch.tensor(t)
    return torch.quantize_per_tensor(t, scale, zero_point, dtype)
 | 
						|
 | 
						|
 | 
						|
def nhwc(t):
    """Return a channels-last copy of ``t`` tagged with ``nnapi_nhwc = True``.

    The input tensor is cloned first, so ``t`` itself is never modified or
    tagged.
    """
    result = t.clone()
    result = result.contiguous(memory_format=torch.channels_last)
    # Python-level marker on the copy; presumably consumed by the NNAPI
    # converter to treat the input as NHWC — confirm in torch.backends._nnapi.
    result.nnapi_nhwc = True
    return result
 | 
						|
 | 
						|
 | 
						|
class TestNNAPI(TestCase):
    """Tests for lowering traced TorchScript modules to NNAPI.

    Every test converts a traced module via ``call_lowering_to_nnapi``. When
    the ``LIBNEURALNETWORKS_PATH`` environment variable names a library to
    load, the converted model is also executed and its output compared with
    eager PyTorch; otherwise only successful conversion is checked.
    """

    def setUp(self):
        """Select the QNNPACK engine and decide whether NNAPI models can run."""
        # Run the base TestCase setup first; skipping it bypasses whatever
        # per-test initialization the shared test harness performs.
        super().setUp()
        # Avoid saturation in fbgemm
        torch.backends.quantized.engine = 'qnnpack'

        libneuralnetworks_path = os.environ.get("LIBNEURALNETWORKS_PATH")
        if libneuralnetworks_path:
            ctypes.cdll.LoadLibrary(libneuralnetworks_path)
            print("Will attempt to run NNAPI models.")
            self.can_run_nnapi = True
        else:
            self.can_run_nnapi = False

    # Created for easy override by subclasses (eg TestNnapiBackend)
    def call_lowering_to_nnapi(self, traced_module, args):
        """Lower ``traced_module`` to NNAPI using ``args`` as example inputs."""
        return convert_model_to_nnapi(traced_module, args)

    # Created for subclasses to set can_run_nnapi (eg TestNnapiBackend)
    def set_can_run_nnapi(self, can_run):
        """Override whether converted models are executed and compared."""
        self.can_run_nnapi = can_run

    def check(
        self,
        module,
        arg_or_args,
        *,
        trace_args=None,
        convert_args=None,
        atol_rtol=None,
        limit=None,
        expected_memory_format=None
    ):
        """Trace ``module``, lower it to NNAPI, and compare against eager mode.

        Args:
            module: the ``torch.nn.Module`` under test; it is put in eval mode.
            arg_or_args: a single tensor, or a list of input tensors, used for
                tracing, conversion, and execution unless overridden below.
            trace_args: inputs for ``torch.jit.trace`` instead of the run inputs.
            convert_args: inputs passed to the NNAPI lowering instead of the
                run inputs; the tests pass zero-sized dimensions here to
                exercise flexible-size conversion.
            atol_rtol: optional ``(atol, rtol)`` pair forwarded to ``assertEqual``.
            limit: if set, the maximum number of quantized output elements whose
                integer representations may differ; exceeding it re-runs the
                comparison with zero tolerance to get a readable failure.
            expected_memory_format: if set, assert the NNAPI output is
                contiguous in this memory format.
        """
        with torch.no_grad():
            if isinstance(arg_or_args, torch.Tensor):
                args = [arg_or_args]
            else:
                args = arg_or_args
            module.eval()
            traced = torch.jit.trace(module, trace_args or args)
            nnapi_module = self.call_lowering_to_nnapi(traced, convert_args or args)
            if not self.can_run_nnapi:
                # Only test that the model was converted successfully.
                return
            eager_output = module(*args)
            nnapi_output = nnapi_module(*args)
            kwargs = {}
            if atol_rtol is not None:
                kwargs["atol"] = atol_rtol[0]
                kwargs["rtol"] = atol_rtol[1]
            self.assertEqual(eager_output, nnapi_output, **kwargs)
            if limit is not None:
                mismatches = \
                    eager_output.int_repr().to(torch.int32) - \
                    nnapi_output.int_repr().to(torch.int32)
                if mismatches.count_nonzero() > limit:
                    # Too many mismatches.  Re-run the check with no tolerance
                    # to get a nice message.
                    self.assertEqual(eager_output, nnapi_output, atol=0, rtol=0)
            if expected_memory_format:
                self.assertTrue(nnapi_output.is_contiguous(memory_format=expected_memory_format))

    def float_and_quant_and_nhwc(self, inp_float, scale, zero_point):
        """Return ``(name, tensor)`` variants of ``inp_float`` for sweeps.

        The variants are: the float input, its NHWC copy, a quint8
        quantization using ``scale``/``zero_point``, and that quantized
        tensor's NHWC copy. Also reseeds the global RNG with 29.
        """
        torch.manual_seed(29)
        # Previously scale/zero_point were ignored in favor of a hard-coded
        # (0.03, 128); callers now pass those values explicitly.
        inp_quant = qpt(inp_float, scale, zero_point)
        return [
            ("float", inp_float),
            ("float-nhwc", nhwc(inp_float)),
            ("quant", inp_quant),
            ("quant-nhwc", nhwc(inp_quant)),
        ]

    def test_prelu(self):
        arg = torch.tensor([[1.0, -1.0, 2.0, -2.0]]).unsqueeze(-1).unsqueeze(-1)
        single_a = torch.nn.PReLU()
        self.check(single_a, arg)
        multi_a = torch.nn.PReLU(4)
        with torch.no_grad():
            multi_a.weight.copy_(torch.tensor([.1, .2, .3, .4]))
        self.check(multi_a, nhwc(arg))

        # Test flexible size
        self.check(
            multi_a,
            arg,
            trace_args=[torch.zeros(1, 4, 3, 3)],
            convert_args=[nhwc(torch.zeros(1, 4, 0, 0))],
        )

    def test_quantize(self):
        self.check(
            torch.nn.quantized.Quantize(0.25, 2, torch.quint8),
            nhwc(torch.tensor([[[[1.0]], [[2.0]]]])))

    def test_dequantize(self):
        self.check(
            torch.nn.quantized.DeQuantize(),
            nhwc(qpt([[[[1.0]], [[2.0]]]], 0.25, 2)))

    def test_unsqueeze(self):
        class UnsqueezeModule(torch.nn.Module):
            def __init__(self, dim):
                super().__init__()
                self.dim = dim

            def forward(self, arg):
                return arg.unsqueeze(self.dim)

        self.check(UnsqueezeModule(-2), torch.randn(4, 2, 2))
        self.check(UnsqueezeModule(-1), torch.randn(4, 2, 2))
        self.check(UnsqueezeModule(0), torch.randn(4, 2, 2))
        self.check(UnsqueezeModule(1), torch.randn(4, 2, 2))
        self.check(UnsqueezeModule(2), torch.randn(4, 2, 2))

    def test_reshape(self):
        class ReshapeModule(torch.nn.Module):
            def __init__(self, shape):
                super().__init__()
                self.shape = shape

            def forward(self, arg):
                return arg.reshape(self.shape)

        self.check(
            ReshapeModule((2, 4)),
            torch.randn(4, 2, 1, 1))

        self.check(
            ReshapeModule((8, -1)),
            nhwc(torch.randn(4, 2, 1, 1)))

        with self.assertRaisesRegex(Exception, "target size"):
            self.check(
                ReshapeModule((2, 4)),
                nhwc(torch.randn(4, 2, 1, 1)))

    def test_flatten(self):
        for mod in [
            torch.nn.Flatten(),
            torch.nn.Flatten(start_dim=2, end_dim=3),
            torch.nn.Flatten(start_dim=2, end_dim=4),
            torch.nn.Flatten(start_dim=0, end_dim=-2),
            torch.nn.Flatten(start_dim=0, end_dim=4)
        ]:
            self.check(mod, torch.randn(4, 2, 1, 3, 7))

        # flex inputs
        self.check(
            torch.nn.Flatten(),
            torch.randn(4, 2, 1, 3, 7),
            convert_args=[torch.zeros(0, 2, 1, 3, 7)]
        )

        # channels last
        self.check(
            torch.nn.Flatten(),
            nhwc(torch.randn(2, 1, 4, 7))
        )
        self.check(
            torch.nn.Flatten(),
            nhwc(torch.randn(2, 3, 1, 1))
        )

        # Exceptions
        with self.assertRaisesRegex(Exception, "not supported on NHWC"):
            self.check(
                torch.nn.Flatten(),
                nhwc(torch.randn(1, 3, 4, 4))
            )
        with self.assertRaisesRegex(Exception, "Flattening flexible dims is not supported yet"):
            self.check(torch.nn.Flatten(), torch.randn(4, 2, 0, 0, 7))
        with self.assertRaisesRegex(Exception, "Only 1 dim"):
            self.check(
                torch.nn.Flatten(start_dim=1, end_dim=-2),
                torch.randn(0, 2, 1, 3, 0))

    def test_slice(self):
        class SliceModule(torch.nn.Module):
            def __init__(self, start, stop, step):
                super().__init__()
                self.start = start
                self.stop = stop
                self.step = step

            def forward(self, t):
                return t[1:, self.start:self.stop:self.step, :]

        class SliceModule2(torch.nn.Module):
            def forward(self, t):
                return t[3:]

        self.check(
            SliceModule(1, 5, 2),
            torch.randn(4, 6, 2)
        )
        self.check(
            SliceModule2(),
            torch.randn(5)
        )

        # flex inputs
        self.check(
            SliceModule(1, 5, 2),
            torch.randn(4, 6, 2),
            convert_args=[torch.zeros(4, 6, 0)]
        )
        with self.assertRaisesRegex(Exception, "slice with flexible shape"):
            self.check(
                SliceModule(1, 5, 2),
                torch.randn(4, 6, 2),
                convert_args=[torch.zeros(0, 0, 0)]
            )

    def test_cat(self):
        class CatModule(torch.nn.Module):
            def __init__(self, dim):
                super().__init__()
                self.dim = dim

            def forward(self, t1, t2):
                return torch.cat([t1, t2], self.dim)

        self.check(
            CatModule(0),
            [
                torch.randn(1, 2, 3, 3),
                torch.randn(2, 2, 3, 3),
            ])

        self.check(
            CatModule(1),
            [
                torch.randn(1, 2, 3, 3),
                torch.randn(1, 4, 3, 3),
            ])

        self.check(
            CatModule(1),
            [
                nhwc(torch.randn(1, 2, 3, 3)),
                nhwc(torch.randn(1, 4, 3, 3)),
            ])

        self.check(
            CatModule(1),
            [
                torch.randn(1, 2, 3, 3),
                torch.randn(1, 4, 3, 3),
            ],
            convert_args=[
                torch.zeros(0, 0, 0, 0),
                torch.zeros(0, 0, 0, 0)
            ])

    def test_pointwise_unary(self):
        for op in ["relu", "sigmoid"]:
            with self.subTest(op):
                class UnaryModule(torch.nn.Module):
                    def forward(self, arg):
                        if op == "relu":
                            return torch.nn.functional.relu(arg)
                        if op == "sigmoid":
                            return torch.sigmoid(arg)
                        raise Exception("Bad op")
                self.check(UnaryModule(), torch.tensor([-1.0, 1.0]))

    def test_pointwise_binary(self):
        for op in ["add", "sub", "mul", "div"]:
            with self.subTest(op):
                class BinaryModule(torch.nn.Module):
                    def forward(self, lhs, rhs):
                        if op == "add":
                            return lhs + rhs
                        if op == "sub":
                            return lhs - rhs
                        if op == "mul":
                            return lhs * rhs
                        if op == "div":
                            return lhs / rhs
                        raise Exception("Bad op")

                self.check(
                    BinaryModule(),
                    [
                        torch.tensor([1.0, 2.0]),
                        torch.tensor([3.0, 4.0]),
                    ])

                self.check(
                    BinaryModule(),
                    [
                        torch.tensor([[1.0, 2.0]]),
                        torch.tensor([[3.0, 4.0], [5.0, 6.0]]),
                    ])

                with self.assertRaisesRegex(Exception, "Non-equal-rank broadcast"):
                    self.check(
                        BinaryModule(),
                        [
                            torch.tensor([1.0, 2.0]),
                            torch.tensor([[3.0, 4.0], [5.0, 6.0]]),
                        ])

    def test_pointwise_binary_const(self):
        const = torch.randn(1, 4, 6, 6)

        class ArgPlusConst(torch.nn.Module):
            def forward(self, arg):
                return arg + const

        class ConstPlusArg(torch.nn.Module):
            def forward(self, arg):
                return const + arg

        arg_contig = torch.randn(2, 4, 6, 6)
        arg_nhwc = nhwc(torch.randn(2, 4, 6, 6))

        for mod_class in [ArgPlusConst, ConstPlusArg]:
            for use_nhwc in [False, True]:
                with self.subTest(mod_class=mod_class.__name__, use_nhwc=use_nhwc):
                    arg = arg_nhwc if use_nhwc else arg_contig
                    memory_format = torch.channels_last if use_nhwc else torch.contiguous_format
                    self.check(mod_class(), arg,
                               expected_memory_format=memory_format)

    def test_hardtanh(self):
        inp = torch.tensor([-2.0, -0.5, 0.5, 2.0, 7.0])
        self.check(torch.nn.Hardtanh(), inp)
        self.check(torch.nn.Hardtanh(0.0, 6.0), inp)
        with self.assertRaisesRegex(Exception, "hardtanh with args"):
            self.check(torch.nn.Hardtanh(0.0, 5.0), inp)

    def test_softmax(self):
        inp = torch.tensor([[-2.0, -0.5], [0.5, 2.0]])
        self.check(torch.nn.Softmax(), inp)
        self.check(torch.nn.Softmax(dim=0), inp)
        # Test flexible size
        self.check(
            torch.nn.Softmax(),
            inp,
            convert_args=[torch.zeros(0, 0)],
        )

    def test_to(self):
        class ToCPU(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.prelu = torch.nn.PReLU()

            def forward(self, x):
                y = x.to("cpu")
                # add prelu since input operand can't be output
                return self.prelu(y)

        arg = torch.randn(1, 2, 3, 3)
        self.check(ToCPU(), arg)
        # Test flexible size
        self.check(
            ToCPU(),
            arg,
            convert_args=[torch.zeros(1, 2, 0, 0)],
        )

    def test_detach(self):
        class DetachModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                y = x.detach()
                return torch.nn.functional.relu(y)

        self.check(DetachModule(), torch.randn(1, 2, 3, 3))
        self.check(
            DetachModule(), torch.randn(1, 2, 3, 3),
            convert_args=[torch.zeros(1, 2, 0, 0)])

    def test_log_softmax(self):
        inp = torch.randn(3, 10)
        self.check(torch.nn.LogSoftmax(), inp)
        self.check(torch.nn.LogSoftmax(0), inp)

    def test_mean(self):
        class MeanModule(torch.nn.Module):
            def __init__(self, dim, keep=False):
                super().__init__()
                self.dim = dim
                self.keep = keep

            def forward(self, t):
                return torch.mean(t, dim=self.dim, keepdim=self.keep)

        self.check(MeanModule(0), torch.randn(2, 3))
        self.check(MeanModule(1), torch.randn(2, 3))
        self.check(MeanModule([2, 3]), torch.randn(2, 3, 6, 6))
        self.check(MeanModule([2, 3]), nhwc(torch.randn(2, 3, 6, 6)))
        self.check(MeanModule([-1, -2]), nhwc(torch.randn(2, 3, 6, 6)))
        self.check(MeanModule([-1, -2], keep=True), nhwc(torch.randn(2, 3, 6, 6)))

    def test_max_pool2d(self):
        for (name, inp) in self.float_and_quant_and_nhwc(torch.randn(2, 3, 12, 16), 0.03, 128):
            with self.subTest(name):
                self.check(torch.nn.MaxPool2d(2), inp)
                self.check(torch.nn.MaxPool2d((3, 4)), inp)
                self.check(torch.nn.MaxPool2d((3, 4), (1, 2)), inp)

    def test_avg_pool2d(self):
        for (name, inp) in self.float_and_quant_and_nhwc(torch.randn(2, 3, 12, 16), 0.03, 128):
            with self.subTest(name):
                atol_rtol = None
                limit = None
                convert_dims = (2, 3, 0, 0)
                convert_arg = torch.zeros(*convert_dims)

                for model in (
                        torch.nn.AvgPool2d(2),
                        torch.nn.AvgPool2d((3, 4)),
                        torch.nn.AvgPool2d((3, 4), (1, 2))):
                    if "quant" in name:
                        atol_rtol = (1, 0)
                        limit = model(inp).numel()
                        convert_arg = qpt(torch.zeros(*convert_dims), 1.0 / 16, 128)
                    if "nhwc" in name:
                        convert_arg = nhwc(convert_arg)

                    self.check(model, inp, atol_rtol=atol_rtol, limit=limit)
                    self.check(
                        model,
                        inp,
                        convert_args=[convert_arg],
                        atol_rtol=atol_rtol,
                        limit=limit
                    )

    def test_adaptive_avg_pool2d(self):
        for (name, inp) in self.float_and_quant_and_nhwc(torch.randn(2, 3, 12, 16), 0.03, 128):
            with self.subTest(name):
                self.check(torch.nn.AdaptiveAvgPool2d((1, 1)), inp)
                with self.assertRaisesRegex(Exception, "with output size"):
                    self.check(torch.nn.AdaptiveAvgPool2d((2, 2)), inp)

    def test_upsample_nearest2d(self):
        convert_args = dict(self.float_and_quant_and_nhwc(torch.randn(2, 3, 0, 0), 0.03, 128))
        for (name, inp) in self.float_and_quant_and_nhwc(torch.randn(2, 3, 12, 16), 0.03, 128):
            with self.subTest(name):
                self.check(torch.nn.UpsamplingNearest2d(size=(16, 20)), inp)
                self.check(torch.nn.UpsamplingNearest2d(size=(24, 32)), inp)
                self.check(torch.nn.UpsamplingNearest2d(size=(36, 48)), inp)
                self.check(torch.nn.UpsamplingNearest2d(scale_factor=(1.5, 1.5)), inp)
                self.check(torch.nn.UpsamplingNearest2d(scale_factor=(2.0, 2.0)), inp)
                self.check(torch.nn.UpsamplingNearest2d(scale_factor=(3.0, 3.0)), inp)

                self.check(
                    torch.nn.UpsamplingNearest2d(size=(24, 32)), inp,
                    convert_args=[convert_args[name]]
                )
                self.check(
                    torch.nn.UpsamplingNearest2d(scale_factor=(2.0, 2.0)), inp,
                    convert_args=[convert_args[name]]
                )

    def test_linear(self):
        torch.manual_seed(29)
        self.check(torch.nn.Linear(16, 32), torch.randn(2, 16))
        self.check(
            torch.nn.Linear(16, 32), torch.randn(2, 16),
            convert_args=[torch.zeros(0, 16)])

    def test_conv2d(self):
        cases = [
            # in_ch, out_ch, kernel, stride, padding, groups, bias, input_dim,      name
            ( 4,     8,      (3, 3), 1,      0,       1,      1,    (2, 4, 16, 16), "3x3"),        # noqa: E201,E241
            ( 4,     8,      (3, 3), 1,      0,       1,      0,    (2, 4, 16, 16), "3x3nobias"),  # noqa: E201,E241
            ( 4,     16,     (3, 3), 1,      1,       1,      1,    (2, 4, 16, 16), "3x3p1"),      # noqa: E201,E241
            ( 8,     8,      (3, 3), 2,      0,       1,      1,    (2, 8, 16, 16), "3x3s2"),      # noqa: E201,E241
            ( 4,     8,      (5, 5), 1,      0,       1,      1,    (2, 4, 16, 16), "5x5"),        # noqa: E201,E241
            ( 4,     4,      (3, 3), 1,      0,       4,      1,    (2, 4, 16, 16), "3x3dw"),      # noqa: E201,E241
            ( 8,     4,      (1, 1), 1,      0,       1,      1,    (2, 8, 16, 16), "1x1"),        # noqa: E201,E241
        ]

        for kind in ["float", "float-nhwc", "quant", "quant-nhwc"]:
            for case in cases:
                in_ch, out_ch, kernel, stride, padding, groups, bias, input_dim, name = case
                with self.subTest("{}-{}".format(kind, name)):
                    inp = torch.randn(input_dim)
                    model = torch.nn.Conv2d(in_ch, out_ch, kernel, stride, padding, groups=groups, bias=bool(bias))
                    output_size = model(inp).numel()
                    atol_rtol = None
                    limit = None
                    convert_dims = (0, in_ch, 0, 0)
                    convert_arg = torch.zeros(*convert_dims)

                    if "quant" in kind:
                        model = torch.nn.Sequential(model)
                        model.eval()
                        model.qconfig = torch.quantization.get_default_qconfig('qnnpack')
                        model = torch.quantization.prepare(model)
                        model(inp)
                        model = torch.quantization.convert(model)
                        inp = qpt(inp, 1.0 / 16, 128)
                        # I've seen numerical differences between QNNPACK and NNAPI,
                        # but never more than 1 quantum, and never more than ~1% of
                        # the output in this test.
                        atol_rtol = (1, 0)
                        limit = output_size * 0.03
                        convert_arg = qpt(torch.zeros(*convert_dims), 1.0 / 16, 128)

                    if "nhwc" in kind:
                        inp = nhwc(inp)
                        convert_arg = nhwc(convert_arg)

                    self.check(model, inp, atol_rtol=atol_rtol, limit=limit)
                    self.check(
                        model,
                        inp,
                        convert_args=[convert_arg],
                        atol_rtol=atol_rtol,
                        limit=limit
                    )

    def test_conv2d_transpose(self):
        torch.manual_seed(29)
        in_ch, out_ch, kernel = (5, 7, (2, 2))
        input_dim = (4, 5, 3, 3)
        convert_dims = input_dim[:2] + (0, 0)

        for kind in ["float", "float-nhwc", "quant", "quant-nhwc"]:
            with self.subTest(kind):
                inp = torch.randn(input_dim)
                model = torch.nn.ConvTranspose2d(in_ch, out_ch, kernel)
                output_size = model(inp).numel()
                atol_rtol = (0.0002, 0)
                limit = None
                convert_arg = torch.zeros(*convert_dims)

                if "quant" in kind:
                    model = torch.nn.quantized.ConvTranspose2d(in_ch, out_ch, kernel)
                    model.qconfig = torch.quantization.get_default_qconfig('qnnpack')
                    inp = qpt(inp, 1.0 / 16, 128)
                    # I've seen numerical differences between QNNPACK and NNAPI,
                    # but never more than 1 quantum, and never more than ~10% of
                    # the output in this test.
                    atol_rtol = (1, 0)
                    limit = output_size * 0.1
                    convert_arg = qpt(convert_arg, 1.0 / 16, 128)

                if "nhwc" in kind:
                    inp = nhwc(inp)
                    convert_arg = nhwc(convert_arg)

                self.check(model, inp, atol_rtol=atol_rtol, limit=limit)
                self.check(
                    model,
                    inp,
                    convert_args=[convert_arg],
                    atol_rtol=atol_rtol,
                    limit=limit
                )

    def test_qadd(self):
        func = torch.nn.quantized.QFunctional()
        func.scale = 0.5
        func.zero_point = 120

        class AddMod(torch.nn.Module):
            def forward(self, lhs, rhs):
                return func.add(lhs, rhs)

        class AddReluMod(torch.nn.Module):
            def forward(self, lhs, rhs):
                return func.add_relu(lhs, rhs)

        class MulMod(torch.nn.Module):
            def forward(self, lhs, rhs):
                return func.mul(lhs, rhs)

        for (name, mod) in [("add", AddMod), ("add_relu", AddReluMod), ("mul", MulMod)]:
            with self.subTest(name):
                self.check(
                    mod(),
                    [
                        qpt([1.0, 2.0], 0.25, 128),
                        qpt([3.0, 4.0], 0.25, 128),
                    ])
                self.check(
                    mod(),
                    [
                        qpt([[1.0, 2.0]], 0.25, 128),
                        qpt([[3.0, 4.0]], 0.25, 128),
                    ],
                    convert_args=[
                        qpt([[1.0, 2.0]], 0.25, 128),
                        qpt(torch.zeros((1, 2)), 0.25, 128),
                    ]
                )
                self.check(
                    mod(),
                    [
                        qpt([[1.0, 2.0]], 0.25, 128),
                        qpt([[3.0, 4.0]], 0.25, 128),
                    ],
                    convert_args=[
                        qpt(torch.zeros((1, 2)), 0.25, 128),
                        qpt([[3.0, 4.0]], 0.25, 128),
                    ]
                )
                self.check(
                    mod(),
                    [
                        qpt([[1.0, 2.0]], 0.25, 128),
                        qpt([[3.0, 4.0]], 0.25, 128),
                    ],
                    convert_args=[
                        qpt(torch.zeros((1, 2)), 0.25, 128),
                        qpt(torch.zeros((1, 2)), 0.25, 128),
                    ]
                )
                # NOTE: NNAPI qadd supports broadcast, but PT does not.

    def test_qlinear(self):
        torch.manual_seed(29)
        weight = qpt(torch.randn(16, 32), 0.125, 0, torch.qint8)
        bias = torch.randn(16)
        mod = torch.nn.quantized.Linear(32, 16)
        mod.set_weight_bias(weight, bias)
        inp = qpt(torch.randn(2, 32), 0.05, 130, torch.quint8)
        self.check(mod, inp)

    def test_seblock_mul(self):
        class MulModel(torch.nn.Module):
            def forward(self, lhs, rhs):
                return lhs * rhs

        self.check(
            MulModel(),
            [
                nhwc(torch.randn(2, 3, 4, 4)),
                torch.randn(1, 3, 1, 1),
            ])

    def test_multi_output(self):
        class MultiModel(torch.nn.Module):
            def forward(self, lhs, rhs) -> Tuple[torch.Tensor, torch.Tensor]:
                the_sum = lhs + rhs
                the_diff = lhs - rhs
                return the_sum, the_diff

        self.check(MultiModel(), [torch.tensor([1.0, 2.0]), torch.tensor([1.0, 3.0])])
 | 
						|
 | 
						|
 | 
						|
if __name__ == '__main__':
    # Executed directly (not imported): delegate to the shared test runner.
    run_tests()
 |