Enable UFMT on all of test/quantization/ao_migration & bc (#123994)

Partially addresses #123062
Ran lintrunner on:
- test/quantization/ao_migration
- test/quantization/bc

Detail:
```
$ lintrunner -a --take UFMT --all-files
ok No lint issues.
Successfully applied all patches.
```
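
For reviewers reproducing this locally, the same check can be scoped to the affected directories instead of the whole tree. A minimal sketch, assuming a checkout where `lintrunner init` has already been run (lintrunner accepts explicit paths in place of `--all-files`; the globs below are illustrative):

```
$ lintrunner -a --take UFMT test/quantization/ao_migration/*.py test/quantization/bc/*.py
```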

@ezyang

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123994
Approved by: https://github.com/ezyang
Author: WeiChunyu-star
Date: 2024-04-13 06:36:07 +00:00
Committed by: PyTorch MergeBot
Parent: 285c93d64d
Commit: 6ac8fe46dd
6 changed files with 707 additions and 516 deletions

View File: .lintrunner.toml

@@ -1334,13 +1334,6 @@ exclude_patterns = [
'test/profiler/test_profiler.py',
'test/profiler/test_profiler_tree.py',
'test/quantization/__init__.py',
'test/quantization/ao_migration/__init__.py',
'test/quantization/ao_migration/common.py',
'test/quantization/ao_migration/test_ao_migration.py',
'test/quantization/ao_migration/test_quantization.py',
'test/quantization/ao_migration/test_quantization_fx.py',
'test/quantization/bc/__init__.py',
'test/quantization/bc/test_backward_compatibility.py',
'test/quantization/core/__init__.py',
'test/quantization/core/experimental/apot_fx_graph_mode_ptq.py',
'test/quantization/core/experimental/apot_fx_graph_mode_qat.py',

View File: test/quantization/ao_migration/common.py

@@ -1,42 +1,52 @@
from torch.testing._internal.common_utils import TestCase
import importlib
from typing import List, Optional
from torch.testing._internal.common_utils import TestCase
class AOMigrationTestCase(TestCase):
def _test_function_import(self, package_name: str, function_list: List[str],
base: Optional[str] = None, new_package_name: Optional[str] = None):
def _test_function_import(
self,
package_name: str,
function_list: List[str],
base: Optional[str] = None,
new_package_name: Optional[str] = None,
):
r"""Tests individual function list import by comparing the functions
and their hashes."""
if base is None:
base = 'quantization'
old_base = 'torch.' + base
new_base = 'torch.ao.' + base
base = "quantization"
old_base = "torch." + base
new_base = "torch.ao." + base
if new_package_name is None:
new_package_name = package_name
old_location = importlib.import_module(f'{old_base}.{package_name}')
new_location = importlib.import_module(f'{new_base}.{new_package_name}')
old_location = importlib.import_module(f"{old_base}.{package_name}")
new_location = importlib.import_module(f"{new_base}.{new_package_name}")
for fn_name in function_list:
old_function = getattr(old_location, fn_name)
new_function = getattr(new_location, fn_name)
assert old_function == new_function, f"Functions don't match: {fn_name}"
assert hash(old_function) == hash(new_function), \
f"Hashes don't match: {old_function}({hash(old_function)}) vs. " \
assert hash(old_function) == hash(new_function), (
f"Hashes don't match: {old_function}({hash(old_function)}) vs. "
f"{new_function}({hash(new_function)})"
)
def _test_dict_import(self, package_name: str, dict_list: List[str],
base: Optional[str] = None):
def _test_dict_import(
self, package_name: str, dict_list: List[str], base: Optional[str] = None
):
r"""Tests individual function list import by comparing the functions
and their hashes."""
if base is None:
base = 'quantization'
old_base = 'torch.' + base
new_base = 'torch.ao.' + base
old_location = importlib.import_module(f'{old_base}.{package_name}')
new_location = importlib.import_module(f'{new_base}.{package_name}')
base = "quantization"
old_base = "torch." + base
new_base = "torch.ao." + base
old_location = importlib.import_module(f"{old_base}.{package_name}")
new_location = importlib.import_module(f"{new_base}.{package_name}")
for dict_name in dict_list:
old_dict = getattr(old_location, dict_name)
new_dict = getattr(new_location, dict_name)
assert old_dict == new_dict, f"Dicts don't match: {dict_name}"
for key in new_dict.keys():
assert old_dict[key] == new_dict[key], f"Dicts don't match: {dict_name} for key {key}"
assert (
old_dict[key] == new_dict[key]
), f"Dicts don't match: {dict_name} for key {key}"

View File: test/quantization/ao_migration/test_ao_migration.py

@@ -7,257 +7,261 @@ class TestAOMigrationNNQuantized(AOMigrationTestCase):
def test_functional_import(self):
r"""Tests the migration of the torch.nn.quantized.functional"""
function_list = [
'avg_pool2d',
'avg_pool3d',
'adaptive_avg_pool2d',
'adaptive_avg_pool3d',
'conv1d',
'conv2d',
'conv3d',
'interpolate',
'linear',
'max_pool1d',
'max_pool2d',
'celu',
'leaky_relu',
'hardtanh',
'hardswish',
'threshold',
'elu',
'hardsigmoid',
'clamp',
'upsample',
'upsample_bilinear',
'upsample_nearest',
"avg_pool2d",
"avg_pool3d",
"adaptive_avg_pool2d",
"adaptive_avg_pool3d",
"conv1d",
"conv2d",
"conv3d",
"interpolate",
"linear",
"max_pool1d",
"max_pool2d",
"celu",
"leaky_relu",
"hardtanh",
"hardswish",
"threshold",
"elu",
"hardsigmoid",
"clamp",
"upsample",
"upsample_bilinear",
"upsample_nearest",
]
self._test_function_import('functional', function_list, base='nn.quantized')
self._test_function_import("functional", function_list, base="nn.quantized")
def test_modules_import(self):
module_list = [
# Modules
'BatchNorm2d',
'BatchNorm3d',
'Conv1d',
'Conv2d',
'Conv3d',
'ConvTranspose1d',
'ConvTranspose2d',
'ConvTranspose3d',
'DeQuantize',
'ELU',
'Embedding',
'EmbeddingBag',
'GroupNorm',
'Hardswish',
'InstanceNorm1d',
'InstanceNorm2d',
'InstanceNorm3d',
'LayerNorm',
'LeakyReLU',
'Linear',
'MaxPool2d',
'Quantize',
'ReLU6',
'Sigmoid',
'Softmax',
'Dropout',
"BatchNorm2d",
"BatchNorm3d",
"Conv1d",
"Conv2d",
"Conv3d",
"ConvTranspose1d",
"ConvTranspose2d",
"ConvTranspose3d",
"DeQuantize",
"ELU",
"Embedding",
"EmbeddingBag",
"GroupNorm",
"Hardswish",
"InstanceNorm1d",
"InstanceNorm2d",
"InstanceNorm3d",
"LayerNorm",
"LeakyReLU",
"Linear",
"MaxPool2d",
"Quantize",
"ReLU6",
"Sigmoid",
"Softmax",
"Dropout",
# Wrapper modules
'FloatFunctional',
'FXFloatFunctional',
'QFunctional',
"FloatFunctional",
"FXFloatFunctional",
"QFunctional",
]
self._test_function_import('modules', module_list, base='nn.quantized')
self._test_function_import("modules", module_list, base="nn.quantized")
def test_modules_activation(self):
function_list = [
'ReLU6',
'Hardswish',
'ELU',
'LeakyReLU',
'Sigmoid',
'Softmax',
"ReLU6",
"Hardswish",
"ELU",
"LeakyReLU",
"Sigmoid",
"Softmax",
]
self._test_function_import('activation', function_list,
base='nn.quantized.modules')
self._test_function_import(
"activation", function_list, base="nn.quantized.modules"
)
def test_modules_batchnorm(self):
function_list = [
'BatchNorm2d',
'BatchNorm3d',
"BatchNorm2d",
"BatchNorm3d",
]
self._test_function_import('batchnorm', function_list,
base='nn.quantized.modules')
self._test_function_import(
"batchnorm", function_list, base="nn.quantized.modules"
)
def test_modules_conv(self):
function_list = [
'_reverse_repeat_padding',
'Conv1d',
'Conv2d',
'Conv3d',
'ConvTranspose1d',
'ConvTranspose2d',
'ConvTranspose3d',
"_reverse_repeat_padding",
"Conv1d",
"Conv2d",
"Conv3d",
"ConvTranspose1d",
"ConvTranspose2d",
"ConvTranspose3d",
]
self._test_function_import('conv', function_list,
base='nn.quantized.modules')
self._test_function_import("conv", function_list, base="nn.quantized.modules")
def test_modules_dropout(self):
function_list = [
'Dropout',
"Dropout",
]
self._test_function_import('dropout', function_list,
base='nn.quantized.modules')
self._test_function_import(
"dropout", function_list, base="nn.quantized.modules"
)
def test_modules_embedding_ops(self):
function_list = [
'EmbeddingPackedParams',
'Embedding',
'EmbeddingBag',
"EmbeddingPackedParams",
"Embedding",
"EmbeddingBag",
]
self._test_function_import('embedding_ops', function_list,
base='nn.quantized.modules')
self._test_function_import(
"embedding_ops", function_list, base="nn.quantized.modules"
)
def test_modules_functional_modules(self):
function_list = [
'FloatFunctional',
'FXFloatFunctional',
'QFunctional',
"FloatFunctional",
"FXFloatFunctional",
"QFunctional",
]
self._test_function_import('functional_modules', function_list,
base='nn.quantized.modules')
self._test_function_import(
"functional_modules", function_list, base="nn.quantized.modules"
)
def test_modules_linear(self):
function_list = [
'Linear',
'LinearPackedParams',
"Linear",
"LinearPackedParams",
]
self._test_function_import('linear', function_list,
base='nn.quantized.modules')
self._test_function_import("linear", function_list, base="nn.quantized.modules")
def test_modules_normalization(self):
function_list = [
'LayerNorm',
'GroupNorm',
'InstanceNorm1d',
'InstanceNorm2d',
'InstanceNorm3d',
"LayerNorm",
"GroupNorm",
"InstanceNorm1d",
"InstanceNorm2d",
"InstanceNorm3d",
]
self._test_function_import('normalization', function_list,
base='nn.quantized.modules')
self._test_function_import(
"normalization", function_list, base="nn.quantized.modules"
)
def test_modules_utils(self):
function_list = [
'_ntuple_from_first',
'_pair_from_first',
'_quantize_weight',
'_hide_packed_params_repr',
'WeightedQuantizedModule',
"_ntuple_from_first",
"_pair_from_first",
"_quantize_weight",
"_hide_packed_params_repr",
"WeightedQuantizedModule",
]
self._test_function_import('utils', function_list,
base='nn.quantized.modules')
self._test_function_import("utils", function_list, base="nn.quantized.modules")
def test_import_nn_quantized_dynamic_import(self):
module_list = [
# Modules
'Linear',
'LSTM',
'GRU',
'LSTMCell',
'RNNCell',
'GRUCell',
'Conv1d',
'Conv2d',
'Conv3d',
'ConvTranspose1d',
'ConvTranspose2d',
'ConvTranspose3d',
"Linear",
"LSTM",
"GRU",
"LSTMCell",
"RNNCell",
"GRUCell",
"Conv1d",
"Conv2d",
"Conv3d",
"ConvTranspose1d",
"ConvTranspose2d",
"ConvTranspose3d",
]
self._test_function_import('dynamic', module_list, base='nn.quantized')
self._test_function_import("dynamic", module_list, base="nn.quantized")
def test_import_nn_quantizable_activation(self):
module_list = [
# Modules
'MultiheadAttention',
"MultiheadAttention",
]
self._test_function_import('activation', module_list, base='nn.quantizable.modules')
self._test_function_import(
"activation", module_list, base="nn.quantizable.modules"
)
def test_import_nn_quantizable_rnn(self):
module_list = [
# Modules
'LSTM',
'LSTMCell',
"LSTM",
"LSTMCell",
]
self._test_function_import('rnn', module_list, base='nn.quantizable.modules')
self._test_function_import("rnn", module_list, base="nn.quantizable.modules")
def test_import_nn_qat_conv(self):
module_list = [
'Conv1d',
'Conv2d',
'Conv3d',
"Conv1d",
"Conv2d",
"Conv3d",
]
self._test_function_import('conv', module_list, base='nn.qat.modules')
self._test_function_import("conv", module_list, base="nn.qat.modules")
def test_import_nn_qat_embedding_ops(self):
module_list = [
'Embedding',
'EmbeddingBag',
"Embedding",
"EmbeddingBag",
]
self._test_function_import('embedding_ops', module_list, base='nn.qat.modules')
self._test_function_import("embedding_ops", module_list, base="nn.qat.modules")
def test_import_nn_qat_linear(self):
module_list = [
'Linear',
"Linear",
]
self._test_function_import('linear', module_list, base='nn.qat.modules')
self._test_function_import("linear", module_list, base="nn.qat.modules")
def test_import_nn_qat_dynamic_linear(self):
module_list = [
'Linear',
"Linear",
]
self._test_function_import('linear', module_list, base='nn.qat.dynamic.modules')
self._test_function_import("linear", module_list, base="nn.qat.dynamic.modules")
class TestAOMigrationNNIntrinsic(AOMigrationTestCase):
def test_modules_import_nn_intrinsic(self):
module_list = [
# Modules
'_FusedModule',
'ConvBn1d',
'ConvBn2d',
'ConvBn3d',
'ConvBnReLU1d',
'ConvBnReLU2d',
'ConvBnReLU3d',
'ConvReLU1d',
'ConvReLU2d',
'ConvReLU3d',
'LinearReLU',
'BNReLU2d',
'BNReLU3d',
'LinearBn1d',
"_FusedModule",
"ConvBn1d",
"ConvBn2d",
"ConvBn3d",
"ConvBnReLU1d",
"ConvBnReLU2d",
"ConvBnReLU3d",
"ConvReLU1d",
"ConvReLU2d",
"ConvReLU3d",
"LinearReLU",
"BNReLU2d",
"BNReLU3d",
"LinearBn1d",
]
self._test_function_import('intrinsic', module_list, base='nn')
self._test_function_import("intrinsic", module_list, base="nn")
def test_modules_nn_intrinsic_fused(self):
function_list = [
'_FusedModule',
'ConvBn1d',
'ConvBn2d',
'ConvBn3d',
'ConvBnReLU1d',
'ConvBnReLU2d',
'ConvBnReLU3d',
'ConvReLU1d',
'ConvReLU2d',
'ConvReLU3d',
'LinearReLU',
'BNReLU2d',
'BNReLU3d',
'LinearBn1d',
"_FusedModule",
"ConvBn1d",
"ConvBn2d",
"ConvBn3d",
"ConvBnReLU1d",
"ConvBnReLU2d",
"ConvBnReLU3d",
"ConvReLU1d",
"ConvReLU2d",
"ConvReLU3d",
"LinearReLU",
"BNReLU2d",
"BNReLU3d",
"LinearBn1d",
]
self._test_function_import('fused', function_list,
base='nn.intrinsic.modules')
self._test_function_import("fused", function_list, base="nn.intrinsic.modules")
def test_modules_import_nn_intrinsic_qat(self):
module_list = [
@@ -275,76 +279,83 @@ class TestAOMigrationNNIntrinsic(AOMigrationTestCase):
"update_bn_stats",
"freeze_bn_stats",
]
self._test_function_import('qat', module_list, base='nn.intrinsic')
self._test_function_import("qat", module_list, base="nn.intrinsic")
def test_modules_intrinsic_qat_conv_fused(self):
function_list = [
'ConvBn1d',
'ConvBnReLU1d',
'ConvReLU1d',
'ConvBn2d',
'ConvBnReLU2d',
'ConvReLU2d',
'ConvBn3d',
'ConvBnReLU3d',
'ConvReLU3d',
'update_bn_stats',
'freeze_bn_stats'
"ConvBn1d",
"ConvBnReLU1d",
"ConvReLU1d",
"ConvBn2d",
"ConvBnReLU2d",
"ConvReLU2d",
"ConvBn3d",
"ConvBnReLU3d",
"ConvReLU3d",
"update_bn_stats",
"freeze_bn_stats",
]
self._test_function_import('conv_fused', function_list,
base='nn.intrinsic.qat.modules')
self._test_function_import(
"conv_fused", function_list, base="nn.intrinsic.qat.modules"
)
def test_modules_intrinsic_qat_linear_fused(self):
function_list = [
'LinearBn1d',
"LinearBn1d",
]
self._test_function_import('linear_fused', function_list,
base='nn.intrinsic.qat.modules')
self._test_function_import(
"linear_fused", function_list, base="nn.intrinsic.qat.modules"
)
def test_modules_intrinsic_qat_linear_relu(self):
function_list = [
'LinearReLU',
"LinearReLU",
]
self._test_function_import('linear_relu', function_list,
base='nn.intrinsic.qat.modules')
self._test_function_import(
"linear_relu", function_list, base="nn.intrinsic.qat.modules"
)
def test_modules_import_nn_intrinsic_quantized(self):
module_list = [
'BNReLU2d',
'BNReLU3d',
'ConvReLU1d',
'ConvReLU2d',
'ConvReLU3d',
'LinearReLU',
"BNReLU2d",
"BNReLU3d",
"ConvReLU1d",
"ConvReLU2d",
"ConvReLU3d",
"LinearReLU",
]
self._test_function_import('quantized', module_list, base='nn.intrinsic')
self._test_function_import("quantized", module_list, base="nn.intrinsic")
def test_modules_intrinsic_quantized_bn_relu(self):
function_list = [
'BNReLU2d',
'BNReLU3d',
"BNReLU2d",
"BNReLU3d",
]
self._test_function_import('bn_relu', function_list,
base='nn.intrinsic.quantized.modules')
self._test_function_import(
"bn_relu", function_list, base="nn.intrinsic.quantized.modules"
)
def test_modules_intrinsic_quantized_conv_relu(self):
function_list = [
'ConvReLU1d',
'ConvReLU2d',
'ConvReLU3d',
"ConvReLU1d",
"ConvReLU2d",
"ConvReLU3d",
]
self._test_function_import('conv_relu', function_list,
base='nn.intrinsic.quantized.modules')
self._test_function_import(
"conv_relu", function_list, base="nn.intrinsic.quantized.modules"
)
def test_modules_intrinsic_quantized_linear_relu(self):
function_list = [
'LinearReLU',
"LinearReLU",
]
self._test_function_import('linear_relu', function_list,
base='nn.intrinsic.quantized.modules')
self._test_function_import(
"linear_relu", function_list, base="nn.intrinsic.quantized.modules"
)
def test_modules_no_import_nn_intrinsic_quantized_dynamic(self):
# TODO(future PR): generalize this
import torch
_ = torch.ao.nn.intrinsic.quantized.dynamic
_ = torch.nn.intrinsic.quantized.dynamic
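
The tests above all reduce to one invariant: every legacy `torch.nn.*` quantization namespace re-exports the objects that now live under `torch.ao.nn.*`. A standalone sketch of that invariant (the module chosen is illustrative, and newer PyTorch releases may emit deprecation warnings on the legacy import):

```python
import importlib

legacy = importlib.import_module("torch.nn.quantized.functional")
migrated = importlib.import_module("torch.ao.nn.quantized.functional")

# Mirrors AOMigrationTestCase._test_function_import: the legacy name
# must resolve to the same object as the migrated one, hash included.
assert legacy.linear == migrated.linear
assert hash(legacy.linear) == hash(migrated.linear)
```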

View File: test/quantization/ao_migration/test_quantization.py

@@ -7,102 +7,103 @@ class TestAOMigrationQuantization(AOMigrationTestCase):
r"""Modules and functions related to the
`torch/quantization` migration to `torch/ao/quantization`.
"""
def test_function_import_quantize(self):
function_list = [
'_convert',
'_observer_forward_hook',
'_propagate_qconfig_helper',
'_remove_activation_post_process',
'_remove_qconfig',
'_add_observer_',
'add_quant_dequant',
'convert',
'_get_observer_dict',
'_get_unique_devices_',
'_is_activation_post_process',
'prepare',
'prepare_qat',
'propagate_qconfig_',
'quantize',
'quantize_dynamic',
'quantize_qat',
'_register_activation_post_process_hook',
'swap_module',
"_convert",
"_observer_forward_hook",
"_propagate_qconfig_helper",
"_remove_activation_post_process",
"_remove_qconfig",
"_add_observer_",
"add_quant_dequant",
"convert",
"_get_observer_dict",
"_get_unique_devices_",
"_is_activation_post_process",
"prepare",
"prepare_qat",
"propagate_qconfig_",
"quantize",
"quantize_dynamic",
"quantize_qat",
"_register_activation_post_process_hook",
"swap_module",
]
self._test_function_import('quantize', function_list)
self._test_function_import("quantize", function_list)
def test_function_import_stubs(self):
function_list = [
'QuantStub',
'DeQuantStub',
'QuantWrapper',
"QuantStub",
"DeQuantStub",
"QuantWrapper",
]
self._test_function_import('stubs', function_list)
self._test_function_import("stubs", function_list)
def test_function_import_quantize_jit(self):
function_list = [
'_check_is_script_module',
'_check_forward_method',
'script_qconfig',
'script_qconfig_dict',
'fuse_conv_bn_jit',
'_prepare_jit',
'prepare_jit',
'prepare_dynamic_jit',
'_convert_jit',
'convert_jit',
'convert_dynamic_jit',
'_quantize_jit',
'quantize_jit',
'quantize_dynamic_jit',
"_check_is_script_module",
"_check_forward_method",
"script_qconfig",
"script_qconfig_dict",
"fuse_conv_bn_jit",
"_prepare_jit",
"prepare_jit",
"prepare_dynamic_jit",
"_convert_jit",
"convert_jit",
"convert_dynamic_jit",
"_quantize_jit",
"quantize_jit",
"quantize_dynamic_jit",
]
self._test_function_import('quantize_jit', function_list)
self._test_function_import("quantize_jit", function_list)
def test_function_import_fake_quantize(self):
function_list = [
'_is_per_channel',
'_is_per_tensor',
'_is_symmetric_quant',
'FakeQuantizeBase',
'FakeQuantize',
'FixedQParamsFakeQuantize',
'FusedMovingAvgObsFakeQuantize',
'default_fake_quant',
'default_weight_fake_quant',
'default_fixed_qparams_range_neg1to1_fake_quant',
'default_fixed_qparams_range_0to1_fake_quant',
'default_per_channel_weight_fake_quant',
'default_histogram_fake_quant',
'default_fused_act_fake_quant',
'default_fused_wt_fake_quant',
'default_fused_per_channel_wt_fake_quant',
'_is_fake_quant_script_module',
'disable_fake_quant',
'enable_fake_quant',
'disable_observer',
'enable_observer',
"_is_per_channel",
"_is_per_tensor",
"_is_symmetric_quant",
"FakeQuantizeBase",
"FakeQuantize",
"FixedQParamsFakeQuantize",
"FusedMovingAvgObsFakeQuantize",
"default_fake_quant",
"default_weight_fake_quant",
"default_fixed_qparams_range_neg1to1_fake_quant",
"default_fixed_qparams_range_0to1_fake_quant",
"default_per_channel_weight_fake_quant",
"default_histogram_fake_quant",
"default_fused_act_fake_quant",
"default_fused_wt_fake_quant",
"default_fused_per_channel_wt_fake_quant",
"_is_fake_quant_script_module",
"disable_fake_quant",
"enable_fake_quant",
"disable_observer",
"enable_observer",
]
self._test_function_import('fake_quantize', function_list)
self._test_function_import("fake_quantize", function_list)
def test_function_import_fuse_modules(self):
function_list = [
'_fuse_modules',
'_get_module',
'_set_module',
'fuse_conv_bn',
'fuse_conv_bn_relu',
'fuse_known_modules',
'fuse_modules',
'get_fuser_method',
"_fuse_modules",
"_get_module",
"_set_module",
"fuse_conv_bn",
"fuse_conv_bn_relu",
"fuse_known_modules",
"fuse_modules",
"get_fuser_method",
]
self._test_function_import('fuse_modules', function_list)
self._test_function_import("fuse_modules", function_list)
def test_function_import_quant_type(self):
function_list = [
'QuantType',
'_get_quant_type_to_str',
"QuantType",
"_get_quant_type_to_str",
]
self._test_function_import('quant_type', function_list)
self._test_function_import("quant_type", function_list)
def test_function_import_observer(self):
function_list = [
@@ -133,7 +134,7 @@ class TestAOMigrationQuantization(AOMigrationTestCase):
"default_dynamic_quant_observer",
"default_float_qparams_observer",
]
self._test_function_import('observer', function_list)
self._test_function_import("observer", function_list)
def test_function_import_qconfig(self):
function_list = [
@@ -156,9 +157,9 @@ class TestAOMigrationQuantization(AOMigrationTestCase):
"_assert_valid_qconfig",
"QConfigAny",
"_add_module_to_qconfig_obs_ctr",
"qconfig_equals"
"qconfig_equals",
]
self._test_function_import('qconfig', function_list)
self._test_function_import("qconfig", function_list)
def test_function_import_quantization_mappings(self):
function_list = [
@@ -184,8 +185,8 @@ class TestAOMigrationQuantization(AOMigrationTestCase):
"DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS",
"DEFAULT_MODULE_TO_ACT_POST_PROCESS",
]
self._test_function_import('quantization_mappings', function_list)
self._test_dict_import('quantization_mappings', dict_list)
self._test_function_import("quantization_mappings", function_list)
self._test_dict_import("quantization_mappings", dict_list)
def test_function_import_fuser_method_mappings(self):
function_list = [
@@ -194,29 +195,27 @@ class TestAOMigrationQuantization(AOMigrationTestCase):
"fuse_linear_bn",
"get_fuser_method",
]
dict_list = [
"_DEFAULT_OP_LIST_TO_FUSER_METHOD"
]
self._test_function_import('fuser_method_mappings', function_list)
self._test_dict_import('fuser_method_mappings', dict_list)
dict_list = ["_DEFAULT_OP_LIST_TO_FUSER_METHOD"]
self._test_function_import("fuser_method_mappings", function_list)
self._test_dict_import("fuser_method_mappings", dict_list)
def test_function_import_utils(self):
function_list = [
'activation_dtype',
'activation_is_int8_quantized',
'activation_is_statically_quantized',
'calculate_qmin_qmax',
'check_min_max_valid',
'get_combined_dict',
'get_qconfig_dtypes',
'get_qparam_dict',
'get_quant_type',
'get_swapped_custom_module_class',
'getattr_from_fqn',
'is_per_channel',
'is_per_tensor',
'weight_dtype',
'weight_is_quantized',
'weight_is_statically_quantized',
"activation_dtype",
"activation_is_int8_quantized",
"activation_is_statically_quantized",
"calculate_qmin_qmax",
"check_min_max_valid",
"get_combined_dict",
"get_qconfig_dtypes",
"get_qparam_dict",
"get_quant_type",
"get_swapped_custom_module_class",
"getattr_from_fqn",
"is_per_channel",
"is_per_tensor",
"weight_dtype",
"weight_is_quantized",
"weight_is_statically_quantized",
]
self._test_function_import('utils', function_list)
self._test_function_import("utils", function_list)

View File: test/quantization/ao_migration/test_quantization_fx.py

@@ -2,144 +2,133 @@
from .common import AOMigrationTestCase
class TestAOMigrationQuantizationFx(AOMigrationTestCase):
def test_function_import_quantize_fx(self):
function_list = [
'_check_is_graph_module',
'_swap_ff_with_fxff',
'_fuse_fx',
'QuantizationTracer',
'_prepare_fx',
'_prepare_standalone_module_fx',
'fuse_fx',
'Scope',
'ScopeContextManager',
'prepare_fx',
'prepare_qat_fx',
'_convert_fx',
'convert_fx',
'_convert_standalone_module_fx',
"_check_is_graph_module",
"_swap_ff_with_fxff",
"_fuse_fx",
"QuantizationTracer",
"_prepare_fx",
"_prepare_standalone_module_fx",
"fuse_fx",
"Scope",
"ScopeContextManager",
"prepare_fx",
"prepare_qat_fx",
"_convert_fx",
"convert_fx",
"_convert_standalone_module_fx",
]
self._test_function_import('quantize_fx', function_list)
self._test_function_import("quantize_fx", function_list)
def test_function_import_fx(self):
function_list = [
'prepare',
'convert',
'fuse',
"prepare",
"convert",
"fuse",
]
self._test_function_import('fx', function_list)
self._test_function_import("fx", function_list)
def test_function_import_fx_graph_module(self):
function_list = [
'FusedGraphModule',
'ObservedGraphModule',
'_is_observed_module',
'ObservedStandaloneGraphModule',
'_is_observed_standalone_module',
'QuantizedGraphModule'
"FusedGraphModule",
"ObservedGraphModule",
"_is_observed_module",
"ObservedStandaloneGraphModule",
"_is_observed_standalone_module",
"QuantizedGraphModule",
]
self._test_function_import('fx.graph_module', function_list)
self._test_function_import("fx.graph_module", function_list)
def test_function_import_fx_pattern_utils(self):
function_list = [
'QuantizeHandler',
'_register_fusion_pattern',
'get_default_fusion_patterns',
'_register_quant_pattern',
'get_default_quant_patterns',
'get_default_output_activation_post_process_map'
"QuantizeHandler",
"_register_fusion_pattern",
"get_default_fusion_patterns",
"_register_quant_pattern",
"get_default_quant_patterns",
"get_default_output_activation_post_process_map",
]
self._test_function_import('fx.pattern_utils', function_list)
self._test_function_import("fx.pattern_utils", function_list)
def test_function_import_fx_equalize(self):
function_list = [
'reshape_scale',
'_InputEqualizationObserver',
'_WeightEqualizationObserver',
'calculate_equalization_scale',
'EqualizationQConfig',
'input_equalization_observer',
'weight_equalization_observer',
'default_equalization_qconfig',
'fused_module_supports_equalization',
'nn_module_supports_equalization',
'node_supports_equalization',
'is_equalization_observer',
'get_op_node_and_weight_eq_obs',
'maybe_get_weight_eq_obs_node',
'maybe_get_next_input_eq_obs',
'maybe_get_next_equalization_scale',
'scale_input_observer',
'scale_weight_node',
'scale_weight_functional',
'clear_weight_quant_obs_node',
'remove_node',
'update_obs_for_equalization',
'convert_eq_obs',
'_convert_equalization_ref',
'get_layer_sqnr_dict',
'get_equalization_qconfig_dict'
"reshape_scale",
"_InputEqualizationObserver",
"_WeightEqualizationObserver",
"calculate_equalization_scale",
"EqualizationQConfig",
"input_equalization_observer",
"weight_equalization_observer",
"default_equalization_qconfig",
"fused_module_supports_equalization",
"nn_module_supports_equalization",
"node_supports_equalization",
"is_equalization_observer",
"get_op_node_and_weight_eq_obs",
"maybe_get_weight_eq_obs_node",
"maybe_get_next_input_eq_obs",
"maybe_get_next_equalization_scale",
"scale_input_observer",
"scale_weight_node",
"scale_weight_functional",
"clear_weight_quant_obs_node",
"remove_node",
"update_obs_for_equalization",
"convert_eq_obs",
"_convert_equalization_ref",
"get_layer_sqnr_dict",
"get_equalization_qconfig_dict",
]
self._test_function_import('fx._equalize', function_list)
self._test_function_import("fx._equalize", function_list)
def test_function_import_fx_quantization_patterns(self):
function_list = [
'QuantizeHandler',
'BinaryOpQuantizeHandler',
'CatQuantizeHandler',
'ConvReluQuantizeHandler',
'LinearReLUQuantizeHandler',
'BatchNormQuantizeHandler',
'EmbeddingQuantizeHandler',
'RNNDynamicQuantizeHandler',
'DefaultNodeQuantizeHandler',
'FixedQParamsOpQuantizeHandler',
'CopyNodeQuantizeHandler',
'CustomModuleQuantizeHandler',
'GeneralTensorShapeOpQuantizeHandler',
'StandaloneModuleQuantizeHandler'
"QuantizeHandler",
"BinaryOpQuantizeHandler",
"CatQuantizeHandler",
"ConvReluQuantizeHandler",
"LinearReLUQuantizeHandler",
"BatchNormQuantizeHandler",
"EmbeddingQuantizeHandler",
"RNNDynamicQuantizeHandler",
"DefaultNodeQuantizeHandler",
"FixedQParamsOpQuantizeHandler",
"CopyNodeQuantizeHandler",
"CustomModuleQuantizeHandler",
"GeneralTensorShapeOpQuantizeHandler",
"StandaloneModuleQuantizeHandler",
]
self._test_function_import(
'fx.quantization_patterns',
"fx.quantization_patterns",
function_list,
new_package_name='fx.quantize_handler',
new_package_name="fx.quantize_handler",
)
def test_function_import_fx_match_utils(self):
function_list = [
'_MatchResult',
'MatchAllNode',
'_is_match',
'_find_matches'
]
self._test_function_import('fx.match_utils', function_list)
function_list = ["_MatchResult", "MatchAllNode", "_is_match", "_find_matches"]
self._test_function_import("fx.match_utils", function_list)
def test_function_import_fx_prepare(self):
function_list = [
'prepare'
]
self._test_function_import('fx.prepare', function_list)
function_list = ["prepare"]
self._test_function_import("fx.prepare", function_list)
def test_function_import_fx_convert(self):
function_list = [
'convert'
]
self._test_function_import('fx.convert', function_list)
function_list = ["convert"]
self._test_function_import("fx.convert", function_list)
def test_function_import_fx_fuse(self):
function_list = ['fuse']
self._test_function_import('fx.fuse', function_list)
function_list = ["fuse"]
self._test_function_import("fx.fuse", function_list)
def test_function_import_fx_fusion_patterns(self):
function_list = [
'FuseHandler',
'DefaultFuseHandler'
]
function_list = ["FuseHandler", "DefaultFuseHandler"]
self._test_function_import(
'fx.fusion_patterns',
"fx.fusion_patterns",
function_list,
new_package_name='fx.fuse_handler',
new_package_name="fx.fuse_handler",
)
# we removed matching test for torch.quantization.fx.quantization_types
@@ -149,15 +138,15 @@ class TestAOMigrationQuantizationFx(AOMigrationTestCase):
def test_function_import_fx_utils(self):
function_list = [
'get_custom_module_class_keys',
'get_linear_prepack_op_for_dtype',
'get_qconv_prepack_op',
'get_new_attr_name_with_prefix',
'graph_module_from_producer_nodes',
'assert_and_get_unique_device',
'create_getattr_from_value',
'all_node_args_have_no_tensors',
'get_non_observable_arg_indexes_and_types',
'maybe_get_next_module'
"get_custom_module_class_keys",
"get_linear_prepack_op_for_dtype",
"get_qconv_prepack_op",
"get_new_attr_name_with_prefix",
"graph_module_from_producer_nodes",
"assert_and_get_unique_device",
"create_getattr_from_value",
"all_node_args_have_no_tensors",
"get_non_observable_arg_indexes_and_types",
"maybe_get_next_module",
]
self._test_function_import('fx.utils', function_list)
self._test_function_import("fx.utils", function_list)

View File: test/quantization/bc/test_backward_compatibility.py

@@ -1,32 +1,39 @@
# Owner(s): ["oncall: quantization"]
import sys
import os
import sys
import unittest
from typing import Set
# torch
import torch
import torch.nn as nn
import torch.ao.nn.intrinsic.quantized as nniq
import torch.ao.nn.quantized as nnq
import torch.ao.nn.quantized.dynamic as nnqd
import torch.ao.nn.intrinsic.quantized as nniq
from torch.fx import GraphModule
# Testing utils
from torch.testing._internal.common_utils import TestCase, IS_AVX512_VNNI_SUPPORTED
from torch.testing._internal.common_quantized import override_qengines, qengine_is_fbgemm
from torch.testing._internal.common_quantization import skipIfNoFBGEMM
from torch.testing._internal.quantization_torch_package_models import LinearReluFunctional
import torch.ao.quantization.quantize_fx as quantize_fx
import torch.nn as nn
from torch.ao.quantization import MinMaxObserver, PerChannelMinMaxObserver
import torch.ao.quantization.quantize_fx as quantize_fx
from torch.fx import GraphModule
from torch.testing._internal.common_quantization import skipIfNoFBGEMM
from torch.testing._internal.common_quantized import (
override_qengines,
qengine_is_fbgemm,
)
# Testing utils
from torch.testing._internal.common_utils import IS_AVX512_VNNI_SUPPORTED, TestCase
from torch.testing._internal.quantization_torch_package_models import (
LinearReluFunctional,
)
def remove_prefix(text, prefix):
if text.startswith(prefix):
return text[len(prefix):]
return text[len(prefix) :]
return text
def get_filenames(self, subname):
# NB: we take __file__ from the module that defined the test
# class, so we place the expect directory where the test script
@@ -34,9 +41,7 @@ def get_filenames(self, subname):
module_id = self.__class__.__module__
munged_id = remove_prefix(self.id(), module_id + ".")
test_file = os.path.realpath(sys.modules[module_id].__file__)
base_name = os.path.join(os.path.dirname(test_file),
"../serialized",
munged_id)
base_name = os.path.join(os.path.dirname(test_file), "../serialized", munged_id)
subname_output = ""
if subname:
@@ -51,32 +56,59 @@ def get_filenames(self, subname):
package_file = base_name + ".package.pt"
get_attr_targets_file = base_name + ".get_attr_targets.pt"
return input_file, state_dict_file, scripted_module_file, \
traced_module_file, expected_file, package_file, get_attr_targets_file
return (
input_file,
state_dict_file,
scripted_module_file,
traced_module_file,
expected_file,
package_file,
get_attr_targets_file,
)
class TestSerialization(TestCase):
""" Test backward compatiblity for serialization and numerics
"""
"""Test backward compatiblity for serialization and numerics"""
# Copy and modified from TestCase.assertExpected
def _test_op(self, qmodule, subname=None, input_size=None, input_quantized=True,
generate=False, prec=None, new_zipfile_serialization=False):
r""" Test quantized modules serialized previously can be loaded
def _test_op(
self,
qmodule,
subname=None,
input_size=None,
input_quantized=True,
generate=False,
prec=None,
new_zipfile_serialization=False,
):
r"""Test quantized modules serialized previously can be loaded
with current code, make sure we don't break backward compatibility for the
serialization of quantized modules
"""
input_file, state_dict_file, scripted_module_file, traced_module_file, \
expected_file, _package_file, _get_attr_targets_file = \
get_filenames(self, subname)
(
input_file,
state_dict_file,
scripted_module_file,
traced_module_file,
expected_file,
_package_file,
_get_attr_targets_file,
) = get_filenames(self, subname)
# only generate once.
if generate and qengine_is_fbgemm():
input_tensor = torch.rand(*input_size).float()
if input_quantized:
input_tensor = torch.quantize_per_tensor(input_tensor, 0.5, 2, torch.quint8)
input_tensor = torch.quantize_per_tensor(
input_tensor, 0.5, 2, torch.quint8
)
torch.save(input_tensor, input_file)
# Temporary fix to use _use_new_zipfile_serialization until #38379 lands.
torch.save(qmodule.state_dict(), state_dict_file, _use_new_zipfile_serialization=new_zipfile_serialization)
torch.save(
qmodule.state_dict(),
state_dict_file,
_use_new_zipfile_serialization=new_zipfile_serialization,
)
torch.jit.save(torch.jit.script(qmodule), scripted_module_file)
torch.jit.save(torch.jit.trace(qmodule, input_tensor), traced_module_file)
torch.save(qmodule(input_tensor), expected_file)
@@ -90,8 +122,16 @@ class TestSerialization(TestCase):
self.assertEqual(qmodule_scripted(input_tensor), expected, atol=prec)
self.assertEqual(qmodule_traced(input_tensor), expected, atol=prec)
def _test_op_graph(self, qmodule, subname=None, input_size=None, input_quantized=True,
generate=False, prec=None, new_zipfile_serialization=False):
def _test_op_graph(
self,
qmodule,
subname=None,
input_size=None,
input_quantized=True,
generate=False,
prec=None,
new_zipfile_serialization=False,
):
r"""
Input: a floating point module
@@ -101,9 +141,15 @@ class TestSerialization(TestCase):
If generate == False, traces and scripts the module and quantizes the results with
PTQ, and compares to saved results.
"""
input_file, state_dict_file, scripted_module_file, traced_module_file, \
expected_file, _package_file, _get_attr_targets_file = \
get_filenames(self, subname)
(
input_file,
state_dict_file,
scripted_module_file,
traced_module_file,
expected_file,
_package_file,
_get_attr_targets_file,
) = get_filenames(self, subname)
# only generate once.
if generate and qengine_is_fbgemm():
@@ -119,11 +165,13 @@ class TestSerialization(TestCase):
def _eval_fn(model, data):
model(data)
qconfig_dict = {'': torch.ao.quantization.default_qconfig}
qconfig_dict = {"": torch.ao.quantization.default_qconfig}
scripted_q = torch.ao.quantization.quantize_jit(
scripted, qconfig_dict, _eval_fn, [input_tensor])
scripted, qconfig_dict, _eval_fn, [input_tensor]
)
traced_q = torch.ao.quantization.quantize_jit(
traced, qconfig_dict, _eval_fn, [input_tensor])
traced, qconfig_dict, _eval_fn, [input_tensor]
)
torch.jit.save(scripted_q, scripted_module_file)
torch.jit.save(traced_q, traced_module_file)
@@ -136,12 +184,21 @@ class TestSerialization(TestCase):
self.assertEqual(qmodule_scripted(input_tensor), expected, atol=prec)
self.assertEqual(qmodule_traced(input_tensor), expected, atol=prec)
def _test_obs(self, obs, input_size, subname=None, generate=False, check_numerics=True):
def _test_obs(
self, obs, input_size, subname=None, generate=False, check_numerics=True
):
"""
Test observer code can be loaded from state_dict.
"""
input_file, state_dict_file, _, traced_module_file, expected_file, \
_package_file, _get_attr_targets_file = get_filenames(self, None)
(
input_file,
state_dict_file,
_,
traced_module_file,
expected_file,
_package_file,
_get_attr_targets_file,
) = get_filenames(self, None)
if generate:
input_tensor = torch.rand(*input_size).float()
torch.save(input_tensor, input_file)
@@ -159,12 +216,18 @@ class TestSerialization(TestCase):
Verifies that files created in the past with torch.package
work on today's FX graph mode quantization transforms.
"""
input_file, state_dict_file, _scripted_module_file, _traced_module_file, \
expected_file, package_file, get_attr_targets_file = \
get_filenames(self, None)
(
input_file,
state_dict_file,
_scripted_module_file,
_traced_module_file,
expected_file,
package_file,
get_attr_targets_file,
) = get_filenames(self, None)
package_name = 'test'
resource_name_model = 'test.pkl'
package_name = "test"
resource_name_model = "test.pkl"
def _do_quant_transforms(
m: torch.nn.Module,
@@ -172,8 +235,8 @@ class TestSerialization(TestCase):
) -> torch.nn.Module:
example_inputs = (input_tensor,)
# do the quantizaton transforms and save result
qconfig = torch.ao.quantization.get_default_qconfig('fbgemm')
mp = quantize_fx.prepare_fx(m, {'': qconfig}, example_inputs=example_inputs)
qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")
mp = quantize_fx.prepare_fx(m, {"": qconfig}, example_inputs=example_inputs)
mp(input_tensor)
mq = quantize_fx.convert_fx(mp)
return mq
@@ -181,7 +244,7 @@ class TestSerialization(TestCase):
def _get_get_attr_target_strings(m: GraphModule) -> Set[str]:
results = set()
for node in m.graph.nodes:
if node.op == 'get_attr':
if node.op == "get_attr":
results.add(node.target)
return results
@@ -191,7 +254,7 @@ class TestSerialization(TestCase):
# save the model with torch.package
with torch.package.PackageExporter(package_file) as exp:
exp.intern('torch.testing._internal.quantization_torch_package_models')
exp.intern("torch.testing._internal.quantization_torch_package_models")
exp.save_pickle(package_name, resource_name_model, fp32_module)
# do the quantization transforms and save the result
@@ -214,7 +277,8 @@ class TestSerialization(TestCase):
get_attrs = _get_get_attr_target_strings(mq)
self.assertTrue(
get_attrs == expected_get_attrs,
f'get_attrs: expected {expected_get_attrs}, got {get_attrs}')
f"get_attrs: expected {expected_get_attrs}, got {get_attrs}",
)
output_tensor = mq(input_tensor)
self.assertTrue(torch.allclose(output_tensor, expected_output_tensor))
@@ -231,29 +295,68 @@ class TestSerialization(TestCase):
@override_qengines
def test_linear_dynamic(self):
module_qint8 = nnqd.Linear(3, 1, bias_=True, dtype=torch.qint8)
self._test_op(module_qint8, "qint8", input_size=[1, 3], input_quantized=False, generate=False)
self._test_op(
module_qint8,
"qint8",
input_size=[1, 3],
input_quantized=False,
generate=False,
)
if qengine_is_fbgemm():
module_float16 = nnqd.Linear(3, 1, bias_=True, dtype=torch.float16)
self._test_op(module_float16, "float16", input_size=[1, 3], input_quantized=False, generate=False)
self._test_op(
module_float16,
"float16",
input_size=[1, 3],
input_quantized=False,
generate=False,
)
@override_qengines
def test_conv2d(self):
module = nnq.Conv2d(3, 3, kernel_size=3, stride=1, padding=0, dilation=1,
groups=1, bias=True, padding_mode="zeros")
module = nnq.Conv2d(
3,
3,
kernel_size=3,
stride=1,
padding=0,
dilation=1,
groups=1,
bias=True,
padding_mode="zeros",
)
self._test_op(module, input_size=[1, 3, 6, 6], generate=False)
@override_qengines
def test_conv2d_nobias(self):
module = nnq.Conv2d(3, 3, kernel_size=3, stride=1, padding=0, dilation=1,
groups=1, bias=False, padding_mode="zeros")
module = nnq.Conv2d(
3,
3,
kernel_size=3,
stride=1,
padding=0,
dilation=1,
groups=1,
bias=False,
padding_mode="zeros",
)
self._test_op(module, input_size=[1, 3, 6, 6], generate=False)
@override_qengines
def test_conv2d_graph(self):
module = nn.Sequential(
torch.ao.quantization.QuantStub(),
nn.Conv2d(3, 3, kernel_size=3, stride=1, padding=0, dilation=1,
groups=1, bias=True, padding_mode="zeros"),
nn.Conv2d(
3,
3,
kernel_size=3,
stride=1,
padding=0,
dilation=1,
groups=1,
bias=True,
padding_mode="zeros",
),
)
self._test_op_graph(module, input_size=[1, 3, 6, 6], generate=False)
@@ -261,8 +364,17 @@ class TestSerialization(TestCase):
def test_conv2d_nobias_graph(self):
module = nn.Sequential(
torch.ao.quantization.QuantStub(),
nn.Conv2d(3, 3, kernel_size=3, stride=1, padding=0, dilation=1,
groups=1, bias=False, padding_mode="zeros"),
nn.Conv2d(
3,
3,
kernel_size=3,
stride=1,
padding=0,
dilation=1,
groups=1,
bias=False,
padding_mode="zeros",
),
)
self._test_op_graph(module, input_size=[1, 3, 6, 6], generate=False)
@@ -272,8 +384,17 @@ class TestSerialization(TestCase):
# ConvPackedParams{n}d
module = nn.Sequential(
torch.ao.quantization.QuantStub(),
nn.Conv2d(3, 3, kernel_size=3, stride=1, padding=0, dilation=1,
groups=1, bias=True, padding_mode="zeros"),
nn.Conv2d(
3,
3,
kernel_size=3,
stride=1,
padding=0,
dilation=1,
groups=1,
bias=True,
padding_mode="zeros",
),
)
self._test_op_graph(module, input_size=[1, 3, 6, 6], generate=False)
@@ -283,8 +404,17 @@ class TestSerialization(TestCase):
# ConvPackedParams{n}d
module = nn.Sequential(
torch.ao.quantization.QuantStub(),
nn.Conv2d(3, 3, kernel_size=3, stride=1, padding=0, dilation=1,
groups=1, bias=False, padding_mode="zeros"),
nn.Conv2d(
3,
3,
kernel_size=3,
stride=1,
padding=0,
dilation=1,
groups=1,
bias=False,
padding_mode="zeros",
),
)
self._test_op_graph(module, input_size=[1, 3, 6, 6], generate=False)
@@ -294,8 +424,17 @@ class TestSerialization(TestCase):
# ConvPackedParams{n}d
module = nn.Sequential(
torch.ao.quantization.QuantStub(),
nn.Conv2d(3, 3, kernel_size=3, stride=1, padding=0, dilation=1,
groups=1, bias=True, padding_mode="zeros"),
nn.Conv2d(
3,
3,
kernel_size=3,
stride=1,
padding=0,
dilation=1,
groups=1,
bias=True,
padding_mode="zeros",
),
)
self._test_op_graph(module, input_size=[1, 3, 6, 6], generate=False)
@@ -305,48 +444,96 @@ class TestSerialization(TestCase):
# ConvPackedParams{n}d
module = nn.Sequential(
torch.ao.quantization.QuantStub(),
nn.Conv2d(3, 3, kernel_size=3, stride=1, padding=0, dilation=1,
groups=1, bias=False, padding_mode="zeros"),
nn.Conv2d(
3,
3,
kernel_size=3,
stride=1,
padding=0,
dilation=1,
groups=1,
bias=False,
padding_mode="zeros",
),
)
self._test_op_graph(module, input_size=[1, 3, 6, 6], generate=False)
@override_qengines
def test_conv2d_relu(self):
module = nniq.ConvReLU2d(3, 3, kernel_size=3, stride=1, padding=0, dilation=1,
groups=1, bias=True, padding_mode="zeros")
module = nniq.ConvReLU2d(
3,
3,
kernel_size=3,
stride=1,
padding=0,
dilation=1,
groups=1,
bias=True,
padding_mode="zeros",
)
self._test_op(module, input_size=[1, 3, 6, 6], generate=False)
# TODO: graph mode quantized conv2d module
@override_qengines
def test_conv3d(self):
if qengine_is_fbgemm():
module = nnq.Conv3d(3, 3, kernel_size=3, stride=1, padding=0, dilation=1,
groups=1, bias=True, padding_mode="zeros")
module = nnq.Conv3d(
3,
3,
kernel_size=3,
stride=1,
padding=0,
dilation=1,
groups=1,
bias=True,
padding_mode="zeros",
)
self._test_op(module, input_size=[1, 3, 6, 6, 6], generate=False)
# TODO: graph mode quantized conv3d module
@override_qengines
def test_conv3d_relu(self):
if qengine_is_fbgemm():
module = nniq.ConvReLU3d(3, 3, kernel_size=3, stride=1, padding=0, dilation=1,
groups=1, bias=True, padding_mode="zeros")
module = nniq.ConvReLU3d(
3,
3,
kernel_size=3,
stride=1,
padding=0,
dilation=1,
groups=1,
bias=True,
padding_mode="zeros",
)
self._test_op(module, input_size=[1, 3, 6, 6, 6], generate=False)
# TODO: graph mode quantized conv3d module
@override_qengines
@unittest.skipIf(IS_AVX512_VNNI_SUPPORTED, "This test fails on machines with AVX512_VNNI support. Ref: GH Issue 59098")
@unittest.skipIf(
IS_AVX512_VNNI_SUPPORTED,
"This test fails on machines with AVX512_VNNI support. Ref: GH Issue 59098",
)
def test_lstm(self):
class LSTMModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.lstm = nnqd.LSTM(input_size=3, hidden_size=7, num_layers=1).to(dtype=torch.float)
self.lstm = nnqd.LSTM(input_size=3, hidden_size=7, num_layers=1).to(
dtype=torch.float
)
def forward(self, x):
x = self.lstm(x)
return x
if qengine_is_fbgemm():
mod = LSTMModule()
self._test_op(mod, input_size=[4, 4, 3], input_quantized=False, generate=False, new_zipfile_serialization=True)
self._test_op(
mod,
input_size=[4, 4, 3],
input_quantized=False,
generate=False,
new_zipfile_serialization=True,
)
def test_per_channel_observer(self):
obs = PerChannelMinMaxObserver()
@@ -373,7 +560,9 @@ class TestSerialization(TestCase):
model.qconfig = torch.ao.quantization.get_default_qat_qconfig("fbgemm")
ref_model = torch.ao.quantization.QuantWrapper(model)
ref_model = torch.ao.quantization.prepare_qat(ref_model)
self._test_obs(ref_model, input_size=[5, 5], generate=False, check_numerics=False)
self._test_obs(
ref_model, input_size=[5, 5], generate=False, check_numerics=False
)
@skipIfNoFBGEMM
def test_linear_relu_package_quantization_transforms(self):
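
The diff view ends mid-test here. For context, this last test exercises a torch.package round trip before applying the FX graph mode transforms; below is a simplified sketch of just the packaging half, with placeholder file and resource names, assuming torch.package's default handling of the torch framework as extern (the real test packages a `LinearReluFunctional` model and compares against saved outputs):

```python
import torch
import torch.nn as nn
from torch.package import PackageExporter, PackageImporter

package_file = "example.package.pt"  # placeholder path
package_name, resource_name = "test", "test.pkl"

# Save an fp32 module. torch itself is treated as extern by default,
# so a plain nn module needs no intern/extern configuration.
with PackageExporter(package_file) as exp:
    exp.save_pickle(package_name, resource_name, nn.Linear(3, 3))

# Load it back and run it, as the BC test does before quantizing.
loaded = PackageImporter(package_file).load_pickle(package_name, resource_name)
out = loaded(torch.randn(1, 3))
```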