Files
pytorch/test/inductor/test_multi_kernel.py
bobrenjc93 5221448574 multi-kernel matmuls based on varying hint sizes (#156628)
The core idea is to generate multiple matmul kernels using different hints for symbolic variables, then select the most appropriate one at runtime for each unique shape we encounter. You can find some early experimentation details in these posts:

https://fb.workplace.com/groups/8940092306109185/posts/9803850776399996/
https://fb.workplace.com/groups/8940092306109185/posts/9695805170537891/
https://fb.workplace.com/groups/257735836456307/posts/906589324904285/
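
To make the runtime-selection idea concrete, here is a small, self-contained sketch of dispatching to one of several pre-generated kernels based on the runtime size. The `pick_kernel_for_size` helper and its nearest-hint heuristic are illustrative assumptions for this description only; the actual dispatch lives in Inductor's `MultiKernelCall` machinery and is described in "How to review this PR" below.

```python
from typing import Callable

# Illustrative sketch only (not Inductor's actual dispatch logic): given one
# compiled kernel per size hint, pick the kernel whose hint is closest to the
# runtime size we are about to execute.
def pick_kernel_for_size(
    runtime_size: int, kernels_by_hint: dict[int, Callable]
) -> Callable:
    best_hint = min(kernels_by_hint, key=lambda hint: abs(hint - runtime_size))
    return kernels_by_hint[best_hint]

# Hypothetical usage: kernels_by_hint = {64: mm_64, 256: mm_256, 4096: mm_4096}
# out = pick_kernel_for_size(m, kernels_by_hint)(a, b)
```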

Here’s a graph illustrating the empirically observed worst-case performance if an oracle always selected the least optimal hint for a given runtime size:

![image](https://github.com/user-attachments/assets/6d90ee06-a572-453e-9cba-03006f343301)

This graph illustrates the performance of a hint size of 64 relative to the worst case. Notice that as the runtime sizes increase, the performance gradually approaches the worst case:

![image](https://github.com/user-attachments/assets/85ad49fe-165a-474c-8d03-db2e57654213)

This graph shows the performance of a hint size of 4096 — very poor for small sizes, and also suboptimal for some mid-sized shapes:

![image](https://github.com/user-attachments/assets/adea1106-3bc8-40f3-97b0-20d940fb74f1)

Finally, here’s the graph that motivated this PR. It illustrates the performance when selecting the best of three kernels generated with three different hints — 64, 256, and 4096:

![image](https://github.com/user-attachments/assets/a7cb0ce5-8139-48b1-b5c9-7670e75cbfce)

## How to review this PR

At a high level, this extends @shunting314's multi-kernel abstraction to support varying GEMM choices driven by different hints. A few key points:

1. Unlike reduction kernels, Triton template matmuls pass their grid as arguments to the kernel. This PR updates `MultiKernelCall` to support kernels with varying arguments.
2. The `V.graph.sizevars.size_hints` API is extended to accept a `hint_override`, allowing us to substitute the example input's size hint with a custom value when generating multiple kernels.
3. The choice generation and benchmarking logic is updated to support multiple hint values. One kernel is generated per value in `torch._inductor.config.multi_kernel_hints`, and at runtime we select the most suitable kernel for the current shape (see the usage sketch after this list).
4. To keep this PR scoped, it does not add cpp wrapper codegen support; that will be added in a follow-up PR.
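
For reference, here is a minimal usage sketch modeled on the test below: it patches `torch._inductor.config` with the same `triton.multi_kernel`, `benchmark_kernel`, and `multi_kernel_hints` settings the test class uses, then compiles a plain matmul with the Triton GEMM backend under max-autotune. Treat it as an illustrative sketch rather than a canonical recipe; the `dynamic=True` flag and the CUDA device are assumptions, not part of this PR.

```python
import torch
from torch._inductor import config

# Sketch only: mirrors the config.patch used by MultiKernelTest below.
with config.patch(
    {
        "triton.multi_kernel": 1,
        "benchmark_kernel": True,
        "multi_kernel_hints": [64, 256, 4096],
    }
):
    compiled_mm = torch.compile(
        lambda x, y: x @ y,
        dynamic=True,  # assumption: keep shapes symbolic so the hints matter
        options={
            "max_autotune": True,
            "max_autotune_gemm_backends": "TRITON",
        },
    )
    a = torch.randn(4096, 4096, device="cuda")
    b = torch.randn(4096, 4096, device="cuda")
    out = compiled_mm(a, b)
```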

## Results

The following basic test shows the multi-kernel approach working: runtime performance no longer varies significantly with the original hint size: https://gist.github.com/bobrenjc93/ba711d529e65fd65839b34799f6323ec

Before
```
Hint\Runtime |     64     |    256     |    4096
---------------------------------------------------
     64      |   0.0948   |   0.3124   |   4.9477
    256      |   0.2243   |   0.2256   |   3.3880
    4096     |   0.3384   |   0.3404   |   3.3010
```

After
```
Hint\Runtime |     64     |    256     |    4096
---------------------------------------------------
     64      |   0.0951   |   0.2289   |   3.3013
    256      |   0.0952   |   0.2258   |   3.4045
    4096     |   0.0957   |   0.2231   |   3.3146
```

We also see an average speedup of 5.04% across the matrix of all hint/runtime pairs in [64, 4096], sampled at increments of 64: https://docs.google.com/spreadsheets/d/12TmYUDrAAFASGuP3POXTKPeAvQWIRzKzdrVSIb3vQkA/edit?gid=480268938#gid=480268938

![Worst Case, multi-kernel](https://github.com/user-attachments/assets/712df23b-87e2-4d9d-95c2-cc25305ba2ed)

NB: This is just the beginning; I plan to investigate further to improve on this initial result.

For posterity, the script used to generate that matrix is here: https://gist.github.com/bobrenjc93/c211fd0bd97fad8f46b91ad9dee76ad0

HUD benchmark runs:
base: https://github.com/pytorch/pytorch/actions/runs/15889871988
head: https://github.com/pytorch/pytorch/actions/runs/15889876842

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156628
Approved by: https://github.com/jansel
2025-07-12 15:08:21 +00:00

361 lines, 12 KiB, Python

# Owner(s): ["module: inductor"]
import os
import re
import unittest

import torch
from torch import nn
from torch._dynamo.testing import reset_rng_state
from torch._inductor import config, test_operators
from torch._inductor.codegen.multi_kernel import MultiKernelCall
from torch._inductor.test_case import TestCase
from torch._inductor.utils import run_and_get_code
from torch.nn import functional as F
from torch.testing import make_tensor
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    skipIfXpu,
)
from torch.testing._internal.inductor_utils import (
    GPU_TYPE,
    HAS_GPU,
    IS_BIG_GPU,
    requires_triton,
)


class TransformerSnippet(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.ln1 = nn.LayerNorm(64)
        self.ln2 = nn.LayerNorm(64)

    def forward(self, x1, x2):
        x1 = F.dropout(x1, 0.1)
        x2 = F.dropout(self.ln1(x2), 0.1)
        return self.ln2(x1 + x2)

    def example_inputs(self):
        return (torch.randn(2, 64).to(GPU_TYPE), torch.randn(2, 64).to(GPU_TYPE))


def _contains_multi_kernel_code(wrapper_code: str):
    return (
        re.search(r"multi_kernel_[^ ]* = async_compile.multi_kernel[(]", wrapper_code)
        is not None
    )


def make_cpp_wrapper_test(orig_test, **extra_args):
    """
    Wrap an existing test into a new test with cpp-wrapper enabled.

    Make this as a free function rather than staticmethod in MultiKernelTest.
    Otherwise we get 'TypeError: 'staticmethod' object is not callable'
    error in py3.8. (py3.10 works)
    """

    @config.patch("cpp_wrapper", True)
    @skipIfXpu(msg="cpp wrapper doesn't currently work on the XPU stack")
    def fn(self):
        # The same kernel may have been compiled by previous tests with
        # cpp_wrapper disabled. Clear the cache so we go ahead to re-compile
        # the kernel with cpp_wrapper enabled.
        from torch._inductor import codecache

        codecache.PyCodeCache.cache_clear()
        return orig_test(self, **extra_args)

    return fn


@config.patch(
    {
        "triton.multi_kernel": int(os.environ.get("TORCHINDUCTOR_MULTI_KERNEL", "1")),
        "benchmark_kernel": True,
        "multi_kernel_hints": [64, 256, 4096],
    }
)
@instantiate_parametrized_tests
class MultiKernelTest(TestCase):
    def test_softmax(self, expect_multi_kernel=True):
        x = torch.rand(2, 1024).to(GPU_TYPE)
        ref = torch.softmax(x, -1)
        compiled_fn = torch.compile(torch.softmax)
        act, wrapper_code = run_and_get_code(compiled_fn, x, -1)

        # wrapper_code will contain 2 entries if cpp_wrapper=True:
        # one for the first pass and one for the second pass.
        # We mainly care about the wrapper for the final pass here.
        wrapper_code = wrapper_code[-1]
        self.assertEqual(ref, act)
        if expect_multi_kernel:
            self.assertTrue(_contains_multi_kernel_code(wrapper_code))
        else:
            self.assertFalse(_contains_multi_kernel_code(wrapper_code))

    @requires_triton()
    @unittest.skipIf(not IS_BIG_GPU, "templates require big gpu")
    def test_triton_gemm(self):
        def fn(x, y):
            return x @ y

        compiled_fn = torch.compile(
            fn,
            options={
                "max_autotune": True,
                "max_autotune_gemm_backends": "TRITON",
            },
        )
        x = torch.randn(4096, 4096, device=GPU_TYPE)
        y = torch.randn(4096, 4096, device=GPU_TYPE)
        act, wrapper_code = run_and_get_code(compiled_fn, x, y)
        ref = fn(x, y)

        # wrapper_code will contain 2 entries if cpp_wrapper=True:
        # one for the first pass and one for the second pass.
        # We mainly care about the wrapper for the final pass here.
        wrapper_code = wrapper_code[-1]
        self.assertEqual(ref, act)
        self.assertTrue(_contains_multi_kernel_code(wrapper_code))

    @requires_triton()
    @unittest.skipIf(not IS_BIG_GPU, "templates require big gpu")
    def test_triton_relu_fused_gemm(self):
        def fn(x, y):
            return (x @ y).relu()

        compiled_fn = torch.compile(
            fn,
            options={
                "max_autotune": True,
                "max_autotune_gemm_backends": "TRITON",
            },
        )
        x = torch.randn(4096, 4096, device=GPU_TYPE)
        y = torch.randn(4096, 4096, device=GPU_TYPE)
        act, wrapper_code = run_and_get_code(compiled_fn, x, y)
        ref = fn(x, y)

        # wrapper_code will contain 2 entries if cpp_wrapper=True:
        # one for the first pass and one for the second pass.
        # We mainly care about the wrapper for the final pass here.
        wrapper_code = wrapper_code[-1]
        self.assertEqual(ref, act)
        self.assertTrue(_contains_multi_kernel_code(wrapper_code))

    @parametrize("force_kernel", (0, 1))
    @unittest.mock.patch.dict(
        os.environ, {"TORCHINDUCTOR_DISABLE_MULTI_KERNEL_CACHE": "1"}
    )
    def test_softmax_force_non_persistent_reduction(self, force_kernel):
        """
        Force a specific sub-kernel to be picked by mocking the benchmark result.
        """
        x = torch.rand(2, 1024).to(GPU_TYPE)
        mock_latency = [0.2, 0.2]
        mock_latency[force_kernel] = 0.1  # make sure force_kernel gets picked

        def f(x):
            return torch.softmax(x, -1) + force_kernel

        orig_run = MultiKernelCall.run
        picked_kernel = None

        def mock_run(self, *args, **kwargs):
            out = orig_run(self, *args, **kwargs)
            nonlocal picked_kernel
            picked_kernel = self.picked_kernel
            return out

        with (
            unittest.mock.patch.object(MultiKernelCall, "run", mock_run),
            unittest.mock.patch.object(
                MultiKernelCall,
                "benchmark_sub_kernels",
                lambda *args, **kwargs: mock_latency,
            ),
        ):
            torch.compile(f)(x)
        self.assertEqual(picked_kernel, force_kernel)

    @config.patch("warn_mix_layout", True)
    def test_softmax_warn_mixed_layout(self):
        self.test_softmax()

    test_softmax_cpp_wrapper = make_cpp_wrapper_test(
        test_softmax, expect_multi_kernel=True
    )

    def test_layernorm(self):
        ln = nn.LayerNorm(1024).to(GPU_TYPE)
        x = torch.rand(2, 1024).to(GPU_TYPE)
        ref = ln(x)
        act = torch.compile(ln)(x)
        self.assertEqual(ref, act, atol=1e-4, rtol=1e-4)

    def test_inplace_update(self):
        """
        Inductor generates an in-place kernel for mul.
        """

        def f(x, y):
            return x.sum(dim=-1, keepdims=True) * (y @ y)

        x = torch.rand(1024, 1024).to(GPU_TYPE)
        y = torch.rand(1024, 1024).to(GPU_TYPE)
        ref = f(x, y)
        act = torch.compile(f)(x, y)
        self.assertEqual(ref, act)

    def test_transformer_snippet(self):
        model = TransformerSnippet().to(GPU_TYPE)
        x = model.example_inputs()

        def f(*x):
            y = model(*x)
            return y

        reset_rng_state()
        ref = f(*x)

        opt_f = torch.compile(f)
        reset_rng_state()
        act = opt_f(*x)

        # Don't compare tensors when using the inductor random number generator:
        # inductor's random number implementation differs from eager's.
        # We should fall back to eager if we want to test accuracy.
        if config.fallback_random:
            self.assertEqual(ref, act, atol=1e-4, rtol=1e-4)

    def test_transformer_snippet_with_fallback_random(self):
        """
        Same as test_transformer_snippet but fall back the random number
        generator to eager so we can check accuracy.
        """
        with config.patch("fallback_random", True):
            self.test_transformer_snippet()

    def test_batchnorm_training(self):
        """
        For training, batchnorm tracks running mean/variance during the forward
        pass. The kernel generated by inductor currently passes those tensors in
        twice: once as input and once as output. They are not treated as in-out
        arguments because they are considered graph inputs.

        Multi-kernel previously assumed that the same argument is never passed
        multiple times to a kernel. Whether or not we change inductor's behavior
        to guarantee that, it's better to make multi-kernel able to handle such
        cases.
        """
        bn = nn.BatchNorm2d(3).to(GPU_TYPE)

        @torch.compile
        def f(x):
            bn(x).sum().backward()

        _, (wrapper_code, _) = run_and_get_code(
            f, torch.randn(2, 3, 8, 8, device=GPU_TYPE)
        )
        self.assertTrue(_contains_multi_kernel_code(wrapper_code))

    def test_pass_same_arg_multi_times(self):
        """
        A super simple example that simulates how BatchNorm updates its running
        stats.

        Inductor currently passes the same tensor multiple times to the generated
        kernel: once as input and once as output.

        Here is a paste of the generated kernel (without multi-kernel enabled):
        https://gist.github.com/shunting314/f0b446b4b9a28f4940e31dcd3e809cf9
        """

        def f(x, y):
            x = x.sum(dim=1, keepdim=False)
            y.copy_(y * 0.9 + x * 0.1)

        x = torch.randn(8, 16, device=GPU_TYPE)
        y = torch.randn(8, device=GPU_TYPE)
        y_ref = y.clone()

        ref = f(x, y_ref)  # noqa: F841
        act = torch.compile(f)(x, y)  # noqa: F841
        self.assertEqual(y_ref, y)

    def test_reduction_scratch_buffer(self, force_multi_kernel=1):
        """
        The explicitly realized buffer in the test function will be passed in
        as a scratch buffer for the non-persistent reduction kernel but
        can be skipped for the persistent reduction kernel.

        This causes different argument lists for the non-persistent reduction
        kernel and the persistent reduction kernel.

        Check the documentation around torch._inductor.config.triton.multi_kernel
        for how to interpret the force_multi_kernel argument.
        """

        def f(x):
            x = x.sum(dim=-1, keepdim=True) + x
            x = test_operators.realize(x)
            x = x.sum(dim=-1, keepdim=True) + x
            return x

        x = torch.rand(16, 16, device=GPU_TYPE)
        ref = f(x)
        with config.patch("triton.multi_kernel", force_multi_kernel):
            act = torch.compile(f)(x)
        self.assertEqual(ref, act)

    def test_split_scan(self, force_multi_kernel=1):
        def f(x):
            x = x.view(-1)
            return torch.cumsum(x, 0)

        x = make_tensor(10, 3, 352, 352, low=0, dtype=torch.float32, device=GPU_TYPE)
        expect = f(x)
        with config.patch("triton.multi_kernel", force_multi_kernel):
            actual = torch.compile(f)(x)
        self.assertEqual(expect, actual)

    def test_sort_disables_multi_kernel(self, force_multi_kernel=1):
        """
        Sort currently requires a persistent kernel, so multi-kernel is not
        possible. Make sure this falls back gracefully.
        """

        def f(x):
            return x.sort(-1).values

        x = torch.rand(32, 32, device=GPU_TYPE)
        expect = f(x)
        with config.patch("triton.multi_kernel", force_multi_kernel):
            actual = torch.compile(f)(x)
        self.assertEqual(expect, actual)

    # Use benchmarking to pick the faster kernel
    test_reduction_scratch_buffer_cpp_wrapper = make_cpp_wrapper_test(
        test_reduction_scratch_buffer, force_multi_kernel=1
    )
    # Force picking the persistent reduction. This is a good test since the
    # persistent reduction uses fewer call arguments than the corresponding
    # non-persistent reduction.
    test_reduction_scratch_buffer_cpp_wrapper_persistent_reduction = (
        make_cpp_wrapper_test(test_reduction_scratch_buffer, force_multi_kernel=2)
    )
    # Force picking the non-persistent reduction
    test_reduction_scratch_buffer_cpp_wrapper_non_persistent_reduction = (
        make_cpp_wrapper_test(test_reduction_scratch_buffer, force_multi_kernel=3)
    )


if __name__ == "__main__":
    from torch._inductor.test_case import run_tests

    if HAS_GPU:
        run_tests()