Files
pytorch/test/quantization/ao_migration/test_quantization.py
Anthony Barbier 954ce94950 Add __main__ guards to quantization tests (#154728)
This PR is part of a series attempting to re-submit https://github.com/pytorch/pytorch/pull/134592 as smaller PRs.

In quantization tests:

- Add and use a common raise_on_run_directly method for when a user directly runs a test file that is not meant to be run that way. It prints the file the user should have run instead.
- Raise a RuntimeError for tests that have been disabled (not run).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154728
Approved by: https://github.com/ezyang
2025-06-10 19:46:07 +00:00

228 lines
7.8 KiB
Python

# Owner(s): ["oncall: quantization"]
from torch.testing._internal.common_utils import raise_on_run_directly
from .common import AOMigrationTestCase
class TestAOMigrationQuantization(AOMigrationTestCase):
    r"""Import checks for the ``torch/quantization`` -> ``torch/ao/quantization``
    migration.

    Every test names one submodule and the attributes expected to be
    importable from it. The actual checking is delegated to the
    ``_test_function_import`` / ``_test_dict_import`` helpers inherited from
    ``AOMigrationTestCase`` (defined in ``.common``, not shown here).
    """

    def test_function_import_quantize(self):
        # Public entry points plus the private helpers of ``quantize``.
        self._test_function_import(
            "quantize",
            [
                "_convert",
                "_observer_forward_hook",
                "_propagate_qconfig_helper",
                "_remove_activation_post_process",
                "_remove_qconfig",
                "_add_observer_",
                "add_quant_dequant",
                "convert",
                "_get_observer_dict",
                "_get_unique_devices_",
                "_is_activation_post_process",
                "prepare",
                "prepare_qat",
                "propagate_qconfig_",
                "quantize",
                "quantize_dynamic",
                "quantize_qat",
                "_register_activation_post_process_hook",
                "swap_module",
            ],
        )

    def test_function_import_stubs(self):
        # Stub / wrapper modules.
        self._test_function_import(
            "stubs",
            ["QuantStub", "DeQuantStub", "QuantWrapper"],
        )

    def test_function_import_quantize_jit(self):
        # TorchScript-based quantization APIs.
        self._test_function_import(
            "quantize_jit",
            [
                "_check_is_script_module",
                "_check_forward_method",
                "script_qconfig",
                "script_qconfig_dict",
                "fuse_conv_bn_jit",
                "_prepare_jit",
                "prepare_jit",
                "prepare_dynamic_jit",
                "_convert_jit",
                "convert_jit",
                "convert_dynamic_jit",
                "_quantize_jit",
                "quantize_jit",
                "quantize_dynamic_jit",
            ],
        )

    def test_function_import_fake_quantize(self):
        # Fake-quantize classes, their default instances, and toggles.
        self._test_function_import(
            "fake_quantize",
            [
                "_is_per_channel",
                "_is_per_tensor",
                "_is_symmetric_quant",
                "FakeQuantizeBase",
                "FakeQuantize",
                "FixedQParamsFakeQuantize",
                "FusedMovingAvgObsFakeQuantize",
                "default_fake_quant",
                "default_weight_fake_quant",
                "default_fixed_qparams_range_neg1to1_fake_quant",
                "default_fixed_qparams_range_0to1_fake_quant",
                "default_per_channel_weight_fake_quant",
                "default_histogram_fake_quant",
                "default_fused_act_fake_quant",
                "default_fused_wt_fake_quant",
                "default_fused_per_channel_wt_fake_quant",
                "_is_fake_quant_script_module",
                "disable_fake_quant",
                "enable_fake_quant",
                "disable_observer",
                "enable_observer",
            ],
        )

    def test_function_import_fuse_modules(self):
        # Module-fusion entry points and helpers.
        self._test_function_import(
            "fuse_modules",
            [
                "_fuse_modules",
                "_get_module",
                "_set_module",
                "fuse_conv_bn",
                "fuse_conv_bn_relu",
                "fuse_known_modules",
                "fuse_modules",
                "get_fuser_method",
            ],
        )

    def test_function_import_quant_type(self):
        # Quantization-type enum and its string helper.
        self._test_function_import(
            "quant_type",
            ["QuantType", "_get_quant_type_to_str"],
        )

    def test_function_import_observer(self):
        # Observer classes, state-dict helpers and default observer instances.
        self._test_function_import(
            "observer",
            [
                "_PartialWrapper",
                "_with_args",
                "_with_callable_args",
                "ABC",
                "ObserverBase",
                "_ObserverBase",
                "MinMaxObserver",
                "MovingAverageMinMaxObserver",
                "PerChannelMinMaxObserver",
                "MovingAveragePerChannelMinMaxObserver",
                "HistogramObserver",
                "PlaceholderObserver",
                "RecordingObserver",
                "NoopObserver",
                "_is_activation_post_process",
                "_is_per_channel_script_obs_instance",
                "get_observer_state_dict",
                "load_observer_state_dict",
                "default_observer",
                "default_placeholder_observer",
                "default_debug_observer",
                "default_weight_observer",
                "default_histogram_observer",
                "default_per_channel_weight_observer",
                "default_dynamic_quant_observer",
                "default_float_qparams_observer",
            ],
        )

    def test_function_import_qconfig(self):
        # QConfig types, predefined qconfigs and qconfig helpers.
        self._test_function_import(
            "qconfig",
            [
                "QConfig",
                "default_qconfig",
                "default_debug_qconfig",
                "default_per_channel_qconfig",
                "QConfigDynamic",
                "default_dynamic_qconfig",
                "float16_dynamic_qconfig",
                "float16_static_qconfig",
                "per_channel_dynamic_qconfig",
                "float_qparams_weight_only_qconfig",
                "default_qat_qconfig",
                "default_weight_only_qconfig",
                "default_activation_only_qconfig",
                "default_qat_qconfig_v2",
                "get_default_qconfig",
                "get_default_qat_qconfig",
                "_assert_valid_qconfig",
                "QConfigAny",
                "_add_module_to_qconfig_obs_ctr",
                "qconfig_equals",
            ],
        )

    def test_function_import_quantization_mappings(self):
        # Functions are checked first, then the module-level mapping dicts.
        self._test_function_import(
            "quantization_mappings",
            [
                "no_observer_set",
                "get_default_static_quant_module_mappings",
                "get_static_quant_module_class",
                "get_dynamic_quant_module_class",
                "get_default_qat_module_mappings",
                "get_default_dynamic_quant_module_mappings",
                "get_default_qconfig_propagation_list",
                "get_default_compare_output_module_list",
                "get_default_float_to_quantized_operator_mappings",
                "get_quantized_operator",
                "_get_special_act_post_process",
                "_has_special_act_post_process",
            ],
        )
        self._test_dict_import(
            "quantization_mappings",
            [
                "DEFAULT_REFERENCE_STATIC_QUANT_MODULE_MAPPINGS",
                "DEFAULT_STATIC_QUANT_MODULE_MAPPINGS",
                "DEFAULT_QAT_MODULE_MAPPINGS",
                "DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS",
                # "_INCLUDE_QCONFIG_PROPAGATE_LIST" is deliberately left out of
                # this dict check.
                "DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS",
                "DEFAULT_MODULE_TO_ACT_POST_PROCESS",
            ],
        )

    def test_function_import_fuser_method_mappings(self):
        # Fuser functions first, then the op-list -> fuser-method dict.
        self._test_function_import(
            "fuser_method_mappings",
            [
                "fuse_conv_bn",
                "fuse_conv_bn_relu",
                "fuse_linear_bn",
                "get_fuser_method",
            ],
        )
        self._test_dict_import(
            "fuser_method_mappings",
            ["_DEFAULT_OP_LIST_TO_FUSER_METHOD"],
        )

    def test_function_import_utils(self):
        # Assorted dtype / qparam / lookup utilities.
        self._test_function_import(
            "utils",
            [
                "activation_dtype",
                "activation_is_int8_quantized",
                "activation_is_statically_quantized",
                "calculate_qmin_qmax",
                "check_min_max_valid",
                "get_combined_dict",
                "get_qconfig_dtypes",
                "get_qparam_dict",
                "get_quant_type",
                "get_swapped_custom_module_class",
                "getattr_from_fqn",
                "is_per_channel",
                "is_per_tensor",
                "weight_dtype",
                "weight_is_quantized",
                "weight_is_statically_quantized",
            ],
        )
if __name__ == "__main__":
    # Running this file on its own is unsupported. raise_on_run_directly
    # reports the file the user should run instead.
    entry_point = "test/test_quantization.py"
    raise_on_run_directly(entry_point)