[quant][graphmode][fx][refactor] Remove patterns from Quantizer class (#59033)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/59033

To remove Quantizer class and split prepare and convert functions to different files

Test Plan:
python test/test_quantization.py TestQuantizeFx
python test/test_quantization.py TestQuantizeFxOps

Imported from OSS

Reviewed By: vkuzo

Differential Revision: D28724861

fbshipit-source-id: 97b38e851b6bf581510a24636b1d8d6f1d977f5a
This commit is contained in:
Jerry Zhang
2021-06-01 13:42:23 -07:00
committed by Facebook GitHub Bot
parent 83892c1861
commit e4b2684331
2 changed files with 23 additions and 23 deletions

View File

@ -1158,8 +1158,8 @@ class ELUQuantizeHandler(QuantizeHandler):
scale, zero_point = activation_post_process.calculate_qparams() # type: ignore[operator]
scale = float(scale)
zero_point = int(zero_point)
scale_arg, zero_point_arg = create_qparam_nodes(quantizer, node.name, scale, zero_point, modules, quantized_graph, node_name_to_scope)
scale_arg, zero_point_arg = create_qparam_nodes(
quantizer, node.name, scale, zero_point, modules, quantized_graph, node_name_to_scope)
quantized_op = get_quantized_operator(node.target)
args = load_arg(quantized=[0])(node.args)

View File

@ -955,17 +955,6 @@ def run_weight_observers(observed: GraphModule) -> None:
class Quantizer:
def __init__(self):
# mapping from a tuple of nodes in reverse order to uninitialized
# QuantizeHandler subclass. For example,
# {
# # match a single node
# (<class 'torch.nn.modules.conv.Conv3d'>:
# <class 'torch.quantization.fx.quantize.ConvRelu'>),
# # match multiple nodes in reverse order
# ((<function relu at 0x7f766a7360d0>, <built-in function add>):
# <class 'torch.quantization.fx.quantize.Add'>),
# }
self.patterns: Dict[Pattern, QuantizeHandler] = {}
self.prepare_custom_config_dict: Dict[str, Any] = {}
def _prepare(
@ -1001,7 +990,17 @@ class Quantizer:
additional_quant_patterns = \
prepare_custom_config_dict.get("additional_quant_pattern", {})
self.patterns = get_combined_dict(
# mapping from a tuple of nodes in reverse order to uninitialized
# QuantizeHandler subclass. For example,
# {
# # match a single node
# (<class 'torch.nn.modules.conv.Conv3d'>:
# <class 'torch.quantization.fx.quantize.ConvRelu'>),
# # match multiple nodes in reverse order
# ((<function relu at 0x7f766a7360d0>, <built-in function add>):
# <class 'torch.quantization.fx.quantize.Add'>),
# }
patterns: Dict[Pattern, QuantizeHandler] = get_combined_dict(
get_default_quant_patterns(), additional_quant_patterns)
convert_dict_to_ordered_dict(qconfig_dict)
@ -1036,7 +1035,7 @@ class Quantizer:
custom_module_classes = get_custom_module_class_keys(
prepare_custom_config_dict, "float_to_observed_custom_module_class")
matches = self._find_matches(
model.graph, modules, self.patterns, qconfig_map, standalone_module_names,
model.graph, modules, patterns, qconfig_map, standalone_module_names,
standalone_module_classes, custom_module_classes)
input_quantized_idxs: List[int] = self.prepare_custom_config_dict.get(
@ -1052,7 +1051,7 @@ class Quantizer:
model.graph, prepare_custom_config_dict,
input_quantized_idxs, output_quantized_idxs)
self.save_state(model, qconfig_map, node_name_to_scope)
self.save_state(model, qconfig_map, node_name_to_scope, patterns)
preserved_attributes = set(prepare_custom_config_dict.get("preserved_attributes", []))
model = ObservedGraphModule(model, model.graph, preserved_attributes)
if is_standalone_module:
@ -1072,21 +1071,22 @@ class Quantizer:
self,
observed: GraphModule,
qconfig_map: Dict[str, QConfigAny],
node_name_to_scope: Dict[str, Tuple[str, type]]) -> None:
observed._patterns = self.patterns # type: ignore[assignment]
node_name_to_scope: Dict[str, Tuple[str, type]],
patterns: Dict[Pattern, QuantizeHandler]) -> None:
observed._patterns = patterns # type: ignore[assignment]
observed._qconfig_map = qconfig_map # type: ignore[assignment]
observed._prepare_custom_config_dict = \
self.prepare_custom_config_dict # type: ignore[assignment]
observed._node_name_to_scope = node_name_to_scope # type: ignore[assignment]
def restore_state(self, observed: GraphModule) -> Dict[str, Tuple[str, type]]:
def restore_state(self, observed: GraphModule) -> Tuple[Dict[Pattern, QuantizeHandler], Dict[str, Tuple[str, type]]]:
assert is_observed_module(observed), \
'incoming model must be produced by prepare_fx'
self.patterns = observed._patterns # type: ignore[assignment]
self.prepare_custom_config_dict = \
observed._prepare_custom_config_dict # type: ignore[assignment]
node_name_to_scope: Dict[str, Tuple[str, type]] = observed._node_name_to_scope # type: ignore[assignment]
return node_name_to_scope
patterns: Dict[Pattern, QuantizeHandler] = observed._patterns # type: ignore[assignment]
return patterns, node_name_to_scope
def prepare(
self,
@ -1113,7 +1113,7 @@ class Quantizer:
"""
if convert_custom_config_dict is None:
convert_custom_config_dict = {}
node_name_to_scope = self.restore_state(model)
patterns, node_name_to_scope = self.restore_state(model)
qconfig_map: Dict[str, QConfigAny] = model._qconfig_map # type: ignore[assignment]
# always run weight observers in the top level forward method
# for dynamic quant ops or weight only quant ops
@ -1136,7 +1136,7 @@ class Quantizer:
convert_custom_config_dict,
"observed_to_quantized_custom_module_class")
matches = self._find_matches(
model.graph, modules, self.patterns,
model.graph, modules, patterns,
qconfig_map,
custom_module_classes=custom_module_classes)