mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Summary: This is a reland of D80285441, fixed the unit test. Test Plan: ``` buck2 run mode/opt-amd-gpu -m rocm641 -c fbcode.split-dwarf=true -c fbcode.use_link_groups=true -c fbcode.enable_gpu_sections=true //hpc/new/models/feed/benchmark:feed_lower_benchmark -- --load=manifold://ads_storage_fblearner/tree/user/facebook/fblearner/predictor/894698382/0/gpu_lowering/new_input8 --skip-eager --skip-flop-estimation --sync-mode=0 --lower-backend=AOT_INDUCTOR ``` will succeed after this diff. Rollback Plan: Differential Revision: D80971224 Pull Request resolved: https://github.com/pytorch/pytorch/pull/161521 Approved by: https://github.com/frank-wei
1044 lines
31 KiB
Python
1044 lines
31 KiB
Python
# mypy: ignore-errors
|
|
|
|
import unittest
|
|
|
|
from torch.testing._internal.inductor_utils import (
|
|
HAS_CUDA_AND_TRITON,
|
|
HAS_GPU,
|
|
HAS_XPU_AND_TRITON,
|
|
)
|
|
from torch.utils._triton import has_triton
|
|
|
|
|
|
# Skip decorators for tests that use the Triton kernels defined in this file.
# Each one skips the decorated test unless the capability flag(s) imported
# from inductor_utils are truthy.
requires_cuda_and_triton = unittest.skipUnless(
    condition=HAS_CUDA_AND_TRITON,
    reason="requires cuda and triton",
)
requires_gpu_and_triton = unittest.skipUnless(
    condition=HAS_XPU_AND_TRITON or HAS_CUDA_AND_TRITON,
    reason="requires gpu and triton",
)
requires_gpu = unittest.skipUnless(
    condition=HAS_GPU,
    reason="requires gpu",
)
|
|
|
|
if has_triton():
|
|
import triton
|
|
from triton import language as tl
|
|
|
|
import torch
|
|
|
|
def _get_strange_configs() -> list[triton.Config]:
    """Build the two matmul autotuning candidates for the current backend.

    When ``torch.version.hip`` is truthy, both candidates use
    ``BLOCK_SIZE_K=16`` / ``GROUP_SIZE_M=4`` and additionally carry the
    ``matrix_instr_nonkdim``, ``waves_per_eu`` and ``kpack`` entries.
    Otherwise the plain tile sizes are used. Every candidate is created with
    ``num_stages=4`` and ``num_warps=4``.

    Returns:
        A fresh list of two ``triton.Config`` objects (small tile first).
    """
    if torch.version.hip:
        # Extra entries that are only added on the HIP path.
        hip_only = {"matrix_instr_nonkdim": 16, "waves_per_eu": 3, "kpack": 2}
        tiles = [
            {
                "BLOCK_SIZE_M": 16,
                "BLOCK_SIZE_N": 16,
                "BLOCK_SIZE_K": 16,
                "GROUP_SIZE_M": 4,
            },
            {
                "BLOCK_SIZE_M": 128,
                "BLOCK_SIZE_N": 64,
                "BLOCK_SIZE_K": 16,
                "GROUP_SIZE_M": 4,
            },
        ]
        # Tile keys first, then the HIP-only keys, matching the literal order.
        kwargs_per_config = [{**tile, **hip_only} for tile in tiles]
    else:
        kwargs_per_config = [
            {
                "BLOCK_SIZE_M": 16,
                "BLOCK_SIZE_N": 16,
                "BLOCK_SIZE_K": 16,
                "GROUP_SIZE_M": 4,
            },
            {
                "BLOCK_SIZE_M": 128,
                "BLOCK_SIZE_N": 64,
                "BLOCK_SIZE_K": 32,
                "GROUP_SIZE_M": 8,
            },
        ]

    return [
        triton.Config(kwargs, num_stages=4, num_warps=4)
        for kwargs in kwargs_per_config
    ]
|
|
|
|
# Define here so that multiple tests can take advantage of it
|
|
@triton.jit
def add_kernel(
    in_ptr0,
    in_ptr1,
    out_ptr,
    n_elements,
    BLOCK_SIZE: "tl.constexpr",
):
    # Elementwise add: out_ptr[i] = in_ptr0[i] + in_ptr1[i] for i < n_elements.
    # Each program on axis 0 handles one BLOCK_SIZE-wide span of indices.
    span_start = tl.program_id(axis=0) * BLOCK_SIZE
    idx = span_start + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements  # disables lanes at or past n_elements
    lhs = tl.load(in_ptr0 + idx, mask=in_bounds)
    rhs = tl.load(in_ptr1 + idx, mask=in_bounds)
    tl.store(out_ptr + idx, lhs + rhs, mask=in_bounds)
|
|
|
|
@triton.jit
def sub_kernel(
    in_ptr0,
    in_ptr1,
    out_ptr,
    n_elements,
    BLOCK_SIZE: "tl.constexpr",
):
    # Elementwise subtract: out_ptr[i] = in_ptr0[i] - in_ptr1[i] for
    # i < n_elements, one BLOCK_SIZE-wide span per program on axis 0.
    span_start = tl.program_id(axis=0) * BLOCK_SIZE
    idx = span_start + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements
    minuend = tl.load(in_ptr0 + idx, mask=in_bounds)
    subtrahend = tl.load(in_ptr1 + idx, mask=in_bounds)
    tl.store(out_ptr + idx, minuend - subtrahend, mask=in_bounds)
