[V1] Input Batch Relocation (#10962)

Signed-off-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
This commit is contained in:
Varun Sundar Rabindranath
2024-12-09 12:33:41 -05:00
committed by GitHub
parent aea2fc38c3
commit 25b79d9fd3
2 changed files with 283 additions and 270 deletions

View File

@ -0,0 +1,280 @@
# Datastructures defining an input batch
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Set
import numpy as np
import torch
from vllm.multimodal import MultiModalKwargs
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.v1.sample.metadata import SamplingMetadata
if TYPE_CHECKING:
from vllm.multimodal.inputs import PlaceholderRange
@dataclass
class CachedRequestState:
req_id: str
prompt_token_ids: List[int]
prompt: Optional[str]
mm_inputs: List[MultiModalKwargs]
mm_positions: List["PlaceholderRange"]
sampling_params: SamplingParams
generator: Optional[torch.Generator]
block_ids: List[int]
num_computed_tokens: int
output_token_ids: List[int]
@property
def num_tokens(self) -> int:
return len(self.prompt_token_ids) + len(self.output_token_ids)
class InputBatch:
def __init__(
self,
max_num_reqs: int,
max_model_len: int,
max_num_blocks_per_req: int,
device: torch.device,
pin_memory: bool,
):
self.max_num_reqs = max_num_reqs
self.max_model_len = max_model_len
self.max_num_blocks_per_req = max_num_blocks_per_req
self.device = device
self.pin_memory = pin_memory
self.req_ids: List[Optional[str]] = [None] * max_num_reqs
self.req_id_to_index: Dict[str, int] = {}
self.token_ids_cpu = np.empty((max_num_reqs, max_model_len),
dtype=np.int32)
self.num_computed_tokens_cpu = np.empty(max_num_reqs, dtype=np.int32)
# Attention-related.
self.block_table = torch.zeros((max_num_reqs, max_num_blocks_per_req),
device=self.device,
dtype=torch.int32)
self.block_table_cpu_tensor = torch.zeros(
(max_num_reqs, max_num_blocks_per_req),
device="cpu",
dtype=torch.int32,
pin_memory=pin_memory,
)
self.block_table_cpu = self.block_table_cpu_tensor.numpy()
# Sampling-related.
self.temperature = torch.empty((max_num_reqs, ),
dtype=torch.float32,
device=device)
self.temperature_cpu_tensor = torch.empty((max_num_reqs, ),
dtype=torch.float32,
device="cpu",
pin_memory=pin_memory)
self.temperature_cpu = self.temperature_cpu_tensor.numpy()
self.greedy_reqs: Set[str] = set()
self.random_reqs: Set[str] = set()
self.top_p = torch.empty((max_num_reqs, ),
dtype=torch.float32,
device=device)
self.top_p_cpu_tensor = torch.empty((max_num_reqs, ),
dtype=torch.float32,
device="cpu",
pin_memory=pin_memory)
self.top_p_cpu = self.top_p_cpu_tensor.numpy()
self.top_p_reqs: Set[str] = set()
self.top_k = torch.empty((max_num_reqs, ),
dtype=torch.int32,
device=device)
self.top_k_cpu_tensor = torch.empty((max_num_reqs, ),
dtype=torch.int32,
device="cpu",
pin_memory=pin_memory)
self.top_k_cpu = self.top_k_cpu_tensor.numpy()
self.top_k_reqs: Set[str] = set()
# req_index -> generator
self.generators: Dict[int, torch.Generator] = {}
self.num_logprobs: Dict[str, int] = {}
self.prompt_logprob_reqs: Set[str] = set()
def add_request(
self,
request: "CachedRequestState",
req_index: Optional[int] = None,
) -> None:
if req_index is None:
req_index = self.num_reqs
assert req_index < self.max_num_reqs
req_id = request.req_id
self.req_ids[req_index] = req_id
self.req_id_to_index[req_id] = req_index
# Copy the prompt token ids and output token ids.
num_prompt_tokens = len(request.prompt_token_ids)
self.token_ids_cpu[
req_index, :num_prompt_tokens] = request.prompt_token_ids
start_idx = num_prompt_tokens
end_idx = start_idx + len(request.output_token_ids)
self.token_ids_cpu[req_index,
start_idx:end_idx] = request.output_token_ids
self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens
num_blocks = len(request.block_ids)
self.block_table_cpu[req_index, :num_blocks] = request.block_ids
sampling_params = request.sampling_params
self.temperature_cpu[req_index] = sampling_params.temperature
if sampling_params.sampling_type == SamplingType.GREEDY:
self.greedy_reqs.add(req_id)
else:
self.random_reqs.add(req_id)
self.top_p_cpu[req_index] = sampling_params.top_p
if sampling_params.top_p < 1:
self.top_p_reqs.add(req_id)
self.top_k_cpu[req_index] = sampling_params.top_k
if sampling_params.top_k > 0:
self.top_k_reqs.add(req_id)
self.generators[req_index] = request.generator
num_logprobs = sampling_params.logprobs
if num_logprobs is not None and num_logprobs > 0:
self.num_logprobs[req_id] = num_logprobs
if sampling_params.prompt_logprobs:
self.prompt_logprob_reqs.add(req_id)
def remove_request(self, req_id: str) -> Optional[int]:
req_index = self.req_id_to_index.pop(req_id, None)
if req_index is None:
return None
self.req_ids[req_index] = None
self.greedy_reqs.discard(req_id)
self.random_reqs.discard(req_id)
self.top_p_reqs.discard(req_id)
self.top_k_reqs.discard(req_id)
self.generators.pop(req_index, None)
self.num_logprobs.pop(req_id, None)
self.prompt_logprob_reqs.discard(req_id)
return req_index
def clear(self) -> None:
self.req_ids = [None] * self.max_num_reqs
self.req_id_to_index.clear()
self.greedy_reqs.clear()
self.random_reqs.clear()
self.top_p_reqs.clear()
self.top_k_reqs.clear()
self.generators.clear()
self.num_logprobs.clear()
self.prompt_logprob_reqs.clear()
def condense(self, empty_req_indices: List[int]) -> None:
if self.num_reqs == 0:
# The batched states are empty.
return
# NOTE(woosuk): This function assumes that the empty_req_indices
# is sorted in descending order.
last_req_index = self.num_reqs + len(empty_req_indices) - 1
while empty_req_indices:
# Find the largest non-empty index.
while last_req_index in empty_req_indices:
last_req_index -= 1
# Find the smallest empty index.
empty_index = empty_req_indices.pop()
if empty_index >= last_req_index:
break
# Swap the states.
req_id = self.req_ids[last_req_index]
self.req_ids[empty_index] = req_id
self.req_ids[last_req_index] = None
self.req_id_to_index[req_id] = empty_index
# TODO(woosuk): Optimize the copy of token_ids_cpu and
# block_table_cpu.
self.token_ids_cpu[empty_index] = self.token_ids_cpu[
last_req_index]
self.num_computed_tokens_cpu[
empty_index] = self.num_computed_tokens_cpu[last_req_index]
self.block_table_cpu[empty_index] = self.block_table_cpu[
last_req_index]
self.temperature_cpu[empty_index] = self.temperature_cpu[
last_req_index]
self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index]
self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index]
generator = self.generators.pop(last_req_index, None)
if generator is not None:
self.generators[empty_index] = generator
# Decrement last_req_index since it is now empty.
last_req_index -= 1
def make_sampling_metadata(
self,
skip_copy: bool = False,
) -> SamplingMetadata:
if not skip_copy:
self.temperature[:self.num_reqs].copy_(
self.temperature_cpu_tensor[:self.num_reqs], non_blocking=True)
self.top_p[:self.num_reqs].copy_(
self.top_p_cpu_tensor[:self.num_reqs], non_blocking=True)
self.top_k[:self.num_reqs].copy_(
self.top_k_cpu_tensor[:self.num_reqs], non_blocking=True)
return SamplingMetadata(
temperature=self.temperature[:self.num_reqs],
all_greedy=self.all_greedy,
all_random=self.all_random,
top_p=self.top_p[:self.num_reqs],
top_k=self.top_k[:self.num_reqs],
no_top_p=self.no_top_p,
no_top_k=self.no_top_k,
generators=self.generators,
max_num_logprobs=self.max_num_logprobs,
)
@property
def num_reqs(self) -> int:
return len(self.req_id_to_index)
@property
def all_greedy(self) -> bool:
return len(self.random_reqs) == 0
@property
def all_random(self) -> bool:
return len(self.greedy_reqs) == 0
@property
def no_top_p(self) -> bool:
return len(self.top_p_reqs) == 0
@property
def no_top_k(self) -> bool:
return len(self.top_k_reqs) == 0
@property
def max_num_logprobs(self) -> int:
return max(self.num_logprobs.values()) if self.num_logprobs else 0
@property
def no_logprob(self) -> bool:
return len(self.num_logprobs) == 0
@property
def no_prompt_logprob(self) -> bool:
return len(self.prompt_logprob_reqs) == 0

