[Bugfix] Fix fp8 tests for triton_unified_attention for Triton 3.3 (#18013)

Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
Co-authored-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
Thomas Parnell
2025-05-15 07:26:34 +02:00
committed by GitHub
parent 26d0419309
commit e6b8e65d2d
2 changed files with 7 additions and 0 deletions

@@ -99,6 +99,9 @@ def test_triton_unified_attn(
 ) -> None:
     torch.set_default_device("cuda")
+    if q_dtype is not None and q_dtype.itemsize < 2 and block_size < 32:
+        pytest.skip("block size must be at least 32 for fp8")
+
     current_platform.seed_everything(0)
     num_seqs = len(seq_lens)
     query_lens = [x[0] for x in seq_lens]

@@ -268,6 +268,10 @@ def unified_attention(
     assert causal, "Only causal attention is supported"
     assert q_descale is None, "Q scales not supported"
 
+    block_size = v.shape[1]
+    assert q.element_size() >= 2 or block_size >= 32, \
+        "Block size must be at least 32 for fp8"
+
     use_alibi_slopes = alibi_slopes is not None
 
     block_size = v.shape[1]
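
For context (not part of the commit): a minimal sketch of the guard logic added above, assuming PyTorch >= 2.1 where torch.dtype.itemsize is available. fp8 dtypes are one byte wide, so the unified attention kernel only supports them with a KV-cache block size of at least 32; smaller block sizes are skipped in the test and rejected by the kernel-side assert.

import torch

def skip_for_fp8(q_dtype, block_size: int) -> bool:
    # Mirrors the test-side check: fp8 dtypes report itemsize 1 (< 2 bytes),
    # and those configurations require block_size >= 32.
    return q_dtype is not None and q_dtype.itemsize < 2 and block_size < 32

assert skip_for_fp8(torch.float8_e4m3fn, 16) is True   # fp8 + small block -> skipped
assert skip_for_fp8(torch.float8_e4m3fn, 32) is False  # fp8 + block >= 32 -> runs
assert skip_for_fp8(torch.bfloat16, 16) is False       # 2-byte dtype -> unaffected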