|
|
|
|
@triton.jit
def add_kernel_with_optional_param(
    in_ptr0,
    in_ptr1,
    out_ptr,
    n_elements,
    ARGS_PASSED: "tl.constexpr",
    BLOCK_SIZE: "tl.constexpr",
):
    # Writes in_ptr0 + in_ptr1 when ARGS_PASSED == "two"; otherwise copies
    # in_ptr0 to out_ptr and never reads in_ptr1.
    span_start = tl.program_id(axis=0) * BLOCK_SIZE
    idx = span_start + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements
    first = tl.load(in_ptr0 + idx, mask=in_bounds)
    if ARGS_PASSED == "two":
        second = tl.load(in_ptr1 + idx, mask=in_bounds)
        result = first + second
    else:
        result = first
    tl.store(out_ptr + idx, result, mask=in_bounds)
|
|
|
|
@triton.jit
def add_kernel_with_none_param_and_equal_to_1_arg(
    in_ptr0,
    in_ptr1,  # in_ptr1 could be None
    out_ptr,
    n_elements,
    stride,
    ARGS_PASSED: "tl.constexpr",
    BLOCK_SIZE: "tl.constexpr",
):
    # Same as the optional-param add, except that in_ptr0 and out_ptr are
    # addressed with idx * stride while in_ptr1 is addressed with plain idx.
    span_start = tl.program_id(axis=0) * BLOCK_SIZE
    idx = span_start + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements
    first = tl.load(in_ptr0 + idx * stride, mask=in_bounds)
    if ARGS_PASSED == "two":
        # Only this branch touches in_ptr1.
        second = tl.load(in_ptr1 + idx, mask=in_bounds)
        result = first + second
    else:
        result = first
    tl.store(out_ptr + idx * stride, result, mask=in_bounds)
|
|
|
|
@triton.autotune(
    # Candidates, in order: (128, 3, 8), (128, 4, 4), (64, 3, 8), (64, 4, 4)
    # as (BLOCK_SIZE, num_stages, num_warps).
    configs=[
        triton.Config({"BLOCK_SIZE": block}, num_stages=stages, num_warps=warps)
        for block in (128, 64)
        for stages, warps in ((3, 8), (4, 4))
    ],
    key=[],
)
@triton.jit
def add_kernel_autotuned(
    in_ptr0,
    in_ptr1,
    out_ptr,
    n_elements,
    BLOCK_SIZE: "tl.constexpr",
):
    # Elementwise add whose BLOCK_SIZE is supplied by the autotuner.
    span_start = tl.program_id(axis=0) * BLOCK_SIZE
    idx = span_start + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements
    lhs = tl.load(in_ptr0 + idx, mask=in_bounds)
    rhs = tl.load(in_ptr1 + idx, mask=in_bounds)
    tl.store(out_ptr + idx, lhs + rhs, mask=in_bounds)
|
|
|
|
@triton.autotune(
    # Candidates, in order: (128, 3, 8), (128, 4, 4), (64, 3, 8), (64, 4, 4)
    # as (BLOCK_SIZE, num_stages, num_warps).
    configs=[
        triton.Config({"BLOCK_SIZE": block}, num_stages=stages, num_warps=warps)
        for block in (128, 64)
        for stages, warps in ((3, 8), (4, 4))
    ],
    key=[],
)
@triton.jit
def sub_kernel_autotuned(
    in_ptr0,
    in_ptr1,
    out_ptr,
    n_elements,
    BLOCK_SIZE: "tl.constexpr",
):
    # Elementwise subtract whose BLOCK_SIZE is supplied by the autotuner.
    span_start = tl.program_id(axis=0) * BLOCK_SIZE
    idx = span_start + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements
    minuend = tl.load(in_ptr0 + idx, mask=in_bounds)
    subtrahend = tl.load(in_ptr1 + idx, mask=in_bounds)
    tl.store(out_ptr + idx, minuend - subtrahend, mask=in_bounds)
|
|
|
|
@triton.autotune(
    configs=[triton.Config({"BLOCK_SIZE": 16}, num_stages=2, num_warps=2)],
    key=[],
)
@triton.jit
def add_kernel_autotuned_weird_param_order(
    in_ptr0,
    in_ptr1,
    n_elements,
    BLOCK_SIZE: "tl.constexpr",
    out_ptr,
):
    # out_ptr deliberately follows BLOCK_SIZE, an autotuned tl.constexpr
    # parameter. Callers that mishandle this ordering can bind arguments
    # incorrectly, which is what this kernel is meant to exercise.
    span_start = tl.program_id(axis=0) * BLOCK_SIZE
    idx = span_start + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements
    lhs = tl.load(in_ptr0 + idx, mask=in_bounds)
    rhs = tl.load(in_ptr1 + idx, mask=in_bounds)
    tl.store(out_ptr + idx, lhs + rhs, mask=in_bounds)
