Introduces a variant of size-hint multi-kernel where, for novel runtime shapes, instead of performing full benchmarking to determine the optimal kernel, we select one of many kernels pre-generated from multi-kernel hints, based on the similarity between the hint and runtime input/output shapes (L1 distance in log2 space).

Some caveats/changes:
- Size-hint multi-kernel now only kicks in if the kernel has dynamic shapes.
- Pre-generation still only does a 1-d search over the specified hints, e.g. `matmul([s0, s1], [s1, s2])` with size hints `[64, 256]` only generates 2 kernels, based on the tuning shapes ([64, 64], [64, 64]) and ([256, 256], [256, 256]). Extending this to a reasonable n-d search (via a user API?) is left as an extension.

Benchmarking results, compared to multi-kernel with full benchmarking (hints 64, 4096) and to compiling with the ground-truth hint:

![Benchmark results](https://github.com/user-attachments/assets/056cca48-c16a-4451-9b4a-fa13a7a058a9)

Full benchmarking doing worse is extremely weird, but we did see similar spikes in #156628.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163090
Approved by: https://github.com/bobrenjc93
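To make the selection step concrete, here is a minimal sketch of the heuristic described above: score each pre-generated kernel's tuning shapes against the runtime input/output shapes by L1 distance in log2 space and run the closest one. The names (`pick_pregenerated_kernel`, `kernels_by_hint`, `log2_l1_distance`) are hypothetical and purely illustrative, not the actual `MultiKernelCall` code path.

```python
import math


def log2_l1_distance(shape_a, shape_b):
    # L1 distance between two same-rank shapes, measured in log2 space.
    return sum(
        abs(math.log2(max(a, 1)) - math.log2(max(b, 1)))
        for a, b in zip(shape_a, shape_b)
    )


def pick_pregenerated_kernel(kernels_by_hint, runtime_shapes):
    # kernels_by_hint: {hint: (tuning_shapes, kernel)}, where tuning_shapes are
    # the shapes the kernel was tuned with (every dynamic dim set to that hint).
    # runtime_shapes: the novel runtime input/output shapes seen at call time.
    def score(hint):
        tuning_shapes, _ = kernels_by_hint[hint]
        return sum(
            log2_l1_distance(t, r) for t, r in zip(tuning_shapes, runtime_shapes)
        )

    best_hint = min(kernels_by_hint, key=score)
    return kernels_by_hint[best_hint][1]
```

For example, with hints `[64, 256]` and runtime matmul shapes `([100, 200], [200, 300])`, the kernel tuned at ([256, 256], [256, 256]) has the smaller log2-space distance and would be picked.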
# Owner(s): ["module: inductor"]

import os
import re
import unittest

import torch
from torch import nn
from torch._dynamo.testing import reset_rng_state
from torch._inductor import config, test_operators
from torch._inductor.codegen.multi_kernel import MultiKernelCall
from torch._inductor.test_case import TestCase
from torch._inductor.utils import run_and_get_code
from torch.nn import functional as F
from torch.testing import make_tensor
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    skipIfRocm,
    skipIfXpu,
)
from torch.testing._internal.inductor_utils import (
    GPU_TYPE,
    HAS_GPU,
    IS_BIG_GPU,
    requires_triton,
)


class TransformerSnippet(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.ln1 = nn.LayerNorm(64)
        self.ln2 = nn.LayerNorm(64)

    def forward(self, x1, x2):
        x1 = F.dropout(x1, 0.1)
        x2 = F.dropout(self.ln1(x2), 0.1)

        return self.ln2(x1 + x2)

    def example_inputs(self):
        return (torch.randn(2, 64).to(GPU_TYPE), torch.randn(2, 64).to(GPU_TYPE))


def _contains_multi_kernel_code(wrapper_code: str):
    return (
        re.search(r"multi_kernel_[^ ]* = async_compile.multi_kernel[(]", wrapper_code)
        is not None
    )


def _contains_size_hint_multi_kernel_code(wrapper_code: str):
    return (
        re.search(
            r"multi_kernel_[^ ]* = async_compile.size_hint_multi_kernel[(]",
            wrapper_code,
        )
        is not None
    )


def make_cpp_wrapper_test(orig_test, **extra_args):
    """
    Wrap an existing test into a new test with cpp-wrapper enabled.

    Make this a free function rather than a staticmethod in MultiKernelTest.
    Otherwise we get a "TypeError: 'staticmethod' object is not callable"
    error on py3.8 (py3.10 works).
    """

    @config.patch("cpp_wrapper", True)
    @skipIfXpu(msg="cpp wrapper doesn't currently work on the XPU stack")
    def fn(self):
        # The same kernel may have been compiled by previous tests with
        # cpp_wrapper disabled. Clear the cache so we re-compile the kernel
        # with cpp_wrapper enabled.
        from torch._inductor import codecache

        codecache.PyCodeCache.cache_clear()
        return orig_test(self, **extra_args)

    return fn


@config.patch(
    {
        "triton.multi_kernel": int(os.environ.get("TORCHINDUCTOR_MULTI_KERNEL", "1")),
        "benchmark_kernel": True,
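        # One sub-kernel is pre-generated per hint below; for a novel dynamic
        # shape, the kernel whose hint is closest (L1 distance in log2 space)
        # is picked instead of re-benchmarking.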
        "multi_kernel_hints": [64, 256, 4096],
    }
)
@instantiate_parametrized_tests
class MultiKernelTest(TestCase):
    def test_softmax(self, expect_multi_kernel=True):
        x = torch.rand(2, 1024).to(GPU_TYPE)
        ref = torch.softmax(x, -1)
        compiled_fn = torch.compile(torch.softmax)
        act, wrapper_code = run_and_get_code(compiled_fn, x, -1)

        # wrapper_code will contain 2 entries if cpp_wrapper=True:
        # one for the first pass and one for the second pass.
        # We mainly care about the wrapper for the final pass here.
        wrapper_code = wrapper_code[-1]
        self.assertEqual(ref, act)
        if expect_multi_kernel:
            self.assertTrue(_contains_multi_kernel_code(wrapper_code))
        else:
            self.assertFalse(_contains_multi_kernel_code(wrapper_code))

    @requires_triton()
    # TODO: bobrenjc93 to fix multi-kernel for ROCM
    @skipIfRocm
    @unittest.skipIf(not IS_BIG_GPU, "templates require big gpu")
    def test_triton_gemm(self):
        def fn(x, y):
            return x @ y

        compiled_fn = torch.compile(
            fn,
            options={
                "max_autotune": True,
                "max_autotune_gemm_backends": "TRITON",
            },
        )
        x = torch.randn(4096, 4096, device=GPU_TYPE)
        y = torch.randn(4096, 4096, device=GPU_TYPE)
        torch._dynamo.mark_dynamic(x, 0)
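        # Size-hint multi-kernel only kicks in for kernels with dynamic shapes,
        # hence dim 0 of x is marked dynamic above.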
        act, wrapper_code = run_and_get_code(compiled_fn, x, y)
        ref = fn(x, y)

        # wrapper_code will contain 2 entries if cpp_wrapper=True:
        # one for the first pass and one for the second pass.
        # We mainly care about the wrapper for the final pass here.
        wrapper_code = wrapper_code[-1]
        self.assertEqual(ref, act)
        self.assertTrue(_contains_size_hint_multi_kernel_code(wrapper_code))

    @requires_triton()
    # TODO: bobrenjc93 to fix multi-kernel for ROCM
    @skipIfRocm
    @unittest.skipIf(not IS_BIG_GPU, "templates require big gpu")
    def test_triton_relu_fused_gemm(self):
        def fn(x, y):
            return (x @ y).relu()

        compiled_fn = torch.compile(
            fn,
            options={
                "max_autotune": True,
                "max_autotune_gemm_backends": "TRITON",
            },
        )
        x = torch.randn(4096, 4096, device=GPU_TYPE)
        y = torch.randn(4096, 4096, device=GPU_TYPE)
        torch._dynamo.mark_dynamic(x, 0)
        act, wrapper_code = run_and_get_code(compiled_fn, x, y)
        ref = fn(x, y)

        # wrapper_code will contain 2 entries if cpp_wrapper=True:
        # one for the first pass and one for the second pass.
        # We mainly care about the wrapper for the final pass here.
        wrapper_code = wrapper_code[-1]
        self.assertEqual(ref, act)
        self.assertTrue(_contains_size_hint_multi_kernel_code(wrapper_code))

    @parametrize("force_kernel", (0, 1))
    @unittest.mock.patch.dict(
        os.environ, {"TORCHINDUCTOR_DISABLE_MULTI_KERNEL_CACHE": "1"}
    )
    def test_softmax_force_non_persistent_reduction(self, force_kernel):
        """
        Force a specific sub-kernel to be picked by mocking the benchmark result.
        """
        x = torch.rand(2, 1024).to(GPU_TYPE)
        mock_latency = [0.2, 0.2]
        mock_latency[force_kernel] = 0.1  # this makes sure force_kernel will be picked

        def f(x):
            return torch.softmax(x, -1) + force_kernel

        orig_run = MultiKernelCall.run
        picked_kernel = None

        def mock_run(self, *args, **kwargs):
            out = orig_run(self, *args, **kwargs)
            nonlocal picked_kernel
            picked_kernel = self.picked_kernel
            return out

        with (
            unittest.mock.patch.object(MultiKernelCall, "run", mock_run),
            unittest.mock.patch.object(
                MultiKernelCall,
                "benchmark_sub_kernels",
                lambda *args, **kwargs: mock_latency,
            ),
        ):
            torch.compile(f)(x)
        self.assertEqual(picked_kernel, force_kernel)

    @config.patch("warn_mix_layout", True)
    def test_softmax_warn_mixed_layout(self):
        self.test_softmax()

    test_softmax_cpp_wrapper = make_cpp_wrapper_test(
        test_softmax, expect_multi_kernel=True
    )

    def test_layernorm(self):
        ln = nn.LayerNorm(1024).to(GPU_TYPE)
        x = torch.rand(2, 1024).to(GPU_TYPE)
        ref = ln(x)
        act = torch.compile(ln)(x)
        self.assertEqual(ref, act, atol=1e-4, rtol=1e-4)

    def test_inplace_update(self):
        """
        Inductor generates an in-place kernel for mul.
        """

        def f(x, y):
            return x.sum(dim=-1, keepdims=True) * (y @ y)

        x = torch.rand(1024, 1024).to(GPU_TYPE)
        y = torch.rand(1024, 1024).to(GPU_TYPE)
        ref = f(x, y)
        act = torch.compile(f)(x, y)
        self.assertEqual(ref, act)

    def test_transformer_snippet(self):
        model = TransformerSnippet().to(GPU_TYPE)
        x = model.example_inputs()

        def f(*x):
            y = model(*x)
            return y

        reset_rng_state()
        ref = f(*x)

        opt_f = torch.compile(f)
        reset_rng_state()
        act = opt_f(*x)

        # Don't compare tensors if using the inductor random number generator:
        # inductor's random number implementation is different from eager's.
        # We should fall back to eager if we want to test accuracy.
        if config.fallback_random:
            self.assertEqual(ref, act, atol=1e-4, rtol=1e-4)

    def test_transformer_snippet_with_fallback_random(self):
        """
        Same as test_transformer_snippet but fall back the random number
        generator to eager so we can check accuracy.
        """
        with config.patch("fallback_random", True):
            self.test_transformer_snippet()

    def test_batchnorm_training(self):
        """
        For training, batchnorm tracks the running mean/variance during the forward pass.
        The kernel generated by inductor currently passes in those tensors twice as
        arguments: once for input and once for output. They are not treated as in-out
        arguments because they are considered graph inputs.

        Multi-kernel previously assumed that we never pass the same argument multiple
        times to a kernel. Regardless of whether we change inductor's behavior to ensure
        that, it's better to make multi-kernel able to handle those cases.
        """
        bn = nn.BatchNorm2d(3).to(GPU_TYPE)

        @torch.compile
        def f(x):
            bn(x).sum().backward()

        _, (wrapper_code, _) = run_and_get_code(
            f, torch.randn(2, 3, 8, 8, device=GPU_TYPE)
        )
        self.assertTrue(_contains_multi_kernel_code(wrapper_code))

    def test_pass_same_arg_multi_times(self):
        """
        A super simple example that simulates how BatchNorm updates the running
        stats.

        Inductor currently passes the same tensor multiple times to the generated
        kernel: once for input and once for output.

        Here is a paste of the generated kernel (without multi-kernel enabled):
        https://gist.github.com/shunting314/f0b446b4b9a28f4940e31dcd3e809cf9
        """

        def f(x, y):
            x = x.sum(dim=1, keepdim=False)
            y.copy_(y * 0.9 + x * 0.1)

        x = torch.randn(8, 16, device=GPU_TYPE)
        y = torch.randn(8, device=GPU_TYPE)
        y_ref = y.clone()

        ref = f(x, y_ref)  # noqa: F841
        act = torch.compile(f)(x, y)  # noqa: F841
        self.assertEqual(y_ref, y)

    def test_reduction_scratch_buffer(self, force_multi_kernel=1):
        """
        The explicitly realized buffer in the test function will be passed in
        as a scratch buffer for the non-persistent reduction kernel but
        can be skipped for the persistent reduction kernel.

        This causes different argument lists for the non-persistent and the
        persistent reduction kernels.

        Check the documentation around torch._inductor.config.triton.multi_kernel
        for how to interpret the force_multi_kernel argument.
        """

        def f(x):
            x = x.sum(dim=-1, keepdim=True) + x
            x = test_operators.realize(x)
            x = x.sum(dim=-1, keepdim=True) + x
            return x

        x = torch.rand(16, 16, device=GPU_TYPE)
        ref = f(x)
        with config.patch("triton.multi_kernel", force_multi_kernel):
            act = torch.compile(f)(x)
        self.assertEqual(ref, act)

    def test_split_scan(self, force_multi_kernel=1):
        def f(x):
            x = x.view(-1)
            return torch.cumsum(x, 0)

        x = make_tensor(10, 3, 352, 352, low=0, dtype=torch.float32, device=GPU_TYPE)
        expect = f(x)
        with config.patch("triton.multi_kernel", force_multi_kernel):
            actual = torch.compile(f)(x)
        self.assertEqual(expect, actual)

    def test_sort_disables_multi_kernel(self, force_multi_kernel=1):
        """
        Sort currently requires a persistent kernel, so multi-kernel is not
        possible. Make sure this falls back gracefully.
        """

        def f(x):
            return x.sort(-1).values

        x = torch.rand(32, 32, device=GPU_TYPE)
        expect = f(x)
        with config.patch("triton.multi_kernel", force_multi_kernel):
            actual = torch.compile(f)(x)
        self.assertEqual(expect, actual)

    # Use benchmarking to pick the faster kernel
    test_reduction_scratch_buffer_cpp_wrapper = make_cpp_wrapper_test(
        test_reduction_scratch_buffer, force_multi_kernel=1
    )
    # Force picking the persistent reduction. This can be a good test since this
    # persistent reduction uses fewer call arguments than the corresponding
    # non-persistent reduction.
    test_reduction_scratch_buffer_cpp_wrapper_persistent_reduction = (
        make_cpp_wrapper_test(test_reduction_scratch_buffer, force_multi_kernel=2)
    )
    # Force picking the non-persistent reduction
    test_reduction_scratch_buffer_cpp_wrapper_non_persistent_reduction = (
        make_cpp_wrapper_test(test_reduction_scratch_buffer, force_multi_kernel=3)
    )


if __name__ == "__main__":
    from torch._inductor.test_case import run_tests

    if HAS_GPU:
        run_tests()