Support matching Args for SubgraphMatcher (#85456)

Subgraph matcher now handles the matching of non-Node arguments.

Here are the 4 cases
- pn is Node, gn is Node: this go through the regular _match_node() function
- pn is Noed, gn is not a Node: this is a match if only pn is a placeholder op
- pn is not Node, gn is Node: this is a no match case
- pn is not a Node, gn is not a Node: this will go through the argument comparison.

With this change
```
def target(x):
    return foo(x, 3)

def pattern(x, y):
    return foo(x, y)
```

is a match

Pull Request resolved: https://github.com/pytorch/pytorch/pull/85456
Approved by: https://github.com/jerryzh168
This commit is contained in:
Sherlock Huang
2022-09-24 17:57:32 +00:00
committed by PyTorch MergeBot
parent db40fbdee0
commit a8add2b92f
2 changed files with 92 additions and 4 deletions

View File

@ -621,6 +621,49 @@ class MultiOutputWithWithInvalidMatches:
TestCase(False, True, 0),
]
class QuantizationFp8Pattern:
@classmethod
def setup(cls):
cls.quantization = torch.library.Library("fp8_quantization", "DEF")
cls.quantization.define("quantize_per_tensor_affine_fp8(Tensor self, int dtype, float scale) -> Tensor")
cls.quantization.define("dequantize_per_tensor_affine_fp8(Tensor self, int dtype, float scale) -> Tensor")
@classmethod
def tearDown(cls):
del cls.quantization
@staticmethod
def forward(self, arg0_1, arg1_1):
qt = torch.ops.fp8_quantization
_scale_0 = self._scale_0
quantize_per_tensor_affine_fp8 = qt.quantize_per_tensor_affine_fp8(arg0_1, 0, _scale_0)
dequantize_per_tensor_affine_fp8 = qt.dequantize_per_tensor_affine_fp8(quantize_per_tensor_affine_fp8, 0, _scale_0)
_scale_1 = self._scale_0
quantize_per_tensor_affine_fp8_1 = qt.quantize_per_tensor_affine_fp8(arg1_1, 0, _scale_1)
dequantize_per_tensor_affine_fp8_1 = qt.dequantize_per_tensor_affine_fp8(quantize_per_tensor_affine_fp8_1, 0, _scale_1)
add = torch.ops.aten.add.Tensor(dequantize_per_tensor_affine_fp8, dequantize_per_tensor_affine_fp8_1)
_scale_2 = self._scale_0
quantize_per_tensor_affine_fp8_2 = qt.quantize_per_tensor_affine_fp8(add, 0, _scale_2)
dequantize_per_tensor_affine_fp8_2 = qt.dequantize_per_tensor_affine_fp8(quantize_per_tensor_affine_fp8_2, 0, _scale_2)
return dequantize_per_tensor_affine_fp8_2
@staticmethod
def pattern(a, a_dtype, a_scale, b, b_dtype, b_scale, out_scale):
qt = torch.ops.fp8_quantization
a = qt.dequantize_per_tensor_affine_fp8(a, a_dtype, a_scale)
b = qt.dequantize_per_tensor_affine_fp8(b, b_dtype, b_scale)
output = torch.ops.aten.add.Tensor(a, b)
qt.dequantize_per_tensor_affine_fp8
output = qt.quantize_per_tensor_affine_fp8(output, a_dtype, out_scale)
return output
test_cases = [
# match_output, match_placeholder, num_matches
TestCase(False, False, 1),
]
@instantiate_parametrized_tests
class TestFXMatcherUtils(JitTestCase):
@ -639,8 +682,14 @@ class TestFXMatcherUtils(JitTestCase):
MultipleOutputsIdenticalAnchor,
MultipleOutputsHorizontalPattern,
MultiOutputWithWithInvalidMatches,
QuantizationFp8Pattern,
])
def test_subgraph_matcher(self, test_model):
setup = getattr(test_model, "setup", None)
if callable(setup):
setup()
traced = symbolic_trace(test_model.forward)
pattern_traced = symbolic_trace(test_model.pattern)
@ -662,6 +711,10 @@ class TestFXMatcherUtils(JitTestCase):
continue
assert node in match.nodes_map
tearDown = getattr(test_model, "tearDown", None)
if callable(setup):
tearDown()
if __name__ == "__main__":
run_tests()

View File

@ -4,7 +4,8 @@ import copy
from torch.fx.graph import Graph
from torch.fx.node import Node
from torch.fx._compatibility import compatibility
from typing import Dict, List, Set
import torch.utils._pytree as pytree
from typing import Dict, List, Set, Any
__all__ = ['SubgraphMatcher', 'InternalMatch']
@ -124,7 +125,27 @@ class SubgraphMatcher:
nodes_matched.add(gn)
return non_overlapping_matches
def _match_args(self, pn: Any, gn: Any, match: InternalMatch) -> bool:
assert not(isinstance(pn, Node) and isinstance(gn, Node)), "pn and gn cannot both be Node"
if isinstance(pn, Node) and not isinstance(gn, Node):
if pn.op == "placeholder":
# Check if we've already matched these nodes in the current
# traversal
if pn in match.nodes_map:
return match.nodes_map[pn] == gn
match.nodes_map[pn] = gn
return True
else:
return False
elif not isinstance(pn, Node) and isinstance(gn, Node):
return False
else:
return type(gn) == type(pn) and gn == pn
def _match_nodes(self, pn: Node, gn: Node, match: InternalMatch) -> bool:
assert isinstance(pn, Node) and isinstance(gn, Node), "pn and gn must be Node"
# Check if we've already matched these nodes in the current
# traversal
@ -146,9 +167,23 @@ class SubgraphMatcher:
# Recursively traverse upwards to check if `pn` is a true
# match for `gn`
match_found = (len(pn.all_input_nodes) == len(gn.all_input_nodes) and
all(self._match_nodes(pn_, gn_, match) for pn_, gn_
in zip(pn.all_input_nodes, gn.all_input_nodes)))
match_found = True
pn_flatten_args, _ = pytree.tree_flatten(pn.args)
gn_flatten_args, _ = pytree.tree_flatten(gn.args)
if len(pn_flatten_args) == len(gn_flatten_args):
for pn_, gn_ in zip(pn_flatten_args, gn_flatten_args):
if isinstance(gn_, Node) and isinstance(pn_, Node):
matched = self._match_nodes(pn_, gn_, match)
else:
matched = self._match_args(pn_, gn_, match)
if not matched:
match_found = False
break
else:
match_found = False
if not match_found:
match.nodes_map.pop(pn)