|
|
|
|
@triton.autotune(
    # Square tiles; candidates in order: (128, 3, 8), (128, 4, 4),
    # (64, 3, 8), (64, 4, 4) as (tile side, num_stages, num_warps).
    configs=[
        triton.Config(
            {"BLOCK_SIZE_X": side, "BLOCK_SIZE_Y": side},
            num_stages=stages,
            num_warps=warps,
        )
        for side in (128, 64)
        for stages, warps in ((3, 8), (4, 4))
    ],
    key=[],
)
@triton.jit
def add_kernel_2d_autotuned(
    in_ptr0,
    in_ptr1,
    out_ptr,
    x_elements,
    y_elements,
    BLOCK_SIZE_X: "tl.constexpr",
    BLOCK_SIZE_Y: "tl.constexpr",
):
    # 2D launch: program axis 0 tiles the x index, axis 1 tiles the y index.
    # NOTE(review): both loads read from in_ptr0 (at two different offset
    # layouts); in_ptr1 is never read by this kernel. Left as is.
    xi = tl.program_id(0) * BLOCK_SIZE_X + tl.arange(0, BLOCK_SIZE_X)[:, None]
    yi = tl.program_id(1) * BLOCK_SIZE_Y + tl.arange(0, BLOCK_SIZE_Y)[None, :]
    valid = (xi < x_elements) & (yi < y_elements)
    direct = xi + (x_elements * yi)
    swapped = yi + (y_elements * xi)
    first = tl.load(in_ptr0 + direct, valid)
    second = tl.load(in_ptr0 + swapped, valid)
    tl.store(out_ptr + direct, first + second, valid)
|
|
|
|
def _dummy_early_config_prune(configs, *_, **__):
|
|
return configs
|
|
|
|
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 128}, num_stages=3, num_warps=8),
        triton.Config({"BLOCK_SIZE": 64}, num_stages=4, num_warps=4),
    ],
    key=[],
    # Extra autotune arguments (per the kernel name, the ones treated as
    # unsupported by the code under test).
    warmup=10,
    rep=20,
    prune_configs_by={"early_config_prune": _dummy_early_config_prune},
)
@triton.jit
def add_kernel_autotuned_with_unsupported_args(
    in_ptr0,
    in_ptr1,
    out_ptr,
    n_elements,
    BLOCK_SIZE: "tl.constexpr",
):
    # Plain elementwise add; the interesting part is the decorator above.
    span_start = tl.program_id(axis=0) * BLOCK_SIZE
    idx = span_start + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements
    lhs = tl.load(in_ptr0 + idx, mask=in_bounds)
    rhs = tl.load(in_ptr1 + idx, mask=in_bounds)
    tl.store(out_ptr + idx, lhs + rhs, mask=in_bounds)
|
|
|
|
@triton.jit
def add_kernel_with_scaling(
    in_ptr0,
    in_ptr1,
    out_ptr,
    n_elements,
    scaling_factor,
    BLOCK_SIZE: "tl.constexpr",
):
    # out_ptr[i] = (in_ptr0[i] + in_ptr1[i]) * scaling_factor for
    # i < n_elements.
    span_start = tl.program_id(axis=0) * BLOCK_SIZE
    idx = span_start + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements
    lhs = tl.load(in_ptr0 + idx, mask=in_bounds)
    rhs = tl.load(in_ptr1 + idx, mask=in_bounds)
    scaled = (lhs + rhs) * scaling_factor
    tl.store(out_ptr + idx, scaled, mask=in_bounds)
|
|
|
|
@triton.jit
def add_kernel_with_tma_1d_old_api(
    in_desc_ptr0,
    in_desc_ptr1,
    out_desc_ptr,
    BLOCK_SIZE: "tl.constexpr",
):
    # 1D add through the experimental descriptor load/store calls. Each
    # program on axis 0 handles the block starting at pid * BLOCK_SIZE.
    start = tl.program_id(axis=0) * BLOCK_SIZE

    lhs = tl._experimental_descriptor_load(
        in_desc_ptr0, [start], [BLOCK_SIZE], tl.float32
    )
    rhs = tl._experimental_descriptor_load(
        in_desc_ptr1, [start], [BLOCK_SIZE], tl.float32
    )

    total = lhs + rhs

    tl._experimental_descriptor_store(out_desc_ptr, total, [start])
|
|
|
|
@triton.jit
def add_kernel_with_tma_2d_old_api(
    in_desc_ptr0,
    in_desc_ptr1,
    out_desc_ptr,
    BLOCK_SIZE_X: "tl.constexpr",
    BLOCK_SIZE_Y: "tl.constexpr",
):
    # 2D add through the experimental descriptor load/store calls. Program
    # axes 0 and 1 select the tile origin along the first and second
    # descriptor dimensions.
    start_x = tl.program_id(axis=0) * BLOCK_SIZE_X
    start_y = tl.program_id(axis=1) * BLOCK_SIZE_Y

    lhs = tl._experimental_descriptor_load(
        in_desc_ptr0,
        [start_x, start_y],
        [BLOCK_SIZE_X, BLOCK_SIZE_Y],
        tl.float32,
    )
    rhs = tl._experimental_descriptor_load(
        in_desc_ptr1,
        [start_x, start_y],
        [BLOCK_SIZE_X, BLOCK_SIZE_Y],
        tl.float32,
    )

    total = lhs + rhs

    tl._experimental_descriptor_store(out_desc_ptr, total, [start_x, start_y])
