Files
pytorch/test/inductor/test_select_algorithm.py
Yuanyuan Chen a8c528c105 [1/N] Apply UP035 rule in tests (#163947)
Apply UP035 `ruff` rule in tests, but some tests for `fx` and `dynamo` are excluded in case the old typing is the test target.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163947
Approved by: https://github.com/ezyang
2025-09-29 01:42:01 +00:00

562 lines
18 KiB
Python

# Owner(s): ["module: inductor"]
import contextlib
import functools
import unittest.mock
from collections.abc import Callable, Iterator
from unittest.mock import patch

import torch
import torch._dynamo.config as dynamo_config
import torch._inductor.config as inductor_config
import torch._inductor.select_algorithm as select_algorithm
import torch.nn.functional as F
from torch._dynamo.testing import expectedFailureDynamicWrapper
from torch._dynamo.utils import counters
from torch._inductor import config
from torch._inductor.autotune_process import TritonBenchmarkRequest
from torch._inductor.ir import FixedLayout
from torch._inductor.select_algorithm import (
    autotune_select_algorithm,
    TritonTemplate,
    TritonTemplateKernel,
)
from torch._inductor.test_case import run_tests, TestCase
from torch._inductor.utils import is_big_gpu, run_and_get_kernels
from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND
from torch.testing._internal.common_utils import IS_LINUX, skipIfRocm, skipIfXpu
from torch.testing._internal.inductor_utils import (
    GPU_TYPE,
    HAS_GPU,
    requires_gpu,
    requires_triton,
)
# Shorthand for the ATen operator namespace; the convolution tests in this
# file call ``aten.convolution`` through it.
aten = torch.ops.aten
def patches(fn):
    """Decorate a test so it runs under this file's autotuning configuration.

    ``fn`` is wrapped, in order, with the following patchers:

    * dynamo config ``verbose=True``
    * inductor config ``debug``, ``max_autotune`` and ``epilogue_fusion`` on
    * ``select_algorithm.VERIFY`` set to ``atol=1e-4, rtol=1e-4``
    * ``AlgorithmSelectorCache.lookup`` replaced by a pass-through stub
    * ``torch.backends.cudnn.flags(allow_tf32=False)``

    Before every call the returned wrapper clears ``counters``, seeds torch's
    RNG with 12345 and asserts that ``torch.backends.cuda.matmul.allow_tf32``
    is disabled.

    Args:
        fn: The test function to decorate.

    Returns:
        The wrapped callable, carrying ``fn``'s metadata via ``functools.wraps``.
    """

    def skip_cache(self, choices, name, key, benchmark, hint_override=None):
        # Stand-in for AlgorithmSelectorCache.lookup: it keeps no results of
        # its own and simply runs ``benchmark`` on the choices when one is
        # supplied, otherwise reports no timings.
        if benchmark is not None:
            return benchmark(choices)
        return {}

    decorators = (
        dynamo_config.patch(verbose=True),
        inductor_config.patch(debug=True, max_autotune=True, epilogue_fusion=True),
        patch.object(select_algorithm, "VERIFY", {"atol": 1e-4, "rtol": 1e-4}),
        patch.object(select_algorithm.AlgorithmSelectorCache, "lookup", skip_cache),
        torch.backends.cudnn.flags(allow_tf32=False),
    )

    # Apply in listed order, so the last entry ends up outermost.
    patched = fn
    for decorate in decorators:
        patched = decorate(patched)

    @functools.wraps(patched)
    def wrapped(*args, **kwargs):
        counters.clear()
        torch.manual_seed(12345)
        assert not torch.backends.cuda.matmul.allow_tf32, (
            "correctness testing is allergic to tf32"
        )
        return patched(*args, **kwargs)

    return wrapped
class TestSelectAlgorithm(TestCase):
    """Max-autotune tests for Inductor's ``select_algorithm`` module.

    Most tests compile a small function with ``torch.compile``, run it once on
    ``GPU_TYPE`` and then check the ``select_algorithm_autotune`` counter to
    see whether autotuning ran (1) or was skipped (0).
    """

    def setUp(self):
        """Skip on small GPUs and start from an empty preprocessing registry."""
        super().setUp()
        if not is_big_gpu():
            # skipTest() raises, so nothing below runs on small GPUs.
            self.skipTest("Need a big GPU to run max_autotune=True")
        # Clear preprocessing functions to ensure clean state
        select_algorithm.clear_preprocessing_fns()
        # Clear again on the way out so a function registered by a test here
        # cannot leak into test classes that do not reset the registry.
        self.addCleanup(select_algorithm.clear_preprocessing_fns)

    @patches
    def test_linear_relu(self):
        """linear + relu goes through autotuning once."""

        @torch.compile
        def foo(input, weight, bias):
            return F.relu(F.linear(input, weight, bias))

        foo(
            torch.randn(64, 32, device=GPU_TYPE),
            torch.randn(16, 32, device=GPU_TYPE),
            torch.randn(1, 16, device=GPU_TYPE),
        )
        # Autotuning checks correctness of each version
        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
        # It would be nice to assert this got fused into a single kernel, but that
        # only happens if we select a triton template (and not aten).

    @patches
    def test_addmm(self):
        """addmm goes through autotuning once."""

        @torch.compile
        def foo(input, weight, bias):
            return torch.addmm(bias, input, weight)

        inps = (
            torch.randn(20, 33, device=GPU_TYPE),
            torch.randn(33, 16, device=GPU_TYPE),
            torch.randn(20, 16, device=GPU_TYPE),
        )

        foo(*inps)
        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)

    @patches
    def test_preprocessing_single_choice(self):
        """A preprocessing fn that keeps one choice makes autotuning a no-op."""
        # pass a list to the preprocessing function to assert that it was
        # actually called
        func_called = [False]

        # Register a preprocessing function that returns only the first choice
        # This in turn will lead to autotuning being skipped as it's a single
        # choice, and the counter itself will not be bumped
        def return_first_choice_only(choices):
            func_called[0] = True
            return choices[:1] if choices else []

        select_algorithm.add_preprocessing_fn(return_first_choice_only)

        @torch.compile
        def foo(input, weight, bias):
            return torch.addmm(bias, input, weight)

        inps = (
            torch.randn(20, 33, device=GPU_TYPE),
            torch.randn(33, 16, device=GPU_TYPE),
            torch.randn(20, 16, device=GPU_TYPE),
        )

        foo(*inps)
        # Since we only have one choice, autotuning should be skipped
        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 0)
        # The preprocessing function should have been called
        self.assertTrue(func_called[0])

    @patch.object(select_algorithm, "VERIFY", dict(atol=5e-2, rtol=5e-2))
    @patches
    def test_addmm_fp16(self):
        """Half-precision addmm with a transposed weight and a 1-D bias."""

        @torch.compile
        def foo(input, weight, bias):
            return torch.addmm(bias, input, weight)

        inps = (
            torch.randn(2, 320, device=GPU_TYPE, dtype=torch.half),
            torch.randn(320, 320, device=GPU_TYPE, dtype=torch.half).t(),
            torch.empty(320, device=GPU_TYPE, dtype=torch.half),
        )

        foo(*inps)
        # Autotuning checks correctness of each version
        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)

    @patches
    def test_mm(self):
        """Plain float32 mm goes through autotuning once."""

        @torch.compile
        def foo(a, b):
            return torch.mm(a, b)

        foo(
            torch.randn(8, 32, device=GPU_TYPE),
            torch.randn(32, 8, device=GPU_TYPE),
        )
        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)

    @patches
    def test__int_mm(self):
        """int8 ``torch._int_mm`` goes through autotuning once."""

        @torch.compile
        def foo(a, b):
            return torch._int_mm(a, b)

        foo(
            torch.randint(-10, 10, (64, 32), device=GPU_TYPE, dtype=torch.int8),
            torch.randint(-10, 10, (32, 64), device=GPU_TYPE, dtype=torch.int8),
        )
        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)

    @patches
    @skipIfXpu(msg="Double datatype matmul is not supported in oneDNN")
    def test_mm_skip(self):
        """float64 mm must not be autotuned."""

        @torch.compile
        def foo(a, b):
            return torch.mm(a, b)

        foo(
            torch.randn(8, 32, device=GPU_TYPE, dtype=torch.float64),
            torch.randn(32, 8, device=GPU_TYPE, dtype=torch.float64),
        )
        # float64 not supported by tl.dot()
        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 0)

    @patches
    def test_bmm(self):
        """Batched matmul goes through autotuning once."""

        @torch.compile
        def foo(a, b):
            return torch.bmm(a, b)

        foo(
            torch.randn(2, 8, 32, device=GPU_TYPE),
            torch.randn(2, 32, 8, device=GPU_TYPE),
        )
        # Autotuning checks correctness of each version
        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)

    @patches
    def test_mm_not_even_k(self):
        """mm with odd-sized dimensions (11x22 @ 22x33)."""

        @torch.compile
        def foo(a, b):
            return torch.mm(a, b)

        foo(
            torch.randn(11, 22, device=GPU_TYPE),
            torch.randn(22, 33, device=GPU_TYPE),
        )
        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)

    @patches
    def test_baddbmm(self):
        """baddbmm with a bias that broadcasts along the middle dimension."""

        @torch.compile
        def foo(a, b, c):
            return torch.baddbmm(c, a, b)

        foo(
            torch.randn(2, 8, 32, device=GPU_TYPE),
            torch.randn(2, 32, 8, device=GPU_TYPE),
            torch.randn(2, 1, 8, device=GPU_TYPE),
        )
        # Autotuning checks correctness of each version
        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)

    @patches
    def test_mm_plus_mm(self):
        """(a @ b) + (c @ d) on 32x32 inputs counts as a single autotune."""

        @torch.compile
        def foo(a, b, c, d):
            return (a @ b) + (c @ d)

        foo(
            torch.randn(32, 32, device=GPU_TYPE),
            torch.randn(32, 32, device=GPU_TYPE),
            torch.randn(32, 32, device=GPU_TYPE),
            torch.randn(32, 32, device=GPU_TYPE),
        )
        # Autotuning checks correctness of each version
        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)

    # TODO: fix accuracy failure of the triton template on XPU.
    # and enable this test case.
    @skipIfXpu
    @patches
    def test_mm_plus_mm2(self):
        """(a @ b) + (c @ d) on 512x512 inputs."""

        @torch.compile
        def foo(a, b, c, d):
            return (a @ b) + (c @ d)

        foo(
            torch.randn(512, 512, device=GPU_TYPE),
            torch.randn(512, 512, device=GPU_TYPE),
            torch.randn(512, 512, device=GPU_TYPE),
            torch.randn(512, 512, device=GPU_TYPE),
        )
        # Autotuning checks correctness of each version
        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)

    @expectedFailureDynamicWrapper
    @patches
    def test_mm_plus_mm3(self):
        """(a @ b) + (c @ d) on non-square inputs (512x32 @ 32x8)."""

        @torch.compile
        def foo(a, b, c, d):
            return (a @ b) + (c @ d)

        foo(
            torch.randn(512, 32, device=GPU_TYPE),
            torch.randn(32, 8, device=GPU_TYPE),
            torch.randn(512, 32, device=GPU_TYPE),
            torch.randn(32, 8, device=GPU_TYPE),
        )
        # Autotuning checks correctness of each version
        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)

    @patches
    def test_mm_dup_args(self):
        """mm where both operands are the same tensor."""

        @torch.compile
        def foo(a):
            return torch.mm(a, a)

        foo(torch.randn(32, 32, device=GPU_TYPE))
        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)

    @patches
    def test_mm_dup_args_view(self):
        """mm where both operands are slices of the same tensor."""

        @torch.compile
        def foo(a):
            q = a[:32, :]
            k = a[32:, :]
            return torch.mm(q, k.transpose(0, 1))

        foo(torch.randn(64, 64, device=GPU_TYPE))
        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)

    @expectedFailureDynamicWrapper
    @patches
    def test_convolution1(self):
        """3x3 convolution with asymmetric stride/padding on ``x + 1``."""

        @torch.compile
        def foo(x, w, b):
            return aten.convolution(
                x + 1,
                w,
                b,
                stride=(2, 3),
                padding=(4, 5),
                dilation=(1, 1),
                transposed=False,
                output_padding=(0, 0),
                groups=1,
            )

        foo(
            torch.randn(2, 33, 34, 41, device=GPU_TYPE),
            torch.randn(34, 33, 3, 3, device=GPU_TYPE),
            torch.randn(34, device=GPU_TYPE),
        )
        # Autotuning checks correctness of each version
        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)

    @skipIfRocm
    @patches
    def test_mm_dropout(self):
        """fp16 mm whose result is multiplied by inductor random values."""

        @torch.compile
        def fn(x1, x2, seed):
            mm_4 = torch.ops.aten.mm.default(x2, x1)
            rnd = torch.ops.prims.inductor_random.default(mm_4.shape, seed, "rand")
            return mm_4 * rnd

        if GPU_TYPE == "xpu":
            # On XPU the call runs with a different VERIFY tolerance (1e-3).
            patcher = patch.object(
                select_algorithm, "VERIFY", dict(atol=1e-3, rtol=1e-3)
            )
            fn = patcher(fn)

        # sizes picked so triton autotuning wins
        fn(
            torch.randn(512, 1024, dtype=torch.float16, device=GPU_TYPE),
            torch.randn(384, 512, dtype=torch.float16, device=GPU_TYPE),
            torch.tensor(12345, device=GPU_TYPE),
        )
        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)

    @patches
    @torch._inductor.config.patch(conv_1x1_as_mm=False)
    def test_convolution2(self):
        """1x1 convolution with ``conv_1x1_as_mm`` disabled."""

        @torch.compile
        def foo(x, w, b):
            return aten.convolution(
                x,
                w,
                b,
                stride=(1, 1),
                padding=(0, 0),
                dilation=(1, 1),
                transposed=False,
                output_padding=(0, 0),
                groups=1,
            )

        foo(
            torch.randn(1, 33, 16, 16, device=GPU_TYPE),
            torch.randn(34, 33, 1, 1, device=GPU_TYPE),
            torch.randn(34, device=GPU_TYPE),
        )
        # Autotuning checks correctness of each version
        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)

    @patches
    @torch._inductor.config.patch(conv_1x1_as_mm=True)
    def test_convolution_as_mm(self):
        """1x1 convolution with ``conv_1x1_as_mm`` enabled."""

        @torch.compile
        def foo(x, w, b):
            return aten.convolution(
                x + 1,
                w,
                b,
                stride=(1, 1),
                padding=(0, 0),
                dilation=(1, 1),
                transposed=False,
                output_padding=(0, 0),
                groups=1,
            )

        foo(
            torch.randn(2, 33, 16, 16, device=GPU_TYPE),
            torch.randn(34, 33, 1, 1, device=GPU_TYPE),
            torch.randn(34, device=GPU_TYPE),
        )
        # Autotuning checks correctness of each version
        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)

    @patches
    @torch._inductor.config.patch(conv_1x1_as_mm=False)
    def test_convolution2_group(self):
        """3x3 convolution with ``groups=32``."""

        @torch.compile
        def foo(x, w, b):
            return aten.convolution(
                x,
                w,
                b,
                stride=(1, 1),
                padding=(1, 1),
                dilation=(1, 1),
                transposed=False,
                output_padding=(0, 0),
                groups=32,  # group is not 1
            )

        foo(
            torch.randn(1, 32, 16, 16, device=GPU_TYPE),
            torch.randn(32, 1, 3, 3, device=GPU_TYPE),
            torch.randn(32, device=GPU_TYPE),
        )
        # Autotuning checks correctness of each version
        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)

    def test_TritonTemplateCaller_str(self):
        """
        Make sure str(TritonTemplateCaller) does not raise exceptions.
        """
        module_path = "abc.py"
        # Only module_path is populated; every other field is left as None.
        bmreq = TritonBenchmarkRequest(
            module_path=module_path,
            module_cache_key=None,
            kernel_name=None,
            extra_args=None,
            num_stages=None,
            num_warps=None,
            num_consumer_groups=None,
            num_buffers_warp_spec=None,
            input_tensor_meta=None,
            output_tensor_meta=None,
        )
        caller = select_algorithm.TritonTemplateCaller(
            None, None, None, None, "extra", bmreq
        )
        caller_str = str(caller)
        self.assertEqual(caller_str, f"TritonTemplateCaller({module_path}, extra)")
