Mirror of
https://github.com/pytorch/pytorch.git
Synced 2025-10-20 21:14:14 +08:00
This PR is part of a series attempting to re-submit https://github.com/pytorch/pytorch/pull/134592 as smaller PRs. In quantization tests:
- Add and use a common `raise_on_run_directly` method for when a user runs a test file directly that should not be run this way; print the file which the user should have run instead.
- Raise a `RuntimeError` on tests which have been disabled (not run).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154728
Approved by: https://github.com/ezyang
228 lines
7.8 KiB
Python
228 lines
7.8 KiB
Python
# Owner(s): ["oncall: quantization"]
|
|
|
|
from torch.testing._internal.common_utils import raise_on_run_directly
|
|
|
|
from .common import AOMigrationTestCase
|
|
|
|
|
|
class TestAOMigrationQuantization(AOMigrationTestCase):
    r"""Import-compatibility checks for the `torch/quantization` ->
    `torch/ao/quantization` migration.

    Every test hands a submodule name plus the names that submodule must
    expose to helpers inherited from ``AOMigrationTestCase``.
    NOTE(review): the helpers live in ``.common`` (not visible here); they
    presumably compare each name between the old and new package paths —
    confirm there.
    """

    def test_function_import_quantize(self):
        # Public and private entry points of the `quantize` submodule.
        self._test_function_import(
            "quantize",
            [
                "_convert",
                "_observer_forward_hook",
                "_propagate_qconfig_helper",
                "_remove_activation_post_process",
                "_remove_qconfig",
                "_add_observer_",
                "add_quant_dequant",
                "convert",
                "_get_observer_dict",
                "_get_unique_devices_",
                "_is_activation_post_process",
                "prepare",
                "prepare_qat",
                "propagate_qconfig_",
                "quantize",
                "quantize_dynamic",
                "quantize_qat",
                "_register_activation_post_process_hook",
                "swap_module",
            ],
        )

    def test_function_import_stubs(self):
        # Stub / wrapper modules exposed by `stubs`.
        self._test_function_import(
            "stubs",
            ["QuantStub", "DeQuantStub", "QuantWrapper"],
        )

    def test_function_import_quantize_jit(self):
        # TorchScript-based quantization helpers in `quantize_jit`.
        self._test_function_import(
            "quantize_jit",
            [
                "_check_is_script_module",
                "_check_forward_method",
                "script_qconfig",
                "script_qconfig_dict",
                "fuse_conv_bn_jit",
                "_prepare_jit",
                "prepare_jit",
                "prepare_dynamic_jit",
                "_convert_jit",
                "convert_jit",
                "convert_dynamic_jit",
                "_quantize_jit",
                "quantize_jit",
                "quantize_dynamic_jit",
            ],
        )

    def test_function_import_fake_quantize(self):
        # Fake-quantize classes, default instances, and toggles.
        self._test_function_import(
            "fake_quantize",
            [
                "_is_per_channel",
                "_is_per_tensor",
                "_is_symmetric_quant",
                "FakeQuantizeBase",
                "FakeQuantize",
                "FixedQParamsFakeQuantize",
                "FusedMovingAvgObsFakeQuantize",
                "default_fake_quant",
                "default_weight_fake_quant",
                "default_fixed_qparams_range_neg1to1_fake_quant",
                "default_fixed_qparams_range_0to1_fake_quant",
                "default_per_channel_weight_fake_quant",
                "default_histogram_fake_quant",
                "default_fused_act_fake_quant",
                "default_fused_wt_fake_quant",
                "default_fused_per_channel_wt_fake_quant",
                "_is_fake_quant_script_module",
                "disable_fake_quant",
                "enable_fake_quant",
                "disable_observer",
                "enable_observer",
            ],
        )

    def test_function_import_fuse_modules(self):
        # Module-fusion utilities in `fuse_modules`.
        self._test_function_import(
            "fuse_modules",
            [
                "_fuse_modules",
                "_get_module",
                "_set_module",
                "fuse_conv_bn",
                "fuse_conv_bn_relu",
                "fuse_known_modules",
                "fuse_modules",
                "get_fuser_method",
            ],
        )

    def test_function_import_quant_type(self):
        # Quantization-type enum and its string helper.
        self._test_function_import(
            "quant_type",
            ["QuantType", "_get_quant_type_to_str"],
        )

    def test_function_import_observer(self):
        # Observer classes, helpers, state-dict utilities and defaults.
        self._test_function_import(
            "observer",
            [
                "_PartialWrapper",
                "_with_args",
                "_with_callable_args",
                "ABC",
                "ObserverBase",
                "_ObserverBase",
                "MinMaxObserver",
                "MovingAverageMinMaxObserver",
                "PerChannelMinMaxObserver",
                "MovingAveragePerChannelMinMaxObserver",
                "HistogramObserver",
                "PlaceholderObserver",
                "RecordingObserver",
                "NoopObserver",
                "_is_activation_post_process",
                "_is_per_channel_script_obs_instance",
                "get_observer_state_dict",
                "load_observer_state_dict",
                "default_observer",
                "default_placeholder_observer",
                "default_debug_observer",
                "default_weight_observer",
                "default_histogram_observer",
                "default_per_channel_weight_observer",
                "default_dynamic_quant_observer",
                "default_float_qparams_observer",
            ],
        )

    def test_function_import_qconfig(self):
        # QConfig types, predefined configs and qconfig helpers.
        self._test_function_import(
            "qconfig",
            [
                "QConfig",
                "default_qconfig",
                "default_debug_qconfig",
                "default_per_channel_qconfig",
                "QConfigDynamic",
                "default_dynamic_qconfig",
                "float16_dynamic_qconfig",
                "float16_static_qconfig",
                "per_channel_dynamic_qconfig",
                "float_qparams_weight_only_qconfig",
                "default_qat_qconfig",
                "default_weight_only_qconfig",
                "default_activation_only_qconfig",
                "default_qat_qconfig_v2",
                "get_default_qconfig",
                "get_default_qat_qconfig",
                "_assert_valid_qconfig",
                "QConfigAny",
                "_add_module_to_qconfig_obs_ctr",
                "qconfig_equals",
            ],
        )

    def test_function_import_quantization_mappings(self):
        # Mapping getters first, then the mapping tables themselves.
        self._test_function_import(
            "quantization_mappings",
            [
                "no_observer_set",
                "get_default_static_quant_module_mappings",
                "get_static_quant_module_class",
                "get_dynamic_quant_module_class",
                "get_default_qat_module_mappings",
                "get_default_dynamic_quant_module_mappings",
                "get_default_qconfig_propagation_list",
                "get_default_compare_output_module_list",
                "get_default_float_to_quantized_operator_mappings",
                "get_quantized_operator",
                "_get_special_act_post_process",
                "_has_special_act_post_process",
            ],
        )
        self._test_dict_import(
            "quantization_mappings",
            [
                "DEFAULT_REFERENCE_STATIC_QUANT_MODULE_MAPPINGS",
                "DEFAULT_STATIC_QUANT_MODULE_MAPPINGS",
                "DEFAULT_QAT_MODULE_MAPPINGS",
                "DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS",
                # "_INCLUDE_QCONFIG_PROPAGATE_LIST" is deliberately left
                # out here, as in the original list.
                "DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS",
                "DEFAULT_MODULE_TO_ACT_POST_PROCESS",
            ],
        )

    def test_function_import_fuser_method_mappings(self):
        # Fuser functions, then the op-list -> fuser lookup table.
        self._test_function_import(
            "fuser_method_mappings",
            [
                "fuse_conv_bn",
                "fuse_conv_bn_relu",
                "fuse_linear_bn",
                "get_fuser_method",
            ],
        )
        self._test_dict_import(
            "fuser_method_mappings",
            ["_DEFAULT_OP_LIST_TO_FUSER_METHOD"],
        )

    def test_function_import_utils(self):
        # General-purpose helpers in `utils`.
        self._test_function_import(
            "utils",
            [
                "activation_dtype",
                "activation_is_int8_quantized",
                "activation_is_statically_quantized",
                "calculate_qmin_qmax",
                "check_min_max_valid",
                "get_combined_dict",
                "get_qconfig_dtypes",
                "get_qparam_dict",
                "get_quant_type",
                "get_swapped_custom_module_class",
                "getattr_from_fqn",
                "is_per_channel",
                "is_per_tensor",
                "weight_dtype",
                "weight_is_quantized",
                "weight_is_statically_quantized",
            ],
        )
|
|
|
|
|
|
if __name__ == "__main__":
    # This module is not meant to be executed on its own. The shared
    # raise_on_run_directly helper reports that, and names the file the
    # user should run instead (test/test_quantization.py).
    raise_on_run_directly("test/test_quantization.py")
|