Optimize decode/prompt prepare code

Alexander Matveev
2025-02-04 21:12:07 +00:00
parent 39c4a4cdb5
commit c2867d5bc1
3 changed files with 358 additions and 265 deletions

View File

@@ -57,6 +57,14 @@ class BlockTable:
                                                         src, :num_blocks]
         self.num_blocks_per_row[tgt] = num_blocks

+    def swap_row(self, src: int, tgt: int) -> None:
+        num_blocks_src = self.num_blocks_per_row[src]
+        num_blocks_tgt = self.num_blocks_per_row[tgt]
+        self.num_blocks_per_row[src] = num_blocks_tgt
+        self.num_blocks_per_row[tgt] = num_blocks_src
+
+        self.block_table_np[[src, tgt]] = self.block_table_np[[tgt, src]]
+
     def commit(self, num_reqs: int) -> None:
         self.block_table[:num_reqs].copy_(self.block_table_cpu[:num_reqs],
                                           non_blocking=True)
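The new swap_row exchanges two rows of the CPU block table with NumPy fancy indexing: the right-hand side block_table_np[[tgt, src]] is materialized as a copy before the assignment, so a single statement swaps the rows in place. A minimal standalone sketch of that idiom in plain NumPy (not the BlockTable API itself):

import numpy as np

# Toy block table: 4 requests, up to 3 blocks each.
block_table = np.array([
    [10, 11, 12],
    [20, 21, 22],
    [30, 31, 32],
    [40, 41, 42],
], dtype=np.int32)

src, tgt = 0, 2
# Fancy indexing on the right builds a copy of rows [tgt, src],
# so assigning it back swaps the two rows without a temporary.
block_table[[src, tgt]] = block_table[[tgt, src]]

print(block_table[0])  # [30 31 32]
print(block_table[2])  # [10 11 12]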

View File

@@ -436,3 +436,72 @@ class InputBatch:
     @property
     def no_prompt_logprob(self) -> bool:
         return len(self.prompt_logprob_reqs) == 0
+
+
+def swap_positions(b: InputBatch, id_1, id_2):
+    assert id_1 != id_2
+    req_id_1 = b.req_ids[id_1]
+    req_id_2 = b.req_ids[id_2]
+    assert req_id_1 is not None
+    assert req_id_2 is not None
+    assert id_1 == b.req_id_to_index[req_id_1]
+    assert id_2 == b.req_id_to_index[req_id_2]
+
+    b.req_ids[id_1], b.req_ids[id_2] = b.req_ids[id_2], b.req_ids[id_1]
+    b.req_id_to_index[req_id_1], b.req_id_to_index[
+        req_id_2] = b.req_id_to_index[req_id_2], b.req_id_to_index[req_id_1]
+
+    ids = [id_1, id_2]
+    rev_ids = [id_2, id_1]
+    b.num_tokens[ids] = b.num_tokens[rev_ids]
+    b.token_ids_cpu[ids] = b.token_ids_cpu[rev_ids]
+    b.num_prompt_tokens[ids] = b.num_prompt_tokens[rev_ids]
+    b.num_computed_tokens_cpu[ids] = b.num_computed_tokens_cpu[rev_ids]
+
+    b.block_table.swap_row(id_1, id_2)
+
+    b.temperature_cpu[ids] = b.temperature_cpu[rev_ids]
+    b.top_p_cpu[ids] = b.top_p_cpu[rev_ids]
+    b.top_k_cpu[ids] = b.top_k_cpu[rev_ids]
+    b.frequency_penalties_cpu[ids] = b.frequency_penalties_cpu[rev_ids]
+    b.presence_penalties_cpu[ids] = b.presence_penalties_cpu[rev_ids]
+    b.repetition_penalties_cpu[ids] = b.repetition_penalties_cpu[rev_ids]
+
+    b.min_tokens[id_1], b.min_tokens[id_2] = b.min_tokens[id_2], b.min_tokens[
+        id_1]
+    b.stop_token_ids[id_1], b.stop_token_ids[id_2] = b.stop_token_ids[
+        id_2], b.stop_token_ids[id_1]
+
+    b.generators[id_1], b.generators[id_2] = b.generators[id_2], b.generators[
+        id_1]
+
+
+def ensure_decodes_first(b: InputBatch):
+    num_reqs = b.num_reqs
+    while True:
+        # Find the first prompt index
+        first_prompt_index = None
+        for i in range(num_reqs):
+            if b.num_computed_tokens_cpu[i] < b.num_prompt_tokens[i]:
+                first_prompt_index = i
+                break
+        if first_prompt_index is None:
+            break
+
+        # Find the last decode index
+        last_decode_index = None
+        for i in reversed(range(num_reqs)):
+            if b.num_computed_tokens_cpu[i] >= b.num_prompt_tokens[i]:
+                last_decode_index = i
+                break
+        if last_decode_index is None:
+            break
+
+        # Sanity
+        assert first_prompt_index != last_decode_index
+
+        # Check if done
+        if first_prompt_index > last_decode_index:
+            break
+
+        # Swap
+        swap_positions(b, first_prompt_index, last_decode_index)
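ensure_decodes_first repeatedly swaps the first prompt it finds with the last decode until every decode request sits at the front of the batch. A small self-contained sketch of the same two-pointer idea on a plain list of flags (illustrative only; the real helper swaps every per-request field through swap_positions rather than returning a permutation):

def reorder_decodes_first(is_decode: list) -> list:
    """Return a permutation with all decode entries before prompt entries,
    using the same first-prompt / last-decode scan as ensure_decodes_first."""
    order = list(range(len(is_decode)))
    while True:
        first_prompt = next(
            (i for i in range(len(order)) if not is_decode[order[i]]), None)
        last_decode = next(
            (i for i in reversed(range(len(order))) if is_decode[order[i]]),
            None)
        if first_prompt is None or last_decode is None:
            break
        if first_prompt > last_decode:
            break
        order[first_prompt], order[last_decode] = (order[last_decode],
                                                   order[first_prompt])
    return order


# Example: requests 0 and 3 are prompts, 1, 2, 4 are decodes.
print(reorder_decodes_first([False, True, True, False, True]))
# -> [4, 1, 2, 3, 0]  (decodes first, prompts pushed to the back)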

View File

@@ -3,6 +3,7 @@ from dataclasses import dataclass
 from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, cast
 from unittest.mock import patch

+import numpy as np
 import torch
 import torch.distributed
 import torch.nn as nn
@@ -20,7 +21,8 @@ from vllm.v1.attention.backends.pallas import (PallasAttentionBackend,
 from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheConfig
 from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.utils import bind_kv_cache
-from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
+from vllm.v1.worker.gpu_input_batch import (CachedRequestState, InputBatch,
+                                            ensure_decodes_first)
 from vllm.v1.worker.model_runner_base import ExecutionMode, ModelRunnerBase

 if TYPE_CHECKING:
@@ -39,22 +41,14 @@ _MAX_NUM_SAMPLES = 128

 @dataclass
-class PromptInputData:
-
-    req_ids: List
-    prompt_lens: List
-    input_tokens: List
-    input_positions: List
-    attn_metadata: List
-
-    def zipped(self):
-        return zip(self.req_ids, self.prompt_lens, self.input_tokens,
-                   self.input_positions, self.attn_metadata)
+class PromptData:
+    input_tokens: torch.Tensor
+    input_positions: torch.Tensor
+    attn_metadata: PallasMetadata


 @dataclass
-class DecodeInputData:
-
-    req_ids: List
-
+class DecodeData:
     input_tokens: Optional[torch.Tensor] = None
     input_positions: Optional[torch.Tensor] = None
     attn_metadata: Optional[PallasMetadata] = None
@@ -85,249 +79,247 @@ class TPUModelRunner(ModelRunnerBase):
         # KV caches for forward pass
         self.kv_caches: List[Tuple[torch.Tensor, torch.Tensor]] = []

-        # Used to initialize positions for the individual prefills
-        self.prefill_input_positions = torch.tensor(range(self.max_model_len),
-                                                    device="cpu",
-                                                    dtype=torch.int32).reshape(
-                                                        1, -1)
-
-    def _prepare_prompt_inputs(
+        # Cache torch/numpy tensors
+        self.input_ids_cpu = torch.zeros(self.max_num_tokens,
+                                         dtype=torch.int32,
+                                         device="cpu",
+                                         pin_memory=self.pin_memory)
+        self.input_ids_np = self.input_ids_cpu.numpy()
+
+        self.input_positions_cpu = torch.empty(self.max_model_len,
+                                               dtype=torch.int64,
+                                               device="cpu",
+                                               pin_memory=self.pin_memory)
+        self.input_positions_np = self.input_positions_cpu.numpy()
+
+        self.slot_mapping_cpu = torch.zeros(self.max_num_tokens,
+                                            dtype=torch.int32,
+                                            device="cpu",
+                                            pin_memory=self.pin_memory)
+        self.slot_mapping_np = self.slot_mapping_cpu.numpy()
+
+        self.prompt_context_lens_cpu = torch.zeros((1),
+                                                   dtype=torch.int32,
+                                                   device="cpu")
+        self.prompt_effective_query_lens = torch.zeros((1),
+                                                       dtype=torch.int32,
+                                                       device="cpu")
+
+        self.decode_context_lens_cpu = torch.zeros(self.max_model_len,
+                                                   dtype=torch.int32,
+                                                   device="cpu")
+        self.decode_context_lens_np = self.decode_context_lens_cpu.numpy()
+
+        self.arange_np = np.arange(self.max_model_len, dtype=np.int32)
+
+        self.req_ids = []
+        self.prompt_token_ids = []
+        self.sampled_token_ids = []
+
+    def _get_prompts_and_decodes(
         self,
         scheduler_output: "SchedulerOutput",
-    ) -> PromptInputData:
+    ):
         total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
         assert total_num_scheduled_tokens > 0

         num_reqs = self.input_batch.num_reqs
         assert num_reqs > 0

-        req_ids = []
-        prompt_lens = []
-        input_tokens_list = []
-        input_positions_list = []
-        attn_metadata_list = []
-
-        for req_id in self.input_batch.req_ids[:num_reqs]:
-            assert req_id is not None
-            req_index = self.input_batch.req_id_to_index[req_id]
-            req_state = self.requests[req_id]
+        # Traverse decodes first
+        decode_req_ids = []
+        for i in range(num_reqs):
+            req_id = self.input_batch.req_ids[i]
+
+            num_computed_tokens = self.input_batch.num_computed_tokens_cpu[i]
+            num_prompt_tokens = self.input_batch.num_prompt_tokens[i]
             num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
                 req_id]
-            num_computed_tokens = req_state.num_computed_tokens
-            num_prompt_tokens = len(req_state.prompt_token_ids)

-            # Detect whether this is a prompt (can be full or chunked)
-            if num_computed_tokens >= num_prompt_tokens:
-                # This is a decode => Skip
-                continue
+            if num_computed_tokens < num_prompt_tokens:
+                # This is prompt
+                break

-            # This is a prompt
-            req_ids.append(req_id)
+            # This is decode
+            assert num_scheduled_tokens == 1
+            decode_req_ids.append(req_id)
+
+        # Traverse prompts
+        prompt_req_ids = []
+        prompt_scheduled_tokens = []
+        for i in range(len(decode_req_ids), num_reqs):
+            req_id = self.input_batch.req_ids[i]
+            num_computed_tokens = self.input_batch.num_computed_tokens_cpu[i]
+            num_prompt_tokens = self.input_batch.num_prompt_tokens[i]
+            num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
+                req_id]
+
+            # Must be prompt
+            assert num_computed_tokens < num_prompt_tokens
+
+            prompt_scheduled_tokens.append(num_scheduled_tokens)
+            prompt_req_ids.append(req_id)
+
+        return prompt_req_ids, decode_req_ids, prompt_scheduled_tokens
+
+    def _prepare_prompt(self, req_index: int,
+                        num_scheduled_tokens: int) -> PromptData:
+        num_computed_tokens = self.input_batch.num_computed_tokens_cpu[
+            req_index]
+        num_prompt_tokens = self.input_batch.num_prompt_tokens[req_index]
+
+        # Must be prompt
+        assert num_computed_tokens < num_prompt_tokens

         # Prompt len
         prompt_len = num_scheduled_tokens
-            prompt_lens.append(prompt_len)
-            padded_prompt_len = _get_padded_prefill_len(prompt_len)
+        padded_prompt_len = _get_padded_prompt_len(prompt_len)
         assert padded_prompt_len <= self.max_model_len

         # Seq len
         seq_len = num_computed_tokens + prompt_len
+        padded_seq_len = num_computed_tokens + padded_prompt_len

         # Input tokens
-            input_tokens = torch.zeros((1, padded_prompt_len),
-                                       dtype=torch.int32,
-                                       device="cpu")
-            input_tokens[:, :prompt_len] = torch.from_numpy(
-                self.input_batch.token_ids_cpu[req_index,
-                                               num_computed_tokens:seq_len])
-            # input_tokens = torch.from_numpy(self.input_batch.token_ids_cpu[
-            #     req_index, num_computed_tokens:padded_seq_len].reshape(1, -1))
-            # input_tokens[:, prompt_len:] = 0
-            input_tokens_list.append(input_tokens.to(self.device))
+        input_tokens_cpu = self.input_batch.token_ids_cpu_tensor[
+            req_index, num_computed_tokens:padded_seq_len]
+        input_tokens_cpu[prompt_len:] = 0

         # Input positions
-            input_positions = torch.zeros((1, padded_prompt_len),
-                                          dtype=torch.int32,
-                                          device="cpu")
-            input_positions[:, :
-                            prompt_len] = self.prefill_input_positions[:,
-                                                                       num_computed_tokens:
-                                                                       seq_len]
-            # input_positions[:, prompt_len:] = 0
-            input_positions_list.append(input_positions.to(self.device))
+        input_positions_np = self.input_positions_np[:padded_prompt_len]
+        np.add(num_computed_tokens,
+               self.arange_np[:padded_prompt_len],
+               out=input_positions_np)
+        input_positions_np[prompt_len:] = 0

         # Slot mapping
-            block_table_cpu_tensor = \
-                self.input_batch.block_table.get_cpu_tensor()
-            block_numbers = block_table_cpu_tensor[req_index,
-                                                   input_positions //
-                                                   self.block_size].reshape(
-                                                       1, -1)
-            block_offsets = input_positions % self.block_size
-            slot_mapping = block_numbers * self.block_size + block_offsets
-            slot_mapping[:, prompt_len:] = _PAD_SLOT_ID
-            slot_mapping = slot_mapping.long()
+        block_table_np = \
+            self.input_batch.block_table.get_numpy_array()
+        block_numbers_np = block_table_np[req_index, input_positions_np //
+                                          self.block_size]
+        block_offsets_np = input_positions_np % self.block_size
+
+        slot_mapping_np = self.slot_mapping_np[:padded_prompt_len]
+        np.add(block_numbers_np * self.block_size,
+               block_offsets_np,
+               out=slot_mapping_np)
+        slot_mapping_np[prompt_len:] = _PAD_SLOT_ID

         # Block table
-            block_table = None
+        block_table_cpu = None
         if num_computed_tokens > 0:
-                block_table = block_table_cpu_tensor[req_index].unsqueeze(0)
-                block_table = block_table.to(self.device)
+            block_table_cpu = self.input_batch.block_table.get_cpu_tensor()
+            block_table_cpu = block_table_cpu[req_index]

         # Context len
-            context_len = 0
+        self.prompt_context_lens_cpu[0] = 0
         if num_computed_tokens > 0:
-                context_len = seq_len
-            context_lens = torch.tensor([context_len],
-                                        dtype=torch.int32,
-                                        device="cpu")
+            self.prompt_context_lens_cpu[0] = seq_len

         # Effective query len
-            effective_query_lens = torch.tensor([prompt_len],
-                                                dtype=torch.int32,
-                                                device="cpu")
+        self.prompt_effective_query_lens[0] = prompt_len
+
+        # Get final tensors
+        input_tokens = input_tokens_cpu.reshape(1, -1).to(self.device)
+        input_positions = self.input_positions_cpu[:padded_prompt_len].reshape(
+            1, -1).to(self.device)
+        slot_mapping = self.slot_mapping_cpu[:padded_prompt_len].reshape(
+            1, -1).to(self.device)
+        block_table = block_table_cpu.reshape(1, -1).to(
+            self.device) if block_table_cpu is not None else None
+        context_lens = self.prompt_context_lens_cpu.reshape(1,
+                                                            -1).to(self.device)
+        effective_query_lens = self.prompt_effective_query_lens.reshape(
+            1, -1).to(self.device)

         # Attn metadata
-            attn_metadata_list.append(
-                PallasMetadata(
-                    num_prefills=1,
-                    num_prefill_tokens=0,  # NOTE: This is not used.
-                    num_decode_tokens=0,
-                    slot_mapping=slot_mapping.to(self.device),
-                    multi_modal_placeholder_index_maps=None,
-                    enable_kv_scales_calculation=True,
-                    block_tables=block_table,
-                    context_lens=context_lens.to(self.device),
-                    effective_query_lens=effective_query_lens.to(self.device),
-                ))
-
-            # TODO: Remove this
-            # if num_computed_tokens > 0:
-            #     print("-------------------")
-            #     print("input_tokens.shape = {}".format(input_tokens.shape))
-            #     print("input_positions.shape = {}".format(
-            #         input_positions.shape))
-            #     print("slot_mapping.shape = {}".format(slot_mapping.shape))
-            #     print("block_table.shape = {}".format(block_table.shape))
-            #     print("context_lens.shape = {} data = {}".format(
-            #         context_lens.shape, context_lens))
-            #     print("effective_query_lens.shape = {} data = {}".format(
-            #         effective_query_lens.shape, effective_query_lens))
-
-        return PromptInputData(
-            req_ids=req_ids,
-            prompt_lens=prompt_lens,
-            input_tokens=input_tokens_list,
-            input_positions=input_positions_list,
-            attn_metadata=attn_metadata_list,
-        )
-
-    def _prepare_decode_inputs(
+        attn_metadata = PallasMetadata(
+            num_prefills=1,
+            num_prefill_tokens=0,  # NOTE: This is not used.
+            num_decode_tokens=0,
+            slot_mapping=slot_mapping,
+            multi_modal_placeholder_index_maps=None,
+            enable_kv_scales_calculation=True,
+            block_tables=block_table,
+            context_lens=context_lens,
+            effective_query_lens=effective_query_lens,
+        )
+
+        return PromptData(input_tokens, input_positions, attn_metadata)
+
+    def _prepare_decode(
         self,
-        scheduler_output: "SchedulerOutput",
-    ) -> DecodeInputData:
-        total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
-        assert total_num_scheduled_tokens > 0
-
-        num_reqs = self.input_batch.num_reqs
-        assert num_reqs > 0
-
-        block_table_cpu_tensor = self.input_batch.block_table.get_cpu_tensor()
-
-        req_ids = []
-        req_indices = []
-        input_tokens = []
-        input_positions = []
-        slot_mapping = []
-        context_lens = []
-
-        for req_id in self.input_batch.req_ids[:num_reqs]:
-            assert req_id is not None
-            req_index = self.input_batch.req_id_to_index[req_id]
-            req_state = self.requests[req_id]
-            num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
-                req_id]
-            num_computed_tokens = req_state.num_computed_tokens
-            num_prompt_tokens = len(req_state.prompt_token_ids)
-
-            # Detect whether this is a decode
-            if num_computed_tokens < num_prompt_tokens:
-                # This is a prompt => Skip
-                continue
-
-            # This is a decode
-            req_ids.append(req_id)
-            req_indices.append(req_index)
-
-            # Seq len
-            seq_len = num_computed_tokens + num_scheduled_tokens
-
-            # Sanity check decode
-            assert num_scheduled_tokens == 1
-            assert seq_len == req_state.num_tokens
-
-            # Input token
-            input_tokens.append([
-                self.input_batch.token_ids_cpu[req_index, num_computed_tokens]
-            ])
-
-            # Position
-            input_positions.append([num_computed_tokens])
+        decode_req_ids: List[str],
+    ) -> DecodeData:
+        # Batch size
+        batch_size = len(decode_req_ids)
+        padded_batch_size = _get_padded_batch_size(batch_size)
+        assert padded_batch_size <= self.max_model_len
+
+        # Input positions
+        input_positions_np = self.input_positions_np[:padded_batch_size]
+        np.add(self.input_batch.num_computed_tokens_cpu[:padded_batch_size],
+               0,
+               out=input_positions_np)
+        input_positions_np[batch_size:] = 0
+        input_positions_cpu = torch.from_numpy(input_positions_np)
+
+        # Input tokens
+        input_tokens_cpu = self.input_ids_cpu[:padded_batch_size]
+        torch.index_select(self.input_batch.token_ids_cpu_tensor,
+                           1,
+                           input_positions_cpu,
+                           out=input_tokens_cpu)
+        input_tokens_cpu[batch_size:] = 0

         # Slot mapping
-            block_number = block_table_cpu_tensor[req_index,
-                                                  num_computed_tokens //
-                                                  self.block_size]
-            block_offset = num_computed_tokens % self.block_size
-            slot_id = block_number * self.block_size + block_offset
-            slot_mapping.append([slot_id])
-
-            # Context len
-            context_lens.append(seq_len)
-
-        # Compute padding
-        batch_size = len(input_tokens)
-        padded_batch_size = _get_padded_batch_size(batch_size)
-        num_padding = padded_batch_size - batch_size
-
-        # Add padding
-        input_tokens.extend([[0]] * num_padding)
-        input_positions.extend([[0]] * num_padding)
-        slot_mapping.extend([[_PAD_SLOT_ID]] * num_padding)
-        context_lens.extend([0] * num_padding)
-        req_indices.extend([0] * num_padding)
-
-        # Create tensors
-        input_tokens_tensor = torch.tensor(input_tokens,
-                                           dtype=torch.int32,
-                                           device="cpu")
-        input_positions_tensor = torch.tensor(input_positions,
-                                              dtype=torch.int32,
-                                              device="cpu")
-        slot_mapping_tensor = torch.tensor(slot_mapping,
-                                           dtype=torch.int64,
-                                           device="cpu")
-        context_lens_tensor = torch.tensor(context_lens,
-                                           dtype=torch.int32,
-                                           device="cpu")
-        block_tables_tensor = block_table_cpu_tensor[req_indices]
+        block_table_cpu = self.input_batch.block_table.get_cpu_tensor()
+        block_numbers_cpu = torch.index_select(
+            block_table_cpu, 1, input_positions_cpu // self.block_size)
+        block_numbers_np = block_numbers_cpu.numpy()
+
+        block_offsets_np = input_positions_np % self.block_size
+
+        slot_mapping_np = self.slot_mapping_np[:padded_batch_size]
+        np.add(block_numbers_np * self.block_size,
+               block_offsets_np,
+               out=slot_mapping_np)
+        slot_mapping_np[batch_size:] = _PAD_SLOT_ID
+
+        block_table_cpu = block_table_cpu[:len(decode_req_ids)]
+
+        # Context lens
+        context_lens_np = self.decode_context_lens_np[:padded_batch_size]
+        np.add(self.input_batch.num_computed_tokens_cpu[:padded_batch_size],
+               1,
+               out=context_lens_np)
+        context_lens_np[batch_size:] = 0
+
+        # Get final tensors
+        input_tokens = input_tokens_cpu.to(self.device)
+        input_positions = input_positions_cpu.to(self.device)
+        slot_mapping = self.slot_mapping_cpu[:padded_batch_size].to(
+            self.device)
+        block_table = block_table_cpu.to(self.device)
+        context_lens = self.decode_context_lens_cpu[:padded_batch_size].to(
+            self.device)

         # Attn metadata
         attn_metadata = PallasMetadata(
             num_prefills=0,
             num_prefill_tokens=0,
             num_decode_tokens=padded_batch_size,
-            slot_mapping=slot_mapping_tensor.to(self.device),
+            slot_mapping=slot_mapping,
             multi_modal_placeholder_index_maps=None,
             enable_kv_scales_calculation=True,
-            block_tables=block_tables_tensor.to(self.device),
-            context_lens=context_lens_tensor.to(self.device),
+            block_tables=block_table,
+            context_lens=context_lens,
             effective_query_lens=None,
         )

-        return DecodeInputData(
-            req_ids=req_ids,
-            input_tokens=input_tokens_tensor.to(self.device),
-            input_positions=input_positions_tensor.to(self.device),
-            attn_metadata=attn_metadata)
+        return DecodeData(input_tokens=input_tokens,
+                          input_positions=input_positions,
+                          attn_metadata=attn_metadata)

     @torch.no_grad()
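The rewritten prepare path writes into CPU buffers that are allocated once and reused, using np.add(..., out=...) so no per-step arrays are created, and only the padded slice is then copied to the device. A minimal sketch of that buffering pattern with made-up sizes (it mirrors the position and slot-mapping arithmetic above but is not the runner's actual code):

import numpy as np

BLOCK_SIZE = 16
MAX_TOKENS = 64
PAD_SLOT_ID = -1

# Preallocated once, reused on every step.
positions = np.zeros(MAX_TOKENS, dtype=np.int32)
slot_mapping = np.zeros(MAX_TOKENS, dtype=np.int32)
arange = np.arange(MAX_TOKENS, dtype=np.int32)

def fill_prompt_buffers(num_computed, prompt_len, padded_len, block_row):
    # positions[i] = num_computed + i, written in place (no new array).
    np.add(num_computed, arange[:padded_len], out=positions[:padded_len])
    positions[prompt_len:padded_len] = 0

    # slot = block_number * BLOCK_SIZE + offset, also written in place.
    block_numbers = block_row[positions[:padded_len] // BLOCK_SIZE]
    np.add(block_numbers * BLOCK_SIZE,
           positions[:padded_len] % BLOCK_SIZE,
           out=slot_mapping[:padded_len])
    slot_mapping[prompt_len:padded_len] = PAD_SLOT_ID

block_row = np.array([7, 3, 9, 4], dtype=np.int32)  # blocks of one request
fill_prompt_buffers(num_computed=5, prompt_len=10, padded_len=16,
                    block_row=block_row)
print(positions[:16])
print(slot_mapping[:16])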
@@ -338,17 +330,67 @@ class TPUModelRunner(ModelRunnerBase):
         # Update cached state
         self.update_states(scheduler_output)

-        # Prepare inputs
-        prompt_data = self._prepare_prompt_inputs(scheduler_output)
-        decode_data = self._prepare_decode_inputs(scheduler_output)
+        # If necessary, swap decodes/prompts to have all decodes on the start
+        ensure_decodes_first(self.input_batch)
+
+        # Prepare prompts/decodes info
+        prompt_req_ids, decode_req_ids, prompt_scheduled_tokens = self._get_prompts_and_decodes(
+            scheduler_output)

         # Init
-        num_reqs = self.input_batch.num_reqs
-        assert num_reqs > 0
-        sampled_token_ids_list = [0] * num_reqs
+        decode_token_ids = None
+        decode_data = None
+        self.req_ids.clear()
+        self.prompt_token_ids.clear()
+        self.sampled_token_ids.clear()
+
+        # Run each prompt
+        is_first = True
+        for i, req_id in enumerate(prompt_req_ids):
+            req_index = len(decode_req_ids) + i
+            req_state = self.requests[req_id]
+            num_scheduled_tokens = prompt_scheduled_tokens[i]
+            seq_len = req_state.num_computed_tokens + num_scheduled_tokens
+            prompt_len = num_scheduled_tokens
+
+            # Prepare first prompt
+            if is_first:
+                prompt_data = self._prepare_prompt(req_index,
+                                                   prompt_scheduled_tokens[i])
+                is_first = False
+
+            # Run forward pass
+            with set_forward_context(prompt_data.attn_metadata,
+                                     self.vllm_config):
+                assert self.model is not None
+                selected_token_ids = self.model(prompt_data.input_tokens,
+                                                prompt_data.input_positions,
+                                                prompt_data.attn_metadata,
+                                                self.kv_caches)
+
+            # In parallel to TPU execution, prepare the next iteration
+            if i < len(prompt_req_ids) - 1:
+                prompt_data = self._prepare_prompt(
+                    req_index + 1, prompt_scheduled_tokens[i + 1])
+            elif i == len(prompt_req_ids) - 1 and len(decode_req_ids) > 0:
+                decode_data = self._prepare_decode(decode_req_ids)
+
+            # Update cached state
+            if seq_len >= len(req_state.prompt_token_ids):
+                # Transfer sampled tokens from TPU to CPU
+                token_id = selected_token_ids.cpu()[prompt_len - 1].item()
+                self.prompt_token_ids.append(token_id)
+
+                # Update cached state
+                self.input_batch.token_ids_cpu[req_index, seq_len] = token_id
+                self.input_batch.num_tokens[req_index] += 1
+                req_state.output_token_ids.append(token_id)

         # Run decodes (a single batch)
-        if len(decode_data.req_ids) > 0:
+        if len(decode_req_ids) > 0:
+            if decode_data is None:
+                decode_data = self._prepare_decode(decode_req_ids)

             # Forward
             with set_forward_context(decode_data.attn_metadata,
                                      self.vllm_config):
@@ -359,59 +401,33 @@
                                                 self.kv_caches)

             # Transfer sampled tokens from TPU to CPU
-            selected_token_ids_list = selected_token_ids.cpu().tolist()
+            decode_token_ids = selected_token_ids.cpu().tolist()

             # Update cached state
-            for i, req_id in enumerate(decode_data.req_ids):
-                req_index = self.input_batch.req_id_to_index[req_id]
+            for i, req_id in enumerate(decode_req_ids):
+                req_index = i
                 req_state = self.requests[req_id]
-
-                seq_len = (req_state.num_computed_tokens +
-                           scheduler_output.num_scheduled_tokens[req_id])
-
-                token_id = selected_token_ids_list[i]
+                seq_len = req_state.num_computed_tokens + 1
+                token_id = decode_token_ids[i]

                 self.input_batch.token_ids_cpu[req_index, seq_len] = token_id
                 self.input_batch.num_tokens[req_index] += 1
                 req_state.output_token_ids.append(token_id)

-                sampled_token_ids_list[req_index] = token_id
-
-        # Run each prompt
-        for (req_id, prompt_len, input_tokens, input_positions,
-             attn_metadata) in prompt_data.zipped():
-            assert req_id is not None
-            req_state = self.requests[req_id]
-            req_index = self.input_batch.req_id_to_index[req_id]
-
-            # Forward
-            with set_forward_context(attn_metadata, self.vllm_config):
-                assert self.model is not None
-                selected_token_ids = self.model(input_tokens, input_positions,
-                                                attn_metadata, self.kv_caches)
-
-            seq_len = (req_state.num_computed_tokens +
-                       scheduler_output.num_scheduled_tokens[req_id])
-            if seq_len >= len(req_state.prompt_token_ids):
-                # Transfer sampled tokens from TPU to CPU
-                token_id = selected_token_ids.cpu()[prompt_len - 1].item()
-                sampled_token_ids_list[req_index] = token_id
-
-                # Update cached state
-                self.input_batch.token_ids_cpu[req_index, seq_len] = token_id
-                self.input_batch.num_tokens[req_index] += 1
-                req_state.output_token_ids.append(token_id)
-
-        # Get req_ids
-        assert all(
-            req_id is not None for req_id in
-            self.input_batch.req_ids[:num_reqs]), "req_ids contains None"
-        req_ids = cast(List[str], self.input_batch.req_ids[:num_reqs])
-
-        # Create output
+        # Create final req_id => token lists.
+        # This must match the actual batch index positions
+        self.req_ids.extend(decode_req_ids)
+        self.req_ids.extend(prompt_req_ids)
+        if decode_token_ids is not None:
+            self.sampled_token_ids.extend(decode_token_ids)
+        self.sampled_token_ids.extend(self.prompt_token_ids)
+
         model_runner_output = ModelRunnerOutput(
-            req_ids=req_ids,
+            req_ids=self.req_ids,
             req_id_to_index=self.input_batch.req_id_to_index,
-            sampled_token_ids=sampled_token_ids_list,
+            sampled_token_ids=self.sampled_token_ids,
             logprob_token_ids_cpu=None,
             logprobs_cpu=None,
         )
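The prompt loop above overlaps host-side preparation with device execution: a forward pass is dispatched, the next request's inputs (or the decode batch) are prepared on the CPU while the device is busy, and the sampled tokens are only fetched afterwards. A rough sketch of that ordering, using a single-worker thread pool to stand in for the asynchronous device (names and sleep times are invented for illustration):

import time
from concurrent.futures import ThreadPoolExecutor

def fake_device_forward(inputs):
    time.sleep(0.05)          # stands in for the accelerator step
    return f"tokens({inputs})"

def prepare_inputs(i):
    time.sleep(0.02)          # stands in for CPU-side prepare work
    return f"batch{i}"

def run_pipelined(num_steps):
    results = []
    with ThreadPoolExecutor(max_workers=1) as device:
        inputs = prepare_inputs(0)                           # first batch
        for i in range(num_steps):
            future = device.submit(fake_device_forward, inputs)  # dispatch
            if i + 1 < num_steps:
                inputs = prepare_inputs(i + 1)               # overlap prep
            results.append(future.result())                  # sync, like .cpu()
    return results

print(run_pipelined(3))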
@@ -710,7 +726,7 @@ class ModelWrapperV1(nn.Module):
         return argmax_token_ids


-def _get_padded_prefill_len(x: int) -> int:
+def _get_padded_prompt_len(x: int) -> int:
     # NOTE(woosuk): The pallas FlashAttention kernel requires the sequence
     # length to be a multiple of 16. We pad the prompt length to the nearest
     # multiple of 16. This is also good for performance.