mirror of
				https://github.com/vllm-project/vllm.git
				synced 2025-10-20 23:03:52 +08:00 
			
		
		
		
	[ROCm] Faster Custom Paged Attention kernels (#12348)
This commit is contained in:
		| @ -77,7 +77,6 @@ echo "Commands:$commands" | ||||
| #ignore certain kernels tests | ||||
| if [[ $commands == *" kernels "* ]]; then | ||||
|   commands="${commands} \ | ||||
|   --ignore=kernels/test_attention.py \ | ||||
|   --ignore=kernels/test_attention_selector.py \ | ||||
|   --ignore=kernels/test_blocksparse_attention.py \ | ||||
|   --ignore=kernels/test_causal_conv1d.py \ | ||||
|  | ||||
| @ -11,8 +11,9 @@ from vllm.platforms import current_platform | ||||
| from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser, | ||||
|                         create_kv_caches_with_random) | ||||
|  | ||||
| NUM_BLOCKS = 1024 | ||||
| NUM_BLOCKS = 128 * 1024 | ||||
| PARTITION_SIZE = 512 | ||||
| PARTITION_SIZE_ROCM = 256 | ||||
|  | ||||
|  | ||||
| @torch.inference_mode() | ||||
| @ -80,6 +81,12 @@ def main( | ||||
|     # Prepare for the paged attention kernel. | ||||
|     output = torch.empty_like(query) | ||||
|     if version == "v2": | ||||
|         if current_platform.is_rocm(): | ||||
|             global PARTITION_SIZE | ||||
|             if not args.custom_paged_attn: | ||||
|                 PARTITION_SIZE = 1024 | ||||
|             else: | ||||
|                 PARTITION_SIZE = PARTITION_SIZE_ROCM | ||||
|         num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE) | ||||
|         tmp_output = torch.empty( | ||||
|             size=(num_seqs, num_query_heads, num_partitions, head_size), | ||||
| @ -123,6 +130,7 @@ def main( | ||||
|                     v_scale, | ||||
|                 ) | ||||
|             elif version == "v2": | ||||
|                 if not args.custom_paged_attn: | ||||
|                     ops.paged_attention_v2( | ||||
|                         output, | ||||
|                         exp_sums, | ||||
| @ -142,6 +150,26 @@ def main( | ||||
|                         k_scale, | ||||
|                         v_scale, | ||||
|                     ) | ||||
|                 else: | ||||
|                     ops.paged_attention_rocm( | ||||
|                         output, | ||||
|                         exp_sums, | ||||
|                         max_logits, | ||||
|                         tmp_output, | ||||
|                         query, | ||||
|                         key_cache, | ||||
|                         value_cache, | ||||
|                         num_kv_heads, | ||||
|                         scale, | ||||
|                         block_tables, | ||||
|                         seq_lens, | ||||
|                         block_size, | ||||
|                         max_seq_len, | ||||
|                         alibi_slopes, | ||||
|                         kv_cache_dtype, | ||||
|                         k_scale, | ||||
|                         v_scale, | ||||
|                     ) | ||||
|             else: | ||||
|                 raise ValueError(f"Invalid version: {version}") | ||||
|         torch.cuda.synchronize() | ||||
| @ -195,6 +223,9 @@ if __name__ == '__main__': | ||||
|         help="Data type for kv cache storage. If 'auto', will use model " | ||||
|         "data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. " | ||||
|         "ROCm (AMD GPU) supports fp8 (=fp8_e4m3)") | ||||
|     parser.add_argument("--custom-paged-attn", | ||||
|                         action="store_true", | ||||
|                         help="Use custom paged attention") | ||||
|     args = parser.parse_args() | ||||
|     print(args) | ||||
|  | ||||
|  | ||||
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							| @ -25,6 +25,7 @@ MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512 | ||||
| # Reduce NUM_BLOCKS when it happens. | ||||
| NUM_BLOCKS = 4321  # Arbitrary values for testing | ||||
| PARTITION_SIZE = 512 | ||||
| PARTITION_SIZE_ROCM = 256 | ||||
| # flshattF and tritonflashattF supported: {torch.float16, torch.bfloat16} | ||||
| DTYPES = [ | ||||
|     torch.half, torch.bfloat16, torch.float | ||||
| @ -146,6 +147,8 @@ def test_paged_attention( | ||||
|             or (version == "rocm" and head_size not in (64, 128))): | ||||
|         pytest.skip() | ||||
|  | ||||
|     global PARTITION_SIZE | ||||
|  | ||||
|     current_platform.seed_everything(seed) | ||||
|     torch.set_default_device(device) | ||||
|     scale = float(1.0 / (head_size**0.5)) | ||||
| @ -214,6 +217,9 @@ def test_paged_attention( | ||||
|                       and block_size == BLOCK_SIZES[0])) | ||||
|  | ||||
|     elif version in ("v2", "rocm"): | ||||
|         if current_platform.is_rocm() and version == "rocm": | ||||
|             PARTITION_SIZE = PARTITION_SIZE_ROCM | ||||
|  | ||||
|         num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE) | ||||
|         assert PARTITION_SIZE % block_size == 0 | ||||
|         num_seqs, num_heads, head_size = output.shape | ||||
|  | ||||
| @ -22,7 +22,7 @@ if TYPE_CHECKING: | ||||
|  | ||||
| logger = init_logger(__name__) | ||||
|  | ||||
| _PARTITION_SIZE_ROCM = 512 | ||||
| _PARTITION_SIZE_ROCM = 256 | ||||
| _GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName | ||||
| _ON_NAVI = "gfx1" in _GPU_ARCH | ||||
| _ON_MI250_MI300 = any(arch in _GPU_ARCH for arch in ["gfx90a", "gfx942"]) | ||||
|  | ||||
		Reference in New Issue
	
	Block a user