mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-06 17:24:59 +08:00
This adds Dynamo tracing support for the host-side Triton TMA API (see `create_2d_tma_descriptor` calls on the host in the [Triton tutorial](https://triton-lang.org/main/getting-started/tutorials/09-persistent-matmul.html#sphx-glr-getting-started-tutorials-09-persistent-matmul-py)). A few notes: - Here we assume the availability of the host-side TMA API added to upstream Triton in https://github.com/triton-lang/triton/pull/4498. As of the time of writing, this is not a part of the PT2 OSS Triton pin (although it has been back-ported internally). The OSS Triton pin update should be done in December 2024. - To capture the chain of calls `t.data_ptr() --> create_{1d,2d}_tma_descriptor(ptr, ...) --> kernel[grid](tma_desc, ...)`, we add three new variable trackers: `DataPtrVariable`, `CreateTMADescriptorVariable` (for the function), and `TMADescriptorVariable` (for the TMA descriptor object). This is to maintain the path back from the Triton kernel to the Tensor from which the TMA descriptor has been created. - The newly introduced variables have `reconstruct` methods used in case of graph breaks. - The `tma_descriptor_metadata` extracted from the captured `create_{1d,2d}_tma_descriptor` calls is propagated through the HOPs in Dynamo and AOTAutograd to be used by the downstream compiler (e.g., Inductor). See the unit tests for what the captured HOP arguments look like. - In the Dynamo-captured fx graph, we replace the TMA descriptor arguments of the Triton kernel by the underlying Tensors, to be able to track the input/output relationships in terms of Tensors. - In the Triton kernel mutation analysis pass (in AOTAutograd), we use the `tt.experimental_descriptor_store` TTIR op to detect mutations of the underlying tensors via TMA descriptors, so that downstream AOTAutograd can perform functionalizations as required. - JIT Inductor and AOT Inductor support will be implemented in follow-up PRs. 
Differential Revision: [D64404928](https://our.internmc.facebook.com/intern/diff/D64404928) Pull Request resolved: https://github.com/pytorch/pytorch/pull/137677 Approved by: https://github.com/zou3519
521 lines
15 KiB
Python
521 lines
15 KiB
Python
# mypy: ignore-errors
|
|
|
|
import unittest
|
|
|
|
from torch.testing._internal.inductor_utils import HAS_CUDA, HAS_GPU
|
|
from torch.utils._triton import has_triton
|
|
|
|
|
|
# Test decorator: skips the decorated test unless HAS_CUDA is truthy.
requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda")
# Test decorator: skips the decorated test unless HAS_GPU is truthy.
requires_gpu = unittest.skipUnless(HAS_GPU, "requires gpu")
|
|
|
|
if has_triton():
|
|
import triton
|
|
from triton import language as tl
|
|
|
|
# Define here so that multiple tests can take advantage of it
|
|
@triton.jit
def add_kernel(
    in_ptr0,
    in_ptr1,
    out_ptr,
    n_elements,
    BLOCK_SIZE: "tl.constexpr",
):
    # Element-wise add: out_ptr[i] = in_ptr0[i] + in_ptr1[i] for i < n_elements.
    # Each program instance covers one BLOCK_SIZE-wide slice chosen by its
    # program id along axis 0.
    idx = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # Mask off lanes that fall past the end of the data.
    in_bounds = idx < n_elements
    lhs = tl.load(in_ptr0 + idx, mask=in_bounds)
    rhs = tl.load(in_ptr1 + idx, mask=in_bounds)
    tl.store(out_ptr + idx, lhs + rhs, mask=in_bounds)
|
|
|
|
@triton.jit
def add_kernel_with_optional_param(
    in_ptr0,
    in_ptr1,
    out_ptr,
    n_elements,
    ARGS_PASSED: "tl.constexpr",
    BLOCK_SIZE: "tl.constexpr",
):
    # Add or copy, selected by the constexpr string ARGS_PASSED:
    #   "two"          -> out_ptr[i] = in_ptr0[i] + in_ptr1[i]
    #   any other value -> out_ptr[i] = in_ptr0[i]  (in_ptr1 is never read)
    idx = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements
    first = tl.load(in_ptr0 + idx, mask=in_bounds)
    if ARGS_PASSED == "two":
        second = tl.load(in_ptr1 + idx, mask=in_bounds)
        result = first + second
    else:
        result = first
    tl.store(out_ptr + idx, result, mask=in_bounds)
|
|
|
|
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 128}, num_stages=3, num_warps=8),
        triton.Config({"BLOCK_SIZE": 128}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_SIZE": 64}, num_stages=3, num_warps=8),
        triton.Config({"BLOCK_SIZE": 64}, num_stages=4, num_warps=4),
    ],
    key=[],
)
@triton.jit
def add_kernel_autotuned(
    in_ptr0,
    in_ptr1,
    out_ptr,
    n_elements,
    BLOCK_SIZE: "tl.constexpr",
):
    # Autotuned variant of the element-wise add kernel: every candidate config
    # above provides a value for BLOCK_SIZE, and key=[] names no arguments as
    # autotune keys.
    idx = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # Mask off lanes that fall past the end of the data.
    in_bounds = idx < n_elements
    lhs = tl.load(in_ptr0 + idx, mask=in_bounds)
    rhs = tl.load(in_ptr1 + idx, mask=in_bounds)
    tl.store(out_ptr + idx, lhs + rhs, mask=in_bounds)
