mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Summary: When a Triton kernel has arguments with None values followed by arguments with value 1, AOTI attempts to remove the None arguments and update the indices of the equal_to_1 arguments in triton_meta["configs"]. However, if the same kernel is called multiple times, this optimization process is repeated. Prior to this diff, the indices of equal_to_1 arguments from subsequent calls (second and later) were based on the updated indices from the previous call, resulting in incorrect behavior. This diff aims to localize the updated indices for equal_to_1 arguments within the optimization process of the current call, ensuring accurate and consistent results. Test Plan: Unit Test: ``` buck2 run mode/dev-nosan caffe2/test/inductor:test_aot_inductor -- -r test_triton_kernel_with_none_inputs_and_equal_to_1_arg ``` Differential Revision: D69998314 Pull Request resolved: https://github.com/pytorch/pytorch/pull/148102 Approved by: https://github.com/davidberard98, https://github.com/chenyang78
579 lines
17 KiB
Python
579 lines
17 KiB
Python
# mypy: ignore-errors
|
|
|
|
import unittest
|
|
|
|
from torch.testing._internal.inductor_utils import HAS_CUDA, HAS_GPU
|
|
from torch.utils._triton import has_triton
|
|
|
|
|
|
# Skip decorators for device-dependent tests; the second argument is the
# reason string attached to the skip.
requires_cuda = unittest.skipIf(not HAS_CUDA, "requires cuda")
requires_gpu = unittest.skipIf(not HAS_GPU, "requires gpu")
|
|
|
|
if has_triton():
|
|
import triton
|
|
from triton import language as tl
|
|
|
|
# Define here so that multiple tests can take advantage of it
|
|
@triton.jit
def add_kernel(
    in_ptr0,
    in_ptr1,
    out_ptr,
    n_elements,
    BLOCK_SIZE: "tl.constexpr",
):
    """Write in_ptr0[i] + in_ptr1[i] to out_ptr[i] for offsets i < n_elements.

    Each program on axis 0 covers BLOCK_SIZE consecutive offsets starting at
    program_id * BLOCK_SIZE.
    """
    lane_offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # Offsets at or past n_elements are masked out of every load and store.
    in_bounds = lane_offsets < n_elements
    lhs = tl.load(in_ptr0 + lane_offsets, mask=in_bounds)
    rhs = tl.load(in_ptr1 + lane_offsets, mask=in_bounds)
    tl.store(out_ptr + lane_offsets, lhs + rhs, mask=in_bounds)
|
|
|
|
@triton.jit
def sub_kernel(
    in_ptr0,
    in_ptr1,
    out_ptr,
    n_elements,
    BLOCK_SIZE: "tl.constexpr",
):
    """Write in_ptr0[i] - in_ptr1[i] to out_ptr[i] for offsets i < n_elements.

    Each program on axis 0 covers BLOCK_SIZE consecutive offsets starting at
    program_id * BLOCK_SIZE.
    """
    lane_offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # Offsets at or past n_elements are masked out of every load and store.
    in_bounds = lane_offsets < n_elements
    minuend = tl.load(in_ptr0 + lane_offsets, mask=in_bounds)
    subtrahend = tl.load(in_ptr1 + lane_offsets, mask=in_bounds)
    tl.store(out_ptr + lane_offsets, minuend - subtrahend, mask=in_bounds)
|
|
|
|
@triton.jit
def add_kernel_with_optional_param(
    in_ptr0,
    in_ptr1,
    out_ptr,
    n_elements,
    ARGS_PASSED: "tl.constexpr",
    BLOCK_SIZE: "tl.constexpr",
):
    """Copy in_ptr0 to out_ptr, adding in_ptr1 only when ARGS_PASSED == "two".

    For any other ARGS_PASSED value in_ptr1 is never dereferenced and the
    output is just the values loaded from in_ptr0.
    """
    lane_offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = lane_offsets < n_elements
    result = tl.load(in_ptr0 + lane_offsets, mask=in_bounds)
    # ARGS_PASSED is a constexpr, so this comparison selects the code path
    # for the second operand.
    if ARGS_PASSED == "two":
        result = result + tl.load(in_ptr1 + lane_offsets, mask=in_bounds)
    tl.store(out_ptr + lane_offsets, result, mask=in_bounds)
