Revert D30279364: [codemod][lint][fbcode/c*] Enable BLACK by default

Test Plan: revert-hammer

Differential Revision: D30279364 (b004307252)

Original commit changeset: c1ed77dfe43a

fbshipit-source-id: eab50857675c51e0088391af06ec0ecb14e2347e
Committed by: Facebook GitHub Bot
Parent: ed0b8a3e83
Commit: 1022443168
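The revert swaps the test file below back from Black's layout to the hand-wrapped style it had before the codemod. For orientation, a minimal sketch of the two layouts on one call that appears in the diff; the tensor shapes here are invented so the snippet runs on its own and are not part of the commit:

import torch
from torch.nn import functional as F

# Illustrative shapes only (not from the diff): 1 image, 3 input channels, 4 filters.
input_data = torch.rand((1, 3, 8, 8))
conv_weight = torch.rand((4, 3, 3, 3))
conv_bias = torch.rand((4))
strides, paddings, dilations, groups = (1, 1), (0, 0), (1, 1), 1

# Pre-Black layout (what this commit restores): arguments wrapped by hand.
result = F.conv2d(input_data, conv_weight, conv_bias,
                  strides, paddings, dilations, groups)

# Black layout (what this commit removes): one indented argument block,
# closing parenthesis on its own line.
result = F.conv2d(
    input_data, conv_weight, conv_bias, strides, paddings, dilations, groups
)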
@@ -1,35 +1,30 @@
 import unittest
-
 import torch
-import torch.backends.xnnpack
 import torch.nn as nn
+import torch.backends.xnnpack
 import torch.utils.bundled_inputs
-from torch._C import MobileOptimizerType
-from torch.nn import functional as F
-from torch.testing._internal.common_quantized import override_quantized_engine
 from torch.testing._internal.common_utils import TestCase, run_tests
 from torch.testing._internal.jit_utils import get_forward, get_forward_graph
-from torch.utils.mobile_optimizer import (
-    LintCode,
-    generate_mobile_module_lints,
-    optimize_for_mobile,
-)
+from torch.utils.mobile_optimizer import (LintCode,
+                                          generate_mobile_module_lints,
+                                          optimize_for_mobile)
+from torch.nn import functional as F
+from torch._C import MobileOptimizerType
+from torch.testing._internal.common_quantized import override_quantized_engine

 try:
     import torchvision
-
     HAS_TORCHVISION = True
 except ImportError:
     HAS_TORCHVISION = False

 FileCheck = torch._C.FileCheck

-
 class TestOptimizer(TestCase):
-    @unittest.skipUnless(
-        torch.backends.xnnpack.enabled,
-        " XNNPACK must be enabled for these tests." " Please build with USE_XNNPACK=1.",
-    )
+
+    @unittest.skipUnless(torch.backends.xnnpack.enabled,
+                         " XNNPACK must be enabled for these tests."
+                         " Please build with USE_XNNPACK=1.")
     def test_optimize_for_mobile(self):
         batch_size = 2
         input_channels_per_group = 6
@@ -47,22 +42,13 @@ class TestOptimizer(TestCase):
         strides = (stride_h, stride_w)
         paddings = (pad_h, pad_w)
         dilations = (dilation, dilation)
-        conv_weight_shape = (
-            output_channels,
-            input_channels_per_group,
-            kernel_h,
-            kernel_w,
-        )
-        conv_bias_shape = output_channels
+        conv_weight_shape = (output_channels, input_channels_per_group, kernel_h, kernel_w)
+        conv_bias_shape = (output_channels)

         input_data = torch.rand((batch_size, input_channels, height, width))
-        conv_weight = torch.rand(
-            (output_channels, input_channels_per_group, kernel_h, kernel_w)
-        )
+        conv_weight = torch.rand((output_channels, input_channels_per_group, kernel_h, kernel_w))
         conv_bias = torch.rand((output_channels))
-        result = F.conv2d(
-            input_data, conv_weight, conv_bias, strides, paddings, dilations, groups
-        )
+        result = F.conv2d(input_data, conv_weight, conv_bias, strides, paddings, dilations, groups)
         weight_output_dim = 24
         linear_input_shape = result.shape[1]
         linear_weight_shape = (weight_output_dim, linear_input_shape)
@@ -80,15 +66,8 @@ class TestOptimizer(TestCase):
                 self.groups = groups

             def forward(self, x):
-                o = F.conv2d(
-                    x,
-                    self.conv_weight,
-                    self.conv_bias,
-                    self.strides,
-                    self.paddings,
-                    self.dilations,
-                    self.groups,
-                )
+                o = F.conv2d(x, self.conv_weight, self.conv_bias,
+                             self.strides, self.paddings, self.dilations, self.groups)
                 o = F.relu(o)
                 x = o.permute([0, 2, 3, 1])
                 o = F.linear(x, self.linear_weight, self.linear_bias)
@@ -97,21 +76,15 @@ class TestOptimizer(TestCase):

             @torch.jit.export
             def foo(self, x):
-                o = F.conv2d(
-                    x,
-                    self.conv_weight,
-                    self.conv_bias,
-                    self.strides,
-                    self.paddings,
-                    self.dilations,
-                    self.groups,
-                )
+                o = F.conv2d(x, self.conv_weight, self.conv_bias,
+                             self.strides, self.paddings, self.dilations, self.groups)
                 o = F.relu(o)
                 x = o.permute([0, 2, 3, 1])
                 o = F.linear(x, self.linear_weight, self.linear_bias)
                 o = o + x
                 return F.relu(o)

+
         class BNTestModule(torch.nn.Module):
             def __init__(self):
                 super(BNTestModule, self).__init__()
@@ -132,109 +105,66 @@ class TestOptimizer(TestCase):
         initial_result = scripted_model(input_data)
         initial_foo_result = scripted_model.foo(input_data)

-        optimized_scripted_model = optimize_for_mobile(
-            scripted_model, preserved_methods=["foo"]
-        )
+        optimized_scripted_model = optimize_for_mobile(scripted_model, preserved_methods=['foo'])
         optimized_result = optimized_scripted_model(input_data)
         optimized_foo_result = optimized_scripted_model.foo(input_data)

-        FileCheck().check_not("Tensor = aten::conv2d").check_not(
-            "Tensor = prim::CallFunction"
-        ).check_not("prepacked::conv2d_clamp_prepack").check_count(
-            "prepacked::conv2d_clamp_run", 1, exactly=True
-        ).check_not(
-            "prepacked::linear_clamp_prepack"
-        ).check_count(
-            "prepacked::linear_clamp_run", 1, exactly=True
-        ).check_not(
-            "aten::add("
-        ).check_not(
-            "aten::relu("
-        ).check_count(
-            "aten::_add_relu(", 1, exactly=True
-        ).run(
-            optimized_scripted_model.graph
-        )
-        torch.testing.assert_allclose(
-            initial_result, optimized_result, rtol=1e-2, atol=1e-3
-        )
+        FileCheck().check_not("Tensor = aten::conv2d") \
+                   .check_not("Tensor = prim::CallFunction") \
+                   .check_not("prepacked::conv2d_clamp_prepack") \
+                   .check_count("prepacked::conv2d_clamp_run", 1, exactly=True) \
+                   .check_not("prepacked::linear_clamp_prepack") \
+                   .check_count("prepacked::linear_clamp_run", 1, exactly=True) \
+                   .check_not("aten::add(") \
+                   .check_not("aten::relu(") \
+                   .check_count("aten::_add_relu(", 1, exactly=True) \
+                   .run(optimized_scripted_model.graph)
+        torch.testing.assert_allclose(initial_result, optimized_result, rtol=1e-2, atol=1e-3)

-        FileCheck().check_not("Tensor = aten::conv2d").check_not(
-            "Tensor = prim::CallFunction"
-        ).check_not("prepacked::conv2d_clamp_prepack").check_count(
-            "prepacked::conv2d_clamp_run", 1, exactly=True
-        ).check_not(
-            "prepacked::linear_clamp_prepack"
-        ).check_count(
-            "prepacked::linear_clamp_run", 1, exactly=True
-        ).check_not(
-            "aten::add("
-        ).check_not(
-            "aten::relu("
-        ).check_count(
-            "aten::_add_relu(", 1, exactly=True
-        ).run(
-            optimized_scripted_model.foo.graph
-        )
-        torch.testing.assert_allclose(
-            initial_foo_result, optimized_foo_result, rtol=1e-2, atol=1e-3
-        )
+        FileCheck().check_not("Tensor = aten::conv2d") \
+                   .check_not("Tensor = prim::CallFunction") \
+                   .check_not("prepacked::conv2d_clamp_prepack") \
+                   .check_count("prepacked::conv2d_clamp_run", 1, exactly=True) \
+                   .check_not("prepacked::linear_clamp_prepack") \
+                   .check_count("prepacked::linear_clamp_run", 1, exactly=True) \
+                   .check_not("aten::add(") \
+                   .check_not("aten::relu(") \
+                   .check_count("aten::_add_relu(", 1, exactly=True) \
+                   .run(optimized_scripted_model.foo.graph)
+        torch.testing.assert_allclose(initial_foo_result, optimized_foo_result, rtol=1e-2, atol=1e-3)

-        optimization_blocklist_no_prepack = {
-            MobileOptimizerType.INSERT_FOLD_PREPACK_OPS
-        }
-        optimized_scripted_model_no_prepack = optimize_for_mobile(
-            scripted_model, optimization_blocklist_no_prepack
-        )
+
+        optimization_blocklist_no_prepack = {MobileOptimizerType.INSERT_FOLD_PREPACK_OPS}
+        optimized_scripted_model_no_prepack = optimize_for_mobile(scripted_model, optimization_blocklist_no_prepack)
         optimized_result_no_prepack = optimized_scripted_model_no_prepack(input_data)

-        FileCheck().check_count("Tensor = aten::conv2d", 1, exactly=True).check_not(
-            "prepacked::linear_clamp_run"
-        ).check_not("prepacked::conv2d_clamp_run").run(
-            optimized_scripted_model_no_prepack.graph
-        )
-        torch.testing.assert_allclose(
-            initial_result, optimized_result_no_prepack, rtol=1e-2, atol=1e-3
-        )
+        FileCheck().check_count("Tensor = aten::conv2d", 1, exactly=True) \
+                   .check_not("prepacked::linear_clamp_run") \
+                   .check_not("prepacked::conv2d_clamp_run") \
+                   .run(optimized_scripted_model_no_prepack.graph)
+        torch.testing.assert_allclose(initial_result, optimized_result_no_prepack, rtol=1e-2, atol=1e-3)
+

         bn_test_module = BNTestModule()
         bn_scripted_module = torch.jit.script(bn_test_module)
         bn_scripted_module.eval()

         self.assertEqual(len(torch.jit.export_opnames(bn_scripted_module)), 14)
-        FileCheck().check_count(
-            'prim::CallMethod[name="forward"]', 2, exactly=True
-        ).run(str(get_forward(bn_scripted_module._c).graph))
+        FileCheck().check_count("prim::CallMethod[name=\"forward\"]", 2, exactly=True) \
+                   .run(str(get_forward(bn_scripted_module._c).graph))

-        optimization_blocklist_no_prepack = {
-            MobileOptimizerType.INSERT_FOLD_PREPACK_OPS
-        }
-        bn_fold_scripted_module = optimize_for_mobile(
-            bn_scripted_module, optimization_blocklist_no_prepack
-        )
+        optimization_blocklist_no_prepack = {MobileOptimizerType.INSERT_FOLD_PREPACK_OPS}
+        bn_fold_scripted_module = optimize_for_mobile(bn_scripted_module, optimization_blocklist_no_prepack)
         self.assertEqual(len(torch.jit.export_opnames(bn_fold_scripted_module)), 1)
         bn_input = torch.rand(1, 1, 6, 6)
-        torch.testing.assert_allclose(
-            bn_scripted_module(bn_input),
-            bn_fold_scripted_module(bn_input),
-            rtol=1e-2,
-            atol=1e-3,
-        )
+        torch.testing.assert_allclose(bn_scripted_module(bn_input), bn_fold_scripted_module(bn_input), rtol=1e-2, atol=1e-3)

         optimization_blocklist_no_fold_bn = {MobileOptimizerType.CONV_BN_FUSION}
-        no_bn_fold_scripted_module = optimize_for_mobile(
-            bn_scripted_module, optimization_blocklist_no_fold_bn
-        )
-        FileCheck().check_count("aten::batch_norm", 1, exactly=True).run(
-            str(get_forward_graph(no_bn_fold_scripted_module._c))
-        )
+        no_bn_fold_scripted_module = optimize_for_mobile(bn_scripted_module, optimization_blocklist_no_fold_bn)
+        FileCheck().check_count("aten::batch_norm", 1, exactly=True) \
+                   .run(str(get_forward_graph(no_bn_fold_scripted_module._c)))
         bn_input = torch.rand(1, 1, 6, 6)
-        torch.testing.assert_allclose(
-            bn_scripted_module(bn_input),
-            no_bn_fold_scripted_module(bn_input),
-            rtol=1e-2,
-            atol=1e-3,
-        )
+        torch.testing.assert_allclose(bn_scripted_module(bn_input), no_bn_fold_scripted_module(bn_input), rtol=1e-2, atol=1e-3)

         class MyMobileOptimizedTagTest(torch.nn.Module):
             def __init__(self):
@@ -290,21 +220,18 @@ class TestOptimizer(TestCase):
                 x = self.l2(x)
                 x = x + torch.ones(1, 100)
                 return F.relu(x)

         input_data = torch.ones(1, 10)
         m = torch.jit.script(OptimizeNoForwardTest())
         m.eval()
         initial_result = m.foo(input_data)

-        optimized_scripted_model = optimize_for_mobile(m, preserved_methods=["foo"])
+        optimized_scripted_model = optimize_for_mobile(m, preserved_methods=['foo'])
         optimized_result = optimized_scripted_model.foo(input_data)

-        FileCheck().check_not("dropout.__").check_count(
-            "aten::_add_relu(", 1, exactly=True
-        ).run(optimized_scripted_model.foo.graph)
-        torch.testing.assert_allclose(
-            initial_result, optimized_result, rtol=1e-2, atol=1e-3
-        )
+        FileCheck().check_not("dropout.__") \
+                   .check_count("aten::_add_relu(", 1, exactly=True) \
+                   .run(optimized_scripted_model.foo.graph)
+        torch.testing.assert_allclose(initial_result, optimized_result, rtol=1e-2, atol=1e-3)

         class BNTestNoForwardModule(torch.nn.Module):
             def __init__(self):
@@ -323,37 +250,28 @@ class TestOptimizer(TestCase):
         bn_no_forward_scripted_module = torch.jit.script(bn_test_no_forward_module)
         bn_no_forward_scripted_module.eval()

-        self.assertEqual(
-            len(torch.jit.export_opnames(bn_no_forward_scripted_module)), 14
-        )
-        FileCheck().check_count(
-            'prim::CallMethod[name="forward"]', 2, exactly=True
-        ).run(bn_no_forward_scripted_module.foo.graph)
+        self.assertEqual(len(torch.jit.export_opnames(bn_no_forward_scripted_module)), 14)
+        FileCheck().check_count("prim::CallMethod[name=\"forward\"]", 2, exactly=True) \
+                   .run(bn_no_forward_scripted_module.foo.graph)

-        bn_fold_no_forward_scripted_module = optimize_for_mobile(
-            bn_no_forward_scripted_module, preserved_methods=["foo"]
-        )
-        self.assertEqual(
-            len(torch.jit.export_opnames(bn_fold_no_forward_scripted_module)), 1
-        )
+        bn_fold_no_forward_scripted_module = optimize_for_mobile(bn_no_forward_scripted_module, preserved_methods=['foo'])
+        self.assertEqual(len(torch.jit.export_opnames(bn_fold_no_forward_scripted_module)), 1)
         bn_input = torch.rand(1, 1, 6, 6)
         torch.testing.assert_allclose(
             bn_no_forward_scripted_module.foo(bn_input),
             bn_fold_no_forward_scripted_module.foo(bn_input),
             rtol=1e-2,
-            atol=1e-3,
-        )
+            atol=1e-3)

-    @unittest.skipUnless(
-        torch.backends.xnnpack.enabled,
-        " XNNPACK must be enabled for these tests." " Please build with USE_XNNPACK=1.",
-    )
+    @unittest.skipUnless(torch.backends.xnnpack.enabled,
+                         " XNNPACK must be enabled for these tests."
+                         " Please build with USE_XNNPACK=1.")
     def test_quantized_conv_no_asan_failures(self):
         # There were ASAN failures when fold_conv_bn was run on
         # already quantized conv modules. Verifying that this does
         # not happen again.

-        if "qnnpack" not in torch.backends.quantized.supported_engines:
+        if 'qnnpack' not in torch.backends.quantized.supported_engines:
             return

         class Child(nn.Module):
@@ -380,9 +298,9 @@ class TestOptimizer(TestCase):
                 x = self.dequant(x)
                 return x

-        with override_quantized_engine("qnnpack"):
+        with override_quantized_engine('qnnpack'):
             model = Parent()
-            model.qconfig = torch.quantization.get_default_qconfig("qnnpack")
+            model.qconfig = torch.quantization.get_default_qconfig('qnnpack')
             torch.quantization.prepare(model, inplace=True)
             model(torch.randn(4, 1, 4, 4))
             torch.quantization.convert(model, inplace=True)
@@ -419,44 +337,25 @@ class TestOptimizer(TestCase):
                 return inputs

         def get_lint_count_by_type(lint_type, module_lint_List):
-            return len(
-                [
-                    lint_dict
-                    for lint_dict in module_lint_List
-                    if lint_dict["name"] == lint_type.name
-                ]
-            )
+            return len([lint_dict for lint_dict in module_lint_List if lint_dict['name'] == lint_type.name])

         test_module = torch.jit.script(MyTestModule())
         test_module_lint_list = generate_mobile_module_lints(test_module)
         self.assertEqual(len(test_module_lint_list), 4)
-        self.assertEqual(
-            get_lint_count_by_type(LintCode.BUNDLED_INPUT, test_module_lint_list), 1
-        )
-        self.assertEqual(
-            get_lint_count_by_type(LintCode.DROPOUT, test_module_lint_list), 1
-        )
-        self.assertEqual(
-            get_lint_count_by_type(LintCode.REQUIRES_GRAD, test_module_lint_list), 2
-        )
+        self.assertEqual(get_lint_count_by_type(LintCode.BUNDLED_INPUT, test_module_lint_list), 1)
+        self.assertEqual(get_lint_count_by_type(LintCode.DROPOUT, test_module_lint_list), 1)
+        self.assertEqual(get_lint_count_by_type(LintCode.REQUIRES_GRAD, test_module_lint_list), 2)

         bn_module = torch.jit.script(MyBNModule())
         bn_module_lint_list = generate_mobile_module_lints(bn_module)
         self.assertEqual(len(bn_module_lint_list), 4)
-        self.assertEqual(
-            get_lint_count_by_type(LintCode.BUNDLED_INPUT, bn_module_lint_list), 1
-        )
-        self.assertEqual(
-            get_lint_count_by_type(LintCode.BATCHNORM, bn_module_lint_list), 1
-        )
-        self.assertEqual(
-            get_lint_count_by_type(LintCode.REQUIRES_GRAD, bn_module_lint_list), 2
-        )
+        self.assertEqual(get_lint_count_by_type(LintCode.BUNDLED_INPUT, bn_module_lint_list), 1)
+        self.assertEqual(get_lint_count_by_type(LintCode.BATCHNORM, bn_module_lint_list), 1)
+        self.assertEqual(get_lint_count_by_type(LintCode.REQUIRES_GRAD, bn_module_lint_list), 2)

         bi_module = torch.jit.script(MyBundledInputModule())
         torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
-            bi_module, [(torch.tensor([1]),)], []
-        )
+            bi_module, [(torch.tensor([1]),)], [])
         bi_module_lint_list = generate_mobile_module_lints(bi_module)
         self.assertEqual(len(bi_module_lint_list), 0)

@@ -484,21 +383,20 @@ class TestOptimizer(TestCase):

         # Expected to be False since no bundled inputs methods were added
         self.assertFalse(
-            hasattr(module_optim_bi_not_preserved, "get_all_bundled_inputs")
-            or hasattr(module_optim_bi_not_preserved, "get_num_bundled_inputs")
+            hasattr(module_optim_bi_not_preserved, 'get_all_bundled_inputs') or
+            hasattr(module_optim_bi_not_preserved, 'get_num_bundled_inputs')
         )

         # Add bundled inputs methods to the module
         torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
-            bi_module, [(torch.tensor([1]),)], []
-        )
+            bi_module, [(torch.tensor([1]),)], [])
         # Now they should be preserved
         module_optim_bi_preserved = optimize_for_mobile(bi_module)

         # All of the bundled inputs methods were preserved
         self.assertTrue(
-            hasattr(module_optim_bi_preserved, "get_all_bundled_inputs")
-            and hasattr(module_optim_bi_preserved, "get_num_bundled_inputs")
+            hasattr(module_optim_bi_preserved, 'get_all_bundled_inputs') and
+            hasattr(module_optim_bi_preserved, 'get_num_bundled_inputs')
         )

         bundled_input = module_optim_bi_preserved.get_all_bundled_inputs()[0]
@@ -508,22 +406,19 @@ class TestOptimizer(TestCase):
         # we will not try to preserve them unless specified by the user.
         incomplete_bi_module = torch.jit.script(MyIncompleteBundledInputModule())
         incomplete_bi_module_optim = optimize_for_mobile(incomplete_bi_module)
-        self.assertFalse(hasattr(incomplete_bi_module_optim, "get_all_bundled_inputs"))
+        self.assertFalse(hasattr(incomplete_bi_module_optim, 'get_all_bundled_inputs'))

         # Specifically preserve get_all_bundled_inputs even if it's the only one
         # bundled inputs method available.
-        incomplete_bi_module_optim = optimize_for_mobile(
-            incomplete_bi_module, preserved_methods=["get_all_bundled_inputs"]
-        )
-        self.assertTrue(hasattr(incomplete_bi_module_optim, "get_all_bundled_inputs"))
+        incomplete_bi_module_optim = optimize_for_mobile(incomplete_bi_module, preserved_methods=['get_all_bundled_inputs'])
+        self.assertTrue(hasattr(incomplete_bi_module_optim, 'get_all_bundled_inputs'))

-    @unittest.skipUnless(
-        torch.backends.xnnpack.enabled,
-        " XNNPACK must be enabled for these tests." " Please build with USE_XNNPACK=1.",
-    )
+    @unittest.skipUnless(torch.backends.xnnpack.enabled,
+                         " XNNPACK must be enabled for these tests."
+                         " Please build with USE_XNNPACK=1.")
     def test_hoist_conv_packed_params(self):

-        if "qnnpack" not in torch.backends.quantized.supported_engines:
+        if 'qnnpack' not in torch.backends.quantized.supported_engines:
             return

         class Standalone(nn.Module):
@@ -544,7 +439,7 @@ class TestOptimizer(TestCase):
                 return x

             def fuse_model(self):
-                torch.quantization.fuse_modules(self, [["conv2", "relu"]], inplace=True)
+                torch.quantization.fuse_modules(self, [['conv2', 'relu']], inplace=True)
                 pass

         class Child(nn.Module):
@@ -575,10 +470,9 @@ class TestOptimizer(TestCase):
             def fuse_model(self):
                 pass

-        with override_quantized_engine("qnnpack"):
-
+        with override_quantized_engine('qnnpack'):
             def _quant_script_and_optimize(model):
-                model.qconfig = torch.quantization.get_default_qconfig("qnnpack")
+                model.qconfig = torch.quantization.get_default_qconfig('qnnpack')
                 model.fuse_model()
                 torch.quantization.prepare(model, inplace=True)
                 model(torch.randn(4, 1, 4, 4))
@@ -590,11 +484,9 @@ class TestOptimizer(TestCase):
             # basic case

             m, m_optim = _quant_script_and_optimize(Standalone())
-            FileCheck().check_not('Conv2d = prim::GetAttr[name="conv1"]').check_count(
-                "__torch__.torch.classes.quantized.Conv2dPackedParamsBase = prim::Constant",
-                2,
-                exactly=True,
-            ).run(m_optim.graph)
+            FileCheck().check_not("Conv2d = prim::GetAttr[name=\"conv1\"]") \
+                       .check_count("__torch__.torch.classes.quantized.Conv2dPackedParamsBase = prim::Constant", 2, exactly=True) \
+                       .run(m_optim.graph)
             self.assertFalse(hasattr(m_optim, "conv1"))
             self.assertFalse(hasattr(m_optim, "conv2"))

@@ -606,11 +498,9 @@ class TestOptimizer(TestCase):
             # generic case

             m, m_optim = _quant_script_and_optimize(Parent())
-            FileCheck().check_not('Conv2d = prim::GetAttr[name="conv1"]').check_count(
-                "__torch__.torch.classes.quantized.Conv2dPackedParamsBase = prim::Constant",
-                2,
-                exactly=True,
-            ).run(m_optim.graph)
+            FileCheck().check_not("Conv2d = prim::GetAttr[name=\"conv1\"]") \
+                       .check_count("__torch__.torch.classes.quantized.Conv2dPackedParamsBase = prim::Constant", 2, exactly=True) \
+                       .run(m_optim.graph)
             self.assertFalse(hasattr(m_optim, "conv1"))
             self.assertFalse(hasattr(m_optim, "child"))

@@ -635,7 +525,7 @@ class TestOptimizer(TestCase):
         class MyInnerTestModule(torch.nn.Module):
             def __init__(self):
                 super(MyInnerTestModule, self).__init__()
-                self.pqr = torch.Tensor([10.0, 20.0, 30.0])
+                self.pqr = torch.Tensor([10., 20., 30.])

             def forward(self, inputs):
                 return inputs
@@ -648,7 +538,7 @@ class TestOptimizer(TestCase):
             def __init__(self):
                 super(MyTestModule, self).__init__()
                 self.abc = 23
-                self.pqr = torch.Tensor([1.0, 2.0, 3.0])
+                self.pqr = torch.Tensor([1., 2., 3.])
                 self.inner = MyInnerTestModule()

             def forward(self, inputs):
@@ -702,21 +592,16 @@ class TestOptimizer(TestCase):

         # Check that the cloned class has a classname that starts with __torch__.
         self.assertTrue(
-            cloned.qualified_name.startswith("__torch__."),
-            (
-                "Expected the cloned module's name to start with the string "
-                "'__torch__.', but got: {0}"
-            ).format(cloned.qualified_name),
+            cloned.qualified_name.startswith('__torch__.'),
+            ("Expected the cloned module's name to start with the string "
+             "'__torch__.', but got: {0}").format(cloned.qualified_name),
         )

+
         # Case-2: Successfully clone the module, ignoring the attribute pqr, and the method that references it.
         cloned = torch._C._hack_do_not_use_clone_module_with_class(
             m._c,
-            [
-                "dummy_method_not_cloned",
-                "dummy_method_not_cloned2",
-                "dummy_method_ref_attr_pqr",
-            ],
+            ["dummy_method_not_cloned", "dummy_method_not_cloned2", "dummy_method_ref_attr_pqr"],
             ["pqr"],
         )

@@ -727,12 +612,11 @@ class TestOptimizer(TestCase):
         self.assertEqual(hasattr(cloned, "dummy_method_ref_attr_pqr"), False)
         self.assertEqual(hasattr(cloned, "pqr"), False)

+
         # Case-3: The statement below will throw since dummy_method_cloned2 is preserved,
         # and references dummy_method_not_cloned, which is not cloned.
         with self.assertRaises(RuntimeError):
-            cloned = torch._C._hack_do_not_use_clone_module_with_class(
-                m._c, ["dummy_method_not_cloned"], []
-            )
+            cloned = torch._C._hack_do_not_use_clone_module_with_class(m._c, ["dummy_method_not_cloned"], [])

         # Case-4: The statement below will throw since dummy_method_ref_attr_pqr
         # is preserved, and references "pqr", which is not cloned.
@@ -744,5 +628,5 @@ class TestOptimizer(TestCase):
             )


-if __name__ == "__main__":
+if __name__ == '__main__':
     run_tests()
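For context on what the reformatted file exercises: a minimal sketch, assuming an XNNPACK-enabled CPU build of PyTorch from around this commit, of the mobile-optimizer API and the FileCheck pattern used throughout the tests above. The TinyModel module and its shapes are illustrative and are not taken from the diff.

import torch
import torch.nn as nn
from torch._C import MobileOptimizerType
from torch.utils.mobile_optimizer import optimize_for_mobile


class TinyModel(nn.Module):
    def __init__(self):
        super(TinyModel, self).__init__()
        self.conv = nn.Conv2d(1, 2, 3)

    def forward(self, x):
        return torch.relu(self.conv(x))


scripted = torch.jit.script(TinyModel())
scripted.eval()

# Default pass list includes conv/bn folding, dropout removal,
# prepacked-op rewriting and add/relu fusion.
optimized = optimize_for_mobile(scripted)

# Individual passes can be skipped with a blocklist, and extra methods kept
# alive with preserved_methods, exactly as the tests above do.
no_prepack = optimize_for_mobile(scripted, {MobileOptimizerType.INSERT_FOLD_PREPACK_OPS})

# FileCheck asserts over the optimized TorchScript graph.
if torch.backends.xnnpack.enabled:
    torch._C.FileCheck().check_count("prepacked::conv2d_clamp_run", 1, exactly=True) \
                        .run(optimized.graph)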