[ROCm] Faster Custom Paged Attention kernels (#12348)

Author: TJian
Date: 2025-03-04 01:24:45 +08:00
Committed by: GitHub
Parent: 98175b2816
Commit: 848a6438ae
6 changed files with 1145 additions and 447 deletions


@@ -77,7 +77,6 @@ echo "Commands:$commands"
#ignore certain kernels tests
if [[ $commands == *" kernels "* ]]; then
commands="${commands} \
--ignore=kernels/test_attention.py \
--ignore=kernels/test_attention_selector.py \
--ignore=kernels/test_blocksparse_attention.py \
--ignore=kernels/test_causal_conv1d.py \


@@ -11,8 +11,9 @@ from vllm.platforms import current_platform
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser,
create_kv_caches_with_random)
NUM_BLOCKS = 1024
NUM_BLOCKS = 128 * 1024
PARTITION_SIZE = 512
PARTITION_SIZE_ROCM = 256
@torch.inference_mode()
@@ -80,6 +81,12 @@ def main(
# Prepare for the paged attention kernel.
output = torch.empty_like(query)
if version == "v2":
if current_platform.is_rocm():
global PARTITION_SIZE
if not args.custom_paged_attn:
PARTITION_SIZE = 1024
else:
PARTITION_SIZE = PARTITION_SIZE_ROCM
num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE)
tmp_output = torch.empty(
size=(num_seqs, num_query_heads, num_partitions, head_size),
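
Note: the hunk above switches the ROCm benchmark to a 256-token partition for the custom kernel (1024 for the stock v2 path), which directly sets how many partial results the scratch buffers hold. A minimal sketch of that arithmetic with illustrative values (ceil_div and max_seq_len = 4096 are assumptions, not part of the diff):

def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b

max_seq_len = 4096                      # illustrative value, not from the diff
for partition_size in (1024, 256):      # stock v2 path vs. ROCm custom kernel
    num_partitions = ceil_div(max_seq_len, partition_size)
    # tmp_output holds one partial result per (seq, head, partition), i.e.
    # shape (num_seqs, num_query_heads, num_partitions, head_size)
    print(partition_size, "->", num_partitions)   # 1024 -> 4, 256 -> 16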
@@ -123,25 +130,46 @@ def main(
v_scale,
)
elif version == "v2":
ops.paged_attention_v2(
output,
exp_sums,
max_logits,
tmp_output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
seq_lens,
block_size,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
k_scale,
v_scale,
)
if not args.custom_paged_attn:
ops.paged_attention_v2(
output,
exp_sums,
max_logits,
tmp_output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
seq_lens,
block_size,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
k_scale,
v_scale,
)
else:
ops.paged_attention_rocm(
output,
exp_sums,
max_logits,
tmp_output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
seq_lens,
block_size,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
k_scale,
v_scale,
)
else:
raise ValueError(f"Invalid version: {version}")
torch.cuda.synchronize()
@@ -195,6 +223,9 @@ if __name__ == '__main__':
help="Data type for kv cache storage. If 'auto', will use model "
"data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. "
"ROCm (AMD GPU) supports fp8 (=fp8_e4m3)")
parser.add_argument("--custom-paged-attn",
action="store_true",
help="Use custom paged attention")
args = parser.parse_args()
print(args)
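
Note: the new --custom-paged-attn flag selects both the partition size and the kernel in the v2 branch above. A condensed, hedged restatement (the helper name pick_rocm_config is hypothetical; on non-ROCm platforms the flag does not change PARTITION_SIZE):

def pick_rocm_config(custom_paged_attn: bool):
    # Partition size and kernel chosen by the v2 benchmark path on ROCm.
    if custom_paged_attn:
        return 256, "ops.paged_attention_rocm"    # PARTITION_SIZE_ROCM
    return 1024, "ops.paged_attention_v2"

print(pick_rocm_config(False))   # (1024, 'ops.paged_attention_v2')
print(pick_rocm_config(True))    # (256, 'ops.paged_attention_rocm')

A typical comparison would then run the benchmark script with --version v2, and again with --custom-paged-attn added to exercise the custom kernel (the script's path is not shown in this diff).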

File diff suppressed because it is too large.


@@ -11,4 +11,4 @@ peft
pytest-asyncio
tensorizer>=2.9.0
runai-model-streamer==0.11.0
runai-model-streamer-s3==0.11.0


@@ -25,6 +25,7 @@ MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512
# Reduce NUM_BLOCKS when it happens.
NUM_BLOCKS = 4321 # Arbitrary values for testing
PARTITION_SIZE = 512
PARTITION_SIZE_ROCM = 256
# flshattF and tritonflashattF supported: {torch.float16, torch.bfloat16}
DTYPES = [
torch.half, torch.bfloat16, torch.float
@@ -146,6 +147,8 @@ def test_paged_attention(
or (version == "rocm" and head_size not in (64, 128))):
pytest.skip()
global PARTITION_SIZE
current_platform.seed_everything(seed)
torch.set_default_device(device)
scale = float(1.0 / (head_size**0.5))
@@ -214,6 +217,9 @@ def test_paged_attention(
and block_size == BLOCK_SIZES[0]))
elif version in ("v2", "rocm"):
if current_platform.is_rocm() and version == "rocm":
PARTITION_SIZE = PARTITION_SIZE_ROCM
num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE)
assert PARTITION_SIZE % block_size == 0
num_seqs, num_heads, head_size = output.shape
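
Note: the assert above requires a partition to be a whole number of KV-cache blocks; both partition sizes in this change satisfy it for the block sizes the custom kernel accepts (16 and 32, per the gating logic later in this diff). A quick check:

for partition_size in (512, 256):       # PARTITION_SIZE, PARTITION_SIZE_ROCM
    for block_size in (16, 32):
        assert partition_size % block_size == 0
        print(partition_size, block_size, partition_size // block_size)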
@@ -432,4 +438,4 @@ def test_multi_query_kv_attention(
)
atol = get_default_atol(output) if current_platform.is_rocm() else 1e-3
rtol = get_default_rtol(output) if current_platform.is_rocm() else 1e-5
torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)


@@ -22,7 +22,7 @@ if TYPE_CHECKING:
logger = init_logger(__name__)
_PARTITION_SIZE_ROCM = 512
_PARTITION_SIZE_ROCM = 256
_GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
_ON_NAVI = "gfx1" in _GPU_ARCH
_ON_MI250_MI300 = any(arch in _GPU_ARCH for arch in ["gfx90a", "gfx942"])
@@ -885,4 +885,4 @@ def _use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int,
and (qtype == torch.half or qtype == torch.bfloat16)
and (head_size == 64 or head_size == 128)
and (block_size == 16 or block_size == 32)
and (gqa_ratio >= 1 and gqa_ratio <= 16) and max_seq_len <= 32768)
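
Note: read together with the arch constants defined near the top of the file, the tail of _use_rocm_custom_paged_attention amounts to the predicate below. This is an approximate, standalone restatement; the leading architecture condition is inferred from _ON_NAVI and _ON_MI250_MI300 and is only partially visible in this diff, and the flags are passed in as plain booleans here rather than derived from the device properties.

import torch

def use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int,
                                     block_size: int, gqa_ratio: int,
                                     max_seq_len: int,
                                     on_mi250_mi300: bool = True,
                                     on_navi: bool = False) -> bool:
    # The file above derives the architecture flags from
    # torch.cuda.get_device_properties("cuda").gcnArchName.
    return (on_mi250_mi300 and not on_navi
            and qtype in (torch.half, torch.bfloat16)
            and head_size in (64, 128)
            and block_size in (16, 32)
            and 1 <= gqa_ratio <= 16
            and max_seq_len <= 32768)

# Example: bf16, head size 128, block size 16, GQA ratio 8, 4K context on
# MI300 takes the custom kernel; fp32 does not.
print(use_rocm_custom_paged_attention(torch.bfloat16, 128, 16, 8, 4096))  # True
print(use_rocm_custom_paged_attention(torch.float32, 128, 16, 8, 4096))   # False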