Apply Triton tensor descriptor for flex-decoding for performance (#161643)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161643
Approved by: https://github.com/drisspg
commit a3d72b09ae
parent ef3be6726f
Author: Wang, Eikan
Date: 2025-09-04 15:49:01 +00:00
Committed by: PyTorch MergeBot
2 changed files with 50 additions and 4 deletions
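The change gates the new path behind a USE_TMA kernel option: when set, the decoding template builds device-side tensor descriptors for K and V and threads them into forward_inner, so the kernel can issue TMA bulk copies instead of per-element pointer loads. As background, here is a minimal, self-contained sketch of the Triton on-device descriptor API the template relies on (tl.make_tensor_descriptor plus a host-side scratch allocator installed via triton.set_allocator); the toy copy kernel is illustrative, not the flex-decoding template itself:

import torch
import triton
import triton.language as tl

# Device-side descriptor creation needs scratch space for the TMA
# metadata; Triton requests it through a user-installed allocator.
def _tma_alloc(size: int, alignment: int, stream):
    return torch.empty(size, dtype=torch.int8, device="cuda")

triton.set_allocator(_tma_alloc)

@triton.jit
def tile_copy(src, dst, M, N, stride_m,
              BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    pid = tl.program_id(0)
    # A descriptor records the tensor's logical shape, its strides, and
    # the tile (block_shape) that each TMA transaction transfers.
    desc_src = tl.make_tensor_descriptor(
        base=src,
        shape=[M, N],
        strides=[stride_m, 1],
        block_shape=[BLOCK_M, BLOCK_N],
    )
    desc_dst = tl.make_tensor_descriptor(
        base=dst,
        shape=[M, N],
        strides=[stride_m, 1],
        block_shape=[BLOCK_M, BLOCK_N],
    )
    # Loads/stores take element offsets; the hardware masks accesses at
    # the tensor boundary, so no explicit bounds check is needed.
    tile = desc_src.load([pid * BLOCK_M, 0])
    desc_dst.store([pid * BLOCK_M, 0], tile)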

View File

@@ -31,6 +31,7 @@ from torch.testing._internal.common_device_type import (
 )
 from torch.testing._internal.common_utils import IS_CI, IS_WINDOWS
 from torch.testing._internal.inductor_utils import HAS_GPU
+from torch.utils._triton import has_triton_tma_device
 
 if IS_WINDOWS and IS_CI:
@@ -101,12 +102,13 @@ def skip_on_xpu(test_func):
     return decorated_func
 
 
-def create_attention(score_mod, block_mask, enable_gqa=False):
+def create_attention(score_mod, block_mask, enable_gqa=False, kernel_options=None):
     return functools.partial(
         flex_attention,
         score_mod=score_mod,
         block_mask=block_mask,
         enable_gqa=enable_gqa,
+        kernel_options=kernel_options,
     )
@@ -379,6 +381,7 @@ class TestFlexDecoding(InductorTestCase):
         V_D: int = D,
         block_mask: Optional[BlockMask] = None,
         device="cuda",
+        kernel_options=None,
     ):
         assert score_mod is not None or block_mask is not None, (
             "Must provide score_mod or block_mask"
@@ -409,7 +412,10 @@ class TestFlexDecoding(InductorTestCase):
         q_gold, k_gold, v_gold = query_key_value_clones(q, k, v, torch.float64)
         sdpa_partial = create_attention(
-            score_mod, block_mask, enable_gqa=(not Q_H == KV_H)
+            score_mod,
+            block_mask,
+            enable_gqa=(not Q_H == KV_H),
+            kernel_options=kernel_options,
         )
         compiled_sdpa = torch.compile(sdpa_partial)
         if not self.test_inference_only:
@@ -846,6 +852,28 @@ class TestFlexDecoding(InductorTestCase):
         )
         self.run_test(score_mod, dtype, block_mask=block_mask, device=device)
 
+    @unittest.skipIf(not has_triton_tma_device(), "Skip when TMA is not available")
+    @common_utils.parametrize("dtype", test_dtypes_fast)
+    def test_tma_decoding(self, device, dtype: torch.dtype):
+        n_heads, head_dim, seq_len = 4, 16, 128
+        score_mod = _generate_alibi_bias(n_heads)
+        kernel_options = {"USE_TMA": True}
+        self.run_test(
+            score_mod=score_mod,
+            dtype=dtype,
+            Q_B=1,
+            Q_H=n_heads,
+            Q_S=1,
+            Q_D=head_dim,
+            KV_B=1,
+            KV_H=n_heads,
+            KV_S=seq_len,
+            V_D=head_dim,
+            device=device,
+            kernel_options=kernel_options,
+        )
+
     @supported_platform
     @common_utils.parametrize("dtype", test_dtypes_fast)
     @common_utils.parametrize("k_s", test_input_strides)

View File

@@ -131,9 +131,27 @@
     offs_n = tl.arange(0, BLOCK_N) + off_n
 
+    desc_k = None
+    desc_v = None
+{%- if USE_TMA %}
+    desc_k = tl.make_tensor_descriptor(
+        base=K,
+        shape=[KV_LEN, QK_HEAD_DIM],
+        strides=[stride_kn, 1],
+        block_shape=[BLOCK_N, QK_HEAD_DIM_ROUNDED],
+    )
+    desc_v = tl.make_tensor_descriptor(
+        base=V,
+        shape=[KV_LEN, V_HEAD_DIM],
+        strides=[stride_vn, 1],
+        block_shape=[BLOCK_N, V_HEAD_DIM_ROUNDED],
+    )
+{%- endif %}
     acc, l_i, m_i = forward_inner(
         {{gen_argdefs()}},
-        q, K, V, None, None, Q_LEN, KV_LEN,
+        q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
         # accumulated values
         acc, l_i, m_i,
         # offsets
@@ -168,7 +186,7 @@
     acc, l_i, m_i = forward_inner(
         {{gen_argdefs()}},
-        q, K, V, None, None, Q_LEN, KV_LEN,
+        q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
         # accumulated values
         acc, l_i, m_i,
         # offsets
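forward_inner itself lives in the shared flex templates and is outside this diff; below is a hypothetical sketch of how the descriptors are consumed inside its KV loop. start_n, offs_d, and the fallback load are stand-ins for the template's actual loop variables, not code from this PR:

{%- if USE_TMA %}
    # One bulk TMA copy of a [BLOCK_N, QK_HEAD_DIM] tile of K; the
    # descriptor handles address arithmetic and boundary masking.
    k = desc_k.load([start_n, 0])
{%- else %}
    # Pointer-based load the descriptor path replaces; addressing and
    # masking are done in software here.
    k = tl.load(K + offs_n[:, None] * stride_kn + offs_d[None, :])
{%- endif %}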