[quant][refactor] Merge add and mul handler (#52651)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/52651

Merging them into a single BinaryOp handler makes it easier to extend to fp16 and to more binary ops.
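The merged handler picks the quantized op by looking the float op up in two tables (qbin_op_mapping and qbin_relu_op_mapping in the diff below). A minimal standalone sketch of that dispatch; the helper name quantized_op_for and the module-level table names are illustrative, not part of the patch:

    import operator
    from typing import Callable, Dict, Union

    import torch

    # Float binary op -> quantized op, mirroring qbin_op_mapping in the diff.
    _QBIN_OP_MAPPING: Dict[Union[Callable, str], Callable] = {
        operator.add: torch.ops.quantized.add,
        torch.add: torch.ops.quantized.add,
        operator.mul: torch.ops.quantized.mul,
        torch.mul: torch.ops.quantized.mul,
    }
    # Same, for a binary op with a fused ReLU, mirroring qbin_relu_op_mapping.
    _QBIN_RELU_OP_MAPPING: Dict[Union[Callable, str], Callable] = {
        operator.add: torch.ops.quantized.add_relu,
        torch.add: torch.ops.quantized.add_relu,
        operator.mul: torch.ops.quantized.mul_relu,
        torch.mul: torch.ops.quantized.mul_relu,
    }

    def quantized_op_for(binary_op: Callable, fused_relu: bool) -> Callable:
        # Mirrors the self.qop selection in BinaryOp.__init__: use the
        # ReLU-fused table when a relu node follows the binary op.
        mapping = _QBIN_RELU_OP_MAPPING if fused_relu else _QBIN_OP_MAPPING
        return mapping[binary_op]

    # e.g. quantized_op_for(torch.add, fused_relu=True) -> torch.ops.quantized.add_relu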

Test Plan: Imported from OSS

Reviewed By: vkuzo

Differential Revision: D26600118

fbshipit-source-id: a1816e593cf3065afe87d2e6e44cdace13bf6aeb
Authored by Jerry Zhang on 2021-02-27 19:48:30 -08:00
Committed by Facebook GitHub Bot
parent a296fa36ac
commit 0818dbf49d
3 changed files with 44 additions and 106 deletions


@@ -42,7 +42,7 @@ from abc import ABC, abstractmethod
 import operator
 import warnings
-from typing import Any, Callable, Dict
+from typing import Any, Callable, Dict, Union

 # -------------------------
 # Pattern Registrations
@@ -74,93 +74,18 @@ class QuantizeHandler(ABC):
         return NotImplemented

 @register_quant_pattern(operator.add)
-@register_quant_pattern(torch.add)
-@register_quant_pattern((torch.nn.ReLU, operator.add))
-@register_quant_pattern((torch.nn.ReLU, torch.add))
-@register_quant_pattern((torch.nn.functional.relu, operator.add))
-@register_quant_pattern((torch.nn.functional.relu, torch.add))
-class Add(QuantizeHandler):
-    def __init__(self, quantizer: QuantizerCls, node: Node):
-        super().__init__(quantizer, node)
-        self.relu_node = None
-        if (node.op == 'call_function' and node.target is torch.nn.functional.relu) or \
-           (node.op == 'call_module' and isinstance(quantizer.modules[node.target], torch.nn.ReLU)):
-            self.relu_node = node
-            node = node.args[0]  # type: ignore
-        assert node.op == 'call_function' and node.target in [operator.add, torch.add]
-        self.add_node = node
-        self.num_node_args = len([a for a in self.add_node.args[:2] if isinstance(a, Node)])
-
-    def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
-                is_reference: bool = False,
-                convert_custom_config_dict: Dict[str, Any] = None) -> Node:
-        # Supported combinations are:
-        # quant_type | activation (compute_type) | weight
-        #  static       quint8                      qint8
-        # tuple (activation_dtype, weight_dtype, compute_dtype)
-        supported_dtypes = [
-            (torch.quint8, torch.qint8, None),
-        ]
-        qconfig = quantizer.qconfig_map[node.name]
-        dtypes = get_qconfig_dtypes(qconfig)
-        # leave the op unquantized if the dtype combination is not supported
-        if dtypes not in supported_dtypes:
-            warnings.warn(
-                "dtype combination: {} is not "
-                "supported by add/mul "
-                "supported dtype combinations are: {}".format(dtypes, supported_dtypes))
-            if self.relu_node:
-                op_out = quantizer.quantized_graph.node_copy(self.add_node, load_arg(quantized=False))
-                relu_args = [op_out]
-                relu_args.extend(load_arg(quantized=False)(self.relu_node.args[1:]))
-                relu_kwargs = load_arg(quantized=False)(self.relu_node.kwargs)
-                return quantizer.quantized_graph.create_node(
-                    "call_function", torch.nn.functional.relu, tuple(relu_args), relu_kwargs)
-            else:
-                return quantizer.quantized_graph.node_copy(node, load_arg(quantized=False))
-
-        if self.num_node_args == 1:
-            # add scalar
-            if self.relu_node is not None:
-                op = torch.ops.quantized.add_relu
-            else:
-                op = torch.ops.quantized.add
-            if isinstance(self.add_node.args[0], Node):
-                quantized_index = 0
-            else:
-                quantized_index = 1
-            return quantizer.quantized_graph.create_node(
-                'call_function', op,
-                load_arg(quantized=[quantized_index])(self.add_node.args), self.add_node.kwargs)
-        else:
-            activation_post_process = quantizer.activation_post_process_map[node.name]
-            scale, zero_point = activation_post_process.calculate_qparams()
-            scale = float(scale)
-            zero_point = int(zero_point)
-            scale_arg, zero_point_arg = create_qparam_nodes(quantizer, node.name, scale, zero_point)
-            if self.relu_node is not None:
-                op = torch.ops.quantized.add_relu
-            else:
-                op = torch.ops.quantized.add
-            kwargs = {**self.add_node.kwargs}
-            add_args = (*load_arg(quantized=True)(self.add_node.args), scale_arg, zero_point_arg)
-            op = quantizer.quantized_graph.create_node(
-                'call_function', op, add_args, kwargs)
-            return op
-
-# TODO: merge with Add
 @register_quant_pattern(operator.mul)
+@register_quant_pattern(torch.add)
 @register_quant_pattern(torch.mul)
+@register_quant_pattern((torch.nn.ReLU, operator.add))
 @register_quant_pattern((torch.nn.ReLU, operator.mul))
+@register_quant_pattern((torch.nn.ReLU, torch.add))
 @register_quant_pattern((torch.nn.ReLU, torch.mul))
+@register_quant_pattern((torch.nn.functional.relu, operator.add))
 @register_quant_pattern((torch.nn.functional.relu, operator.mul))
+@register_quant_pattern((torch.nn.functional.relu, torch.add))
 @register_quant_pattern((torch.nn.functional.relu, torch.mul))
-class Mul(QuantizeHandler):
+class BinaryOp(QuantizeHandler):
     def __init__(self, quantizer: QuantizerCls, node: Node):
         super().__init__(quantizer, node)
         self.relu_node = None
@@ -168,9 +93,25 @@ class Mul(QuantizeHandler):
            (node.op == 'call_module' and isinstance(quantizer.modules[node.target], torch.nn.ReLU)):
             self.relu_node = node
             node = node.args[0]  # type: ignore
-        assert node.op == 'call_function' and node.target in [operator.mul, torch.mul]
-        self.mul_node = node
-        self.num_node_args = len([a for a in self.mul_node.args[:2] if isinstance(a, Node)])
+        self.bop_node = node
+        self.bop = node.target
+        self.num_node_args = len([a for a in self.bop_node.args[:2] if isinstance(a, Node)])
+        qbin_op_mapping: Dict[Union[Callable, str], Callable] = {
+            operator.add: torch.ops.quantized.add,
+            torch.add: torch.ops.quantized.add,
+            operator.mul: torch.ops.quantized.mul,
+            torch.mul: torch.ops.quantized.mul,
+        }
+        qbin_relu_op_mapping: Dict[Union[Callable, str], Callable] = {
+            operator.add: torch.ops.quantized.add_relu,
+            torch.add: torch.ops.quantized.add_relu,
+            operator.mul: torch.ops.quantized.mul_relu,
+            torch.mul: torch.ops.quantized.mul_relu,
+        }
+        # corresponding quantized op
+        self.qop = qbin_relu_op_mapping[self.bop] \
+            if self.relu_node is not None \
+            else qbin_op_mapping[self.bop]  # type: ignore

     def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
                 is_reference: bool = False,
@@ -193,7 +134,7 @@ class Mul(QuantizeHandler):
                 "supported by add/mul "
                 "supported dtype combinations are: {}".format(dtypes, supported_dtypes))
             if self.relu_node:
-                op_out = quantizer.quantized_graph.node_copy(self.mul_node, load_arg(quantized=False))
+                op_out = quantizer.quantized_graph.node_copy(self.bop_node, load_arg(quantized=False))
                 relu_args = [op_out]
                 relu_args.extend(load_arg(quantized=False)(self.relu_node.args[1:]))
                 relu_kwargs = load_arg(quantized=False)(self.relu_node.kwargs)
@@ -203,34 +144,31 @@ class Mul(QuantizeHandler):
                 return quantizer.quantized_graph.node_copy(node, load_arg(quantized=False))

         if self.num_node_args == 1:
-            # mul scalar
-            if self.relu_node is not None:
-                op = torch.ops.quantized.mul_relu
-            else:
-                op = torch.ops.quantized.mul
-            if isinstance(self.mul_node.args[0], Node):
+            # add/mul scalar
+            if isinstance(self.bop_node.args[0], Node):
                 quantized_index = 0
             else:
                 quantized_index = 1
             return quantizer.quantized_graph.create_node(
-                'call_function', op, load_arg(quantized=[quantized_index])(self.mul_node.args), self.mul_node.kwargs)
+                'call_function', self.qop,
+                load_arg(quantized=[quantized_index])(self.bop_node.args), self.bop_node.kwargs)
         else:
             activation_post_process = quantizer.activation_post_process_map[node.name]
             scale, zero_point = activation_post_process.calculate_qparams()
             scale = float(scale)
             zero_point = int(zero_point)
             scale_arg, zero_point_arg = create_qparam_nodes(quantizer, node.name, scale, zero_point)
             if self.relu_node is not None:
-                op = torch.ops.quantized.mul_relu
+                op = torch.ops.quantized.add_relu
             else:
-                op = torch.ops.quantized.mul
-            kwargs = {**self.mul_node.kwargs}
-            args = (*load_arg(quantized=True)(self.mul_node.args), scale_arg, zero_point_arg)
-            return quantizer.quantized_graph.create_node('call_function', op, args, kwargs)
+                op = torch.ops.quantized.add
+            kwargs = {**self.bop_node.kwargs}
+            add_args = (*load_arg(quantized=True)(self.bop_node.args), scale_arg, zero_point_arg)
+            op = quantizer.quantized_graph.create_node(
+                'call_function', self.qop, add_args, kwargs)
+            return op

 @register_quant_pattern(torch.cat)
 class Cat(QuantizeHandler):
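For context on the convert() paths above: the tensor-tensor quantized kernels take the output quantization parameters explicitly, which is why that branch appends scale_arg and zero_point_arg to the call, while the scalar branch passes none and only marks the Node argument as quantized. A small usage sketch, assuming the standard quantized::add overloads (tensor values here are illustrative):

    import torch

    # Two quantized input tensors.
    a = torch.quantize_per_tensor(torch.randn(4), scale=0.1, zero_point=0, dtype=torch.quint8)
    b = torch.quantize_per_tensor(torch.randn(4), scale=0.1, zero_point=0, dtype=torch.quint8)

    # Tensor-tensor case: output scale/zero_point are passed explicitly,
    # matching the (*args, scale_arg, zero_point_arg) call built in convert().
    out = torch.ops.quantized.add(a, b, 0.2, 0)

    # Scalar case: no output qparams are passed, matching the
    # num_node_args == 1 branch, which only quantizes the Node argument.
    out_scalar = torch.ops.quantized.add(a, 2.0)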