|
|
|
|
@triton.jit
def add_kernel_with_tma_1d_new_api(
    in_desc_ptr0,
    in_desc_ptr1,
    out_desc_ptr,
    BLOCK_SIZE: "tl.constexpr",
):
    # 1D add through tl.load_tensor_descriptor / tl.store_tensor_descriptor.
    # Each program on axis 0 handles the block starting at pid * BLOCK_SIZE.
    start = tl.program_id(axis=0) * BLOCK_SIZE

    lhs = tl.load_tensor_descriptor(in_desc_ptr0, [start])
    rhs = tl.load_tensor_descriptor(in_desc_ptr1, [start])

    total = lhs + rhs

    # Argument order here is (descriptor, offsets, value).
    tl.store_tensor_descriptor(out_desc_ptr, [start], total)
|
|
|
|
@triton.jit
def add_kernel_with_tma_2d_new_api(
    in_desc_ptr0,
    in_desc_ptr1,
    out_desc_ptr,
    BLOCK_SIZE_X: "tl.constexpr",
    BLOCK_SIZE_Y: "tl.constexpr",
):
    # 2D add through tl.load_tensor_descriptor / tl.store_tensor_descriptor.
    # Program axes 0 and 1 select the tile origin.
    start_x = tl.program_id(axis=0) * BLOCK_SIZE_X
    start_y = tl.program_id(axis=1) * BLOCK_SIZE_Y

    lhs = tl.load_tensor_descriptor(in_desc_ptr0, [start_x, start_y])
    rhs = tl.load_tensor_descriptor(in_desc_ptr1, [start_x, start_y])

    total = lhs + rhs

    # Argument order here is (descriptor, offsets, value).
    tl.store_tensor_descriptor(out_desc_ptr, [start_x, start_y], total)
|
|
|
|
@triton.jit
def add_kernel_on_device_tma_old_api(
    a_ptr,
    b_ptr,
    c_ptr,
    m,
    n,
    workspace,
    BLOCK_SIZE: "tl.constexpr",
):
    # Builds three 2D tensor maps inside the kernel, stored in `workspace`
    # at offsets 0, 128 and 256, then adds a and b tile by tile into c.
    desc_a = workspace
    desc_b = workspace + 128
    desc_c = workspace + 256

    tl.extra.cuda.experimental_device_tensormap_create2d(
        desc_ptr=desc_a,
        global_address=a_ptr,
        load_size=[BLOCK_SIZE, BLOCK_SIZE],
        global_size=[m, n],
        element_ty=a_ptr.dtype.element_ty,
    )
    tl.extra.cuda.experimental_device_tensormap_create2d(
        desc_ptr=desc_b,
        global_address=b_ptr,
        load_size=[BLOCK_SIZE, BLOCK_SIZE],
        global_size=[m, n],
        element_ty=b_ptr.dtype.element_ty,
    )
    tl.extra.cuda.experimental_device_tensormap_create2d(
        desc_ptr=desc_c,
        global_address=c_ptr,
        load_size=[BLOCK_SIZE, BLOCK_SIZE],
        global_size=[m, n],
        element_ty=c_ptr.dtype.element_ty,
    )

    # Fence each descriptor after creation and before it is used below.
    tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(desc_a)
    tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(desc_b)
    tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(desc_c)

    # Tile origin for this program (2D launch grid).
    start_x = tl.program_id(axis=0) * BLOCK_SIZE
    start_y = tl.program_id(axis=1) * BLOCK_SIZE

    # Fetch one BLOCK_SIZE x BLOCK_SIZE tile from each input descriptor.
    tile_a = tl._experimental_descriptor_load(
        desc_a,
        [start_x, start_y],
        [BLOCK_SIZE, BLOCK_SIZE],
        tl.float32,
    )
    tile_b = tl._experimental_descriptor_load(
        desc_b,
        [start_x, start_y],
        [BLOCK_SIZE, BLOCK_SIZE],
        tl.float32,
    )

    # Add the tiles and write the sum through the output descriptor.
    tile_c = tile_a + tile_b
    tl._experimental_descriptor_store(desc_c, tile_c, [start_x, start_y])
