mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-21 05:34:18 +08:00
[quant][graphmode][fx][refactor] Split quantize.py to prepare.py and convert.py (#59353)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/59353
Next: remove Quantizer class

Test Plan: Imported from OSS

Reviewed By: raghuramank100

Differential Revision: D28856277

fbshipit-source-id: 25f5502be387dbe9706780f667501b46b82789a5

Committed by: Facebook GitHub Bot
Parent: 8b4784a9c6
Commit: 18642e664a
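
Note (not part of the diff): the two new files map onto the two halves of the public FX graph mode quantization flow. A minimal usage sketch, assuming the torch.quantization.quantize_fx entry points of this era; the toy model is illustrative only:

# prepare.py backs prepare_fx (observer insertion); convert.py backs convert_fx
# (replacing observed ops with quantized ones).
import torch
from torch.quantization import get_default_qconfig
from torch.quantization.quantize_fx import prepare_fx, convert_fx

class SmallModel(torch.nn.Module):  # hypothetical toy model, not from this PR
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        return torch.nn.functional.relu(self.linear(x))

model = SmallModel().eval()
qconfig_dict = {"": get_default_qconfig("fbgemm")}
prepared = prepare_fx(model, qconfig_dict)      # prepare step: insert observers
prepared(torch.randn(2, 4))                     # calibration pass
quantized = convert_fx(prepared)                # convert step: lower to quantized ops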
@@ -16,7 +16,7 @@ from torch.quantization.quantize_fx import (
 from torch.quantization.fx.quantization_patterns import DefaultNodeQuantizeHandler
-from torch.quantization.fx.pattern_utils import (
+from torch.quantization.fx.match_utils import (
     is_match,
     MatchAllNode,
 )
481  torch/quantization/fx/convert.py  (new file)
@@ -0,0 +1,481 @@
from typing import Any, Dict, Tuple, List, Callable, Optional, Union

import torch

from torch.fx import (
    GraphModule,
    Proxy,
    map_arg
)

from torch.fx.graph import (
    Graph,
    Node,
)

from torch.fx.node import Argument

from .quantization_types import Pattern

from .qconfig_utils import QConfigAny

from .match_utils import (
    find_matches,
)

from .graph_module import (
    is_observed_module,
    is_observed_standalone_module,
    QuantizedGraphModule,
)

from .quantization_patterns import (
    QuantizeHandler,
)

from .utils import (
    is_get_tensor_info_node,
    node_return_type_is_int,
    quantize_node,
    get_new_attr_name_with_prefix,
    collect_producer_nodes,
    graph_module_from_producer_nodes,
    get_custom_module_class_keys,
    WEIGHT_INDEX_DICT,
)

from ..quantize import (
    _remove_qconfig,
    is_activation_post_process,
)

from ..utils import (
    activation_is_statically_quantized,
    activation_dtype,
)


# weight prepacking ops
WEIGHT_PREPACK_OPS = {
    torch._ops.ops.quantized.linear_prepack,
    torch._ops.ops.quantized.linear_prepack_fp16,
    torch._ops.ops.quantized.conv1d_prepack,
    torch._ops.ops.quantized.conv2d_prepack,
    torch._ops.ops.quantized.conv3d_prepack,
}

def run_weight_observers(observed: GraphModule) -> None:
    r''' Extract the subgraph that produces the weight for a dynamic quant
    or weight-only quant node and run the subgraph to observe the weight.
    Note that the observers of dynamic quant or weight-only quant ops are
    run during the convert step.
    '''
    for node in observed.graph.nodes:
        if node.op == 'call_function' and node.target in WEIGHT_INDEX_DICT:
            for i, node_arg in enumerate(node.args):
                if i in WEIGHT_INDEX_DICT[node.target]:
                    # node_arg is weight
                    weight_observer_nodes = collect_producer_nodes(node_arg)
                    if weight_observer_nodes is not None:
                        weight_observer_module = \
                            graph_module_from_producer_nodes(
                                observed, weight_observer_nodes)
                        # run the weight observer
                        weight_observer_module()

def fold_weight(
        quantized: QuantizedGraphModule,
        node_name_to_scope: Dict[str, Tuple[str, type]]) -> QuantizedGraphModule:
    """
    Trace back from the weight node until we hit getattr, reconstruct the
    graph module with the traced nodes and run the graph module to pack the
    weight, then replace the original chain of ops with the packed weight.
    """
    packed_weights = dict()
    # map from folded node name to the prepacked weight name
    folded_nodes = dict()
    # get packed weights
    for node in quantized.graph.nodes:
        if node.op == 'call_function' and node.target in WEIGHT_PREPACK_OPS:
            nodes_to_fold = collect_producer_nodes(node)
            if nodes_to_fold is not None:
                for node_to_fold in nodes_to_fold:
                    folded_nodes[node_to_fold.name] = node

                prepacking_module = graph_module_from_producer_nodes(
                    quantized, nodes_to_fold)
                packed_weight = prepacking_module()
                packed_weights[node.name] = packed_weight

    # remove folded nodes and replace the prepacking node with getattr
    folded_graph = Graph()
    env: Dict[Any, Any] = {}

    def load_arg(a):
        return map_arg(a, lambda node: env[node.name])
    quantized_root = quantized
    quantized_graph = quantized.graph

    for node in quantized_graph.nodes:
        prepack_node = folded_nodes.get(node.name, None)
        if prepack_node is node:
            packed_weight = packed_weights[node.name]
            # add a prepacked attribute to root
            op_node = list(prepack_node.users)[0]
            module_path, _ = node_name_to_scope[op_node.name]
            get_new_packed_weight_name = \
                get_new_attr_name_with_prefix(module_path + '_packed_weight_')
            packed_weight_name = get_new_packed_weight_name(quantized_root)
            setattr(quantized_root, packed_weight_name, packed_weight)
            # replace prepack node with a getattr node
            env[node.name] = folded_graph.create_node(
                'get_attr', packed_weight_name, (), {})
        elif prepack_node is not None:
            # remove the folded node
            continue
        else:
            # copy other nodes
            env[node.name] = folded_graph.node_copy(node, load_arg)
    quantized = QuantizedGraphModule(quantized_root, folded_graph, quantized_root.preserved_attr_names)
    return quantized

def restore_state(
        observed: GraphModule
) -> Tuple[Dict[Pattern, QuantizeHandler], Dict[str, Tuple[str, type]], Dict[str, Any]]:
    assert is_observed_module(observed), \
        'incoming model must be produced by prepare_fx'
    prepare_custom_config_dict: Dict[str, Any] = \
        observed._prepare_custom_config_dict  # type: ignore[assignment]
    node_name_to_scope: Dict[str, Tuple[str, type]] = observed._node_name_to_scope  # type: ignore[assignment]
    patterns: Dict[Pattern, QuantizeHandler] = observed._patterns  # type: ignore[assignment]
    return patterns, node_name_to_scope, prepare_custom_config_dict

def _convert(model: GraphModule, is_reference: bool = False,
             convert_custom_config_dict: Dict[str, Any] = None,
             is_standalone_module: bool = False,
             _remove_qconfig_flag: bool = True) -> QuantizedGraphModule:
    """ standalone_module means it is a submodule that is not inlined in
    the parent module, and will be quantized separately as one unit.

    Returns a quantized standalone module; whether its input/output is
    quantized is specified by prepare_custom_config_dict, with
    input_quantized_idxs and output_quantized_idxs, please
    see the docs for prepare_fx for details
    """
    if convert_custom_config_dict is None:
        convert_custom_config_dict = {}
    patterns, node_name_to_scope, prepare_custom_config_dict = restore_state(model)
    qconfig_map: Dict[str, QConfigAny] = model._qconfig_map  # type: ignore[assignment]
    # always run weight observers in the top level forward method
    # for dynamic quant ops or weight only quant ops
    run_weight_observers(model)

    # move to cpu since we only have quantized cpu kernels
    model.eval().cpu()
    # mapping from fully qualified module name to module instance
    # for example,
    # {
    #   '': Model(...),
    #   'linear': Linear(...),
    #   'linear.weight_fake_quant': PerChannelMinMaxObserver(...),
    # }
    # We use remove_duplicate=False here because torch.cat uses
    # the same activation_post_process module instance but different names
    modules = dict(model.named_modules(remove_duplicate=False))

    custom_module_classes = get_custom_module_class_keys(
        convert_custom_config_dict,
        "observed_to_quantized_custom_module_class")
    matches = find_matches(
        model.graph, modules, patterns,
        qconfig_map,
        custom_module_classes=custom_module_classes)

    quantized_graph = Graph()
    env: Dict[str, Tuple[Node, Optional[torch.dtype]]] = {}

    graph_inputs: List[str] = []
    for node in model.graph.nodes:
        if node.op == 'placeholder':
            graph_inputs.append(node.name)

    def load_non_quantized(n: Node) -> Node:
        assert n.name in env, \
            'trying to load float node but did not find ' + \
            'node:' + n.name + \
            ' in env: ' + \
            str(env)
        quantized_node, dtype = env[n.name]
        if dtype and dtype != torch.float:
            env[n.name] = Proxy(quantized_node).dequantize().node, torch.float
        return env[n.name][0]

    def load_quantized(n: Node) -> Node:
        assert n.name in env, \
            'trying to load quantized node but did not find node:' + \
            n.name + ' in environment:' + str(env)
        quantized_node, dtype = env[n.name]
        assert dtype in [torch.quint8, torch.qint8, torch.float16], \
            f'Expecting node {quantized_node} to be quantized but got dtype: {dtype}'
        return quantized_node

    def load_x(n: Node) -> Node:
        assert n.name in env, \
            'node ' + n.name + ' does not exist in environment'
        return env[n.name][0]

    def load_arg(quantized: Optional[Union[List[int], bool, Tuple[int, ...]]]
                 ) -> Callable[[Node], Argument]:
        """
        Input: quantized, which can be None, a list, a boolean or a tuple
          - if quantized is None, then we'll load the node as long as it
            exists
          - if quantized is a boolean, then all args will be
            quantized/not quantized
          - if quantized is an empty list or tuple, then it is the same as load_arg(quantized=False)
          - if quantized is a list or tuple, then arg should be a list and
            the args with corresponding indexes will be quantized

        Output: fn which takes arg_or_args, and loads them from the
        corresponding environment depending on the value of quantized.
        """
        assert quantized is None or \
            isinstance(quantized, (tuple, list, bool)), type(quantized)
        if isinstance(quantized, (tuple, list)) and len(quantized) == 0:
            # empty tuple or list means nothing is quantized
            quantized = False

        def load_arg_impl(arg_or_args):
            # we'll update the format of `quantized`
            # to better match arg_or_args
            updated_quantized: Optional[Union[List[int], bool, Tuple[int, ...]]] = quantized

            if isinstance(quantized, (tuple, list)) and \
                    len(quantized) == 1 and isinstance(arg_or_args, Node):
                # when argument is one Node instead of tuple, we just need to check
                # 0 is in the quantized list
                updated_quantized = 0 in quantized

            if updated_quantized is None:
                return map_arg(arg_or_args, load_x)
            if isinstance(updated_quantized, bool):
                return map_arg(
                    arg_or_args,
                    load_quantized if updated_quantized else load_non_quantized)
            elif isinstance(updated_quantized, (tuple, list)):
                assert isinstance(arg_or_args, (tuple, list)), arg_or_args
                loaded_args = []
                # for now, we only support quantizing positional arguments
                for i, a in enumerate(arg_or_args):
                    if i in updated_quantized:
                        loaded_args.append(map_arg(a, load_quantized))
                    else:
                        loaded_args.append(map_arg(a, load_non_quantized))
                return type(arg_or_args)(loaded_args)
        return load_arg_impl

    def node_arg_is_quantized(node_arg: Any) -> bool:
        if isinstance(node_arg, Node):
            assert node_arg.name in env, \
                'Expecting node_arg to be in the environment'
            if node_arg.name in env:
                _, dtype = env[node_arg.name]
                return dtype != torch.float
            else:
                return False
        elif isinstance(node_arg, list):
            # materialize the map so that all() and any() both see every element
            quantized = list(map(node_arg_is_quantized, node_arg))
            if all(quantized):
                return True
            elif not any(quantized):
                return False
            else:
                raise Exception(
                    "partially quantized inputs in list not handled yet")
        else:
            return False

    def is_output_quantized(node: Node, obj: QuantizeHandler, qconfig: QConfigAny, modules: Dict[str, torch.nn.Module]) -> bool:
        """ Check if output node is quantized or not """
        assert modules is not None
        # by default the output for a quantizable node is expected to be quantized
        quantized = True

        # Need to get the correct quantized/non-quantized state for the output
        # of FixedQParamsQuantizeHandler
        # TODO: we may want to try to remove the special case here
        # as well
        if obj.should_mark_output_quantized_from_input_quantized_status(qconfig):
            assert node.op in [
                'call_module',
                'call_function',
                'call_method'], \
                'FixedQParamsQuantizeHandler of type ' + node.op + ' is not handled'
            # TODO: need to extend this to consider all relevant args instead of just arg[0]
            quantized = node_arg_is_quantized(node.args[0])

        # the output is unquantized if the node is not a CopyNode
        # or the activation is not statically quantized
        if not activation_is_statically_quantized(qconfig) or \
                not obj.input_output_observed():
            quantized = False
        if node_return_type_is_int(node):
            quantized = False

        return quantized

    def insert_quantize_node(node: Node, modules: Dict[str, torch.nn.Module]) -> None:
        """ Given an activation_post_process module call node, insert a
        quantize node"""
        assert modules is not None
        assert isinstance(node.target, str)
        observer_module = modules[node.target]
        prev_node = node.args[0]
        if observer_module.dtype == torch.float32:
            # copy the observer for fp32 dtype
            env[node.name] = quantized_graph.node_copy(
                node, load_non_quantized), torch.float
        elif isinstance(prev_node, Node) and prev_node.name in env:
            # if previous node is already quantized, we'll just remove the
            # activation_post_process
            _, prev_dtype = env[prev_node.name]
            current_dtype = observer_module.dtype
            if prev_dtype == current_dtype:
                env[node.name] = env[prev_node.name]
            else:
                root_module = modules[""]
                assert isinstance(prev_node, Node)
                observer_dtype: torch.dtype = observer_module.dtype  # type: ignore[assignment]
                env[node.name] = (
                    quantize_node(
                        load_non_quantized(prev_node),
                        observer_module, node, modules, quantized_graph,
                        node_name_to_scope, is_input=True),
                    observer_dtype)
        else:
            # replace activation post process with quantization ops
            root_module = modules[""]
            assert isinstance(node.args[0], Node)
            dtype: torch.dtype = observer_module.dtype  # type: ignore[assignment]
            env[node.name] = (
                quantize_node(
                    load_non_quantized(node.args[0]),
                    observer_module, node, modules,
                    quantized_graph,
                    node_name_to_scope, is_input=True),
                dtype)

    # additional state to override inputs to be quantized, if specified
    # by the user
    placeholder_node_seen_cnt = 0
    output_node_seen_cnt = 0
    input_quantized_idxs: List[int] = prepare_custom_config_dict.get(
        "input_quantized_idxs", [])
    output_quantized_idxs: List[int] = prepare_custom_config_dict.get(
        "output_quantized_idxs", [])

    for node in model.graph.nodes:
        if node.op == "output":
            cur_output_node_idx = output_node_seen_cnt
            output_node_seen_cnt += 1
            if cur_output_node_idx in output_quantized_idxs:
                # Results are kept quantized if the user specified the
                # output_quantized_idxs override.
                graph_output = map_arg(node.args[0], load_x)
            else:
                graph_output = map_arg(node.args[0], load_non_quantized)
            quantized_graph.output(graph_output)
            continue
        root_node, matched, matched_pattern, obj, qconfig = \
            matches.get(node.name, (None, None, None, None, None))
        if root_node is node:
            is_observed_standalone_module_node = (
                node.op == 'call_module' and
                is_observed_standalone_module(
                    modules[node.target])
            )
            if qconfig is None and not is_observed_standalone_module_node:
                result = quantized_graph.node_copy(
                    node, load_non_quantized)
                quantized = False
            else:
                assert obj is not None
                # We will get whether the output is quantized or not before
                # convert for standalone module and after convert
                # for non-standalone module, since _standalone_module_output_quantized_idxs
                # is only available in observed standalone module
                if is_observed_standalone_module_node:
                    out_quant_idxs = modules[node.target]._standalone_module_output_quantized_idxs.tolist()  # type: ignore[operator] # noqa: B950
                    assert len(out_quant_idxs) <= 1, "Currently standalone only support one output"
                    quantized = 0 in out_quant_idxs

                qconfig = qconfig_map[node.name]
                result = obj.convert(
                    node, qconfig, modules, quantized_graph, node_name_to_scope, load_arg, is_reference=is_reference,
                    convert_custom_config_dict=convert_custom_config_dict)
                if not is_observed_standalone_module_node:
                    quantized = is_output_quantized(node, obj, qconfig, modules)

            if quantized:
                env[node.name] = result, activation_dtype(qconfig)
            else:
                env[node.name] = result, torch.float
            continue
        elif root_node is not None:
            if qconfig is None:
                # This branch is hit if all of these conditions are met:
                # 1. we are in a fusion pattern of multiple nodes (i.e. add-relu)
                # 2. the current node is not the "root_node" of the pattern
                # 3. quantization for this pattern is disabled
                #
                # In this case, we need to make sure to populate the env with
                # intermediate nodes manually, because the QuantizeHandler.convert
                # function will not be called.
                result = quantized_graph.node_copy(
                    node, load_non_quantized)
                env[node.name] = result, torch.float
                continue

        # handle activation post process calls
        if node.op == 'call_module' and \
                is_activation_post_process(modules[node.target]):
            insert_quantize_node(node, modules)
        elif node.op == 'placeholder':
            cur_placeholder_node_idx = placeholder_node_seen_cnt
            placeholder_node_seen_cnt += 1
            if cur_placeholder_node_idx in input_quantized_idxs:
                env[node.name] = \
                    quantized_graph.node_copy(
                        node, load_non_quantized), torch.quint8
            else:
                env[node.name] = \
                    quantized_graph.node_copy(node, load_non_quantized), torch.float
        else:
            # copy quantized or non-quantized node
            # get_tensor_info_node like shape works for both
            # quantized and non-quantized input and output a non-Tensor
            # (we use None for dtype currently for non-Tensors)
            if is_get_tensor_info_node(node):
                env[node.name] = \
                    quantized_graph.node_copy(node, load_x), None
            else:
                env[node.name] = \
                    quantized_graph.node_copy(node, load_non_quantized), torch.float

    # remove activation post process
    act_post_process_removed_graph = Graph()
    remove_env: Dict[str, Node] = {}

    def load_arg_remove(a: Argument) -> Argument:
        return map_arg(a, lambda node: remove_env[node.name])

    for node in quantized_graph.nodes:
        if node.op == 'output':
            act_post_process_removed_graph.output(
                map_arg(node.args[0], load_arg_remove))
            continue
        if node.op == 'call_module' and \
                is_activation_post_process(modules[node.target]):
            # remove activation post process node
            remove_env[node.name] = remove_env[node.args[0].name]
        else:
            remove_env[node.name] = act_post_process_removed_graph.node_copy(
                node, load_arg_remove)

    # removes qconfig and activation_post_process modules
    if _remove_qconfig_flag:
        _remove_qconfig(model)
    preserved_attributes = set(convert_custom_config_dict.get("preserved_attributes", []))
    model = QuantizedGraphModule(model, act_post_process_removed_graph, preserved_attributes)
    if not is_reference:
        model = fold_weight(model, node_name_to_scope)
    return model
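
Note (not part of the diff): fold_weight() above bakes the result of the prepack call chain into a get_attr constant on the root module. A minimal sketch of the prepack/quantized-linear pairing it folds, assuming a CPU build with a quantized engine (fbgemm or qnnpack) available; the scales and zero points are arbitrary illustration values:

import torch

float_weight = torch.randn(8, 4)
qweight = torch.quantize_per_tensor(float_weight, scale=0.1, zero_point=0, dtype=torch.qint8)
bias = torch.zeros(8)

# the "weight -> quantize -> linear_prepack" chain that fold_weight runs once
# and replaces with a single get_attr node holding `packed`
packed = torch.ops.quantized.linear_prepack(qweight, bias)

qx = torch.quantize_per_tensor(torch.randn(2, 4), scale=0.1, zero_point=0, dtype=torch.quint8)
qy = torch.ops.quantized.linear(qx, packed, 0.2, 0)   # output scale, output zero_point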
@@ -13,10 +13,11 @@ from ..utils import (
 )
 from .pattern_utils import (
-    is_match,
     get_default_fusion_patterns,
 )
+from .match_utils import is_match
 
 from .graph_module import (
     FusedGraphModule
 )
222  torch/quantization/fx/match_utils.py  (new file)
@@ -0,0 +1,222 @@
import sys
import torch
from torch.fx.graph import (
    Graph,
    Node,
)
from .quantization_types import Pattern
from .quantization_patterns import (
    QuantizeHandler,
    CustomModuleQuantizeHandler,
    StandaloneModuleQuantizeHandler,
    BinaryOpQuantizeHandler,
    binary_op_supported_dtypes,
    binary_reference_op_supported_dtypes,
)
from .qconfig_utils import (
    QConfigAny,
)
from .graph_module import (
    is_observed_standalone_module,
)
from ..utils import get_qconfig_dtypes

from typing import Any, Dict, List, Callable, Optional, Tuple, Set

MatchResult = Tuple[Node, List[Node], Optional[Pattern], QuantizeHandler,
                    QConfigAny]

class MatchAllNode:
    """ A node pattern that matches all nodes
    """
    pass

# Note: The order of patterns is important! match function will take whatever is matched first, so we'll
# need to put the fusion patterns before single patterns. For example, add_relu should be registered before relu.
# Decorators are applied in the reverse order we see. Also when we match the nodes in the graph with these patterns,
# we'll start from the last node of the graph and traverse back.
def is_match(modules, node, pattern, max_uses=sys.maxsize):
    """ Matches a node in fx against a pattern
    """
    if isinstance(pattern, tuple):
        self_match, *arg_matches = pattern
        if self_match is getattr:
            assert len(pattern) == 2, 'Expecting getattr pattern to have two elements'
            arg_matches = []
    else:
        self_match = pattern
        arg_matches = []

    if isinstance(self_match, type) and issubclass(self_match, MatchAllNode):
        return True

    if len(node.users) > max_uses:
        return False

    if isinstance(self_match, type) and issubclass(self_match, torch.nn.Module):
        if node.op != 'call_module':
            return False
        if not type(modules[node.target]) == self_match:
            return False
    elif callable(self_match):
        if node.op != 'call_function' or node.target is not self_match:
            return False
        elif node.target is getattr:
            if node.args[1] != pattern[1]:
                return False
    elif isinstance(self_match, str):
        if node.op != 'call_method' or node.target != self_match:
            return False
    elif node.target != self_match:
        return False

    if not arg_matches:
        return True

    if len(arg_matches) != len(node.args):
        return False

    return all(is_match(modules, node, arg_match, max_uses=1) for node, arg_match in zip(node.args, arg_matches))

def find_matches(
        graph: Graph,
        modules: Dict[str, torch.nn.Module],
        patterns: Dict[Pattern, QuantizeHandler],
        qconfig_map: Dict[str, QConfigAny],
        standalone_module_names: List[str] = None,
        standalone_module_classes: List[Callable] = None,
        custom_module_classes: List[Any] = None) -> Dict[str, MatchResult]:
    """
    Matches the nodes in the input graph to quantization patterns, and
    outputs the information needed to quantize them in future steps.

    Inputs:
      - graph: an fx.Graph object
      - modules: a mapping of fully qualified module name to instance,
        for example, {'foo': ModuleFoo, ...}
      - patterns: a mapping from a tuple of nodes in reverse order to
        an uninitialized QuantizeHandler subclass.

    Outputs a map of
      node_name ->
        (node, matched_values, matched_pattern, QuantizeHandler instance,
         qconfig)

    For example, {
      'relu_1': (relu_1, [relu_1], torch.nn.functional.relu,
                 <CopyNodeQuantizeHandler instance>, QConfig(...)),
      ...
    }
    """
    if custom_module_classes is None:
        custom_module_classes = []

    if standalone_module_classes is None:
        standalone_module_classes = []

    if standalone_module_names is None:
        standalone_module_names = []

    match_map: Dict[str, MatchResult] = {}
    all_matched : Set[str] = set()

    def record_match(pattern, node, matched):
        if isinstance(pattern, tuple):
            s, *args = pattern
            record_match(s, node, matched)
            if pattern[0] is not getattr:
                for subpattern, arg in zip(args, node.args):
                    record_match(subpattern, arg, matched)
        else:
            matched.append(node)

    cache_for_no_tensor_check: Dict[Node, bool] = dict()
    for node in reversed(graph.nodes):
        if node.name not in match_map and node.name not in all_matched:
            for pattern, value in patterns.items():
                if is_match(modules, node, pattern):
                    skip_this_match = False
                    if value is BinaryOpQuantizeHandler:

                        # to properly check for dtype support, we need to
                        # navigate to the base node of an add-relu or mul-relu
                        # pattern
                        base_node = node
                        if (
                            (node.op == 'call_function' and
                                node.target is torch.nn.functional.relu) or
                            (node.op == 'call_module' and
                                isinstance(modules[node.target], torch.nn.ReLU))
                        ):
                            base_node = node.args[0]

                        this_node_qconfig = \
                            qconfig_map[base_node.name]
                        if this_node_qconfig:
                            dtypes = get_qconfig_dtypes(this_node_qconfig)
                            # TODO(future PR): update the pattern to quantize
                            # handler logic to take this into account.

                            # This needs to handle 3 cases
                            # 1) op and dtype is in either [is_ref or non-ref] list -> don't skip
                            # 2) op is not in either list (i.e. relu) -> don't skip
                            # 3) op is in non-ref list, but not for dtype, and op+dtype not in is_ref list -> skip

                            # note: the value of is_reference is unknown at prepare, so we have to cover both cases
                            # handle is_reference = False
                            skip_match_not_is_reference = (
                                (base_node.target in binary_op_supported_dtypes) and
                                (dtypes not in binary_op_supported_dtypes[base_node.target])
                            )

                            # handle is_reference = True
                            supported_is_reference = (
                                (base_node.target in binary_reference_op_supported_dtypes) and
                                (dtypes in binary_reference_op_supported_dtypes[base_node.target])
                            )

                            # only skip if the non-reference check says skip and the reference path does not support it either
                            skip_this_match = skip_match_not_is_reference and not supported_is_reference

                    if not skip_this_match:
                        matched: List[Any] = []
                        record_match(pattern, node, matched)
                        for n in matched:
                            match_map[n.name] = (
                                node, matched, pattern, value(node, modules),  # type: ignore[operator]
                                qconfig_map[n.name])
                            all_matched.add(n.name)
                        # break after finding the first match
                        break

    # add custom module instances to the match result
    assert modules is not None
    for node in graph.nodes:
        if node.op == 'call_module' and \
           type(modules[node.target]) in custom_module_classes:
            custom_module_qconfig = qconfig_map[node.name]
            match_map[node.name] = (
                node, [node], None, CustomModuleQuantizeHandler(node, modules),
                custom_module_qconfig)

    def is_standalone_module(node_target: str, modules: Dict[str, torch.nn.Module]):
        assert modules is not None
        return (
            node_target in standalone_module_names or  # type: ignore[operator]
            type(modules[node_target]) in standalone_module_classes  # type: ignore[operator]
        )

    # add standalone modules to the match
    for node in graph.nodes:
        if node.op == 'call_module' and \
           (is_standalone_module(node.target, modules) or
                is_observed_standalone_module(modules[node.target])):
            # add node to matched nodes
            custom_module_qconfig = qconfig_map[node.name]
            match_map[node.name] = (
                node, [node], None,
                StandaloneModuleQuantizeHandler(node, modules),
                custom_module_qconfig)

    return match_map
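
Note (not part of the diff): a minimal sketch of how is_match() walks a pattern, assuming the torch.quantization.fx.match_utils module added above is importable; the toy module is illustrative only. Patterns are written in reverse order, so (nn.ReLU, nn.Conv2d) matches a Conv2d whose output feeds a ReLU, starting from the ReLU node:

import torch
from torch.fx import symbolic_trace
from torch.quantization.fx.match_utils import is_match

class ConvReLU(torch.nn.Module):  # hypothetical toy module
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 3, 1)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        return self.relu(self.conv(x))

gm = symbolic_trace(ConvReLU())
modules = dict(gm.named_modules())
relu_node = next(n for n in gm.graph.nodes if n.op == 'call_module' and n.target == 'relu')

print(is_match(modules, relu_node, (torch.nn.ReLU, torch.nn.Conv2d)))  # True: Conv2d -> ReLU chain
print(is_match(modules, relu_node, torch.nn.Conv2d))                   # False: node is a ReLU, not a Conv2d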
@@ -1,13 +1,20 @@
 import torch
-import sys
 from collections import OrderedDict
-from typing import Dict, Any
+from typing import Dict, Any, Tuple, List, Optional
+from torch.fx.graph import (
+    Node,
+)
+from .quantization_types import Pattern
+from .qconfig_utils import QConfigAny
+# from .quantization_patterns import BinaryOpQuantizeHandler
+
+# TODO(future PR): fix the typing on QuantizeHandler (currently a circular dependency)
+QuantizeHandler = Any
+
+MatchResult = Tuple[Node, List[Node], Optional[Pattern], QuantizeHandler,
+                    QConfigAny]
 
 # pattern for conv bn fusion
 DEFAULT_FUSION_PATTERNS = OrderedDict()
 def register_fusion_pattern(pattern):
@@ -44,61 +51,9 @@ def get_default_output_activation_post_process_map() -> Dict[Pattern, torch.quan
     return DEFAULT_OUTPUT_ACTIVATION_POST_PROCESS_MAP
 
-
-class MatchAllNode:
-    """ A node pattern that matches all nodes
-    """
-    pass
-
 # Example use of register pattern function:
 # @register_fusion_pattern(torch.nn.ReLU, (torch.nn.BatchNorm2d, torch.nn.Conv2d)))
 # class ConvBNReLUFusion():
 #   def __init__(...):
 #     ...
 #
-# Note: The order of patterns is important! match function will take whatever is matched first, so we'll
-# need to put the fusion patterns before single patterns. For example, add_relu should be registered come before relu.
-# decorators are applied in the reverse order we see. Also when we match the nodes in the graph with these patterns,
-# we'll start from the last node of the graph and traverse back.
-
-def is_match(modules, node, pattern, max_uses=sys.maxsize):
-    """ Matches a node in fx against a pattern
-    """
-    if isinstance(pattern, tuple):
-        self_match, *arg_matches = pattern
-        if self_match is getattr:
-            assert len(pattern) == 2, 'Expecting getattr pattern to have two elements'
-            arg_matches = []
-    else:
-        self_match = pattern
-        arg_matches = []
-
-    if isinstance(self_match, type) and issubclass(self_match, MatchAllNode):
-        return True
-
-    if len(node.users) > max_uses:
-        return False
-
-    if isinstance(self_match, type) and issubclass(self_match, torch.nn.Module):
-        if node.op != 'call_module':
-            return False
-        if not type(modules[node.target]) == self_match:
-            return False
-    elif callable(self_match):
-        if node.op != 'call_function' or node.target is not self_match:
-            return False
-        elif node.target is getattr:
-            if node.args[1] != pattern[1]:
-                return False
-    elif isinstance(self_match, str):
-        if node.op != 'call_method' or node.target != self_match:
-            return False
-    elif node.target != self_match:
-        return False
-
-    if not arg_matches:
-        return True
-
-    if len(arg_matches) != len(node.args):
-        return False
-
-    return all(is_match(modules, node, arg_match, max_uses=1) for node, arg_match in zip(node.args, arg_matches))
1037  torch/quantization/fx/prepare.py  (new file)
File diff suppressed because it is too large.
@@ -1,4 +1,4 @@
-from typing import Union, Callable, Tuple, Any
+from typing import Any, Callable, Tuple, Union
 
 Pattern = Union[Callable, Tuple[Callable, Callable], Tuple[Callable, Callable, Callable]]
 
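
Note (not part of the diff): the Pattern alias above covers a single callable or a reverse-order tuple of callables. Two illustrative values (the actual registrations live in pattern_utils.py and quantization_patterns.py):

import torch
from torch.quantization.fx.quantization_types import Pattern

single_op: Pattern = torch.nn.functional.linear          # a single callable
fused_pair: Pattern = (torch.nn.ReLU, torch.nn.Conv2d)   # reverse order: Conv2d followed by ReLU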
File diff suppressed because it is too large.
@@ -13,6 +13,14 @@ from torch.fx.graph import (
 from typing import Callable, Optional, List, Dict, Any, Set, Tuple, Union
 import operator
 
+# A dictionary for querying the weight index for a given op
+WEIGHT_INDEX_DICT = {
+    torch.nn.functional.conv1d : [1],
+    torch.nn.functional.conv2d : [1],
+    torch.nn.functional.conv3d : [1],
+    torch.nn.functional.linear : [1],
+}
+
 # turn foo.bar -> ['foo', 'bar']
 def _parent_name(target):
     r = target.rsplit('.', 1)
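
Note (not part of the diff): a sketch of how WEIGHT_INDEX_DICT is consulted (run_weight_observers in convert.py does essentially this): for a call_function node such as F.linear(input, weight, bias), argument index 1 is the weight:

import torch
from torch.quantization.fx.utils import WEIGHT_INDEX_DICT

target = torch.nn.functional.linear
for i, name in enumerate(['input', 'weight', 'bias']):
    if i in WEIGHT_INDEX_DICT.get(target, []):
        print(f'arg {i} ({name}) is a weight')   # -> arg 1 (weight) is a weight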
@@ -1,7 +1,7 @@
 import torch
 from torch.fx import GraphModule, map_arg
 from torch.fx.graph import Graph, Node
-from torch.quantization.fx.quantize import is_activation_post_process
+from torch.quantization.quantize import is_activation_post_process
 from torch.quantization.fx.utils import get_new_attr_name_with_prefix
 
 from .utils import (
@@ -6,7 +6,7 @@ import torch.nn as nn
 toq = torch.ops.quantized
 from torch.fx import GraphModule
 from torch.fx.graph import Node
-from torch.quantization.fx.quantize import is_activation_post_process
+from torch.quantization.quantize import is_activation_post_process
 
 from .ns_types import NSNodeTargetType
 
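
Note (not part of the diff): the import moved here is the same helper convert.py uses to spot observer nodes. A small sketch of what it reports, assuming torch.quantization.quantize.is_activation_post_process behaves as in this release:

import torch
from torch.quantization.quantize import is_activation_post_process

print(is_activation_post_process(torch.quantization.MinMaxObserver()))  # True: observers are activation post-process modules
print(is_activation_post_process(torch.nn.ReLU()))                      # False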