|
|
|
|
@triton.jit
def add_kernel_with_none_param_and_equal_to_1_arg(
    in_ptr0,
    in_ptr1,  # in_ptr1 could be None
    out_ptr,
    n_elements,
    stride,
    ARGS_PASSED: "tl.constexpr",
    BLOCK_SIZE: "tl.constexpr",
):
    """Strided copy of in_ptr0 to out_ptr, optionally adding in_ptr1.

    in_ptr0 and out_ptr are addressed at offsets * stride. in_ptr1 is read
    (at plain offsets, without the stride) only when ARGS_PASSED == "two".
    """
    lane_offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = lane_offsets < n_elements
    result = tl.load(in_ptr0 + lane_offsets * stride, mask=in_bounds)
    # ARGS_PASSED is a constexpr; in_ptr1 is only touched on the "two" path.
    if ARGS_PASSED == "two":
        result = result + tl.load(in_ptr1 + lane_offsets, mask=in_bounds)
    tl.store(out_ptr + lane_offsets * stride, result, mask=in_bounds)
|
|
|
|
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 128}, num_stages=3, num_warps=8),
        triton.Config({"BLOCK_SIZE": 128}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_SIZE": 64}, num_stages=3, num_warps=8),
        triton.Config({"BLOCK_SIZE": 64}, num_stages=4, num_warps=4),
    ],
    key=[],
)
@triton.jit
def add_kernel_autotuned(
    in_ptr0,
    in_ptr1,
    out_ptr,
    n_elements,
    BLOCK_SIZE: "tl.constexpr",
):
    """Masked element-wise add wrapped in triton.autotune.

    Writes in_ptr0[i] + in_ptr1[i] to out_ptr[i] for offsets i < n_elements.
    The candidate BLOCK_SIZE values are listed in the autotune configs.
    """
    lane_offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = lane_offsets < n_elements
    lhs = tl.load(in_ptr0 + lane_offsets, mask=in_bounds)
    rhs = tl.load(in_ptr1 + lane_offsets, mask=in_bounds)
    tl.store(out_ptr + lane_offsets, lhs + rhs, mask=in_bounds)
|
|
|
|
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 16}, num_stages=2, num_warps=2),
    ],
    key=[],
)
@triton.jit
def add_kernel_autotuned_weird_param_order(
    in_ptr0,
    in_ptr1,
    n_elements,
    BLOCK_SIZE: "tl.constexpr",
    out_ptr,
):
    """Masked element-wise add whose output pointer follows a constexpr param.

    Writes in_ptr0[i] + in_ptr1[i] to out_ptr[i] for offsets i < n_elements.
    """
    # out_ptr is declared after BLOCK_SIZE, an autotuned tl.constexpr param.
    # This ordering can create bugs if not handled correctly, so the
    # signature must stay exactly as it is.
    lane_offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = lane_offsets < n_elements
    lhs = tl.load(in_ptr0 + lane_offsets, mask=in_bounds)
    rhs = tl.load(in_ptr1 + lane_offsets, mask=in_bounds)
    tl.store(out_ptr + lane_offsets, lhs + rhs, mask=in_bounds)
|
|
|
|
@triton.autotune(
    configs=[
        triton.Config(
            {"BLOCK_SIZE_X": 128, "BLOCK_SIZE_Y": 128}, num_stages=3, num_warps=8
        ),
        triton.Config(
            {"BLOCK_SIZE_X": 128, "BLOCK_SIZE_Y": 128}, num_stages=4, num_warps=4
        ),
        triton.Config(
            {"BLOCK_SIZE_X": 64, "BLOCK_SIZE_Y": 64}, num_stages=3, num_warps=8
        ),
        triton.Config(
            {"BLOCK_SIZE_X": 64, "BLOCK_SIZE_Y": 64}, num_stages=4, num_warps=4
        ),
    ],
    key=[],
)
@triton.jit
def add_kernel_2d_autotuned(
    in_ptr0,
    in_ptr1,
    out_ptr,
    x_elements,
    y_elements,
    BLOCK_SIZE_X: "tl.constexpr",
    BLOCK_SIZE_Y: "tl.constexpr",
):
    """Autotuned kernel over a 2D launch grid (program axes 0 and 1).

    For each in-range (x, y) pair it sums two values read from in_ptr0, one
    at x + x_elements * y and one at y + y_elements * x, and stores the sum
    to out_ptr at x + x_elements * y.
    """
    # NOTE(review): both loads read in_ptr0 and in_ptr1 is never used. This
    # may be intentional for the tests that use this kernel - confirm before
    # changing it.
    x_idx = tl.program_id(0) * BLOCK_SIZE_X + tl.arange(0, BLOCK_SIZE_X)[:, None]
    y_idx = tl.program_id(1) * BLOCK_SIZE_Y + tl.arange(0, BLOCK_SIZE_Y)[None, :]
    # x_idx is a column and y_idx a row, so the combined mask broadcasts to
    # the full BLOCK_SIZE_X by BLOCK_SIZE_Y tile.
    in_bounds = (x_idx < x_elements) & (y_idx < y_elements)
    first = tl.load(in_ptr0 + (x_idx + (x_elements * y_idx)), in_bounds)
    second = tl.load(in_ptr0 + (y_idx + (y_elements * x_idx)), in_bounds)
    tl.store(out_ptr + (x_idx + (x_elements * y_idx)), first + second, in_bounds)