@contextlib.contextmanager
def patch_lowering(lowering_overrides) -> Iterator[None]:
    """Temporarily register replacement Inductor lowerings.

    Args:
        lowering_overrides: Mapping from an op to a 4-tuple
            ``(decomp_fn, broadcast, type_promotion_kind,
            convert_input_to_bool)``; each entry is forwarded to
            ``inductor_lowering._register_lowering`` for that op.

    Yields:
        ``None``. While the ``with`` body runs the overrides are registered in
        ``inductor_lowering.lowerings``; ``patch.dict`` restores the dict's
        previous contents when the block exits.
    """
    import torch._inductor.lowering as inductor_lowering

    # patch.dict with no values snapshots the dict and rolls it back on exit.
    with unittest.mock.patch.dict(inductor_lowering.lowerings):
        for fn, (
            decomp_fn,
            broadcast,
            type_promotion_kind,
            convert_input_to_bool,
        ) in lowering_overrides.items():
            inductor_lowering._register_lowering(
                fn,
                decomp_fn,
                broadcast=broadcast,
                type_promotion_kind=type_promotion_kind,
                convert_input_to_bool=convert_input_to_bool,
                lowering_dict=inductor_lowering.lowerings,
            )

        yield
class TestTemplateRender(TestCase):
    """Tests for rendering Triton templates through subclassed kernels."""

    @requires_gpu()
    @requires_triton()
    @config.patch(cuda_backend="triton")
    def test_finalized_subclass_hooks(self):
        """
        Tests that all registered triton template hooks have been finalized,
        especially in the case that the hooks are finalized manually by the
        caller i.e. by calling template.finalize_hook(hook_name)
        """
        hook_identifier = "# CUSTOM_HOOK"

        class ExtensionTritonTemplateKernel(TritonTemplateKernel):
            def __init__(self, *args, **kwargs):
                super().__init__(*args, **kwargs)
                # Expose custom_hook() to the template source below.
                self._register_extra_template_env_fns(
                    self.custom_hook,
                )

            def custom_hook(self) -> str:
                """
                Custom hook that just returns a test string for validation
                """

                def hook() -> str:
                    return hook_identifier

                return self._register_hook("<CUSTOM_HOOK>", hook)

            def inductor_meta_common(self):
                return super().inductor_meta_common()

        class ExtensionTritonTemplate(TritonTemplate):
            kernel_type = ExtensionTritonTemplateKernel

        # The statements after def_kernel form the kernel body, so they are
        # indented one level under the generated function header.
        add_template = ExtensionTritonTemplate(
            name="add",
            grid=lambda *args, **kwargs: (1, 1, 1),
            source=(
                r"""
{{def_kernel("A", "B")}}
    {{custom_hook()}}
    xoffset = tl.program_id(0)
    xindex = xoffset + tl.arange(0, XBLOCK)
    xmask = tl.full([XBLOCK], True, tl.int1)
    tmp0 = tl.load(A + xindex)
    tmp1 = tl.load(B + xindex)
    tmp2 = tmp0 + tmp1
    {{store_output(("xindex",), "tmp2", mask="xmask")}}
"""
            ),
        )

        XBLOCK = 32

        def add_override(a, b, alpha=None):
            # Lowering for aten.add.Tensor that offers the template above as
            # the only choice.
            layout = FixedLayout(a.get_device(), a.get_dtype(), a.get_size())
            choices = []
            add_template.maybe_append_choice(
                choices,
                input_nodes=(a, b),
                layout=layout,
                num_stages=1,
                num_warps=2,
                XBLOCK=XBLOCK,
            )
            return autotune_select_algorithm("add", choices, [a, b], layout)

        with patch_lowering(
            {
                torch.ops.aten.add.Tensor: (
                    add_override,
                    True,
                    ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
                    False,
                )
            }
        ):

            @torch.compile
            def add(a, b):
                return a + b

            a = torch.zeros((XBLOCK,), device=GPU_TYPE)
            b = torch.zeros((XBLOCK,), device=GPU_TYPE)

            _result, kernels = run_and_get_kernels(add, a, b)
            # unittest assertions are not stripped under ``python -O`` and
            # report the offending values on failure.
            self.assertEqual(len(kernels), 1)
            self.assertIn(hook_identifier, kernels[0])
if __name__ == "__main__":
    # Run only on Linux with a GPU present; is_big_gpu() is evaluated last,
    # and only when the two cheaper flags are already true.
    if IS_LINUX and HAS_GPU:
        if is_big_gpu():
            run_tests()