|
|
|
|
@triton.jit
def add_kernel_on_device_tma_new_api(
    a_ptr,
    b_ptr,
    c_ptr,
    m,
    n,
    workspace,  # unused but left here to match the old API kernel
    BLOCK_SIZE: "tl.constexpr",
):
    # Same computation as the old-API on-device kernel, but the descriptors
    # come from tl.make_tensor_descriptor with shape [m, n] and strides
    # [n, 1].
    desc_a = tl.make_tensor_descriptor(
        base=a_ptr,
        shape=[m, n],
        strides=[n, 1],
        block_shape=[BLOCK_SIZE, BLOCK_SIZE],
    )
    desc_b = tl.make_tensor_descriptor(
        base=b_ptr,
        shape=[m, n],
        strides=[n, 1],
        block_shape=[BLOCK_SIZE, BLOCK_SIZE],
    )
    desc_c = tl.make_tensor_descriptor(
        base=c_ptr,
        shape=[m, n],
        strides=[n, 1],
        block_shape=[BLOCK_SIZE, BLOCK_SIZE],
    )

    # Tile origin for this program (2D launch grid).
    start_x = tl.program_id(axis=0) * BLOCK_SIZE
    start_y = tl.program_id(axis=1) * BLOCK_SIZE

    # Fetch one tile from each input descriptor.
    tile_a = tl.load_tensor_descriptor(desc_a, [start_x, start_y])
    tile_b = tl.load_tensor_descriptor(desc_b, [start_x, start_y])

    # Add the tiles and write the sum; the store takes
    # (descriptor, offsets, value).
    tile_c = tile_a + tile_b
    tl.store_tensor_descriptor(desc_c, [start_x, start_y], tile_c)
|
|
|
|
@triton.jit
def mul2_kernel(
    in_ptr0,
    out_ptr,
    n_elements,
    BLOCK_SIZE: "tl.constexpr",
):
    # out_ptr[i] = 2 * in_ptr0[i] for i < n_elements.
    span_start = tl.program_id(axis=0) * BLOCK_SIZE
    idx = span_start + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements
    value = tl.load(in_ptr0 + idx, mask=in_bounds)
    doubled = 2 * value
    tl.store(out_ptr + idx, doubled, mask=in_bounds)
|
|
|
|
@triton.jit
def mul2_inplace_kernel(
    ptr,
    n_elements,
    BLOCK_SIZE: "tl.constexpr",
):
    # Doubles the first n_elements values behind `ptr`, writing each result
    # back to the location it was read from.
    span_start = tl.program_id(axis=0) * BLOCK_SIZE
    idx = span_start + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements
    value = tl.load(ptr + idx, mask=in_bounds)
    doubled = 2 * value
    tl.store(ptr + idx, doubled, mask=in_bounds)
|
|
|
|
@triton.jit
def zero_negs(x):
    # Keep entries of x that are >= 0 and replace every other entry with 0.
    non_negative = x >= 0
    return tl.where(non_negative, x, 0)
|
|
|
|
@triton.jit
def indirection_kernel(
    in_ptr0,
    out_ptr,
    n_elements,
    BLOCK_SIZE: "tl.constexpr",
    ACTIVATION: "tl.constexpr",
):
    # Optionally calls another jitted kernel chosen by the ACTIVATION string,
    # then always copies in_ptr0 to out_ptr for i < n_elements.
    span_start = tl.program_id(axis=0) * BLOCK_SIZE
    idx = span_start + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements
    if ACTIVATION == "mul2_inplace_kernel":
        mul2_inplace_kernel(in_ptr0, n_elements, BLOCK_SIZE=BLOCK_SIZE)
    elif ACTIVATION == "add_kernel":
        # in_ptr0 is passed as both inputs of add_kernel.
        add_kernel(in_ptr0, in_ptr0, out_ptr, n_elements, BLOCK_SIZE=BLOCK_SIZE)
    # The copy below runs for every ACTIVATION value.
    copied = tl.load(in_ptr0 + idx, mask=in_bounds)
    tl.store(out_ptr + idx, copied, mask=in_bounds)
|
|
|
|
@triton.jit
def double_strided_kernel(
    in_ptr,
    out_ptr,
    in_y_stride,
    out_y_stride,
    X_BLOCK_SIZE: "tl.constexpr",
    Y_BLOCK_SIZE: "tl.constexpr",
):
    # Reads a Y_BLOCK_SIZE x X_BLOCK_SIZE tile, multiplies it by 2.0 and
    # writes it out. Input and output use separate y strides; the x step is 1
    # for both. No bounds mask is applied to the load or the store.
    xs = tl.program_id(axis=0) * X_BLOCK_SIZE + tl.arange(0, X_BLOCK_SIZE)
    ys = tl.program_id(axis=1) * Y_BLOCK_SIZE + tl.arange(0, Y_BLOCK_SIZE)
    read_at = ys[:, None] * in_y_stride + xs[None, :]
    write_at = ys[:, None] * out_y_stride + xs[None, :]
    tile = tl.load(in_ptr + read_at)
    tl.store(out_ptr + write_at, tile * 2.0)
|
|
|
|
@triton.jit
def inline_asm_kernel_is_pure_true(
    X, Y, Z, n: "tl.constexpr", BLOCK: "tl.constexpr"
):
    # Applies one inline-asm instruction elementwise to (X, Y, n) over the
    # first BLOCK elements and stores the int32 result to Z. The asm call is
    # declared with is_pure=True.
    lane = tl.arange(0, BLOCK)
    first = tl.load(X + lane)
    second = tl.load(Y + lane)
    amount = tl.full([BLOCK], n, tl.int32)  # n repeated in every lane
    result = tl.inline_asm_elementwise(
        "shf.l.wrap.b32 $0, $1, $2, $3;",
        "=r,r, r, r",
        [first, second, amount],
        dtype=tl.int32,
        is_pure=True,
        pack=1,
    )
    tl.store(Z + tl.arange(0, BLOCK), result)
