[Quant][fx] Implement BackendConfig (part 1) (#81469)

Summary: Following https://github.com/pytorch/pytorch/pull/78452
and https://github.com/pytorch/pytorch/pull/79066, this commit
is part 1 of the broader effort to replace `backend_config_dict`
with a python config object, a more formal and robust API that
leads to better user experience. Note that there is no change in
behavior in this commit by itself. A future commit (part 2) will
replace all existing usages of `backend_config_dict` with the
`BackendConfig` object added in this commit.

Test Plan:
python test/test_quantization.py TestBackendConfig

Reviewers: jerryzh168

Subscribers: jerryzh168
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81469
Approved by: https://github.com/jerryzh168
This commit is contained in:
Andrew Or
2022-07-22 19:12:32 -07:00
committed by PyTorch MergeBot
parent 1a18ff3247
commit 194255bb56
5 changed files with 714 additions and 1 deletions

View File

@ -0,0 +1,319 @@
# Owner(s): ["oncall: quantization"]
import torch
import torch.nn.intrinsic as nni
import torch.nn.qat as nnqat
import torch.nn.quantized._reference as nnqr
from torch.testing._internal.common_quantization import QuantizationTestCase
from torch.ao.quantization.backend_config import (
BackendConfig,
BackendPatternConfig,
DTypeConfig,
ObservationType,
)
from torch.ao.quantization.fake_quantize import FixedQParamsFakeQuantize
from torch.ao.quantization.fuser_method_mappings import reverse_sequential_wrapper2
from torch.ao.quantization.fx.quantization_patterns import _default_root_node_getter
from torch.ao.quantization.observer import default_fixed_qparams_range_0to1_observer
class TestBackendConfig(QuantizationTestCase):
    """
    Tests for the config objects exported by `torch.ao.quantization.backend_config`:
    `DTypeConfig`, `BackendPatternConfig` and `BackendConfig`.

    Covers the default attribute values, the chainable setters, and the conversions
    to and from the dictionary format (`from_dict` / `to_dict`).
    """

    # =============
    #  DTypeConfig
    # =============

    # Fixtures: every DTypeConfig below has a dictionary twin that holds only
    # the fields that were explicitly set on the object.
    dtype_config1 = DTypeConfig(
        input_dtype=torch.quint8,
        output_dtype=torch.quint8,
        weight_dtype=torch.qint8,
        bias_dtype=torch.float
    )

    dtype_config2 = DTypeConfig(
        input_dtype=torch.float16,
        output_dtype=torch.float,
        is_dynamic=True
    )

    dtype_config_dict1 = {
        "input_dtype": torch.quint8,
        "output_dtype": torch.quint8,
        "weight_dtype": torch.qint8,
        "bias_dtype": torch.float,
    }

    dtype_config_dict2 = {
        "input_dtype": torch.float16,
        "output_dtype": torch.float,
        "is_dynamic": True,
    }

    def test_dtype_config_from_dict(self):
        """`DTypeConfig.from_dict` rebuilds the fixture objects from their dict twins."""
        self.assertEqual(DTypeConfig.from_dict(self.dtype_config_dict1), self.dtype_config1)
        self.assertEqual(DTypeConfig.from_dict(self.dtype_config_dict2), self.dtype_config2)

    def test_dtype_config_to_dict(self):
        """`DTypeConfig.to_dict` emits only the fields that were set."""
        self.assertEqual(self.dtype_config1.to_dict(), self.dtype_config_dict1)
        self.assertEqual(self.dtype_config2.to_dict(), self.dtype_config_dict2)

    # ======================
    #  BackendPatternConfig
    # ======================

    # Fixtures shared by the BackendPatternConfig tests below.
    _fuser_method = reverse_sequential_wrapper2(nni.LinearReLU)

    _num_tensor_args_to_observation_type = {
        0: ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
        1: ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT,
        2: ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
    }

    _input_type_to_index = {
        "bias": 0,
        "input": 1,
        "weight": 2,
    }

    _fake_quantize = FixedQParamsFakeQuantize.with_args(observer=default_fixed_qparams_range_0to1_observer)

    def _extra_inputs_getter(self, p):
        """Dummy extra-inputs getter; only its identity matters to the tests."""
        return (torch.rand(3, 3),)

    def _get_backend_op_config1(self):
        """Build a config that exercises every public setter."""
        return (
            BackendPatternConfig((torch.nn.ReLU, torch.nn.Linear))
            .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT)
            .add_dtype_config(self.dtype_config1)
            .add_dtype_config(self.dtype_config2)
            .set_root_module(torch.nn.Linear)
            .set_qat_module(nnqat.Linear)
            .set_reference_quantized_module(nnqr.Linear)
            .set_fused_module(nni.LinearReLU)
            .set_fuser_method(self._fuser_method)
        )

    def _get_backend_op_config2(self):
        """Build a config that exercises every temporary/internal setter."""
        return (
            BackendPatternConfig(torch.add)
            .add_dtype_config(self.dtype_config2)
            ._set_root_node_getter(_default_root_node_getter)
            ._set_extra_inputs_getter(self._extra_inputs_getter)
            ._set_num_tensor_args_to_observation_type(self._num_tensor_args_to_observation_type)
            ._set_input_type_to_index(self._input_type_to_index)
            ._set_input_output_observed(False)
            ._set_overwrite_output_fake_quantize(self._fake_quantize)
            ._set_overwrite_output_observer(default_fixed_qparams_range_0to1_observer)
        )

    def _get_backend_pattern_config_dict1(self):
        """Dictionary twin of `_get_backend_op_config1`."""
        return {
            "pattern": (torch.nn.ReLU, torch.nn.Linear),
            "observation_type": ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
            "dtype_configs": [self.dtype_config_dict1, self.dtype_config_dict2],
            "root_module": torch.nn.Linear,
            "qat_module": nnqat.Linear,
            "reference_quantized_module_for_root": nnqr.Linear,
            "fused_module": nni.LinearReLU,
            "fuser_method": self._fuser_method,
        }

    def _get_backend_pattern_config_dict2(self):
        """Dictionary twin of `_get_backend_op_config2`."""
        return {
            "pattern": torch.add,
            "observation_type": ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
            "dtype_configs": [self.dtype_config_dict2],
            "root_node_getter": _default_root_node_getter,
            "extra_inputs_getter": self._extra_inputs_getter,
            "num_tensor_args_to_observation_type": self._num_tensor_args_to_observation_type,
            "input_type_to_index": self._input_type_to_index,
            "input_output_observed": False,
            "overwrite_output_fake_quantize": self._fake_quantize,
            "overwrite_output_observer": default_fixed_qparams_range_0to1_observer
        }

    def test_backend_op_config_set_observation_type(self):
        conf = BackendPatternConfig(torch.nn.Linear)
        self.assertEqual(conf.observation_type, ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT)
        conf.set_observation_type(ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT)
        self.assertEqual(conf.observation_type, ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT)

    def test_backend_op_config_add_dtype_config(self):
        conf = BackendPatternConfig(torch.nn.Linear)
        self.assertEqual(len(conf.dtype_configs), 0)
        conf.add_dtype_config(self.dtype_config1)
        conf.add_dtype_config(self.dtype_config2)
        self.assertEqual(len(conf.dtype_configs), 2)
        self.assertEqual(conf.dtype_configs[0], self.dtype_config1)
        self.assertEqual(conf.dtype_configs[1], self.dtype_config2)

    def test_backend_op_config_set_root_module(self):
        conf = BackendPatternConfig(nni.LinearReLU)
        self.assertIsNone(conf.root_module)
        conf.set_root_module(torch.nn.Linear)
        self.assertEqual(conf.root_module, torch.nn.Linear)

    def test_backend_op_config_set_qat_module(self):
        conf = BackendPatternConfig(torch.nn.Linear)
        self.assertIsNone(conf.qat_module)
        conf.set_qat_module(nnqat.Linear)
        self.assertEqual(conf.qat_module, nnqat.Linear)

    def test_backend_op_config_set_reference_quantized_module(self):
        conf = BackendPatternConfig(torch.nn.Linear)
        self.assertIsNone(conf.reference_quantized_module)
        conf.set_reference_quantized_module(nnqr.Linear)
        self.assertEqual(conf.reference_quantized_module, nnqr.Linear)

    def test_backend_op_config_set_fused_module(self):
        conf = BackendPatternConfig((torch.nn.ReLU, torch.nn.Linear))
        self.assertIsNone(conf.fused_module)
        conf.set_fused_module(nni.LinearReLU)
        self.assertEqual(conf.fused_module, nni.LinearReLU)

    def test_backend_op_config_set_fuser_method(self):
        conf = BackendPatternConfig((torch.nn.ReLU, torch.nn.Linear))
        self.assertIsNone(conf.fuser_method)
        conf.set_fuser_method(self._fuser_method)
        self.assertEqual(conf.fuser_method, self._fuser_method)

    def test_backend_op_config_set_root_node_getter(self):
        conf = BackendPatternConfig((torch.nn.ReLU, torch.nn.Linear))
        self.assertIsNone(conf._root_node_getter)
        conf._set_root_node_getter(_default_root_node_getter)
        self.assertEqual(conf._root_node_getter, _default_root_node_getter)

    def test_backend_op_config_set_extra_inputs_getter(self):
        conf = BackendPatternConfig(torch.nn.Linear)
        self.assertIsNone(conf._extra_inputs_getter)
        conf._set_extra_inputs_getter(self._extra_inputs_getter)
        self.assertEqual(conf._extra_inputs_getter, self._extra_inputs_getter)

    def test_backend_op_config_set_num_tensor_args_to_observation_type(self):
        conf = BackendPatternConfig(torch.add)
        self.assertEqual(len(conf._num_tensor_args_to_observation_type), 0)
        conf._set_num_tensor_args_to_observation_type(self._num_tensor_args_to_observation_type)
        self.assertEqual(conf._num_tensor_args_to_observation_type, self._num_tensor_args_to_observation_type)

    def test_backend_op_config_set_input_type_to_index(self):
        conf = BackendPatternConfig(torch.addmm)
        self.assertEqual(len(conf._input_type_to_index), 0)
        conf._set_input_type_to_index(self._input_type_to_index)
        self.assertEqual(conf._input_type_to_index, self._input_type_to_index)

    def test_backend_op_config_set_input_output_observed(self):
        conf = BackendPatternConfig(torch.nn.Embedding)
        self.assertIsNone(conf._input_output_observed)
        conf._set_input_output_observed(False)
        self.assertEqual(conf._input_output_observed, False)

    def test_backend_op_config_set_overwrite_output_fake_quantize(self):
        conf = BackendPatternConfig(torch.sigmoid)
        self.assertIsNone(conf._overwrite_output_fake_quantize)
        conf._set_overwrite_output_fake_quantize(self._fake_quantize)
        self.assertEqual(conf._overwrite_output_fake_quantize, self._fake_quantize)

    def test_backend_op_config_set_overwrite_output_observer(self):
        conf = BackendPatternConfig(torch.sigmoid)
        self.assertIsNone(conf._overwrite_output_observer)
        conf._set_overwrite_output_observer(default_fixed_qparams_range_0to1_observer)
        self.assertEqual(conf._overwrite_output_observer, default_fixed_qparams_range_0to1_observer)

    def test_backend_op_config_from_dict(self):
        """`BackendPatternConfig.from_dict` fills given keys and leaves the rest at their defaults."""
        conf_dict1 = self._get_backend_pattern_config_dict1()
        conf1 = BackendPatternConfig.from_dict(conf_dict1)
        self.assertEqual(conf1.pattern, (torch.nn.ReLU, torch.nn.Linear))
        self.assertEqual(conf1.observation_type, ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT)
        self.assertEqual(conf1.root_module, torch.nn.Linear)
        self.assertEqual(conf1.qat_module, nnqat.Linear)
        self.assertEqual(conf1.reference_quantized_module, nnqr.Linear)
        self.assertEqual(conf1.fused_module, nni.LinearReLU)
        self.assertEqual(conf1.fuser_method, self._fuser_method)
        self.assertIsNone(conf1._root_node_getter)
        self.assertIsNone(conf1._extra_inputs_getter)
        self.assertEqual(len(conf1._num_tensor_args_to_observation_type), 0)
        self.assertEqual(len(conf1._input_type_to_index), 0)
        self.assertIsNone(conf1._input_output_observed)
        self.assertIsNone(conf1._overwrite_output_fake_quantize)
        self.assertIsNone(conf1._overwrite_output_observer)
        # Test temporary/internal keys
        conf_dict2 = self._get_backend_pattern_config_dict2()
        conf2 = BackendPatternConfig.from_dict(conf_dict2)
        self.assertEqual(conf2.pattern, torch.add)
        self.assertEqual(conf2.observation_type, ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT)
        self.assertIsNone(conf2.root_module)
        self.assertIsNone(conf2.qat_module)
        self.assertIsNone(conf2.reference_quantized_module)
        self.assertIsNone(conf2.fused_module)
        self.assertIsNone(conf2.fuser_method)
        self.assertEqual(conf2._root_node_getter, _default_root_node_getter)
        self.assertEqual(conf2._extra_inputs_getter, self._extra_inputs_getter)
        self.assertEqual(conf2._num_tensor_args_to_observation_type, self._num_tensor_args_to_observation_type)
        self.assertEqual(conf2._input_type_to_index, self._input_type_to_index)
        self.assertEqual(conf2._input_output_observed, False)
        self.assertEqual(conf2._overwrite_output_fake_quantize, self._fake_quantize)
        self.assertEqual(conf2._overwrite_output_observer, default_fixed_qparams_range_0to1_observer)

    def test_backend_op_config_to_dict(self):
        """`BackendPatternConfig.to_dict` matches the dict twins of both fixture configs."""
        conf1 = self._get_backend_op_config1()
        conf2 = self._get_backend_op_config2()
        conf_dict1 = self._get_backend_pattern_config_dict1()
        conf_dict2 = self._get_backend_pattern_config_dict2()
        self.assertEqual(conf1.to_dict(), conf_dict1)
        self.assertEqual(conf2.to_dict(), conf_dict2)

    # ===============
    #  BackendConfig
    # ===============

    def test_backend_config_set_name(self):
        conf = BackendConfig("name1")
        self.assertEqual(conf.name, "name1")
        conf.set_name("name2")
        self.assertEqual(conf.name, "name2")

    def test_backend_config_set_backend_pattern_config(self):
        """Pattern configs are stored in `configs`, keyed by their pattern."""
        conf = BackendConfig("name1")
        self.assertEqual(len(conf.configs), 0)
        backend_op_config1 = self._get_backend_op_config1()
        backend_op_config2 = self._get_backend_op_config2()
        conf.set_backend_pattern_config(backend_op_config1)
        self.assertEqual(conf.configs, {
            (torch.nn.ReLU, torch.nn.Linear): backend_op_config1,
        })
        conf.set_backend_pattern_config(backend_op_config2)
        self.assertEqual(conf.configs, {
            (torch.nn.ReLU, torch.nn.Linear): backend_op_config1,
            torch.add: backend_op_config2
        })

    def test_backend_config_from_dict(self):
        """`BackendConfig.from_dict` converts every entry of "configs" into a pattern config."""
        op_dict1 = self._get_backend_pattern_config_dict1()
        op_dict2 = self._get_backend_pattern_config_dict2()
        conf_dict = {
            "name": "name1",
            "configs": [op_dict1, op_dict2],
        }
        conf = BackendConfig.from_dict(conf_dict)
        self.assertEqual(conf.name, "name1")
        self.assertEqual(len(conf.configs), 2)
        key1 = (torch.nn.ReLU, torch.nn.Linear)
        key2 = torch.add
        self.assertIn(key1, conf.configs)
        self.assertIn(key2, conf.configs)
        self.assertEqual(conf.configs[key1].to_dict(), op_dict1)
        self.assertEqual(conf.configs[key2].to_dict(), op_dict2)

    def test_backend_config_to_dict(self):
        """`BackendConfig.to_dict` lists the pattern configs in insertion order."""
        op1 = self._get_backend_op_config1()
        op2 = self._get_backend_op_config2()
        op_dict1 = self._get_backend_pattern_config_dict1()
        op_dict2 = self._get_backend_pattern_config_dict2()
        conf = BackendConfig("name1").set_backend_pattern_config(op1).set_backend_pattern_config(op2)
        conf_dict = {
            "name": "name1",
            "configs": [op_dict1, op_dict2],
        }
        self.assertEqual(conf.to_dict(), conf_dict)
