Revert "[ao] making _is_activation_post_process private (#87520)"

This reverts commit 45c62a337756ff9db97cd64d2d42d9e65dda0a85.

Reverted https://github.com/pytorch/pytorch/pull/87520 on behalf of https://github.com/bigfootjon due to Diff reverted internally
This commit is contained in:
PyTorch MergeBot
2022-11-21 16:48:26 +00:00
parent f3db03612f
commit 9d209e7834
15 changed files with 36 additions and 30 deletions

View File

@ -786,7 +786,7 @@
"get_quantized_operator",
"get_static_quant_module_class",
"get_unique_devices_",
"_is_activation_post_process",
"is_activation_post_process",
"load_observer_state_dict",
"no_observer_set",
"prepare",
@ -894,7 +894,7 @@
"convert",
"get_observer_dict",
"get_unique_devices_",
"_is_activation_post_process",
"is_activation_post_process",
"prepare",
"prepare_qat",
"propagate_qconfig_",

View File

@ -19,7 +19,7 @@ class TestAOMigrationQuantization(AOMigrationTestCase):
'convert',
'get_observer_dict',
'get_unique_devices_',
'_is_activation_post_process',
'is_activation_post_process',
'prepare',
'prepare_qat',
'propagate_qconfig_',

View File

@ -22,7 +22,7 @@ class TestAOMigrationQuantization(AOMigrationTestCase):
'convert',
'get_observer_dict',
'get_unique_devices_',
'_is_activation_post_process',
'is_activation_post_process',
'prepare',
'prepare_qat',
'propagate_qconfig_',

View File

