[quant][fx] Move backend_config folder to torch.ao.quantization
Summary:
Following https://github.com/pytorch/rfcs/blob/master/RFC-0019-Extending-PyTorch-Quantization-to-Custom-Backends.md we implemented the backend configuration for the fbgemm/qnnpack backend. It currently lives under the fx folder, but we'd like to use it for all workflows, including eager, fx graph, and define-by-run quantization, so this PR moves it to the torch.ao.quantization namespace where it can be shared. It also moves some fx-specific utility functions to fx/backend_config_utils.py; some files are kept in the fx folder (quantize_handler.py and fuse_handler.py).

Test Plan:
python test/test_quantization.py TestQuantizeFx
python test/test_quantization.py TestQuantizeFxOps
python test/test_quantization.py TestQuantizeFxModels
python test/test_quantization.py TestAOMigrationQuantization
python test/test_quantization.py TestAOMigrationQuantizationFx

Reviewers:

Subscribers:

Tasks:

Tags:

Pull Request resolved: https://github.com/pytorch/pytorch/pull/75823
Approved by: https://github.com/vkuzo
Commit: 74454bdb46 (parent: cbb9b33c85), committed by PyTorch MergeBot
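
For downstream code, the visible effect of this change is the import-path move repeated throughout the diff below. A minimal sketch of the old and new spellings (hypothetical usage, not part of the diff itself):

    # old location, under the fx folder:
    #   from torch.ao.quantization.fx.backend_config import get_native_backend_config_dict
    # new shared location, usable from eager, fx graph, and define-by-run workflows:
    from torch.ao.quantization.backend_config import get_native_backend_config_dict

    backend_config_dict = get_native_backend_config_dict()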
@@ -133,7 +133,7 @@ coverage_ignore_functions = [
     "unregister_custom_op_symbolic",
     # torch.ao.quantization
     "default_eval_fn",
-    # torch.ao.quantization.fx.backend_config
+    # torch.ao.quantization.backend_config
     "validate_backend_config_dict",
     # torch.backends
     "disable_global_flags",
@@ -910,7 +910,7 @@ Numerical Debugging (prototype)
 .. py:module:: torch.ao.ns.fx
 .. py:module:: torch.ao.quantization
 .. py:module:: torch.ao.quantization.fx
-.. py:module:: torch.ao.quantization.fx.backend_config
+.. py:module:: torch.ao.quantization.backend_config
 .. py:module:: torch.ao.sparsity
 .. py:module:: torch.ao.sparsity.experimental
 .. py:module:: torch.ao.sparsity.experimental.pruner
@@ -3,7 +3,7 @@ This script will generate default values of quantization configs.
 These are for use in the documentation.
 """
 
-from torch.ao.quantization.fx.backend_config import get_native_backend_config_dict
+from torch.ao.quantization.backend_config import get_native_backend_config_dict
 import os.path
 from pprint import pprint
 
@@ -80,10 +80,7 @@
         "Union",
         "get_combined_dict"
     ],
-    "torch.ao.quantization.fx.backend_config.fuse_handler": [
-        "DefaultFuseHandler"
-    ],
-    "torch.ao.quantization.fx.backend_config.native": [
+    "torch.ao.quantization.backend_config.native": [
         "Any",
         "Dict",
         "FixedQParamsFakeQuantize",
@@ -100,39 +97,21 @@
         "reverse3",
         "reverse_sequential_wrapper2"
     ],
-    "torch.ao.quantization.fx.backend_config.observation_type": [
+    "torch.ao.quantization.backend_config.observation_type": [
         "Enum"
     ],
-    "torch.ao.quantization.fx.backend_config.quantize_handler": [
-        "Any",
-        "Callable",
-        "Dict",
-        "NodePattern",
-        "ObservationType",
-        "Optional",
-        "Pattern",
-        "QuantizeHandler",
-        "activation_dtype"
-    ],
-    "torch.ao.quantization.fx.backend_config.tensorrt": [
+    "torch.ao.quantization.backend_config.tensorrt": [
         "ObservationType",
         "reverse_sequential_wrapper2"
     ],
-    "torch.ao.quantization.fx.backend_config.utils": [
-        "Any",
-        "Callable",
-        "Dict",
-        "List",
-        "Pattern",
-        "QuantizerCls",
-        "Tuple",
-        "Union",
-        "get_combined_dict",
-        "get_default_quant_patterns",
-        "get_fuse_handler_cls",
-        "get_native_backend_config_dict",
-        "get_quantize_handler_cls",
-        "sorted_patterns_dict"
+    "torch.ao.quantization.quantization_types": [
+        "Any",
+        "Node",
+        "NodePattern",
+        "Pattern",
+        "QuantizerCls",
+        "Tuple",
+        "Union"
     ],
     "torch.ao.quantization.fx.convert": [
         "Any",
@@ -380,6 +359,23 @@
         "map_arg",
         "namedtuple"
     ],
+    "torch.ao.quantization.fx.backend_config_utils": [
+        "Any",
+        "Callable",
+        "DefaultFuseHandler",
+        "Dict",
+        "NodePattern",
+        "ObservationType",
+        "Optional",
+        "Pattern",
+        "QuantizeHandler",
+        "QuantizerCls",
+        "activation_dtype",
+        "get_combined_dict",
+        "get_default_quant_patterns",
+        "get_native_backend_config_dict",
+        "sorted_patterns_dict"
+    ],
     "torch.ao.quantization.observer": [
         "ABC",
         "ABCMeta",
@@ -168,15 +168,10 @@ class TestAOMigrationQuantizationFx(AOMigrationTestCase):
         ]
         self._test_function_import('fx.fusion_patterns', function_list)
 
-    def test_package_import_fx_quantization_types(self):
-        self._test_package_import('fx.quantization_types')
-
-    def test_function_import_fx_quantization_types(self):
-        function_list = [
-            'Pattern',
-            'QuantizerCls'
-        ]
-        self._test_function_import('fx.quantization_types', function_list)
+    # we removed matching test for torch.quantization.fx.quantization_types
+    # old: torch.quantization.fx.quantization_types
+    # new: torch.ao.quantization.quantization_types
+    # both are valid, but we'll deprecate the old path in the future
 
     def test_package_import_fx_utils(self):
         self._test_package_import('fx.utils')
@@ -71,8 +71,8 @@ from torch.ao.ns._numeric_suite_fx import (
     extract_shadow_logger_info,
     extend_logger_results_with_comparison,
 )
-from torch.ao.quantization.fx.backend_config import get_native_backend_config_dict
-from torch.ao.quantization.fx.backend_config.utils import get_pattern_to_quantize_handlers
+from torch.ao.quantization.backend_config import get_native_backend_config_dict
+from torch.ao.quantization.fx.backend_config_utils import get_pattern_to_quantize_handlers
 
 
 # Note: these models are not for use outside of this file. While it's good
@@ -8,7 +8,7 @@ from torch.fx.graph import Node
 
 from torch.ao.quantization.utils import getattr_from_fqn
 from .ns_types import NSNodeTargetType
-from torch.ao.quantization.fx.backend_config.utils import get_native_quant_patterns
+from torch.ao.quantization.fx.backend_config_utils import get_native_quant_patterns
 from torch.ao.quantization import (
     ObserverBase,
     FakeQuantizeBase,
@@ -4,3 +4,8 @@ from .native import get_native_backend_config_dict
 # TODO: add more validations
 def validate_backend_config_dict(backend_config_dict):
     return "configs" in backend_config_dict
+
+__all__ = [
+    "get_native_backend_config_dict",
+    "get_tensorrt_backend_config_dict",
+]
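
A minimal usage sketch for the validator above (hypothetical snippet, not part of this diff); note that validate_backend_config_dict only checks for the top-level "configs" key, not the full schema:

    from torch.ao.quantization.backend_config import (
        get_native_backend_config_dict,
        validate_backend_config_dict,
    )

    assert validate_backend_config_dict(get_native_backend_config_dict())
    assert not validate_backend_config_dict({})  # missing the "configs" key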
@@ -2,19 +2,19 @@ from collections import namedtuple
 from typing import List, Dict, Any
 import operator
 import torch
-from .observation_type import ObservationType
+from torch.ao.quantization.backend_config.observation_type import ObservationType
 import torch.nn.functional as F
 import torch.nn as nn
 import torch.nn.intrinsic as nni
 import torch.nn.intrinsic.qat as nniqat
 import torch.nn.qat as nnqat
 import torch.nn.quantized._reference as nnqr
-from ...observer import (
+from ..observer import (
     default_affine_fixed_qparams_observer,
     default_symmetric_fixed_qparams_observer,
 )
-from ...fake_quantize import FixedQParamsFakeQuantize
-from ...fuser_method_mappings import (
+from ..fake_quantize import FixedQParamsFakeQuantize
+from ..fuser_method_mappings import (
     reverse_sequential_wrapper2,
     reverse2,
     reverse3,
@@ -711,3 +711,7 @@ def get_native_backend_config_dict():
             *_get_embedding_op_configs(),
         ],
     }
+
+__all__ = [
+    "get_native_backend_config_dict",
+]
@@ -4,7 +4,7 @@ import torch.nn.qat as nnqat
 import torch.nn.intrinsic as nni
 import torch.nn.intrinsic.qat as nniqat
 
-from ...fuser_method_mappings import reverse_sequential_wrapper2
+from ..fuser_method_mappings import reverse_sequential_wrapper2
 
 def get_tensorrt_backend_config_dict():
     """ Get the backend config dictionary for tensorrt backend
@@ -218,3 +218,7 @@ def get_tensorrt_backend_config_dict():
             identity_config,
         ]
     }
+
+__all__ = [
+    "get_tensorrt_backend_config_dict",
+]
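
After the move, the TensorRT config is fetched from the shared namespace the same way as the native one; a sketch (hypothetical snippet, assuming each config entry carries a "pattern" key, as the utility functions in this diff do):

    from torch.ao.quantization.backend_config import get_tensorrt_backend_config_dict

    trt_config = get_tensorrt_backend_config_dict()
    for config in trt_config.get("configs", []):
        print(config["pattern"])  # one entry per quantizable or fusable pattern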
@@ -1,41 +1,8 @@
 from typing import Dict, Any, List, Callable, Union, Tuple
 
 import torch
-from torch.ao.quantization.utils import get_combined_dict
-from torch.ao.quantization.fx.pattern_utils import get_default_quant_patterns, sorted_patterns_dict
 import torch.nn as nn
-from .quantize_handler import get_quantize_handler_cls
-from .fuse_handler import get_fuse_handler_cls
-from .native import get_native_backend_config_dict
-from ..quantization_types import Pattern, QuantizerCls
-
-def get_pattern_to_quantize_handlers(
-        backend_config_dict: Dict[str, Any]) -> Dict[Pattern, QuantizerCls]:
-    """
-    Note: Quantize handler is just a holder for some check methods like
-    (should_insert_observer_for_output), maybe this can be a enum as well,
-    we can refactor this after we convert the path for fbgemm/qnnpack fully to the
-    new path, this is not exposed to backend developers
-    """
-    pattern_to_quantize_handlers = dict()
-    for config in backend_config_dict.get("configs", []):
-        pattern = config["pattern"]
-        observation_type = config.get("observation_type", None)
-        dtype_configs = config["dtype_configs"]
-        num_tensor_args_to_observation_type = config.get("num_tensor_args_to_observation_type", {})
-        overwrite_fake_quantizer = config.get("_overwrite_output_fake_quantizer", None)
-        overwrite_observer = config.get("_overwrite_output_observer", None)
-        input_output_observed = config.get("_input_output_observed", True)
-        pattern_to_quantize_handlers[pattern] = \
-            get_quantize_handler_cls(
-                observation_type,
-                dtype_configs,
-                num_tensor_args_to_observation_type,
-                overwrite_fake_quantizer,
-                overwrite_observer,
-                input_output_observed)
-
-    return pattern_to_quantize_handlers
+from ..quantization_types import Pattern
 
 def get_pattern_to_dtype_configs(
         backend_config_dict: Dict[str, Any]) -> Dict[Pattern, List[Dict[str, Any]]]:
@@ -83,17 +50,6 @@ def get_root_module_to_quantized_reference_module(
         mapping[config["root_module"]] = config["reference_quantized_module_for_root"]
     return mapping
 
-def get_fusion_pattern_to_fuse_handler_cls(
-        backend_config_dict: Dict[str, Any]) -> Dict[Pattern, Callable]:
-    fusion_pattern_to_fuse_handlers = dict()
-    for config in backend_config_dict.get("configs", []):
-        if "fuser_method" in config:
-            pattern = config["pattern"]
-            fusion_pattern_to_fuse_handlers[pattern] = \
-                get_fuse_handler_cls()
-
-    return fusion_pattern_to_fuse_handlers
-
 def get_fuser_method_mapping(
         backend_config_dict: Dict[str, Any]) -> Dict[Pattern, Union[nn.Sequential, Callable]]:
     fuser_method_mapping : Dict[Pattern, Union[nn.Sequential, Callable]] = dict()
@@ -159,19 +115,3 @@ def get_fusion_pattern_to_extra_inputs_getter(
         extra_inputs_getter_mapping[pattern] = extra_inputs_getter
 
     return extra_inputs_getter_mapping
-
-# TODO: remove when all uses are changed to backend_config_dict
-def get_native_quant_patterns(additional_quant_patterns: Dict[Pattern, QuantizerCls] = None) -> Dict[Pattern, QuantizerCls]:
-    """
-    Return a map from pattern to quantize handlers based on the default patterns and the native backend_config_dict.
-    The returned map is sorted such that longer patterns will be encountered first when iterating through it.
-    """
-    patterns = get_default_quant_patterns()
-    if additional_quant_patterns is not None:
-        patterns = get_combined_dict(patterns, additional_quant_patterns)
-    # TODO: currently we just extend the quantize handlers generated from
-    # `get_native_backend_config_dict`
-    # in the future we can just assign backend_config_dict when everything is defined
-    for pattern, quantize_handler in get_pattern_to_quantize_handlers(get_native_backend_config_dict()).items():
-        patterns[pattern] = quantize_handler
-    return sorted_patterns_dict(patterns)
@@ -1,4 +1,3 @@
 from .prepare import prepare
 from .convert import convert
 from .fuse import fuse
-from .backend_config import get_tensorrt_backend_config_dict
@@ -1,5 +0,0 @@
-from ..fusion_patterns import DefaultFuseHandler
-
-# TODO: move DefaultFuseHandler
-def get_fuse_handler_cls():
-    return DefaultFuseHandler
@@ -1,67 +0,0 @@
-import torch
-from typing import Dict, Callable, Any, Optional
-from .observation_type import ObservationType
-from ..quantization_patterns import QuantizeHandler
-from ..quantization_types import Pattern, NodePattern
-from ...utils import (
-    activation_dtype,
-)
-
-def get_quantize_handler_cls(
-        observation_type,
-        dtype_configs,
-        num_tensor_args_to_observation_type,
-        overwrite_output_fake_quantizer,
-        overwrite_output_observer,
-        input_output_observed):
-
-    class ConfigurableQuantizeHandler(QuantizeHandler):
-        def __init__(
-                self,
-                node_pattern: NodePattern,
-                modules: Dict[str, torch.nn.Module],
-                root_node_getter: Callable = None):
-            super().__init__(node_pattern, modules, root_node_getter)
-            if num_tensor_args_to_observation_type:
-                assert self.num_tensor_args in num_tensor_args_to_observation_type, \
-                    f"Must provide observation_type config for tensor number {self.num_tensor_args}" \
-                    f" in num_tensor_args_to_observation_type for {node_pattern}"
-                self.observation_type = num_tensor_args_to_observation_type[self.num_tensor_args]
-            else:
-                self.observation_type = observation_type
-            self.dtype_configs = dtype_configs
-            self.overwrite_output_fake_quantizer = overwrite_output_fake_quantizer
-            self.overwrite_output_observer = overwrite_output_observer
-            self.input_output_observed_ = input_output_observed
-
-        def is_general_tensor_value_op(self) -> bool:
-            return self.observation_type == ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT
-
-        # TODO: change this to output activation
-        def get_activation_ctr(
-                self,
-                qconfig: Any,
-                pattern: Pattern,
-                is_training: bool,
-        ) -> Optional[Callable]:
-            """
-            Returns the constructor for the activation observer which should be
-            used for the pattern matched to this handler. Some handlers override
-            this to a different value than what is specified in the qconfig.
-            """
-            act_dtype = activation_dtype(qconfig)
-            # TODO: change to is_qat
-            if is_training:
-                if act_dtype == torch.quint8 and self.overwrite_output_fake_quantizer is not None:
-                    return self.overwrite_output_fake_quantizer
-            else:
-                if act_dtype == torch.quint8 and self.overwrite_output_observer is not None:
-                    return self.overwrite_output_observer
-            return qconfig.activation
-
-        # This is temporary, and will be removed soon
-        def input_output_observed(self):
-            return self.input_output_observed_
-
-
-    return ConfigurableQuantizeHandler
torch/ao/quantization/fx/backend_config_utils.py (new file, 141 lines)
@@ -0,0 +1,141 @@
+import torch
+from torch.ao.quantization.fx.pattern_utils import get_default_quant_patterns, sorted_patterns_dict
+from torch.ao.quantization.backend_config import get_native_backend_config_dict
+from torch.ao.quantization.backend_config.observation_type import ObservationType
+from torch.ao.quantization.quantization_types import (
+    Pattern,
+    NodePattern,
+    QuantizerCls,
+)
+from torch.ao.quantization.utils import (
+    activation_dtype,
+    get_combined_dict,
+)
+
+from .quantization_patterns import QuantizeHandler
+from .fusion_patterns import DefaultFuseHandler
+
+from typing import Dict, Any, Callable, Optional
+
+def get_quantize_handler_cls(
+        observation_type,
+        dtype_configs,
+        num_tensor_args_to_observation_type,
+        overwrite_output_fake_quantizer,
+        overwrite_output_observer,
+        input_output_observed):
+
+    class ConfigurableQuantizeHandler(QuantizeHandler):
+        def __init__(
+                self,
+                node_pattern: NodePattern,
+                modules: Dict[str, torch.nn.Module],
+                root_node_getter: Callable = None):
+            super().__init__(node_pattern, modules, root_node_getter)
+            if num_tensor_args_to_observation_type:
+                assert self.num_tensor_args in num_tensor_args_to_observation_type, \
+                    f"Must provide observation_type config for tensor number {self.num_tensor_args}" \
+                    f" in num_tensor_args_to_observation_type for {node_pattern}"
+                self.observation_type = num_tensor_args_to_observation_type[self.num_tensor_args]
+            else:
+                self.observation_type = observation_type
+            self.dtype_configs = dtype_configs
+            self.overwrite_output_fake_quantizer = overwrite_output_fake_quantizer
+            self.overwrite_output_observer = overwrite_output_observer
+            self.input_output_observed_ = input_output_observed
+
+        def is_general_tensor_value_op(self) -> bool:
+            return self.observation_type == ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT
+
+        # TODO: change this to output activation
+        def get_activation_ctr(
+                self,
+                qconfig: Any,
+                pattern: Pattern,
+                is_training: bool,
+        ) -> Optional[Callable]:
+            """
+            Returns the constructor for the activation observer which should be
+            used for the pattern matched to this handler. Some handlers override
+            this to a different value than what is specified in the qconfig.
+            """
+            act_dtype = activation_dtype(qconfig)
+            # TODO: change to is_qat
+            if is_training:
+                if act_dtype == torch.quint8 and self.overwrite_output_fake_quantizer is not None:
+                    return self.overwrite_output_fake_quantizer
+            else:
+                if act_dtype == torch.quint8 and self.overwrite_output_observer is not None:
+                    return self.overwrite_output_observer
+            return qconfig.activation
+
+        # This is temporary, and will be removed soon
+        def input_output_observed(self):
+            return self.input_output_observed_
+
+
+    return ConfigurableQuantizeHandler
+
+def get_pattern_to_quantize_handlers(
+        backend_config_dict: Dict[str, Any]) -> Dict[Pattern, QuantizerCls]:
+    """
+    Note: Quantize handler is just a holder for some check methods like
+    (should_insert_observer_for_output), maybe this can be a enum as well,
+    we can refactor this after we convert the path for fbgemm/qnnpack fully to the
+    new path, this is not exposed to backend developers
+    """
+    pattern_to_quantize_handlers = dict()
+    for config in backend_config_dict.get("configs", []):
+        pattern = config["pattern"]
+        observation_type = config.get("observation_type", None)
+        dtype_configs = config["dtype_configs"]
+        num_tensor_args_to_observation_type = config.get("num_tensor_args_to_observation_type", {})
+        overwrite_fake_quantizer = config.get("_overwrite_output_fake_quantizer", None)
+        overwrite_observer = config.get("_overwrite_output_observer", None)
+        input_output_observed = config.get("_input_output_observed", True)
+        pattern_to_quantize_handlers[pattern] = \
+            get_quantize_handler_cls(
+                observation_type,
+                dtype_configs,
+                num_tensor_args_to_observation_type,
+                overwrite_fake_quantizer,
+                overwrite_observer,
+                input_output_observed)
+
+    return pattern_to_quantize_handlers
+
+def get_fusion_pattern_to_fuse_handler_cls(
+        backend_config_dict: Dict[str, Any]) -> Dict[Pattern, Callable]:
+    fusion_pattern_to_fuse_handlers: Dict[Pattern, Callable] = dict()
+    for config in backend_config_dict.get("configs", []):
+        if "fuser_method" in config:
+            pattern = config["pattern"]
+            fusion_pattern_to_fuse_handlers[pattern] = DefaultFuseHandler
+
+    return fusion_pattern_to_fuse_handlers
+
+# TODO: remove when all uses are changed to backend_config_dict
+def get_native_quant_patterns(additional_quant_patterns: Dict[Pattern, QuantizerCls] = None) -> Dict[Pattern, QuantizerCls]:
+    """
+    Return a map from pattern to quantize handlers based on the default patterns and the native backend_config_dict.
+    The returned map is sorted such that longer patterns will be encountered first when iterating through it.
+    """
+    patterns = get_default_quant_patterns()
+    if additional_quant_patterns is not None:
+        patterns = get_combined_dict(patterns, additional_quant_patterns)
+    # TODO: currently we just extend the quantize handlers generated from
+    # `get_native_backend_config_dict`
+    # in the future we can just assign backend_config_dict when everything is defined
+    for pattern, quantize_handler in get_pattern_to_quantize_handlers(get_native_backend_config_dict()).items():
+        patterns[pattern] = quantize_handler
+    return sorted_patterns_dict(patterns)
+
+get_fusion_pattern_to_fuse_handler_cls.__module__ = "torch.ao.quantization.fx.backend_config_utils"
+get_native_quant_patterns.__module__ = "torch.ao.quantization.fx.backend_config_utils"
+get_pattern_to_quantize_handlers.__module__ = "torch.ao.quantization.fx.backend_config_utils"
+
+__all__ = [
+    "get_fusion_pattern_to_fuse_handler_cls",
+    "get_native_quant_patterns",
+    "get_pattern_to_quantize_handlers",
+]
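
A sketch of how this new fx-side module is driven (hypothetical snippet mirroring what fuse/prepare do internally, per the hunks below):

    from torch.ao.quantization.backend_config import get_native_backend_config_dict
    from torch.ao.quantization.fx.backend_config_utils import get_pattern_to_quantize_handlers

    # maps each pattern in the backend config to a ConfigurableQuantizeHandler subclass
    handlers = get_pattern_to_quantize_handlers(get_native_backend_config_dict())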
@@ -31,13 +31,13 @@ from .qconfig_utils import (
     update_qconfig_for_fusion,
     is_qconfig_supported_by_dtype_configs,
 )
-from .backend_config.utils import (
+from torch.ao.quantization.backend_config.utils import (
     get_root_module_to_quantized_reference_module,
     get_pattern_to_dtype_configs,
     get_fused_module_classes,
     get_qat_module_classes,
 )
-from .backend_config import get_native_backend_config_dict
+from torch.ao.quantization.backend_config import get_native_backend_config_dict
 from .graph_module import (
     QuantizedGraphModule,
     is_observed_module,
@@ -15,17 +15,17 @@ from .pattern_utils import (
     sorted_patterns_dict,
 )
 
-from .backend_config.utils import get_fusion_pattern_to_fuse_handler_cls
-from .backend_config.utils import get_fuser_method_mapping
-from .backend_config.utils import get_fusion_pattern_to_root_node_getter
-from .backend_config.utils import get_fusion_pattern_to_extra_inputs_getter
-from .backend_config import get_native_backend_config_dict
+from ..backend_config.utils import get_fuser_method_mapping
+from ..backend_config.utils import get_fusion_pattern_to_root_node_getter
+from ..backend_config.utils import get_fusion_pattern_to_extra_inputs_getter
+from ..backend_config import get_native_backend_config_dict
+from .backend_config_utils import get_fusion_pattern_to_fuse_handler_cls
 
 from .fusion_patterns import *  # noqa: F401,F403
 
 from typing import Callable, Tuple, Dict, Any, Optional, List
 
-from .quantization_types import Pattern, NodePattern
+from torch.ao.quantization.quantization_types import Pattern, NodePattern
 
 def fuse(
         model: GraphModule,
@@ -1,7 +1,7 @@
 import torch
 from torch.fx.graph import Node, Graph
 from ..utils import _parent_name
-from .quantization_types import NodePattern, Pattern
+from torch.ao.quantization.quantization_types import NodePattern, Pattern
 from ..fuser_method_mappings import get_fuser_method_new
 from abc import ABC, abstractmethod
 from typing import Any, Callable, Dict, Optional, Union, List
@@ -4,7 +4,7 @@ from torch.fx.graph import (
     Graph,
     Node,
 )
-from .quantization_types import Pattern
+from torch.ao.quantization.quantization_types import Pattern
 from .quantization_patterns import (
     QuantizeHandler,
 )
@@ -3,7 +3,7 @@ from typing import Dict, Any, Tuple, List, Optional
 from torch.fx.graph import (
     Node,
 )
-from .quantization_types import Pattern
+from torch.ao.quantization.quantization_types import Pattern
 from ..qconfig import QConfigAny
 from ..fake_quantize import FixedQParamsFakeQuantize
 # from .quantization_patterns import BinaryOpQuantizeHandler
@@ -32,7 +32,7 @@ from .quantization_patterns import (
     QuantizeHandler,
 )
 
-from .quantization_types import (
+from torch.ao.quantization.quantization_types import (
     Pattern,
     NodePattern
 )
@@ -80,16 +80,18 @@ from ..utils import (
     activation_is_int8_quantized,
 )
 
-from .backend_config.utils import (
-    get_pattern_to_quantize_handlers,
+from ..backend_config.utils import (
     get_pattern_to_dtype_configs,
     get_pattern_to_input_type_to_index,
     get_module_to_qat_module,
     get_fusion_pattern_to_root_node_getter,
 )
-from .backend_config import (
+from ..backend_config import (
     get_native_backend_config_dict,
 )
+from .backend_config_utils import (
+    get_pattern_to_quantize_handlers,
+)
 
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Set
 from collections import defaultdict
@@ -6,7 +6,7 @@ from torch.fx.graph import (
 from .utils import (
     all_node_args_have_no_tensors,
 )
-from .quantization_types import (
+from torch.ao.quantization.quantization_types import (
     Pattern,
     NodePattern,
 )
@@ -1,6 +1,8 @@
+# TODO: the name of this file is probably confusing, remove this file and move the type
+# definitions to somewhere else, e.g. to .utils
 from typing import Any, Tuple, Union
 from torch.fx import Node
-from ..utils import Pattern  # noqa: F401
+from .utils import Pattern  # noqa: F401
 
 NodePattern = Union[Tuple[Node, Node], Tuple[Node, Tuple[Node, Node]], Any]
 
@@ -8,3 +10,9 @@ NodePattern = Union[Tuple[Node, Node], Tuple[Node, Tuple[Node, Node]], Any]
 # Define separately to prevent circular imports.
 # TODO(future PR): improve this.
 QuantizerCls = Any
+
+__all__ = [
+    "Pattern",
+    "NodePattern",
+    "QuantizerCls",
+]
@@ -8,7 +8,7 @@ from torch.nn.intrinsic import _FusedModule
 from .fx import fuse  # noqa: F401
 from .fx import prepare  # noqa: F401
 from .fx.convert import convert
-from .fx import get_tensorrt_backend_config_dict  # noqa: F401
+from .backend_config import get_tensorrt_backend_config_dict  # noqa: F401
 from .fx.graph_module import ObservedGraphModule
 from .fx.qconfig_utils import (
     check_is_valid_convert_custom_config_dict,
@@ -59,7 +59,7 @@ def _fuse_fx(
     """
     _check_is_graph_module(graph_module)
     return fuse(
-        graph_module, is_qat, fuse_custom_config_dict, backend_config_dict)
+        graph_module, is_qat, fuse_custom_config_dict, backend_config_dict)  # type: ignore[operator]
 
 
 class Scope(object):
@@ -251,7 +251,7 @@ forward graph of the parent module,
         equalization_qconfig_dict=equalization_qconfig_dict,
         backend_config_dict=backend_config_dict,
         is_standalone_module=is_standalone_module,
-    )
+    )  # type: ignore[operator]
 
     for attr_name in preserved_attributes:
         setattr(prepared, attr_name, getattr(model, attr_name))
@@ -643,7 +643,7 @@ def convert_fx(
         operators should be quantized in the backend, this includes quantization
         mode support (static/dynamic/weight_only), dtype support (quint8/qint8 etc.),
         observer placement for each operators and fused operators. Detailed
-        documentation can be found in torch/ao/quantization/fx/backend_config/README.md
+        documentation can be found in torch/ao/quantization/backend_config/README.md
 
     Return:
         A quantized model (GraphModule)
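
The corrected README path documents the backend_config_dict argument referenced in this docstring; a sketch of the end-to-end fx flow it feeds into (hypothetical snippet, assuming the prepare_fx/convert_fx signatures of this release):

    import torch
    from torch.ao.quantization import get_default_qconfig
    from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx
    from torch.ao.quantization.backend_config import get_native_backend_config_dict

    model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU()).eval()
    qconfig_dict = {"": get_default_qconfig("fbgemm")}
    backend = get_native_backend_config_dict()
    prepared = prepare_fx(model, qconfig_dict, backend_config_dict=backend)
    prepared(torch.randn(1, 4))  # calibration pass
    quantized = convert_fx(prepared, backend_config_dict=backend)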
@@ -15,3 +15,21 @@ from torch.ao.quantization.fx.pattern_utils import (
     get_default_quant_patterns,
     get_default_output_activation_post_process_map
 )
+
+# QuantizeHandler.__module__ = _NAMESPACE
+MatchResult.__module__ = "torch.quantization.fx.pattern_utils"
+register_fusion_pattern.__module__ = "torch.quantization.fx.pattern_utils"
+get_default_fusion_patterns.__module__ = "torch.quantization.fx.pattern_utils"
+register_quant_pattern.__module__ = "torch.quantization.fx.pattern_utils"
+get_default_quant_patterns.__module__ = "torch.quantization.fx.pattern_utils"
+get_default_output_activation_post_process_map.__module__ = "torch.quantization.fx.pattern_utils"
+
+# __all__ = [
+#     "QuantizeHandler",
+#     "MatchResult",
+#     "register_fusion_pattern",
+#     "get_default_fusion_patterns",
+#     "register_quant_pattern",
+#     "get_default_quant_patterns",
+#     "get_default_output_activation_post_process_map",
+# ]
@@ -22,3 +22,18 @@ from torch.ao.quantization.fx.quantization_patterns import (
     GeneralTensorShapeOpQuantizeHandler,
     StandaloneModuleQuantizeHandler
 )
+
+QuantizeHandler.__module__ = "torch.quantization.fx.quantization_patterns"
+BinaryOpQuantizeHandler.__module__ = "torch.quantization.fx.quantization_patterns"
+CatQuantizeHandler.__module__ = "torch.quantization.fx.quantization_patterns"
+ConvReluQuantizeHandler.__module__ = "torch.quantization.fx.quantization_patterns"
+LinearReLUQuantizeHandler.__module__ = "torch.quantization.fx.quantization_patterns"
+BatchNormQuantizeHandler.__module__ = "torch.quantization.fx.quantization_patterns"
+EmbeddingQuantizeHandler.__module__ = "torch.quantization.fx.quantization_patterns"
+RNNDynamicQuantizeHandler.__module__ = "torch.quantization.fx.quantization_patterns"
+DefaultNodeQuantizeHandler.__module__ = "torch.quantization.fx.quantization_patterns"
+FixedQParamsOpQuantizeHandler.__module__ = "torch.quantization.fx.quantization_patterns"
+CopyNodeQuantizeHandler.__module__ = "torch.quantization.fx.quantization_patterns"
+CustomModuleQuantizeHandler.__module__ = "torch.quantization.fx.quantization_patterns"
+GeneralTensorShapeOpQuantizeHandler.__module__ = "torch.quantization.fx.quantization_patterns"
+StandaloneModuleQuantizeHandler.__module__ = "torch.quantization.fx.quantization_patterns"
@@ -6,7 +6,7 @@ If you are adding a new entry/functionality, please, add it to the
 appropriate files under `torch/ao/quantization/fx/`, while adding an import statement
 here.
 """
-from torch.ao.quantization.fx.quantization_types import (
+from torch.ao.quantization.quantization_types import (
     Pattern,
     QuantizerCls
 )