[ao] quantize.py fixing public vs. private (#87521)

Summary: made register_activation_post_process_hook, add_observer_, get_unique_devices_,
and get_observer_dict private (now _register_activation_post_process_hook, _add_observer_,
_get_unique_devices_, and _get_observer_dict)
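
The rename means external code that imported these helpers by their old public
names has to switch to the underscored names (or, better, stay on the public
prepare/convert APIs). A minimal sketch of the adjustment, based on the
observer-debugging pattern exercised in the test diff below; "model" stands for
an already-prepared and calibrated module and is only illustrative:

    # before this change: from torch.ao.quantization.quantize import get_observer_dict
    from torch.ao.quantization.quantize import _get_observer_dict

    observer_dict = {}                        # collects every attached observer
    _get_observer_dict(model, observer_dict)  # model: a prepared, calibrated module
    # keys look like 'fc1.module.activation_post_process' (see the test diff below)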

Test Plan: python test/test_public_bindings.py
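
As an illustration of the invariant that check enforces (a hedged sketch, not the
actual test logic): after this change the renamed helpers should no longer appear
in the module's public export list.

    import torch.ao.quantization.quantize as q

    # the helpers now carry a leading underscore, so the old public names
    # must be gone from __all__
    for name in ("add_observer_", "get_observer_dict",
                 "get_unique_devices_", "register_activation_post_process_hook"):
        assert name not in q.__all__, f"{name} should no longer be public"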

Differential Revision: [D40709277](https://our.internmc.facebook.com/intern/diff/D40709277)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/87521
Approved by: https://github.com/jerryzh168
Author: HDCharles
Date: 2022-12-13 17:00:11 -08:00
Committed by: PyTorch MergeBot
Parent: 691a44f403
Commit: 1ca9d43d4e
10 changed files with 30 additions and 47 deletions

View File

@@ -45,11 +45,9 @@ Utility functions
:nosignatures:
:template: classtemplate.rst
add_observer_
swap_module
propagate_qconfig_
default_eval_fn
get_observer_dict
torch.quantization.quantize_fx
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

View File

@@ -737,7 +737,6 @@
"QuantWrapper",
"RecordingObserver",
"_add_module_to_qconfig_obs_ctr",
"add_observer_",
"add_quant_dequant",
"_assert_valid_qconfig",
"convert",
@@ -781,11 +780,9 @@
"get_default_static_quant_module_mappings",
"get_dynamic_quant_module_class",
"get_fuser_method",
"get_observer_dict",
"get_observer_state_dict",
"get_quantized_operator",
"get_static_quant_module_class",
"get_unique_devices_",
"is_activation_post_process",
"load_observer_state_dict",
"no_observer_set",
@@ -801,7 +798,6 @@
"quantize_dynamic_jit",
"quantize_jit",
"quantize_qat",
"register_activation_post_process_hook",
"script_qconfig",
"script_qconfig_dict",
"swap_module"
@@ -889,11 +885,8 @@
"no_observer_set"
],
"torch.quantization.quantize": [
"add_observer_",
"add_quant_dequant",
"convert",
"get_observer_dict",
"get_unique_devices_",
"is_activation_post_process",
"prepare",
"prepare_qat",
@@ -901,7 +894,6 @@
"quantize",
"quantize_dynamic",
"quantize_qat",
"register_activation_post_process_hook",
"swap_module"
],
"torch.quantization.quantize_jit": [

View File

@@ -14,11 +14,11 @@ class TestAOMigrationQuantization(AOMigrationTestCase):
'_propagate_qconfig_helper',
'_remove_activation_post_process',
'_remove_qconfig',
'add_observer_',
'_add_observer_',
'add_quant_dequant',
'convert',
'get_observer_dict',
'get_unique_devices_',
'_get_observer_dict',
'_get_unique_devices_',
'is_activation_post_process',
'prepare',
'prepare_qat',
@@ -26,7 +26,7 @@ class TestAOMigrationQuantization(AOMigrationTestCase):
'quantize',
'quantize_dynamic',
'quantize_qat',
'register_activation_post_process_hook',
'_register_activation_post_process_hook',
'swap_module',
]
self._test_function_import('quantize', function_list)

View File

@@ -17,11 +17,11 @@ class TestAOMigrationQuantization(AOMigrationTestCase):
'_propagate_qconfig_helper',
'_remove_activation_post_process',
'_remove_qconfig',
'add_observer_',
'_add_observer_',
'add_quant_dequant',
'convert',
'get_observer_dict',
'get_unique_devices_',
'_get_observer_dict',
'_get_unique_devices_',
'is_activation_post_process',
'prepare',
'prepare_qat',
@@ -29,7 +29,7 @@ class TestAOMigrationQuantization(AOMigrationTestCase):
'quantize',
'quantize_dynamic',
'quantize_qat',
'register_activation_post_process_hook',
'_register_activation_post_process_hook',
'swap_module',
]
self._test_function_import('quantize', function_list)

View File

@@ -17,7 +17,6 @@ from torch.ao.quantization import (
default_observer,
default_histogram_observer,
default_per_channel_weight_observer,
get_observer_dict,
prepare,
prepare_qat,
convert,
@@ -26,6 +25,7 @@ from torch.ao.quantization import (
get_embedding_qat_module_mappings,
get_embedding_static_quant_module_mappings,
)
from torch.ao.quantization.quantize import _get_observer_dict
import torch.nn as nn
@@ -566,7 +566,7 @@ class TestRecordHistogramObserver(QuantizationTestCase):
test_only_eval_fn(model, self.calib_data)
test_only_eval_fn(model, self.calib_data)
observer_dict = {}
get_observer_dict(model, observer_dict)
_get_observer_dict(model, observer_dict)
self.assertTrue('fc1.module.activation_post_process' in observer_dict.keys(),
'observer is not recorded in the dict')

View File

@@ -40,7 +40,6 @@ __all__ = [
"RecordingObserver",
"ReuseInputObserver",
"UniformQuantizationObserverBase",
"add_observer_",
"add_quant_dequant",
"convert",
"convert_dynamic_jit",
@@ -109,11 +108,9 @@ __all__ = [
"get_embedding_static_quant_module_mappings",
"get_fuser_method",
"get_fuser_method_new",
"get_observer_dict",
"get_observer_state_dict",
"get_quantized_operator",
"get_static_quant_module_class",
"get_unique_devices_",
"is_activation_post_process",
"load_observer_state_dict",
"no_observer_set",
@@ -129,7 +126,6 @@ __all__ = [
"quantize_dynamic_jit",
"quantize_jit",
"quantize_qat",
"register_activation_post_process_hook",
"script_qconfig",
"script_qconfig_dict",
"swap_module",

View File

@@ -32,9 +32,6 @@ __all__ = [
"get_default_custom_config_dict",
"is_activation_post_process",
"propagate_qconfig_",
"register_activation_post_process_hook",
"add_observer_",
"get_unique_devices_",
"add_quant_dequant",
"prepare",
"quantize",
@@ -43,7 +40,6 @@ __all__ = [
"quantize_qat",
"convert",
"swap_module",
"get_observer_dict",
]
_DEFAULT_CUSTOM_CONFIG_DICT = {
@@ -139,7 +135,7 @@ def _observer_forward_pre_hook(self, input):
"""
return self.activation_post_process(input[0])
def register_activation_post_process_hook(module, pre_hook=False):
def _register_activation_post_process_hook(module, pre_hook=False):
assert hasattr(module, 'activation_post_process'), \
'Expect activation_post_process attribute already attached to the module'
if pre_hook:
@@ -152,7 +148,7 @@ def register_activation_post_process_hook(module, pre_hook=False):
)
def add_observer_(module, qconfig_propagation_list=None, non_leaf_module_list=None, device=None, custom_module_class_mapping=None):
def _add_observer_(module, qconfig_propagation_list=None, non_leaf_module_list=None, device=None, custom_module_class_mapping=None):
r"""Add observer for the leaf child of the module.
This function insert observer module to all leaf child module that
@@ -176,9 +172,9 @@ def add_observer_(module, qconfig_propagation_list=None, non_leaf_module_list=No
# respect device affinity when adding observers
if device is None:
devices = get_unique_devices_(module)
devices = _get_unique_devices_(module)
assert len(devices) <= 1, (
"add_observer_ only works with cpu or single-device CUDA modules, "
"_add_observer_ only works with cpu or single-device CUDA modules, "
"but got devices {}".format(devices)
)
device = next(iter(devices)) if len(devices) > 0 else None
@@ -203,7 +199,7 @@ def add_observer_(module, qconfig_propagation_list=None, non_leaf_module_list=No
m.qconfig, device, special_act_post_process))
# Register observer as the first entry in the hook list
# All post forward hooks are preserved and will be executed after the observer before convert
register_activation_post_process_hook(m, pre_hook=_activation_is_memoryless(m.qconfig))
_register_activation_post_process_hook(m, pre_hook=_activation_is_memoryless(m.qconfig))
for name, child in module.named_children():
# TODO remove Dropout special after codebase stable
@@ -230,7 +226,7 @@ def add_observer_(module, qconfig_propagation_list=None, non_leaf_module_list=No
if custom_module_class_mapping[type_before_parametrizations(child)] not in no_observer_set():
insert_activation_post_process(observed_child)
else:
add_observer_(child, qconfig_propagation_list, non_leaf_module_list, device, custom_module_class_mapping)
_add_observer_(child, qconfig_propagation_list, non_leaf_module_list, device, custom_module_class_mapping)
# Insert observers only for leaf nodes, note that this observer is for
# the output of the module, for input QuantStub will observe them
@@ -238,7 +234,7 @@ def add_observer_(module, qconfig_propagation_list=None, non_leaf_module_list=No
and type_before_parametrizations(module) in qconfig_propagation_list:
insert_activation_post_process(module)
def get_unique_devices_(module):
def _get_unique_devices_(module):
return {p.device for p in module.parameters()} | \
{p.device for p in module.buffers()}
@@ -315,7 +311,7 @@ def prepare(model, inplace=False, allow_list=None,
"passed correct configuration through `qconfig_dict` or "
"by assigning the `.qconfig` attribute directly on submodules")
add_observer_(
_add_observer_(
model, qconfig_propagation_list, observer_non_leaf_module_list,
custom_module_class_mapping=custom_module_class_mapping)
return model
@@ -639,7 +635,7 @@ def swap_module(mod, mapping, custom_module_class_mapping):
new_mod.register_forward_hook(hook_fn)
# respect device affinity when swapping modules
devices = get_unique_devices_(mod)
devices = _get_unique_devices_(mod)
assert len(devices) <= 1, (
"swap_module only works with cpu or single-device CUDA modules, "
"but got devices {}".format(devices)
@@ -649,7 +645,7 @@ def swap_module(mod, mapping, custom_module_class_mapping):
new_mod.to(device)
return new_mod
def get_observer_dict(mod, target_dict, prefix=""):
def _get_observer_dict(mod, target_dict, prefix=""):
r"""Traverse the modules and save all observers into dict.
This is mainly used for quantization accuracy debug
Args:
@@ -664,4 +660,4 @@ def get_observer_dict(mod, target_dict, prefix=""):
target_dict[get_prefix(prefix) + 'activation_post_process'] = mod.activation_post_process
for name, child in mod.named_children():
module_prefix = get_prefix(prefix) + name if prefix else name
get_observer_dict(child, target_dict, module_prefix)
_get_observer_dict(child, target_dict, module_prefix)

View File

@@ -625,9 +625,11 @@ def _get_lstm_with_individually_observed_parts(
cell.initial_hidden_state_qparams = (obs.scale, obs.zero_point)
cell.hidden_state_dtype = obs.dtype
# need to do this here to avoid circular dependency
from torch.ao.quantization.quantize import _add_observer_
# Insert the observers based on the previously attached QConfigs
# Pass in non_leaf_module_list to prevent the observers for sigmoid/tanh from being overridden
torch.ao.quantization.add_observer_(
_add_observer_( # type: ignore[attr-defined]
observed_lstm,
non_leaf_module_list=[torch.nn.Sigmoid, torch.nn.Tanh]
)

View File

@@ -40,9 +40,8 @@ _all__ = [
'get_quantized_operator',
'get_fuser_method',
# Sub functions for `prepare` and `swap_module`
'propagate_qconfig_', 'add_quant_dequant', 'add_observer_', 'swap_module',
'default_eval_fn', 'get_observer_dict',
'register_activation_post_process_hook',
'propagate_qconfig_', 'add_quant_dequant', 'swap_module',
'default_eval_fn',
# Observers
'ObserverBase', 'WeightObserver', 'HistogramObserver',
'observer', 'default_observer',

View File

@@ -12,11 +12,11 @@ from torch.ao.quantization.quantize import _observer_forward_hook
from torch.ao.quantization.quantize import _propagate_qconfig_helper
from torch.ao.quantization.quantize import _remove_activation_post_process
from torch.ao.quantization.quantize import _remove_qconfig
from torch.ao.quantization.quantize import add_observer_
from torch.ao.quantization.quantize import _add_observer_
from torch.ao.quantization.quantize import add_quant_dequant
from torch.ao.quantization.quantize import convert
from torch.ao.quantization.quantize import get_observer_dict
from torch.ao.quantization.quantize import get_unique_devices_
from torch.ao.quantization.quantize import _get_observer_dict
from torch.ao.quantization.quantize import _get_unique_devices_
from torch.ao.quantization.quantize import is_activation_post_process
from torch.ao.quantization.quantize import prepare
from torch.ao.quantization.quantize import prepare_qat
@@ -24,5 +24,5 @@ from torch.ao.quantization.quantize import propagate_qconfig_
from torch.ao.quantization.quantize import quantize
from torch.ao.quantization.quantize import quantize_dynamic
from torch.ao.quantization.quantize import quantize_qat
from torch.ao.quantization.quantize import register_activation_post_process_hook
from torch.ao.quantization.quantize import _register_activation_post_process_hook
from torch.ao.quantization.quantize import swap_module