# Owner(s): ["module: fx.passes"]

from dataclasses import dataclass
import operator
import logging
import sys

import torch
from torch.fx._symbolic_trace import symbolic_trace

from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
from torch.fx.passes.operator_support import OperatorSupport
from torch.fx.passes.utils.fuser_utils import fuse_by_partitions
from torch.fx.passes.utils.matcher_utils import SubgraphMatcher

from torch.testing._internal.common_utils import run_tests, parametrize, instantiate_parametrized_tests
from torch.testing._internal.jit_utils import JitTestCase

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger(__name__)

class TestModule(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)
        self.linear2 = torch.nn.Linear(4, 4)
        self.param = torch.nn.Parameter(torch.rand(4, 4))

    def forward(self, a, b, c):
        add = a + b

        linear_1 = self.linear(add)

        add_1 = add + c
        add_2 = add_1 + self.param
        add_3 = add_1 + linear_1
        add_4 = add_2 + add_3

        linear_2 = self.linear2(add_4)

        add_5 = linear_2 + add_4
        add_6 = add_5 + a
        relu = add_6.relu()

        return add_4, add_6, relu

class TestDeepModule(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, a, b, c):
        o = a + b
        o = o + 1.0

        # Tests that passes avoid DFS, since Python has a max recursion depth.
        for _ in range(sys.getrecursionlimit() + 1):
            o = o - c

        return o

class TestPartitionFunctions:
    @staticmethod
    def forward1(a, b, c):
        add = a + b
        add_1 = add + b
        add_2 = add_1 + c
        relu_1 = add_2.relu()
        add_3 = add_1 + add_2
        add_4 = add_1 + relu_1 + add_3
        relu_2 = add_4.relu()
        add_5 = relu_2 + add_4
        add_6 = add_5 + add_4
        return add_4, add_6

    @staticmethod
    def forward2(a, b, _):
        add = a + b
        add_1 = add + b
        relu_1 = add_1.relu()  # unsupported node; fusion is blocked here
        add_3 = add_1 + relu_1
        add_4 = add_1 + add_3
        return add_4, add_1

    @staticmethod
    def forward3(a, b, c):
        add = a + b
        add_1 = a + c
        add_2 = b + c
        return add, add_1, add_2

    @staticmethod
    def forward4(a, b, c):
        add = a + b
        add_1 = a + c
        add_2 = b + c
        return torch.where(add > 0, add_1, add_2)

    @staticmethod
    def forward5(a, b, c):
        # add should be fused with the right branch, as the left branch is not supported
        add = a + 1
        # left branch
        relu = add.relu()
        # right branch
        add_1 = add + 2
        return relu, add_1

    @staticmethod
    def forward6(a, b, c):
        # add should have its own partition, as neither branch is supported
        add = a + 1
        # left branch
        relu = add.relu()
        # right branch
        relu_1 = add.relu()
        return relu, relu_1

    @staticmethod
    def forward7(a, b, c):
        # both branches are supported, all adds should be fused together
        add = a + 1
        # left branch
        add_1 = add + 2
        # right branch is larger
        add_2 = add + 1
        add_3 = add_2 + 1
        return add_3, add_1

    @staticmethod
    def forward8(a, b, c):
        # both branches are in the same partition, add should join the same partition
        add = a + 1
        # left branch
        add_1 = add + 2
        # right branch
        add_2 = add + 1
        # left and right branches merge
        add_3 = add_2 + add_1

        return add_3

    @staticmethod
    def forward9(a, b, c):
        add = a + 1
        # branch 1
        add_1 = add + 1
        # branch 2
        add_2 = add + 1
        # branch 3
        add_3 = add + 1
        out = torch.stack([add_1, add_2, add_3])
        return out

    @staticmethod
    def forward10(a, b, c):
        add = a + 1
        # branch 1
        add_1 = add + 1
        # branch 2
        add_2 = add + 1
        # branch 3: depends on branch 2
        add_3 = add + add_2
        out = torch.stack([add_1, add_2, add_3])
        return out

    @staticmethod
    def forward11(a, b, c):
        add = a + 1
        # branch 1
        add_1 = add.relu()
        # branch 2: depends on branch 1
        add_2 = add + add_1
        # branch 3
        add_3 = add.relu()
        out = torch.stack([add_1, add_2, add_3])
        return out

    @staticmethod
    def forward12(a, b, c):
        b0 = a + 1.0
        c0 = a + 1.5
        x0 = b0.relu()
        x1 = c0.relu()
        b1 = b0 + x1
        c1 = c0 + 1.2
        # c2 depends on x0 (and hence, transitively, on b0). When we merge
        # {c0, c1, c2}, this dependency must be lifted to the fusion group and
        # reflected in the decision not to fuse b0 & b1, which would otherwise
        # form a cyclic dependency in the new graph.
        c2 = x0 + c0
        return b1, c2

    @staticmethod
    def forward13(a, b, c):
        a0, a1, a2, a3 = a.split(1, 0)
        b1 = a0 + b
        c1 = a1 + c
        return b1 + c1

    @staticmethod
    def forward14(a, b, c):
        a0, a1 = torch.ops.aten.std_mean(a)
        out = a0 + 1.0
        return out

    @staticmethod
    def forward15(a, b, c):
        a0 = torch.ops.aten.view(a, [2, 2])
        a1 = torch.ops.aten.permute(a0, [1, 0])
        a2 = a1 + 1.0
        a3 = torch.ops.aten.permute(a2, [1, 0])
        a4 = a3 + 1.0
        a5 = torch.ops.aten.permute(a4, [1, 0])
        return torch.ops.aten.permute(a5, [1, 0])

    @staticmethod
    def forward16(a, b, c):
        a0 = a - 1.0
        a1 = torch.ops.aten.view(a0, [2, 2])
        a2 = torch.ops.aten.permute(a1, [1, 0])
        a3 = a2 + 1.0
        a4 = torch.ops.aten.permute(a3, [1, 0])
        a5 = a4 + 1.0
        a6 = torch.ops.aten.permute(a5, [1, 0])
        a7 = torch.ops.aten.permute(a6, [1, 0])
        return a7 - 1.0

    @staticmethod
    def forward17(a, b, c, d, e, f):
        a0 = a + b
        a1 = c + d
        a2 = e + f
        return a0, a1, a2

    @staticmethod
    def forward18(a, b, c):
        a0, a1 = torch.ops.aten.var_mean(a)
        return a0

# A mock OperatorSupport class that supports only a small set of ops:
# add, getitem, view, permute, and std_mean.
class MockOperatorSupport(OperatorSupport):
    def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
        return (node.op == "call_function" and
                node.target in {operator.add, operator.getitem,
                                torch.ops.aten.view,
                                torch.ops.aten.permute,
                                torch.ops.aten.std_mean})
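
# A minimal usage sketch (for orientation only; the tests below exercise this
# for real): the partitioner consults is_node_supported() for each node to
# decide partition membership.
#
#     traced = symbolic_trace(TestPartitionFunctions.forward3)
#     partitioner = CapabilityBasedPartitioner(traced, MockOperatorSupport())
#     partitions = partitioner.propose_partitions()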

@instantiate_parametrized_tests
class TestFXGraphPasses(JitTestCase):

    @parametrize("fn, expected_partition, bookend_non_compute_pass", [
        (TestPartitionFunctions.forward1, [["add_7", "add_6"], ["add_5", "add_4", "add_3"], ["add_2", "add_1", "add"]], False),
        (TestPartitionFunctions.forward2, [["add_3", "add_2"], ["add_1", "add"]], False),

        # 1 horizontal fusion with common producer
        (TestPartitionFunctions.forward3, [["add_2", "add_1", "add"]], False),
        (TestPartitionFunctions.forward4, [["add_2", "add_1", "add"]], False),

        # 2 two-branch cases
        (TestPartitionFunctions.forward5, [["add_1", "add"]], False),
        (TestPartitionFunctions.forward6, [["add"]], False),
        (TestPartitionFunctions.forward7, [["add_3", "add_2", "add", "add_1"]], False),
        (TestPartitionFunctions.forward8, [["add_3", "add_2", "add", "add_1"]], False),

        # 3 three-branch cases
        (TestPartitionFunctions.forward9, [['add_3', 'add_2', 'add_1', 'add']], False),
        (TestPartitionFunctions.forward10, [['add_3', 'add_2', 'add', 'add_1']], False),
        (TestPartitionFunctions.forward11, [['add_1'], ['add']], False),

        # 4 not necessarily the only valid partitioning; just verifies that
        # there is no cyclic dependency after partitioning
        (TestPartitionFunctions.forward12, [["add_2", "add_3", "add_4"], ["add", "add_1"]], False),

        # 5 getitem special case
        (TestPartitionFunctions.forward13, [["add_2", "add_1", "add"]], False),
        (TestPartitionFunctions.forward14, [["add", "std_mean", "getitem", "getitem_1"]], False),

        # 6 bookend non_compute pass
        (TestPartitionFunctions.forward15, [["permute_1", "add_1", "add"]], True),
        (TestPartitionFunctions.forward15, [['add_1', 'add', 'permute_1', 'view', 'permute_2', 'permute_3', 'permute']], False),
        (TestPartitionFunctions.forward16, [["permute_1", "add_1", "add"]], True),
        (TestPartitionFunctions.forward16, [['add_1', 'add', 'permute_1', 'view', 'permute_2', 'permute_3', 'permute']], False),

        # should yield no partitions, not a partition with no nodes
        (TestPartitionFunctions.forward18, [], False),
    ])
    def test_partitioner(self, fn, expected_partition, bookend_non_compute_pass):
        traced = symbolic_trace(fn)

        non_compute_ops = []
        if bookend_non_compute_pass:
            non_compute_ops = ["torch.ops.aten.view", "torch.ops.aten.permute"]

        supported_ops = MockOperatorSupport()
        partitioner = CapabilityBasedPartitioner(traced,
                                                 supported_ops,
                                                 allows_single_node_partition=True,
                                                 non_compute_ops=non_compute_ops)
        partitions = partitioner.propose_partitions()
        if bookend_non_compute_pass:
            partitioner.remove_bookend_non_compute_ops(partitions)

        partitions_name = [[node.name for node in partition.nodes] for partition in partitions]
        assert len(partitions_name) == len(expected_partition)
        for i in range(len(partitions_name)):
            assert set(partitions_name[i]) == set(expected_partition[i])

        fused_graph = partitioner.fuse_partitions(partitions)

        a, b, c = torch.rand(4), torch.rand(4), torch.rand(4)

        expected = fn(a, b, c)
        result = fused_graph(a, b, c)
        torch.testing.assert_close(expected, result)

    @parametrize("fn, expected_partition", [
        (TestPartitionFunctions.forward17, [['add', 'add_1', 'add_2']]),
    ])
    def test_partitioner_independent_output(self, fn, expected_partition):
        traced = symbolic_trace(fn)

        supported_ops = MockOperatorSupport()
        partitioner = CapabilityBasedPartitioner(traced,
                                                 supported_ops,
                                                 allows_single_node_partition=True)
        partitions = partitioner.propose_partitions()
        partitions_name = [[node.name for node in partition.nodes] for partition in partitions]
        assert len(partitions_name) == len(expected_partition)
        for i in range(len(partitions_name)):
            assert set(partitions_name[i]) == set(expected_partition[i])

        fused_graph = partitioner.fuse_partitions(partitions)

        a, b, c, d, e, f = torch.rand(4), torch.rand(4), torch.rand(4), torch.rand(4), torch.rand(4), torch.rand(4)

        expected = fn(a, b, c, d, e, f)
        result = fused_graph(a, b, c, d, e, f)
        torch.testing.assert_close(expected, result)

    @parametrize("partition", [
        [['add', 'add_1'], ['add_5', 'add_6']],
        [['add', 'add_1', 'add_2']],  # vertical fusion
        [['add_2', 'add_3']],  # horizontal fusion
        [['add_3', 'add_4']],
        [['add_6', 'add_5']],  # arbitrary node order
        [['add_4', 'add_1', 'add_3', 'add_2']],  # arbitrary node order
        [['add_5', 'add_6'], ['add_1', 'add_2', 'add_3', 'add_4']],  # arbitrary partition order
        [['add_5', 'linear2']],  # includes call_function + call_module nodes
        [['add_6', 'relu']],  # includes call_function + call_method nodes
        [['param', 'add_2']],  # includes get_attr + call_function nodes
        [['param', 'add_1', 'linear']],  # includes get_attr + call_function + call_module nodes
        [["add", "linear", "add_1", "param", "add_2", "add_3", "add_4", "linear2", "add_5", "add_6", "relu"]],  # full graph
    ])
    def test_fuser_util(self, partition):
        m = TestModule()
        gm = symbolic_trace(m)

        nodes_by_name = {node.name: node for node in gm.graph.nodes}

        partitions = []
        for node_names in partition:
            partitions.append(dict.fromkeys([nodes_by_name[name] for name in node_names]))

        fused_graph = fuse_by_partitions(gm, partitions)

        a, b, c = torch.rand(4), torch.rand(4), torch.rand(4)

        expected = m(a, b, c)
        result = fused_graph(a, b, c)

        torch.testing.assert_close(expected, result)

    @parametrize("partition", [
        [['add', 'add_1'], ['add_1', 'add_5', 'add_6']],  # add_1 exists in multiple partitions
        [['add', 'add_1', 'add_3']],  # invalid partition: circular dependency
        [['add_4', 'add_5']],  # invalid partition: circular dependency
        [['relu', 'add_5']],  # invalid partition: circular dependency
    ])
    def test_fuser_util_xfail(self, partition):
        m = TestModule()
        gm = symbolic_trace(m)

        nodes_by_name = {node.name: node for node in gm.graph.nodes}

        partitions = []
        for node_names in partition:
            partitions.append(dict.fromkeys([nodes_by_name[name] for name in node_names]))

        with self.assertRaises(Exception):
            fuse_by_partitions(gm, partitions)

    def test_fuser_pass_deep_model(self):
        m = TestDeepModule()
        traced = symbolic_trace(m)

        supported_ops = MockOperatorSupport()
        partitioner = CapabilityBasedPartitioner(traced,
                                                 supported_ops,
                                                 allows_single_node_partition=True)
        # The model is deeper than Python's recursion limit, so simply
        # completing without a RecursionError is the assertion here.
        partitions = partitioner.propose_partitions()

@dataclass
class TestCase:
    match_output: bool
    match_placeholder: bool
    num_matches: int
    remove_overlapping_matches: bool = True
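
# match_output, match_placeholder, and remove_overlapping_matches mirror the
# SubgraphMatcher constructor arguments of the same names (see
# test_subgraph_matcher below); num_matches is the number of matches that
# SubgraphMatcher.match() is expected to return for that configuration.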

class SingleNodePattern:
    @staticmethod
    def forward(x):
        val = torch.neg(x)
        return torch.add(val, val)

    @staticmethod
    def pattern(a):
        return torch.neg(a)

    test_cases = [
        # match_output, match_placeholder, num_matches
        TestCase(False, False, 1),
        TestCase(True, False, 0),
        TestCase(False, True, 1),
        TestCase(True, True, 0)
    ]

class SimplePattern:
    @staticmethod
    def forward(x, w1, w2):
        m1 = torch.cat([w1, w2]).sum()
        m2 = torch.cat([w2, w1]).sum()
        m3 = torch.cat([m1, m2]).sum()
        return x + torch.max(m1) + torch.max(m2) + m3

    @staticmethod
    def pattern(a, b):
        return torch.cat([a, b]).sum()

    test_cases = [
        # match_output, match_placeholder, num_matches
        TestCase(False, False, 3),
        TestCase(True, False, 0),
        TestCase(False, True, 2),
        TestCase(True, True, 0)
    ]

class SimpleFullGraphMatching:
    @staticmethod
    def forward(x):
        a = torch.neg(x)
        return torch.add(a, a)

    @staticmethod
    def pattern(x):
        a = torch.neg(x)
        return torch.add(a, a)

    test_cases = [
        # match_output, match_placeholder, num_matches
        TestCase(False, False, 1),
        TestCase(True, False, 1),
        TestCase(False, True, 1),
        TestCase(True, True, 1)
    ]

class DiamondShapePatternTestCase:
    @staticmethod
    def forward(x):
        a = torch.neg(x)

        a = a.relu()
        left = a.sigmoid()
        right = a.relu()
        out = left + right

        return out

    @staticmethod
    def pattern(a):
        a = a.relu()
        left = a.sigmoid()
        right = a.relu()
        out = left + right
        return out

    test_cases = [
        # match_output, match_placeholder, num_matches
        TestCase(False, False, 1),
        TestCase(True, False, 1),
        TestCase(False, True, 0),
        TestCase(True, True, 0)
    ]

class NonFullyContainedMatches:
    @staticmethod
    def forward(x, w1, w2, b1, b2):
        # fully contained matched subgraph
        m1 = torch.cat([w1, w2])
        m2 = torch.cat([x, b2])
        t0 = torch.addmm(b1, m1, m2.t())
        t0_sum = torch.sum(t0)  # this use of t0 does not leak out of the match

        # leaking matched subgraph: m3 is used outside the match
        m3 = torch.cat([w1, w2])
        m4 = torch.cat([x, b2])
        t1 = torch.addmm(b1, m3, m4.t())
        m3_sum = torch.sum(m3)

        return t0_sum, m3_sum

    @staticmethod
    def pattern(x, w1, w2, b1, b2):
        m1 = torch.cat([w1, w2])
        m2 = torch.cat([x, b2])
        return torch.addmm(b1, m1, m2.t())

    test_cases = [
        # match_output, match_placeholder, num_matches
        TestCase(False, False, 1),

        TestCase(True, False, 0),

        TestCase(False, True, 1),  # a leaked use of a placeholder does not count as leaking
    ]

class ChainRepeatedPattern:
    @staticmethod
    def forward(x):
        x = torch.sigmoid(x)
        x = torch.sigmoid(x)
        x = torch.sigmoid(x)
        return torch.sigmoid(x)

    @staticmethod
    def pattern(x):
        return torch.sigmoid(torch.sigmoid(x))

    test_cases = [
        # match_output, match_placeholder, num_matches
        TestCase(False, False, 3, remove_overlapping_matches=False),
        TestCase(False, False, 2, remove_overlapping_matches=True),
        TestCase(True, False, 1),
        TestCase(False, True, 1),
        TestCase(True, True, 0)
    ]

class QuantizationModel:
    @staticmethod
    def forward(x):
        x += 3
        x = x.dequantize()
        x = torch.sigmoid(x)
        x = x.to(torch.float16)
        return x

    @staticmethod
    def pattern(x):
        x = x.dequantize()
        x = torch.sigmoid(x)
        x = x.to(torch.float16)
        return x

    test_cases = [
        # match_output, match_placeholder, num_matches
        TestCase(False, False, 1),
        TestCase(True, False, 1),
        TestCase(False, True, 0),
        TestCase(True, True, 0)
    ]

class MultipleOutputsWithDependency:
    @staticmethod
    def forward(x):
        y = x.relu()
        z = y.sigmoid()
        return z, y

    @staticmethod
    def pattern(a):
        b = a.relu()
        c = b.sigmoid()
        return b, c  # outputs have a data dependency

    test_cases = [
        # match_output, match_placeholder, num_matches
        TestCase(False, False, 1),
        TestCase(True, False, 0),
        TestCase(False, True, 1),
        TestCase(True, True, 0)
    ]

class MultipleOutputsWithoutDependency:
    @staticmethod
    def forward(x):
        x = x + 1

        # target subgraph to match
        x = x.relu()
        z = x.sum()
        y = x.sigmoid()

        out = y.sigmoid() + z.sum()
        return out

    @staticmethod
    def pattern(a):
        a = a.relu()
        b = a.sigmoid()
        c = a.sum()
        return b, c

    test_cases = [
        # match_output, match_placeholder, num_matches
        TestCase(False, False, 1),
        TestCase(True, False, 0),
        TestCase(False, True, 0),
        TestCase(True, True, 0)
    ]

class MultipleOutputsMultipleOverlappingMatches:
    @staticmethod
    def forward(x):
        x = x + 1

        # target subgraph to match
        x = x.relu()
        z = x.sum()
        z1 = x.sum()
        y = x.sigmoid()
        y1 = x.sigmoid()

        return z + z1 + y + y1

    @staticmethod
    def pattern(a):
        a = a.relu()
        b = a.sigmoid()
        c = a.sum()
        return a, b, c

    test_cases = [
        # match_output, match_placeholder, num_matches
        TestCase(False, False, 4, remove_overlapping_matches=False),
        TestCase(False, False, 1, remove_overlapping_matches=True),
    ]

class MultipleOutputsMultipleNonOverlappingMatches:
    @staticmethod
    def forward(x):
        x = x + 1

        # target subgraph to match
        x = x.relu()
        z = x.sum()
        y = x.sigmoid()

        x = x.relu()
        z1 = x.sum()
        y1 = x.sigmoid()

        return z + z1 + y + y1

    @staticmethod
    def pattern(a):
        a = a.relu()
        b = a.sigmoid()
        c = a.sum()
        return b, c

    test_cases = [
        # match_output, match_placeholder, num_matches
        TestCase(False, False, 1),
    ]

class MultipleOutputsIdenticalAnchor:
    @staticmethod
    def forward(x):
        x = x + 1

        # target subgraph to match
        x = x.relu()
        y = x.sigmoid()
        y1 = x.sigmoid()

        return y, y1

    @staticmethod
    def pattern(a):
        a = a.relu()
        b = a.sigmoid()
        b1 = a.sigmoid()
        return b, b1

    test_cases = [
        # match_output, match_placeholder, num_matches
        # (False, False, 2),  # FIXME: currently still matches to 2, should fix to 1
        TestCase(True, False, 1),
        TestCase(False, True, 0),
    ]

class MultipleOutputsHorizontalPattern:
    @staticmethod
    def forward(x):
        x = x + 1

        # target subgraph to match
        y1 = x.relu()
        y2 = x.sigmoid()

        return y1, y2

    @staticmethod
    def pattern(a):
        b1 = a.relu()
        b2 = a.sigmoid()

        return b1, b2

    test_cases = [
        # match_output, match_placeholder, num_matches
        TestCase(False, False, 1),
        TestCase(True, False, 1),
        TestCase(False, True, 0),
        TestCase(True, True, 0)
    ]

class MultiOutputWithWithInvalidMatches:
    @staticmethod
    def forward(x):
        res0 = torch.nn.functional.linear(x, torch.rand(3, 3))
        res1 = torch.sigmoid(res0)
        res2 = res0 * res1
        res3 = torch.sum(res2, dim=1)
        return res3

    @staticmethod
    def pattern(a, b, c):
        lin_res = torch.nn.functional.linear(a, b)
        mul_res = lin_res * c
        return lin_res, mul_res

    test_cases = [
        # match_output, match_placeholder, num_matches
        TestCase(False, False, 0),
        TestCase(True, False, 0),
        TestCase(False, True, 0),
    ]

class QuantizationFp8Pattern:
    @classmethod
    def setup(cls):
        cls.quantization = torch.library.Library("fp8_quantization", "DEF")  # noqa: TOR901
        cls.quantization.define("quantize_per_tensor_affine_fp8(Tensor self, int dtype, float scale) -> Tensor")
        cls.quantization.define("dequantize_per_tensor_affine_fp8(Tensor self, int dtype, float scale) -> Tensor")

    @classmethod
    def tearDown(cls):
        del cls.quantization

    @staticmethod
    def forward(self, arg0_1, arg1_1):
        qt = torch.ops.fp8_quantization
        _scale_0 = self._scale_0
        quantize_per_tensor_affine_fp8 = qt.quantize_per_tensor_affine_fp8(arg0_1, 0, _scale_0)
        dequantize_per_tensor_affine_fp8 = qt.dequantize_per_tensor_affine_fp8(quantize_per_tensor_affine_fp8, 0, _scale_0)
        _scale_1 = self._scale_0
        quantize_per_tensor_affine_fp8_1 = qt.quantize_per_tensor_affine_fp8(arg1_1, 0, _scale_1)
        dequantize_per_tensor_affine_fp8_1 = qt.dequantize_per_tensor_affine_fp8(quantize_per_tensor_affine_fp8_1, 0, _scale_1)
        add = torch.ops.aten.add.Tensor(dequantize_per_tensor_affine_fp8, dequantize_per_tensor_affine_fp8_1)
        _scale_2 = self._scale_0
        quantize_per_tensor_affine_fp8_2 = qt.quantize_per_tensor_affine_fp8(add, 0, _scale_2)
        dequantize_per_tensor_affine_fp8_2 = qt.dequantize_per_tensor_affine_fp8(quantize_per_tensor_affine_fp8_2, 0, _scale_2)
        return dequantize_per_tensor_affine_fp8_2

    @staticmethod
    def pattern(a, a_dtype, a_scale, b, b_dtype, b_scale, out_scale):
        qt = torch.ops.fp8_quantization
        a = qt.dequantize_per_tensor_affine_fp8(a, a_dtype, a_scale)
        b = qt.dequantize_per_tensor_affine_fp8(b, b_dtype, b_scale)
        output = torch.ops.aten.add.Tensor(a, b)
        output = qt.quantize_per_tensor_affine_fp8(output, a_dtype, out_scale)
        return output

    test_cases = [
        # match_output, match_placeholder, num_matches
        TestCase(False, False, 1),
    ]

class NoAnchorFound:
    # This test case covers a pattern for which no matching anchor is found in
    # the target graph. An "anchor" is the starting point of pattern matching;
    # it is usually one of the boundary nodes feeding the pattern's output
    # (see the sketch after this class).
    @staticmethod
    def forward(x):
        x = x + 1
        return x

    @staticmethod
    def pattern(a):
        b1 = a.relu()
        return b1

    test_cases = [
        # match_output, match_placeholder, num_matches
        TestCase(False, False, 0),
        TestCase(True, False, 0),
        TestCase(False, True, 0),
        TestCase(True, True, 0)
    ]
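
# A rough sketch of the anchor concept (assumed matcher behavior, for
# orientation only): for pattern(a) -> a.relu(), the matcher anchors at the
# relu call and scans the target graph for nodes with the same target. The
# forward() above contains only an add, so every configuration expects zero
# matches.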

@instantiate_parametrized_tests
class TestFXMatcherUtils(JitTestCase):

    @parametrize("test_model", [
        SingleNodePattern,
        SimplePattern,
        SimpleFullGraphMatching,
        DiamondShapePatternTestCase,
        NonFullyContainedMatches,
        ChainRepeatedPattern,
        QuantizationModel,
        MultipleOutputsWithDependency,
        MultipleOutputsWithoutDependency,
        MultipleOutputsMultipleOverlappingMatches,
        MultipleOutputsMultipleNonOverlappingMatches,
        MultipleOutputsIdenticalAnchor,
        MultipleOutputsHorizontalPattern,
        MultiOutputWithWithInvalidMatches,
        QuantizationFp8Pattern,
        NoAnchorFound,
    ])
    def test_subgraph_matcher(self, test_model):

        setup = getattr(test_model, "setup", None)
        if callable(setup):
            setup()

        traced = symbolic_trace(test_model.forward)
        pattern_traced = symbolic_trace(test_model.pattern)

        for test_case in test_model.test_cases:

            matcher = SubgraphMatcher(pattern_traced.graph,
                                      match_output=test_case.match_output,
                                      match_placeholder=test_case.match_placeholder,
                                      remove_overlapping_matches=test_case.remove_overlapping_matches)
            matches = matcher.match(traced.graph)

            assert len(matches) == test_case.num_matches

            for match in matches:
                for node in pattern_traced.graph.nodes:
                    if not test_case.match_placeholder and node.op == "placeholder":
                        continue
                    if not test_case.match_output and node.op == "output":
                        continue
                    assert node in match.nodes_map

        tearDown = getattr(test_model, "tearDown", None)
        if callable(tearDown):
            tearDown()


if __name__ == "__main__":
    run_tests()