[quant][fx] Move backend_config folder to torch.ao.quantization

Summary:
Following https://github.com/pytorch/rfcs/blob/master/RFC-0019-Extending-PyTorch-Quantization-to-Custom-Backends.md, we implemented
the backend configuration for the fbgemm/qnnpack backend. It currently lives under the fx folder, but we'd like to use it for all
workflows, including eager mode, fx graph mode, and define-by-run quantization, so this PR moves it to the torch.ao.quantization
namespace where it can be shared across workflows.
It also moves some fx-specific utility functions to fx/backend_config_utils.py; some files (quantize_handler.py and fuse_handler.py) are kept in the fx folder.
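For downstream code that uses the public getters, the change is just a shorter import path. A minimal sketch of the migration (exact call sites vary):

    # before this PR
    from torch.ao.quantization.fx.backend_config import get_native_backend_config_dict
    # after this PR
    from torch.ao.quantization.backend_config import get_native_backend_config_dict

    # returns the backend_config_dict describing fbgemm/qnnpack support
    backend_config_dict = get_native_backend_config_dict()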

Test Plan:
python test/test_quantization.py TestQuantizeFx
python test/test_quantization.py TestQuantizeFxOps
python test/test_quantization.py TestQuantizeFxModels
python test/test_quantization.py TestAOMigrationQuantization
python test/test_quantization.py TestAOMigrationQuantizationFx

Reviewers:

Subscribers:

Tasks:

Tags:

Pull Request resolved: https://github.com/pytorch/pytorch/pull/75823

Approved by: https://github.com/vkuzo
Authored by Jerry Zhang on 2022-04-18 15:20:09 -07:00; committed by PyTorch MergeBot
parent cbb9b33c85
commit 74454bdb46
29 changed files with 263 additions and 208 deletions

View File

@ -133,7 +133,7 @@ coverage_ignore_functions = [
"unregister_custom_op_symbolic",
# torch.ao.quantization
"default_eval_fn",
# torch.ao.quantization.fx.backend_config
# torch.ao.quantization.backend_config
"validate_backend_config_dict",
# torch.backends
"disable_global_flags",

View File

@ -910,7 +910,7 @@ Numerical Debugging (prototype)
.. py:module:: torch.ao.ns.fx
.. py:module:: torch.ao.quantization
.. py:module:: torch.ao.quantization.fx
.. py:module:: torch.ao.quantization.fx.backend_config
.. py:module:: torch.ao.quantization.backend_config
.. py:module:: torch.ao.sparsity
.. py:module:: torch.ao.sparsity.experimental
.. py:module:: torch.ao.sparsity.experimental.pruner

View File

@ -3,7 +3,7 @@ This script will generate default values of quantization configs.
These are for use in the documentation.
"""
from torch.ao.quantization.fx.backend_config import get_native_backend_config_dict
from torch.ao.quantization.backend_config import get_native_backend_config_dict
import os.path
from pprint import pprint

View File

@ -80,10 +80,7 @@
"Union",
"get_combined_dict"
],
"torch.ao.quantization.fx.backend_config.fuse_handler": [
"DefaultFuseHandler"
],
"torch.ao.quantization.fx.backend_config.native": [
"torch.ao.quantization.backend_config.native": [
"Any",
"Dict",
"FixedQParamsFakeQuantize",
@ -100,39 +97,21 @@
"reverse3",
"reverse_sequential_wrapper2"
],
"torch.ao.quantization.fx.backend_config.observation_type": [
"torch.ao.quantization.backend_config.observation_type": [
"Enum"
],
"torch.ao.quantization.fx.backend_config.quantize_handler": [
"Any",
"Callable",
"Dict",
"NodePattern",
"ObservationType",
"Optional",
"Pattern",
"QuantizeHandler",
"activation_dtype"
],
"torch.ao.quantization.fx.backend_config.tensorrt": [
"torch.ao.quantization.backend_config.tensorrt": [
"ObservationType",
"reverse_sequential_wrapper2"
],
"torch.ao.quantization.fx.backend_config.utils": [
"Any",
"Callable",
"Dict",
"List",
"Pattern",
"QuantizerCls",
"Tuple",
"Union",
"get_combined_dict",
"get_default_quant_patterns",
"get_fuse_handler_cls",
"get_native_backend_config_dict",
"get_quantize_handler_cls",
"sorted_patterns_dict"
"torch.ao.quantization.quantization_types": [
"Any",
"Node",
"NodePattern",
"Pattern",
"QuantizerCls",
"Tuple",
"Union"
],
"torch.ao.quantization.fx.convert": [
"Any",
@ -380,6 +359,23 @@
"map_arg",
"namedtuple"
],
"torch.ao.quantization.fx.backend_config_utils": [
"Any",
"Callable",
"DefaultFuseHandler",
"Dict",
"NodePattern",
"ObservationType",
"Optional",
"Pattern",
"QuantizeHandler",
"QuantizerCls",
"activation_dtype",
"get_combined_dict",
"get_default_quant_patterns",
"get_native_backend_config_dict",
"sorted_patterns_dict"
],
"torch.ao.quantization.observer": [
"ABC",
"ABCMeta",

View File

@ -168,15 +168,10 @@ class TestAOMigrationQuantizationFx(AOMigrationTestCase):
]
self._test_function_import('fx.fusion_patterns', function_list)
def test_package_import_fx_quantization_types(self):
self._test_package_import('fx.quantization_types')
def test_function_import_fx_quantization_types(self):
function_list = [
'Pattern',
'QuantizerCls'
]
self._test_function_import('fx.quantization_types', function_list)
# we removed matching test for torch.quantization.fx.quantization_types
# old: torch.quantization.fx.quantization_types
# new: torch.ao.quantization.quantization_types
# both are valid, but we'll deprecate the old path in the future
def test_package_import_fx_utils(self):
self._test_package_import('fx.utils')

View File

@ -71,8 +71,8 @@ from torch.ao.ns._numeric_suite_fx import (
extract_shadow_logger_info,
extend_logger_results_with_comparison,
)
from torch.ao.quantization.fx.backend_config import get_native_backend_config_dict
from torch.ao.quantization.fx.backend_config.utils import get_pattern_to_quantize_handlers
from torch.ao.quantization.backend_config import get_native_backend_config_dict
from torch.ao.quantization.fx.backend_config_utils import get_pattern_to_quantize_handlers
# Note: these models are not for use outside of this file. While it's good

View File

@ -8,7 +8,7 @@ from torch.fx.graph import Node
from torch.ao.quantization.utils import getattr_from_fqn
from .ns_types import NSNodeTargetType
from torch.ao.quantization.fx.backend_config.utils import get_native_quant_patterns
from torch.ao.quantization.fx.backend_config_utils import get_native_quant_patterns
from torch.ao.quantization import (
ObserverBase,
FakeQuantizeBase,

View File

@ -4,3 +4,8 @@ from .native import get_native_backend_config_dict
# TODO: add more validations
def validate_backend_config_dict(backend_config_dict):
return "configs" in backend_config_dict
__all__ = [
"get_native_backend_config_dict",
"get_tensorrt_backend_config_dict",
]

View File

@ -2,19 +2,19 @@ from collections import namedtuple
from typing import List, Dict, Any
import operator
import torch
from .observation_type import ObservationType
from torch.ao.quantization.backend_config.observation_type import ObservationType
import torch.nn.functional as F
import torch.nn as nn
import torch.nn.intrinsic as nni
import torch.nn.intrinsic.qat as nniqat
import torch.nn.qat as nnqat
import torch.nn.quantized._reference as nnqr
from ...observer import (
from ..observer import (
default_affine_fixed_qparams_observer,
default_symmetric_fixed_qparams_observer,
)
from ...fake_quantize import FixedQParamsFakeQuantize
from ...fuser_method_mappings import (
from ..fake_quantize import FixedQParamsFakeQuantize
from ..fuser_method_mappings import (
reverse_sequential_wrapper2,
reverse2,
reverse3,
@ -711,3 +711,7 @@ def get_native_backend_config_dict():
*_get_embedding_op_configs(),
],
}
__all__ = [
"get_native_backend_config_dict",
]

View File

@ -4,7 +4,7 @@ import torch.nn.qat as nnqat
import torch.nn.intrinsic as nni
import torch.nn.intrinsic.qat as nniqat
from ...fuser_method_mappings import reverse_sequential_wrapper2
from ..fuser_method_mappings import reverse_sequential_wrapper2
def get_tensorrt_backend_config_dict():
""" Get the backend config dictionary for tensorrt backend
@ -218,3 +218,7 @@ def get_tensorrt_backend_config_dict():
identity_config,
]
}
__all__ = [
"get_tensorrt_backend_config_dict",
]

View File

@ -1,41 +1,8 @@
from typing import Dict, Any, List, Callable, Union, Tuple
import torch
from torch.ao.quantization.utils import get_combined_dict
from torch.ao.quantization.fx.pattern_utils import get_default_quant_patterns, sorted_patterns_dict
import torch.nn as nn
from .quantize_handler import get_quantize_handler_cls
from .fuse_handler import get_fuse_handler_cls
from .native import get_native_backend_config_dict
from ..quantization_types import Pattern, QuantizerCls
def get_pattern_to_quantize_handlers(
backend_config_dict: Dict[str, Any]) -> Dict[Pattern, QuantizerCls]:
"""
Note: Quantize handler is just a holder for some check methods like
(should_insert_observer_for_output), maybe this can be a enum as well,
we can refactor this after we convert the path for fbgemm/qnnpack fully to the
new path, this is not exposed to backend developers
"""
pattern_to_quantize_handlers = dict()
for config in backend_config_dict.get("configs", []):
pattern = config["pattern"]
observation_type = config.get("observation_type", None)
dtype_configs = config["dtype_configs"]
num_tensor_args_to_observation_type = config.get("num_tensor_args_to_observation_type", {})
overwrite_fake_quantizer = config.get("_overwrite_output_fake_quantizer", None)
overwrite_observer = config.get("_overwrite_output_observer", None)
input_output_observed = config.get("_input_output_observed", True)
pattern_to_quantize_handlers[pattern] = \
get_quantize_handler_cls(
observation_type,
dtype_configs,
num_tensor_args_to_observation_type,
overwrite_fake_quantizer,
overwrite_observer,
input_output_observed)
return pattern_to_quantize_handlers
from ..quantization_types import Pattern
def get_pattern_to_dtype_configs(
backend_config_dict: Dict[str, Any]) -> Dict[Pattern, List[Dict[str, Any]]]:
@ -83,17 +50,6 @@ def get_root_module_to_quantized_reference_module(
mapping[config["root_module"]] = config["reference_quantized_module_for_root"]
return mapping
def get_fusion_pattern_to_fuse_handler_cls(
backend_config_dict: Dict[str, Any]) -> Dict[Pattern, Callable]:
fusion_pattern_to_fuse_handlers = dict()
for config in backend_config_dict.get("configs", []):
if "fuser_method" in config:
pattern = config["pattern"]
fusion_pattern_to_fuse_handlers[pattern] = \
get_fuse_handler_cls()
return fusion_pattern_to_fuse_handlers
def get_fuser_method_mapping(
backend_config_dict: Dict[str, Any]) -> Dict[Pattern, Union[nn.Sequential, Callable]]:
fuser_method_mapping : Dict[Pattern, Union[nn.Sequential, Callable]] = dict()
@ -159,19 +115,3 @@ def get_fusion_pattern_to_extra_inputs_getter(
extra_inputs_getter_mapping[pattern] = extra_inputs_getter
return extra_inputs_getter_mapping
# TODO: remove when all uses are changed to backend_config_dict
def get_native_quant_patterns(additional_quant_patterns: Dict[Pattern, QuantizerCls] = None) -> Dict[Pattern, QuantizerCls]:
"""
Return a map from pattern to quantize handlers based on the default patterns and the native backend_config_dict.
The returned map is sorted such that longer patterns will be encountered first when iterating through it.
"""
patterns = get_default_quant_patterns()
if additional_quant_patterns is not None:
patterns = get_combined_dict(patterns, additional_quant_patterns)
# TODO: currently we just extend the quantize handlers generated from
# `get_native_backend_config_dict`
# in the future we can just assign backend_config_dict when everything is defined
for pattern, quantize_handler in get_pattern_to_quantize_handlers(get_native_backend_config_dict()).items():
patterns[pattern] = quantize_handler
return sorted_patterns_dict(patterns)

View File

@ -1,4 +1,3 @@
from .prepare import prepare
from .convert import convert
from .fuse import fuse
from .backend_config import get_tensorrt_backend_config_dict

View File

@ -1,5 +0,0 @@
from ..fusion_patterns import DefaultFuseHandler
# TODO: move DefaultFuseHandler
def get_fuse_handler_cls():
return DefaultFuseHandler

View File

@ -1,67 +0,0 @@
import torch
from typing import Dict, Callable, Any, Optional
from .observation_type import ObservationType
from ..quantization_patterns import QuantizeHandler
from ..quantization_types import Pattern, NodePattern
from ...utils import (
activation_dtype,
)
def get_quantize_handler_cls(
observation_type,
dtype_configs,
num_tensor_args_to_observation_type,
overwrite_output_fake_quantizer,
overwrite_output_observer,
input_output_observed):
class ConfigurableQuantizeHandler(QuantizeHandler):
def __init__(
self,
node_pattern: NodePattern,
modules: Dict[str, torch.nn.Module],
root_node_getter: Callable = None):
super().__init__(node_pattern, modules, root_node_getter)
if num_tensor_args_to_observation_type:
assert self.num_tensor_args in num_tensor_args_to_observation_type, \
f"Must provide observation_type config for tensor number {self.num_tensor_args}" \
f" in num_tensor_args_to_observation_type for {node_pattern}"
self.observation_type = num_tensor_args_to_observation_type[self.num_tensor_args]
else:
self.observation_type = observation_type
self.dtype_configs = dtype_configs
self.overwrite_output_fake_quantizer = overwrite_output_fake_quantizer
self.overwrite_output_observer = overwrite_output_observer
self.input_output_observed_ = input_output_observed
def is_general_tensor_value_op(self) -> bool:
return self.observation_type == ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT
# TODO: change this to output activation
def get_activation_ctr(
self,
qconfig: Any,
pattern: Pattern,
is_training: bool,
) -> Optional[Callable]:
"""
Returns the constructor for the activation observer which should be
used for the pattern matched to this handler. Some handlers override
this to a different value than what is specified in the qconfig.
"""
act_dtype = activation_dtype(qconfig)
# TODO: change to is_qat
if is_training:
if act_dtype == torch.quint8 and self.overwrite_output_fake_quantizer is not None:
return self.overwrite_output_fake_quantizer
else:
if act_dtype == torch.quint8 and self.overwrite_output_observer is not None:
return self.overwrite_output_observer
return qconfig.activation
# This is temporary, and will be removed soon
def input_output_observed(self):
return self.input_output_observed_
return ConfigurableQuantizeHandler

View File

@ -0,0 +1,141 @@
import torch
from torch.ao.quantization.fx.pattern_utils import get_default_quant_patterns, sorted_patterns_dict
from torch.ao.quantization.backend_config import get_native_backend_config_dict
from torch.ao.quantization.backend_config.observation_type import ObservationType
from torch.ao.quantization.quantization_types import (
Pattern,
NodePattern,
QuantizerCls,
)
from torch.ao.quantization.utils import (
activation_dtype,
get_combined_dict,
)
from .quantization_patterns import QuantizeHandler
from .fusion_patterns import DefaultFuseHandler
from typing import Dict, Any, Callable, Optional
def get_quantize_handler_cls(
observation_type,
dtype_configs,
num_tensor_args_to_observation_type,
overwrite_output_fake_quantizer,
overwrite_output_observer,
input_output_observed):
class ConfigurableQuantizeHandler(QuantizeHandler):
def __init__(
self,
node_pattern: NodePattern,
modules: Dict[str, torch.nn.Module],
root_node_getter: Callable = None):
super().__init__(node_pattern, modules, root_node_getter)
if num_tensor_args_to_observation_type:
assert self.num_tensor_args in num_tensor_args_to_observation_type, \
f"Must provide observation_type config for tensor number {self.num_tensor_args}" \
f" in num_tensor_args_to_observation_type for {node_pattern}"
self.observation_type = num_tensor_args_to_observation_type[self.num_tensor_args]
else:
self.observation_type = observation_type
self.dtype_configs = dtype_configs
self.overwrite_output_fake_quantizer = overwrite_output_fake_quantizer
self.overwrite_output_observer = overwrite_output_observer
self.input_output_observed_ = input_output_observed
def is_general_tensor_value_op(self) -> bool:
return self.observation_type == ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT
# TODO: change this to output activation
def get_activation_ctr(
self,
qconfig: Any,
pattern: Pattern,
is_training: bool,
) -> Optional[Callable]:
"""
Returns the constructor for the activation observer which should be
used for the pattern matched to this handler. Some handlers override
this to a different value than what is specified in the qconfig.
"""
act_dtype = activation_dtype(qconfig)
# TODO: change to is_qat
if is_training:
if act_dtype == torch.quint8 and self.overwrite_output_fake_quantizer is not None:
return self.overwrite_output_fake_quantizer
else:
if act_dtype == torch.quint8 and self.overwrite_output_observer is not None:
return self.overwrite_output_observer
return qconfig.activation
# This is temporary, and will be removed soon
def input_output_observed(self):
return self.input_output_observed_
return ConfigurableQuantizeHandler
def get_pattern_to_quantize_handlers(
backend_config_dict: Dict[str, Any]) -> Dict[Pattern, QuantizerCls]:
"""
Note: Quantize handler is just a holder for some check methods like
(should_insert_observer_for_output), maybe this can be a enum as well,
we can refactor this after we convert the path for fbgemm/qnnpack fully to the
new path, this is not exposed to backend developers
"""
pattern_to_quantize_handlers = dict()
for config in backend_config_dict.get("configs", []):
pattern = config["pattern"]
observation_type = config.get("observation_type", None)
dtype_configs = config["dtype_configs"]
num_tensor_args_to_observation_type = config.get("num_tensor_args_to_observation_type", {})
overwrite_fake_quantizer = config.get("_overwrite_output_fake_quantizer", None)
overwrite_observer = config.get("_overwrite_output_observer", None)
input_output_observed = config.get("_input_output_observed", True)
pattern_to_quantize_handlers[pattern] = \
get_quantize_handler_cls(
observation_type,
dtype_configs,
num_tensor_args_to_observation_type,
overwrite_fake_quantizer,
overwrite_observer,
input_output_observed)
return pattern_to_quantize_handlers
def get_fusion_pattern_to_fuse_handler_cls(
backend_config_dict: Dict[str, Any]) -> Dict[Pattern, Callable]:
fusion_pattern_to_fuse_handlers: Dict[Pattern, Callable] = dict()
for config in backend_config_dict.get("configs", []):
if "fuser_method" in config:
pattern = config["pattern"]
fusion_pattern_to_fuse_handlers[pattern] = DefaultFuseHandler
return fusion_pattern_to_fuse_handlers
# TODO: remove when all uses are changed to backend_config_dict
def get_native_quant_patterns(additional_quant_patterns: Dict[Pattern, QuantizerCls] = None) -> Dict[Pattern, QuantizerCls]:
"""
Return a map from pattern to quantize handlers based on the default patterns and the native backend_config_dict.
The returned map is sorted such that longer patterns will be encountered first when iterating through it.
"""
patterns = get_default_quant_patterns()
if additional_quant_patterns is not None:
patterns = get_combined_dict(patterns, additional_quant_patterns)
# TODO: currently we just extend the quantize handlers generated from
# `get_native_backend_config_dict`
# in the future we can just assign backend_config_dict when everything is defined
for pattern, quantize_handler in get_pattern_to_quantize_handlers(get_native_backend_config_dict()).items():
patterns[pattern] = quantize_handler
return sorted_patterns_dict(patterns)
get_fusion_pattern_to_fuse_handler_cls.__module__ = "torch.ao.quantization.fx.backend_config_utils"
get_native_quant_patterns.__module__ = "torch.ao.quantization.fx.backend_config_utils"
get_pattern_to_quantize_handlers.__module__ = "torch.ao.quantization.fx.backend_config_utils"
__all__ = [
"get_fusion_pattern_to_fuse_handler_cls",
"get_native_quant_patterns",
"get_pattern_to_quantize_handlers",
]

View File

@ -31,13 +31,13 @@ from .qconfig_utils import (
update_qconfig_for_fusion,
is_qconfig_supported_by_dtype_configs,
)
from .backend_config.utils import (
from torch.ao.quantization.backend_config.utils import (
get_root_module_to_quantized_reference_module,
get_pattern_to_dtype_configs,
get_fused_module_classes,
get_qat_module_classes,
)
from .backend_config import get_native_backend_config_dict
from torch.ao.quantization.backend_config import get_native_backend_config_dict
from .graph_module import (
QuantizedGraphModule,
is_observed_module,

View File

@ -15,17 +15,17 @@ from .pattern_utils import (
sorted_patterns_dict,
)
from .backend_config.utils import get_fusion_pattern_to_fuse_handler_cls
from .backend_config.utils import get_fuser_method_mapping
from .backend_config.utils import get_fusion_pattern_to_root_node_getter
from .backend_config.utils import get_fusion_pattern_to_extra_inputs_getter
from .backend_config import get_native_backend_config_dict
from ..backend_config.utils import get_fuser_method_mapping
from ..backend_config.utils import get_fusion_pattern_to_root_node_getter
from ..backend_config.utils import get_fusion_pattern_to_extra_inputs_getter
from ..backend_config import get_native_backend_config_dict
from .backend_config_utils import get_fusion_pattern_to_fuse_handler_cls
from .fusion_patterns import * # noqa: F401,F403
from typing import Callable, Tuple, Dict, Any, Optional, List
from .quantization_types import Pattern, NodePattern
from torch.ao.quantization.quantization_types import Pattern, NodePattern
def fuse(
model: GraphModule,

View File

@ -1,7 +1,7 @@
import torch
from torch.fx.graph import Node, Graph
from ..utils import _parent_name
from .quantization_types import NodePattern, Pattern
from torch.ao.quantization.quantization_types import NodePattern, Pattern
from ..fuser_method_mappings import get_fuser_method_new
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, Optional, Union, List

View File

@ -4,7 +4,7 @@ from torch.fx.graph import (
Graph,
Node,
)
from .quantization_types import Pattern
from torch.ao.quantization.quantization_types import Pattern
from .quantization_patterns import (
QuantizeHandler,
)

View File

@ -3,7 +3,7 @@ from typing import Dict, Any, Tuple, List, Optional
from torch.fx.graph import (
Node,
)
from .quantization_types import Pattern
from torch.ao.quantization.quantization_types import Pattern
from ..qconfig import QConfigAny
from ..fake_quantize import FixedQParamsFakeQuantize
# from .quantization_patterns import BinaryOpQuantizeHandler

View File

@ -32,7 +32,7 @@ from .quantization_patterns import (
QuantizeHandler,
)
from .quantization_types import (
from torch.ao.quantization.quantization_types import (
Pattern,
NodePattern
)
@ -80,16 +80,18 @@ from ..utils import (
activation_is_int8_quantized,
)
from .backend_config.utils import (
get_pattern_to_quantize_handlers,
from ..backend_config.utils import (
get_pattern_to_dtype_configs,
get_pattern_to_input_type_to_index,
get_module_to_qat_module,
get_fusion_pattern_to_root_node_getter,
)
from .backend_config import (
from ..backend_config import (
get_native_backend_config_dict,
)
from .backend_config_utils import (
get_pattern_to_quantize_handlers,
)
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Set
from collections import defaultdict

View File

@ -6,7 +6,7 @@ from torch.fx.graph import (
from .utils import (
all_node_args_have_no_tensors,
)
from .quantization_types import (
from torch.ao.quantization.quantization_types import (
Pattern,
NodePattern,
)

View File

@ -1,6 +1,8 @@
# TODO: the name of this file is probably confusing, remove this file and move the type
# definitions to somewhere else, e.g. to .utils
from typing import Any, Tuple, Union
from torch.fx import Node
from ..utils import Pattern # noqa: F401
from .utils import Pattern # noqa: F401
NodePattern = Union[Tuple[Node, Node], Tuple[Node, Tuple[Node, Node]], Any]
@ -8,3 +10,9 @@ NodePattern = Union[Tuple[Node, Node], Tuple[Node, Tuple[Node, Node]], Any]
# Define separately to prevent circular imports.
# TODO(future PR): improve this.
QuantizerCls = Any
__all__ = [
"Pattern",
"NodePattern",
"QuantizerCls",
]

View File

@ -8,7 +8,7 @@ from torch.nn.intrinsic import _FusedModule
from .fx import fuse # noqa: F401
from .fx import prepare # noqa: F401
from .fx.convert import convert
from .fx import get_tensorrt_backend_config_dict # noqa: F401
from .backend_config import get_tensorrt_backend_config_dict # noqa: F401
from .fx.graph_module import ObservedGraphModule
from .fx.qconfig_utils import (
check_is_valid_convert_custom_config_dict,
@ -59,7 +59,7 @@ def _fuse_fx(
"""
_check_is_graph_module(graph_module)
return fuse(
graph_module, is_qat, fuse_custom_config_dict, backend_config_dict)
graph_module, is_qat, fuse_custom_config_dict, backend_config_dict) # type: ignore[operator]
class Scope(object):
@ -251,7 +251,7 @@ forward graph of the parent module,
equalization_qconfig_dict=equalization_qconfig_dict,
backend_config_dict=backend_config_dict,
is_standalone_module=is_standalone_module,
)
) # type: ignore[operator]
for attr_name in preserved_attributes:
setattr(prepared, attr_name, getattr(model, attr_name))
@ -643,7 +643,7 @@ def convert_fx(
operators should be quantized in the backend, this includes quantization
mode support (static/dynamic/weight_only), dtype support (quint8/qint8 etc.),
observer placement for each operators and fused operators. Detailed
documentation can be found in torch/ao/quantization/fx/backend_config/README.md
documentation can be found in torch/ao/quantization/backend_config/README.md
Return:
A quantized model (GraphModule)

View File

@ -15,3 +15,21 @@ from torch.ao.quantization.fx.pattern_utils import (
get_default_quant_patterns,
get_default_output_activation_post_process_map
)
# QuantizeHandler.__module__ = _NAMESPACE
MatchResult.__module__ = "torch.quantization.fx.pattern_utils"
register_fusion_pattern.__module__ = "torch.quantization.fx.pattern_utils"
get_default_fusion_patterns.__module__ = "torch.quantization.fx.pattern_utils"
register_quant_pattern.__module__ = "torch.quantization.fx.pattern_utils"
get_default_quant_patterns.__module__ = "torch.quantization.fx.pattern_utils"
get_default_output_activation_post_process_map.__module__ = "torch.quantization.fx.pattern_utils"
# __all__ = [
# "QuantizeHandler",
# "MatchResult",
# "register_fusion_pattern",
# "get_default_fusion_patterns",
# "register_quant_pattern",
# "get_default_quant_patterns",
# "get_default_output_activation_post_process_map",
# ]

View File

@ -22,3 +22,18 @@ from torch.ao.quantization.fx.quantization_patterns import (
GeneralTensorShapeOpQuantizeHandler,
StandaloneModuleQuantizeHandler
)
QuantizeHandler.__module__ = "torch.quantization.fx.quantization_patterns"
BinaryOpQuantizeHandler.__module__ = "torch.quantization.fx.quantization_patterns"
CatQuantizeHandler.__module__ = "torch.quantization.fx.quantization_patterns"
ConvReluQuantizeHandler.__module__ = "torch.quantization.fx.quantization_patterns"
LinearReLUQuantizeHandler.__module__ = "torch.quantization.fx.quantization_patterns"
BatchNormQuantizeHandler.__module__ = "torch.quantization.fx.quantization_patterns"
EmbeddingQuantizeHandler.__module__ = "torch.quantization.fx.quantization_patterns"
RNNDynamicQuantizeHandler.__module__ = "torch.quantization.fx.quantization_patterns"
DefaultNodeQuantizeHandler.__module__ = "torch.quantization.fx.quantization_patterns"
FixedQParamsOpQuantizeHandler.__module__ = "torch.quantization.fx.quantization_patterns"
CopyNodeQuantizeHandler.__module__ = "torch.quantization.fx.quantization_patterns"
CustomModuleQuantizeHandler.__module__ = "torch.quantization.fx.quantization_patterns"
GeneralTensorShapeOpQuantizeHandler.__module__ = "torch.quantization.fx.quantization_patterns"
StandaloneModuleQuantizeHandler.__module__ = "torch.quantization.fx.quantization_patterns"

View File

@ -6,7 +6,7 @@ If you are adding a new entry/functionality, please, add it to the
appropriate files under `torch/ao/quantization/fx/`, while adding an import statement
here.
"""
from torch.ao.quantization.fx.quantization_types import (
from torch.ao.quantization.quantization_types import (
Pattern,
QuantizerCls
)