|
|
|
|
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 16}, num_stages=2, num_warps=2),
    ],
    key=[],
)
@triton.jit
def add_kernel_autotuned_weird_param_order(
    in_ptr0,
    in_ptr1,
    n_elements,
    BLOCK_SIZE: "tl.constexpr",
    out_ptr,
):
    # out_ptr is after an autotuned param that's declared as tl.constexpr.
    # This param ordering can create bugs if not handled correctly.
    #
    # The computation itself is the plain masked element-wise add:
    # out_ptr[i] = in_ptr0[i] + in_ptr1[i] for i < n_elements.
    idx = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements
    lhs = tl.load(in_ptr0 + idx, mask=in_bounds)
    rhs = tl.load(in_ptr1 + idx, mask=in_bounds)
    tl.store(out_ptr + idx, lhs + rhs, mask=in_bounds)
|
|
|
|
@triton.autotune(
    configs=[
        triton.Config(
            {"BLOCK_SIZE_X": 128, "BLOCK_SIZE_Y": 128}, num_stages=3, num_warps=8
        ),
        triton.Config(
            {"BLOCK_SIZE_X": 128, "BLOCK_SIZE_Y": 128}, num_stages=4, num_warps=4
        ),
        triton.Config(
            {"BLOCK_SIZE_X": 64, "BLOCK_SIZE_Y": 64}, num_stages=3, num_warps=8
        ),
        triton.Config(
            {"BLOCK_SIZE_X": 64, "BLOCK_SIZE_Y": 64}, num_stages=4, num_warps=4
        ),
    ],
    key=[],
)
@triton.jit
def add_kernel_2d_autotuned(
    in_ptr0,
    in_ptr1,
    out_ptr,
    x_elements,
    y_elements,
    BLOCK_SIZE_X: "tl.constexpr",
    BLOCK_SIZE_Y: "tl.constexpr",
):
    # Autotuned kernel tiled over two program-id axes: axis 0 walks x in steps
    # of BLOCK_SIZE_X, axis 1 walks y in steps of BLOCK_SIZE_Y. Both block
    # sizes come from the autotune configs above.
    #
    # NOTE(review): both loads below read from in_ptr0; in_ptr1 is accepted
    # but never read. Confirm with the tests that use this kernel whether
    # that is intentional before changing it.

    # x indices as a column ([:, None]) so they broadcast against the y row.
    xoffset = tl.program_id(0) * BLOCK_SIZE_X
    xindex = xoffset + tl.arange(0, BLOCK_SIZE_X)[:, None]
    xmask = xindex < x_elements
    # y indices as a row ([None, :]).
    yoffset = tl.program_id(1) * BLOCK_SIZE_Y
    yindex = yoffset + tl.arange(0, BLOCK_SIZE_Y)[None, :]
    ymask = yindex < y_elements
    x1 = xindex
    y0 = yindex
    # First operand: linear offset x + x_elements * y.
    tmp0 = tl.load(in_ptr0 + (x1 + (x_elements * y0)), xmask & ymask)
    # Second operand: linear offset y + y_elements * x (the other index order).
    tmp1 = tl.load(in_ptr0 + (y0 + (y_elements * x1)), xmask & ymask)
    tmp2 = tmp0 + tmp1
    # The result is written with the same offset formula as the first load.
    tl.store(out_ptr + (x1 + (x_elements * y0)), tmp2, xmask & ymask)
