mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[Quant][fx] Add get_default_qconfig_mapping
Summary: This follows https://github.com/pytorch/pytorch/pull/78452, which replaced the qconfig_dict with QConfigMapping. This PR additionally replaces get_default_*qconfig_dict with get_default_*qconfig_mapping. For backward compatibility, we deprecate the old functions instead of removing them. Test Plan: python test/test_quantization.py TestQuantizeFx python test/test_quantization.py TestQuantizeFxOps Reviewers: jerryzh168, vkuzo Subscribers: jerryzh168, vkuzo, supriyar Pull Request resolved: https://github.com/pytorch/pytorch/pull/79618 Approved by: https://github.com/jerryzh168
This commit is contained in:
committed by
PyTorch MergeBot
parent
355a1c8c3f
commit
61a1eef7fc
@ -50,8 +50,8 @@ from torch.ao.quantization import (
|
||||
float_qparams_weight_only_qconfig_4bit,
|
||||
get_default_qconfig,
|
||||
get_default_qat_qconfig,
|
||||
get_default_qconfig_dict,
|
||||
get_default_qat_qconfig_dict,
|
||||
get_default_qconfig_mapping,
|
||||
get_default_qat_qconfig_mapping,
|
||||
fuse_modules,
|
||||
fuse_modules_qat,
|
||||
prepare,
|
||||
@ -1854,7 +1854,7 @@ class TestQuantizeFx(QuantizationTestCase):
|
||||
for model in [LinearReLUModel, ConvReLUModel, ConvBnReLUModel]:
|
||||
for relu in [torch.nn.ReLU(), torch.nn.functional.relu, torch.relu]:
|
||||
m = model(relu).eval()
|
||||
qconfig_dict = torch.ao.quantization.get_default_qconfig_dict("fbgemm")
|
||||
qconfig_dict = torch.ao.quantization.get_default_qconfig_mapping("fbgemm")
|
||||
# should not crash as in https://github.com/pytorch/pytorch/issues/75825
|
||||
prepare_fx(m, qconfig_dict, example_inputs=(torch.randn(1, 3, 3, 3),))
|
||||
|
||||
@ -4388,7 +4388,7 @@ class TestQuantizeFx(QuantizationTestCase):
|
||||
for M, is_qat in options:
|
||||
m = M1().eval()
|
||||
example_inputs = (torch.randn(1, 3, 3, 3),)
|
||||
m = prepare_fx(m, get_default_qconfig_dict(), example_inputs=example_inputs)
|
||||
m = prepare_fx(m, get_default_qconfig_mapping(), example_inputs=example_inputs)
|
||||
m = convert_fx(m)
|
||||
node_list = [
|
||||
ns.call_function(torch.quantize_per_tensor),
|
||||
@ -4401,7 +4401,7 @@ class TestQuantizeFx(QuantizationTestCase):
|
||||
expected_node_list=node_list)
|
||||
|
||||
m = M2().eval()
|
||||
m = prepare_fx(m, get_default_qconfig_dict(), example_inputs=example_inputs)
|
||||
m = prepare_fx(m, get_default_qconfig_mapping(), example_inputs=example_inputs)
|
||||
m = convert_fx(m)
|
||||
node_occurrence = {
|
||||
ns.call_function(torch.quantize_per_tensor): 0,
|
||||
@ -4426,7 +4426,7 @@ class TestQuantizeFx(QuantizationTestCase):
|
||||
return x
|
||||
|
||||
m = M().eval()
|
||||
mp = prepare_fx(m, get_default_qconfig_dict(), example_inputs=(torch.randn(1, 1),))
|
||||
mp = prepare_fx(m, get_default_qconfig_mapping(), example_inputs=(torch.randn(1, 1),))
|
||||
|
||||
found_stack_trace = False
|
||||
for n in mp.graph.nodes:
|
||||
@ -4541,7 +4541,7 @@ class TestQuantizeFx(QuantizationTestCase):
|
||||
return x
|
||||
|
||||
backends = ["qnnpack", "fbgemm"]
|
||||
for func in [get_default_qconfig_dict, get_default_qat_qconfig_dict]:
|
||||
for func in [get_default_qconfig_mapping, get_default_qat_qconfig_mapping]:
|
||||
for backend in backends:
|
||||
m = M().eval()
|
||||
qconfig_dict = func(backend)
|
||||
@ -4581,8 +4581,8 @@ class TestQuantizeFx(QuantizationTestCase):
|
||||
prepare_fn(m2, qconfig_dict, example_inputs=example_inputs)
|
||||
|
||||
# Ensure prepare_fx and prepare_qat_fx work in both training and eval modes
|
||||
_test(prepare_fx, get_default_qconfig_dict())
|
||||
_test(prepare_qat_fx, get_default_qat_qconfig_dict())
|
||||
_test(prepare_fx, get_default_qconfig_mapping())
|
||||
_test(prepare_qat_fx, get_default_qat_qconfig_mapping())
|
||||
|
||||
@skipIfNoFBGEMM
|
||||
class TestQuantizeFxOps(QuantizationTestCase):
|
||||
|
@ -335,52 +335,17 @@ default_per_channel_symmetric_qnnpack_qat_qconfig = QConfig(
|
||||
eps=2 ** -12),
|
||||
weight=fused_per_channel_wt_fake_quant_range_neg_127_to_127)
|
||||
|
||||
def _get_default_qconfig_dict_helper(qconfig, qconfig_transpose):
|
||||
return {
|
||||
"": qconfig,
|
||||
"object_type": [("reshape", default_reuse_input_qconfig),
|
||||
(torch.nn.Conv1d, qconfig),
|
||||
(torch.nn.Conv2d, qconfig),
|
||||
(torch.nn.Conv3d, qconfig),
|
||||
(torch.nn.ConvTranspose1d, qconfig_transpose),
|
||||
(torch.nn.ConvTranspose2d, qconfig_transpose),
|
||||
(torch.nn.ConvTranspose3d, qconfig_transpose),
|
||||
(torch.nn.Linear, qconfig),
|
||||
(torch.nn.functional.conv1d, qconfig),
|
||||
(torch.nn.functional.conv2d, qconfig),
|
||||
(torch.nn.functional.conv3d, qconfig),
|
||||
(torch.nn.functional.conv_transpose1d, qconfig_transpose),
|
||||
(torch.nn.functional.conv_transpose2d, qconfig_transpose),
|
||||
(torch.nn.functional.conv_transpose3d, qconfig_transpose),
|
||||
(torch.nn.functional.linear, qconfig),
|
||||
(torch.nn.ReLU, qconfig),
|
||||
(torch.nn.functional.relu, qconfig),
|
||||
(torch.relu, qconfig),
|
||||
(torch.nn.BatchNorm1d, qconfig),
|
||||
(torch.nn.BatchNorm2d, qconfig),
|
||||
(torch.nn.BatchNorm3d, qconfig)]}
|
||||
|
||||
def get_default_qconfig_dict(backend='fbgemm', version=0):
|
||||
qconfig = get_default_qconfig(backend, version)
|
||||
qconfig_transpose = qconfig
|
||||
# default_per_channel_weight_observer is not currently compatible with fbgemm backend
|
||||
# so we have to modify the weight observer to default_weight_observer or another
|
||||
# per tensor supported observer.
|
||||
# see https://github.com/pytorch/pytorch/issues/47535
|
||||
if backend == "fbgemm":
|
||||
qconfig_transpose = QConfig(activation=qconfig.activation, weight=default_weight_observer)
|
||||
return _get_default_qconfig_dict_helper(qconfig, qconfig_transpose)
|
||||
warnings.warn(
|
||||
"torch.ao.quantization.get_default_qconfig_dict is deprecated and will be removed in "
|
||||
"a future version. Please use torch.ao.quantization.get_default_qconfig_mapping instead.")
|
||||
return torch.ao.quantization.get_default_qconfig_mapping(backend, version).to_dict()
|
||||
|
||||
def get_default_qat_qconfig_dict(backend='fbgemm', version=1):
|
||||
qconfig = get_default_qat_qconfig(backend, version)
|
||||
qconfig_transpose = qconfig
|
||||
# default_per_channel_weight_observer is not currently compatible with fbgemm backend
|
||||
# so we have to modify the weight observer to default_weight_observer or another
|
||||
# per tensor supported observer
|
||||
# see https://github.com/pytorch/pytorch/issues/47535
|
||||
if backend == "fbgemm":
|
||||
qconfig_transpose = QConfig(activation=qconfig.activation, weight=default_weight_fake_quant)
|
||||
return _get_default_qconfig_dict_helper(qconfig, qconfig_transpose)
|
||||
warnings.warn(
|
||||
"torch.ao.quantization.get_default_qat_qconfig_dict is deprecated and will be removed in "
|
||||
"a future version. Please use torch.ao.quantization.get_default_qat_qconfig_mapping instead.")
|
||||
return torch.ao.quantization.get_default_qat_qconfig_mapping(backend, version).to_dict()
|
||||
|
||||
def assert_valid_qconfig(qconfig: Optional[QConfig],
|
||||
mod: torch.nn.Module) -> None:
|
||||
|
@ -2,10 +2,22 @@ from __future__ import annotations
|
||||
from collections import OrderedDict
|
||||
from typing import Any, Callable, Dict, Tuple, Union
|
||||
|
||||
from .qconfig import QConfigAny
|
||||
import torch
|
||||
|
||||
from .fake_quantize import default_weight_fake_quant
|
||||
from .observer import default_weight_observer
|
||||
from .qconfig import (
|
||||
default_reuse_input_qconfig,
|
||||
get_default_qconfig,
|
||||
get_default_qat_qconfig,
|
||||
QConfig,
|
||||
QConfigAny
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"get_default_qconfig_mapping",
|
||||
"get_default_qat_qconfig_mapping",
|
||||
"QConfigMapping",
|
||||
]
|
||||
|
||||
@ -18,6 +30,62 @@ MODULE_NAME_DICT_KEY = "module_name"
|
||||
MODULE_NAME_OBJECT_TYPE_ORDER_DICT_KEY = "module_name_object_type_order"
|
||||
|
||||
|
||||
def _get_default_qconfig_mapping(is_qat: bool, backend: str, version: int):
|
||||
"""
|
||||
Return the default QConfigMapping for the given quantization type and backend.
|
||||
"""
|
||||
if is_qat:
|
||||
qconfig = get_default_qat_qconfig(backend, version)
|
||||
else:
|
||||
qconfig = get_default_qconfig(backend, version)
|
||||
|
||||
# default_per_channel_weight_observer is not currently compatible with fbgemm backend
|
||||
# so we have to modify the weight observer to default_weight_observer or another
|
||||
# per tensor supported observer.
|
||||
# see https://github.com/pytorch/pytorch/issues/47535
|
||||
if backend == "fbgemm":
|
||||
default_weight = default_weight_fake_quant if is_qat else default_weight_observer
|
||||
qconfig_transpose = QConfig(activation=qconfig.activation, weight=default_weight)
|
||||
else:
|
||||
qconfig_transpose = qconfig
|
||||
|
||||
return QConfigMapping() \
|
||||
.set_global(qconfig) \
|
||||
.set_object_type("reshape", default_reuse_input_qconfig) \
|
||||
.set_object_type(torch.nn.Conv1d, qconfig) \
|
||||
.set_object_type(torch.nn.Conv2d, qconfig) \
|
||||
.set_object_type(torch.nn.Conv3d, qconfig) \
|
||||
.set_object_type(torch.nn.ConvTranspose1d, qconfig_transpose) \
|
||||
.set_object_type(torch.nn.ConvTranspose2d, qconfig_transpose) \
|
||||
.set_object_type(torch.nn.ConvTranspose3d, qconfig_transpose) \
|
||||
.set_object_type(torch.nn.Linear, qconfig) \
|
||||
.set_object_type(torch.nn.functional.conv1d, qconfig) \
|
||||
.set_object_type(torch.nn.functional.conv2d, qconfig) \
|
||||
.set_object_type(torch.nn.functional.conv3d, qconfig) \
|
||||
.set_object_type(torch.nn.functional.conv_transpose1d, qconfig_transpose) \
|
||||
.set_object_type(torch.nn.functional.conv_transpose2d, qconfig_transpose) \
|
||||
.set_object_type(torch.nn.functional.conv_transpose3d, qconfig_transpose) \
|
||||
.set_object_type(torch.nn.functional.linear, qconfig) \
|
||||
.set_object_type(torch.nn.ReLU, qconfig) \
|
||||
.set_object_type(torch.nn.functional.relu, qconfig) \
|
||||
.set_object_type(torch.relu, qconfig) \
|
||||
.set_object_type(torch.nn.BatchNorm1d, qconfig) \
|
||||
.set_object_type(torch.nn.BatchNorm2d, qconfig) \
|
||||
.set_object_type(torch.nn.BatchNorm3d, qconfig)
|
||||
|
||||
def get_default_qconfig_mapping(backend="fbgemm", version=0):
|
||||
"""
|
||||
Return the default QConfigMapping for post training quantization.
|
||||
"""
|
||||
return _get_default_qconfig_mapping(False, backend, version)
|
||||
|
||||
def get_default_qat_qconfig_mapping(backend="fbgemm", version=1):
|
||||
"""
|
||||
Return the default QConfigMapping for quantization aware training.
|
||||
"""
|
||||
return _get_default_qconfig_mapping(True, backend, version)
|
||||
|
||||
|
||||
class QConfigMapping:
|
||||
"""
|
||||
Mapping from model ops to :class:`torch.ao.quantization.QConfig`s.
|
||||
|
Reference in New Issue
Block a user