[Quant][fx] Add get_default_qconfig_mapping

Summary: This follows https://github.com/pytorch/pytorch/pull/78452,
which replaced qconfig_dict with the QConfigMapping class. This PR
additionally replaces the get_default_*qconfig_dict functions with
get_default_*qconfig_mapping. For backward compatibility, the old
functions are deprecated rather than removed.
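
Example of the resulting API (illustration only, not part of this commit;
the model and input shapes are made up):

    import torch
    from torch.ao.quantization import get_default_qconfig_mapping
    from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx

    model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU()).eval()
    example_inputs = (torch.randn(1, 4),)

    # before: qconfig_dict = get_default_qconfig_dict("fbgemm")  (now deprecated)
    # after: the same defaults come back as a QConfigMapping object
    qconfig_mapping = get_default_qconfig_mapping("fbgemm")
    prepared = prepare_fx(model, qconfig_mapping, example_inputs=example_inputs)
    quantized = convert_fx(prepared)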

Test Plan:
python test/test_quantization.py TestQuantizeFx
python test/test_quantization.py TestQuantizeFxOps

Reviewers: jerryzh168, vkuzo

Subscribers: jerryzh168, vkuzo, supriyar

Pull Request resolved: https://github.com/pytorch/pytorch/pull/79618

Approved by: https://github.com/jerryzh168
Authored by: Andrew Or, 2022-06-15 09:57:27 -07:00
Committed by: PyTorch MergeBot
parent 355a1c8c3f
commit 61a1eef7fc
3 changed files with 86 additions and 53 deletions


@@ -50,8 +50,8 @@ from torch.ao.quantization import (
     float_qparams_weight_only_qconfig_4bit,
     get_default_qconfig,
     get_default_qat_qconfig,
-    get_default_qconfig_dict,
-    get_default_qat_qconfig_dict,
+    get_default_qconfig_mapping,
+    get_default_qat_qconfig_mapping,
     fuse_modules,
     fuse_modules_qat,
     prepare,
@@ -1854,7 +1854,7 @@ class TestQuantizeFx(QuantizationTestCase):
         for model in [LinearReLUModel, ConvReLUModel, ConvBnReLUModel]:
             for relu in [torch.nn.ReLU(), torch.nn.functional.relu, torch.relu]:
                 m = model(relu).eval()
-                qconfig_dict = torch.ao.quantization.get_default_qconfig_dict("fbgemm")
+                qconfig_dict = torch.ao.quantization.get_default_qconfig_mapping("fbgemm")
                 # should not crash as in https://github.com/pytorch/pytorch/issues/75825
                 prepare_fx(m, qconfig_dict, example_inputs=(torch.randn(1, 3, 3, 3),))
@@ -4388,7 +4388,7 @@ class TestQuantizeFx(QuantizationTestCase):
         for M, is_qat in options:
             m = M1().eval()
             example_inputs = (torch.randn(1, 3, 3, 3),)
-            m = prepare_fx(m, get_default_qconfig_dict(), example_inputs=example_inputs)
+            m = prepare_fx(m, get_default_qconfig_mapping(), example_inputs=example_inputs)
             m = convert_fx(m)
             node_list = [
                 ns.call_function(torch.quantize_per_tensor),
@@ -4401,7 +4401,7 @@ class TestQuantizeFx(QuantizationTestCase):
                 expected_node_list=node_list)

             m = M2().eval()
-            m = prepare_fx(m, get_default_qconfig_dict(), example_inputs=example_inputs)
+            m = prepare_fx(m, get_default_qconfig_mapping(), example_inputs=example_inputs)
             m = convert_fx(m)
             node_occurrence = {
                 ns.call_function(torch.quantize_per_tensor): 0,
@@ -4426,7 +4426,7 @@ class TestQuantizeFx(QuantizationTestCase):
                 return x

         m = M().eval()
-        mp = prepare_fx(m, get_default_qconfig_dict(), example_inputs=(torch.randn(1, 1),))
+        mp = prepare_fx(m, get_default_qconfig_mapping(), example_inputs=(torch.randn(1, 1),))

         found_stack_trace = False
         for n in mp.graph.nodes:
@@ -4541,7 +4541,7 @@ class TestQuantizeFx(QuantizationTestCase):
                 return x

         backends = ["qnnpack", "fbgemm"]
-        for func in [get_default_qconfig_dict, get_default_qat_qconfig_dict]:
+        for func in [get_default_qconfig_mapping, get_default_qat_qconfig_mapping]:
             for backend in backends:
                 m = M().eval()
                 qconfig_dict = func(backend)
@@ -4581,8 +4581,8 @@ class TestQuantizeFx(QuantizationTestCase):
             prepare_fn(m2, qconfig_dict, example_inputs=example_inputs)

         # Ensure prepare_fx and prepare_qat_fx work in both training and eval modes
-        _test(prepare_fx, get_default_qconfig_dict())
-        _test(prepare_qat_fx, get_default_qat_qconfig_dict())
+        _test(prepare_fx, get_default_qconfig_mapping())
+        _test(prepare_qat_fx, get_default_qat_qconfig_mapping())

 @skipIfNoFBGEMM
 class TestQuantizeFxOps(QuantizationTestCase):


@@ -335,52 +335,17 @@ default_per_channel_symmetric_qnnpack_qat_qconfig = QConfig(
                                                        eps=2 ** -12),
     weight=fused_per_channel_wt_fake_quant_range_neg_127_to_127)

-def _get_default_qconfig_dict_helper(qconfig, qconfig_transpose):
-    return {
-        "": qconfig,
-        "object_type": [("reshape", default_reuse_input_qconfig),
-                        (torch.nn.Conv1d, qconfig),
-                        (torch.nn.Conv2d, qconfig),
-                        (torch.nn.Conv3d, qconfig),
-                        (torch.nn.ConvTranspose1d, qconfig_transpose),
-                        (torch.nn.ConvTranspose2d, qconfig_transpose),
-                        (torch.nn.ConvTranspose3d, qconfig_transpose),
-                        (torch.nn.Linear, qconfig),
-                        (torch.nn.functional.conv1d, qconfig),
-                        (torch.nn.functional.conv2d, qconfig),
-                        (torch.nn.functional.conv3d, qconfig),
-                        (torch.nn.functional.conv_transpose1d, qconfig_transpose),
-                        (torch.nn.functional.conv_transpose2d, qconfig_transpose),
-                        (torch.nn.functional.conv_transpose3d, qconfig_transpose),
-                        (torch.nn.functional.linear, qconfig),
-                        (torch.nn.ReLU, qconfig),
-                        (torch.nn.functional.relu, qconfig),
-                        (torch.relu, qconfig),
-                        (torch.nn.BatchNorm1d, qconfig),
-                        (torch.nn.BatchNorm2d, qconfig),
-                        (torch.nn.BatchNorm3d, qconfig)]}
-
 def get_default_qconfig_dict(backend='fbgemm', version=0):
-    qconfig = get_default_qconfig(backend, version)
-    qconfig_transpose = qconfig
-    # default_per_channel_weight_observer is not currently compatible with fbgemm backend
-    # so we have to modify the weight observer to default_weight_observer or another
-    # per tensor supported observer.
-    # see https://github.com/pytorch/pytorch/issues/47535
-    if backend == "fbgemm":
-        qconfig_transpose = QConfig(activation=qconfig.activation, weight=default_weight_observer)
-    return _get_default_qconfig_dict_helper(qconfig, qconfig_transpose)
+    warnings.warn(
+        "torch.ao.quantization.get_default_qconfig_dict is deprecated and will be removed in "
+        "a future version. Please use torch.ao.quantization.get_default_qconfig_mapping instead.")
+    return torch.ao.quantization.get_default_qconfig_mapping(backend, version).to_dict()

 def get_default_qat_qconfig_dict(backend='fbgemm', version=1):
-    qconfig = get_default_qat_qconfig(backend, version)
-    qconfig_transpose = qconfig
-    # default_per_channel_weight_observer is not currently compatible with fbgemm backend
-    # so we have to modify the weight observer to default_weight_observer or another
-    # per tensor supported observer
-    # see https://github.com/pytorch/pytorch/issues/47535
-    if backend == "fbgemm":
-        qconfig_transpose = QConfig(activation=qconfig.activation, weight=default_weight_fake_quant)
-    return _get_default_qconfig_dict_helper(qconfig, qconfig_transpose)
+    warnings.warn(
+        "torch.ao.quantization.get_default_qat_qconfig_dict is deprecated and will be removed in "
+        "a future version. Please use torch.ao.quantization.get_default_qat_qconfig_mapping instead.")
+    return torch.ao.quantization.get_default_qat_qconfig_mapping(backend, version).to_dict()

 def assert_valid_qconfig(qconfig: Optional[QConfig],
                          mod: torch.nn.Module) -> None:
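
A minimal sketch of what the deprecation shim above means for existing
callers (hypothetical usage, not part of the diff):

    import warnings
    import torch.ao.quantization as tq

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        qconfig_dict = tq.get_default_qconfig_dict("fbgemm")  # warns about deprecation
    assert isinstance(qconfig_dict, dict)  # same dict shape as before, via to_dict()
    assert any("deprecated" in str(w.message) for w in caught)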


@@ -2,10 +2,22 @@ from __future__ import annotations

 from collections import OrderedDict
 from typing import Any, Callable, Dict, Tuple, Union

-from .qconfig import QConfigAny
+import torch
+
+from .fake_quantize import default_weight_fake_quant
+from .observer import default_weight_observer
+from .qconfig import (
+    default_reuse_input_qconfig,
+    get_default_qconfig,
+    get_default_qat_qconfig,
+    QConfig,
+    QConfigAny
+)

 __all__ = [
+    "get_default_qconfig_mapping",
+    "get_default_qat_qconfig_mapping",
     "QConfigMapping",
 ]
@@ -18,6 +30,62 @@ MODULE_NAME_DICT_KEY = "module_name"
 MODULE_NAME_OBJECT_TYPE_ORDER_DICT_KEY = "module_name_object_type_order"

+def _get_default_qconfig_mapping(is_qat: bool, backend: str, version: int):
+    """
+    Return the default QConfigMapping for the given quantization type and backend.
+    """
+    if is_qat:
+        qconfig = get_default_qat_qconfig(backend, version)
+    else:
+        qconfig = get_default_qconfig(backend, version)
+
+    # default_per_channel_weight_observer is not currently compatible with fbgemm backend
+    # so we have to modify the weight observer to default_weight_observer or another
+    # per tensor supported observer.
+    # see https://github.com/pytorch/pytorch/issues/47535
+    if backend == "fbgemm":
+        default_weight = default_weight_fake_quant if is_qat else default_weight_observer
+        qconfig_transpose = QConfig(activation=qconfig.activation, weight=default_weight)
+    else:
+        qconfig_transpose = qconfig
+
+    return QConfigMapping() \
+        .set_global(qconfig) \
+        .set_object_type("reshape", default_reuse_input_qconfig) \
+        .set_object_type(torch.nn.Conv1d, qconfig) \
+        .set_object_type(torch.nn.Conv2d, qconfig) \
+        .set_object_type(torch.nn.Conv3d, qconfig) \
+        .set_object_type(torch.nn.ConvTranspose1d, qconfig_transpose) \
+        .set_object_type(torch.nn.ConvTranspose2d, qconfig_transpose) \
+        .set_object_type(torch.nn.ConvTranspose3d, qconfig_transpose) \
+        .set_object_type(torch.nn.Linear, qconfig) \
+        .set_object_type(torch.nn.functional.conv1d, qconfig) \
+        .set_object_type(torch.nn.functional.conv2d, qconfig) \
+        .set_object_type(torch.nn.functional.conv3d, qconfig) \
+        .set_object_type(torch.nn.functional.conv_transpose1d, qconfig_transpose) \
+        .set_object_type(torch.nn.functional.conv_transpose2d, qconfig_transpose) \
+        .set_object_type(torch.nn.functional.conv_transpose3d, qconfig_transpose) \
+        .set_object_type(torch.nn.functional.linear, qconfig) \
+        .set_object_type(torch.nn.ReLU, qconfig) \
+        .set_object_type(torch.nn.functional.relu, qconfig) \
+        .set_object_type(torch.relu, qconfig) \
+        .set_object_type(torch.nn.BatchNorm1d, qconfig) \
+        .set_object_type(torch.nn.BatchNorm2d, qconfig) \
+        .set_object_type(torch.nn.BatchNorm3d, qconfig)
+
+def get_default_qconfig_mapping(backend="fbgemm", version=0):
+    """
+    Return the default QConfigMapping for post training quantization.
+    """
+    return _get_default_qconfig_mapping(False, backend, version)
+
+def get_default_qat_qconfig_mapping(backend="fbgemm", version=1):
+    """
+    Return the default QConfigMapping for quantization aware training.
+    """
+    return _get_default_qconfig_mapping(True, backend, version)
+

 class QConfigMapping:
     """
     Mapping from model ops to :class:`torch.ao.quantization.QConfig`s.
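
The chaining style used in _get_default_qconfig_mapping above also supports
user customization. A minimal sketch (the override qconfig and module name
are hypothetical):

    import torch
    from torch.ao.quantization import get_default_qconfig, get_default_qconfig_mapping

    # set_object_type and set_module_name return self, so calls chain
    override_qconfig = get_default_qconfig("qnnpack")  # hypothetical per-op override
    qconfig_mapping = get_default_qconfig_mapping("fbgemm") \
        .set_object_type(torch.nn.Linear, override_qconfig) \
        .set_module_name("features.0", None)  # None disables quantization for that submodule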