[Quant][fx] Add get_default_qconfig_mapping

Summary: This follows https://github.com/pytorch/pytorch/pull/78452,
which replaced qconfig_dict with QConfigMapping. This PR additionally
replaces the get_default_*qconfig_dict functions with
get_default_*qconfig_mapping equivalents. For backward compatibility,
we deprecate the old functions instead of removing them.
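
For reference, the intended usage after this change looks roughly like the
sketch below (the toy model, input shapes, and calibration step are
illustrative only, not part of this PR):

    import torch
    from torch.ao.quantization import get_default_qconfig_mapping
    from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx

    # Toy eval-mode model, used only to illustrate the API change.
    model = torch.nn.Sequential(torch.nn.Conv2d(3, 3, 3), torch.nn.ReLU()).eval()
    example_inputs = (torch.randn(1, 3, 8, 8),)

    # Old (now deprecated): qconfig_dict = get_default_qconfig_dict("fbgemm")
    # New: returns a QConfigMapping carrying the same defaults.
    qconfig_mapping = get_default_qconfig_mapping("fbgemm")

    prepared = prepare_fx(model, qconfig_mapping, example_inputs=example_inputs)
    # (calibration with representative data would go here)
    quantized = convert_fx(prepared)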

Test Plan:
python test/test_quantization.py TestQuantizeFx
python test/test_quantization.py TestQuantizeFxOps
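
The deprecated entry points are kept for backward compatibility and now warn
before delegating to the new functions; a minimal hand-check might look like
this (the assertions are illustrative, not part of the test plan above):

    import warnings
    from torch.ao.quantization import get_default_qconfig_dict

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        qconfig_dict = get_default_qconfig_dict("fbgemm")  # deprecated alias

    # The old function now emits a deprecation warning and returns the new
    # QConfigMapping converted back to a dict via to_dict().
    assert any("deprecated" in str(w.message) for w in caught)
    assert isinstance(qconfig_dict, dict)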

Reviewers: jerryzh168, vkuzo

Subscribers: jerryzh168, vkuzo, supriyar

Pull Request resolved: https://github.com/pytorch/pytorch/pull/79618

Approved by: https://github.com/jerryzh168
Author: Andrew Or
Date: 2022-06-15 09:57:27 -07:00
Committed by: PyTorch MergeBot
Parent: 355a1c8c3f
Commit: 61a1eef7fc
3 changed files with 86 additions and 53 deletions


@@ -50,8 +50,8 @@ from torch.ao.quantization import (
float_qparams_weight_only_qconfig_4bit,
get_default_qconfig,
get_default_qat_qconfig,
-get_default_qconfig_dict,
-get_default_qat_qconfig_dict,
+get_default_qconfig_mapping,
+get_default_qat_qconfig_mapping,
fuse_modules,
fuse_modules_qat,
prepare,
@@ -1854,7 +1854,7 @@ class TestQuantizeFx(QuantizationTestCase):
for model in [LinearReLUModel, ConvReLUModel, ConvBnReLUModel]:
for relu in [torch.nn.ReLU(), torch.nn.functional.relu, torch.relu]:
m = model(relu).eval()
-qconfig_dict = torch.ao.quantization.get_default_qconfig_dict("fbgemm")
+qconfig_dict = torch.ao.quantization.get_default_qconfig_mapping("fbgemm")
# should not crash as in https://github.com/pytorch/pytorch/issues/75825
prepare_fx(m, qconfig_dict, example_inputs=(torch.randn(1, 3, 3, 3),))
@@ -4388,7 +4388,7 @@ class TestQuantizeFx(QuantizationTestCase):
for M, is_qat in options:
m = M1().eval()
example_inputs = (torch.randn(1, 3, 3, 3),)
-m = prepare_fx(m, get_default_qconfig_dict(), example_inputs=example_inputs)
+m = prepare_fx(m, get_default_qconfig_mapping(), example_inputs=example_inputs)
m = convert_fx(m)
node_list = [
ns.call_function(torch.quantize_per_tensor),
@@ -4401,7 +4401,7 @@ class TestQuantizeFx(QuantizationTestCase):
expected_node_list=node_list)
m = M2().eval()
-m = prepare_fx(m, get_default_qconfig_dict(), example_inputs=example_inputs)
+m = prepare_fx(m, get_default_qconfig_mapping(), example_inputs=example_inputs)
m = convert_fx(m)
node_occurrence = {
ns.call_function(torch.quantize_per_tensor): 0,
@@ -4426,7 +4426,7 @@ class TestQuantizeFx(QuantizationTestCase):
return x
m = M().eval()
-mp = prepare_fx(m, get_default_qconfig_dict(), example_inputs=(torch.randn(1, 1),))
+mp = prepare_fx(m, get_default_qconfig_mapping(), example_inputs=(torch.randn(1, 1),))
found_stack_trace = False
for n in mp.graph.nodes:
@@ -4541,7 +4541,7 @@ class TestQuantizeFx(QuantizationTestCase):
return x
backends = ["qnnpack", "fbgemm"]
-for func in [get_default_qconfig_dict, get_default_qat_qconfig_dict]:
+for func in [get_default_qconfig_mapping, get_default_qat_qconfig_mapping]:
for backend in backends:
m = M().eval()
qconfig_dict = func(backend)
@@ -4581,8 +4581,8 @@ class TestQuantizeFx(QuantizationTestCase):
prepare_fn(m2, qconfig_dict, example_inputs=example_inputs)
# Ensure prepare_fx and prepare_qat_fx work in both training and eval modes
-_test(prepare_fx, get_default_qconfig_dict())
-_test(prepare_qat_fx, get_default_qat_qconfig_dict())
+_test(prepare_fx, get_default_qconfig_mapping())
+_test(prepare_qat_fx, get_default_qat_qconfig_mapping())
@skipIfNoFBGEMM
class TestQuantizeFxOps(QuantizationTestCase):


@@ -335,52 +335,17 @@ default_per_channel_symmetric_qnnpack_qat_qconfig = QConfig(
eps=2 ** -12),
weight=fused_per_channel_wt_fake_quant_range_neg_127_to_127)
-def _get_default_qconfig_dict_helper(qconfig, qconfig_transpose):
-return {
-"": qconfig,
-"object_type": [("reshape", default_reuse_input_qconfig),
-(torch.nn.Conv1d, qconfig),
-(torch.nn.Conv2d, qconfig),
-(torch.nn.Conv3d, qconfig),
-(torch.nn.ConvTranspose1d, qconfig_transpose),
-(torch.nn.ConvTranspose2d, qconfig_transpose),
-(torch.nn.ConvTranspose3d, qconfig_transpose),
-(torch.nn.Linear, qconfig),
-(torch.nn.functional.conv1d, qconfig),
-(torch.nn.functional.conv2d, qconfig),
-(torch.nn.functional.conv3d, qconfig),
-(torch.nn.functional.conv_transpose1d, qconfig_transpose),
-(torch.nn.functional.conv_transpose2d, qconfig_transpose),
-(torch.nn.functional.conv_transpose3d, qconfig_transpose),
-(torch.nn.functional.linear, qconfig),
-(torch.nn.ReLU, qconfig),
-(torch.nn.functional.relu, qconfig),
-(torch.relu, qconfig),
-(torch.nn.BatchNorm1d, qconfig),
-(torch.nn.BatchNorm2d, qconfig),
-(torch.nn.BatchNorm3d, qconfig)]}
def get_default_qconfig_dict(backend='fbgemm', version=0):
-qconfig = get_default_qconfig(backend, version)
-qconfig_transpose = qconfig
-# default_per_channel_weight_observer is not currently compatible with fbgemm backend
-# so we have to modify the weight observer to default_weight_observer or another
-# per tensor supported observer.
-# see https://github.com/pytorch/pytorch/issues/47535
-if backend == "fbgemm":
-qconfig_transpose = QConfig(activation=qconfig.activation, weight=default_weight_observer)
-return _get_default_qconfig_dict_helper(qconfig, qconfig_transpose)
+warnings.warn(
+"torch.ao.quantization.get_default_qconfig_dict is deprecated and will be removed in "
+"a future version. Please use torch.ao.quantization.get_default_qconfig_mapping instead.")
+return torch.ao.quantization.get_default_qconfig_mapping(backend, version).to_dict()
def get_default_qat_qconfig_dict(backend='fbgemm', version=1):
-qconfig = get_default_qat_qconfig(backend, version)
-qconfig_transpose = qconfig
-# default_per_channel_weight_observer is not currently compatible with fbgemm backend
-# so we have to modify the weight observer to default_weight_observer or another
-# per tensor supported observer
-# see https://github.com/pytorch/pytorch/issues/47535
-if backend == "fbgemm":
-qconfig_transpose = QConfig(activation=qconfig.activation, weight=default_weight_fake_quant)
-return _get_default_qconfig_dict_helper(qconfig, qconfig_transpose)
+warnings.warn(
+"torch.ao.quantization.get_default_qat_qconfig_dict is deprecated and will be removed in "
+"a future version. Please use torch.ao.quantization.get_default_qat_qconfig_mapping instead.")
+return torch.ao.quantization.get_default_qat_qconfig_mapping(backend, version).to_dict()
def assert_valid_qconfig(qconfig: Optional[QConfig],
mod: torch.nn.Module) -> None:


@@ -2,10 +2,22 @@ from __future__ import annotations
from collections import OrderedDict
from typing import Any, Callable, Dict, Tuple, Union
-from .qconfig import QConfigAny
+import torch
+from .fake_quantize import default_weight_fake_quant
+from .observer import default_weight_observer
+from .qconfig import (
+default_reuse_input_qconfig,
+get_default_qconfig,
+get_default_qat_qconfig,
+QConfig,
+QConfigAny
+)
__all__ = [
"get_default_qconfig_mapping",
"get_default_qat_qconfig_mapping",
"QConfigMapping",
]
@@ -18,6 +30,62 @@ MODULE_NAME_DICT_KEY = "module_name"
MODULE_NAME_OBJECT_TYPE_ORDER_DICT_KEY = "module_name_object_type_order"
+def _get_default_qconfig_mapping(is_qat: bool, backend: str, version: int):
+"""
+Return the default QConfigMapping for the given quantization type and backend.
+"""
+if is_qat:
+qconfig = get_default_qat_qconfig(backend, version)
+else:
+qconfig = get_default_qconfig(backend, version)
+# default_per_channel_weight_observer is not currently compatible with fbgemm backend
+# so we have to modify the weight observer to default_weight_observer or another
+# per tensor supported observer.
+# see https://github.com/pytorch/pytorch/issues/47535
+if backend == "fbgemm":
+default_weight = default_weight_fake_quant if is_qat else default_weight_observer
+qconfig_transpose = QConfig(activation=qconfig.activation, weight=default_weight)
+else:
+qconfig_transpose = qconfig
+return QConfigMapping() \
+.set_global(qconfig) \
+.set_object_type("reshape", default_reuse_input_qconfig) \
+.set_object_type(torch.nn.Conv1d, qconfig) \
+.set_object_type(torch.nn.Conv2d, qconfig) \
+.set_object_type(torch.nn.Conv3d, qconfig) \
+.set_object_type(torch.nn.ConvTranspose1d, qconfig_transpose) \
+.set_object_type(torch.nn.ConvTranspose2d, qconfig_transpose) \
+.set_object_type(torch.nn.ConvTranspose3d, qconfig_transpose) \
+.set_object_type(torch.nn.Linear, qconfig) \
+.set_object_type(torch.nn.functional.conv1d, qconfig) \
+.set_object_type(torch.nn.functional.conv2d, qconfig) \
+.set_object_type(torch.nn.functional.conv3d, qconfig) \
+.set_object_type(torch.nn.functional.conv_transpose1d, qconfig_transpose) \
+.set_object_type(torch.nn.functional.conv_transpose2d, qconfig_transpose) \
+.set_object_type(torch.nn.functional.conv_transpose3d, qconfig_transpose) \
+.set_object_type(torch.nn.functional.linear, qconfig) \
+.set_object_type(torch.nn.ReLU, qconfig) \
+.set_object_type(torch.nn.functional.relu, qconfig) \
+.set_object_type(torch.relu, qconfig) \
+.set_object_type(torch.nn.BatchNorm1d, qconfig) \
+.set_object_type(torch.nn.BatchNorm2d, qconfig) \
+.set_object_type(torch.nn.BatchNorm3d, qconfig)
+def get_default_qconfig_mapping(backend="fbgemm", version=0):
+"""
+Return the default QConfigMapping for post training quantization.
+"""
+return _get_default_qconfig_mapping(False, backend, version)
+def get_default_qat_qconfig_mapping(backend="fbgemm", version=1):
+"""
+Return the default QConfigMapping for quantization aware training.
+"""
+return _get_default_qconfig_mapping(True, backend, version)
class QConfigMapping:
"""
Mapping from model ops to :class:`torch.ao.quantization.QConfig`s.
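
QConfigMapping exposes a small fluent API (set_global, set_object_type,
set_module_name, ...), so the defaults built above can be taken as a starting
point and selectively overridden. A hedged sketch; the qnnpack override and
the submodule name are made-up examples, not from this PR:

    import torch
    from torch.ao.quantization import get_default_qconfig, get_default_qconfig_mapping

    # Start from the fbgemm defaults and override individual entries.
    qconfig_mapping = get_default_qconfig_mapping("fbgemm") \
        .set_object_type(torch.nn.Linear, get_default_qconfig("qnnpack")) \
        .set_module_name("head.linear", None)  # None disables quantization for this submodule

This mirrors the chained set_global / set_object_type calls used by
_get_default_qconfig_mapping in the diff above.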