mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[Kernel][Model] logits_soft_cap for Gemma2 with flashinfer (#6051)
Co-authored-by: Simon Mo <simon.mo@hey.com>
This commit is contained in:
@ -118,12 +118,15 @@ steps:
|
||||
|
||||
- label: Kernels Test %N
|
||||
#mirror_hardwares: [amd]
|
||||
command: pytest -v -s kernels --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT
|
||||
commands:
|
||||
- pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.7/flashinfer-0.0.7+cu121torch2.3-cp310-cp310-linux_x86_64.whl
|
||||
- pytest -v -s kernels --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT
|
||||
parallelism: 4
|
||||
|
||||
- label: Models Test
|
||||
#mirror_hardwares: [amd]
|
||||
commands:
|
||||
- pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.7/flashinfer-0.0.7+cu121torch2.3-cp310-cp310-linux_x86_64.whl
|
||||
- pytest -v -s models -m \"not vlm\"
|
||||
|
||||
- label: Vision Language Models Test
|
||||
@ -234,7 +237,7 @@ steps:
|
||||
- pytest -v -s distributed/test_custom_all_reduce.py
|
||||
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
|
||||
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py
|
||||
- pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.5/flashinfer-0.0.5+cu121torch2.3-cp310-cp310-linux_x86_64.whl
|
||||
- pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.7/flashinfer-0.0.7+cu121torch2.3-cp310-cp310-linux_x86_64.whl
|
||||
- VLLM_ATTENTION_BACKEND=FLASHINFER TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
|
||||
- VLLM_ATTENTION_BACKEND=FLASHINFER TEST_DIST_MODEL=meta-llama/Meta-Llama-3-8B DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
|
||||
- pytest -v -s -x lora/test_mixtral.py
|
||||
|
248
tests/kernels/test_flashinfer.py
Normal file
248
tests/kernels/test_flashinfer.py
Normal file
@ -0,0 +1,248 @@
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import flashinfer
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
NUM_HEADS = [(16, 16), (32, 8), (64, 8)]
|
||||
HEAD_SIZES = [128, 256]
|
||||
BLOCK_SIZES = [16, 32]
|
||||
DTYPES = [torch.float16, torch.bfloat16]
|
||||
NUM_BLOCKS = 32768 # Large enough to test overflow in index calculation.
|
||||
|
||||
|
||||
def ref_paged_attn(
|
||||
query: torch.Tensor,
|
||||
key_cache: torch.Tensor,
|
||||
value_cache: torch.Tensor,
|
||||
query_lens: List[int],
|
||||
kv_lens: List[int],
|
||||
block_tables: torch.Tensor,
|
||||
scale: float,
|
||||
sliding_window: Optional[int] = None,
|
||||
soft_cap: Optional[float] = None,
|
||||
) -> torch.Tensor:
|
||||
num_seqs = len(query_lens)
|
||||
block_tables = block_tables.cpu().numpy()
|
||||
_, block_size, num_kv_heads, head_size = key_cache.shape
|
||||
|
||||
outputs: List[torch.Tensor] = []
|
||||
start_idx = 0
|
||||
for i in range(num_seqs):
|
||||
query_len = query_lens[i]
|
||||
kv_len = kv_lens[i]
|
||||
q = query[start_idx:start_idx + query_len]
|
||||
q *= scale
|
||||
|
||||
num_kv_blocks = (kv_len + block_size - 1) // block_size
|
||||
block_indices = block_tables[i, :num_kv_blocks]
|
||||
|
||||
k = key_cache[block_indices].view(-1, num_kv_heads, head_size)
|
||||
k = k[:kv_len]
|
||||
v = value_cache[block_indices].view(-1, num_kv_heads, head_size)
|
||||
v = v[:kv_len]
|
||||
|
||||
if q.shape[1] != k.shape[1]:
|
||||
k = torch.repeat_interleave(k, q.shape[1] // k.shape[1], dim=1)
|
||||
v = torch.repeat_interleave(v, q.shape[1] // v.shape[1], dim=1)
|
||||
attn = torch.einsum("qhd,khd->hqk", q, k).float()
|
||||
empty_mask = torch.ones(query_len, kv_len)
|
||||
mask = torch.triu(empty_mask, diagonal=kv_len - query_len + 1).bool()
|
||||
if sliding_window is not None:
|
||||
sliding_window_mask = torch.triu(empty_mask,
|
||||
diagonal=kv_len -
|
||||
(query_len + sliding_window) +
|
||||
1).bool().logical_not()
|
||||
mask |= sliding_window_mask
|
||||
if soft_cap is not None:
|
||||
attn = soft_cap * torch.tanh(attn / soft_cap)
|
||||
attn.masked_fill_(mask, float("-inf"))
|
||||
attn = torch.softmax(attn, dim=-1).to(v.dtype)
|
||||
out = torch.einsum("hqk,khd->qhd", attn, v)
|
||||
|
||||
outputs.append(out)
|
||||
start_idx += query_len
|
||||
|
||||
return torch.cat(outputs, dim=0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("kv_lens", [[1328, 18, 463], [1, 54, 293, 70]])
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0])
|
||||
@torch.inference_mode
|
||||
def test_flashinfer_decode_with_paged_kv(kv_lens: List[int],
|
||||
num_heads: Tuple[int,
|
||||
int], head_size: int,
|
||||
dtype: torch.dtype, block_size: int,
|
||||
soft_cap: Optional[float]) -> None:
|
||||
torch.set_default_device("cuda")
|
||||
torch.cuda.manual_seed_all(0)
|
||||
num_seqs = len(kv_lens)
|
||||
num_query_heads = num_heads[0]
|
||||
num_kv_heads = num_heads[1]
|
||||
assert num_query_heads % num_kv_heads == 0
|
||||
max_kv_len = max(kv_lens)
|
||||
scale = head_size**-0.5
|
||||
|
||||
query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype)
|
||||
key_value_cache = torch.randn(NUM_BLOCKS,
|
||||
2,
|
||||
block_size,
|
||||
num_kv_heads,
|
||||
head_size,
|
||||
dtype=dtype)
|
||||
key_cache = key_value_cache[:, 0, :, :, :].squeeze(1)
|
||||
value_cache = key_value_cache[:, 1, :, :, :].squeeze(1)
|
||||
|
||||
max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
|
||||
block_tables = torch.randint(0,
|
||||
NUM_BLOCKS,
|
||||
(num_seqs, max_num_blocks_per_seq),
|
||||
dtype=torch.int32)
|
||||
|
||||
kv_indptr = [0]
|
||||
kv_indices = []
|
||||
kv_last_page_lens = []
|
||||
for i in range(num_seqs):
|
||||
seq_len = kv_lens[i]
|
||||
assert seq_len > 0
|
||||
num_blocks = (seq_len + block_size - 1) // block_size
|
||||
kv_indices.extend(block_tables[i, :num_blocks])
|
||||
kv_indptr.append(kv_indptr[-1] + num_blocks)
|
||||
kv_last_page_len = seq_len % block_size
|
||||
if kv_last_page_len == 0:
|
||||
kv_last_page_len = block_size
|
||||
kv_last_page_lens.append(kv_last_page_len)
|
||||
|
||||
kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
|
||||
kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
|
||||
kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)
|
||||
|
||||
workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8)
|
||||
wrapper = flashinfer.\
|
||||
BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, "NHD")
|
||||
wrapper.begin_forward(kv_indptr,
|
||||
kv_indices,
|
||||
kv_last_page_lens,
|
||||
num_query_heads,
|
||||
num_kv_heads,
|
||||
head_size,
|
||||
block_size,
|
||||
"NONE",
|
||||
data_type=dtype)
|
||||
|
||||
output = wrapper.forward(query, key_value_cache, logits_soft_cap=soft_cap)
|
||||
|
||||
ref_output = ref_paged_attn(query=query,
|
||||
key_cache=key_cache,
|
||||
value_cache=value_cache,
|
||||
query_lens=[1] * num_seqs,
|
||||
kv_lens=kv_lens,
|
||||
block_tables=block_tables,
|
||||
scale=scale,
|
||||
soft_cap=soft_cap)
|
||||
assert torch.allclose(output, ref_output, atol=1e-2, rtol=1e-2), \
|
||||
f"{torch.max(torch.abs(output - ref_output))}"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seq_lens", [[(1, 1328), (5, 18), (129, 463)]])
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0])
|
||||
@torch.inference_mode
|
||||
def test_flashinfer_prefill_with_paged_kv(seq_lens: List[Tuple[int, int]],
|
||||
num_heads: Tuple[int, int],
|
||||
head_size: int, dtype: torch.dtype,
|
||||
block_size: int,
|
||||
soft_cap: Optional[float]) -> None:
|
||||
torch.set_default_device("cuda")
|
||||
torch.cuda.manual_seed_all(0)
|
||||
num_seqs = len(seq_lens)
|
||||
query_lens = [x[0] for x in seq_lens]
|
||||
kv_lens = [x[1] for x in seq_lens]
|
||||
num_query_heads = num_heads[0]
|
||||
num_kv_heads = num_heads[1]
|
||||
assert num_query_heads % num_kv_heads == 0
|
||||
max_kv_len = max(kv_lens)
|
||||
scale = head_size**-0.5
|
||||
|
||||
query = torch.randn(sum(query_lens),
|
||||
num_query_heads,
|
||||
head_size,
|
||||
dtype=dtype)
|
||||
key_value_cache = torch.randn(NUM_BLOCKS,
|
||||
2,
|
||||
block_size,
|
||||
num_kv_heads,
|
||||
head_size,
|
||||
dtype=dtype)
|
||||
key_cache = key_value_cache[:, 0, :, :, :].squeeze(1)
|
||||
value_cache = key_value_cache[:, 1, :, :, :].squeeze(1)
|
||||
|
||||
# Normalize the scale of the key and value caches to mitigate
|
||||
# numerical instability.
|
||||
key_cache /= head_size**0.5
|
||||
value_cache /= head_size**0.5
|
||||
|
||||
max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
|
||||
block_tables = torch.randint(0,
|
||||
NUM_BLOCKS,
|
||||
(num_seqs, max_num_blocks_per_seq),
|
||||
dtype=torch.int32)
|
||||
|
||||
qo_indptr = [0]
|
||||
kv_indptr = [0]
|
||||
kv_indices = []
|
||||
kv_last_page_lens = []
|
||||
for i in range(num_seqs):
|
||||
seq_len = kv_lens[i]
|
||||
assert seq_len > 0
|
||||
num_blocks = (seq_len + block_size - 1) // block_size
|
||||
kv_indices.extend(block_tables[i, :num_blocks])
|
||||
kv_indptr.append(kv_indptr[-1] + num_blocks)
|
||||
kv_last_page_len = seq_len % block_size
|
||||
if kv_last_page_len == 0:
|
||||
kv_last_page_len = block_size
|
||||
kv_last_page_lens.append(kv_last_page_len)
|
||||
qo_indptr.append(qo_indptr[-1] + query_lens[i])
|
||||
|
||||
qo_indptr = torch.tensor(qo_indptr, dtype=torch.int32)
|
||||
kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
|
||||
kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
|
||||
kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)
|
||||
|
||||
workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8)
|
||||
wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
|
||||
workspace_buffer, "NHD")
|
||||
wrapper.begin_forward(
|
||||
qo_indptr,
|
||||
kv_indptr,
|
||||
kv_indices,
|
||||
kv_last_page_lens,
|
||||
num_query_heads,
|
||||
num_kv_heads,
|
||||
head_size,
|
||||
block_size,
|
||||
)
|
||||
|
||||
output = wrapper.forward(
|
||||
query,
|
||||
key_value_cache,
|
||||
logits_soft_cap=soft_cap,
|
||||
)
|
||||
|
||||
ref_output = ref_paged_attn(query=query,
|
||||
key_cache=key_cache,
|
||||
value_cache=value_cache,
|
||||
query_lens=query_lens,
|
||||
kv_lens=kv_lens,
|
||||
block_tables=block_tables,
|
||||
scale=scale,
|
||||
soft_cap=soft_cap)
|
||||
assert torch.allclose(output, ref_output, atol=1e-2, rtol=1e-2), \
|
||||
f"{torch.max(torch.abs(output - ref_output))}"
|
@ -102,6 +102,8 @@ class FlashInferMetadata(AttentionMetadata):
|
||||
# The data type of the paged kv cache
|
||||
data_type: torch.dtype = None
|
||||
device: torch.device = torch.device("cuda")
|
||||
# Only used by gemma2 model
|
||||
logits_soft_cap: Optional[float] = None
|
||||
|
||||
def __post_init__(self):
|
||||
# Refer to
|
||||
@ -271,9 +273,11 @@ class FlashInferImpl(AttentionImpl):
|
||||
else:
|
||||
assert prefill_meta is not None
|
||||
assert prefill_meta.prefill_wrapper is not None
|
||||
output = prefill_meta.prefill_wrapper.forward(query,
|
||||
kv_cache,
|
||||
causal=True)
|
||||
output = prefill_meta.prefill_wrapper.forward(
|
||||
query,
|
||||
kv_cache,
|
||||
logits_soft_cap=attn_metadata.logits_soft_cap,
|
||||
causal=True)
|
||||
else:
|
||||
assert attn_metadata.decode_metadata is not None
|
||||
assert attn_metadata.decode_metadata.decode_wrapper is not None
|
||||
@ -281,5 +285,5 @@ class FlashInferImpl(AttentionImpl):
|
||||
query,
|
||||
kv_cache,
|
||||
sm_scale=self.scale,
|
||||
)
|
||||
logits_soft_cap=attn_metadata.logits_soft_cap)
|
||||
return output.view(num_tokens, hidden_size)
|
||||
|
@ -77,9 +77,9 @@ def get_attn_backend(
|
||||
return IpexAttnBackend
|
||||
elif backend == _Backend.FLASHINFER:
|
||||
logger.info("Using Flashinfer backend.")
|
||||
logger.warning(("Flashinfer will be stuck on llma-2-7b,"
|
||||
" please avoid using Flashinfer as the"
|
||||
"backend when running on llma-2-7b."))
|
||||
logger.warning(("Flashinfer will be stuck on llama-2-7b,"
|
||||
" please avoid using Flashinfer as the "
|
||||
"backend when running on llama-2-7b."))
|
||||
from vllm.attention.backends.flashinfer import FlashInferBackend
|
||||
return FlashInferBackend
|
||||
elif backend == _Backend.PALLAS:
|
||||
|
@ -38,7 +38,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import IntermediateTensors, SamplerOutput
|
||||
from vllm.utils import print_warning_once
|
||||
|
||||
from .interfaces import SupportsLoRA
|
||||
|
||||
@ -137,12 +136,6 @@ class Gemma2Attention(nn.Module):
|
||||
dtype=torch.get_default_dtype(),
|
||||
)
|
||||
|
||||
if self.config.attn_logit_softcapping is not None:
|
||||
print_warning_once(
|
||||
"Gemma 2 normally uses attention logit soft-capping; "
|
||||
"soft-capping is currently incompatible with the flash "
|
||||
"attention kernels, so vLLM removes it to enable speed and "
|
||||
"efficiency gains of flash attention.")
|
||||
# FIXME(woosuk): While Gemma 2 uses sliding window attention for every
|
||||
# odd layer, vLLM currently ignores it and uses global attention for
|
||||
# all layers.
|
||||
|
@ -15,7 +15,7 @@ try:
|
||||
from flashinfer import BatchDecodeWithPagedKVCacheWrapper
|
||||
from flashinfer.decode import CUDAGraphBatchDecodeWithPagedKVCacheWrapper
|
||||
from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper
|
||||
FLASHINFER_WORKSPACE_BUFFER_SIZE = 128 * 1024 * 1024
|
||||
FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
|
||||
except ImportError:
|
||||
BatchDecodeWithPagedKVCacheWrapper = None
|
||||
CUDAGraphBatchDecodeWithPagedKVCacheWrapper = None
|
||||
@ -683,6 +683,16 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
|
||||
dtype=torch.long,
|
||||
device=self.device)
|
||||
|
||||
logits_soft_cap = getattr(self.model_config.hf_config,
|
||||
'attn_logit_softcapping', None)
|
||||
if logits_soft_cap is not None and self.attn_backend.get_name(
|
||||
) != "flashinfer":
|
||||
raise ValueError("Please use Flashinfer backend for models with"
|
||||
"logits_soft_cap (i.e., Gemma-2)."
|
||||
" Otherwise, the output might be wrong."
|
||||
" Set Flashinfer backend by "
|
||||
"export VLLM_ATTENTION_BACKEND=FLASHINFER.")
|
||||
|
||||
if self.attn_backend.get_name() == "flashinfer":
|
||||
if len(paged_kv_indptr) > 0:
|
||||
paged_kv_indices_tensor = torch.tensor(paged_kv_indices,
|
||||
@ -700,7 +710,6 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
|
||||
|
||||
kv_cache_dtype = get_kv_cache_torch_dtype(self.kv_cache_dtype,
|
||||
self.model_config.dtype)
|
||||
|
||||
attn_metadata = self.attn_backend.make_metadata(
|
||||
num_prefills=num_prefills,
|
||||
slot_mapping=slot_mapping_tensor,
|
||||
@ -721,7 +730,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
|
||||
query_start_loc=query_start_loc,
|
||||
device=self.device,
|
||||
data_type=kv_cache_dtype,
|
||||
use_cuda_graph=use_captured_graph)
|
||||
use_cuda_graph=use_captured_graph,
|
||||
logits_soft_cap=logits_soft_cap)
|
||||
|
||||
else:
|
||||
attn_metadata = self.attn_backend.make_metadata(
|
||||
@ -1196,7 +1206,8 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
|
||||
if model_input.attn_metadata.use_cuda_graph:
|
||||
batch_size = model_input.input_tokens.shape[0]
|
||||
model_input.attn_metadata.decode_wrapper = self.graph_runners[
|
||||
batch_size].flashinfer_decode_wrapper
|
||||
model_input.
|
||||
virtual_engine][batch_size].flashinfer_decode_wrapper
|
||||
else:
|
||||
model_input.attn_metadata.decode_wrapper = \
|
||||
self.flashinfer_decode_wrapper
|
||||
|
Reference in New Issue
Block a user