Revert D26655616: [quant][graphmode][fx] Add support for fp16 bmm pattern

Test Plan: revert-hammer

Differential Revision:
D26655616 (2c44b256d8)

Original commit changeset: 1d0639303e5c

fbshipit-source-id: 403429c706c8a9e6a657669daf8aadf282025f83
This commit is contained in:
Mike Ruberry
2021-03-01 04:48:33 -08:00
committed by Facebook GitHub Bot
parent e43ea227fe
commit 3a024a7ae2
2 changed files with 21 additions and 77 deletions

View File

@ -42,7 +42,7 @@ from abc import ABC, abstractmethod
import operator
import warnings
from typing import Any, Callable, Dict, Union, Optional, Tuple, List
from typing import Any, Callable, Dict, Union
# -------------------------
# Pattern Registrations
@ -77,7 +77,6 @@ class QuantizeHandler(ABC):
@register_quant_pattern(operator.mul)
@register_quant_pattern(torch.add)
@register_quant_pattern(torch.mul)
@register_quant_pattern(torch.bmm)
@register_quant_pattern((torch.nn.ReLU, operator.add))
@register_quant_pattern((torch.nn.ReLU, operator.mul))
@register_quant_pattern((torch.nn.ReLU, torch.add))
@ -94,9 +93,9 @@ class BinaryOp(QuantizeHandler):
(node.op == 'call_module' and isinstance(quantizer.modules[node.target], torch.nn.ReLU)):
self.relu_node = node
node = node.args[0] # type: ignore
self.binary_op_node = node
self.binary_op = node.target
self.num_node_args = len([a for a in self.binary_op_node.args[:2] if isinstance(a, Node)])
self.bop_node = node
self.bop = node.target
self.num_node_args = len([a for a in self.bop_node.args[:2] if isinstance(a, Node)])
qbin_op_mapping: Dict[Union[Callable, str], Callable] = {
operator.add: torch.ops.quantized.add,
torch.add: torch.ops.quantized.add,
@ -110,11 +109,9 @@ class BinaryOp(QuantizeHandler):
torch.mul: torch.ops.quantized.mul_relu,
}
# corresponding quantized op
self.quantized_binary_op: Optional[Callable] = None
if self.binary_op in qbin_op_mapping:
self.quantized_binary_op = qbin_relu_op_mapping[self.binary_op] \
if self.relu_node is not None \
else qbin_op_mapping[self.binary_op] # type: ignore
self.qop = qbin_relu_op_mapping[self.bop] \
if self.relu_node is not None \
else qbin_op_mapping[self.bop] # type: ignore
def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
is_reference: bool = False,
@ -124,32 +121,21 @@ class BinaryOp(QuantizeHandler):
# static quint8 qint8
# tuple (activation_dtype, weight_dtype, compute_dtype)
# these are supported types for common binary ops like add/mul etc.
all_bop_dtypes = [
supported_dtypes = [
(torch.quint8, torch.qint8, None),
(torch.float16, torch.float16, None),
]
float16_dtypes = [
(torch.float16, torch.float16, None)
]
supported_dtypes : Dict[Union[Callable, str], List[Tuple[torch.dtype, torch.dtype, None]]] = {
operator.add: all_bop_dtypes,
torch.add: all_bop_dtypes,
operator.mul: all_bop_dtypes,
torch.mul: all_bop_dtypes,
torch.bmm: float16_dtypes,
}
qconfig = quantizer.qconfig_map[node.name]
dtypes = get_qconfig_dtypes(qconfig)
# leave the op unquantized if the dtype combination is not supported
if dtypes not in supported_dtypes[self.binary_op]:
if dtypes not in supported_dtypes:
warnings.warn(
"dtype combination: {} is not "
"supported by {} "
"supported dtype combinations are: {}".format(dtypes, self.binary_op, supported_dtypes[self.binary_op]))
"supported by add/mul "
"supported dtype combinations are: {}".format(dtypes, supported_dtypes))
if self.relu_node:
op_out = quantizer.quantized_graph.node_copy(self.binary_op_node, load_arg(quantized=False))
op_out = quantizer.quantized_graph.node_copy(self.bop_node, load_arg(quantized=False))
relu_args = [op_out]
relu_args.extend(load_arg(quantized=False)(self.relu_node.args[1:]))
relu_kwargs = load_arg(quantized=False)(self.relu_node.kwargs)
@ -159,17 +145,16 @@ class BinaryOp(QuantizeHandler):
return quantizer.quantized_graph.node_copy(node, load_arg(quantized=False))
if dtypes in [(torch.quint8, torch.qint8, None)]:
assert self.quantized_binary_op is not None
if self.num_node_args == 1:
# add/mul scalar
if isinstance(self.binary_op_node.args[0], Node):
if isinstance(self.bop_node.args[0], Node):
quantized_index = 0
else:
quantized_index = 1
return quantizer.quantized_graph.create_node(
'call_function', self.quantized_binary_op,
load_arg(quantized=[quantized_index])(self.binary_op_node.args), self.binary_op_node.kwargs)
'call_function', self.qop,
load_arg(quantized=[quantized_index])(self.bop_node.args), self.bop_node.kwargs)
else:
activation_post_process = quantizer.activation_post_process_map[node.name]
scale, zero_point = activation_post_process.calculate_qparams()
@ -181,16 +166,15 @@ class BinaryOp(QuantizeHandler):
op = torch.ops.quantized.add_relu
else:
op = torch.ops.quantized.add
kwargs = {**self.binary_op_node.kwargs}
add_args = (*load_arg(quantized=True)(self.binary_op_node.args), scale_arg, zero_point_arg)
kwargs = {**self.bop_node.kwargs}
add_args = (*load_arg(quantized=True)(self.bop_node.args), scale_arg, zero_point_arg)
op = quantizer.quantized_graph.create_node(
'call_function', self.quantized_binary_op, add_args, kwargs)
'call_function', self.qop, add_args, kwargs)
return op
else:
assert dtypes == (torch.float16, torch.float16, None)
elif dtypes in [(torch.float16, torch.float16, None)]:
# TODO (refactor) this is duplicated, maybe have a helper function
if self.relu_node:
op_out = quantizer.quantized_graph.node_copy(self.binary_op_node, load_arg(quantized=False))
op_out = quantizer.quantized_graph.node_copy(self.bop_node, load_arg(quantized=False))
relu_args = [op_out]
relu_args.extend(load_arg(quantized=False)(self.relu_node.args[1:]))
relu_kwargs = load_arg(quantized=False)(self.relu_node.kwargs)