|
|
|
|
def _dummy_early_config_prune(configs, *_, **__):
|
|
return configs
|
|
|
|
# NOTE(review): warmup, rep and prune_configs_by are presumably the
# "unsupported args" the kernel name refers to - confirm against the tests
# that use this kernel.
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 128}, num_stages=3, num_warps=8),
        triton.Config({"BLOCK_SIZE": 64}, num_stages=4, num_warps=4),
    ],
    key=[],
    warmup=10,
    rep=20,
    prune_configs_by={"early_config_prune": _dummy_early_config_prune},
)
@triton.jit
def add_kernel_autotuned_with_unsupported_args(
    in_ptr0,
    in_ptr1,
    out_ptr,
    n_elements,
    BLOCK_SIZE: "tl.constexpr",
):
    """Masked element-wise add wrapped in an autotuner with extra options.

    Writes in_ptr0[i] + in_ptr1[i] to out_ptr[i] for offsets i < n_elements.
    """
    lane_offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = lane_offsets < n_elements
    lhs = tl.load(in_ptr0 + lane_offsets, mask=in_bounds)
    rhs = tl.load(in_ptr1 + lane_offsets, mask=in_bounds)
    tl.store(out_ptr + lane_offsets, lhs + rhs, mask=in_bounds)
|
|
|
|
@triton.jit
def add_kernel_with_scaling(
    in_ptr0,
    in_ptr1,
    out_ptr,
    n_elements,
    scaling_factor,
    BLOCK_SIZE: "tl.constexpr",
):
    """Write (in_ptr0[i] + in_ptr1[i]) * scaling_factor to out_ptr[i].

    Only offsets i < n_elements are loaded and stored; scaling_factor is a
    regular (non-constexpr) kernel argument.
    """
    lane_offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = lane_offsets < n_elements
    lhs = tl.load(in_ptr0 + lane_offsets, mask=in_bounds)
    rhs = tl.load(in_ptr1 + lane_offsets, mask=in_bounds)
    scaled_sum = (lhs + rhs) * scaling_factor
    tl.store(out_ptr + lane_offsets, scaled_sum, mask=in_bounds)
|
|
|
|
@triton.jit
def add_kernel_with_tma_1d(
    in_desc_ptr0,
    in_desc_ptr1,
    out_desc_ptr,
    BLOCK_SIZE: "tl.constexpr",
):
    """Add two 1D blocks fetched through descriptors and store the sum.

    Each program on axis 0 loads a [BLOCK_SIZE] float32 block from each input
    descriptor at offset program_id * BLOCK_SIZE and stores their sum through
    out_desc_ptr at the same offset.
    """
    block_offset = tl.program_id(axis=0) * BLOCK_SIZE

    lhs = tl._experimental_descriptor_load(
        in_desc_ptr0,
        [block_offset],
        [BLOCK_SIZE],
        tl.float32,
    )
    rhs = tl._experimental_descriptor_load(
        in_desc_ptr1,
        [block_offset],
        [BLOCK_SIZE],
        tl.float32,
    )

    tl._experimental_descriptor_store(
        out_desc_ptr,
        lhs + rhs,
        [block_offset],
    )
