[flex attention][triton pin] use new TMA API (#155771)

Triton 3.4 will remove the experimental TMA APIs: https://github.com/triton-lang/triton/pull/6488. Ahead of this, we are **replacing the experimental TMA API usage with the stable TMA API** in flex attention. This means that **flex attention TMA will stop working with Triton 3.2 or Triton 3.3/3.3.1** for now, but it should work with Triton 3.4 in the PyTorch 2.8 release and with Meta-internal Triton 3.3.1fb, both of which have the new TMA API.
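
For reference, a minimal sketch of what the stable device-side TMA API looks like next to the experimental one being removed. This is not the flex attention kernel: the toy `copy_rows` kernel, its arguments, and the host-side allocator setup are illustrative assumptions, not code from this PR.

```python
import torch
import triton
import triton.language as tl

# Assumption: device-side descriptor creation needs a host-registered allocator
# for descriptor scratch space (Inductor wires this up differently internally).
triton.set_allocator(
    lambda size, align, stream: torch.empty(size, dtype=torch.int8, device="cuda")
)


@triton.jit
def copy_rows(src, dst, M, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    # Stable API: the descriptor carries dtype, strides, and block shape,
    # so loads only need offsets -- and no caller-provided workspace pointer.
    desc = tl.make_tensor_descriptor(
        base=src,
        shape=[M, N],
        strides=[N, 1],
        block_shape=[BLOCK_M, BLOCK_N],
    )
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    block = tl.load_tensor_descriptor(desc, [pid_m * BLOCK_M, pid_n * BLOCK_N])
    # Old experimental pattern this PR removes (for contrast; gone in Triton 3.4):
    #   triton.language.extra.cuda.experimental_device_tensormap_create2d(
    #       desc_ptr=ws_ptr, global_address=src,
    #       load_size=[BLOCK_M, BLOCK_N], global_size=[M, N],
    #       element_ty=src.dtype.element_ty)
    #   tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(ws_ptr)
    #   block = tl._experimental_descriptor_load(
    #       ws_ptr, [pid_m * BLOCK_M, pid_n * BLOCK_N],
    #       [BLOCK_M, BLOCK_N], src.dtype.element_ty)
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    tl.store(dst + offs_m[:, None] * N + offs_n[None, :], block)


# TMA requires a Hopper-class GPU (e.g. H100).
x = torch.randn(256, 128, device="cuda", dtype=torch.float16)
y = torch.empty_like(x)
copy_rows[(256 // 64, 128 // 64)](x, y, 256, 128, BLOCK_M=64, BLOCK_N=64)
```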

This PR does the following:
* replace the experimental TMA APIs with the stable TMA APIs (`tl.make_tensor_descriptor` / `tl.load_tensor_descriptor`)
* remove the workspace args, which the stable API no longer needs (see the sketch below this list)
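
Rough sketch of the bookkeeping the experimental API forced on the caller, reconstructed from the code removed in this PR (the example `grid` value is made up; the real grid comes from `flex_attention_grid`):

```python
TMA_SIZE = 128           # bytes of global memory per TMA descriptor (from the old kernel code)
NUM_TMA_DESCRIPTORS = 3  # desc_q, desc_k, desc_v
grid = (32, 16, 1)       # example launch grid, for illustration only
num_programs = grid[0] * grid[1] * grid[2]
workspace_bytes = TMA_SIZE * NUM_TMA_DESCRIPTORS * num_programs
# This is what get_tma_workspace_arg(num_tma_descriptors=3, ...) provided, and
# each program then carved out its own slice inside the kernel:
#   workspace_base = ws_ptr + TMA_SIZE * 3 * (
#       tl.program_id(1) + tl.program_id(0) * tl.num_programs(1))
# tl.make_tensor_descriptor builds descriptors without a caller-provided
# workspace, so both the workspace arg and ws_ptr can be dropped.
print(workspace_bytes)  # 32 * 16 * 3 * 128 = 196608 bytes
```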

Testing: I ran test/inductor/test_flex_attention.py on an H100 with @mandroid6's PR #153662 patched in to turn on TMA. [TODO: confirm results once all the local tests pass; from the first 100 tests I ran locally, all of the failing tests were also failing on #153662 alone.]

Note: when #153662 lands and turns TMA support on by default, it should check specifically for stable TMA API support (commented on that PR).
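
For illustration, such a check could look roughly like this; the helper name `has_stable_tma_api` is hypothetical, not an existing PyTorch utility:

```python
def has_stable_tma_api() -> bool:
    """Return True if the installed Triton exposes the stable TMA API."""
    try:
        import triton.language as tl
    except ImportError:
        return False
    # These are the stable names used by this PR; per the description above
    # they are available in Triton 3.4 and in Meta-internal Triton 3.3.1fb.
    return hasattr(tl, "make_tensor_descriptor") and hasattr(
        tl, "load_tensor_descriptor"
    )
```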

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155771
Approved by: https://github.com/mandroid6, https://github.com/nmacchioni
Author: David Berard
Date: 2025-06-12 09:40:11 -07:00
Committed by: PyTorch MergeBot
Parent: 92b7ed6d07
Commit: c843909d9e


@@ -51,7 +51,6 @@ from ..select_algorithm import (
     SymbolicGridFn,
     TritonTemplate,
 )
-from ..utils import get_tma_workspace_arg
 
 log = logging.getLogger(__name__)
@@ -394,41 +393,26 @@ compute_flex_attention = r"""
     desc_q = None
     desc_k = None
     desc_v = None
-    if USE_TMA:
-        TMA_SIZE = 128
-        workspace_base = ws_ptr + TMA_SIZE * 3 * (
-            tl.program_id(1) + tl.program_id(0) * tl.num_programs(1)
-        )
-        desc_q = workspace_base
-        desc_v = workspace_base + TMA_SIZE
-        desc_k = workspace_base + 2 * TMA_SIZE
-        triton.language.extra.cuda.experimental_device_tensormap_create2d(
-            desc_ptr=desc_q,
-            global_address=Q,
-            load_size=[BLOCK_M, QK_HEAD_DIM_ROUNDED],
-            global_size=[Q_LEN*HQ*ZQ, QK_HEAD_DIM],
-            element_ty=Q.dtype.element_ty,
-        )
-        triton.language.extra.cuda.experimental_device_tensormap_create2d(
-            desc_ptr=desc_v,
-            global_address=V,
-            load_size=[BLOCK_N, V_HEAD_DIM_ROUNDED],
-            global_size=[KV_LEN*ZKV*HQ, V_HEAD_DIM],
-            element_ty=K.dtype.element_ty,
-        )
-        triton.language.extra.cuda.experimental_device_tensormap_create2d(
-            desc_ptr=desc_k,
-            global_address=K,
-            load_size=[BLOCK_N, QK_HEAD_DIM_ROUNDED],
-            global_size=[KV_LEN*ZKV*HQ, QK_HEAD_DIM],
-            element_ty=K.dtype.element_ty,
-        )
-        tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(desc_q)
-        tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(desc_k)
+    {%- if USE_TMA %}
+    desc_q = tl.make_tensor_descriptor(
+        base=Q,
+        shape=[Q_LEN*HQ*ZQ, QK_HEAD_DIM],
+        strides=[QK_HEAD_DIM, 1],
+        block_shape=[BLOCK_M, QK_HEAD_DIM_ROUNDED],
+    )
+
+    desc_v = tl.make_tensor_descriptor(
+        base=V,
+        shape=[KV_LEN*ZKV*HQ, V_HEAD_DIM],
+        strides=[V_HEAD_DIM, 1],
+        block_shape=[BLOCK_N, V_HEAD_DIM_ROUNDED],
+    )
+
+    desc_k = tl.make_tensor_descriptor(
+        base=K,
+        shape=[KV_LEN*ZKV*HQ, QK_HEAD_DIM],
+        strides=[QK_HEAD_DIM, 1],
+        block_shape=[BLOCK_N, QK_HEAD_DIM_ROUNDED],
+    )
+    {%- endif %}
 
     # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq.
@@ -483,15 +467,14 @@ compute_flex_attention = r"""
         order=(1, 0)
     )
 
-    if USE_TMA:
-        q = tl._experimental_descriptor_load(  # load in row major
-            desc_q,
-            [(q_start * BLOCK_M).to(tl.int32), 0],
-            [BLOCK_M, QK_HEAD_DIM_ROUNDED],
-            Q.dtype.element_ty,
-        )
-    else:
-        q = load_checked_block(Q_block_ptr, IS_DIVISIBLE, SAFE_HEAD_DIM)
+    {%- if USE_TMA %}
+    q = tl.load_tensor_descriptor(
+        desc_q,
+        [(q_start * BLOCK_M).to(tl.int32), 0],
+    )
+    {%- else %}
+    q = load_checked_block(Q_block_ptr, IS_DIVISIBLE, SAFE_HEAD_DIM)
+    {%- endif %}
 
     # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
     # We don't know anything "special" about these blocks, so we need to apply
@@ -709,15 +692,14 @@ def forward_block_mn(
     # -- load k --
     # NB reversed order to since K is transposed
-    if USE_TMA:
-        k = tl._experimental_descriptor_load(  # load in row major
-            desc_k,
-            [start_n.to(tl.int32) , kv_start],
-            [BLOCK_N, QK_HEAD_DIM_ROUNDED],
-            MATMUL_PRECISION,
-        )
-    else:
-        k = load_checked_block(K_block_ptr, SAFE_HEAD_DIM, IS_DIVISIBLE)
+    {%- if USE_TMA %}
+    k = tl.load_tensor_descriptor(  # load in row major
+        desc_k,
+        [start_n.to(tl.int32) , kv_start],
+    )
+    {%- else %}
+    k = load_checked_block(K_block_ptr, SAFE_HEAD_DIM, IS_DIVISIBLE)
+    {%- endif %}
 
     if USE_TMA:
         k = tl.trans(k)
@@ -784,15 +766,14 @@ def forward_block_mn(
     l_i = l_i * alpha + tl.sum(p, 1)
     # # -- scale and update acc --
     acc = acc * alpha[:, None]
-    if USE_TMA:
-        v = tl._experimental_descriptor_load(  # load in row major
-            desc_v,
-            [kv_start.to(tl.int32) + start_n.to(tl.int32),0],
-            [BLOCK_N, V_HEAD_DIM_ROUNDED],
-            MATMUL_PRECISION,
-        )
-    else:
-        v = load_checked_block(V_block_ptr, IS_DIVISIBLE, SAFE_HEAD_DIM)
+    {%- if USE_TMA %}
+    v = tl.load_tensor_descriptor(
+        desc_v,
+        [kv_start.to(tl.int32) + start_n.to(tl.int32),0],
+    )
+    {%- else %}
+    v = load_checked_block(V_block_ptr, IS_DIVISIBLE, SAFE_HEAD_DIM)
+    {%- endif %}
 
     acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION)
     # -- update m_i
@@ -1653,20 +1634,6 @@ def flex_attention(
         cur_kernel_options.setdefault("SPARSE_Q_BLOCK_SIZE", SPARSE_Q_BLOCK_SIZE)
         cur_kernel_options.setdefault("SPARSE_KV_BLOCK_SIZE", SPARSE_KV_BLOCK_SIZE)
 
-        workspace_arg = None
-        if cur_kernel_options.get("USE_TMA", False):
-            seq_len_q = V.graph.sizevars.evaluate_static_shape(seq_len_q)
-            grid = flex_attention_grid(
-                Bq, Hq, seq_len_q, qk_head_dim, cur_kernel_options
-            )
-
-            num_programs = grid[0] * grid[1] * grid[2]
-
-            workspace_arg = get_tma_workspace_arg(
-                num_tma_descriptors=3,
-                device=query.get_device(),
-                num_programs=num_programs,
-            )
-
         error = flex_attention_template.maybe_append_choice(
             choices=choices,
             input_nodes=[
@@ -1687,7 +1654,6 @@ def flex_attention(
             mutated_inputs=[
                 logsumexp,
             ],
-            workspace_arg=workspace_arg,
             call_sizes=query.get_size(),
             **cur_kernel_options,
         )