if __name__ == '__main__':
    # This suite is meant to be launched through the aggregate quantization
    # test runner rather than executed as a script; point the user at the
    # real runner path.
    raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
                       "\tpython test/test_quantization.py TESTNAME\n\n"
                       "instead.")

View File

@ -1858,6 +1858,7 @@ class TestQuantizeFx(QuantizationTestCase):
# should not crash as in https://github.com/pytorch/pytorch/issues/75825
prepare_fx(m, qconfig_dict, example_inputs=(torch.randn(1, 3, 3, 3),))
# TODO: move QConfigMapping tests to test/quantization/core
def test_qconfig_mapping_set_global(self):
qconfig = get_default_qconfig()
qconfig_mapping = QConfigMapping()

View File

@ -36,6 +36,7 @@ from quantization.core.test_workflow_module import TestRecordHistogramObserver
from quantization.core.test_workflow_module import TestHistogramObserver # noqa: F401
from quantization.core.test_workflow_module import TestDistributed # noqa: F401
from quantization.core.test_workflow_module import TestFusedObsFakeQuantModule # noqa: F401
from quantization.core.test_backend_config import TestBackendConfig # noqa: F401
from quantization.core.test_utils import TestUtils # noqa: F401
from quantization.core.test_docs import TestQuantizationDocs # noqa: F401

