# Owner(s): ["oncall: mobile"]

import unittest
import torch
import torch.nn as nn
import torch.utils.bundled_inputs
from torch.testing._internal.common_utils import TestCase, run_tests, skipIfNoXNNPACK
from torch.testing._internal.jit_utils import get_forward, get_forward_graph
from torch.utils.mobile_optimizer import (LintCode,
                                          generate_mobile_module_lints,
                                          optimize_for_mobile,
                                          MobileOptimizerType)
from torch.nn import functional as F
from torch.testing._internal.common_quantized import override_quantized_engine

try:
    import torchvision
    HAS_TORCHVISION = True
except ImportError:
    HAS_TORCHVISION = False

FileCheck = torch._C.FileCheck

class TestOptimizer(TestCase):

    @skipIfNoXNNPACK
    def test_optimize_for_mobile(self):
        batch_size = 2
        input_channels_per_group = 6
        height = 16
        width = 16
        output_channels_per_group = 6
        groups = 4
        kernel_h = kernel_w = 3
        stride_h = stride_w = 1
        pad_h = pad_w = 1
        dilation = 1
        input_channels = input_channels_per_group * groups
        output_channels = output_channels_per_group * groups
        strides = (stride_h, stride_w)
        paddings = (pad_h, pad_w)
        dilations = (dilation, dilation)
        conv_weight_shape = (output_channels, input_channels_per_group, kernel_h, kernel_w)
        conv_bias_shape = (output_channels)

        input_data = torch.rand((batch_size, input_channels, height, width))
        conv_weight = torch.rand((output_channels, input_channels_per_group, kernel_h, kernel_w))
        conv_bias = torch.rand(output_channels)
        result = F.conv2d(input_data, conv_weight, conv_bias, strides, paddings, dilations, groups)
        weight_output_dim = 24
        linear_input_shape = result.shape[1]
        linear_weight_shape = (weight_output_dim, linear_input_shape)

        class MyTestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv_weight = torch.nn.Parameter(torch.rand(conv_weight_shape))
                self.conv_bias = torch.nn.Parameter(torch.rand(conv_bias_shape))
                self.linear_weight = torch.nn.Parameter(torch.rand(linear_weight_shape))
                self.linear_bias = torch.nn.Parameter(torch.rand(weight_output_dim))
                self.strides = strides
                self.paddings = paddings
                self.dilations = dilations
                self.groups = groups

            def forward(self, x):
                o = F.conv2d(x, self.conv_weight, self.conv_bias,
                             self.strides, self.paddings, self.dilations, self.groups)
                o = F.relu(o)
                x = o.permute([0, 2, 3, 1])
                o = F.linear(x, self.linear_weight, self.linear_bias)
                o = o + x
                return F.relu(o)

            @torch.jit.export
            def foo(self, x):
                o = F.conv2d(x, self.conv_weight, self.conv_bias,
                             self.strides, self.paddings, self.dilations, self.groups)
                o = F.relu(o)
                x = o.permute([0, 2, 3, 1])
                o = F.linear(x, self.linear_weight, self.linear_bias)
                o = o + x
                return F.relu(o)

        class BNTestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(1, 20, 5, 1)
                self.bn = torch.nn.BatchNorm2d(num_features=20)
                self.bn.eps = 0.0023

            def forward(self, x):
                x = self.conv(x)
                x = self.bn(x)
                return x

        data_shape = (batch_size, input_channels, height, width)
        input_data = torch.normal(1, 20, size=data_shape)

        scripted_model = torch.jit.script(MyTestModule())
        scripted_model.eval()
        initial_result = scripted_model(input_data)
        initial_foo_result = scripted_model.foo(input_data)

        optimized_scripted_model = optimize_for_mobile(scripted_model, preserved_methods=['foo'])
        optimized_result = optimized_scripted_model(input_data)
        optimized_foo_result = optimized_scripted_model.foo(input_data)

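        # Both the forward and the preserved foo graphs should call the XNNPACK
        # prepacked conv2d/linear ops (with their prepack nodes folded away) and
        # the fused add_relu op, with no remaining aten::conv2d, aten::add, or
        # aten::relu calls.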
        FileCheck().check_not("Tensor = aten::conv2d") \
                   .check_not("Tensor = prim::CallFunction") \
                   .check_not("prepacked::conv2d_clamp_prepack") \
                   .check_count("prepacked::conv2d_clamp_run", 1, exactly=True) \
                   .check_not("prepacked::linear_clamp_prepack") \
                   .check_count("prepacked::linear_clamp_run", 1, exactly=True) \
                   .check_not("aten::add(") \
                   .check_not("aten::relu(") \
                   .check_count("aten::_add_relu(", 1, exactly=True) \
                   .run(optimized_scripted_model.graph)
        torch.testing.assert_close(initial_result, optimized_result, rtol=1e-2, atol=1e-3)

        FileCheck().check_not("Tensor = aten::conv2d") \
                   .check_not("Tensor = prim::CallFunction") \
                   .check_not("prepacked::conv2d_clamp_prepack") \
                   .check_count("prepacked::conv2d_clamp_run", 1, exactly=True) \
                   .check_not("prepacked::linear_clamp_prepack") \
                   .check_count("prepacked::linear_clamp_run", 1, exactly=True) \
                   .check_not("aten::add(") \
                   .check_not("aten::relu(") \
                   .check_count("aten::_add_relu(", 1, exactly=True) \
                   .run(optimized_scripted_model.foo.graph)
        torch.testing.assert_close(initial_foo_result, optimized_foo_result, rtol=1e-2, atol=1e-3)

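        # Blocklisting INSERT_FOLD_PREPACK_OPS should keep the plain aten::conv2d
        # call and prevent any prepacked XNNPACK ops from being emitted.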
        optimization_blocklist_no_prepack = {MobileOptimizerType.INSERT_FOLD_PREPACK_OPS}
        optimized_scripted_model_no_prepack = optimize_for_mobile(scripted_model, optimization_blocklist_no_prepack)
        optimized_result_no_prepack = optimized_scripted_model_no_prepack(input_data)

        FileCheck().check_count("Tensor = aten::conv2d", 1, exactly=True) \
                   .check_not("prepacked::linear_clamp_run") \
                   .check_not("prepacked::conv2d_clamp_run") \
                   .run(optimized_scripted_model_no_prepack.graph)
        torch.testing.assert_close(initial_result, optimized_result_no_prepack, rtol=1e-2, atol=1e-3)

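        # CONV_BN_FUSION folds the BatchNorm into the preceding conv: the exported
        # op count should drop from 11 to 1. When CONV_BN_FUSION is blocklisted,
        # aten::batch_norm should stay in the graph.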
        bn_test_module = BNTestModule()
        bn_scripted_module = torch.jit.script(bn_test_module)
        bn_scripted_module.eval()

        self.assertEqual(len(torch.jit.export_opnames(bn_scripted_module)), 11)
        FileCheck().check_count('prim::CallMethod[name="forward"]', 2, exactly=True) \
                   .run(str(get_forward(bn_scripted_module._c).graph))

        optimization_blocklist_no_prepack = {MobileOptimizerType.INSERT_FOLD_PREPACK_OPS}
        bn_fold_scripted_module = optimize_for_mobile(bn_scripted_module, optimization_blocklist_no_prepack)
        self.assertEqual(len(torch.jit.export_opnames(bn_fold_scripted_module)), 1)
        bn_input = torch.rand(1, 1, 6, 6)
        torch.testing.assert_close(bn_scripted_module(bn_input), bn_fold_scripted_module(bn_input), rtol=1e-2, atol=1e-3)

        optimization_blocklist_no_fold_bn = {MobileOptimizerType.CONV_BN_FUSION}
        no_bn_fold_scripted_module = optimize_for_mobile(bn_scripted_module, optimization_blocklist_no_fold_bn)
        FileCheck().check_count("aten::batch_norm", 1, exactly=True) \
                   .run(str(get_forward_graph(no_bn_fold_scripted_module._c)))
        bn_input = torch.rand(1, 1, 6, 6)
        torch.testing.assert_close(bn_scripted_module(bn_input), no_bn_fold_scripted_module(bn_input), rtol=1e-2, atol=1e-3)

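        # optimize_for_mobile should mark the returned module with a truthy
        # "mobile_optimized" attribute.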
        class MyMobileOptimizedTagTest(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear_weight = torch.nn.Parameter(torch.rand(linear_weight_shape))
                self.linear_bias = torch.nn.Parameter(torch.rand(weight_output_dim))

            def forward(self, x):
                o = F.linear(x, self.linear_weight, self.linear_bias)
                return F.relu(o)

        mobile_optimized_tag_module = MyMobileOptimizedTagTest()
        m = torch.jit.script(mobile_optimized_tag_module)
        m.eval()
        opt_m = optimize_for_mobile(m)
        tag = getattr(opt_m, "mobile_optimized", None)
        self.assertTrue(tag)

        class MyPreserveMethodsTest(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear_weight = torch.nn.Parameter(torch.rand(linear_weight_shape))
                self.linear_bias = torch.nn.Parameter(torch.rand(weight_output_dim))

            def forward(self, x):
                o = F.linear(x, self.linear_weight, self.linear_bias)
                return F.relu(o)

            @torch.jit.export
            def preserveThis(self):
                pass

        preserve_method_module = MyPreserveMethodsTest()
        m = torch.jit.script(preserve_method_module)
        m.eval()
        opt_m = optimize_for_mobile(m)
        no_preserveThis = getattr(opt_m, "preserveThis", None)
        self.assertEqual(no_preserveThis, None)
        opt_m = optimize_for_mobile(m, preserved_methods=["preserveThis"])
        preserveThis = getattr(opt_m, "preserveThis", None)
        self.assertNotEqual(preserveThis, None)

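        # A module with no forward() can still be optimized through the methods
        # listed in preserved_methods.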
        class OptimizeNoForwardTest(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.l = nn.Linear(10, 100)
                self.l2 = nn.Linear(100, 1)
                self.d = nn.Dropout(p=0.2)

            @torch.jit.export
            def foo(self, x):
                x = self.d(F.relu(self.l(x)))
                x = self.l2(x)
                x = x + torch.ones(1, 100)
                return F.relu(x)

        input_data = torch.ones(1, 10)
        m = torch.jit.script(OptimizeNoForwardTest())
        m.eval()
        initial_result = m.foo(input_data)

        optimized_scripted_model = optimize_for_mobile(m, preserved_methods=['foo'])
        optimized_result = optimized_scripted_model.foo(input_data)

        FileCheck().check_not("dropout.__") \
            .check_count("aten::_add_relu(", 1, exactly=True) \
            .run(optimized_scripted_model.foo.graph)
        torch.testing.assert_close(initial_result, optimized_result, rtol=1e-2, atol=1e-3)

        class BNTestNoForwardModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(1, 20, 5, 1)
                self.bn = torch.nn.BatchNorm2d(num_features=20)
                self.bn.eps = 0.0023

            @torch.jit.export
            def foo(self, x):
                x = self.conv(x)
                x = self.bn(x)
                return x

        bn_test_no_forward_module = BNTestNoForwardModule()
        bn_no_forward_scripted_module = torch.jit.script(bn_test_no_forward_module)
        bn_no_forward_scripted_module.eval()

        self.assertEqual(len(torch.jit.export_opnames(bn_no_forward_scripted_module)), 11)
        FileCheck().check_count('prim::CallMethod[name="forward"]', 2, exactly=True) \
                   .run(bn_no_forward_scripted_module.foo.graph)

        bn_fold_no_forward_scripted_module = optimize_for_mobile(bn_no_forward_scripted_module, preserved_methods=['foo'])
        self.assertEqual(len(torch.jit.export_opnames(bn_fold_no_forward_scripted_module)), 1)
        bn_input = torch.rand(1, 1, 6, 6)
        torch.testing.assert_close(
            bn_no_forward_scripted_module.foo(bn_input),
            bn_fold_no_forward_scripted_module.foo(bn_input),
            rtol=1e-2,
            atol=1e-3)

    @skipIfNoXNNPACK
    def test_quantized_conv_no_asan_failures(self):
        # There were ASAN failures when fold_conv_bn was run on already
        # quantized conv modules. Verify that this does not happen again.

        if 'qnnpack' not in torch.backends.quantized.supported_engines:
            return

        class Child(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv2 = nn.Conv2d(1, 1, 1)

            def forward(self, x):
                x = self.conv2(x)
                return x

        class Parent(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.quant = torch.ao.quantization.QuantStub()
                self.conv1 = nn.Conv2d(1, 1, 1)
                self.child = Child()
                self.dequant = torch.ao.quantization.DeQuantStub()

            def forward(self, x):
                x = self.quant(x)
                x = self.conv1(x)
                x = self.child(x)
                x = self.dequant(x)
                return x

        with override_quantized_engine('qnnpack'):
            model = Parent()
            model.qconfig = torch.ao.quantization.get_default_qconfig('qnnpack')
            torch.ao.quantization.prepare(model, inplace=True)
            model(torch.randn(4, 1, 4, 4))
            torch.ao.quantization.convert(model, inplace=True)
            model = torch.jit.script(model)
            # This call should not trigger ASAN failures.
            optimize_for_mobile(model)

    def test_generate_mobile_module_lints(self):
        class MyTestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.fc = torch.nn.Linear(4, 4)
                self.dropout = torch.nn.Dropout(p=0.5)

            def forward(self, inputs):
                out = self.fc(inputs)
                out = self.dropout(out)
                return out

        class MyBNModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.bn = torch.nn.BatchNorm2d(4, affine=True)

            def forward(self, inputs):
                bn = self.bn(inputs)
                return bn

        class MyBundledInputModule(torch.nn.Module):
            def forward(self, inputs):
                return inputs

        def get_lint_count_by_type(lint_type, module_lint_List):
            return len([lint_dict for lint_dict in module_lint_List if lint_dict['name'] == lint_type.name])

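        # Modules without bundled inputs should produce a BUNDLED_INPUT lint, a
        # DROPOUT or BATCHNORM lint where applicable, and REQUIRES_GRAD lints for
        # their trainable parameters.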
        test_module = torch.jit.script(MyTestModule())
        test_module_lint_list = generate_mobile_module_lints(test_module)
        self.assertEqual(len(test_module_lint_list), 4)
        self.assertEqual(get_lint_count_by_type(LintCode.BUNDLED_INPUT, test_module_lint_list), 1)
        self.assertEqual(get_lint_count_by_type(LintCode.DROPOUT, test_module_lint_list), 1)
        self.assertEqual(get_lint_count_by_type(LintCode.REQUIRES_GRAD, test_module_lint_list), 2)

        bn_module = torch.jit.script(MyBNModule())
        bn_module_lint_list = generate_mobile_module_lints(bn_module)
        self.assertEqual(len(bn_module_lint_list), 4)
        self.assertEqual(get_lint_count_by_type(LintCode.BUNDLED_INPUT, bn_module_lint_list), 1)
        self.assertEqual(get_lint_count_by_type(LintCode.BATCHNORM, bn_module_lint_list), 1)
        self.assertEqual(get_lint_count_by_type(LintCode.REQUIRES_GRAD, bn_module_lint_list), 2)

        bi_module = torch.jit.script(MyBundledInputModule())
        torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
            bi_module, [(torch.tensor([1]),)], [])
        bi_module_lint_list = generate_mobile_module_lints(bi_module)
        self.assertEqual(len(bi_module_lint_list), 0)

    @skipIfNoXNNPACK
    def test_preserve_bundled_inputs_methods(self):
        class MyBundledInputModule(torch.nn.Module):
            def forward(self, inputs):
                return inputs

        class MyIncompleteBundledInputModule(torch.nn.Module):
            def forward(self, inputs):
                return inputs

            @torch.jit.export
            def get_all_bundled_inputs(self):
                pass

        bi_module = torch.jit.script(MyBundledInputModule())
        module_optim_bi_not_preserved = optimize_for_mobile(bi_module)

        # Expected to be False since no bundled inputs methods were added
        self.assertFalse(
            hasattr(module_optim_bi_not_preserved, 'get_all_bundled_inputs') or
            hasattr(module_optim_bi_not_preserved, 'get_num_bundled_inputs')
        )

        # Add bundled inputs methods to the module
        torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
            bi_module, [(torch.tensor([1]),)], [])
        # Now they should be preserved
        module_optim_bi_preserved = optimize_for_mobile(bi_module)

        # All of the bundled inputs methods were preserved
        self.assertTrue(
            hasattr(module_optim_bi_preserved, 'get_all_bundled_inputs') and
            hasattr(module_optim_bi_preserved, 'get_num_bundled_inputs')
        )

        bundled_input = module_optim_bi_preserved.get_all_bundled_inputs()[0]
        module_optim_bi_preserved(*bundled_input)

        # If not all 3 bundled inputs methods are present in the module,
        # we will not try to preserve them unless specified by the user.
        incomplete_bi_module = torch.jit.script(MyIncompleteBundledInputModule())
        incomplete_bi_module_optim = optimize_for_mobile(incomplete_bi_module)
        self.assertFalse(hasattr(incomplete_bi_module_optim, 'get_all_bundled_inputs'))

        # Specifically preserve get_all_bundled_inputs even if it's the only
        # bundled inputs method available.
        incomplete_bi_module_optim = optimize_for_mobile(incomplete_bi_module, preserved_methods=['get_all_bundled_inputs'])
        self.assertTrue(hasattr(incomplete_bi_module_optim, 'get_all_bundled_inputs'))

    @skipIfNoXNNPACK
    def test_hoist_conv_packed_params(self):
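        # optimize_for_mobile should hoist the quantized Conv2d packed params into
        # top-level constants and remove the conv submodules that held them.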
        if 'qnnpack' not in torch.backends.quantized.supported_engines:
            return

        class Standalone(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.quant = torch.ao.quantization.QuantStub()
                self.conv1 = nn.Conv2d(1, 1, 1)
                self.conv2 = nn.Conv2d(1, 1, 1)
                self.relu = nn.ReLU()
                self.dequant = torch.ao.quantization.DeQuantStub()

            def forward(self, x):
                x = self.quant(x)
                x = self.conv1(x)
                x = self.conv2(x)
                x = self.relu(x)
                x = self.dequant(x)
                return x

            def fuse_model(self):
                torch.ao.quantization.fuse_modules(self, [['conv2', 'relu']], inplace=True)

        class Child(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv1 = nn.Conv2d(1, 1, 1)

            def forward(self, x):
                x = self.conv1(x)
                return x

        class Parent(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.quant = torch.ao.quantization.QuantStub()
                self.conv1 = nn.Conv2d(1, 1, 1)
                self.child = Child()
                # TODO: test nn.Sequential after #42039 is fixed
                self.dequant = torch.ao.quantization.DeQuantStub()

            def forward(self, x):
                x = self.quant(x)
                x = self.conv1(x)
                x = self.child(x)
                x = self.dequant(x)
                return x

            def fuse_model(self):
                pass

        with override_quantized_engine('qnnpack'):
            def _quant_script_and_optimize(model):
                model.qconfig = torch.ao.quantization.get_default_qconfig('qnnpack')
                model.fuse_model()
                torch.ao.quantization.prepare(model, inplace=True)
                model(torch.randn(4, 1, 4, 4))
                torch.ao.quantization.convert(model, inplace=True)
                model = torch.jit.script(model)
                model_optim = optimize_for_mobile(model)
                return model, model_optim

            # basic case

            m, m_optim = _quant_script_and_optimize(Standalone())
            FileCheck().check_not('Conv2d = prim::GetAttr[name="conv1"]') \
                       .check_count("__torch__.torch.classes.quantized.Conv2dPackedParamsBase = prim::Constant", 2, exactly=True) \
                       .run(m_optim.graph)
            self.assertFalse(hasattr(m_optim, "conv1"))
            self.assertFalse(hasattr(m_optim, "conv2"))

            data = torch.randn(4, 1, 4, 4)
            m_res = m(data)
            m_optim_res = m_optim(data)
            torch.testing.assert_close(m_res, m_optim_res, rtol=1e-2, atol=1e-3)

            # generic case

            m, m_optim = _quant_script_and_optimize(Parent())
            FileCheck().check_not('Conv2d = prim::GetAttr[name="conv1"]') \
                       .check_count("__torch__.torch.classes.quantized.Conv2dPackedParamsBase = prim::Constant", 2, exactly=True) \
                       .run(m_optim.graph)
            self.assertFalse(hasattr(m_optim, "conv1"))
            self.assertFalse(hasattr(m_optim, "child"))

            data = torch.randn(4, 1, 4, 4)
            m_res = m(data)
            m_optim_res = m_optim(data)
            torch.testing.assert_close(m_res, m_optim_res, rtol=1e-2, atol=1e-3)

    @skipIfNoXNNPACK
    @unittest.skipUnless(HAS_TORCHVISION, "Needs torchvision")
    def test_mobilenet_optimize_for_mobile(self):
        m = torchvision.models.mobilenet_v3_small()
        m = torch.jit.script(m)
        m = optimize_for_mobile(m)

        # Run forward three times; repeated runs used to segfault,
        # see https://github.com/pytorch/pytorch/issues/52463
        x = torch.zeros(1, 3, 56, 56)
        self.assertEqual(m(x).numel(), 1000)
        self.assertEqual(m(x).numel(), 1000)
        self.assertEqual(m(x).numel(), 1000)

    def test_clone_module_with_class(self):
        class MyInnerTestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.pqr = torch.Tensor([10., 20., 30.])

            def forward(self, inputs):
                return inputs

            @torch.jit.export
            def dummy_method_not_cloned(self):
                return 20

        class MyTestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.abc = 23
                self.pqr = torch.Tensor([1., 2., 3.])
                self.inner = MyInnerTestModule()

            def forward(self, inputs):
                x = self.dummy_method_cloned()
                # The call to self.inner.dummy_method_not_cloned should not raise an error
                y = self.inner.dummy_method_not_cloned()
                # The call to self.inner.pqr should not raise an error
                z = self.inner.pqr
                return (inputs, x, y, z)

            @torch.jit.export
            def dummy_method_not_cloned2(self):
                # The call to self.inner.dummy_method_not_cloned should not raise an error
                y = self.inner.dummy_method_not_cloned()
                # The call to self.inner.pqr should not raise an error
                z = self.inner.pqr
                return self.pqr, self.dummy_method_not_cloned(), y, z

            @torch.jit.export
            def dummy_method_not_cloned(self):
                return None

            @torch.jit.export
            def dummy_method_cloned(self):
                return None

            @torch.jit.export
            def dummy_method_ref_attr_pqr(self):
                return self.pqr, self.inner.pqr

        m = torch.jit.script(MyTestModule())

        # Check that the methods exist on the original model.
        self.assertEqual(hasattr(m, "dummy_method_not_cloned"), True)
        self.assertEqual(hasattr(m, "dummy_method_cloned"), True)
        self.assertEqual(hasattr(m, "dummy_method_not_cloned2"), True)
        self.assertEqual(hasattr(m, "pqr"), True)

        # Case-1: Successfully clone, ignoring 2 methods, keeping all attributes.
        cloned = torch._C._hack_do_not_use_clone_module_with_class(
            m._c,
            ["dummy_method_not_cloned", "dummy_method_not_cloned2"],  # ignored_methods
            [],  # ignored_attributes
        )

        # Check that the ignored methods don't exist on the cloned model.
        self.assertEqual(hasattr(cloned, "dummy_method_not_cloned"), False)
        self.assertEqual(hasattr(cloned, "dummy_method_cloned"), True)
        self.assertEqual(hasattr(cloned, "dummy_method_not_cloned2"), False)
        self.assertEqual(hasattr(cloned, "pqr"), True)

        # Check that the cloned class has a classname that starts with __torch__.
        self.assertTrue(
            cloned.qualified_name.startswith('__torch__.'),
            ("Expected the cloned module's name to start with the string "
             f"'__torch__.', but got: {cloned.qualified_name}"),
        )

        # Case-2: Successfully clone the module, ignoring the attribute pqr, and the methods that reference it.
        cloned = torch._C._hack_do_not_use_clone_module_with_class(
            m._c,
            ["dummy_method_not_cloned", "dummy_method_not_cloned2", "dummy_method_ref_attr_pqr"],
            ["pqr"],
        )

        # Check that the ignored methods don't exist on the cloned model.
        self.assertEqual(hasattr(cloned, "dummy_method_not_cloned"), False)
        self.assertEqual(hasattr(cloned, "dummy_method_cloned"), True)
        self.assertEqual(hasattr(cloned, "dummy_method_not_cloned2"), False)
        self.assertEqual(hasattr(cloned, "dummy_method_ref_attr_pqr"), False)
        self.assertEqual(hasattr(cloned, "pqr"), False)

        # Case-3: The statement below will throw since dummy_method_not_cloned2 is preserved,
        # and references dummy_method_not_cloned, which is not cloned.
        with self.assertRaises(RuntimeError):
            cloned = torch._C._hack_do_not_use_clone_module_with_class(m._c, ["dummy_method_not_cloned"], [])

        # Case-4: The statement below will throw since dummy_method_ref_attr_pqr
        # is preserved, and references "pqr", which is not cloned.
        with self.assertRaises(RuntimeError):
            cloned = torch._C._hack_do_not_use_clone_module_with_class(
                m._c,
                ["dummy_method_not_cloned", "dummy_method_not_cloned2"],
                ["pqr"],
            )


if __name__ == '__main__':
    run_tests()