Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
[flex attention][triton pin] use new TMA API (#155771)
Triton 3.4 will remove the experimental TMA APIs: https://github.com/triton-lang/triton/pull/6488. Ahead of this, we are **replacing the experimental TMA API usage with the stable TMA API** in flex attention. This means that **flex attention TMA will stop working with Triton 3.2 or Triton 3.3/3.3.1** for now (but it should work with Triton 3.4 in the PyTorch 2.8 release and with Meta-internal Triton 3.3.1fb, both of which have the new TMA API).

This PR does the following:
* replace the experimental TMA APIs with the stable TMA APIs
* remove the workspace args

Testing: I ran test/inductor/test_flex_attention.py on an H100 with @mandroid6's PR #153662 patched in to turn on TMA. [TODO: confirm results once all the local tests pass, but of the first 100 tests I ran locally, all the failing tests were also failing on #153662 alone.]

Note: when #153662 lands, turning on TMA support by default, it should check specifically for stable TMA API support (commented on the PR).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155771
Approved by: https://github.com/mandroid6, https://github.com/nmacchioni
committed by PyTorch MergeBot
parent 92b7ed6d07
commit c843909d9e
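For context, here is a minimal sketch of the stable on-device TMA API that the diff below switches flex attention to. This is not PyTorch's kernel: the copy kernel, tensor shapes, and block sizes are illustrative only, and it assumes Triton >= 3.4 (`tl.make_tensor_descriptor` / `tl.load_tensor_descriptor`) on TMA-capable hardware such as an H100. With the stable API, on-device descriptor creation draws scratch space from a host-side allocator (`triton.set_allocator`), which is why the explicit TMA workspace argument is removed in this PR.

```python
import torch
import triton
import triton.language as tl

# On-device descriptor creation allocates its own scratch space through this
# host-side allocator; with the stable API there is no explicit TMA workspace
# pointer to thread through the kernel.
triton.set_allocator(
    lambda size, alignment, stream: torch.empty(size, dtype=torch.int8, device="cuda")
)


@triton.jit
def tma_tile_copy(src, dst, M, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    # Stable API: the descriptor carries shape/strides/block shape, so the load
    # no longer passes a block shape, element type, or workspace pointer
    # (compare the removed experimental_device_tensormap_create2d calls below).
    desc_src = tl.make_tensor_descriptor(
        base=src, shape=[M, N], strides=[N, 1], block_shape=[BLOCK_M, BLOCK_N]
    )
    tile = tl.load_tensor_descriptor(desc_src, [pid_m * BLOCK_M, pid_n * BLOCK_N])

    # Plain (non-TMA) store for the result, mirroring how the flex attention
    # kernel only *loads* q/k/v through descriptors.
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    tl.store(dst + offs_m[:, None] * N + offs_n[None, :], tile)


x = torch.randn(256, 256, device="cuda", dtype=torch.float16)
y = torch.empty_like(x)
grid = (triton.cdiv(256, 128), triton.cdiv(256, 128))
tma_tile_copy[grid](x, y, 256, 256, BLOCK_M=128, BLOCK_N=128)
torch.testing.assert_close(x, y)
```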
@@ -51,7 +51,6 @@ from ..select_algorithm import (
     SymbolicGridFn,
     TritonTemplate,
 )
-from ..utils import get_tma_workspace_arg


 log = logging.getLogger(__name__)
@@ -394,41 +393,26 @@ compute_flex_attention = r"""
     desc_q = None
     desc_k = None
     desc_v = None
-    if USE_TMA:
-        TMA_SIZE = 128
-        workspace_base = ws_ptr + TMA_SIZE * 3 * (
-            tl.program_id(1) + tl.program_id(0) * tl.num_programs(1)
-        )
-        desc_q = workspace_base
-        desc_v = workspace_base + TMA_SIZE
-        desc_k = workspace_base + 2 * TMA_SIZE
-
-        triton.language.extra.cuda.experimental_device_tensormap_create2d(
-            desc_ptr=desc_q,
-            global_address=Q,
-            load_size=[BLOCK_M, QK_HEAD_DIM_ROUNDED],
-            global_size=[Q_LEN*HQ*ZQ, QK_HEAD_DIM],
-            element_ty=Q.dtype.element_ty,
-        )
-        triton.language.extra.cuda.experimental_device_tensormap_create2d(
-            desc_ptr=desc_v,
-            global_address=V,
-            load_size=[BLOCK_N, V_HEAD_DIM_ROUNDED],
-            global_size=[KV_LEN*ZKV*HQ, V_HEAD_DIM],
-            element_ty=K.dtype.element_ty,
-        )
-        triton.language.extra.cuda.experimental_device_tensormap_create2d(
-            desc_ptr=desc_k,
-            global_address=K,
-            load_size=[BLOCK_N, QK_HEAD_DIM_ROUNDED],
-            global_size=[KV_LEN*ZKV*HQ, QK_HEAD_DIM],
-            element_ty=K.dtype.element_ty,
-        )
-
-        tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(desc_q)
-        tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(desc_k)
+    {%- if USE_TMA %}
+    desc_q = tl.make_tensor_descriptor(
+        base=Q,
+        shape=[Q_LEN*HQ*ZQ, QK_HEAD_DIM],
+        strides=[QK_HEAD_DIM, 1],
+        block_shape=[BLOCK_M, QK_HEAD_DIM_ROUNDED],
+    )
+    desc_v = tl.make_tensor_descriptor(
+        base=V,
+        shape=[KV_LEN*ZKV*HQ, V_HEAD_DIM],
+        strides=[V_HEAD_DIM, 1],
+        block_shape=[BLOCK_N, V_HEAD_DIM_ROUNDED],
+    )
+    desc_k = tl.make_tensor_descriptor(
+        base=V,
+        shape=[KV_LEN*ZKV*HQ, V_HEAD_DIM],
+        strides=[V_HEAD_DIM, 1],
+        block_shape=[BLOCK_N, V_HEAD_DIM_ROUNDED],
+    )
+    {%- endif %}

     # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq.
@@ -483,15 +467,14 @@ compute_flex_attention = r"""
         order=(1, 0)
     )

-    if USE_TMA:
-        q = tl._experimental_descriptor_load(  # load in row major
-            desc_q,
-            [(q_start * BLOCK_M).to(tl.int32), 0],
-            [BLOCK_M, QK_HEAD_DIM_ROUNDED],
-            Q.dtype.element_ty,
-        )
-    else:
+    {%- if USE_TMA %}
+    q = tl.load_tensor_descriptor(
+        desc_q,
+        [(q_start * BLOCK_M).to(tl.int32), 0],
+    )
+    {%- else %}
     q = load_checked_block(Q_block_ptr, IS_DIVISIBLE, SAFE_HEAD_DIM)
+    {%- endif %}

     # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
     # We don't know anything "special" about these blocks, so we need to apply
@@ -709,15 +692,14 @@ def forward_block_mn(

     # -- load k --
     # NB reversed order since K is transposed
-    if USE_TMA:
-        k = tl._experimental_descriptor_load(  # load in row major
-            desc_k,
-            [start_n.to(tl.int32), kv_start],
-            [BLOCK_N, QK_HEAD_DIM_ROUNDED],
-            MATMUL_PRECISION,
-        )
-    else:
+    {%- if USE_TMA %}
+    k = tl.load_tensor_descriptor(  # load in row major
+        desc_k,
+        [start_n.to(tl.int32), kv_start],
+    )
+    {%- else %}
     k = load_checked_block(K_block_ptr, SAFE_HEAD_DIM, IS_DIVISIBLE)
+    {%- endif %}

     if USE_TMA:
         k = tl.trans(k)
@@ -784,15 +766,14 @@ def forward_block_mn(
     l_i = l_i * alpha + tl.sum(p, 1)
     # # -- scale and update acc --
     acc = acc * alpha[:, None]
-    if USE_TMA:
-        v = tl._experimental_descriptor_load(  # load in row major
-            desc_v,
-            [kv_start.to(tl.int32) + start_n.to(tl.int32), 0],
-            [BLOCK_N, V_HEAD_DIM_ROUNDED],
-            MATMUL_PRECISION,
-        )
-    else:
+    {%- if USE_TMA %}
+    v = tl.load_tensor_descriptor(
+        desc_v,
+        [kv_start.to(tl.int32) + start_n.to(tl.int32), 0],
+    )
+    {%- else %}
     v = load_checked_block(V_block_ptr, IS_DIVISIBLE, SAFE_HEAD_DIM)
+    {%- endif %}
     acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION)

     # -- update m_i
@@ -1653,20 +1634,6 @@ def flex_attention(
         cur_kernel_options.setdefault("SPARSE_Q_BLOCK_SIZE", SPARSE_Q_BLOCK_SIZE)
         cur_kernel_options.setdefault("SPARSE_KV_BLOCK_SIZE", SPARSE_KV_BLOCK_SIZE)

-        workspace_arg = None
-        if cur_kernel_options.get("USE_TMA", False):
-            seq_len_q = V.graph.sizevars.evaluate_static_shape(seq_len_q)
-
-            grid = flex_attention_grid(
-                Bq, Hq, seq_len_q, qk_head_dim, cur_kernel_options
-            )
-
-            num_programs = grid[0] * grid[1] * grid[2]
-            workspace_arg = get_tma_workspace_arg(
-                num_tma_descriptors=3,
-                device=query.get_device(),
-                num_programs=num_programs,
-            )
         error = flex_attention_template.maybe_append_choice(
             choices=choices,
             input_nodes=[
@@ -1687,7 +1654,6 @@ def flex_attention(
             mutated_inputs=[
                 logsumexp,
             ],
-            workspace_arg=workspace_arg,
             call_sizes=query.get_size(),
             **cur_kernel_options,
         )