|
|
|
|
def _dummy_early_config_prune(configs, *_, **__):
|
|
return configs
|
|
|
|
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 128}, num_stages=3, num_warps=8),
        triton.Config({"BLOCK_SIZE": 64}, num_stages=4, num_warps=4),
    ],
    key=[],
    warmup=10,
    rep=20,
    prune_configs_by={"early_config_prune": _dummy_early_config_prune},
)
@triton.jit
def add_kernel_autotuned_with_unsupported_args(
    in_ptr0,
    in_ptr1,
    out_ptr,
    n_elements,
    BLOCK_SIZE: "tl.constexpr",
):
    # Element-wise add whose autotune decorator additionally passes `warmup`,
    # `rep` and `prune_configs_by`.
    # NOTE(review): going by the function name, those extra autotune
    # arguments are presumably the "unsupported" ones -- confirm against the
    # tests that use this kernel.
    idx = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements
    lhs = tl.load(in_ptr0 + idx, mask=in_bounds)
    rhs = tl.load(in_ptr1 + idx, mask=in_bounds)
    tl.store(out_ptr + idx, lhs + rhs, mask=in_bounds)
|
|
|
|
@triton.jit
def add_kernel_with_scaling(
    in_ptr0,
    in_ptr1,
    out_ptr,
    n_elements,
    scaling_factor,
    BLOCK_SIZE: "tl.constexpr",
):
    # Scaled element-wise add:
    # out_ptr[i] = (in_ptr0[i] + in_ptr1[i]) * scaling_factor for i < n_elements.
    idx = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements
    lhs = tl.load(in_ptr0 + idx, mask=in_bounds)
    rhs = tl.load(in_ptr1 + idx, mask=in_bounds)
    scaled = (lhs + rhs) * scaling_factor
    tl.store(out_ptr + idx, scaled, mask=in_bounds)
|
|
|
|
@triton.jit
def add_kernel_with_tma_1d(
    in_desc_ptr0,
    in_desc_ptr1,
    out_desc_ptr,
    BLOCK_SIZE: "tl.constexpr",
):
    # 1D add that accesses memory through descriptor arguments, using the
    # experimental tl._experimental_descriptor_{load,store} calls rather than
    # pointer arithmetic with masks. No n_elements argument or mask is used.
    pid = tl.program_id(axis=0)
    # Start of the block handled by this program instance.
    offset = pid * BLOCK_SIZE

    # Arguments passed: descriptor, [offset], [BLOCK_SIZE], tl.float32
    # (presumably start offsets, block shape and element dtype -- confirm
    # against the Triton version in use).
    a = tl._experimental_descriptor_load(
        in_desc_ptr0,
        [offset],
        [BLOCK_SIZE],
        tl.float32,
    )
    b = tl._experimental_descriptor_load(
        in_desc_ptr1,
        [offset],
        [BLOCK_SIZE],
        tl.float32,
    )

    output = a + b

    # Write the sum through the output descriptor at the same offset.
    tl._experimental_descriptor_store(
        out_desc_ptr,
        output,
        [offset],
    )
