Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-21 21:49:24 +08:00
[Quant][fx] Implement BackendConfig (part 1) (#81469)
Summary: Following https://github.com/pytorch/pytorch/pull/78452 and https://github.com/pytorch/pytorch/pull/79066, this commit is part 1 of the broader effort to replace `backend_config_dict` with a Python config object, a more formal and robust API that leads to a better user experience. Note that this commit does not change any behavior by itself. A future commit (part 2) will replace all existing usages of `backend_config_dict` with the `BackendConfig` object added here.

Test Plan: python test/test_quantization.py TestBackendConfig

Reviewers: jerryzh168

Subscribers: jerryzh168

Pull Request resolved: https://github.com/pytorch/pytorch/pull/81469
Approved by: https://github.com/jerryzh168
Committed by: PyTorch MergeBot
Parent: 1a18ff3247
Commit: 194255bb56
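As a quick illustration of the API introduced here, the sketch below builds a backend config with the new classes instead of a raw `backend_config_dict`. It is adapted from the test file and docstring example in this diff; the backend name "my_backend" and the (Linear, ReLU) pattern are only illustrative choices, not part of the commit.

import torch
import torch.nn.intrinsic as nni
import torch.nn.qat as nnqat
from torch.ao.quantization.backend_config import (
    BackendConfig,
    BackendPatternConfig,
    DTypeConfig,
    ObservationType,
)
from torch.ao.quantization.fuser_method_mappings import reverse_sequential_wrapper2

# Supported dtypes for a statically quantized int8 linear pattern
weighted_int8_dtype_config = DTypeConfig(
    input_dtype=torch.quint8,
    output_dtype=torch.quint8,
    weight_dtype=torch.qint8,
    bias_dtype=torch.float,
)

# The fused (Linear, ReLU) pattern, configured through the new object API
linear_relu_config = BackendPatternConfig((torch.nn.ReLU, torch.nn.Linear)) \
    .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) \
    .add_dtype_config(weighted_int8_dtype_config) \
    .set_root_module(torch.nn.Linear) \
    .set_qat_module(nnqat.Linear) \
    .set_fused_module(nni.LinearReLU) \
    .set_fuser_method(reverse_sequential_wrapper2(nni.LinearReLU))

backend_config = BackendConfig("my_backend").set_backend_pattern_config(linear_relu_config)

# The dict form remains available for code that has not migrated yet (part 2)
backend_config_dict = backend_config.to_dict()
assert BackendConfig.from_dict(backend_config_dict).to_dict() == backend_config_dict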
test/quantization/core/test_backend_config.py (new file, 319 lines)
@@ -0,0 +1,319 @@
# Owner(s): ["oncall: quantization"]

import torch
import torch.nn.intrinsic as nni
import torch.nn.qat as nnqat
import torch.nn.quantized._reference as nnqr
from torch.testing._internal.common_quantization import QuantizationTestCase

from torch.ao.quantization.backend_config import (
    BackendConfig,
    BackendPatternConfig,
    DTypeConfig,
    ObservationType,
)
from torch.ao.quantization.fake_quantize import FixedQParamsFakeQuantize
from torch.ao.quantization.fuser_method_mappings import reverse_sequential_wrapper2
from torch.ao.quantization.fx.quantization_patterns import _default_root_node_getter
from torch.ao.quantization.observer import default_fixed_qparams_range_0to1_observer


class TestBackendConfig(QuantizationTestCase):

    # =============
    #  DTypeConfig
    # =============

    dtype_config1 = DTypeConfig(
        input_dtype=torch.quint8,
        output_dtype=torch.quint8,
        weight_dtype=torch.qint8,
        bias_dtype=torch.float
    )

    dtype_config2 = DTypeConfig(
        input_dtype=torch.float16,
        output_dtype=torch.float,
        is_dynamic=True
    )

    dtype_config_dict1 = {
        "input_dtype": torch.quint8,
        "output_dtype": torch.quint8,
        "weight_dtype": torch.qint8,
        "bias_dtype": torch.float,
    }

    dtype_config_dict2 = {
        "input_dtype": torch.float16,
        "output_dtype": torch.float,
        "is_dynamic": True,
    }

    def test_dtype_config_from_dict(self):
        self.assertEqual(DTypeConfig.from_dict(self.dtype_config_dict1), self.dtype_config1)
        self.assertEqual(DTypeConfig.from_dict(self.dtype_config_dict2), self.dtype_config2)

    def test_dtype_config_to_dict(self):
        self.assertEqual(self.dtype_config1.to_dict(), self.dtype_config_dict1)
        self.assertEqual(self.dtype_config2.to_dict(), self.dtype_config_dict2)

    # ======================
    #  BackendPatternConfig
    # ======================

    _fuser_method = reverse_sequential_wrapper2(nni.LinearReLU)

    _num_tensor_args_to_observation_type = {
        0: ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
        1: ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT,
        2: ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
    }
    _input_type_to_index = {
        "bias": 0,
        "input": 1,
        "weight": 2,
    }
    _fake_quantize = FixedQParamsFakeQuantize.with_args(observer=default_fixed_qparams_range_0to1_observer)

    def _extra_inputs_getter(self, p):
        return (torch.rand(3, 3),)

    def _get_backend_op_config1(self):
        return BackendPatternConfig((torch.nn.ReLU, torch.nn.Linear)) \
            .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) \
            .add_dtype_config(self.dtype_config1) \
            .add_dtype_config(self.dtype_config2) \
            .set_root_module(torch.nn.Linear) \
            .set_qat_module(nnqat.Linear) \
            .set_reference_quantized_module(nnqr.Linear) \
            .set_fused_module(nni.LinearReLU) \
            .set_fuser_method(self._fuser_method)

    def _get_backend_op_config2(self):
        return BackendPatternConfig(torch.add) \
            .add_dtype_config(self.dtype_config2) \
            ._set_root_node_getter(_default_root_node_getter) \
            ._set_extra_inputs_getter(self._extra_inputs_getter) \
            ._set_num_tensor_args_to_observation_type(self._num_tensor_args_to_observation_type) \
            ._set_input_type_to_index(self._input_type_to_index) \
            ._set_input_output_observed(False) \
            ._set_overwrite_output_fake_quantize(self._fake_quantize) \
            ._set_overwrite_output_observer(default_fixed_qparams_range_0to1_observer)

    def _get_backend_pattern_config_dict1(self):
        return {
            "pattern": (torch.nn.ReLU, torch.nn.Linear),
            "observation_type": ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
            "dtype_configs": [self.dtype_config_dict1, self.dtype_config_dict2],
            "root_module": torch.nn.Linear,
            "qat_module": nnqat.Linear,
            "reference_quantized_module_for_root": nnqr.Linear,
            "fused_module": nni.LinearReLU,
            "fuser_method": self._fuser_method,
        }

    def _get_backend_pattern_config_dict2(self):
        return {
            "pattern": torch.add,
            "observation_type": ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
            "dtype_configs": [self.dtype_config_dict2],
            "root_node_getter": _default_root_node_getter,
            "extra_inputs_getter": self._extra_inputs_getter,
            "num_tensor_args_to_observation_type": self._num_tensor_args_to_observation_type,
            "input_type_to_index": self._input_type_to_index,
            "input_output_observed": False,
            "overwrite_output_fake_quantize": self._fake_quantize,
            "overwrite_output_observer": default_fixed_qparams_range_0to1_observer
        }

    def test_backend_op_config_set_observation_type(self):
        conf = BackendPatternConfig(torch.nn.Linear)
        self.assertEqual(conf.observation_type, ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT)
        conf.set_observation_type(ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT)
        self.assertEqual(conf.observation_type, ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT)

    def test_backend_op_config_add_dtype_config(self):
        conf = BackendPatternConfig(torch.nn.Linear)
        self.assertEqual(len(conf.dtype_configs), 0)
        conf.add_dtype_config(self.dtype_config1)
        conf.add_dtype_config(self.dtype_config2)
        self.assertEqual(len(conf.dtype_configs), 2)
        self.assertEqual(conf.dtype_configs[0], self.dtype_config1)
        self.assertEqual(conf.dtype_configs[1], self.dtype_config2)

    def test_backend_op_config_set_root_module(self):
        conf = BackendPatternConfig(nni.LinearReLU)
        self.assertTrue(conf.root_module is None)
        conf.set_root_module(torch.nn.Linear)
        self.assertEqual(conf.root_module, torch.nn.Linear)

    def test_backend_op_config_set_qat_module(self):
        conf = BackendPatternConfig(torch.nn.Linear)
        self.assertTrue(conf.qat_module is None)
        conf.set_qat_module(nnqat.Linear)
        self.assertEqual(conf.qat_module, nnqat.Linear)

    def test_backend_op_config_set_reference_quantized_module(self):
        conf = BackendPatternConfig(torch.nn.Linear)
        self.assertTrue(conf.reference_quantized_module is None)
        conf.set_reference_quantized_module(nnqr.Linear)
        self.assertEqual(conf.reference_quantized_module, nnqr.Linear)

    def test_backend_op_config_set_fused_module(self):
        conf = BackendPatternConfig((torch.nn.ReLU, torch.nn.Linear))
        self.assertTrue(conf.fused_module is None)
        conf.set_fused_module(nni.LinearReLU)
        self.assertEqual(conf.fused_module, nni.LinearReLU)

    def test_backend_op_config_set_fuser_method(self):
        conf = BackendPatternConfig((torch.nn.ReLU, torch.nn.Linear))
        self.assertTrue(conf.fuser_method is None)
        conf.set_fuser_method(self._fuser_method)
        self.assertEqual(conf.fuser_method, self._fuser_method)

    def test_backend_op_config_set_root_node_getter(self):
        conf = BackendPatternConfig((torch.nn.ReLU, torch.nn.Linear))
        self.assertTrue(conf._root_node_getter is None)
        conf._set_root_node_getter(_default_root_node_getter)
        self.assertEqual(conf._root_node_getter, _default_root_node_getter)

    def test_backend_op_config_set_extra_inputs_getter(self):
        conf = BackendPatternConfig(torch.nn.Linear)
        self.assertTrue(conf._extra_inputs_getter is None)
        conf._set_extra_inputs_getter(self._extra_inputs_getter)
        self.assertEqual(conf._extra_inputs_getter, self._extra_inputs_getter)

    def test_backend_op_config_set_num_tensor_args_to_observation_type(self):
        conf = BackendPatternConfig(torch.add)
        self.assertEqual(len(conf._num_tensor_args_to_observation_type), 0)
        conf._set_num_tensor_args_to_observation_type(self._num_tensor_args_to_observation_type)
        self.assertEqual(conf._num_tensor_args_to_observation_type, self._num_tensor_args_to_observation_type)

    def test_backend_op_config_set_input_type_to_index(self):
        conf = BackendPatternConfig(torch.addmm)
        self.assertEqual(len(conf._input_type_to_index), 0)
        conf._set_input_type_to_index(self._input_type_to_index)
        self.assertEqual(conf._input_type_to_index, self._input_type_to_index)

    def test_backend_op_config_set_input_output_observed(self):
        conf = BackendPatternConfig(torch.nn.Embedding)
        self.assertTrue(conf._input_output_observed is None)
        conf._set_input_output_observed(False)
        self.assertEqual(conf._input_output_observed, False)

    def test_backend_op_config_set_overwrite_output_fake_quantize(self):
        conf = BackendPatternConfig(torch.sigmoid)
        self.assertTrue(conf._overwrite_output_fake_quantize is None)
        conf._set_overwrite_output_fake_quantize(self._fake_quantize)
        self.assertEqual(conf._overwrite_output_fake_quantize, self._fake_quantize)

    def test_backend_op_config_set_overwrite_output_observer(self):
        conf = BackendPatternConfig(torch.sigmoid)
        self.assertTrue(conf._overwrite_output_observer is None)
        conf._set_overwrite_output_observer(default_fixed_qparams_range_0to1_observer)
        self.assertEqual(conf._overwrite_output_observer, default_fixed_qparams_range_0to1_observer)

    def test_backend_op_config_from_dict(self):
        conf_dict1 = self._get_backend_pattern_config_dict1()
        conf1 = BackendPatternConfig.from_dict(conf_dict1)
        self.assertEqual(conf1.pattern, (torch.nn.ReLU, torch.nn.Linear))
        self.assertEqual(conf1.observation_type, ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT)
        self.assertEqual(conf1.root_module, torch.nn.Linear)
        self.assertEqual(conf1.qat_module, nnqat.Linear)
        self.assertEqual(conf1.reference_quantized_module, nnqr.Linear)
        self.assertEqual(conf1.fused_module, nni.LinearReLU)
        self.assertEqual(conf1.fuser_method, self._fuser_method)
        self.assertTrue(conf1._root_node_getter is None)
        self.assertTrue(conf1._extra_inputs_getter is None)
        self.assertEqual(len(conf1._num_tensor_args_to_observation_type), 0)
        self.assertEqual(len(conf1._input_type_to_index), 0)
        self.assertTrue(conf1._input_output_observed is None)
        self.assertTrue(conf1._overwrite_output_fake_quantize is None)
        self.assertTrue(conf1._overwrite_output_observer is None)
        # Test temporary/internal keys
        conf_dict2 = self._get_backend_pattern_config_dict2()
        conf2 = BackendPatternConfig.from_dict(conf_dict2)
        self.assertEqual(conf2.pattern, torch.add)
        self.assertEqual(conf2.observation_type, ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT)
        self.assertTrue(conf2.root_module is None)
        self.assertTrue(conf2.qat_module is None)
        self.assertTrue(conf2.reference_quantized_module is None)
        self.assertTrue(conf2.fused_module is None)
        self.assertTrue(conf2.fuser_method is None)
        self.assertEqual(conf2._root_node_getter, _default_root_node_getter)
        self.assertEqual(conf2._extra_inputs_getter, self._extra_inputs_getter)
        self.assertEqual(conf2._num_tensor_args_to_observation_type, self._num_tensor_args_to_observation_type)
        self.assertEqual(conf2._input_type_to_index, self._input_type_to_index)
        self.assertEqual(conf2._input_output_observed, False)
        self.assertEqual(conf2._overwrite_output_fake_quantize, self._fake_quantize)
        self.assertEqual(conf2._overwrite_output_observer, default_fixed_qparams_range_0to1_observer)

    def test_backend_op_config_to_dict(self):
        conf1 = self._get_backend_op_config1()
        conf2 = self._get_backend_op_config2()
        conf_dict1 = self._get_backend_pattern_config_dict1()
        conf_dict2 = self._get_backend_pattern_config_dict2()
        self.assertEqual(conf1.to_dict(), conf_dict1)
        self.assertEqual(conf2.to_dict(), conf_dict2)

    # ===============
    #  BackendConfig
    # ===============

    def test_backend_config_set_name(self):
        conf = BackendConfig("name1")
        self.assertEqual(conf.name, "name1")
        conf.set_name("name2")
        self.assertEqual(conf.name, "name2")

    def test_backend_config_set_backend_pattern_config(self):
        conf = BackendConfig("name1")
        self.assertEqual(len(conf.configs), 0)
        backend_op_config1 = self._get_backend_op_config1()
        backend_op_config2 = self._get_backend_op_config2()
        conf.set_backend_pattern_config(backend_op_config1)
        self.assertEqual(conf.configs, {
            (torch.nn.ReLU, torch.nn.Linear): backend_op_config1,
        })
        conf.set_backend_pattern_config(backend_op_config2)
        self.assertEqual(conf.configs, {
            (torch.nn.ReLU, torch.nn.Linear): backend_op_config1,
            torch.add: backend_op_config2
        })

    def test_backend_config_from_dict(self):
        op1 = self._get_backend_op_config1()
        op2 = self._get_backend_op_config2()
        op_dict1 = self._get_backend_pattern_config_dict1()
        op_dict2 = self._get_backend_pattern_config_dict2()
        conf_dict = {
            "name": "name1",
            "configs": [op_dict1, op_dict2],
        }
        conf = BackendConfig.from_dict(conf_dict)
        self.assertEqual(conf.name, "name1")
        self.assertEqual(len(conf.configs), 2)
        key1 = (torch.nn.ReLU, torch.nn.Linear)
        key2 = torch.add
        self.assertTrue(key1 in conf.configs)
        self.assertTrue(key2 in conf.configs)
        self.assertEqual(conf.configs[key1].to_dict(), op_dict1)
        self.assertEqual(conf.configs[key2].to_dict(), op_dict2)

    def test_backend_config_to_dict(self):
        op1 = self._get_backend_op_config1()
        op2 = self._get_backend_op_config2()
        op_dict1 = self._get_backend_pattern_config_dict1()
        op_dict2 = self._get_backend_pattern_config_dict2()
        conf = BackendConfig("name1").set_backend_pattern_config(op1).set_backend_pattern_config(op2)
        conf_dict = {
            "name": "name1",
            "configs": [op_dict1, op_dict2],
        }
        self.assertEqual(conf.to_dict(), conf_dict)

if __name__ == '__main__':
    raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
                       "\tpython test/test_quantization.py TESTNAME\n\n"
                       "instead.")
@@ -1858,6 +1858,7 @@ class TestQuantizeFx(QuantizationTestCase):
         # should not crash as in https://github.com/pytorch/pytorch/issues/75825
         prepare_fx(m, qconfig_dict, example_inputs=(torch.randn(1, 3, 3, 3),))
 
+    # TODO: move QConfigMapping tests to test/quantization/core
     def test_qconfig_mapping_set_global(self):
         qconfig = get_default_qconfig()
         qconfig_mapping = QConfigMapping()
@@ -36,6 +36,7 @@ from quantization.core.test_workflow_module import TestRecordHistogramObserver
 from quantization.core.test_workflow_module import TestHistogramObserver  # noqa: F401
 from quantization.core.test_workflow_module import TestDistributed  # noqa: F401
 from quantization.core.test_workflow_module import TestFusedObsFakeQuantModule  # noqa: F401
+from quantization.core.test_backend_config import TestBackendConfig  # noqa: F401
 from quantization.core.test_utils import TestUtils  # noqa: F401
 from quantization.core.test_docs import TestQuantizationDocs  # noqa: F401
@@ -1,5 +1,7 @@
-from .tensorrt import get_tensorrt_backend_config_dict
+from .backend_config import BackendConfig, BackendPatternConfig, DTypeConfig
 from .native import get_native_backend_config_dict
+from .observation_type import ObservationType
+from .tensorrt import get_tensorrt_backend_config_dict
 
 # TODO: add more validations
 def validate_backend_config_dict(backend_config_dict):
torch/ao/quantization/backend_config/backend_config.py (new file, 390 lines)
@@ -0,0 +1,390 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional

import torch
from torch.ao.quantization.backend_config.observation_type import ObservationType
from torch.ao.quantization.observer import _PartialWrapper
from torch.ao.quantization.utils import Pattern


__all__ = [
    "BackendConfig",
    "BackendPatternConfig",
    "DTypeConfig",
]


# DTypeConfig dict keys
INPUT_DTYPE_DICT_KEY = "input_dtype"
OUTPUT_DTYPE_DICT_KEY = "output_dtype"
WEIGHT_DTYPE_DICT_KEY = "weight_dtype"
BIAS_DTYPE_DICT_KEY = "bias_dtype"
IS_DYNAMIC_DICT_KEY = "is_dynamic"

# BackendConfig dict keys
NAME_DICT_KEY = "name"
CONFIGS_DICT_KEY = "configs"

# BackendPatternConfig dict keys
PATTERN_DICT_KEY = "pattern"
OBSERVATION_TYPE_DICT_KEY = "observation_type"
DTYPE_CONFIGS_DICT_KEY = "dtype_configs"
ROOT_MODULE_DICT_KEY = "root_module"
QAT_MODULE_DICT_KEY = "qat_module"
REFERENCE_QUANTIZED_MODULE_DICT_KEY = "reference_quantized_module_for_root"
FUSED_MODULE_DICT_KEY = "fused_module"
FUSER_METHOD_DICT_KEY = "fuser_method"
ROOT_NODE_GETTER_DICT_KEY = "root_node_getter"
EXTRA_INPUTS_GETTER_DICT_KEY = "extra_inputs_getter"
NUM_TENSOR_ARGS_TO_OBSERVATION_TYPE_DICT_KEY = "num_tensor_args_to_observation_type"
INPUT_TYPE_TO_INDEX_DICT_KEY = "input_type_to_index"
INPUT_OUTPUT_OBSERVED_DICT_KEY = "input_output_observed"
OVERWRITE_OUTPUT_FAKE_QUANTIZE_DICT_KEY = "overwrite_output_fake_quantize"
OVERWRITE_OUTPUT_OBSERVER_DICT_KEY = "overwrite_output_observer"


@dataclass
class DTypeConfig:
    """
    Config for the set of supported input/output activation, weight, and bias data types for the
    patterns defined in :class:`~torch.ao.quantization.backend_config.BackendConfig`.
    """
    input_dtype: Optional[torch.dtype] = None
    output_dtype: Optional[torch.dtype] = None
    weight_dtype: Optional[torch.dtype] = None
    bias_dtype: Optional[torch.dtype] = None
    is_dynamic: Optional[bool] = None

    @classmethod
    def from_dict(cls, dtype_config_dict: Dict[str, Any]) -> DTypeConfig:
        """
        Create a `DTypeConfig` from a dictionary with the following items (all optional):

            "input_dtype": torch.dtype
            "output_dtype": torch.dtype
            "weight_dtype": torch.dtype
            "bias_dtype": torch.dtype
            "is_dynamic": bool
        """
        input_dtype = dtype_config_dict.get(INPUT_DTYPE_DICT_KEY, None)
        output_dtype = dtype_config_dict.get(OUTPUT_DTYPE_DICT_KEY, None)
        weight_dtype = dtype_config_dict.get(WEIGHT_DTYPE_DICT_KEY, None)
        bias_dtype = dtype_config_dict.get(BIAS_DTYPE_DICT_KEY, None)
        is_dynamic = dtype_config_dict.get(IS_DYNAMIC_DICT_KEY, None)
        return cls(input_dtype, output_dtype, weight_dtype, bias_dtype, is_dynamic)

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert this `DTypeConfig` to a dictionary with the items described in
        :func:`~torch.ao.quantization.backend_config.DTypeConfig.from_dict`.
        """
        dtype_config_dict: Dict[str, Any] = {}
        if self.input_dtype is not None:
            dtype_config_dict[INPUT_DTYPE_DICT_KEY] = self.input_dtype
        if self.output_dtype is not None:
            dtype_config_dict[OUTPUT_DTYPE_DICT_KEY] = self.output_dtype
        if self.weight_dtype is not None:
            dtype_config_dict[WEIGHT_DTYPE_DICT_KEY] = self.weight_dtype
        if self.bias_dtype is not None:
            dtype_config_dict[BIAS_DTYPE_DICT_KEY] = self.bias_dtype
        if self.is_dynamic is not None:
            dtype_config_dict[IS_DYNAMIC_DICT_KEY] = self.is_dynamic
        return dtype_config_dict


class BackendConfig:
    # TODO: refer to NativeBackendConfig once that is implemented
    """
    Config that defines the set of patterns that can be quantized on a given backend, and how reference
    quantized models can be produced from these patterns.

    A pattern in this context refers to a module, a functional, an operator, or a directed acyclic graph
    of the above. Each pattern supported on the target backend can be individually configured through
    :class:`~torch.ao.quantization.backend_config.BackendPatternConfig` in terms of:

    (1) The supported input/output activation, weight, and bias data types

    (2) How observers and quant/dequant ops are inserted in order to construct the reference pattern, and

    (3) (Optionally) Fusion, QAT, and reference module mappings.

    The format of the patterns is described in:
    https://github.com/pytorch/pytorch/blob/master/torch/ao/quantization/backend_config/README.md

    Example usage::

        import torch
        from torch.ao.quantization.backend_config import BackendConfig, BackendPatternConfig, DTypeConfig, ObservationType
        from torch.ao.quantization.fuser_method_mappings import reverse_sequential_wrapper2

        weighted_int8_dtype_config = DTypeConfig(
            input_dtype=torch.quint8,
            output_dtype=torch.quint8,
            weight_dtype=torch.qint8,
            bias_dtype=torch.float)

        linear_config = BackendPatternConfig(torch.nn.Linear) \
            .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) \
            .add_dtype_config(weighted_int8_dtype_config) \
            .set_root_module(torch.nn.Linear) \
            .set_qat_module(torch.nn.qat.Linear) \
            .set_reference_quantized_module(torch.nn.quantized._reference.Linear)

        conv_relu_config = BackendPatternConfig((torch.nn.ReLU, torch.nn.Conv2d)) \
            .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) \
            .add_dtype_config(weighted_int8_dtype_config) \
            .set_fused_module(torch.nn.intrinsic.ConvReLU2d) \
            .set_fuser_method(reverse_sequential_wrapper2(torch.nn.intrinsic.ConvReLU2d))

        backend_config = BackendConfig("my_backend") \
            .set_backend_pattern_config(linear_config) \
            .set_backend_pattern_config(conv_relu_config)
    """
    def __init__(self, name: str = ""):
        self.name = name
        self.configs: Dict[Pattern, BackendPatternConfig] = {}

    def set_name(self, name: str) -> BackendConfig:
        """
        Set the name of the target backend.
        """
        self.name = name
        return self

    def set_backend_pattern_config(self, config: BackendPatternConfig) -> BackendConfig:
        """
        Set the config for an op that can be run on the target backend.
        This overrides any existing config for the given op.
        """
        self.configs[config.pattern] = config
        return self

    @classmethod
    def from_dict(cls, backend_config_dict: Dict[str, Any]) -> BackendConfig:
        """
        Create a `BackendConfig` from a dictionary with the following items:

            "name": the name of the target backend
            "configs": a list of dictionaries, each representing a `BackendPatternConfig`
        """
        for dict_key in [NAME_DICT_KEY, CONFIGS_DICT_KEY]:
            if dict_key not in backend_config_dict:
                raise ValueError("backend_config_dict must contain '%s'" % dict_key)
        conf = cls(backend_config_dict[NAME_DICT_KEY])
        for d in backend_config_dict[CONFIGS_DICT_KEY]:
            if isinstance(d, BackendPatternConfig):
                conf.set_backend_pattern_config(d)
            elif isinstance(d, Dict):
                conf.set_backend_pattern_config(BackendPatternConfig.from_dict(d))
            else:
                raise ValueError("Expected backend_config_dict['%s'] to be a dictionary" % CONFIGS_DICT_KEY)
        return conf

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert this `BackendConfig` to a dictionary with the items described in
        :func:`~torch.ao.quantization.backend_config.BackendConfig.from_dict`.
        """
        return {
            NAME_DICT_KEY: self.name,
            CONFIGS_DICT_KEY: [c.to_dict() for c in self.configs.values()],
        }


class BackendPatternConfig:
    """
    Config for ops defined in :class:`~torch.ao.quantization.backend_config.BackendConfig`.

    The user can configure how an operator pattern is handled on a given backend using the following methods:

        `set_observation_type`: sets how observers should be inserted for this pattern.
            See :class:`~torch.ao.quantization.backend_config.ObservationType`
        `add_dtype_config`: adds a set of supported data types for this pattern
        `set_root_module`: sets the module that represents the root for this pattern
        `set_qat_module`: sets the module that represents the QAT implementation for this pattern
        `set_reference_quantized_module`: sets the module that represents the reference quantized
            implementation for this pattern's root module
        `set_fused_module`: sets the module that represents the fused implementation for this pattern
        `set_fuser_method`: sets the function that specifies how to fuse this pattern

    For a detailed example usage, see :class:`~torch.ao.quantization.backend_config.BackendConfig`.
    """
    def __init__(self, pattern: Pattern):
        self.pattern = pattern
        self.observation_type = ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT
        self.dtype_configs: List[DTypeConfig] = []
        self.root_module: Optional[torch.nn.Module] = None
        self.qat_module: Optional[torch.nn.Module] = None
        self.reference_quantized_module: Optional[torch.nn.Module] = None
        self.fused_module: Optional[torch.nn.Module] = None
        self.fuser_method: Optional[Callable] = None

        # Temporary/internal configs
        self._root_node_getter: Optional[Callable] = None
        self._extra_inputs_getter: Optional[Callable] = None
        self._num_tensor_args_to_observation_type: Dict[int, ObservationType] = {}
        self._input_type_to_index: Dict[str, int] = {}
        self._input_output_observed: Optional[bool] = None
        self._overwrite_output_fake_quantize: Optional[_PartialWrapper] = None
        self._overwrite_output_observer: Optional[_PartialWrapper] = None

    def set_observation_type(self, observation_type: ObservationType) -> BackendPatternConfig:
        """
        Set how observers should be inserted for this pattern.
        """
        self.observation_type = observation_type
        return self

    def add_dtype_config(self, dtype_config: DTypeConfig) -> BackendPatternConfig:
        """
        Register a set of supported input/output activation, weight, and bias data types for this pattern.
        """
        self.dtype_configs.append(dtype_config)
        return self

    def set_root_module(self, root_module: torch.nn.Module) -> BackendPatternConfig:
        """
        Set the module that represents the root for this pattern.
        For example, the root module for :class:`torch.nn.intrinsic.LinearReLU` should be :class:`torch.nn.Linear`.
        """
        self.root_module = root_module
        return self

    def set_qat_module(self, qat_module: torch.nn.Module) -> BackendPatternConfig:
        """
        Set the module that represents the QAT implementation for this pattern.
        """
        self.qat_module = qat_module
        return self

    def set_reference_quantized_module(self, reference_quantized_module: torch.nn.Module) -> BackendPatternConfig:
        """
        Set the module that represents the reference quantized implementation for this pattern's root module.
        """
        self.reference_quantized_module = reference_quantized_module
        return self

    def set_fused_module(self, fused_module: torch.nn.Module) -> BackendPatternConfig:
        """
        Set the module that represents the fused implementation for this pattern.
        """
        self.fused_module = fused_module
        return self

    def set_fuser_method(self, fuser_method: Callable) -> BackendPatternConfig:
        """
        Set the function that specifies how to fuse this pattern.
        """
        self.fuser_method = fuser_method
        return self

    def _set_root_node_getter(self, root_node_getter: Callable) -> BackendPatternConfig:
        self._root_node_getter = root_node_getter
        return self

    def _set_extra_inputs_getter(self, extra_inputs_getter: Callable) -> BackendPatternConfig:
        self._extra_inputs_getter = extra_inputs_getter
        return self

    def _set_num_tensor_args_to_observation_type(
            self, num_tensor_args_to_observation_type: Dict[int, ObservationType]) -> BackendPatternConfig:
        self._num_tensor_args_to_observation_type = num_tensor_args_to_observation_type
        return self

    def _set_input_type_to_index(self, input_type_to_index: Dict[str, int]) -> BackendPatternConfig:
        self._input_type_to_index = input_type_to_index
        return self

    def _set_input_output_observed(self, input_output_observed: bool) -> BackendPatternConfig:
        self._input_output_observed = input_output_observed
        return self

    def _set_overwrite_output_fake_quantize(self, overwrite_output_fake_quantize: _PartialWrapper) -> BackendPatternConfig:
        self._overwrite_output_fake_quantize = overwrite_output_fake_quantize
        return self

    def _set_overwrite_output_observer(self, overwrite_output_observer: _PartialWrapper) -> BackendPatternConfig:
        self._overwrite_output_observer = overwrite_output_observer
        return self

    @classmethod
    def from_dict(cls, backend_pattern_config_dict: Dict[str, Any]) -> BackendPatternConfig:
        """
        Create a `BackendPatternConfig` from a dictionary with the following items:

            "pattern": the pattern being configured
            "observation_type": the :class:`~torch.ao.quantization.backend_config.ObservationType` that specifies how
                observers should be inserted for this pattern
            "dtype_configs": a list of dictionaries, each representing a :class:`~torch.ao.quantization.backend_config.DTypeConfig`
            "root_module": a :class:`torch.nn.Module` that represents the root for this pattern
            "qat_module": a :class:`torch.nn.Module` that represents the QAT implementation for this pattern
            "reference_quantized_module_for_root": a :class:`torch.nn.Module` that represents the reference quantized
                implementation for this pattern's root module
            "fused_module": a :class:`torch.nn.Module` that represents the fused implementation for this pattern
            "fuser_method": a function that specifies how to fuse this pattern
        """
        def _get_dtype_config(obj: Any) -> DTypeConfig:
            """
            Convert the given object into a `DTypeConfig` if possible, else throw an exception.
            """
            if isinstance(obj, DTypeConfig):
                return obj
            if isinstance(obj, Dict):
                return DTypeConfig.from_dict(obj)
            raise ValueError("Expected a list of DTypeConfigs in backend_pattern_config_dict[\"%s\"], got '%s'" %
                             (DTYPE_CONFIGS_DICT_KEY, type(obj)))

        if PATTERN_DICT_KEY not in backend_pattern_config_dict:
            raise ValueError("backend_pattern_config_dict must contain '%s'" % PATTERN_DICT_KEY)
        conf = cls(backend_pattern_config_dict[PATTERN_DICT_KEY])
        if OBSERVATION_TYPE_DICT_KEY in backend_pattern_config_dict:
            conf.set_observation_type(backend_pattern_config_dict[OBSERVATION_TYPE_DICT_KEY])
        for d in backend_pattern_config_dict.get(DTYPE_CONFIGS_DICT_KEY, []):
            conf.add_dtype_config(_get_dtype_config(d))
        conf.set_root_module(backend_pattern_config_dict.get(ROOT_MODULE_DICT_KEY, None))
        conf.set_qat_module(backend_pattern_config_dict.get(QAT_MODULE_DICT_KEY, None))
        conf.set_reference_quantized_module(backend_pattern_config_dict.get(REFERENCE_QUANTIZED_MODULE_DICT_KEY, None))
        conf.set_fused_module(backend_pattern_config_dict.get(FUSED_MODULE_DICT_KEY, None))
        conf.set_fuser_method(backend_pattern_config_dict.get(FUSER_METHOD_DICT_KEY, None))
        conf._set_root_node_getter(backend_pattern_config_dict.get(ROOT_NODE_GETTER_DICT_KEY, None))
        conf._set_extra_inputs_getter(backend_pattern_config_dict.get(EXTRA_INPUTS_GETTER_DICT_KEY, None))
        conf._set_num_tensor_args_to_observation_type(
            backend_pattern_config_dict.get(NUM_TENSOR_ARGS_TO_OBSERVATION_TYPE_DICT_KEY, {}))
        conf._set_input_type_to_index(backend_pattern_config_dict.get(INPUT_TYPE_TO_INDEX_DICT_KEY, {}))
        conf._set_input_output_observed(backend_pattern_config_dict.get(INPUT_OUTPUT_OBSERVED_DICT_KEY, None))
        conf._set_overwrite_output_fake_quantize(backend_pattern_config_dict.get(OVERWRITE_OUTPUT_FAKE_QUANTIZE_DICT_KEY, None))
        conf._set_overwrite_output_observer(backend_pattern_config_dict.get(OVERWRITE_OUTPUT_OBSERVER_DICT_KEY, None))
        return conf

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert this `BackendPatternConfig` to a dictionary with the items described in
        :func:`~torch.ao.quantization.backend_config.BackendPatternConfig.from_dict`.
        """
        backend_pattern_config_dict: Dict[str, Any] = {
            PATTERN_DICT_KEY: self.pattern,
            OBSERVATION_TYPE_DICT_KEY: self.observation_type,
            DTYPE_CONFIGS_DICT_KEY: [c.to_dict() for c in self.dtype_configs],
        }
        if self.root_module is not None:
            backend_pattern_config_dict[ROOT_MODULE_DICT_KEY] = self.root_module
        if self.qat_module is not None:
            backend_pattern_config_dict[QAT_MODULE_DICT_KEY] = self.qat_module
        if self.reference_quantized_module is not None:
            backend_pattern_config_dict[REFERENCE_QUANTIZED_MODULE_DICT_KEY] = self.reference_quantized_module
        if self.fused_module is not None:
            backend_pattern_config_dict[FUSED_MODULE_DICT_KEY] = self.fused_module
        if self.fuser_method is not None:
            backend_pattern_config_dict[FUSER_METHOD_DICT_KEY] = self.fuser_method
        if self._root_node_getter is not None:
            backend_pattern_config_dict[ROOT_NODE_GETTER_DICT_KEY] = self._root_node_getter
        if self._extra_inputs_getter is not None:
            backend_pattern_config_dict[EXTRA_INPUTS_GETTER_DICT_KEY] = self._extra_inputs_getter
        if len(self._num_tensor_args_to_observation_type) > 0:
            backend_pattern_config_dict[NUM_TENSOR_ARGS_TO_OBSERVATION_TYPE_DICT_KEY] = self._num_tensor_args_to_observation_type
        if len(self._input_type_to_index) > 0:
            backend_pattern_config_dict[INPUT_TYPE_TO_INDEX_DICT_KEY] = self._input_type_to_index
        if self._input_output_observed is not None:
            backend_pattern_config_dict[INPUT_OUTPUT_OBSERVED_DICT_KEY] = self._input_output_observed
        if self._overwrite_output_fake_quantize is not None:
            backend_pattern_config_dict[OVERWRITE_OUTPUT_FAKE_QUANTIZE_DICT_KEY] = self._overwrite_output_fake_quantize
        if self._overwrite_output_observer is not None:
            backend_pattern_config_dict[OVERWRITE_OUTPUT_OBSERVER_DICT_KEY] = self._overwrite_output_observer
        return backend_pattern_config_dict
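To give a sense of what the part 2 migration could look like, here is a minimal sketch of converting a dictionary in the format accepted by `BackendConfig.from_dict` into the new object form and back. The dictionary contents are hypothetical, chosen only to exercise the documented keys, and are not taken from this commit.

import torch
from torch.ao.quantization.backend_config import BackendConfig, ObservationType

# A backend_config_dict-style entry for torch.nn.Linear (hypothetical values)
my_backend_config_dict = {
    "name": "my_backend",
    "configs": [
        {
            "pattern": torch.nn.Linear,
            "observation_type": ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
            "dtype_configs": [
                {
                    "input_dtype": torch.quint8,
                    "output_dtype": torch.quint8,
                    "weight_dtype": torch.qint8,
                    "bias_dtype": torch.float,
                },
            ],
            "root_module": torch.nn.Linear,
        },
    ],
}

# Dict -> object -> dict: the round trip preserves all of the original entries
backend_config = BackendConfig.from_dict(my_backend_config_dict)
assert backend_config.to_dict() == my_backend_config_dict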