[quant][graphmode][fx][refactor] Split quantize.py to prepare.py and convert.py (#59353)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/59353

Next: remove Quantizer class

Test Plan: Imported from OSS

Reviewed By: raghuramank100

Differential Revision: D28856277

fbshipit-source-id: 25f5502be387dbe9706780f667501b46b82789a5
Jerry Zhang
2021-06-02 23:51:28 -07:00
committed by Facebook GitHub Bot
parent 8b4784a9c6
commit 18642e664a
12 changed files with 1769 additions and 1701 deletions

View File

@ -16,7 +16,7 @@ from torch.quantization.quantize_fx import (
from torch.quantization.fx.quantization_patterns import DefaultNodeQuantizeHandler
from torch.quantization.fx.pattern_utils import (
from torch.quantization.fx.match_utils import (
is_match,
MatchAllNode,
)

View File

@ -46,9 +46,9 @@ class _LearnableFakeQuantize(torch.quantization.FakeQuantizeBase):
self.activation_post_process = observer(**observer_kwargs)
assert torch.iinfo(self.activation_post_process.dtype).min <= quant_min, \
'quant_min out of bound'
'quant_min out of bound'
assert quant_max <= torch.iinfo(self.activation_post_process.dtype).max, \
'quant_max out of bound'
'quant_max out of bound'
self.dtype = self.activation_post_process.dtype
self.qscheme = self.activation_post_process.qscheme
self.ch_axis = self.activation_post_process.ch_axis \

View File

@ -0,0 +1,481 @@
from typing import Any, Dict, Tuple, List, Callable, Optional, Union
import torch
from torch.fx import (
GraphModule,
Proxy,
map_arg
)
from torch.fx.graph import (
Graph,
Node,
)
from torch.fx.node import Argument
from .quantization_types import Pattern
from .qconfig_utils import QConfigAny
from .match_utils import (
find_matches,
)
from .graph_module import (
is_observed_module,
is_observed_standalone_module,
QuantizedGraphModule,
)
from .quantization_patterns import (
QuantizeHandler,
)
from .utils import (
is_get_tensor_info_node,
node_return_type_is_int,
quantize_node,
get_new_attr_name_with_prefix,
collect_producer_nodes,
graph_module_from_producer_nodes,
get_custom_module_class_keys,
WEIGHT_INDEX_DICT,
)
from ..quantize import (
_remove_qconfig,
is_activation_post_process,
)
from ..utils import (
activation_is_statically_quantized,
activation_dtype,
)
# weight prepacking ops
WEIGHT_PREPACK_OPS = {
torch._ops.ops.quantized.linear_prepack,
torch._ops.ops.quantized.linear_prepack_fp16,
torch._ops.ops.quantized.conv1d_prepack,
torch._ops.ops.quantized.conv2d_prepack,
torch._ops.ops.quantized.conv3d_prepack,
}
def run_weight_observers(observed: GraphModule) -> None:
r''' Extract the subgraph that produces the weight for a dynamic quant
or weight-only quant node and run the subgraph to observe the weight.
Note that the observers of dynamic quant or weight only quant ops are
run during the convert step.
'''
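# For example (illustrative names): for a functional call F.linear(x, w_obs)
# where w_obs is produced by a weight observer/fake-quant subgraph,
# WEIGHT_INDEX_DICT maps torch.nn.functional.linear -> [1], so args[1] is
# treated as the weight and its producer subgraph is run here to record
# weight statistics.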
for node in observed.graph.nodes:
if node.op == 'call_function' and node.target in WEIGHT_INDEX_DICT:
for i, node_arg in enumerate(node.args):
if i in WEIGHT_INDEX_DICT[node.target]:
# node_arg is weight
weight_observer_nodes = collect_producer_nodes(node_arg)
if weight_observer_nodes is not None:
weight_observer_module = \
graph_module_from_producer_nodes(
observed, weight_observer_nodes)
# run the weight observer
weight_observer_module()
def fold_weight(
quantized: QuantizedGraphModule,
node_name_to_scope: Dict[str, Tuple[str, type]]) -> QuantizedGraphModule:
"""
Trace back from the weight node until we hit getattr, reconstruct the
graph module with the traced nodes, and run the graph module to pack the
weight. Then replace the original chain of ops with the packed weight.
"""
packed_weights = dict()
# map from folded node name to the prepacked weight name
folded_nodes = dict()
# get packed weights
for node in quantized.graph.nodes:
if node.op == 'call_function' and node.target in WEIGHT_PREPACK_OPS:
nodes_to_fold = collect_producer_nodes(node)
if nodes_to_fold is not None:
for node_to_fold in nodes_to_fold:
folded_nodes[node_to_fold.name] = node
prepacking_module = graph_module_from_producer_nodes(
quantized, nodes_to_fold)
packed_weight = prepacking_module()
packed_weights[node.name] = packed_weight
# remove folded nodes and replace the prepacking node with getattr
folded_graph = Graph()
env: Dict[Any, Any] = {}
def load_arg(a):
return map_arg(a, lambda node: env[node.name])
quantized_root = quantized
quantized_graph = quantized.graph
for node in quantized_graph.nodes:
prepack_node = folded_nodes.get(node.name, None)
if prepack_node is node:
packed_weight = packed_weights[node.name]
# add a prepacked attribute to root
op_node = list(prepack_node.users)[0]
module_path, _ = node_name_to_scope[op_node.name]
get_new_packed_weight_name = \
get_new_attr_name_with_prefix(module_path + '_packed_weight_')
packed_weight_name = get_new_packed_weight_name(quantized_root)
setattr(quantized_root, packed_weight_name, packed_weight)
# replace prepack node with a getattr node
env[node.name] = folded_graph.create_node(
'get_attr', packed_weight_name, (), {})
elif prepack_node is not None:
# remove the folded node
continue
else:
# copy other nodes
env[node.name] = folded_graph.node_copy(node, load_arg)
quantized = QuantizedGraphModule(quantized_root, folded_graph, quantized_root.preserved_attr_names)
return quantized
def restore_state(
observed: GraphModule
) -> Tuple[Dict[Pattern, QuantizeHandler], Dict[str, Tuple[str, type]], Dict[str, Any]]:
assert is_observed_module(observed), \
'incoming model must be produced by prepare_fx'
prepare_custom_config_dict: Dict[str, Any] = \
observed._prepare_custom_config_dict # type: ignore[assignment]
node_name_to_scope: Dict[str, Tuple[str, type]] = observed._node_name_to_scope # type: ignore[assignment]
patterns: Dict[Pattern, QuantizeHandler] = observed._patterns # type: ignore[assignment]
return patterns, node_name_to_scope, prepare_custom_config_dict
def _convert(model: GraphModule, is_reference: bool = False,
convert_custom_config_dict: Dict[str, Any] = None,
is_standalone_module: bool = False,
_remove_qconfig_flag: bool = True) -> QuantizedGraphModule:
""" standalone_module means it a submodule that is not inlined in
parent module, and will be quantized separately as one unit.
Returns a quantized standalone module, whether input/output is quantized is
specified by prepare_custom_config_dict, with
input_quantized_idxs, output_quantized_idxs, please
see docs for prepare_fx for details
"""
if convert_custom_config_dict is None:
convert_custom_config_dict = {}
patterns, node_name_to_scope, prepare_custom_config_dict = restore_state(model)
qconfig_map: Dict[str, QConfigAny] = model._qconfig_map # type: ignore[assignment]
# always run weight observers in the top level forward method
# for dynamic quant ops or weight only quant ops
run_weight_observers(model)
# move to cpu since we only have quantized cpu kernels
model.eval().cpu()
# mapping from fully qualified module name to module instance
# for example,
# {
# '': Model(...),
# 'linear': Linear(...),
# 'linear.weight_fake_quant': PerChannelMinMaxObserver(...),
# }
# We use remove_duplicate=False here because torch.cat uses
# the same activation_post_process module instance but different names
modules = dict(model.named_modules(remove_duplicate=False))
custom_module_classes = get_custom_module_class_keys(
convert_custom_config_dict,
"observed_to_quantized_custom_module_class")
matches = find_matches(
model.graph, modules, patterns,
qconfig_map,
custom_module_classes=custom_module_classes)
quantized_graph = Graph()
env: Dict[str, Tuple[Node, Optional[torch.dtype]]] = {}
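# env maps an original node name to (the corresponding node in quantized_graph,
# the dtype of its output); torch.float marks fp32 values and None marks
# non-Tensor outputs (e.g. results of shape/size queries).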
graph_inputs: List[str] = []
for node in model.graph.nodes:
if node.op == 'placeholder':
graph_inputs.append(node.name)
def load_non_quantized(n: Node) -> Node:
assert n.name in env, \
'trying to load float node but did not find ' + \
'node:' + n.name + \
' in env: ' + \
str(env)
quantized_node, dtype = env[n.name]
if dtype and dtype != torch.float:
env[n.name] = Proxy(quantized_node).dequantize().node, torch.float
return env[n.name][0]
def load_quantized(n: Node) -> Node:
assert n.name in env, \
'trying to load quantized node but did not find node:' + \
n.name + ' in environment:' + str(env)
quantized_node, dtype = env[n.name]
assert dtype in [torch.quint8, torch.qint8, torch.float16], \
f'Expecting node {quantized_node} to be quantized but got dtype: {dtype}'
return quantized_node
def load_x(n: Node) -> Node:
assert n.name in env, \
'node ' + n.name + ' does not exist in environment'
return env[n.name][0]
def load_arg(quantized: Optional[Union[List[int], bool, Tuple[int, ...]]]
) -> Callable[[Node], Argument]:
"""
Input: quantized, which can be None, list, boolean or tuple
- if quantized is None, then we'll load the node as long as it
exists
- if quantized is a boolean, then all args will be
quantized/not quantized
- if quantized is an empty list or tuple, then it is the same as load_arg(quantized=False)
- if quantized is a list or tuple, then arg should be a list and
the args with corresponding indexes will be quantized
Output: fn which takes arg_or_args, and loads them from the
corresponding environment depending on the value of quantized.
"""
assert quantized is None or \
isinstance(quantized, (tuple, list, bool)), type(quantized)
if isinstance(quantized, (tuple, list)) and len(quantized) == 0:
# empty tuple or list means nothing is quantized
quantized = False
def load_arg_impl(arg_or_args):
# we'll update the format of `quantized`
# to better match arg_or_args
updated_quantized: Optional[Union[List[int], bool, Tuple[int, ...]]] = quantized
if isinstance(quantized, (tuple, list)) and \
len(quantized) == 1 and isinstance(arg_or_args, Node):
# when the argument is a single Node instead of a tuple, we just need to
# check whether 0 is in the quantized list
updated_quantized = 0 in quantized
if updated_quantized is None:
return map_arg(arg_or_args, load_x)
if isinstance(updated_quantized, bool):
return map_arg(
arg_or_args,
load_quantized if updated_quantized else load_non_quantized)
elif isinstance(updated_quantized, (tuple, list)):
assert isinstance(arg_or_args, (tuple, list)), arg_or_args
loaded_args = []
# for now, we only support quantizing positional arguments
for i, a in enumerate(arg_or_args):
if i in updated_quantized:
loaded_args.append(map_arg(a, load_quantized))
else:
loaded_args.append(map_arg(a, load_non_quantized))
return type(arg_or_args)(loaded_args)
return load_arg_impl
def node_arg_is_quantized(node_arg: Any) -> bool:
if isinstance(node_arg, Node):
assert node_arg.name in env, \
'Expecting node_arg to be in the environment'
if node_arg.name in env:
_, dtype = env[node_arg.name]
return dtype != torch.float
else:
return False
elif isinstance(node_arg, list):
quantized = map(node_arg_is_quantized, node_arg)
if all(quantized):
return True
elif not any(quantized):
return False
else:
raise Exception(
"partially quantized inputs in list not handled yet")
else:
return False
def is_output_quantized(node: Node, obj: QuantizeHandler, qconfig: QConfigAny, modules: Dict[str, torch.nn.Module]) -> bool:
""" Check if output node is quantized or not """
assert modules is not None
# by default the output for a quantizable node is expected to be quantized
quantized = True
# Need to get correct quantized/non-quantized state for the output
# of FixedQParamsQuantizeHandler
# TODO: we may want to try to remove the special case here
# as well
if obj.should_mark_output_quantized_from_input_quantized_status(qconfig):
assert node.op in [
'call_module',
'call_function',
'call_method'], \
'FixedQParamsQuantizeHandler of type ' + node.op + ' is not handled'
# TODO: need to extend this to consider all relevant args instead of just arg[0]
quantized = node_arg_is_quantized(node.args[0])
# the output is unquantized if the node is not a CopyNode
# or the activation is not statically quantized
if not activation_is_statically_quantized(qconfig) or \
not obj.input_output_observed():
quantized = False
if node_return_type_is_int(node):
quantized = False
return quantized
def insert_quantize_node(node: Node, modules: Dict[str, torch.nn.Module]) -> None:
""" Given a activation_post_process module call node, insert a
quantize node"""
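# Illustrative example: a call_module node targeting an observer such as
# 'activation_post_process_0' is replaced in the quantized graph by a quantize
# op built from the observer's qparams via quantize_node, unless the producer
# already has the same dtype in env, in which case the observer call is dropped.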
assert modules is not None
assert isinstance(node.target, str)
observer_module = modules[node.target]
prev_node = node.args[0]
if observer_module.dtype == torch.float32:
# copy the observer for fp32 dtype
env[node.name] = quantized_graph.node_copy(
node, load_non_quantized), torch.float
elif isinstance(prev_node, Node) and prev_node.name in env:
# if previous node is already quantized, we'll just remove the
# activation_post_process
_, prev_dtype = env[prev_node.name]
current_dtype = observer_module.dtype
if prev_dtype == current_dtype:
env[node.name] = env[prev_node.name]
else:
root_module = modules[""]
assert isinstance(prev_node, Node)
observer_dtype: torch.dtype = observer_module.dtype # type: ignore[assignment]
env[node.name] = (
quantize_node(
load_non_quantized(prev_node),
observer_module, node, modules, quantized_graph,
node_name_to_scope, is_input=True),
observer_dtype)
else:
# replace activation post process with quantization ops
root_module = modules[""]
assert isinstance(node.args[0], Node)
dtype: torch.dtype = observer_module.dtype # type: ignore[assignment]
env[node.name] = (
quantize_node(
load_non_quantized(node.args[0]),
observer_module, node, modules,
quantized_graph,
node_name_to_scope, is_input=True),
dtype)
# additional state to override inputs to be quantized, if specified
# by the user
placeholder_node_seen_cnt = 0
output_node_seen_cnt = 0
input_quantized_idxs: List[int] = prepare_custom_config_dict.get(
"input_quantized_idxs", [])
output_quantized_idxs: List[int] = prepare_custom_config_dict.get(
"output_quantized_idxs", [])
for node in model.graph.nodes:
if node.op == "output":
cur_output_node_idx = output_node_seen_cnt
output_node_seen_cnt += 1
if cur_output_node_idx in output_quantized_idxs:
# Results are kept quantized if the user specified the
# output_quantized_idxs override.
graph_output = map_arg(node.args[0], load_x)
else:
graph_output = map_arg(node.args[0], load_non_quantized)
quantized_graph.output(graph_output)
continue
root_node, matched, matched_pattern, obj, qconfig = \
matches.get(node.name, (None, None, None, None, None))
if root_node is node:
is_observed_standalone_module_node = (
node.op == 'call_module' and
is_observed_standalone_module(
modules[node.target])
)
if qconfig is None and not is_observed_standalone_module_node:
result = quantized_graph.node_copy(
node, load_non_quantized)
quantized = False
else:
assert obj is not None
# We determine whether the output is quantized before convert for
# standalone modules and after convert for non-standalone modules,
# since _standalone_module_output_quantized_idxs is only available
# on an observed standalone module
if is_observed_standalone_module_node:
out_quant_idxs = modules[node.target]._standalone_module_output_quantized_idxs.tolist() # type: ignore[operator] # noqa: B950
assert len(out_quant_idxs) <= 1, "Currently standalone module only supports one output"
quantized = 0 in out_quant_idxs
qconfig = qconfig_map[node.name]
result = obj.convert(
node, qconfig, modules, quantized_graph, node_name_to_scope, load_arg, is_reference=is_reference,
convert_custom_config_dict=convert_custom_config_dict)
if not is_observed_standalone_module_node:
quantized = is_output_quantized(node, obj, qconfig, modules)
if quantized:
env[node.name] = result, activation_dtype(qconfig)
else:
env[node.name] = result, torch.float
continue
elif root_node is not None:
if qconfig is None:
# This branch is hit if all of these conditions are met:
# 1. we are in a fusion pattern of multiple nodes (i.e. add-relu)
# 2. the current node is not the "root_node" of the pattern
# 3. quantization for this pattern is disabled
#
# In this case, we need to make sure to populate the env with
# intermediate nodes manually, because the QuantizeHandler.convert
# function will not be called.
result = quantized_graph.node_copy(
node, load_non_quantized)
env[node.name] = result, torch.float
continue
# handle activation post process calls
if node.op == 'call_module' and \
is_activation_post_process(modules[node.target]):
insert_quantize_node(node, modules)
elif node.op == 'placeholder':
cur_placeholder_node_idx = placeholder_node_seen_cnt
placeholder_node_seen_cnt += 1
if cur_placeholder_node_idx in input_quantized_idxs:
env[node.name] = \
quantized_graph.node_copy(
node, load_non_quantized), torch.quint8
else:
env[node.name] = \
quantized_graph.node_copy(node, load_non_quantized), torch.float
else:
# copy quantized or non-quantized node
# a get_tensor_info_node (e.g. shape) works for both
# quantized and non-quantized inputs and outputs a non-Tensor
# (we currently use None as the dtype for non-Tensors)
if is_get_tensor_info_node(node):
env[node.name] = \
quantized_graph.node_copy(node, load_x), None
else:
env[node.name] = \
quantized_graph.node_copy(node, load_non_quantized), torch.float
# remove activation post process
act_post_process_removed_graph = Graph()
remove_env: Dict[str, Node] = {}
def load_arg_remove(a: Argument) -> Argument:
return map_arg(a, lambda node: remove_env[node.name])
for node in quantized_graph.nodes:
if node.op == 'output':
act_post_process_removed_graph.output(
map_arg(node.args[0], load_arg_remove))
continue
if node.op == 'call_module' and \
is_activation_post_process(modules[node.target]):
# remove activation post process node
remove_env[node.name] = remove_env[node.args[0].name]
else:
remove_env[node.name] = act_post_process_removed_graph.node_copy(
node, load_arg_remove)
# removes qconfig and activation_post_process modules
if _remove_qconfig_flag:
_remove_qconfig(model)
preserved_attributes = set(convert_custom_config_dict.get("preserved_attributes", []))
model = QuantizedGraphModule(model, act_post_process_removed_graph, preserved_attributes)
if not is_reference:
model = fold_weight(model, node_name_to_scope)
return model
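For orientation, a minimal sketch of the public workflow these internals serve, assuming the torch.quantization.quantize_fx API as of this commit; float_model and example_batch are placeholders:

import torch
from torch.quantization import get_default_qconfig
from torch.quantization.quantize_fx import prepare_fx, convert_fx

float_model.eval()  # placeholder: any symbolically traceable fp32 model
qconfig_dict = {"": get_default_qconfig("fbgemm")}
# prepare_fx (prepare.py) inserts observers according to qconfig_dict
prepared = prepare_fx(float_model, qconfig_dict)
prepared(example_batch)  # placeholder calibration pass to populate observers
# convert_fx (convert.py, ultimately _convert above) replaces observers with
# quantize/dequantize ops and swaps in quantized modules and functions
quantized_model = convert_fx(prepared)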

View File

@ -13,10 +13,11 @@ from ..utils import (
)
from .pattern_utils import (
is_match,
get_default_fusion_patterns,
)
from .match_utils import is_match
from .graph_module import (
FusedGraphModule
)

View File

@ -0,0 +1,222 @@
import sys
import torch
from torch.fx.graph import (
Graph,
Node,
)
from .quantization_types import Pattern
from .quantization_patterns import (
QuantizeHandler,
CustomModuleQuantizeHandler,
StandaloneModuleQuantizeHandler,
BinaryOpQuantizeHandler,
binary_op_supported_dtypes,
binary_reference_op_supported_dtypes,
)
from .qconfig_utils import (
QConfigAny,
)
from .graph_module import (
is_observed_standalone_module,
)
from ..utils import get_qconfig_dtypes
from typing import Any, Dict, List, Callable, Optional, Tuple, Set
MatchResult = Tuple[Node, List[Node], Optional[Pattern], QuantizeHandler,
QConfigAny]
class MatchAllNode:
""" A node pattern that matches all nodes
"""
pass
# Note: The order of patterns is important! The match function will take whatever is matched first, so we'll
# need to put the fusion patterns before single patterns. For example, add_relu should be registered before relu.
# Decorators are applied in the reverse order we see them. Also, when we match the nodes in the graph with these patterns,
# we'll start from the last node of the graph and traverse back.
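# For example (illustrative), a fused pattern such as (torch.nn.ReLU, operator.add)
# describes add followed by ReLU: matching starts at the ReLU node (the last node
# of the pattern) and then recurses into its args to match the add.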
def is_match(modules, node, pattern, max_uses=sys.maxsize):
""" Matches a node in fx against a pattern
"""
if isinstance(pattern, tuple):
self_match, *arg_matches = pattern
if self_match is getattr:
assert len(pattern) == 2, 'Expecting getattr pattern to have two elements'
arg_matches = []
else:
self_match = pattern
arg_matches = []
if isinstance(self_match, type) and issubclass(self_match, MatchAllNode):
return True
if len(node.users) > max_uses:
return False
if isinstance(self_match, type) and issubclass(self_match, torch.nn.Module):
if node.op != 'call_module':
return False
if not type(modules[node.target]) == self_match:
return False
elif callable(self_match):
if node.op != 'call_function' or node.target is not self_match:
return False
elif node.target is getattr:
if node.args[1] != pattern[1]:
return False
elif isinstance(self_match, str):
if node.op != 'call_method' or node.target != self_match:
return False
elif node.target != self_match:
return False
if not arg_matches:
return True
if len(arg_matches) != len(node.args):
return False
return all(is_match(modules, node, arg_match, max_uses=1) for node, arg_match in zip(node.args, arg_matches))
def find_matches(
graph: Graph,
modules: Dict[str, torch.nn.Module],
patterns: Dict[Pattern, QuantizeHandler],
qconfig_map: Dict[str, QConfigAny],
standalone_module_names: List[str] = None,
standalone_module_classes: List[Callable] = None,
custom_module_classes: List[Any] = None) -> Dict[str, MatchResult]:
"""
Matches the nodes in the input graph to quantization patterns, and
outputs the information needed to quantize them in future steps.
Inputs:
- graph: an fx.Graph object
- modules: a mapping of fully qualified module name to instance,
for example, {'foo': ModuleFoo, ...}
- patterns: a mapping from a tuple of nodes in reverse order to
an uninitialized QuantizeHandler subclass.
Outputs a map of
node_name ->
(node, matched_values, matched_pattern, QuantizeHandler instance,
qconfig)
For example, {
'relu_1': (relu_1, [relu_1], torch.nn.functional.relu,
<CopyNodeQuantizeHandler instance>, QConfig(...)),
...
}
"""
if custom_module_classes is None:
custom_module_classes = []
if standalone_module_classes is None:
standalone_module_classes = []
if standalone_module_names is None:
standalone_module_names = []
match_map: Dict[str, MatchResult] = {}
all_matched : Set[str] = set()
def record_match(pattern, node, matched):
if isinstance(pattern, tuple):
s, *args = pattern
record_match(s, node, matched)
if pattern[0] is not getattr:
for subpattern, arg in zip(args, node.args):
record_match(subpattern, arg, matched)
else:
matched.append(node)
cache_for_no_tensor_check: Dict[Node, bool] = dict()
for node in reversed(graph.nodes):
if node.name not in match_map and node.name not in all_matched:
for pattern, value in patterns.items():
if is_match(modules, node, pattern):
skip_this_match = False
if value is BinaryOpQuantizeHandler:
# to properly check for dtype support, we need to
# navigate to the base node of an add-relu or mul-relu
# pattern
base_node = node
if (
(node.op == 'call_function' and
node.target is torch.nn.functional.relu) or
(node.op == 'call_module' and
isinstance(modules[node.target], torch.nn.ReLU))
):
base_node = node.args[0]
this_node_qconfig = \
qconfig_map[base_node.name]
if this_node_qconfig:
dtypes = get_qconfig_dtypes(this_node_qconfig)
# TODO(future PR): update the pattern to quantize
# handler logic to take this into account.
# This needs to handle 3 cases
# 1) op and dtype is in either [is_ref or non-ref] list -> don't skip
# 2) op is not in either list (i.e. relu) -> don't skip
# 3) op is in non-ref list, but not for dtype, and op+dtype not in is_ref list -> skip
# note: the value of is_reference is unknown at prepare, so we have to cover both cases
# handle is_reference = False
skip_match_not_is_reference = (
(base_node.target in binary_op_supported_dtypes) and
(dtypes not in binary_op_supported_dtypes[base_node.target])
)
# handle is_reference = True
supported_is_reference = (
(base_node.target in binary_reference_op_supported_dtypes) and
(dtypes in binary_reference_op_supported_dtypes[base_node.target])
)
# only skip if not reference says skip and is_reference doesn't support
skip_this_match = skip_match_not_is_reference and not supported_is_reference
if not skip_this_match:
matched: List[Any] = []
record_match(pattern, node, matched)
for n in matched:
match_map[n.name] = (
node, matched, pattern, value(node, modules), # type: ignore[operator]
qconfig_map[n.name])
all_matched.add(n.name)
# break after finding the first match
break
# add custom module instances to the match result
assert modules is not None
for node in graph.nodes:
if node.op == 'call_module' and \
type(modules[node.target]) in custom_module_classes:
custom_module_qconfig = qconfig_map[node.name]
match_map[node.name] = (
node, [node], None, CustomModuleQuantizeHandler(node, modules),
custom_module_qconfig)
def is_standalone_module(node_target: str, modules: Dict[str, torch.nn.Module]):
assert modules is not None
return (
node_target in standalone_module_names or # type: ignore[operator]
type(modules[node_target]) in standalone_module_classes # type: ignore[operator]
)
# add standalone modules to the match
for node in graph.nodes:
if node.op == 'call_module' and \
(is_standalone_module(node.target, modules) or
is_observed_standalone_module(modules[node.target])):
# add node to matched nodes
custom_module_qconfig = qconfig_map[node.name]
match_map[node.name] = (
node, [node], None,
StandaloneModuleQuantizeHandler(node, modules),
custom_module_qconfig)
return match_map

View File

@ -1,13 +1,20 @@
import torch
import sys
from collections import OrderedDict
from typing import Dict, Any
from typing import Dict, Any, Tuple, List, Optional
from torch.fx.graph import (
Node,
)
from .quantization_types import Pattern
from .qconfig_utils import QConfigAny
# from .quantization_patterns import BinaryOpQuantizeHandler
# TODO(future PR): fix the typing on QuantizeHandler (currently a circular dependency)
QuantizeHandler = Any
MatchResult = Tuple[Node, List[Node], Optional[Pattern], QuantizeHandler,
QConfigAny]
# pattern for conv bn fusion
DEFAULT_FUSION_PATTERNS = OrderedDict()
def register_fusion_pattern(pattern):
@ -44,61 +51,9 @@ def get_default_output_activation_post_process_map() -> Dict[Pattern, torch.quan
return DEFAULT_OUTPUT_ACTIVATION_POST_PROCESS_MAP
class MatchAllNode:
""" A node pattern that matches all nodes
"""
pass
# Example use of register pattern function:
# @register_fusion_pattern(torch.nn.ReLU, (torch.nn.BatchNorm2d, torch.nn.Conv2d)))
# class ConvBNReLUFusion():
# def __init__(...):
# ...
#
# Note: The order of patterns is important! The match function will take whatever is matched first, so we'll
# need to put the fusion patterns before single patterns. For example, add_relu should be registered before relu.
# Decorators are applied in the reverse order we see them. Also, when we match the nodes in the graph with these patterns,
# we'll start from the last node of the graph and traverse back.
def is_match(modules, node, pattern, max_uses=sys.maxsize):
""" Matches a node in fx against a pattern
"""
if isinstance(pattern, tuple):
self_match, *arg_matches = pattern
if self_match is getattr:
assert len(pattern) == 2, 'Expecting getattr pattern to have two elements'
arg_matches = []
else:
self_match = pattern
arg_matches = []
if isinstance(self_match, type) and issubclass(self_match, MatchAllNode):
return True
if len(node.users) > max_uses:
return False
if isinstance(self_match, type) and issubclass(self_match, torch.nn.Module):
if node.op != 'call_module':
return False
if not type(modules[node.target]) == self_match:
return False
elif callable(self_match):
if node.op != 'call_function' or node.target is not self_match:
return False
elif node.target is getattr:
if node.args[1] != pattern[1]:
return False
elif isinstance(self_match, str):
if node.op != 'call_method' or node.target != self_match:
return False
elif node.target != self_match:
return False
if not arg_matches:
return True
if len(arg_matches) != len(node.args):
return False
return all(is_match(modules, node, arg_match, max_uses=1) for node, arg_match in zip(node.args, arg_matches))

File diff suppressed because it is too large

View File

@ -1,4 +1,4 @@
from typing import Union, Callable, Tuple, Any
from typing import Any, Callable, Tuple, Union
Pattern = Union[Callable, Tuple[Callable, Callable], Tuple[Callable, Callable, Callable]]

File diff suppressed because it is too large

View File

@ -13,6 +13,14 @@ from torch.fx.graph import (
from typing import Callable, Optional, List, Dict, Any, Set, Tuple, Union
import operator
# A dictionary for querying the weight index for a given op
WEIGHT_INDEX_DICT = {
torch.nn.functional.conv1d : [1],
torch.nn.functional.conv2d : [1],
torch.nn.functional.conv3d : [1],
torch.nn.functional.linear : [1],
}
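# For example, WEIGHT_INDEX_DICT[torch.nn.functional.linear] == [1], i.e. for a
# call F.linear(input, weight, bias) the argument at index 1 is the weight.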
# turn foo.bar -> ['foo', 'bar']
def _parent_name(target):
r = target.rsplit('.', 1)

View File

@ -1,7 +1,7 @@
import torch
from torch.fx import GraphModule, map_arg
from torch.fx.graph import Graph, Node
from torch.quantization.fx.quantize import is_activation_post_process
from torch.quantization.quantize import is_activation_post_process
from torch.quantization.fx.utils import get_new_attr_name_with_prefix
from .utils import (

View File

@ -6,7 +6,7 @@ import torch.nn as nn
toq = torch.ops.quantized
from torch.fx import GraphModule
from torch.fx.graph import Node
from torch.quantization.fx.quantize import is_activation_post_process
from torch.quantization.quantize import is_activation_post_process
from .ns_types import NSNodeTargetType