|
|
|
|
@triton.jit
def add_kernel_with_tma_2d(
    in_desc_ptr0,
    in_desc_ptr1,
    out_desc_ptr,
    BLOCK_SIZE_X: "tl.constexpr",
    BLOCK_SIZE_Y: "tl.constexpr",
):
    # 2D add that accesses memory through descriptor arguments, using the
    # experimental tl._experimental_descriptor_{load,store} calls. Tiles are
    # selected by program ids along axes 0 (x) and 1 (y); no masks are used.
    pid_x = tl.program_id(axis=0)
    pid_y = tl.program_id(axis=1)
    offset_x = pid_x * BLOCK_SIZE_X
    offset_y = pid_y * BLOCK_SIZE_Y

    # Arguments passed: descriptor, [offset_x, offset_y],
    # [BLOCK_SIZE_X, BLOCK_SIZE_Y], tl.float32 (presumably tile offsets, tile
    # shape and element dtype -- confirm against the Triton version in use).
    x = tl._experimental_descriptor_load(
        in_desc_ptr0,
        [offset_x, offset_y],
        [BLOCK_SIZE_X, BLOCK_SIZE_Y],
        tl.float32,
    )
    y = tl._experimental_descriptor_load(
        in_desc_ptr1,
        [offset_x, offset_y],
        [BLOCK_SIZE_X, BLOCK_SIZE_Y],
        tl.float32,
    )

    output = x + y

    # Write the sum through the output descriptor at the same tile offsets.
    tl._experimental_descriptor_store(
        out_desc_ptr,
        output,
        [offset_x, offset_y],
    )
|
|
|
|
@triton.jit
def mul2_kernel(
    in_ptr0,
    out_ptr,
    n_elements,
    BLOCK_SIZE: "tl.constexpr",
):
    # Out-of-place doubling: out_ptr[i] = 2 * in_ptr0[i] for i < n_elements.
    idx = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements
    val = tl.load(in_ptr0 + idx, mask=in_bounds)
    doubled = 2 * val
    tl.store(out_ptr + idx, doubled, mask=in_bounds)
|
|
|
|
@triton.jit
def mul2_inplace_kernel(
    ptr,
    n_elements,
    BLOCK_SIZE: "tl.constexpr",
):
    # In-place doubling: ptr[i] = 2 * ptr[i] for i < n_elements. The same
    # buffer is both read and written.
    idx = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements
    val = tl.load(ptr + idx, mask=in_bounds)
    doubled = 2 * val
    tl.store(ptr + idx, doubled, mask=in_bounds)
|
|
|
|
@triton.jit
def zero_negs(x):
    # Keep elements where x >= 0 and substitute 0 everywhere else.
    non_negative = x >= 0
    return tl.where(non_negative, x, 0)
|
|
|
|
@triton.jit
def indirection_kernel(
    in_ptr0,
    out_ptr,
    n_elements,
    BLOCK_SIZE: "tl.constexpr",
    ACTIVATION: "tl.constexpr",
):
    # Kernel that calls another @triton.jit function chosen by the constexpr
    # string ACTIVATION, then copies in_ptr0 to out_ptr.
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    if ACTIVATION == "mul2_inplace_kernel":
        # Nested call that reads and writes in_ptr0.
        mul2_inplace_kernel(in_ptr0, n_elements, BLOCK_SIZE=BLOCK_SIZE)
    elif ACTIVATION == "add_kernel":
        # Nested call with in_ptr0 passed as both inputs, writing to out_ptr.
        add_kernel(in_ptr0, in_ptr0, out_ptr, n_elements, BLOCK_SIZE=BLOCK_SIZE)
    # Any other ACTIVATION value skips both nested calls. The copy below runs
    # in every case.
    x = tl.load(in_ptr0 + offsets, mask=mask)
    tl.store(out_ptr + offsets, x, mask=mask)
