[inductor][templates] Distinguish between kernel input nodes and codegen input nodes (#163752)
If there is a single autotuner choice, the wrong type of input node is used to instantiate `TritonTemplateBuffer` through `TritonTemplateCaller.output_node`. This PR distinguishes, in `AlgorithmSelectorCache.__call__`, between the actual input nodes passed to the kernel at runtime and the possibly viewed input nodes that influence scheduling behaviour (e.g. `MemoryDeps`) and codegen. See the added unit test for more detail.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163752
Approved by: https://github.com/eellison
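In rough terms, the fix threads two separate node lists through the autotuning path. The following is a hypothetical sketch only; the class and attribute names are illustrative and do not match the actual definitions in torch/_inductor/select_algorithm.py:

    from dataclasses import dataclass


    @dataclass
    class TemplateBufferSketch:
        # Stand-in for TritonTemplateBuffer: the inputs recorded here are
        # what the scheduler later inspects to build its memory deps.
        inputs: list


    @dataclass
    class TemplateCallerSketch:
        # Buffers actually handed to the generated kernel at runtime.
        kernel_input_nodes: list
        # Possibly viewed versions of those buffers (e.g. a permute); their
        # layouts are what scheduling and codegen should reason about.
        codegen_input_nodes: list

        def output_node(self) -> TemplateBufferSketch:
            # The bug: with a single autotuner choice, the buffer was built
            # from the runtime inputs, hiding views from the scheduler. The
            # fix uses the viewed inputs, so read/write compatibility is
            # judged against the layouts codegen actually sees.
            return TemplateBufferSketch(inputs=self.codegen_input_nodes)


    caller = TemplateCallerSketch(
        kernel_input_nodes=["buf_x", "buf_w"],
        codegen_input_nodes=["view(buf_x)", "buf_w"],
    )
    assert caller.output_node().inputs == ["view(buf_x)", "buf_w"]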
Committed by: PyTorch MergeBot
Parent: 322091d8d8
Commit: 2b58adc3bd
@@ -3,6 +3,7 @@ import contextlib
import functools
import unittest.mock
from collections.abc import Callable
from typing import Any, Optional, Union
from unittest.mock import patch

import torch
@@ -14,14 +15,19 @@ from torch._dynamo.testing import expectedFailureDynamicWrapper
from torch._dynamo.utils import counters
from torch._inductor import config
from torch._inductor.autotune_process import TritonBenchmarkRequest
from torch._inductor.choices import InductorChoices
from torch._inductor.codegen.common import KernelTemplate
from torch._inductor.ir import FixedLayout
from torch._inductor.kernel_inputs import KernelInputs
from torch._inductor.select_algorithm import (
    autotune_select_algorithm,
    ExternKernelChoice,
    TritonTemplate,
    TritonTemplateKernel,
)
from torch._inductor.test_case import run_tests, TestCase
from torch._inductor.utils import is_big_gpu, run_and_get_kernels
from torch._inductor.virtualized import V
from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND
from torch.testing._internal.common_utils import IS_LINUX, skipIfRocm, skipIfXpu
from torch.testing._internal.inductor_utils import (
@@ -393,6 +399,68 @@ class TestSelectAlgorithm(TestCase):
        # Autotuning checks correctness of each version
        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)

    @patches
    @torch._inductor.config.patch(
        {"conv_1x1_as_mm": True, "max_autotune_gemm_backends": "TRITON"}
    )
    def test_convolution_as_mm_triton_only(self):
        # To convert the 1x1 conv to a matmul, x is converted to a channels-last
        # tensor and the channels dimension is permuted to be innermost. This
        # prologue should not be fused with the matmul, since the prologue writes
        # discontiguously whilst the mm template currently only supports reading
        # the input contiguously.
        #
        # Before the change associated with this PR, fusion would occur because
        # the actual kernel input nodes (which don't include views, e.g. permute)
        # were passed to the `TritonTemplateCaller` rather than the input nodes
        # that include views. For example, after x is converted to channels last,
        # its layout is shape @ stride [2, 33, 16, 16] @ [8448, 1, 528, 33], or
        # equivalently [2, 33, 256] @ [8448, 1, 33], and the prologue writes this
        # value discontiguously. After the permute, the mm template fixes the
        # layout to [512, 33] @ [33, 1] and reads the input contiguously. If the
        # kernel input node for x is passed to the `TritonTemplateCaller`, the
        # scheduler will fuse the prologue since the write is compatible with
        # the read. If instead the viewed input is passed to
        # `TritonTemplateCaller`, the write won't be compatible with the read,
        # and the prologue won't be fused.
        def foo(x, w, b):
            return aten.convolution(
                x + 1,
                w,
                b,
                stride=(1, 1),
                padding=(0, 0),
                dilation=(1, 1),
                transposed=False,
                output_padding=(0, 0),
                groups=1,
            )

        x = torch.randn(2, 33, 16, 16, device=GPU_TYPE)
        w = torch.randn(34, 33, 1, 1, device=GPU_TYPE)
        b = torch.randn(34, device=GPU_TYPE)

        class SingleMMConfigChoice(InductorChoices):
            def get_template_configs(
                self,
                kernel_inputs: KernelInputs,
                templates: list[Union[KernelTemplate, ExternKernelChoice]],
                op_name: str,
                kwarg_overrides: Optional[dict[str, dict[str, Any]]] = None,
            ):
                return super().get_template_configs(
                    kernel_inputs, templates, op_name, kwarg_overrides
                )[:1]

        with V.set_choices_handler(SingleMMConfigChoice()):
            result_compile = torch.compile(foo)(x, w, b)
            result_eager = foo(x, w, b)

        # If the prologue has been fused this should fail
        torch.testing.assert_close(result_compile, result_eager)

        # There should not be any autotuning
        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 0)

    @patches
    @torch._inductor.config.patch(conv_1x1_as_mm=False)
    def test_convolution2_group(self):
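The stride arithmetic in the test comment can be reproduced with public PyTorch tensor APIs alone. This standalone check (runnable on CPU, not part of the commit) confirms the layouts the comment describes:

    import torch

    # Same shape as the test's x: [N=2, C=33, H=16, W=16].
    x = torch.randn(2, 33, 16, 16)

    # Channels last makes C innermost: strides become [C*H*W, 1, W*C, C].
    xcl = x.to(memory_format=torch.channels_last)
    assert xcl.stride() == (8448, 1, 528, 33)

    # Flattening H and W reproduces the [2, 33, 256] @ [8448, 1, 33] layout
    # from the comment; the prologue (x + 1) writes in this strided order.
    x3 = xcl.as_strided((2, 33, 256), (8448, 1, 33))
    assert torch.equal(x3.reshape(2, 33, 16, 16), x)

    # The permute the mm template applies makes the read contiguous:
    # [2, 16, 16, 33] @ [8448, 528, 33, 1] reshapes for free to the
    # [512, 33] @ [33, 1] matmul input.
    mm_in = xcl.permute(0, 2, 3, 1).reshape(512, 33)
    assert mm_in.stride() == (33, 1) and mm_in.is_contiguous()

Both views describe the same bytes, but their indexing expressions differ, and it is the indexing of the viewed node that determines the memory deps the scheduler compares when deciding whether the prologue can be fused.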