Mirror of https://github.com/vllm-project/vllm.git
[ROCm] Faster Custom Paged Attention kernels (#12348)
@ -77,7 +77,6 @@ echo "Commands:$commands"
 #ignore certain kernels tests
 if [[ $commands == *" kernels "* ]]; then
   commands="${commands} \
-  --ignore=kernels/test_attention.py \
   --ignore=kernels/test_attention_selector.py \
   --ignore=kernels/test_blocksparse_attention.py \
   --ignore=kernels/test_causal_conv1d.py \
@ -11,8 +11,9 @@ from vllm.platforms import current_platform
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser,
                         create_kv_caches_with_random)

-NUM_BLOCKS = 1024
+NUM_BLOCKS = 128 * 1024
 PARTITION_SIZE = 512
+PARTITION_SIZE_ROCM = 256


 @torch.inference_mode()
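For a sense of scale, the larger NUM_BLOCKS changes how much KV cache the benchmark allocates. A rough sizing sketch follows; the block size, head count, head size, and dtype are assumed example values, not taken from this diff.

# Rough KV-cache sizing for the benchmark's block pool (illustrative only).
num_blocks = 128 * 1024      # new NUM_BLOCKS in the benchmark
block_size = 16              # tokens per block (assumed)
num_kv_heads = 8             # assumed
head_size = 128              # assumed
bytes_per_elem = 2           # fp16 / bf16
# Key cache plus value cache:
total_bytes = 2 * num_blocks * block_size * num_kv_heads * head_size * bytes_per_elem
print(f"{total_bytes / 2**30:.1f} GiB")  # 8.0 GiB for these example shapes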
@ -80,6 +81,12 @@ def main(
     # Prepare for the paged attention kernel.
     output = torch.empty_like(query)
     if version == "v2":
+        if current_platform.is_rocm():
+            global PARTITION_SIZE
+            if not args.custom_paged_attn:
+                PARTITION_SIZE = 1024
+            else:
+                PARTITION_SIZE = PARTITION_SIZE_ROCM
         num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE)
         tmp_output = torch.empty(
             size=(num_seqs, num_query_heads, num_partitions, head_size),
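To make the effect of the partition-size switch concrete, here is a minimal sketch of how the partition count and the temporary reduction buffers scale with PARTITION_SIZE. The tmp_output shape mirrors the allocation shown above; the exp_sums shape is an assumption based on the usual paged_attention_v2 workspace layout.

def v2_workspace_shapes(num_seqs, num_query_heads, head_size,
                        max_seq_len, partition_size):
    # One partition covers `partition_size` tokens of the sequence.
    num_partitions = (max_seq_len + partition_size - 1) // partition_size
    tmp_output_shape = (num_seqs, num_query_heads, num_partitions, head_size)
    # Assumed: one float per (seq, head, partition) for the reduction stats.
    exp_sums_shape = (num_seqs, num_query_heads, num_partitions)
    return num_partitions, tmp_output_shape, exp_sums_shape

# A 4096-token sequence splits into 16 partitions with the ROCm custom
# kernel's 256-token partitions, versus 4 partitions at 1024 tokens.
print(v2_workspace_shapes(8, 32, 128, 4096, 256)[0])   # 16
print(v2_workspace_shapes(8, 32, 128, 4096, 1024)[0])  # 4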
@ -123,25 +130,46 @@ def main(
                     v_scale,
                 )
            elif version == "v2":
-                ops.paged_attention_v2(
-                    output,
-                    exp_sums,
-                    max_logits,
-                    tmp_output,
-                    query,
-                    key_cache,
-                    value_cache,
-                    num_kv_heads,
-                    scale,
-                    block_tables,
-                    seq_lens,
-                    block_size,
-                    max_seq_len,
-                    alibi_slopes,
-                    kv_cache_dtype,
-                    k_scale,
-                    v_scale,
-                )
+                if not args.custom_paged_attn:
+                    ops.paged_attention_v2(
+                        output,
+                        exp_sums,
+                        max_logits,
+                        tmp_output,
+                        query,
+                        key_cache,
+                        value_cache,
+                        num_kv_heads,
+                        scale,
+                        block_tables,
+                        seq_lens,
+                        block_size,
+                        max_seq_len,
+                        alibi_slopes,
+                        kv_cache_dtype,
+                        k_scale,
+                        v_scale,
+                    )
+                else:
+                    ops.paged_attention_rocm(
+                        output,
+                        exp_sums,
+                        max_logits,
+                        tmp_output,
+                        query,
+                        key_cache,
+                        value_cache,
+                        num_kv_heads,
+                        scale,
+                        block_tables,
+                        seq_lens,
+                        block_size,
+                        max_seq_len,
+                        alibi_slopes,
+                        kv_cache_dtype,
+                        k_scale,
+                        v_scale,
+                    )
            else:
                raise ValueError(f"Invalid version: {version}")
        torch.cuda.synchronize()
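The hunk above only changes which kernel the benchmark dispatches to. For context, a minimal GPU kernel timing pattern of the kind such a benchmark relies on looks like the standalone sketch below; it is not the script's actual code, and `run_once` is a hypothetical closure over whichever paged-attention call is being measured.

import time
import torch

def time_kernel(run_once, num_iters=100):
    # Warm up once, then synchronize so setup work is not counted.
    run_once()
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(num_iters):
        run_once()
    # Kernel launches are asynchronous; wait for completion before stopping the clock.
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / num_iters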
@ -195,6 +223,9 @@ if __name__ == '__main__':
                         help="Data type for kv cache storage. If 'auto', will use model "
                         "data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. "
                         "ROCm (AMD GPU) supports fp8 (=fp8_e4m3)")
+    parser.add_argument("--custom-paged-attn",
+                        action="store_true",
+                        help="Use custom paged attention")
     args = parser.parse_args()
     print(args)

File diff suppressed because it is too large
@ -11,4 +11,4 @@ peft
 pytest-asyncio
 tensorizer>=2.9.0
 runai-model-streamer==0.11.0
-runai-model-streamer-s3==0.11.0
+runai-model-streamer-s3==0.11.0
@ -25,6 +25,7 @@ MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512
 # Reduce NUM_BLOCKS when it happens.
 NUM_BLOCKS = 4321  # Arbitrary values for testing
 PARTITION_SIZE = 512
+PARTITION_SIZE_ROCM = 256
 # flshattF and tritonflashattF supported: {torch.float16, torch.bfloat16}
 DTYPES = [
     torch.half, torch.bfloat16, torch.float
@ -146,6 +147,8 @@ def test_paged_attention(
             or (version == "rocm" and head_size not in (64, 128))):
         pytest.skip()

+    global PARTITION_SIZE
+
     current_platform.seed_everything(seed)
     torch.set_default_device(device)
     scale = float(1.0 / (head_size**0.5))
@ -214,6 +217,9 @@ def test_paged_attention(
                       and block_size == BLOCK_SIZES[0]))

     elif version in ("v2", "rocm"):
+        if current_platform.is_rocm() and version == "rocm":
+            PARTITION_SIZE = PARTITION_SIZE_ROCM
+
         num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE)
         assert PARTITION_SIZE % block_size == 0
         num_seqs, num_heads, head_size = output.shape
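The assert above holds because each partition of the split-KV reduction must span a whole number of KV-cache blocks. A quick check with the partition and block sizes that appear in this diff:

# PARTITION_SIZE % block_size == 0 for every combination used here.
for partition_size in (512, 256):     # default vs. ROCm custom kernel
    for block_size in (16, 32):       # block sizes the custom kernel supports
        assert partition_size % block_size == 0
        print(partition_size, block_size,
              partition_size // block_size, "blocks per partition")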
@ -432,4 +438,4 @@ def test_multi_query_kv_attention(
     )
     atol = get_default_atol(output) if current_platform.is_rocm() else 1e-3
     rtol = get_default_rtol(output) if current_platform.is_rocm() else 1e-5
-    torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)
+    torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)
@ -22,7 +22,7 @@ if TYPE_CHECKING:

 logger = init_logger(__name__)

-_PARTITION_SIZE_ROCM = 512
+_PARTITION_SIZE_ROCM = 256
 _GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
 _ON_NAVI = "gfx1" in _GPU_ARCH
 _ON_MI250_MI300 = any(arch in _GPU_ARCH for arch in ["gfx90a", "gfx942"])
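Halving _PARTITION_SIZE_ROCM from 512 to 256 doubles the number of partitions each sequence is split into by the backend's partitioned attention path. A small arithmetic sketch (sequence lengths chosen arbitrarily):

def num_partitions(seq_len: int, partition_size: int) -> int:
    # Ceiling division: how many partitions cover `seq_len` tokens.
    return (seq_len + partition_size - 1) // partition_size

for seq_len in (1024, 8192, 32768):
    old = num_partitions(seq_len, 512)   # previous _PARTITION_SIZE_ROCM
    new = num_partitions(seq_len, 256)   # value after this change
    print(seq_len, old, new)
# 1024 -> 2 vs 4 partitions, 8192 -> 16 vs 32, 32768 -> 64 vs 128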
@ -885,4 +885,4 @@ def _use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int,
             and (qtype == torch.half or qtype == torch.bfloat16)
             and (head_size == 64 or head_size == 128)
             and (block_size == 16 or block_size == 32)
-            and (gqa_ratio >= 1 and gqa_ratio <= 16) and max_seq_len <= 32768)
+            and (gqa_ratio >= 1 and gqa_ratio <= 16) and max_seq_len <= 32768)
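Taken together with the module-level flags changed above (_ON_NAVI, _ON_MI250_MI300), the gate for the custom ROCm kernel can be read as a small predicate. The sketch below restates the conditions visible in this diff; the architecture gating and the exact function body are assumptions, not a copy of the source.

import torch

def use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int,
                                    block_size: int, gqa_ratio: int,
                                    max_seq_len: int,
                                    on_mi250_mi300: bool = True,
                                    on_navi: bool = False) -> bool:
    # Architecture flags are passed in here for illustration; the real module
    # derives them from torch.cuda.get_device_properties("cuda").gcnArchName.
    return (on_mi250_mi300 and not on_navi
            and qtype in (torch.half, torch.bfloat16)
            and head_size in (64, 128)
            and block_size in (16, 32)
            and 1 <= gqa_ratio <= 16
            and max_seq_len <= 32768)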