Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
[AMD] [Reland] Fix AMD User Defined Kernel Autotune (#161521)
Summary: This is a reland of D80285441, with the unit test fixed.

Test Plan:
```
buck2 run mode/opt-amd-gpu -m rocm641 -c fbcode.split-dwarf=true -c fbcode.use_link_groups=true -c fbcode.enable_gpu_sections=true //hpc/new/models/feed/benchmark:feed_lower_benchmark -- --load=manifold://ads_storage_fblearner/tree/user/facebook/fblearner/predictor/894698382/0/gpu_lowering/new_input8 --skip-eager --skip-flop-estimation --sync-mode=0 --lower-backend=AOT_INDUCTOR
```
will succeed after this diff.

Rollback Plan:

Differential Revision: D80971224

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161521
Approved by: https://github.com/frank-wei
Committed by: PyTorch MergeBot
Parent: 8fd3c9ce91
Commit: c024b1f5a1
```
@@ -63,6 +63,7 @@ from torch.testing._internal.common_utils import (
    MACOS_VERSION,
    MI300_ARCH,
    parametrize,
    runOnRocm,
    skipIfMPS,
    skipIfRocm,
    skipIfRocmArch,
```
```
@@ -6440,6 +6441,48 @@ class AOTInductorTestsTemplate:
            rtol=1e-3,
        )

    @runOnRocm
    def test_rocm_triton_autotuning(self):
        if self.device != GPU_TYPE:
            raise unittest.SkipTest("requires GPU")

        class Model(torch.nn.Module):
            def forward(self, x, y, m):
                _M, K = x.shape
                K, N = y.shape
                M = torch.abs(m)
                out = torch.empty((_M, N), device=x.device, dtype=torch.float32)
                grid = lambda META: (  # noqa: E731
                    triton.cdiv(
                        4096 * 2046, META["BLOCK_SIZE_M"] * META["BLOCK_SIZE_N"]
                    ),
                )
                strange_config_matmul_kernel[grid](
                    x,
                    y,
                    out,
                    M,
                    N,
                    K,
                )
                return out

        x = torch.randn(4096, 1024, device=self.device)
        y = torch.randn(1024, 2048, device=self.device)
        m = torch.tensor([4096], dtype=torch.int32, device=self.device)

        with (
            torch.no_grad(),
            config.patch(
                {
                    "triton.autotune_with_sample_inputs": True,
                    "aot_inductor.allow_stack_allocation": self.allow_stack_allocation,
                    "aot_inductor.use_minimal_arrayref_interface": self.use_minimal_arrayref_interface,
                }
            ),
        ):
            torch._export.aot_compile(Model(), (x, y, m))

    @skipIfRocm  # RoCM does not support the config block size in test suite.
    def test_triton_autotuning(self):
        if self.device != GPU_TYPE:
```
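The new `test_rocm_triton_autotuning` above AOT-compiles a model that launches a user-defined, autotuned Triton kernel whose grid is a function of the autotuner-chosen META (`BLOCK_SIZE_M`, `BLOCK_SIZE_N`). For readers unfamiliar with that pattern, here is a minimal, self-contained sketch of a user-defined autotuned kernel with a META-dependent grid; the kernel, configs, and block sizes are illustrative and not taken from this diff:

```
import torch
import triton
import triton.language as tl


@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 128}, num_warps=4),
        triton.Config({"BLOCK_SIZE": 256}, num_warps=8),
    ],
    key=["n_elements"],
)
@triton.jit
def vector_add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)


def vector_add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    n_elements = out.numel()
    # As in the test above, the grid is a function of META, so the launch
    # geometry depends on whichever config the autotuner selects.
    grid = lambda META: (triton.cdiv(n_elements, META["BLOCK_SIZE"]),)  # noqa: E731
    vector_add_kernel[grid](x, y, out, n_elements)
    return out
```

When such a kernel is AOT-compiled with `triton.autotune_with_sample_inputs`, Inductor has to emit a grid function that dispatches on the winning config's meta-parameters, which is what the codegen change in the next hunk adjusts for AMD.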
```
@@ -228,11 +228,18 @@ def user_defined_kernel_grid_fn_code(
        key=lambda x: len(x[1].kwargs),
        reverse=True,
    ):
        guardslist = []
        if c.kwargs:
            guards = [
                f"meta['{name}'] == {val}" for name, val in c.kwargs.items()
            ]
            guards = " and ".join(guards)
            # Remove AMD specific kwargs.
            for kwarg in c.kwargs:
                if kwarg not in [
                    "matrix_instr_nonkdim",
                    "waves_per_eu",
                    "kpack",
                ]:
                    guardslist.append(f"meta['{kwarg}'] == {c.kwargs[kwarg]}")
            if guardslist:
                guards = " and ".join(guardslist)
        else:
            guards = "True"  # for configs with empty kwargs
        grid, example_grid = determine_grid(grid, example_grid)
```
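The change above is the heart of the fix: the guard expression for each config in the generated grid function is now built only from kwargs that are real kernel meta-parameters, while AMD-specific config entries (`matrix_instr_nonkdim`, `waves_per_eu`, `kpack`) are skipped. A standalone sketch of that filtering, with an illustrative helper name not taken from the PyTorch source:

```
# ROCm-only options that appear in triton.Config kwargs but should not be
# guarded on as kernel meta-parameters (names taken from the diff above).
AMD_ONLY_KWARGS = {"matrix_instr_nonkdim", "waves_per_eu", "kpack"}


def build_guards(config_kwargs: dict) -> str:
    """Build the `meta[...]` guard string for one config, skipping AMD-only kwargs."""
    guardslist = [
        f"meta['{name}'] == {val}"
        for name, val in config_kwargs.items()
        if name not in AMD_ONLY_KWARGS
    ]
    return " and ".join(guardslist) if guardslist else "True"


# Using one of the HIP configs added later in this diff:
print(
    build_guards(
        {
            "BLOCK_SIZE_M": 16,
            "BLOCK_SIZE_N": 16,
            "BLOCK_SIZE_K": 16,
            "GROUP_SIZE_M": 4,
            "matrix_instr_nonkdim": 16,
            "waves_per_eu": 3,
            "kpack": 2,
        }
    )
)
# Prints (one line):
# meta['BLOCK_SIZE_M'] == 16 and meta['BLOCK_SIZE_N'] == 16 and meta['BLOCK_SIZE_K'] == 16 and meta['GROUP_SIZE_M'] == 4
```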
```
@@ -22,6 +22,63 @@ if has_triton():
    import triton
    from triton import language as tl

    import torch

    def _get_strange_configs() -> list[triton.Config]:
        if torch.version.hip:
            configs = [
                triton.Config(
                    {
                        "BLOCK_SIZE_M": 16,
                        "BLOCK_SIZE_N": 16,
                        "BLOCK_SIZE_K": 16,
                        "GROUP_SIZE_M": 4,
                        "matrix_instr_nonkdim": 16,
                        "waves_per_eu": 3,
                        "kpack": 2,
                    },
                    num_stages=4,
                    num_warps=4,
                ),
                triton.Config(
                    {
                        "BLOCK_SIZE_M": 128,
                        "BLOCK_SIZE_N": 64,
                        "BLOCK_SIZE_K": 16,
                        "GROUP_SIZE_M": 4,
                        "matrix_instr_nonkdim": 16,
                        "waves_per_eu": 3,
                        "kpack": 2,
                    },
                    num_stages=4,
                    num_warps=4,
                ),
            ]
        else:
            configs = [
                triton.Config(
                    {
                        "BLOCK_SIZE_M": 16,
                        "BLOCK_SIZE_N": 16,
                        "BLOCK_SIZE_K": 16,
                        "GROUP_SIZE_M": 4,
                    },
                    num_stages=4,
                    num_warps=4,
                ),
                triton.Config(
                    {
                        "BLOCK_SIZE_M": 128,
                        "BLOCK_SIZE_N": 64,
                        "BLOCK_SIZE_K": 32,
                        "GROUP_SIZE_M": 8,
                    },
                    num_stages=4,
                    num_warps=4,
                ),
            ]
        return configs

    # Define here so that multiple tests can take advantage of it
    @triton.jit
    def add_kernel(
```
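The hunk above is truncated at the shared `add_kernel` definition. As context for how `_get_strange_configs()` is presumably consumed, the sketch below attaches the configs to a kernel via `triton.autotune`; the signature, `key`, and body are assumptions, since the real `strange_config_matmul_kernel` is not shown in this hunk. The point is that the block-size kwargs map to `tl.constexpr` meta-parameters, whereas the ROCm-only kwargs are handled by the backend at launch and never become meta-parameters, which is why the guard codegen above must ignore them:

```
# Hypothetical wiring only: the real strange_config_matmul_kernel used by the
# test is not shown in this hunk, so the signature and `key` are assumptions.
@triton.autotune(configs=_get_strange_configs(), key=["N", "K"])
@triton.jit
def strange_config_matmul_kernel_sketch(
    x_ptr,
    y_ptr,
    out_ptr,
    M,
    N,
    K,
    # Each BLOCK_SIZE_* / GROUP_SIZE_M kwarg in the configs maps to a
    # tl.constexpr meta-parameter here...
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
    # ...while matrix_instr_nonkdim / waves_per_eu / kpack never become
    # parameters: on ROCm they are consumed as backend launch options.
):
    # Placeholder body; the real tiled matmul is not part of this diff.
    pid = tl.program_id(axis=0)
    tl.store(out_ptr + pid, 0.0)
```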