mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/76637

The previous names `default_affine_fixed_qparams_observer` and `default_symmetric_fixed_qparams_observer` were uninformative: users had to read the definitions to understand what these observers are. The new naming convention reveals the range each observer covers. The analogous renames were made for `default_symmetric_fixed_qparams_fake_quant` and `default_affine_fixed_qparams_fake_quant`.

Test Plan:
```
python test/test_quantization.py
```

Differential Revision: D36054169
Reviewed By: vkuzo
Pulled By: dzdang
fbshipit-source-id: 215f7786a4b7abda7327f17cc61735697ec5cca9
(cherry picked from commit 21a4e6eda4467c8adca7fd534a506a14e975f9cf)
64 lines
2.6 KiB
Python
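The commit above renames the fixed-qparams observers and fake-quants so that the quantization range is visible in the name. As a rough sketch of how the renamed fake-quant defaults might be exercised (assuming a PyTorch build that includes this commit and that this package is importable as `torch.quantization`; the old-to-new mapping below is inferred from the commit summary, not re-verified):

```python
import torch
from torch.quantization import (
    default_fixed_qparams_range_neg1to1_fake_quant,  # formerly default_symmetric_fixed_qparams_fake_quant
    default_fixed_qparams_range_0to1_fake_quant,     # formerly default_affine_fixed_qparams_fake_quant
)

# Each default is a constructor partial; calling it builds a FakeQuantize module
# whose quantization parameters are fixed for the range named in the identifier.
fq_neg1to1 = default_fixed_qparams_range_neg1to1_fake_quant()
fq_0to1 = default_fixed_qparams_range_0to1_fake_quant()

x = torch.randn(8)
print(fq_neg1to1(x))  # simulated quantization for an output range of [-1, 1]
print(fq_0to1(x))     # simulated quantization for an output range of [0, 1]
```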
from .quantize import *  # noqa: F403
from .observer import *  # noqa: F403
from .qconfig import *  # noqa: F403
from .fake_quantize import *  # noqa: F403
from .fuse_modules import fuse_modules
from .stubs import *  # noqa: F403
from .quant_type import *  # noqa: F403
from .quantize_jit import *  # noqa: F403
# from .quantize_fx import *
from .quantization_mappings import *  # noqa: F403
from .fuser_method_mappings import *  # noqa: F403

def default_eval_fn(model, calib_data):
    r"""
    Default evaluation function: takes a torch.utils.data.Dataset or a list of
    input Tensors and runs the model on the dataset.
    """
    for data, target in calib_data:
        model(data)

# TODO(future PR): fix the typo, should be `__all__`
_all__ = [
    'QuantWrapper', 'QuantStub', 'DeQuantStub',
    # Top level API for eager mode quantization
    'quantize', 'quantize_dynamic', 'quantize_qat',
    'prepare', 'convert', 'prepare_qat',
    # Top level API for graph mode quantization on TorchScript
    'quantize_jit', 'quantize_dynamic_jit',
    # Top level API for graph mode quantization on GraphModule(torch.fx)
    # 'fuse_fx', 'quantize_fx',  # TODO: add quantize_dynamic_fx
    # 'prepare_fx', 'prepare_dynamic_fx', 'convert_fx',
    'QuantType', 'quant_type_to_str',  # quantization type
    # custom module APIs
    'get_default_static_quant_module_mappings', 'get_static_quant_module_class',
    'get_default_dynamic_quant_module_mappings',
    'get_default_qat_module_mappings',
    'get_default_qconfig_propagation_list',
    'get_default_compare_output_module_list',
    'get_quantized_operator',
    'get_fuser_method',
    # Sub functions for `prepare` and `swap_module`
    'propagate_qconfig_', 'add_quant_dequant', 'add_observer_', 'swap_module',
    'default_eval_fn', 'get_observer_dict',
    'register_activation_post_process_hook',
    # Observers
    'ObserverBase', 'WeightObserver', 'HistogramObserver',
    'observer', 'default_observer',
    'default_weight_observer', 'default_placeholder_observer',
    'default_per_channel_weight_observer',
    # FakeQuantize (for qat)
    'default_fake_quant', 'default_weight_fake_quant',
    'default_fixed_qparams_range_neg1to1_fake_quant',
    'default_fixed_qparams_range_0to1_fake_quant',
    'default_per_channel_weight_fake_quant',
    'default_histogram_fake_quant',
    # QConfig
    'QConfig', 'default_qconfig', 'default_dynamic_qconfig', 'float16_dynamic_qconfig',
    'float_qparams_weight_only_qconfig',
    # QAT utilities
    'default_qat_qconfig', 'prepare_qat', 'quantize_qat',
    # module transformations
    'fuse_modules',
]
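For reference, a minimal sketch of how the eager-mode post-training static quantization entry points exported here (`prepare`, `convert`, `QuantStub`, `DeQuantStub`) combine with `default_eval_fn` for calibration. The toy model and calibration data below are made up for illustration, and the example assumes this package is importable as `torch.quantization` (on newer builds the same names live under `torch.ao.quantization`).

```python
import torch
import torch.nn as nn
import torch.quantization as tq

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = tq.QuantStub()      # float -> quantized boundary
        self.fc = nn.Linear(16, 4)
        self.dequant = tq.DeQuantStub()  # quantized -> float boundary

    def forward(self, x):
        return self.dequant(self.fc(self.quant(x)))

model = ToyModel().eval()
model.qconfig = tq.default_qconfig   # observers for activations and weights
prepared = tq.prepare(model)         # insert observers (returns a prepared copy)

# Calibration: default_eval_fn simply iterates (data, target) pairs and runs the model.
calib_data = [(torch.randn(2, 16), torch.zeros(2)) for _ in range(8)]
tq.default_eval_fn(prepared, calib_data)

quantized = tq.convert(prepared)     # swap observed modules for quantized ones
print(quantized)
```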
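Similarly, a sketch of the dynamic-quantization entry point listed in `_all__` (`quantize_dynamic`), which quantizes weights ahead of time and computes activation quantization parameters on the fly; again an illustration with a made-up model, not part of the file.

```python
import torch
import torch.nn as nn
import torch.quantization as tq

float_model = nn.Sequential(nn.Linear(16, 8), nn.ReLU(), nn.Linear(8, 4)).eval()

# Swap nn.Linear modules for dynamically quantized equivalents with qint8 weights.
dq_model = tq.quantize_dynamic(float_model, {nn.Linear}, dtype=torch.qint8)

print(dq_model(torch.randn(2, 16)))
```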