[inductor][templates] Distinguish between kernel input nodes and codegen input nodes (#163752)

If there is only a single autotuner choice, the wrong type of input node is used to instantiate `TritonTemplateBuffer` through `TritonTemplateCaller.output_node`. This PR distinguishes, in `AlgorithmSelectorCache.__call__`, between the actual input nodes passed to the kernel at runtime and the possibly viewed input nodes that influence scheduling behaviour (e.g. `MemoryDeps`) and codegen. See the added unit test for more detail.
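
For illustration, a minimal sketch in plain PyTorch (eager tensors, not inductor IR; the variable names here are invented for the example) of why the two kinds of input node can disagree: a view such as a permute shares storage with its base, so the buffer the kernel actually receives at runtime and the viewed node that scheduling/codegen should reason about describe the same memory with different strides.

```python
# Illustration only (plain PyTorch, not the inductor API): the base buffer is what
# the kernel receives at runtime, while the view is what scheduling should reason
# about, since its strides decide whether a prologue's writes line up with the
# template's reads.
import torch

base = torch.randn(2, 33, 16, 16).to(memory_format=torch.channels_last)
view = base.permute(0, 2, 3, 1).reshape(2 * 16 * 16, 33)  # the [512, 33] layout an mm template would read

assert view.data_ptr() == base.data_ptr()  # same storage, i.e. the same runtime kernel input
print(base.stride())  # (8448, 1, 528, 33): strides seen by a prologue writing the buffer
print(view.stride())  # (33, 1): the contiguous layout the template reads
```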

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163752
Approved by: https://github.com/eellison
Author: Mwiza Kunda
Date: 2025-10-08 14:12:14 +00:00
Committed by: PyTorch MergeBot
Parent: 322091d8d8
Commit: 2b58adc3bd
2 changed files with 82 additions and 3 deletions


@@ -3,6 +3,7 @@ import contextlib
import functools
import unittest.mock
from collections.abc import Callable
from typing import Any, Optional, Union
from unittest.mock import patch
import torch
@@ -14,14 +15,19 @@ from torch._dynamo.testing import expectedFailureDynamicWrapper
from torch._dynamo.utils import counters
from torch._inductor import config
from torch._inductor.autotune_process import TritonBenchmarkRequest
from torch._inductor.choices import InductorChoices
from torch._inductor.codegen.common import KernelTemplate
from torch._inductor.ir import FixedLayout
from torch._inductor.kernel_inputs import KernelInputs
from torch._inductor.select_algorithm import (
    autotune_select_algorithm,
    ExternKernelChoice,
    TritonTemplate,
    TritonTemplateKernel,
)
from torch._inductor.test_case import run_tests, TestCase
from torch._inductor.utils import is_big_gpu, run_and_get_kernels
from torch._inductor.virtualized import V
from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND
from torch.testing._internal.common_utils import IS_LINUX, skipIfRocm, skipIfXpu
from torch.testing._internal.inductor_utils import (
@@ -393,6 +399,68 @@ class TestSelectAlgorithm(TestCase):
        # Autotuning checks correctness of each version
        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)

    @patches
    @torch._inductor.config.patch(
        {"conv_1x1_as_mm": True, "max_autotune_gemm_backends": "TRITON"}
    )
    def test_convolution_as_mm_triton_only(self):
        # To convert the 1x1 conv to matmul, x is converted to a channels last
        # tensor and the channels dimension is permuted to be innermost. This
        # prologue should not be fused with the matmul since the prologue writes
        # discontiguously, whilst the mm template currently only supports reading
        # the input contiguously.
        #
        # Before the change associated with this PR, fusion would occur because the actual kernel
        # input nodes (which don't include views e.g. permute) would be passed to the
        # `TritonTemplateCaller` rather than the input nodes that include views.
        # For example after x is converted to channels last, its layout is shape @ stride
        # [2, 33, 16, 16] @ [8432, 1, 528, 33], or [2, 33, 256] @ [8432, 1, 33], and the
        # prologue writes this value discontiguously.
        # After the permute, the mm template fixes the layout to [512, 33] @ [33, 1] and
        # reads the input contiguously. If the kernel input node for x is passed to the
        # `TritonTemplateCaller`, then the scheduler will fuse the prologue since the
        # write is compatible with the read. If however the viewed input is passed
        # to `TritonTemplateCaller`, then the write won't be compatible with the read,
        # and the prologue won't be fused.
        def foo(x, w, b):
            return aten.convolution(
                x + 1,
                w,
                b,
                stride=(1, 1),
                padding=(0, 0),
                dilation=(1, 1),
                transposed=False,
                output_padding=(0, 0),
                groups=1,
            )

        x = torch.randn(2, 33, 16, 16, device=GPU_TYPE)
        w = torch.randn(34, 33, 1, 1, device=GPU_TYPE)
        b = torch.randn(34, device=GPU_TYPE)

        class SingleMMConfigChoice(InductorChoices):
            def get_template_configs(
                self,
                kernel_inputs: KernelInputs,
                templates: list[Union[KernelTemplate, ExternKernelChoice]],
                op_name: str,
                kwarg_overrides: Optional[dict[str, dict[str, Any]]] = None,
            ):
                return super().get_template_configs(
                    kernel_inputs, templates, op_name, kwarg_overrides
                )[:1]

        with V.set_choices_handler(SingleMMConfigChoice()):
            result_compile = torch.compile(foo)(x, w, b)
        result_eager = foo(x, w, b)

        # If the prologue has been fused this should fail
        torch.testing.assert_close(result_compile, result_eager)

        # There should not be any autotuning
        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 0)

    @patches
    @torch._inductor.config.patch(conv_1x1_as_mm=False)
    def test_convolution2_group(self):