Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[ao] quantize.py fixing public v private (#87521)
Summary: made _register_activation_post_process_hook, _add_observer_,
_get_unique_devices_, _get_observer_dict private

Test Plan: python test/test_public_bindings.py

Differential Revision: [D40709277](https://our.internmc.facebook.com/intern/diff/D40709277)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/87521
Approved by: https://github.com/jerryzh168
Committed by: PyTorch MergeBot
Parent: 691a44f403
Commit: 1ca9d43d4e
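The rename is mechanical: four helpers in torch.ao.quantization.quantize keep their behavior but gain a leading underscore. A minimal sketch of the import change, assuming a PyTorch build that includes this commit (the old spellings are exactly the ones removed in the hunks below):

    # Old public names, removed from the public surface by this commit:
    #   add_observer_, get_observer_dict, get_unique_devices_,
    #   register_activation_post_process_hook

    # New private names, as imported by the updated tests and the
    # import-forwarding module at the end of this diff:
    from torch.ao.quantization.quantize import (
        _add_observer_,
        _get_observer_dict,
        _get_unique_devices_,
        _register_activation_post_process_hook,
    )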
@@ -45,11 +45,9 @@ Utility functions
     :nosignatures:
     :template: classtemplate.rst
 
-    add_observer_
     swap_module
     propagate_qconfig_
     default_eval_fn
-    get_observer_dict
 
 torch.quantization.quantize_fx
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -737,7 +737,6 @@
     "QuantWrapper",
     "RecordingObserver",
     "_add_module_to_qconfig_obs_ctr",
-    "add_observer_",
     "add_quant_dequant",
     "_assert_valid_qconfig",
     "convert",
@@ -781,11 +780,9 @@
     "get_default_static_quant_module_mappings",
     "get_dynamic_quant_module_class",
     "get_fuser_method",
-    "get_observer_dict",
     "get_observer_state_dict",
     "get_quantized_operator",
     "get_static_quant_module_class",
-    "get_unique_devices_",
     "is_activation_post_process",
     "load_observer_state_dict",
     "no_observer_set",
@@ -801,7 +798,6 @@
     "quantize_dynamic_jit",
     "quantize_jit",
     "quantize_qat",
-    "register_activation_post_process_hook",
     "script_qconfig",
     "script_qconfig_dict",
     "swap_module"
@@ -889,11 +885,8 @@
     "no_observer_set"
   ],
   "torch.quantization.quantize": [
-    "add_observer_",
     "add_quant_dequant",
     "convert",
-    "get_observer_dict",
-    "get_unique_devices_",
     "is_activation_post_process",
     "prepare",
     "prepare_qat",
@@ -901,7 +894,6 @@
     "quantize",
     "quantize_dynamic",
     "quantize_qat",
-    "register_activation_post_process_hook",
     "swap_module"
   ],
   "torch.quantization.quantize_jit": [
@@ -14,11 +14,11 @@ class TestAOMigrationQuantization(AOMigrationTestCase):
             '_propagate_qconfig_helper',
             '_remove_activation_post_process',
             '_remove_qconfig',
-            'add_observer_',
+            '_add_observer_',
             'add_quant_dequant',
             'convert',
-            'get_observer_dict',
-            'get_unique_devices_',
+            '_get_observer_dict',
+            '_get_unique_devices_',
             'is_activation_post_process',
             'prepare',
             'prepare_qat',
@@ -26,7 +26,7 @@ class TestAOMigrationQuantization(AOMigrationTestCase):
             'quantize',
             'quantize_dynamic',
             'quantize_qat',
-            'register_activation_post_process_hook',
+            '_register_activation_post_process_hook',
             'swap_module',
         ]
         self._test_function_import('quantize', function_list)
@@ -17,11 +17,11 @@ class TestAOMigrationQuantization(AOMigrationTestCase):
             '_propagate_qconfig_helper',
             '_remove_activation_post_process',
             '_remove_qconfig',
-            'add_observer_',
+            '_add_observer_',
             'add_quant_dequant',
             'convert',
-            'get_observer_dict',
-            'get_unique_devices_',
+            '_get_observer_dict',
+            '_get_unique_devices_',
             'is_activation_post_process',
             'prepare',
             'prepare_qat',
@@ -29,7 +29,7 @@ class TestAOMigrationQuantization(AOMigrationTestCase):
             'quantize',
             'quantize_dynamic',
             'quantize_qat',
-            'register_activation_post_process_hook',
+            '_register_activation_post_process_hook',
             'swap_module',
         ]
         self._test_function_import('quantize', function_list)
@ -17,7 +17,6 @@ from torch.ao.quantization import (
|
||||
default_observer,
|
||||
default_histogram_observer,
|
||||
default_per_channel_weight_observer,
|
||||
get_observer_dict,
|
||||
prepare,
|
||||
prepare_qat,
|
||||
convert,
|
||||
@ -26,6 +25,7 @@ from torch.ao.quantization import (
|
||||
get_embedding_qat_module_mappings,
|
||||
get_embedding_static_quant_module_mappings,
|
||||
)
|
||||
from torch.ao.quantization.quantize import _get_observer_dict
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
@ -566,7 +566,7 @@ class TestRecordHistogramObserver(QuantizationTestCase):
|
||||
test_only_eval_fn(model, self.calib_data)
|
||||
test_only_eval_fn(model, self.calib_data)
|
||||
observer_dict = {}
|
||||
get_observer_dict(model, observer_dict)
|
||||
_get_observer_dict(model, observer_dict)
|
||||
|
||||
self.assertTrue('fc1.module.activation_post_process' in observer_dict.keys(),
|
||||
'observer is not recorded in the dict')
|
||||
|
||||
@@ -40,7 +40,6 @@ __all__ = [
     "RecordingObserver",
     "ReuseInputObserver",
     "UniformQuantizationObserverBase",
-    "add_observer_",
     "add_quant_dequant",
     "convert",
     "convert_dynamic_jit",
@@ -109,11 +108,9 @@ __all__ = [
     "get_embedding_static_quant_module_mappings",
     "get_fuser_method",
     "get_fuser_method_new",
-    "get_observer_dict",
     "get_observer_state_dict",
     "get_quantized_operator",
     "get_static_quant_module_class",
-    "get_unique_devices_",
     "is_activation_post_process",
     "load_observer_state_dict",
     "no_observer_set",
@@ -129,7 +126,6 @@ __all__ = [
     "quantize_dynamic_jit",
     "quantize_jit",
     "quantize_qat",
-    "register_activation_post_process_hook",
     "script_qconfig",
     "script_qconfig_dict",
     "swap_module",
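Dropping the old names from the package-level __all__ (previous hunks) removes them from the star-import and documented public surface, while the implementations stay reachable under their private spellings. A quick check, assuming a build with this commit applied:

    import torch.ao.quantization as tq

    # The public export list no longer advertises the renamed helpers.
    assert "add_observer_" not in tq.__all__
    assert "get_observer_dict" not in tq.__all__

    # The implementations are still importable under the private names.
    from torch.ao.quantization.quantize import _add_observer_, _get_observer_dict  # noqa: F401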
@@ -32,9 +32,6 @@ __all__ = [
     "get_default_custom_config_dict",
     "is_activation_post_process",
     "propagate_qconfig_",
-    "register_activation_post_process_hook",
-    "add_observer_",
-    "get_unique_devices_",
     "add_quant_dequant",
     "prepare",
     "quantize",
@@ -43,7 +40,6 @@ __all__ = [
     "quantize_qat",
     "convert",
     "swap_module",
-    "get_observer_dict",
 ]
 
 _DEFAULT_CUSTOM_CONFIG_DICT = {
@@ -139,7 +135,7 @@ def _observer_forward_pre_hook(self, input):
     """
     return self.activation_post_process(input[0])
 
-def register_activation_post_process_hook(module, pre_hook=False):
+def _register_activation_post_process_hook(module, pre_hook=False):
     assert hasattr(module, 'activation_post_process'), \
         'Expect activation_post_process attribute already attached to the module'
     if pre_hook:
@@ -152,7 +148,7 @@ def register_activation_post_process_hook(module, pre_hook=False):
         )
 
 
-def add_observer_(module, qconfig_propagation_list=None, non_leaf_module_list=None, device=None, custom_module_class_mapping=None):
+def _add_observer_(module, qconfig_propagation_list=None, non_leaf_module_list=None, device=None, custom_module_class_mapping=None):
     r"""Add observer for the leaf child of the module.
 
     This function insert observer module to all leaf child module that
@@ -176,9 +172,9 @@ def add_observer_(module, qconfig_propagation_list=None, non_leaf_module_list=No
 
     # respect device affinity when adding observers
     if device is None:
-        devices = get_unique_devices_(module)
+        devices = _get_unique_devices_(module)
         assert len(devices) <= 1, (
-            "add_observer_ only works with cpu or single-device CUDA modules, "
+            "_add_observer_ only works with cpu or single-device CUDA modules, "
             "but got devices {}".format(devices)
         )
         device = next(iter(devices)) if len(devices) > 0 else None
@@ -203,7 +199,7 @@ def add_observer_(module, qconfig_propagation_list=None, non_leaf_module_list=No
                 m.qconfig, device, special_act_post_process))
             # Register observer as the first entry in the hook list
            # All post forward hooks are preserved and will be executed after the observer before convert
-            register_activation_post_process_hook(m, pre_hook=_activation_is_memoryless(m.qconfig))
+            _register_activation_post_process_hook(m, pre_hook=_activation_is_memoryless(m.qconfig))
 
     for name, child in module.named_children():
         # TODO remove Dropout special after codebase stable
@@ -230,7 +226,7 @@ def add_observer_(module, qconfig_propagation_list=None, non_leaf_module_list=No
             if custom_module_class_mapping[type_before_parametrizations(child)] not in no_observer_set():
                 insert_activation_post_process(observed_child)
         else:
-            add_observer_(child, qconfig_propagation_list, non_leaf_module_list, device, custom_module_class_mapping)
+            _add_observer_(child, qconfig_propagation_list, non_leaf_module_list, device, custom_module_class_mapping)
 
     # Insert observers only for leaf nodes, note that this observer is for
     # the output of the module, for input QuantStub will observe them
@@ -238,7 +234,7 @@ def add_observer_(module, qconfig_propagation_list=None, non_leaf_module_list=No
             and type_before_parametrizations(module) in qconfig_propagation_list:
         insert_activation_post_process(module)
 
-def get_unique_devices_(module):
+def _get_unique_devices_(module):
     return {p.device for p in module.parameters()} | \
         {p.device for p in module.buffers()}
 
@@ -315,7 +311,7 @@ def prepare(model, inplace=False, allow_list=None,
                       "passed correct configuration through `qconfig_dict` or "
                       "by assigning the `.qconfig` attribute directly on submodules")
 
-    add_observer_(
+    _add_observer_(
         model, qconfig_propagation_list, observer_non_leaf_module_list,
         custom_module_class_mapping=custom_module_class_mapping)
     return model
@@ -639,7 +635,7 @@ def swap_module(mod, mapping, custom_module_class_mapping):
                 new_mod.register_forward_hook(hook_fn)
 
         # respect device affinity when swapping modules
-        devices = get_unique_devices_(mod)
+        devices = _get_unique_devices_(mod)
         assert len(devices) <= 1, (
             "swap_module only works with cpu or single-device CUDA modules, "
             "but got devices {}".format(devices)
@@ -649,7 +645,7 @@ def swap_module(mod, mapping, custom_module_class_mapping):
         new_mod.to(device)
     return new_mod
 
-def get_observer_dict(mod, target_dict, prefix=""):
+def _get_observer_dict(mod, target_dict, prefix=""):
     r"""Traverse the modules and save all observers into dict.
     This is mainly used for quantization accuracy debug
     Args:
@@ -664,4 +660,4 @@ def get_observer_dict(mod, target_dict, prefix=""):
         target_dict[get_prefix(prefix) + 'activation_post_process'] = mod.activation_post_process
     for name, child in mod.named_children():
         module_prefix = get_prefix(prefix) + name if prefix else name
-        get_observer_dict(child, target_dict, module_prefix)
+        _get_observer_dict(child, target_dict, module_prefix)
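_get_observer_dict keeps its old behavior: it walks the module tree and collects every activation_post_process observer keyed by its dotted prefix, which is what the accuracy-debugging workflow relies on. A minimal end-to-end sketch assuming this commit; the toy model and calibration tensor are placeholders:

    import torch
    from torch.ao.quantization import get_default_qconfig, prepare
    from torch.ao.quantization.quantize import _get_observer_dict

    # Placeholder eager-mode model with a qconfig attached at the top level.
    model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())
    model.qconfig = get_default_qconfig("fbgemm")

    prepared = prepare(model)      # prepare() calls _add_observer_ internally (see hunk above)
    prepared(torch.randn(2, 4))    # run calibration data through the inserted observers

    observer_dict = {}
    _get_observer_dict(prepared, observer_dict)   # traverse modules and collect observers
    print(sorted(observer_dict.keys()))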
@@ -625,9 +625,11 @@ def _get_lstm_with_individually_observed_parts(
         cell.initial_hidden_state_qparams = (obs.scale, obs.zero_point)
         cell.hidden_state_dtype = obs.dtype
 
+    # need to do this here to avoid circular dependency
+    from torch.ao.quantization.quantize import _add_observer_
     # Insert the observers based on the previously attached QConfigs
     # Pass in non_leaf_module_list to prevent the observers for sigmoid/tanh from being overridden
-    torch.ao.quantization.add_observer_(
+    _add_observer_(  # type: ignore[attr-defined]
         observed_lstm,
         non_leaf_module_list=[torch.nn.Sigmoid, torch.nn.Tanh]
     )
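The call site above also switches from the removed torch.ao.quantization.add_observer_ attribute to a function-local import of _add_observer_; deferring the import to call time keeps the fx LSTM utilities and quantize.py from forming a module-level import cycle. A generic sketch of that deferred-import pattern, using hypothetical modules a.py and b.py (save as two files to run):

    # a.py (hypothetical)
    import b  # safe at module level: b only imports a when use_helper() runs

    def helper():
        return 1

    # b.py (hypothetical)
    def use_helper():
        # Importing inside the function defers resolution to call time,
        # breaking the module-level cycle between a and b.
        from a import helper
        return helper()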
@@ -40,9 +40,8 @@ _all__ = [
     'get_quantized_operator',
     'get_fuser_method',
     # Sub functions for `prepare` and `swap_module`
-    'propagate_qconfig_', 'add_quant_dequant', 'add_observer_', 'swap_module',
-    'default_eval_fn', 'get_observer_dict',
-    'register_activation_post_process_hook',
+    'propagate_qconfig_', 'add_quant_dequant', 'swap_module',
+    'default_eval_fn',
     # Observers
     'ObserverBase', 'WeightObserver', 'HistogramObserver',
     'observer', 'default_observer',
@@ -12,11 +12,11 @@ from torch.ao.quantization.quantize import _observer_forward_hook
 from torch.ao.quantization.quantize import _propagate_qconfig_helper
 from torch.ao.quantization.quantize import _remove_activation_post_process
 from torch.ao.quantization.quantize import _remove_qconfig
-from torch.ao.quantization.quantize import add_observer_
+from torch.ao.quantization.quantize import _add_observer_
 from torch.ao.quantization.quantize import add_quant_dequant
 from torch.ao.quantization.quantize import convert
-from torch.ao.quantization.quantize import get_observer_dict
-from torch.ao.quantization.quantize import get_unique_devices_
+from torch.ao.quantization.quantize import _get_observer_dict
+from torch.ao.quantization.quantize import _get_unique_devices_
 from torch.ao.quantization.quantize import is_activation_post_process
 from torch.ao.quantization.quantize import prepare
 from torch.ao.quantization.quantize import prepare_qat
@@ -24,5 +24,5 @@ from torch.ao.quantization.quantize import propagate_qconfig_
 from torch.ao.quantization.quantize import quantize
 from torch.ao.quantization.quantize import quantize_dynamic
 from torch.ao.quantization.quantize import quantize_qat
-from torch.ao.quantization.quantize import register_activation_post_process_hook
+from torch.ao.quantization.quantize import _register_activation_post_process_hook
 from torch.ao.quantization.quantize import swap_module