mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[ao][fx] fixing public v private match_utils.py (#88396)
Summary: made _is_match, _find_matches, _MatchResult private also added __all__ to lower_to_qnnpack.py Test Plan: python test/test_public_bindings.py Reviewers: Subscribers: Tasks: Tags: Differential Revision: [D41015540](https://our.internmc.facebook.com/intern/diff/D41015540) Pull Request resolved: https://github.com/pytorch/pytorch/pull/88396 Approved by: https://github.com/jcaip
This commit is contained in:
committed by
PyTorch MergeBot
parent
a856557b3a
commit
79156c11c3
@ -135,10 +135,10 @@ class TestAOMigrationQuantizationFx(AOMigrationTestCase):
|
|||||||
|
|
||||||
def test_function_import_fx_match_utils(self):
|
def test_function_import_fx_match_utils(self):
|
||||||
function_list = [
|
function_list = [
|
||||||
'MatchResult',
|
'_MatchResult',
|
||||||
'MatchAllNode',
|
'MatchAllNode',
|
||||||
'is_match',
|
'_is_match',
|
||||||
'find_matches'
|
'_find_matches'
|
||||||
]
|
]
|
||||||
self._test_function_import('fx.match_utils', function_list)
|
self._test_function_import('fx.match_utils', function_list)
|
||||||
|
|
||||||
|
@ -27,7 +27,7 @@ from torch.ao.quantization.quantize_fx import (
|
|||||||
from torch.ao.quantization.fx.quantize_handler import DefaultNodeQuantizeHandler
|
from torch.ao.quantization.fx.quantize_handler import DefaultNodeQuantizeHandler
|
||||||
|
|
||||||
from torch.ao.quantization.fx.match_utils import (
|
from torch.ao.quantization.fx.match_utils import (
|
||||||
is_match,
|
_is_match,
|
||||||
MatchAllNode,
|
MatchAllNode,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -711,7 +711,7 @@ class TestQuantizeFx(QuantizationTestCase):
|
|||||||
modules = dict(m.named_modules())
|
modules = dict(m.named_modules())
|
||||||
for n in m.graph.nodes:
|
for n in m.graph.nodes:
|
||||||
if n.op == 'call_module' and type(modules[n.target]) == nn.ReLU:
|
if n.op == 'call_module' and type(modules[n.target]) == nn.ReLU:
|
||||||
self.assertTrue(is_match(modules, n, pattern))
|
self.assertTrue(_is_match(modules, n, pattern))
|
||||||
|
|
||||||
def test_pattern_match_constant(self):
|
def test_pattern_match_constant(self):
|
||||||
class M(torch.nn.Module):
|
class M(torch.nn.Module):
|
||||||
@ -727,7 +727,7 @@ class TestQuantizeFx(QuantizationTestCase):
|
|||||||
modules = dict(m.named_modules())
|
modules = dict(m.named_modules())
|
||||||
for n in m.graph.nodes:
|
for n in m.graph.nodes:
|
||||||
if n.op == "call_function" and n.target == operator.getitem:
|
if n.op == "call_function" and n.target == operator.getitem:
|
||||||
self.assertTrue(is_match(modules, n, pattern))
|
self.assertTrue(_is_match(modules, n, pattern))
|
||||||
|
|
||||||
def test_fused_module_qat_swap(self):
|
def test_fused_module_qat_swap(self):
|
||||||
class Tmp(torch.nn.Module):
|
class Tmp(torch.nn.Module):
|
||||||
|
@ -121,7 +121,7 @@ from .fx.ns_types import (
|
|||||||
)
|
)
|
||||||
from torch.ao.quantization.backend_config.utils import get_fusion_pattern_to_root_node_getter
|
from torch.ao.quantization.backend_config.utils import get_fusion_pattern_to_root_node_getter
|
||||||
from torch.ao.quantization.backend_config import BackendConfig
|
from torch.ao.quantization.backend_config import BackendConfig
|
||||||
from torch.ao.quantization.fx.match_utils import find_matches
|
from torch.ao.quantization.fx.match_utils import _find_matches
|
||||||
from torch.ao.quantization.fx.qconfig_mapping_utils import generate_node_name_to_qconfig
|
from torch.ao.quantization.fx.qconfig_mapping_utils import generate_node_name_to_qconfig
|
||||||
from torch.ao.quantization.fx.quantize_handler import _get_pattern_to_quantize_handlers
|
from torch.ao.quantization.fx.quantize_handler import _get_pattern_to_quantize_handlers
|
||||||
from torch.ao.quantization.qconfig import QConfigAny
|
from torch.ao.quantization.qconfig import QConfigAny
|
||||||
@ -815,7 +815,7 @@ def prepare_n_shadows_model(
|
|||||||
standalone_module_names: List[str] = []
|
standalone_module_names: List[str] = []
|
||||||
standalone_module_classes: List[Type] = []
|
standalone_module_classes: List[Type] = []
|
||||||
custom_module_classes: List[Type] = []
|
custom_module_classes: List[Type] = []
|
||||||
matches = find_matches(
|
matches = _find_matches(
|
||||||
mt.graph, modules, patterns, root_node_getter_mapping,
|
mt.graph, modules, patterns, root_node_getter_mapping,
|
||||||
standalone_module_names, standalone_module_classes, custom_module_classes)
|
standalone_module_names, standalone_module_classes, custom_module_classes)
|
||||||
subgraphs_dedup: Dict[str, List[Node]] = \
|
subgraphs_dedup: Dict[str, List[Node]] = \
|
||||||
|
@ -20,7 +20,7 @@ from torch.ao.quantization import QConfigMapping
|
|||||||
from torch.ao.quantization.fx.custom_config import PrepareCustomConfig
|
from torch.ao.quantization.fx.custom_config import PrepareCustomConfig
|
||||||
from torch.ao.quantization.qconfig import QConfigAny
|
from torch.ao.quantization.qconfig import QConfigAny
|
||||||
from torch.ao.quantization.utils import getattr_from_fqn
|
from torch.ao.quantization.utils import getattr_from_fqn
|
||||||
from torch.ao.quantization.fx.match_utils import MatchResult
|
from torch.ao.quantization.fx.match_utils import _MatchResult
|
||||||
|
|
||||||
import collections
|
import collections
|
||||||
import copy
|
import copy
|
||||||
@ -100,7 +100,7 @@ class OutputProp:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
def _get_dedup_subgraphs(
|
def _get_dedup_subgraphs(
|
||||||
matches: Dict[str, MatchResult]
|
matches: Dict[str, _MatchResult]
|
||||||
) -> Dict[str, List[Node]]:
|
) -> Dict[str, List[Node]]:
|
||||||
# the original matches variable is unique by node, make it unique by subgraph
|
# the original matches variable is unique by node, make it unique by subgraph
|
||||||
# instead
|
# instead
|
||||||
@ -110,7 +110,7 @@ def _get_dedup_subgraphs(
|
|||||||
# Dict items are not reversible until Python 3.8, so we hack it
|
# Dict items are not reversible until Python 3.8, so we hack it
|
||||||
# to be compatible with previous Python versions
|
# to be compatible with previous Python versions
|
||||||
# TODO(future PR): try reversed(list(matches.items()))
|
# TODO(future PR): try reversed(list(matches.items()))
|
||||||
matches_items_reversed: List[Tuple[str, MatchResult]] = []
|
matches_items_reversed: List[Tuple[str, _MatchResult]] = []
|
||||||
for name, cur_match in matches.items():
|
for name, cur_match in matches.items():
|
||||||
matches_items_reversed.insert(0, (name, cur_match))
|
matches_items_reversed.insert(0, (name, cur_match))
|
||||||
|
|
||||||
@ -162,7 +162,7 @@ def _get_dedup_subgraphs(
|
|||||||
assert len(cur_match[1]) == 2
|
assert len(cur_match[1]) == 2
|
||||||
# either (a, b), or ((a, b), c) or (c, (a, b))
|
# either (a, b), or ((a, b), c) or (c, (a, b))
|
||||||
# cannot make any assumptions on order, not clear what the
|
# cannot make any assumptions on order, not clear what the
|
||||||
# find_matches function is doing to populate this
|
# _find_matches function is doing to populate this
|
||||||
# TODO(future PR): make this code less confusing, see discussion
|
# TODO(future PR): make this code less confusing, see discussion
|
||||||
# in https://github.com/pytorch/pytorch/pull/80521/files#r975918836
|
# in https://github.com/pytorch/pytorch/pull/80521/files#r975918836
|
||||||
|
|
||||||
|
@ -72,7 +72,7 @@ def get_reversed_fusions() -> List[Tuple[NSFusionType, int]]:
|
|||||||
default_base_op_idx = 0
|
default_base_op_idx = 0
|
||||||
for quant_pattern, _quant_handler in all_quant_patterns.items():
|
for quant_pattern, _quant_handler in all_quant_patterns.items():
|
||||||
# TODO: this is a temporary hack to flatten the patterns from quantization so
|
# TODO: this is a temporary hack to flatten the patterns from quantization so
|
||||||
# that it works with the ns matcher function, maybe we should use `is_match`
|
# that it works with the ns matcher function, maybe we should use `_is_match`
|
||||||
# in torch.ao.quantization.fx.match_utils to match the patterns
|
# in torch.ao.quantization.fx.match_utils to match the patterns
|
||||||
if isinstance(quant_pattern, tuple) and len(quant_pattern) == 2 and \
|
if isinstance(quant_pattern, tuple) and len(quant_pattern) == 2 and \
|
||||||
isinstance(quant_pattern[1], tuple) and len(quant_pattern[1]) == 2:
|
isinstance(quant_pattern[1], tuple) and len(quant_pattern[1]) == 2:
|
||||||
|
@ -8,7 +8,7 @@ from .graph_module import (
|
|||||||
FusedGraphModule
|
FusedGraphModule
|
||||||
)
|
)
|
||||||
from .match_utils import (
|
from .match_utils import (
|
||||||
is_match,
|
_is_match,
|
||||||
MatchAllNode,
|
MatchAllNode,
|
||||||
)
|
)
|
||||||
from .pattern_utils import (
|
from .pattern_utils import (
|
||||||
@ -157,7 +157,7 @@ def _find_matches(
|
|||||||
if node.name not in match_map:
|
if node.name not in match_map:
|
||||||
for pattern, value in patterns.items():
|
for pattern, value in patterns.items():
|
||||||
matched_node_pattern: List[Node] = []
|
matched_node_pattern: List[Node] = []
|
||||||
if is_match(modules, node, pattern):
|
if _is_match(modules, node, pattern):
|
||||||
apply_match(pattern, node, (node, pattern, value(node)), matched_node_pattern, node_to_subpattern)
|
apply_match(pattern, node, (node, pattern, value(node)), matched_node_pattern, node_to_subpattern)
|
||||||
break
|
break
|
||||||
|
|
||||||
|
@ -3,6 +3,10 @@ from .graph_module import QuantizedGraphModule
|
|||||||
from ..qconfig import QConfigAny
|
from ..qconfig import QConfigAny
|
||||||
from typing import Dict, Tuple
|
from typing import Dict, Tuple
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"lower_to_qnnpack"
|
||||||
|
]
|
||||||
|
|
||||||
def lower_to_qnnpack(
|
def lower_to_qnnpack(
|
||||||
model: QuantizedGraphModule,
|
model: QuantizedGraphModule,
|
||||||
qconfig_map: Dict[str, QConfigAny],
|
qconfig_map: Dict[str, QConfigAny],
|
||||||
|
@ -21,15 +21,11 @@ from torch.nn.utils.parametrize import type_before_parametrizations
|
|||||||
from typing import Any, Dict, List, Callable, Optional, Tuple, Type, Set, Iterable
|
from typing import Any, Dict, List, Callable, Optional, Tuple, Type, Set, Iterable
|
||||||
|
|
||||||
|
|
||||||
# TODO: revisit this list. Many helper methods shouldn't be public
|
__all__: List[str] = []
|
||||||
__all__ = [
|
|
||||||
"is_match",
|
|
||||||
"find_matches",
|
|
||||||
]
|
|
||||||
|
|
||||||
# TODO(future PR): the 1st argument is typed as `List[Node]`, but a better type
|
# TODO(future PR): the 1st argument is typed as `List[Node]`, but a better type
|
||||||
# would be a recursive `List[Union[Node, Tuple[Union[Node, ...]]]]`
|
# would be a recursive `List[Union[Node, Tuple[Union[Node, ...]]]]`
|
||||||
MatchResult = Tuple[Node, List[Node], Optional[Pattern], QuantizeHandler]
|
_MatchResult = Tuple[Node, List[Node], Optional[Pattern], QuantizeHandler]
|
||||||
|
|
||||||
_MatchResultWithQConfig = Tuple[Node, List[Node], Optional[Pattern], QuantizeHandler,
|
_MatchResultWithQConfig = Tuple[Node, List[Node], Optional[Pattern], QuantizeHandler,
|
||||||
QConfigAny]
|
QConfigAny]
|
||||||
@ -38,7 +34,7 @@ _MatchResultWithQConfig = Tuple[Node, List[Node], Optional[Pattern], QuantizeHan
|
|||||||
# need to put the fusion patterns before single patterns. For example, add_relu should be registered come before relu.
|
# need to put the fusion patterns before single patterns. For example, add_relu should be registered come before relu.
|
||||||
# decorators are applied in the reverse order we see. Also when we match the nodes in the graph with these patterns,
|
# decorators are applied in the reverse order we see. Also when we match the nodes in the graph with these patterns,
|
||||||
# we'll start from the last node of the graph and traverse back.
|
# we'll start from the last node of the graph and traverse back.
|
||||||
def is_match(modules, node, pattern, max_uses=sys.maxsize):
|
def _is_match(modules, node, pattern, max_uses=sys.maxsize):
|
||||||
""" Matches a node in fx against a pattern
|
""" Matches a node in fx against a pattern
|
||||||
"""
|
"""
|
||||||
if isinstance(pattern, tuple):
|
if isinstance(pattern, tuple):
|
||||||
@ -82,16 +78,16 @@ def is_match(modules, node, pattern, max_uses=sys.maxsize):
|
|||||||
if len(arg_matches) != len(node.args):
|
if len(arg_matches) != len(node.args):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
return all(is_match(modules, node, arg_match, max_uses=1) for node, arg_match in zip(node.args, arg_matches))
|
return all(_is_match(modules, node, arg_match, max_uses=1) for node, arg_match in zip(node.args, arg_matches))
|
||||||
|
|
||||||
def find_matches(
|
def _find_matches(
|
||||||
graph: Graph,
|
graph: Graph,
|
||||||
modules: Dict[str, torch.nn.Module],
|
modules: Dict[str, torch.nn.Module],
|
||||||
patterns: Dict[Pattern, QuantizeHandler],
|
patterns: Dict[Pattern, QuantizeHandler],
|
||||||
root_node_getter_mapping: Dict[Pattern, Callable],
|
root_node_getter_mapping: Dict[Pattern, Callable],
|
||||||
standalone_module_names: List[str] = None,
|
standalone_module_names: List[str] = None,
|
||||||
standalone_module_classes: List[Type] = None,
|
standalone_module_classes: List[Type] = None,
|
||||||
custom_module_classes: List[Any] = None) -> Dict[str, MatchResult]:
|
custom_module_classes: List[Any] = None) -> Dict[str, _MatchResult]:
|
||||||
"""
|
"""
|
||||||
Matches the nodes in the input graph to quantization patterns, and
|
Matches the nodes in the input graph to quantization patterns, and
|
||||||
outputs the information needed to quantize them in future steps.
|
outputs the information needed to quantize them in future steps.
|
||||||
@ -123,7 +119,7 @@ def find_matches(
|
|||||||
if standalone_module_names is None:
|
if standalone_module_names is None:
|
||||||
standalone_module_names = []
|
standalone_module_names = []
|
||||||
|
|
||||||
match_map: Dict[str, MatchResult] = {}
|
match_map: Dict[str, _MatchResult] = {}
|
||||||
all_matched : Set[str] = set()
|
all_matched : Set[str] = set()
|
||||||
|
|
||||||
def _recursive_record_node_in_match_map(
|
def _recursive_record_node_in_match_map(
|
||||||
@ -188,7 +184,7 @@ def find_matches(
|
|||||||
if node.name not in match_map and node.name not in all_matched:
|
if node.name not in match_map and node.name not in all_matched:
|
||||||
for pattern, quantize_handler_cls in patterns.items():
|
for pattern, quantize_handler_cls in patterns.items():
|
||||||
root_node_getter = root_node_getter_mapping.get(pattern, None)
|
root_node_getter = root_node_getter_mapping.get(pattern, None)
|
||||||
if is_match(modules, node, pattern) and node.name not in match_map:
|
if _is_match(modules, node, pattern) and node.name not in match_map:
|
||||||
matched_node_pattern: List[Node] = []
|
matched_node_pattern: List[Node] = []
|
||||||
record_match(
|
record_match(
|
||||||
pattern,
|
pattern,
|
||||||
|
@ -57,7 +57,7 @@ from .pattern_utils import (
|
|||||||
|
|
||||||
from .match_utils import (
|
from .match_utils import (
|
||||||
_MatchResultWithQConfig,
|
_MatchResultWithQConfig,
|
||||||
find_matches,
|
_find_matches,
|
||||||
)
|
)
|
||||||
|
|
||||||
from ..utils import _parent_name
|
from ..utils import _parent_name
|
||||||
@ -1527,7 +1527,7 @@ def prepare(
|
|||||||
standalone_module_classes = list(prepare_custom_config.standalone_module_classes.keys())
|
standalone_module_classes = list(prepare_custom_config.standalone_module_classes.keys())
|
||||||
|
|
||||||
custom_module_classes = get_custom_module_class_keys(prepare_custom_config.float_to_observed_mapping)
|
custom_module_classes = get_custom_module_class_keys(prepare_custom_config.float_to_observed_mapping)
|
||||||
matches_without_qconfig = find_matches(
|
matches_without_qconfig = _find_matches(
|
||||||
model.graph, modules, pattern_to_quantize_handler, root_node_getter_mapping,
|
model.graph, modules, pattern_to_quantize_handler, root_node_getter_mapping,
|
||||||
standalone_module_names, standalone_module_classes, custom_module_classes)
|
standalone_module_names, standalone_module_classes, custom_module_classes)
|
||||||
|
|
||||||
|
@ -7,8 +7,8 @@ appropriate files under `torch/ao/quantization/fx/`, while adding an import stat
|
|||||||
here.
|
here.
|
||||||
"""
|
"""
|
||||||
from torch.ao.quantization.fx.match_utils import (
|
from torch.ao.quantization.fx.match_utils import (
|
||||||
MatchResult,
|
_MatchResult,
|
||||||
MatchAllNode,
|
MatchAllNode,
|
||||||
is_match,
|
_is_match,
|
||||||
find_matches
|
_find_matches
|
||||||
)
|
)
|
||||||
|
Reference in New Issue
Block a user