|
|
|
|
@triton.jit
def double_strided_kernel(
    in_ptr,
    out_ptr,
    in_y_stride,
    out_y_stride,
    X_BLOCK_SIZE: "tl.constexpr",
    Y_BLOCK_SIZE: "tl.constexpr",
):
    # Doubles a 2D tile: reads a (Y_BLOCK_SIZE, X_BLOCK_SIZE) tile from in_ptr
    # using row stride in_y_stride, multiplies it by 2.0 and writes it to
    # out_ptr using row stride out_y_stride. The x step is 1 on both sides.
    # No bounds mask is applied to the load or the store.
    cols = tl.program_id(axis=0) * X_BLOCK_SIZE + tl.arange(0, X_BLOCK_SIZE)
    rows = tl.program_id(axis=1) * Y_BLOCK_SIZE + tl.arange(0, Y_BLOCK_SIZE)
    # Rows become a column vector and cols a row vector so they broadcast to
    # a full tile of linear offsets.
    read_at = rows[:, None] * in_y_stride + cols[None, :]
    write_at = rows[:, None] * out_y_stride + cols[None, :]
    tile = tl.load(in_ptr + read_at)
    tl.store(out_ptr + write_at, tile * 2.0)
|
|
|
|
@triton.jit
def inline_asm_kernel(X, Y, Z, n: "tl.constexpr", BLOCK: "tl.constexpr"):
    # Applies an inline-assembly instruction element-wise to BLOCK elements of
    # X and Y and stores the result to Z. No masks or program ids are used:
    # the kernel always touches offsets 0..BLOCK-1.
    x = tl.load(X + tl.arange(0, BLOCK))
    y = tl.load(Y + tl.arange(0, BLOCK))
    # Third asm operand: a BLOCK-long int32 tensor filled with the constexpr n.
    s = tl.full([BLOCK], n, tl.int32)
    # Constraint string: one output register ("=r") followed by three input
    # registers, matching $0 and $1..$3 in the instruction text.
    # NOTE(review): `shf.l.wrap.b32` is taken to be the PTX funnel-shift-left
    # instruction (so NVIDIA-only) -- confirm against the PTX ISA.
    z = tl.inline_asm_elementwise(
        "shf.l.wrap.b32 $0, $1, $2, $3;",
        "=r,r, r, r",
        [x, y, s],
        dtype=tl.int32,
        is_pure=True,
        pack=1,
    )
    tl.store(Z + tl.arange(0, BLOCK), z)
