[Inductor] add flag for linear binary folding and turn it off by default (#142108)

Fix https://github.com/pytorch/pytorch/issues/141755.

Summary:
Linear binary folding causes an accuracy regression in the timm model levit_128. This PR adds the flag `enable_linear_binary_folding` to control linear binary folding and turns it off by default.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/142108
Approved by: https://github.com/jgong5, https://github.com/leslie-fang-intel, https://github.com/jansel
This commit is contained in:
Sun, Jiayi
2024-12-04 23:00:19 -08:00
committed by PyTorch MergeBot
parent 67ba79676f
commit 23e2f8ab3a
5 changed files with 12 additions and 3 deletions

View File

@ -209,6 +209,7 @@ class BinaryFoldingTemplate(TestCase):
expect_success=True,
)
@inductor_config.patch({"enable_linear_binary_folding": True})
def test_linear_binary_folding(self):
@torch.no_grad()
def test_linear_fusion(

View File

@ -919,7 +919,7 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm):
rtol=rtol,
)
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 1)
self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 2)
@inductor_config.patch({"freezing": True})
@patches

View File

@ -474,7 +474,6 @@ class TestPatternMatcher(TestPatternMatcherBase):
self.assertEqual(
counters["inductor"]["mkldnn_linear_weight_pack_matcher_count"], 2
)
self.assertEqual(counters["inductor"]["binary_folding"], 2)
self._test_common(
fold_mod,

View File

@ -764,6 +764,11 @@ check_stack_no_cycles_TESTING_ONLY: bool = False
# When True, complex_memory_overlap always reports True
always_complex_memory_overlap_TESTING_ONLY: bool = False
# Enable binary folding for linear computation nodes (aten.addmm / aten.mm)
# in the binary-folding pass.
# Off by default: linear binary folding caused an accuracy regression in the
# timm model levit_128 (https://github.com/pytorch/pytorch/issues/141755).
# The value is read from the environment when this module is evaluated; it is
# True only if TORCHINDUCTOR_ENABLE_LINEAR_BINARY_FOLDING is exactly "1", and
# False when the variable is unset or holds any other value.
enable_linear_binary_folding = (
os.environ.get("TORCHINDUCTOR_ENABLE_LINEAR_BINARY_FOLDING", "0") == "1"
)
# config specific to codegen/cpp.py
class cpp:

View File

@ -5,6 +5,7 @@ import itertools
import torch
from ..._dynamo.utils import counters
from .. import config
from ..pattern_matcher import Arg, CallFunction, KeywordArg
from .freezing_patterns import register_binary_folding_pattern
@ -297,7 +298,10 @@ def binary_folding_init():
if computation_node.target == aten.convolution.default:
return _check_conv_and_broadcast_op(computation_node, other)
elif computation_node.target in [aten.addmm.default, aten.mm.default]:
return _check_linear_and_broadcast_op(computation_node, other, has_reshape)
return (
config.enable_linear_binary_folding
and _check_linear_and_broadcast_op(computation_node, other, has_reshape)
)
return False