|
|
|
|
@triton.jit
def add_kernel_with_tma_2d(
    in_desc_ptr0,
    in_desc_ptr1,
    out_desc_ptr,
    BLOCK_SIZE_X: "tl.constexpr",
    BLOCK_SIZE_Y: "tl.constexpr",
):
    """Add two 2D blocks fetched through descriptors and store the sum.

    Program axes 0 and 1 select the block: the offsets are
    program_id(0) * BLOCK_SIZE_X and program_id(1) * BLOCK_SIZE_Y. Both input
    blocks are loaded as [BLOCK_SIZE_X, BLOCK_SIZE_Y] float32.
    """
    x_start = tl.program_id(axis=0) * BLOCK_SIZE_X
    y_start = tl.program_id(axis=1) * BLOCK_SIZE_Y

    lhs = tl._experimental_descriptor_load(
        in_desc_ptr0,
        [x_start, y_start],
        [BLOCK_SIZE_X, BLOCK_SIZE_Y],
        tl.float32,
    )
    rhs = tl._experimental_descriptor_load(
        in_desc_ptr1,
        [x_start, y_start],
        [BLOCK_SIZE_X, BLOCK_SIZE_Y],
        tl.float32,
    )

    tl._experimental_descriptor_store(
        out_desc_ptr,
        lhs + rhs,
        [x_start, y_start],
    )
|
|
|
|
@triton.jit
def mul2_kernel(
    in_ptr0,
    out_ptr,
    n_elements,
    BLOCK_SIZE: "tl.constexpr",
):
    """Write 2 * in_ptr0[i] to out_ptr[i] for offsets i < n_elements."""
    lane_offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = lane_offsets < n_elements
    values = tl.load(in_ptr0 + lane_offsets, mask=in_bounds)
    tl.store(out_ptr + lane_offsets, 2 * values, mask=in_bounds)
|
|
|
|
@triton.jit
def mul2_inplace_kernel(
    ptr,
    n_elements,
    BLOCK_SIZE: "tl.constexpr",
):
    """Double the values at ptr in place for offsets i < n_elements.

    The same pointer is used for the load and the store.
    """
    lane_offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = lane_offsets < n_elements
    values = tl.load(ptr + lane_offsets, mask=in_bounds)
    tl.store(ptr + lane_offsets, 2 * values, mask=in_bounds)
|
|
|
|
@triton.jit
def zero_negs(x):
    """Return x where x >= 0 holds and 0 everywhere else."""
    keep = x >= 0
    return tl.where(keep, x, 0)
|
|
|
|
@triton.jit
def indirection_kernel(
    in_ptr0,
    out_ptr,
    n_elements,
    BLOCK_SIZE: "tl.constexpr",
    ACTIVATION: "tl.constexpr",
):
    """Optionally call a sibling kernel, then copy in_ptr0 to out_ptr.

    ACTIVATION selects the nested call: "mul2_inplace_kernel" runs
    mul2_inplace_kernel on in_ptr0, "add_kernel" runs add_kernel with in_ptr0
    as both inputs; any other value makes no nested call. The masked copy
    from in_ptr0 to out_ptr always runs afterwards.
    """
    lane_offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = lane_offsets < n_elements
    if ACTIVATION == "mul2_inplace_kernel":
        mul2_inplace_kernel(in_ptr0, n_elements, BLOCK_SIZE=BLOCK_SIZE)
    elif ACTIVATION == "add_kernel":
        add_kernel(in_ptr0, in_ptr0, out_ptr, n_elements, BLOCK_SIZE=BLOCK_SIZE)
    # The load comes after the nested call, which may have written to in_ptr0.
    values = tl.load(in_ptr0 + lane_offsets, mask=in_bounds)
    tl.store(out_ptr + lane_offsets, values, mask=in_bounds)
|
|
|
|
@triton.jit
def double_strided_kernel(
    in_ptr,
    out_ptr,
    in_y_stride,
    out_y_stride,
    X_BLOCK_SIZE: "tl.constexpr",
    Y_BLOCK_SIZE: "tl.constexpr",
):
    """Store 2.0 * in_ptr[tile] to out_ptr[tile] with separate y strides.

    Program axis 0 picks the x block and axis 1 the y block. Source offsets
    are y * in_y_stride + x and destination offsets y * out_y_stride + x.
    No mask is applied to the load or the store.
    """
    x_idx = tl.program_id(axis=0) * X_BLOCK_SIZE + tl.arange(0, X_BLOCK_SIZE)
    y_idx = tl.program_id(axis=1) * Y_BLOCK_SIZE + tl.arange(0, Y_BLOCK_SIZE)
    # y varies along the first tile dimension and x along the second.
    read_offsets = y_idx[:, None] * in_y_stride + x_idx[None, :]
    write_offsets = y_idx[:, None] * out_y_stride + x_idx[None, :]
    tile = tl.load(in_ptr + read_offsets)
    tl.store(out_ptr + write_offsets, tile * 2.0)