|
|
|
|
@triton.jit
def add_kernel_with_block_ptr(
    x_ptr,
    y_ptr,
    output_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    # Element-wise add written with tl.make_block_ptr rather than explicit
    # offset arithmetic and masks. Each operand is described as a 1D region of
    # shape [n_elements] with stride 1, and this program instance addresses
    # the BLOCK_SIZE-wide block starting at block_start.
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    # boundary_check=[0] names dimension 0 for bounds checking, in place of an
    # explicit mask.
    x = tl.load(
        tl.make_block_ptr(
            base=x_ptr,
            shape=[n_elements],
            strides=[1],
            offsets=[block_start],
            block_shape=[BLOCK_SIZE],
            order=[0],
        ),
        boundary_check=[0],
    )
    y = tl.load(
        tl.make_block_ptr(
            base=y_ptr,
            shape=[n_elements],
            strides=[1],
            offsets=[block_start],
            block_shape=[BLOCK_SIZE],
            order=[0],
        ),
        boundary_check=[0],
    )
    output = x + y
    # The store uses a block pointer built the same way, over output_ptr.
    tl.store(
        tl.make_block_ptr(
            base=output_ptr,
            shape=[n_elements],
            strides=[1],
            offsets=[block_start],
            block_shape=[BLOCK_SIZE],
            order=[0],
        ),
        output,
        boundary_check=[0],
    )
|
|
|
|
@triton.jit
def kernel_with_block_ptr_2d(
    x_ptr,
    output_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    # Copies x_ptr to output_ptr through 2D block pointers. The data is
    # described as shape [n_elements, 1] with strides [1, 1], and each program
    # instance addresses a [BLOCK_SIZE, 1] block starting at row block_start.
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    # Only dimension 0 is listed in boundary_check; dimension 1 has extent 1.
    x = tl.load(
        tl.make_block_ptr(
            base=x_ptr,
            shape=[n_elements, 1],
            strides=[1, 1],
            offsets=[block_start, 0],
            block_shape=[BLOCK_SIZE, 1],
            order=[1, 0],
        ),
        boundary_check=[0],
    )
    # Plain copy: the loaded block is stored unchanged.
    output = x
    tl.store(
        tl.make_block_ptr(
            base=output_ptr,
            shape=[n_elements, 1],
            strides=[1, 1],
            offsets=[block_start, 0],
            block_shape=[BLOCK_SIZE, 1],
            order=[1, 0],
        ),
        output,
        boundary_check=[0],
    )
|
|
|
|
from triton.language import load, store
|
|
|
|
@triton.jit
def add_kernel_with_import(
    in_ptr0,
    in_ptr1,
    out_ptr,
    n_elements,
    BLOCK_SIZE: "tl.constexpr",
):
    # Element-wise add that calls `load` / `store` through the bare names
    # imported via `from triton.language import load, store`, rather than as
    # `tl.load` / `tl.store`.
    idx = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements
    lhs = load(in_ptr0 + idx, mask=in_bounds)
    rhs = load(in_ptr1 + idx, mask=in_bounds)
    store(out_ptr + idx, lhs + rhs, mask=in_bounds)
|
|
|
|
@triton.jit
def cond_op_kernel(
    in_ptr0,
    in_ptr1,
    out_ptr,
    n_elements,
    BLOCK_SIZE: "tl.constexpr",
):
    # Element-wise kernel whose operation depends on the program id: the
    # program with id 0 writes x + y, every other program writes x * y.
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(in_ptr0 + offsets, mask=mask)
    y = tl.load(in_ptr1 + offsets, mask=mask)
    # This condition tests tl.program_id(0), not a constexpr argument.
    if tl.program_id(0) == 0:
        output = x + y
    else:
        output = x * y
    tl.store(out_ptr + offsets, output, mask=mask)
|
|
|
|
@triton.jit
def atomic_add_kernel(
    in_ptr0,
    in_ptr1,
    out_ptr,
    n_elements,
    BLOCK_SIZE: "tl.constexpr",
):
    # Accumulating add: in_ptr0[i] + in_ptr1[i] is added into out_ptr[i] with
    # tl.atomic_add rather than overwriting it, for i < n_elements.
    idx = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements
    lhs = tl.load(in_ptr0 + idx, mask=in_bounds)
    rhs = tl.load(in_ptr1 + idx, mask=in_bounds)
    total = lhs + rhs
    tl.atomic_add(out_ptr + idx, total, mask=in_bounds)
|
|
|
|
@triton.jit
def add_4_times_kernel(
    in_ptr0,
    in_ptr1,
    out_ptr,
    n_elements,
    BLOCK_SIZE: "tl.constexpr",
):
    # Computes x + y and stores it to out_ptr four times in total: twice in a
    # `for` loop and twice in a `while` loop. Every store writes the same
    # value to the same locations.
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(in_ptr0 + offsets, mask=mask)
    y = tl.load(in_ptr1 + offsets, mask=mask)
    # First two stores: counted `for` loop.
    for i in range(2):
        output = x + y
        tl.store(out_ptr + offsets, output, mask=mask)
    # Last two stores: `while` loop counting i down from 2 to 0.
    i = 2
    while i > 0:
        i -= 1
        output = x + y
        tl.store(out_ptr + offsets, output, mask=mask)
|
|
|
|
@triton.jit
def add_kernel_out_of_order_fn2(
    in_ptr0,
    in_ptr1,
    n_elements,
    out_ptr,
    BLOCK_SIZE: "tl.constexpr",
):
    # Masked element-wise add, out_ptr[i] = in_ptr0[i] + in_ptr1[i] for
    # i < n_elements. Unlike add_kernel, n_elements is declared before
    # out_ptr in the parameter list.
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Mask off lanes that fall past the end of the data.
    mask = offsets < n_elements
    x = tl.load(in_ptr0 + offsets, mask=mask)
    y = tl.load(in_ptr1 + offsets, mask=mask)
    output = x + y
    tl.store(out_ptr + offsets, output, mask=mask)
|