Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/58443 Test Plan: arc lint Reviewed By: vkuzo Differential Revision: D28494119 fbshipit-source-id: 05abf4e12051afc237096812fb0ee08a8b9447f9
import torch
from torch.fx.graph import (
    Node,
)
import torch.nn.quantized as nnq
import torch.nn.quantized.dynamic as nnqd
from torch.quantization import (
    default_affine_fixed_qparams_fake_quant,
    default_symmetric_fixed_qparams_fake_quant,
)

from ..quantization_mappings import (
    get_static_quant_module_class,
    get_dynamic_quant_module_class,
    get_quantized_operator,
)
from ..utils import (
    get_swapped_custom_module_class,
    activation_is_statically_quantized,
    activation_is_int8_quantized,
    weight_is_statically_quantized,
    get_qconfig_dtypes,
    activation_dtype,
)

from ..quantize import (
    is_activation_post_process,
)

from .pattern_utils import (
    register_quant_pattern,
    get_default_output_activation_post_process_map,
    Pattern,
)

from .utils import (
    _parent_name,
    all_node_args_have_no_tensors,
    quantize_node,
    get_per_tensor_qparams,
    get_linear_prepack_op_for_dtype,
    create_qparam_nodes,
    get_qconv_prepack_op,
    get_qconv_op,
)

from .quantization_types import QuantizerCls

from abc import ABC, abstractmethod
import operator
import warnings

from typing import Any, Callable, Dict, Union, Optional, Tuple, List

# -------------------------
|
|
# Pattern Registrations
|
|
# -------------------------
|
|
|
|
# 1. Post Training Static Quantization and Quantization Aware Training Patterns
|
|
|
|
# Base Pattern Handler
|
|
class QuantizeHandler(ABC):
|
|
""" Base handler class for the quantizer patterns
|
|
"""
|
|
def __init__(self, quantizer: QuantizerCls, node: Node):
|
|
""" Records pattern information in __init__, which will be used
|
|
in convert
|
|
"""
|
|
# this is an indicator of whether all the inputs are Node or not
|
|
# since some op might be quantized differently depending on whether
|
|
# all inputs are tensors or not, e.g. add/mul
|
|
self.num_tensor_args = len(node.args)
|
|
self.all_node_args_are_tensors = True
|
|
# the last node of the matched pattern
|
|
self.last_node = node
|
|
|
|
def _maybe_get_last_node_only_observer(
|
|
self,
|
|
quantizer: QuantizerCls,
|
|
) -> Optional[Union[torch.quantization.ObserverBase, torch.quantization.FakeQuantizeBase]]:
|
|
"""
|
|
If the last node of the pattern is observed, return the observer
|
|
instance. Otherwise, return None.
|
|
"""
|
|
for maybe_obs_node, _ in self.last_node.users.items():
|
|
if maybe_obs_node.op == 'call_module':
|
|
maybe_obs = quantizer.modules[maybe_obs_node.target]
|
|
if is_activation_post_process(maybe_obs):
|
|
return maybe_obs
|
|
return None
|
|
|
|
|
|
def input_output_observed(self) -> bool:
|
|
"""
|
|
        Returns True if the pattern matched to this qhandler could be
        observed, and False if it should not be observed.
|
|
"""
|
|
return True
|
|
|
|
def should_insert_observer_for_output(
|
|
self,
|
|
qconfig: Any,
|
|
model_is_training: bool,
|
|
) -> bool:
|
|
"""
|
|
Returns true if an observer should be inserted for the output of
|
|
the pattern matched to this QuantizeHandler instance during the
|
|
prepare step.
|
|
"""
|
|
# TODO(future PR): potentially clean up and deduplicate these
|
|
# mappings.
|
|
return self.all_node_args_are_tensors and self.input_output_observed()
|
|
|
|
def should_mark_output_quantized_from_input_quantized_status(
|
|
self,
|
|
) -> bool:
|
|
"""
|
|
Returns true if after convert, the output of the matched pattern is
|
|
quantized iff the first input is also quantized.
|
|
"""
|
|
return False
|
|
|
|
def get_activation_ctr(
|
|
self,
|
|
qconfig: Any,
|
|
pattern: Pattern,
|
|
) -> Optional[Callable]:
|
|
"""
|
|
Returns the constructor for the activation observer which should be
|
|
used for the pattern matched to this handler. Some handlers override
|
|
this to a different value than what is specified in the qconfig.
|
|
"""
|
|
return qconfig.activation
|
|
|
|
|
|
@abstractmethod
|
|
def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
|
|
is_reference: bool = False,
|
|
convert_custom_config_dict: Dict[str, Any] = None) -> Node:
|
|
""" Convert the given node to a quantized node and insert
|
|
it to the quantized graph
|
|
"""
|
|
return NotImplemented
|
|
|
|
|
|
# Binary op configs
|
|
|
|
# Supported combinations are:
|
|
# quant_type | activation (compute_type) | weight
|
|
# static quint8 qint8
|
|
|
|
# tuple (activation_dtype, weight_dtype, compute_dtype)
|
|
# these are supported types for common binary ops like add/mul etc.
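# for example, the default static qconfig (quint8 activation observer, qint8 weight
# observer) corresponds to the tuple (torch.quint8, torch.qint8, None)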
|
|
binary_op_all_dtypes = [
|
|
(torch.quint8, torch.qint8, None),
|
|
(torch.float16, torch.float16, None),
|
|
]
|
|
binary_op_float16_dtypes = [
|
|
(torch.float16, torch.float16, None)
|
|
]
|
|
binary_op_int8_dtypes = [
|
|
(torch.quint8, torch.qint8, None),
|
|
]
|
|
binary_op_supported_dtypes : Dict[Union[Callable, str], List[Tuple[torch.dtype, torch.dtype, None]]] = {
|
|
operator.add: binary_op_all_dtypes,
|
|
torch.add: binary_op_all_dtypes,
|
|
operator.mul: binary_op_all_dtypes,
|
|
torch.mul: binary_op_all_dtypes,
|
|
torch.bmm: binary_op_float16_dtypes,
|
|
torch.sub: binary_op_float16_dtypes,
|
|
operator.sub: binary_op_float16_dtypes,
|
|
torch.div: binary_op_float16_dtypes,
|
|
operator.truediv: binary_op_float16_dtypes,
|
|
torch.sum: binary_op_float16_dtypes
|
|
}
|
|
binary_reference_op_supported_dtypes : Dict[Union[Callable, str], List[Tuple[torch.dtype, torch.dtype, None]]] = {
|
|
torch.bmm: binary_op_int8_dtypes,
|
|
}
|
|
|
|
|
|
@register_quant_pattern(operator.add)
|
|
@register_quant_pattern(operator.sub)
|
|
@register_quant_pattern(operator.mul)
|
|
@register_quant_pattern(operator.truediv)
|
|
@register_quant_pattern(torch.add)
|
|
@register_quant_pattern(torch.sub)
|
|
@register_quant_pattern(torch.mul)
|
|
@register_quant_pattern(torch.div)
|
|
@register_quant_pattern(torch.sum)
|
|
@register_quant_pattern(torch.bmm)
|
|
@register_quant_pattern((torch.nn.ReLU, operator.add))
|
|
@register_quant_pattern((torch.nn.ReLU, operator.mul))
|
|
@register_quant_pattern((torch.nn.ReLU, torch.add))
|
|
@register_quant_pattern((torch.nn.ReLU, torch.mul))
|
|
@register_quant_pattern((torch.nn.functional.relu, operator.add))
|
|
@register_quant_pattern((torch.nn.functional.relu, operator.mul))
|
|
@register_quant_pattern((torch.nn.functional.relu, torch.add))
|
|
@register_quant_pattern((torch.nn.functional.relu, torch.mul))
|
|
class BinaryOpQuantizeHandler(QuantizeHandler):
|
|
def __init__(self, quantizer: QuantizerCls, node: Node):
|
|
super().__init__(quantizer, node)
|
|
self.relu_node = None
|
|
if (node.op == 'call_function' and node.target is torch.nn.functional.relu) or \
|
|
(node.op == 'call_module' and isinstance(quantizer.modules[node.target], torch.nn.ReLU)):
|
|
self.relu_node = node
|
|
node = node.args[0] # type: ignore[assignment]
|
|
self.binary_op_node = node
|
|
self.binary_op = node.target
|
|
|
|
# determine how many of the first two args are Tensors (versus scalars)
|
|
# this distinguishes things like "x + y" from "x + 2" or "2 + x"
|
|
self.num_tensor_args = 0
|
|
cache_for_no_tensor_check: Dict[Node, bool] = dict()
|
|
for arg_idx in range(len(self.binary_op_node.args)):
|
|
arg = self.binary_op_node.args[arg_idx]
|
|
if isinstance(arg, Node) and (not all_node_args_have_no_tensors(arg, quantizer.modules, cache_for_no_tensor_check)):
|
|
self.num_tensor_args += 1
|
|
self.all_node_args_are_tensors = \
|
|
(self.num_tensor_args == len(self.binary_op_node.args))
|
|
|
|
qbin_op_mapping: Dict[Union[Callable, str], Callable] = {
|
|
operator.add: torch.ops.quantized.add,
|
|
torch.add: torch.ops.quantized.add,
|
|
operator.mul: torch.ops.quantized.mul,
|
|
torch.mul: torch.ops.quantized.mul,
|
|
}
|
|
qbin_relu_op_mapping: Dict[Union[Callable, str], Callable] = {
|
|
operator.add: torch.ops.quantized.add_relu,
|
|
torch.add: torch.ops.quantized.add_relu,
|
|
operator.mul: torch.ops.quantized.mul_relu,
|
|
torch.mul: torch.ops.quantized.mul_relu,
|
|
}
|
|
# corresponding quantized op
|
|
self.quantized_binary_op: Optional[Callable] = None
|
|
if self.binary_op in qbin_op_mapping:
|
|
self.quantized_binary_op = qbin_relu_op_mapping[self.binary_op] \
|
|
if self.relu_node is not None \
|
|
else qbin_op_mapping[self.binary_op]
|
|
|
|
def should_insert_observer_for_output(
|
|
self,
|
|
qconfig: Any,
|
|
model_is_training: bool,
|
|
) -> bool:
|
|
"""
|
|
Returns true if an observer should be inserted for the output of
|
|
the pattern matched to this QuantizeHandler instance during the
|
|
prepare step.
|
|
"""
|
|
if self.num_tensor_args == 1:
|
|
return activation_dtype(qconfig) == torch.float16
|
|
elif self.all_node_args_are_tensors and self.input_output_observed():
|
|
return True
|
|
else:
|
|
return False
|
|
|
|
def input_output_observed(self):
|
|
# for x + y where x and y are scalars, we do not observe anything
|
|
return self.num_tensor_args > 0
|
|
|
|
def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
|
|
is_reference: bool = False,
|
|
convert_custom_config_dict: Dict[str, Any] = None) -> Node:
|
|
|
|
if self.num_tensor_args == 0:
|
|
# example: x + y, when x and y are scalars
|
|
return quantizer.quantized_graph.node_copy(
|
|
node, load_arg(quantized=None))
|
|
|
|
qconfig = quantizer.qconfig_map[node.name]
|
|
dtypes = get_qconfig_dtypes(qconfig)
|
|
|
|
if is_reference and self.binary_op in binary_reference_op_supported_dtypes and \
|
|
dtypes in binary_reference_op_supported_dtypes[self.binary_op]:
|
|
if dtypes in binary_op_int8_dtypes:
|
|
args = load_arg(quantized=[0, 1])(node.args)
|
|
args = load_arg(quantized=False)(node.args)
|
|
kwargs = load_arg(quantized=False)(node.kwargs)
|
|
op_out = quantizer.quantized_graph.node_copy(node, load_arg(quantized=False))
|
|
activation_post_process = \
|
|
self._maybe_get_last_node_only_observer(quantizer)
|
|
assert activation_post_process is not None
|
|
return quantize_node(
|
|
quantizer, op_out, activation_post_process,
|
|
node, is_input=False)
|
|
else:
|
|
warnings.warn(
|
|
"No implementation found for dtype combination: {}"
|
|
"for op {} with is_reference={} despite it being listed as supported"
|
|
"this should not happen".format(dtypes, self.binary_op, is_reference))
|
|
return quantizer.quantized_graph.node_copy(node, load_arg(quantized=False))
|
|
elif not is_reference and self.binary_op in binary_op_supported_dtypes and \
|
|
dtypes in binary_op_supported_dtypes[self.binary_op]:
|
|
if dtypes in [(torch.quint8, torch.qint8, None)]:
|
|
assert self.quantized_binary_op is not None
|
|
if self.num_tensor_args == 1:
|
|
# add/mul scalar
|
|
first_arg = self.binary_op_node.args[0]
|
|
cache_for_no_tensor_check: Dict[Node, bool] = dict()
|
|
if isinstance(first_arg, Node) and (
|
|
not all_node_args_have_no_tensors(
|
|
first_arg, quantizer.modules, cache_for_no_tensor_check)):
|
|
quantized_index = 0
|
|
else:
|
|
quantized_index = 1
|
|
|
|
return quantizer.quantized_graph.create_node(
|
|
'call_function', self.quantized_binary_op,
|
|
load_arg(quantized=[quantized_index])(self.binary_op_node.args), self.binary_op_node.kwargs)
|
|
else:
|
|
activation_post_process = \
|
|
self._maybe_get_last_node_only_observer(quantizer)
|
|
assert activation_post_process is not None
|
|
scale, zero_point = activation_post_process.calculate_qparams()
|
|
scale = float(scale)
|
|
zero_point = int(zero_point)
|
|
scale_arg, zero_point_arg = create_qparam_nodes(quantizer, node.name, scale, zero_point)
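                    # scale_arg and zero_point_arg are graph Nodes, so the output qparams can be
                    # passed as extra inputs to the quantized op created below; note that
                    # self.quantized_binary_op is already the fused op (e.g. quantized::add_relu)
                    # when self.relu_node is set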
|
|
|
|
                    kwargs = {**self.binary_op_node.kwargs}
                    add_args = (*load_arg(quantized=True)(self.binary_op_node.args), scale_arg, zero_point_arg)
                    op = quantizer.quantized_graph.create_node(
                        'call_function', self.quantized_binary_op, add_args, kwargs)
                    return op
|
|
else:
|
|
assert dtypes == (torch.float16, torch.float16, None)
|
|
# TODO (refactor) this is duplicated, maybe have a helper function
|
|
if self.relu_node:
|
|
op_out = quantizer.quantized_graph.node_copy(self.binary_op_node, load_arg(quantized=False))
|
|
relu_args = [op_out]
|
|
relu_args.extend(load_arg(quantized=False)(self.relu_node.args[1:]))
|
|
relu_kwargs = load_arg(quantized=False)(self.relu_node.kwargs)
|
|
return quantizer.quantized_graph.create_node(
|
|
"call_function", torch.nn.functional.relu, tuple(relu_args), relu_kwargs)
|
|
else:
|
|
return quantizer.quantized_graph.node_copy(node, load_arg(quantized=False))
|
|
else:
|
|
# leave the op unquantized if the dtype,reference combination is not supported
|
|
warnings.warn(
|
|
"dtype combination: {} is not "
|
|
"supported by {} for is_reference={}. "
|
|
"Supported non-reference dtype combinations are: {} "
|
|
"Supported reference dtype combinations are: {}"
|
|
"".format(dtypes,
|
|
self.binary_op,
|
|
is_reference,
|
|
binary_op_supported_dtypes[self.binary_op],
|
|
(
|
|
[] if self.binary_op not in binary_reference_op_supported_dtypes.keys()
|
|
else binary_reference_op_supported_dtypes[self.binary_op]
|
|
)
|
|
)
|
|
)
|
|
if self.relu_node:
|
|
op_out = quantizer.quantized_graph.node_copy(self.binary_op_node, load_arg(quantized=False))
|
|
relu_args = [op_out]
|
|
relu_args.extend(load_arg(quantized=False)(self.relu_node.args[1:]))
|
|
relu_kwargs = load_arg(quantized=False)(self.relu_node.kwargs)
|
|
return quantizer.quantized_graph.create_node(
|
|
"call_function", torch.nn.functional.relu, tuple(relu_args), relu_kwargs)
|
|
else:
|
|
return quantizer.quantized_graph.node_copy(node, load_arg(quantized=False))
|
|
|
|
|
|
@register_quant_pattern(torch.cat)
|
|
class CatQuantizeHandler(QuantizeHandler):
|
|
def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
|
|
is_reference: bool = False,
|
|
convert_custom_config_dict: Dict[str, Any] = None) -> Node:
|
|
if not self.all_node_args_are_tensors:
|
|
return NotImplemented
|
|
return quantizer.quantized_graph.node_copy(node, load_arg(quantized=True))
|
|
|
|
# handle conv, maybe followed by relu
|
|
# NB: matching order is reversed, that is we match from the bottom of this list to the beginning
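# e.g. the tuple pattern (torch.nn.functional.relu, torch.nn.functional.conv2d) matches a
# conv2d call whose output is consumed by relu, i.e. relu(conv2d(...))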
|
|
@register_quant_pattern(torch.nn.Conv1d)
|
|
@register_quant_pattern(torch.nn.Conv2d)
|
|
@register_quant_pattern(torch.nn.Conv3d)
|
|
@register_quant_pattern(torch.nn.functional.conv1d)
|
|
@register_quant_pattern(torch.nn.functional.conv2d)
|
|
@register_quant_pattern(torch.nn.functional.conv3d)
|
|
# TODO: add qat.Conv1d
|
|
@register_quant_pattern(torch.nn.qat.Conv2d)
|
|
@register_quant_pattern(torch.nn.qat.Conv3d)
|
|
@register_quant_pattern(torch.nn.intrinsic.ConvReLU1d)
|
|
@register_quant_pattern(torch.nn.intrinsic.ConvReLU2d)
|
|
@register_quant_pattern(torch.nn.intrinsic.ConvReLU3d)
|
|
@register_quant_pattern(torch.nn.intrinsic.qat.ConvBn1d)
|
|
@register_quant_pattern(torch.nn.intrinsic.qat.ConvBn2d)
|
|
@register_quant_pattern(torch.nn.intrinsic.qat.ConvBn3d)
|
|
@register_quant_pattern(torch.nn.intrinsic.qat.ConvBnReLU1d)
|
|
@register_quant_pattern(torch.nn.intrinsic.qat.ConvBnReLU2d)
|
|
@register_quant_pattern(torch.nn.intrinsic.qat.ConvBnReLU3d)
|
|
@register_quant_pattern(torch.nn.intrinsic.qat.ConvReLU2d)
|
|
@register_quant_pattern(torch.nn.intrinsic.qat.ConvReLU3d)
|
|
@register_quant_pattern((torch.nn.functional.relu, torch.nn.functional.conv1d))
|
|
@register_quant_pattern((torch.nn.functional.relu, torch.nn.functional.conv2d))
|
|
@register_quant_pattern((torch.nn.functional.relu, torch.nn.functional.conv3d))
|
|
@register_quant_pattern((torch.nn.ReLU, torch.nn.functional.conv1d))
|
|
@register_quant_pattern((torch.nn.ReLU, torch.nn.functional.conv2d))
|
|
@register_quant_pattern((torch.nn.ReLU, torch.nn.functional.conv3d))
|
|
# just for error checks
|
|
@register_quant_pattern((torch.nn.ReLU, torch.nn.Conv2d))
|
|
@register_quant_pattern((torch.nn.ReLU, torch.nn.Conv3d))
|
|
@register_quant_pattern((torch.nn.functional.relu, torch.nn.Conv2d))
|
|
@register_quant_pattern((torch.nn.functional.relu, torch.nn.Conv3d))
|
|
class ConvReluQuantizeHandler(QuantizeHandler):
|
|
def __init__(self, quantizer: QuantizerCls, node: Node):
|
|
super().__init__(quantizer, node)
|
|
self.relu_node = None
|
|
if (node.op == 'call_function' and node.target is torch.nn.functional.relu) or \
|
|
(node.op == 'call_module' and isinstance(quantizer.modules[node.target], torch.nn.ReLU)):
|
|
self.relu_node = node
|
|
node = node.args[0] # type: ignore[assignment]
|
|
self.conv_node = node
|
|
if node.op == "call_module":
|
|
self.conv = quantizer.modules[self.conv_node.target]
|
|
elif node.op == "call_function":
|
|
self.conv = node.target
|
|
|
|
def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
|
|
is_reference: bool = False,
|
|
convert_custom_config_dict: Dict[str, Any] = None) -> Node:
|
|
# Supported combinations are:
|
|
# quant_type | activation (compute_type) | weight
|
|
# static quint8 qint8
|
|
|
|
# tuple (activation_dtype, weight_dtype, compute_dtype)
|
|
supported_dtypes = [
|
|
(torch.quint8, torch.qint8, None),
|
|
]
|
|
|
|
# TODO: is_reference option for conv module
|
|
qconfig = quantizer.qconfig_map[node.name]
|
|
dtypes = get_qconfig_dtypes(qconfig)
|
|
# leave the op unquantized if the dtype combination is not supported
|
|
if dtypes not in supported_dtypes:
|
|
warnings.warn(
|
|
"dtype combination: {} is not "
|
|
"supported by Conv "
|
|
"supported dtype combinations are: {}".format(dtypes, supported_dtypes))
|
|
if self.relu_node:
|
|
conv_out = quantizer.quantized_graph.node_copy(self.conv_node, load_arg(quantized=False))
|
|
relu_args = [conv_out]
|
|
relu_args.extend(load_arg(quantized=False)(self.relu_node.args[1:]))
|
|
relu_kwargs = load_arg(quantized=False)(self.relu_node.kwargs)
|
|
return quantizer.quantized_graph.create_node(
|
|
"call_function", torch.nn.functional.relu, tuple(relu_args), relu_kwargs)
|
|
else:
|
|
return quantizer.quantized_graph.node_copy(node, load_arg(quantized=False))
|
|
|
|
activation_int8_quantized = activation_is_int8_quantized(qconfig)
|
|
|
|
if self.conv_node.op == 'call_module':
|
|
# note that relu should already be fused into conv module in the fusion step
|
|
assert self.relu_node is None, 'conv module and relu fusion is not executed, ' \
|
|
'please make sure to run fusion before prepare'
|
|
if convert_custom_config_dict is None:
|
|
convert_custom_config_dict = {}
|
|
additional_static_quant_mapping = convert_custom_config_dict.get("static", {})
|
|
# 1. attach activation post process to module
|
|
output_activation_post_process = \
|
|
self._maybe_get_last_node_only_observer(quantizer)
|
|
assert output_activation_post_process is not None
|
|
self.conv.activation_post_process = output_activation_post_process
|
|
# 2. select quantized class
|
|
qconv_cls = get_static_quant_module_class(
|
|
type(self.conv), additional_static_quant_mapping)
|
|
quantized = qconv_cls.from_float(self.conv)
|
|
parent_name, name = _parent_name(self.conv_node.target)
|
|
setattr(quantizer.modules[parent_name], name, quantized)
|
|
return quantizer.quantized_graph.create_node(
|
|
'call_module',
|
|
self.conv_node.target,
|
|
(load_arg(quantized=True)(self.conv_node.args[0]),),
|
|
{})
|
|
else: # call_function
|
|
assert self.conv_node.op == "call_function"
|
|
if is_reference:
|
|
args = load_arg(quantized=[0, 1])(self.conv_node.args)
|
|
args = load_arg(quantized=False)(self.conv_node.args)
|
|
kwargs = load_arg(quantized=False)(self.conv_node.kwargs)
|
|
op_out = quantizer.quantized_graph.create_node(
|
|
"call_function", self.conv, args, kwargs)
|
|
if self.relu_node:
|
|
relu_args = [op_out]
|
|
relu_args.extend(load_arg(quantized=False)(self.relu_node.args[1:]))
|
|
relu_kwargs = load_arg(quantized=False)(self.relu_node.kwargs)
|
|
op_out = quantizer.quantized_graph.create_node(
|
|
"call_function", torch.nn.functional.relu, tuple(relu_args), relu_kwargs)
|
|
|
|
if activation_int8_quantized:
|
|
root_module = quantizer.modules['']
|
|
act_post_process_name = self.relu_node.name if self.relu_node else self.conv_node.name
|
|
act_post_process_node = self.relu_node if self.relu_node else self.conv_node
|
|
activation_post_process = \
|
|
self._maybe_get_last_node_only_observer(quantizer)
|
|
assert activation_post_process is not None
|
|
return quantize_node(
|
|
quantizer, op_out, activation_post_process,
|
|
act_post_process_node, is_input=False)
|
|
else:
|
|
# output for dynamically quantized conv op is not quantized
|
|
return op_out
|
|
else:
|
|
assert len(self.conv_node.args) >= 7, \
|
|
"only conv2d calls with all arguments specified is supported right now in is_reference=False option"
|
|
args = load_arg(quantized=[0, 1])(self.conv_node.args)
|
|
# pack weight
|
|
weight = load_arg(quantized=True)(self.conv_node.args[1])
|
|
other_args = load_arg(quantized=False)(self.conv_node.args[2:])
|
|
bias, stride, padding, dilation, groups = other_args
|
|
if self.conv == torch.nn.functional.conv1d:
|
|
# F.conv1d can take `int` as well as `list[int]` for stride,
|
|
# padding, dilation, but the prepack op cannot. Convert
|
|
# these to lists if needed.
|
|
stride = [stride] if isinstance(stride, int) else stride
|
|
padding = [padding] if isinstance(padding, int) else padding
|
|
dilation = [dilation] if isinstance(dilation, int) else dilation
|
|
prepack_args = (weight, bias, stride, padding, dilation, groups)
|
|
prepack_op = get_qconv_prepack_op(self.conv)
|
|
packed_weight = quantizer.quantized_graph.create_node(
|
|
"call_function", prepack_op, prepack_args, {})
|
|
assert activation_int8_quantized, \
|
|
"currently only static quantization is supported for conv"
|
|
# construct conv input
|
|
if activation_int8_quantized:
|
|
qconv_op = get_qconv_op(self.conv, self.relu_node is not None)
|
|
conv_input = load_arg(quantized=True)(self.conv_node.args[0])
|
|
|
|
activation_post_process = \
|
|
self._maybe_get_last_node_only_observer(quantizer)
|
|
assert activation_post_process is not None
|
|
|
|
scale, zero_point, _ = get_per_tensor_qparams(activation_post_process)
|
|
scale_node, zero_point_node = create_qparam_nodes(quantizer, self.conv_node.name, scale, zero_point)
|
|
qconv_args = (conv_input, packed_weight, scale_node, zero_point_node)
|
|
kwargs = load_arg(quantized=False)(self.conv_node.kwargs)
|
|
op = quantizer.quantized_graph.create_node(
|
|
'call_function', qconv_op, qconv_args, kwargs)
|
|
# Store the name of the fused op to get the path of node after fusion as well.
|
|
# TODO: may need to change the key to Node regenerate the map in each transformation,
|
|
# since we might not be able to rely on the name
|
|
quantizer.node_name_to_scope[op.name] = quantizer.node_name_to_scope[self.conv_node.name]
|
|
return op
|
|
else:
|
|
                    # conv2d_dynamic branch
|
|
raise Exception("Only static quant is supported for conv")
|
|
|
|
|
|
# handle linear, maybe followed by relu
|
|
@register_quant_pattern(torch.nn.Linear)
|
|
@register_quant_pattern(torch.nn.functional.linear)
|
|
@register_quant_pattern(torch.nn.qat.Linear)
|
|
@register_quant_pattern(torch.nn.intrinsic.LinearReLU)
|
|
@register_quant_pattern(torch.nn.intrinsic.qat.LinearReLU)
|
|
@register_quant_pattern((torch.nn.functional.relu, torch.nn.functional.linear))
|
|
@register_quant_pattern((torch.nn.ReLU, torch.nn.functional.linear))
|
|
# for error checks
|
|
@register_quant_pattern((torch.nn.ReLU, torch.nn.Linear))
|
|
@register_quant_pattern((torch.nn.functional.relu, torch.nn.Linear))
|
|
class LinearReLUQuantizeHandler(QuantizeHandler):
|
|
def __init__(self, quantizer: QuantizerCls, node: Node):
|
|
super().__init__(quantizer, node)
|
|
self.relu_node = None
|
|
if (node.op == 'call_function' and node.target is torch.nn.functional.relu) or \
|
|
(node.op == 'call_module' and isinstance(quantizer.modules[node.target], torch.nn.ReLU)):
|
|
self.relu_node = node
|
|
node = node.args[0] # type: ignore[assignment]
|
|
self.linear_node = node
|
|
if node.op == 'call_module':
|
|
self.linear = quantizer.modules[self.linear_node.target]
|
|
|
|
def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
|
|
is_reference: bool = False,
|
|
convert_custom_config_dict: Dict[str, Any] = None) -> Node:
|
|
# Supported combinations are:
|
|
# quant_type | activation (compute_type) | weight
|
|
# static quint8 qint8
|
|
# dynamic float32 (quint8) qint8
|
|
# weight_only float32 float16
|
|
# tuple (activation_dtype, weight_dtype, compute_dtype)
|
|
supported_dtypes = [
|
|
(torch.quint8, torch.qint8, None),
|
|
(torch.float32, torch.qint8, torch.quint8),
|
|
(torch.float32, torch.float16, None),
|
|
# static float16 quantization
|
|
(torch.float16, torch.float16, None),
|
|
]
|
|
qconfig = quantizer.qconfig_map[node.name]
|
|
dtypes = get_qconfig_dtypes(qconfig)
|
|
# leave the op unquantized if the dtype combination is not supported
|
|
if dtypes not in supported_dtypes:
|
|
warnings.warn(
|
|
"dtype combination: {} is not "
|
|
"supported by Linear "
|
|
"supported dtype combinations are: {}".format(dtypes, supported_dtypes))
|
|
if self.relu_node:
|
|
op_out = quantizer.quantized_graph.node_copy(self.linear_node, load_arg(quantized=False))
|
|
relu_args = [op_out]
|
|
relu_args.extend(load_arg(quantized=False)(self.relu_node.args[1:]))
|
|
relu_kwargs = load_arg(quantized=False)(self.relu_node.kwargs)
|
|
return quantizer.quantized_graph.create_node(
|
|
"call_function", torch.nn.functional.relu, tuple(relu_args), relu_kwargs)
|
|
else:
|
|
return quantizer.quantized_graph.node_copy(node, load_arg(quantized=None))
|
|
|
|
activation_int8_quantized = activation_is_int8_quantized(qconfig)
|
|
weight_dtype = dtypes[1]
|
|
# TODO: reference_model option for linear module
|
|
if self.linear_node.op == 'call_module':
|
|
|
|
output_activation_post_process = \
|
|
self._maybe_get_last_node_only_observer(quantizer)
|
|
|
|
            # note that relu should already be fused into the linear module in the fusion step
|
|
assert self.relu_node is None, 'linear module and relu fusion is not executed, ' \
|
|
'please make sure to run fusion before prepare'
|
|
# 1. attach output activation post process to linear module
|
|
|
|
if output_activation_post_process:
|
|
self.linear.activation_post_process = output_activation_post_process
|
|
|
|
# 2. select corresponding quantized linear class for the float linear class
|
|
if type(self.linear) in [torch.nn.Linear, torch.nn.qat.Linear]:
|
|
qlinear = nnq.Linear if activation_int8_quantized else nnqd.Linear
|
|
elif type(self.linear) in [torch.nn.intrinsic.LinearReLU, torch.nn.intrinsic.qat.LinearReLU]:
|
|
assert activation_int8_quantized, \
|
|
'Only int8 static quantization is supported for LinearReLU'
|
|
qlinear = torch.nn.intrinsic.quantized.LinearReLU
|
|
else:
|
|
raise Exception("unhandled linear type:", type(self.linear))
|
|
quantized = qlinear.from_float(self.linear)
|
|
parent_name, name = _parent_name(self.linear_node.target)
|
|
setattr(quantizer.modules[parent_name], name, quantized)
|
|
# activation needs to be quantized for static quantization
|
|
return quantizer.quantized_graph.create_node(
|
|
'call_module',
|
|
self.linear_node.target,
|
|
(load_arg(quantized=activation_int8_quantized)(self.linear_node.args[0]),), {})
|
|
else: # call_function
|
|
assert self.linear_node.op == 'call_function'
|
|
if is_reference:
|
|
quantized_input_idxs = []
|
|
if activation_int8_quantized:
|
|
quantized_input_idxs.append(0)
|
|
if weight_is_statically_quantized(qconfig):
|
|
quantized_input_idxs.append(1)
|
|
args = load_arg(quantized=quantized_input_idxs)(self.linear_node.args)
|
|
args = load_arg(quantized=False)(self.linear_node.args)
|
|
kwargs = load_arg(quantized=False)(self.linear_node.kwargs)
|
|
op_out = quantizer.quantized_graph.create_node(
|
|
"call_function", torch.nn.functional.linear, args, kwargs)
|
|
if self.relu_node:
|
|
relu_args = [op_out]
|
|
relu_args.extend(load_arg(quantized=False)(self.relu_node.args[1:]))
|
|
relu_kwargs = load_arg(quantized=False)(self.relu_node.kwargs)
|
|
op_out = quantizer.quantized_graph.create_node(
|
|
"call_function", torch.nn.functional.relu, tuple(relu_args), relu_kwargs)
|
|
|
|
if activation_int8_quantized:
|
|
# quantize output for statically quantized linear op
|
|
root_module = quantizer.modules['']
|
|
act_post_process_name = self.relu_node.name if self.relu_node else self.linear_node.name
|
|
act_post_process_node = self.relu_node if self.relu_node else self.linear_node
|
|
activation_post_process = \
|
|
self._maybe_get_last_node_only_observer(quantizer)
|
|
assert activation_post_process is not None
|
|
return quantize_node(
|
|
quantizer,
|
|
op_out,
|
|
activation_post_process,
|
|
act_post_process_node,
|
|
is_input=False)
|
|
else:
|
|
# output for dynamically quantized linear op is not quantized
|
|
return op_out
|
|
else: # non-reference option
|
|
# prepacking weights for static int8 quant and dynamic quant
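                # the quantized/dynamic linear ops consume a single prepacked argument
                # (weight + bias packed by a linear_prepack op) instead of separate
                # weight and bias tensors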
|
|
if dtypes != (torch.float16, torch.float16, None):
|
|
# linear args
|
|
# (x, weight, bias, ...)
|
|
weight_quantized = weight_is_statically_quantized(qconfig)
|
|
linear_weight = load_arg(quantized=weight_quantized)(self.linear_node.args[1])
|
|
|
|
# get other arguments
|
|
kwargs = {**load_arg(quantized=False)(self.linear_node.kwargs)}
|
|
# pack weight
|
|
bias = None
|
|
# all args after bias, including bias
|
|
other_args = load_arg(quantized=False)(self.linear_node.args[2:])
|
|
if len(self.linear_node.args) > 2:
|
|
bias = load_arg(quantized=False)(self.linear_node.args[2])
|
|
other_args = other_args[1:] # remove the bias argument
|
|
else:
|
|
assert 'bias' in kwargs, \
|
|
'expect bias provided as a keyword argument when it is not a positional argument'
|
|
bias = kwargs['bias']
|
|
kwargs.pop('bias')
|
|
prepack_args = (linear_weight, bias)
|
|
prepack_op = get_linear_prepack_op_for_dtype(weight_dtype)
|
|
packed_weight = quantizer.quantized_graph.create_node(
|
|
'call_function', prepack_op, prepack_args, {})
|
|
# construct linear input
|
|
if activation_int8_quantized:
|
|
qlinear_op = torch.ops.quantized.linear_relu if self.relu_node else torch.ops.quantized.linear
|
|
linear_input = load_arg(quantized=True)(self.linear_node.args[0])
|
|
activation_post_process = \
|
|
self._maybe_get_last_node_only_observer(quantizer)
|
|
assert activation_post_process is not None
|
|
scale, zero_point, _ = get_per_tensor_qparams(activation_post_process)
|
|
|
|
scale_node, zero_point_node = create_qparam_nodes(quantizer, self.linear_node.name, scale, zero_point)
|
|
|
|
qlinear_args = (linear_input, packed_weight, scale_node, zero_point_node)
|
|
op = quantizer.quantized_graph.create_node(
|
|
"call_function", qlinear_op, qlinear_args, kwargs)
|
|
# Store the name of the fused op to get the path of node after fusion as well.
|
|
# TODO: may need to change the key to Node regenerate the map in each transformation,
|
|
# since we might not be able to rely on the name
|
|
quantizer.node_name_to_scope[op.name] = quantizer.node_name_to_scope[self.linear_node.name]
|
|
return op
|
|
elif dtypes in [(torch.float32, torch.qint8, torch.quint8),
|
|
(torch.float32, torch.float16, None)]:
|
|
# choose linear dynamic or linear dynamic fp16 op based on weight dtype
|
|
qlinear_op = torch.ops.quantized.linear_dynamic \
|
|
if weight_dtype == torch.qint8 \
|
|
else torch.ops.quantized.linear_dynamic_fp16
|
|
linear_input = load_arg(quantized=False)(self.linear_node.args[0])
|
|
qlinear_args = (linear_input, packed_weight) # type: ignore[assignment]
|
|
op_out = quantizer.quantized_graph.create_node(
|
|
"call_function", qlinear_op, qlinear_args, kwargs)
|
|
# Store the name of the dynamic op to get the path of node after replacement as well.
|
|
# TODO: may need to change the key to Node regenerate the map in each transformation,
|
|
# since we might not be able to rely on the name
|
|
quantizer.node_name_to_scope[op_out.name] = quantizer.node_name_to_scope[self.linear_node.name]
|
|
if self.relu_node:
|
|
op_out = quantizer.quantized_graph.create_node("call_function", torch.nn.functional.relu, (op_out,), {})
|
|
return op_out
|
|
else:
|
|
assert dtypes == (torch.float16, torch.float16, None)
|
|
# TODO (refactor) this is duplicated, maybe have a helper function
|
|
if self.relu_node:
|
|
op_out = quantizer.quantized_graph.node_copy(self.linear_node, load_arg(quantized=False))
|
|
relu_args = [op_out]
|
|
relu_args.extend(load_arg(quantized=False)(self.relu_node.args[1:]))
|
|
relu_kwargs = load_arg(quantized=False)(self.relu_node.kwargs)
|
|
return quantizer.quantized_graph.create_node(
|
|
"call_function", torch.nn.functional.relu, tuple(relu_args), relu_kwargs)
|
|
else:
|
|
return quantizer.quantized_graph.node_copy(node, load_arg(quantized=False))
|
|
|
|
@register_quant_pattern(torch.nn.BatchNorm2d)
|
|
@register_quant_pattern(torch.nn.BatchNorm3d)
|
|
@register_quant_pattern(torch.nn.intrinsic.BNReLU2d)
|
|
@register_quant_pattern(torch.nn.intrinsic.BNReLU3d)
|
|
class BatchNormQuantizeHandler(QuantizeHandler):
|
|
def __init__(self, quantizer: QuantizerCls, node: Node):
|
|
super().__init__(quantizer, node)
|
|
assert node.op == 'call_module'
|
|
self.bn_node = node
|
|
self.bn = quantizer.modules[self.bn_node.target]
|
|
|
|
def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
|
|
is_reference: bool = False,
|
|
convert_custom_config_dict: Dict[str, Any] = None) -> Node:
|
|
if convert_custom_config_dict is None:
|
|
convert_custom_config_dict = {}
|
|
additional_static_quant_mapping = convert_custom_config_dict.get("static", {})
|
|
# 1. attach activation post process to module
|
|
activation_post_process = \
|
|
self._maybe_get_last_node_only_observer(quantizer)
|
|
assert activation_post_process is not None
|
|
self.bn.activation_post_process = activation_post_process
|
|
qbn_cls = get_static_quant_module_class(type(self.bn), additional_static_quant_mapping)
|
|
quantized = qbn_cls.from_float(self.bn)
|
|
parent_name, name = _parent_name(self.bn_node.target)
|
|
setattr(quantizer.modules[parent_name], name, quantized)
|
|
return quantizer.quantized_graph.create_node(
|
|
'call_module',
|
|
self.bn_node.target,
|
|
load_arg(quantized=[0])(self.bn_node.args),
|
|
load_arg(quantized=False)(self.bn_node.kwargs))
|
|
|
|
@register_quant_pattern(torch.nn.Embedding)
|
|
@register_quant_pattern(torch.nn.EmbeddingBag)
|
|
class EmbeddingQuantizeHandler(QuantizeHandler):
|
|
def __init__(self, quantizer: QuantizerCls, node: Node):
|
|
super().__init__(quantizer, node)
|
|
|
|
def input_output_observed(self) -> bool:
|
|
return False
|
|
|
|
def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
|
|
is_reference: bool = False,
|
|
convert_custom_config_dict: Dict[str, Any] = None) -> Node:
|
|
# Supported combinations are:
|
|
# quant_type | activation | weight | activation_compute_type
|
|
# weight_only | float32 | quint8 | None
|
|
# weight_only | float32 | quint4x2 | None
|
|
# tuple (activation_dtype, weight_dtype, compute_dtype)
|
|
supported_dtypes = [
|
|
(torch.float32, torch.quint8, None),
|
|
(torch.float32, torch.quint4x2, None),
|
|
]
|
|
assert node.op == 'call_module'
|
|
emb_node = node
|
|
qconfig = quantizer.qconfig_map[node.name]
|
|
dtypes = get_qconfig_dtypes(qconfig)
|
|
# leave the op unquantized if the dtype combination is not supported
|
|
if dtypes not in supported_dtypes:
|
|
warnings.warn(
|
|
"dtype combination: {} is not "
|
|
"supported by Embedding/EmbeddingBag, "
|
|
"supported dtype combinations are: {}".format(dtypes, supported_dtypes))
|
|
return quantizer.quantized_graph.node_copy(node, load_arg(quantized=None))
|
|
|
|
emb = quantizer.modules[emb_node.target]
|
|
qemb = get_static_quant_module_class(type(emb))
|
|
quantized = qemb.from_float(emb)
|
|
parent_name, name = _parent_name(emb_node.target)
|
|
setattr(quantizer.modules[parent_name], name, quantized)
|
|
return quantizer.quantized_graph.create_node(
|
|
'call_module',
|
|
emb_node.target,
|
|
load_arg(quantized=False)(emb_node.args),
|
|
load_arg(quantized=False)(emb_node.kwargs))
|
|
|
|
# TODO (maybe): merge with embedding quantize handler
|
|
@register_quant_pattern(torch.nn.GRUCell)
|
|
@register_quant_pattern(torch.nn.LSTMCell)
|
|
@register_quant_pattern(torch.nn.RNNCell)
|
|
@register_quant_pattern(torch.nn.LSTM)
|
|
class RNNDynamicQuantizeHandler(QuantizeHandler):
|
|
def __init__(self, quantizer: QuantizerCls, node: Node):
|
|
super().__init__(quantizer, node)
|
|
|
|
def input_output_observed(self) -> bool:
|
|
return False
|
|
|
|
def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
|
|
is_reference: bool = False,
|
|
convert_custom_config_dict: Dict[str, Any] = None) -> Node:
|
|
# Supported combinations are:
|
|
# quant_type | activation | weight | activation_compute_type
|
|
# dynamic | float32 | qint8 | quint8
|
|
# dynamic | float32 | float16 | None
|
|
# tuple (activation_dtype, weight_dtype, compute_dtype)
|
|
supported_dtypes = [
|
|
(torch.float32, torch.qint8, torch.quint8),
|
|
(torch.float32, torch.float16, None),
|
|
]
|
|
assert node.op == 'call_module'
|
|
qconfig = quantizer.qconfig_map[node.name]
|
|
dtypes = get_qconfig_dtypes(qconfig)
|
|
# leave the op unquantized if the dtype combination is not supported
|
|
if dtypes not in supported_dtypes:
|
|
warnings.warn(
|
|
"dtype combination: {} is not "
|
|
"supported by Embedding/EmbeddingBag, "
|
|
"supported dtype combinations are: {}".format(dtypes, supported_dtypes))
|
|
return quantizer.quantized_graph.node_copy(node, load_arg(quantized=None))
|
|
|
|
module = quantizer.modules[node.target]
|
|
qmodule_cls = get_dynamic_quant_module_class(type(module))
|
|
qmodule = qmodule_cls.from_float(module)
|
|
parent_name, name = _parent_name(node.target)
|
|
setattr(quantizer.modules[parent_name], name, qmodule)
|
|
return quantizer.quantized_graph.create_node(
|
|
'call_module',
|
|
node.target,
|
|
load_arg(quantized=False)(node.args),
|
|
load_arg(quantized=False)(node.kwargs))
|
|
|
|
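# kwargs that exist on the fp32 op but are not accepted by the corresponding quantized op;
# DefaultNodeQuantizeHandler.convert drops them before creating the quantized call_function node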
ARGS_TO_SKIP = {
|
|
torch._ops.ops.quantized.hardswish: ['inplace'],
|
|
torch._ops.ops.quantized.instance_norm:
|
|
['running_mean', 'running_var', 'use_input_stats', 'momentum'],
|
|
}
|
|
@register_quant_pattern(torch.nn.ConvTranspose1d)
|
|
@register_quant_pattern(torch.nn.ConvTranspose2d)
|
|
@register_quant_pattern(torch.nn.ELU)
|
|
@register_quant_pattern(torch.nn.LeakyReLU)
|
|
@register_quant_pattern(torch.nn.Hardswish)
|
|
@register_quant_pattern(torch.nn.InstanceNorm1d)
|
|
@register_quant_pattern(torch.nn.InstanceNorm2d)
|
|
@register_quant_pattern(torch.nn.InstanceNorm3d)
|
|
@register_quant_pattern(torch.nn.LayerNorm)
|
|
@register_quant_pattern(torch.nn.SiLU)
|
|
# we currently only support reference patterns for these ops so they have been removed
|
|
# until they receive a proper fp16 kernel. To use the reference pattern, use a custom qconfig
|
|
# @register_quant_pattern(torch.nn.GELU)
|
|
# @register_quant_pattern(torch.nn.Softmax)
|
|
@register_quant_pattern(torch.nn.functional.hardswish)
|
|
@register_quant_pattern(torch.nn.functional.instance_norm)
|
|
@register_quant_pattern(torch.nn.functional.layer_norm)
|
|
@register_quant_pattern(torch.nn.functional.leaky_relu)
|
|
@register_quant_pattern(torch.nn.functional.silu)
|
|
# we currently only support reference patterns for these ops so they have been removed
|
|
# until they receive a proper fp16 kernel. To use the reference pattern, use a custom qconfig
|
|
# @register_quant_pattern(torch.nn.functional.gelu)
|
|
# @register_quant_pattern(torch.nn.functional.softmax)
|
|
class DefaultNodeQuantizeHandler(QuantizeHandler):
|
|
''' Common quantized op, first input and first output will be quantized
|
|
'''
|
|
|
|
def __init__(self, quantizer: QuantizerCls, node: Node):
|
|
super().__init__(quantizer, node)
|
|
if node.op == "call_function" or node.op == "call_method":
|
|
self.op = node.target
|
|
elif node.op == "call_module":
|
|
self.op = type(quantizer.modules[node.target])
|
|
|
|
def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
|
|
is_reference: bool = False,
|
|
convert_custom_config_dict: Dict[str, Any] = None) -> Node:
|
|
if not self.all_node_args_are_tensors:
|
|
return NotImplemented
|
|
assert node.op in ['call_module', 'call_function'], 'Only call_module and ' + \
|
|
'call_function are handled in DefaultNode'
|
|
if convert_custom_config_dict is None:
|
|
convert_custom_config_dict = {}
|
|
additional_static_quant_mapping = convert_custom_config_dict.get("static", {})
|
|
|
|
all_dtypes = [
|
|
(torch.quint8, torch.qint8, None),
|
|
(torch.float16, torch.float16, None)
|
|
]
|
|
int8_dtypes = [
|
|
(torch.quint8, torch.qint8, None)
|
|
]
|
|
fp16_dtypes = [
|
|
(torch.float16, torch.float16, None)
|
|
]
|
|
supported_dtypes = {
|
|
torch.nn.ConvTranspose1d: int8_dtypes,
|
|
torch.nn.ConvTranspose2d: int8_dtypes,
|
|
torch.nn.ELU: int8_dtypes,
|
|
torch.nn.LeakyReLU: int8_dtypes,
|
|
torch.nn.Hardswish: int8_dtypes,
|
|
torch.nn.InstanceNorm1d: int8_dtypes,
|
|
torch.nn.InstanceNorm2d: int8_dtypes,
|
|
torch.nn.InstanceNorm3d: int8_dtypes,
|
|
torch.nn.LayerNorm: all_dtypes,
|
|
torch.nn.SiLU: fp16_dtypes,
|
|
torch.nn.GELU: int8_dtypes,
|
|
torch.nn.Softmax: int8_dtypes,
|
|
torch.nn.functional.hardswish: int8_dtypes,
|
|
torch.nn.functional.instance_norm: int8_dtypes,
|
|
torch.nn.functional.layer_norm: all_dtypes,
|
|
torch.nn.functional.leaky_relu: int8_dtypes,
|
|
torch.nn.functional.silu: fp16_dtypes,
|
|
torch.nn.functional.gelu: int8_dtypes,
|
|
torch.nn.functional.softmax: int8_dtypes,
|
|
}
|
|
qconfig = quantizer.qconfig_map[node.name]
|
|
dtypes = get_qconfig_dtypes(qconfig)
|
|
if dtypes not in supported_dtypes[self.op]:
|
|
warnings.warn(
|
|
"dtype combination: {} is not "
|
|
"supported by {} "
|
|
"supported dtype combinations are: {}".format(dtypes, self.op, supported_dtypes[self.op]))
|
|
return quantizer.quantized_graph.node_copy(node, load_arg(quantized=False))
|
|
# TODO: make helper functions for (torch.quint8, torch.qint8, None)
|
|
if not is_reference:
|
|
if dtypes in [(torch.quint8, torch.qint8, None)]:
|
|
activation_post_process = \
|
|
self._maybe_get_last_node_only_observer(quantizer)
|
|
assert activation_post_process is not None
|
|
if node.op == 'call_module':
|
|
module = quantizer.modules[node.target]
|
|
module.activation_post_process = activation_post_process
|
|
quantized_module_cls = get_static_quant_module_class(
|
|
type(module), additional_static_quant_mapping)
|
|
quantized_module = quantized_module_cls.from_float(module)
|
|
parent_name, name = _parent_name(node.target)
|
|
setattr(quantizer.modules[parent_name], name, quantized_module)
|
|
return quantizer.quantized_graph.create_node(
|
|
'call_module',
|
|
node.target,
|
|
load_arg(quantized=[0])(node.args),
|
|
load_arg(quantized=False)(node.kwargs))
|
|
else:
|
|
assert node.op == "call_function"
|
|
# call_function
|
|
scale, zero_point = activation_post_process.calculate_qparams()
|
|
scale = float(scale)
|
|
zero_point = int(zero_point)
|
|
|
|
scale_arg, zero_point_arg = create_qparam_nodes(quantizer, node.name, scale, zero_point)
|
|
|
|
assert not isinstance(node.target, str), "Expecting node.target for "
|
|
"call_function to be a function instead of a string"
|
|
quantized_op = get_quantized_operator(node.target)
|
|
args = load_arg(quantized=[0])(node.args)
|
|
kwargs = {**load_arg(quantized=False)(node.kwargs), "output_scale": scale_arg,
|
|
"output_zero_point": zero_point_arg}
|
|
if quantized_op in ARGS_TO_SKIP:
|
|
args_to_skip = ARGS_TO_SKIP[quantized_op]
|
|
for arg in args_to_skip:
|
|
if arg in kwargs:
|
|
kwargs.pop(arg)
|
|
return quantizer.quantized_graph.create_node(
|
|
"call_function", quantized_op, args, kwargs)
|
|
else:
|
|
assert dtypes in [(torch.float16, torch.float16, None)]
|
|
# Generally fp16 kernels don't exist for fp16 ops
|
|
warnings.warn(
|
|
"Only reference patterns are currently supported for {dtype} dtype with {op} op"
|
|
"".format(dtype=dtypes, op=self.op))
|
|
return quantizer.quantized_graph.node_copy(node, load_arg(quantized=False))
|
|
else:
|
|
assert is_reference
|
|
if dtypes in [(torch.quint8, torch.qint8, None)]:
|
|
load_arg(quantized=[0])(node.args)
|
|
args = load_arg(quantized=False)(node.args)
|
|
kwargs = load_arg(quantized=False)(node.kwargs)
|
|
op_out = quantizer.quantized_graph.node_copy(node, load_arg(quantized=False))
|
|
activation_post_process = \
|
|
self._maybe_get_last_node_only_observer(quantizer)
|
|
assert activation_post_process is not None
|
|
return quantize_node(
|
|
quantizer, op_out, activation_post_process,
|
|
node, is_input=False)
|
|
else:
|
|
assert dtypes in [(torch.float16, torch.float16, None)]
|
|
return quantizer.quantized_graph.node_copy(node, load_arg(quantized=False))
|
|
|
|
|
|
# TODO: elu is using scale/zero_point instead of output_scale, output_zero_point
|
|
@register_quant_pattern(torch.nn.functional.elu)
|
|
class ELUQuantizeHandler(QuantizeHandler):
|
|
def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
|
|
is_reference: bool = False,
|
|
convert_custom_config_dict: Dict[str, Any] = None) -> Node:
|
|
activation_post_process = \
|
|
self._maybe_get_last_node_only_observer(quantizer)
|
|
assert activation_post_process is not None
|
|
scale, zero_point = activation_post_process.calculate_qparams()
|
|
scale = float(scale)
|
|
zero_point = int(zero_point)
|
|
|
|
scale_arg, zero_point_arg = create_qparam_nodes(quantizer, node.name, scale, zero_point)
|
|
|
|
quantized_op = get_quantized_operator(node.target)
|
|
args = load_arg(quantized=[0])(node.args)
|
|
kwargs = {**load_arg(quantized=False)(node.kwargs), 'output_scale': scale_arg, 'output_zero_point': zero_point_arg}
|
|
        kwargs.pop('inplace', None)  # 'inplace' may be absent if it was not passed explicitly
|
|
return quantizer.quantized_graph.create_node(
|
|
'call_function', quantized_op, args, kwargs)
|
|
|
|
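# the ops below map their input onto a fixed output range (e.g. sigmoid -> [0, 1],
# tanh -> [-1, 1]), so their output observers use fixed-qparams fake quants instead of
# the activation observer constructor from the qconfig (see get_activation_ctr below)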
@register_quant_pattern(torch.nn.Hardsigmoid, default_affine_fixed_qparams_fake_quant)
|
|
@register_quant_pattern(torch.nn.functional.hardsigmoid, default_affine_fixed_qparams_fake_quant)
|
|
@register_quant_pattern('hardsigmoid', default_affine_fixed_qparams_fake_quant)
|
|
@register_quant_pattern('hardsigmoid_', default_affine_fixed_qparams_fake_quant)
|
|
@register_quant_pattern(torch.nn.Sigmoid, default_affine_fixed_qparams_fake_quant)
|
|
@register_quant_pattern(torch.sigmoid, default_affine_fixed_qparams_fake_quant)
|
|
@register_quant_pattern('sigmoid', default_affine_fixed_qparams_fake_quant)
|
|
@register_quant_pattern('sigmoid_', default_affine_fixed_qparams_fake_quant)
|
|
@register_quant_pattern(torch.nn.Tanh, default_symmetric_fixed_qparams_fake_quant)
|
|
@register_quant_pattern(torch.tanh, default_symmetric_fixed_qparams_fake_quant)
|
|
@register_quant_pattern('tanh', default_symmetric_fixed_qparams_fake_quant)
|
|
@register_quant_pattern('tanh_', default_symmetric_fixed_qparams_fake_quant)
|
|
class FixedQParamsOpQuantizeHandler(QuantizeHandler):
|
|
def __init__(self, quantizer: QuantizerCls, node: Node):
|
|
super().__init__(quantizer, node)
|
|
self.node = node
|
|
|
|
def should_insert_observer_for_output(
|
|
self,
|
|
qconfig: Any,
|
|
model_is_training: bool,
|
|
) -> bool:
|
|
if model_is_training:
|
|
# in QAT, always insert fake_quants
|
|
return True
|
|
else:
|
|
# in PTQ, only insert observers when emulating fp16
|
|
return activation_dtype(qconfig) == torch.float16
|
|
|
|
def should_mark_output_quantized_from_input_quantized_status(
|
|
self,
|
|
) -> bool:
|
|
return True
|
|
|
|
# some qhandlers override the activations constructor
|
|
def get_activation_ctr(self, qconfig, pattern) -> Optional[Callable]:
|
|
if activation_dtype(qconfig) == torch.float16:
|
|
return qconfig.activation
|
|
else:
|
|
return get_default_output_activation_post_process_map().get(
|
|
pattern, None)
|
|
|
|
def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
|
|
is_reference: bool = False,
|
|
convert_custom_config_dict: Dict[str, Any] = None) -> Node:
|
|
qconfig = quantizer.qconfig_map[node.name]
|
|
dtypes = get_qconfig_dtypes(qconfig)
|
|
if dtypes == (torch.float16, torch.float16, None):
|
|
return quantizer.quantized_graph.node_copy(node, load_arg(quantized=False))
|
|
else:
|
|
return quantizer.quantized_graph.node_copy(node, load_arg(quantized=None))
|
|
|
|
|
|
# these ops have quantized equivalents that do not need any extra information
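# their output is quantized iff their input is quantized, so convert just copies the node
# into the quantized graph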
|
|
@register_quant_pattern(torch.nn.AdaptiveAvgPool1d)
|
|
@register_quant_pattern(torch.nn.AdaptiveAvgPool2d)
|
|
@register_quant_pattern(torch.nn.AdaptiveAvgPool3d)
|
|
@register_quant_pattern(torch.nn.AvgPool1d)
|
|
@register_quant_pattern(torch.nn.AvgPool2d)
|
|
@register_quant_pattern(torch.nn.AvgPool3d)
|
|
@register_quant_pattern(torch.nn.Dropout)
|
|
@register_quant_pattern(torch.nn.Hardtanh)
|
|
@register_quant_pattern(torch.nn.Identity)
|
|
@register_quant_pattern(torch.nn.MaxPool1d)
|
|
@register_quant_pattern(torch.nn.MaxPool2d)
|
|
@register_quant_pattern(torch.nn.MaxPool3d)
|
|
@register_quant_pattern(torch.nn.ReLU)
|
|
@register_quant_pattern(torch.nn.ReLU6)
|
|
@register_quant_pattern(torch.adaptive_avg_pool1d)
|
|
@register_quant_pattern(torch.nn.functional.adaptive_avg_pool2d)
|
|
@register_quant_pattern(torch.nn.functional.adaptive_avg_pool3d)
|
|
@register_quant_pattern(torch.nn.functional.dropout)
|
|
@register_quant_pattern(torch.nn.functional.hardtanh)
|
|
@register_quant_pattern(torch.nn.functional.hardtanh_)
|
|
@register_quant_pattern(torch.nn.functional.interpolate)
|
|
@register_quant_pattern(torch.nn.functional.max_pool1d)
|
|
@register_quant_pattern(torch.nn.functional.max_pool2d)
|
|
@register_quant_pattern(torch.nn.functional.max_pool3d)
|
|
@register_quant_pattern(torch.nn.functional.relu)
|
|
@register_quant_pattern(torch.nn.functional.relu6)
|
|
@register_quant_pattern(torch.avg_pool1d)
|
|
@register_quant_pattern(torch._C._nn.avg_pool2d)
|
|
@register_quant_pattern(torch._C._nn.avg_pool3d)
|
|
@register_quant_pattern(torch.chunk)
|
|
@register_quant_pattern(torch.clamp)
|
|
@register_quant_pattern(torch.flatten)
|
|
@register_quant_pattern(torch.transpose)
|
|
@register_quant_pattern(torch.max)
|
|
@register_quant_pattern(torch.mean)
|
|
@register_quant_pattern(torch.min)
|
|
@register_quant_pattern(torch.repeat_interleave)
|
|
@register_quant_pattern(torch.sort)
|
|
@register_quant_pattern(torch.squeeze)
|
|
@register_quant_pattern(torch.stack)
|
|
@register_quant_pattern(torch.unsqueeze)
|
|
@register_quant_pattern(operator.floordiv)
|
|
@register_quant_pattern(operator.getitem)
|
|
@register_quant_pattern('chunk')
|
|
@register_quant_pattern('clamp')
|
|
@register_quant_pattern('contiguous')
|
|
@register_quant_pattern('detach')
|
|
@register_quant_pattern('detach_')
|
|
@register_quant_pattern('mean')
|
|
@register_quant_pattern('numel')
|
|
@register_quant_pattern('permute')
|
|
@register_quant_pattern('relu')
|
|
@register_quant_pattern('relu_')
|
|
@register_quant_pattern('repeat')
|
|
@register_quant_pattern('repeat_interleave')
|
|
@register_quant_pattern('reshape')
|
|
@register_quant_pattern('resize_')
|
|
@register_quant_pattern('shape')
|
|
@register_quant_pattern('size')
|
|
@register_quant_pattern('squeeze')
|
|
@register_quant_pattern('squeeze_')
|
|
@register_quant_pattern('transpose')
|
|
@register_quant_pattern('unsqueeze')
|
|
@register_quant_pattern('unsqueeze_')
|
|
@register_quant_pattern('view')
|
|
class CopyNodeQuantizeHandler(QuantizeHandler):
|
|
def should_mark_output_quantized_from_input_quantized_status(
|
|
self,
|
|
) -> bool:
|
|
return True
|
|
|
|
def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
|
|
is_reference: bool = False,
|
|
convert_custom_config_dict: Dict[str, Any] = None) -> Node:
|
|
return quantizer.quantized_graph.node_copy(node, load_arg(quantized=None))
|
|
|
|
class CustomModuleQuantizeHandler(QuantizeHandler):
|
|
def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
|
|
is_reference: bool = False,
|
|
convert_custom_config_dict: Dict[str, Any] = None) -> Node:
|
|
""" Convert a float custom module to quantized custom module
|
|
"""
|
|
assert node.op == 'call_module'
|
|
assert convert_custom_config_dict is not None
|
|
custom_module_class_mapping = convert_custom_config_dict.get("observed_to_quantized_custom_module_class", None)
|
|
assert custom_module_class_mapping is not None
|
|
qconfig = quantizer.qconfig_map[node.name]
|
|
observed_custom_module = quantizer.modules[node.target]
|
|
if activation_is_statically_quantized(qconfig):
|
|
assert node.name in quantizer.activation_post_process_map
|
|
cur_idx = quantizer.activation_post_process_indexes[node.name]
|
|
observed_custom_module.activation_post_process = \
|
|
quantizer.modules[quantizer.activation_post_process_map[node.name][cur_idx]]
|
|
quantizer.activation_post_process_indexes[node.name] += 1
|
|
quantized_custom_module_class = get_swapped_custom_module_class(
|
|
observed_custom_module, custom_module_class_mapping, qconfig)
|
|
quantized_custom_module = \
|
|
quantized_custom_module_class.from_observed(observed_custom_module)
|
|
parent_name, name = _parent_name(node.target)
|
|
setattr(quantizer.modules[parent_name], name, quantized_custom_module)
|
|
        # hardcoded the quantized input to be None (take whatever is in the environment),
|
|
# we can extend this
|
|
# if there is a need, e.g. get the indexes of quantized inputs from some
|
|
# module attribute like module._QUANTIZED_INPUT_INDEXES
|
|
return quantizer.quantized_graph.node_copy(node, load_arg(quantized=None))
|
|
|
|
class StandaloneModuleQuantizeHandler(QuantizeHandler):
|
|
""" Converts an observed standalone module to quantized standalone module
|
|
by calling convert_fx on the observed standalone module.
|
|
"""
|
|
def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
|
|
is_reference: bool = False,
|
|
convert_custom_config_dict: Dict[str, Any] = None) -> Node:
|
|
assert node.op == 'call_module'
|
|
qconfig = quantizer.qconfig_map[node.name]
|
|
convert = torch.quantization.quantize_fx._convert_standalone_module_fx # type: ignore[attr-defined]
|
|
observed_standalone_module = quantizer.modules[node.target]
|
|
input_quantized_idxs = observed_standalone_module._standalone_module_input_quantized_idxs.tolist()
|
|
quantized_standalone_module = convert(observed_standalone_module, is_reference=is_reference)
|
|
parent_name, name = _parent_name(node.target)
|
|
# update the modules dict
|
|
setattr(quantizer.modules[parent_name], name, quantized_standalone_module)
|
|
quantizer.modules[node.target] = quantized_standalone_module
|
|
return quantizer.quantized_graph.node_copy(node, load_arg(quantized=input_quantized_idxs))
|