mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-26 08:34:52 +08:00
Support matching Args for SubgraphMatcher (#85456)
Subgraph matcher now handles the matching of non-Node arguments.
Here are the 4 cases
- pn is Node, gn is Node: this go through the regular _match_node() function
- pn is Noed, gn is not a Node: this is a match if only pn is a placeholder op
- pn is not Node, gn is Node: this is a no match case
- pn is not a Node, gn is not a Node: this will go through the argument comparison.
With this change
```
def target(x):
return foo(x, 3)
def pattern(x, y):
return foo(x, y)
```
is a match
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85456
Approved by: https://github.com/jerryzh168
This commit is contained in:
committed by
PyTorch MergeBot
parent
db40fbdee0
commit
a8add2b92f
@ -621,6 +621,49 @@ class MultiOutputWithWithInvalidMatches:
|
||||
TestCase(False, True, 0),
|
||||
]
|
||||
|
||||
class QuantizationFp8Pattern:
|
||||
@classmethod
|
||||
def setup(cls):
|
||||
cls.quantization = torch.library.Library("fp8_quantization", "DEF")
|
||||
cls.quantization.define("quantize_per_tensor_affine_fp8(Tensor self, int dtype, float scale) -> Tensor")
|
||||
cls.quantization.define("dequantize_per_tensor_affine_fp8(Tensor self, int dtype, float scale) -> Tensor")
|
||||
|
||||
@classmethod
|
||||
def tearDown(cls):
|
||||
del cls.quantization
|
||||
|
||||
@staticmethod
|
||||
def forward(self, arg0_1, arg1_1):
|
||||
qt = torch.ops.fp8_quantization
|
||||
_scale_0 = self._scale_0
|
||||
quantize_per_tensor_affine_fp8 = qt.quantize_per_tensor_affine_fp8(arg0_1, 0, _scale_0)
|
||||
dequantize_per_tensor_affine_fp8 = qt.dequantize_per_tensor_affine_fp8(quantize_per_tensor_affine_fp8, 0, _scale_0)
|
||||
_scale_1 = self._scale_0
|
||||
quantize_per_tensor_affine_fp8_1 = qt.quantize_per_tensor_affine_fp8(arg1_1, 0, _scale_1)
|
||||
dequantize_per_tensor_affine_fp8_1 = qt.dequantize_per_tensor_affine_fp8(quantize_per_tensor_affine_fp8_1, 0, _scale_1)
|
||||
add = torch.ops.aten.add.Tensor(dequantize_per_tensor_affine_fp8, dequantize_per_tensor_affine_fp8_1)
|
||||
_scale_2 = self._scale_0
|
||||
quantize_per_tensor_affine_fp8_2 = qt.quantize_per_tensor_affine_fp8(add, 0, _scale_2)
|
||||
dequantize_per_tensor_affine_fp8_2 = qt.dequantize_per_tensor_affine_fp8(quantize_per_tensor_affine_fp8_2, 0, _scale_2)
|
||||
return dequantize_per_tensor_affine_fp8_2
|
||||
|
||||
@staticmethod
|
||||
def pattern(a, a_dtype, a_scale, b, b_dtype, b_scale, out_scale):
|
||||
qt = torch.ops.fp8_quantization
|
||||
a = qt.dequantize_per_tensor_affine_fp8(a, a_dtype, a_scale)
|
||||
b = qt.dequantize_per_tensor_affine_fp8(b, b_dtype, b_scale)
|
||||
output = torch.ops.aten.add.Tensor(a, b)
|
||||
|
||||
qt.dequantize_per_tensor_affine_fp8
|
||||
|
||||
output = qt.quantize_per_tensor_affine_fp8(output, a_dtype, out_scale)
|
||||
return output
|
||||
|
||||
test_cases = [
|
||||
# match_output, match_placeholder, num_matches
|
||||
TestCase(False, False, 1),
|
||||
]
|
||||
|
||||
@instantiate_parametrized_tests
|
||||
class TestFXMatcherUtils(JitTestCase):
|
||||
|
||||
@ -639,8 +682,14 @@ class TestFXMatcherUtils(JitTestCase):
|
||||
MultipleOutputsIdenticalAnchor,
|
||||
MultipleOutputsHorizontalPattern,
|
||||
MultiOutputWithWithInvalidMatches,
|
||||
QuantizationFp8Pattern,
|
||||
])
|
||||
def test_subgraph_matcher(self, test_model):
|
||||
|
||||
setup = getattr(test_model, "setup", None)
|
||||
if callable(setup):
|
||||
setup()
|
||||
|
||||
traced = symbolic_trace(test_model.forward)
|
||||
pattern_traced = symbolic_trace(test_model.pattern)
|
||||
|
||||
@ -662,6 +711,10 @@ class TestFXMatcherUtils(JitTestCase):
|
||||
continue
|
||||
assert node in match.nodes_map
|
||||
|
||||
tearDown = getattr(test_model, "tearDown", None)
|
||||
if callable(setup):
|
||||
tearDown()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
||||
@ -4,7 +4,8 @@ import copy
|
||||
from torch.fx.graph import Graph
|
||||
from torch.fx.node import Node
|
||||
from torch.fx._compatibility import compatibility
|
||||
from typing import Dict, List, Set
|
||||
import torch.utils._pytree as pytree
|
||||
from typing import Dict, List, Set, Any
|
||||
|
||||
__all__ = ['SubgraphMatcher', 'InternalMatch']
|
||||
|
||||
@ -124,7 +125,27 @@ class SubgraphMatcher:
|
||||
nodes_matched.add(gn)
|
||||
return non_overlapping_matches
|
||||
|
||||
def _match_args(self, pn: Any, gn: Any, match: InternalMatch) -> bool:
|
||||
assert not(isinstance(pn, Node) and isinstance(gn, Node)), "pn and gn cannot both be Node"
|
||||
|
||||
if isinstance(pn, Node) and not isinstance(gn, Node):
|
||||
if pn.op == "placeholder":
|
||||
# Check if we've already matched these nodes in the current
|
||||
# traversal
|
||||
if pn in match.nodes_map:
|
||||
return match.nodes_map[pn] == gn
|
||||
|
||||
match.nodes_map[pn] = gn
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
elif not isinstance(pn, Node) and isinstance(gn, Node):
|
||||
return False
|
||||
else:
|
||||
return type(gn) == type(pn) and gn == pn
|
||||
|
||||
def _match_nodes(self, pn: Node, gn: Node, match: InternalMatch) -> bool:
|
||||
assert isinstance(pn, Node) and isinstance(gn, Node), "pn and gn must be Node"
|
||||
|
||||
# Check if we've already matched these nodes in the current
|
||||
# traversal
|
||||
@ -146,9 +167,23 @@ class SubgraphMatcher:
|
||||
|
||||
# Recursively traverse upwards to check if `pn` is a true
|
||||
# match for `gn`
|
||||
match_found = (len(pn.all_input_nodes) == len(gn.all_input_nodes) and
|
||||
all(self._match_nodes(pn_, gn_, match) for pn_, gn_
|
||||
in zip(pn.all_input_nodes, gn.all_input_nodes)))
|
||||
match_found = True
|
||||
|
||||
pn_flatten_args, _ = pytree.tree_flatten(pn.args)
|
||||
gn_flatten_args, _ = pytree.tree_flatten(gn.args)
|
||||
|
||||
if len(pn_flatten_args) == len(gn_flatten_args):
|
||||
for pn_, gn_ in zip(pn_flatten_args, gn_flatten_args):
|
||||
if isinstance(gn_, Node) and isinstance(pn_, Node):
|
||||
matched = self._match_nodes(pn_, gn_, match)
|
||||
else:
|
||||
matched = self._match_args(pn_, gn_, match)
|
||||
|
||||
if not matched:
|
||||
match_found = False
|
||||
break
|
||||
else:
|
||||
match_found = False
|
||||
|
||||
if not match_found:
|
||||
match.nodes_map.pop(pn)
|
||||
|
||||
Reference in New Issue
Block a user