mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Enable UFMT on all of test/quantization/ao_migration &bc (#123994)
Partially addresses #123062 Ran lintrunner on: - test/quantization/ao_migration - test/quantization/bc Detail: ``` $ lintrunner -a --take UFMT --all-files ok No lint issues. Successfully applied all patches. ``` @ezyang Pull Request resolved: https://github.com/pytorch/pytorch/pull/123994 Approved by: https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
285c93d64d
commit
6ac8fe46dd
@ -1334,13 +1334,6 @@ exclude_patterns = [
|
||||
'test/profiler/test_profiler.py',
|
||||
'test/profiler/test_profiler_tree.py',
|
||||
'test/quantization/__init__.py',
|
||||
'test/quantization/ao_migration/__init__.py',
|
||||
'test/quantization/ao_migration/common.py',
|
||||
'test/quantization/ao_migration/test_ao_migration.py',
|
||||
'test/quantization/ao_migration/test_quantization.py',
|
||||
'test/quantization/ao_migration/test_quantization_fx.py',
|
||||
'test/quantization/bc/__init__.py',
|
||||
'test/quantization/bc/test_backward_compatibility.py',
|
||||
'test/quantization/core/__init__.py',
|
||||
'test/quantization/core/experimental/apot_fx_graph_mode_ptq.py',
|
||||
'test/quantization/core/experimental/apot_fx_graph_mode_qat.py',
|
||||
|
@ -1,42 +1,52 @@
|
||||
from torch.testing._internal.common_utils import TestCase
|
||||
|
||||
import importlib
|
||||
from typing import List, Optional
|
||||
|
||||
from torch.testing._internal.common_utils import TestCase
|
||||
|
||||
|
||||
class AOMigrationTestCase(TestCase):
|
||||
def _test_function_import(self, package_name: str, function_list: List[str],
|
||||
base: Optional[str] = None, new_package_name: Optional[str] = None):
|
||||
def _test_function_import(
|
||||
self,
|
||||
package_name: str,
|
||||
function_list: List[str],
|
||||
base: Optional[str] = None,
|
||||
new_package_name: Optional[str] = None,
|
||||
):
|
||||
r"""Tests individual function list import by comparing the functions
|
||||
and their hashes."""
|
||||
if base is None:
|
||||
base = 'quantization'
|
||||
old_base = 'torch.' + base
|
||||
new_base = 'torch.ao.' + base
|
||||
base = "quantization"
|
||||
old_base = "torch." + base
|
||||
new_base = "torch.ao." + base
|
||||
if new_package_name is None:
|
||||
new_package_name = package_name
|
||||
old_location = importlib.import_module(f'{old_base}.{package_name}')
|
||||
new_location = importlib.import_module(f'{new_base}.{new_package_name}')
|
||||
old_location = importlib.import_module(f"{old_base}.{package_name}")
|
||||
new_location = importlib.import_module(f"{new_base}.{new_package_name}")
|
||||
for fn_name in function_list:
|
||||
old_function = getattr(old_location, fn_name)
|
||||
new_function = getattr(new_location, fn_name)
|
||||
assert old_function == new_function, f"Functions don't match: {fn_name}"
|
||||
assert hash(old_function) == hash(new_function), \
|
||||
f"Hashes don't match: {old_function}({hash(old_function)}) vs. " \
|
||||
assert hash(old_function) == hash(new_function), (
|
||||
f"Hashes don't match: {old_function}({hash(old_function)}) vs. "
|
||||
f"{new_function}({hash(new_function)})"
|
||||
)
|
||||
|
||||
def _test_dict_import(self, package_name: str, dict_list: List[str],
|
||||
base: Optional[str] = None):
|
||||
def _test_dict_import(
|
||||
self, package_name: str, dict_list: List[str], base: Optional[str] = None
|
||||
):
|
||||
r"""Tests individual function list import by comparing the functions
|
||||
and their hashes."""
|
||||
if base is None:
|
||||
base = 'quantization'
|
||||
old_base = 'torch.' + base
|
||||
new_base = 'torch.ao.' + base
|
||||
old_location = importlib.import_module(f'{old_base}.{package_name}')
|
||||
new_location = importlib.import_module(f'{new_base}.{package_name}')
|
||||
base = "quantization"
|
||||
old_base = "torch." + base
|
||||
new_base = "torch.ao." + base
|
||||
old_location = importlib.import_module(f"{old_base}.{package_name}")
|
||||
new_location = importlib.import_module(f"{new_base}.{package_name}")
|
||||
for dict_name in dict_list:
|
||||
old_dict = getattr(old_location, dict_name)
|
||||
new_dict = getattr(new_location, dict_name)
|
||||
assert old_dict == new_dict, f"Dicts don't match: {dict_name}"
|
||||
for key in new_dict.keys():
|
||||
assert old_dict[key] == new_dict[key], f"Dicts don't match: {dict_name} for key {key}"
|
||||
assert (
|
||||
old_dict[key] == new_dict[key]
|
||||
), f"Dicts don't match: {dict_name} for key {key}"
|
||||
|
@ -7,257 +7,261 @@ class TestAOMigrationNNQuantized(AOMigrationTestCase):
|
||||
def test_functional_import(self):
|
||||
r"""Tests the migration of the torch.nn.quantized.functional"""
|
||||
function_list = [
|
||||
'avg_pool2d',
|
||||
'avg_pool3d',
|
||||
'adaptive_avg_pool2d',
|
||||
'adaptive_avg_pool3d',
|
||||
'conv1d',
|
||||
'conv2d',
|
||||
'conv3d',
|
||||
'interpolate',
|
||||
'linear',
|
||||
'max_pool1d',
|
||||
'max_pool2d',
|
||||
'celu',
|
||||
'leaky_relu',
|
||||
'hardtanh',
|
||||
'hardswish',
|
||||
'threshold',
|
||||
'elu',
|
||||
'hardsigmoid',
|
||||
'clamp',
|
||||
'upsample',
|
||||
'upsample_bilinear',
|
||||
'upsample_nearest',
|
||||
"avg_pool2d",
|
||||
"avg_pool3d",
|
||||
"adaptive_avg_pool2d",
|
||||
"adaptive_avg_pool3d",
|
||||
"conv1d",
|
||||
"conv2d",
|
||||
"conv3d",
|
||||
"interpolate",
|
||||
"linear",
|
||||
"max_pool1d",
|
||||
"max_pool2d",
|
||||
"celu",
|
||||
"leaky_relu",
|
||||
"hardtanh",
|
||||
"hardswish",
|
||||
"threshold",
|
||||
"elu",
|
||||
"hardsigmoid",
|
||||
"clamp",
|
||||
"upsample",
|
||||
"upsample_bilinear",
|
||||
"upsample_nearest",
|
||||
]
|
||||
self._test_function_import('functional', function_list, base='nn.quantized')
|
||||
self._test_function_import("functional", function_list, base="nn.quantized")
|
||||
|
||||
def test_modules_import(self):
|
||||
module_list = [
|
||||
# Modules
|
||||
'BatchNorm2d',
|
||||
'BatchNorm3d',
|
||||
'Conv1d',
|
||||
'Conv2d',
|
||||
'Conv3d',
|
||||
'ConvTranspose1d',
|
||||
'ConvTranspose2d',
|
||||
'ConvTranspose3d',
|
||||
'DeQuantize',
|
||||
'ELU',
|
||||
'Embedding',
|
||||
'EmbeddingBag',
|
||||
'GroupNorm',
|
||||
'Hardswish',
|
||||
'InstanceNorm1d',
|
||||
'InstanceNorm2d',
|
||||
'InstanceNorm3d',
|
||||
'LayerNorm',
|
||||
'LeakyReLU',
|
||||
'Linear',
|
||||
'MaxPool2d',
|
||||
'Quantize',
|
||||
'ReLU6',
|
||||
'Sigmoid',
|
||||
'Softmax',
|
||||
'Dropout',
|
||||
"BatchNorm2d",
|
||||
"BatchNorm3d",
|
||||
"Conv1d",
|
||||
"Conv2d",
|
||||
"Conv3d",
|
||||
"ConvTranspose1d",
|
||||
"ConvTranspose2d",
|
||||
"ConvTranspose3d",
|
||||
"DeQuantize",
|
||||
"ELU",
|
||||
"Embedding",
|
||||
"EmbeddingBag",
|
||||
"GroupNorm",
|
||||
"Hardswish",
|
||||
"InstanceNorm1d",
|
||||
"InstanceNorm2d",
|
||||
"InstanceNorm3d",
|
||||
"LayerNorm",
|
||||
"LeakyReLU",
|
||||
"Linear",
|
||||
"MaxPool2d",
|
||||
"Quantize",
|
||||
"ReLU6",
|
||||
"Sigmoid",
|
||||
"Softmax",
|
||||
"Dropout",
|
||||
# Wrapper modules
|
||||
'FloatFunctional',
|
||||
'FXFloatFunctional',
|
||||
'QFunctional',
|
||||
"FloatFunctional",
|
||||
"FXFloatFunctional",
|
||||
"QFunctional",
|
||||
]
|
||||
self._test_function_import('modules', module_list, base='nn.quantized')
|
||||
self._test_function_import("modules", module_list, base="nn.quantized")
|
||||
|
||||
def test_modules_activation(self):
|
||||
function_list = [
|
||||
'ReLU6',
|
||||
'Hardswish',
|
||||
'ELU',
|
||||
'LeakyReLU',
|
||||
'Sigmoid',
|
||||
'Softmax',
|
||||
"ReLU6",
|
||||
"Hardswish",
|
||||
"ELU",
|
||||
"LeakyReLU",
|
||||
"Sigmoid",
|
||||
"Softmax",
|
||||
]
|
||||
self._test_function_import('activation', function_list,
|
||||
base='nn.quantized.modules')
|
||||
self._test_function_import(
|
||||
"activation", function_list, base="nn.quantized.modules"
|
||||
)
|
||||
|
||||
def test_modules_batchnorm(self):
|
||||
function_list = [
|
||||
'BatchNorm2d',
|
||||
'BatchNorm3d',
|
||||
"BatchNorm2d",
|
||||
"BatchNorm3d",
|
||||
]
|
||||
self._test_function_import('batchnorm', function_list,
|
||||
base='nn.quantized.modules')
|
||||
self._test_function_import(
|
||||
"batchnorm", function_list, base="nn.quantized.modules"
|
||||
)
|
||||
|
||||
def test_modules_conv(self):
|
||||
function_list = [
|
||||
'_reverse_repeat_padding',
|
||||
'Conv1d',
|
||||
'Conv2d',
|
||||
'Conv3d',
|
||||
'ConvTranspose1d',
|
||||
'ConvTranspose2d',
|
||||
'ConvTranspose3d',
|
||||
"_reverse_repeat_padding",
|
||||
"Conv1d",
|
||||
"Conv2d",
|
||||
"Conv3d",
|
||||
"ConvTranspose1d",
|
||||
"ConvTranspose2d",
|
||||
"ConvTranspose3d",
|
||||
]
|
||||
|
||||
self._test_function_import('conv', function_list,
|
||||
base='nn.quantized.modules')
|
||||
self._test_function_import("conv", function_list, base="nn.quantized.modules")
|
||||
|
||||
def test_modules_dropout(self):
|
||||
function_list = [
|
||||
'Dropout',
|
||||
"Dropout",
|
||||
]
|
||||
self._test_function_import('dropout', function_list,
|
||||
base='nn.quantized.modules')
|
||||
self._test_function_import(
|
||||
"dropout", function_list, base="nn.quantized.modules"
|
||||
)
|
||||
|
||||
def test_modules_embedding_ops(self):
|
||||
function_list = [
|
||||
'EmbeddingPackedParams',
|
||||
'Embedding',
|
||||
'EmbeddingBag',
|
||||
"EmbeddingPackedParams",
|
||||
"Embedding",
|
||||
"EmbeddingBag",
|
||||
]
|
||||
self._test_function_import('embedding_ops', function_list,
|
||||
base='nn.quantized.modules')
|
||||
self._test_function_import(
|
||||
"embedding_ops", function_list, base="nn.quantized.modules"
|
||||
)
|
||||
|
||||
def test_modules_functional_modules(self):
|
||||
function_list = [
|
||||
'FloatFunctional',
|
||||
'FXFloatFunctional',
|
||||
'QFunctional',
|
||||
"FloatFunctional",
|
||||
"FXFloatFunctional",
|
||||
"QFunctional",
|
||||
]
|
||||
self._test_function_import('functional_modules', function_list,
|
||||
base='nn.quantized.modules')
|
||||
self._test_function_import(
|
||||
"functional_modules", function_list, base="nn.quantized.modules"
|
||||
)
|
||||
|
||||
def test_modules_linear(self):
|
||||
function_list = [
|
||||
'Linear',
|
||||
'LinearPackedParams',
|
||||
"Linear",
|
||||
"LinearPackedParams",
|
||||
]
|
||||
self._test_function_import('linear', function_list,
|
||||
base='nn.quantized.modules')
|
||||
self._test_function_import("linear", function_list, base="nn.quantized.modules")
|
||||
|
||||
def test_modules_normalization(self):
|
||||
function_list = [
|
||||
'LayerNorm',
|
||||
'GroupNorm',
|
||||
'InstanceNorm1d',
|
||||
'InstanceNorm2d',
|
||||
'InstanceNorm3d',
|
||||
"LayerNorm",
|
||||
"GroupNorm",
|
||||
"InstanceNorm1d",
|
||||
"InstanceNorm2d",
|
||||
"InstanceNorm3d",
|
||||
]
|
||||
self._test_function_import('normalization', function_list,
|
||||
base='nn.quantized.modules')
|
||||
self._test_function_import(
|
||||
"normalization", function_list, base="nn.quantized.modules"
|
||||
)
|
||||
|
||||
def test_modules_utils(self):
|
||||
function_list = [
|
||||
'_ntuple_from_first',
|
||||
'_pair_from_first',
|
||||
'_quantize_weight',
|
||||
'_hide_packed_params_repr',
|
||||
'WeightedQuantizedModule',
|
||||
"_ntuple_from_first",
|
||||
"_pair_from_first",
|
||||
"_quantize_weight",
|
||||
"_hide_packed_params_repr",
|
||||
"WeightedQuantizedModule",
|
||||
]
|
||||
self._test_function_import('utils', function_list,
|
||||
base='nn.quantized.modules')
|
||||
self._test_function_import("utils", function_list, base="nn.quantized.modules")
|
||||
|
||||
def test_import_nn_quantized_dynamic_import(self):
|
||||
module_list = [
|
||||
# Modules
|
||||
'Linear',
|
||||
'LSTM',
|
||||
'GRU',
|
||||
'LSTMCell',
|
||||
'RNNCell',
|
||||
'GRUCell',
|
||||
'Conv1d',
|
||||
'Conv2d',
|
||||
'Conv3d',
|
||||
'ConvTranspose1d',
|
||||
'ConvTranspose2d',
|
||||
'ConvTranspose3d',
|
||||
"Linear",
|
||||
"LSTM",
|
||||
"GRU",
|
||||
"LSTMCell",
|
||||
"RNNCell",
|
||||
"GRUCell",
|
||||
"Conv1d",
|
||||
"Conv2d",
|
||||
"Conv3d",
|
||||
"ConvTranspose1d",
|
||||
"ConvTranspose2d",
|
||||
"ConvTranspose3d",
|
||||
]
|
||||
self._test_function_import('dynamic', module_list, base='nn.quantized')
|
||||
self._test_function_import("dynamic", module_list, base="nn.quantized")
|
||||
|
||||
def test_import_nn_quantizable_activation(self):
|
||||
module_list = [
|
||||
# Modules
|
||||
'MultiheadAttention',
|
||||
"MultiheadAttention",
|
||||
]
|
||||
self._test_function_import('activation', module_list, base='nn.quantizable.modules')
|
||||
self._test_function_import(
|
||||
"activation", module_list, base="nn.quantizable.modules"
|
||||
)
|
||||
|
||||
def test_import_nn_quantizable_rnn(self):
|
||||
module_list = [
|
||||
# Modules
|
||||
'LSTM',
|
||||
'LSTMCell',
|
||||
"LSTM",
|
||||
"LSTMCell",
|
||||
]
|
||||
self._test_function_import('rnn', module_list, base='nn.quantizable.modules')
|
||||
self._test_function_import("rnn", module_list, base="nn.quantizable.modules")
|
||||
|
||||
def test_import_nn_qat_conv(self):
|
||||
module_list = [
|
||||
'Conv1d',
|
||||
'Conv2d',
|
||||
'Conv3d',
|
||||
"Conv1d",
|
||||
"Conv2d",
|
||||
"Conv3d",
|
||||
]
|
||||
self._test_function_import('conv', module_list, base='nn.qat.modules')
|
||||
self._test_function_import("conv", module_list, base="nn.qat.modules")
|
||||
|
||||
def test_import_nn_qat_embedding_ops(self):
|
||||
module_list = [
|
||||
'Embedding',
|
||||
'EmbeddingBag',
|
||||
"Embedding",
|
||||
"EmbeddingBag",
|
||||
]
|
||||
self._test_function_import('embedding_ops', module_list, base='nn.qat.modules')
|
||||
self._test_function_import("embedding_ops", module_list, base="nn.qat.modules")
|
||||
|
||||
def test_import_nn_qat_linear(self):
|
||||
module_list = [
|
||||
'Linear',
|
||||
"Linear",
|
||||
]
|
||||
self._test_function_import('linear', module_list, base='nn.qat.modules')
|
||||
self._test_function_import("linear", module_list, base="nn.qat.modules")
|
||||
|
||||
def test_import_nn_qat_dynamic_linear(self):
|
||||
module_list = [
|
||||
'Linear',
|
||||
"Linear",
|
||||
]
|
||||
self._test_function_import('linear', module_list, base='nn.qat.dynamic.modules')
|
||||
self._test_function_import("linear", module_list, base="nn.qat.dynamic.modules")
|
||||
|
||||
|
||||
class TestAOMigrationNNIntrinsic(AOMigrationTestCase):
|
||||
def test_modules_import_nn_intrinsic(self):
|
||||
module_list = [
|
||||
# Modules
|
||||
'_FusedModule',
|
||||
'ConvBn1d',
|
||||
'ConvBn2d',
|
||||
'ConvBn3d',
|
||||
'ConvBnReLU1d',
|
||||
'ConvBnReLU2d',
|
||||
'ConvBnReLU3d',
|
||||
'ConvReLU1d',
|
||||
'ConvReLU2d',
|
||||
'ConvReLU3d',
|
||||
'LinearReLU',
|
||||
'BNReLU2d',
|
||||
'BNReLU3d',
|
||||
'LinearBn1d',
|
||||
"_FusedModule",
|
||||
"ConvBn1d",
|
||||
"ConvBn2d",
|
||||
"ConvBn3d",
|
||||
"ConvBnReLU1d",
|
||||
"ConvBnReLU2d",
|
||||
"ConvBnReLU3d",
|
||||
"ConvReLU1d",
|
||||
"ConvReLU2d",
|
||||
"ConvReLU3d",
|
||||
"LinearReLU",
|
||||
"BNReLU2d",
|
||||
"BNReLU3d",
|
||||
"LinearBn1d",
|
||||
]
|
||||
self._test_function_import('intrinsic', module_list, base='nn')
|
||||
self._test_function_import("intrinsic", module_list, base="nn")
|
||||
|
||||
def test_modules_nn_intrinsic_fused(self):
|
||||
function_list = [
|
||||
'_FusedModule',
|
||||
'ConvBn1d',
|
||||
'ConvBn2d',
|
||||
'ConvBn3d',
|
||||
'ConvBnReLU1d',
|
||||
'ConvBnReLU2d',
|
||||
'ConvBnReLU3d',
|
||||
'ConvReLU1d',
|
||||
'ConvReLU2d',
|
||||
'ConvReLU3d',
|
||||
'LinearReLU',
|
||||
'BNReLU2d',
|
||||
'BNReLU3d',
|
||||
'LinearBn1d',
|
||||
"_FusedModule",
|
||||
"ConvBn1d",
|
||||
"ConvBn2d",
|
||||
"ConvBn3d",
|
||||
"ConvBnReLU1d",
|
||||
"ConvBnReLU2d",
|
||||
"ConvBnReLU3d",
|
||||
"ConvReLU1d",
|
||||
"ConvReLU2d",
|
||||
"ConvReLU3d",
|
||||
"LinearReLU",
|
||||
"BNReLU2d",
|
||||
"BNReLU3d",
|
||||
"LinearBn1d",
|
||||
]
|
||||
self._test_function_import('fused', function_list,
|
||||
base='nn.intrinsic.modules')
|
||||
self._test_function_import("fused", function_list, base="nn.intrinsic.modules")
|
||||
|
||||
def test_modules_import_nn_intrinsic_qat(self):
|
||||
module_list = [
|
||||
@ -275,76 +279,83 @@ class TestAOMigrationNNIntrinsic(AOMigrationTestCase):
|
||||
"update_bn_stats",
|
||||
"freeze_bn_stats",
|
||||
]
|
||||
self._test_function_import('qat', module_list, base='nn.intrinsic')
|
||||
self._test_function_import("qat", module_list, base="nn.intrinsic")
|
||||
|
||||
def test_modules_intrinsic_qat_conv_fused(self):
|
||||
function_list = [
|
||||
'ConvBn1d',
|
||||
'ConvBnReLU1d',
|
||||
'ConvReLU1d',
|
||||
'ConvBn2d',
|
||||
'ConvBnReLU2d',
|
||||
'ConvReLU2d',
|
||||
'ConvBn3d',
|
||||
'ConvBnReLU3d',
|
||||
'ConvReLU3d',
|
||||
'update_bn_stats',
|
||||
'freeze_bn_stats'
|
||||
"ConvBn1d",
|
||||
"ConvBnReLU1d",
|
||||
"ConvReLU1d",
|
||||
"ConvBn2d",
|
||||
"ConvBnReLU2d",
|
||||
"ConvReLU2d",
|
||||
"ConvBn3d",
|
||||
"ConvBnReLU3d",
|
||||
"ConvReLU3d",
|
||||
"update_bn_stats",
|
||||
"freeze_bn_stats",
|
||||
]
|
||||
self._test_function_import('conv_fused', function_list,
|
||||
base='nn.intrinsic.qat.modules')
|
||||
self._test_function_import(
|
||||
"conv_fused", function_list, base="nn.intrinsic.qat.modules"
|
||||
)
|
||||
|
||||
def test_modules_intrinsic_qat_linear_fused(self):
|
||||
function_list = [
|
||||
'LinearBn1d',
|
||||
"LinearBn1d",
|
||||
]
|
||||
self._test_function_import('linear_fused', function_list,
|
||||
base='nn.intrinsic.qat.modules')
|
||||
self._test_function_import(
|
||||
"linear_fused", function_list, base="nn.intrinsic.qat.modules"
|
||||
)
|
||||
|
||||
def test_modules_intrinsic_qat_linear_relu(self):
|
||||
function_list = [
|
||||
'LinearReLU',
|
||||
"LinearReLU",
|
||||
]
|
||||
self._test_function_import('linear_relu', function_list,
|
||||
base='nn.intrinsic.qat.modules')
|
||||
self._test_function_import(
|
||||
"linear_relu", function_list, base="nn.intrinsic.qat.modules"
|
||||
)
|
||||
|
||||
def test_modules_import_nn_intrinsic_quantized(self):
|
||||
module_list = [
|
||||
'BNReLU2d',
|
||||
'BNReLU3d',
|
||||
'ConvReLU1d',
|
||||
'ConvReLU2d',
|
||||
'ConvReLU3d',
|
||||
'LinearReLU',
|
||||
"BNReLU2d",
|
||||
"BNReLU3d",
|
||||
"ConvReLU1d",
|
||||
"ConvReLU2d",
|
||||
"ConvReLU3d",
|
||||
"LinearReLU",
|
||||
]
|
||||
self._test_function_import('quantized', module_list, base='nn.intrinsic')
|
||||
self._test_function_import("quantized", module_list, base="nn.intrinsic")
|
||||
|
||||
def test_modules_intrinsic_quantized_bn_relu(self):
|
||||
function_list = [
|
||||
'BNReLU2d',
|
||||
'BNReLU3d',
|
||||
"BNReLU2d",
|
||||
"BNReLU3d",
|
||||
]
|
||||
self._test_function_import('bn_relu', function_list,
|
||||
base='nn.intrinsic.quantized.modules')
|
||||
self._test_function_import(
|
||||
"bn_relu", function_list, base="nn.intrinsic.quantized.modules"
|
||||
)
|
||||
|
||||
def test_modules_intrinsic_quantized_conv_relu(self):
|
||||
function_list = [
|
||||
'ConvReLU1d',
|
||||
'ConvReLU2d',
|
||||
'ConvReLU3d',
|
||||
"ConvReLU1d",
|
||||
"ConvReLU2d",
|
||||
"ConvReLU3d",
|
||||
]
|
||||
self._test_function_import('conv_relu', function_list,
|
||||
base='nn.intrinsic.quantized.modules')
|
||||
self._test_function_import(
|
||||
"conv_relu", function_list, base="nn.intrinsic.quantized.modules"
|
||||
)
|
||||
|
||||
def test_modules_intrinsic_quantized_linear_relu(self):
|
||||
function_list = [
|
||||
'LinearReLU',
|
||||
"LinearReLU",
|
||||
]
|
||||
self._test_function_import('linear_relu', function_list,
|
||||
base='nn.intrinsic.quantized.modules')
|
||||
self._test_function_import(
|
||||
"linear_relu", function_list, base="nn.intrinsic.quantized.modules"
|
||||
)
|
||||
|
||||
def test_modules_no_import_nn_intrinsic_quantized_dynamic(self):
|
||||
# TODO(future PR): generalize this
|
||||
import torch
|
||||
|
||||
_ = torch.ao.nn.intrinsic.quantized.dynamic
|
||||
_ = torch.nn.intrinsic.quantized.dynamic
|
||||
|
@ -7,102 +7,103 @@ class TestAOMigrationQuantization(AOMigrationTestCase):
|
||||
r"""Modules and functions related to the
|
||||
`torch/quantization` migration to `torch/ao/quantization`.
|
||||
"""
|
||||
|
||||
def test_function_import_quantize(self):
|
||||
function_list = [
|
||||
'_convert',
|
||||
'_observer_forward_hook',
|
||||
'_propagate_qconfig_helper',
|
||||
'_remove_activation_post_process',
|
||||
'_remove_qconfig',
|
||||
'_add_observer_',
|
||||
'add_quant_dequant',
|
||||
'convert',
|
||||
'_get_observer_dict',
|
||||
'_get_unique_devices_',
|
||||
'_is_activation_post_process',
|
||||
'prepare',
|
||||
'prepare_qat',
|
||||
'propagate_qconfig_',
|
||||
'quantize',
|
||||
'quantize_dynamic',
|
||||
'quantize_qat',
|
||||
'_register_activation_post_process_hook',
|
||||
'swap_module',
|
||||
"_convert",
|
||||
"_observer_forward_hook",
|
||||
"_propagate_qconfig_helper",
|
||||
"_remove_activation_post_process",
|
||||
"_remove_qconfig",
|
||||
"_add_observer_",
|
||||
"add_quant_dequant",
|
||||
"convert",
|
||||
"_get_observer_dict",
|
||||
"_get_unique_devices_",
|
||||
"_is_activation_post_process",
|
||||
"prepare",
|
||||
"prepare_qat",
|
||||
"propagate_qconfig_",
|
||||
"quantize",
|
||||
"quantize_dynamic",
|
||||
"quantize_qat",
|
||||
"_register_activation_post_process_hook",
|
||||
"swap_module",
|
||||
]
|
||||
self._test_function_import('quantize', function_list)
|
||||
self._test_function_import("quantize", function_list)
|
||||
|
||||
def test_function_import_stubs(self):
|
||||
function_list = [
|
||||
'QuantStub',
|
||||
'DeQuantStub',
|
||||
'QuantWrapper',
|
||||
"QuantStub",
|
||||
"DeQuantStub",
|
||||
"QuantWrapper",
|
||||
]
|
||||
self._test_function_import('stubs', function_list)
|
||||
self._test_function_import("stubs", function_list)
|
||||
|
||||
def test_function_import_quantize_jit(self):
|
||||
function_list = [
|
||||
'_check_is_script_module',
|
||||
'_check_forward_method',
|
||||
'script_qconfig',
|
||||
'script_qconfig_dict',
|
||||
'fuse_conv_bn_jit',
|
||||
'_prepare_jit',
|
||||
'prepare_jit',
|
||||
'prepare_dynamic_jit',
|
||||
'_convert_jit',
|
||||
'convert_jit',
|
||||
'convert_dynamic_jit',
|
||||
'_quantize_jit',
|
||||
'quantize_jit',
|
||||
'quantize_dynamic_jit',
|
||||
"_check_is_script_module",
|
||||
"_check_forward_method",
|
||||
"script_qconfig",
|
||||
"script_qconfig_dict",
|
||||
"fuse_conv_bn_jit",
|
||||
"_prepare_jit",
|
||||
"prepare_jit",
|
||||
"prepare_dynamic_jit",
|
||||
"_convert_jit",
|
||||
"convert_jit",
|
||||
"convert_dynamic_jit",
|
||||
"_quantize_jit",
|
||||
"quantize_jit",
|
||||
"quantize_dynamic_jit",
|
||||
]
|
||||
self._test_function_import('quantize_jit', function_list)
|
||||
self._test_function_import("quantize_jit", function_list)
|
||||
|
||||
def test_function_import_fake_quantize(self):
|
||||
function_list = [
|
||||
'_is_per_channel',
|
||||
'_is_per_tensor',
|
||||
'_is_symmetric_quant',
|
||||
'FakeQuantizeBase',
|
||||
'FakeQuantize',
|
||||
'FixedQParamsFakeQuantize',
|
||||
'FusedMovingAvgObsFakeQuantize',
|
||||
'default_fake_quant',
|
||||
'default_weight_fake_quant',
|
||||
'default_fixed_qparams_range_neg1to1_fake_quant',
|
||||
'default_fixed_qparams_range_0to1_fake_quant',
|
||||
'default_per_channel_weight_fake_quant',
|
||||
'default_histogram_fake_quant',
|
||||
'default_fused_act_fake_quant',
|
||||
'default_fused_wt_fake_quant',
|
||||
'default_fused_per_channel_wt_fake_quant',
|
||||
'_is_fake_quant_script_module',
|
||||
'disable_fake_quant',
|
||||
'enable_fake_quant',
|
||||
'disable_observer',
|
||||
'enable_observer',
|
||||
"_is_per_channel",
|
||||
"_is_per_tensor",
|
||||
"_is_symmetric_quant",
|
||||
"FakeQuantizeBase",
|
||||
"FakeQuantize",
|
||||
"FixedQParamsFakeQuantize",
|
||||
"FusedMovingAvgObsFakeQuantize",
|
||||
"default_fake_quant",
|
||||
"default_weight_fake_quant",
|
||||
"default_fixed_qparams_range_neg1to1_fake_quant",
|
||||
"default_fixed_qparams_range_0to1_fake_quant",
|
||||
"default_per_channel_weight_fake_quant",
|
||||
"default_histogram_fake_quant",
|
||||
"default_fused_act_fake_quant",
|
||||
"default_fused_wt_fake_quant",
|
||||
"default_fused_per_channel_wt_fake_quant",
|
||||
"_is_fake_quant_script_module",
|
||||
"disable_fake_quant",
|
||||
"enable_fake_quant",
|
||||
"disable_observer",
|
||||
"enable_observer",
|
||||
]
|
||||
self._test_function_import('fake_quantize', function_list)
|
||||
self._test_function_import("fake_quantize", function_list)
|
||||
|
||||
def test_function_import_fuse_modules(self):
|
||||
function_list = [
|
||||
'_fuse_modules',
|
||||
'_get_module',
|
||||
'_set_module',
|
||||
'fuse_conv_bn',
|
||||
'fuse_conv_bn_relu',
|
||||
'fuse_known_modules',
|
||||
'fuse_modules',
|
||||
'get_fuser_method',
|
||||
"_fuse_modules",
|
||||
"_get_module",
|
||||
"_set_module",
|
||||
"fuse_conv_bn",
|
||||
"fuse_conv_bn_relu",
|
||||
"fuse_known_modules",
|
||||
"fuse_modules",
|
||||
"get_fuser_method",
|
||||
]
|
||||
self._test_function_import('fuse_modules', function_list)
|
||||
self._test_function_import("fuse_modules", function_list)
|
||||
|
||||
def test_function_import_quant_type(self):
|
||||
function_list = [
|
||||
'QuantType',
|
||||
'_get_quant_type_to_str',
|
||||
"QuantType",
|
||||
"_get_quant_type_to_str",
|
||||
]
|
||||
self._test_function_import('quant_type', function_list)
|
||||
self._test_function_import("quant_type", function_list)
|
||||
|
||||
def test_function_import_observer(self):
|
||||
function_list = [
|
||||
@ -133,7 +134,7 @@ class TestAOMigrationQuantization(AOMigrationTestCase):
|
||||
"default_dynamic_quant_observer",
|
||||
"default_float_qparams_observer",
|
||||
]
|
||||
self._test_function_import('observer', function_list)
|
||||
self._test_function_import("observer", function_list)
|
||||
|
||||
def test_function_import_qconfig(self):
|
||||
function_list = [
|
||||
@ -156,9 +157,9 @@ class TestAOMigrationQuantization(AOMigrationTestCase):
|
||||
"_assert_valid_qconfig",
|
||||
"QConfigAny",
|
||||
"_add_module_to_qconfig_obs_ctr",
|
||||
"qconfig_equals"
|
||||
"qconfig_equals",
|
||||
]
|
||||
self._test_function_import('qconfig', function_list)
|
||||
self._test_function_import("qconfig", function_list)
|
||||
|
||||
def test_function_import_quantization_mappings(self):
|
||||
function_list = [
|
||||
@ -184,8 +185,8 @@ class TestAOMigrationQuantization(AOMigrationTestCase):
|
||||
"DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS",
|
||||
"DEFAULT_MODULE_TO_ACT_POST_PROCESS",
|
||||
]
|
||||
self._test_function_import('quantization_mappings', function_list)
|
||||
self._test_dict_import('quantization_mappings', dict_list)
|
||||
self._test_function_import("quantization_mappings", function_list)
|
||||
self._test_dict_import("quantization_mappings", dict_list)
|
||||
|
||||
def test_function_import_fuser_method_mappings(self):
|
||||
function_list = [
|
||||
@ -194,29 +195,27 @@ class TestAOMigrationQuantization(AOMigrationTestCase):
|
||||
"fuse_linear_bn",
|
||||
"get_fuser_method",
|
||||
]
|
||||
dict_list = [
|
||||
"_DEFAULT_OP_LIST_TO_FUSER_METHOD"
|
||||
]
|
||||
self._test_function_import('fuser_method_mappings', function_list)
|
||||
self._test_dict_import('fuser_method_mappings', dict_list)
|
||||
dict_list = ["_DEFAULT_OP_LIST_TO_FUSER_METHOD"]
|
||||
self._test_function_import("fuser_method_mappings", function_list)
|
||||
self._test_dict_import("fuser_method_mappings", dict_list)
|
||||
|
||||
def test_function_import_utils(self):
|
||||
function_list = [
|
||||
'activation_dtype',
|
||||
'activation_is_int8_quantized',
|
||||
'activation_is_statically_quantized',
|
||||
'calculate_qmin_qmax',
|
||||
'check_min_max_valid',
|
||||
'get_combined_dict',
|
||||
'get_qconfig_dtypes',
|
||||
'get_qparam_dict',
|
||||
'get_quant_type',
|
||||
'get_swapped_custom_module_class',
|
||||
'getattr_from_fqn',
|
||||
'is_per_channel',
|
||||
'is_per_tensor',
|
||||
'weight_dtype',
|
||||
'weight_is_quantized',
|
||||
'weight_is_statically_quantized',
|
||||
"activation_dtype",
|
||||
"activation_is_int8_quantized",
|
||||
"activation_is_statically_quantized",
|
||||
"calculate_qmin_qmax",
|
||||
"check_min_max_valid",
|
||||
"get_combined_dict",
|
||||
"get_qconfig_dtypes",
|
||||
"get_qparam_dict",
|
||||
"get_quant_type",
|
||||
"get_swapped_custom_module_class",
|
||||
"getattr_from_fqn",
|
||||
"is_per_channel",
|
||||
"is_per_tensor",
|
||||
"weight_dtype",
|
||||
"weight_is_quantized",
|
||||
"weight_is_statically_quantized",
|
||||
]
|
||||
self._test_function_import('utils', function_list)
|
||||
self._test_function_import("utils", function_list)
|
||||
|
@ -2,144 +2,133 @@
|
||||
|
||||
from .common import AOMigrationTestCase
|
||||
|
||||
|
||||
class TestAOMigrationQuantizationFx(AOMigrationTestCase):
|
||||
def test_function_import_quantize_fx(self):
|
||||
function_list = [
|
||||
'_check_is_graph_module',
|
||||
'_swap_ff_with_fxff',
|
||||
'_fuse_fx',
|
||||
'QuantizationTracer',
|
||||
'_prepare_fx',
|
||||
'_prepare_standalone_module_fx',
|
||||
'fuse_fx',
|
||||
'Scope',
|
||||
'ScopeContextManager',
|
||||
'prepare_fx',
|
||||
'prepare_qat_fx',
|
||||
'_convert_fx',
|
||||
'convert_fx',
|
||||
'_convert_standalone_module_fx',
|
||||
"_check_is_graph_module",
|
||||
"_swap_ff_with_fxff",
|
||||
"_fuse_fx",
|
||||
"QuantizationTracer",
|
||||
"_prepare_fx",
|
||||
"_prepare_standalone_module_fx",
|
||||
"fuse_fx",
|
||||
"Scope",
|
||||
"ScopeContextManager",
|
||||
"prepare_fx",
|
||||
"prepare_qat_fx",
|
||||
"_convert_fx",
|
||||
"convert_fx",
|
||||
"_convert_standalone_module_fx",
|
||||
]
|
||||
self._test_function_import('quantize_fx', function_list)
|
||||
self._test_function_import("quantize_fx", function_list)
|
||||
|
||||
def test_function_import_fx(self):
|
||||
function_list = [
|
||||
'prepare',
|
||||
'convert',
|
||||
'fuse',
|
||||
"prepare",
|
||||
"convert",
|
||||
"fuse",
|
||||
]
|
||||
self._test_function_import('fx', function_list)
|
||||
self._test_function_import("fx", function_list)
|
||||
|
||||
def test_function_import_fx_graph_module(self):
|
||||
function_list = [
|
||||
'FusedGraphModule',
|
||||
'ObservedGraphModule',
|
||||
'_is_observed_module',
|
||||
'ObservedStandaloneGraphModule',
|
||||
'_is_observed_standalone_module',
|
||||
'QuantizedGraphModule'
|
||||
"FusedGraphModule",
|
||||
"ObservedGraphModule",
|
||||
"_is_observed_module",
|
||||
"ObservedStandaloneGraphModule",
|
||||
"_is_observed_standalone_module",
|
||||
"QuantizedGraphModule",
|
||||
]
|
||||
self._test_function_import('fx.graph_module', function_list)
|
||||
self._test_function_import("fx.graph_module", function_list)
|
||||
|
||||
def test_function_import_fx_pattern_utils(self):
|
||||
function_list = [
|
||||
'QuantizeHandler',
|
||||
'_register_fusion_pattern',
|
||||
'get_default_fusion_patterns',
|
||||
'_register_quant_pattern',
|
||||
'get_default_quant_patterns',
|
||||
'get_default_output_activation_post_process_map'
|
||||
"QuantizeHandler",
|
||||
"_register_fusion_pattern",
|
||||
"get_default_fusion_patterns",
|
||||
"_register_quant_pattern",
|
||||
"get_default_quant_patterns",
|
||||
"get_default_output_activation_post_process_map",
|
||||
]
|
||||
self._test_function_import('fx.pattern_utils', function_list)
|
||||
self._test_function_import("fx.pattern_utils", function_list)
|
||||
|
||||
def test_function_import_fx_equalize(self):
|
||||
function_list = [
|
||||
'reshape_scale',
|
||||
'_InputEqualizationObserver',
|
||||
'_WeightEqualizationObserver',
|
||||
'calculate_equalization_scale',
|
||||
'EqualizationQConfig',
|
||||
'input_equalization_observer',
|
||||
'weight_equalization_observer',
|
||||
'default_equalization_qconfig',
|
||||
'fused_module_supports_equalization',
|
||||
'nn_module_supports_equalization',
|
||||
'node_supports_equalization',
|
||||
'is_equalization_observer',
|
||||
'get_op_node_and_weight_eq_obs',
|
||||
'maybe_get_weight_eq_obs_node',
|
||||
'maybe_get_next_input_eq_obs',
|
||||
'maybe_get_next_equalization_scale',
|
||||
'scale_input_observer',
|
||||
'scale_weight_node',
|
||||
'scale_weight_functional',
|
||||
'clear_weight_quant_obs_node',
|
||||
'remove_node',
|
||||
'update_obs_for_equalization',
|
||||
'convert_eq_obs',
|
||||
'_convert_equalization_ref',
|
||||
'get_layer_sqnr_dict',
|
||||
'get_equalization_qconfig_dict'
|
||||
"reshape_scale",
|
||||
"_InputEqualizationObserver",
|
||||
"_WeightEqualizationObserver",
|
||||
"calculate_equalization_scale",
|
||||
"EqualizationQConfig",
|
||||
"input_equalization_observer",
|
||||
"weight_equalization_observer",
|
||||
"default_equalization_qconfig",
|
||||
"fused_module_supports_equalization",
|
||||
"nn_module_supports_equalization",
|
||||
"node_supports_equalization",
|
||||
"is_equalization_observer",
|
||||
"get_op_node_and_weight_eq_obs",
|
||||
"maybe_get_weight_eq_obs_node",
|
||||
"maybe_get_next_input_eq_obs",
|
||||
"maybe_get_next_equalization_scale",
|
||||
"scale_input_observer",
|
||||
"scale_weight_node",
|
||||
"scale_weight_functional",
|
||||
"clear_weight_quant_obs_node",
|
||||
"remove_node",
|
||||
"update_obs_for_equalization",
|
||||
"convert_eq_obs",
|
||||
"_convert_equalization_ref",
|
||||
"get_layer_sqnr_dict",
|
||||
"get_equalization_qconfig_dict",
|
||||
]
|
||||
self._test_function_import('fx._equalize', function_list)
|
||||
self._test_function_import("fx._equalize", function_list)
|
||||
|
||||
def test_function_import_fx_quantization_patterns(self):
|
||||
function_list = [
|
||||
'QuantizeHandler',
|
||||
'BinaryOpQuantizeHandler',
|
||||
'CatQuantizeHandler',
|
||||
'ConvReluQuantizeHandler',
|
||||
'LinearReLUQuantizeHandler',
|
||||
'BatchNormQuantizeHandler',
|
||||
'EmbeddingQuantizeHandler',
|
||||
'RNNDynamicQuantizeHandler',
|
||||
'DefaultNodeQuantizeHandler',
|
||||
'FixedQParamsOpQuantizeHandler',
|
||||
'CopyNodeQuantizeHandler',
|
||||
'CustomModuleQuantizeHandler',
|
||||
'GeneralTensorShapeOpQuantizeHandler',
|
||||
'StandaloneModuleQuantizeHandler'
|
||||
"QuantizeHandler",
|
||||
"BinaryOpQuantizeHandler",
|
||||
"CatQuantizeHandler",
|
||||
"ConvReluQuantizeHandler",
|
||||
"LinearReLUQuantizeHandler",
|
||||
"BatchNormQuantizeHandler",
|
||||
"EmbeddingQuantizeHandler",
|
||||
"RNNDynamicQuantizeHandler",
|
||||
"DefaultNodeQuantizeHandler",
|
||||
"FixedQParamsOpQuantizeHandler",
|
||||
"CopyNodeQuantizeHandler",
|
||||
"CustomModuleQuantizeHandler",
|
||||
"GeneralTensorShapeOpQuantizeHandler",
|
||||
"StandaloneModuleQuantizeHandler",
|
||||
]
|
||||
self._test_function_import(
|
||||
'fx.quantization_patterns',
|
||||
"fx.quantization_patterns",
|
||||
function_list,
|
||||
new_package_name='fx.quantize_handler',
|
||||
new_package_name="fx.quantize_handler",
|
||||
)
|
||||
|
||||
def test_function_import_fx_match_utils(self):
|
||||
function_list = [
|
||||
'_MatchResult',
|
||||
'MatchAllNode',
|
||||
'_is_match',
|
||||
'_find_matches'
|
||||
]
|
||||
self._test_function_import('fx.match_utils', function_list)
|
||||
function_list = ["_MatchResult", "MatchAllNode", "_is_match", "_find_matches"]
|
||||
self._test_function_import("fx.match_utils", function_list)
|
||||
|
||||
def test_function_import_fx_prepare(self):
|
||||
function_list = [
|
||||
'prepare'
|
||||
]
|
||||
self._test_function_import('fx.prepare', function_list)
|
||||
function_list = ["prepare"]
|
||||
self._test_function_import("fx.prepare", function_list)
|
||||
|
||||
def test_function_import_fx_convert(self):
|
||||
function_list = [
|
||||
'convert'
|
||||
]
|
||||
self._test_function_import('fx.convert', function_list)
|
||||
function_list = ["convert"]
|
||||
self._test_function_import("fx.convert", function_list)
|
||||
|
||||
def test_function_import_fx_fuse(self):
|
||||
function_list = ['fuse']
|
||||
self._test_function_import('fx.fuse', function_list)
|
||||
function_list = ["fuse"]
|
||||
self._test_function_import("fx.fuse", function_list)
|
||||
|
||||
def test_function_import_fx_fusion_patterns(self):
|
||||
function_list = [
|
||||
'FuseHandler',
|
||||
'DefaultFuseHandler'
|
||||
]
|
||||
function_list = ["FuseHandler", "DefaultFuseHandler"]
|
||||
self._test_function_import(
|
||||
'fx.fusion_patterns',
|
||||
"fx.fusion_patterns",
|
||||
function_list,
|
||||
new_package_name='fx.fuse_handler',
|
||||
new_package_name="fx.fuse_handler",
|
||||
)
|
||||
|
||||
# we removed matching test for torch.quantization.fx.quantization_types
|
||||
@ -149,15 +138,15 @@ class TestAOMigrationQuantizationFx(AOMigrationTestCase):
|
||||
|
||||
def test_function_import_fx_utils(self):
|
||||
function_list = [
|
||||
'get_custom_module_class_keys',
|
||||
'get_linear_prepack_op_for_dtype',
|
||||
'get_qconv_prepack_op',
|
||||
'get_new_attr_name_with_prefix',
|
||||
'graph_module_from_producer_nodes',
|
||||
'assert_and_get_unique_device',
|
||||
'create_getattr_from_value',
|
||||
'all_node_args_have_no_tensors',
|
||||
'get_non_observable_arg_indexes_and_types',
|
||||
'maybe_get_next_module'
|
||||
"get_custom_module_class_keys",
|
||||
"get_linear_prepack_op_for_dtype",
|
||||
"get_qconv_prepack_op",
|
||||
"get_new_attr_name_with_prefix",
|
||||
"graph_module_from_producer_nodes",
|
||||
"assert_and_get_unique_device",
|
||||
"create_getattr_from_value",
|
||||
"all_node_args_have_no_tensors",
|
||||
"get_non_observable_arg_indexes_and_types",
|
||||
"maybe_get_next_module",
|
||||
]
|
||||
self._test_function_import('fx.utils', function_list)
|
||||
self._test_function_import("fx.utils", function_list)
|
||||
|
@ -1,32 +1,39 @@
|
||||
# Owner(s): ["oncall: quantization"]
|
||||
|
||||
import sys
|
||||
import os
|
||||
import sys
|
||||
import unittest
|
||||
from typing import Set
|
||||
|
||||
# torch
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.ao.nn.intrinsic.quantized as nniq
|
||||
import torch.ao.nn.quantized as nnq
|
||||
import torch.ao.nn.quantized.dynamic as nnqd
|
||||
import torch.ao.nn.intrinsic.quantized as nniq
|
||||
from torch.fx import GraphModule
|
||||
|
||||
# Testing utils
|
||||
from torch.testing._internal.common_utils import TestCase, IS_AVX512_VNNI_SUPPORTED
|
||||
from torch.testing._internal.common_quantized import override_qengines, qengine_is_fbgemm
|
||||
from torch.testing._internal.common_quantization import skipIfNoFBGEMM
|
||||
from torch.testing._internal.quantization_torch_package_models import LinearReluFunctional
|
||||
import torch.ao.quantization.quantize_fx as quantize_fx
|
||||
import torch.nn as nn
|
||||
|
||||
from torch.ao.quantization import MinMaxObserver, PerChannelMinMaxObserver
|
||||
import torch.ao.quantization.quantize_fx as quantize_fx
|
||||
from torch.fx import GraphModule
|
||||
from torch.testing._internal.common_quantization import skipIfNoFBGEMM
|
||||
from torch.testing._internal.common_quantized import (
|
||||
override_qengines,
|
||||
qengine_is_fbgemm,
|
||||
)
|
||||
|
||||
# Testing utils
|
||||
from torch.testing._internal.common_utils import IS_AVX512_VNNI_SUPPORTED, TestCase
|
||||
from torch.testing._internal.quantization_torch_package_models import (
|
||||
LinearReluFunctional,
|
||||
)
|
||||
|
||||
|
||||
def remove_prefix(text, prefix):
|
||||
if text.startswith(prefix):
|
||||
return text[len(prefix):]
|
||||
return text[len(prefix) :]
|
||||
return text
|
||||
|
||||
|
||||
def get_filenames(self, subname):
|
||||
# NB: we take __file__ from the module that defined the test
|
||||
# class, so we place the expect directory where the test script
|
||||
@ -34,9 +41,7 @@ def get_filenames(self, subname):
|
||||
module_id = self.__class__.__module__
|
||||
munged_id = remove_prefix(self.id(), module_id + ".")
|
||||
test_file = os.path.realpath(sys.modules[module_id].__file__)
|
||||
base_name = os.path.join(os.path.dirname(test_file),
|
||||
"../serialized",
|
||||
munged_id)
|
||||
base_name = os.path.join(os.path.dirname(test_file), "../serialized", munged_id)
|
||||
|
||||
subname_output = ""
|
||||
if subname:
|
||||
@ -51,32 +56,59 @@ def get_filenames(self, subname):
|
||||
package_file = base_name + ".package.pt"
|
||||
get_attr_targets_file = base_name + ".get_attr_targets.pt"
|
||||
|
||||
return input_file, state_dict_file, scripted_module_file, \
|
||||
traced_module_file, expected_file, package_file, get_attr_targets_file
|
||||
return (
|
||||
input_file,
|
||||
state_dict_file,
|
||||
scripted_module_file,
|
||||
traced_module_file,
|
||||
expected_file,
|
||||
package_file,
|
||||
get_attr_targets_file,
|
||||
)
|
||||
|
||||
|
||||
class TestSerialization(TestCase):
|
||||
""" Test backward compatiblity for serialization and numerics
|
||||
"""
|
||||
"""Test backward compatiblity for serialization and numerics"""
|
||||
|
||||
# Copy and modified from TestCase.assertExpected
|
||||
def _test_op(self, qmodule, subname=None, input_size=None, input_quantized=True,
|
||||
generate=False, prec=None, new_zipfile_serialization=False):
|
||||
r""" Test quantized modules serialized previously can be loaded
|
||||
def _test_op(
|
||||
self,
|
||||
qmodule,
|
||||
subname=None,
|
||||
input_size=None,
|
||||
input_quantized=True,
|
||||
generate=False,
|
||||
prec=None,
|
||||
new_zipfile_serialization=False,
|
||||
):
|
||||
r"""Test quantized modules serialized previously can be loaded
|
||||
with current code, make sure we don't break backward compatibility for the
|
||||
serialization of quantized modules
|
||||
"""
|
||||
input_file, state_dict_file, scripted_module_file, traced_module_file, \
|
||||
expected_file, _package_file, _get_attr_targets_file = \
|
||||
get_filenames(self, subname)
|
||||
(
|
||||
input_file,
|
||||
state_dict_file,
|
||||
scripted_module_file,
|
||||
traced_module_file,
|
||||
expected_file,
|
||||
_package_file,
|
||||
_get_attr_targets_file,
|
||||
) = get_filenames(self, subname)
|
||||
|
||||
# only generate once.
|
||||
if generate and qengine_is_fbgemm():
|
||||
input_tensor = torch.rand(*input_size).float()
|
||||
if input_quantized:
|
||||
input_tensor = torch.quantize_per_tensor(input_tensor, 0.5, 2, torch.quint8)
|
||||
input_tensor = torch.quantize_per_tensor(
|
||||
input_tensor, 0.5, 2, torch.quint8
|
||||
)
|
||||
torch.save(input_tensor, input_file)
|
||||
# Temporary fix to use _use_new_zipfile_serialization until #38379 lands.
|
||||
torch.save(qmodule.state_dict(), state_dict_file, _use_new_zipfile_serialization=new_zipfile_serialization)
|
||||
torch.save(
|
||||
qmodule.state_dict(),
|
||||
state_dict_file,
|
||||
_use_new_zipfile_serialization=new_zipfile_serialization,
|
||||
)
|
||||
torch.jit.save(torch.jit.script(qmodule), scripted_module_file)
|
||||
torch.jit.save(torch.jit.trace(qmodule, input_tensor), traced_module_file)
|
||||
torch.save(qmodule(input_tensor), expected_file)
|
||||
@ -90,8 +122,16 @@ class TestSerialization(TestCase):
|
||||
self.assertEqual(qmodule_scripted(input_tensor), expected, atol=prec)
|
||||
self.assertEqual(qmodule_traced(input_tensor), expected, atol=prec)
|
||||
|
||||
def _test_op_graph(self, qmodule, subname=None, input_size=None, input_quantized=True,
|
||||
generate=False, prec=None, new_zipfile_serialization=False):
|
||||
def _test_op_graph(
|
||||
self,
|
||||
qmodule,
|
||||
subname=None,
|
||||
input_size=None,
|
||||
input_quantized=True,
|
||||
generate=False,
|
||||
prec=None,
|
||||
new_zipfile_serialization=False,
|
||||
):
|
||||
r"""
|
||||
Input: a floating point module
|
||||
|
||||
@ -101,9 +141,15 @@ class TestSerialization(TestCase):
|
||||
If generate == False, traces and scripts the module and quantizes the results with
|
||||
PTQ, and compares to saved results.
|
||||
"""
|
||||
input_file, state_dict_file, scripted_module_file, traced_module_file, \
|
||||
expected_file, _package_file, _get_attr_targets_file = \
|
||||
get_filenames(self, subname)
|
||||
(
|
||||
input_file,
|
||||
state_dict_file,
|
||||
scripted_module_file,
|
||||
traced_module_file,
|
||||
expected_file,
|
||||
_package_file,
|
||||
_get_attr_targets_file,
|
||||
) = get_filenames(self, subname)
|
||||
|
||||
# only generate once.
|
||||
if generate and qengine_is_fbgemm():
|
||||
@ -119,11 +165,13 @@ class TestSerialization(TestCase):
|
||||
def _eval_fn(model, data):
|
||||
model(data)
|
||||
|
||||
qconfig_dict = {'': torch.ao.quantization.default_qconfig}
|
||||
qconfig_dict = {"": torch.ao.quantization.default_qconfig}
|
||||
scripted_q = torch.ao.quantization.quantize_jit(
|
||||
scripted, qconfig_dict, _eval_fn, [input_tensor])
|
||||
scripted, qconfig_dict, _eval_fn, [input_tensor]
|
||||
)
|
||||
traced_q = torch.ao.quantization.quantize_jit(
|
||||
traced, qconfig_dict, _eval_fn, [input_tensor])
|
||||
traced, qconfig_dict, _eval_fn, [input_tensor]
|
||||
)
|
||||
|
||||
torch.jit.save(scripted_q, scripted_module_file)
|
||||
torch.jit.save(traced_q, traced_module_file)
|
||||
@ -136,12 +184,21 @@ class TestSerialization(TestCase):
|
||||
self.assertEqual(qmodule_scripted(input_tensor), expected, atol=prec)
|
||||
self.assertEqual(qmodule_traced(input_tensor), expected, atol=prec)
|
||||
|
||||
def _test_obs(self, obs, input_size, subname=None, generate=False, check_numerics=True):
|
||||
def _test_obs(
|
||||
self, obs, input_size, subname=None, generate=False, check_numerics=True
|
||||
):
|
||||
"""
|
||||
Test observer code can be loaded from state_dict.
|
||||
"""
|
||||
input_file, state_dict_file, _, traced_module_file, expected_file, \
|
||||
_package_file, _get_attr_targets_file = get_filenames(self, None)
|
||||
(
|
||||
input_file,
|
||||
state_dict_file,
|
||||
_,
|
||||
traced_module_file,
|
||||
expected_file,
|
||||
_package_file,
|
||||
_get_attr_targets_file,
|
||||
) = get_filenames(self, None)
|
||||
if generate:
|
||||
input_tensor = torch.rand(*input_size).float()
|
||||
torch.save(input_tensor, input_file)
|
||||
@ -159,12 +216,18 @@ class TestSerialization(TestCase):
|
||||
Verifies that files created in the past with torch.package
|
||||
work on today's FX graph mode quantization transforms.
|
||||
"""
|
||||
input_file, state_dict_file, _scripted_module_file, _traced_module_file, \
|
||||
expected_file, package_file, get_attr_targets_file = \
|
||||
get_filenames(self, None)
|
||||
(
|
||||
input_file,
|
||||
state_dict_file,
|
||||
_scripted_module_file,
|
||||
_traced_module_file,
|
||||
expected_file,
|
||||
package_file,
|
||||
get_attr_targets_file,
|
||||
) = get_filenames(self, None)
|
||||
|
||||
package_name = 'test'
|
||||
resource_name_model = 'test.pkl'
|
||||
package_name = "test"
|
||||
resource_name_model = "test.pkl"
|
||||
|
||||
def _do_quant_transforms(
|
||||
m: torch.nn.Module,
|
||||
@ -172,8 +235,8 @@ class TestSerialization(TestCase):
|
||||
) -> torch.nn.Module:
|
||||
example_inputs = (input_tensor,)
|
||||
# do the quantizaton transforms and save result
|
||||
qconfig = torch.ao.quantization.get_default_qconfig('fbgemm')
|
||||
mp = quantize_fx.prepare_fx(m, {'': qconfig}, example_inputs=example_inputs)
|
||||
qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")
|
||||
mp = quantize_fx.prepare_fx(m, {"": qconfig}, example_inputs=example_inputs)
|
||||
mp(input_tensor)
|
||||
mq = quantize_fx.convert_fx(mp)
|
||||
return mq
|
||||
@ -181,7 +244,7 @@ class TestSerialization(TestCase):
|
||||
def _get_get_attr_target_strings(m: GraphModule) -> Set[str]:
|
||||
results = set()
|
||||
for node in m.graph.nodes:
|
||||
if node.op == 'get_attr':
|
||||
if node.op == "get_attr":
|
||||
results.add(node.target)
|
||||
return results
|
||||
|
||||
@ -191,7 +254,7 @@ class TestSerialization(TestCase):
|
||||
|
||||
# save the model with torch.package
|
||||
with torch.package.PackageExporter(package_file) as exp:
|
||||
exp.intern('torch.testing._internal.quantization_torch_package_models')
|
||||
exp.intern("torch.testing._internal.quantization_torch_package_models")
|
||||
exp.save_pickle(package_name, resource_name_model, fp32_module)
|
||||
|
||||
# do the quantization transforms and save the result
|
||||
@ -214,7 +277,8 @@ class TestSerialization(TestCase):
|
||||
get_attrs = _get_get_attr_target_strings(mq)
|
||||
self.assertTrue(
|
||||
get_attrs == expected_get_attrs,
|
||||
f'get_attrs: expected {expected_get_attrs}, got {get_attrs}')
|
||||
f"get_attrs: expected {expected_get_attrs}, got {get_attrs}",
|
||||
)
|
||||
output_tensor = mq(input_tensor)
|
||||
self.assertTrue(torch.allclose(output_tensor, expected_output_tensor))
|
||||
|
||||
@ -231,29 +295,68 @@ class TestSerialization(TestCase):
|
||||
@override_qengines
|
||||
def test_linear_dynamic(self):
|
||||
module_qint8 = nnqd.Linear(3, 1, bias_=True, dtype=torch.qint8)
|
||||
self._test_op(module_qint8, "qint8", input_size=[1, 3], input_quantized=False, generate=False)
|
||||
self._test_op(
|
||||
module_qint8,
|
||||
"qint8",
|
||||
input_size=[1, 3],
|
||||
input_quantized=False,
|
||||
generate=False,
|
||||
)
|
||||
if qengine_is_fbgemm():
|
||||
module_float16 = nnqd.Linear(3, 1, bias_=True, dtype=torch.float16)
|
||||
self._test_op(module_float16, "float16", input_size=[1, 3], input_quantized=False, generate=False)
|
||||
self._test_op(
|
||||
module_float16,
|
||||
"float16",
|
||||
input_size=[1, 3],
|
||||
input_quantized=False,
|
||||
generate=False,
|
||||
)
|
||||
|
||||
@override_qengines
|
||||
def test_conv2d(self):
|
||||
module = nnq.Conv2d(3, 3, kernel_size=3, stride=1, padding=0, dilation=1,
|
||||
groups=1, bias=True, padding_mode="zeros")
|
||||
module = nnq.Conv2d(
|
||||
3,
|
||||
3,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=0,
|
||||
dilation=1,
|
||||
groups=1,
|
||||
bias=True,
|
||||
padding_mode="zeros",
|
||||
)
|
||||
self._test_op(module, input_size=[1, 3, 6, 6], generate=False)
|
||||
|
||||
@override_qengines
|
||||
def test_conv2d_nobias(self):
|
||||
module = nnq.Conv2d(3, 3, kernel_size=3, stride=1, padding=0, dilation=1,
|
||||
groups=1, bias=False, padding_mode="zeros")
|
||||
module = nnq.Conv2d(
|
||||
3,
|
||||
3,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=0,
|
||||
dilation=1,
|
||||
groups=1,
|
||||
bias=False,
|
||||
padding_mode="zeros",
|
||||
)
|
||||
self._test_op(module, input_size=[1, 3, 6, 6], generate=False)
|
||||
|
||||
@override_qengines
|
||||
def test_conv2d_graph(self):
|
||||
module = nn.Sequential(
|
||||
torch.ao.quantization.QuantStub(),
|
||||
nn.Conv2d(3, 3, kernel_size=3, stride=1, padding=0, dilation=1,
|
||||
groups=1, bias=True, padding_mode="zeros"),
|
||||
nn.Conv2d(
|
||||
3,
|
||||
3,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=0,
|
||||
dilation=1,
|
||||
groups=1,
|
||||
bias=True,
|
||||
padding_mode="zeros",
|
||||
),
|
||||
)
|
||||
self._test_op_graph(module, input_size=[1, 3, 6, 6], generate=False)
|
||||
|
||||
@ -261,8 +364,17 @@ class TestSerialization(TestCase):
|
||||
def test_conv2d_nobias_graph(self):
|
||||
module = nn.Sequential(
|
||||
torch.ao.quantization.QuantStub(),
|
||||
nn.Conv2d(3, 3, kernel_size=3, stride=1, padding=0, dilation=1,
|
||||
groups=1, bias=False, padding_mode="zeros"),
|
||||
nn.Conv2d(
|
||||
3,
|
||||
3,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=0,
|
||||
dilation=1,
|
||||
groups=1,
|
||||
bias=False,
|
||||
padding_mode="zeros",
|
||||
),
|
||||
)
|
||||
self._test_op_graph(module, input_size=[1, 3, 6, 6], generate=False)
|
||||
|
||||
@ -272,8 +384,17 @@ class TestSerialization(TestCase):
|
||||
# ConvPackedParams{n}d
|
||||
module = nn.Sequential(
|
||||
torch.ao.quantization.QuantStub(),
|
||||
nn.Conv2d(3, 3, kernel_size=3, stride=1, padding=0, dilation=1,
|
||||
groups=1, bias=True, padding_mode="zeros"),
|
||||
nn.Conv2d(
|
||||
3,
|
||||
3,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=0,
|
||||
dilation=1,
|
||||
groups=1,
|
||||
bias=True,
|
||||
padding_mode="zeros",
|
||||
),
|
||||
)
|
||||
self._test_op_graph(module, input_size=[1, 3, 6, 6], generate=False)
|
||||
|
||||
@ -283,8 +404,17 @@ class TestSerialization(TestCase):
|
||||
# ConvPackedParams{n}d
|
||||
module = nn.Sequential(
|
||||
torch.ao.quantization.QuantStub(),
|
||||
nn.Conv2d(3, 3, kernel_size=3, stride=1, padding=0, dilation=1,
|
||||
groups=1, bias=False, padding_mode="zeros"),
|
||||
nn.Conv2d(
|
||||
3,
|
||||
3,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=0,
|
||||
dilation=1,
|
||||
groups=1,
|
||||
bias=False,
|
||||
padding_mode="zeros",
|
||||
),
|
||||
)
|
||||
self._test_op_graph(module, input_size=[1, 3, 6, 6], generate=False)
|
||||
|
||||
@ -294,8 +424,17 @@ class TestSerialization(TestCase):
|
||||
# ConvPackedParams{n}d
|
||||
module = nn.Sequential(
|
||||
torch.ao.quantization.QuantStub(),
|
||||
nn.Conv2d(3, 3, kernel_size=3, stride=1, padding=0, dilation=1,
|
||||
groups=1, bias=True, padding_mode="zeros"),
|
||||
nn.Conv2d(
|
||||
3,
|
||||
3,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=0,
|
||||
dilation=1,
|
||||
groups=1,
|
||||
bias=True,
|
||||
padding_mode="zeros",
|
||||
),
|
||||
)
|
||||
self._test_op_graph(module, input_size=[1, 3, 6, 6], generate=False)
|
||||
|
||||
@ -305,48 +444,96 @@ class TestSerialization(TestCase):
|
||||
# ConvPackedParams{n}d
|
||||
module = nn.Sequential(
|
||||
torch.ao.quantization.QuantStub(),
|
||||
nn.Conv2d(3, 3, kernel_size=3, stride=1, padding=0, dilation=1,
|
||||
groups=1, bias=False, padding_mode="zeros"),
|
||||
nn.Conv2d(
|
||||
3,
|
||||
3,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=0,
|
||||
dilation=1,
|
||||
groups=1,
|
||||
bias=False,
|
||||
padding_mode="zeros",
|
||||
),
|
||||
)
|
||||
self._test_op_graph(module, input_size=[1, 3, 6, 6], generate=False)
|
||||
|
||||
@override_qengines
|
||||
def test_conv2d_relu(self):
|
||||
module = nniq.ConvReLU2d(3, 3, kernel_size=3, stride=1, padding=0, dilation=1,
|
||||
groups=1, bias=True, padding_mode="zeros")
|
||||
module = nniq.ConvReLU2d(
|
||||
3,
|
||||
3,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=0,
|
||||
dilation=1,
|
||||
groups=1,
|
||||
bias=True,
|
||||
padding_mode="zeros",
|
||||
)
|
||||
self._test_op(module, input_size=[1, 3, 6, 6], generate=False)
|
||||
# TODO: graph mode quantized conv2d module
|
||||
|
||||
@override_qengines
|
||||
def test_conv3d(self):
|
||||
if qengine_is_fbgemm():
|
||||
module = nnq.Conv3d(3, 3, kernel_size=3, stride=1, padding=0, dilation=1,
|
||||
groups=1, bias=True, padding_mode="zeros")
|
||||
module = nnq.Conv3d(
|
||||
3,
|
||||
3,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=0,
|
||||
dilation=1,
|
||||
groups=1,
|
||||
bias=True,
|
||||
padding_mode="zeros",
|
||||
)
|
||||
self._test_op(module, input_size=[1, 3, 6, 6, 6], generate=False)
|
||||
# TODO: graph mode quantized conv3d module
|
||||
|
||||
@override_qengines
|
||||
def test_conv3d_relu(self):
|
||||
if qengine_is_fbgemm():
|
||||
module = nniq.ConvReLU3d(3, 3, kernel_size=3, stride=1, padding=0, dilation=1,
|
||||
groups=1, bias=True, padding_mode="zeros")
|
||||
module = nniq.ConvReLU3d(
|
||||
3,
|
||||
3,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=0,
|
||||
dilation=1,
|
||||
groups=1,
|
||||
bias=True,
|
||||
padding_mode="zeros",
|
||||
)
|
||||
self._test_op(module, input_size=[1, 3, 6, 6, 6], generate=False)
|
||||
# TODO: graph mode quantized conv3d module
|
||||
|
||||
@override_qengines
|
||||
@unittest.skipIf(IS_AVX512_VNNI_SUPPORTED, "This test fails on machines with AVX512_VNNI support. Ref: GH Issue 59098")
|
||||
@unittest.skipIf(
|
||||
IS_AVX512_VNNI_SUPPORTED,
|
||||
"This test fails on machines with AVX512_VNNI support. Ref: GH Issue 59098",
|
||||
)
|
||||
def test_lstm(self):
|
||||
class LSTMModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.lstm = nnqd.LSTM(input_size=3, hidden_size=7, num_layers=1).to(dtype=torch.float)
|
||||
self.lstm = nnqd.LSTM(input_size=3, hidden_size=7, num_layers=1).to(
|
||||
dtype=torch.float
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.lstm(x)
|
||||
return x
|
||||
|
||||
if qengine_is_fbgemm():
|
||||
mod = LSTMModule()
|
||||
self._test_op(mod, input_size=[4, 4, 3], input_quantized=False, generate=False, new_zipfile_serialization=True)
|
||||
self._test_op(
|
||||
mod,
|
||||
input_size=[4, 4, 3],
|
||||
input_quantized=False,
|
||||
generate=False,
|
||||
new_zipfile_serialization=True,
|
||||
)
|
||||
|
||||
def test_per_channel_observer(self):
|
||||
obs = PerChannelMinMaxObserver()
|
||||
@ -373,7 +560,9 @@ class TestSerialization(TestCase):
|
||||
model.qconfig = torch.ao.quantization.get_default_qat_qconfig("fbgemm")
|
||||
ref_model = torch.ao.quantization.QuantWrapper(model)
|
||||
ref_model = torch.ao.quantization.prepare_qat(ref_model)
|
||||
self._test_obs(ref_model, input_size=[5, 5], generate=False, check_numerics=False)
|
||||
self._test_obs(
|
||||
ref_model, input_size=[5, 5], generate=False, check_numerics=False
|
||||
)
|
||||
|
||||
@skipIfNoFBGEMM
|
||||
def test_linear_relu_package_quantization_transforms(self):
|
||||
|
Reference in New Issue
Block a user