View File

@ -1,5 +1,7 @@
from .tensorrt import get_tensorrt_backend_config_dict
from .backend_config import BackendConfig, BackendPatternConfig, DTypeConfig
from .native import get_native_backend_config_dict
from .observation_type import ObservationType
from .tensorrt import get_tensorrt_backend_config_dict
# TODO: add more validations
def validate_backend_config_dict(backend_config_dict):

View File

@ -0,0 +1,390 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional
import torch
from torch.ao.quantization.backend_config.observation_type import ObservationType
from torch.ao.quantization.observer import _PartialWrapper
from torch.ao.quantization.utils import Pattern
# Names exported by this module. The dict-key constants defined below are not
# listed here.
__all__ = [
    "BackendConfig",
    "BackendPatternConfig",
    "DTypeConfig",
]
# The string constants below are the dictionary keys read by the `from_dict`
# classmethods and written by the `to_dict` methods of the config classes in
# this file.

# DTypeConfig dict keys
INPUT_DTYPE_DICT_KEY = "input_dtype"
OUTPUT_DTYPE_DICT_KEY = "output_dtype"
WEIGHT_DTYPE_DICT_KEY = "weight_dtype"
BIAS_DTYPE_DICT_KEY = "bias_dtype"
IS_DYNAMIC_DICT_KEY = "is_dynamic"
# BackendConfig dict keys
NAME_DICT_KEY = "name"
CONFIGS_DICT_KEY = "configs"
# BackendPatternConfig dict keys
PATTERN_DICT_KEY = "pattern"
OBSERVATION_TYPE_DICT_KEY = "observation_type"
DTYPE_CONFIGS_DICT_KEY = "dtype_configs"
ROOT_MODULE_DICT_KEY = "root_module"
QAT_MODULE_DICT_KEY = "qat_module"
# NOTE: this key carries a "_for_root" suffix, unlike the
# `reference_quantized_module` attribute and setter on BackendPatternConfig.
REFERENCE_QUANTIZED_MODULE_DICT_KEY = "reference_quantized_module_for_root"
FUSED_MODULE_DICT_KEY = "fused_module"
FUSER_METHOD_DICT_KEY = "fuser_method"
# Keys for the temporary/internal BackendPatternConfig settings, i.e. the ones
# stored in underscore-prefixed attributes.
ROOT_NODE_GETTER_DICT_KEY = "root_node_getter"
EXTRA_INPUTS_GETTER_DICT_KEY = "extra_inputs_getter"
NUM_TENSOR_ARGS_TO_OBSERVATION_TYPE_DICT_KEY = "num_tensor_args_to_observation_type"
INPUT_TYPE_TO_INDEX_DICT_KEY = "input_type_to_index"
INPUT_OUTPUT_OBSERVED_DICT_KEY = "input_output_observed"
OVERWRITE_OUTPUT_FAKE_QUANTIZE_DICT_KEY = "overwrite_output_fake_quantize"
OVERWRITE_OUTPUT_OBSERVER_DICT_KEY = "overwrite_output_observer"
@dataclass
class DTypeConfig:
    """
    Config for the set of supported input/output activation, weight, and bias data types for the
    patterns defined in :class:`~torch.ao.quantization.backend_config.BackendConfig`.

    Every field is optional and defaults to `None`. Fields left as `None` are omitted
    from the dictionary produced by :meth:`to_dict`.
    """
    input_dtype: Optional[torch.dtype] = None
    output_dtype: Optional[torch.dtype] = None
    weight_dtype: Optional[torch.dtype] = None
    bias_dtype: Optional[torch.dtype] = None
    is_dynamic: Optional[bool] = None

    @classmethod
    def from_dict(cls, dtype_config_dict: Dict[str, Any]) -> DTypeConfig:
        """
        Create a `DTypeConfig` from a dictionary with the following items (all optional):

            "input_dtype": torch.dtype
            "output_dtype": torch.dtype
            "weight_dtype": torch.dtype
            "bias_dtype": torch.dtype
            "is_dynamic": bool

        Keys that are absent leave the corresponding field as `None`; any other
        keys in the dictionary are ignored.
        """
        # Keyword arguments keep this correct even if the field order changes.
        return cls(
            input_dtype=dtype_config_dict.get(INPUT_DTYPE_DICT_KEY),
            output_dtype=dtype_config_dict.get(OUTPUT_DTYPE_DICT_KEY),
            weight_dtype=dtype_config_dict.get(WEIGHT_DTYPE_DICT_KEY),
            bias_dtype=dtype_config_dict.get(BIAS_DTYPE_DICT_KEY),
            is_dynamic=dtype_config_dict.get(IS_DYNAMIC_DICT_KEY),
        )

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert this `DTypeConfig` to a dictionary with the items described in
        :func:`~torch.ao.quantization.backend_config.DTypeConfig.from_dict`.

        Only fields that are not `None` appear in the result.
        """
        keyed_fields = (
            (INPUT_DTYPE_DICT_KEY, self.input_dtype),
            (OUTPUT_DTYPE_DICT_KEY, self.output_dtype),
            (WEIGHT_DTYPE_DICT_KEY, self.weight_dtype),
            (BIAS_DTYPE_DICT_KEY, self.bias_dtype),
            (IS_DYNAMIC_DICT_KEY, self.is_dynamic),
        )
        dtype_config_dict: Dict[str, Any] = {
            key: value for key, value in keyed_fields if value is not None
        }
        return dtype_config_dict
class BackendConfig:
    # TODO: refer to NativeBackendConfig once that is implemented
    """
    Config that defines the set of patterns that can be quantized on a given backend, and how reference
    quantized models can be produced from these patterns.

    A pattern in this context refers to a module, a functional, an operator, or a directed acyclic graph
    of the above. Each pattern supported on the target backend can be individually configured through
    :class:`~torch.ao.quantization.backend_config.BackendPatternConfig` in terms of:

    (1) The supported input/output activation, weight, and bias data types

    (2) How observers and quant/dequant ops are inserted in order to construct the reference pattern, and

    (3) (Optionally) Fusion, QAT, and reference module mappings.

    The format of the patterns is described in:
    https://github.com/pytorch/pytorch/blob/master/torch/ao/quantization/backend_config/README.md

    Example usage::

        import torch
        from torch.ao.quantization.backend_config import BackendConfig, BackendPatternConfig, DTypeConfig, ObservationType
        from torch.ao.quantization.fuser_method_mappings import reverse_sequential_wrapper2

        weighted_int8_dtype_config = DTypeConfig(
            input_dtype=torch.quint8,
            output_dtype=torch.quint8,
            weight_dtype=torch.qint8,
            bias_dtype=torch.float)

        linear_config = (
            BackendPatternConfig(torch.nn.Linear)
            .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT)
            .add_dtype_config(weighted_int8_dtype_config)
            .set_root_module(torch.nn.Linear)
            .set_qat_module(torch.nn.qat.Linear)
            .set_reference_quantized_module(torch.nn.quantized._reference.Linear)
        )

        conv_relu_config = (
            BackendPatternConfig((torch.nn.ReLU, torch.nn.Conv2d))
            .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT)
            .add_dtype_config(weighted_int8_dtype_config)
            .set_fused_module(torch.nn.intrinsic.ConvReLU2d)
            .set_fuser_method(reverse_sequential_wrapper2(torch.nn.intrinsic.ConvReLU2d))
        )

        backend_config = (
            BackendConfig("my_backend")
            .set_backend_pattern_config(linear_config)
            .set_backend_pattern_config(conv_relu_config)
        )
    """
    def __init__(self, name: str = ""):
        self.name = name
        # Maps each configured pattern to its BackendPatternConfig.
        self.configs: Dict[Pattern, BackendPatternConfig] = {}

    def set_name(self, name: str) -> BackendConfig:
        """
        Set the name of the target backend.

        Returns this `BackendConfig` so that calls can be chained.
        """
        self.name = name
        return self

    def set_backend_pattern_config(self, config: BackendPatternConfig) -> BackendConfig:
        """
        Set the config for a pattern that can be run on the target backend.
        This overrides any existing config for the given pattern.

        Returns this `BackendConfig` so that calls can be chained.
        """
        # Keyed by pattern, so a second config for the same pattern replaces the first.
        self.configs[config.pattern] = config
        return self

    @classmethod
    def from_dict(cls, backend_config_dict: Dict[str, Any]) -> BackendConfig:
        """
        Create a `BackendConfig` from a dictionary with the following items:

            "name": the name of the target backend
            "configs": a list of dictionaries that each represents a `BackendPatternConfig`
                (entries that already are `BackendPatternConfig` objects are accepted as-is)

        Raises:
            ValueError: if either key is missing, or if an entry of "configs" is neither
                a dictionary nor a `BackendPatternConfig`.
        """
        for dict_key in [NAME_DICT_KEY, CONFIGS_DICT_KEY]:
            if dict_key not in backend_config_dict:
                raise ValueError("backend_config_dict must contain '%s'" % dict_key)
        conf = cls(backend_config_dict[NAME_DICT_KEY])
        for entry in backend_config_dict[CONFIGS_DICT_KEY]:
            if isinstance(entry, BackendPatternConfig):
                conf.set_backend_pattern_config(entry)
            elif isinstance(entry, dict):
                conf.set_backend_pattern_config(BackendPatternConfig.from_dict(entry))
            else:
                raise ValueError("Expected backend_config_dict['%s'] to be a dictionary" % CONFIGS_DICT_KEY)
        return conf

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert this `BackendConfig` to a dictionary with the items described in
        :func:`~torch.ao.quantization.backend_config.BackendConfig.from_dict`.

        The "configs" list follows the order in which the pattern configs were set.
        """
        return {
            NAME_DICT_KEY: self.name,
            CONFIGS_DICT_KEY: [c.to_dict() for c in self.configs.values()],
        }
