mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[Inductor] add flag for linear binary folding and turn it off by default (#142108)
Fix https://github.com/pytorch/pytorch/issues/141755. Summary: linear binary folding results in a timm_model (levit_128) accuracy regression, so this PR adds the flag `enable_linear_binary_folding` for linear binary folding and turns it off by default. Pull Request resolved: https://github.com/pytorch/pytorch/pull/142108 Approved by: https://github.com/jgong5, https://github.com/leslie-fang-intel, https://github.com/jansel
This commit is contained in:
committed by
PyTorch MergeBot
parent
67ba79676f
commit
23e2f8ab3a
@ -209,6 +209,7 @@ class BinaryFoldingTemplate(TestCase):
|
||||
expect_success=True,
|
||||
)
|
||||
|
||||
@inductor_config.patch({"enable_linear_binary_folding": True})
|
||||
def test_linear_binary_folding(self):
|
||||
@torch.no_grad()
|
||||
def test_linear_fusion(
|
||||
|
@ -919,7 +919,7 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm):
|
||||
rtol=rtol,
|
||||
)
|
||||
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
|
||||
self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 1)
|
||||
self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 2)
|
||||
|
||||
@inductor_config.patch({"freezing": True})
|
||||
@patches
|
||||
|
@ -474,7 +474,6 @@ class TestPatternMatcher(TestPatternMatcherBase):
|
||||
self.assertEqual(
|
||||
counters["inductor"]["mkldnn_linear_weight_pack_matcher_count"], 2
|
||||
)
|
||||
self.assertEqual(counters["inductor"]["binary_folding"], 2)
|
||||
|
||||
self._test_common(
|
||||
fold_mod,
|
||||
|
@ -764,6 +764,11 @@ check_stack_no_cycles_TESTING_ONLY: bool = False
|
||||
# When True, complex_memory_overlap always reports True
|
||||
always_complex_memory_overlap_TESTING_ONLY: bool = False
|
||||
|
||||
# enable linear binary folding
|
||||
enable_linear_binary_folding = (
|
||||
os.environ.get("TORCHINDUCTOR_ENABLE_LINEAR_BINARY_FOLDING", "0") == "1"
|
||||
)
|
||||
|
||||
|
||||
# config specific to codegen/cpp.py
|
||||
class cpp:
|
||||
|
@ -5,6 +5,7 @@ import itertools
|
||||
import torch
|
||||
|
||||
from ..._dynamo.utils import counters
|
||||
from .. import config
|
||||
from ..pattern_matcher import Arg, CallFunction, KeywordArg
|
||||
from .freezing_patterns import register_binary_folding_pattern
|
||||
|
||||
@ -297,7 +298,10 @@ def binary_folding_init():
|
||||
if computation_node.target == aten.convolution.default:
|
||||
return _check_conv_and_broadcast_op(computation_node, other)
|
||||
elif computation_node.target in [aten.addmm.default, aten.mm.default]:
|
||||
return _check_linear_and_broadcast_op(computation_node, other, has_reshape)
|
||||
return (
|
||||
config.enable_linear_binary_folding
|
||||
and _check_linear_and_broadcast_op(computation_node, other, has_reshape)
|
||||
)
|
||||
|
||||
return False
|
||||
|
||||
|
Reference in New Issue
Block a user