mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 14:53:52 +08:00)
[Kernel] Flashinfer for prefill & decode, with Cudagraph support for decode (#4628)
Co-authored-by: LiuXiaoxuanPKU <llilyliupku@gmail.com>, bong-furiosa <bongwon.jang@furiosa.ai>
@@ -211,3 +211,6 @@ steps:
- pytest -v -s distributed/test_custom_all_reduce.py
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py
- pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.5/flashinfer-0.0.5+cu121torch2.3-cp310-cp310-linux_x86_64.whl
- VLLM_ATTENTION_BACKEND=FLASHINFER TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
- VLLM_ATTENTION_BACKEND=FLASHINFER TEST_DIST_MODEL=meta-llama/Meta-Llama-3-8B DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
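For context, a minimal usage sketch in Python (not part of this diff; the model name and sampling settings are illustrative). The FlashInfer backend is opted into through the same VLLM_ATTENTION_BACKEND environment variable the CI steps above set, and the tests in this commit still skip the non-eager path for it:

import os

os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER"  # same switch as in the CI steps above

from vllm import LLM, SamplingParams

# enforce_eager=True mirrors the tests in this commit, which skip the
# fully CUDA-graphed path for the FlashInfer backend.
llm = LLM(model="facebook/opt-125m", enforce_eager=True)
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.0, max_tokens=16))
print(outputs[0].outputs[0].text)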
@@ -19,4 +19,4 @@ sentence-transformers # required for embedding
aiohttp

# quantization
bitsandbytes==0.42.0
bitsandbytes==0.42.0
@@ -2,7 +2,6 @@

Run `pytest tests/basic_correctness/test_basic_correctness.py`.
"""
import os
import weakref

import pytest
@@ -13,7 +12,6 @@ MODELS = [
    "facebook/opt-125m",
    "meta-llama/Llama-2-7b-hf",
]
VLLM_ATTENTION_BACKEND = "VLLM_ATTENTION_BACKEND"


def test_vllm_gc_ed():
@@ -39,10 +37,6 @@ def test_models(
    max_tokens: int,
    enforce_eager: bool,
) -> None:
    backend_by_env_var = os.getenv(VLLM_ATTENTION_BACKEND)
    if backend_by_env_var == "FLASHINFER" and enforce_eager is False:
        pytest.skip("Skipping non-eager test for FlashInferBackend.")

    with hf_runner(model, dtype=dtype) as hf_model:
        hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)

@@ -21,7 +21,6 @@ MODELS = [
    os.environ["TEST_DIST_MODEL"],
]
DISTRIBUTED_EXECUTOR_BACKEND = "DISTRIBUTED_EXECUTOR_BACKEND"
VLLM_ATTENTION_BACKEND = "VLLM_ATTENTION_BACKEND"


@pytest.mark.skipif(torch.cuda.device_count() < 2,
@@ -39,16 +38,12 @@ def test_models(
) -> None:
    distributed_executor_backend = os.getenv(DISTRIBUTED_EXECUTOR_BACKEND)

    backend_by_env_var = os.getenv(VLLM_ATTENTION_BACKEND)
    enforce_eager = backend_by_env_var == "FLASHINFER"

    with hf_runner(model, dtype=dtype) as hf_model:
        hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)

    with vllm_runner(model,
                     dtype=dtype,
                     tensor_parallel_size=2,
                     enforce_eager=enforce_eager,
                     distributed_executor_backend=distributed_executor_backend
                     ) as vllm_model:
        vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)

@@ -1,10 +1,16 @@
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Set, Tuple, Type

import flashinfer
try:
    from flashinfer import BatchDecodeWithPagedKVCacheWrapper
    from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper
    from vllm_flash_attn import flash_attn_varlen_func
except ImportError:
    flash_attn_varlen_func = None
    BatchDecodeWithPagedKVCacheWrapper = None
    BatchPrefillWithPagedKVCacheWrapper = None

import torch
from flashinfer import BatchDecodeWithPagedKVCacheWrapper
from vllm_flash_attn import flash_attn_varlen_func

from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
@@ -60,19 +66,16 @@ class FlashInferMetadata(AttentionMetadata):
    # requests only.
    max_prefill_seq_len: int

    use_cuda_graph: bool = False
    use_cuda_graph: bool = True

    prefill_wrapper: Optional[BatchPrefillWithPagedKVCacheWrapper] = None
    decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None

    # Metadata for the prefill stage since we still
    # use flash attention for prefill.
    # Metadata for the prefill stage
    seq_start_loc: Optional[torch.Tensor] = None
    query_start_loc: Optional[torch.Tensor] = None
    block_tables: Optional[torch.Tensor] = None

    # Metadata for the decode stage
    # Workspace buffer required by the kernel; the buffer should not
    # be allocated/deallocated by the FlashInferMetadata object.
    workspace_buffer: Optional[torch.Tensor] = None
    # An example for paged_kv_indices, paged_kv_indptr:
    # request 1, page indices [0, 5, 8]
    # request 2, page indices [1, 6, 7]
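To make the paged-KV layout described in the comment above concrete, here is a small standalone Python sketch (illustrative values only, not code from this diff) that builds the flat indices/indptr/last_page_len triple the way the metadata expects it:

block_tables = [[0, 5, 8], [1, 6, 7]]  # request 1 and request 2 page indices
seq_lens = [40, 35]                    # hypothetical sequence lengths
page_size = 16

# Flattened page indices of all requests, laid out back to back.
paged_kv_indices = [page for table in block_tables for page in table]
# indptr[i]:indptr[i+1] is the slice of paged_kv_indices owned by request i.
paged_kv_indptr = [0]
for table in block_tables:
    paged_kv_indptr.append(paged_kv_indptr[-1] + len(table))
# Number of tokens stored in each request's last page (page_size if it is full).
paged_kv_last_page_len = [(n % page_size) or page_size for n in seq_lens]

print(paged_kv_indices)        # [0, 5, 8, 1, 6, 7]
print(paged_kv_indptr)         # [0, 3, 6]
print(paged_kv_last_page_len)  # [8, 3]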
@@ -98,6 +101,7 @@ class FlashInferMetadata(AttentionMetadata):
    page_size: Optional[int] = None
    # The data type of the paged kv cache
    data_type: torch.dtype = None
    device: torch.device = torch.device("cuda")

    def __post_init__(self):
        # Refer to
@@ -109,13 +113,35 @@ class FlashInferMetadata(AttentionMetadata):
                f"Only {supported_head_sizes} are supported for head_dim,",
                f"received {self.head_dim}.")

        # When using flashinfer, we are also creating the FlashInferMetadata,
        # which will also call post_init by default, here we want to skip the
        # post_init if it's the prefill phase.
        if self.num_prefills == 0:
            assert self.num_decode_tokens > 0
            self.decode_wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
                self.workspace_buffer, "NHD")
    def begin_forward(self):
        if self.num_prefill_tokens > 0:
            if self.paged_kv_indices is None:
                return

            assert self.prefill_wrapper is not None
            assert self.paged_kv_indices is not None
            assert self.paged_kv_indptr is not None
            assert self.paged_kv_last_page_len is not None
            self.paged_kv_indices = self.paged_kv_indices.to(self.device)
            self.paged_kv_indptr = self.paged_kv_indptr.to(self.device)
            self.paged_kv_last_page_len = self.paged_kv_last_page_len.to(
                self.device)
            self.prefill_wrapper.begin_forward(
                self.query_start_loc, self.paged_kv_indptr,
                self.paged_kv_indices, self.paged_kv_last_page_len,
                self.num_qo_heads, self.num_kv_heads, self.head_dim,
                self.page_size)
        else:
            if not self.use_cuda_graph:
                assert self.paged_kv_indices is not None
                assert self.paged_kv_indptr is not None
                assert self.paged_kv_last_page_len is not None
                self.paged_kv_indices = self.paged_kv_indices.to(self.device)
                self.paged_kv_indptr = self.paged_kv_indptr.to(self.device)
                self.paged_kv_last_page_len = self.paged_kv_last_page_len.to(
                    self.device)

            assert self.decode_wrapper is not None
            self.decode_wrapper.begin_forward(
                self.paged_kv_indptr,
                self.paged_kv_indices,
@@ -133,8 +159,9 @@ class FlashInferMetadata(AttentionMetadata):
    ) -> Dict[str, Any]:
        if skip_fields is None:
            skip_fields = set()
        # We need to skip the decode_wrapper field since it cannot be
        # We need to skip the prefill/decode_wrapper field since it cannot be
        # broadcasted with nccl when TP is enabled.
        skip_fields.add('prefill_wrapper')
        skip_fields.add('decode_wrapper')
        return super().asdict_zerocopy(skip_fields)

@@ -168,6 +195,7 @@ class FlashInferImpl(AttentionImpl):
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[Dict[str, Any]] = None,
    ) -> None:
        self.num_heads = num_heads
        self.head_size = head_size
@@ -217,10 +245,14 @@ class FlashInferImpl(AttentionImpl):
                self.kv_cache_dtype,
            )

        query = query.contiguous(
        )  # Flashinfer requires query to be contiguous
        if prefill_meta := attn_metadata.prefill_metadata:
            # Prompt run.
            assert prefill_meta.block_tables is not None
            if kv_cache is None or prefill_meta.block_tables.numel() == 0:
            # We will use flash attention for prefill
            # when kv_cache is not provided.
            # This happens when vllm runs the profiling to
            # determine the number of blocks.
            if kv_cache is None:
                output = flash_attn_varlen_func(
                    q=query,
                    k=key,
@@ -235,13 +267,14 @@ class FlashInferImpl(AttentionImpl):
                    alibi_slopes=self.alibi_slopes,
                )
            else:
                raise NotImplementedError(
                    "Prefix caching is not supported with flashinfer yet.")
                assert prefill_meta is not None
                assert prefill_meta.prefill_wrapper is not None
                output = prefill_meta.prefill_wrapper.forward(query,
                                                              kv_cache,
                                                              causal=True)
        else:
            assert attn_metadata.decode_metadata is not None
            assert attn_metadata.decode_metadata.decode_wrapper is not None
            query = query.contiguous(
            )  # Flashinfer requires query to be contiguous
            output = attn_metadata.decode_metadata.decode_wrapper.forward(
                query,
                kv_cache,
@@ -77,8 +77,9 @@ def get_attn_backend(
        return IpexAttnBackend
    elif backend == _Backend.FLASHINFER:
        logger.info("Using Flashinfer backend.")
        logger.warning("Eager mode is required for the Flashinfer backend. "
                       "Please make sure --enforce-eager is set.")
        logger.warning(("Flashinfer will be stuck on llama-2-7b,"
                        " please avoid using Flashinfer as the"
                        " backend when running on llama-2-7b."))
        from vllm.attention.backends.flashinfer import FlashInferBackend
        return FlashInferBackend
    elif backend == _Backend.PALLAS:
@@ -10,6 +10,17 @@ import numpy as np
import torch
import torch.nn as nn

try:
    from flashinfer import BatchDecodeWithPagedKVCacheWrapper
    from flashinfer.decode import CUDAGraphBatchDecodeWithPagedKVCacheWrapper
    from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper
    FLASHINFER_WORKSPACE_BUFFER_SIZE = 128 * 1024 * 1024
except ImportError:
    BatchDecodeWithPagedKVCacheWrapper = None
    CUDAGraphBatchDecodeWithPagedKVCacheWrapper = None
    BatchPrefillWithPagedKVCacheWrapper = None
    FLASHINFER_WORKSPACE_BUFFER_SIZE = 0

from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                         ModelConfig, ParallelConfig, SchedulerConfig,
@@ -198,11 +209,14 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):

        # Lazy initialization
        self.model: nn.Module  # Set after load_model
        # Set if the backend is flashinfer.
        self.flashinfer_workspace_buffer: torch.Tensor
        # Set after load_model.
        self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None

        self.flashinfer_decode_workspace_buffer = None
        self.flashinfer_decode_wrapper = None
        self.flashinfer_prefill_workspace_buffer = None
        self.flashinfer_prefill_wrapper = None

    def load_model(self) -> None:
        with CudaMemoryProfiler() as m:
            self.model = get_model(
@@ -450,15 +464,6 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
                    if curr_sliding_window_blocks is not None:
                        block_table = block_table[
                            -curr_sliding_window_blocks:]
                    if self.attn_backend.get_name() == "flashinfer":
                        paged_kv_indices.extend(block_table)
                        paged_kv_indptr.append(paged_kv_indptr[-1] +
                                               len(block_table))
                        last_page_len = seq_data.get_len(
                        ) % self.block_size
                        if last_page_len == 0:
                            last_page_len = self.block_size
                        paged_kv_last_page_len.append(last_page_len)
                else:
                    # Only happens when memory profiling runs.
                    block_table = []
@@ -505,7 +510,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
                for k, v in mm_kwargs.items():
                    multi_modal_kwargs_list[k].append(v)

            if _is_block_tables_empty(seq_group_metadata.block_tables):
            is_profile_run = _is_block_tables_empty(
                seq_group_metadata.block_tables)
            if is_profile_run:
                # During memory profiling, the block tables are not
                # initialized yet. In this case, we just use a dummy
                # slot mapping.
@@ -544,6 +551,27 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
                slot = block_number * self.block_size + block_offset
                slot_mapping.append(slot)

                # Prepare input tensors for flashinfer
                if self.attn_backend.get_name() == "flashinfer":
                    seq_len = seq_data.get_len()
                    # Get the number of valid blocks based on sequence length.
                    # If seq_len = 16, block_size = 16,
                    # block_table_bound is 1 with 1 valid block.
                    # If seq_len = 15, block_size = 16,
                    # block_table_bound is 0 + 1 with 1 valid block.
                    block_table_bound = seq_len // self.block_size + 1 \
                        if seq_len % self.block_size != 0 \
                        else seq_len // self.block_size

                    paged_kv_indices.extend(block_table[:block_table_bound])
                    paged_kv_indptr.append(paged_kv_indptr[-1] +
                                           block_table_bound)

                    last_page_len = seq_len % self.block_size
                    if last_page_len == 0:
                        last_page_len = self.block_size
                    paged_kv_last_page_len.append(last_page_len)

        batch_size = len(input_tokens)
        max_query_len = max(query_lens)
        max_prefill_seq_len = max(prefill_seq_lens, default=0)
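The block_table_bound / last_page_len arithmetic commented above can be checked in isolation; a standalone sketch with hypothetical values (not code from this diff):

block_size = 16

def valid_blocks_and_last_page_len(seq_len: int):
    # Number of block-table entries that actually hold KV data for seq_len tokens.
    block_table_bound = seq_len // block_size + 1 \
        if seq_len % block_size != 0 \
        else seq_len // block_size
    # Tokens in the last (possibly partial) page; a full last page counts as block_size.
    last_page_len = seq_len % block_size or block_size
    return block_table_bound, last_page_len

print(valid_blocks_and_last_page_len(16))  # (1, 16): exactly one full page
print(valid_blocks_and_last_page_len(15))  # (1, 15): one partial page
print(valid_blocks_and_last_page_len(33))  # (3, 1): two full pages plus one token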
@@ -566,6 +594,12 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
                seq_lens.append(1)
                block_tables.append([])
                lora_index_mapping.append(0)

                if self.attn_backend.get_name() == "flashinfer":
                    last_paged_kv_indptr = paged_kv_indptr[-1]
                    paged_kv_indptr.append(last_paged_kv_indptr)
                    paged_kv_last_page_len.append(0)

            batch_size = graph_batch_size
            num_decode_tokens = batch_size

@@ -589,9 +623,19 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
        )
        assert max_query_len > 0, ("query_lens: {}".format(query_lens))

        context_lens_tensor = torch.tensor(context_lens,
                                           dtype=torch.int,
                                           device=self.device)

        seq_lens_tensor = torch.tensor(seq_lens,
                                       dtype=torch.int,
                                       device=self.device)
        query_lens_tensor = torch.tensor(query_lens,
                                         dtype=torch.long,
                                         device=self.device)
        query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
                                      dtype=torch.int32,
                                      device=self.device)
        seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
                                    dtype=torch.int32,
                                    device=self.device)
@@ -600,6 +644,10 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
                     dim=0,
                     dtype=seq_start_loc.dtype,
                     out=seq_start_loc[1:])
        torch.cumsum(query_lens_tensor,
                     dim=0,
                     dtype=query_start_loc.dtype,
                     out=query_start_loc[1:])

        input_tokens_tensor = torch.tensor(input_tokens,
                                           dtype=torch.long,
@@ -612,30 +660,30 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
                                           device=self.device)

        if self.attn_backend.get_name() == "flashinfer":
            if not hasattr(self, "flashinfer_workspace_buffer"):
                # Allocate 16MB workspace buffer
                # Follow the example of flashinfer: https://docs.flashinfer.ai/api/python/decode.html
                self.flashinfer_workspace_buffer = torch.empty(
                    16 * 1024 * 1024, dtype=torch.uint8, device=self.device)
            paged_kv_indptr_tensor = torch.tensor(paged_kv_indptr,
                                                  dtype=torch.int,
                                                  device=self.device)
            paged_kv_indices_tensor = torch.tensor(paged_kv_indices,
                                                   dtype=torch.int,
                                                   device=self.device)
            paged_kv_last_page_len_tensor = torch.tensor(
                paged_kv_last_page_len, dtype=torch.int, device=self.device)
            if len(paged_kv_indptr) > 0:
                paged_kv_indices_tensor = torch.tensor(paged_kv_indices,
                                                       device='cpu',
                                                       dtype=torch.int)
                paged_kv_indptr_tensor = torch.tensor(paged_kv_indptr,
                                                      device='cpu',
                                                      dtype=torch.int)
                paged_kv_last_page_len_tensor = torch.tensor(
                    paged_kv_last_page_len, device='cpu', dtype=torch.int)
            else:
                paged_kv_indices_tensor = None
                paged_kv_indptr_tensor = None
                paged_kv_last_page_len_tensor = None

            kv_cache_dtype = get_kv_cache_torch_dtype(self.kv_cache_dtype,
                                                      self.model_config.dtype)

            attn_metadata = self.attn_backend.make_metadata(
                num_prefills=num_prefills,
                slot_mapping=slot_mapping_tensor,
                num_prefill_tokens=num_prefill_tokens,
                num_decode_tokens=num_decode_tokens,
                use_cuda_graph=False,
                max_prefill_seq_len=max_prefill_seq_len,
                block_tables=block_tables,
                workspace_buffer=self.flashinfer_workspace_buffer,
                paged_kv_indptr=paged_kv_indptr_tensor,
                paged_kv_indices=paged_kv_indices_tensor,
                paged_kv_last_page_len=paged_kv_last_page_len_tensor,
@@ -644,25 +692,14 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
                num_kv_heads=self.model_config.get_num_kv_heads(
                    self.parallel_config),
                head_dim=self.model_config.get_head_size(),
                page_size=16,
                page_size=self.block_size,
                seq_start_loc=seq_start_loc,
                data_type=kv_cache_dtype)
                query_start_loc=query_start_loc,
                device=self.device,
                data_type=kv_cache_dtype,
                use_cuda_graph=use_captured_graph)

        else:
            context_lens_tensor = torch.tensor(context_lens,
                                               dtype=torch.int,
                                               device=self.device)
            query_lens_tensor = torch.tensor(query_lens,
                                             dtype=torch.long,
                                             device=self.device)
            query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
                                          dtype=torch.int32,
                                          device=self.device)

            torch.cumsum(query_lens_tensor,
                         dim=0,
                         dtype=query_start_loc.dtype,
                         out=query_start_loc[1:])

            attn_metadata = self.attn_backend.make_metadata(
                num_prefills=num_prefills,
                slot_mapping=slot_mapping_tensor,
@@ -854,27 +891,97 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
            bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size
        ]

        if self.attn_backend.get_name() == "flashinfer":
            # For flashinfer, different batch sizes will share the
            # same workspace buffer.
            decode_workspace_buffer = \
                torch.empty(FLASHINFER_WORKSPACE_BUFFER_SIZE,
                            dtype=torch.uint8,
                            device=self.device)
            indices_buffer = torch.empty(max_batch_size *
                                         self.cache_config.num_gpu_blocks,
                                         dtype=torch.int32,
                                         device=self.device)
            indptr_buffer = torch.empty(max_batch_size + 1,
                                        dtype=torch.int32,
                                        device=self.device)
            last_page_len_buffer = torch.empty(max_batch_size,
                                               dtype=torch.int32,
                                               device=self.device)

        with graph_capture() as graph_capture_context:
            # NOTE: Capturing the largest batch size first may help reduce the
            # memory usage of CUDA graph.
            for batch_size in reversed(batch_size_capture_list):
                # Create dummy attn_metadata.
                attn_metadata = self.attn_backend.make_metadata(
                    num_prefills=0,
                    num_prefill_tokens=0,
                    num_decode_tokens=batch_size,
                    slot_mapping=slot_mapping[:batch_size],
                    seq_lens=None,
                    seq_lens_tensor=seq_lens[:batch_size],
                    max_query_len=None,
                    max_prefill_seq_len=0,
                    max_decode_seq_len=self.max_seq_len_to_capture,
                    query_start_loc=None,
                    seq_start_loc=None,
                    context_lens_tensor=None,
                    block_tables=block_tables[:batch_size],
                    use_cuda_graph=True,
                )
                if self.attn_backend.get_name() == "flashinfer":
                    indptr_buffer = indptr_buffer[:batch_size + 1]
                    last_page_len_buffer = last_page_len_buffer[:batch_size]

                    num_qo_heads = self.model_config.get_num_attention_heads(
                        self.parallel_config)
                    num_kv_heads = self.model_config.get_num_kv_heads(
                        self.parallel_config)
                    if num_qo_heads // num_kv_heads >= 4:
                        use_tensor_cores = True
                    else:
                        use_tensor_cores = False
                    decode_wrapper = \
                        CUDAGraphBatchDecodeWithPagedKVCacheWrapper(
                        decode_workspace_buffer, indptr_buffer, indices_buffer,
                        last_page_len_buffer, "NHD", use_tensor_cores)
                    kv_cache_dtype = get_kv_cache_torch_dtype(
                        self.kv_cache_dtype, self.model_config.dtype)

                    paged_kv_indptr_tensor_host = torch.arange(
                        0, batch_size + 1, dtype=torch.int32)
                    paged_kv_indices_tensor_host = torch.arange(
                        0, batch_size, dtype=torch.int32)
                    paged_kv_last_page_len_tensor_host = torch.full(
                        (batch_size, ), self.block_size, dtype=torch.int32)
                    query_start_loc_host = torch.arange(0,
                                                        batch_size + 1,
                                                        dtype=torch.int32)

                    attn_metadata = self.attn_backend.make_metadata(
                        num_prefills=0,
                        slot_mapping=slot_mapping[:batch_size],
                        num_prefill_tokens=0,
                        num_decode_tokens=batch_size,
                        max_prefill_seq_len=0,
                        block_tables=block_tables,
                        paged_kv_indptr=paged_kv_indptr_tensor_host,
                        paged_kv_indices=paged_kv_indices_tensor_host,
                        paged_kv_last_page_len=
                        paged_kv_last_page_len_tensor_host,
                        num_qo_heads=num_qo_heads,
                        num_kv_heads=num_kv_heads,
                        head_dim=self.model_config.get_head_size(),
                        page_size=self.block_size,
                        seq_start_loc=None,
                        query_start_loc=query_start_loc_host,
                        device=self.device,
                        data_type=kv_cache_dtype,
                        use_cuda_graph=True,
                        decode_wrapper=decode_wrapper,
                        prefill_wrapper=None)
                    attn_metadata.begin_forward()
                else:
                    attn_metadata = self.attn_backend.make_metadata(
                        num_prefills=0,
                        num_prefill_tokens=0,
                        num_decode_tokens=batch_size,
                        slot_mapping=slot_mapping[:batch_size],
                        seq_lens=None,
                        seq_lens_tensor=seq_lens[:batch_size],
                        max_query_len=None,
                        max_prefill_seq_len=0,
                        max_decode_seq_len=self.max_seq_len_to_capture,
                        query_start_loc=None,
                        seq_start_loc=None,
                        context_lens_tensor=None,
                        block_tables=block_tables[:batch_size],
                        use_cuda_graph=True,
                    )

                if self.lora_config:
                    lora_mapping = LoRAMapping(
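One detail from the capture path above worth restating: the CUDA-graph decode wrapper enables tensor cores only when each KV head serves several query heads. A standalone restatement of that heuristic (illustrative head counts, not code from this diff):

def should_use_tensor_cores(num_qo_heads: int, num_kv_heads: int) -> bool:
    # Same check as the capture code above: use the tensor-core decode kernel
    # when the grouped-query ratio is at least 4.
    return num_qo_heads // num_kv_heads >= 4

print(should_use_tensor_cores(32, 32))  # False: plain multi-head attention
print(should_use_tensor_cores(32, 8))   # True: grouped-query attention, ratio 4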
@@ -883,8 +990,20 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
                    )
                    self.set_active_loras(set(), lora_mapping)

                graph_runner = CUDAGraphRunner(self.model)
                hidden_states = graph_runner.capture(
                graph_runner = CUDAGraphRunner(self.model,
                                               self.attn_backend.get_name())

                if self.attn_backend.get_name() == "flashinfer":
                    graph_runner.flashinfer_indptr_buffer = indptr_buffer
                    graph_runner.flashinfer_indices_buffer = indices_buffer
                    graph_runner.flashinfer_last_page_len_buffer = \
                        last_page_len_buffer
                    graph_runner.flashinfer_decode_workspace_buffer = \
                        decode_workspace_buffer
                    graph_runner.flashinfer_decode_wrapper = \
                        decode_wrapper

                graph_runner.capture(
                    input_tokens[:batch_size],
                    input_positions[:batch_size],
                    hidden_states[:batch_size]
@@ -918,11 +1037,12 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
        self,
        tensor_dict: Dict[str, Any],
    ) -> ModelInputForGPUWithSamplingMetadata:
        return (
        model_input = \
            ModelInputForGPUWithSamplingMetadata.from_broadcasted_tensor_dict(
                tensor_dict,
                attn_backend=self.attn_backend,
            ))
            )
        return model_input

    def prepare_model_input(
        self,
@@ -970,6 +1090,36 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
            self.set_active_loras(model_input.lora_requests,
                                  model_input.lora_mapping)

        if self.attn_backend.get_name() == "flashinfer":
            assert model_input.attn_metadata is not None
            assert model_input.input_tokens is not None
            if self.flashinfer_decode_workspace_buffer is None:
                self.flashinfer_decode_workspace_buffer = torch.empty(
                    FLASHINFER_WORKSPACE_BUFFER_SIZE,
                    dtype=torch.uint8,
                    device=self.device)
                self.flashinfer_decode_wrapper = \
                    BatchDecodeWithPagedKVCacheWrapper(
                    self.flashinfer_decode_workspace_buffer, "NHD")
                self.flashinfer_prefill_workspace_buffer = torch.empty(
                    FLASHINFER_WORKSPACE_BUFFER_SIZE,
                    dtype=torch.uint8,
                    device=self.device)
                self.flashinfer_prefill_wrapper = \
                    BatchPrefillWithPagedKVCacheWrapper(
                    self.flashinfer_prefill_workspace_buffer, "NHD")

            model_input.attn_metadata.prefill_wrapper = \
                self.flashinfer_prefill_wrapper
            if model_input.attn_metadata.use_cuda_graph:
                batch_size = model_input.input_tokens.shape[0]
                model_input.attn_metadata.decode_wrapper = self.graph_runners[
                    batch_size].flashinfer_decode_wrapper
            else:
                model_input.attn_metadata.decode_wrapper = \
                    self.flashinfer_decode_wrapper
            model_input.attn_metadata.begin_forward()

        # Currently cuda graph is only supported by the decode phase.
        assert model_input.attn_metadata is not None
        prefill_meta = model_input.attn_metadata.prefill_metadata
@@ -1020,13 +1170,22 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):

class CUDAGraphRunner:

    def __init__(self, model: nn.Module):
    def __init__(self, model: nn.Module, backend_name: str):
        self.model = model
        self.backend_name = backend_name

        self.input_buffers: Dict[str, torch.Tensor] = {}
        self.output_buffers: Dict[str, torch.Tensor] = {}

        self._graph: Optional[torch.cuda.CUDAGraph] = None

        self.flashinfer_decode_workspace_buffer: Optional[torch.Tensor] = None
        self.flashinfer_indptr_buffer: Optional[torch.Tensor] = None
        self.flashinfer_indices_buffer: Optional[torch.Tensor] = None
        self.flashinfer_last_page_len_buffer: Optional[torch.Tensor] = None
        self.flashinfer_decode_wrapper: Optional[
            CUDAGraphBatchDecodeWithPagedKVCacheWrapper] = None

    @property
    def graph(self):
        assert self._graph is not None
@@ -1079,14 +1238,23 @@ class CUDAGraphRunner:
        torch.cuda.synchronize()

        # Save the input and output buffers.
        self.input_buffers = {
            "input_ids": input_ids,
            "positions": positions,
            "kv_caches": kv_caches,
            "slot_mapping": attn_metadata.slot_mapping,
            "seq_lens_tensor": attn_metadata.decode_metadata.seq_lens_tensor,
            "block_tables": attn_metadata.decode_metadata.block_tables,
        }
        if self.backend_name == "flashinfer":
            self.input_buffers = {
                "input_ids": input_ids,
                "positions": positions,
                "kv_caches": kv_caches,
                "slot_mapping": attn_metadata.slot_mapping,
            }
        else:
            self.input_buffers = {
                "input_ids": input_ids,
                "positions": positions,
                "kv_caches": kv_caches,
                "slot_mapping": attn_metadata.slot_mapping,
                "seq_lens_tensor":
                attn_metadata.decode_metadata.seq_lens_tensor,
                "block_tables": attn_metadata.decode_metadata.block_tables,
            }
        self.output_buffers = {"hidden_states": hidden_states}
        return hidden_states
@@ -1106,10 +1274,12 @@ class CUDAGraphRunner:
        self.input_buffers["positions"].copy_(positions, non_blocking=True)
        self.input_buffers["slot_mapping"].copy_(attn_metadata.slot_mapping,
                                                 non_blocking=True)
        self.input_buffers["seq_lens_tensor"].copy_(
            attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True)
        self.input_buffers["block_tables"].copy_(
            attn_metadata.decode_metadata.block_tables, non_blocking=True)
        if self.backend_name != "flashinfer":
            self.input_buffers["seq_lens_tensor"].copy_(
                attn_metadata.decode_metadata.seq_lens_tensor,
                non_blocking=True)
            self.input_buffers["block_tables"].copy_(
                attn_metadata.decode_metadata.block_tables, non_blocking=True)
        # Run the graph.
        self.graph.replay()