The core idea is to generate multiple matmul kernels using different hints for symbolic variables, then select the most appropriate one at runtime for each unique shape we encounter.

You can find some early experimentation details in these posts:
https://fb.workplace.com/groups/8940092306109185/posts/9803850776399996/
https://fb.workplace.com/groups/8940092306109185/posts/9695805170537891/
https://fb.workplace.com/groups/257735836456307/posts/906589324904285/

Here is a graph illustrating the empirically observed worst-case performance if an oracle always selected the least optimal hint for a given runtime size.

This graph illustrates the performance of a hint size of 64 relative to the worst case. Notice that as the runtime sizes increase, the performance gradually approaches the worst case.

This graph shows the performance of a hint size of 4096 — very poor for small sizes, and also suboptimal for some mid-sized shapes.

Finally, here is the graph that motivated this PR. It illustrates the performance when selecting the best of three kernels generated with three different hints — 64, 256, and 4096.

## How to review this PR

At a high level, this extends @shunting314's multi-kernel abstraction to support varying GEMM choices driven by different hints. A few key points:

1. Unlike reduction kernels, triton template matmuls pass their grid as arguments to the kernel. This PR updates `MultiKernelCall` to support kernels with varying arguments.
2. The `V.graph.sizevars.size_hints` API is extended to accept a `hint_override`, allowing us to substitute the example input's size hint with a custom value when generating multiple kernels.
3. The choice generation and benchmarking logic is updated to support multiple hint values. One kernel is generated per value in `torch._inductor.config.multi_kernel_hints`, and at runtime we select the most suitable kernel for the current shape.
4. This PR does not add support for cpp wrapper codegen, to keep it scoped. That will be added in the next PR.

## Results

The following basic test shows the multi-kernel dispatch working: we no longer see significant variance based on the original hint size. https://gist.github.com/bobrenjc93/ba711d529e65fd65839b34799f6323ec

Before (lower is better):

```
Hint\Runtime |   64   |  256   |  4096
---------------------------------------
64           | 0.0948 | 0.3124 | 4.9477
256          | 0.2243 | 0.2256 | 3.3880
4096         | 0.3384 | 0.3404 | 3.3010
```

After:

```
Hint\Runtime |   64   |  256   |  4096
---------------------------------------
64           | 0.0951 | 0.2289 | 3.3013
256          | 0.0952 | 0.2258 | 3.4045
4096         | 0.0957 | 0.2231 | 3.3146
```

We also see an average speedup of 5.04% across the matrix of all hint/runtime pairs in [64, 4096], at every increment of 64: https://docs.google.com/spreadsheets/d/12TmYUDrAAFASGuP3POXTKPeAvQWIRzKzdrVSIb3vQkA/edit?gid=480268938#gid=480268938

NB: This is just the beginning, and I plan on doing more investigation to further improve on this initial result. For posterity, the script used to generate that matrix is here: https://gist.github.com/bobrenjc93/c211fd0bd97fad8f46b91ad9dee76ad0

HUD benchmark runs:
base: https://github.com/pytorch/pytorch/actions/runs/15889871988
head: https://github.com/pytorch/pytorch/actions/runs/15889876842

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156628
Approved by: https://github.com/jansel
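To make the runtime-selection idea concrete, here is a minimal sketch of the dispatch pattern described above: one compiled kernel per hint, benchmarked once per unique input shape, with the winner cached for later calls. This is an illustration only, not Inductor's actual `MultiKernelCall` logic; `ShapeDispatchingKernel` and `_benchmark` are hypothetical names, and the benchmark-and-cache policy shown is just one plausible selection strategy.

```python
# Sketch only: dispatch between per-hint kernels at runtime, keyed by input shape.
import time
from typing import Callable

import torch


def _benchmark(kernel: Callable, *args: torch.Tensor, reps: int = 10) -> float:
    # Crude wall-clock timing; real code would use proper GPU-aware benchmarking.
    kernel(*args)  # warm-up
    start = time.perf_counter()
    for _ in range(reps):
        kernel(*args)
    return (time.perf_counter() - start) / reps


class ShapeDispatchingKernel:
    def __init__(self, kernels_by_hint: dict[int, Callable]) -> None:
        # e.g. {64: kernel_hint_64, 256: kernel_hint_256, 4096: kernel_hint_4096}
        self.kernels_by_hint = kernels_by_hint
        self._choice_cache: dict[tuple, Callable] = {}

    def __call__(self, *args: torch.Tensor) -> torch.Tensor:
        shape_key = tuple(tuple(a.shape) for a in args)
        kernel = self._choice_cache.get(shape_key)
        if kernel is None:
            # First time we see this shape: benchmark every candidate, keep the fastest.
            kernel = min(
                self.kernels_by_hint.values(), key=lambda k: _benchmark(k, *args)
            )
            self._choice_cache[shape_key] = kernel
        return kernel(*args)
```

In the PR itself, the per-hint kernels come from the Triton GEMM template compiled once per value in `torch._inductor.config.multi_kernel_hints`, and the selection happens inside `MultiKernelCall`.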
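The test file below exercises these knobs directly. As a quick orientation, here is an illustrative standalone snippet using the same configuration the tests patch in (`triton.multi_kernel`, `benchmark_kernel`, `multi_kernel_hints`, and Triton-only max-autotune for GEMMs); it is a sketch that assumes a CUDA device, whereas the tests use the device-agnostic `GPU_TYPE`.

```python
# Illustrative only: the same knobs the MultiKernelTest class below patches in.
import torch
from torch._inductor import config

with config.patch(
    {
        "triton.multi_kernel": 1,               # enable multi-kernel dispatch
        "benchmark_kernel": True,
        "multi_kernel_hints": [64, 256, 4096],  # one GEMM kernel per hint
    }
):
    compiled_mm = torch.compile(
        lambda x, y: x @ y,
        options={
            "max_autotune": True,
            "max_autotune_gemm_backends": "TRITON",
        },
    )
    x = torch.randn(4096, 4096, device="cuda")  # assumes a CUDA GPU
    y = torch.randn(4096, 4096, device="cuda")
    out = compiled_mm(x, y)
```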
# Owner(s): ["module: inductor"]

import os
import re
import unittest

import torch
from torch import nn
from torch._dynamo.testing import reset_rng_state
from torch._inductor import config, test_operators
from torch._inductor.codegen.multi_kernel import MultiKernelCall
from torch._inductor.test_case import TestCase
from torch._inductor.utils import run_and_get_code
from torch.nn import functional as F
from torch.testing import make_tensor
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    skipIfXpu,
)
from torch.testing._internal.inductor_utils import (
    GPU_TYPE,
    HAS_GPU,
    IS_BIG_GPU,
    requires_triton,
)


class TransformerSnippet(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.ln1 = nn.LayerNorm(64)
        self.ln2 = nn.LayerNorm(64)

    def forward(self, x1, x2):
        x1 = F.dropout(x1, 0.1)
        x2 = F.dropout(self.ln1(x2), 0.1)

        return self.ln2(x1 + x2)

    def example_inputs(self):
        return (torch.randn(2, 64).to(GPU_TYPE), torch.randn(2, 64).to(GPU_TYPE))


def _contains_multi_kernel_code(wrapper_code: str):
    return (
        re.search(r"multi_kernel_[^ ]* = async_compile.multi_kernel[(]", wrapper_code)
        is not None
    )


def make_cpp_wrapper_test(orig_test, **extra_args):
    """
    Wrap an existing test into a new test with cpp-wrapper enabled.

    Make this a free function rather than a staticmethod in MultiKernelTest.
    Otherwise we get a "TypeError: 'staticmethod' object is not callable"
    error in py3.8. (py3.10 works)
    """

    @config.patch("cpp_wrapper", True)
    @skipIfXpu(msg="cpp wrapper doesn't currently work on the XPU stack")
    def fn(self):
        # The same kernel may have been compiled by previous tests with
        # cpp_wrapper disabled. Clear the cache so we re-compile the
        # kernel with cpp_wrapper enabled.
        from torch._inductor import codecache

        codecache.PyCodeCache.cache_clear()
        return orig_test(self, **extra_args)

    return fn


@config.patch(
    {
        "triton.multi_kernel": int(os.environ.get("TORCHINDUCTOR_MULTI_KERNEL", "1")),
        "benchmark_kernel": True,
        "multi_kernel_hints": [64, 256, 4096],
    }
)
@instantiate_parametrized_tests
class MultiKernelTest(TestCase):
    def test_softmax(self, expect_multi_kernel=True):
        x = torch.rand(2, 1024).to(GPU_TYPE)
        ref = torch.softmax(x, -1)
        compiled_fn = torch.compile(torch.softmax)
        act, wrapper_code = run_and_get_code(compiled_fn, x, -1)

        # wrapper_code will contain 2 entries if cpp_wrapper=True:
        # one for the first pass and one for the second pass.
        # We mainly care about the wrapper for the final pass here.
        wrapper_code = wrapper_code[-1]
        self.assertEqual(ref, act)
        if expect_multi_kernel:
            self.assertTrue(_contains_multi_kernel_code(wrapper_code))
        else:
            self.assertFalse(_contains_multi_kernel_code(wrapper_code))

    @requires_triton()
    @unittest.skipIf(not IS_BIG_GPU, "templates require big gpu")
    def test_triton_gemm(self):
        def fn(x, y):
            return x @ y

        compiled_fn = torch.compile(
            fn,
            options={
                "max_autotune": True,
                "max_autotune_gemm_backends": "TRITON",
            },
        )
        x = torch.randn(4096, 4096, device=GPU_TYPE)
        y = torch.randn(4096, 4096, device=GPU_TYPE)
        act, wrapper_code = run_and_get_code(compiled_fn, x, y)
        ref = fn(x, y)

        # wrapper_code will contain 2 entries if cpp_wrapper=True:
        # one for the first pass and one for the second pass.
        # We mainly care about the wrapper for the final pass here.
        wrapper_code = wrapper_code[-1]
        self.assertEqual(ref, act)
        self.assertTrue(_contains_multi_kernel_code(wrapper_code))

    @requires_triton()
    @unittest.skipIf(not IS_BIG_GPU, "templates require big gpu")
    def test_triton_relu_fused_gemm(self):
        def fn(x, y):
            return (x @ y).relu()

        compiled_fn = torch.compile(
            fn,
            options={
                "max_autotune": True,
                "max_autotune_gemm_backends": "TRITON",
            },
        )
        x = torch.randn(4096, 4096, device=GPU_TYPE)
        y = torch.randn(4096, 4096, device=GPU_TYPE)
        act, wrapper_code = run_and_get_code(compiled_fn, x, y)
        ref = fn(x, y)

        # wrapper_code will contain 2 entries if cpp_wrapper=True:
        # one for the first pass and one for the second pass.
        # We mainly care about the wrapper for the final pass here.
        wrapper_code = wrapper_code[-1]
        self.assertEqual(ref, act)
        self.assertTrue(_contains_multi_kernel_code(wrapper_code))

    @parametrize("force_kernel", (0, 1))
    @unittest.mock.patch.dict(
        os.environ, {"TORCHINDUCTOR_DISABLE_MULTI_KERNEL_CACHE": "1"}
    )
    def test_softmax_force_non_persistent_reduction(self, force_kernel):
        """
        Force a specific sub-kernel to be picked by mocking the benchmark result.
        """
        x = torch.rand(2, 1024).to(GPU_TYPE)
        mock_latency = [0.2, 0.2]
        mock_latency[force_kernel] = 0.1  # this makes sure force_kernel will be picked

        def f(x):
            return torch.softmax(x, -1) + force_kernel

        orig_run = MultiKernelCall.run
        picked_kernel = None

        def mock_run(self, *args, **kwargs):
            out = orig_run(self, *args, **kwargs)
            nonlocal picked_kernel
            picked_kernel = self.picked_kernel
            return out

        with (
            unittest.mock.patch.object(MultiKernelCall, "run", mock_run),
            unittest.mock.patch.object(
                MultiKernelCall,
                "benchmark_sub_kernels",
                lambda *args, **kwargs: mock_latency,
            ),
        ):
            torch.compile(f)(x)
        self.assertEqual(picked_kernel, force_kernel)

    @config.patch("warn_mix_layout", True)
    def test_softmax_warn_mixed_layout(self):
        self.test_softmax()

    test_softmax_cpp_wrapper = make_cpp_wrapper_test(
        test_softmax, expect_multi_kernel=True
    )

    def test_layernorm(self):
        ln = nn.LayerNorm(1024).to(GPU_TYPE)
        x = torch.rand(2, 1024).to(GPU_TYPE)
        ref = ln(x)
        act = torch.compile(ln)(x)
        self.assertEqual(ref, act, atol=1e-4, rtol=1e-4)

    def test_inplace_update(self):
        """
        Inductor generates an in-place kernel for mul.
        """

        def f(x, y):
            return x.sum(dim=-1, keepdims=True) * (y @ y)

        x = torch.rand(1024, 1024).to(GPU_TYPE)
        y = torch.rand(1024, 1024).to(GPU_TYPE)
        ref = f(x, y)
        act = torch.compile(f)(x, y)
        self.assertEqual(ref, act)

    def test_transformer_snippet(self):
        model = TransformerSnippet().to(GPU_TYPE)
        x = model.example_inputs()

        def f(*x):
            y = model(*x)
            return y

        reset_rng_state()
        ref = f(*x)

        opt_f = torch.compile(f)
        reset_rng_state()
        act = opt_f(*x)

        # Don't compare tensors if using the inductor random number generator.
        # Inductor's random number implementation is different from eager's.
        # We should fall back to eager if we want to test accuracy.
        if config.fallback_random:
            self.assertEqual(ref, act, atol=1e-4, rtol=1e-4)

    def test_transformer_snippet_with_fallback_random(self):
        """
        Same as test_transformer_snippet but fall back the random number
        generator to eager so we can check accuracy.
        """
        with config.patch("fallback_random", True):
            self.test_transformer_snippet()

    def test_batchnorm_training(self):
        """
        For training, batchnorm tracks running mean/variance during the forward pass.
        The kernel generated by inductor currently passes those tensors in twice as
        arguments: once for input and once for output. They are not treated as in-out
        arguments because they are considered graph inputs.

        Multi-kernel previously assumed that we never pass the same argument multiple
        times to a kernel. Even if we change inductor's behavior to ensure that, it's
        better to make multi-kernel able to handle those cases.
        """
        bn = nn.BatchNorm2d(3).to(GPU_TYPE)

        @torch.compile
        def f(x):
            bn(x).sum().backward()

        _, (wrapper_code, _) = run_and_get_code(
            f, torch.randn(2, 3, 8, 8, device=GPU_TYPE)
        )
        self.assertTrue(_contains_multi_kernel_code(wrapper_code))

    def test_pass_same_arg_multi_times(self):
        """
        A super simple example that simulates how BatchNorm updates the running
        stats.

        Inductor currently passes the same tensor multiple times to the generated
        kernel: once for input and once for output.

        Here is a paste of the generated kernel (without multi-kernel enabled):
        https://gist.github.com/shunting314/f0b446b4b9a28f4940e31dcd3e809cf9
        """

        def f(x, y):
            x = x.sum(dim=1, keepdim=False)
            y.copy_(y * 0.9 + x * 0.1)

        x = torch.randn(8, 16, device=GPU_TYPE)
        y = torch.randn(8, device=GPU_TYPE)
        y_ref = y.clone()

        ref = f(x, y_ref)  # noqa: F841
        act = torch.compile(f)(x, y)  # noqa: F841
        self.assertEqual(y_ref, y)

    def test_reduction_scratch_buffer(self, force_multi_kernel=1):
        """
        The explicitly realized buffer in the test function will be passed in
        as a scratch buffer for the non-persistent reduction kernel but
        can be skipped for the persistent reduction kernel.

        This causes different argument lists for the non-persistent reduction
        kernel and the persistent reduction kernel.

        Check the documentation around torch._inductor.config.triton.multi_kernel
        for how to interpret the force_multi_kernel argument.
        """

        def f(x):
            x = x.sum(dim=-1, keepdim=True) + x
            x = test_operators.realize(x)
            x = x.sum(dim=-1, keepdim=True) + x
            return x

        x = torch.rand(16, 16, device=GPU_TYPE)
        ref = f(x)
        with config.patch("triton.multi_kernel", force_multi_kernel):
            act = torch.compile(f)(x)
        self.assertEqual(ref, act)

    def test_split_scan(self, force_multi_kernel=1):
        def f(x):
            x = x.view(-1)
            return torch.cumsum(x, 0)

        x = make_tensor(10, 3, 352, 352, low=0, dtype=torch.float32, device=GPU_TYPE)
        expect = f(x)
        with config.patch("triton.multi_kernel", force_multi_kernel):
            actual = torch.compile(f)(x)
        self.assertEqual(expect, actual)

    def test_sort_disables_multi_kernel(self, force_multi_kernel=1):
        """
        Sort currently requires a persistent kernel, so multi-kernel is not
        possible. Make sure this falls back gracefully.
        """

        def f(x):
            return x.sort(-1).values

        x = torch.rand(32, 32, device=GPU_TYPE)
        expect = f(x)
        with config.patch("triton.multi_kernel", force_multi_kernel):
            actual = torch.compile(f)(x)
        self.assertEqual(expect, actual)

    # Use benchmarking to pick the faster kernel.
    test_reduction_scratch_buffer_cpp_wrapper = make_cpp_wrapper_test(
        test_reduction_scratch_buffer, force_multi_kernel=1
    )
    # Force picking the persistent reduction. This is a good test since the
    # persistent reduction uses fewer call arguments than the corresponding
    # non-persistent reduction.
    test_reduction_scratch_buffer_cpp_wrapper_persistent_reduction = (
        make_cpp_wrapper_test(test_reduction_scratch_buffer, force_multi_kernel=2)
    )
    # Force picking the non-persistent reduction.
    test_reduction_scratch_buffer_cpp_wrapper_non_persistent_reduction = (
        make_cpp_wrapper_test(test_reduction_scratch_buffer, force_multi_kernel=3)
    )


if __name__ == "__main__":
    from torch._inductor.test_case import run_tests

    if HAS_GPU:
        run_tests()