Files
pytorch/torch/ao/quantization/fx/fuse.py
Jerry Zhang a73c6a45b6 [reland][quant][graphmode][fx] Enable fuse handler for sequence of 3 ops (#70006)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/70006

reland: fixing some mypy errors that was missed before

This PR enables fuse handler for sequence of three ops, and merges all fuse handlers into one

TODO: we can also move this to backend_config_dict folder

Test Plan:
regression fusion test
```
python test/test_quantization.py TestFuseFx
```

Imported from OSS

Imported from OSS

Reviewed By: supriyar

Differential Revision: D33144606

fbshipit-source-id: ca34f282018a0fb4d04c7e35119eaf2d64258e78
2021-12-16 15:04:16 -08:00

114 lines
4.7 KiB
Python

from torch.fx import (
GraphModule,
Node,
map_arg
)
from torch.fx.graph import Graph
from ..utils import (
get_combined_dict
)
from .graph_module import (
FusedGraphModule
)
from .match_utils import is_match
from .pattern_utils import (
get_default_fusion_patterns,
)
from .backend_config_dict.utils import get_fusion_pattern_to_fuse_handler_cls
from .backend_config_dict.utils import get_fuser_method_mapping
from .fusion_patterns import * # noqa: F401,F403
from typing import Callable, Tuple, Dict, Any, Optional, List
from .quantization_types import Pattern, NodePattern
class Fuser:
def fuse(
self,
model: GraphModule,
fuse_custom_config_dict: Optional[Dict[str, Any]] = None,
backend_config_dict: Optional[Dict[str, Any]] = None,
) -> GraphModule:
if fuse_custom_config_dict is None:
fuse_custom_config_dict = {}
input_root = model
input_graph = model.graph
self.modules = dict(input_root.named_modules())
if backend_config_dict is None:
additional_fusion_patterns = \
fuse_custom_config_dict.get("additional_fusion_pattern", {})
fusion_pattern_to_fuse_handler_cls = get_combined_dict(
get_default_fusion_patterns(), additional_fusion_patterns)
fuser_method_mapping = None
else:
fusion_pattern_to_fuse_handler_cls = get_fusion_pattern_to_fuse_handler_cls(backend_config_dict)
fuser_method_mapping = get_fuser_method_mapping(backend_config_dict)
# find fusion
fusion_pairs = self._find_matches(
input_root, input_graph, fusion_pattern_to_fuse_handler_cls)
self.fused_graph = Graph()
env: Dict[Any, Any] = {}
def load_arg(a):
return map_arg(a, lambda node: env[node.name])
def get_root_node(node_pattern):
while not isinstance(node_pattern[-1], Node):
node_pattern = node_pattern[-1]
return node_pattern[-1]
for node in input_graph.nodes:
maybe_last_node, pattern, matched_node_pattern, obj = \
fusion_pairs.get(node.name, (None, None, None, None))
if maybe_last_node is node:
assert obj is not None
# TODO: currently we hard code the root node, which only works for
# a sequence of ops and assume the root node is the last node,
# we want to make this more general to support more complex patterns
root_node = get_root_node(matched_node_pattern) # type: ignore[index]
env[node.name] = obj.fuse(
self, load_arg, root_node, matched_node_pattern, # type: ignore[arg-type]
fuse_custom_config_dict, fuser_method_mapping)
elif maybe_last_node is None:
env[node.name] = self.fused_graph.node_copy(node, load_arg)
# node matched in patterns and is not root is removed here
preserved_attributes = set(fuse_custom_config_dict.get("preserved_attributes", []))
model = FusedGraphModule(input_root, self.fused_graph, preserved_attributes)
return model
def _find_matches(
self, root: GraphModule, graph: Graph,
patterns: Dict[Pattern, Callable]
) -> Dict[str, Tuple[Node, Pattern, NodePattern, FuseHandler]]:
modules = dict(root.named_modules())
match_map : Dict[str, Tuple[Node, Pattern, NodePattern, FuseHandler]] = {} # node name -> (root_node, match_value)
def apply_match(pattern, node, match, matched_node_pattern):
if isinstance(pattern, tuple):
s, *args = pattern
current_node_pattern: List[Node] = []
apply_match(s, node, match, current_node_pattern)
for subpattern, arg in zip(args, node.args):
apply_match(subpattern, arg, match, current_node_pattern)
matched_node_pattern.append(tuple(current_node_pattern))
else:
# the first pattern matches will take precedence
if node.name not in match_map:
matched_node_pattern.append(node)
root_node, pattern, handler = match
match_map[node.name] = (root_node, pattern, matched_node_pattern, handler)
for node in reversed(graph.nodes):
if node.name not in match_map:
for pattern, value in patterns.items():
matched_node_pattern: List[Node] = []
if is_match(modules, node, pattern):
apply_match(pattern, node, (node, pattern, value(self, node)), matched_node_pattern)
return match_map