[Quant][fx] Add get_default_qconfig_mapping
Summary: This follows https://github.com/pytorch/pytorch/pull/78452, which replaced the qconfig_dict with QConfigMapping. This PR additionally replaces get_default_*qconfig_dict with get_default_*qconfig_mapping. For backward compatibility, we deprecate the old functions instead of removing them.

Test Plan:
python test/test_quantization.py TestQuantizeFx
python test/test_quantization.py TestQuantizeFxOps

Reviewers: jerryzh168, vkuzo

Subscribers: jerryzh168, vkuzo, supriyar

Pull Request resolved: https://github.com/pytorch/pytorch/pull/79618
Approved by: https://github.com/jerryzh168
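For context, the post-change workflow looks roughly like the sketch below. The toy model, input shape, and calibration step are invented for illustration; prepare_fx and convert_fx are the torch.ao.quantization.quantize_fx entry points exercised by the tests in this commit.

    import torch
    from torch.ao.quantization import get_default_qconfig_mapping
    from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx

    # Toy eval-mode model standing in for any FX-traceable network.
    model = torch.nn.Sequential(torch.nn.Conv2d(3, 3, 3), torch.nn.ReLU()).eval()

    # New style: pass a QConfigMapping instead of the legacy qconfig_dict.
    qconfig_mapping = get_default_qconfig_mapping("fbgemm")
    example_inputs = (torch.randn(1, 3, 8, 8),)
    prepared = prepare_fx(model, qconfig_mapping, example_inputs=example_inputs)
    prepared(*example_inputs)  # calibrate observers with representative data
    quantized = convert_fx(prepared)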
Committed by: PyTorch MergeBot
Parent: 355a1c8c3f
Commit: 61a1eef7fc
@@ -50,8 +50,8 @@ from torch.ao.quantization import (
     float_qparams_weight_only_qconfig_4bit,
     get_default_qconfig,
     get_default_qat_qconfig,
-    get_default_qconfig_dict,
-    get_default_qat_qconfig_dict,
+    get_default_qconfig_mapping,
+    get_default_qat_qconfig_mapping,
     fuse_modules,
     fuse_modules_qat,
     prepare,
@@ -1854,7 +1854,7 @@ class TestQuantizeFx(QuantizationTestCase):
         for model in [LinearReLUModel, ConvReLUModel, ConvBnReLUModel]:
             for relu in [torch.nn.ReLU(), torch.nn.functional.relu, torch.relu]:
                 m = model(relu).eval()
-                qconfig_dict = torch.ao.quantization.get_default_qconfig_dict("fbgemm")
+                qconfig_dict = torch.ao.quantization.get_default_qconfig_mapping("fbgemm")
                 # should not crash as in https://github.com/pytorch/pytorch/issues/75825
                 prepare_fx(m, qconfig_dict, example_inputs=(torch.randn(1, 3, 3, 3),))

@@ -4388,7 +4388,7 @@ class TestQuantizeFx(QuantizationTestCase):
         for M, is_qat in options:
             m = M1().eval()
             example_inputs = (torch.randn(1, 3, 3, 3),)
-            m = prepare_fx(m, get_default_qconfig_dict(), example_inputs=example_inputs)
+            m = prepare_fx(m, get_default_qconfig_mapping(), example_inputs=example_inputs)
             m = convert_fx(m)
             node_list = [
                 ns.call_function(torch.quantize_per_tensor),
@@ -4401,7 +4401,7 @@ class TestQuantizeFx(QuantizationTestCase):
                 expected_node_list=node_list)

             m = M2().eval()
-            m = prepare_fx(m, get_default_qconfig_dict(), example_inputs=example_inputs)
+            m = prepare_fx(m, get_default_qconfig_mapping(), example_inputs=example_inputs)
             m = convert_fx(m)
             node_occurrence = {
                 ns.call_function(torch.quantize_per_tensor): 0,
@@ -4426,7 +4426,7 @@ class TestQuantizeFx(QuantizationTestCase):
                 return x

         m = M().eval()
-        mp = prepare_fx(m, get_default_qconfig_dict(), example_inputs=(torch.randn(1, 1),))
+        mp = prepare_fx(m, get_default_qconfig_mapping(), example_inputs=(torch.randn(1, 1),))

         found_stack_trace = False
         for n in mp.graph.nodes:
@@ -4541,7 +4541,7 @@ class TestQuantizeFx(QuantizationTestCase):
                 return x

         backends = ["qnnpack", "fbgemm"]
-        for func in [get_default_qconfig_dict, get_default_qat_qconfig_dict]:
+        for func in [get_default_qconfig_mapping, get_default_qat_qconfig_mapping]:
            for backend in backends:
                m = M().eval()
                qconfig_dict = func(backend)
@@ -4581,8 +4581,8 @@ class TestQuantizeFx(QuantizationTestCase):
            prepare_fn(m2, qconfig_dict, example_inputs=example_inputs)

        # Ensure prepare_fx and prepare_qat_fx work in both training and eval modes
-        _test(prepare_fx, get_default_qconfig_dict())
-        _test(prepare_qat_fx, get_default_qat_qconfig_dict())
+        _test(prepare_fx, get_default_qconfig_mapping())
+        _test(prepare_qat_fx, get_default_qat_qconfig_mapping())

 @skipIfNoFBGEMM
 class TestQuantizeFxOps(QuantizationTestCase):
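The hunk above drives both entry points through a shared helper; pulled out as a standalone sketch, the QAT path looks roughly like this (the toy model and training-loop placeholder are invented for illustration):

    import torch
    from torch.ao.quantization import get_default_qat_qconfig_mapping
    from torch.ao.quantization.quantize_fx import convert_fx, prepare_qat_fx

    # Toy train-mode model; prepare_qat_fx inserts fake-quant modules for QAT.
    model = torch.nn.Sequential(
        torch.nn.Conv2d(3, 3, 3),
        torch.nn.BatchNorm2d(3),
        torch.nn.ReLU(),
    ).train()
    example_inputs = (torch.randn(1, 3, 8, 8),)
    prepared = prepare_qat_fx(
        model, get_default_qat_qconfig_mapping("fbgemm"), example_inputs=example_inputs)
    # ... fine-tune `prepared` here with fake quantization in the loop ...
    quantized = convert_fx(prepared.eval())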
@@ -335,52 +335,17 @@ default_per_channel_symmetric_qnnpack_qat_qconfig = QConfig(
         eps=2 ** -12),
     weight=fused_per_channel_wt_fake_quant_range_neg_127_to_127)

-def _get_default_qconfig_dict_helper(qconfig, qconfig_transpose):
-    return {
-        "": qconfig,
-        "object_type": [("reshape", default_reuse_input_qconfig),
-                        (torch.nn.Conv1d, qconfig),
-                        (torch.nn.Conv2d, qconfig),
-                        (torch.nn.Conv3d, qconfig),
-                        (torch.nn.ConvTranspose1d, qconfig_transpose),
-                        (torch.nn.ConvTranspose2d, qconfig_transpose),
-                        (torch.nn.ConvTranspose3d, qconfig_transpose),
-                        (torch.nn.Linear, qconfig),
-                        (torch.nn.functional.conv1d, qconfig),
-                        (torch.nn.functional.conv2d, qconfig),
-                        (torch.nn.functional.conv3d, qconfig),
-                        (torch.nn.functional.conv_transpose1d, qconfig_transpose),
-                        (torch.nn.functional.conv_transpose2d, qconfig_transpose),
-                        (torch.nn.functional.conv_transpose3d, qconfig_transpose),
-                        (torch.nn.functional.linear, qconfig),
-                        (torch.nn.ReLU, qconfig),
-                        (torch.nn.functional.relu, qconfig),
-                        (torch.relu, qconfig),
-                        (torch.nn.BatchNorm1d, qconfig),
-                        (torch.nn.BatchNorm2d, qconfig),
-                        (torch.nn.BatchNorm3d, qconfig)]}
-
 def get_default_qconfig_dict(backend='fbgemm', version=0):
-    qconfig = get_default_qconfig(backend, version)
-    qconfig_transpose = qconfig
-    # default_per_channel_weight_observer is not currently compatible with fbgemm backend
-    # so we have to modify the weight observer to default_weight_observer or another
-    # per tensor supported observer.
-    # see https://github.com/pytorch/pytorch/issues/47535
-    if backend == "fbgemm":
-        qconfig_transpose = QConfig(activation=qconfig.activation, weight=default_weight_observer)
-    return _get_default_qconfig_dict_helper(qconfig, qconfig_transpose)
+    warnings.warn(
+        "torch.ao.quantization.get_default_qconfig_dict is deprecated and will be removed in "
+        "a future version. Please use torch.ao.quantization.get_default_qconfig_mapping instead.")
+    return torch.ao.quantization.get_default_qconfig_mapping(backend, version).to_dict()

 def get_default_qat_qconfig_dict(backend='fbgemm', version=1):
-    qconfig = get_default_qat_qconfig(backend, version)
-    qconfig_transpose = qconfig
-    # default_per_channel_weight_observer is not currently compatible with fbgemm backend
-    # so we have to modify the weight observer to default_weight_observer or another
-    # per tensor supported observer
-    # see https://github.com/pytorch/pytorch/issues/47535
-    if backend == "fbgemm":
-        qconfig_transpose = QConfig(activation=qconfig.activation, weight=default_weight_fake_quant)
-    return _get_default_qconfig_dict_helper(qconfig, qconfig_transpose)
+    warnings.warn(
+        "torch.ao.quantization.get_default_qat_qconfig_dict is deprecated and will be removed in "
+        "a future version. Please use torch.ao.quantization.get_default_qat_qconfig_mapping instead.")
+    return torch.ao.quantization.get_default_qat_qconfig_mapping(backend, version).to_dict()

 def assert_valid_qconfig(qconfig: Optional[QConfig],
                          mod: torch.nn.Module) -> None:
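Note the backward-compatibility contract in this hunk: the old functions keep their signatures but now warn and delegate to the QConfigMapping getters, flattened back through to_dict(). A sketch of what a caller observes (the "" and "object_type" keys come from the legacy dict layout removed above):

    import warnings
    import torch
    from torch.ao.quantization import get_default_qconfig_dict

    # The deprecated entry point still works, but emits a warning and returns
    # the dict form of the new default QConfigMapping.
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        qconfig_dict = get_default_qconfig_dict("fbgemm")

    assert any("deprecated" in str(w.message) for w in caught)
    assert qconfig_dict[""] is not None   # global qconfig survives the round trip
    assert "object_type" in qconfig_dict  # per-op entries from the mapping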
@@ -2,10 +2,22 @@ from __future__ import annotations
 from collections import OrderedDict
 from typing import Any, Callable, Dict, Tuple, Union

-from .qconfig import QConfigAny
+import torch
+
+from .fake_quantize import default_weight_fake_quant
+from .observer import default_weight_observer
+from .qconfig import (
+    default_reuse_input_qconfig,
+    get_default_qconfig,
+    get_default_qat_qconfig,
+    QConfig,
+    QConfigAny
+)


 __all__ = [
+    "get_default_qconfig_mapping",
+    "get_default_qat_qconfig_mapping",
     "QConfigMapping",
 ]

@@ -18,6 +30,62 @@ MODULE_NAME_DICT_KEY = "module_name"
 MODULE_NAME_OBJECT_TYPE_ORDER_DICT_KEY = "module_name_object_type_order"


+def _get_default_qconfig_mapping(is_qat: bool, backend: str, version: int):
+    """
+    Return the default QConfigMapping for the given quantization type and backend.
+    """
+    if is_qat:
+        qconfig = get_default_qat_qconfig(backend, version)
+    else:
+        qconfig = get_default_qconfig(backend, version)
+
+    # default_per_channel_weight_observer is not currently compatible with fbgemm backend
+    # so we have to modify the weight observer to default_weight_observer or another
+    # per tensor supported observer.
+    # see https://github.com/pytorch/pytorch/issues/47535
+    if backend == "fbgemm":
+        default_weight = default_weight_fake_quant if is_qat else default_weight_observer
+        qconfig_transpose = QConfig(activation=qconfig.activation, weight=default_weight)
+    else:
+        qconfig_transpose = qconfig
+
+    return QConfigMapping() \
+        .set_global(qconfig) \
+        .set_object_type("reshape", default_reuse_input_qconfig) \
+        .set_object_type(torch.nn.Conv1d, qconfig) \
+        .set_object_type(torch.nn.Conv2d, qconfig) \
+        .set_object_type(torch.nn.Conv3d, qconfig) \
+        .set_object_type(torch.nn.ConvTranspose1d, qconfig_transpose) \
+        .set_object_type(torch.nn.ConvTranspose2d, qconfig_transpose) \
+        .set_object_type(torch.nn.ConvTranspose3d, qconfig_transpose) \
+        .set_object_type(torch.nn.Linear, qconfig) \
+        .set_object_type(torch.nn.functional.conv1d, qconfig) \
+        .set_object_type(torch.nn.functional.conv2d, qconfig) \
+        .set_object_type(torch.nn.functional.conv3d, qconfig) \
+        .set_object_type(torch.nn.functional.conv_transpose1d, qconfig_transpose) \
+        .set_object_type(torch.nn.functional.conv_transpose2d, qconfig_transpose) \
+        .set_object_type(torch.nn.functional.conv_transpose3d, qconfig_transpose) \
+        .set_object_type(torch.nn.functional.linear, qconfig) \
+        .set_object_type(torch.nn.ReLU, qconfig) \
+        .set_object_type(torch.nn.functional.relu, qconfig) \
+        .set_object_type(torch.relu, qconfig) \
+        .set_object_type(torch.nn.BatchNorm1d, qconfig) \
+        .set_object_type(torch.nn.BatchNorm2d, qconfig) \
+        .set_object_type(torch.nn.BatchNorm3d, qconfig)
+
+def get_default_qconfig_mapping(backend="fbgemm", version=0):
+    """
+    Return the default QConfigMapping for post training quantization.
+    """
+    return _get_default_qconfig_mapping(False, backend, version)
+
+def get_default_qat_qconfig_mapping(backend="fbgemm", version=1):
+    """
+    Return the default QConfigMapping for quantization aware training.
+    """
+    return _get_default_qconfig_mapping(True, backend, version)
+
+
 class QConfigMapping:
     """
     Mapping from model ops to :class:`torch.ao.quantization.QConfig`s.
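Since set_global and set_object_type return the mapping itself, the new getters compose naturally with user overrides. A sketch, assuming default_dynamic_qconfig is an acceptable override for Linear (any QConfig would do) and that re-setting a type replaces the earlier entry in the OrderedDict-backed mapping:

    import torch
    from torch.ao.quantization import default_dynamic_qconfig, get_default_qconfig_mapping

    # Start from the defaults, then override a single op type; re-setting
    # torch.nn.Linear replaces the default entry installed above.
    qconfig_mapping = (
        get_default_qconfig_mapping("fbgemm")
        .set_object_type(torch.nn.Linear, default_dynamic_qconfig)
    )

    # to_dict() flattens back into the legacy qconfig_dict layout if needed.
    legacy_qconfig_dict = qconfig_mapping.to_dict()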