mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Revert "[inductor][pattern matcher] revise mkldnn pattern matcher UT (#141334)"
This reverts commit 942a2438e263a2632b8934dd245060c9b237f4be. Reverted https://github.com/pytorch/pytorch/pull/141334 on behalf of https://github.com/atalman due to Failing internally ([comment](https://github.com/pytorch/pytorch/pull/141334#issuecomment-2512891840))
This commit is contained in:
@ -135,18 +135,23 @@ class TestPatternMatcherBase(TestCase):
|
||||
self,
|
||||
mod,
|
||||
inputs,
|
||||
matcher_check_fn,
|
||||
matcher_count=None,
|
||||
matcher_nodes=None,
|
||||
atol=1e-5,
|
||||
rtol=1.3e-6,
|
||||
check_autocast=torch.float32,
|
||||
check_quantization=False,
|
||||
is_qat=False,
|
||||
matcher_check_fn=None,
|
||||
dtype=None,
|
||||
is_dynamic=False,
|
||||
quantizer=None,
|
||||
):
|
||||
counters.clear()
|
||||
torch._dynamo.reset()
|
||||
assert matcher_check_fn is not None or (
|
||||
matcher_count is not None and matcher_nodes is not None
|
||||
)
|
||||
if (
|
||||
check_autocast == torch.bfloat16
|
||||
and torch.ops.mkldnn._is_mkldnn_bf16_supported()
|
||||
@ -169,14 +174,34 @@ class TestPatternMatcherBase(TestCase):
|
||||
)
|
||||
with torch.no_grad(), maybe_autocast:
|
||||
_ = torch.compile(convert_model)(*inputs)
|
||||
matcher_check_fn()
|
||||
if matcher_count is not None:
|
||||
self.assertEqual(
|
||||
counters["inductor"]["pattern_matcher_count"], matcher_count
|
||||
)
|
||||
if matcher_nodes is not None:
|
||||
self.assertEqual(
|
||||
counters["inductor"]["pattern_matcher_nodes"],
|
||||
matcher_nodes,
|
||||
)
|
||||
if matcher_check_fn is not None:
|
||||
matcher_check_fn()
|
||||
else:
|
||||
with torch.no_grad(), maybe_autocast:
|
||||
clone_inputs = self._clone_inputs(inputs)
|
||||
expected = mod(*inputs)
|
||||
actual = torch.compile(mod)(*clone_inputs)
|
||||
torch.testing.assert_close(actual, expected, atol=atol, rtol=rtol)
|
||||
matcher_check_fn()
|
||||
if matcher_count is not None:
|
||||
self.assertEqual(
|
||||
counters["inductor"]["pattern_matcher_count"], matcher_count
|
||||
)
|
||||
if matcher_nodes is not None:
|
||||
self.assertEqual(
|
||||
counters["inductor"]["pattern_matcher_nodes"],
|
||||
matcher_nodes,
|
||||
)
|
||||
if matcher_check_fn is not None:
|
||||
matcher_check_fn()
|
||||
|
||||
def _test_code_common(
|
||||
self,
|
||||
@ -268,24 +293,15 @@ class TestPatternMatcher(TestPatternMatcherBase):
|
||||
.add(1)
|
||||
.to(memory_format=memory_format)
|
||||
)
|
||||
|
||||
def matcher_check_fn():
|
||||
match_nodes = unary_list[unary_fn]
|
||||
if dtype in (
|
||||
torch.float16,
|
||||
torch.bfloat16,
|
||||
) and self._check_unary_is_decomposed(unary_fn):
|
||||
# Has extra dtype conversion nodes for autocast.
|
||||
match_nodes += 2
|
||||
self.assertEqual(
|
||||
counters["inductor"]["mkldnn_unary_fusion_matcher_nodes"],
|
||||
match_nodes,
|
||||
)
|
||||
self.assertEqual(
|
||||
counters["inductor"]["mkldnn_conv_weight_pack_matcher_count"], 1
|
||||
)
|
||||
|
||||
self._test_common(mod, (v,), matcher_check_fn, check_autocast=dtype)
|
||||
# Add 1 for weight packing pass.
|
||||
match_nodes = unary_list[unary_fn] + 1
|
||||
if dtype in (
|
||||
torch.float16,
|
||||
torch.bfloat16,
|
||||
) and self._check_unary_is_decomposed(unary_fn):
|
||||
# Has extra dtype conversion nodes for autocast.
|
||||
match_nodes += 2
|
||||
self._test_common(mod, (v,), 2, match_nodes, check_autocast=dtype)
|
||||
generated_kernel_count = cal_conv_generated_kernel_number(mod, v, dtype)
|
||||
self.assertEqual(metrics.generated_kernel_count, generated_kernel_count)
|
||||
|
||||
@ -336,21 +352,16 @@ class TestPatternMatcher(TestPatternMatcherBase):
|
||||
# only fuse for linear when the dtype is bf16
|
||||
mod = mod
|
||||
v = torch.randn(2, 10)
|
||||
|
||||
def matcher_check_fn():
|
||||
match_nodes = unary_list[unary_fn]
|
||||
if self._check_unary_is_decomposed(unary_fn):
|
||||
# Has extra dtype conversion nodes for autocast.
|
||||
match_nodes += 2
|
||||
self.assertEqual(
|
||||
counters["inductor"]["mkldnn_unary_fusion_matcher_nodes"],
|
||||
match_nodes,
|
||||
)
|
||||
self.assertEqual(
|
||||
counters["inductor"]["mkldnn_linear_weight_pack_matcher_count"], 1
|
||||
)
|
||||
|
||||
self._test_common(mod, (v,), matcher_check_fn, check_autocast=dtype)
|
||||
# packing pass + unary fusion.
|
||||
matcher_count = 2
|
||||
# Add 1 for weight packing pass.
|
||||
matcher_nodes = unary_list[unary_fn] + 1
|
||||
if self._check_unary_is_decomposed(unary_fn):
|
||||
# Has extra dtype conversion nodes for autocast.
|
||||
matcher_nodes += 2
|
||||
self._test_common(
|
||||
mod, (v,), matcher_count, matcher_nodes, check_autocast=dtype
|
||||
)
|
||||
# only generated 1 kernel for "to"
|
||||
self.assertEqual(metrics.generated_kernel_count, 1)
|
||||
|
||||
@ -367,14 +378,10 @@ class TestPatternMatcher(TestPatternMatcherBase):
|
||||
for bias in [True, False]:
|
||||
mod = M(bias=bias).eval()
|
||||
v = torch.randn(2, 10)
|
||||
|
||||
# packing pass.
|
||||
def matcher_check_fn():
|
||||
self.assertEqual(
|
||||
counters["inductor"]["mkldnn_linear_weight_pack_matcher_count"], 1
|
||||
)
|
||||
|
||||
self._test_common(mod, (v,), matcher_check_fn)
|
||||
matcher_count = 1
|
||||
matcher_nodes = 1
|
||||
self._test_common(mod, (v,), matcher_count, matcher_nodes)
|
||||
|
||||
@unittest.skipIf(not TEST_MKL, "Test requires MKL")
|
||||
def test_linear_input_non_contiguous_3D_wo_bias(self):
|
||||
@ -447,29 +454,18 @@ class TestPatternMatcher(TestPatternMatcherBase):
|
||||
metrics.reset()
|
||||
fold_mod = M(dtype, unary_fn, cast_bias=True).eval()
|
||||
v = torch.randn(2, 10)
|
||||
|
||||
def folder_matcher_check_fn():
|
||||
match_nodes = unary_list[unary_fn]
|
||||
if self._check_unary_is_decomposed(unary_fn):
|
||||
# Has extra dtype conversion nodes for autocast.
|
||||
match_nodes += 2
|
||||
# we have 2 linears, so we double the matcher_count/nodes
|
||||
self.assertEqual(
|
||||
counters["inductor"]["mkldnn_unary_fusion_matcher_count"], 2
|
||||
)
|
||||
self.assertEqual(
|
||||
counters["inductor"]["mkldnn_unary_fusion_matcher_nodes"],
|
||||
match_nodes * 2,
|
||||
)
|
||||
self.assertEqual(
|
||||
counters["inductor"]["mkldnn_linear_weight_pack_matcher_count"], 2
|
||||
)
|
||||
self.assertEqual(counters["inductor"]["binary_folding"], 2)
|
||||
|
||||
matcher_count = 3
|
||||
# Add 1 for weight packing pass, add 2 for bias folding pass per linear.
|
||||
matcher_nodes = unary_list[unary_fn] + 3
|
||||
if self._check_unary_is_decomposed(unary_fn):
|
||||
# Has extra dtype conversion nodes for autocast.
|
||||
matcher_nodes += 2
|
||||
# we have 2 linears, so we double the matcher_count/nodes
|
||||
self._test_common(
|
||||
fold_mod,
|
||||
(v,),
|
||||
folder_matcher_check_fn,
|
||||
matcher_count * 2,
|
||||
matcher_nodes * 2,
|
||||
check_autocast=dtype,
|
||||
)
|
||||
self.assertEqual(metrics.generated_kernel_count, 1)
|
||||
@ -477,13 +473,7 @@ class TestPatternMatcher(TestPatternMatcherBase):
|
||||
# https://github.com/pytorch/pytorch/pull/129138
|
||||
metrics.reset()
|
||||
mod = M(dtype, unary_fn, cast_bias=False).eval()
|
||||
|
||||
def matcher_check_fn():
|
||||
self.assertEqual(
|
||||
counters["inductor"]["mkldnn_linear_weight_pack_matcher_count"], 2
|
||||
)
|
||||
|
||||
self._test_common(mod, (v,), matcher_check_fn, check_autocast=dtype)
|
||||
self._test_common(mod, (v,), 2, 2, check_autocast=dtype)
|
||||
# 1 kernel for "to_lowp", 2 kernels for unary ops
|
||||
self.assertEqual(metrics.generated_kernel_count, 3)
|
||||
|
||||
@ -537,24 +527,15 @@ class TestPatternMatcher(TestPatternMatcherBase):
|
||||
v = torch.randn(x_shape, dtype=torch.float32).to(
|
||||
memory_format=memory_format
|
||||
)
|
||||
|
||||
def matcher_check_fn():
|
||||
match_nodes = unary_list[unary_fn]
|
||||
if dtype in (
|
||||
torch.float16,
|
||||
torch.bfloat16,
|
||||
) and self._check_unary_is_decomposed(unary_fn):
|
||||
# Has extra dtype conversion nodes for autocast.
|
||||
match_nodes += 2
|
||||
self.assertEqual(
|
||||
counters["inductor"]["mkldnn_unary_fusion_matcher_nodes"],
|
||||
match_nodes,
|
||||
)
|
||||
self.assertEqual(
|
||||
counters["inductor"]["mkldnn_conv_weight_pack_matcher_count"], 1
|
||||
)
|
||||
|
||||
self._test_common(mod, (v,), matcher_check_fn, check_autocast=dtype)
|
||||
# Add 1 for weight packing pass.
|
||||
match_nodes = unary_list[unary_fn] + 1
|
||||
if dtype in (
|
||||
torch.float16,
|
||||
torch.bfloat16,
|
||||
) and self._check_unary_is_decomposed(unary_fn):
|
||||
# Has extra dtype conversion nodes for autocast.
|
||||
match_nodes += 2
|
||||
self._test_common(mod, (v,), 2, match_nodes, check_autocast=dtype)
|
||||
generated_kernel_count = cal_conv_generated_kernel_number(mod, v, dtype)
|
||||
self.assertEqual(metrics.generated_kernel_count, generated_kernel_count)
|
||||
|
||||
@ -631,22 +612,13 @@ class TestPatternMatcher(TestPatternMatcherBase):
|
||||
.add(1)
|
||||
.to(memory_format=memory_format)
|
||||
)
|
||||
|
||||
def matcher_check_fn():
|
||||
match_nodes = binary_list[binary_fn][1]
|
||||
if has_relu:
|
||||
match_nodes += 1
|
||||
self.assertEqual(
|
||||
counters["inductor"][
|
||||
"mkldnn_conv_binary_unary_fusion_matcher_nodes"
|
||||
],
|
||||
match_nodes,
|
||||
)
|
||||
self.assertEqual(
|
||||
counters["inductor"]["mkldnn_conv_weight_pack_matcher_count"], 2
|
||||
)
|
||||
|
||||
self._test_common(mod, (v,), matcher_check_fn, check_autocast=dtype)
|
||||
match_count = binary_list[binary_fn][0] + 2
|
||||
match_nodes = binary_list[binary_fn][1]
|
||||
if has_relu:
|
||||
match_nodes += 1
|
||||
self._test_common(
|
||||
mod, (v,), match_count, match_nodes + 2, check_autocast=dtype
|
||||
)
|
||||
generated_kernel_count = cal_conv_generated_kernel_number(mod, v, dtype)
|
||||
self.assertEqual(metrics.generated_kernel_count, generated_kernel_count)
|
||||
|
||||
@ -774,33 +746,26 @@ class TestPatternMatcher(TestPatternMatcherBase):
|
||||
|
||||
for binary_fn, input_shape, bias, dtype in options:
|
||||
metrics.reset()
|
||||
|
||||
def matcher_check_fn():
|
||||
self.assertEqual(
|
||||
counters["inductor"][
|
||||
"mkldnn_conv_binary_unary_fusion_matcher_nodes"
|
||||
],
|
||||
2,
|
||||
)
|
||||
reshape_linear_reshape_match_nodes = 3 if len(input_shape) == 3 else 0
|
||||
self.assertEqual(
|
||||
counters["inductor"]["mkldnn_reshape_linear_reshape_matcher_nodes"],
|
||||
reshape_linear_reshape_match_nodes,
|
||||
)
|
||||
self.assertEqual(
|
||||
counters["inductor"]["mkldnn_linear_weight_pack_matcher_count"], 1
|
||||
)
|
||||
|
||||
# addmm(mm) + (linear+add)
|
||||
match_count = 2
|
||||
match_nodes = 3
|
||||
if len(input_shape) == 3:
|
||||
is_inplace = binary_list[binary_fn][2]
|
||||
# view + linear + view(joint_graph+freeze pass)
|
||||
match_count = match_count + 5 if is_inplace else match_count + 3
|
||||
match_nodes = match_nodes + 8 if is_inplace else match_nodes + 5
|
||||
mod = M(binary_fn, input_shape[-1], out_feature, bias).eval()
|
||||
v = torch.randn(input_shape)
|
||||
other = torch.randn(input_shape[:-1] + [out_feature]).to(dtype)
|
||||
|
||||
self._test_common(
|
||||
mod,
|
||||
(
|
||||
v,
|
||||
other,
|
||||
),
|
||||
matcher_check_fn,
|
||||
match_count,
|
||||
match_nodes,
|
||||
check_autocast=dtype,
|
||||
)
|
||||
self.assertEqual(metrics.generated_kernel_count, 1)
|
||||
@ -872,25 +837,18 @@ class TestPatternMatcher(TestPatternMatcherBase):
|
||||
dtypes.append(torch.bfloat16)
|
||||
if torch.ops.mkldnn._is_mkldnn_fp16_supported():
|
||||
dtypes.append(torch.float16)
|
||||
|
||||
def matcher_check_fn():
|
||||
self.assertEqual(
|
||||
counters["inductor"]["mkldnn_unary_fusion_matcher_nodes"], 7
|
||||
)
|
||||
self.assertEqual(
|
||||
counters["inductor"]["mkldnn_unary_fusion_matcher_count"], 2
|
||||
)
|
||||
self.assertEqual(
|
||||
counters["inductor"]["mkldnn_reshape_linear_reshape_matcher_nodes"], 6
|
||||
)
|
||||
self.assertEqual(
|
||||
counters["inductor"]["mkldnn_linear_weight_pack_matcher_count"], 2
|
||||
)
|
||||
|
||||
for dtype in dtypes:
|
||||
mod = M().to(dtype).eval()
|
||||
v = torch.randn(2, 4, 16).to(dtype)
|
||||
self._test_common(mod, (v,), matcher_check_fn, rtol=1e-2, atol=1e-2)
|
||||
# 1. view(match_count=4, match_nodes=4).
|
||||
# 2. mm to packed linear(match_count=2, match_nodes=2).
|
||||
# 3. view+linear+view to linear(match_count=2, match_nodes=6).
|
||||
# 4. linear+silu fusion(match_count=1, match_nodes=5)
|
||||
# 5. linear+relu fusion(match_count=1, match_nodes=2)
|
||||
|
||||
match_count = 10
|
||||
match_nodes = 19
|
||||
self._test_common(mod, (v,), match_count, match_nodes, rtol=1e-2, atol=1e-2)
|
||||
|
||||
def _qconv2d_test_helper(self, device="cpu", int8_mixed_bf16=False):
|
||||
class M(torch.nn.Module):
|
||||
@ -928,9 +886,9 @@ class TestPatternMatcher(TestPatternMatcherBase):
|
||||
self._test_common(
|
||||
mod,
|
||||
(v,),
|
||||
matcher_check_fn,
|
||||
check_quantization=True,
|
||||
check_autocast=torch.bfloat16 if int8_mixed_bf16 else torch.float,
|
||||
matcher_check_fn=matcher_check_fn,
|
||||
)
|
||||
|
||||
@skipIfNoDynamoSupport
|
||||
@ -1220,9 +1178,9 @@ class TestPatternMatcher(TestPatternMatcherBase):
|
||||
self._test_common(
|
||||
mod,
|
||||
(v,),
|
||||
matcher_check_fn,
|
||||
check_quantization=True,
|
||||
check_autocast=torch.bfloat16 if int8_mixed_bf16 else torch.float,
|
||||
matcher_check_fn=matcher_check_fn,
|
||||
)
|
||||
|
||||
def _qconv2d_add_cpu_test_helper2(self, use_relu=False, int8_mixed_bf16=False):
|
||||
@ -1304,9 +1262,9 @@ class TestPatternMatcher(TestPatternMatcherBase):
|
||||
self._test_common(
|
||||
mod,
|
||||
(x, x2, x3),
|
||||
matcher_check_fn,
|
||||
check_quantization=True,
|
||||
check_autocast=torch.bfloat16 if int8_mixed_bf16 else torch.float,
|
||||
matcher_check_fn=matcher_check_fn,
|
||||
)
|
||||
|
||||
@skipIfNoDynamoSupport
|
||||
@ -1370,8 +1328,8 @@ class TestPatternMatcher(TestPatternMatcherBase):
|
||||
self._test_common(
|
||||
mod,
|
||||
(x1, x2),
|
||||
matcher_check_fn,
|
||||
check_quantization=True,
|
||||
matcher_check_fn=matcher_check_fn,
|
||||
)
|
||||
|
||||
@skipIfNoDynamoSupport
|
||||
@ -1423,8 +1381,8 @@ class TestPatternMatcher(TestPatternMatcherBase):
|
||||
self._test_common(
|
||||
mod,
|
||||
(v,),
|
||||
matcher_check_fn,
|
||||
check_quantization=True,
|
||||
matcher_check_fn=matcher_check_fn,
|
||||
)
|
||||
|
||||
@skipIfNoDynamoSupport
|
||||
@ -1468,8 +1426,8 @@ class TestPatternMatcher(TestPatternMatcherBase):
|
||||
self._test_common(
|
||||
mod,
|
||||
(v,),
|
||||
matcher_check_fn,
|
||||
check_quantization=True,
|
||||
matcher_check_fn=matcher_check_fn,
|
||||
)
|
||||
|
||||
@skipIfNoDynamoSupport
|
||||
@ -1542,8 +1500,8 @@ class TestPatternMatcher(TestPatternMatcherBase):
|
||||
self._test_common(
|
||||
mod,
|
||||
(v,),
|
||||
matcher_check_fn,
|
||||
check_quantization=True,
|
||||
matcher_check_fn=matcher_check_fn,
|
||||
)
|
||||
|
||||
@skipIfNoDynamoSupport
|
||||
@ -1586,9 +1544,9 @@ class TestPatternMatcher(TestPatternMatcherBase):
|
||||
self._test_common(
|
||||
mod,
|
||||
(v,),
|
||||
matcher_check_fn,
|
||||
check_quantization=True,
|
||||
is_qat=True,
|
||||
matcher_check_fn=matcher_check_fn,
|
||||
)
|
||||
|
||||
def _qat_qconv2d_unary_cpu_test_helper(
|
||||
@ -1628,9 +1586,9 @@ class TestPatternMatcher(TestPatternMatcherBase):
|
||||
self._test_common(
|
||||
mod,
|
||||
(v,),
|
||||
matcher_check_fn,
|
||||
check_quantization=True,
|
||||
is_qat=True,
|
||||
matcher_check_fn=matcher_check_fn,
|
||||
)
|
||||
|
||||
@skipIfNoDynamoSupport
|
||||
@ -1725,9 +1683,9 @@ class TestPatternMatcher(TestPatternMatcherBase):
|
||||
self._test_common(
|
||||
mod,
|
||||
(v,),
|
||||
matcher_check_fn,
|
||||
check_quantization=True,
|
||||
is_qat=True,
|
||||
matcher_check_fn=matcher_check_fn,
|
||||
)
|
||||
|
||||
@skipIfNoDynamoSupport
|
||||
@ -1784,9 +1742,9 @@ class TestPatternMatcher(TestPatternMatcherBase):
|
||||
self._test_common(
|
||||
mod,
|
||||
(v,),
|
||||
matcher_check_fn,
|
||||
check_quantization=True,
|
||||
is_qat=True,
|
||||
matcher_check_fn=matcher_check_fn,
|
||||
)
|
||||
|
||||
@skipIfNoDynamoSupport
|
||||
@ -1845,8 +1803,8 @@ class TestPatternMatcher(TestPatternMatcherBase):
|
||||
self._test_common(
|
||||
mod,
|
||||
(v,),
|
||||
matcher_check_fn,
|
||||
check_quantization=True,
|
||||
matcher_check_fn=matcher_check_fn,
|
||||
)
|
||||
|
||||
def _qlinear_cpu_test_helper(
|
||||
@ -1881,13 +1839,13 @@ class TestPatternMatcher(TestPatternMatcherBase):
|
||||
self._test_common(
|
||||
mod,
|
||||
inputs,
|
||||
check_autocast=torch.bfloat16 if int8_mixed_bf16 else torch.float,
|
||||
check_quantization=True,
|
||||
matcher_check_fn=(
|
||||
matcher_check_fn
|
||||
if matcher_check_fn is not None
|
||||
else _default_matcher_check_fn
|
||||
),
|
||||
check_autocast=torch.bfloat16 if int8_mixed_bf16 else torch.float,
|
||||
check_quantization=True,
|
||||
is_qat=is_qat,
|
||||
is_dynamic=is_dynamic,
|
||||
)
|
||||
@ -2051,9 +2009,9 @@ class TestPatternMatcher(TestPatternMatcherBase):
|
||||
self._test_common(
|
||||
mod,
|
||||
inputs,
|
||||
matcher_check_fn,
|
||||
check_autocast=torch.bfloat16 if int8_mixed_bf16 else torch.float,
|
||||
check_quantization=True,
|
||||
matcher_check_fn=matcher_check_fn,
|
||||
)
|
||||
|
||||
@skipIfNoDynamoSupport
|
||||
@ -2222,9 +2180,9 @@ class TestPatternMatcher(TestPatternMatcherBase):
|
||||
self._test_common(
|
||||
mod,
|
||||
(v,),
|
||||
matcher_check_fn,
|
||||
check_quantization=True,
|
||||
check_autocast=torch.bfloat16 if int8_mixed_bf16 else torch.float,
|
||||
matcher_check_fn=matcher_check_fn,
|
||||
is_qat=is_qat,
|
||||
is_dynamic=is_dynamic,
|
||||
)
|
||||
@ -2316,13 +2274,13 @@ class TestPatternMatcher(TestPatternMatcherBase):
|
||||
self._test_common(
|
||||
mod,
|
||||
inputs,
|
||||
check_autocast=torch.bfloat16 if int8_mixed_bf16 else torch.float,
|
||||
check_quantization=True,
|
||||
matcher_check_fn=(
|
||||
matcher_check_fn
|
||||
if matcher_check_fn is not None
|
||||
else default_matcher_check_fn
|
||||
),
|
||||
check_autocast=torch.bfloat16 if int8_mixed_bf16 else torch.float,
|
||||
check_quantization=True,
|
||||
is_dynamic=is_dynamic,
|
||||
)
|
||||
|
||||
@ -2461,8 +2419,8 @@ class TestPatternMatcher(TestPatternMatcherBase):
|
||||
self._test_common(
|
||||
mod,
|
||||
(x1, x2),
|
||||
matcher_check_fn,
|
||||
check_quantization=True,
|
||||
matcher_check_fn=matcher_check_fn,
|
||||
)
|
||||
|
||||
@skipIfNoDynamoSupport
|
||||
@ -2508,8 +2466,8 @@ class TestPatternMatcher(TestPatternMatcherBase):
|
||||
self._test_common(
|
||||
mod,
|
||||
(v,),
|
||||
matcher_check_fn,
|
||||
check_quantization=True,
|
||||
matcher_check_fn=matcher_check_fn,
|
||||
)
|
||||
|
||||
@skipIfNoDynamoSupport
|
||||
@ -2547,8 +2505,8 @@ class TestPatternMatcher(TestPatternMatcherBase):
|
||||
self._test_common(
|
||||
mod,
|
||||
(v,),
|
||||
matcher_check_fn,
|
||||
check_quantization=True,
|
||||
matcher_check_fn=matcher_check_fn,
|
||||
)
|
||||
|
||||
@skipIfNoDynamoSupport
|
||||
@ -2596,8 +2554,8 @@ class TestPatternMatcher(TestPatternMatcherBase):
|
||||
self._test_common(
|
||||
mod,
|
||||
(v,),
|
||||
matcher_check_fn,
|
||||
check_quantization=True,
|
||||
matcher_check_fn=matcher_check_fn,
|
||||
)
|
||||
|
||||
# https://github.com/pytorch/pytorch/issues/99841.
|
||||
@ -2619,18 +2577,9 @@ class TestPatternMatcher(TestPatternMatcherBase):
|
||||
min_values = [3, torch.randn(1, 32, 28, 28)]
|
||||
max_values = [0, torch.randn(1, 32, 28, 28)]
|
||||
v = torch.randn(1, 3, 28, 28)
|
||||
|
||||
def matcher_check_fn():
|
||||
self.assertEqual(
|
||||
counters["inductor"]["mkldnn_unary_fusion_matcher_nodes"], 3
|
||||
)
|
||||
self.assertEqual(
|
||||
counters["inductor"]["mkldnn_conv_weight_pack_matcher_count"], 1
|
||||
)
|
||||
|
||||
for min_value, max_value in zip(min_values, max_values):
|
||||
mod = Model().eval()
|
||||
self._test_common(mod, (v, min_value, max_value), matcher_check_fn)
|
||||
self._test_common(mod, (v, min_value, max_value), 2, 4)
|
||||
|
||||
def test_leaky_relu_pattern_fallback(self):
|
||||
class Model(torch.nn.Module):
|
||||
@ -2645,20 +2594,11 @@ class TestPatternMatcher(TestPatternMatcherBase):
|
||||
return torch.where(conv_out > 0, conv_out, conv_out * negative_slope)
|
||||
|
||||
negative_slopes = [0.1, torch.randn(1, 32, 28, 28)]
|
||||
|
||||
def matcher_check_fn():
|
||||
self.assertEqual(
|
||||
counters["inductor"]["mkldnn_unary_fusion_matcher_nodes"], 4
|
||||
)
|
||||
self.assertEqual(
|
||||
counters["inductor"]["mkldnn_conv_weight_pack_matcher_count"], 1
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
v = torch.randn(1, 3, 28, 28)
|
||||
for negative_slope in negative_slopes:
|
||||
mod = Model().eval()
|
||||
self._test_common(mod, (v, negative_slope), matcher_check_fn)
|
||||
self._test_common(mod, (v, negative_slope), 2, 5)
|
||||
|
||||
# https://github.com/pytorch/pytorch/issues/99838.
|
||||
def test_conv2d_add_scalar(self):
|
||||
@ -2674,16 +2614,10 @@ class TestPatternMatcher(TestPatternMatcherBase):
|
||||
out = torch.add(out_conv, 1.0)
|
||||
return out
|
||||
|
||||
def matcher_check_fn():
|
||||
self.assertEqual(counters["inductor"]["binary_folding"], 1)
|
||||
self.assertEqual(
|
||||
counters["inductor"]["mkldnn_conv_weight_pack_matcher_count"], 1
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
mod = Model().eval()
|
||||
v = torch.randn(1, 3, 28, 28)
|
||||
self._test_common(mod, (v,), matcher_check_fn)
|
||||
self._test_common(mod, (v,), 2, 3)
|
||||
|
||||
def test_conv2d_binary_inplace_fusion_pass_cpu(
|
||||
self, include_ops=None, exclude_ops=None
|
||||
@ -3083,7 +3017,7 @@ class TestPatternMatcher(TestPatternMatcherBase):
|
||||
self._test_common(
|
||||
mod,
|
||||
(x, w, s),
|
||||
matcher_check_fn,
|
||||
matcher_check_fn=matcher_check_fn,
|
||||
check_quantization=False,
|
||||
atol=0.001,
|
||||
rtol=0.07,
|
||||
@ -3118,11 +3052,7 @@ class TestDynamicPatternMatcher(TestPatternMatcherBase):
|
||||
x_shape = (1, 3, 28, 28)
|
||||
mod = M().eval()
|
||||
v = torch.randn(x_shape, dtype=torch.float32)
|
||||
|
||||
def matcher_check_fn():
|
||||
return
|
||||
|
||||
self._test_common(mod, (v,), matcher_check_fn)
|
||||
self._test_common(mod, (v,), 0, 0)
|
||||
|
||||
def test_multi_linear_share_same_input_dynamic(self):
|
||||
# llama pattern.
|
||||
@ -3142,28 +3072,18 @@ class TestDynamicPatternMatcher(TestPatternMatcherBase):
|
||||
dtypes.append(torch.bfloat16)
|
||||
if torch.ops.mkldnn._is_mkldnn_fp16_supported():
|
||||
dtypes.append(torch.float16)
|
||||
|
||||
def matcher_check_fn():
|
||||
self.assertEqual(
|
||||
counters["inductor"]["mkldnn_unary_fusion_matcher_nodes"], 7
|
||||
)
|
||||
self.assertEqual(
|
||||
counters["inductor"]["mkldnn_unary_fusion_matcher_count"], 2
|
||||
)
|
||||
self.assertEqual(
|
||||
counters["inductor"]["mkldnn_reshape_linear_reshape_matcher_nodes"], 6
|
||||
)
|
||||
self.assertEqual(
|
||||
counters["inductor"]["mkldnn_reshape_linear_reshape_matcher_count"], 2
|
||||
)
|
||||
self.assertEqual(
|
||||
counters["inductor"]["mkldnn_linear_weight_pack_matcher_count"], 2
|
||||
)
|
||||
|
||||
for dtype in dtypes:
|
||||
mod = M().to(dtype).eval()
|
||||
v = torch.randn(2, 4, 16).to(dtype)
|
||||
self._test_common(mod, (v,), matcher_check_fn, rtol=1e-2, atol=1e-2)
|
||||
# 1. view(match_count=4, match_nodes=4).
|
||||
# 2. mm to packed linear(match_count=2, match_nodes=2).
|
||||
# 3. view+linear+view to linear(match_count=2, match_nodes=6).
|
||||
# 4. linear to linear+swish(match_count=1, match_nodes=2).
|
||||
# 5. linear to linear+relu(match_count=1, match_nodes=5).
|
||||
|
||||
match_count = 10
|
||||
match_nodes = 19
|
||||
self._test_common(mod, (v,), match_count, match_nodes, rtol=1e-2, atol=1e-2)
|
||||
|
||||
def test_qconv2d_maxpool2d_linear_dynamic_cpu(self, include_ops=None):
|
||||
r"""
|
||||
@ -3241,9 +3161,9 @@ class TestDynamicPatternMatcher(TestPatternMatcherBase):
|
||||
self._test_common(
|
||||
mod,
|
||||
(v,),
|
||||
matcher_check_fn,
|
||||
check_quantization=True,
|
||||
is_qat=True,
|
||||
matcher_check_fn=matcher_check_fn,
|
||||
)
|
||||
|
||||
@skipIfNoDynamoSupport
|
||||
@ -3319,8 +3239,8 @@ class TestDynamicPatternMatcher(TestPatternMatcherBase):
|
||||
self._test_common(
|
||||
mod,
|
||||
(v,),
|
||||
matcher_check_fn,
|
||||
check_quantization=True,
|
||||
matcher_check_fn=matcher_check_fn,
|
||||
quantizer=quantizer,
|
||||
)
|
||||
|
||||
|
@ -5,7 +5,6 @@ from functools import reduce
|
||||
from typing import Any, Tuple
|
||||
|
||||
import torch
|
||||
from torch._dynamo.utils import counters
|
||||
from torch.fx.experimental.symbolic_shapes import has_free_symbols
|
||||
|
||||
from .. import ir
|
||||
@ -250,10 +249,6 @@ if torch._C._has_mkldnn:
|
||||
unary_attr.scalars_attr,
|
||||
unary_attr.algorithm_attr,
|
||||
]
|
||||
counters["inductor"]["mkldnn_unary_fusion_matcher_count"] += 1
|
||||
counters["inductor"]["mkldnn_unary_fusion_matcher_nodes"] += len(
|
||||
match.nodes
|
||||
)
|
||||
return L[computation_op](*computation_args)
|
||||
|
||||
return fn
|
||||
@ -277,10 +272,6 @@ if torch._C._has_mkldnn:
|
||||
)
|
||||
matched = matched and dtype1 == torch.float and dtype2 == lowp_dtype
|
||||
computation_args = list(args)
|
||||
counters["inductor"]["mkldnn_unary_fusion_matcher_count"] += 1
|
||||
counters["inductor"]["mkldnn_unary_fusion_matcher_nodes"] += len(
|
||||
match.nodes
|
||||
)
|
||||
if matched:
|
||||
computation_args = computation_args[:-3] + [
|
||||
"leaky_relu",
|
||||
@ -327,10 +318,6 @@ if torch._C._has_mkldnn:
|
||||
)
|
||||
matched = matched and dtype1 == torch.float and dtype2 == lowp_dtype
|
||||
computation_args = list(args)
|
||||
counters["inductor"]["mkldnn_unary_fusion_matcher_count"] += 1
|
||||
counters["inductor"]["mkldnn_unary_fusion_matcher_nodes"] += len(
|
||||
match.nodes
|
||||
)
|
||||
if matched:
|
||||
computation_args = computation_args[:-3] + [
|
||||
"hardtanh",
|
||||
@ -511,10 +498,6 @@ if torch._C._has_mkldnn:
|
||||
]
|
||||
else:
|
||||
computation_args += [1.0, None, [], None]
|
||||
counters["inductor"]["mkldnn_conv_binary_unary_fusion_matcher_count"] += 1
|
||||
counters["inductor"][
|
||||
"mkldnn_conv_binary_unary_fusion_matcher_nodes"
|
||||
] += len(match.nodes)
|
||||
return L[fusion_op](*computation_args)
|
||||
|
||||
return fn
|
||||
@ -559,10 +542,6 @@ if torch._C._has_mkldnn:
|
||||
]
|
||||
else:
|
||||
computation_args += [1.0, None, [], None]
|
||||
counters["inductor"]["mkldnn_conv_binary_unary_fusion_matcher_count"] += 1
|
||||
counters["inductor"][
|
||||
"mkldnn_conv_binary_unary_fusion_matcher_nodes"
|
||||
] += len(match.nodes)
|
||||
# Make sure the other is not an alias or mutation(fx side doesn't has such info).
|
||||
other.realize()
|
||||
if not _can_be_inplace(other) or other.data.shape != list(
|
||||
@ -816,10 +795,6 @@ if torch._C._has_mkldnn:
|
||||
graph.erase_node(old_linear_node)
|
||||
if len(reshape_1_node.users) == 0:
|
||||
graph.erase_node(reshape_1_node)
|
||||
counters["inductor"]["mkldnn_reshape_linear_reshape_matcher_count"] += 1
|
||||
counters["inductor"]["mkldnn_reshape_linear_reshape_matcher_nodes"] += len(
|
||||
match.nodes
|
||||
)
|
||||
|
||||
def is_linear_add_bias(match):
|
||||
add_node = match.output_node()
|
||||
@ -870,8 +845,6 @@ if torch._C._has_mkldnn:
|
||||
repl.meta.update(add_node.meta)
|
||||
add_node.replace_all_uses_with(repl)
|
||||
match.erase_nodes()
|
||||
counters["inductor"]["mkldnn_linear_bias_matcher_count"] += 1
|
||||
counters["inductor"]["mkldnn_linear_bias_matcher_nodes"] += len(match.nodes)
|
||||
|
||||
def _is_packable_mkldnn_rnn_layer(match):
|
||||
lstm_node = match.output_node()
|
||||
@ -1110,10 +1083,6 @@ if torch._C._has_mkldnn:
|
||||
conv_node.replace_all_uses_with(packed_conv_node)
|
||||
packed_conv_node.meta.update(conv_node.meta)
|
||||
graph.erase_node(conv_node)
|
||||
counters["inductor"]["mkldnn_conv_weight_pack_matcher_count"] += 1
|
||||
counters["inductor"]["mkldnn_conv_weight_pack_matcher_nodes"] += len(
|
||||
match.nodes
|
||||
)
|
||||
|
||||
@register_freezing_graph_pattern(
|
||||
CallFunction(aten.mkldnn_rnn_layer.default, *_aten_mkldnn_rnn_layer_args),
|
||||
@ -1165,10 +1134,6 @@ if torch._C._has_mkldnn:
|
||||
lstm_node.replace_all_uses_with(packed_lstm_node)
|
||||
packed_lstm_node.meta.update(lstm_node.meta)
|
||||
graph.erase_node(lstm_node)
|
||||
counters["inductor"]["mkldnn_rnn_weight_pack_matcher_count"] += 1
|
||||
counters["inductor"]["mkldnn_rnn_weight_pack_matcher_nodes"] += len(
|
||||
match.nodes
|
||||
)
|
||||
|
||||
@register_freezing_graph_pattern(
|
||||
CallFunction(
|
||||
@ -1257,10 +1222,6 @@ if torch._C._has_mkldnn:
|
||||
linear_node.replace_all_uses_with(packed_linear_node)
|
||||
packed_linear_node.meta.update(linear_node.meta)
|
||||
graph.erase_node(linear_node)
|
||||
counters["inductor"]["mkldnn_linear_weight_pack_matcher_count"] += 1
|
||||
counters["inductor"]["mkldnn_linear_weight_pack_matcher_nodes"] += len(
|
||||
match.nodes
|
||||
)
|
||||
|
||||
def _eliminate_duplicate_packed_nodes(gm):
|
||||
"""
|
||||
|
Reference in New Issue
Block a user