[ao][fx] fix public vs. private API in match_utils.py (#88396)

Summary: renamed `is_match`, `find_matches`, and `MatchResult` to the private
`_is_match`, `_find_matches`, and `_MatchResult`; also added `__all__` to `lower_to_qnnpack.py`.

Test Plan: python test/test_public_bindings.py

Differential Revision: [D41015540](https://our.internmc.facebook.com/intern/diff/D41015540)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88396
Approved by: https://github.com/jcaip
Author: HDCharles
Date: 2022-12-12 19:46:04 -08:00
Committer: PyTorch MergeBot
Parent: a856557b3a
Commit: 79156c11c3
10 changed files with 32 additions and 32 deletions
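
For context on the convention this change enforces (an editorial sketch, not part of the diff): in Python, a leading underscore marks a name as private by convention, and an explicit __all__ pins down exactly what a module exports. A minimal illustration with hypothetical names:

    # hypothetical module, e.g. example_utils.py
    __all__ = ["public_fn"]        # only this name is the module's public API

    def public_fn():
        return "exported"

    def _private_helper():         # leading underscore: private by convention,
        return "internal"          # skipped by `from example_utils import *`
                                   # even if __all__ were absent

Public-bindings checks such as the one in the test plan can then compare a module's declared exports against the names it actually defines.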

View File

@@ -135,10 +135,10 @@ class TestAOMigrationQuantizationFx(AOMigrationTestCase):
     def test_function_import_fx_match_utils(self):
         function_list = [
-            'MatchResult',
+            '_MatchResult',
             'MatchAllNode',
-            'is_match',
-            'find_matches'
+            '_is_match',
+            '_find_matches'
         ]
         self._test_function_import('fx.match_utils', function_list)
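
The helper above is part of the AO migration tests; a hedged sketch of the kind of check such a helper can perform (check_migrated_names is a hypothetical stand-in, not the actual AOMigrationTestCase implementation):

    import importlib

    def check_migrated_names(old_mod, new_mod, names):
        """Assert the legacy namespace re-exports the same objects as the new one."""
        old = importlib.import_module(old_mod)
        new = importlib.import_module(new_mod)
        for name in names:
            assert getattr(old, name) is getattr(new, name), name

    check_migrated_names(
        "torch.quantization.fx.match_utils",     # legacy location (re-exports)
        "torch.ao.quantization.fx.match_utils",  # new home of the code
        ["_MatchResult", "MatchAllNode", "_is_match", "_find_matches"],
    )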

View File

@@ -27,7 +27,7 @@ from torch.ao.quantization.quantize_fx import (
 from torch.ao.quantization.fx.quantize_handler import DefaultNodeQuantizeHandler
 from torch.ao.quantization.fx.match_utils import (
-    is_match,
+    _is_match,
     MatchAllNode,
 )
@@ -711,7 +711,7 @@ class TestQuantizeFx(QuantizationTestCase):
         modules = dict(m.named_modules())
         for n in m.graph.nodes:
             if n.op == 'call_module' and type(modules[n.target]) == nn.ReLU:
-                self.assertTrue(is_match(modules, n, pattern))
+                self.assertTrue(_is_match(modules, n, pattern))

     def test_pattern_match_constant(self):
         class M(torch.nn.Module):
@@ -727,7 +727,7 @@ class TestQuantizeFx(QuantizationTestCase):
         modules = dict(m.named_modules())
         for n in m.graph.nodes:
             if n.op == "call_function" and n.target == operator.getitem:
-                self.assertTrue(is_match(modules, n, pattern))
+                self.assertTrue(_is_match(modules, n, pattern))

     def test_fused_module_qat_swap(self):
         class Tmp(torch.nn.Module):
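
A minimal sketch of how these tests drive the matcher, modeled on the assertions above (the module M is illustrative; the import reflects the renamed private helper):

    import torch
    import torch.nn as nn
    from torch.ao.quantization.fx.match_utils import _is_match

    class M(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(1, 1, 1)
            self.relu = nn.ReLU()

        def forward(self, x):
            return self.relu(self.conv(x))

    m = torch.fx.symbolic_trace(M())
    modules = dict(m.named_modules())
    pattern = (nn.ReLU, nn.Conv2d)  # a ReLU fed directly by a Conv2d
    for n in m.graph.nodes:
        if n.op == "call_module" and type(modules[n.target]) == nn.ReLU:
            assert _is_match(modules, n, pattern)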

View File

@@ -121,7 +121,7 @@ from .fx.ns_types import (
 )
 from torch.ao.quantization.backend_config.utils import get_fusion_pattern_to_root_node_getter
 from torch.ao.quantization.backend_config import BackendConfig
-from torch.ao.quantization.fx.match_utils import find_matches
+from torch.ao.quantization.fx.match_utils import _find_matches
 from torch.ao.quantization.fx.qconfig_mapping_utils import generate_node_name_to_qconfig
 from torch.ao.quantization.fx.quantize_handler import _get_pattern_to_quantize_handlers
 from torch.ao.quantization.qconfig import QConfigAny
@@ -815,7 +815,7 @@ def prepare_n_shadows_model(
     standalone_module_names: List[str] = []
     standalone_module_classes: List[Type] = []
     custom_module_classes: List[Type] = []
-    matches = find_matches(
+    matches = _find_matches(
         mt.graph, modules, patterns, root_node_getter_mapping,
         standalone_module_names, standalone_module_classes, custom_module_classes)
     subgraphs_dedup: Dict[str, List[Node]] = \

View File

@@ -20,7 +20,7 @@ from torch.ao.quantization import QConfigMapping
 from torch.ao.quantization.fx.custom_config import PrepareCustomConfig
 from torch.ao.quantization.qconfig import QConfigAny
 from torch.ao.quantization.utils import getattr_from_fqn
-from torch.ao.quantization.fx.match_utils import MatchResult
+from torch.ao.quantization.fx.match_utils import _MatchResult
 import collections
 import copy
@@ -100,7 +100,7 @@ class OutputProp:
         return None

 def _get_dedup_subgraphs(
-    matches: Dict[str, MatchResult]
+    matches: Dict[str, _MatchResult]
 ) -> Dict[str, List[Node]]:
     # the original matches variable is unique by node, make it unique by subgraph
     # instead
@@ -110,7 +110,7 @@ def _get_dedup_subgraphs(
     # Dict items are not reversible until Python 3.8, so we hack it
     # to be compatible with previous Python versions
     # TODO(future PR): try reversed(list(matches.items()))
-    matches_items_reversed: List[Tuple[str, MatchResult]] = []
+    matches_items_reversed: List[Tuple[str, _MatchResult]] = []
     for name, cur_match in matches.items():
         matches_items_reversed.insert(0, (name, cur_match))
@@ -162,7 +162,7 @@ def _get_dedup_subgraphs(
         assert len(cur_match[1]) == 2
         # either (a, b), or ((a, b), c) or (c, (a, b))
         # cannot make any assumptions on order, not clear what the
-        # find_matches function is doing to populate this
+        # _find_matches function is doing to populate this
         # TODO(future PR): make this code less confusing, see discussion
         # in https://github.com/pytorch/pytorch/pull/80521/files#r975918836
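
The comment above is the reason for the prepend loop: reversed() only works directly on dict views from Python 3.8 onward, so the code builds the reversed list manually. A standalone sketch of the equivalence:

    d = {"a": 1, "b": 2, "c": 3}

    # pre-3.8-compatible: prepend each item so the last-seen item ends up first
    items_reversed = []
    for k, v in d.items():
        items_reversed.insert(0, (k, v))

    # the TODO's suggestion; works wherever dicts preserve order (Python 3.7+)
    assert items_reversed == list(reversed(list(d.items())))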

View File

@@ -72,7 +72,7 @@ def get_reversed_fusions() -> List[Tuple[NSFusionType, int]]:
     default_base_op_idx = 0
     for quant_pattern, _quant_handler in all_quant_patterns.items():
         # TODO: this is a temporary hack to flatten the patterns from quantization so
-        # that it works with the ns matcher function, maybe we should use `is_match`
+        # that it works with the ns matcher function, maybe we should use `_is_match`
         # in torch.ao.quantization.fx.match_utils to match the patterns
         if isinstance(quant_pattern, tuple) and len(quant_pattern) == 2 and \
                 isinstance(quant_pattern[1], tuple) and len(quant_pattern[1]) == 2:

View File

@@ -8,7 +8,7 @@ from .graph_module import (
     FusedGraphModule
 )
 from .match_utils import (
-    is_match,
+    _is_match,
     MatchAllNode,
 )
 from .pattern_utils import (
@@ -157,7 +157,7 @@ def _find_matches(
         if node.name not in match_map:
             for pattern, value in patterns.items():
                 matched_node_pattern: List[Node] = []
-                if is_match(modules, node, pattern):
+                if _is_match(modules, node, pattern):
                     apply_match(pattern, node, (node, pattern, value(node)), matched_node_pattern, node_to_subpattern)
                     break

View File

@@ -3,6 +3,10 @@ from .graph_module import QuantizedGraphModule
 from ..qconfig import QConfigAny
 from typing import Dict, Tuple

+__all__ = [
+    "lower_to_qnnpack"
+]
+
 def lower_to_qnnpack(
     model: QuantizedGraphModule,
     qconfig_map: Dict[str, QConfigAny],
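
Before this hunk, lower_to_qnnpack.py declared no __all__, so every top-level name, including its imports, counted as implicitly public. A hedged sketch of the kind of audit that flags this (implicit_public_names is a hypothetical helper, not the actual test_public_bindings.py code):

    import importlib

    def implicit_public_names(module_name):
        """Names a module exposes when it does not declare __all__."""
        mod = importlib.import_module(module_name)
        if hasattr(mod, "__all__"):
            return list(mod.__all__)  # explicit, curated export list
        return [n for n in dir(mod) if not n.startswith("_")]

    # after this diff the module declares its single intended export
    print(implicit_public_names("torch.ao.quantization.fx.lower_to_qnnpack"))
    # -> ['lower_to_qnnpack']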

View File

@@ -21,15 +21,11 @@ from torch.nn.utils.parametrize import type_before_parametrizations
 from typing import Any, Dict, List, Callable, Optional, Tuple, Type, Set, Iterable

-# TODO: revisit this list. Many helper methods shouldn't be public
-__all__ = [
-    "is_match",
-    "find_matches",
-]
+__all__: List[str] = []

 # TODO(future PR): the 1st argument is typed as `List[Node]`, but a better type
 # would be a recursive `List[Union[Node, Tuple[Union[Node, ...]]]]`
-MatchResult = Tuple[Node, List[Node], Optional[Pattern], QuantizeHandler]
+_MatchResult = Tuple[Node, List[Node], Optional[Pattern], QuantizeHandler]

 _MatchResultWithQConfig = Tuple[Node, List[Node], Optional[Pattern], QuantizeHandler,
                                 QConfigAny]
@@ -38,7 +34,7 @@ _MatchResultWithQConfig = Tuple[Node, List[Node], Optional[Pattern], QuantizeHan
 # need to put the fusion patterns before single patterns. For example, add_relu should be registered come before relu.
 # decorators are applied in the reverse order we see. Also when we match the nodes in the graph with these patterns,
 # we'll start from the last node of the graph and traverse back.
-def is_match(modules, node, pattern, max_uses=sys.maxsize):
+def _is_match(modules, node, pattern, max_uses=sys.maxsize):
     """ Matches a node in fx against a pattern
     """
     if isinstance(pattern, tuple):
@@ -82,16 +78,16 @@ def is_match(modules, node, pattern, max_uses=sys.maxsize):
     if len(arg_matches) != len(node.args):
         return False
-    return all(is_match(modules, node, arg_match, max_uses=1) for node, arg_match in zip(node.args, arg_matches))
+    return all(_is_match(modules, node, arg_match, max_uses=1) for node, arg_match in zip(node.args, arg_matches))

-def find_matches(
+def _find_matches(
     graph: Graph,
     modules: Dict[str, torch.nn.Module],
     patterns: Dict[Pattern, QuantizeHandler],
     root_node_getter_mapping: Dict[Pattern, Callable],
     standalone_module_names: List[str] = None,
     standalone_module_classes: List[Type] = None,
-    custom_module_classes: List[Any] = None) -> Dict[str, MatchResult]:
+    custom_module_classes: List[Any] = None) -> Dict[str, _MatchResult]:
     """
     Matches the nodes in the input graph to quantization patterns, and
     outputs the information needed to quantize them in future steps.
@@ -123,7 +119,7 @@ def find_matches(
     if standalone_module_names is None:
         standalone_module_names = []

-    match_map: Dict[str, MatchResult] = {}
+    match_map: Dict[str, _MatchResult] = {}
     all_matched : Set[str] = set()

     def _recursive_record_node_in_match_map(
@@ -188,7 +184,7 @@ def find_matches(
         if node.name not in match_map and node.name not in all_matched:
             for pattern, quantize_handler_cls in patterns.items():
                 root_node_getter = root_node_getter_mapping.get(pattern, None)
-                if is_match(modules, node, pattern) and node.name not in match_map:
+                if _is_match(modules, node, pattern) and node.name not in match_map:
                     matched_node_pattern: List[Node] = []
                     record_match(
                         pattern,
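
To make the _MatchResult alias concrete, here is the shape of one match_map entry as the alias defines it; the values are string stand-ins purely for illustration (real entries hold fx Nodes, a pattern, and a QuantizeHandler instance):

    # _MatchResult = Tuple[Node, List[Node], Optional[Pattern], QuantizeHandler]
    match_entry = (
        "relu_node",                  # Node at the root of the matched pattern
        ["relu_node", "conv_node"],   # List[Node]: nodes the pattern covered
        ("nn.ReLU", "nn.Conv2d"),     # Optional[Pattern]: the pattern that fired
        "quantize_handler_instance",  # QuantizeHandler chosen for this pattern
    )
    match_map = {"relu": match_entry, "conv": match_entry}  # keyed by node.name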

View File

@@ -57,7 +57,7 @@ from .pattern_utils import (
 from .match_utils import (
     _MatchResultWithQConfig,
-    find_matches,
+    _find_matches,
 )
 from ..utils import _parent_name
@@ -1527,7 +1527,7 @@ def prepare(
     standalone_module_classes = list(prepare_custom_config.standalone_module_classes.keys())
     custom_module_classes = get_custom_module_class_keys(prepare_custom_config.float_to_observed_mapping)

-    matches_without_qconfig = find_matches(
+    matches_without_qconfig = _find_matches(
         model.graph, modules, pattern_to_quantize_handler, root_node_getter_mapping,
         standalone_module_names, standalone_module_classes, custom_module_classes)

View File

@@ -7,8 +7,8 @@ appropriate files under `torch/ao/quantization/fx/`, while adding an import stat
 here.
 """
 from torch.ao.quantization.fx.match_utils import (
-    MatchResult,
+    _MatchResult,
     MatchAllNode,
-    is_match,
-    find_matches
+    _is_match,
+    _find_matches
 )