[Kernel] Flashinfer for prefill & decode, with Cudagraph support for decode (#4628)

Co-authored-by: LiuXiaoxuanPKU <llilyliupku@gmail.com>
Co-authored-by: bong-furiosa <bongwon.jang@furiosa.ai>
Author: Lily Liu
Date: 2024-06-28 15:28:49 -07:00
Committed by: GitHub
parent 6a62cb82cc
commit 7041de4384
7 changed files with 313 additions and 117 deletions

View File

@ -211,3 +211,6 @@ steps:
- pytest -v -s distributed/test_custom_all_reduce.py
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py
- pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.5/flashinfer-0.0.5+cu121torch2.3-cp310-cp310-linux_x86_64.whl
- VLLM_ATTENTION_BACKEND=FLASHINFER TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
- VLLM_ATTENTION_BACKEND=FLASHINFER TEST_DIST_MODEL=meta-llama/Meta-Llama-3-8B DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
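
The new CI steps install the FlashInfer wheel and rerun the distributed correctness tests with VLLM_ATTENTION_BACKEND=FLASHINFER. For reference, a minimal sketch (outside this diff; the model and prompt are placeholders taken from the tests) of selecting the same backend from Python via that environment variable:

# Illustrative sketch: pick the FlashInfer backend the same way the CI
# steps above do, via the VLLM_ATTENTION_BACKEND environment variable.
import os
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER"

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")  # placeholder model used in the tests
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)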

View File

@ -19,4 +19,4 @@ sentence-transformers # required for embedding
aiohttp
# quantization
bitsandbytes==0.42.0
bitsandbytes==0.42.0

View File

@ -2,7 +2,6 @@
Run `pytest tests/basic_correctness/test_basic_correctness.py`.
"""
import os
import weakref
import pytest
@ -13,7 +12,6 @@ MODELS = [
"facebook/opt-125m",
"meta-llama/Llama-2-7b-hf",
]
VLLM_ATTENTION_BACKEND = "VLLM_ATTENTION_BACKEND"
def test_vllm_gc_ed():
@ -39,10 +37,6 @@ def test_models(
max_tokens: int,
enforce_eager: bool,
) -> None:
backend_by_env_var = os.getenv(VLLM_ATTENTION_BACKEND)
if backend_by_env_var == "FLASHINFER" and enforce_eager is False:
pytest.skip("Skipping non-eager test for FlashInferBackend.")
with hf_runner(model, dtype=dtype) as hf_model:
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)

View File

@ -21,7 +21,6 @@ MODELS = [
os.environ["TEST_DIST_MODEL"],
]
DISTRIBUTED_EXECUTOR_BACKEND = "DISTRIBUTED_EXECUTOR_BACKEND"
VLLM_ATTENTION_BACKEND = "VLLM_ATTENTION_BACKEND"
@pytest.mark.skipif(torch.cuda.device_count() < 2,
@ -39,16 +38,12 @@ def test_models(
) -> None:
distributed_executor_backend = os.getenv(DISTRIBUTED_EXECUTOR_BACKEND)
backend_by_env_var = os.getenv(VLLM_ATTENTION_BACKEND)
enforce_eager = backend_by_env_var == "FLASHINFER"
with hf_runner(model, dtype=dtype) as hf_model:
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
with vllm_runner(model,
dtype=dtype,
tensor_parallel_size=2,
enforce_eager=enforce_eager,
distributed_executor_backend=distributed_executor_backend
) as vllm_model:
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)

View File

@ -1,10 +1,16 @@
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Set, Tuple, Type
import flashinfer
try:
from flashinfer import BatchDecodeWithPagedKVCacheWrapper
from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper
from vllm_flash_attn import flash_attn_varlen_func
except ImportError:
flash_attn_varlen_func = None
BatchDecodeWithPagedKVCacheWrapper = None
BatchPrefillWithPagedKVCacheWrapper = None
import torch
from flashinfer import BatchDecodeWithPagedKVCacheWrapper
from vllm_flash_attn import flash_attn_varlen_func
from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
@ -60,19 +66,16 @@ class FlashInferMetadata(AttentionMetadata):
# requests only.
max_prefill_seq_len: int
use_cuda_graph: bool = False
use_cuda_graph: bool = True
prefill_wrapper: Optional[BatchPrefillWithPagedKVCacheWrapper] = None
decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None
# Metadata for the prefill stage since we still
# use flash attention for prefill.
# Metadata for the prefill stage
seq_start_loc: Optional[torch.Tensor] = None
query_start_loc: Optional[torch.Tensor] = None
block_tables: Optional[torch.Tensor] = None
# Metadata for the decode stage
# Workspace buffer required by the kernel; the buffer should not
# be allocated/deallocated by the FlashInferMetadata object.
workspace_buffer: Optional[torch.Tensor] = None
# An example for paged_kv_indices, paged_kv_indptr:
# request 1, page indices [0, 5, 8]
# request 2, page indices [1, 6, 7]
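
To make the comment above concrete, a small illustrative sketch (outside this diff, with assumed sequence lengths) of how those page indices flatten into paged_kv_indices, paged_kv_indptr, and paged_kv_last_page_len:

# Illustrative layout of the paged KV arrays for the two example requests
# above, assuming page_size = 16 and sequence lengths of 40 and 33 tokens.
import torch

page_size = 16
pages_per_request = [[0, 5, 8], [1, 6, 7]]
seq_lens = [40, 33]

# Flat concatenation of every request's page indices.
paged_kv_indices = torch.tensor(
    [p for pages in pages_per_request for p in pages], dtype=torch.int32)
# Prefix sums delimiting each request's slice of paged_kv_indices.
paged_kv_indptr = torch.tensor([0, 3, 6], dtype=torch.int32)
# Number of valid tokens in each request's last page (1..page_size).
paged_kv_last_page_len = torch.tensor(
    [(s - 1) % page_size + 1 for s in seq_lens], dtype=torch.int32)

assert paged_kv_indices.tolist() == [0, 5, 8, 1, 6, 7]
assert paged_kv_last_page_len.tolist() == [8, 1]
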
@ -98,6 +101,7 @@ class FlashInferMetadata(AttentionMetadata):
page_size: Optional[int] = None
# The data type of the paged kv cache
data_type: torch.dtype = None
device: torch.device = torch.device("cuda")
def __post_init__(self):
# Refer to
@ -109,13 +113,35 @@ class FlashInferMetadata(AttentionMetadata):
f"Only {supported_head_sizes} are supported for head_dim,",
f"received {self.head_dim}.")
# When using flashinfer, we also create the FlashInferMetadata,
# which calls __post_init__ by default; here we want to skip the
# post_init work if it's the prefill phase.
if self.num_prefills == 0:
assert self.num_decode_tokens > 0
self.decode_wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
self.workspace_buffer, "NHD")
def begin_forward(self):
if self.num_prefill_tokens > 0:
if self.paged_kv_indices is None:
return
assert self.prefill_wrapper is not None
assert self.paged_kv_indices is not None
assert self.paged_kv_indptr is not None
assert self.paged_kv_last_page_len is not None
self.paged_kv_indices = self.paged_kv_indices.to(self.device)
self.paged_kv_indptr = self.paged_kv_indptr.to(self.device)
self.paged_kv_last_page_len = self.paged_kv_last_page_len.to(
self.device)
self.prefill_wrapper.begin_forward(
self.query_start_loc, self.paged_kv_indptr,
self.paged_kv_indices, self.paged_kv_last_page_len,
self.num_qo_heads, self.num_kv_heads, self.head_dim,
self.page_size)
else:
if not self.use_cuda_graph:
assert self.paged_kv_indices is not None
assert self.paged_kv_indptr is not None
assert self.paged_kv_last_page_len is not None
self.paged_kv_indices = self.paged_kv_indices.to(self.device)
self.paged_kv_indptr = self.paged_kv_indptr.to(self.device)
self.paged_kv_last_page_len = self.paged_kv_last_page_len.to(
self.device)
assert self.decode_wrapper is not None
self.decode_wrapper.begin_forward(
self.paged_kv_indptr,
self.paged_kv_indices,
@ -133,8 +159,9 @@ class FlashInferMetadata(AttentionMetadata):
) -> Dict[str, Any]:
if skip_fields is None:
skip_fields = set()
# We need to skip the decode_wrapper field since it cannot be
# We need to skip the prefill/decode_wrapper field since it cannot be
# broadcasted with nccl when TP is enabled.
skip_fields.add('prefill_wrapper')
skip_fields.add('decode_wrapper')
return super().asdict_zerocopy(skip_fields)
@ -168,6 +195,7 @@ class FlashInferImpl(AttentionImpl):
alibi_slopes: Optional[List[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None,
) -> None:
self.num_heads = num_heads
self.head_size = head_size
@ -217,10 +245,14 @@ class FlashInferImpl(AttentionImpl):
self.kv_cache_dtype,
)
query = query.contiguous(
) # Flashinfer requires query to be contiguous
if prefill_meta := attn_metadata.prefill_metadata:
# Prompt run.
assert prefill_meta.block_tables is not None
if kv_cache is None or prefill_meta.block_tables.numel() == 0:
# We will use flash attention for prefill
# when kv_cache is not provided.
# This happens when vllm runs the profiling to
# determine the number of blocks.
if kv_cache is None:
output = flash_attn_varlen_func(
q=query,
k=key,
@ -235,13 +267,14 @@ class FlashInferImpl(AttentionImpl):
alibi_slopes=self.alibi_slopes,
)
else:
raise NotImplementedError(
"Prefix caching is not supported with flashinfer yet.")
assert prefill_meta is not None
assert prefill_meta.prefill_wrapper is not None
output = prefill_meta.prefill_wrapper.forward(query,
kv_cache,
causal=True)
else:
assert attn_metadata.decode_metadata is not None
assert attn_metadata.decode_metadata.decode_wrapper is not None
query = query.contiguous(
) # Flashinfer requires query to be contiguous
output = attn_metadata.decode_metadata.decode_wrapper.forward(
query,
kv_cache,
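
Pulling the fragmented hunks of FlashInferImpl.forward together, a condensed, illustrative sketch of the dispatch (the reshape-and-cache step is omitted, and the varlen metadata arguments are assumed to be passed through as keyword arguments):

# Condensed sketch of the prefill/decode dispatch wired up above.
from vllm_flash_attn import flash_attn_varlen_func

def flashinfer_attention_sketch(query, key, value, kv_cache, attn_metadata,
                                **flash_attn_kwargs):
    # FlashInfer requires the query to be contiguous.
    query = query.contiguous()
    prefill_meta = attn_metadata.prefill_metadata
    if prefill_meta is not None:
        if kv_cache is None:
            # Memory-profiling run: no KV cache allocated yet, so fall back
            # to the flash-attn varlen kernel (metadata args passed through).
            return flash_attn_varlen_func(q=query, k=key, v=value,
                                          causal=True, **flash_attn_kwargs)
        # Regular prefill goes through the paged prefill wrapper.
        return prefill_meta.prefill_wrapper.forward(query, kv_cache,
                                                    causal=True)
    # Decode goes through the paged decode wrapper (possibly the CUDA-graph
    # variant captured ahead of time).
    return attn_metadata.decode_metadata.decode_wrapper.forward(
        query, kv_cache)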

View File

@ -77,8 +77,9 @@ def get_attn_backend(
return IpexAttnBackend
elif backend == _Backend.FLASHINFER:
logger.info("Using Flashinfer backend.")
logger.warning("Eager mode is required for the Flashinfer backend. "
"Please make sure --enforce-eager is set.")
logger.warning(("Flashinfer will be stuck on llma-2-7b,"
" please avoid using Flashinfer as the"
"backend when running on llma-2-7b."))
from vllm.attention.backends.flashinfer import FlashInferBackend
return FlashInferBackend
elif backend == _Backend.PALLAS:

View File

@ -10,6 +10,17 @@ import numpy as np
import torch
import torch.nn as nn
try:
from flashinfer import BatchDecodeWithPagedKVCacheWrapper
from flashinfer.decode import CUDAGraphBatchDecodeWithPagedKVCacheWrapper
from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper
FLASHINFER_WORKSPACE_BUFFER_SIZE = 128 * 1024 * 1024
except ImportError:
BatchDecodeWithPagedKVCacheWrapper = None
CUDAGraphBatchDecodeWithPagedKVCacheWrapper = None
BatchPrefillWithPagedKVCacheWrapper = None
FLASHINFER_WORKSPACE_BUFFER_SIZE = 0
from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ParallelConfig, SchedulerConfig,
@ -198,11 +209,14 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
# Lazy initialization
self.model: nn.Module # Set after load_model
# Set if the backend is flashinfer.
self.flashinfer_workspace_buffer: torch.Tensor
# Set after load_model.
self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None
self.flashinfer_decode_workspace_buffer = None
self.flashinfer_decode_wrapper = None
self.flashinfer_prefill_workspace_buffer = None
self.flashinfer_prefill_wrapper = None
def load_model(self) -> None:
with CudaMemoryProfiler() as m:
self.model = get_model(
@ -450,15 +464,6 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
if curr_sliding_window_blocks is not None:
block_table = block_table[
-curr_sliding_window_blocks:]
if self.attn_backend.get_name() == "flashinfer":
paged_kv_indices.extend(block_table)
paged_kv_indptr.append(paged_kv_indptr[-1] +
len(block_table))
last_page_len = seq_data.get_len(
) % self.block_size
if last_page_len == 0:
last_page_len = self.block_size
paged_kv_last_page_len.append(last_page_len)
else:
# Only happens when memory profiling runs.
block_table = []
@ -505,7 +510,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
for k, v in mm_kwargs.items():
multi_modal_kwargs_list[k].append(v)
if _is_block_tables_empty(seq_group_metadata.block_tables):
is_profile_run = _is_block_tables_empty(
seq_group_metadata.block_tables)
if is_profile_run:
# During memory profiling, the block tables are not
# initialized yet. In this case, we just use a dummy
# slot mapping.
@ -544,6 +551,27 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
slot = block_number * self.block_size + block_offset
slot_mapping.append(slot)
# Prepare input tensors for flashinfer
if self.attn_backend.get_name() == "flashinfer":
seq_len = seq_data.get_len()
# Get the number of valid blocks based on sequence length.
# If seq_len = 16, block_size = 16,
# block_table_bound is 1 with 1 valid block.
# If seq_len = 15, block_size = 16,
# block_table_bound is 0 + 1 with 1 valid block.
block_table_bound = seq_len // self.block_size + 1 \
if seq_len % self.block_size != 0 \
else seq_len // self.block_size
paged_kv_indices.extend(block_table[:block_table_bound])
paged_kv_indptr.append(paged_kv_indptr[-1] +
block_table_bound)
last_page_len = seq_len % self.block_size
if last_page_len == 0:
last_page_len = self.block_size
paged_kv_last_page_len.append(last_page_len)
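# Illustrative extra case (not in the diff): seq_len = 17, block_size = 16
# gives block_table_bound = 2 and last_page_len = 1, i.e. the second
# block holds a single valid token.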
batch_size = len(input_tokens)
max_query_len = max(query_lens)
max_prefill_seq_len = max(prefill_seq_lens, default=0)
@ -566,6 +594,12 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
seq_lens.append(1)
block_tables.append([])
lora_index_mapping.append(0)
if self.attn_backend.get_name() == "flashinfer":
last_paged_kv_indptr = paged_kv_indptr[-1]
paged_kv_indptr.append(last_paged_kv_indptr)
paged_kv_last_page_len.append(0)
batch_size = graph_batch_size
num_decode_tokens = batch_size
@ -589,9 +623,19 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
)
assert max_query_len > 0, ("query_lens: {}".format(query_lens))
context_lens_tensor = torch.tensor(context_lens,
dtype=torch.int,
device=self.device)
seq_lens_tensor = torch.tensor(seq_lens,
dtype=torch.int,
device=self.device)
query_lens_tensor = torch.tensor(query_lens,
dtype=torch.long,
device=self.device)
query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
dtype=torch.int32,
device=self.device)
seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
dtype=torch.int32,
device=self.device)
@ -600,6 +644,10 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
dim=0,
dtype=seq_start_loc.dtype,
out=seq_start_loc[1:])
torch.cumsum(query_lens_tensor,
dim=0,
dtype=query_start_loc.dtype,
out=query_start_loc[1:])
input_tokens_tensor = torch.tensor(input_tokens,
dtype=torch.long,
@ -612,30 +660,30 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
device=self.device)
if self.attn_backend.get_name() == "flashinfer":
if not hasattr(self, "flashinfer_workspace_buffer"):
# Allocate 16MB workspace buffer
# Follow the example of flashinfer: https://docs.flashinfer.ai/api/python/decode.html
self.flashinfer_workspace_buffer = torch.empty(
16 * 1024 * 1024, dtype=torch.uint8, device=self.device)
paged_kv_indptr_tensor = torch.tensor(paged_kv_indptr,
dtype=torch.int,
device=self.device)
paged_kv_indices_tensor = torch.tensor(paged_kv_indices,
dtype=torch.int,
device=self.device)
paged_kv_last_page_len_tensor = torch.tensor(
paged_kv_last_page_len, dtype=torch.int, device=self.device)
if len(paged_kv_indptr) > 0:
paged_kv_indices_tensor = torch.tensor(paged_kv_indices,
device='cpu',
dtype=torch.int)
paged_kv_indptr_tensor = torch.tensor(paged_kv_indptr,
device='cpu',
dtype=torch.int)
paged_kv_last_page_len_tensor = torch.tensor(
paged_kv_last_page_len, device='cpu', dtype=torch.int)
else:
paged_kv_indices_tensor = None
paged_kv_indptr_tensor = None
paged_kv_last_page_len_tensor = None
kv_cache_dtype = get_kv_cache_torch_dtype(self.kv_cache_dtype,
self.model_config.dtype)
attn_metadata = self.attn_backend.make_metadata(
num_prefills=num_prefills,
slot_mapping=slot_mapping_tensor,
num_prefill_tokens=num_prefill_tokens,
num_decode_tokens=num_decode_tokens,
use_cuda_graph=False,
max_prefill_seq_len=max_prefill_seq_len,
block_tables=block_tables,
workspace_buffer=self.flashinfer_workspace_buffer,
paged_kv_indptr=paged_kv_indptr_tensor,
paged_kv_indices=paged_kv_indices_tensor,
paged_kv_last_page_len=paged_kv_last_page_len_tensor,
@ -644,25 +692,14 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
num_kv_heads=self.model_config.get_num_kv_heads(
self.parallel_config),
head_dim=self.model_config.get_head_size(),
page_size=16,
page_size=self.block_size,
seq_start_loc=seq_start_loc,
data_type=kv_cache_dtype)
query_start_loc=query_start_loc,
device=self.device,
data_type=kv_cache_dtype,
use_cuda_graph=use_captured_graph)
else:
context_lens_tensor = torch.tensor(context_lens,
dtype=torch.int,
device=self.device)
query_lens_tensor = torch.tensor(query_lens,
dtype=torch.long,
device=self.device)
query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
dtype=torch.int32,
device=self.device)
torch.cumsum(query_lens_tensor,
dim=0,
dtype=query_start_loc.dtype,
out=query_start_loc[1:])
attn_metadata = self.attn_backend.make_metadata(
num_prefills=num_prefills,
slot_mapping=slot_mapping_tensor,
@ -854,27 +891,97 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size
]
if self.attn_backend.get_name() == "flashinfer":
# For flashinfer, different batch sizes will share the
# same workspace buffer.
decode_workspace_buffer = \
torch.empty(FLASHINFER_WORKSPACE_BUFFER_SIZE,
dtype=torch.uint8,
device=self.device)
indices_buffer = torch.empty(max_batch_size *
self.cache_config.num_gpu_blocks,
dtype=torch.int32,
device=self.device)
indptr_buffer = torch.empty(max_batch_size + 1,
dtype=torch.int32,
device=self.device)
last_page_len_buffer = torch.empty(max_batch_size,
dtype=torch.int32,
device=self.device)
with graph_capture() as graph_capture_context:
# NOTE: Capturing the largest batch size first may help reduce the
# memory usage of CUDA graph.
for batch_size in reversed(batch_size_capture_list):
# Create dummy attn_metadata.
attn_metadata = self.attn_backend.make_metadata(
num_prefills=0,
num_prefill_tokens=0,
num_decode_tokens=batch_size,
slot_mapping=slot_mapping[:batch_size],
seq_lens=None,
seq_lens_tensor=seq_lens[:batch_size],
max_query_len=None,
max_prefill_seq_len=0,
max_decode_seq_len=self.max_seq_len_to_capture,
query_start_loc=None,
seq_start_loc=None,
context_lens_tensor=None,
block_tables=block_tables[:batch_size],
use_cuda_graph=True,
)
if self.attn_backend.get_name() == "flashinfer":
indptr_buffer = indptr_buffer[:batch_size + 1]
last_page_len_buffer = last_page_len_buffer[:batch_size]
num_qo_heads = self.model_config.get_num_attention_heads(
self.parallel_config)
num_kv_heads = self.model_config.get_num_kv_heads(
self.parallel_config)
if num_qo_heads // num_kv_heads >= 4:
use_tensor_cores = True
else:
use_tensor_cores = False
decode_wrapper = \
CUDAGraphBatchDecodeWithPagedKVCacheWrapper(
decode_workspace_buffer, indptr_buffer, indices_buffer,
last_page_len_buffer, "NHD", use_tensor_cores)
kv_cache_dtype = get_kv_cache_torch_dtype(
self.kv_cache_dtype, self.model_config.dtype)
paged_kv_indptr_tensor_host = torch.arange(
0, batch_size + 1, dtype=torch.int32)
paged_kv_indices_tensor_host = torch.arange(
0, batch_size, dtype=torch.int32)
paged_kv_last_page_len_tensor_host = torch.full(
(batch_size, ), self.block_size, dtype=torch.int32)
query_start_loc_host = torch.arange(0,
batch_size + 1,
dtype=torch.int32)
attn_metadata = self.attn_backend.make_metadata(
num_prefills=0,
slot_mapping=slot_mapping[:batch_size],
num_prefill_tokens=0,
num_decode_tokens=batch_size,
max_prefill_seq_len=0,
block_tables=block_tables,
paged_kv_indptr=paged_kv_indptr_tensor_host,
paged_kv_indices=paged_kv_indices_tensor_host,
paged_kv_last_page_len=
paged_kv_last_page_len_tensor_host,
num_qo_heads=num_qo_heads,
num_kv_heads=num_kv_heads,
head_dim=self.model_config.get_head_size(),
page_size=self.block_size,
seq_start_loc=None,
query_start_loc=query_start_loc_host,
device=self.device,
data_type=kv_cache_dtype,
use_cuda_graph=True,
decode_wrapper=decode_wrapper,
prefill_wrapper=None)
attn_metadata.begin_forward()
else:
attn_metadata = self.attn_backend.make_metadata(
num_prefills=0,
num_prefill_tokens=0,
num_decode_tokens=batch_size,
slot_mapping=slot_mapping[:batch_size],
seq_lens=None,
seq_lens_tensor=seq_lens[:batch_size],
max_query_len=None,
max_prefill_seq_len=0,
max_decode_seq_len=self.max_seq_len_to_capture,
query_start_loc=None,
seq_start_loc=None,
context_lens_tensor=None,
block_tables=block_tables[:batch_size],
use_cuda_graph=True,
)
if self.lora_config:
lora_mapping = LoRAMapping(
@ -883,8 +990,20 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
)
self.set_active_loras(set(), lora_mapping)
graph_runner = CUDAGraphRunner(self.model)
hidden_states = graph_runner.capture(
graph_runner = CUDAGraphRunner(self.model,
self.attn_backend.get_name())
if self.attn_backend.get_name() == "flashinfer":
graph_runner.flashinfer_indptr_buffer = indptr_buffer
graph_runner.flashinfer_indices_buffer = indices_buffer
graph_runner.flashinfer_last_page_len_buffer = \
last_page_len_buffer
graph_runner.flashinfer_decode_workspace_buffer = \
decode_workspace_buffer
graph_runner.flashinfer_decode_wrapper = \
decode_wrapper
graph_runner.capture(
input_tokens[:batch_size],
input_positions[:batch_size],
hidden_states[:batch_size]
@ -918,11 +1037,12 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
self,
tensor_dict: Dict[str, Any],
) -> ModelInputForGPUWithSamplingMetadata:
return (
model_input = \
ModelInputForGPUWithSamplingMetadata.from_broadcasted_tensor_dict(
tensor_dict,
attn_backend=self.attn_backend,
))
)
return model_input
def prepare_model_input(
self,
@ -970,6 +1090,36 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
self.set_active_loras(model_input.lora_requests,
model_input.lora_mapping)
if self.attn_backend.get_name() == "flashinfer":
assert model_input.attn_metadata is not None
assert model_input.input_tokens is not None
if self.flashinfer_decode_workspace_buffer is None:
self.flashinfer_decode_workspace_buffer = torch.empty(
FLASHINFER_WORKSPACE_BUFFER_SIZE,
dtype=torch.uint8,
device=self.device)
self.flashinfer_decode_wrapper = \
BatchDecodeWithPagedKVCacheWrapper(
self.flashinfer_decode_workspace_buffer, "NHD")
self.flashinfer_prefill_workspace_buffer = torch.empty(
FLASHINFER_WORKSPACE_BUFFER_SIZE,
dtype=torch.uint8,
device=self.device)
self.flashinfer_prefill_wrapper = \
BatchPrefillWithPagedKVCacheWrapper(
self.flashinfer_prefill_workspace_buffer, "NHD")
model_input.attn_metadata.prefill_wrapper = \
self.flashinfer_prefill_wrapper
if model_input.attn_metadata.use_cuda_graph:
batch_size = model_input.input_tokens.shape[0]
model_input.attn_metadata.decode_wrapper = self.graph_runners[
batch_size].flashinfer_decode_wrapper
else:
model_input.attn_metadata.decode_wrapper = \
self.flashinfer_decode_wrapper
model_input.attn_metadata.begin_forward()
# Currently, CUDA graphs are only supported for the decode phase.
assert model_input.attn_metadata is not None
prefill_meta = model_input.attn_metadata.prefill_metadata
@ -1020,13 +1170,22 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
class CUDAGraphRunner:
def __init__(self, model: nn.Module):
def __init__(self, model: nn.Module, backend_name: str):
self.model = model
self.backend_name = backend_name
self.input_buffers: Dict[str, torch.Tensor] = {}
self.output_buffers: Dict[str, torch.Tensor] = {}
self._graph: Optional[torch.cuda.CUDAGraph] = None
self.flashinfer_decode_workspace_buffer: Optional[torch.Tensor] = None
self.flashinfer_indptr_buffer: Optional[torch.Tensor] = None
self.flashinfer_indices_buffer: Optional[torch.Tensor] = None
self.flashinfer_last_page_len_buffer: Optional[torch.Tensor] = None
self.flashinfer_decode_wrapper: Optional[
CUDAGraphBatchDecodeWithPagedKVCacheWrapper] = None
@property
def graph(self):
assert self._graph is not None
@ -1079,14 +1238,23 @@ class CUDAGraphRunner:
torch.cuda.synchronize()
# Save the input and output buffers.
self.input_buffers = {
"input_ids": input_ids,
"positions": positions,
"kv_caches": kv_caches,
"slot_mapping": attn_metadata.slot_mapping,
"seq_lens_tensor": attn_metadata.decode_metadata.seq_lens_tensor,
"block_tables": attn_metadata.decode_metadata.block_tables,
}
if self.backend_name == "flashinfer":
self.input_buffers = {
"input_ids": input_ids,
"positions": positions,
"kv_caches": kv_caches,
"slot_mapping": attn_metadata.slot_mapping,
}
else:
self.input_buffers = {
"input_ids": input_ids,
"positions": positions,
"kv_caches": kv_caches,
"slot_mapping": attn_metadata.slot_mapping,
"seq_lens_tensor":
attn_metadata.decode_metadata.seq_lens_tensor,
"block_tables": attn_metadata.decode_metadata.block_tables,
}
self.output_buffers = {"hidden_states": hidden_states}
return hidden_states
@ -1106,10 +1274,12 @@ class CUDAGraphRunner:
self.input_buffers["positions"].copy_(positions, non_blocking=True)
self.input_buffers["slot_mapping"].copy_(attn_metadata.slot_mapping,
non_blocking=True)
self.input_buffers["seq_lens_tensor"].copy_(
attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True)
self.input_buffers["block_tables"].copy_(
attn_metadata.decode_metadata.block_tables, non_blocking=True)
if self.backend_name != "flashinfer":
self.input_buffers["seq_lens_tensor"].copy_(
attn_metadata.decode_metadata.seq_lens_tensor,
non_blocking=True)
self.input_buffers["block_tables"].copy_(
attn_metadata.decode_metadata.block_tables, non_blocking=True)
# Run the graph.
self.graph.replay()
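
Finally, tying together the FlashInfer pieces of the CUDA-graph capture path above (a shared workspace buffer, per-batch-size slices of the indptr/indices/last_page_len buffers, and the tensor-core heuristic), a minimal illustrative sketch, assuming the buffer sizes used in this diff:

# Illustrative setup of the CUDA-graph decode wrapper as captured above.
import torch
from flashinfer.decode import CUDAGraphBatchDecodeWithPagedKVCacheWrapper

FLASHINFER_WORKSPACE_BUFFER_SIZE = 128 * 1024 * 1024

def build_cuda_graph_decode_wrapper(max_batch_size, batch_size,
                                    num_gpu_blocks, num_qo_heads,
                                    num_kv_heads, device="cuda"):
    # Buffers are allocated once at the maximum batch size and shared by
    # every captured graph; only views are handed to each wrapper.
    workspace = torch.empty(FLASHINFER_WORKSPACE_BUFFER_SIZE,
                            dtype=torch.uint8, device=device)
    indices = torch.empty(max_batch_size * num_gpu_blocks,
                          dtype=torch.int32, device=device)
    indptr = torch.empty(max_batch_size + 1, dtype=torch.int32,
                         device=device)
    last_page_len = torch.empty(max_batch_size, dtype=torch.int32,
                                device=device)
    # The tensor-core decode kernel only pays off for wide GQA ratios.
    use_tensor_cores = num_qo_heads // num_kv_heads >= 4
    return CUDAGraphBatchDecodeWithPagedKVCacheWrapper(
        workspace, indptr[:batch_size + 1], indices,
        last_page_len[:batch_size], "NHD", use_tensor_cores)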