class BackendPatternConfig:
    """
    Config for ops defined in :class:`~torch.ao.quantization.backend_config.BackendConfig`.

    The user can configure how an operator pattern graph is handled on a given backend using the following methods:

        `set_observation_type`: sets how observers should be inserted for this pattern.
            See :class:`~torch.ao.quantization.backend_config.ObservationType`
        `add_dtype_config`: add a set of supported data types for this pattern
        `set_root_module`: sets the module that represents the root for this pattern
        `set_qat_module`: sets the module that represents the QAT implementation for this pattern
        `set_reference_quantized_module`: sets the module that represents the reference quantized
            implementation for this pattern's root module.
        `set_fused_module`: sets the module that represents the fused implementation for this pattern
        `set_fuser_method`: sets the function that specifies how to fuse the pattern for this pattern

    Every setter returns this `BackendPatternConfig` so that calls can be chained.

    For a detailed example usage, see :class:`~torch.ao.quantization.backend_config.BackendConfig`.
    """
    def __init__(self, pattern: Pattern):
        self.pattern = pattern
        self.observation_type = ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT
        self.dtype_configs: List[DTypeConfig] = []
        # NOTE(review): the module attributes below are annotated as `torch.nn.Module`,
        # but the usage example on BackendConfig passes module classes
        # (e.g. `torch.nn.Linear`) - consider `Type[torch.nn.Module]`.
        self.root_module: Optional[torch.nn.Module] = None
        self.qat_module: Optional[torch.nn.Module] = None
        self.reference_quantized_module: Optional[torch.nn.Module] = None
        self.fused_module: Optional[torch.nn.Module] = None
        self.fuser_method: Optional[Callable] = None

        # Temporary/internal configs
        self._root_node_getter: Optional[Callable] = None
        self._extra_inputs_getter: Optional[Callable] = None
        self._num_tensor_args_to_observation_type: Dict[int, ObservationType] = {}
        self._input_type_to_index: Dict[str, int] = {}
        self._input_output_observed: Optional[bool] = None
        self._overwrite_output_fake_quantize: Optional[_PartialWrapper] = None
        self._overwrite_output_observer: Optional[_PartialWrapper] = None

    def set_observation_type(self, observation_type: ObservationType) -> BackendPatternConfig:
        """
        Set how observers should be inserted for this pattern.
        """
        self.observation_type = observation_type
        return self

    def add_dtype_config(self, dtype_config: DTypeConfig) -> BackendPatternConfig:
        """
        Register a set of supported input/output activation, weight, and bias data types for this pattern.

        Configs accumulate: each call appends to `dtype_configs`.
        """
        self.dtype_configs.append(dtype_config)
        return self

    def set_root_module(self, root_module: torch.nn.Module) -> BackendPatternConfig:
        """
        Set the module that represents the root for this pattern.
        For example, the root module for :class:`torch.nn.intrinsic.LinearReLU` should be :class:`torch.nn.Linear`.
        """
        self.root_module = root_module
        return self

    def set_qat_module(self, qat_module: torch.nn.Module) -> BackendPatternConfig:
        """
        Set the module that represents the QAT implementation for this pattern.
        """
        self.qat_module = qat_module
        return self

    def set_reference_quantized_module(self, reference_quantized_module: torch.nn.Module) -> BackendPatternConfig:
        """
        Set the module that represents the reference quantized implementation for this pattern's root module.
        """
        self.reference_quantized_module = reference_quantized_module
        return self

    def set_fused_module(self, fused_module: torch.nn.Module) -> BackendPatternConfig:
        """
        Set the module that represents the fused implementation for this pattern.
        """
        self.fused_module = fused_module
        return self

    def set_fuser_method(self, fuser_method: Callable) -> BackendPatternConfig:
        """
        Set the function that specifies how to fuse the pattern for this pattern.
        """
        self.fuser_method = fuser_method
        return self

    # The setters below cover the temporary/internal configs and are
    # intentionally underscore-prefixed.

    def _set_root_node_getter(self, root_node_getter: Callable) -> BackendPatternConfig:
        self._root_node_getter = root_node_getter
        return self

    def _set_extra_inputs_getter(self, extra_inputs_getter: Callable) -> BackendPatternConfig:
        self._extra_inputs_getter = extra_inputs_getter
        return self

    def _set_num_tensor_args_to_observation_type(
            self, num_tensor_args_to_observation_type: Dict[int, ObservationType]) -> BackendPatternConfig:
        self._num_tensor_args_to_observation_type = num_tensor_args_to_observation_type
        return self

    def _set_input_type_to_index(self, input_type_to_index: Dict[str, int]) -> BackendPatternConfig:
        self._input_type_to_index = input_type_to_index
        return self

    def _set_input_output_observed(self, input_output_observed: bool) -> BackendPatternConfig:
        self._input_output_observed = input_output_observed
        return self

    def _set_overwrite_output_fake_quantize(self, overwrite_output_fake_quantize: _PartialWrapper) -> BackendPatternConfig:
        self._overwrite_output_fake_quantize = overwrite_output_fake_quantize
        return self

    def _set_overwrite_output_observer(self, overwrite_output_observer: _PartialWrapper) -> BackendPatternConfig:
        self._overwrite_output_observer = overwrite_output_observer
        return self

    @classmethod
    def from_dict(cls, backend_pattern_config_dict: Dict[str, Any]) -> BackendPatternConfig:
        """
        Create a `BackendPatternConfig` from a dictionary with the following items:

            "pattern": the pattern being configured (required)
            "observation_type": the :class:`~torch.ao.quantization.backend_config.ObservationType` that specifies how
                observers should be inserted for this pattern
            "dtype_configs": a list of dictionaries that represents :class:`~torch.ao.quantization.backend_config.DTypeConfig`s
                (entries that already are `DTypeConfig` objects are accepted as-is)
            "root_module": a :class:`torch.nn.Module` that represents the root for this pattern
            "qat_module": a :class:`torch.nn.Module` that represents the QAT implementation for this pattern
            "reference_quantized_module_for_root": a :class:`torch.nn.Module` that represents the reference quantized
                implementation for this pattern's root module.
            "fused_module": a :class:`torch.nn.Module` that represents the fused implementation for this pattern
            "fuser_method": a function that specifies how to fuse the pattern for this pattern

        Raises:
            ValueError: if "pattern" is missing, or if an entry of "dtype_configs" is neither
                a dictionary nor a `DTypeConfig`.
        """
        def _get_dtype_config(obj: Any) -> DTypeConfig:
            """
            Convert the given object into a `DTypeConfig` if possible, else throw an exception.
            """
            if isinstance(obj, DTypeConfig):
                return obj
            if isinstance(obj, dict):
                return DTypeConfig.from_dict(obj)
            raise ValueError("Expected a list of DTypeConfigs in backend_pattern_config_dict[\"%s\"], got '%s'" %
                             (DTYPE_CONFIGS_DICT_KEY, type(obj)))

        if PATTERN_DICT_KEY not in backend_pattern_config_dict:
            raise ValueError("backend_pattern_config_dict must contain '%s'" % PATTERN_DICT_KEY)
        source = backend_pattern_config_dict
        conf = cls(source[PATTERN_DICT_KEY])
        # Only override the constructor's default observation type when one is given.
        if OBSERVATION_TYPE_DICT_KEY in source:
            conf.set_observation_type(source[OBSERVATION_TYPE_DICT_KEY])
        for d in source.get(DTYPE_CONFIGS_DICT_KEY, []):
            conf.add_dtype_config(_get_dtype_config(d))
        conf.set_root_module(source.get(ROOT_MODULE_DICT_KEY, None))
        conf.set_qat_module(source.get(QAT_MODULE_DICT_KEY, None))
        conf.set_reference_quantized_module(source.get(REFERENCE_QUANTIZED_MODULE_DICT_KEY, None))
        conf.set_fused_module(source.get(FUSED_MODULE_DICT_KEY, None))
        conf.set_fuser_method(source.get(FUSER_METHOD_DICT_KEY, None))
        # Temporary/internal keys
        conf._set_root_node_getter(source.get(ROOT_NODE_GETTER_DICT_KEY, None))
        conf._set_extra_inputs_getter(source.get(EXTRA_INPUTS_GETTER_DICT_KEY, None))
        conf._set_num_tensor_args_to_observation_type(
            source.get(NUM_TENSOR_ARGS_TO_OBSERVATION_TYPE_DICT_KEY, {}))
        conf._set_input_type_to_index(source.get(INPUT_TYPE_TO_INDEX_DICT_KEY, {}))
        conf._set_input_output_observed(source.get(INPUT_OUTPUT_OBSERVED_DICT_KEY, None))
        conf._set_overwrite_output_fake_quantize(source.get(OVERWRITE_OUTPUT_FAKE_QUANTIZE_DICT_KEY, None))
        conf._set_overwrite_output_observer(source.get(OVERWRITE_OUTPUT_OBSERVER_DICT_KEY, None))
        return conf

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert this `BackendPatternConfig` to a dictionary with the items described in
        :func:`~torch.ao.quantization.backend_config.BackendPatternConfig.from_dict`.

        "pattern", "observation_type" and "dtype_configs" are always present; every other
        entry is emitted only when it is set (not `None`, or non-empty for the mappings).
        """
        backend_pattern_config_dict: Dict[str, Any] = {
            PATTERN_DICT_KEY: self.pattern,
            OBSERVATION_TYPE_DICT_KEY: self.observation_type,
            DTYPE_CONFIGS_DICT_KEY: [c.to_dict() for c in self.dtype_configs],
        }
        if self.root_module is not None:
            backend_pattern_config_dict[ROOT_MODULE_DICT_KEY] = self.root_module
        if self.qat_module is not None:
            backend_pattern_config_dict[QAT_MODULE_DICT_KEY] = self.qat_module
        if self.reference_quantized_module is not None:
            backend_pattern_config_dict[REFERENCE_QUANTIZED_MODULE_DICT_KEY] = self.reference_quantized_module
        if self.fused_module is not None:
            backend_pattern_config_dict[FUSED_MODULE_DICT_KEY] = self.fused_module
        if self.fuser_method is not None:
            backend_pattern_config_dict[FUSER_METHOD_DICT_KEY] = self.fuser_method
        if self._root_node_getter is not None:
            backend_pattern_config_dict[ROOT_NODE_GETTER_DICT_KEY] = self._root_node_getter
        if self._extra_inputs_getter is not None:
            backend_pattern_config_dict[EXTRA_INPUTS_GETTER_DICT_KEY] = self._extra_inputs_getter
        # Empty mappings are treated as "not set".
        if self._num_tensor_args_to_observation_type:
            backend_pattern_config_dict[NUM_TENSOR_ARGS_TO_OBSERVATION_TYPE_DICT_KEY] = self._num_tensor_args_to_observation_type
        if self._input_type_to_index:
            backend_pattern_config_dict[INPUT_TYPE_TO_INDEX_DICT_KEY] = self._input_type_to_index
        if self._input_output_observed is not None:
            backend_pattern_config_dict[INPUT_OUTPUT_OBSERVED_DICT_KEY] = self._input_output_observed
        if self._overwrite_output_fake_quantize is not None:
            backend_pattern_config_dict[OVERWRITE_OUTPUT_FAKE_QUANTIZE_DICT_KEY] = self._overwrite_output_fake_quantize
        if self._overwrite_output_observer is not None:
            backend_pattern_config_dict[OVERWRITE_OUTPUT_OBSERVER_DICT_KEY] = self._overwrite_output_observer
        return backend_pattern_config_dict