[TPU] Implement prefix caching for TPUs (#10307)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
Woosuk Kwon
2024-11-20 13:54:15 -08:00
committed by GitHub
parent c68f7ede6a
commit 2f77b6cfec
4 changed files with 181 additions and 104 deletions

View File

@ -16,8 +16,8 @@ ray[default]
--find-links https://storage.googleapis.com/libtpu-releases/index.html
--find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
--find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
torch==2.6.0.dev20241028+cpu
torchvision==0.20.0.dev20241028+cpu
torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev20241028-cp310-cp310-linux_x86_64.whl
torch==2.6.0.dev20241114+cpu
torchvision==0.20.0.dev20241114+cpu
torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev20241114-cp310-cp310-linux_x86_64.whl
jaxlib==0.4.32.dev20240829
jax==0.4.32.dev20240829

View File

@ -65,6 +65,7 @@ class PallasMetadata(AttentionMetadata):
# or all decoding.
block_tables: Optional[torch.Tensor] = None
context_lens: Optional[torch.Tensor] = None
effective_query_lens: Optional[torch.Tensor] = None
@property
def prefill_metadata(self) -> Optional["PallasMetadata"]:
@ -72,8 +73,6 @@ class PallasMetadata(AttentionMetadata):
return None
assert self.num_decode_tokens == 0
assert self.block_tables is None
assert self.context_lens is None
return self
@property
@ -186,29 +185,50 @@ class PallasAttentionBackendImpl(AttentionImpl):
query = query * self.scale
if attn_metadata.num_prefills > 0:
assert seq_len % 16 == 0, (
"Pallas FlashAttention kernel requires seq_len to be a "
f"multiple of 16 but got {seq_len}")
if attn_metadata.block_tables is None:
# Prefill without paged KV cache.
assert seq_len % 16 == 0, (
"Pallas FlashAttention kernel requires seq_len to be a "
f"multiple of 16 but got {seq_len}")
# Handle GQA/MQA.
if self.num_kv_heads != self.num_heads:
key = key.repeat_interleave(self.num_queries_per_kv, dim=-2)
key = key.view(batch_size, seq_len, self.num_heads,
self.head_size)
value = value.repeat_interleave(self.num_queries_per_kv,
# Handle GQA/MQA.
if self.num_kv_heads != self.num_heads:
key = key.repeat_interleave(self.num_queries_per_kv,
dim=-2)
value = value.view(batch_size, seq_len, self.num_heads,
key = key.view(batch_size, seq_len, self.num_heads,
self.head_size)
# FlashAttention requires [batch_size, num_heads, seq_len, d_model]
# while the input is [batch_size, seq_len, num_heads, d_model].
# Permute the input to match the required format.
output = torch.ops.xla.flash_attention(
query.permute(0, 2, 1, 3),
key.permute(0, 2, 1, 3),
value.permute(0, 2, 1, 3),
True,
)
output = output.permute(0, 2, 1, 3)
value = value.repeat_interleave(self.num_queries_per_kv,
dim=-2)
value = value.view(batch_size, seq_len, self.num_heads,
self.head_size)
# FlashAttention kernel requires the input shape to be
# [batch_size, num_heads, seq_len, d_model]
# while the input is [batch_size, seq_len, num_heads, d_model].
# Permute the input to match the required format.
output = torch.ops.xla.flash_attention(
query.permute(0, 2, 1, 3),
key.permute(0, 2, 1, 3),
value.permute(0, 2, 1, 3),
True,
)
output = output.permute(0, 2, 1, 3)
else:
# Prefill with paged KV cache.
# TODO(woosuk): Tune the below knobs.
num_kv_pages_per_compute_block = 16
num_queries_per_compute_block = 16
assert seq_len % num_queries_per_compute_block == 0
output = torch.ops.xla.multi_queries_paged_attention(
query,
key_cache,
value_cache,
attn_metadata.context_lens,
attn_metadata.block_tables,
attn_metadata.effective_query_lens,
num_kv_pages_per_compute_block,
num_queries_per_compute_block,
use_kernel=True,
)
else:
# Decoding run.
assert kv_cache[0].numel() > 0

View File

@ -1,3 +1,4 @@
import enum
import time
from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple,
@ -11,7 +12,6 @@ import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.sampler import SamplerOutput
@ -39,6 +39,15 @@ _ENABLE_TOP_P = False
_MAX_NUM_SAMPLES = 128
class ExecutionMode(enum.Enum):
PREFILL = enum.auto()
DECODE = enum.auto()
PREFIX_PREFILL = enum.auto()
def is_prefill(self) -> bool:
return self in (ExecutionMode.PREFILL, ExecutionMode.PREFIX_PREFILL)
@dataclass(frozen=True)
class ModelInputForTPU(ModelRunnerInputBase):
token_ids: torch.Tensor
@ -140,16 +149,21 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
model = get_model(vllm_config=self.vllm_config)
model = model.eval()
xm.wait_device_ops()
self.model = ModelWrapper(model, self.vllm_config)
model = ModelWrapper(model)
self.model = torch.compile(model,
backend="openxla",
fullgraph=True,
dynamic=False)
def _dummy_run(
self,
batch_size: int,
seq_len: int,
kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
is_prompt: bool,
exec_mode: ExecutionMode,
) -> None:
if is_prompt:
exec_mode = ExecutionMode(exec_mode)
if exec_mode.is_prefill():
seq_len = (seq_len + 15) // 16 * 16
token_ids = torch.zeros((batch_size, seq_len),
dtype=torch.int32,
@ -160,18 +174,38 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
slot_mapping = torch.zeros((batch_size, seq_len),
dtype=torch.int64,
device=self.device)
attn_metadata = self.attn_backend.make_metadata(
num_prefills=batch_size,
num_prefill_tokens=batch_size * seq_len,
num_decode_tokens=0,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None,
block_tables=None,
context_lens=None,
)
input_lens = torch.ones((batch_size, ),
dtype=torch.int32,
device=self.device)
if exec_mode == ExecutionMode.PREFILL:
attn_metadata = self.attn_backend.make_metadata(
num_prefills=batch_size,
num_prefill_tokens=batch_size * seq_len,
num_decode_tokens=0,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None,
block_tables=None,
context_lens=None,
effective_query_lens=None,
)
else:
context_lens = torch.ones((batch_size, ),
dtype=torch.int32,
device=self.device)
block_tables = torch.tensor(self.block_tables[:batch_size],
dtype=torch.int32,
device=self.device)
effective_query_lens = torch.ones_like(context_lens)
attn_metadata = self.attn_backend.make_metadata(
num_prefills=batch_size,
num_prefill_tokens=batch_size * seq_len,
num_decode_tokens=0,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None,
block_tables=block_tables,
context_lens=context_lens,
effective_query_lens=effective_query_lens,
)
else:
assert seq_len == 1
token_ids = torch.zeros((batch_size, seq_len),
@ -204,7 +238,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
)
t = torch.ones((batch_size, ), dtype=torch.float32, device=self.device)
p = torch.ones((batch_size, ), dtype=torch.float32, device=self.device)
num_samples = _MAX_NUM_SAMPLES if is_prompt else 1
num_samples = _MAX_NUM_SAMPLES if exec_mode.is_prefill() else 1
# NOTE(woosuk): There are two stages of compilation: torch.compile and
# XLA compilation. Using `mark_dynamic` can reduce the torch.compile
@ -213,7 +247,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
# be re-compiled for every different shapes. This overhead is inevitable
# in the first run, but can be skipped afterwards as we cache the XLA
# graphs in the disk (VLLM_XLA_CACHE_PATH).
if is_prompt:
if exec_mode.is_prefill():
# Prefll
torch._dynamo.mark_dynamic(token_ids, 1)
torch._dynamo.mark_dynamic(position_ids, 1)
@ -229,15 +263,8 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
torch._dynamo.mark_dynamic(t, 0)
torch._dynamo.mark_dynamic(p, 0)
# Dummy run.
self.model(token_ids,
position_ids,
attn_metadata,
input_lens,
t,
p,
num_samples,
kv_caches,
is_prompt=is_prompt)
self.model(token_ids, position_ids, attn_metadata, input_lens, t, p,
num_samples, kv_caches)
def warmup_model(
self,
@ -248,13 +275,13 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
start = time.time()
for batch_size in [1]:
seq_len = 16
while True:
self._dummy_run(batch_size, seq_len, kv_caches, is_prompt=True)
while seq_len <= self.model_config.max_model_len:
self._dummy_run(batch_size,
seq_len,
kv_caches,
exec_mode=ExecutionMode.PREFILL)
xm.wait_device_ops()
logger.info("batch_size: %d, seq_len: %d", batch_size, seq_len)
if seq_len >= self.model_config.max_model_len:
break
num_tokens = batch_size * seq_len
if num_tokens >= self.scheduler_config.max_num_batched_tokens:
break
@ -263,12 +290,39 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
end = time.time()
logger.info("Compilation for prefill done in %.2f s.", end - start)
# Prefix prefill
if self.cache_config.enable_prefix_caching:
logger.info("Compiling the model with different input shapes for "
"prefix prefill...")
start = time.time()
for batch_size in [1]:
seq_len = 16
while seq_len <= self.model_config.max_model_len:
self._dummy_run(batch_size,
seq_len,
kv_caches,
exec_mode=ExecutionMode.PREFIX_PREFILL)
xm.wait_device_ops()
logger.info("batch_size: %d, seq_len: %d", batch_size,
seq_len)
num_tokens = batch_size * seq_len
if (num_tokens >=
self.scheduler_config.max_num_batched_tokens):
break
seq_len = seq_len * 2
end = time.time()
logger.info("Compilation for prefix prefill done in %.2f s.",
end - start)
# Decode
start = time.time()
seq_len = 1
batch_size = 8 # Must be in sync with _get_padded_batch_size()
while True:
self._dummy_run(batch_size, seq_len, kv_caches, is_prompt=False)
self._dummy_run(batch_size,
seq_len,
kv_caches,
exec_mode=ExecutionMode.DECODE)
xm.wait_device_ops()
logger.info("batch_size: %d, seq_len: %d", batch_size, seq_len)
@ -287,9 +341,11 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
input_tokens: List[int] = []
input_positions: List[int] = []
prompt_lens: List[int] = []
context_lens: List[int] = []
slot_mapping: List[int] = []
for seq_group_metadata in seq_group_metadata_list:
for batch_idx, seq_group_metadata in enumerate(
seq_group_metadata_list):
assert seq_group_metadata.is_prompt
seq_ids = list(seq_group_metadata.seq_data.keys())
assert len(seq_ids) == 1
@ -298,19 +354,31 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
seq_data = seq_group_metadata.seq_data[seq_id]
# Could include output tokens when a request is preempted.
prompt_tokens = seq_data.get_token_ids()
seq_len = len(prompt_tokens)
num_computed_blocks = len(seq_group_metadata.computed_block_nums)
num_computed_tokens = num_computed_blocks * self.block_size
if num_computed_tokens > 0:
prompt_tokens = prompt_tokens[num_computed_tokens:]
context_lens.append(seq_len)
else:
context_lens.append(0)
prompt_len = len(prompt_tokens)
prompt_lens.append(prompt_len)
input_tokens.extend(prompt_tokens)
input_positions.extend(list(range(prompt_len)))
input_positions.extend(range(num_computed_tokens, seq_len))
assert seq_group_metadata.block_tables is not None
block_table = seq_group_metadata.block_tables[seq_id]
for i in range(prompt_len):
for i in range(num_computed_tokens, seq_len):
block_number = block_table[i // self.block_size]
block_offset = i % self.block_size
slot = block_number * self.block_size + block_offset
slot_mapping.append(slot)
if num_computed_tokens > 0:
self.block_tables[batch_idx, :len(block_table)] = block_table
# Add paddings to EACH prompt to the smallest power of 2 that is
# greater than or equal to the prompt length.
@ -338,14 +406,21 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
prompt_lens = torch.tensor(prompt_lens,
dtype=torch.int32,
device="cpu")
context_lens = torch.tensor(context_lens,
dtype=torch.int32,
device="cpu")
block_tables = torch.tensor(self.block_tables[:num_prefills],
dtype=torch.int32,
device="cpu")
attn_metadata = self.attn_backend.make_metadata(
num_prefills=num_prefills,
num_prefill_tokens=0, # NOTE: This is not used.
num_decode_tokens=0,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None,
block_tables=None,
context_lens=None,
block_tables=block_tables,
context_lens=context_lens,
effective_query_lens=prompt_lens,
)
return input_tokens, input_positions, attn_metadata, prompt_lens
@ -550,6 +625,10 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
# process them separately. This is a temporary hack that should be
# optimized by using SplashAttention.
orig_slot_mapping = model_input.attn_metadata.slot_mapping
orig_block_tables = model_input.attn_metadata.block_tables
orig_context_lens = model_input.attn_metadata.context_lens
orig_effective_query_lens = \
model_input.attn_metadata.effective_query_lens
batch_size = model_input.input_lens.shape[0]
start_idx = 0
next_token_ids = []
@ -568,18 +647,24 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
attn_metadata.num_prefills = 1
attn_metadata.slot_mapping = orig_slot_mapping[
None, start_idx:end_idx].to(self.device)
if orig_context_lens[i].item() > 0:
attn_metadata.context_lens = orig_context_lens[i:i + 1].to(
self.device)
attn_metadata.block_tables = orig_block_tables[
i].unsqueeze(0).to(self.device)
attn_metadata.effective_query_lens = \
orig_effective_query_lens[i:i + 1].to(self.device)
else:
attn_metadata.context_lens = None
attn_metadata.block_tables = None
attn_metadata.effective_query_lens = None
input_lens = model_input.input_lens[i:i + 1].to(self.device)
t = model_input.t[i:i + 1].to(self.device)
p = model_input.p[i:i + 1].to(self.device)
output_token_ids = self.model(token_ids,
position_ids,
attn_metadata,
input_lens,
t,
p,
output_token_ids = self.model(token_ids, position_ids,
attn_metadata, input_lens, t, p,
model_input.num_samples,
kv_caches,
is_prompt=True)
kv_caches)
next_token_ids.append(output_token_ids[0])
start_idx = end_idx
@ -624,15 +709,10 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
input_lens = model_input.input_lens.to(self.device)
for i in range(num_steps):
slot_mapping = attn_metadata.slot_mapping
output_token_ids = self.model(token_ids,
position_ids,
attn_metadata,
input_lens,
t,
p,
output_token_ids = self.model(token_ids, position_ids,
attn_metadata, input_lens, t, p,
model_input.num_samples,
kv_caches,
is_prompt=False)
kv_caches)
self.cached_step_outputs.append(output_token_ids)
if i < num_steps - 1:
@ -667,34 +747,11 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
return [sampler_output]
class ModelWrapper(TorchCompileWrapperWithCustomDispatcher):
class ModelWrapper(nn.Module):
def __init__(self, model: nn.Module, vllm_config: VllmConfig):
def __init__(self, model: nn.Module):
super().__init__()
self.model = model
compiled_callable = torch.compile(self.forward,
backend="openxla",
fullgraph=True,
dynamic=False)
super().__init__(
compiled_callable,
compilation_level=vllm_config.compilation_config.level)
def __call__(self, *args, is_prompt: bool, **kwargs):
if len(self.compiled_codes) < 3 or not self.use_custom_dispatcher:
# not fully compiled yet, or not using the custom dispatcher,
# let PyTorch handle it
return self.compiled_callable(*args, **kwargs)
# the 3 compiled codes are:
# 0: for profiling
# 1: for prompt
# 2: for decode
# dispatch to the compiled code directly, skip PyTorch
if is_prompt:
with self.dispatch_to_code(1):
return self.forward(*args, **kwargs)
else:
with self.dispatch_to_code(2):
return self.forward(*args, **kwargs)
def forward(
self,

View File

@ -13,7 +13,7 @@ from vllm.logger import init_logger
from vllm.model_executor import set_random_seed
from vllm.sequence import ExecuteModelRequest
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size
from vllm.worker.tpu_model_runner import TPUModelRunner
from vllm.worker.tpu_model_runner import ExecutionMode, TPUModelRunner
from vllm.worker.worker_base import (LocalOrDistributedWorkerBase,
LoraNotSupportedWorkerBase, WorkerBase,
WorkerInput)
@ -112,7 +112,7 @@ class TPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
batch_size=1,
seq_len=self.scheduler_config.max_num_batched_tokens,
kv_caches=kv_caches,
is_prompt=True,
exec_mode=ExecutionMode.PREFILL,
)
# Synchronize before measuring the memory usage.
xm.wait_device_ops()