|
|
|
|
@triton.jit
def inline_asm_kernel_is_pure_true(
    X, Y, Z, n: "tl.constexpr", BLOCK: "tl.constexpr"
):
    """Apply an inline-asm elementwise op (is_pure=True) to X, Y and n.

    Loads BLOCK elements from X and Y, builds an int32 block filled with n,
    passes all three to the shf.l.wrap.b32 instruction and stores the int32
    result to Z. No masking is applied.
    """
    lanes = tl.arange(0, BLOCK)
    lhs = tl.load(X + lanes)
    rhs = tl.load(Y + lanes)
    shift = tl.full([BLOCK], n, tl.int32)
    result = tl.inline_asm_elementwise(
        "shf.l.wrap.b32 $0, $1, $2, $3;",
        "=r,r, r, r",
        [lhs, rhs, shift],
        dtype=tl.int32,
        is_pure=True,
        pack=1,
    )
    tl.store(Z + lanes, result)
|
|
|
|
@triton.jit
def inline_asm_kernel_is_pure_false(
    X, Y, Z, n: "tl.constexpr", BLOCK: "tl.constexpr"
):
    """Apply an inline-asm elementwise op (is_pure=False) to X, Y and n.

    Loads BLOCK elements from X and Y, builds an int32 block filled with n,
    passes all three to the shf.l.wrap.b32 instruction and stores the int32
    result to Z. No masking is applied.
    """
    lanes = tl.arange(0, BLOCK)
    lhs = tl.load(X + lanes)
    rhs = tl.load(Y + lanes)
    shift = tl.full([BLOCK], n, tl.int32)
    result = tl.inline_asm_elementwise(
        "shf.l.wrap.b32 $0, $1, $2, $3;",
        "=r,r, r, r",
        [lhs, rhs, shift],
        dtype=tl.int32,
        is_pure=False,
        pack=1,
    )
    tl.store(Z + lanes, result)