@ -55,6 +55,7 @@ from torch.ao.quantization import (
get_default_qat_qconfig,
get_default_qconfig_mapping,
get_default_qat_qconfig_mapping,
is_activation_post_process,
fuse_modules,
fuse_modules_qat,
prepare,
@ -147,7 +148,6 @@ from torch.ao.quantization.observer import (
default_fixed_qparams_range_0to1_observer,
default_fixed_qparams_range_neg1to1_observer,
MinMaxObserver,
_is_activation_post_process,
)
# test utils
@ -3249,7 +3249,7 @@ class TestQuantizeFx(QuantizationTestCase):
_check_node_not_observed(model, new_node, node)
elif arg_node.op == "call_module":
self.assertTrue(
not _is_activation_post_process(getattr(model, arg_node.target)),
not is_activation_post_process(getattr(model, arg_node.target)),
"Arg: {0} of node: {1} is observed but is not a float tensor".format(
arg_node, node
),
@ -5008,7 +5008,7 @@ class TestQuantizeFx(QuantizationTestCase):
qconfig_dict = func(backend)
m = prepare_fx(m, qconfig_dict, example_inputs=(torch.randn(1, 1, 1, 1)))
for name, mod in m.named_modules():
if _is_activation_post_process(mod) and mod.dtype == torch.quint8:
if is_activation_post_process(mod) and mod.dtype == torch.quint8:
if backend == "fbgemm":
lower_bnd = 0
upper_bnd = 127

View File

@ -24,7 +24,7 @@ from .ns_types import (
from torch.ao.ns.fx.mappings import (
get_node_type_to_io_type_map,
)
from torch.ao.quantization.observer import _is_activation_post_process
from torch.ao.quantization.quantize import is_activation_post_process
from typing import Dict, Tuple, Callable, List, Any, Union, Optional, Set
@ -38,7 +38,7 @@ def _maybe_get_fqn(node: Node, gm: GraphModule) -> Optional[str]:
if node.op == 'call_module':
assert isinstance(node.target, str)
module = getattr_from_fqn(gm, node.target)
if _is_activation_post_process(module):
if is_activation_post_process(module):
node_to_use_for_fqn = get_normalized_nth_input(node, gm, 0)
fqn = gm._node_name_to_scope[node_to_use_for_fqn.name][0] # type: ignore[index]
return fqn # type: ignore[return-value]

View File

@ -13,10 +13,10 @@ from torch.fx import GraphModule
from torch.fx.graph import Node
from torch.ao.quantization import (
ObserverBase,
FakeQuantizeBase
FakeQuantizeBase,
)
from torch.ao.quantization.observer import _is_activation_post_process
from torch.ao.quantization.utils import getattr_from_fqn
from torch.ao.quantization.quantize import is_activation_post_process
from .ns_types import NSNodeTargetType, NSResultsType
@ -256,14 +256,14 @@ def return_first_non_observer_node(
"""
if node.op == "call_module":
node_obj = getattr_from_fqn(gm, node.target) # type: ignore[arg-type]
if _is_activation_post_process(node_obj):
if is_activation_post_process(node_obj):
assert len(node.args) == 1
assert isinstance(node.args[0], Node)
node = node.args[0]
# code duplication intended, not worth refactoring
assert isinstance(node.target, str)
node_obj = getattr_from_fqn(gm, node.target)
if _is_activation_post_process(node_obj):
if is_activation_post_process(node_obj):
assert len(node.args) == 1
assert isinstance(node.args[0], Node)
node = node.args[0]

View File

@ -114,6 +114,7 @@ __all__ = [
"get_quantized_operator",
"get_static_quant_module_class",
"get_unique_devices_",
"is_activation_post_process",
"load_observer_state_dict",
"no_observer_set",
"per_channel_weight_observer_range_neg_127_to_127",

View File

@ -23,7 +23,7 @@ from torch.ao.quantization.fx._equalize import (
default_equalization_qconfig,
EqualizationQConfig,
)
from torch.ao.quantization.observer import _is_activation_post_process
from torch.ao.quantization.quantize import is_activation_post_process
# Names for observer insert keys
DETECTOR_TARGET_NODE_KEY = "target_node"
@ -1273,7 +1273,7 @@ class OutlierDetector(DetectorBase):
# case for insertion of module
# check if the module has any children and isn't observer
num_children = len(list(module.children()))
return num_children == 0 and not _is_activation_post_process(module)
return num_children == 0 and not is_activation_post_process(module)
def get_qconfig_info(self, model) -> Dict[str, DetectorQConfigInfo]:
r""" Returns the DetectorQConfigInfo for each module_fqn relavent

View File

@ -64,6 +64,7 @@ from torch.ao.quantization.utils import (
)
from torch.ao.quantization.quantize import (
_remove_qconfig,
is_activation_post_process,
)
from torch.ao.quantization.stubs import DeQuantStub
from .custom_config import (
@ -73,7 +74,6 @@ from .custom_config import (
from .lower_to_fbgemm import lower_to_fbgemm
# importing the lib so that the quantized_decomposed ops are registered
from ._decomposed import quantized_decomposed_lib # noqa: F401
from torch.ao.quantization.observer import _is_activation_post_process
# TODO: revisit this list. Many helper methods shouldn't be public
@ -359,7 +359,7 @@ def maybe_get_observer_for_node(
for maybe_obs_node, _ in node.users.items():
if maybe_obs_node.op == 'call_module':
maybe_obs = modules[str(maybe_obs_node.target)]
if _is_activation_post_process(maybe_obs):
if is_activation_post_process(maybe_obs):
return maybe_obs
return None
@ -787,7 +787,7 @@ def convert(
elif node.op == "call_module":
mod = _get_module(node, modules)
assert mod is not None
if _is_activation_post_process(mod):
if is_activation_post_process(mod):
observed_node = node.args[0]
if observed_node in statically_quantized_custom_module_nodes:
_replace_observer_or_dequant_stub_with_dequantize_node(node, model.graph)

View File

@ -16,7 +16,6 @@ from ..quantize import (
)
from ..observer import (
ObserverBase,
_is_activation_post_process
)
from ..qconfig import (
_is_reuse_input_qconfig,
@ -79,6 +78,7 @@ from .utils import (
)
from torch.ao.quantization.quantize import (
is_activation_post_process,
convert
)
@ -148,7 +148,7 @@ DO_NOT_OBS_DTYPE_LIST = [int, float, torch.bool, None]
def is_activation_post_process_node(node: Node, modules: Dict[str, torch.nn.Module]) -> bool:
return isinstance(node, torch.fx.Node) and node.op == "call_module" and \
_is_activation_post_process(modules[str(node.target)])
is_activation_post_process(modules[str(node.target)])
def is_input_arg_dtype_supported_by_backend(
arg: Argument,

View File

@ -3,8 +3,8 @@ from collections import defaultdict, OrderedDict
from typing import Callable, Any, Dict, Tuple, Set, List
from torch.ao.quantization import QConfig
from torch.ao.quantization.qconfig import _add_module_to_qconfig_obs_ctr, QConfigAny, qconfig_equals
from torch.ao.quantization.observer import (
_is_activation_post_process,
from torch.ao.quantization.quantize import (
is_activation_post_process,
)
from torch.ao.quantization.backend_config import (
DTypeConfig,
@ -158,7 +158,7 @@ def generate_node_name_to_qconfig(
elif node.op == 'call_module':
# if the node is an observer, just continue - don't add it to the qconfig_map
if _is_activation_post_process(modules[node.target]):
if is_activation_post_process(modules[node.target]):
continue
qconfig = _maybe_adjust_qconfig_for_module_type_or_name(
qconfig_mapping, type(modules[node.target]), node.target, global_qconfig)

View File

@ -30,7 +30,7 @@ from torch.ao.quantization.utils import (
is_per_channel,
to_underlying_dtype,
)
from torch.ao.quantization.observer import _is_activation_post_process
from torch.ao.quantization.quantize import is_activation_post_process
from torch.fx import GraphModule, map_arg
@ -447,7 +447,7 @@ def all_node_args_have_no_tensors(node: Node, modules: Dict[str, torch.nn.Module
result = False
elif node.op == 'call_module':
assert isinstance(node.target, str)
if _is_activation_post_process(modules[node.target]):
if is_activation_post_process(modules[node.target]):
result = all_node_args_have_no_tensors(node.args[0], modules, cache) # type: ignore[arg-type]
elif node.op == 'call_module':
result = False
@ -1040,7 +1040,7 @@ def _qconfig_satisfies_dtype_config_constraints(
satisfies_constraints = True
if activation_post_process_ctr is not None:
activation_post_process = activation_post_process_ctr()
assert _is_activation_post_process(activation_post_process)
assert is_activation_post_process(activation_post_process)
# If dtypes don't match, don't check the activation_post_process and return True early
if activation_post_process.dtype != dtype_with_constraints.dtype:
return True

View File

@ -1442,7 +1442,7 @@ def _is_observer_script_module(mod, obs_type_name):
def _is_activation_post_process(module):
return (
isinstance(module, torch.ao.quantization.ObserverBase)
or isinstance(module, torch.ao.quantization.FakeQuantizeBase)
or isinstance(module, torch.ao.quantization.FakeQuantize)
or _is_observer_script_module(module, "quantization.observer")
)

View File

@ -27,10 +27,10 @@ from torch.ao.quantization.qconfig import (
float_qparams_weight_only_qconfig_4bit,
_activation_is_memoryless)
from torch.nn.utils.parametrize import type_before_parametrizations
from torch.ao.quantization.observer import _is_activation_post_process
__all__ = [
"get_default_custom_config_dict",
"is_activation_post_process",
"propagate_qconfig_",
"register_activation_post_process_hook",
"add_observer_",
@ -62,6 +62,11 @@ def get_default_custom_config_dict():
"""
return _DEFAULT_CUSTOM_CONFIG_DICT
def is_activation_post_process(module):
return (isinstance(module, torch.ao.quantization.ObserverBase) or
isinstance(module, torch.ao.quantization.FakeQuantizeBase))
def _propagate_qconfig_helper(module, qconfig_dict,
qconfig_parent=None, prefix='', prepare_custom_config_dict=None):
r"""This is a helper function for `propagate_qconfig_`
@ -319,7 +324,7 @@ def _remove_activation_post_process(module):
# TODO: maybe we should change activation_post_process to _activation_post_process
# to prevent it from being used by user
if hasattr(module, 'activation_post_process') and \
_is_activation_post_process(module.activation_post_process):
is_activation_post_process(module.activation_post_process):
delattr(module, 'activation_post_process')
# remove activation_post_proceess pre and post hooks

View File

@ -17,7 +17,7 @@ from torch.ao.quantization.quantize import add_quant_dequant
from torch.ao.quantization.quantize import convert
from torch.ao.quantization.quantize import get_observer_dict
from torch.ao.quantization.quantize import get_unique_devices_
from torch.ao.quantization.quantize import _is_activation_post_process
from torch.ao.quantization.quantize import is_activation_post_process
from torch.ao.quantization.quantize import prepare
from torch.ao.quantization.quantize import prepare_qat
from torch.ao.quantization.quantize import propagate_qconfig_