[quant][graphmode][fx][refactor] Remove patterns from Quantizer class (#59033)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/59033

A step toward removing the Quantizer class and splitting the prepare and convert functions into separate files.

Test Plan:
python test/test_quantization.py TestQuantizeFx
python test/test_quantization.py TestQuantizeFxOps

Imported from OSS

Reviewed By: vkuzo

Differential Revision: D28724861

fbshipit-source-id: 97b38e851b6bf581510a24636b1d8d6f1d977f5a
commit e4b2684331 (parent 83892c1861), committed by Facebook GitHub Bot
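At a high level, the change turns `patterns` from mutable state on the `Quantizer` instance into a value that is threaded explicitly through `save_state`/`restore_state`. A minimal before/after sketch of that pattern (simplified names, not the actual PyTorch source):

    # Before: the pattern table lives on the instance and is shared implicitly.
    class QuantizerBefore:
        def __init__(self):
            self.patterns = {}  # Pattern -> QuantizeHandler

        def save_state(self, observed):
            observed._patterns = self.patterns

        def restore_state(self, observed):
            self.patterns = observed._patterns

    # After: patterns is computed locally in _prepare, stashed on the observed
    # module by save_state, and handed back by restore_state, so prepare and
    # convert no longer need to share state through a Quantizer instance.
    class QuantizerAfter:
        def save_state(self, observed, patterns):
            observed._patterns = patterns

        def restore_state(self, observed):
            return observed._patterns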
@@ -1158,8 +1158,8 @@ class ELUQuantizeHandler(QuantizeHandler):
         scale, zero_point = activation_post_process.calculate_qparams()  # type: ignore[operator]
         scale = float(scale)
         zero_point = int(zero_point)
 
-        scale_arg, zero_point_arg = create_qparam_nodes(quantizer, node.name, scale, zero_point, modules, quantized_graph, node_name_to_scope)
+        scale_arg, zero_point_arg = create_qparam_nodes(
+            quantizer, node.name, scale, zero_point, modules, quantized_graph, node_name_to_scope)
 
         quantized_op = get_quantized_operator(node.target)
         args = load_arg(quantized=[0])(node.args)
@@ -955,17 +955,6 @@ def run_weight_observers(observed: GraphModule) -> None:
 
 
 class Quantizer:
     def __init__(self):
-        # mapping from a tuple of nodes in reverse order to uninitialized
-        # QuantizeHandler subclass. For example,
-        # {
-        #   # match a single node
-        #   (<class 'torch.nn.modules.conv.Conv3d'>:
-        #     <class 'torch.quantization.fx.quantize.ConvRelu'>),
-        #   # match multiple nodes in reverse order
-        #   ((<function relu at 0x7f766a7360d0>, <built-in function add>):
-        #     <class 'torch.quantization.fx.quantize.Add'>),
-        # }
-        self.patterns: Dict[Pattern, QuantizeHandler] = {}
         self.prepare_custom_config_dict: Dict[str, Any] = {}
 
     def _prepare(
@@ -1001,7 +990,17 @@ class Quantizer:
 
         additional_quant_patterns = \
             prepare_custom_config_dict.get("additional_quant_pattern", {})
-        self.patterns = get_combined_dict(
+        # mapping from a tuple of nodes in reverse order to uninitialized
+        # QuantizeHandler subclass. For example,
+        # {
+        #   # match a single node
+        #   (<class 'torch.nn.modules.conv.Conv3d'>:
+        #     <class 'torch.quantization.fx.quantize.ConvRelu'>),
+        #   # match multiple nodes in reverse order
+        #   ((<function relu at 0x7f766a7360d0>, <built-in function add>):
+        #     <class 'torch.quantization.fx.quantize.Add'>),
+        # }
+        patterns: Dict[Pattern, QuantizeHandler] = get_combined_dict(
             get_default_quant_patterns(), additional_quant_patterns)
 
         convert_dict_to_ordered_dict(qconfig_dict)
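For context, `get_combined_dict` overlays the user-supplied `additional_quant_pattern` entries on the default pattern table. A minimal sketch of the merge semantics, assuming the helper is a plain copy-and-update (which is how the utility in torch.quantization.utils behaves):

    from typing import Any, Dict

    def get_combined_dict(default_dict: Dict[Any, Any],
                          additional_dict: Dict[Any, Any]) -> Dict[Any, Any]:
        # Copy the defaults, then let user-supplied entries override them.
        d = default_dict.copy()
        d.update(additional_dict)
        return d

    # An entry in additional_quant_pattern wins over the default handler:
    combined = get_combined_dict({"a": 1, "b": 2}, {"b": 20})
    assert combined == {"a": 1, "b": 20}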
@@ -1036,7 +1035,7 @@ class Quantizer:
         custom_module_classes = get_custom_module_class_keys(
             prepare_custom_config_dict, "float_to_observed_custom_module_class")
         matches = self._find_matches(
-            model.graph, modules, self.patterns, qconfig_map, standalone_module_names,
+            model.graph, modules, patterns, qconfig_map, standalone_module_names,
             standalone_module_classes, custom_module_classes)
 
         input_quantized_idxs: List[int] = self.prepare_custom_config_dict.get(
@@ -1052,7 +1051,7 @@ class Quantizer:
             model.graph, prepare_custom_config_dict,
             input_quantized_idxs, output_quantized_idxs)
 
-        self.save_state(model, qconfig_map, node_name_to_scope)
+        self.save_state(model, qconfig_map, node_name_to_scope, patterns)
         preserved_attributes = set(prepare_custom_config_dict.get("preserved_attributes", []))
         model = ObservedGraphModule(model, model.graph, preserved_attributes)
         if is_standalone_module:
@@ -1072,21 +1071,22 @@ class Quantizer:
             self,
             observed: GraphModule,
             qconfig_map: Dict[str, QConfigAny],
-            node_name_to_scope: Dict[str, Tuple[str, type]]) -> None:
-        observed._patterns = self.patterns  # type: ignore[assignment]
+            node_name_to_scope: Dict[str, Tuple[str, type]],
+            patterns: Dict[Pattern, QuantizeHandler]) -> None:
+        observed._patterns = patterns  # type: ignore[assignment]
         observed._qconfig_map = qconfig_map  # type: ignore[assignment]
         observed._prepare_custom_config_dict = \
             self.prepare_custom_config_dict  # type: ignore[assignment]
         observed._node_name_to_scope = node_name_to_scope  # type: ignore[assignment]
 
-    def restore_state(self, observed: GraphModule) -> Dict[str, Tuple[str, type]]:
+    def restore_state(self, observed: GraphModule) -> Tuple[Dict[Pattern, QuantizeHandler], Dict[str, Tuple[str, type]]]:
         assert is_observed_module(observed), \
             'incoming model must be produced by prepare_fx'
-        self.patterns = observed._patterns  # type: ignore[assignment]
         self.prepare_custom_config_dict = \
             observed._prepare_custom_config_dict  # type: ignore[assignment]
         node_name_to_scope: Dict[str, Tuple[str, type]] = observed._node_name_to_scope  # type: ignore[assignment]
-        return node_name_to_scope
+        patterns: Dict[Pattern, QuantizeHandler] = observed._patterns  # type: ignore[assignment]
+        return patterns, node_name_to_scope
 
     def prepare(
             self,
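A self-contained round-trip of the new `save_state`/`restore_state` contract; the observed-module stand-in and example values below are illustrative only, while the attribute names come from the diff:

    from typing import Any, Dict, Tuple

    class FakeObserved:
        """Stand-in for an observed GraphModule that just carries attributes."""

    def save_state(observed: Any,
                   qconfig_map: Dict[str, Any],
                   node_name_to_scope: Dict[str, Tuple[str, type]],
                   patterns: Dict[Any, Any]) -> None:
        # prepare() stashes everything the convert step will need later.
        observed._patterns = patterns
        observed._qconfig_map = qconfig_map
        observed._node_name_to_scope = node_name_to_scope

    def restore_state(observed: Any) -> Tuple[Dict[Any, Any], Dict[str, Tuple[str, type]]]:
        # convert() reads the state back; patterns is now returned to the
        # caller rather than re-attached to a Quantizer instance.
        return observed._patterns, observed._node_name_to_scope

    m = FakeObserved()
    save_state(m, {"conv1": None}, {"conv1": ("conv1", type(None))}, {"pat": "handler"})
    patterns, node_name_to_scope = restore_state(m)
    assert "pat" in patterns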
@@ -1113,7 +1113,7 @@ class Quantizer:
         """
         if convert_custom_config_dict is None:
             convert_custom_config_dict = {}
-        node_name_to_scope = self.restore_state(model)
+        patterns, node_name_to_scope = self.restore_state(model)
         qconfig_map: Dict[str, QConfigAny] = model._qconfig_map  # type: ignore[assignment]
         # always run weight observers in the top level forward method
         # for dynamic quant ops or weight only quant ops
@@ -1136,7 +1136,7 @@ class Quantizer:
             convert_custom_config_dict,
             "observed_to_quantized_custom_module_class")
         matches = self._find_matches(
-            model.graph, modules, self.patterns,
+            model.graph, modules, patterns,
             qconfig_map,
             custom_module_classes=custom_module_classes)
 