From c843909d9e32f92b2e31cf9b8f066daf311a6f18 Mon Sep 17 00:00:00 2001
From: David Berard
Date: Thu, 12 Jun 2025 09:40:11 -0700
Subject: [PATCH] [flex attention][triton pin] use new TMA API (#155771)

Triton 3.4 will remove the experimental TMA APIs: https://github.com/triton-lang/triton/pull/6488.

Ahead of this, we are **replacing the experimental TMA API usage with the stable TMA API** in flex attention. This means that **flex attention TMA will stop working with Triton 3.2 or Triton 3.3/3.3.1** for now (but it should work with Triton 3.4 in the PyTorch 2.8 release, and with Meta-internal triton 3.3.1fb, both of which have the new TMA API).

This PR does the following:
* replace the experimental TMA APIs with the stable TMA APIs
* remove the workspace args.

Testing: I ran test/inductor/test_flex_attention.py on an H100 with @mandroid6's PR #153662 patched in to turn on TMA [TODO: confirm results once all the local tests pass, but from the first 100 tests I ran locally, all the failing tests were also failing on #153662 alone]

Note: when #153662 lands and turns on TMA support by default, it should check specifically for stable TMA API support (commented on the PR).
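For readers who have not used the stable API, the sketch below shows the two calls the template now relies on (tl.make_tensor_descriptor / tl.load_tensor_descriptor) in a standalone toy kernel. It is illustrative only and not part of this patch: the tiled-copy kernel, its shapes, and the allocator are assumptions, and it needs a Triton with the stable TMA API (3.4, or 3.3.1fb as noted above) plus a TMA-capable GPU (H100-class).

    import torch
    import triton
    import triton.language as tl


    def tma_scratch_alloc(size: int, alignment: int, stream):
        # The stable TMA API stages its descriptors in global-memory scratch that
        # the Triton runtime requests from a registered allocator, instead of a
        # caller-provided workspace pointer.
        return torch.empty(size, device="cuda", dtype=torch.int8)


    triton.set_allocator(tma_scratch_alloc)


    @triton.jit
    def tiled_copy(src, dst, M, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
        pid_m = tl.program_id(0)
        pid_n = tl.program_id(1)
        # Replaces experimental_device_tensormap_create2d(...) plus the explicit
        # experimental_tensormap_fenceproxy_acquire(...) calls.
        desc = tl.make_tensor_descriptor(
            base=src,
            shape=[M, N],
            strides=[N, 1],
            block_shape=[BLOCK_M, BLOCK_N],
        )
        # Replaces tl._experimental_descriptor_load(...); the block shape and dtype
        # now come from the descriptor rather than from extra arguments.
        block = tl.load_tensor_descriptor(desc, [pid_m * BLOCK_M, pid_n * BLOCK_N])
        offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
        offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
        tl.store(dst + offs_m[:, None] * N + offs_n[None, :], block)


    x = torch.randn(256, 128, device="cuda", dtype=torch.float16)
    y = torch.empty_like(x)
    tiled_copy[(256 // 64, 128 // 64)](x, y, 256, 128, BLOCK_M=64, BLOCK_N=64)

The same pattern appears in the template changes below, with one descriptor each for Q, K, and V.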
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155771
Approved by: https://github.com/mandroid6, https://github.com/nmacchioni
---
 torch/_inductor/kernel/flex_attention.py | 120 ++++++++---------------
 1 file changed, 43 insertions(+), 77 deletions(-)

diff --git a/torch/_inductor/kernel/flex_attention.py b/torch/_inductor/kernel/flex_attention.py
index e79683b9e8bb..a3204de8b39f 100644
--- a/torch/_inductor/kernel/flex_attention.py
+++ b/torch/_inductor/kernel/flex_attention.py
@@ -51,7 +51,6 @@ from ..select_algorithm import (
     SymbolicGridFn,
     TritonTemplate,
 )
-from ..utils import get_tma_workspace_arg
 
 log = logging.getLogger(__name__)
 
@@ -394,41 +393,26 @@ compute_flex_attention = r"""
     desc_q = None
     desc_k = None
     desc_v = None
-    if USE_TMA:
-        TMA_SIZE = 128
-        workspace_base = ws_ptr + TMA_SIZE * 3 * (
-            tl.program_id(1) + tl.program_id(0) * tl.num_programs(1)
-        )
-        desc_q = workspace_base
-        desc_v = workspace_base + TMA_SIZE
-        desc_k = workspace_base + 2 * TMA_SIZE
-
-        triton.language.extra.cuda.experimental_device_tensormap_create2d(
-            desc_ptr=desc_q,
-            global_address=Q,
-            load_size=[BLOCK_M, QK_HEAD_DIM_ROUNDED],
-            global_size=[Q_LEN*HQ*ZQ, QK_HEAD_DIM],
-            element_ty=Q.dtype.element_ty,
-        )
-        triton.language.extra.cuda.experimental_device_tensormap_create2d(
-            desc_ptr=desc_v,
-            global_address=V,
-            load_size=[BLOCK_N, V_HEAD_DIM_ROUNDED],
-            global_size=[KV_LEN*ZKV*HQ, V_HEAD_DIM],
-            element_ty=K.dtype.element_ty,
-        )
-
-        triton.language.extra.cuda.experimental_device_tensormap_create2d(
-            desc_ptr=desc_k,
-            global_address=K,
-            load_size=[BLOCK_N, QK_HEAD_DIM_ROUNDED],
-            global_size=[KV_LEN*ZKV*HQ, QK_HEAD_DIM],
-            element_ty=K.dtype.element_ty,
-        )
-
-
-        tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(desc_q)
-        tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(desc_k)
+    {%- if USE_TMA %}
+    desc_q = tl.make_tensor_descriptor(
+        base=Q,
+        shape=[Q_LEN*HQ*ZQ, QK_HEAD_DIM],
+        strides=[QK_HEAD_DIM, 1],
+        block_shape=[BLOCK_M, QK_HEAD_DIM_ROUNDED],
+    )
+    desc_v = tl.make_tensor_descriptor(
+        base=V,
+        shape=[KV_LEN*ZKV*HQ, V_HEAD_DIM],
+        strides=[V_HEAD_DIM, 1],
+        block_shape=[BLOCK_N, V_HEAD_DIM_ROUNDED],
+    )
+    desc_k = tl.make_tensor_descriptor(
+        base=K,
+        shape=[KV_LEN*ZKV*HQ, QK_HEAD_DIM],
+        strides=[QK_HEAD_DIM, 1],
+        block_shape=[BLOCK_N, QK_HEAD_DIM_ROUNDED],
+    )
+    {%- endif %}
 
     # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq.
@@ -483,15 +467,14 @@ compute_flex_attention = r"""
         order=(1, 0)
     )
 
-    if USE_TMA:
-        q = tl._experimental_descriptor_load( # load in row major
-            desc_q,
-            [(q_start * BLOCK_M).to(tl.int32), 0],
-            [BLOCK_M, QK_HEAD_DIM_ROUNDED],
-            Q.dtype.element_ty,
-        )
-    else:
+    {%- if USE_TMA %}
+    q = tl.load_tensor_descriptor(
+        desc_q,
+        [(q_start * BLOCK_M).to(tl.int32), 0],
+    )
+    {%- else %}
     q = load_checked_block(Q_block_ptr, IS_DIVISIBLE, SAFE_HEAD_DIM)
+    {%- endif %}
 
     # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
     # We don't know anything "special" about these blocks, so we need to apply
@@ -709,15 +692,14 @@ def forward_block_mn(
 
     # -- load k --
     # NB reversed order to since K is transposed
-    if USE_TMA:
-        k = tl._experimental_descriptor_load( # load in row major
-            desc_k,
-            [start_n.to(tl.int32) , kv_start],
-            [BLOCK_N, QK_HEAD_DIM_ROUNDED],
-            MATMUL_PRECISION,
-        )
-    else:
-        k = load_checked_block(K_block_ptr, SAFE_HEAD_DIM, IS_DIVISIBLE)
+    {%- if USE_TMA %}
+    k = tl.load_tensor_descriptor( # load in row major
+        desc_k,
+        [start_n.to(tl.int32) , kv_start],
+    )
+    {%- else %}
+    k = load_checked_block(K_block_ptr, SAFE_HEAD_DIM, IS_DIVISIBLE)
+    {%- endif %}
 
     if USE_TMA:
         k = tl.trans(k)
@@ -784,15 +766,14 @@ def forward_block_mn(
         l_i = l_i * alpha + tl.sum(p, 1)
     # # -- scale and update acc --
     acc = acc * alpha[:, None]
-    if USE_TMA:
-        v = tl._experimental_descriptor_load( # load in row major
-            desc_v,
-            [kv_start.to(tl.int32) + start_n.to(tl.int32),0],
-            [BLOCK_N, V_HEAD_DIM_ROUNDED],
-            MATMUL_PRECISION,
-        )
-    else:
-        v = load_checked_block(V_block_ptr, IS_DIVISIBLE, SAFE_HEAD_DIM)
+    {%- if USE_TMA %}
+    v = tl.load_tensor_descriptor(
+        desc_v,
+        [kv_start.to(tl.int32) + start_n.to(tl.int32),0],
+    )
+    {%- else %}
+    v = load_checked_block(V_block_ptr, IS_DIVISIBLE, SAFE_HEAD_DIM)
+    {%- endif %}
     acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION)
 
     # -- update m_i
@@ -1653,20 +1634,6 @@ def flex_attention(
         cur_kernel_options.setdefault("SPARSE_Q_BLOCK_SIZE", SPARSE_Q_BLOCK_SIZE)
         cur_kernel_options.setdefault("SPARSE_KV_BLOCK_SIZE", SPARSE_KV_BLOCK_SIZE)
 
-        workspace_arg = None
-        if cur_kernel_options.get("USE_TMA", False):
-            seq_len_q = V.graph.sizevars.evaluate_static_shape(seq_len_q)
-
-            grid = flex_attention_grid(
-                Bq, Hq, seq_len_q, qk_head_dim, cur_kernel_options
-            )
-
-            num_programs = grid[0] * grid[1] * grid[2]
-            workspace_arg = get_tma_workspace_arg(
-                num_tma_descriptors=3,
-                device=query.get_device(),
-                num_programs=num_programs,
-            )
         error = flex_attention_template.maybe_append_choice(
             choices=choices,
             input_nodes=[
@@ -1687,7 +1654,6 @@ def flex_attention(
             mutated_inputs=[
                 logsumexp,
             ],
-            workspace_arg=workspace_arg,
            call_sizes=query.get_size(),
             **cur_kernel_options,
         )
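For orientation, the host-side code deleted in the last two hunks existed only to size and thread a workspace buffer for the experimental API: a fixed 128-byte slot per descriptor, three descriptors (Q, K, V), one set per program. A rough, hypothetical restatement of that sizing (the helper name is illustrative, not from the tree):

    # Rough equivalent of the deleted workspace sizing; with the stable API the
    # scratch for tl.make_tensor_descriptor is obtained through Triton's
    # registered allocator at launch time instead.
    TMA_SIZE = 128           # bytes per TMA descriptor (from the old device code)
    NUM_TMA_DESCRIPTORS = 3  # desc_q, desc_k, desc_v

    def old_workspace_bytes(grid):
        num_programs = grid[0] * grid[1] * grid[2]
        return TMA_SIZE * NUM_TMA_DESCRIPTORS * num_programs

Because the stable API requests that scratch from a host-registered allocator (see the sketch in the commit message above), neither the template nor flex_attention() needs to size or pass a workspace argument, which is why get_tma_workspace_arg and workspace_arg disappear here.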