mirror of https://github.com/pytorch/pytorch.git
Apply Triton tensor descriptor for flex-decoding for performance (#161643)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161643
Approved by: https://github.com/drisspg
committed by PyTorch MergeBot
parent ef3be6726f
commit a3d72b09ae
@@ -31,6 +31,7 @@ from torch.testing._internal.common_device_type import (
 )
 from torch.testing._internal.common_utils import IS_CI, IS_WINDOWS
 from torch.testing._internal.inductor_utils import HAS_GPU
+from torch.utils._triton import has_triton_tma_device


 if IS_WINDOWS and IS_CI:
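Aside: has_triton_tma_device() is what gates the new test below on TMA support. A minimal hedged sketch of how a caller might reuse the same gate (the fallback choice here is an assumption, not part of the PR):

from torch.utils._triton import has_triton_tma_device

# TMA needs a recent Triton plus TMA-capable hardware; when unavailable,
# omit the flag and keep the default pointer-based kernel path.
kernel_options = {"USE_TMA": True} if has_triton_tma_device() else None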
@@ -101,12 +102,13 @@ def skip_on_xpu(test_func):
     return decorated_func


-def create_attention(score_mod, block_mask, enable_gqa=False):
+def create_attention(score_mod, block_mask, enable_gqa=False, kernel_options=None):
     return functools.partial(
         flex_attention,
         score_mod=score_mod,
         block_mask=block_mask,
         enable_gqa=enable_gqa,
+        kernel_options=kernel_options,
     )


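For orientation, a hedged usage sketch of the extended helper (the tensor shapes are illustrative, not taken from this diff; create_attention is the function defined in the hunk above):

import torch

q = torch.randn(1, 4, 1, 16, device="cuda", dtype=torch.float16)
k = torch.randn(1, 4, 128, 16, device="cuda", dtype=torch.float16)
v = torch.randn(1, 4, 128, 16, device="cuda", dtype=torch.float16)

# kernel_options is forwarded verbatim to flex_attention, which hands it
# to the Inductor-generated Triton kernel.
sdpa = create_attention(None, None, kernel_options={"USE_TMA": True})
out = torch.compile(sdpa)(q, k, v)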
@@ -379,6 +381,7 @@ class TestFlexDecoding(InductorTestCase):
         V_D: int = D,
         block_mask: Optional[BlockMask] = None,
         device="cuda",
+        kernel_options=None,
     ):
         assert score_mod is not None or block_mask is not None, (
             "Must provide score_mod or block_mask"
@@ -409,7 +412,10 @@ class TestFlexDecoding(InductorTestCase):
         q_gold, k_gold, v_gold = query_key_value_clones(q, k, v, torch.float64)

         sdpa_partial = create_attention(
-            score_mod, block_mask, enable_gqa=(not Q_H == KV_H)
+            score_mod,
+            block_mask,
+            enable_gqa=(not Q_H == KV_H),
+            kernel_options=kernel_options,
         )
         compiled_sdpa = torch.compile(sdpa_partial)
         if not self.test_inference_only:
@@ -846,6 +852,28 @@ class TestFlexDecoding(InductorTestCase):
         )
         self.run_test(score_mod, dtype, block_mask=block_mask, device=device)

+    @unittest.skipIf(not has_triton_tma_device(), "Skip when TMA is not available")
+    @common_utils.parametrize("dtype", test_dtypes_fast)
+    def test_tma_decoding(self, device, dtype: torch.dtype):
+        n_heads, head_dim, seq_len = 4, 16, 128
+
+        score_mod = _generate_alibi_bias(n_heads)
+        kernel_options = {"USE_TMA": True}
+        self.run_test(
+            score_mod=score_mod,
+            dtype=dtype,
+            Q_B=1,
+            Q_H=n_heads,
+            Q_S=1,
+            Q_D=head_dim,
+            KV_B=1,
+            KV_H=n_heads,
+            KV_S=seq_len,
+            V_D=head_dim,
+            device=device,
+            kernel_options=kernel_options,
+        )
+
     @supported_platform
     @common_utils.parametrize("dtype", test_dtypes_fast)
     @common_utils.parametrize("k_s", test_input_strides)
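_generate_alibi_bias is an existing helper in the test file; an illustrative stand-in plus a standalone invocation of the new TMA path might look like the following (a sketch, not the PR's code; shapes mirror the test's single-query-token decoding setup):

import torch
from torch.nn.attention.flex_attention import flex_attention


def _generate_alibi_bias(num_heads: int):
    # Illustrative stand-in: a per-head linear bias on the query/key
    # distance, in the spirit of ALiBi.
    def _alibi_bias(score, b, h, q_idx, kv_idx):
        scale = torch.exp2(-((h + 1) * 8.0 / num_heads))
        return score + (kv_idx - q_idx) * scale

    return _alibi_bias


# Decoding shape: one query token attending to a 128-token KV cache.
q = torch.randn(1, 4, 1, 16, device="cuda", dtype=torch.float16)
k = torch.randn(1, 4, 128, 16, device="cuda", dtype=torch.float16)
v = torch.randn(1, 4, 128, 16, device="cuda", dtype=torch.float16)

out = torch.compile(flex_attention)(
    q, k, v,
    score_mod=_generate_alibi_bias(4),
    kernel_options={"USE_TMA": True},  # exercises the new TMA path
)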
The remaining hunks modify the flex-decoding Triton template:

@@ -131,9 +131,27 @@

     offs_n = tl.arange(0, BLOCK_N) + off_n

+    desc_k = None
+    desc_v = None
+    {%- if USE_TMA %}
+    desc_k = tl.make_tensor_descriptor(
+        base=K,
+        shape=[KV_LEN, QK_HEAD_DIM],
+        strides=[stride_kn, 1],
+        block_shape=[BLOCK_N, QK_HEAD_DIM_ROUNDED],
+    )
+
+    desc_v = tl.make_tensor_descriptor(
+        base=V,
+        shape=[KV_LEN, V_HEAD_DIM],
+        strides=[stride_vn, 1],
+        block_shape=[BLOCK_N, V_HEAD_DIM_ROUNDED],
+    )
+    {%- endif %}
+
     acc, l_i, m_i = forward_inner(
         {{gen_argdefs()}},
-        q, K, V, None, None, Q_LEN, KV_LEN,
+        q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
         # accumulated values
         acc, l_i, m_i,
         # offsets
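For readers unfamiliar with tensor descriptors: a self-contained sketch of the same tl.make_tensor_descriptor pattern outside the template (the kernel and names here are illustrative, not part of the PR, and assume a Triton recent enough to provide make_tensor_descriptor):

import torch
import triton
import triton.language as tl


@triton.jit
def tile_copy_kernel(in_ptr, out_ptr, M, N, stride_m,
                     BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    # One descriptor per tensor: shape, strides, and the tile that
    # TMA-capable hardware can fetch in a single bulk copy.
    desc_in = tl.make_tensor_descriptor(
        base=in_ptr, shape=[M, N], strides=[stride_m, 1],
        block_shape=[BLOCK_M, BLOCK_N],
    )
    desc_out = tl.make_tensor_descriptor(
        base=out_ptr, shape=[M, N], strides=[stride_m, 1],
        block_shape=[BLOCK_M, BLOCK_N],
    )
    # Descriptor loads/stores take element offsets rather than raw
    # pointer arithmetic with explicit masks.
    tile = desc_in.load([pid_m * BLOCK_M, pid_n * BLOCK_N])
    desc_out.store([pid_m * BLOCK_M, pid_n * BLOCK_N], tile)


def demo():
    # On-device descriptor creation needs a host-side workspace allocator.
    triton.set_allocator(
        lambda size, align, stream: torch.empty(size, device="cuda", dtype=torch.int8)
    )
    x = torch.randn(128, 64, device="cuda")
    y = torch.empty_like(x)
    tile_copy_kernel[(128 // 32, 64 // 32)](x, y, 128, 64, x.stride(0),
                                            BLOCK_M=32, BLOCK_N=32)
    assert torch.equal(x, y)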
@@ -168,7 +186,7 @@

     acc, l_i, m_i = forward_inner(
         {{gen_argdefs()}},
-        q, K, V, None, None, Q_LEN, KV_LEN,
+        q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
         # accumulated values
         acc, l_i, m_i,
         # offsets
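Both forward_inner call sites now pass the descriptors instead of hard-coded None placeholders. Since desc_k and desc_v default to None and are only populated inside the {%- if USE_TMA %} block, the inner loop can take descriptor-based (TMA) loads when the option is set and fall back to the existing pointer-based loads otherwise.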