[quant][fx] Add _convert_to_reference_decomposed (#87094)

Summary:
_convert_to_reference_decomposed is a private convert function in the FX graph mode quantization flow that converts
a calibrated/trained model to a reference quantized model with a decomposed quantized tensor representation.
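A minimal usage sketch, mirroring the test added in this PR (the toy model and the fbgemm qconfig mapping are illustrative):

import torch
from torch.ao.quantization import get_default_qconfig_mapping
from torch.ao.quantization.quantize_fx import prepare_fx, _convert_to_reference_decomposed_fx

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(5, 10)

    def forward(self, x):
        return self.linear(x)

m = M().eval()
example_inputs = (torch.randn(1, 5),)
# insert observers according to the default fbgemm qconfig mapping
m = prepare_fx(m, get_default_qconfig_mapping("fbgemm"), example_inputs)
m(*example_inputs)  # calibration
# produce a reference quantized model whose graph uses the decomposed
# torch.ops.quantized_decomposed.{quantize,dequantize}_per_tensor ops
m = _convert_to_reference_decomposed_fx(m)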

Test Plan:
python test/test_quantization.py TestQuantizeFx.test__convert_to_reference_decomposed_fx

Reviewers:

Subscribers:

Tasks:

Tags:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87094
Approved by: https://github.com/andrewor14
Jerry Zhang
2022-10-26 14:43:41 -07:00
committed by PyTorch MergeBot
parent a12d3d6b49
commit 0e3b5ea026
5 changed files with 175 additions and 26 deletions

View File

@@ -18,10 +18,12 @@ from torch.ao.quantization.quantize_fx import (
prepare_fx,
convert_fx,
convert_to_reference_fx,
_convert_to_reference_decomposed_fx,
prepare_qat_fx,
fuse_fx,
)
from torch.ao.quantization.fx.quantization_patterns import DefaultNodeQuantizeHandler
from torch.ao.quantization.fx.match_utils import (
@@ -5237,6 +5239,30 @@ class TestQuantizeFx(QuantizationTestCase):
with self.assertRaisesRegex(AssertionError, "not supported"):
qconfig_mapping = get_default_qat_qconfig_mapping(invalid_backend)
def test__convert_to_reference_decomposed_fx(self):
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(5, 10)
def forward(self, x):
return self.linear(x)
m = M().eval()
qconfig_mapping = get_default_qconfig_mapping("fbgemm")
example_inputs = (torch.randn(1, 5),)
m = prepare_fx(m, qconfig_mapping, example_inputs)
m = _convert_to_reference_decomposed_fx(m)
expected_occurrence = {
ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor): 2,
ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor): 2,
}
self.checkGraphModuleNodes(
m,
expected_node_occurrence=expected_occurrence)
# make sure it runs
m(*example_inputs)
@skipIfNoFBGEMM
class TestQuantizeFxOps(QuantizationTestCase):
def setUp(self):

View File

@@ -69,6 +69,8 @@ from .custom_config import (
PrepareCustomConfig,
)
from .lower_to_fbgemm import lower_to_fbgemm
# importing the lib so that the quantized_decomposed ops are registered
from ._decomposed import quantized_decomposed_lib # noqa: F401
# TODO: revisit this list. Many helper methods shouldn't be public
@@ -485,7 +487,8 @@ def convert(
is_standalone_module: bool = False,
_remove_qconfig_flag: bool = True,
qconfig_mapping: Union[QConfigMapping, Dict[str, Any], None] = None,
backend_config: Union[BackendConfig, Dict[str, Any], None] = None) -> torch.nn.Module:
backend_config: Union[BackendConfig, Dict[str, Any], None] = None,
is_decomposed: bool = False) -> torch.nn.Module:
"""
We will convert an observed model (a module with observer calls) to a reference
quantized model, the rule is simple:
@@ -497,13 +500,21 @@ def convert(
is stored in observed_node_names, we can decide whether we need to swap the
module based on this set
standalone_module means it is a submodule that is not inlined in the
parent module, and will be quantized separately as one unit.
Args:
* `is_standalone_module`: when this flag is True, it means we are quantizing
a submodule that is not inlined in parent module, and will be quantized
separately as one unit.
Returns a quantized standalone module, whether input/output is quantized is
specified by prepare_custom_config, with
input_quantized_idxs, output_quantized_idxs, please
see docs for prepare_fx for details
* `is_decomposed`: a boolean flag to indicate whether we want to use the
quantize operator for decomposed quantized tensor
(torch.ops.quantized_decomposed.quantize_per_tensor) or default/standalone
quantized tensor (torch.quantize_per_tensor)
Returns:
a quantized standalone module, whether input/output is quantized is
specified by prepare_custom_config, with
input_quantized_idxs, output_quantized_idxs, please
see docs for :func:`~torch.ao.quantization.prepare_fx` for details
"""
if convert_custom_config is None:
convert_custom_config = ConvertCustomConfig()
@@ -595,7 +606,8 @@ def convert(
node: Node,
modules: Dict[str, torch.nn.Module],
node_name_to_scope: Dict[str, Tuple[str, type]],
node_name_to_qconfig: Dict[str, QConfigAny]) -> None:
node_name_to_qconfig: Dict[str, QConfigAny],
is_decomposed: bool) -> None:
""" Replace activation_post_process module call node with quantize and
dequantize node
@@ -608,7 +620,7 @@ def convert(
assert isinstance(node.target, str)
module_path, prefix = get_module_path_and_prefix(node, node_name_to_scope, node_name_to_qconfig)
observer_module = modules[node.target]
maybe_quantize_node_info = get_quantize_node_info(observer_module)
maybe_quantize_node_info = get_quantize_node_info(observer_module, is_decomposed)
# Skip replacing observers to quant/dequant nodes if the qconfigs of all
# consumers and producers of this observer are None
skip_replacement = all([
@@ -626,7 +638,7 @@ def convert(
# replace observer node with quant - dequant node
with graph.inserting_before(node):
input_node = node.args[0]
inputs = [input_node]
quantize_op_inputs = [input_node]
for key, value in qparams.items():
# TODO: we can add the information of whether a value needs to
# be registered as an attribute in qparams dict itself
@@ -634,13 +646,22 @@ def convert(
# For scale and zero_point values we register them as buffers in the root module.
# TODO: maybe need more complex attr name here
qparam_node = create_getattr_from_value(model, graph, module_path + prefix + key, value)
inputs.append(qparam_node)
quantize_op_inputs.append(qparam_node)
else:
# for qparams that are not scale/zero_point (like axis, dtype) we store them as literals in the graph.
inputs.append(value)
quantize_op_inputs.append(value)
quantized_node = graph.create_node(node_type, quantize_op, tuple(inputs), {})
dequantized_node = graph.call_method("dequantize", args=(quantized_node,))
quantized_node = graph.create_node(node_type, quantize_op, tuple(quantize_op_inputs), {})
if is_decomposed:
# use the same qparams from quantize op
dq_inputs = [quantized_node] + quantize_op_inputs[1:]
dequantized_node = graph.call_function(
torch.ops.quantized_decomposed.dequantize_per_tensor,
tuple(dq_inputs),
{}
)
else:
dequantized_node = graph.call_method("dequantize", args=(quantized_node,))
node.replace_all_uses_with(dequantized_node)
graph.erase_node(node)
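(For context, not part of the diff: with is_decomposed=True the observer call is replaced by a quantize/dequantize pair that shares the same qparams. A rough sketch of the resulting per-tensor affine pattern, with illustrative scale/zero_point/quant range values:)

import torch
# importing the lib registers the quantized_decomposed ops (see the import added above)
from torch.ao.quantization.fx import _decomposed  # noqa: F401

x = torch.randn(1, 5)
scale, zero_point, quant_min, quant_max = 0.05, 128, 0, 255
# quantize into a plain uint8 tensor; qparams are passed explicitly instead of
# being carried by a quantized Tensor object
x_q = torch.ops.quantized_decomposed.quantize_per_tensor(
    x, scale, zero_point, quant_min, quant_max, torch.uint8)
# dequantize back to float reusing the same qparams, like the dq node inserted above
x_fp = torch.ops.quantized_decomposed.dequantize_per_tensor(
    x_q, scale, zero_point, quant_min, quant_max, torch.uint8)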
@@ -711,7 +732,7 @@ def convert(
else:
replace_observer_with_quantize_dequantize_node(
model, model.graph, node, modules, node_name_to_scope,
node_name_to_qconfig)
node_name_to_qconfig, is_decomposed)
elif isinstance(mod, DeQuantStub):
replace_observer_or_dequant_stub_with_dequantize_node(node, model.graph)
elif is_observed_standalone_module(mod):

View File

@@ -17,6 +17,7 @@ from torch.ao.quantization.utils import (
activation_is_statically_quantized,
is_per_tensor,
is_per_channel,
to_underlying_dtype,
)
from torch.ao.quantization.quantize import is_activation_post_process
@@ -27,6 +28,8 @@ from torch.fx.graph import (
Node,
)
from .custom_config import PrepareCustomConfig
# importing the lib so that the quantized_decomposed ops are registered
from ._decomposed import quantized_decomposed_lib # noqa: F401
from typing import Callable, Optional, List, Dict, Any, Set, Tuple, Union, Type
from collections import namedtuple
@@ -160,11 +163,22 @@ def get_per_tensor_qparams(activation_post_process):
dtype = activation_post_process.dtype
return scale, zero_point, dtype
def get_quantize_node_info(activation_post_process: Callable) -> Optional[Tuple[str, Union[Callable, str], Dict[str, Any]]]:
''' Given an activation_post_process module,
return node_type(e.g. call_function), quantize op(e.g. quantize_per_tensor) and a dictionary
of extracted qparams from the module
'''
def get_quantize_node_info(
activation_post_process: Callable,
is_decomposed: bool
) -> Optional[Tuple[str, Union[Callable[..., Any], str], Dict[str, Any]]]:
""" Extract information about quantize op from activation_post_process module
Args:
* `activation_post_process`: observer module instance or fake quant module instance
after calibration/QAT
* `is_decomposed`: a boolean flag to indicate whether we want to use the
quantize operator for decomposed quantized tensor (torch.ops.quantized_decomposed.quantize_per_tensor) or default/standalone
quantized tensor (torch.quantize_per_tensor)
Returns
node_type(e.g. call_function), quantize op(e.g. quantize_per_tensor) and a dictionary
of extracted qparams from the module
"""
dtype = activation_post_process.dtype # type: ignore[attr-defined]
compute_dtype = None
if hasattr(activation_post_process, "compute_dtype"):
@@ -177,17 +191,36 @@ def get_quantize_node_info(activation_post_process: Callable) -> Optional[Tuple[
if is_per_channel(activation_post_process.qscheme): # type: ignore[attr-defined]
ch_axis = int(activation_post_process.ch_axis) # type: ignore[attr-defined]
qparams = {"_scale_": scale, "_zero_point_": zero_point, "_axis_": ch_axis, "_dtype_": dtype}
quantize_op = torch.quantize_per_channel
if is_decomposed:
raise NotImplementedError("decomposed quantize_per_channel op not implemented yet")
else:
quantize_op = torch.quantize_per_channel
else:
scale = float(scale)
zero_point = int(zero_point)
qparams = {"_scale_": scale, "_zero_point_": zero_point, "_dtype_": dtype}
quantize_op = torch.quantize_per_tensor
if is_decomposed:
quant_min = activation_post_process.quant_min # type: ignore[attr-defined]
quant_max = activation_post_process.quant_max # type: ignore[attr-defined]
dtype = to_underlying_dtype(dtype)
qparams = {
"_scale_": scale,
"_zero_point_": zero_point,
"_quant_min": quant_max,
"_quant_max": quant_max,
"_dtype_": dtype
}
quantize_op = torch.ops.quantized_decomposed.quantize_per_tensor
else:
qparams = {"_scale_": scale, "_zero_point_": zero_point, "_dtype_": dtype}
quantize_op = torch.quantize_per_tensor
elif compute_dtype in [torch.quint8, torch.qint8, torch.float16]:
# TODO(future PR): switch compute_dtype to is_dynamic
# dynamic quantization
node_type = "call_function"
quantize_op = torch.quantize_per_tensor_dynamic
if is_decomposed:
raise NotImplementedError("decomposed quantize_per_tensor_dynamic op not implemented yet")
else:
quantize_op = torch.quantize_per_tensor_dynamic
# TODO: get reduce range from observer
# reduce_range = activation_post_process.reduce_range
reduce_range = torch.backends.quantized.engine in ("fbgemm", "x86")
@@ -199,8 +232,9 @@ def get_quantize_node_info(activation_post_process: Callable) -> Optional[Tuple[
else:
warnings.warn(f"Unsupported activation_post_process in get_quantize_node_info: {activation_post_process}")
return None
return node_type, quantize_op, qparams
return node_type, quantize_op, qparams # type: ignore[return-value]
# TODO: looks like this is not used, remove
def quantize_node(
in_node: Node,
obs_module: torch.nn.Module,
@@ -247,7 +281,8 @@ def quantize_node(
module_path = ""
root_module = modules['']
graph = quantized_graph
maybe_quantize_node_info = get_quantize_node_info(obs_module)
is_decomposed_qtensor = False
maybe_quantize_node_info = get_quantize_node_info(obs_module, is_decomposed_qtensor)
assert maybe_quantize_node_info is not None, \
f"Expecting quantize node info not to be None, observer: {obs_module}"
node_type, quantize_op, qparams = maybe_quantize_node_info
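(Not part of the diff: with the change above, for a calibrated per-tensor quint8 observer the tuple returned by get_quantize_node_info in the decomposed case would look roughly as follows; the numeric values are illustrative.)

import torch
from torch.ao.quantization.fx import _decomposed  # noqa: F401  # registers the decomposed ops

node_type = "call_function"
quantize_op = torch.ops.quantized_decomposed.quantize_per_tensor
# note the underlying integer dtype (torch.quint8 -> torch.uint8) and the explicit quant range
qparams = {
    "_scale_": 0.05,
    "_zero_point_": 128,
    "_quant_min_": 0,
    "_quant_max_": 255,
    "_dtype_": torch.uint8,
}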

View File

@@ -530,6 +530,7 @@ def _convert_fx(
_remove_qconfig: bool = True,
qconfig_mapping: Union[QConfigMapping, Dict[str, Any], None] = None,
backend_config: Union[BackendConfig, Dict[str, Any], None] = None,
is_decomposed: bool = False,
) -> torch.nn.Module:
""" `is_standalone_module`: see docs in :func:`~torch.ao.quantization.prepare_standalone_module_fx`
"""
@@ -552,6 +553,7 @@ def _convert_fx(
_remove_qconfig_flag=_remove_qconfig,
qconfig_mapping=qconfig_mapping,
backend_config=backend_config,
is_decomposed=is_decomposed,
)
preserved_attributes = convert_custom_config.preserved_attributes
@@ -676,6 +678,59 @@ def convert_to_reference_fx(
backend_config=backend_config,
)
def _convert_to_reference_decomposed_fx(
graph_module: GraphModule,
convert_custom_config: Union[ConvertCustomConfig, Dict[str, Any], None] = None,
_remove_qconfig: bool = True,
qconfig_mapping: Union[QConfigMapping, Dict[str, Any], None] = None,
backend_config: Union[BackendConfig, Dict[str, Any], None] = None,
) -> torch.nn.Module:
r""" Convert a calibrated or trained model to a reference quantized model, with
decomposed representation for quantized Tensor
see https://github.com/pytorch/rfcs/blob/master/RFC-0019-Extending-PyTorch-Quantization-to-Custom-Backends.md for more details,
reference quantized model is a standard representation of a quantized model provided
by FX Graph Mode Quantization, it can be further lowered to run on the target
hardware, like accelerators
Note: this is not public API
Args:
* `graph_module` (GraphModule): A prepared and calibrated/trained model (GraphModule)
* `convert_custom_config` (ConvertCustomConfig): custom configurations for convert function.
See :func:`~torch.ao.quantization.quantize_fx.convert_fx` for more details.
* `_remove_qconfig` (bool): Option to remove the qconfig attributes in the model after convert.
* `qconfig_mapping` (QConfigMapping): config for specifying how to convert a model for quantization.
See :func:`~torch.ao.quantization.quantize_fx.convert_fx` for more details.
* `backend_config` (BackendConfig): A configuration for the backend which describes how
operators should be quantized in the backend. See
:func:`~torch.ao.quantization.quantize_fx.convert_fx` for more details.
Return:
A reference quantized model (GraphModule) with operators working with decomposed quantized Tensor
Example::
# prepared_model: the model after prepare_fx/prepare_qat_fx and calibration/training
# TODO: add backend_config after we split the backend_config for fbgemm and qnnpack
# e.g. backend_config = get_default_backend_config("fbgemm")
reference_quantized_model = _convert_to_reference_decomposed_fx(prepared_model)
"""
torch._C._log_api_usage_once("quantization_api.quantize_fx._convert_to_reference_decomposed_fx")
return _convert_fx(
graph_module,
is_reference=True,
convert_custom_config=convert_custom_config,
_remove_qconfig=_remove_qconfig,
qconfig_mapping=qconfig_mapping,
backend_config=backend_config,
is_decomposed=True,
)
def _convert_standalone_module_fx(
graph_module: GraphModule,

View File

@@ -140,6 +140,17 @@ def getattr_from_fqn(obj: Any, fqn: str) -> Any:
"""
return functools.reduce(getattr, fqn.split("."), obj)
def to_underlying_dtype(qdtype):
DTYPE_MAPPING = {
torch.quint8: torch.uint8,
torch.qint8: torch.int8,
torch.qint32: torch.int32,
torch.quint4x2: torch.uint8,
torch.quint2x4: torch.uint8,
}
assert qdtype in DTYPE_MAPPING, "Unsupported dtype: " + str(qdtype)
return DTYPE_MAPPING[qdtype]
def get_qparam_dict(observer_or_fake_quant):
qscheme = observer_or_fake_quant.qscheme if hasattr(observer_or_fake_quant, "qscheme") else None
dtype = observer_or_fake_quant.dtype
@@ -562,4 +573,5 @@ __all__ = [
"calculate_qmin_qmax",
"has_no_children_ignoring_parametrizations",
"get_fqn_to_example_inputs",
"to_underlying_dtype",
]
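(Not part of the diff: a quick illustration of the new helper, which maps quantized dtypes to the underlying integer dtypes used by the decomposed ops.)

import torch
from torch.ao.quantization.utils import to_underlying_dtype

assert to_underlying_dtype(torch.quint8) is torch.uint8
assert to_underlying_dtype(torch.qint8) is torch.int8
assert to_underlying_dtype(torch.qint32) is torch.int32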