Mirror of https://github.com/pytorch/pytorch.git, synced 2025-11-02 23:15:01 +08:00
[quant][fx] Add _convert_to_reference_decomposed (#87094)
Summary: _convert_to_reference_decomposed_fx is a private convert function in the FX graph mode quantization flow that converts a calibrated/trained model to a reference quantized model with a decomposed representation for quantized tensors.

Test Plan: python test/test_quantization.py TestQuantizeFx.test__convert_to_reference_decomposed_fx

Pull Request resolved: https://github.com/pytorch/pytorch/pull/87094
Approved by: https://github.com/andrewor14
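For orientation, here is a minimal usage sketch of the new private API, modeled on the test added below (the toy module, qconfig choice, and input shape are illustrative, not part of the commit):

    import torch
    from torch.ao.quantization import get_default_qconfig_mapping
    from torch.ao.quantization.quantize_fx import prepare_fx, _convert_to_reference_decomposed_fx

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(5, 10)

        def forward(self, x):
            return self.linear(x)

    m = M().eval()
    example_inputs = (torch.randn(1, 5),)
    # insert observers, then run calibration data through the model
    m = prepare_fx(m, get_default_qconfig_mapping("fbgemm"), example_inputs)
    m(*example_inputs)  # calibration
    # produce a reference quantized model whose quantize/dequantize ops use the
    # decomposed representation (plain integer tensors instead of quantized tensors)
    m = _convert_to_reference_decomposed_fx(m)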
Committed by: PyTorch MergeBot
Parent: a12d3d6b49
Commit: 0e3b5ea026
@@ -18,10 +18,12 @@ from torch.ao.quantization.quantize_fx import (
     prepare_fx,
     convert_fx,
     convert_to_reference_fx,
+    _convert_to_reference_decomposed_fx,
     prepare_qat_fx,
     fuse_fx,
 )

 from torch.ao.quantization.fx.quantization_patterns import DefaultNodeQuantizeHandler

 from torch.ao.quantization.fx.match_utils import (
@@ -5237,6 +5239,30 @@ class TestQuantizeFx(QuantizationTestCase):
             with self.assertRaisesRegex(AssertionError, "not supported"):
                 qconfig_mapping = get_default_qat_qconfig_mapping(invalid_backend)

+    def test__convert_to_reference_decomposed_fx(self):
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = torch.nn.Linear(5, 10)
+
+            def forward(self, x):
+                return self.linear(x)
+
+        m = M().eval()
+        qconfig_mapping = get_default_qconfig_mapping("fbgemm")
+        example_inputs = (torch.randn(1, 5),)
+        m = prepare_fx(m, qconfig_mapping, example_inputs)
+        m = _convert_to_reference_decomposed_fx(m)
+        expected_occurrence = {
+            ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor): 2,
+            ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor): 2,
+        }
+        self.checkGraphModuleNodes(
+            m,
+            expected_node_occurrence=expected_occurrence)
+        # make sure it runs
+        m(*example_inputs)
+
 @skipIfNoFBGEMM
 class TestQuantizeFxOps(QuantizationTestCase):
     def setUp(self):
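Note on the test above: the two occurrences of each decomposed op correspond to the activation being quantized/dequantized at the input and output of the reference linear. A quick, hedged way to eyeball them after conversion is torch.fx's built-in graph printer:

    # prints the graph as a table; the decomposed ops show up as call_function nodes
    # targeting torch.ops.quantized_decomposed.quantize_per_tensor / dequantize_per_tensor
    m.graph.print_tabular()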
@@ -69,6 +69,8 @@ from .custom_config import (
     PrepareCustomConfig,
 )
 from .lower_to_fbgemm import lower_to_fbgemm
+# importing the lib so that the quantized_decomposed ops are registered
+from ._decomposed import quantized_decomposed_lib  # noqa: F401


 # TODO: revisit this list. Many helper methods shouldn't be public
@@ -485,7 +487,8 @@ def convert(
         is_standalone_module: bool = False,
         _remove_qconfig_flag: bool = True,
         qconfig_mapping: Union[QConfigMapping, Dict[str, Any], None] = None,
-        backend_config: Union[BackendConfig, Dict[str, Any], None] = None) -> torch.nn.Module:
+        backend_config: Union[BackendConfig, Dict[str, Any], None] = None,
+        is_decomposed: bool = False) -> torch.nn.Module:
     """
     We will convert an observed model (a module with observer calls) to a reference
     quantized model, the rule is simple:
@@ -497,13 +500,21 @@ def convert(
     is stored in observed_node_names, we can decide whether we need to swap the
     module based on this set

-    standalone_module means it a submodule that is not inlined in
-    parent module, and will be quantized separately as one unit.
+    Args:
+       * `is_standalone_module`: when this flag is True, it means we are quantizing
+       a submodule that is not inlined in parent module, and will be quantized
+       separately as one unit.

-    Returns a quantized standalone module, whether input/output is quantized is
-    specified by prepare_custom_config, with
-    input_quantized_idxs, output_quantized_idxs, please
-    see docs for prepare_fx for details
+       * `is_decomposed`: a boolean flag to indicate whether we want to use the
+       quantize operator for decomposed quantized tensor
+       (torch.ops.quantized_decomposed.quantize_per_tensor) or default/standalone
+       quantized tensor (torch.quantize_per_tensor)
+
+    Returns:
+         a quantized standalone module, whether input/output is quantized is
+         specified by prepare_custom_config, with
+         input_quantized_idxs, output_quantized_idxs, please
+         see docs for :func:`~torch.ao.quantization.prepare_fx` for details
     """
     if convert_custom_config is None:
         convert_custom_config = ConvertCustomConfig()
@@ -595,7 +606,8 @@ def convert(
             node: Node,
             modules: Dict[str, torch.nn.Module],
             node_name_to_scope: Dict[str, Tuple[str, type]],
-            node_name_to_qconfig: Dict[str, QConfigAny]) -> None:
+            node_name_to_qconfig: Dict[str, QConfigAny],
+            is_decomposed: bool) -> None:
         """ Replace activation_post_process module call node with quantize and
         dequantize node

@@ -608,7 +620,7 @@ def convert(
         assert isinstance(node.target, str)
         module_path, prefix = get_module_path_and_prefix(node, node_name_to_scope, node_name_to_qconfig)
         observer_module = modules[node.target]
-        maybe_quantize_node_info = get_quantize_node_info(observer_module)
+        maybe_quantize_node_info = get_quantize_node_info(observer_module, is_decomposed)
         # Skip replacing observers to quant/dequant nodes if the qconfigs of all
         # consumers and producers of this observer are None
         skip_replacement = all([
@@ -626,7 +638,7 @@ def convert(
         # replace observer node with quant - dequant node
         with graph.inserting_before(node):
             input_node = node.args[0]
-            inputs = [input_node]
+            quantize_op_inputs = [input_node]
             for key, value in qparams.items():
                 # TODO: we can add the information of whether a value needs to
                 # be registered as an attribute in qparams dict itself
@@ -634,13 +646,22 @@ def convert(
                     # For scale and zero_point values we register them as buffers in the root module.
                     # TODO: maybe need more complex attr name here
                     qparam_node = create_getattr_from_value(model, graph, module_path + prefix + key, value)
-                    inputs.append(qparam_node)
+                    quantize_op_inputs.append(qparam_node)
                 else:
                     # for qparams that are not scale/zero_point (like axis, dtype) we store them as literals in the graph.
-                    inputs.append(value)
+                    quantize_op_inputs.append(value)

-            quantized_node = graph.create_node(node_type, quantize_op, tuple(inputs), {})
-            dequantized_node = graph.call_method("dequantize", args=(quantized_node,))
+            quantized_node = graph.create_node(node_type, quantize_op, tuple(quantize_op_inputs), {})
+            if is_decomposed:
+                # use the same qparams from quantize op
+                dq_inputs = [quantized_node] + quantize_op_inputs[1:]
+                dequantized_node = graph.call_function(
+                    torch.ops.quantized_decomposed.dequantize_per_tensor,
+                    tuple(dq_inputs),
+                    {}
+                )
+            else:
+                dequantized_node = graph.call_method("dequantize", args=(quantized_node,))
             node.replace_all_uses_with(dequantized_node)
             graph.erase_node(node)

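The decomposed branch above swaps the `dequantize` call_method for a call_function on torch.ops.quantized_decomposed.dequantize_per_tensor that reuses the quantize op's qparams. A rough eager-mode sketch of the resulting node pair, assuming the (input, scale, zero_point, quant_min, quant_max, dtype) signature registered in torch.ao.quantization.fx._decomposed (the scale/zero_point values here are made up):

    import torch
    import torch.ao.quantization.fx._decomposed  # noqa: F401  # registers the quantized_decomposed ops

    x = torch.randn(1, 5)
    scale, zero_point = 0.05, 128  # illustrative observer outputs
    q = torch.ops.quantized_decomposed.quantize_per_tensor(
        x, scale, zero_point, 0, 255, torch.uint8)   # plain uint8 tensor, not a torch.quantize_per_tensor qtensor
    dq = torch.ops.quantized_decomposed.dequantize_per_tensor(
        q, scale, zero_point, 0, 255, torch.uint8)   # same qparams reused, as in dq_inputs above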
@@ -711,7 +732,7 @@ def convert(
                 else:
                     replace_observer_with_quantize_dequantize_node(
                         model, model.graph, node, modules, node_name_to_scope,
-                        node_name_to_qconfig)
+                        node_name_to_qconfig, is_decomposed)
             elif isinstance(mod, DeQuantStub):
                 replace_observer_or_dequant_stub_with_dequantize_node(node, model.graph)
             elif is_observed_standalone_module(mod):
@@ -17,6 +17,7 @@ from torch.ao.quantization.utils import (
     activation_is_statically_quantized,
     is_per_tensor,
     is_per_channel,
+    to_underlying_dtype,
 )
 from torch.ao.quantization.quantize import is_activation_post_process

@@ -27,6 +28,8 @@ from torch.fx.graph import (
     Node,
 )
 from .custom_config import PrepareCustomConfig
+# importing the lib so that the quantized_decomposed ops are registered
+from ._decomposed import quantized_decomposed_lib  # noqa: F401

 from typing import Callable, Optional, List, Dict, Any, Set, Tuple, Union, Type
 from collections import namedtuple
@@ -160,11 +163,22 @@ def get_per_tensor_qparams(activation_post_process):
     dtype = activation_post_process.dtype
     return scale, zero_point, dtype

-def get_quantize_node_info(activation_post_process: Callable) -> Optional[Tuple[str, Union[Callable, str], Dict[str, Any]]]:
-    ''' Given an activation_post_process module,
-    return node_type(e.g. call_function), quantize op(e.g. quantize_per_tensor) and a dictionary
-    of extracted qparams from the module
-    '''
+def get_quantize_node_info(
+        activation_post_process: Callable,
+        is_decomposed: bool
+) -> Optional[Tuple[str, Union[Callable[..., Any], str], Dict[str, Any]]]:
+    """ Extract information about quantize op from activation_post_process module
+    Args:
+      * `activation_post_process`: observer module instance or fake quant module instance
+        after calibration/QAT
+      * `is_decomposed`: a boolean flag to indicate whether we want to use the
+        quantize operator for decomposed quantized tensor (torch.ops.quantized_decomposed.quantize_per_tensor) or default/standalone
+        quantized tensor (torch.quantize_per_tensor)
+
+    Returns
+      node_type(e.g. call_function), quantize op(e.g. quantize_per_tensor) and a dictionary
+      of extracted qparams from the module
+    """
     dtype = activation_post_process.dtype  # type: ignore[attr-defined]
     compute_dtype = None
     if hasattr(activation_post_process, "compute_dtype"):
@@ -177,17 +191,36 @@ def get_quantize_node_info(
         if is_per_channel(activation_post_process.qscheme):  # type: ignore[attr-defined]
             ch_axis = int(activation_post_process.ch_axis)  # type: ignore[attr-defined]
             qparams = {"_scale_": scale, "_zero_point_": zero_point, "_axis_": ch_axis, "_dtype_": dtype}
-            quantize_op = torch.quantize_per_channel
+            if is_decomposed:
+                raise NotImplementedError("decomposed quantize_per_channel op not implemented yet")
+            else:
+                quantize_op = torch.quantize_per_channel
         else:
             scale = float(scale)
             zero_point = int(zero_point)
-            qparams = {"_scale_": scale, "_zero_point_": zero_point, "_dtype_": dtype}
-            quantize_op = torch.quantize_per_tensor
+            if is_decomposed:
+                quant_min = activation_post_process.quant_min  # type: ignore[attr-defined]
+                quant_max = activation_post_process.quant_max  # type: ignore[attr-defined]
+                dtype = to_underlying_dtype(dtype)
+                qparams = {
+                    "_scale_": scale,
+                    "_zero_point_": zero_point,
+                    "_quant_min": quant_min,
+                    "_quant_max": quant_max,
+                    "_dtype_": dtype
+                }
+                quantize_op = torch.ops.quantized_decomposed.quantize_per_tensor
+            else:
+                qparams = {"_scale_": scale, "_zero_point_": zero_point, "_dtype_": dtype}
+                quantize_op = torch.quantize_per_tensor
     elif compute_dtype in [torch.quint8, torch.qint8, torch.float16]:
         # TODO(future PR): switch compute_dtype to is_dynamic
         # dynamic quantization
         node_type = "call_function"
-        quantize_op = torch.quantize_per_tensor_dynamic
+        if is_decomposed:
+            raise NotImplementedError("decomposed quantize_per_tensor_dynamic op not implemented yet")
+        else:
+            quantize_op = torch.quantize_per_tensor_dynamic
         # TODO: get reduce range from observer
         # reduce_range = activation_post_process.reduce_range
         reduce_range = torch.backends.quantized.engine in ("fbgemm", "x86")
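A hedged sketch of what the static per-tensor branch now returns when is_decomposed=True; MinMaxObserver and the calibration input are illustrative, and the exact attribute handling depends on the observer class:

    import torch
    from torch.ao.quantization.observer import MinMaxObserver
    from torch.ao.quantization.fx.utils import get_quantize_node_info

    obs = MinMaxObserver(dtype=torch.quint8)
    obs(torch.randn(4, 4))  # "calibrate" the observer so scale/zero_point are defined
    node_type, quantize_op, qparams = get_quantize_node_info(obs, True)
    # node_type == "call_function"
    # quantize_op is torch.ops.quantized_decomposed.quantize_per_tensor
    # qparams carries _scale_/_zero_point_/_quant_min/_quant_max/_dtype_, with _dtype_
    # mapped to torch.uint8 by to_underlying_dtype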
@@ -199,8 +232,9 @@ def get_quantize_node_info(
     else:
         warnings.warn(f"Unsupported activation_post_process in get_quantize_node_info: {activation_post_process}")
         return None
-    return node_type, quantize_op, qparams
+    return node_type, quantize_op, qparams  # type: ignore[return-value]

+# TODO: looks like this is not used, remove
 def quantize_node(
         in_node: Node,
         obs_module: torch.nn.Module,
@@ -247,7 +281,8 @@ def quantize_node(
         module_path = ""
     root_module = modules['']
     graph = quantized_graph
-    maybe_quantize_node_info = get_quantize_node_info(obs_module)
+    is_decomposed_qtensor = False
+    maybe_quantize_node_info = get_quantize_node_info(obs_module, is_decomposed_qtensor)
     assert maybe_quantize_node_info is not None, \
         f"Expecting quantize node info not to be None, observer: {obs_module}"
     node_type, quantize_op, qparams = maybe_quantize_node_info
@@ -530,6 +530,7 @@ def _convert_fx(
     _remove_qconfig: bool = True,
     qconfig_mapping: Union[QConfigMapping, Dict[str, Any], None] = None,
     backend_config: Union[BackendConfig, Dict[str, Any], None] = None,
+    is_decomposed: bool = False,
 ) -> torch.nn.Module:
     """ `is_standalone_module`: see docs in :func:`~torch.ao.quantization.prepare_standalone_module_fx`
     """
@@ -552,6 +553,7 @@ def _convert_fx(
         _remove_qconfig_flag=_remove_qconfig,
         qconfig_mapping=qconfig_mapping,
         backend_config=backend_config,
+        is_decomposed=is_decomposed,
     )

     preserved_attributes = convert_custom_config.preserved_attributes
@@ -676,6 +678,59 @@ def convert_to_reference_fx(
         backend_config=backend_config,
     )

+def _convert_to_reference_decomposed_fx(
+    graph_module: GraphModule,
+    convert_custom_config: Union[ConvertCustomConfig, Dict[str, Any], None] = None,
+    _remove_qconfig: bool = True,
+    qconfig_mapping: Union[QConfigMapping, Dict[str, Any], None] = None,
+    backend_config: Union[BackendConfig, Dict[str, Any], None] = None,
+) -> torch.nn.Module:
+    r""" Convert a calibrated or trained model to a reference quantized model, with
+    decomposed representation for quantized Tensor
+    see https://github.com/pytorch/rfcs/blob/master/RFC-0019-Extending-PyTorch-Quantization-to-Custom-Backends.md for more details,
+    reference quantized model is a standard representation of a quantized model provided
+    by FX Graph Mode Quantization, it can be further lowered to run on the target
+    hardware, like accelerators
+
+    Note: this is not public API
+
+    Args:
+        * `graph_module` (GraphModule): A prepared and calibrated/trained model (GraphModule)
+
+        * `convert_custom_config` (ConvertCustomConfig): custom configurations for convert function.
+            See :func:`~torch.ao.quantization.quantize_fx.convert_fx` for more details.
+
+        * `_remove_qconfig` (bool): Option to remove the qconfig attributes in the model after convert.
+
+        * `qconfig_mapping` (QConfigMapping): config for specifying how to convert a model for quantization.
+            See :func:`~torch.ao.quantization.quantize_fx.convert_fx` for more details.
+
+        * `backend_config` (BackendConfig): A configuration for the backend which describes how
+            operators should be quantized in the backend. See
+            :func:`~torch.ao.quantization.quantize_fx.convert_fx` for more details.
+
+    Return:
+        A reference quantized model (GraphModule) with operators working with decomposed quantized Tensor
+
+    Example::
+
+        # prepared_model: the model after prepare_fx/prepare_qat_fx and calibration/training
+        # TODO: add backend_config after we split the backend_config for fbgemm and qnnpack
+        # e.g. backend_config = get_default_backend_config("fbgemm")
+        reference_quantized_model = _convert_to_reference_decomposed_fx(prepared_model)
+
+    """
+    torch._C._log_api_usage_once("quantization_api.quantize_fx._convert_to_reference_decomposed_fx")
+    return _convert_fx(
+        graph_module,
+        is_reference=True,
+        convert_custom_config=convert_custom_config,
+        _remove_qconfig=_remove_qconfig,
+        qconfig_mapping=qconfig_mapping,
+        backend_config=backend_config,
+        is_decomposed=True,
+    )
+

 def _convert_standalone_module_fx(
     graph_module: GraphModule,
@@ -140,6 +140,17 @@ def getattr_from_fqn(obj: Any, fqn: str) -> Any:
     """
     return functools.reduce(getattr, fqn.split("."), obj)

+def to_underlying_dtype(qdtype):
+    DTYPE_MAPPING = {
+        torch.quint8: torch.uint8,
+        torch.qint8: torch.int8,
+        torch.qint32: torch.int32,
+        torch.quint4x2: torch.uint8,
+        torch.quint2x4: torch.uint8,
+    }
+    assert qdtype in DTYPE_MAPPING, "Unsupported dtype: " + str(qdtype)
+    return DTYPE_MAPPING[qdtype]
+
 def get_qparam_dict(observer_or_fake_quant):
     qscheme = observer_or_fake_quant.qscheme if hasattr(observer_or_fake_quant, "qscheme") else None
     dtype = observer_or_fake_quant.dtype
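A quick sanity check of the new to_underlying_dtype helper, following the mapping above:

    import torch
    from torch.ao.quantization.utils import to_underlying_dtype

    assert to_underlying_dtype(torch.quint8) == torch.uint8
    assert to_underlying_dtype(torch.qint8) == torch.int8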
@@ -562,4 +573,5 @@ __all__ = [
     "calculate_qmin_qmax",
     "has_no_children_ignoring_parametrizations",
     "get_fqn_to_example_inputs",
+    "to_underlying_dtype",
 ]