mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-20 21:14:14 +08:00 
			
		
		
		
	[flex attention][triton pin] use new TMA API (#155771)
Triton 3.4 will remove the experimental TMA APIs: https://github.com/triton-lang/triton/pull/6488. Ahead of this, we are **replacing the experimental TMA API usage with the stable TMA API** in flex attention. This means that **flex attention TMA will stop working with Triton 3.2 or Triton 3.3/3.3.1** for now (but it should work for Triton 3.4 in the PyTorch 2.8 release, and Meta-internal triton 3.3.1fb, which have the new TMA API). This PR does the following: * replace the experimental TMA APIs with the stable TMA APIs * remove the workspace args. Testing: I ran test/inductor/test_flex_attention.py on a H100 with @mandroid6's PR #153662 patched in to turn on TMA [TODO: confirm results once all the local tests pass, but from the first 100 tests I ran locally, all the failing tests were also failing on #153662 alone] Note: When #153662 lands, turning on TMA support by default, it should be checking specifically for stable TMA API support (commented on PR) Pull Request resolved: https://github.com/pytorch/pytorch/pull/155771 Approved by: https://github.com/mandroid6, https://github.com/nmacchioni
This commit is contained in:
		
				
					committed by
					
						 PyTorch MergeBot
						PyTorch MergeBot
					
				
			
			
				
	
			
			
			
						parent
						
							92b7ed6d07
						
					
				
				
					commit
					c843909d9e
				
			| @ -51,7 +51,6 @@ from ..select_algorithm import ( | ||||
|     SymbolicGridFn, | ||||
|     TritonTemplate, | ||||
| ) | ||||
| from ..utils import get_tma_workspace_arg | ||||
|  | ||||
|  | ||||
| log = logging.getLogger(__name__) | ||||
| @ -394,41 +393,26 @@ compute_flex_attention = r""" | ||||
|     desc_q = None | ||||
|     desc_k = None | ||||
|     desc_v = None | ||||
|     if USE_TMA: | ||||
|         TMA_SIZE = 128 | ||||
|         workspace_base = ws_ptr + TMA_SIZE * 3 * ( | ||||
|             tl.program_id(1) + tl.program_id(0) * tl.num_programs(1) | ||||
|     {%- if USE_TMA %} | ||||
|     desc_q = tl.make_tensor_descriptor( | ||||
|         base=Q, | ||||
|         shape=[Q_LEN*HQ*ZQ, QK_HEAD_DIM], | ||||
|         strides=[QK_HEAD_DIM, 1], | ||||
|         block_shape=[BLOCK_M, QK_HEAD_DIM_ROUNDED], | ||||
|     ) | ||||
|         desc_q = workspace_base | ||||
|         desc_v = workspace_base + TMA_SIZE | ||||
|         desc_k = workspace_base + 2 * TMA_SIZE | ||||
|  | ||||
|         triton.language.extra.cuda.experimental_device_tensormap_create2d( | ||||
|             desc_ptr=desc_q, | ||||
|             global_address=Q, | ||||
|             load_size=[BLOCK_M, QK_HEAD_DIM_ROUNDED], | ||||
|             global_size=[Q_LEN*HQ*ZQ, QK_HEAD_DIM], | ||||
|             element_ty=Q.dtype.element_ty, | ||||
|     desc_v = tl.make_tensor_descriptor( | ||||
|         base=V, | ||||
|         shape=[KV_LEN*ZKV*HQ, V_HEAD_DIM], | ||||
|         strides=[V_HEAD_DIM, 1], | ||||
|         block_shape=[BLOCK_N, V_HEAD_DIM_ROUNDED], | ||||
|     ) | ||||
|         triton.language.extra.cuda.experimental_device_tensormap_create2d( | ||||
|             desc_ptr=desc_v, | ||||
|             global_address=V, | ||||
|             load_size=[BLOCK_N, V_HEAD_DIM_ROUNDED], | ||||
|             global_size=[KV_LEN*ZKV*HQ, V_HEAD_DIM], | ||||
|             element_ty=K.dtype.element_ty, | ||||
|     desc_k = tl.make_tensor_descriptor( | ||||
|         base=V, | ||||
|         shape=[KV_LEN*ZKV*HQ, V_HEAD_DIM], | ||||
|         strides=[V_HEAD_DIM, 1], | ||||
|         block_shape=[BLOCK_N, V_HEAD_DIM_ROUNDED], | ||||
|     ) | ||||
|  | ||||
|         triton.language.extra.cuda.experimental_device_tensormap_create2d( | ||||
|             desc_ptr=desc_k, | ||||
|             global_address=K, | ||||
|             load_size=[BLOCK_N, QK_HEAD_DIM_ROUNDED], | ||||
|             global_size=[KV_LEN*ZKV*HQ, QK_HEAD_DIM], | ||||
|             element_ty=K.dtype.element_ty, | ||||
|         ) | ||||
|  | ||||
|  | ||||
|         tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(desc_q) | ||||
|         tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(desc_k) | ||||
|     {%- endif %} | ||||
|  | ||||
|  | ||||
|     # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. | ||||
| @ -483,15 +467,14 @@ compute_flex_attention = r""" | ||||
|             order=(1, 0) | ||||
|         ) | ||||
|  | ||||
|     if USE_TMA: | ||||
|         q = tl._experimental_descriptor_load(  # load in row major | ||||
|     {%- if USE_TMA %} | ||||
|     q = tl.load_tensor_descriptor( | ||||
|         desc_q, | ||||
|         [(q_start * BLOCK_M).to(tl.int32), 0], | ||||
|             [BLOCK_M, QK_HEAD_DIM_ROUNDED], | ||||
|             Q.dtype.element_ty, | ||||
|     ) | ||||
|     else: | ||||
|     {%- else %} | ||||
|         q = load_checked_block(Q_block_ptr, IS_DIVISIBLE, SAFE_HEAD_DIM) | ||||
|     {%- endif %} | ||||
|  | ||||
|     # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | ||||
|     # We don't know anything "special" about these blocks, so we need to apply | ||||
| @ -709,15 +692,14 @@ def forward_block_mn( | ||||
|  | ||||
|     # -- load k -- | ||||
|     # NB reversed order to since K is transposed | ||||
|     if USE_TMA: | ||||
|        k = tl._experimental_descriptor_load(  # load in row major | ||||
|     {%- if USE_TMA %} | ||||
|     k = tl.load_tensor_descriptor(  # load in row major | ||||
|             desc_k, | ||||
|             [start_n.to(tl.int32) , kv_start], | ||||
|                 [BLOCK_N, QK_HEAD_DIM_ROUNDED], | ||||
|                 MATMUL_PRECISION, | ||||
|     ) | ||||
|     else: | ||||
|     {%- else %} | ||||
|     k = load_checked_block(K_block_ptr, SAFE_HEAD_DIM, IS_DIVISIBLE) | ||||
|     {%- endif %} | ||||
|  | ||||
|     if USE_TMA: | ||||
|         k = tl.trans(k) | ||||
| @ -784,15 +766,14 @@ def forward_block_mn( | ||||
|     l_i = l_i * alpha + tl.sum(p, 1) | ||||
|     # # -- scale and update acc -- | ||||
|     acc = acc * alpha[:, None] | ||||
|     if USE_TMA: | ||||
|         v = tl._experimental_descriptor_load(  # load in row major | ||||
|     {%- if USE_TMA %} | ||||
|     v = tl.load_tensor_descriptor( | ||||
|         desc_v, | ||||
|         [kv_start.to(tl.int32) + start_n.to(tl.int32),0], | ||||
|                     [BLOCK_N, V_HEAD_DIM_ROUNDED], | ||||
|                     MATMUL_PRECISION, | ||||
|     ) | ||||
|     else: | ||||
|     {%- else %} | ||||
|     v = load_checked_block(V_block_ptr, IS_DIVISIBLE, SAFE_HEAD_DIM) | ||||
|     {%- endif %} | ||||
|     acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) | ||||
|  | ||||
|     # -- update m_i | ||||
| @ -1653,20 +1634,6 @@ def flex_attention( | ||||
|         cur_kernel_options.setdefault("SPARSE_Q_BLOCK_SIZE", SPARSE_Q_BLOCK_SIZE) | ||||
|         cur_kernel_options.setdefault("SPARSE_KV_BLOCK_SIZE", SPARSE_KV_BLOCK_SIZE) | ||||
|  | ||||
|         workspace_arg = None | ||||
|         if cur_kernel_options.get("USE_TMA", False): | ||||
|             seq_len_q = V.graph.sizevars.evaluate_static_shape(seq_len_q) | ||||
|  | ||||
|             grid = flex_attention_grid( | ||||
|                 Bq, Hq, seq_len_q, qk_head_dim, cur_kernel_options | ||||
|             ) | ||||
|  | ||||
|             num_programs = grid[0] * grid[1] * grid[2] | ||||
|             workspace_arg = get_tma_workspace_arg( | ||||
|                 num_tma_descriptors=3, | ||||
|                 device=query.get_device(), | ||||
|                 num_programs=num_programs, | ||||
|             ) | ||||
|         error = flex_attention_template.maybe_append_choice( | ||||
|             choices=choices, | ||||
|             input_nodes=[ | ||||
| @ -1687,7 +1654,6 @@ def flex_attention( | ||||
|             mutated_inputs=[ | ||||
|                 logsumexp, | ||||
|             ], | ||||
|             workspace_arg=workspace_arg, | ||||
|             call_sizes=query.get_size(), | ||||
|             **cur_kernel_options, | ||||
|         ) | ||||
|  | ||||
		Reference in New Issue
	
	Block a user