mirror of
				https://github.com/vllm-project/vllm.git
				synced 2025-10-20 23:03:52 +08:00 
			
		
		
		
	[Misc] Add FA2 support to ViT MHA layer (#12355)
Signed-off-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
		
							
								
								
									
										126
									
								
								tests/kernels/test_mha_attn.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										126
									
								
								tests/kernels/test_mha_attn.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,126 @@ | ||||
| """ | ||||
| Test: | ||||
|  | ||||
| * Tests for MultiHeadAttention layer | ||||
| """ | ||||
| from unittest.mock import patch | ||||
|  | ||||
| import pytest | ||||
| import torch | ||||
|  | ||||
| from vllm.attention.layer import MultiHeadAttention | ||||
| from vllm.attention.selector import _Backend, _cached_get_attn_backend | ||||
| from vllm.platforms import current_platform | ||||
| from vllm.platforms.cpu import CpuPlatform | ||||
| from vllm.platforms.cuda import CudaPlatform | ||||
| from vllm.platforms.rocm import RocmPlatform | ||||
|  | ||||
|  | ||||
| @pytest.fixture(autouse=True) | ||||
| def clear_cache(): | ||||
|     """Clear lru cache to ensure each test case runs without caching. | ||||
|     """ | ||||
|     _cached_get_attn_backend.cache_clear() | ||||
|  | ||||
|  | ||||
| @pytest.mark.parametrize("device", ["cpu", "hip", "cuda"]) | ||||
| def test_mha_attn_platform(device: str): | ||||
|     """ | ||||
|     Test that the attention selector picks the expected backend for each platform and head size. | ||||
|     """ | ||||
|     torch.set_default_dtype(torch.float16) | ||||
|  | ||||
|     if device == "cpu": | ||||
|         with patch("vllm.attention.selector.current_platform", CpuPlatform()): | ||||
|             attn = MultiHeadAttention(16, 64, scale=1) | ||||
|             assert attn.attn_backend == _Backend.TORCH_SDPA | ||||
|     elif device == "hip": | ||||
|         with patch("vllm.attention.selector.current_platform", RocmPlatform()): | ||||
|             attn = MultiHeadAttention(16, 64, scale=1) | ||||
|             assert attn.attn_backend == _Backend.TORCH_SDPA | ||||
|     else: | ||||
|         with patch("vllm.attention.selector.current_platform", CudaPlatform()): | ||||
|             attn = MultiHeadAttention(16, 64, scale=1) | ||||
|             assert attn.attn_backend == _Backend.FLASH_ATTN | ||||
|  | ||||
|         with patch("vllm.attention.selector.current_platform", CudaPlatform()): | ||||
|             attn = MultiHeadAttention(16, 72, scale=1) | ||||
|             assert attn.attn_backend == _Backend.XFORMERS | ||||
|  | ||||
|  | ||||
| def ref_attention( | ||||
|     query: torch.Tensor, | ||||
|     key: torch.Tensor, | ||||
|     value: torch.Tensor, | ||||
|     scale: float, | ||||
| ) -> torch.Tensor: | ||||
|     """ | ||||
|     Native implementation of scaled dot product attention without mask: | ||||
|     - query, key, value: [batch_size, seq_len, num_heads, head_size] | ||||
|     - returns: [batch_size, seq_len, num_heads, head_size] | ||||
|     """ | ||||
|     query, key, value = (x.transpose(1, 2) for x in (query, key, value)) | ||||
|     attn_weights = scale * torch.matmul(query, key.transpose(2, 3)) | ||||
|     attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype) | ||||
|     out = torch.matmul(attn_weights, value).transpose(1, 2) | ||||
|     return out | ||||
|  | ||||
|  | ||||
| BATCH_SIZES = [1, 16] | ||||
| SEQ_LENS = [1] | ||||
| NUM_HEADS = [1, 16] | ||||
| NUM_KV_HEADS = [1] | ||||
| HEAD_SIZES = [64, 80] | ||||
| # flshattF and tritonflashattF supported: {torch.float16, torch.bfloat16} | ||||
| DTYPES = [ | ||||
|     torch.half, torch.bfloat16, torch.float | ||||
| ] if not current_platform.is_rocm() else [torch.half, torch.bfloat16] | ||||
| CUDA_DEVICES = ["cuda"] | ||||
|  | ||||
|  | ||||
| @pytest.mark.parametrize("batch_size", BATCH_SIZES) | ||||
| @pytest.mark.parametrize("seq_len", SEQ_LENS) | ||||
| @pytest.mark.parametrize("num_heads", NUM_HEADS) | ||||
| @pytest.mark.parametrize("num_kv_heads", NUM_KV_HEADS) | ||||
| @pytest.mark.parametrize("head_size", HEAD_SIZES) | ||||
| @pytest.mark.parametrize("dtype", DTYPES) | ||||
| @pytest.mark.parametrize("device", CUDA_DEVICES) | ||||
| def test_mha_attn_forward( | ||||
|     batch_size: int, | ||||
|     seq_len: int, | ||||
|     num_heads: int, | ||||
|     num_kv_heads: int, | ||||
|     head_size: int, | ||||
|     dtype: torch.dtype, | ||||
|     device: str, | ||||
| ): | ||||
|     current_platform.seed_everything(0) | ||||
|     torch.set_default_device(device) | ||||
|     torch.set_default_dtype(dtype) | ||||
|  | ||||
|     q = torch.randn(batch_size, seq_len, num_heads * head_size) | ||||
|     k = torch.randn(batch_size, seq_len, num_kv_heads * head_size) | ||||
|     v = torch.randn(batch_size, seq_len, num_kv_heads * head_size) | ||||
|     scale = 1.0 / head_size**0.5 | ||||
|     attn = MultiHeadAttention(num_heads, | ||||
|                               head_size, | ||||
|                               scale=scale, | ||||
|                               num_kv_heads=num_kv_heads) | ||||
|     output = attn(q, k, v) | ||||
|  | ||||
|     assert num_heads % num_kv_heads == 0 | ||||
|     num_queries_per_kv = num_heads // num_kv_heads | ||||
|     q = q.reshape(batch_size, seq_len, num_heads, head_size) | ||||
|     k = k.reshape(batch_size, seq_len, num_kv_heads, head_size) | ||||
|     v = v.reshape(batch_size, seq_len, num_kv_heads, head_size) | ||||
|     if num_queries_per_kv > 1: | ||||
|         k = torch.repeat_interleave(k, num_queries_per_kv, dim=2) | ||||
|         v = torch.repeat_interleave(v, num_queries_per_kv, dim=2) | ||||
|  | ||||
|     ref_output = ref_attention( | ||||
|         q, | ||||
|         k, | ||||
|         v, | ||||
|         scale=scale, | ||||
|     ).reshape(batch_size, seq_len, num_heads * head_size) | ||||
|     torch.testing.assert_close(output, ref_output) | ||||
| @ -210,6 +210,9 @@ class MultiHeadAttention(nn.Module): | ||||
|         self.scale = scale | ||||
|         self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads | ||||
|  | ||||
|         assert self.num_heads % self.num_kv_heads == 0 | ||||
|         self.num_queries_per_kv = self.num_heads // self.num_kv_heads | ||||
|  | ||||
|         dtype = torch.get_default_dtype() | ||||
|         attn_backend = get_attn_backend(head_size, | ||||
|                                         dtype, | ||||
| @ -217,11 +220,12 @@ class MultiHeadAttention(nn.Module): | ||||
|                                         block_size=16, | ||||
|                                         is_attention_free=False) | ||||
|         backend = backend_name_to_enum(attn_backend.get_name()) | ||||
|         if backend in {_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1}: | ||||
|             backend = _Backend.XFORMERS | ||||
|  | ||||
|         self.attn_backend = backend if backend in { | ||||
|             _Backend.TORCH_SDPA, _Backend.XFORMERS | ||||
|             _Backend.TORCH_SDPA, | ||||
|             _Backend.XFORMERS, | ||||
|             _Backend.FLASH_ATTN, | ||||
|             _Backend.FLASH_ATTN_VLLM_V1, | ||||
|         } else _Backend.TORCH_SDPA | ||||
|  | ||||
|     def forward( | ||||
| @ -231,7 +235,6 @@ class MultiHeadAttention(nn.Module): | ||||
|         value: torch.Tensor, | ||||
|     ) -> torch.Tensor: | ||||
|         """Input shape: batch_size x seq_len x hidden_size""" | ||||
|         # TODO(Isotr0py): Use existing backend implementations and support FA2 | ||||
|         bsz, q_len, _ = query.size() | ||||
|         kv_len = key.size(1) | ||||
|  | ||||
| @ -239,7 +242,19 @@ class MultiHeadAttention(nn.Module): | ||||
|         key = key.view(bsz, kv_len, self.num_kv_heads, self.head_size) | ||||
|         value = value.view(bsz, kv_len, self.num_kv_heads, self.head_size) | ||||
|  | ||||
|         if self.attn_backend == _Backend.XFORMERS: | ||||
|         if (num_repeat := self.num_queries_per_kv) > 1: | ||||
|             # Handle MQA and GQA | ||||
|             key = torch.repeat_interleave(key, num_repeat, dim=2) | ||||
|             value = torch.repeat_interleave(value, num_repeat, dim=2) | ||||
|  | ||||
|         if self.attn_backend in { | ||||
|                 _Backend.FLASH_ATTN, | ||||
|                 _Backend.FLASH_ATTN_VLLM_V1, | ||||
|         }: | ||||
|             from vllm.vllm_flash_attn import flash_attn_func | ||||
|  | ||||
|             out = flash_attn_func(query, key, value, softmax_scale=self.scale) | ||||
|         elif self.attn_backend == _Backend.XFORMERS: | ||||
|             from xformers import ops as xops | ||||
|  | ||||
|             out = xops.memory_efficient_attention_forward(query, | ||||
|  | ||||
		Reference in New Issue
	
	Block a user