Revert "[inductor][pattern matcher] revise mkldnn pattern matcher UT (#141334)"

This reverts commit 942a2438e263a2632b8934dd245060c9b237f4be.

Reverted https://github.com/pytorch/pytorch/pull/141334 on behalf of https://github.com/atalman due to Failing internally ([comment](https://github.com/pytorch/pytorch/pull/141334#issuecomment-2512891840))
PyTorch MergeBot
2024-12-02 21:29:01 +00:00
parent 6b05e31042
commit b47bdb06d8
2 changed files with 134 additions and 253 deletions
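The reverted PR had made `matcher_check_fn` a required positional argument of `TestPatternMatcherBase._test_common` and dropped the `matcher_count`/`matcher_nodes` parameters; this revert restores the count-based signature and keeps `matcher_check_fn` as an optional keyword. A minimal standalone sketch of the restored argument handling (hypothetical helper name and stand-in counters; the real logic is in the first hunk below):

```python
# A minimal sketch (hypothetical, standalone) of the argument handling this
# revert restores: explicit match counts come back as optional parameters,
# and matcher_check_fn becomes an optional keyword instead of a required
# positional argument.
from collections import Counter, defaultdict
from typing import Callable, Optional

counters = defaultdict(Counter)  # stand-in for torch._dynamo.utils.counters

def check_pattern_counters(
    matcher_count: Optional[int] = None,
    matcher_nodes: Optional[int] = None,
    matcher_check_fn: Optional[Callable[[], None]] = None,
) -> None:
    # Mirrors the assert added back in the first hunk: callers must supply
    # either both counts or a callback.
    assert matcher_check_fn is not None or (
        matcher_count is not None and matcher_nodes is not None
    )
    if matcher_count is not None:
        assert counters["inductor"]["pattern_matcher_count"] == matcher_count
    if matcher_nodes is not None:
        assert counters["inductor"]["pattern_matcher_nodes"] == matcher_nodes
    if matcher_check_fn is not None:
        matcher_check_fn()

counters["inductor"]["pattern_matcher_count"] = 2
counters["inductor"]["pattern_matcher_nodes"] = 4
check_pattern_counters(matcher_count=2, matcher_nodes=4)
```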

test/inductor/test_mkldnn_pattern_matcher.py

@@ -135,18 +135,23 @@ class TestPatternMatcherBase(TestCase):
self,
mod,
inputs,
matcher_check_fn,
matcher_count=None,
matcher_nodes=None,
atol=1e-5,
rtol=1.3e-6,
check_autocast=torch.float32,
check_quantization=False,
is_qat=False,
matcher_check_fn=None,
dtype=None,
is_dynamic=False,
quantizer=None,
):
counters.clear()
torch._dynamo.reset()
assert matcher_check_fn is not None or (
matcher_count is not None and matcher_nodes is not None
)
if (
check_autocast == torch.bfloat16
and torch.ops.mkldnn._is_mkldnn_bf16_supported()
@@ -169,14 +174,34 @@ class TestPatternMatcherBase(TestCase):
)
with torch.no_grad(), maybe_autocast:
_ = torch.compile(convert_model)(*inputs)
matcher_check_fn()
if matcher_count is not None:
self.assertEqual(
counters["inductor"]["pattern_matcher_count"], matcher_count
)
if matcher_nodes is not None:
self.assertEqual(
counters["inductor"]["pattern_matcher_nodes"],
matcher_nodes,
)
if matcher_check_fn is not None:
matcher_check_fn()
else:
with torch.no_grad(), maybe_autocast:
clone_inputs = self._clone_inputs(inputs)
expected = mod(*inputs)
actual = torch.compile(mod)(*clone_inputs)
torch.testing.assert_close(actual, expected, atol=atol, rtol=rtol)
matcher_check_fn()
if matcher_count is not None:
self.assertEqual(
counters["inductor"]["pattern_matcher_count"], matcher_count
)
if matcher_nodes is not None:
self.assertEqual(
counters["inductor"]["pattern_matcher_nodes"],
matcher_nodes,
)
if matcher_check_fn is not None:
matcher_check_fn()
def _test_code_common(
self,
@@ -268,24 +293,15 @@ class TestPatternMatcher(TestPatternMatcherBase):
.add(1)
.to(memory_format=memory_format)
)
def matcher_check_fn():
match_nodes = unary_list[unary_fn]
if dtype in (
torch.float16,
torch.bfloat16,
) and self._check_unary_is_decomposed(unary_fn):
# Has extra dtype conversion nodes for autocast.
match_nodes += 2
self.assertEqual(
counters["inductor"]["mkldnn_unary_fusion_matcher_nodes"],
match_nodes,
)
self.assertEqual(
counters["inductor"]["mkldnn_conv_weight_pack_matcher_count"], 1
)
self._test_common(mod, (v,), matcher_check_fn, check_autocast=dtype)
# Add 1 for weight packing pass.
match_nodes = unary_list[unary_fn] + 1
if dtype in (
torch.float16,
torch.bfloat16,
) and self._check_unary_is_decomposed(unary_fn):
# Has extra dtype conversion nodes for autocast.
match_nodes += 2
self._test_common(mod, (v,), 2, match_nodes, check_autocast=dtype)
generated_kernel_count = cal_conv_generated_kernel_number(mod, v, dtype)
self.assertEqual(metrics.generated_kernel_count, generated_kernel_count)
@@ -336,21 +352,16 @@ class TestPatternMatcher(TestPatternMatcherBase):
# only fuse for linear when the dtype is bf16
mod = mod
v = torch.randn(2, 10)
def matcher_check_fn():
match_nodes = unary_list[unary_fn]
if self._check_unary_is_decomposed(unary_fn):
# Has extra dtype conversion nodes for autocast.
match_nodes += 2
self.assertEqual(
counters["inductor"]["mkldnn_unary_fusion_matcher_nodes"],
match_nodes,
)
self.assertEqual(
counters["inductor"]["mkldnn_linear_weight_pack_matcher_count"], 1
)
self._test_common(mod, (v,), matcher_check_fn, check_autocast=dtype)
# packing pass + unary fusion.
matcher_count = 2
# Add 1 for weight packing pass.
matcher_nodes = unary_list[unary_fn] + 1
if self._check_unary_is_decomposed(unary_fn):
# Has extra dtype conversion nodes for autocast.
matcher_nodes += 2
self._test_common(
mod, (v,), matcher_count, matcher_nodes, check_autocast=dtype
)
# only generated 1 kernel for "to"
self.assertEqual(metrics.generated_kernel_count, 1)
@@ -367,14 +378,10 @@ class TestPatternMatcher(TestPatternMatcherBase):
for bias in [True, False]:
mod = M(bias=bias).eval()
v = torch.randn(2, 10)
# packing pass.
def matcher_check_fn():
self.assertEqual(
counters["inductor"]["mkldnn_linear_weight_pack_matcher_count"], 1
)
self._test_common(mod, (v,), matcher_check_fn)
matcher_count = 1
matcher_nodes = 1
self._test_common(mod, (v,), matcher_count, matcher_nodes)
@unittest.skipIf(not TEST_MKL, "Test requires MKL")
def test_linear_input_non_contiguous_3D_wo_bias(self):
@@ -447,29 +454,18 @@ class TestPatternMatcher(TestPatternMatcherBase):
metrics.reset()
fold_mod = M(dtype, unary_fn, cast_bias=True).eval()
v = torch.randn(2, 10)
def folder_matcher_check_fn():
match_nodes = unary_list[unary_fn]
if self._check_unary_is_decomposed(unary_fn):
# Has extra dtype conversion nodes for autocast.
match_nodes += 2
# we have 2 linears, so we double the matcher_count/nodes
self.assertEqual(
counters["inductor"]["mkldnn_unary_fusion_matcher_count"], 2
)
self.assertEqual(
counters["inductor"]["mkldnn_unary_fusion_matcher_nodes"],
match_nodes * 2,
)
self.assertEqual(
counters["inductor"]["mkldnn_linear_weight_pack_matcher_count"], 2
)
self.assertEqual(counters["inductor"]["binary_folding"], 2)
matcher_count = 3
# Add 1 for weight packing pass, add 2 for bias folding pass per linear.
matcher_nodes = unary_list[unary_fn] + 3
if self._check_unary_is_decomposed(unary_fn):
# Has extra dtype conversion nodes for autocast.
matcher_nodes += 2
# we have 2 linears, so we double the matcher_count/nodes
self._test_common(
fold_mod,
(v,),
folder_matcher_check_fn,
matcher_count * 2,
matcher_nodes * 2,
check_autocast=dtype,
)
self.assertEqual(metrics.generated_kernel_count, 1)
@@ -477,13 +473,7 @@ class TestPatternMatcher(TestPatternMatcherBase):
# https://github.com/pytorch/pytorch/pull/129138
metrics.reset()
mod = M(dtype, unary_fn, cast_bias=False).eval()
def matcher_check_fn():
self.assertEqual(
counters["inductor"]["mkldnn_linear_weight_pack_matcher_count"], 2
)
self._test_common(mod, (v,), matcher_check_fn, check_autocast=dtype)
self._test_common(mod, (v,), 2, 2, check_autocast=dtype)
# 1 kernel for "to_lowp", 2 kernels for unary ops
self.assertEqual(metrics.generated_kernel_count, 3)
@@ -537,24 +527,15 @@ class TestPatternMatcher(TestPatternMatcherBase):
v = torch.randn(x_shape, dtype=torch.float32).to(
memory_format=memory_format
)
def matcher_check_fn():
match_nodes = unary_list[unary_fn]
if dtype in (
torch.float16,
torch.bfloat16,
) and self._check_unary_is_decomposed(unary_fn):
# Has extra dtype conversion nodes for autocast.
match_nodes += 2
self.assertEqual(
counters["inductor"]["mkldnn_unary_fusion_matcher_nodes"],
match_nodes,
)
self.assertEqual(
counters["inductor"]["mkldnn_conv_weight_pack_matcher_count"], 1
)
self._test_common(mod, (v,), matcher_check_fn, check_autocast=dtype)
# Add 1 for weight packing pass.
match_nodes = unary_list[unary_fn] + 1
if dtype in (
torch.float16,
torch.bfloat16,
) and self._check_unary_is_decomposed(unary_fn):
# Has extra dtype conversion nodes for autocast.
match_nodes += 2
self._test_common(mod, (v,), 2, match_nodes, check_autocast=dtype)
generated_kernel_count = cal_conv_generated_kernel_number(mod, v, dtype)
self.assertEqual(metrics.generated_kernel_count, generated_kernel_count)
@@ -631,22 +612,13 @@ class TestPatternMatcher(TestPatternMatcherBase):
.add(1)
.to(memory_format=memory_format)
)
def matcher_check_fn():
match_nodes = binary_list[binary_fn][1]
if has_relu:
match_nodes += 1
self.assertEqual(
counters["inductor"][
"mkldnn_conv_binary_unary_fusion_matcher_nodes"
],
match_nodes,
)
self.assertEqual(
counters["inductor"]["mkldnn_conv_weight_pack_matcher_count"], 2
)
self._test_common(mod, (v,), matcher_check_fn, check_autocast=dtype)
match_count = binary_list[binary_fn][0] + 2
match_nodes = binary_list[binary_fn][1]
if has_relu:
match_nodes += 1
self._test_common(
mod, (v,), match_count, match_nodes + 2, check_autocast=dtype
)
generated_kernel_count = cal_conv_generated_kernel_number(mod, v, dtype)
self.assertEqual(metrics.generated_kernel_count, generated_kernel_count)
@@ -774,33 +746,26 @@ class TestPatternMatcher(TestPatternMatcherBase):
for binary_fn, input_shape, bias, dtype in options:
metrics.reset()
def matcher_check_fn():
self.assertEqual(
counters["inductor"][
"mkldnn_conv_binary_unary_fusion_matcher_nodes"
],
2,
)
reshape_linear_reshape_match_nodes = 3 if len(input_shape) == 3 else 0
self.assertEqual(
counters["inductor"]["mkldnn_reshape_linear_reshape_matcher_nodes"],
reshape_linear_reshape_match_nodes,
)
self.assertEqual(
counters["inductor"]["mkldnn_linear_weight_pack_matcher_count"], 1
)
# addmm(mm) + (linear+add)
match_count = 2
match_nodes = 3
if len(input_shape) == 3:
is_inplace = binary_list[binary_fn][2]
# view + linear + view(joint_graph+freeze pass)
match_count = match_count + 5 if is_inplace else match_count + 3
match_nodes = match_nodes + 8 if is_inplace else match_nodes + 5
mod = M(binary_fn, input_shape[-1], out_feature, bias).eval()
v = torch.randn(input_shape)
other = torch.randn(input_shape[:-1] + [out_feature]).to(dtype)
self._test_common(
mod,
(
v,
other,
),
matcher_check_fn,
match_count,
match_nodes,
check_autocast=dtype,
)
self.assertEqual(metrics.generated_kernel_count, 1)
@@ -872,25 +837,18 @@ class TestPatternMatcher(TestPatternMatcherBase):
dtypes.append(torch.bfloat16)
if torch.ops.mkldnn._is_mkldnn_fp16_supported():
dtypes.append(torch.float16)
def matcher_check_fn():
self.assertEqual(
counters["inductor"]["mkldnn_unary_fusion_matcher_nodes"], 7
)
self.assertEqual(
counters["inductor"]["mkldnn_unary_fusion_matcher_count"], 2
)
self.assertEqual(
counters["inductor"]["mkldnn_reshape_linear_reshape_matcher_nodes"], 6
)
self.assertEqual(
counters["inductor"]["mkldnn_linear_weight_pack_matcher_count"], 2
)
for dtype in dtypes:
mod = M().to(dtype).eval()
v = torch.randn(2, 4, 16).to(dtype)
self._test_common(mod, (v,), matcher_check_fn, rtol=1e-2, atol=1e-2)
# 1. view(match_count=4, match_nodes=4).
# 2. mm to packed linear(match_count=2, match_nodes=2).
# 3. view+linear+view to linear(match_count=2, match_nodes=6).
# 4. linear+silu fusion(match_count=1, match_nodes=5)
# 5. linear+relu fusion(match_count=1, match_nodes=2)
match_count = 10
match_nodes = 19
self._test_common(mod, (v,), match_count, match_nodes, rtol=1e-2, atol=1e-2)
def _qconv2d_test_helper(self, device="cpu", int8_mixed_bf16=False):
class M(torch.nn.Module):
@@ -928,9 +886,9 @@ class TestPatternMatcher(TestPatternMatcherBase):
self._test_common(
mod,
(v,),
matcher_check_fn,
check_quantization=True,
check_autocast=torch.bfloat16 if int8_mixed_bf16 else torch.float,
matcher_check_fn=matcher_check_fn,
)
@skipIfNoDynamoSupport
@@ -1220,9 +1178,9 @@ class TestPatternMatcher(TestPatternMatcherBase):
self._test_common(
mod,
(v,),
matcher_check_fn,
check_quantization=True,
check_autocast=torch.bfloat16 if int8_mixed_bf16 else torch.float,
matcher_check_fn=matcher_check_fn,
)
def _qconv2d_add_cpu_test_helper2(self, use_relu=False, int8_mixed_bf16=False):
@@ -1304,9 +1262,9 @@ class TestPatternMatcher(TestPatternMatcherBase):
self._test_common(
mod,
(x, x2, x3),
matcher_check_fn,
check_quantization=True,
check_autocast=torch.bfloat16 if int8_mixed_bf16 else torch.float,
matcher_check_fn=matcher_check_fn,
)
@skipIfNoDynamoSupport
@@ -1370,8 +1328,8 @@ class TestPatternMatcher(TestPatternMatcherBase):
self._test_common(
mod,
(x1, x2),
matcher_check_fn,
check_quantization=True,
matcher_check_fn=matcher_check_fn,
)
@skipIfNoDynamoSupport
@@ -1423,8 +1381,8 @@ class TestPatternMatcher(TestPatternMatcherBase):
self._test_common(
mod,
(v,),
matcher_check_fn,
check_quantization=True,
matcher_check_fn=matcher_check_fn,
)
@skipIfNoDynamoSupport
@@ -1468,8 +1426,8 @@ class TestPatternMatcher(TestPatternMatcherBase):
self._test_common(
mod,
(v,),
matcher_check_fn,
check_quantization=True,
matcher_check_fn=matcher_check_fn,
)
@skipIfNoDynamoSupport
@@ -1542,8 +1500,8 @@ class TestPatternMatcher(TestPatternMatcherBase):
self._test_common(
mod,
(v,),
matcher_check_fn,
check_quantization=True,
matcher_check_fn=matcher_check_fn,
)
@skipIfNoDynamoSupport
@@ -1586,9 +1544,9 @@ class TestPatternMatcher(TestPatternMatcherBase):
self._test_common(
mod,
(v,),
matcher_check_fn,
check_quantization=True,
is_qat=True,
matcher_check_fn=matcher_check_fn,
)
def _qat_qconv2d_unary_cpu_test_helper(
@@ -1628,9 +1586,9 @@ class TestPatternMatcher(TestPatternMatcherBase):
self._test_common(
mod,
(v,),
matcher_check_fn,
check_quantization=True,
is_qat=True,
matcher_check_fn=matcher_check_fn,
)
@skipIfNoDynamoSupport
@@ -1725,9 +1683,9 @@ class TestPatternMatcher(TestPatternMatcherBase):
self._test_common(
mod,
(v,),
matcher_check_fn,
check_quantization=True,
is_qat=True,
matcher_check_fn=matcher_check_fn,
)
@skipIfNoDynamoSupport
@@ -1784,9 +1742,9 @@ class TestPatternMatcher(TestPatternMatcherBase):
self._test_common(
mod,
(v,),
matcher_check_fn,
check_quantization=True,
is_qat=True,
matcher_check_fn=matcher_check_fn,
)
@skipIfNoDynamoSupport
@@ -1845,8 +1803,8 @@ class TestPatternMatcher(TestPatternMatcherBase):
self._test_common(
mod,
(v,),
matcher_check_fn,
check_quantization=True,
matcher_check_fn=matcher_check_fn,
)
def _qlinear_cpu_test_helper(
@@ -1881,13 +1839,13 @@ class TestPatternMatcher(TestPatternMatcherBase):
self._test_common(
mod,
inputs,
check_autocast=torch.bfloat16 if int8_mixed_bf16 else torch.float,
check_quantization=True,
matcher_check_fn=(
matcher_check_fn
if matcher_check_fn is not None
else _default_matcher_check_fn
),
check_autocast=torch.bfloat16 if int8_mixed_bf16 else torch.float,
check_quantization=True,
is_qat=is_qat,
is_dynamic=is_dynamic,
)
@@ -2051,9 +2009,9 @@ class TestPatternMatcher(TestPatternMatcherBase):
self._test_common(
mod,
inputs,
matcher_check_fn,
check_autocast=torch.bfloat16 if int8_mixed_bf16 else torch.float,
check_quantization=True,
matcher_check_fn=matcher_check_fn,
)
@skipIfNoDynamoSupport
@@ -2222,9 +2180,9 @@ class TestPatternMatcher(TestPatternMatcherBase):
self._test_common(
mod,
(v,),
matcher_check_fn,
check_quantization=True,
check_autocast=torch.bfloat16 if int8_mixed_bf16 else torch.float,
matcher_check_fn=matcher_check_fn,
is_qat=is_qat,
is_dynamic=is_dynamic,
)
@@ -2316,13 +2274,13 @@ class TestPatternMatcher(TestPatternMatcherBase):
self._test_common(
mod,
inputs,
check_autocast=torch.bfloat16 if int8_mixed_bf16 else torch.float,
check_quantization=True,
matcher_check_fn=(
matcher_check_fn
if matcher_check_fn is not None
else default_matcher_check_fn
),
check_autocast=torch.bfloat16 if int8_mixed_bf16 else torch.float,
check_quantization=True,
is_dynamic=is_dynamic,
)
@@ -2461,8 +2419,8 @@ class TestPatternMatcher(TestPatternMatcherBase):
self._test_common(
mod,
(x1, x2),
matcher_check_fn,
check_quantization=True,
matcher_check_fn=matcher_check_fn,
)
@skipIfNoDynamoSupport
@@ -2508,8 +2466,8 @@ class TestPatternMatcher(TestPatternMatcherBase):
self._test_common(
mod,
(v,),
matcher_check_fn,
check_quantization=True,
matcher_check_fn=matcher_check_fn,
)
@skipIfNoDynamoSupport
@@ -2547,8 +2505,8 @@ class TestPatternMatcher(TestPatternMatcherBase):
self._test_common(
mod,
(v,),
matcher_check_fn,
check_quantization=True,
matcher_check_fn=matcher_check_fn,
)
@skipIfNoDynamoSupport
@@ -2596,8 +2554,8 @@ class TestPatternMatcher(TestPatternMatcherBase):
self._test_common(
mod,
(v,),
matcher_check_fn,
check_quantization=True,
matcher_check_fn=matcher_check_fn,
)
# https://github.com/pytorch/pytorch/issues/99841.
@@ -2619,18 +2577,9 @@ class TestPatternMatcher(TestPatternMatcherBase):
min_values = [3, torch.randn(1, 32, 28, 28)]
max_values = [0, torch.randn(1, 32, 28, 28)]
v = torch.randn(1, 3, 28, 28)
def matcher_check_fn():
self.assertEqual(
counters["inductor"]["mkldnn_unary_fusion_matcher_nodes"], 3
)
self.assertEqual(
counters["inductor"]["mkldnn_conv_weight_pack_matcher_count"], 1
)
for min_value, max_value in zip(min_values, max_values):
mod = Model().eval()
self._test_common(mod, (v, min_value, max_value), matcher_check_fn)
self._test_common(mod, (v, min_value, max_value), 2, 4)
def test_leaky_relu_pattern_fallback(self):
class Model(torch.nn.Module):
@@ -2645,20 +2594,11 @@ class TestPatternMatcher(TestPatternMatcherBase):
return torch.where(conv_out > 0, conv_out, conv_out * negative_slope)
negative_slopes = [0.1, torch.randn(1, 32, 28, 28)]
def matcher_check_fn():
self.assertEqual(
counters["inductor"]["mkldnn_unary_fusion_matcher_nodes"], 4
)
self.assertEqual(
counters["inductor"]["mkldnn_conv_weight_pack_matcher_count"], 1
)
with torch.no_grad():
v = torch.randn(1, 3, 28, 28)
for negative_slope in negative_slopes:
mod = Model().eval()
self._test_common(mod, (v, negative_slope), matcher_check_fn)
self._test_common(mod, (v, negative_slope), 2, 5)
# https://github.com/pytorch/pytorch/issues/99838.
def test_conv2d_add_scalar(self):
@@ -2674,16 +2614,10 @@ class TestPatternMatcher(TestPatternMatcherBase):
out = torch.add(out_conv, 1.0)
return out
def matcher_check_fn():
self.assertEqual(counters["inductor"]["binary_folding"], 1)
self.assertEqual(
counters["inductor"]["mkldnn_conv_weight_pack_matcher_count"], 1
)
with torch.no_grad():
mod = Model().eval()
v = torch.randn(1, 3, 28, 28)
self._test_common(mod, (v,), matcher_check_fn)
self._test_common(mod, (v,), 2, 3)
def test_conv2d_binary_inplace_fusion_pass_cpu(
self, include_ops=None, exclude_ops=None
@@ -3083,7 +3017,7 @@ class TestPatternMatcher(TestPatternMatcherBase):
self._test_common(
mod,
(x, w, s),
matcher_check_fn,
matcher_check_fn=matcher_check_fn,
check_quantization=False,
atol=0.001,
rtol=0.07,
@@ -3118,11 +3052,7 @@ class TestDynamicPatternMatcher(TestPatternMatcherBase):
x_shape = (1, 3, 28, 28)
mod = M().eval()
v = torch.randn(x_shape, dtype=torch.float32)
def matcher_check_fn():
return
self._test_common(mod, (v,), matcher_check_fn)
self._test_common(mod, (v,), 0, 0)
def test_multi_linear_share_same_input_dynamic(self):
# llama pattern.
@@ -3142,28 +3072,18 @@ class TestDynamicPatternMatcher(TestPatternMatcherBase):
dtypes.append(torch.bfloat16)
if torch.ops.mkldnn._is_mkldnn_fp16_supported():
dtypes.append(torch.float16)
def matcher_check_fn():
self.assertEqual(
counters["inductor"]["mkldnn_unary_fusion_matcher_nodes"], 7
)
self.assertEqual(
counters["inductor"]["mkldnn_unary_fusion_matcher_count"], 2
)
self.assertEqual(
counters["inductor"]["mkldnn_reshape_linear_reshape_matcher_nodes"], 6
)
self.assertEqual(
counters["inductor"]["mkldnn_reshape_linear_reshape_matcher_count"], 2
)
self.assertEqual(
counters["inductor"]["mkldnn_linear_weight_pack_matcher_count"], 2
)
for dtype in dtypes:
mod = M().to(dtype).eval()
v = torch.randn(2, 4, 16).to(dtype)
self._test_common(mod, (v,), matcher_check_fn, rtol=1e-2, atol=1e-2)
# 1. view(match_count=4, match_nodes=4).
# 2. mm to packed linear(match_count=2, match_nodes=2).
# 3. view+linear+view to linear(match_count=2, match_nodes=6).
# 4. linear to linear+swish(match_count=1, match_nodes=2).
# 5. linear to linear+relu(match_count=1, match_nodes=5).
match_count = 10
match_nodes = 19
self._test_common(mod, (v,), match_count, match_nodes, rtol=1e-2, atol=1e-2)
def test_qconv2d_maxpool2d_linear_dynamic_cpu(self, include_ops=None):
r"""
@@ -3241,9 +3161,9 @@ class TestDynamicPatternMatcher(TestPatternMatcherBase):
self._test_common(
mod,
(v,),
matcher_check_fn,
check_quantization=True,
is_qat=True,
matcher_check_fn=matcher_check_fn,
)
@skipIfNoDynamoSupport
@@ -3319,8 +3239,8 @@ class TestDynamicPatternMatcher(TestPatternMatcherBase):
self._test_common(
mod,
(v,),
matcher_check_fn,
check_quantization=True,
matcher_check_fn=matcher_check_fn,
quantizer=quantizer,
)

torch/_inductor/fx_passes/mkldnn_fusion.py
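The hunks below delete the per-pattern counter bookkeeping that the reverted PR had added to the fusion passes (the `mkldnn_*_matcher_count`/`mkldnn_*_matcher_nodes` entries that the test-side `matcher_check_fn`s asserted on). For reference, `torch._dynamo.utils.counters` is a `defaultdict(Counter)`, so each pattern handler could bump arbitrary keys; a simplified sketch with a hypothetical match object:

```python
# Simplified sketch of the bookkeeping being deleted below; FakeMatch is a
# hypothetical stand-in for the pattern matcher's Match object, whose .nodes
# lists the fx nodes covered by the match.
from collections import Counter, defaultdict

counters = defaultdict(Counter)  # stand-in for torch._dynamo.utils.counters

class FakeMatch:
    nodes = ["conv", "relu"]

def record_match(match: FakeMatch) -> None:
    # One hit per matched pattern, plus the number of nodes it covered.
    counters["inductor"]["mkldnn_unary_fusion_matcher_count"] += 1
    counters["inductor"]["mkldnn_unary_fusion_matcher_nodes"] += len(match.nodes)

record_match(FakeMatch())
assert counters["inductor"]["mkldnn_unary_fusion_matcher_nodes"] == 2
```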

@@ -5,7 +5,6 @@ from functools import reduce
from typing import Any, Tuple
import torch
from torch._dynamo.utils import counters
from torch.fx.experimental.symbolic_shapes import has_free_symbols
from .. import ir
@@ -250,10 +249,6 @@ if torch._C._has_mkldnn:
unary_attr.scalars_attr,
unary_attr.algorithm_attr,
]
counters["inductor"]["mkldnn_unary_fusion_matcher_count"] += 1
counters["inductor"]["mkldnn_unary_fusion_matcher_nodes"] += len(
match.nodes
)
return L[computation_op](*computation_args)
return fn
@@ -277,10 +272,6 @@ if torch._C._has_mkldnn:
)
matched = matched and dtype1 == torch.float and dtype2 == lowp_dtype
computation_args = list(args)
counters["inductor"]["mkldnn_unary_fusion_matcher_count"] += 1
counters["inductor"]["mkldnn_unary_fusion_matcher_nodes"] += len(
match.nodes
)
if matched:
computation_args = computation_args[:-3] + [
"leaky_relu",
@@ -327,10 +318,6 @@ if torch._C._has_mkldnn:
)
matched = matched and dtype1 == torch.float and dtype2 == lowp_dtype
computation_args = list(args)
counters["inductor"]["mkldnn_unary_fusion_matcher_count"] += 1
counters["inductor"]["mkldnn_unary_fusion_matcher_nodes"] += len(
match.nodes
)
if matched:
computation_args = computation_args[:-3] + [
"hardtanh",
@@ -511,10 +498,6 @@ if torch._C._has_mkldnn:
]
else:
computation_args += [1.0, None, [], None]
counters["inductor"]["mkldnn_conv_binary_unary_fusion_matcher_count"] += 1
counters["inductor"][
"mkldnn_conv_binary_unary_fusion_matcher_nodes"
] += len(match.nodes)
return L[fusion_op](*computation_args)
return fn
@@ -559,10 +542,6 @@ if torch._C._has_mkldnn:
]
else:
computation_args += [1.0, None, [], None]
counters["inductor"]["mkldnn_conv_binary_unary_fusion_matcher_count"] += 1
counters["inductor"][
"mkldnn_conv_binary_unary_fusion_matcher_nodes"
] += len(match.nodes)
# Make sure the other is not an alias or mutation (the fx side doesn't have such info).
other.realize()
if not _can_be_inplace(other) or other.data.shape != list(
@@ -816,10 +795,6 @@ if torch._C._has_mkldnn:
graph.erase_node(old_linear_node)
if len(reshape_1_node.users) == 0:
graph.erase_node(reshape_1_node)
counters["inductor"]["mkldnn_reshape_linear_reshape_matcher_count"] += 1
counters["inductor"]["mkldnn_reshape_linear_reshape_matcher_nodes"] += len(
match.nodes
)
def is_linear_add_bias(match):
add_node = match.output_node()
@@ -870,8 +845,6 @@ if torch._C._has_mkldnn:
repl.meta.update(add_node.meta)
add_node.replace_all_uses_with(repl)
match.erase_nodes()
counters["inductor"]["mkldnn_linear_bias_matcher_count"] += 1
counters["inductor"]["mkldnn_linear_bias_matcher_nodes"] += len(match.nodes)
def _is_packable_mkldnn_rnn_layer(match):
lstm_node = match.output_node()
@@ -1110,10 +1083,6 @@ if torch._C._has_mkldnn:
conv_node.replace_all_uses_with(packed_conv_node)
packed_conv_node.meta.update(conv_node.meta)
graph.erase_node(conv_node)
counters["inductor"]["mkldnn_conv_weight_pack_matcher_count"] += 1
counters["inductor"]["mkldnn_conv_weight_pack_matcher_nodes"] += len(
match.nodes
)
@register_freezing_graph_pattern(
CallFunction(aten.mkldnn_rnn_layer.default, *_aten_mkldnn_rnn_layer_args),
@@ -1165,10 +1134,6 @@ if torch._C._has_mkldnn:
lstm_node.replace_all_uses_with(packed_lstm_node)
packed_lstm_node.meta.update(lstm_node.meta)
graph.erase_node(lstm_node)
counters["inductor"]["mkldnn_rnn_weight_pack_matcher_count"] += 1
counters["inductor"]["mkldnn_rnn_weight_pack_matcher_nodes"] += len(
match.nodes
)
@register_freezing_graph_pattern(
CallFunction(
@@ -1257,10 +1222,6 @@ if torch._C._has_mkldnn:
linear_node.replace_all_uses_with(packed_linear_node)
packed_linear_node.meta.update(linear_node.meta)
graph.erase_node(linear_node)
counters["inductor"]["mkldnn_linear_weight_pack_matcher_count"] += 1
counters["inductor"]["mkldnn_linear_weight_pack_matcher_nodes"] += len(
match.nodes
)
def _eliminate_duplicate_packed_nodes(gm):
"""