View File

@ -1,7 +1,6 @@
import gc
import time
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
import numpy as np
import torch
@ -15,16 +14,16 @@ from vllm.inputs import INPUT_REGISTRY, InputRegistry
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model
from vllm.multimodal import MultiModalKwargs
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sampling_params import SamplingType
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, cdiv,
is_pin_memory_available)
from vllm.v1.attention.backends.flash_attn import (FlashAttentionBackend,
FlashAttentionMetadata)
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
if TYPE_CHECKING:
from vllm.multimodal.inputs import PlaceholderRange
from vllm.v1.core.scheduler import SchedulerOutput
logger = init_logger(__name__)
@ -609,269 +608,3 @@ class GPUModelRunner:
if batch_size <= size:
return size
return None
@dataclass
class CachedRequestState:
req_id: str
prompt_token_ids: List[int]
prompt: Optional[str]
mm_inputs: List[MultiModalKwargs]
mm_positions: List["PlaceholderRange"]
sampling_params: SamplingParams
generator: Optional[torch.Generator]
block_ids: List[int]
num_computed_tokens: int
output_token_ids: List[int]
@property
def num_tokens(self) -> int:
return len(self.prompt_token_ids) + len(self.output_token_ids)
class InputBatch:
def __init__(
self,
max_num_reqs: int,
max_model_len: int,
max_num_blocks_per_req: int,
device: torch.device,
pin_memory: bool,
):
self.max_num_reqs = max_num_reqs
self.max_model_len = max_model_len
self.max_num_blocks_per_req = max_num_blocks_per_req
self.device = device
self.pin_memory = pin_memory
self.req_ids: List[Optional[str]] = [None] * max_num_reqs
self.req_id_to_index: Dict[str, int] = {}
self.token_ids_cpu = np.empty((max_num_reqs, max_model_len),
dtype=np.int32)
self.num_computed_tokens_cpu = np.empty(max_num_reqs, dtype=np.int32)
# Attention-related.
self.block_table = torch.zeros((max_num_reqs, max_num_blocks_per_req),
device=self.device,
dtype=torch.int32)
self.block_table_cpu_tensor = torch.zeros(
(max_num_reqs, max_num_blocks_per_req),
device="cpu",
dtype=torch.int32,
pin_memory=pin_memory,
)
self.block_table_cpu = self.block_table_cpu_tensor.numpy()
# Sampling-related.
self.temperature = torch.empty((max_num_reqs, ),
dtype=torch.float32,
device=device)
self.temperature_cpu_tensor = torch.empty((max_num_reqs, ),
dtype=torch.float32,
device="cpu",
pin_memory=pin_memory)
self.temperature_cpu = self.temperature_cpu_tensor.numpy()
self.greedy_reqs: Set[str] = set()
self.random_reqs: Set[str] = set()
self.top_p = torch.empty((max_num_reqs, ),
dtype=torch.float32,
device=device)
self.top_p_cpu_tensor = torch.empty((max_num_reqs, ),
dtype=torch.float32,
device="cpu",
pin_memory=pin_memory)
self.top_p_cpu = self.top_p_cpu_tensor.numpy()
self.top_p_reqs: Set[str] = set()
self.top_k = torch.empty((max_num_reqs, ),
dtype=torch.int32,
device=device)
self.top_k_cpu_tensor = torch.empty((max_num_reqs, ),
dtype=torch.int32,
device="cpu",
pin_memory=pin_memory)
self.top_k_cpu = self.top_k_cpu_tensor.numpy()
self.top_k_reqs: Set[str] = set()
# req_index -> generator
self.generators: Dict[int, torch.Generator] = {}
self.num_logprobs: Dict[str, int] = {}
self.prompt_logprob_reqs: Set[str] = set()
def add_request(
self,
request: "CachedRequestState",
req_index: Optional[int] = None,
) -> None:
if req_index is None:
req_index = self.num_reqs
assert req_index < self.max_num_reqs
req_id = request.req_id
self.req_ids[req_index] = req_id
self.req_id_to_index[req_id] = req_index
# Copy the prompt token ids and output token ids.
num_prompt_tokens = len(request.prompt_token_ids)
self.token_ids_cpu[
req_index, :num_prompt_tokens] = request.prompt_token_ids
start_idx = num_prompt_tokens
end_idx = start_idx + len(request.output_token_ids)
self.token_ids_cpu[req_index,
start_idx:end_idx] = request.output_token_ids
self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens
num_blocks = len(request.block_ids)
self.block_table_cpu[req_index, :num_blocks] = request.block_ids
sampling_params = request.sampling_params
self.temperature_cpu[req_index] = sampling_params.temperature
if sampling_params.sampling_type == SamplingType.GREEDY:
self.greedy_reqs.add(req_id)
else:
self.random_reqs.add(req_id)
self.top_p_cpu[req_index] = sampling_params.top_p
if sampling_params.top_p < 1:
self.top_p_reqs.add(req_id)
self.top_k_cpu[req_index] = sampling_params.top_k
if sampling_params.top_k > 0:
self.top_k_reqs.add(req_id)
self.generators[req_index] = request.generator
num_logprobs = sampling_params.logprobs
if num_logprobs is not None and num_logprobs > 0:
self.num_logprobs[req_id] = num_logprobs
if sampling_params.prompt_logprobs:
self.prompt_logprob_reqs.add(req_id)
def remove_request(self, req_id: str) -> Optional[int]:
req_index = self.req_id_to_index.pop(req_id, None)
if req_index is None:
return None
self.req_ids[req_index] = None
self.greedy_reqs.discard(req_id)
self.random_reqs.discard(req_id)
self.top_p_reqs.discard(req_id)
self.top_k_reqs.discard(req_id)
self.generators.pop(req_index, None)
self.num_logprobs.pop(req_id, None)
self.prompt_logprob_reqs.discard(req_id)
return req_index
def clear(self) -> None:
self.req_ids = [None] * self.max_num_reqs
self.req_id_to_index.clear()
self.greedy_reqs.clear()
self.random_reqs.clear()
self.top_p_reqs.clear()
self.top_k_reqs.clear()
self.generators.clear()
self.num_logprobs.clear()
self.prompt_logprob_reqs.clear()
def condense(self, empty_req_indices: List[int]) -> None:
if self.num_reqs == 0:
# The batched states are empty.
return
# NOTE(woosuk): This function assumes that the empty_req_indices
# is sorted in descending order.
last_req_index = self.num_reqs + len(empty_req_indices) - 1
while empty_req_indices:
# Find the largest non-empty index.
while last_req_index in empty_req_indices:
last_req_index -= 1
# Find the smallest empty index.
empty_index = empty_req_indices.pop()
if empty_index >= last_req_index:
break
# Swap the states.
req_id = self.req_ids[last_req_index]
self.req_ids[empty_index] = req_id
self.req_ids[last_req_index] = None
self.req_id_to_index[req_id] = empty_index
# TODO(woosuk): Optimize the copy of token_ids_cpu and
# block_table_cpu.
self.token_ids_cpu[empty_index] = self.token_ids_cpu[
last_req_index]
self.num_computed_tokens_cpu[
empty_index] = self.num_computed_tokens_cpu[last_req_index]
self.block_table_cpu[empty_index] = self.block_table_cpu[
last_req_index]
self.temperature_cpu[empty_index] = self.temperature_cpu[
last_req_index]
self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index]
self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index]
generator = self.generators.pop(last_req_index, None)
if generator is not None:
self.generators[empty_index] = generator
# Decrement last_req_index since it is now empty.
last_req_index -= 1
def make_sampling_metadata(
self,
skip_copy: bool = False,
) -> SamplingMetadata:
if not skip_copy:
self.temperature[:self.num_reqs].copy_(
self.temperature_cpu_tensor[:self.num_reqs], non_blocking=True)
self.top_p[:self.num_reqs].copy_(
self.top_p_cpu_tensor[:self.num_reqs], non_blocking=True)
self.top_k[:self.num_reqs].copy_(
self.top_k_cpu_tensor[:self.num_reqs], non_blocking=True)
return SamplingMetadata(
temperature=self.temperature[:self.num_reqs],
all_greedy=self.all_greedy,
all_random=self.all_random,
top_p=self.top_p[:self.num_reqs],
top_k=self.top_k[:self.num_reqs],
no_top_p=self.no_top_p,
no_top_k=self.no_top_k,
generators=self.generators,
max_num_logprobs=self.max_num_logprobs,
)
@property
def num_reqs(self) -> int:
return len(self.req_id_to_index)
@property
def all_greedy(self) -> bool:
return len(self.random_reqs) == 0
@property
def all_random(self) -> bool:
return len(self.greedy_reqs) == 0
@property
def no_top_p(self) -> bool:
return len(self.top_p_reqs) == 0
@property
def no_top_k(self) -> bool:
return len(self.top_k_reqs) == 0
@property
def max_num_logprobs(self) -> int:
return max(self.num_logprobs.values()) if self.num_logprobs else 0
@property
def no_logprob(self) -> bool:
return len(self.num_logprobs) == 0
@property
def no_prompt_logprob(self) -> bool:
return len(self.prompt_logprob_reqs) == 0