[ao][fx] fixing public v private match_utils.py (#88396)

Summary: made _is_match, _find_matches, _MatchResult private also added
__all__ to lower_to_qnnpack.py

Test Plan: python test/test_public_bindings.py

Reviewers:

Subscribers:

Tasks:

Tags:

Differential Revision: [D41015540](https://our.internmc.facebook.com/intern/diff/D41015540)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88396
Approved by: https://github.com/jcaip
This commit is contained in:
HDCharles
2022-12-12 19:46:04 -08:00
committed by PyTorch MergeBot
parent a856557b3a
commit 79156c11c3
10 changed files with 32 additions and 32 deletions

View File

@ -135,10 +135,10 @@ class TestAOMigrationQuantizationFx(AOMigrationTestCase):
def test_function_import_fx_match_utils(self): def test_function_import_fx_match_utils(self):
function_list = [ function_list = [
'MatchResult', '_MatchResult',
'MatchAllNode', 'MatchAllNode',
'is_match', '_is_match',
'find_matches' '_find_matches'
] ]
self._test_function_import('fx.match_utils', function_list) self._test_function_import('fx.match_utils', function_list)

View File

@ -27,7 +27,7 @@ from torch.ao.quantization.quantize_fx import (
from torch.ao.quantization.fx.quantize_handler import DefaultNodeQuantizeHandler from torch.ao.quantization.fx.quantize_handler import DefaultNodeQuantizeHandler
from torch.ao.quantization.fx.match_utils import ( from torch.ao.quantization.fx.match_utils import (
is_match, _is_match,
MatchAllNode, MatchAllNode,
) )
@ -711,7 +711,7 @@ class TestQuantizeFx(QuantizationTestCase):
modules = dict(m.named_modules()) modules = dict(m.named_modules())
for n in m.graph.nodes: for n in m.graph.nodes:
if n.op == 'call_module' and type(modules[n.target]) == nn.ReLU: if n.op == 'call_module' and type(modules[n.target]) == nn.ReLU:
self.assertTrue(is_match(modules, n, pattern)) self.assertTrue(_is_match(modules, n, pattern))
def test_pattern_match_constant(self): def test_pattern_match_constant(self):
class M(torch.nn.Module): class M(torch.nn.Module):
@ -727,7 +727,7 @@ class TestQuantizeFx(QuantizationTestCase):
modules = dict(m.named_modules()) modules = dict(m.named_modules())
for n in m.graph.nodes: for n in m.graph.nodes:
if n.op == "call_function" and n.target == operator.getitem: if n.op == "call_function" and n.target == operator.getitem:
self.assertTrue(is_match(modules, n, pattern)) self.assertTrue(_is_match(modules, n, pattern))
def test_fused_module_qat_swap(self): def test_fused_module_qat_swap(self):
class Tmp(torch.nn.Module): class Tmp(torch.nn.Module):

View File

@ -121,7 +121,7 @@ from .fx.ns_types import (
) )
from torch.ao.quantization.backend_config.utils import get_fusion_pattern_to_root_node_getter from torch.ao.quantization.backend_config.utils import get_fusion_pattern_to_root_node_getter
from torch.ao.quantization.backend_config import BackendConfig from torch.ao.quantization.backend_config import BackendConfig
from torch.ao.quantization.fx.match_utils import find_matches from torch.ao.quantization.fx.match_utils import _find_matches
from torch.ao.quantization.fx.qconfig_mapping_utils import generate_node_name_to_qconfig from torch.ao.quantization.fx.qconfig_mapping_utils import generate_node_name_to_qconfig
from torch.ao.quantization.fx.quantize_handler import _get_pattern_to_quantize_handlers from torch.ao.quantization.fx.quantize_handler import _get_pattern_to_quantize_handlers
from torch.ao.quantization.qconfig import QConfigAny from torch.ao.quantization.qconfig import QConfigAny
@ -815,7 +815,7 @@ def prepare_n_shadows_model(
standalone_module_names: List[str] = [] standalone_module_names: List[str] = []
standalone_module_classes: List[Type] = [] standalone_module_classes: List[Type] = []
custom_module_classes: List[Type] = [] custom_module_classes: List[Type] = []
matches = find_matches( matches = _find_matches(
mt.graph, modules, patterns, root_node_getter_mapping, mt.graph, modules, patterns, root_node_getter_mapping,
standalone_module_names, standalone_module_classes, custom_module_classes) standalone_module_names, standalone_module_classes, custom_module_classes)
subgraphs_dedup: Dict[str, List[Node]] = \ subgraphs_dedup: Dict[str, List[Node]] = \

View File

@ -20,7 +20,7 @@ from torch.ao.quantization import QConfigMapping
from torch.ao.quantization.fx.custom_config import PrepareCustomConfig from torch.ao.quantization.fx.custom_config import PrepareCustomConfig
from torch.ao.quantization.qconfig import QConfigAny from torch.ao.quantization.qconfig import QConfigAny
from torch.ao.quantization.utils import getattr_from_fqn from torch.ao.quantization.utils import getattr_from_fqn
from torch.ao.quantization.fx.match_utils import MatchResult from torch.ao.quantization.fx.match_utils import _MatchResult
import collections import collections
import copy import copy
@ -100,7 +100,7 @@ class OutputProp:
return None return None
def _get_dedup_subgraphs( def _get_dedup_subgraphs(
matches: Dict[str, MatchResult] matches: Dict[str, _MatchResult]
) -> Dict[str, List[Node]]: ) -> Dict[str, List[Node]]:
# the original matches variable is unique by node, make it unique by subgraph # the original matches variable is unique by node, make it unique by subgraph
# instead # instead
@ -110,7 +110,7 @@ def _get_dedup_subgraphs(
# Dict items are not reversible until Python 3.8, so we hack it # Dict items are not reversible until Python 3.8, so we hack it
# to be compatible with previous Python versions # to be compatible with previous Python versions
# TODO(future PR): try reversed(list(matches.items())) # TODO(future PR): try reversed(list(matches.items()))
matches_items_reversed: List[Tuple[str, MatchResult]] = [] matches_items_reversed: List[Tuple[str, _MatchResult]] = []
for name, cur_match in matches.items(): for name, cur_match in matches.items():
matches_items_reversed.insert(0, (name, cur_match)) matches_items_reversed.insert(0, (name, cur_match))
@ -162,7 +162,7 @@ def _get_dedup_subgraphs(
assert len(cur_match[1]) == 2 assert len(cur_match[1]) == 2
# either (a, b), or ((a, b), c) or (c, (a, b)) # either (a, b), or ((a, b), c) or (c, (a, b))
# cannot make any assumptions on order, not clear what the # cannot make any assumptions on order, not clear what the
# find_matches function is doing to populate this # _find_matches function is doing to populate this
# TODO(future PR): make this code less confusing, see discussion # TODO(future PR): make this code less confusing, see discussion
# in https://github.com/pytorch/pytorch/pull/80521/files#r975918836 # in https://github.com/pytorch/pytorch/pull/80521/files#r975918836

View File

@ -72,7 +72,7 @@ def get_reversed_fusions() -> List[Tuple[NSFusionType, int]]:
default_base_op_idx = 0 default_base_op_idx = 0
for quant_pattern, _quant_handler in all_quant_patterns.items(): for quant_pattern, _quant_handler in all_quant_patterns.items():
# TODO: this is a temporary hack to flatten the patterns from quantization so # TODO: this is a temporary hack to flatten the patterns from quantization so
# that it works with the ns matcher function, maybe we should use `is_match` # that it works with the ns matcher function, maybe we should use `_is_match`
# in torch.ao.quantization.fx.match_utils to match the patterns # in torch.ao.quantization.fx.match_utils to match the patterns
if isinstance(quant_pattern, tuple) and len(quant_pattern) == 2 and \ if isinstance(quant_pattern, tuple) and len(quant_pattern) == 2 and \
isinstance(quant_pattern[1], tuple) and len(quant_pattern[1]) == 2: isinstance(quant_pattern[1], tuple) and len(quant_pattern[1]) == 2:

View File

@ -8,7 +8,7 @@ from .graph_module import (
FusedGraphModule FusedGraphModule
) )
from .match_utils import ( from .match_utils import (
is_match, _is_match,
MatchAllNode, MatchAllNode,
) )
from .pattern_utils import ( from .pattern_utils import (
@ -157,7 +157,7 @@ def _find_matches(
if node.name not in match_map: if node.name not in match_map:
for pattern, value in patterns.items(): for pattern, value in patterns.items():
matched_node_pattern: List[Node] = [] matched_node_pattern: List[Node] = []
if is_match(modules, node, pattern): if _is_match(modules, node, pattern):
apply_match(pattern, node, (node, pattern, value(node)), matched_node_pattern, node_to_subpattern) apply_match(pattern, node, (node, pattern, value(node)), matched_node_pattern, node_to_subpattern)
break break

View File

@ -3,6 +3,10 @@ from .graph_module import QuantizedGraphModule
from ..qconfig import QConfigAny from ..qconfig import QConfigAny
from typing import Dict, Tuple from typing import Dict, Tuple
__all__ = [
"lower_to_qnnpack"
]
def lower_to_qnnpack( def lower_to_qnnpack(
model: QuantizedGraphModule, model: QuantizedGraphModule,
qconfig_map: Dict[str, QConfigAny], qconfig_map: Dict[str, QConfigAny],

View File

@ -21,15 +21,11 @@ from torch.nn.utils.parametrize import type_before_parametrizations
from typing import Any, Dict, List, Callable, Optional, Tuple, Type, Set, Iterable from typing import Any, Dict, List, Callable, Optional, Tuple, Type, Set, Iterable
# TODO: revisit this list. Many helper methods shouldn't be public __all__: List[str] = []
__all__ = [
"is_match",
"find_matches",
]
# TODO(future PR): the 1st argument is typed as `List[Node]`, but a better type # TODO(future PR): the 1st argument is typed as `List[Node]`, but a better type
# would be a recursive `List[Union[Node, Tuple[Union[Node, ...]]]]` # would be a recursive `List[Union[Node, Tuple[Union[Node, ...]]]]`
MatchResult = Tuple[Node, List[Node], Optional[Pattern], QuantizeHandler] _MatchResult = Tuple[Node, List[Node], Optional[Pattern], QuantizeHandler]
_MatchResultWithQConfig = Tuple[Node, List[Node], Optional[Pattern], QuantizeHandler, _MatchResultWithQConfig = Tuple[Node, List[Node], Optional[Pattern], QuantizeHandler,
QConfigAny] QConfigAny]
@ -38,7 +34,7 @@ _MatchResultWithQConfig = Tuple[Node, List[Node], Optional[Pattern], QuantizeHan
# need to put the fusion patterns before single patterns. For example, add_relu should be registered come before relu. # need to put the fusion patterns before single patterns. For example, add_relu should be registered come before relu.
# decorators are applied in the reverse order we see. Also when we match the nodes in the graph with these patterns, # decorators are applied in the reverse order we see. Also when we match the nodes in the graph with these patterns,
# we'll start from the last node of the graph and traverse back. # we'll start from the last node of the graph and traverse back.
def is_match(modules, node, pattern, max_uses=sys.maxsize): def _is_match(modules, node, pattern, max_uses=sys.maxsize):
""" Matches a node in fx against a pattern """ Matches a node in fx against a pattern
""" """
if isinstance(pattern, tuple): if isinstance(pattern, tuple):
@ -82,16 +78,16 @@ def is_match(modules, node, pattern, max_uses=sys.maxsize):
if len(arg_matches) != len(node.args): if len(arg_matches) != len(node.args):
return False return False
return all(is_match(modules, node, arg_match, max_uses=1) for node, arg_match in zip(node.args, arg_matches)) return all(_is_match(modules, node, arg_match, max_uses=1) for node, arg_match in zip(node.args, arg_matches))
def find_matches( def _find_matches(
graph: Graph, graph: Graph,
modules: Dict[str, torch.nn.Module], modules: Dict[str, torch.nn.Module],
patterns: Dict[Pattern, QuantizeHandler], patterns: Dict[Pattern, QuantizeHandler],
root_node_getter_mapping: Dict[Pattern, Callable], root_node_getter_mapping: Dict[Pattern, Callable],
standalone_module_names: List[str] = None, standalone_module_names: List[str] = None,
standalone_module_classes: List[Type] = None, standalone_module_classes: List[Type] = None,
custom_module_classes: List[Any] = None) -> Dict[str, MatchResult]: custom_module_classes: List[Any] = None) -> Dict[str, _MatchResult]:
""" """
Matches the nodes in the input graph to quantization patterns, and Matches the nodes in the input graph to quantization patterns, and
outputs the information needed to quantize them in future steps. outputs the information needed to quantize them in future steps.
@ -123,7 +119,7 @@ def find_matches(
if standalone_module_names is None: if standalone_module_names is None:
standalone_module_names = [] standalone_module_names = []
match_map: Dict[str, MatchResult] = {} match_map: Dict[str, _MatchResult] = {}
all_matched : Set[str] = set() all_matched : Set[str] = set()
def _recursive_record_node_in_match_map( def _recursive_record_node_in_match_map(
@ -188,7 +184,7 @@ def find_matches(
if node.name not in match_map and node.name not in all_matched: if node.name not in match_map and node.name not in all_matched:
for pattern, quantize_handler_cls in patterns.items(): for pattern, quantize_handler_cls in patterns.items():
root_node_getter = root_node_getter_mapping.get(pattern, None) root_node_getter = root_node_getter_mapping.get(pattern, None)
if is_match(modules, node, pattern) and node.name not in match_map: if _is_match(modules, node, pattern) and node.name not in match_map:
matched_node_pattern: List[Node] = [] matched_node_pattern: List[Node] = []
record_match( record_match(
pattern, pattern,

View File

@ -57,7 +57,7 @@ from .pattern_utils import (
from .match_utils import ( from .match_utils import (
_MatchResultWithQConfig, _MatchResultWithQConfig,
find_matches, _find_matches,
) )
from ..utils import _parent_name from ..utils import _parent_name
@ -1527,7 +1527,7 @@ def prepare(
standalone_module_classes = list(prepare_custom_config.standalone_module_classes.keys()) standalone_module_classes = list(prepare_custom_config.standalone_module_classes.keys())
custom_module_classes = get_custom_module_class_keys(prepare_custom_config.float_to_observed_mapping) custom_module_classes = get_custom_module_class_keys(prepare_custom_config.float_to_observed_mapping)
matches_without_qconfig = find_matches( matches_without_qconfig = _find_matches(
model.graph, modules, pattern_to_quantize_handler, root_node_getter_mapping, model.graph, modules, pattern_to_quantize_handler, root_node_getter_mapping,
standalone_module_names, standalone_module_classes, custom_module_classes) standalone_module_names, standalone_module_classes, custom_module_classes)

View File

@ -7,8 +7,8 @@ appropriate files under `torch/ao/quantization/fx/`, while adding an import stat
here. here.
""" """
from torch.ao.quantization.fx.match_utils import ( from torch.ao.quantization.fx.match_utils import (
MatchResult, _MatchResult,
MatchAllNode, MatchAllNode,
is_match, _is_match,
find_matches _find_matches
) )