Files
pytorch/torch/testing/_internal/triton_utils.py
Adnan Akhundov 809ff3b274 Add host-side Triton TMA support to Dynamo (#137677)
This adds Dynamo tracing support for the host-side Triton TMA API (see `create_2d_tma_descriptor` calls on the host in the [Triton tutorial](https://triton-lang.org/main/getting-started/tutorials/09-persistent-matmul.html#sphx-glr-getting-started-tutorials-09-persistent-matmul-py)). A few notes:

- Here we assume the availability of the host-side TMA API added to upstream Triton in https://github.com/triton-lang/triton/pull/4498. As of the time of writing, this is not yet part of the PT2 OSS Triton pin (although it has been back-ported internally). The OSS Triton pin update should be done in December 2024.
- To capture the chain of calls `t.data_ptr() --> create_{1d,2d}_tma_descriptor(ptr, ...) --> kernel[grid](tma_desc, ...)`, we add three new variable trackers: `DataPtrVariable`, `CreateTMADescriptorVariable` (for the `create_{1d,2d}_tma_descriptor` functions), and `TMADescriptorVariable` (for the TMA descriptor object). This maintains the path back from the Triton kernel to the Tensor from which the TMA descriptor was created (see the sketch after this list).
- The newly introduced variables have `reconstruct` methods used in case of graph breaks.
- The `tma_descriptor_metadata` extracted from the captured `create_{1d,2d}_tma_descriptor` calls is propagated through the HOPs in Dynamo and AOTAutograd to be used by the downstream compiler (e.g., Inductor). See the unit tests for what the captured HOP arguments look like.
- In the Dynamo-captured fx graph, we replace the TMA descriptor arguments of the Triton kernel with the underlying Tensors, so that input/output relationships can be tracked in terms of Tensors.
- In the Triton kernel mutation analysis pass (in AOTAutograd), we use the `tt.experimental_descriptor_store` TTIR op to detect mutations of the underlying tensors performed via TMA descriptors, so that AOTAutograd can perform the required functionalization downstream.
- JIT Inductor and AOT Inductor support will be implemented in follow-up PRs.
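
For illustration, a minimal sketch of the now-traceable pattern (assuming a Triton build that already includes the host-side API from the PR above; the kernel and function names here are made up for the example):

```python
import torch
import triton
import triton.language as tl

# Host-side TMA API from triton-lang/triton#4498 (not yet in the OSS pin).
from triton.tools.experimental_descriptor import create_1d_tma_descriptor


@triton.jit
def add_kernel_tma(in_desc0, in_desc1, out_desc, BLOCK_SIZE: "tl.constexpr"):
    pid = tl.program_id(axis=0)
    offset = pid * BLOCK_SIZE
    a = tl._experimental_descriptor_load(in_desc0, [offset], [BLOCK_SIZE], tl.float32)
    b = tl._experimental_descriptor_load(in_desc1, [offset], [BLOCK_SIZE], tl.float32)
    tl._experimental_descriptor_store(out_desc, a + b, [offset])


def add_tma(x, y):
    out = torch.empty_like(x)
    BLOCK_SIZE = 256
    # The chain traced by Dynamo: t.data_ptr() -> create_1d_tma_descriptor(ptr, ...)
    # -> kernel[grid](tma_desc, ...), modeled by DataPtrVariable,
    # CreateTMADescriptorVariable, and TMADescriptorVariable, respectively.
    desc_x = create_1d_tma_descriptor(x.data_ptr(), x.numel(), BLOCK_SIZE, x.element_size())
    desc_y = create_1d_tma_descriptor(y.data_ptr(), y.numel(), BLOCK_SIZE, y.element_size())
    desc_out = create_1d_tma_descriptor(out.data_ptr(), out.numel(), BLOCK_SIZE, out.element_size())
    grid = (triton.cdiv(x.numel(), BLOCK_SIZE),)
    add_kernel_tma[grid](desc_x, desc_y, desc_out, BLOCK_SIZE)
    return out


# After this PR, Dynamo can trace this into an fx graph (with the TMA descriptor
# kernel arguments replaced by the underlying Tensors); Inductor support for
# actually compiling it lands in the follow-up PRs.
compiled_add_tma = torch.compile(add_tma, fullgraph=True)
```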

Differential Revision: [D64404928](https://our.internmc.facebook.com/intern/diff/D64404928)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137677
Approved by: https://github.com/zou3519
2024-10-16 02:18:48 +00:00


# mypy: ignore-errors
import unittest
from torch.testing._internal.inductor_utils import HAS_CUDA, HAS_GPU
from torch.utils._triton import has_triton
requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda")
requires_gpu = unittest.skipUnless(HAS_GPU, "requires gpu")
if has_triton():
    import triton
    from triton import language as tl

    # Define here so that multiple tests can take advantage of it
    @triton.jit
    def add_kernel(
        in_ptr0,
        in_ptr1,
        out_ptr,
        n_elements,
        BLOCK_SIZE: "tl.constexpr",
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(in_ptr0 + offsets, mask=mask)
        y = tl.load(in_ptr1 + offsets, mask=mask)
        output = x + y
        tl.store(out_ptr + offsets, output, mask=mask)

    @triton.jit
    def add_kernel_with_optional_param(
        in_ptr0,
        in_ptr1,
        out_ptr,
        n_elements,
        ARGS_PASSED: "tl.constexpr",
        BLOCK_SIZE: "tl.constexpr",
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(in_ptr0 + offsets, mask=mask)
        if ARGS_PASSED == "two":
            y = tl.load(in_ptr1 + offsets, mask=mask)
            output = x + y
        else:
            output = x
        tl.store(out_ptr + offsets, output, mask=mask)

    @triton.autotune(
        configs=[
            triton.Config({"BLOCK_SIZE": 128}, num_stages=3, num_warps=8),
            triton.Config({"BLOCK_SIZE": 128}, num_stages=4, num_warps=4),
            triton.Config({"BLOCK_SIZE": 64}, num_stages=3, num_warps=8),
            triton.Config({"BLOCK_SIZE": 64}, num_stages=4, num_warps=4),
        ],
        key=[],
    )
    @triton.jit
    def add_kernel_autotuned(
        in_ptr0,
        in_ptr1,
        out_ptr,
        n_elements,
        BLOCK_SIZE: "tl.constexpr",
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(in_ptr0 + offsets, mask=mask)
        y = tl.load(in_ptr1 + offsets, mask=mask)
        output = x + y
        tl.store(out_ptr + offsets, output, mask=mask)

    @triton.autotune(
        configs=[
            triton.Config({"BLOCK_SIZE": 16}, num_stages=2, num_warps=2),
        ],
        key=[],
    )
    @triton.jit
    def add_kernel_autotuned_weird_param_order(
        in_ptr0,
        in_ptr1,
        n_elements,
        BLOCK_SIZE: "tl.constexpr",
        out_ptr,
    ):
        # out_ptr is after an autotuned param that's declared as tl.constexpr.
        # This param ordering can create bugs if not handled correctly.
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(in_ptr0 + offsets, mask=mask)
        y = tl.load(in_ptr1 + offsets, mask=mask)
        output = x + y
        tl.store(out_ptr + offsets, output, mask=mask)

    @triton.autotune(
        configs=[
            triton.Config(
                {"BLOCK_SIZE_X": 128, "BLOCK_SIZE_Y": 128}, num_stages=3, num_warps=8
            ),
            triton.Config(
                {"BLOCK_SIZE_X": 128, "BLOCK_SIZE_Y": 128}, num_stages=4, num_warps=4
            ),
            triton.Config(
                {"BLOCK_SIZE_X": 64, "BLOCK_SIZE_Y": 64}, num_stages=3, num_warps=8
            ),
            triton.Config(
                {"BLOCK_SIZE_X": 64, "BLOCK_SIZE_Y": 64}, num_stages=4, num_warps=4
            ),
        ],
        key=[],
    )
    @triton.jit
    def add_kernel_2d_autotuned(
        in_ptr0,
        in_ptr1,
        out_ptr,
        x_elements,
        y_elements,
        BLOCK_SIZE_X: "tl.constexpr",
        BLOCK_SIZE_Y: "tl.constexpr",
    ):
        xoffset = tl.program_id(0) * BLOCK_SIZE_X
        xindex = xoffset + tl.arange(0, BLOCK_SIZE_X)[:, None]
        xmask = xindex < x_elements
        yoffset = tl.program_id(1) * BLOCK_SIZE_Y
        yindex = yoffset + tl.arange(0, BLOCK_SIZE_Y)[None, :]
        ymask = yindex < y_elements
        x1 = xindex
        y0 = yindex
        tmp0 = tl.load(in_ptr0 + (x1 + (x_elements * y0)), xmask & ymask)
        tmp1 = tl.load(in_ptr0 + (y0 + (y_elements * x1)), xmask & ymask)
        tmp2 = tmp0 + tmp1
        tl.store(out_ptr + (x1 + (x_elements * y0)), tmp2, xmask & ymask)

    def _dummy_early_config_prune(configs, *_, **__):
        return configs

    @triton.autotune(
        configs=[
            triton.Config({"BLOCK_SIZE": 128}, num_stages=3, num_warps=8),
            triton.Config({"BLOCK_SIZE": 64}, num_stages=4, num_warps=4),
        ],
        key=[],
        warmup=10,
        rep=20,
        prune_configs_by={"early_config_prune": _dummy_early_config_prune},
    )
    @triton.jit
    def add_kernel_autotuned_with_unsupported_args(
        in_ptr0,
        in_ptr1,
        out_ptr,
        n_elements,
        BLOCK_SIZE: "tl.constexpr",
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(in_ptr0 + offsets, mask=mask)
        y = tl.load(in_ptr1 + offsets, mask=mask)
        output = x + y
        tl.store(out_ptr + offsets, output, mask=mask)

    @triton.jit
    def add_kernel_with_scaling(
        in_ptr0,
        in_ptr1,
        out_ptr,
        n_elements,
        scaling_factor,
        BLOCK_SIZE: "tl.constexpr",
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(in_ptr0 + offsets, mask=mask)
        y = tl.load(in_ptr1 + offsets, mask=mask)
        output = (x + y) * scaling_factor
        tl.store(out_ptr + offsets, output, mask=mask)

    @triton.jit
    def add_kernel_with_tma_1d(
        in_desc_ptr0,
        in_desc_ptr1,
        out_desc_ptr,
        BLOCK_SIZE: "tl.constexpr",
    ):
        pid = tl.program_id(axis=0)
        offset = pid * BLOCK_SIZE
        a = tl._experimental_descriptor_load(
            in_desc_ptr0,
            [offset],
            [BLOCK_SIZE],
            tl.float32,
        )
        b = tl._experimental_descriptor_load(
            in_desc_ptr1,
            [offset],
            [BLOCK_SIZE],
            tl.float32,
        )
        output = a + b
        tl._experimental_descriptor_store(
            out_desc_ptr,
            output,
            [offset],
        )

    @triton.jit
    def add_kernel_with_tma_2d(
        in_desc_ptr0,
        in_desc_ptr1,
        out_desc_ptr,
        BLOCK_SIZE_X: "tl.constexpr",
        BLOCK_SIZE_Y: "tl.constexpr",
    ):
        pid_x = tl.program_id(axis=0)
        pid_y = tl.program_id(axis=1)
        offset_x = pid_x * BLOCK_SIZE_X
        offset_y = pid_y * BLOCK_SIZE_Y
        x = tl._experimental_descriptor_load(
            in_desc_ptr0,
            [offset_x, offset_y],
            [BLOCK_SIZE_X, BLOCK_SIZE_Y],
            tl.float32,
        )
        y = tl._experimental_descriptor_load(
            in_desc_ptr1,
            [offset_x, offset_y],
            [BLOCK_SIZE_X, BLOCK_SIZE_Y],
            tl.float32,
        )
        output = x + y
        tl._experimental_descriptor_store(
            out_desc_ptr,
            output,
            [offset_x, offset_y],
        )

    @triton.jit
    def mul2_kernel(
        in_ptr0,
        out_ptr,
        n_elements,
        BLOCK_SIZE: "tl.constexpr",
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(in_ptr0 + offsets, mask=mask)
        output = 2 * x
        tl.store(out_ptr + offsets, output, mask=mask)

    @triton.jit
    def mul2_inplace_kernel(
        ptr,
        n_elements,
        BLOCK_SIZE: "tl.constexpr",
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(ptr + offsets, mask=mask)
        output = 2 * x
        tl.store(ptr + offsets, output, mask=mask)

    @triton.jit
    def zero_negs(x):
        return tl.where(x >= 0, x, 0)

    @triton.jit
    def indirection_kernel(
        in_ptr0,
        out_ptr,
        n_elements,
        BLOCK_SIZE: "tl.constexpr",
        ACTIVATION: "tl.constexpr",
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        if ACTIVATION == "mul2_inplace_kernel":
            mul2_inplace_kernel(in_ptr0, n_elements, BLOCK_SIZE=BLOCK_SIZE)
        elif ACTIVATION == "add_kernel":
            add_kernel(in_ptr0, in_ptr0, out_ptr, n_elements, BLOCK_SIZE=BLOCK_SIZE)
        x = tl.load(in_ptr0 + offsets, mask=mask)
        tl.store(out_ptr + offsets, x, mask=mask)

    @triton.jit
    def double_strided_kernel(
        in_ptr,
        out_ptr,
        in_y_stride,
        out_y_stride,
        X_BLOCK_SIZE: "tl.constexpr",
        Y_BLOCK_SIZE: "tl.constexpr",
    ):
        xid = tl.program_id(axis=0)
        yid = tl.program_id(axis=1)
        x_start = xid * X_BLOCK_SIZE
        y_start = yid * Y_BLOCK_SIZE
        x_offsets = x_start + tl.arange(0, X_BLOCK_SIZE)
        y_offsets = y_start + tl.arange(0, Y_BLOCK_SIZE)
        src_offsets = y_offsets[:, None] * in_y_stride + x_offsets[None, :]
        dst_offsets = y_offsets[:, None] * out_y_stride + x_offsets[None, :]
        src = tl.load(in_ptr + src_offsets)
        tl.store(out_ptr + dst_offsets, src * 2.0)

    @triton.jit
    def inline_asm_kernel(X, Y, Z, n: "tl.constexpr", BLOCK: "tl.constexpr"):
        x = tl.load(X + tl.arange(0, BLOCK))
        y = tl.load(Y + tl.arange(0, BLOCK))
        s = tl.full([BLOCK], n, tl.int32)
        z = tl.inline_asm_elementwise(
            "shf.l.wrap.b32 $0, $1, $2, $3;",
            "=r,r, r, r",
            [x, y, s],
            dtype=tl.int32,
            is_pure=True,
            pack=1,
        )
        tl.store(Z + tl.arange(0, BLOCK), z)

    @triton.jit
    def add_kernel_with_block_ptr(
        x_ptr,
        y_ptr,
        output_ptr,
        n_elements,
        BLOCK_SIZE: tl.constexpr,
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        x = tl.load(
            tl.make_block_ptr(
                base=x_ptr,
                shape=[n_elements],
                strides=[1],
                offsets=[block_start],
                block_shape=[BLOCK_SIZE],
                order=[0],
            ),
            boundary_check=[0],
        )
        y = tl.load(
            tl.make_block_ptr(
                base=y_ptr,
                shape=[n_elements],
                strides=[1],
                offsets=[block_start],
                block_shape=[BLOCK_SIZE],
                order=[0],
            ),
            boundary_check=[0],
        )
        output = x + y
        tl.store(
            tl.make_block_ptr(
                base=output_ptr,
                shape=[n_elements],
                strides=[1],
                offsets=[block_start],
                block_shape=[BLOCK_SIZE],
                order=[0],
            ),
            output,
            boundary_check=[0],
        )

    @triton.jit
    def kernel_with_block_ptr_2d(
        x_ptr,
        output_ptr,
        n_elements,
        BLOCK_SIZE: tl.constexpr,
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        x = tl.load(
            tl.make_block_ptr(
                base=x_ptr,
                shape=[n_elements, 1],
                strides=[1, 1],
                offsets=[block_start, 0],
                block_shape=[BLOCK_SIZE, 1],
                order=[1, 0],
            ),
            boundary_check=[0],
        )
        output = x
        tl.store(
            tl.make_block_ptr(
                base=output_ptr,
                shape=[n_elements, 1],
                strides=[1, 1],
                offsets=[block_start, 0],
                block_shape=[BLOCK_SIZE, 1],
                order=[1, 0],
            ),
            output,
            boundary_check=[0],
        )

    from triton.language import load, store

    @triton.jit
    def add_kernel_with_import(
        in_ptr0,
        in_ptr1,
        out_ptr,
        n_elements,
        BLOCK_SIZE: "tl.constexpr",
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = load(in_ptr0 + offsets, mask=mask)
        y = load(in_ptr1 + offsets, mask=mask)
        output = x + y
        store(out_ptr + offsets, output, mask=mask)

    @triton.jit
    def cond_op_kernel(
        in_ptr0,
        in_ptr1,
        out_ptr,
        n_elements,
        BLOCK_SIZE: "tl.constexpr",
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(in_ptr0 + offsets, mask=mask)
        y = tl.load(in_ptr1 + offsets, mask=mask)
        if tl.program_id(0) == 0:
            output = x + y
        else:
            output = x * y
        tl.store(out_ptr + offsets, output, mask=mask)

    @triton.jit
    def atomic_add_kernel(
        in_ptr0,
        in_ptr1,
        out_ptr,
        n_elements,
        BLOCK_SIZE: "tl.constexpr",
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(in_ptr0 + offsets, mask=mask)
        y = tl.load(in_ptr1 + offsets, mask=mask)
        output = x + y
        tl.atomic_add(out_ptr + offsets, output, mask=mask)

    @triton.jit
    def add_4_times_kernel(
        in_ptr0,
        in_ptr1,
        out_ptr,
        n_elements,
        BLOCK_SIZE: "tl.constexpr",
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(in_ptr0 + offsets, mask=mask)
        y = tl.load(in_ptr1 + offsets, mask=mask)
        # The add and store happen four times in total: twice in a for
        # loop and twice in a while loop, exercising both loop forms.
        for i in range(2):
            output = x + y
            tl.store(out_ptr + offsets, output, mask=mask)
        i = 2
        while i > 0:
            i -= 1
            output = x + y
            tl.store(out_ptr + offsets, output, mask=mask)

    @triton.jit
    def add_kernel_out_of_order_fn2(
        in_ptr0,
        in_ptr1,
        n_elements,
        out_ptr,
        BLOCK_SIZE: "tl.constexpr",
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(in_ptr0 + offsets, mask=mask)
        y = tl.load(in_ptr1 + offsets, mask=mask)
        output = x + y
        tl.store(out_ptr + offsets, output, mask=mask)