Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
[ao] making _is_activation_post_process private (#87520)
Summary: The same function existed in both observer and quantize; it is now consolidated into a single function. Note that the two definitions were slightly different; the consolidated definition was made maximally inclusive so that the function's name remains accurate.

Test Plan:
python test/test_public_bindings.py
python test/test_quantization.py

Reviewers: Subscribers: Tasks: Tags:

Differential Revision: [D40709276](https://our.internmc.facebook.com/intern/diff/D40709276)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/87520
Approved by: https://github.com/jcaip
Committed by: PyTorch MergeBot
Parent: aee96bbf5a
Commit: 45c62a3377
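As context for the summary above, here is a minimal sketch (not the verbatim source) of what the consolidated private check covers, together with the import-path change the diff below applies at every call site. The helper name in the sketch is illustrative; the real helper is _is_activation_post_process in torch.ao.quantization.observer, shown in the diff.

import torch
import torch.ao.quantization as tq

# Illustrative stand-in for the consolidated helper: the private check treats
# anything deriving from ObserverBase or FakeQuantizeBase as an activation
# post-process. Matching on FakeQuantizeBase (rather than only FakeQuantize)
# is what makes the consolidated definition maximally inclusive.
def _looks_like_activation_post_process(module: torch.nn.Module) -> bool:
    return isinstance(module, (tq.ObserverBase, tq.FakeQuantizeBase))

print(_looks_like_activation_post_process(tq.MinMaxObserver()))  # True
print(_looks_like_activation_post_process(tq.FakeQuantize()))    # True
print(_looks_like_activation_post_process(torch.nn.ReLU()))      # False

# Call sites migrate from the old public name to the private helper:
#   before: from torch.ao.quantization.quantize import is_activation_post_process
#   after:  from torch.ao.quantization.observer import _is_activation_post_process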
@@ -786,7 +786,7 @@
     "get_quantized_operator",
     "get_static_quant_module_class",
     "get_unique_devices_",
-    "is_activation_post_process",
+    "_is_activation_post_process",
     "load_observer_state_dict",
     "no_observer_set",
     "prepare",
@@ -894,7 +894,7 @@
     "convert",
     "get_observer_dict",
     "get_unique_devices_",
-    "is_activation_post_process",
+    "_is_activation_post_process",
     "prepare",
     "prepare_qat",
     "propagate_qconfig_",
@@ -19,7 +19,7 @@ class TestAOMigrationQuantization(AOMigrationTestCase):
            'convert',
            'get_observer_dict',
            'get_unique_devices_',
-            'is_activation_post_process',
+            '_is_activation_post_process',
            'prepare',
            'prepare_qat',
            'propagate_qconfig_',
@@ -22,7 +22,7 @@ class TestAOMigrationQuantization(AOMigrationTestCase):
            'convert',
            'get_observer_dict',
            'get_unique_devices_',
-            'is_activation_post_process',
+            '_is_activation_post_process',
            'prepare',
            'prepare_qat',
            'propagate_qconfig_',
@@ -55,7 +55,6 @@ from torch.ao.quantization import (
     get_default_qat_qconfig,
     get_default_qconfig_mapping,
     get_default_qat_qconfig_mapping,
-    is_activation_post_process,
     fuse_modules,
     fuse_modules_qat,
     prepare,
@@ -148,6 +147,7 @@ from torch.ao.quantization.observer import (
     default_fixed_qparams_range_0to1_observer,
     default_fixed_qparams_range_neg1to1_observer,
     MinMaxObserver,
+    _is_activation_post_process,
 )

 # test utils
@@ -3249,7 +3249,7 @@ class TestQuantizeFx(QuantizationTestCase):
                 _check_node_not_observed(model, new_node, node)
             elif arg_node.op == "call_module":
                 self.assertTrue(
-                    not is_activation_post_process(getattr(model, arg_node.target)),
+                    not _is_activation_post_process(getattr(model, arg_node.target)),
                     "Arg: {0} of node: {1} is observed but is not a float tensor".format(
                         arg_node, node
                     ),
@@ -4933,7 +4933,7 @@ class TestQuantizeFx(QuantizationTestCase):
             qconfig_dict = func(backend)
             m = prepare_fx(m, qconfig_dict, example_inputs=(torch.randn(1, 1, 1, 1)))
             for name, mod in m.named_modules():
-                if is_activation_post_process(mod) and mod.dtype == torch.quint8:
+                if _is_activation_post_process(mod) and mod.dtype == torch.quint8:
                     if backend == "fbgemm":
                         lower_bnd = 0
                         upper_bnd = 127
@@ -24,7 +24,7 @@ from .ns_types import (
 from torch.ao.ns.fx.mappings import (
     get_node_type_to_io_type_map,
 )
-from torch.ao.quantization.quantize import is_activation_post_process
+from torch.ao.quantization.observer import _is_activation_post_process

 from typing import Dict, Tuple, Callable, List, Any, Union, Optional, Set

@@ -38,7 +38,7 @@ def _maybe_get_fqn(node: Node, gm: GraphModule) -> Optional[str]:
         if node.op == 'call_module':
             assert isinstance(node.target, str)
             module = getattr_from_fqn(gm, node.target)
-            if is_activation_post_process(module):
+            if _is_activation_post_process(module):
                 node_to_use_for_fqn = get_normalized_nth_input(node, gm, 0)
         fqn = gm._node_name_to_scope[node_to_use_for_fqn.name][0] # type: ignore[index]
     return fqn # type: ignore[return-value]
@@ -13,10 +13,10 @@ from torch.fx import GraphModule
 from torch.fx.graph import Node
 from torch.ao.quantization import (
     ObserverBase,
-    FakeQuantizeBase,
+    FakeQuantizeBase
 )
+from torch.ao.quantization.observer import _is_activation_post_process
 from torch.ao.quantization.utils import getattr_from_fqn
-from torch.ao.quantization.quantize import is_activation_post_process

 from .ns_types import NSNodeTargetType, NSResultsType

@@ -256,14 +256,14 @@ def return_first_non_observer_node(
     """
     if node.op == "call_module":
         node_obj = getattr_from_fqn(gm, node.target) # type: ignore[arg-type]
-        if is_activation_post_process(node_obj):
+        if _is_activation_post_process(node_obj):
             assert len(node.args) == 1
             assert isinstance(node.args[0], Node)
             node = node.args[0]
             # code duplication intended, not worth refactoring
             assert isinstance(node.target, str)
             node_obj = getattr_from_fqn(gm, node.target)
-            if is_activation_post_process(node_obj):
+            if _is_activation_post_process(node_obj):
                 assert len(node.args) == 1
                 assert isinstance(node.args[0], Node)
                 node = node.args[0]
@@ -114,7 +114,6 @@ __all__ = [
     "get_quantized_operator",
     "get_static_quant_module_class",
     "get_unique_devices_",
-    "is_activation_post_process",
     "load_observer_state_dict",
     "no_observer_set",
     "per_channel_weight_observer_range_neg_127_to_127",
@@ -23,7 +23,7 @@ from torch.ao.quantization.fx._equalize import (
     default_equalization_qconfig,
     EqualizationQConfig,
 )
-from torch.ao.quantization.quantize import is_activation_post_process
+from torch.ao.quantization.observer import _is_activation_post_process

 # Names for observer insert keys
 DETECTOR_TARGET_NODE_KEY = "target_node"
@@ -1273,7 +1273,7 @@ class OutlierDetector(DetectorBase):
         # case for insertion of module
         # check if the module has any children and isn't observer
         num_children = len(list(module.children()))
-        return num_children == 0 and not is_activation_post_process(module)
+        return num_children == 0 and not _is_activation_post_process(module)

     def get_qconfig_info(self, model) -> Dict[str, DetectorQConfigInfo]:
         r""" Returns the DetectorQConfigInfo for each module_fqn relavent
@@ -61,7 +61,6 @@ from .utils import (
 )
 from torch.ao.quantization.quantize import (
     _remove_qconfig,
-    is_activation_post_process,
 )
 from torch.ao.quantization.stubs import DeQuantStub
 from .custom_config import (
@@ -71,6 +70,7 @@ from .custom_config import (
 from .lower_to_fbgemm import lower_to_fbgemm
 # importing the lib so that the quantized_decomposed ops are registered
 from ._decomposed import quantized_decomposed_lib # noqa: F401
+from torch.ao.quantization.observer import _is_activation_post_process


 # TODO: revisit this list. Many helper methods shouldn't be public
@@ -218,7 +218,7 @@ def maybe_get_observer_for_node(
     for maybe_obs_node, _ in node.users.items():
         if maybe_obs_node.op == 'call_module':
             maybe_obs = modules[str(maybe_obs_node.target)]
-            if is_activation_post_process(maybe_obs):
+            if _is_activation_post_process(maybe_obs):
                 return maybe_obs
     return None

@@ -725,7 +725,7 @@ def convert(
         elif node.op == "call_module":
             mod = _get_module(node, modules)
             assert mod is not None
-            if is_activation_post_process(mod):
+            if _is_activation_post_process(mod):
                 observed_node = node.args[0]
                 if observed_node in statically_quantized_custom_module_nodes:
                     replace_observer_or_dequant_stub_with_dequantize_node(node, model.graph)
@@ -16,6 +16,7 @@ from ..quantize import (
 )
 from ..observer import (
     ObserverBase,
+    _is_activation_post_process
 )
 from ..qconfig import (
     _is_reuse_input_qconfig,
@@ -78,7 +79,6 @@ from .utils import (
 )

 from torch.ao.quantization.quantize import (
-    is_activation_post_process,
     convert
 )

@@ -148,7 +148,7 @@ DO_NOT_OBS_DTYPE_LIST = [int, float, torch.bool, None]

 def is_activation_post_process_node(node: Node, modules: Dict[str, torch.nn.Module]) -> bool:
     return isinstance(node, torch.fx.Node) and node.op == "call_module" and \
-        is_activation_post_process(modules[str(node.target)])
+        _is_activation_post_process(modules[str(node.target)])

 def is_input_arg_dtype_supported_by_backend(
     arg: Argument,
@@ -3,8 +3,8 @@ from collections import defaultdict, OrderedDict
 from typing import Callable, Any, Dict, Tuple, Set, List
 from torch.ao.quantization import QConfig
 from torch.ao.quantization.qconfig import _add_module_to_qconfig_obs_ctr, QConfigAny, qconfig_equals
-from torch.ao.quantization.quantize import (
-    is_activation_post_process,
+from torch.ao.quantization.observer import (
+    _is_activation_post_process,
 )
 from torch.ao.quantization.backend_config import (
     DTypeConfig,
@@ -158,7 +158,7 @@ def generate_node_name_to_qconfig(

         elif node.op == 'call_module':
             # if the node is an observer, just continue - don't add it to the qconfig_map
-            if is_activation_post_process(modules[node.target]):
+            if _is_activation_post_process(modules[node.target]):
                 continue
             qconfig = _maybe_adjust_qconfig_for_module_type_or_name(
                 qconfig_mapping, type(modules[node.target]), node.target, global_qconfig)
@@ -30,7 +30,7 @@ from torch.ao.quantization.utils import (
     is_per_channel,
     to_underlying_dtype,
 )
-from torch.ao.quantization.quantize import is_activation_post_process
+from torch.ao.quantization.observer import _is_activation_post_process

 from torch.fx import GraphModule, map_arg

@@ -447,7 +447,7 @@ def all_node_args_have_no_tensors(node: Node, modules: Dict[str, torch.nn.Module
         result = False
     elif node.op == 'call_module':
         assert isinstance(node.target, str)
-        if is_activation_post_process(modules[node.target]):
+        if _is_activation_post_process(modules[node.target]):
             result = all_node_args_have_no_tensors(node.args[0], modules, cache) # type: ignore[arg-type]
     elif node.op == 'call_module':
         result = False
@@ -1040,7 +1040,7 @@ def _qconfig_satisfies_dtype_config_constraints(
     satisfies_constraints = True
     if activation_post_process_ctr is not None:
         activation_post_process = activation_post_process_ctr()
-        assert is_activation_post_process(activation_post_process)
+        assert _is_activation_post_process(activation_post_process)
         # If dtypes don't match, don't check the activation_post_process and return True early
         if activation_post_process.dtype != dtype_with_constraints.dtype:
             return True
@@ -1437,7 +1437,7 @@ def _is_observer_script_module(mod, obs_type_name):
 def _is_activation_post_process(module):
     return (
         isinstance(module, torch.ao.quantization.ObserverBase)
-        or isinstance(module, torch.ao.quantization.FakeQuantize)
+        or isinstance(module, torch.ao.quantization.FakeQuantizeBase)
         or _is_observer_script_module(module, "quantization.observer")
     )

@@ -27,10 +27,10 @@ from torch.ao.quantization.qconfig import (
     float_qparams_weight_only_qconfig_4bit,
     _activation_is_memoryless)
 from torch.nn.utils.parametrize import type_before_parametrizations
+from torch.ao.quantization.observer import _is_activation_post_process

 __all__ = [
     "get_default_custom_config_dict",
-    "is_activation_post_process",
     "propagate_qconfig_",
     "register_activation_post_process_hook",
     "add_observer_",
@@ -62,11 +62,6 @@ def get_default_custom_config_dict():
     """
     return _DEFAULT_CUSTOM_CONFIG_DICT

-def is_activation_post_process(module):
-    return (isinstance(module, torch.ao.quantization.ObserverBase) or
-            isinstance(module, torch.ao.quantization.FakeQuantizeBase))
-
-
 def _propagate_qconfig_helper(module, qconfig_dict,
                               qconfig_parent=None, prefix='', prepare_custom_config_dict=None):
     r"""This is a helper function for `propagate_qconfig_`
@@ -322,7 +317,7 @@ def _remove_activation_post_process(module):
     # TODO: maybe we should change activation_post_process to _activation_post_process
     # to prevent it from being used by user
     if hasattr(module, 'activation_post_process') and \
-        is_activation_post_process(module.activation_post_process):
+        _is_activation_post_process(module.activation_post_process):
         delattr(module, 'activation_post_process')

     # remove activation_post_proceess pre and post hooks
@@ -17,7 +17,7 @@ from torch.ao.quantization.quantize import add_quant_dequant
 from torch.ao.quantization.quantize import convert
 from torch.ao.quantization.quantize import get_observer_dict
 from torch.ao.quantization.quantize import get_unique_devices_
-from torch.ao.quantization.quantize import is_activation_post_process
+from torch.ao.quantization.quantize import _is_activation_post_process
 from torch.ao.quantization.quantize import prepare
 from torch.ao.quantization.quantize import prepare_qat
 from torch.ao.quantization.quantize import propagate_qconfig_