|
|
|
|
@triton.jit
def inline_asm_kernel_is_pure_false(
    X, Y, Z, n: "tl.constexpr", BLOCK: "tl.constexpr"
):
    # Same computation as the is_pure=True variant, but the inline asm call
    # is declared with is_pure=False.
    lane = tl.arange(0, BLOCK)
    first = tl.load(X + lane)
    second = tl.load(Y + lane)
    amount = tl.full([BLOCK], n, tl.int32)  # n repeated in every lane
    result = tl.inline_asm_elementwise(
        "shf.l.wrap.b32 $0, $1, $2, $3;",
        "=r,r, r, r",
        [first, second, amount],
        dtype=tl.int32,
        is_pure=False,
        pack=1,
    )
    tl.store(Z + tl.arange(0, BLOCK), result)
|
|
|
|
@triton.jit
def add_kernel_with_block_ptr(
    x_ptr,
    y_ptr,
    output_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    # Elementwise add written with 1D block pointers. Bounds are handled by
    # boundary_check on dimension 0 rather than by an explicit mask.
    span_start = tl.program_id(axis=0) * BLOCK_SIZE

    x_block = tl.make_block_ptr(
        base=x_ptr,
        shape=[n_elements],
        strides=[1],
        offsets=[span_start],
        block_shape=[BLOCK_SIZE],
        order=[0],
    )
    lhs = tl.load(x_block, boundary_check=[0])

    y_block = tl.make_block_ptr(
        base=y_ptr,
        shape=[n_elements],
        strides=[1],
        offsets=[span_start],
        block_shape=[BLOCK_SIZE],
        order=[0],
    )
    rhs = tl.load(y_block, boundary_check=[0])

    total = lhs + rhs

    out_block = tl.make_block_ptr(
        base=output_ptr,
        shape=[n_elements],
        strides=[1],
        offsets=[span_start],
        block_shape=[BLOCK_SIZE],
        order=[0],
    )
    tl.store(out_block, total, boundary_check=[0])
|
|
|
|
@triton.jit
def kernel_with_block_ptr_2d(
    x_ptr,
    output_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    # Copies x_ptr to output_ptr through 2D block pointers of shape
    # [n_elements, 1]. Only dimension 0 is boundary-checked.
    span_start = tl.program_id(axis=0) * BLOCK_SIZE

    src_block = tl.make_block_ptr(
        base=x_ptr,
        shape=[n_elements, 1],
        strides=[1, 1],
        offsets=[span_start, 0],
        block_shape=[BLOCK_SIZE, 1],
        order=[1, 0],
    )
    tile = tl.load(src_block, boundary_check=[0])

    dst_block = tl.make_block_ptr(
        base=output_ptr,
        shape=[n_elements, 1],
        strides=[1, 1],
        offsets=[span_start, 0],
        block_shape=[BLOCK_SIZE, 1],
        order=[1, 0],
    )
    tl.store(dst_block, tile, boundary_check=[0])
|
|
|
|
from triton.language import load, store
|
|
|
|
@triton.jit
def add_kernel_with_import(
    in_ptr0,
    in_ptr1,
    out_ptr,
    n_elements,
    BLOCK_SIZE: "tl.constexpr",
):
    # Elementwise add that calls the bare `load` / `store` names imported
    # from triton.language instead of going through the `tl.` prefix.
    span_start = tl.program_id(axis=0) * BLOCK_SIZE
    idx = span_start + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements
    lhs = load(in_ptr0 + idx, mask=in_bounds)
    rhs = load(in_ptr1 + idx, mask=in_bounds)
    store(out_ptr + idx, lhs + rhs, mask=in_bounds)
|
|
|
|
@triton.jit
def cond_op_kernel(
    in_ptr0,
    in_ptr1,
    out_ptr,
    n_elements,
    BLOCK_SIZE: "tl.constexpr",
):
    # Program 0 writes the elementwise sum; every other program writes the
    # elementwise product. The branch depends on the runtime program id.
    span_start = tl.program_id(axis=0) * BLOCK_SIZE
    idx = span_start + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements
    lhs = tl.load(in_ptr0 + idx, mask=in_bounds)
    rhs = tl.load(in_ptr1 + idx, mask=in_bounds)
    if tl.program_id(0) == 0:
        result = lhs + rhs
    else:
        result = lhs * rhs
    tl.store(out_ptr + idx, result, mask=in_bounds)
|
|
|
|
@triton.jit
def atomic_add_kernel(
    in_ptr0,
    in_ptr1,
    out_ptr,
    n_elements,
    BLOCK_SIZE: "tl.constexpr",
):
    # Accumulates in_ptr0[i] + in_ptr1[i] into out_ptr[i] with tl.atomic_add,
    # so values already in out_ptr are added to rather than overwritten.
    span_start = tl.program_id(axis=0) * BLOCK_SIZE
    idx = span_start + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements
    lhs = tl.load(in_ptr0 + idx, mask=in_bounds)
    rhs = tl.load(in_ptr1 + idx, mask=in_bounds)
    contribution = lhs + rhs
    tl.atomic_add(out_ptr + idx, contribution, mask=in_bounds)
|
|
|
|
@triton.jit
def add_4_times_kernel(
    in_ptr0,
    in_ptr1,
    out_ptr,
    n_elements,
    BLOCK_SIZE: "tl.constexpr",
):
    # Computes and stores the same elementwise sum four times: twice inside
    # a for loop and twice inside a while loop. Both loop forms are kept so
    # the kernel contains each kind of control flow.
    span_start = tl.program_id(axis=0) * BLOCK_SIZE
    idx = span_start + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements
    lhs = tl.load(in_ptr0 + idx, mask=in_bounds)
    rhs = tl.load(in_ptr1 + idx, mask=in_bounds)
    # First two stores: for loop.
    for i in range(2):
        total = lhs + rhs
        tl.store(out_ptr + idx, total, mask=in_bounds)
    # Last two stores: while loop counting down from 2.
    i = 2
    while i > 0:
        i -= 1
        total = lhs + rhs
        tl.store(out_ptr + idx, total, mask=in_bounds)
|
|
|
|
@triton.jit
def add_kernel_out_of_order_fn2(
    in_ptr0,
    in_ptr1,
    n_elements,
    out_ptr,
    BLOCK_SIZE: "tl.constexpr",
):
    # Elementwise add whose signature places n_elements before out_ptr,
    # unlike add_kernel. The body is otherwise the same computation.
    span_start = tl.program_id(axis=0) * BLOCK_SIZE
    idx = span_start + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements
    lhs = tl.load(in_ptr0 + idx, mask=in_bounds)
    rhs = tl.load(in_ptr1 + idx, mask=in_bounds)
    tl.store(out_ptr + idx, lhs + rhs, mask=in_bounds)
|
|
|
|
# Two autotune candidates: a small 16x16x16 tile and a larger 128x64x32 tile.
# The kernel body treats them differently when the loaded M is 0 (see below).
@triton.autotune(
    configs=[
        triton.Config(
            {
                "BLOCK_SIZE_M": 16,
                "BLOCK_SIZE_N": 16,
                "BLOCK_SIZE_K": 16,
                "GROUP_SIZE_M": 4,
            },
            num_stages=4,
            num_warps=4,
        ),
        triton.Config(
            {
                "BLOCK_SIZE_M": 128,
                "BLOCK_SIZE_N": 64,
                "BLOCK_SIZE_K": 32,
                "GROUP_SIZE_M": 8,
            },
            num_stages=4,
            num_warps=4,
        ),
    ],
    key=["M_ptr", "N", "K"],
)
@triton.jit
def strange_config_matmul_kernel(
    a_ptr,
    b_ptr,
    c_ptr,
    M_ptr,
    N,
    K,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    # This is a simplified matmul from Triton tutorial.
    # Unlike N and K, M is not passed by value: it is read from memory via
    # M_ptr.
    pid = tl.program_id(axis=0)
    M = tl.load(M_ptr)
    if M == 0 and BLOCK_SIZE_M > 32:
        # This will run the full matmul if BLOCK_SIZE_M > 32
        M = 4096
    elif M == 0:
        # This directly returns, which will cut short the bad config of 16-block size.
        return
    # Map the 1D program id to a (pid_m, pid_n) tile, grouping GROUP_SIZE_M
    # rows of tiles together.
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # Row/column offsets wrap modulo M and N; note that no stride arguments
    # are used when forming the a/b pointers.
    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + (offs_am[:, None] + offs_k[None, :])
    b_ptrs = b_ptr + (offs_k[:, None] + offs_bn[None, :])

    # Accumulate in float32 over K in BLOCK_SIZE_K steps; lanes past the
    # remaining K are masked and read as 0.0.
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
        accumulator = tl.dot(a, b, accumulator)
        a_ptrs += BLOCK_SIZE_K
        b_ptrs += BLOCK_SIZE_K
    # The result is converted to float16 before being stored.
    c = accumulator.to(tl.float16)

    # Store the output tile, masking positions outside M x N.
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + offs_cm[:, None] + offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)
|
|
|
|
# Fills out_ptr[i] with 1.0 for i < numel. The docstring below is deliberate
# test data for codegen and must keep its double-quote delimiters.
@triton.jit
def kernel_with_docstring_double_quotes(out_ptr, numel, BLOCK_SIZE: tl.constexpr):
    """
    This kernel contains a triple-quote docstring w/ double quotes.
    Make sure that codegen sanitizes the docstring.
    """
    program = tl.program_id(axis=0)
    idx = tl.arange(0, BLOCK_SIZE) + program * BLOCK_SIZE
    fill = tl.full([BLOCK_SIZE], 1.0, dtype=tl.float32)
    tl.store(out_ptr + idx, fill, mask=idx < numel)
|
|
|
|
# Fills out_ptr[i] with 1.0 for i < numel. The docstring below is deliberate
# test data for codegen and must keep its single-quote delimiters.
@triton.jit
def kernel_with_docstring_single_quotes(out_ptr, numel, BLOCK_SIZE: tl.constexpr):
    '''
    This kernel contains a triple-quote docstring w/ single quotes
    Make sure that codegen sanitizes the docstring.
    To prevent it from being linted to double quotes: """!!!"""
    '''
    program = tl.program_id(axis=0)
    idx = tl.arange(0, BLOCK_SIZE) + program * BLOCK_SIZE
    fill = tl.full([BLOCK_SIZE], 1.0, dtype=tl.float32)
    tl.store(out_ptr + idx, fill, mask=idx < numel)
|
|
|
|
# Applies a two-instruction inline-asm block (cos.approx then ex2.approx)
# elementwise to in_ptr and stores the float32 result to out_ptr.
# The asm text is a double-quoted triple-quote string on purpose.
@triton.jit
def kernel_inline_asm_double_quotes(
    in_ptr, out_ptr, numel, BLOCK_SIZE: tl.constexpr
):
    program = tl.program_id(axis=0)
    idx = tl.arange(0, BLOCK_SIZE) + program * BLOCK_SIZE
    values = tl.load(in_ptr + idx, mask=idx < numel)
    result = tl.inline_asm_elementwise(
        asm="""
        {
            cos.approx.f32 $0, $1;
            ex2.approx.f32 $0, $0;
        }
            """,
        constraints="=r, r",
        args=[values],
        dtype=tl.float32,
        is_pure=True,
        pack=1,
    )
    tl.store(out_ptr + idx, result, mask=idx < numel)
|
|
|
|
# Same computation as the double-quote variant, but the asm text is a
# single-quoted triple-quote string that itself contains double quotes.
@triton.jit
def kernel_inline_asm_single_quotes(
    in_ptr, out_ptr, numel, BLOCK_SIZE: tl.constexpr
):
    program = tl.program_id(axis=0)
    idx = tl.arange(0, BLOCK_SIZE) + program * BLOCK_SIZE
    values = tl.load(in_ptr + idx, mask=idx < numel)
    result = tl.inline_asm_elementwise(
        asm='''
        {
            // double quotes to pacify the linter """!!!"""
            cos.approx.f32 $0, $1;
            ex2.approx.f32 $0, $0;
        }
            ''',
        constraints="=r, r",
        args=[values],
        dtype=tl.float32,
        is_pure=True,
        pack=1,
    )
    tl.store(out_ptr + idx, result, mask=idx < numel)
|
|
|
|
@triton.jit
def add_kernel_with_boolean_param(
    in_ptr0,
    in_ptr1,
    out_ptr,
    n_elements,
    add_xy,  # boolean param
    BLOCK_SIZE: "tl.constexpr",
):
    # Writes in_ptr0 + in_ptr1 when add_xy is truthy; otherwise copies
    # in_ptr0 to out_ptr without reading in_ptr1.
    span_start = tl.program_id(axis=0) * BLOCK_SIZE
    idx = span_start + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements
    first = tl.load(in_ptr0 + idx, mask=in_bounds)
    if add_xy:
        second = tl.load(in_ptr1 + idx, mask=in_bounds)
        result = first + second
    else:
        result = first
    tl.store(out_ptr + idx, result, mask=in_bounds)
|
|
|
|
# support the old (experimental) and new (tensor_descriptor) APIs
def create_tensor_descriptor_shim(
    tensor, block_sizes: list[int], new_api: bool = True
):
    """Create a host-side TMA descriptor for ``tensor`` with either Triton API.

    Args:
        tensor: Object providing ``data_ptr()``, ``size(dim)`` and
            ``element_size()`` (old API), or anything accepted by
            ``TensorDescriptor.from_tensor`` (new API).
        block_sizes: Block extent per dimension. The old API supports only
            one or two entries.
        new_api: When true (default) use
            ``triton.tools.tensor_descriptor.TensorDescriptor``; otherwise use
            the ``triton.tools.experimental_descriptor`` helpers.

    Returns:
        Whatever the selected Triton descriptor factory returns.

    Raises:
        ValueError: If ``new_api`` is false and ``block_sizes`` does not have
            exactly one or two entries.
    """
    if new_api:
        # Import the submodule explicitly instead of relying on some other
        # module having already loaded `triton.tools.tensor_descriptor`.
        from triton.tools.tensor_descriptor import TensorDescriptor

        return TensorDescriptor.from_tensor(tensor, block_sizes)

    rank = len(block_sizes)
    if rank not in (1, 2):
        # Explicit check (not `assert`) so it still runs under `python -O`,
        # and placed before the Triton import so a bad argument is reported
        # as such.
        raise ValueError(
            "the experimental descriptor API supports 1 or 2 block sizes, "
            f"got {rank}"
        )

    from triton.tools.experimental_descriptor import (
        create_1d_tma_descriptor,
        create_2d_tma_descriptor,
    )

    if rank == 1:
        return create_1d_tma_descriptor(
            tensor.data_ptr(),
            tensor.size(0),
            block_sizes[0],
            tensor.element_size(),
        )
    return create_2d_tma_descriptor(
        tensor.data_ptr(),
        tensor.size(0),
        tensor.size(1),
        block_sizes[0],
        block_sizes[1],
        tensor.element_size(),
    )
|