diff --git a/test/inductor/test_cpu_select_algorithm.py b/test/inductor/test_cpu_select_algorithm.py
index 320c51087b25..cda3ca1c42d2 100644
--- a/test/inductor/test_cpu_select_algorithm.py
+++ b/test/inductor/test_cpu_select_algorithm.py
@@ -14,6 +14,7 @@ import torch._dynamo.config as dynamo_config
 import torch._inductor.config as inductor_config
 import torch._inductor.select_algorithm as select_algorithm
 from torch._dynamo.utils import counters
+from torch._inductor import test_operators
 from torch._inductor.cpu_vec_isa import VecAMX
 from torch._inductor.test_case import run_tests, TestCase
 from torch.testing._internal.common_device_type import (
@@ -540,6 +541,137 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm):
         self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
         self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 1)

+    @inductor_config.patch({"freezing": True})
+    @patches
+    @torch.no_grad
+    @unittest.skipIf(not TEST_MKL, "Test requires MKL")
+    @parametrize("batch_size", (8,))
+    @parametrize("in_features", (128,))
+    @parametrize("size_0", (4,))
+    @parametrize("size_1", (14,))
+    @parametrize("out_features", (512,))
+    @parametrize("out_features_conv", (256,))
+    @parametrize(
+        "bias",
+        (
+            False,
+            True,
+        ),
+    )
+    @parametrize(
+        "epilogue",
+        (
+            False,
+            True,
+        ),
+    )
+    @dtypes(torch.float32)
+    def test_linear_unsupported_epilogue_fusion(
+        self,
+        batch_size,
+        in_features,
+        size_0,
+        size_1,
+        out_features,
+        out_features_conv,
+        bias,
+        epilogue,
+        dtype,
+    ):
+        img_size_0 = int(size_0 * size_0)
+        img_size_1 = int(size_1 * size_1)
+        conv_shape = int(size_0 * size_1)
+        flatten_BS = int(batch_size * size_0 * size_0 * size_1 * size_1)
+
+        # Reproducer from the jx_nest_base model in timm
+        class M(torch.nn.Module):
+            def __init__(self, bias):
+                super().__init__()
+                self.linear1 = torch.nn.Linear(in_features, in_features, bias=bias)
+                self.linear2 = torch.nn.Linear(out_features, in_features, bias=bias)
+                self.conv = torch.nn.Conv2d(
+                    in_features,
+                    out_features_conv,
+                    kernel_size=3,
+                    padding=1,
+                    stride=1,
+                    dilation=1,
+                    groups=1,
+                )
+                self.epilogue = epilogue
+
+            def forward(self, mul_239, view_425, add_184):
+                _mkl_linear_91 = self.linear1(view_425)
+                view_426 = torch.ops.aten.reshape.default(
+                    _mkl_linear_91, [batch_size, img_size_0, img_size_1, in_features]
+                )
+                _mkl_linear_91 = None
+                add_187 = torch.ops.aten.add.Tensor(add_184, view_426)
+                add_184 = view_426 = None
+                view_429 = torch.ops.aten.reshape.default(
+                    mul_239, [flatten_BS, out_features]
+                )
+                mul_239 = None
+
+                _mkl_linear_89 = self.linear2(view_429)
+                if self.epilogue:
+                    _mkl_linear_89 = torch.pow(_mkl_linear_89, 2)
+                    _mkl_linear_89 = test_operators.realize(_mkl_linear_89)
+
+                view_430 = torch.ops.aten.reshape.default(
+                    _mkl_linear_89, [batch_size, img_size_0, img_size_1, in_features]
+                )
+                _mkl_linear_89 = None
+
+                add_191 = torch.ops.aten.add.Tensor(add_187, view_430)
+                add_187 = view_430 = None
+
+                view_431 = torch.ops.aten.reshape.default(
+                    add_191, [batch_size, size_0, size_0, size_1, size_1, in_features]
+                )
+                add_191 = None
+                permute_203 = torch.ops.aten.permute.default(
+                    view_431, [0, 1, 3, 2, 4, 5]
+                )
+                view_431 = None
+                clone_188 = torch.ops.aten.clone.default(
+                    permute_203, memory_format=torch.contiguous_format
+                )
+                permute_203 = None
+                view_432 = torch.ops.aten.reshape.default(
+                    clone_188, [batch_size, conv_shape, conv_shape, in_features]
+                )
+                clone_188 = None
+                permute_204 = torch.ops.aten.permute.default(view_432, [0, 3, 1, 2])
+                view_432 = None
+
+                _convolution_pointwise_default_1 = self.conv(permute_204)
+
+                return _convolution_pointwise_default_1
+
+        mul_239 = torch.randn(batch_size, img_size_0, img_size_1, out_features)
+        view_425 = torch.randn(flatten_BS, in_features)
+        add_184 = torch.randn(batch_size, img_size_0, img_size_1, in_features)
+        mod = M(bias=bias).eval()
+        with verify(dtype) as (atol, rtol), torch.cpu.amp.autocast(
+            enabled=dtype == torch.bfloat16
+        ):
+            self.common(
+                mod,
+                (
+                    mul_239,
+                    view_425,
+                    add_184,
+                ),
+                atol=atol,
+                rtol=rtol,
+            )
+        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 2)
+        # TODO: change cpp_epilogue_fusion_counter to 1 once supported
+        self.assertEqual(
+            counters["inductor"]["cpp_epilogue_fusion_counter"], 1 if epilogue else 0
+        )
+
     @inductor_config.patch({"freezing": True})
     @patches
     @torch.no_grad
diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py
index ea1cc05f940b..92fa2f46492a 100644
--- a/torch/_inductor/codegen/cpp.py
+++ b/torch/_inductor/codegen/cpp.py
@@ -71,6 +71,7 @@ from .cpp_utils import (
     INDEX_TYPE,
     LocalBufferContext,
     promote_args,
+    template_fusion_with_epilogues_supported,
     unify_mask_base_type,
     value_to_cpp,
 )
@@ -4148,7 +4149,10 @@ class CppScheduling(BaseScheduling):
             # TODO(jgong5): support pre-op fusion with template
             return False
         if node1.is_template():
-            return not node2.is_reduction()
+            template_fusion_supported, _ = template_fusion_with_epilogues_supported(
+                node1, [node2]
+            )
+            return not node2.is_reduction() and template_fusion_supported
         return (
             self._can_fuse_horizontal_impl(node1, node2) and not node1.is_reduction()
         ) or self.can_fuse_vertical_outer_loop(node1, node2)
diff --git a/torch/_inductor/codegen/cpp_utils.py b/torch/_inductor/codegen/cpp_utils.py
index a68549d4a9b3..0e756e01f756 100644
--- a/torch/_inductor/codegen/cpp_utils.py
+++ b/torch/_inductor/codegen/cpp_utils.py
@@ -5,18 +5,21 @@ import functools
 import math
 import sys
 from collections import namedtuple
-from typing import Any, Callable, Dict, List, Optional, Set, Tuple
+from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple
 from unittest.mock import patch

 import sympy

 import torch
 from torch._prims_common import is_integer_dtype
+from torch.utils._ordered_set import OrderedSet
 from torch.utils._sympy.symbol import symbol_is_type, SymT
 from torch.utils._sympy.value_ranges import ValueRanges

 from .. import ir
+from ..dependencies import Dep
 from ..loop_body import LoopBody
+from ..scheduler import BaseSchedulerNode, SchedulerBuffer
 from ..utils import IndentedBuffer, sympy_index_symbol_with_prefix, sympy_subs
 from ..virtualized import ops, OpsValue, V
 from .common import (
@@ -916,3 +919,72 @@ def _get_dtype_from_loopbodies(loop_bodies):
                     continue
                 dtypes.add(node.meta[OptimizationContext.key].dtype)
     return dtypes
+
+
+def template_fusion_with_epilogues_supported(
+    template: BaseSchedulerNode, epilogues: List[BaseSchedulerNode]
+) -> Tuple[bool, bool]:
+    def _get_indexes_of_template_buf_read(
+        epilogue_node: ir.Operation, template_buf_names: List[str]
+    ) -> List[sympy.Expr]:
+        return [
+            read.index
+            for read in epilogue_node.get_reads()
+            if read.name in template_buf_names
+        ]
+
+    def _check_supported_and_same_indexes(
+        index_of_template_buf_read: List[sympy.Expr], epilogue_writes: OrderedSet[Dep]
+    ) -> Tuple[bool, bool]:
+        num_indexes = len(set(index_of_template_buf_read))
+
+        if num_indexes > 1:
+            same_index = False
+            supported = False  # Different read indexes not supported
+        elif num_indexes == 0:
+            same_index = True
+            supported = True  # No reads, automatically supported
+        elif num_indexes == 1:
+            index_of_template_buf_read = index_of_template_buf_read[0]
+            same_index = all(
+                write.index == index_of_template_buf_read for write in epilogue_writes
+            )
+            # TODO: Add support for fusion when the read of the template buffer and the write of
+            # the epilogue output don't have the same index, and change supported to True
+            supported = same_index
+        else:
+            raise AssertionError("Should not reach here")
+
+        return supported, same_index
+
+    def _template_fusion_supported(
+        template_outputs: Sequence[SchedulerBuffer], epilogue_nodes: List[ir.Operation]
+    ) -> Tuple[bool, bool]:
+        template_buf_names = [x.get_name() for x in template_outputs]
+        indexes_of_template_buf_reads = [
+            _get_indexes_of_template_buf_read(epilogue_node, template_buf_names)
+            for epilogue_node in epilogue_nodes
+        ]
+        epilogue_nodes_writes = [
+            epilogue_node.get_read_writes().writes for epilogue_node in epilogue_nodes
+        ]
+
+        results = [
+            _check_supported_and_same_indexes(reads, writes)
+            for reads, writes in zip(
+                indexes_of_template_buf_reads, epilogue_nodes_writes
+            )
+        ]
+        supported, same_indexes = zip(*results)
+        return all(supported), all(same_indexes)
+
+    assert template.is_template()
+    template_outputs = template.get_outputs()
+
+    epilogue_nodes = [
+        n.node
+        for epilogue in epilogues
+        for n in epilogue.get_nodes()
+        if n.node is not None
+    ]
+    return _template_fusion_supported(template_outputs, epilogue_nodes)
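
Note for reviewers: below is a small standalone sketch of the index-matching rule that the new template_fusion_with_epilogues_supported helper enforces, namely that an epilogue is only kept fusible when it reads the template buffer at a single index that also matches the indexes at which it writes its own output. This is illustrative only, not Inductor API: it uses plain sympy expressions in place of Dep/SchedulerBuffer objects, and check_same_index is a hypothetical name.

# Illustrative sketch (not part of this patch): mirrors the branches of
# _check_supported_and_same_indexes using plain sympy index expressions.
import sympy


def check_same_index(template_read_indexes, epilogue_write_indexes):
    # Returns (supported, same_index):
    #   >1 distinct read index -> unsupported;
    #   no reads of the template buffer -> trivially supported;
    #   exactly 1 read index -> supported only if every write uses that index.
    unique_reads = set(template_read_indexes)
    if len(unique_reads) > 1:
        return False, False
    if len(unique_reads) == 0:
        return True, True
    read = template_read_indexes[0]
    same = all(write == read for write in epilogue_write_indexes)
    return same, same


x0, x1 = sympy.symbols("x0 x1")

# Element-wise epilogue (e.g. the optional pow in the new test): read and write
# use the same index, so fusion is allowed.
print(check_same_index([x0], [x0]))            # (True, True)

# Epilogue that writes at a reshaped/permuted index relative to its read of the
# template buffer: rejected, which is the unsupported case the new test covers.
print(check_same_index([x0], [56 * x1 + x0]))  # (False, False)

Fusion of the different-index case is left as a TODO in _check_supported_and_same_indexes; once it is supported, the cpp_epilogue_fusion_counter assertion in the new test is expected to change accordingly.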