mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[ao][fx] fixing public v private match_utils.py (#88396)
Summary: made _is_match, _find_matches, _MatchResult private also added __all__ to lower_to_qnnpack.py Test Plan: python test/test_public_bindings.py Reviewers: Subscribers: Tasks: Tags: Differential Revision: [D41015540](https://our.internmc.facebook.com/intern/diff/D41015540) Pull Request resolved: https://github.com/pytorch/pytorch/pull/88396 Approved by: https://github.com/jcaip
This commit is contained in:
committed by
PyTorch MergeBot
parent
a856557b3a
commit
79156c11c3
@ -135,10 +135,10 @@ class TestAOMigrationQuantizationFx(AOMigrationTestCase):
|
||||
|
||||
def test_function_import_fx_match_utils(self):
|
||||
function_list = [
|
||||
'MatchResult',
|
||||
'_MatchResult',
|
||||
'MatchAllNode',
|
||||
'is_match',
|
||||
'find_matches'
|
||||
'_is_match',
|
||||
'_find_matches'
|
||||
]
|
||||
self._test_function_import('fx.match_utils', function_list)
|
||||
|
||||
|
@ -27,7 +27,7 @@ from torch.ao.quantization.quantize_fx import (
|
||||
from torch.ao.quantization.fx.quantize_handler import DefaultNodeQuantizeHandler
|
||||
|
||||
from torch.ao.quantization.fx.match_utils import (
|
||||
is_match,
|
||||
_is_match,
|
||||
MatchAllNode,
|
||||
)
|
||||
|
||||
@ -711,7 +711,7 @@ class TestQuantizeFx(QuantizationTestCase):
|
||||
modules = dict(m.named_modules())
|
||||
for n in m.graph.nodes:
|
||||
if n.op == 'call_module' and type(modules[n.target]) == nn.ReLU:
|
||||
self.assertTrue(is_match(modules, n, pattern))
|
||||
self.assertTrue(_is_match(modules, n, pattern))
|
||||
|
||||
def test_pattern_match_constant(self):
|
||||
class M(torch.nn.Module):
|
||||
@ -727,7 +727,7 @@ class TestQuantizeFx(QuantizationTestCase):
|
||||
modules = dict(m.named_modules())
|
||||
for n in m.graph.nodes:
|
||||
if n.op == "call_function" and n.target == operator.getitem:
|
||||
self.assertTrue(is_match(modules, n, pattern))
|
||||
self.assertTrue(_is_match(modules, n, pattern))
|
||||
|
||||
def test_fused_module_qat_swap(self):
|
||||
class Tmp(torch.nn.Module):
|
||||
|
@ -121,7 +121,7 @@ from .fx.ns_types import (
|
||||
)
|
||||
from torch.ao.quantization.backend_config.utils import get_fusion_pattern_to_root_node_getter
|
||||
from torch.ao.quantization.backend_config import BackendConfig
|
||||
from torch.ao.quantization.fx.match_utils import find_matches
|
||||
from torch.ao.quantization.fx.match_utils import _find_matches
|
||||
from torch.ao.quantization.fx.qconfig_mapping_utils import generate_node_name_to_qconfig
|
||||
from torch.ao.quantization.fx.quantize_handler import _get_pattern_to_quantize_handlers
|
||||
from torch.ao.quantization.qconfig import QConfigAny
|
||||
@ -815,7 +815,7 @@ def prepare_n_shadows_model(
|
||||
standalone_module_names: List[str] = []
|
||||
standalone_module_classes: List[Type] = []
|
||||
custom_module_classes: List[Type] = []
|
||||
matches = find_matches(
|
||||
matches = _find_matches(
|
||||
mt.graph, modules, patterns, root_node_getter_mapping,
|
||||
standalone_module_names, standalone_module_classes, custom_module_classes)
|
||||
subgraphs_dedup: Dict[str, List[Node]] = \
|
||||
|
@ -20,7 +20,7 @@ from torch.ao.quantization import QConfigMapping
|
||||
from torch.ao.quantization.fx.custom_config import PrepareCustomConfig
|
||||
from torch.ao.quantization.qconfig import QConfigAny
|
||||
from torch.ao.quantization.utils import getattr_from_fqn
|
||||
from torch.ao.quantization.fx.match_utils import MatchResult
|
||||
from torch.ao.quantization.fx.match_utils import _MatchResult
|
||||
|
||||
import collections
|
||||
import copy
|
||||
@ -100,7 +100,7 @@ class OutputProp:
|
||||
return None
|
||||
|
||||
def _get_dedup_subgraphs(
|
||||
matches: Dict[str, MatchResult]
|
||||
matches: Dict[str, _MatchResult]
|
||||
) -> Dict[str, List[Node]]:
|
||||
# the original matches variable is unique by node, make it unique by subgraph
|
||||
# instead
|
||||
@ -110,7 +110,7 @@ def _get_dedup_subgraphs(
|
||||
# Dict items are not reversible until Python 3.8, so we hack it
|
||||
# to be compatible with previous Python versions
|
||||
# TODO(future PR): try reversed(list(matches.items()))
|
||||
matches_items_reversed: List[Tuple[str, MatchResult]] = []
|
||||
matches_items_reversed: List[Tuple[str, _MatchResult]] = []
|
||||
for name, cur_match in matches.items():
|
||||
matches_items_reversed.insert(0, (name, cur_match))
|
||||
|
||||
@ -162,7 +162,7 @@ def _get_dedup_subgraphs(
|
||||
assert len(cur_match[1]) == 2
|
||||
# either (a, b), or ((a, b), c) or (c, (a, b))
|
||||
# cannot make any assumptions on order, not clear what the
|
||||
# find_matches function is doing to populate this
|
||||
# _find_matches function is doing to populate this
|
||||
# TODO(future PR): make this code less confusing, see discussion
|
||||
# in https://github.com/pytorch/pytorch/pull/80521/files#r975918836
|
||||
|
||||
|
@ -72,7 +72,7 @@ def get_reversed_fusions() -> List[Tuple[NSFusionType, int]]:
|
||||
default_base_op_idx = 0
|
||||
for quant_pattern, _quant_handler in all_quant_patterns.items():
|
||||
# TODO: this is a temporary hack to flatten the patterns from quantization so
|
||||
# that it works with the ns matcher function, maybe we should use `is_match`
|
||||
# that it works with the ns matcher function, maybe we should use `_is_match`
|
||||
# in torch.ao.quantization.fx.match_utils to match the patterns
|
||||
if isinstance(quant_pattern, tuple) and len(quant_pattern) == 2 and \
|
||||
isinstance(quant_pattern[1], tuple) and len(quant_pattern[1]) == 2:
|
||||
|
@ -8,7 +8,7 @@ from .graph_module import (
|
||||
FusedGraphModule
|
||||
)
|
||||
from .match_utils import (
|
||||
is_match,
|
||||
_is_match,
|
||||
MatchAllNode,
|
||||
)
|
||||
from .pattern_utils import (
|
||||
@ -157,7 +157,7 @@ def _find_matches(
|
||||
if node.name not in match_map:
|
||||
for pattern, value in patterns.items():
|
||||
matched_node_pattern: List[Node] = []
|
||||
if is_match(modules, node, pattern):
|
||||
if _is_match(modules, node, pattern):
|
||||
apply_match(pattern, node, (node, pattern, value(node)), matched_node_pattern, node_to_subpattern)
|
||||
break
|
||||
|
||||
|
@ -3,6 +3,10 @@ from .graph_module import QuantizedGraphModule
|
||||
from ..qconfig import QConfigAny
|
||||
from typing import Dict, Tuple
|
||||
|
||||
__all__ = [
|
||||
"lower_to_qnnpack"
|
||||
]
|
||||
|
||||
def lower_to_qnnpack(
|
||||
model: QuantizedGraphModule,
|
||||
qconfig_map: Dict[str, QConfigAny],
|
||||
|
@ -21,15 +21,11 @@ from torch.nn.utils.parametrize import type_before_parametrizations
|
||||
from typing import Any, Dict, List, Callable, Optional, Tuple, Type, Set, Iterable
|
||||
|
||||
|
||||
# TODO: revisit this list. Many helper methods shouldn't be public
|
||||
__all__ = [
|
||||
"is_match",
|
||||
"find_matches",
|
||||
]
|
||||
__all__: List[str] = []
|
||||
|
||||
# TODO(future PR): the 1st argument is typed as `List[Node]`, but a better type
|
||||
# would be a recursive `List[Union[Node, Tuple[Union[Node, ...]]]]`
|
||||
MatchResult = Tuple[Node, List[Node], Optional[Pattern], QuantizeHandler]
|
||||
_MatchResult = Tuple[Node, List[Node], Optional[Pattern], QuantizeHandler]
|
||||
|
||||
_MatchResultWithQConfig = Tuple[Node, List[Node], Optional[Pattern], QuantizeHandler,
|
||||
QConfigAny]
|
||||
@ -38,7 +34,7 @@ _MatchResultWithQConfig = Tuple[Node, List[Node], Optional[Pattern], QuantizeHan
|
||||
# need to put the fusion patterns before single patterns. For example, add_relu should be registered come before relu.
|
||||
# decorators are applied in the reverse order we see. Also when we match the nodes in the graph with these patterns,
|
||||
# we'll start from the last node of the graph and traverse back.
|
||||
def is_match(modules, node, pattern, max_uses=sys.maxsize):
|
||||
def _is_match(modules, node, pattern, max_uses=sys.maxsize):
|
||||
""" Matches a node in fx against a pattern
|
||||
"""
|
||||
if isinstance(pattern, tuple):
|
||||
@ -82,16 +78,16 @@ def is_match(modules, node, pattern, max_uses=sys.maxsize):
|
||||
if len(arg_matches) != len(node.args):
|
||||
return False
|
||||
|
||||
return all(is_match(modules, node, arg_match, max_uses=1) for node, arg_match in zip(node.args, arg_matches))
|
||||
return all(_is_match(modules, node, arg_match, max_uses=1) for node, arg_match in zip(node.args, arg_matches))
|
||||
|
||||
def find_matches(
|
||||
def _find_matches(
|
||||
graph: Graph,
|
||||
modules: Dict[str, torch.nn.Module],
|
||||
patterns: Dict[Pattern, QuantizeHandler],
|
||||
root_node_getter_mapping: Dict[Pattern, Callable],
|
||||
standalone_module_names: List[str] = None,
|
||||
standalone_module_classes: List[Type] = None,
|
||||
custom_module_classes: List[Any] = None) -> Dict[str, MatchResult]:
|
||||
custom_module_classes: List[Any] = None) -> Dict[str, _MatchResult]:
|
||||
"""
|
||||
Matches the nodes in the input graph to quantization patterns, and
|
||||
outputs the information needed to quantize them in future steps.
|
||||
@ -123,7 +119,7 @@ def find_matches(
|
||||
if standalone_module_names is None:
|
||||
standalone_module_names = []
|
||||
|
||||
match_map: Dict[str, MatchResult] = {}
|
||||
match_map: Dict[str, _MatchResult] = {}
|
||||
all_matched : Set[str] = set()
|
||||
|
||||
def _recursive_record_node_in_match_map(
|
||||
@ -188,7 +184,7 @@ def find_matches(
|
||||
if node.name not in match_map and node.name not in all_matched:
|
||||
for pattern, quantize_handler_cls in patterns.items():
|
||||
root_node_getter = root_node_getter_mapping.get(pattern, None)
|
||||
if is_match(modules, node, pattern) and node.name not in match_map:
|
||||
if _is_match(modules, node, pattern) and node.name not in match_map:
|
||||
matched_node_pattern: List[Node] = []
|
||||
record_match(
|
||||
pattern,
|
||||
|
@ -57,7 +57,7 @@ from .pattern_utils import (
|
||||
|
||||
from .match_utils import (
|
||||
_MatchResultWithQConfig,
|
||||
find_matches,
|
||||
_find_matches,
|
||||
)
|
||||
|
||||
from ..utils import _parent_name
|
||||
@ -1527,7 +1527,7 @@ def prepare(
|
||||
standalone_module_classes = list(prepare_custom_config.standalone_module_classes.keys())
|
||||
|
||||
custom_module_classes = get_custom_module_class_keys(prepare_custom_config.float_to_observed_mapping)
|
||||
matches_without_qconfig = find_matches(
|
||||
matches_without_qconfig = _find_matches(
|
||||
model.graph, modules, pattern_to_quantize_handler, root_node_getter_mapping,
|
||||
standalone_module_names, standalone_module_classes, custom_module_classes)
|
||||
|
||||
|
@ -7,8 +7,8 @@ appropriate files under `torch/ao/quantization/fx/`, while adding an import stat
|
||||
here.
|
||||
"""
|
||||
from torch.ao.quantization.fx.match_utils import (
|
||||
MatchResult,
|
||||
_MatchResult,
|
||||
MatchAllNode,
|
||||
is_match,
|
||||
find_matches
|
||||
_is_match,
|
||||
_find_matches
|
||||
)
|
||||
|
Reference in New Issue
Block a user