|
|
|
|
@triton.jit
def add_kernel_with_block_ptr(
    x_ptr,
    y_ptr,
    output_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    """Element-wise add implemented with 1D block pointers.

    Each program on axis 0 builds block pointers of shape [n_elements] and
    stride 1 starting at program_id * BLOCK_SIZE, loads x and y through them
    and stores x + y. Every access uses boundary_check on dimension 0.
    """
    block_start = tl.program_id(axis=0) * BLOCK_SIZE
    x_block = tl.make_block_ptr(
        base=x_ptr,
        shape=[n_elements],
        strides=[1],
        offsets=[block_start],
        block_shape=[BLOCK_SIZE],
        order=[0],
    )
    lhs = tl.load(x_block, boundary_check=[0])
    y_block = tl.make_block_ptr(
        base=y_ptr,
        shape=[n_elements],
        strides=[1],
        offsets=[block_start],
        block_shape=[BLOCK_SIZE],
        order=[0],
    )
    rhs = tl.load(y_block, boundary_check=[0])
    out_block = tl.make_block_ptr(
        base=output_ptr,
        shape=[n_elements],
        strides=[1],
        offsets=[block_start],
        block_shape=[BLOCK_SIZE],
        order=[0],
    )
    tl.store(out_block, lhs + rhs, boundary_check=[0])
|
|
|
|
@triton.jit
def kernel_with_block_ptr_2d(
    x_ptr,
    output_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    """Copy x_ptr to output_ptr through 2D block pointers.

    Both block pointers use shape [n_elements, 1], strides [1, 1], block
    shape [BLOCK_SIZE, 1] and order [1, 0], starting at row
    program_id * BLOCK_SIZE. Only dimension 0 is boundary-checked.
    """
    block_start = tl.program_id(axis=0) * BLOCK_SIZE
    src_block = tl.make_block_ptr(
        base=x_ptr,
        shape=[n_elements, 1],
        strides=[1, 1],
        offsets=[block_start, 0],
        block_shape=[BLOCK_SIZE, 1],
        order=[1, 0],
    )
    values = tl.load(src_block, boundary_check=[0])
    dst_block = tl.make_block_ptr(
        base=output_ptr,
        shape=[n_elements, 1],
        strides=[1, 1],
        offsets=[block_start, 0],
        block_shape=[BLOCK_SIZE, 1],
        order=[1, 0],
    )
    tl.store(dst_block, values, boundary_check=[0])
|
|
|
|
from triton.language import load, store
|
|
|
|
@triton.jit
def add_kernel_with_import(
    in_ptr0,
    in_ptr1,
    out_ptr,
    n_elements,
    BLOCK_SIZE: "tl.constexpr",
):
    """Masked element-wise add using the directly imported load and store.

    Writes in_ptr0[i] + in_ptr1[i] to out_ptr[i] for offsets i < n_elements.
    Memory access goes through the bare names load and store rather than
    tl.load and tl.store.
    """
    lane_offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = lane_offsets < n_elements
    lhs = load(in_ptr0 + lane_offsets, mask=in_bounds)
    rhs = load(in_ptr1 + lane_offsets, mask=in_bounds)
    store(out_ptr + lane_offsets, lhs + rhs, mask=in_bounds)
|
|
|
|
@triton.jit
def cond_op_kernel(
    in_ptr0,
    in_ptr1,
    out_ptr,
    n_elements,
    BLOCK_SIZE: "tl.constexpr",
):
    """Store x + y from program 0 and x * y from every other program.

    x and y are masked loads from in_ptr0 and in_ptr1; the result goes to
    out_ptr at the same offsets, for offsets below n_elements.
    """
    lane_offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = lane_offsets < n_elements
    lhs = tl.load(in_ptr0 + lane_offsets, mask=in_bounds)
    rhs = tl.load(in_ptr1 + lane_offsets, mask=in_bounds)
    # The condition depends on the program id, not on a constexpr.
    if tl.program_id(0) == 0:
        result = lhs + rhs
    else:
        result = lhs * rhs
    tl.store(out_ptr + lane_offsets, result, mask=in_bounds)
|
|
|
|
@triton.jit
def atomic_add_kernel(
    in_ptr0,
    in_ptr1,
    out_ptr,
    n_elements,
    BLOCK_SIZE: "tl.constexpr",
):
    """Atomically add in_ptr0[i] + in_ptr1[i] into out_ptr[i].

    Only offsets i < n_elements are loaded and updated. The result is
    accumulated with tl.atomic_add rather than written with tl.store.
    """
    lane_offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = lane_offsets < n_elements
    lhs = tl.load(in_ptr0 + lane_offsets, mask=in_bounds)
    rhs = tl.load(in_ptr1 + lane_offsets, mask=in_bounds)
    tl.atomic_add(out_ptr + lane_offsets, lhs + rhs, mask=in_bounds)
|
|
|
|
@triton.jit
def add_4_times_kernel(
    in_ptr0,
    in_ptr1,
    out_ptr,
    n_elements,
    BLOCK_SIZE: "tl.constexpr",
):
    """Compute and store x + y four times: twice in a for, twice in a while.

    x and y are masked loads from in_ptr0 and in_ptr1; every iteration writes
    the same sum to out_ptr at the same offsets.
    """
    lane_offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = lane_offsets < n_elements
    lhs = tl.load(in_ptr0 + lane_offsets, mask=in_bounds)
    rhs = tl.load(in_ptr1 + lane_offsets, mask=in_bounds)
    # First two stores: counted for loop.
    for i in range(2):
        total = lhs + rhs
        tl.store(out_ptr + lane_offsets, total, mask=in_bounds)
    # Last two stores: while loop counting the same name down from 2.
    i = 2
    while i > 0:
        i -= 1
        total = lhs + rhs
        tl.store(out_ptr + lane_offsets, total, mask=in_bounds)
|
|
|
|
@triton.jit
def add_kernel_out_of_order_fn2(
    in_ptr0,
    in_ptr1,
    n_elements,
    out_ptr,
    BLOCK_SIZE: "tl.constexpr",
):
    """Masked element-wise add with n_elements declared before out_ptr.

    Writes in_ptr0[i] + in_ptr1[i] to out_ptr[i] for offsets i < n_elements.
    The parameter order is part of the interface and is left as is.
    """
    lane_offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = lane_offsets < n_elements
    lhs = tl.load(in_ptr0 + lane_offsets, mask=in_bounds)
    rhs = tl.load(in_ptr1 + lane_offsets, mask=in_bounds)
    tl.store(out_ptr + lane_offsets, lhs + rhs, mask=in_bounds)
|