[ao] making _is_activation_post_process private (#87520)

Summary: same function in observer and quantize, consolidated to a
single function. Note the definitions were slightly different, I've
changed the definition to be maximally inclusive so that the name of the
function is more accurate

Test Plan: python test/test_public_bindings.py
python test/test_quantization.py

Reviewers:

Subscribers:

Tasks:

Tags:

Differential Revision: [D40709276](https://our.internmc.facebook.com/intern/diff/D40709276)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/87520
Approved by: https://github.com/jcaip
This commit is contained in:
HDCharles
2022-11-16 10:07:14 -08:00
committed by PyTorch MergeBot
parent aee96bbf5a
commit 45c62a3377
15 changed files with 30 additions and 36 deletions

View File

@ -786,7 +786,7 @@
"get_quantized_operator",
"get_static_quant_module_class",
"get_unique_devices_",
"is_activation_post_process",
"_is_activation_post_process",
"load_observer_state_dict",
"no_observer_set",
"prepare",
@ -894,7 +894,7 @@
"convert",
"get_observer_dict",
"get_unique_devices_",
"is_activation_post_process",
"_is_activation_post_process",
"prepare",
"prepare_qat",
"propagate_qconfig_",

View File

@ -19,7 +19,7 @@ class TestAOMigrationQuantization(AOMigrationTestCase):
'convert',
'get_observer_dict',
'get_unique_devices_',
'is_activation_post_process',
'_is_activation_post_process',
'prepare',
'prepare_qat',
'propagate_qconfig_',

View File

@ -22,7 +22,7 @@ class TestAOMigrationQuantization(AOMigrationTestCase):
'convert',
'get_observer_dict',
'get_unique_devices_',
'is_activation_post_process',
'_is_activation_post_process',
'prepare',
'prepare_qat',
'propagate_qconfig_',

View File

@ -55,7 +55,6 @@ from torch.ao.quantization import (
get_default_qat_qconfig,
get_default_qconfig_mapping,
get_default_qat_qconfig_mapping,
is_activation_post_process,
fuse_modules,
fuse_modules_qat,
prepare,
@ -148,6 +147,7 @@ from torch.ao.quantization.observer import (
default_fixed_qparams_range_0to1_observer,
default_fixed_qparams_range_neg1to1_observer,
MinMaxObserver,
_is_activation_post_process,
)
# test utils
@ -3249,7 +3249,7 @@ class TestQuantizeFx(QuantizationTestCase):
_check_node_not_observed(model, new_node, node)
elif arg_node.op == "call_module":
self.assertTrue(
not is_activation_post_process(getattr(model, arg_node.target)),
not _is_activation_post_process(getattr(model, arg_node.target)),
"Arg: {0} of node: {1} is observed but is not a float tensor".format(
arg_node, node
),
@ -4933,7 +4933,7 @@ class TestQuantizeFx(QuantizationTestCase):
qconfig_dict = func(backend)
m = prepare_fx(m, qconfig_dict, example_inputs=(torch.randn(1, 1, 1, 1)))
for name, mod in m.named_modules():
if is_activation_post_process(mod) and mod.dtype == torch.quint8:
if _is_activation_post_process(mod) and mod.dtype == torch.quint8:
if backend == "fbgemm":
lower_bnd = 0
upper_bnd = 127

View File

@ -24,7 +24,7 @@ from .ns_types import (
from torch.ao.ns.fx.mappings import (
get_node_type_to_io_type_map,
)
from torch.ao.quantization.quantize import is_activation_post_process
from torch.ao.quantization.observer import _is_activation_post_process
from typing import Dict, Tuple, Callable, List, Any, Union, Optional, Set
@ -38,7 +38,7 @@ def _maybe_get_fqn(node: Node, gm: GraphModule) -> Optional[str]:
if node.op == 'call_module':
assert isinstance(node.target, str)
module = getattr_from_fqn(gm, node.target)
if is_activation_post_process(module):
if _is_activation_post_process(module):
node_to_use_for_fqn = get_normalized_nth_input(node, gm, 0)
fqn = gm._node_name_to_scope[node_to_use_for_fqn.name][0] # type: ignore[index]
return fqn # type: ignore[return-value]

View File

@ -13,10 +13,10 @@ from torch.fx import GraphModule
from torch.fx.graph import Node
from torch.ao.quantization import (
ObserverBase,
FakeQuantizeBase,
FakeQuantizeBase
)
from torch.ao.quantization.observer import _is_activation_post_process
from torch.ao.quantization.utils import getattr_from_fqn
from torch.ao.quantization.quantize import is_activation_post_process
from .ns_types import NSNodeTargetType, NSResultsType
@ -256,14 +256,14 @@ def return_first_non_observer_node(
"""
if node.op == "call_module":
node_obj = getattr_from_fqn(gm, node.target) # type: ignore[arg-type]
if is_activation_post_process(node_obj):
if _is_activation_post_process(node_obj):
assert len(node.args) == 1
assert isinstance(node.args[0], Node)
node = node.args[0]
# code duplication intended, not worth refactoring
assert isinstance(node.target, str)
node_obj = getattr_from_fqn(gm, node.target)
if is_activation_post_process(node_obj):
if _is_activation_post_process(node_obj):
assert len(node.args) == 1
assert isinstance(node.args[0], Node)
node = node.args[0]

View File

@ -114,7 +114,6 @@ __all__ = [
"get_quantized_operator",
"get_static_quant_module_class",
"get_unique_devices_",
"is_activation_post_process",
"load_observer_state_dict",
"no_observer_set",
"per_channel_weight_observer_range_neg_127_to_127",

View File

@ -23,7 +23,7 @@ from torch.ao.quantization.fx._equalize import (
default_equalization_qconfig,
EqualizationQConfig,
)
from torch.ao.quantization.quantize import is_activation_post_process
from torch.ao.quantization.observer import _is_activation_post_process
# Names for observer insert keys
DETECTOR_TARGET_NODE_KEY = "target_node"
@ -1273,7 +1273,7 @@ class OutlierDetector(DetectorBase):
# case for insertion of module
# check if the module has any children and isn't observer
num_children = len(list(module.children()))
return num_children == 0 and not is_activation_post_process(module)
return num_children == 0 and not _is_activation_post_process(module)
def get_qconfig_info(self, model) -> Dict[str, DetectorQConfigInfo]:
r""" Returns the DetectorQConfigInfo for each module_fqn relavent

View File

@ -61,7 +61,6 @@ from .utils import (
)
from torch.ao.quantization.quantize import (
_remove_qconfig,
is_activation_post_process,
)
from torch.ao.quantization.stubs import DeQuantStub
from .custom_config import (
@ -71,6 +70,7 @@ from .custom_config import (
from .lower_to_fbgemm import lower_to_fbgemm
# importing the lib so that the quantized_decomposed ops are registered
from ._decomposed import quantized_decomposed_lib # noqa: F401
from torch.ao.quantization.observer import _is_activation_post_process
# TODO: revisit this list. Many helper methods shouldn't be public
@ -218,7 +218,7 @@ def maybe_get_observer_for_node(
for maybe_obs_node, _ in node.users.items():
if maybe_obs_node.op == 'call_module':
maybe_obs = modules[str(maybe_obs_node.target)]
if is_activation_post_process(maybe_obs):
if _is_activation_post_process(maybe_obs):
return maybe_obs
return None
@ -725,7 +725,7 @@ def convert(
elif node.op == "call_module":
mod = _get_module(node, modules)
assert mod is not None
if is_activation_post_process(mod):
if _is_activation_post_process(mod):
observed_node = node.args[0]
if observed_node in statically_quantized_custom_module_nodes:
replace_observer_or_dequant_stub_with_dequantize_node(node, model.graph)

View File

@ -16,6 +16,7 @@ from ..quantize import (
)
from ..observer import (
ObserverBase,
_is_activation_post_process
)
from ..qconfig import (
_is_reuse_input_qconfig,
@ -78,7 +79,6 @@ from .utils import (
)
from torch.ao.quantization.quantize import (
is_activation_post_process,
convert
)
@ -148,7 +148,7 @@ DO_NOT_OBS_DTYPE_LIST = [int, float, torch.bool, None]
def is_activation_post_process_node(node: Node, modules: Dict[str, torch.nn.Module]) -> bool:
return isinstance(node, torch.fx.Node) and node.op == "call_module" and \
is_activation_post_process(modules[str(node.target)])
_is_activation_post_process(modules[str(node.target)])
def is_input_arg_dtype_supported_by_backend(
arg: Argument,

View File

@ -3,8 +3,8 @@ from collections import defaultdict, OrderedDict
from typing import Callable, Any, Dict, Tuple, Set, List
from torch.ao.quantization import QConfig
from torch.ao.quantization.qconfig import _add_module_to_qconfig_obs_ctr, QConfigAny, qconfig_equals
from torch.ao.quantization.quantize import (
is_activation_post_process,
from torch.ao.quantization.observer import (
_is_activation_post_process,
)
from torch.ao.quantization.backend_config import (
DTypeConfig,
@ -158,7 +158,7 @@ def generate_node_name_to_qconfig(
elif node.op == 'call_module':
# if the node is an observer, just continue - don't add it to the qconfig_map
if is_activation_post_process(modules[node.target]):
if _is_activation_post_process(modules[node.target]):
continue
qconfig = _maybe_adjust_qconfig_for_module_type_or_name(
qconfig_mapping, type(modules[node.target]), node.target, global_qconfig)

View File

@ -30,7 +30,7 @@ from torch.ao.quantization.utils import (
is_per_channel,
to_underlying_dtype,
)
from torch.ao.quantization.quantize import is_activation_post_process
from torch.ao.quantization.observer import _is_activation_post_process
from torch.fx import GraphModule, map_arg
@ -447,7 +447,7 @@ def all_node_args_have_no_tensors(node: Node, modules: Dict[str, torch.nn.Module
result = False
elif node.op == 'call_module':
assert isinstance(node.target, str)
if is_activation_post_process(modules[node.target]):
if _is_activation_post_process(modules[node.target]):
result = all_node_args_have_no_tensors(node.args[0], modules, cache) # type: ignore[arg-type]
elif node.op == 'call_module':
result = False
@ -1040,7 +1040,7 @@ def _qconfig_satisfies_dtype_config_constraints(
satisfies_constraints = True
if activation_post_process_ctr is not None:
activation_post_process = activation_post_process_ctr()
assert is_activation_post_process(activation_post_process)
assert _is_activation_post_process(activation_post_process)
# If dtypes don't match, don't check the activation_post_process and return True early
if activation_post_process.dtype != dtype_with_constraints.dtype:
return True

View File

@ -1437,7 +1437,7 @@ def _is_observer_script_module(mod, obs_type_name):
def _is_activation_post_process(module):
return (
isinstance(module, torch.ao.quantization.ObserverBase)
or isinstance(module, torch.ao.quantization.FakeQuantize)
or isinstance(module, torch.ao.quantization.FakeQuantizeBase)
or _is_observer_script_module(module, "quantization.observer")
)

View File

@ -27,10 +27,10 @@ from torch.ao.quantization.qconfig import (
float_qparams_weight_only_qconfig_4bit,
_activation_is_memoryless)
from torch.nn.utils.parametrize import type_before_parametrizations
from torch.ao.quantization.observer import _is_activation_post_process
__all__ = [
"get_default_custom_config_dict",
"is_activation_post_process",
"propagate_qconfig_",
"register_activation_post_process_hook",
"add_observer_",
@ -62,11 +62,6 @@ def get_default_custom_config_dict():
"""
return _DEFAULT_CUSTOM_CONFIG_DICT
def is_activation_post_process(module):
return (isinstance(module, torch.ao.quantization.ObserverBase) or
isinstance(module, torch.ao.quantization.FakeQuantizeBase))
def _propagate_qconfig_helper(module, qconfig_dict,
qconfig_parent=None, prefix='', prepare_custom_config_dict=None):
r"""This is a helper function for `propagate_qconfig_`
@ -322,7 +317,7 @@ def _remove_activation_post_process(module):
# TODO: maybe we should change activation_post_process to _activation_post_process
# to prevent it from being used by user
if hasattr(module, 'activation_post_process') and \
is_activation_post_process(module.activation_post_process):
_is_activation_post_process(module.activation_post_process):
delattr(module, 'activation_post_process')
# remove activation_post_proceess pre and post hooks

View File

@ -17,7 +17,7 @@ from torch.ao.quantization.quantize import add_quant_dequant
from torch.ao.quantization.quantize import convert
from torch.ao.quantization.quantize import get_observer_dict
from torch.ao.quantization.quantize import get_unique_devices_
from torch.ao.quantization.quantize import is_activation_post_process
from torch.ao.quantization.quantize import _is_activation_post_process
from torch.ao.quantization.quantize import prepare
from torch.ao.quantization.quantize import prepare_qat
from torch.ao.quantization.quantize import propagate_qconfig_