[quant][fx] Add cat to backend_config_dict (#75259)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/75259

att

Test Plan:
python test/test_quantization.py TestQuantizeFx
python test/test_quantization.py TestQuantizeFxOps
python test/test_quantization.py TestFXNumericSuiteCoreAPIs

Imported from OSS

Reviewed By: andrewor14

Differential Revision: D35403586

fbshipit-source-id: 066ace239a7ca5a49463f6fcc2fa10e3efef8794
(cherry picked from commit e4b7d91cc48f3a4c913940bf292272c5418c5cb0)
This commit is contained in:
Jerry Zhang
2022-04-06 13:13:40 -07:00
committed by Nikita Shulga
parent 86485f61c5
commit 2f3a94996c
2 changed files with 10 additions and 3 deletions

View File

@ -344,6 +344,14 @@ _HARDSIGMOID_MODULE_CONFIG = {
],
}
_CAT_CONFIG = {
"pattern": torch.cat,
"observation_type": ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT,
"dtype_configs": [
default_op_quint8_dtype_config,
]
}
def get_native_backend_config_dict():
""" Get backend_config_dict for PyTorch Native backend (fbgemm/qnnpack). """
return {
@ -355,5 +363,6 @@ def get_native_backend_config_dict():
*_get_conv_configs(),
*_get_binary_op_configs(),
_HARDSIGMOID_MODULE_CONFIG,
_CAT_CONFIG,
],
}

View File

@ -121,10 +121,8 @@ class QuantizeHandler(ABC):
class BinaryOpQuantizeHandler(QuantizeHandler):
pass
@register_quant_pattern(torch.cat)
class CatQuantizeHandler(QuantizeHandler):
def is_general_tensor_value_op(self) -> bool:
return True
pass
# TODO: remove this class
class ConvReluQuantizeHandler(QuantizeHandler):