[V1] Implement vLLM V1 [1/N] (#9289)
@@ -17,6 +17,7 @@ logger = init_logger(__name__)

class _Backend(enum.Enum):
    FLASH_ATTN = enum.auto()
    FLASH_ATTN_VLLM_V1 = enum.auto()
    XFORMERS = enum.auto()
    ROCM_FLASH = enum.auto()
    TORCH_SDPA = enum.auto()
@@ -110,6 +111,10 @@ def get_attn_backend(
        from vllm.attention.backends.flash_attn import (  # noqa: F401
            FlashAttentionBackend)
        return FlashAttentionBackend
    if backend == _Backend.FLASH_ATTN_VLLM_V1:
        from vllm.v1.attention.backends.flash_attn import (  # noqa: F401
            FlashAttentionBackend as FlashAttentionBackendV1)
        return FlashAttentionBackendV1
    if backend == _Backend.XFORMERS:
        logger.info("Using XFormers backend.")
        from vllm.attention.backends.xformers import (  # noqa: F401
@@ -215,6 +220,9 @@ def which_attn_to_use(
            logger.info("%s is not supported in AMD GPUs.", selected_backend)
        return _Backend.ROCM_FLASH

    if envs.VLLM_USE_V1:
        return _Backend.FLASH_ATTN_VLLM_V1

    # FlashAttn in NVIDIA GPUs.
    if selected_backend == _Backend.FLASH_ATTN:
        if not current_platform.has_device_capability(80):
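Taken together, these hunks route attention-backend selection through the new enum value: when the VLLM_USE_V1 environment variable is set, which_attn_to_use short-circuits to _Backend.FLASH_ATTN_VLLM_V1, and get_attn_backend resolves that value to the V1 FlashAttention backend. Below is a minimal, self-contained sketch of the same dispatch pattern; the resolve_backend helper is purely illustrative and not part of the diff.

import enum
import os


class _Backend(enum.Enum):
    FLASH_ATTN = enum.auto()
    FLASH_ATTN_VLLM_V1 = enum.auto()


def resolve_backend() -> _Backend:
    # Mirrors the early return added to which_attn_to_use: the V1 code path
    # always selects the V1 FlashAttention backend.
    if bool(int(os.getenv("VLLM_USE_V1", "0"))):
        return _Backend.FLASH_ATTN_VLLM_V1
    return _Backend.FLASH_ATTN


print(resolve_backend())  # FLASH_ATTN unless VLLM_USE_V1=1 is exported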
@@ -8,7 +8,7 @@ from typing import Iterator, List, Optional, Union

import cloudpickle
import zmq

from vllm import AsyncEngineArgs, LLMEngine, SamplingParams
from vllm import AsyncEngineArgs, SamplingParams
from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
                         ParallelConfig, SchedulerConfig)
# yapf conflicts with isort for this block
@@ -21,12 +21,17 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
                                         RPCStartupRequest, RPCStartupResponse,
                                         RPCUProfileRequest)
# yapf: enable
from vllm.envs import VLLM_RPC_TIMEOUT
from vllm.envs import VLLM_RPC_TIMEOUT, VLLM_USE_V1
from vllm.executor.gpu_executor import GPUExecutor
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
from vllm.usage.usage_lib import UsageContext

if VLLM_USE_V1:
    from vllm.v1.engine.llm_engine import LLMEngine
else:
    from vllm.engine.llm_engine import LLMEngine

CONFIG_TYPE = Union[ModelConfig, DecodingConfig, ParallelConfig,
                    SchedulerConfig, LoRAConfig]

@@ -136,14 +141,16 @@ class MQLLMEngine:

        executor_class = LLMEngine._get_executor_cls(engine_config)

        return cls(
            ipc_path=ipc_path,
            use_async_sockets=engine_config.model_config.use_async_output_proc,
            **engine_config.to_dict(),
            executor_class=executor_class,
            log_requests=not engine_args.disable_log_requests,
            log_stats=not engine_args.disable_log_stats,
            usage_context=usage_context)
        use_async_sockets = (engine_config.model_config.use_async_output_proc
                             and not VLLM_USE_V1)

        return cls(ipc_path=ipc_path,
                   use_async_sockets=use_async_sockets,
                   **engine_config.to_dict(),
                   executor_class=executor_class,
                   log_requests=not engine_args.disable_log_requests,
                   log_stats=not engine_args.disable_log_stats,
                   usage_context=usage_context)

    def start(self):
        try:
@@ -6,10 +6,10 @@ from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Tuple,

from tqdm import tqdm

from vllm import envs
from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
                              BeamSearchSequence, get_beam_search_score)
from vllm.engine.arg_utils import EngineArgs, TaskOption
from vllm.engine.llm_engine import LLMEngine
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
                                         apply_hf_chat_template,
                                         apply_mistral_chat_template,
@@ -31,6 +31,11 @@ from vllm.transformers_utils.tokenizer_group import TokenizerGroup
from vllm.usage.usage_lib import UsageContext
from vllm.utils import Counter, deprecate_args, deprecate_kwargs, is_list_of

if envs.VLLM_USE_V1:
    from vllm.v1.engine.llm_engine import LLMEngine  # type: ignore
else:
    from vllm.engine.llm_engine import LLMEngine  # type: ignore

logger = init_logger(__name__)

@@ -68,6 +68,7 @@ if TYPE_CHECKING:
    VLLM_TORCH_COMPILE_LEVEL: int = 0
    VLLM_CUSTOM_OPS: List[str] = []
    VLLM_DISABLED_KERNELS: List[str] = []
    VLLM_USE_V1: bool = False


def get_default_cache_root():
@@ -450,6 +451,10 @@ environment_variables: Dict[str, Callable[[], Any]] = {
    "VLLM_DISABLED_KERNELS":
    lambda: [] if "VLLM_DISABLED_KERNELS" not in os.environ else os.environ[
        "VLLM_DISABLED_KERNELS"].split(","),

    # If set, use the V1 code path.
    "VLLM_USE_V1":
    lambda: bool(int(os.getenv("VLLM_USE_V1", "0"))),
}

# end-env-vars-definition
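The new variable is parsed with bool(int(...)), so "0" disables and "1" enables the V1 code path, and it has to be set before vllm is imported because the conditional imports in the hunks above run at module load time. A hedged usage sketch follows; the model name and sampling settings are placeholders, not part of the diff.

import os

os.environ["VLLM_USE_V1"] = "1"  # opt into the V1 engine before importing vllm

from vllm import LLM, SamplingParams  # LLM now binds the V1 LLMEngine internally

llm = LLM(model="facebook/opt-125m")  # placeholder model
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.8, max_tokens=16))
print(outputs[0].outputs[0].text)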
@@ -48,14 +48,15 @@ class LogitsProcessor(nn.Module):
        self,
        lm_head: VocabParallelEmbedding,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
        sampling_metadata: Optional[SamplingMetadata] = None,
        embedding_bias: Optional[torch.Tensor] = None,
    ) -> Optional[torch.Tensor]:
        if self.logits_as_input:
            logits = hidden_states
        else:
            hidden_states = _prune_hidden_states(hidden_states,
                                                 sampling_metadata)
            if sampling_metadata is not None:
                hidden_states = _prune_hidden_states(hidden_states,
                                                     sampling_metadata)

            # Get the logits for the next tokens.
            logits = self._get_logits(hidden_states, lm_head, embedding_bias)
@@ -69,7 +70,8 @@ class LogitsProcessor(nn.Module):
            logits *= self.scale

            # Apply logits processors (if any).
            logits = _apply_logits_processors(logits, sampling_metadata)
            if sampling_metadata is not None:
                logits = _apply_logits_processors(logits, sampling_metadata)

        return logits

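Making sampling_metadata optional lets V1 callers compute logits without V0's SamplingMetadata: pruning and per-request logits processors are simply skipped when it is None. The standalone sketch below reproduces the same guard pattern; the compute_logits helper is illustrative and is not vLLM's LogitsProcessor.

from typing import Callable, List, Optional

import torch


def compute_logits(
    hidden_states: torch.Tensor,
    lm_head_weight: torch.Tensor,
    scale: float = 1.0,
    processors: Optional[List[Callable[[torch.Tensor], torch.Tensor]]] = None,
) -> torch.Tensor:
    logits = hidden_states @ lm_head_weight.t()
    logits *= scale
    # Mirrors the new guard: processors run only when metadata is provided.
    if processors is not None:
        for proc in processors:
            logits = proc(logits)
    return logits


# V1-style call: no per-request processors, logits for every scheduled token.
logits = compute_logits(torch.randn(4, 8), torch.randn(32, 8))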
@@ -1,8 +1,10 @@
from typing import Dict, List, Optional, Tuple
from typing import Dict, List, Optional

from vllm.sequence import (VLLM_INVALID_TOKEN_ID, Logprob, SamplingParams,
                           Sequence, SequenceGroup)

from .detokenizer_utils import (convert_prompt_ids_to_tokens,
                                detokenize_incrementally)
from .tokenizer import AnyTokenizer
from .tokenizer_group import BaseTokenizerGroup

@@ -161,167 +163,3 @@ class Detokenizer:
        seq.output_text += new_decoded_token_text

        return len(new_decoded_token_text)


def _replace_none_with_empty(tokens: List[Optional[str]]):
    for i, token in enumerate(tokens):
        if token is None:
            tokens[i] = ""


def _convert_tokens_to_string_with_added_encoders(
    tokenizer: AnyTokenizer,
    output_tokens: List[str],
    skip_special_tokens: bool,
    spaces_between_special_tokens: bool,
) -> str:
    # Adapted from
    # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/tokenization_utils.py#L921
    # NOTE(woosuk): The following code is slow because it runs a for loop over
    # the output_tokens. In Python, running a for loop over a list can be slow
    # even when the loop body is very simple.
    sub_texts: List[str] = []
    current_sub_text: List[str] = []
    all_special_tokens = set(tokenizer.all_special_tokens)
    for token in output_tokens:
        if skip_special_tokens and token in all_special_tokens:
            continue
        if token in tokenizer.get_added_vocab():
            if current_sub_text:
                sub_text = tokenizer.convert_tokens_to_string(current_sub_text)
                sub_texts.append(sub_text)
                current_sub_text = []
            sub_texts.append(token)
        else:
            current_sub_text.append(token)
    if current_sub_text:
        sub_text = tokenizer.convert_tokens_to_string(current_sub_text)
        sub_texts.append(sub_text)
    if spaces_between_special_tokens:
        return " ".join(sub_texts)
    else:
        return "".join(sub_texts)


# 5 is an arbitrary value that should work for all
# tokenizers (bigger = more conservative).
INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET = 5


def convert_prompt_ids_to_tokens(
    tokenizer: AnyTokenizer,
    prompt_ids: List[int],
    skip_special_tokens: bool = False,
) -> Tuple[List[str], int, int]:
    """Converts the prompt ids to tokens and returns the tokens and offsets
    for incremental detokenization.

    Note that not all tokens are converted to strings. Only the tokens that
    are necessary for incremental detokenization are converted to strings.
    """
    # We do not need to convert the whole prompt to tokens.
    # Offset a little more in case we have special tokens.
    new_tokens = tokenizer.convert_ids_to_tokens(
        prompt_ids[-INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET - 2:],
        skip_special_tokens=skip_special_tokens)
    read_offset = len(new_tokens)
    prefix_offset = max(
        read_offset - INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET, 0)
    # This is required to guard against out-of-vocab prompt token ids
    _replace_none_with_empty(new_tokens)  # type: ignore[arg-type]
    return new_tokens, prefix_offset, read_offset


# Based on
# https://github.com/huggingface/text-generation-inference/blob/v0.9.4/server/text_generation_server/models/model.py#L62C9-L62C15
# under Apache 2.0 license
def detokenize_incrementally(
    tokenizer: AnyTokenizer,
    all_input_ids: List[int],
    prev_tokens: Optional[List[str]],
    prefix_offset: int,
    read_offset: int,
    skip_special_tokens: bool = False,
    spaces_between_special_tokens: bool = True,
) -> Tuple[List[str], str, int, int]:
    """Detokenizes the input ids incrementally and returns the new tokens
    and the new text.

    If `prev_tokens` is None, this function will convert the input ids to
    tokens and return the tokens and the new text. Otherwise, it will return the
    new tokens and the new text.

    This function will also return the new prefix offset and the new read
    offset to be used in the next iteration.

    The offsets are necessary to defeat cleanup algorithms in the decode which
    decide to add a space or not depending on the surrounding ids.

    Args:
        tokenizer: The tokenizer to use.
        all_input_ids: The input ids. The last id is the new token id.
        prev_tokens: The previous tokens. If None, this function will convert
            the input ids to tokens and return the tokens and the new text.
        prefix_offset: The prefix offset.
        read_offset: The read offset.
        skip_special_tokens: Whether to skip special tokens.
        spaces_between_special_tokens: Whether to add spaces between special
            tokens.
    """
    new_token_id = all_input_ids[-1]
    # This is the first iteration for this sequence
    is_first_iter = prev_tokens is None
    if is_first_iter:
        (prev_tokens, prefix_offset,
         read_offset) = convert_prompt_ids_to_tokens(
             tokenizer,
             all_input_ids[:-1],
             skip_special_tokens=skip_special_tokens)
    assert prev_tokens is not None

    # If the new token id is out of bounds, return an empty string.
    if 0 <= new_token_id < len(tokenizer):
        # Put new_token_id in a list so skip_special_tokens is respected
        new_tokens = tokenizer.convert_ids_to_tokens(
            [new_token_id], skip_special_tokens=skip_special_tokens)
        if isinstance(new_tokens, str):
            new_tokens = [new_tokens]
    else:
        new_tokens = [""]
    output_tokens = prev_tokens + new_tokens

    # If this is the first iteration, return all tokens.
    if is_first_iter:
        new_tokens = output_tokens

    # The prefix text is necessary only to defeat cleanup algorithms in
    # the decode which decide to add a space or not depending on the
    # surrounding ids.
    if tokenizer.is_fast or not tokenizer.get_added_vocab():
        prefix_text = tokenizer.convert_tokens_to_string(
            output_tokens[prefix_offset:read_offset])
        new_text = tokenizer.convert_tokens_to_string(
            output_tokens[prefix_offset:])
    else:
        prefix_text = _convert_tokens_to_string_with_added_encoders(
            tokenizer,
            output_tokens[prefix_offset:read_offset],
            skip_special_tokens=skip_special_tokens,
            spaces_between_special_tokens=spaces_between_special_tokens,
        )
        new_text = _convert_tokens_to_string_with_added_encoders(
            tokenizer,
            output_tokens[prefix_offset:],
            skip_special_tokens=skip_special_tokens,
            spaces_between_special_tokens=spaces_between_special_tokens,
        )

    if len(new_text) <= len(prefix_text) or new_text.endswith("�"):
        # utf-8 char at the end means it's a potential unfinished byte sequence
        # from byte fallback tokenization.
        # If it's in the middle, it's probably a real invalid id generated
        # by the model
        return new_tokens, "", prefix_offset, read_offset

    new_text = new_text[len(prefix_text):]
    return new_tokens, new_text, read_offset, len(output_tokens)
vllm/transformers_utils/detokenizer_utils.py  (new file, 167 lines)
@@ -0,0 +1,167 @@
from typing import List, Optional, Tuple

from .tokenizer import AnyTokenizer


def _replace_none_with_empty(tokens: List[Optional[str]]):
    for i, token in enumerate(tokens):
        if token is None:
            tokens[i] = ""


def _convert_tokens_to_string_with_added_encoders(
    tokenizer: AnyTokenizer,
    output_tokens: List[str],
    skip_special_tokens: bool,
    spaces_between_special_tokens: bool,
) -> str:
    # Adapted from
    # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/tokenization_utils.py#L921
    # NOTE(woosuk): The following code is slow because it runs a for loop over
    # the output_tokens. In Python, running a for loop over a list can be slow
    # even when the loop body is very simple.
    sub_texts: List[str] = []
    current_sub_text: List[str] = []
    all_special_tokens = set(tokenizer.all_special_tokens)
    for token in output_tokens:
        if skip_special_tokens and token in all_special_tokens:
            continue
        if token in tokenizer.get_added_vocab():
            if current_sub_text:
                sub_text = tokenizer.convert_tokens_to_string(current_sub_text)
                sub_texts.append(sub_text)
                current_sub_text = []
            sub_texts.append(token)
        else:
            current_sub_text.append(token)
    if current_sub_text:
        sub_text = tokenizer.convert_tokens_to_string(current_sub_text)
        sub_texts.append(sub_text)
    if spaces_between_special_tokens:
        return " ".join(sub_texts)
    else:
        return "".join(sub_texts)


# 5 is an arbitrary value that should work for all
# tokenizers (bigger = more conservative).
INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET = 5


def convert_prompt_ids_to_tokens(
    tokenizer: AnyTokenizer,
    prompt_ids: List[int],
    skip_special_tokens: bool = False,
) -> Tuple[List[str], int, int]:
    """Converts the prompt ids to tokens and returns the tokens and offsets
    for incremental detokenization.

    Note that not all tokens are converted to strings. Only the tokens that
    are necessary for incremental detokenization are converted to strings.
    """
    # We do not need to convert the whole prompt to tokens.
    # Offset a little more in case we have special tokens.
    new_tokens = tokenizer.convert_ids_to_tokens(
        prompt_ids[-INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET - 2:],
        skip_special_tokens=skip_special_tokens)
    read_offset = len(new_tokens)
    prefix_offset = max(
        read_offset - INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET, 0)
    # This is required to guard against out-of-vocab prompt token ids
    _replace_none_with_empty(new_tokens)  # type: ignore[arg-type]
    return new_tokens, prefix_offset, read_offset


# Based on
# https://github.com/huggingface/text-generation-inference/blob/v0.9.4/server/text_generation_server/models/model.py#L62C9-L62C15
# under Apache 2.0 license
def detokenize_incrementally(
    tokenizer: AnyTokenizer,
    all_input_ids: List[int],
    prev_tokens: Optional[List[str]],
    prefix_offset: int,
    read_offset: int,
    skip_special_tokens: bool = False,
    spaces_between_special_tokens: bool = True,
) -> Tuple[List[str], str, int, int]:
    """Detokenizes the input ids incrementally and returns the new tokens
    and the new text.

    If `prev_tokens` is None, this function will convert the input ids to
    tokens and return the tokens and the new text. Otherwise, it will return the
    new tokens and the new text.

    This function will also return the new prefix offset and the new read
    offset to be used in the next iteration.

    The offsets are necessary to defeat cleanup algorithms in the decode which
    decide to add a space or not depending on the surrounding ids.

    Args:
        tokenizer: The tokenizer to use.
        all_input_ids: The input ids. The last id is the new token id.
        prev_tokens: The previous tokens. If None, this function will convert
            the input ids to tokens and return the tokens and the new text.
        prefix_offset: The prefix offset.
        read_offset: The read offset.
        skip_special_tokens: Whether to skip special tokens.
        spaces_between_special_tokens: Whether to add spaces between special
            tokens.
    """
    new_token_id = all_input_ids[-1]
    # This is the first iteration for this sequence
    is_first_iter = prev_tokens is None
    if is_first_iter:
        (prev_tokens, prefix_offset,
         read_offset) = convert_prompt_ids_to_tokens(
             tokenizer,
             all_input_ids[:-1],
             skip_special_tokens=skip_special_tokens)
    assert prev_tokens is not None

    # If the new token id is out of bounds, return an empty string.
    if 0 <= new_token_id < len(tokenizer):
        # Put new_token_id in a list so skip_special_tokens is respected
        new_tokens = tokenizer.convert_ids_to_tokens(
            [new_token_id], skip_special_tokens=skip_special_tokens)
        if isinstance(new_tokens, str):
            new_tokens = [new_tokens]
    else:
        new_tokens = [""]
    output_tokens = prev_tokens + new_tokens

    # If this is the first iteration, return all tokens.
    if is_first_iter:
        new_tokens = output_tokens

    # The prefix text is necessary only to defeat cleanup algorithms in
    # the decode which decide to add a space or not depending on the
    # surrounding ids.
    if tokenizer.is_fast or not tokenizer.get_added_vocab():
        prefix_text = tokenizer.convert_tokens_to_string(
            output_tokens[prefix_offset:read_offset])
        new_text = tokenizer.convert_tokens_to_string(
            output_tokens[prefix_offset:])
    else:
        prefix_text = _convert_tokens_to_string_with_added_encoders(
            tokenizer,
            output_tokens[prefix_offset:read_offset],
            skip_special_tokens=skip_special_tokens,
            spaces_between_special_tokens=spaces_between_special_tokens,
        )
        new_text = _convert_tokens_to_string_with_added_encoders(
            tokenizer,
            output_tokens[prefix_offset:],
            skip_special_tokens=skip_special_tokens,
            spaces_between_special_tokens=spaces_between_special_tokens,
        )

    if len(new_text) <= len(prefix_text) or new_text.endswith("�"):
        # utf-8 char at the end means it's a potential unfinished byte sequence
        # from byte fallback tokenization.
        # If it's in the middle, it's probably a real invalid id generated
        # by the model
        return new_tokens, "", prefix_offset, read_offset

    new_text = new_text[len(prefix_text):]
    return new_tokens, new_text, read_offset, len(output_tokens)
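These helpers implement streaming detokenization: convert_prompt_ids_to_tokens seeds prev_tokens and the two offsets from the prompt, and each call to detokenize_incrementally consumes one newly appended id and returns only the text that is safe to emit. Below is a hedged usage sketch with a Hugging Face tokenizer; the tokenizer name and the id-by-id loop are illustrative, not part of the diff.

from transformers import AutoTokenizer

from vllm.transformers_utils.detokenizer_utils import detokenize_incrementally

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder tokenizer
all_ids = tokenizer.encode("Incremental detokenization keeps streamed text stable.")

prev_tokens = None
prefix_offset = read_offset = 0
pieces = []
# Feed ids one at a time, as the engine would after each sampling step.
for i in range(1, len(all_ids) + 1):
    new_tokens, new_text, prefix_offset, read_offset = detokenize_incrementally(
        tokenizer, all_ids[:i], prev_tokens, prefix_offset, read_offset)
    prev_tokens = new_tokens if prev_tokens is None else prev_tokens + new_tokens
    pieces.append(new_text)
print("".join(pieces))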
vllm/v1/attention/__init__.py  (new empty file)
vllm/v1/attention/backends/__init__.py  (new empty file)
vllm/v1/attention/backends/flash_attn.py  (new file, 241 lines)
@@ -0,0 +1,241 @@
"""Attention layer with FlashAttention."""
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionMetadata, AttentionType)
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.vllm_flash_attn import flash_attn_varlen_func
|
||||
|
||||
|
||||
class FlashAttentionBackend(AttentionBackend):
|
||||
|
||||
@staticmethod
|
||||
def get_supported_head_sizes() -> List[int]:
|
||||
return [32, 64, 96, 128, 160, 192, 224, 256]
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "flash-attn-vllm-v1"
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> Type["FlashAttentionImpl"]:
|
||||
return FlashAttentionImpl
|
||||
|
||||
@staticmethod
|
||||
def get_metadata_cls() -> Type["AttentionMetadata"]:
|
||||
return FlashAttentionMetadata
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_shape(
|
||||
num_blocks: int,
|
||||
block_size: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
) -> Tuple[int, ...]:
|
||||
if block_size % 16 != 0:
|
||||
raise ValueError("Block size must be a multiple of 16.")
|
||||
return (2, num_blocks, block_size, num_kv_heads, head_size)
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlashAttentionMetadata:
|
||||
# NOTE(sang): Definition of context_len, query_len, and seq_len.
|
||||
# |---------- N-1 iteration --------|
|
||||
# |---------------- N iteration ---------------------|
|
||||
# |- tokenA -|......................|-- newTokens ---|
|
||||
# |---------- context_len ----------|
|
||||
# |-------------------- seq_len ---------------------|
|
||||
# |-- query_len ---|
|
||||
|
||||
max_query_len: int
|
||||
query_start_loc: torch.Tensor
|
||||
max_seq_len: int
|
||||
seq_start_loc: torch.Tensor
|
||||
block_table: torch.Tensor
|
||||
slot_mapping: torch.Tensor
|
||||
|
||||
|
||||
class FlashAttentionImpl(AttentionImpl):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: int,
|
||||
alibi_slopes: Optional[List[float]],
|
||||
sliding_window: Optional[int],
|
||||
kv_cache_dtype: str,
|
||||
blocksparse_params: Optional[Dict[str, Any]] = None,
|
||||
logits_soft_cap: Optional[float] = None,
|
||||
) -> None:
|
||||
if blocksparse_params is not None:
|
||||
raise ValueError(
|
||||
"FlashAttention does not support block-sparse attention.")
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
self.scale = float(scale)
|
||||
self.num_kv_heads = num_kv_heads
|
||||
if alibi_slopes is not None:
|
||||
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
|
||||
self.alibi_slopes = alibi_slopes
|
||||
self.sliding_window = ((sliding_window, sliding_window)
|
||||
if sliding_window is not None else (-1, -1))
|
||||
self.kv_cache_dtype = kv_cache_dtype
|
||||
if logits_soft_cap is None:
|
||||
# In flash-attn, setting logits_soft_cap as 0 means no soft cap.
|
||||
logits_soft_cap = 0
|
||||
self.logits_soft_cap = logits_soft_cap
|
||||
|
||||
assert self.num_heads % self.num_kv_heads == 0
|
||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||
|
||||
if sliding_window is not None:
|
||||
# NOTE(woosuk): flash-attn's sliding window does not work with
|
||||
# paged KV cache.
|
||||
raise ValueError(
|
||||
"Sliding window is not supported in FlashAttention.")
|
||||
|
||||
support_head_sizes = FlashAttentionBackend.get_supported_head_sizes()
|
||||
if head_size not in support_head_sizes:
|
||||
raise ValueError(
|
||||
f"Head size {head_size} is not supported by FlashAttention. "
|
||||
f"Supported head sizes are: {support_head_sizes}.")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: FlashAttentionMetadata,
|
||||
k_scale: float = 1.0,
|
||||
v_scale: float = 1.0,
|
||||
attn_type: AttentionType = AttentionType.DECODER,
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass with FlashAttention.
|
||||
|
||||
Args:
|
||||
query: shape = [num_tokens, num_heads * head_size]
|
||||
key: shape = [num_tokens, num_kv_heads * head_size]
|
||||
value: shape = [num_tokens, num_kv_heads * head_size]
|
||||
kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size]
|
||||
attn_metadata: Metadata for attention.
|
||||
Returns:
|
||||
shape = [num_tokens, num_heads * head_size]
|
||||
"""
|
||||
if attn_type != AttentionType.DECODER:
|
||||
raise NotImplementedError("Encoder self-attention and "
|
||||
"encoder/decoder cross-attention "
|
||||
"are not implemented for "
|
||||
"FlashAttentionImpl")
|
||||
|
||||
# NOTE(woosuk): FlashAttention does not support FP8 KV cache.
|
||||
assert k_scale == 1.0 and v_scale == 1.0, (
|
||||
"key/v_scale is not supported in FlashAttention.")
|
||||
|
||||
output = torch.ops.vllm.unified_flash_attention(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
self.num_heads,
|
||||
self.head_size,
|
||||
self.num_kv_heads,
|
||||
kv_cache,
|
||||
self.kv_cache_dtype,
|
||||
k_scale,
|
||||
v_scale,
|
||||
self.scale,
|
||||
self.sliding_window,
|
||||
self.alibi_slopes,
|
||||
self.logits_soft_cap,
|
||||
)
|
||||
return output
|
||||
|
||||
|
||||
@torch.library.custom_op("vllm::unified_flash_attention",
|
||||
mutates_args=["kv_cache"])
|
||||
def unified_flash_attention(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
num_kv_heads: int,
|
||||
kv_cache: torch.Tensor,
|
||||
kv_cache_dtype: str,
|
||||
k_scale: float,
|
||||
v_scale: float,
|
||||
softmax_scale: float,
|
||||
window_size: Optional[List[int]] = None,
|
||||
alibi_slopes: Optional[torch.Tensor] = None,
|
||||
logits_soft_cap: Optional[float] = None,
|
||||
) -> torch.Tensor:
|
||||
current_metadata = get_forward_context()
|
||||
if current_metadata is None:
|
||||
# Profiling run.
|
||||
return torch.empty_like(query)
|
||||
|
||||
assert current_metadata is not None
|
||||
assert isinstance(current_metadata, FlashAttentionMetadata)
|
||||
attn_metadata: FlashAttentionMetadata = current_metadata
|
||||
|
||||
num_tokens, hidden_size = query.shape
|
||||
# Reshape the query, key, and value tensors.
|
||||
query = query.view(-1, num_heads, head_size)
|
||||
key = key.view(-1, num_kv_heads, head_size)
|
||||
value = value.view(-1, num_kv_heads, head_size)
|
||||
|
||||
# Reshape the input keys and values and store them in the cache.
|
||||
key_cache = kv_cache[0]
|
||||
value_cache = kv_cache[1]
|
||||
torch.ops._C_cache_ops.reshape_and_cache_flash(
|
||||
key,
|
||||
value,
|
||||
kv_cache[0],
|
||||
kv_cache[1],
|
||||
attn_metadata.slot_mapping,
|
||||
kv_cache_dtype,
|
||||
k_scale,
|
||||
v_scale,
|
||||
)
|
||||
|
||||
output = flash_attn_varlen_func(
|
||||
q=query,
|
||||
k=key_cache,
|
||||
v=value_cache,
|
||||
cu_seqlens_q=attn_metadata.query_start_loc,
|
||||
max_seqlen_q=attn_metadata.max_query_len,
|
||||
cu_seqlens_k=attn_metadata.seq_start_loc,
|
||||
max_seqlen_k=attn_metadata.max_seq_len,
|
||||
softmax_scale=softmax_scale,
|
||||
causal=True,
|
||||
alibi_slopes=alibi_slopes,
|
||||
window_size=window_size,
|
||||
block_table=attn_metadata.block_table,
|
||||
softcap=logits_soft_cap,
|
||||
)
|
||||
return output.view(num_tokens, hidden_size)
|
||||
|
||||
|
||||
@unified_flash_attention.register_fake
|
||||
def _(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
num_kv_heads: int,
|
||||
kv_cache: torch.Tensor,
|
||||
kv_cache_dtype: str,
|
||||
k_scale: float,
|
||||
v_scale: float,
|
||||
softmax_scale: float,
|
||||
window_size: Optional[List[int]] = None,
|
||||
alibi_slopes: Optional[torch.Tensor] = None,
|
||||
logits_soft_cap: Optional[float] = None,
|
||||
) -> torch.Tensor:
|
||||
return torch.empty_like(query)
|
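The backend stores keys and values in a paged cache whose layout comes from get_kv_cache_shape, and forward dispatches through the vllm::unified_flash_attention custom op (registered with a fake implementation so it can be traced). A small sketch of the cache-shape math only, with illustrative sizes that are not taken from the diff:

from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend

# Illustrative sizes: 1000 blocks of 16 slots, 8 KV heads, head size 128.
shape = FlashAttentionBackend.get_kv_cache_shape(
    num_blocks=1000, block_size=16, num_kv_heads=8, head_size=128)
print(shape)  # (2, 1000, 16, 8, 128): one plane for keys, one for values

# Per-layer cache size in fp16:
# 2 * 1000 * 16 * 8 * 128 elements * 2 bytes = 65,536,000 bytes ~= 62.5 MiB.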
vllm/v1/core/__init__.py  (new empty file)
vllm/v1/core/kv_cache_manager.py  (new file, 108 lines)
@@ -0,0 +1,108 @@
from typing import Dict, List, Optional

import numpy as np

from vllm.logger import init_logger
from vllm.utils import cdiv
from vllm.v1.request import Request

logger = init_logger(__name__)


class KVCacheManager:

    def __init__(
        self,
        block_size: int,
        num_gpu_blocks: int,
        sliding_window: Optional[int] = None,
        enable_caching: bool = True,
        num_preallocate_tokens: int = 64,
    ) -> None:
        self.block_size = block_size
        self.num_gpu_blocks = num_gpu_blocks
        self.sliding_window = sliding_window
        self.enable_caching = enable_caching
        # NOTE(woosuk): To avoid frequent block allocation, we preallocate some
        # blocks for each request. For example, when a request reaches the end
        # of its block table, we preallocate N blocks in advance. This way, we
        # reduce the overhead of updating free_block_ids and ref_cnts for each
        # request every step (at the cost of some memory waste).
        # NOTE(woosuk): This is different from the "lookahead" slots since this
        # does not guarantee that the request always has N empty blocks. After
        # the request gets N empty blocks, it starts to use the blocks without
        # further allocation. When it uses up all the N empty blocks, it gets
        # N new empty blocks.
        self.num_preallocate_tokens = num_preallocate_tokens
        self.num_preallocate_blocks = cdiv(num_preallocate_tokens, block_size)

        self.free_block_ids = list(range(num_gpu_blocks))
        self.req_to_block_ids: Dict[str, List[int]] = {}
        self.ref_cnts = np.zeros(num_gpu_blocks, dtype=np.int32)

    def get_computed_blocks(self, request: Request) -> List[int]:
        if not self.enable_caching:
            # No prefix caching.
            return []
        # TODO(woosuk): Implement hash-based caching.
        return []

    def append_slots(
        self,
        request: Request,
        num_tokens: int,
    ) -> Optional[List[int]]:
        num_required_blocks = cdiv(request.num_computed_tokens + num_tokens,
                                   self.block_size)
        req_block_ids = self.req_to_block_ids[request.request_id]
        if num_required_blocks <= len(req_block_ids):
            # No new block is needed.
            return []

        num_new_blocks = num_required_blocks - len(req_block_ids)
        num_free_blocks = len(self.free_block_ids)
        if num_new_blocks > num_free_blocks:
            # Cannot allocate new blocks.
            return None

        # Allocate new blocks.
        num_new_blocks = min(num_new_blocks + self.num_preallocate_blocks,
                             num_free_blocks)
        new_block_ids = self._get_new_blocks(num_new_blocks)
        req_block_ids.extend(new_block_ids)
        self.ref_cnts[new_block_ids] += 1
        return new_block_ids

    def allocate_slots(
        self,
        request: Request,
        num_tokens: int,
        computed_block_ids: List[int],
    ) -> Optional[List[int]]:
        num_required_blocks = cdiv(num_tokens, self.block_size)
        num_free_blocks = len(self.free_block_ids)
        if num_required_blocks > num_free_blocks:
            # Cannot allocate new blocks.
            return None

        num_new_blocks = min(num_required_blocks + self.num_preallocate_blocks,
                             num_free_blocks)
        new_block_ids = self._get_new_blocks(num_new_blocks)
        block_ids = computed_block_ids + new_block_ids
        self.req_to_block_ids[request.request_id] = block_ids
        self.ref_cnts[block_ids] += 1
        return new_block_ids

    def free(self, request: Request) -> None:
        block_ids = self.req_to_block_ids.pop(request.request_id)
        self.ref_cnts[block_ids] -= 1
        for block_id in block_ids:
            ref_cnt = self.ref_cnts[block_id]
            if ref_cnt == 0:
                self.free_block_ids.append(block_id)

    def _get_new_blocks(self, num_blocks: int) -> List[int]:
        assert num_blocks <= len(self.free_block_ids)
        new_block_ids = self.free_block_ids[-num_blocks:]
        self.free_block_ids = self.free_block_ids[:-num_blocks]
        return new_block_ids
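The manager tracks blocks purely by integer id: a free list, a per-request block table, and a NumPy ref-count array, with cdiv-based rounding plus num_preallocate_blocks extra blocks to amortize allocation. A standalone sketch of that accounting follows; it uses a tiny stand-in for vllm.v1.request.Request (only the two attributes the manager reads), since the Request class itself is not part of this diff.

from types import SimpleNamespace

from vllm.v1.core.kv_cache_manager import KVCacheManager

manager = KVCacheManager(block_size=16, num_gpu_blocks=64, enable_caching=False)
# Stand-in for vllm.v1.request.Request: only the fields the manager touches.
req = SimpleNamespace(request_id="req-0", num_computed_tokens=0)

# New 50-token prompt: cdiv(50, 16) = 4 required blocks, plus 4 preallocated.
new_blocks = manager.allocate_slots(req, num_tokens=50, computed_block_ids=[])
print(len(new_blocks))               # 8
print(len(manager.free_block_ids))   # 56

# Once decoding spills past the 8 allocated blocks, append_slots hands out
# another batch, again rounded up and padded with preallocated blocks.
req.num_computed_tokens = 8 * 16
more = manager.append_slots(req, num_tokens=1)
print(len(more))                     # 5 (1 needed + 4 preallocated)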
vllm/v1/core/scheduler.py  (new file, 412 lines)
@@ -0,0 +1,412 @@
from collections import deque
from dataclasses import dataclass
from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple, Union

from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
from vllm.logger import init_logger
from vllm.multimodal import MultiModalDataDict
from vllm.sampling_params import SamplingParams
from vllm.v1.core.kv_cache_manager import KVCacheManager
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus

logger = init_logger(__name__)


class Scheduler:

    def __init__(
        self,
        scheduler_config: SchedulerConfig,
        cache_config: CacheConfig,
        lora_config: Optional[LoRAConfig],
    ) -> None:
        self.scheduler_config = scheduler_config
        self.cache_config = cache_config
        self.lora_config = lora_config
        # TODO: Support LoRA.
        assert lora_config is None, "V1 does not support LoRA yet."

        num_gpu_blocks = cache_config.num_gpu_blocks
        assert isinstance(num_gpu_blocks, int) and num_gpu_blocks > 0
        # Create the block space manager.
        self.kv_cache_manager = KVCacheManager(
            block_size=self.cache_config.block_size,
            num_gpu_blocks=num_gpu_blocks,
            sliding_window=self.cache_config.sliding_window,
            enable_caching=True)
        self.block_size = self.cache_config.block_size

        # Scheduling constraints.
        self.max_num_running_reqs = self.scheduler_config.max_num_seqs
        self.max_num_scheduled_tokens = \
            self.scheduler_config.max_num_batched_tokens
        self.max_model_len = self.scheduler_config.max_model_len

        # req_id -> Request
        self.requests: Dict[str, Request] = {}
        # Priority queues for requests.
        self.waiting: Deque[Request] = deque()
        self.running: List[Request] = []

        # The request IDs that are finished in between the previous and the
        # current steps. This is used to notify the workers about the finished
        # requests so that they can free the cached states for those requests.
        # This is flushed at the end of each scheduling step.
        self.finished_req_ids: Set[str] = set()

        # OPTIMIZATION: Cache the RunningRequestData objects to avoid creating
        # them at each scheduling step.
        # Request id -> RunningRequestData
        self.running_reqs_data: Dict[str, RunningRequestData] = {}

    def schedule(self) -> "SchedulerOutput":
        scheduled_new_reqs: List[Request] = []
        scheduled_resumed_reqs: List[Request] = []
        scheduled_running_reqs: List[Request] = []
        preempted_reqs: List[Request] = []

        # NOTE(woosuk) on the scheduling algorithm:
        # There's no "decoding phase" nor "prefill phase" in the scheduler.
        # Each request just has the num_computed_tokens and num_tokens,
        # which is equal to len(prompt_token_ids) + len(output_token_ids).
        # At each step, the scheduler tries to assign tokens to the requests
        # so that each request's num_computed_tokens can catch up its
        # num_tokens. This is general enough to cover chunked prefills,
        # prefix caching, and the "jump forward" optimization in the future.

        req_to_new_block_ids: Dict[str, List[int]] = {}
        num_scheduled_tokens: Dict[str, int] = {}
        token_budget = self.max_num_scheduled_tokens

        # First, schedule the RUNNING requests.
        req_index = 0
        while req_index < len(self.running):
            if token_budget == 0:
                break

            request = self.running[req_index]
            num_new_tokens = request.num_tokens - request.num_computed_tokens
            num_new_tokens = min(num_new_tokens, token_budget)
            assert num_new_tokens > 0

            while True:
                new_block_ids = self.kv_cache_manager.append_slots(
                    request, num_new_tokens)
                if new_block_ids is None:
                    # The request cannot be scheduled.
                    # Preempt the lowest-priority request.
                    preempted_req = self.running.pop()
                    self.kv_cache_manager.free(preempted_req)
                    preempted_req.status = RequestStatus.PREEMPTED
                    preempted_req.num_computed_tokens = 0

                    self.waiting.appendleft(preempted_req)
                    preempted_reqs.append(preempted_req)
                    if preempted_req == request:
                        # No more request to preempt.
                        break
                else:
                    # The request can be scheduled.
                    scheduled_running_reqs.append(request)

                    req_to_new_block_ids[request.request_id] = new_block_ids
                    num_scheduled_tokens[request.request_id] = num_new_tokens
                    token_budget -= num_new_tokens
                    req_index += 1
                    break

        # Next, schedule the WAITING requests.
        if not preempted_reqs:
            while self.waiting:
                if len(self.running) == self.max_num_running_reqs:
                    break
                if token_budget == 0:
                    break

                request = self.waiting[0]
                # Get already-cached tokens.
                computed_block_ids = self.kv_cache_manager.get_computed_blocks(
                    request)
                # NOTE(woosuk): Since incomplete blocks are not eligible for
                # sharing, `num_computed_tokens` is always a multiple of
                # `block_size`.
                num_computed_tokens = len(computed_block_ids) * self.block_size
                # Number of tokens to be scheduled.
                # We use `request.num_tokens` instead of
                # `request.num_prompt_tokens` to consider the resumed requests,
                # which have output tokens.
                num_new_tokens = request.num_tokens - num_computed_tokens
                num_new_tokens = min(num_new_tokens, token_budget)
                assert num_new_tokens > 0
                new_block_ids = self.kv_cache_manager.allocate_slots(
                    request, num_new_tokens, computed_block_ids)
                if new_block_ids is None:
                    # The request cannot be scheduled.
                    break
                request.num_computed_tokens = num_computed_tokens

                self.waiting.popleft()
                self.running.append(request)
                if request.status == RequestStatus.WAITING:
                    scheduled_new_reqs.append(request)
                elif request.status == RequestStatus.PREEMPTED:
                    scheduled_resumed_reqs.append(request)
                else:
                    raise RuntimeError(
                        f"Invalid request status: {request.status}")

                req_to_new_block_ids[request.request_id] = (
                    computed_block_ids + new_block_ids)
                num_scheduled_tokens[request.request_id] = num_new_tokens
                token_budget -= num_new_tokens
                request.status = RequestStatus.RUNNING

        # Check if the scheduling constraints are satisfied.
        total_num_scheduled_tokens = sum(num_scheduled_tokens.values())
        assert total_num_scheduled_tokens <= self.max_num_scheduled_tokens
        assert token_budget >= 0
        assert len(self.running) <= self.max_num_running_reqs
        assert (len(scheduled_new_reqs) + len(scheduled_resumed_reqs) +
                len(scheduled_running_reqs) == len(self.running))

        # Construct the scheduler output.
        new_reqs_data = [
            NewRequestData.from_request(req,
                                        req_to_new_block_ids[req.request_id],
                                        req.num_computed_tokens)
            for req in scheduled_new_reqs
        ]
        resumed_reqs_data = [
            ResumedRequestData.from_request(
                req, req_to_new_block_ids[req.request_id],
                req.num_computed_tokens) for req in scheduled_resumed_reqs
        ]
        running_reqs_data = [
            self._make_running_request_data(
                req, req_to_new_block_ids[req.request_id],
                req.num_computed_tokens) for req in scheduled_running_reqs
        ]
        preempted_req_ids = {req.request_id for req in preempted_reqs}
        scheduler_output = SchedulerOutput(
            scheduled_new_reqs=new_reqs_data,
            scheduled_resumed_reqs=resumed_reqs_data,
            scheduled_running_reqs=running_reqs_data,
            num_scheduled_tokens=num_scheduled_tokens,
            total_num_scheduled_tokens=total_num_scheduled_tokens,
            preempted_req_ids=preempted_req_ids,
            # finished_req_ids is an existing state in the scheduler,
            # instead of being newly scheduled in this step.
            # It contains the request IDs that are finished in between
            # the previous and the current steps.
            finished_req_ids=self.finished_req_ids,
        )

        self.finished_req_ids = set()
        return scheduler_output

    def _make_running_request_data(
        self,
        request: Request,
        new_block_ids: List[int],
        num_computed_tokens: int,
    ) -> "RunningRequestData":
        # OPTIMIZATION: Cache the RunningRequestData objects to avoid creating
        # them at each scheduling step.
        if request.request_id in self.running_reqs_data:
            req_data = self.running_reqs_data[request.request_id]
            req_data.new_block_ids = new_block_ids
            req_data.num_computed_tokens = num_computed_tokens
        else:
            req_data = RunningRequestData.from_request(request, new_block_ids,
                                                       num_computed_tokens)
            self.running_reqs_data[request.request_id] = req_data
        return req_data

    def update_from_output(
        self,
        scheduler_output: "SchedulerOutput",
        model_runner_output: "ModelRunnerOutput",
    ) -> List[Tuple[Request, int]]:
        # NOTE(woosuk): This method doesn't consider speculative decoding.
        sampled_token_ids = model_runner_output.sampled_token_ids_cpu.tolist()
        num_scheduled_tokens = scheduler_output.num_scheduled_tokens
        new_running: List[Request] = []
        # (request, num_sampled_tokens)
        sampled: List[Tuple[Request, int]] = []
        for request in self.running:
            req_id = request.request_id
            request.num_computed_tokens += num_scheduled_tokens[req_id]
            # When the request's num_computed_tokens catches up its num_tokens,
            # the request generates output tokens. Otherwise, we ignore the
            # sampler output for the request.
            assert request.num_computed_tokens <= request.num_tokens
            if request.num_computed_tokens == request.num_tokens:
                req_index = model_runner_output.req_id_to_index[req_id]
                # NOTE(woosuk): Currently, we assume that each request
                # generates at most one token at each step.
                token_id = sampled_token_ids[req_index]
                request.output_token_ids.append(token_id)
                sampled.append((request, 1))
                # TODO: Update the KV cache manager for prefix caching.

                # Check if the request is finished.
                stopped = self._check_stop(request)
                if stopped:
                    continue

            new_running.append(request)
        self.running = new_running
        return sampled

    def _check_stop(self, request: Request) -> bool:
        if (request.num_tokens >= self.max_model_len
                or request.num_output_tokens >= request.max_tokens):
            request.status = RequestStatus.FINISHED_LENGTH_CAPPED
            self._free_request(request)
            return True

        sampling_params = request.sampling_params
        last_token_id = request.output_token_ids[-1]
        if (not sampling_params.ignore_eos
                and last_token_id == request.eos_token_id):
            request.status = RequestStatus.FINISHED_STOPPED
            self._free_request(request)
            return True

        if last_token_id in (sampling_params.stop_token_ids or ()):
            request.status = RequestStatus.FINISHED_STOPPED
            request.stop_reason = last_token_id
            self._free_request(request)
            return True
        return False

    def add_request(self, request: Request) -> None:
        self.waiting.append(request)
        self.requests[request.request_id] = request

    def finish_requests(
        self,
        request_ids: Union[str, Iterable[str]],
        finished_status: RequestStatus,
    ) -> None:
        """Handles the finish signal from outside the scheduler.

        For example, the API server can abort a request when the client
        disconnects.
        """
        assert RequestStatus.is_finished(finished_status)
        if isinstance(request_ids, str):
            request_ids = (request_ids, )
        request_ids = set(request_ids)

        for req_id in request_ids:
            request = self.requests.get(req_id)
            if request is None:
                # Invalid request ID.
                continue

            if request.status == RequestStatus.RUNNING:
                self.running.remove(request)
            else:
                self.waiting.remove(request)
            request.status = finished_status
            self._free_request(request)

    def _free_request(self, request: Request) -> None:
        assert request.is_finished()
        self.kv_cache_manager.free(request)
        self.running_reqs_data.pop(request.request_id, None)
        del self.requests[request.request_id]
        self.finished_req_ids.add(request.request_id)

    def get_num_unfinished_requests(self) -> int:
        return len(self.waiting) + len(self.running)

    def has_unfinished_requests(self) -> bool:
        return self.get_num_unfinished_requests() > 0


@dataclass
class NewRequestData:

    req_id: str
    prompt_token_ids: List[int]
    prompt: Optional[str]
    multi_modal_data: Optional[MultiModalDataDict]
    sampling_params: SamplingParams
    block_ids: List[int]
    num_computed_tokens: int

    @classmethod
    def from_request(
        cls,
        request: Request,
        block_ids: List[int],
        num_computed_tokens: int,
    ) -> "NewRequestData":
        return cls(
            req_id=request.request_id,
            prompt_token_ids=request.inputs["prompt_token_ids"],
            prompt=request.inputs.get("prompt"),
            multi_modal_data=request.inputs.get("multi_modal_data"),
            sampling_params=request.sampling_params,
            block_ids=block_ids,
            num_computed_tokens=num_computed_tokens,
        )


@dataclass
class ResumedRequestData:

    req_id: str
    block_ids: List[int]
    num_computed_tokens: int

    @classmethod
    def from_request(
        cls,
        request: Request,
        block_ids: List[int],
        num_computed_tokens: int,
    ) -> "ResumedRequestData":
        return cls(
            req_id=request.request_id,
            block_ids=block_ids,
            num_computed_tokens=num_computed_tokens,
        )


@dataclass
class RunningRequestData:

    req_id: str
    new_block_ids: List[int]
    num_computed_tokens: int

    @classmethod
    def from_request(
        cls,
        request: Request,
        new_block_ids: List[int],
        num_computed_tokens: int,
    ) -> "RunningRequestData":
        return cls(
            req_id=request.request_id,
            new_block_ids=new_block_ids,
            num_computed_tokens=num_computed_tokens,
        )


@dataclass
class SchedulerOutput:

    scheduled_new_reqs: List[NewRequestData]
    scheduled_resumed_reqs: List[ResumedRequestData]
    scheduled_running_reqs: List[RunningRequestData]

    num_scheduled_tokens: Dict[str, int]
    total_num_scheduled_tokens: int

    preempted_req_ids: Set[str]
    finished_req_ids: Set[str]
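Because there is no prefill/decode split, the engine's step loop reduces to: schedule() hands the model runner a per-request token quota, the runner returns sampled ids, and update_from_output() advances num_computed_tokens, appends output tokens, and retires finished requests. The sketch below shows that loop under assumptions: the executor's execute_model(scheduler_output) -> ModelRunnerOutput call and the detokenize callback are placeholders, not APIs defined in this diff.

# Hedged pseudo-driver for the V1 scheduler; `executor` and `detokenize`
# are assumed interfaces supplied by the caller.
def run_to_completion(scheduler, executor, detokenize):
    while scheduler.has_unfinished_requests():
        scheduler_output = scheduler.schedule()
        if scheduler_output.total_num_scheduled_tokens == 0:
            break  # nothing could be scheduled (e.g. out of KV cache blocks)
        model_runner_output = executor.execute_model(scheduler_output)
        # Requests whose num_computed_tokens caught up to num_tokens produced
        # exactly one new token each in this step.
        for request, num_sampled in scheduler.update_from_output(
                scheduler_output, model_runner_output):
            detokenize(request, num_sampled)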
vllm/v1/engine/__init__.py  (new empty file)
vllm/v1/engine/llm_engine.py  (new file, 523 lines; diff truncated below)
@@ -0,0 +1,523 @@
import time
from typing import (Any, Dict, Iterable, List, Mapping, Optional, Tuple, Type,
                    Union)

from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
                         EngineConfig, LoadConfig, LoRAConfig, ModelConfig,
                         ObservabilityConfig, ParallelConfig,
                         PromptAdapterConfig, SchedulerConfig,
                         SpeculativeConfig)
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.metrics_types import StatLoggerBase
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs,
                         EncoderDecoderLLMInputs, InputRegistry, PromptType)
from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.transformers_utils.config import try_get_generation_config
from vllm.transformers_utils.tokenizer_group import (
    BaseTokenizerGroup, init_tokenizer_from_configs)
from vllm.usage.usage_lib import UsageContext
from vllm.v1.core.scheduler import Scheduler
from vllm.v1.executor.gpu_executor import GPUExecutor
from vllm.v1.request import Request, RequestStatus
from vllm.v1.tokenizer.detokenizer import Detokenizer, DetokenizerInputs
from vllm.version import __version__ as VLLM_VERSION

logger = init_logger(__name__)


class LLMEngine:

    def __init__(
        self,
        model_config: ModelConfig,
        cache_config: CacheConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        load_config: LoadConfig,
        lora_config: Optional[LoRAConfig],
        speculative_config: Optional[SpeculativeConfig],
        decoding_config: Optional[DecodingConfig],
        observability_config: Optional[ObservabilityConfig],
        prompt_adapter_config: Optional[PromptAdapterConfig],
        executor_class: Type[GPUExecutor],
        log_stats: bool,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
        input_registry: InputRegistry = INPUT_REGISTRY,
        use_cached_outputs: bool = False,
    ) -> None:
        # Override the configs for V1.
        # FIXME
        if usage_context == UsageContext.LLM_CLASS:
            scheduler_config.max_num_seqs = 1024
            scheduler_config.max_num_batched_tokens = 8192
        elif usage_context == UsageContext.OPENAI_API_SERVER:
            scheduler_config.max_num_seqs = 1024
            scheduler_config.max_num_batched_tokens = 2048

        logger.info(
            "Initializing an LLM engine (v%s) with config: "
            "model=%r, speculative_config=%r, tokenizer=%r, "
            "skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, "
            "override_neuron_config=%s, "
            "rope_scaling=%r, rope_theta=%r, tokenizer_revision=%s, "
            "trust_remote_code=%s, dtype=%s, max_seq_len=%d, "
            "download_dir=%r, load_format=%s, tensor_parallel_size=%d, "
            "pipeline_parallel_size=%d, "
            "disable_custom_all_reduce=%s, quantization=%s, "
            "enforce_eager=%s, kv_cache_dtype=%s, "
            "quantization_param_path=%s, device_config=%s, "
            "decoding_config=%r, observability_config=%r, "
            "seed=%d, served_model_name=%s, "
            "num_scheduler_steps=%d, enable_prefix_caching=%s, "
            "use_async_output_proc=%s, mm_processor_kwargs=%s)",
            VLLM_VERSION,
            model_config.model,
            speculative_config,
            model_config.tokenizer,
            model_config.skip_tokenizer_init,
            model_config.tokenizer_mode,
            model_config.revision,
            model_config.override_neuron_config,
            model_config.rope_scaling,
            model_config.rope_theta,
            model_config.tokenizer_revision,
            model_config.trust_remote_code,
            model_config.dtype,
            model_config.max_model_len,
            load_config.download_dir,
            load_config.load_format,
            parallel_config.tensor_parallel_size,
            parallel_config.pipeline_parallel_size,
            parallel_config.disable_custom_all_reduce,
            model_config.quantization,
            model_config.enforce_eager,
            cache_config.cache_dtype,
            model_config.quantization_param_path,
            device_config.device,
            decoding_config,
            observability_config,
            model_config.seed,
            model_config.served_model_name,
            scheduler_config.num_scheduler_steps,
            cache_config.enable_prefix_caching,
            model_config.use_async_output_proc,
            model_config.mm_processor_kwargs,
        )

        self.model_config = model_config
        self.cache_config = cache_config
        self.lora_config = lora_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.device_config = device_config
        self.speculative_config = speculative_config
        self.load_config = load_config
        self.decoding_config = decoding_config or DecodingConfig()
        self.prompt_adapter_config = prompt_adapter_config
        self.observability_config = observability_config or ObservabilityConfig(
        )
        self.log_stats = log_stats

        assert not self.model_config.skip_tokenizer_init
        self.tokenizer = self._init_tokenizer()
        if self.tokenizer:
            # Ping the tokenizer to ensure liveness if it runs in a
            # different process.
            self.tokenizer.ping()
        self.detokenizer = Detokenizer(self.model_config.tokenizer)

        self.generation_config_fields = _load_generation_config_dict(
            model_config)
        self.input_preprocessor = InputPreprocessor(model_config,
                                                    self.tokenizer)
        self.input_registry = input_registry
        self.input_processor = input_registry.create_input_processor(
            model_config)

        # Request id -> Request
        self.requests: Dict[str, Request] = {}
        # NOTE(woosuk): Now that the detokenizer works asynchronously, we need
        # to keep track of how many steps each request has been lagged behind
        # in terms of detokenization.
        # Request id -> how many detokenizer steps the request should wait for.
        self.num_lagged_steps: Dict[str, int] = {}
        # OPTIMIZATION: Cache the request output and update it incrementally.
        # This is used to avoid creating a new RequestOutput object every step.
        # Request id -> RequestOutput
        self.request_outputs: Dict[str, RequestOutput] = {}

        self.model_executor = executor_class(
            model_config=model_config,
            cache_config=cache_config,
            parallel_config=parallel_config,
            scheduler_config=scheduler_config,
            device_config=device_config,
            lora_config=lora_config,
            speculative_config=speculative_config,
            load_config=load_config,
            prompt_adapter_config=prompt_adapter_config,
            observability_config=self.observability_config,
        )
        assert self.model_config.task != "embedding"
        self._initialize_kv_caches()

        # Create the scheduler.
        # NOTE: the cache_config here have been updated with the numbers of
        # GPU and CPU blocks, which are profiled in the distributed executor.
        self.scheduler = Scheduler(scheduler_config, cache_config, lora_config)

    def _initialize_kv_caches(self) -> None:
        num_gpu_blocks, _ = self.model_executor.determine_num_available_blocks(
        )

        if self.cache_config.num_gpu_blocks_override is not None:
            num_gpu_blocks_override = self.cache_config.num_gpu_blocks_override
            logger.info(
                "Overriding num_gpu_blocks=%d with "
                "num_gpu_blocks_override=%d", num_gpu_blocks,
                num_gpu_blocks_override)
            num_gpu_blocks = num_gpu_blocks_override

        self.cache_config.num_gpu_blocks = num_gpu_blocks
        self.cache_config.num_cpu_blocks = 0
        self.model_executor.initialize_cache(num_gpu_blocks)

    @classmethod
    def from_engine_args(
        cls,
        engine_args: EngineArgs,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
    ) -> "LLMEngine":
        """Creates an LLM engine from the engine arguments."""
        # Create the engine configs.
        engine_config = engine_args.create_engine_config()
        executor_class = cls._get_executor_cls(engine_config)
        # Create the LLM engine.
        engine = cls(
            **engine_config.to_dict(),
            executor_class=executor_class,
            log_stats=not engine_args.disable_log_stats,
            usage_context=usage_context,
            stat_loggers=stat_loggers,
        )
        return engine

    def _init_tokenizer(self) -> BaseTokenizerGroup:
        return init_tokenizer_from_configs(
            model_config=self.model_config,
            scheduler_config=self.scheduler_config,
            parallel_config=self.parallel_config,
            enable_lora=bool(self.lora_config))

    def _verify_args(self) -> None:
        self.model_config.verify_with_parallel_config(self.parallel_config)
        self.cache_config.verify_with_parallel_config(self.parallel_config)
        if self.lora_config:
            self.lora_config.verify_with_model_config(self.model_config)
            self.lora_config.verify_with_scheduler_config(
                self.scheduler_config)
        if self.prompt_adapter_config:
            self.prompt_adapter_config.verify_with_model_config(
                self.model_config)

    def _add_processed_request(
        self,
        request_id: str,
        processed_inputs: Union[DecoderOnlyInputs, EncoderDecoderLLMInputs],
        params: Union[SamplingParams, PoolingParams],
        arrival_time: float,
        lora_request: Optional[LoRARequest],
        prompt_adapter_request: Optional[PromptAdapterRequest],
        trace_headers: Optional[Mapping[str, str]] = None,
    ) -> None:
        assert prompt_adapter_request is None
        assert trace_headers is None
        self._validate_model_inputs(processed_inputs)
        eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request)

        # TODO(woosuk): Support embedding mode.
        assert isinstance(params, SamplingParams)
        sampling_params = params.clone()
        sampling_params.update_from_generation_config(
            self.generation_config_fields, eos_token_id)

        # TODO(woosuk): Check max_logprobs
        # TODO(woosuk): Support encoder-decoder models.
        req = Request(request_id, processed_inputs, params, eos_token_id,
                      arrival_time)
        self.requests[request_id] = req
        self.num_lagged_steps[request_id] = 0
        self.scheduler.add_request(req)

    def stop_remote_worker_execution_loop(self) -> None:
        raise NotImplementedError("TP not implemented yet.")

    def add_request(
        self,
|
||||
request_id: str,
|
||||
prompt: PromptType,
|
||||
params: Union[SamplingParams, PoolingParams],
|
||||
arrival_time: Optional[float] = None,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
trace_headers: Optional[Mapping[str, str]] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
priority: int = 0,
|
||||
) -> None:
|
||||
if lora_request is not None and not self.lora_config:
|
||||
raise ValueError(f"Got lora_request {lora_request} but LoRA is "
|
||||
"not enabled!")
|
||||
if arrival_time is None:
|
||||
arrival_time = time.time()
|
||||
assert priority == 0, "vLLM V1 does not support priority at the moment."
|
||||
|
||||
preprocessed_inputs = self.input_preprocessor.preprocess(
|
||||
prompt,
|
||||
request_id=request_id,
|
||||
lora_request=lora_request,
|
||||
prompt_adapter_request=prompt_adapter_request,
|
||||
)
|
||||
processed_inputs = self.input_processor(preprocessed_inputs)
|
||||
|
||||
self._add_processed_request(
|
||||
request_id=request_id,
|
||||
processed_inputs=processed_inputs,
|
||||
params=params,
|
||||
arrival_time=arrival_time,
|
||||
lora_request=lora_request,
|
||||
prompt_adapter_request=prompt_adapter_request,
|
||||
trace_headers=trace_headers,
|
||||
)
|
||||
|
||||
def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
|
||||
self.scheduler.finish_requests(request_id,
|
||||
RequestStatus.FINISHED_ABORTED)
|
||||
|
||||
def get_num_unfinished_requests(self) -> int:
|
||||
"""Gets the number of unfinished requests."""
|
||||
return len(self.requests)
|
||||
|
||||
def has_unfinished_requests(self) -> bool:
|
||||
"""Returns True if there are unfinished requests."""
|
||||
return len(self.requests) > 0
|
||||
|
||||
def step(self) -> List[RequestOutput]:
|
||||
# NOTE(woosuk): This method may return an empty list when the
|
||||
# detokenizer is still processing the outputs. This should not be
|
||||
# considered as the end of the generation process.
|
||||
# FIXME(woosuk): Currently, the step method is inefficient because it
|
||||
# creates RequestOutput objects for all running requests, while they
|
||||
# may not be needed unless the output is streamed to the client.
|
||||
if self.scheduler.has_unfinished_requests():
|
||||
scheduler_output = self.scheduler.schedule()
|
||||
output = self.model_executor.execute_model(scheduler_output)
|
||||
sampled = self.scheduler.update_from_output(
|
||||
scheduler_output, output)
|
||||
self.send_to_detokenizer(sampled)
|
||||
req_outputs = self.recv_from_detokenizer()
|
||||
return req_outputs
|
||||
|
||||
def send_to_detokenizer(self, sampled: List[Tuple[Request, int]]) -> None:
|
||||
inputs = DetokenizerInputs(
|
||||
req_ids=[],
|
||||
prompt_token_ids=[],
|
||||
new_token_ids=[],
|
||||
skip_special_tokens=[],
|
||||
spaces_between_special_tokens=[],
|
||||
free_req_ids=[], # TODO(woosuk): Implement freeing.
|
||||
)
|
||||
for req, num_tokens in sampled:
|
||||
inputs.req_ids.append(req.request_id)
|
||||
if len(req.output_token_ids) == num_tokens:
|
||||
# The request is first detokenized.
|
||||
inputs.prompt_token_ids.append(req.prompt_token_ids)
|
||||
else:
|
||||
# The prompt token ids are already cached in the detokenizer.
|
||||
inputs.prompt_token_ids.append([])
|
||||
inputs.new_token_ids.append(req.output_token_ids[-num_tokens:])
|
||||
inputs.skip_special_tokens.append(
|
||||
req.sampling_params.skip_special_tokens)
|
||||
inputs.spaces_between_special_tokens.append(
|
||||
req.sampling_params.spaces_between_special_tokens)
|
||||
|
||||
# Update the number of lagged steps.
|
||||
self.num_lagged_steps[req.request_id] += 1
|
||||
self.detokenizer.send(inputs)
|
||||
|
||||
def recv_from_detokenizer(self) -> List[RequestOutput]:
|
||||
detokenizer_output = self.detokenizer.recv()
|
||||
if detokenizer_output is None:
|
||||
return []
|
||||
|
||||
req_outputs: List[RequestOutput] = []
|
||||
num_reqs = len(detokenizer_output.req_ids)
|
||||
for i in range(num_reqs):
|
||||
req_id = detokenizer_output.req_ids[i]
|
||||
req = self.requests[req_id]
|
||||
req.output_text += detokenizer_output.detokenized_texts[i]
|
||||
|
||||
self.num_lagged_steps[req_id] -= 1
|
||||
finished = (self.num_lagged_steps[req_id] == 0
|
||||
and req.is_finished())
|
||||
req_output = self._make_request_output(
|
||||
req, detokenizer_output.num_output_token_ids[i],
|
||||
detokenizer_output.detokenized_texts[i], finished)
|
||||
req_outputs.append(req_output)
|
||||
|
||||
if finished:
|
||||
del self.requests[req_id]
|
||||
del self.num_lagged_steps[req_id]
|
||||
del self.request_outputs[req_id]
|
||||
return req_outputs
|
||||
|
||||
def terminate_detokenizer(self) -> None:
|
||||
self.detokenizer.terminate()
|
||||
|
||||
def _make_request_output(
|
||||
self,
|
||||
request: Request,
|
||||
num_output_tokens: int,
|
||||
new_output_text: str,
|
||||
finished: bool,
|
||||
) -> RequestOutput:
|
||||
req_output = self.request_outputs.get(request.request_id)
|
||||
if req_output is None:
|
||||
# TODO: Support `n` > 1.
|
||||
completion_output = CompletionOutput(
|
||||
index=0,
|
||||
text="",
|
||||
token_ids=[],
|
||||
cumulative_logprob=None,
|
||||
logprobs=None, # TODO
|
||||
finish_reason=None,
|
||||
stop_reason=None,
|
||||
lora_request=None,
|
||||
)
|
||||
req_output = RequestOutput(
|
||||
request_id=request.request_id,
|
||||
prompt=request.prompt,
|
||||
prompt_token_ids=request.prompt_token_ids,
|
||||
prompt_logprobs=None, # TODO
|
||||
outputs=[completion_output],
|
||||
finished=False,
|
||||
metrics=None,
|
||||
lora_request=None,
|
||||
encoder_prompt=None,
|
||||
encoder_prompt_token_ids=None,
|
||||
)
|
||||
self.request_outputs[request.request_id] = req_output
|
||||
|
||||
completion_output = req_output.outputs[0]
|
||||
if request.sampling_params.output_kind == RequestOutputKind.CUMULATIVE:
|
||||
completion_output.text += new_output_text
|
||||
completion_output.token_ids = (
|
||||
request.output_token_ids[:num_output_tokens])
|
||||
elif request.sampling_params.output_kind == RequestOutputKind.DELTA:
|
||||
completion_output.text = new_output_text
|
||||
num_prev_tokens = len(completion_output.token_ids)
|
||||
completion_output.token_ids = request.output_token_ids[
|
||||
num_prev_tokens:num_output_tokens]
|
||||
elif (request.sampling_params.output_kind ==
|
||||
RequestOutputKind.FINAL_ONLY):
|
||||
if finished:
|
||||
completion_output.text = request.output_text
|
||||
completion_output.token_ids = request.output_token_ids
|
||||
else:
|
||||
completion_output.text = ""
|
||||
completion_output.token_ids = []
|
||||
|
||||
if finished:
|
||||
completion_output.finish_reason = request.get_finished_reason()
|
||||
completion_output.stop_reason = request.stop_reason
|
||||
req_output.finished = finished
|
||||
return req_output
|
||||
|
||||
def check_health(self) -> None:
|
||||
if self.tokenizer:
|
||||
self.tokenizer.check_health()
|
||||
self.model_executor.check_health()
|
||||
|
||||
def _validate_model_inputs(self, inputs: Union[DecoderOnlyInputs,
|
||||
EncoderDecoderLLMInputs]):
|
||||
prompt_ids = inputs.get("prompt_token_ids")
|
||||
if prompt_ids is None or len(prompt_ids) == 0:
|
||||
raise ValueError("Prompt cannot be empty")
|
||||
|
||||
if self.model_config.is_multimodal_model:
|
||||
max_prompt_len = self.model_config.max_model_len
|
||||
|
||||
if len(prompt_ids) > max_prompt_len:
|
||||
raise ValueError(
|
||||
f"The prompt (total length {len(prompt_ids)}) is too long "
|
||||
f"to fit into the model (context length {max_prompt_len}). "
|
||||
"Make sure that `max_model_len` is no smaller than the "
|
||||
"number of text tokens plus multimodal tokens. For image "
|
||||
"inputs, the number of image tokens depends on the number "
|
||||
"of images, and possibly their aspect ratios as well.")
|
||||
|
||||
@classmethod
|
||||
def validate_outputs(cls, outputs, output_type):
|
||||
return outputs
|
||||
|
||||
def get_model_config(self) -> ModelConfig:
|
||||
"""Gets the model configuration."""
|
||||
return self.model_config
|
||||
|
||||
def get_parallel_config(self) -> ParallelConfig:
|
||||
"""Gets the parallel configuration."""
|
||||
return self.parallel_config
|
||||
|
||||
def get_decoding_config(self) -> DecodingConfig:
|
||||
"""Gets the decoding configuration."""
|
||||
return self.decoding_config
|
||||
|
||||
def get_scheduler_config(self) -> SchedulerConfig:
|
||||
"""Gets the scheduler configuration."""
|
||||
return self.scheduler_config
|
||||
|
||||
def get_lora_config(self) -> LoRAConfig:
|
||||
"""Gets the LoRA configuration."""
|
||||
return self.lora_config
|
||||
|
||||
@classmethod
|
||||
def _get_executor_cls(cls, engine_config: EngineConfig):
|
||||
return GPUExecutor
|
||||
|
||||
def is_tracing_enabled(self) -> bool:
|
||||
return False
|
||||
|
||||
def do_log_stats(self, *args, **kwargs) -> None:
|
||||
pass
|
||||
|
||||
def is_encoder_decoder_model(self) -> bool:
|
||||
return False
|
||||
|
||||
def start_profile(self) -> None:
|
||||
pass
|
||||
|
||||
def stop_profile(self) -> None:
|
||||
pass
|
||||
|
||||
def get_tokenizer_group(self, *args, **kwargs):
|
||||
return self.tokenizer
|
||||
|
||||
|
||||
def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]:
|
||||
config = try_get_generation_config(
|
||||
model_config.model,
|
||||
trust_remote_code=model_config.trust_remote_code,
|
||||
revision=model_config.revision,
|
||||
)
|
||||
|
||||
if config is None:
|
||||
return {}
|
||||
|
||||
return config.to_diff_dict()
|
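Illustrative aside (not part of this commit): a minimal driver loop over the V1 engine API defined above, assuming VLLM_USE_V1 is honored via the environment and using a placeholder model name and request id. It only exercises methods shown in this file; step() can return an empty list while the detokenizer catches up, so the loop keys off has_unfinished_requests().

import os
os.environ["VLLM_USE_V1"] = "1"  # assumption: V1 is opted into via this env var

from vllm import EngineArgs, SamplingParams
from vllm.v1.engine.llm_engine import LLMEngine

engine = LLMEngine.from_engine_args(EngineArgs(model="facebook/opt-125m"))
engine.add_request("req-0", "Hello, my name is", SamplingParams(max_tokens=16))
while engine.has_unfinished_requests():
    for out in engine.step():  # may be empty while detokenization lags behind
        if out.finished:
            print(out.outputs[0].text)
engine.terminate_detokenizer()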
0 vllm/v1/executor/__init__.py Normal file

100 vllm/v1/executor/gpu_executor.py Normal file
@ -0,0 +1,100 @@
import os
from typing import Optional, Tuple

from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                         ModelConfig, ObservabilityConfig, ParallelConfig,
                         PromptAdapterConfig, SchedulerConfig,
                         SpeculativeConfig)
from vllm.logger import init_logger
from vllm.utils import get_distributed_init_method, get_ip, get_open_port
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.worker.gpu_worker import Worker

logger = init_logger(__name__)


class GPUExecutor:

    def __init__(
        self,
        model_config: ModelConfig,
        cache_config: CacheConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        load_config: LoadConfig,
        lora_config: Optional[LoRAConfig],
        speculative_config: Optional[SpeculativeConfig],
        prompt_adapter_config: Optional[PromptAdapterConfig],
        observability_config: Optional[ObservabilityConfig],
    ) -> None:
        self.model_config = model_config
        self.cache_config = cache_config
        self.lora_config = lora_config
        self.load_config = load_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.device_config = device_config
        self.speculative_config = speculative_config
        self.prompt_adapter_config = prompt_adapter_config
        self.observability_config = observability_config

        self.worker = self._create_worker()
        self.worker.initialize()
        self.worker.load_model()

    def _create_worker(
            self,
            local_rank: int = 0,
            rank: int = 0,
            distributed_init_method: Optional[str] = None) -> Worker:
        """Return worker init args for a given rank."""
        # see https://github.com/NVIDIA/nccl/issues/1234
        os.environ['NCCL_CUMEM_ENABLE'] = '0'

        if distributed_init_method is None:
            distributed_init_method = get_distributed_init_method(
                get_ip(), get_open_port())
        return Worker(
            model_config=self.model_config,
            parallel_config=self.parallel_config,
            scheduler_config=self.scheduler_config,
            device_config=self.device_config,
            cache_config=self.cache_config,
            load_config=self.load_config,
            local_rank=local_rank,
            rank=rank,
            distributed_init_method=distributed_init_method,
            lora_config=self.lora_config,
            speculative_config=self.speculative_config,
            prompt_adapter_config=self.prompt_adapter_config,
            observability_config=self.observability_config,
        )

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Determine the number of available KV blocks by invoking the
        underlying worker.
        """
        return self.worker.determine_num_available_blocks()

    def initialize_cache(self, num_gpu_blocks: int) -> None:
        """Initialize the KV cache by invoking the underlying worker.
        """
        # NOTE: This is logged in the executor because there can be >1 worker
        # with other executors. We could log in the engine level, but work
        # remains to abstract away the device for non-GPU configurations.
        logger.info("# GPU blocks: %d", num_gpu_blocks)
        self.worker.initialize_cache(num_gpu_blocks)
        self.worker.compile_or_warm_up_model()

    def execute_model(
        self,
        scheduler_output,
    ) -> ModelRunnerOutput:
        output = self.worker.execute_model(scheduler_output)
        return output

    def check_health(self) -> None:
        # GPUExecutor will always be healthy as long as
        # it's running.
        return
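Illustrative aside (not part of this commit): the distributed_init_method built in _create_worker above is a single-process TCP rendezvous address derived from the local IP and a free port. A rough sketch, assuming the vllm.utils helpers imported in this file behave as their names suggest:

from vllm.utils import get_distributed_init_method, get_ip, get_open_port

# Assumption: the helper formats an address along the lines of "tcp://<ip>:<port>".
init_method = get_distributed_init_method(get_ip(), get_open_port())
print(init_method)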
37 vllm/v1/outputs.py Normal file
@ -0,0 +1,37 @@
from dataclasses import dataclass
from typing import Dict, List, Optional

import torch


@dataclass
class SamplerOutput:

    # [num_reqs]
    sampled_token_ids: torch.Tensor

    # [num_reqs, max_num_logprobs + 1]
    logprob_token_ids: Optional[torch.Tensor]
    # [num_reqs, max_num_logprobs + 1]
    logprobs: Optional[torch.Tensor]

    # TODO: Support prompt logprobs.
    prompt_logprob_token_ids: Optional[torch.Tensor]
    prompt_logprobs: Optional[torch.Tensor]


@dataclass
class ModelRunnerOutput:

    # [num_reqs]
    req_ids: List[str]
    # req_id -> index
    req_id_to_index: Dict[str, int]

    # [num_reqs]
    sampled_token_ids_cpu: torch.Tensor

    # [num_reqs, max_num_logprobs + 1]
    logprob_token_ids_cpu: Optional[torch.Tensor]
    # [num_reqs, max_num_logprobs + 1]
    logprobs_cpu: Optional[torch.Tensor]
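Illustrative aside (not part of this commit): how the two dataclasses above relate in practice. The sampler produces GPU tensors (SamplerOutput); the model runner copies them to CPU and keys them by request id (ModelRunnerOutput). Values below are made up for the example.

import torch

from vllm.v1.outputs import ModelRunnerOutput, SamplerOutput

sampler_out = SamplerOutput(
    sampled_token_ids=torch.tensor([11, 42], dtype=torch.int32),
    logprob_token_ids=None,
    logprobs=None,
    prompt_logprob_token_ids=None,
    prompt_logprobs=None,
)
runner_out = ModelRunnerOutput(
    req_ids=["req-0", "req-1"],
    req_id_to_index={"req-0": 0, "req-1": 1},
    # Already on CPU in this toy example; in the runner this is a GPU-to-CPU copy.
    sampled_token_ids_cpu=sampler_out.sampled_token_ids.cpu(),
    logprob_token_ids_cpu=None,
    logprobs_cpu=None,
)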
92 vllm/v1/request.py Normal file
@ -0,0 +1,92 @@
import enum
from typing import TYPE_CHECKING, List, Optional, Union

from vllm.lora.request import LoRARequest
from vllm.sampling_params import SamplingParams
from vllm.sequence import RequestMetrics

if TYPE_CHECKING:
    from vllm.inputs import DecoderOnlyInputs


class Request:

    def __init__(
        self,
        request_id: str,
        inputs: "DecoderOnlyInputs",
        sampling_params: SamplingParams,
        eos_token_id: Optional[int],
        arrival_time: float,
        lora_request: Optional[LoRARequest] = None,
    ) -> None:
        self.request_id = request_id
        self.inputs = inputs
        self.sampling_params = sampling_params
        # Because of LoRA, the eos token id can be different for each request.
        self.eos_token_id = eos_token_id
        self.metrics = RequestMetrics(arrival_time=arrival_time,
                                      last_token_time=arrival_time,
                                      first_scheduled_time=None,
                                      first_token_time=None,
                                      time_in_queue=None)
        self.lora_request = lora_request

        self.status = RequestStatus.WAITING
        self.stop_reason: Union[int, str, None] = None
        assert sampling_params.max_tokens is not None
        self.max_tokens = sampling_params.max_tokens

        self.prompt = inputs.get("prompt")
        self.prompt_token_ids = inputs["prompt_token_ids"]
        self.num_prompt_tokens = len(self.prompt_token_ids)
        self.output_token_ids: List[int] = []
        self.output_text = ""
        self.num_computed_tokens = 0

    @property
    def num_tokens(self) -> int:
        return self.num_prompt_tokens + len(self.output_token_ids)

    @property
    def num_output_tokens(self) -> int:
        return len(self.output_token_ids)

    def is_finished(self) -> bool:
        return RequestStatus.is_finished(self.status)

    def get_finished_reason(self) -> Union[str, None]:
        return RequestStatus.get_finished_reason(self.status)


class RequestStatus(enum.IntEnum):
    """Status of a sequence."""
    WAITING = 0
    RUNNING = 1
    PREEMPTED = 2
    # Note: anything after PREEMPTED (2) will be considered
    # as a finished status.
    FINISHED_STOPPED = 3
    FINISHED_LENGTH_CAPPED = 4
    FINISHED_ABORTED = 5
    FINISHED_IGNORED = 6

    @staticmethod
    def is_finished(status: "RequestStatus") -> bool:
        return status > RequestStatus.PREEMPTED

    @staticmethod
    def get_finished_reason(status: "RequestStatus") -> Union[str, None]:
        return _FINISHED_REASON_MAP.get(status)


# Mapping of finished statuses to their finish reasons.
# NOTE: The ignored sequences are the sequences whose prompt lengths
# are longer than the model's length cap. Therefore, the stop
# reason should also be "length" as in OpenAI API.
_FINISHED_REASON_MAP = {
    RequestStatus.FINISHED_STOPPED: "stop",
    RequestStatus.FINISHED_LENGTH_CAPPED: "length",
    RequestStatus.FINISHED_ABORTED: "abort",
    RequestStatus.FINISHED_IGNORED: "length",
}
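Illustrative aside (not part of this commit): the finished-status semantics defined above, expressed as a few assertions; anything past PREEMPTED counts as finished, and ignored requests report "length" to match the OpenAI API.

from vllm.v1.request import RequestStatus

assert not RequestStatus.is_finished(RequestStatus.RUNNING)
assert RequestStatus.is_finished(RequestStatus.FINISHED_STOPPED)
assert RequestStatus.get_finished_reason(RequestStatus.FINISHED_STOPPED) == "stop"
assert RequestStatus.get_finished_reason(RequestStatus.FINISHED_IGNORED) == "length"
assert RequestStatus.get_finished_reason(RequestStatus.RUNNING) is None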
0 vllm/v1/sample/__init__.py Normal file

22 vllm/v1/sample/metadata.py Normal file
@ -0,0 +1,22 @@
from dataclasses import dataclass
from typing import List, Optional

import torch


@dataclass
class SamplingMetadata:

    temperature: torch.Tensor
    all_greedy: bool
    all_random: bool

    top_p: torch.Tensor
    top_k: torch.Tensor
    no_top_p: bool
    no_top_k: bool

    generators: List[Optional[torch.Generator]]
    no_generator: bool

    max_num_logprobs: int
161 vllm/v1/sample/sampler.py Normal file
@ -0,0 +1,161 @@
"""A layer that samples the next tokens from the model's outputs."""
from typing import List, Optional

import torch
import torch.nn as nn

from vllm.v1.outputs import SamplerOutput
from vllm.v1.sample.metadata import SamplingMetadata

_SAMPLING_EPS = 1e-5


class Sampler(nn.Module):

    def forward(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> SamplerOutput:
        logits = self.apply_temperature(logits, sampling_metadata.temperature)
        logits = self.apply_top_k_top_p(logits, sampling_metadata)

        probs = self.get_probs(logits)
        sampled = self.sample(probs, sampling_metadata)
        # Use int32 to reduce the tensor size.
        sampled = sampled.to(torch.int32)

        if sampling_metadata.max_num_logprobs > 0:
            logprobs = self.get_logprobs(logits)
            # FIXME: Mask the sampled token_id, get topk logprobs,
            # and concatenate the topk with the sampled token_id.
            topk_logprobs, topk_indices = torch.topk(
                logprobs, sampling_metadata.max_num_logprobs, dim=-1)
            # Use int32 to reduce the tensor size.
            topk_indices = topk_indices.to(torch.int32)
        else:
            topk_logprobs = None
            topk_indices = None

        sampler_output = SamplerOutput(
            sampled_token_ids=sampled,
            logprob_token_ids=topk_indices,
            logprobs=topk_logprobs,
            prompt_logprob_token_ids=None,
            prompt_logprobs=None,
        )
        return sampler_output

    def apply_temperature(
        self,
        logits: torch.Tensor,
        temp: torch.Tensor,
    ) -> torch.Tensor:
        # Use float32 to apply temperature scaling.
        logits = logits.to(torch.float32)
        # Avoid division by zero.
        temp = torch.where(temp < _SAMPLING_EPS, 1.0, temp)
        # Use in-place division to avoid creating a new tensor.
        logits.div_(temp.unsqueeze(dim=1))
        return logits

    def apply_top_k_top_p(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> torch.Tensor:
        return _apply_top_k_top_p(
            logits,
            sampling_metadata.no_top_k,
            sampling_metadata.top_k,
            sampling_metadata.no_top_p,
            sampling_metadata.top_p,
        )

    def get_probs(self, logits: torch.Tensor) -> torch.Tensor:
        return torch.softmax(logits, dim=-1, dtype=torch.float32)

    def get_logprobs(self, logits: torch.Tensor) -> torch.Tensor:
        return torch.log_softmax(logits, dim=-1, dtype=torch.float32)

    def greedy_sample(self, probs: torch.Tensor) -> torch.Tensor:
        return probs.argmax(dim=-1).view(-1)

    def random_sample(
        self,
        probs: torch.Tensor,
        generators: List[Optional[torch.Generator]],
        no_generator: bool,
    ) -> torch.Tensor:
        q = torch.empty_like(probs)
        # NOTE(woosuk): To batch-process the requests without their own seeds,
        # which is the common case, we first assume that every request does
        # not have its own seed. Then, we overwrite the values for the requests
        # that have their own seeds.
        q.exponential_()
        if not no_generator:
            assert len(generators) == probs.shape[0]
            # TODO(woosuk): This can be slow because we handle each request
            # one by one. Optimize this.
            for i, generator in enumerate(generators):
                if generator is not None:
                    q[i].exponential_(generator=generator)
        return probs.div_(q).argmax(dim=-1).view(-1)

    def sample(
        self,
        probs: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> torch.Tensor:
        assert not (sampling_metadata.all_greedy
                    and sampling_metadata.all_random)
        if sampling_metadata.all_greedy:
            return self.greedy_sample(probs)
        if sampling_metadata.all_random:
            return self.random_sample(probs, sampling_metadata.generators,
                                      sampling_metadata.no_generator)

        greedy_sampled = self.greedy_sample(probs)
        random_sampled = self.random_sample(probs,
                                            sampling_metadata.generators,
                                            sampling_metadata.no_generator)
        sampled = torch.where(
            sampling_metadata.temperature < _SAMPLING_EPS,
            greedy_sampled,
            random_sampled,
        )
        return sampled


# TODO(woosuk): Optimize this with a custom kernel.
def _apply_top_k_top_p(
    logits: torch.Tensor,
    no_top_k: bool,
    k: torch.Tensor,
    no_top_p: bool,
    p: torch.Tensor,
) -> torch.Tensor:
    if no_top_k and no_top_p:
        return logits
    logits_sort, logits_idx = logits.sort(dim=-1, descending=False)

    if not no_top_k:
        # Apply top-k.
        top_k_mask = logits_sort.size(1) - k.to(torch.long)
        # Get all the top_k values.
        top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1))
        top_k_mask = logits_sort < top_k_mask
        logits_sort.masked_fill_(top_k_mask, -float("inf"))

    if not no_top_p:
        # Apply top-p.
        probs_sort = logits_sort.softmax(dim=-1)
        probs_sum = probs_sort.cumsum(dim=-1)
        top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1)
        # at least one
        top_p_mask[:, -1] = False
        logits_sort.masked_fill_(top_p_mask, -float("inf"))

    # Re-sort the probabilities.
    logits = logits_sort.scatter(dim=-1, index=logits_idx, src=logits_sort)
    return logits
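Illustrative aside (not part of this commit): exercising the Sampler above on a toy all-greedy batch. Shapes and values are made up for the example; temperature 0.0 marks a request as greedy (see sample()), and max_num_logprobs=0 disables the logprob path.

import torch

from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.sampler import Sampler

logits = torch.tensor([[2.0, 1.0, 0.1], [0.1, 1.0, 2.0]])
meta = SamplingMetadata(
    temperature=torch.zeros(2),
    all_greedy=True,
    all_random=False,
    top_p=torch.ones(2),
    top_k=torch.zeros(2, dtype=torch.int32),
    no_top_p=True,
    no_top_k=True,
    generators=[None, None],
    no_generator=True,
    max_num_logprobs=0,
)
out = Sampler()(logits, meta)
print(out.sampled_token_ids)  # tensor([0, 2], dtype=torch.int32)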
0 vllm/v1/tokenizer/__init__.py Normal file

215 vllm/v1/tokenizer/detokenizer.py Normal file
@ -0,0 +1,215 @@
import multiprocessing
from dataclasses import dataclass
from typing import Dict, List, Optional

import msgspec
import zmq
from msgspec import msgpack

from vllm.transformers_utils.detokenizer_utils import (
    convert_prompt_ids_to_tokens, detokenize_incrementally)
from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.utils import get_open_port


class DetokenizerInputs(msgspec.Struct):

    # [num_reqs]
    req_ids: List[str]
    # A request's prompt token ids are sent to the detokenizer only when
    # the request is first detokenized. Otherwise, an empty list is sent.
    prompt_token_ids: List[List[int]]
    new_token_ids: List[List[int]]
    skip_special_tokens: List[bool]
    spaces_between_special_tokens: List[bool]

    # [num_free_reqs]
    free_req_ids: List[str]


class DetokenizerOutputs(msgspec.Struct):

    # [num_reqs]
    req_ids: List[str]
    detokenized_texts: List[str]
    # NOTE(woosuk): The number of the output token ids of each request
    # at the time of detokenization. The detokenizer returns this to the engine
    # because the request state (including the output token ids) is
    # asynchronously updated in the engine, while RequestOutput requires the
    # output token ids to be consistent with the detokenized text.
    num_output_token_ids: List[int]


class Detokenizer:

    def __init__(self, tokenizer_name: str):
        # FIXME(woosuk): Currently, the detokenizer is just a hacky prototype.
        # For example, it does not terminate properly. We need to improve this.
        self.push_port = get_open_port()
        self.pull_port = get_open_port()
        self.detokenizer = DetokenizerProc(tokenizer_name, self.push_port,
                                           self.pull_port)
        self.detokenizer.start()

        self.zmq_context = zmq.Context()
        self.push_socket = self.zmq_context.socket(zmq.PUSH)
        self.push_socket.connect(f"tcp://localhost:{self.push_port}")
        self.pull_socket = self.zmq_context.socket(zmq.PULL)
        self.pull_socket.connect(f"tcp://localhost:{self.pull_port}")
        self.poller = zmq.Poller()
        self.poller.register(self.pull_socket, zmq.POLLIN)
        self.msgpack_encoder = msgpack.Encoder()
        self.msgpack_decoder = msgpack.Decoder(DetokenizerOutputs)

    def send(self, inputs: DetokenizerInputs) -> None:
        self.push_socket.send(self.msgpack_encoder.encode(inputs),
                              flags=zmq.NOBLOCK)

    def recv(self) -> Optional[DetokenizerOutputs]:
        socks = dict(self.poller.poll(timeout=0))
        if self.pull_socket in socks and socks[self.pull_socket] == zmq.POLLIN:
            msg = self.pull_socket.recv()
            return self.msgpack_decoder.decode(msg)
        return None

    def terminate(self) -> None:
        self.push_socket.send(b"", flags=zmq.NOBLOCK)
        self.detokenizer.join()


class DetokenizerProc(multiprocessing.Process):

    def __init__(
        self,
        tokenizer_name: str,
        pull_port: int,
        push_port: int,
    ):
        super().__init__()
        self.tokenizer_name = tokenizer_name
        # NOTE: The pull_port of the detokenizer should be the same as the
        # push_port of the engine. Vice versa.
        self.pull_port = pull_port
        self.push_port = push_port

    def run(self):
        # Initialize these objects after the process is forked since they are
        # not picklable.
        self.msgpack_encoder = msgpack.Encoder()
        self.msgpack_decoder = msgpack.Decoder(DetokenizerInputs)
        self.tokenizer = get_tokenizer(self.tokenizer_name)
        # req_id -> RequestState
        self.request_states: Dict[str, RequestState] = {}

        self.zmq_context = zmq.Context()
        self.pull_socket = self.zmq_context.socket(zmq.PULL)
        self.pull_socket.bind(f"tcp://*:{self.pull_port}")
        self.push_socket = self.zmq_context.socket(zmq.PUSH)
        self.push_socket.bind(f"tcp://*:{self.push_port}")

        while True:
            message = self.pull_socket.recv()
            if message == b"":
                # Terminate signal.
                break
            inputs = self.msgpack_decoder.decode(message)

            for req_id in inputs.free_req_ids:
                self.free(req_id)

            detokenized_texts: List[str] = []
            num_output_token_ids: List[int] = []
            num_reqs = len(inputs.req_ids)
            for i in range(num_reqs):
                req_id = inputs.req_ids[i]
                if req_id not in self.request_states:
                    self.add_request(
                        request_id=req_id,
                        prompt_token_ids=inputs.prompt_token_ids[i],
                        skip_special_tokens=inputs.skip_special_tokens[i],
                        spaces_between_special_tokens=inputs.
                        spaces_between_special_tokens[i],
                    )
                new_str = self.detokenize(req_id, inputs.new_token_ids[i])
                detokenized_texts.append(new_str)
                req_state = self.request_states[req_id]
                num_output_token_ids.append(
                    len(req_state.token_ids) - req_state.num_prompt_tokens)

            detokenized = DetokenizerOutputs(
                req_ids=inputs.req_ids,
                detokenized_texts=detokenized_texts,
                num_output_token_ids=num_output_token_ids,
            )
            self.push_socket.send(self.msgpack_encoder.encode(detokenized),
                                  flags=zmq.NOBLOCK)

    def add_request(
        self,
        request_id: str,
        prompt_token_ids: List[int],
        skip_special_tokens: bool,
        spaces_between_special_tokens: bool,
    ) -> None:
        tokens, prefix_offset, read_offset = convert_prompt_ids_to_tokens(
            tokenizer=self.tokenizer,
            prompt_ids=prompt_token_ids,
            skip_special_tokens=skip_special_tokens,
        )
        self.request_states[request_id] = RequestState(
            req_id=request_id,
            token_ids=prompt_token_ids,
            tokens=tokens,
            num_prompt_tokens=len(prompt_token_ids),
            prefix_offset=prefix_offset,
            read_offset=read_offset,
            skip_special_tokens=skip_special_tokens,
            spaces_between_special_tokens=spaces_between_special_tokens,
        )

    def free(self, request_id: str) -> None:
        del self.request_states[request_id]

    def detokenize(self, request_id: str, new_token_ids: List[int]) -> str:
        # TODO(woosuk): This method becomes very inefficient when the number of
        # new_token_ids is more than 1. We need to optimize this.
        req_state = self.request_states[request_id]
        decoded_text = ""
        for new_token_id in new_token_ids:
            req_state.token_ids.append(new_token_id)
            (new_tokens, new_decoded_token_text, prefix_offset,
             read_offset) = detokenize_incrementally(
                 tokenizer=self.tokenizer,
                 all_input_ids=req_state.token_ids,
                 prev_tokens=req_state.tokens,
                 prefix_offset=req_state.prefix_offset,
                 read_offset=req_state.read_offset,
                 skip_special_tokens=req_state.skip_special_tokens,
                 spaces_between_special_tokens=req_state.
                 spaces_between_special_tokens,
             )

            req_state.tokens.extend(new_tokens)
            req_state.prefix_offset = prefix_offset
            req_state.read_offset = read_offset
            req_state.output_text += new_decoded_token_text
            decoded_text += new_decoded_token_text
        return decoded_text


@dataclass
class RequestState:

    req_id: str

    token_ids: List[int]
    tokens: List[str]
    num_prompt_tokens: int

    prefix_offset: int
    read_offset: int

    skip_special_tokens: bool
    spaces_between_special_tokens: bool

    output_text: str = ""
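Illustrative aside (not part of this commit): the engine-side handshake with the detokenizer process defined above, using a placeholder tokenizer name and made-up token ids. The first send for a request carries its prompt token ids; recv() polls without blocking and returns None until the process has replied.

from vllm.v1.tokenizer.detokenizer import Detokenizer, DetokenizerInputs

detok = Detokenizer("facebook/opt-125m")  # placeholder tokenizer name
detok.send(
    DetokenizerInputs(
        req_ids=["req-0"],
        prompt_token_ids=[[2, 100, 200]],  # made-up ids for illustration
        new_token_ids=[[300]],
        skip_special_tokens=[True],
        spaces_between_special_tokens=[True],
        free_req_ids=[],
    ))
out = None
while out is None:  # poll until the detokenizer process replies
    out = detok.recv()
print(out.detokenized_texts[0])
detok.terminate()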
0 vllm/v1/worker/__init__.py Normal file

690 vllm/v1/worker/gpu_model_runner.py Normal file
@ -0,0 +1,690 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Set
|
||||
from unittest.mock import patch
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
|
||||
ModelConfig, ObservabilityConfig, ParallelConfig,
|
||||
PromptAdapterConfig, SchedulerConfig)
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.model_loader import get_model
|
||||
from vllm.multimodal import MultiModalDataDict
|
||||
from vllm.sampling_params import SamplingParams, SamplingType
|
||||
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, cdiv,
|
||||
is_pin_memory_available)
|
||||
from vllm.v1.attention.backends.flash_attn import (FlashAttentionBackend,
|
||||
FlashAttentionMetadata)
|
||||
from vllm.v1.outputs import ModelRunnerOutput
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
from vllm.v1.sample.sampler import Sampler
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.v1.core.scheduler import SchedulerOutput
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class GPUModelRunner:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_config: ModelConfig,
|
||||
parallel_config: ParallelConfig,
|
||||
scheduler_config: SchedulerConfig,
|
||||
device_config: DeviceConfig,
|
||||
cache_config: CacheConfig,
|
||||
load_config: LoadConfig,
|
||||
lora_config: Optional[LoRAConfig] = None,
|
||||
prompt_adapter_config: Optional[PromptAdapterConfig] = None,
|
||||
observability_config: Optional[ObservabilityConfig] = None,
|
||||
):
|
||||
self.model_config = model_config
|
||||
self.parallel_config = parallel_config
|
||||
self.scheduler_config = scheduler_config
|
||||
self.device_config = device_config
|
||||
self.cache_config = cache_config
|
||||
self.lora_config = lora_config
|
||||
self.load_config = load_config
|
||||
self.prompt_adapter_config = prompt_adapter_config
|
||||
self.observability_config = observability_config
|
||||
|
||||
self.device = self.device_config.device
|
||||
self.pin_memory = is_pin_memory_available()
|
||||
self.dtype = self.model_config.dtype
|
||||
if cache_config.cache_dtype == "auto":
|
||||
self.kv_cache_dtype = self.dtype
|
||||
else:
|
||||
self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
|
||||
cache_config.cache_dtype]
|
||||
|
||||
self.sliding_window = model_config.get_sliding_window()
|
||||
self.block_size = cache_config.block_size
|
||||
self.max_model_len = model_config.max_model_len
|
||||
self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size)
|
||||
self.max_num_tokens = scheduler_config.max_num_batched_tokens
|
||||
|
||||
# Model-related.
|
||||
self.num_attn_layers = model_config.get_num_attention_layers(
|
||||
parallel_config)
|
||||
self.num_kv_heads = model_config.get_num_kv_heads(parallel_config)
|
||||
self.head_size = model_config.get_head_size()
|
||||
|
||||
# Lazy initialization
|
||||
# self.model: nn.Module # Set after load_model
|
||||
self.kv_caches: List[torch.Tensor] = []
|
||||
|
||||
# Request states.
|
||||
self.requests: Dict[str, CachedRequestState] = {}
|
||||
# Persistent batch.
|
||||
self.input_batch = InputBatch(
|
||||
max_num_reqs=self.scheduler_config.max_num_seqs,
|
||||
max_model_len=self.max_model_len,
|
||||
max_num_blocks_per_req=self.max_num_blocks_per_req,
|
||||
device=self.device,
|
||||
pin_memory=self.pin_memory,
|
||||
)
|
||||
|
||||
def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
|
||||
# Remove stopped requests from the cached states.
|
||||
# Keep the states of the pre-empted requests.
|
||||
for req_id in scheduler_output.finished_req_ids:
|
||||
self.requests.pop(req_id, None)
|
||||
|
||||
# Remove the requests from the persistent batch.
|
||||
stopped_req_ids = set().union(
|
||||
scheduler_output.preempted_req_ids,
|
||||
scheduler_output.finished_req_ids,
|
||||
)
|
||||
removed_req_indices: List[int] = []
|
||||
for req_id in stopped_req_ids:
|
||||
req_index = self.input_batch.remove_request(req_id)
|
||||
if req_index is not None:
|
||||
removed_req_indices.append(req_index)
|
||||
|
||||
# Update the states of the running requests.
|
||||
for req_data in scheduler_output.scheduled_running_reqs:
|
||||
req_id = req_data.req_id
|
||||
req_state = self.requests[req_id]
|
||||
req_index = self.input_batch.req_id_to_index[req_id]
|
||||
|
||||
# Update the num_computed_tokens.
|
||||
req_state.num_computed_tokens = req_data.num_computed_tokens
|
||||
self.input_batch.num_computed_tokens_cpu[req_index] = (
|
||||
req_data.num_computed_tokens)
|
||||
|
||||
# Update the block table.
|
||||
num_new_blocks = len(req_data.new_block_ids)
|
||||
if num_new_blocks == 0:
|
||||
continue
|
||||
start_index = len(req_state.block_ids)
|
||||
end_index = start_index + num_new_blocks
|
||||
req_state.block_ids.extend(req_data.new_block_ids)
|
||||
self.input_batch.block_table_cpu[
|
||||
req_index, start_index:end_index] = req_data.new_block_ids
|
||||
|
||||
req_ids_to_add: List[str] = []
|
||||
# Add new requests to the cached states.
|
||||
for req_data in scheduler_output.scheduled_new_reqs:
|
||||
req_id = req_data.req_id
|
||||
self.requests[req_id] = CachedRequestState(
|
||||
req_id=req_id,
|
||||
prompt_token_ids=req_data.prompt_token_ids,
|
||||
prompt=req_data.prompt,
|
||||
multi_modal_data=req_data.multi_modal_data,
|
||||
sampling_params=req_data.sampling_params,
|
||||
generator=None, # TODO
|
||||
block_ids=req_data.block_ids,
|
||||
num_computed_tokens=req_data.num_computed_tokens,
|
||||
output_token_ids=[],
|
||||
)
|
||||
req_ids_to_add.append(req_id)
|
||||
|
||||
# Update the cached states of the resumed requests.
|
||||
for req_data in scheduler_output.scheduled_resumed_reqs:
|
||||
req_id = req_data.req_id
|
||||
req_state = self.requests[req_id]
|
||||
|
||||
req_state.block_ids = req_data.block_ids
|
||||
req_state.num_computed_tokens = req_data.num_computed_tokens
|
||||
req_ids_to_add.append(req_id)
|
||||
|
||||
# Add the new or resumed requests to the persistent batch.
|
||||
# The smaller empty indices are filled first.
|
||||
removed_req_indices = sorted(removed_req_indices, reverse=True)
|
||||
for req_id in req_ids_to_add:
|
||||
req_state = self.requests[req_id]
|
||||
if removed_req_indices:
|
||||
# Fill the empty index.
|
||||
req_index = removed_req_indices.pop()
|
||||
else:
|
||||
# Append to the end.
|
||||
req_index = None
|
||||
self.input_batch.add_request(req_state, req_index)
|
||||
|
||||
# Condense the batched states if there are empty indices.
|
||||
if removed_req_indices:
|
||||
self.input_batch.condense(removed_req_indices)
|
||||
|
||||
def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
|
||||
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
|
||||
assert total_num_scheduled_tokens > 0
|
||||
num_reqs = self.input_batch.num_reqs
|
||||
assert num_reqs > 0
|
||||
|
||||
# OPTIMIZATION: Start copying the block table first.
|
||||
# This way, we can overlap the copy with the following CPU operations.
|
||||
self.input_batch.block_table[:num_reqs].copy_(
|
||||
self.input_batch.block_table_cpu_tensor[:num_reqs],
|
||||
non_blocking=True)
|
||||
|
||||
# Get the number of scheduled tokens for each request.
|
||||
# TODO: The Python loop can be slow. Optimize.
|
||||
num_scheduled_tokens = []
|
||||
max_num_scheduled_tokens = 0
|
||||
for req_id in self.input_batch.req_ids[:num_reqs]:
|
||||
num_tokens = scheduler_output.num_scheduled_tokens[req_id]
|
||||
num_scheduled_tokens.append(num_tokens)
|
||||
max_num_scheduled_tokens = max(max_num_scheduled_tokens,
|
||||
num_tokens)
|
||||
num_scheduled_tokens = np.array(num_scheduled_tokens, dtype=np.int32)
|
||||
assert max_num_scheduled_tokens > 0
|
||||
|
||||
# Get request indices.
|
||||
# E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
|
||||
indices = np.arange(num_reqs)
|
||||
req_indices = np.repeat(indices, num_scheduled_tokens)
|
||||
|
||||
# Get batched arange.
|
||||
# E.g., [2, 5, 3] -> [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
|
||||
arange_matrix = np.tile(np.arange(max_num_scheduled_tokens),
|
||||
(num_reqs, 1))
|
||||
mask = arange_matrix < num_scheduled_tokens[:, np.newaxis]
|
||||
arange = arange_matrix[mask]
|
||||
|
||||
# Get positions.
|
||||
positions = torch.empty((total_num_scheduled_tokens, ),
|
||||
dtype=torch.int32,
|
||||
device="cpu",
|
||||
pin_memory=self.pin_memory)
|
||||
positions_np = positions.numpy()
|
||||
np.add(self.input_batch.num_computed_tokens_cpu[req_indices],
|
||||
arange,
|
||||
out=positions_np)
|
||||
|
||||
# Get token indices.
|
||||
# E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
|
||||
# -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2]
|
||||
# where M is the max_model_len.
|
||||
token_indices = positions_np + req_indices * self.max_model_len
|
||||
token_indices = torch.from_numpy(token_indices)
|
||||
input_ids = torch.empty((total_num_scheduled_tokens, ),
|
||||
dtype=torch.int32,
|
||||
device="cpu",
|
||||
pin_memory=self.pin_memory)
|
||||
torch.index_select(torch.from_numpy(
|
||||
self.input_batch.token_ids_cpu).flatten(),
|
||||
0,
|
||||
token_indices,
|
||||
out=input_ids)
|
||||
|
||||
# Calculate the slot mapping.
|
||||
block_numbers = self.input_batch.block_table_cpu_tensor.flatten()[
|
||||
token_indices // self.block_size]
|
||||
block_offsets = token_indices % self.block_size
|
||||
slot_mapping = torch.empty((total_num_scheduled_tokens, ),
|
||||
dtype=torch.int32,
|
||||
device="cpu",
|
||||
pin_memory=self.pin_memory)
|
||||
torch.add(block_numbers * self.block_size,
|
||||
block_offsets,
|
||||
out=slot_mapping)
|
||||
|
||||
# Prepare the attention metadata.
|
||||
query_start_loc = torch.empty((num_reqs + 1, ),
|
||||
dtype=torch.int32,
|
||||
device="cpu",
|
||||
pin_memory=self.pin_memory)
|
||||
query_start_loc_np = query_start_loc.numpy()
|
||||
query_start_loc_np[0] = 0
|
||||
np.cumsum(num_scheduled_tokens, out=query_start_loc_np[1:])
|
||||
|
||||
seq_lens = (self.input_batch.num_computed_tokens_cpu[:num_reqs] +
|
||||
num_scheduled_tokens)
|
||||
max_seq_len = seq_lens.max()
|
||||
seq_start_loc = torch.empty((num_reqs + 1, ),
|
||||
dtype=torch.int32,
|
||||
device="cpu",
|
||||
pin_memory=self.pin_memory)
|
||||
seq_start_loc_np = seq_start_loc.numpy()
|
||||
seq_start_loc_np[0] = 0
|
||||
np.cumsum(seq_lens, out=seq_start_loc_np[1:])
|
||||
|
||||
input_ids = input_ids.to(self.device, non_blocking=True)
|
||||
positions = positions.to(self.device, non_blocking=True).long()
|
||||
query_start_loc = query_start_loc.to(self.device, non_blocking=True)
|
||||
seq_start_loc = seq_start_loc.to(self.device, non_blocking=True)
|
||||
slot_mapping = slot_mapping.to(self.device, non_blocking=True).long()
|
||||
attn_metadata = FlashAttentionMetadata(
|
||||
max_query_len=max_num_scheduled_tokens,
|
||||
query_start_loc=query_start_loc,
|
||||
max_seq_len=max_seq_len,
|
||||
seq_start_loc=seq_start_loc,
|
||||
block_table=self.input_batch.block_table[:num_reqs],
|
||||
slot_mapping=slot_mapping,
|
||||
)
|
||||
# NOTE(woosuk): Due to chunked prefills, there can be at most 1 partial
|
||||
# request in the batch. While we should not sample any token from this
|
||||
# partial request, we do so for simplicity. We will ignore the sampled
|
||||
# token from the partial request.
|
||||
# TODO: Support prompt logprobs.
|
||||
logits_indices = query_start_loc[1:] - 1
|
||||
return input_ids, positions, attn_metadata, logits_indices
|
||||
|
||||
def _prepare_sampling(
|
||||
self,
|
||||
scheduler_output: "SchedulerOutput",
|
||||
) -> SamplingMetadata:
|
||||
skip_copy = True
|
||||
if (scheduler_output.finished_req_ids
|
||||
or scheduler_output.preempted_req_ids):
|
||||
skip_copy = False
|
||||
if (scheduler_output.scheduled_new_reqs
|
||||
or scheduler_output.scheduled_resumed_reqs):
|
||||
skip_copy = False
|
||||
# Create the sampling metadata.
|
||||
sampling_metadata = self.input_batch.make_sampling_metadata(skip_copy)
|
||||
return sampling_metadata
|
||||
|
||||
@torch.inference_mode()
|
||||
def execute_model(
|
||||
self,
|
||||
scheduler_output: "SchedulerOutput",
|
||||
) -> ModelRunnerOutput:
|
||||
self._update_states(scheduler_output)
|
||||
inputs = self._prepare_inputs(scheduler_output)
|
||||
input_ids, positions, attn_metadata, logits_indices = inputs
|
||||
|
||||
with set_forward_context(attn_metadata):
|
||||
hidden_states = self.model(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
kv_caches=self.kv_caches,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
hidden_states = hidden_states[logits_indices]
|
||||
logits = self.model.compute_logits(hidden_states, None)
|
||||
|
||||
# Sample the next token and get logprobs if needed.
|
||||
sampling_metadata = self._prepare_sampling(scheduler_output)
|
||||
sampler_output = self.model.sample(
|
||||
logits=logits,
|
||||
sampling_metadata=sampling_metadata,
|
||||
)
|
||||
|
||||
# NOTE: CPU-GPU synchronization happens here.
|
||||
sampled_token_ids = sampler_output.sampled_token_ids.cpu()
|
||||
sampled_token_ids_list = sampled_token_ids.tolist()
|
||||
# TODO(woosuk): The following loop can be slow since it iterates over
|
||||
# the requests one by one. Optimize.
|
||||
num_reqs = self.input_batch.num_reqs
|
||||
for i, req_id in enumerate(self.input_batch.req_ids[:num_reqs]):
|
||||
req_state = self.requests[req_id]
|
||||
seq_len = (req_state.num_computed_tokens +
|
||||
scheduler_output.num_scheduled_tokens[req_id])
|
||||
assert seq_len <= req_state.num_tokens
|
||||
if seq_len == req_state.num_tokens:
|
||||
# Append the sampled token to the output token ids.
|
||||
token_id = sampled_token_ids_list[i]
|
||||
self.input_batch.token_ids_cpu[i, seq_len] = token_id
|
||||
req_state.output_token_ids.append(token_id)
|
||||
else:
|
||||
# Ignore the sampled token from the partial request.
|
||||
# Rewind the generator state as if the token was not sampled.
|
||||
generator = self.input_batch.generators[i]
|
||||
if generator is not None:
|
||||
offset = generator.get_offset()
|
||||
generator = generator.set_offset(offset - 1)
|
||||
self.input_batch.generators[i] = generator
|
||||
|
||||
if sampler_output.logprob_token_ids is None:
|
||||
logprob_token_ids = None
|
||||
else:
|
||||
logprob_token_ids = sampler_output.logprob_token_ids.cpu()
|
||||
if sampler_output.logprobs is None:
|
||||
logprobs = None
|
||||
else:
|
||||
logprobs = sampler_output.logprobs.cpu()
|
||||
model_runner_output = ModelRunnerOutput(
|
||||
req_ids=self.input_batch.req_ids[:num_reqs],
|
||||
req_id_to_index=self.input_batch.req_id_to_index,
|
||||
sampled_token_ids_cpu=sampled_token_ids,
|
||||
logprob_token_ids_cpu=logprob_token_ids,
|
||||
logprobs_cpu=logprobs,
|
||||
)
|
||||
return model_runner_output
|
||||
|
||||
def load_model(self) -> None:
|
||||
logger.info("Starting to load model %s...", self.model_config.model)
|
||||
with DeviceMemoryProfiler() as m: # noqa: SIM117
|
||||
with patch("vllm.model_executor.layers.sampler.Sampler", Sampler):
|
||||
self.model = get_model(model_config=self.model_config,
|
||||
device_config=self.device_config,
|
||||
load_config=self.load_config,
|
||||
lora_config=self.lora_config,
|
||||
parallel_config=self.parallel_config,
|
||||
scheduler_config=self.scheduler_config,
|
||||
cache_config=self.cache_config)
|
||||
|
||||
self.model_memory_usage = m.consumed_memory
|
||||
logger.info("Loading model weights took %.4f GB",
|
||||
self.model_memory_usage / float(2**30))
|
||||
|
||||
def _dummy_run(self, model: nn.Module, num_tokens: int) -> None:
|
||||
input_ids = torch.zeros(num_tokens,
|
||||
dtype=torch.int32,
|
||||
device=self.device)
|
||||
positions = torch.zeros(num_tokens,
|
||||
dtype=torch.long,
|
||||
device=self.device)
|
||||
kv_caches = [None for _ in range(self.num_attn_layers)]
|
||||
model(input_ids, positions, kv_caches, attn_metadata=None)
|
||||
return
|
||||
|
||||
@torch.inference_mode()
|
||||
def profile_run(self) -> None:
|
||||
self._dummy_run(self.model, self.max_num_tokens)
|
||||
torch.cuda.synchronize()
|
||||
return
|
||||
|
||||
@torch.inference_mode()
|
||||
def capture_model(self) -> None:
|
||||
# TODO: Implement CUDA graph support.
|
||||
return
|
||||
|
||||
def initialize_kv_cache(self, num_blocks: int) -> None:
|
||||
assert len(self.kv_caches) == 0
|
||||
kv_cache_shape = FlashAttentionBackend.get_kv_cache_shape(
|
||||
num_blocks, self.block_size, self.num_kv_heads, self.head_size)
|
||||
for _ in range(self.num_attn_layers):
|
||||
self.kv_caches.append(
|
||||
torch.zeros(kv_cache_shape,
|
||||
dtype=self.kv_cache_dtype,
|
||||
device=self.device))
|
||||
|
||||
|
||||
@dataclass
|
||||
class CachedRequestState:
|
||||
|
||||
req_id: str
|
||||
prompt_token_ids: List[int]
|
||||
prompt: Optional[str]
|
||||
multi_modal_data: Optional["MultiModalDataDict"]
|
||||
sampling_params: SamplingParams
|
||||
generator: Optional[torch.Generator]
|
||||
|
||||
block_ids: List[int]
|
||||
num_computed_tokens: int
|
||||
output_token_ids: List[int]
|
||||
|
||||
@property
|
||||
def num_tokens(self) -> int:
|
||||
return len(self.prompt_token_ids) + len(self.output_token_ids)
|
||||
|
||||
|
||||
class InputBatch:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_num_reqs: int,
|
||||
max_model_len: int,
|
||||
max_num_blocks_per_req: int,
|
||||
device: torch.device,
|
||||
pin_memory: bool,
|
||||
):
|
||||
self.max_num_reqs = max_num_reqs
|
||||
self.max_model_len = max_model_len
|
||||
self.max_num_blocks_per_req = max_num_blocks_per_req
|
||||
self.device = device
|
||||
self.pin_memory = pin_memory
|
||||
|
||||
self.req_ids: List[Optional[str]] = [None] * max_num_reqs
|
||||
self.req_id_to_index: Dict[str, int] = {}
|
||||
|
||||
self.token_ids_cpu = np.empty((max_num_reqs, max_model_len),
|
||||
dtype=np.int32)
|
||||
self.num_computed_tokens_cpu = np.empty(max_num_reqs, dtype=np.int32)
|
||||
|
||||
# Attention-related.
|
||||
self.block_table = torch.zeros((max_num_reqs, max_num_blocks_per_req),
|
||||
device=self.device,
|
||||
dtype=torch.int32)
|
||||
self.block_table_cpu_tensor = torch.zeros(
|
||||
(max_num_reqs, max_num_blocks_per_req),
|
||||
device="cpu",
|
||||
dtype=torch.int32,
|
||||
pin_memory=pin_memory,
|
||||
)
|
||||
self.block_table_cpu = self.block_table_cpu_tensor.numpy()
|
||||
|
||||
# Sampling-related.
|
||||
self.temperature = torch.empty((max_num_reqs, ),
|
||||
dtype=torch.float32,
|
||||
device=device)
|
||||
self.temperature_cpu_tensor = torch.empty((max_num_reqs, ),
|
||||
dtype=torch.float32,
|
||||
device="cpu",
|
||||
pin_memory=pin_memory)
|
||||
self.temperature_cpu = self.temperature_cpu_tensor.numpy()
|
||||
self.greedy_reqs: Set[str] = set()
|
||||
self.random_reqs: Set[str] = set()
|
||||
|
||||
self.top_p = torch.empty((max_num_reqs, ),
|
||||
dtype=torch.float32,
|
||||
device=device)
|
||||
self.top_p_cpu_tensor = torch.empty((max_num_reqs, ),
|
||||
dtype=torch.float32,
|
||||
device="cpu",
|
||||
pin_memory=pin_memory)
|
||||
self.top_p_cpu = self.top_p_cpu_tensor.numpy()
|
||||
self.top_p_reqs: Set[str] = set()
|
||||
|
||||
self.top_k = torch.empty((max_num_reqs, ),
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
self.top_k_cpu_tensor = torch.empty((max_num_reqs, ),
|
||||
dtype=torch.int32,
|
||||
device="cpu",
|
||||
pin_memory=pin_memory)
|
||||
self.top_k_cpu = self.top_k_cpu_tensor.numpy()
|
||||
self.top_k_reqs: Set[str] = set()
|
||||
|
||||
self.generators: List[Optional[torch.Generator]] = [None
|
||||
] * max_num_reqs
|
||||
|
||||
self.num_logprobs: Dict[str, int] = {}
|
||||
self.prompt_logprob_reqs: Set[str] = set()

    def add_request(
        self,
        request: "CachedRequestState",
        req_index: Optional[int] = None,
    ) -> None:
        if req_index is None:
            req_index = self.num_reqs
        assert req_index < self.max_num_reqs

        self.req_ids[req_index] = request.req_id
        self.req_id_to_index[request.req_id] = req_index

        # Copy the prompt token ids and output token ids.
        num_prompt_tokens = len(request.prompt_token_ids)
        self.token_ids_cpu[
            req_index, :num_prompt_tokens] = request.prompt_token_ids
        start_idx = num_prompt_tokens
        end_idx = start_idx + len(request.output_token_ids)
        self.token_ids_cpu[req_index,
                           start_idx:end_idx] = request.output_token_ids

        self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens
        num_blocks = len(request.block_ids)
        self.block_table_cpu[req_index, :num_blocks] = request.block_ids

        sampling_params = request.sampling_params
        self.temperature_cpu[req_index] = sampling_params.temperature
        if sampling_params.sampling_type == SamplingType.GREEDY:
            self.greedy_reqs.add(request.req_id)
        elif sampling_params.sampling_type == SamplingType.RANDOM:
            self.random_reqs.add(request.req_id)
        elif sampling_params.sampling_type == SamplingType.RANDOM_SEED:
            # TODO(woosuk): Support per-request random seed.
            raise NotImplementedError("Per-request seed is not supported yet.")

        self.top_p_cpu[req_index] = sampling_params.top_p
        if sampling_params.top_p < 1:
            self.top_p_reqs.add(request.req_id)
        self.top_k_cpu[req_index] = sampling_params.top_k
        if sampling_params.top_k > 0:
            self.top_k_reqs.add(request.req_id)

        self.generators[req_index] = request.generator

        num_logprobs = sampling_params.logprobs
        if num_logprobs is not None and num_logprobs > 0:
            self.num_logprobs[request.req_id] = num_logprobs
        if sampling_params.prompt_logprobs:
            self.prompt_logprob_reqs.add(request.req_id)

    def remove_request(self, req_id: str) -> Optional[int]:
        req_index = self.req_id_to_index.pop(req_id, None)
        if req_index is None:
            return None
        self.req_ids[req_index] = None

        self.greedy_reqs.discard(req_id)
        self.random_reqs.discard(req_id)
        self.top_p_reqs.discard(req_id)
        self.top_k_reqs.discard(req_id)
        self.generators[req_index] = None
        self.num_logprobs.pop(req_id, None)
        self.prompt_logprob_reqs.discard(req_id)
        return req_index

    def clear(self) -> None:
        self.req_ids = [None] * self.max_num_reqs
        self.req_id_to_index.clear()
        self.greedy_reqs.clear()
        self.random_reqs.clear()
        self.top_p_reqs.clear()
        self.top_k_reqs.clear()
        self.generators.clear()
        self.num_logprobs.clear()
        self.prompt_logprob_reqs.clear()

    def condense(self, empty_req_indices: List[int]) -> None:
        if self.num_reqs == 0:
            # The batched states are empty.
            return

        # NOTE(woosuk): This function assumes that the empty_req_indices
        # is sorted in descending order.
        last_req_index = self.num_reqs + len(empty_req_indices) - 1
        while empty_req_indices:
            # Find the largest non-empty index.
            while last_req_index in empty_req_indices:
                last_req_index -= 1

            # Find the smallest empty index.
            empty_index = empty_req_indices.pop()
            if empty_index >= last_req_index:
                break

            # Swap the states.
            req_id = self.req_ids[last_req_index]
            self.req_ids[empty_index] = req_id
            self.req_ids[last_req_index] = None
            self.req_id_to_index[req_id] = empty_index

            # TODO(woosuk): Optimize the copy of token_ids_cpu and
            # block_table_cpu.
            self.token_ids_cpu[empty_index] = self.token_ids_cpu[
                last_req_index]
            self.num_computed_tokens_cpu[
                empty_index] = self.num_computed_tokens_cpu[last_req_index]
            self.block_table_cpu[empty_index] = self.block_table_cpu[
                last_req_index]
            self.temperature_cpu[empty_index] = self.temperature_cpu[
                last_req_index]
            self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index]
            self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index]
            self.generators[empty_index] = self.generators[last_req_index]

            # Decrement last_req_index since it is now empty.
            last_req_index -= 1
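
    # A hypothetical trace of condense(): suppose five requests occupied
    # indices 0-4 and the ones at indices 1 and 3 were removed, so num_reqs is
    # now 3 and empty_req_indices == [3, 1] (descending). last_req_index
    # starts at 3 + 2 - 1 = 4; slot 4 is live, so its state is swapped into
    # the smallest empty slot (1) and last_req_index drops to 3. On the next
    # pass, slot 3 is itself empty, last_req_index falls to 2, and the popped
    # empty index 3 >= 2 ends the loop, leaving the live requests packed in
    # slots 0-2.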

    def make_sampling_metadata(
        self,
        skip_copy: bool = False,
    ) -> SamplingMetadata:
        if not skip_copy:
            self.temperature[:self.num_reqs].copy_(
                self.temperature_cpu_tensor[:self.num_reqs], non_blocking=True)
            self.top_p[:self.num_reqs].copy_(
                self.top_p_cpu_tensor[:self.num_reqs], non_blocking=True)
            self.top_k[:self.num_reqs].copy_(
                self.top_k_cpu_tensor[:self.num_reqs], non_blocking=True)
        return SamplingMetadata(
            temperature=self.temperature[:self.num_reqs],
            all_greedy=self.all_greedy,
            all_random=self.all_random,
            top_p=self.top_p[:self.num_reqs],
            top_k=self.top_k[:self.num_reqs],
            no_top_p=self.no_top_p,
            no_top_k=self.no_top_k,
            generators=self.generators[:self.num_reqs],
            no_generator=self.no_generator,
            max_num_logprobs=self.max_num_logprobs,
        )

    @property
    def num_reqs(self) -> int:
        return len(self.req_id_to_index)

    @property
    def all_greedy(self) -> bool:
        return len(self.random_reqs) == 0

    @property
    def all_random(self) -> bool:
        return len(self.greedy_reqs) == 0

    @property
    def no_top_p(self) -> bool:
        return len(self.top_p_reqs) == 0

    @property
    def no_top_k(self) -> bool:
        return len(self.top_k_reqs) == 0

    @property
    def no_generator(self) -> bool:
        return len(self.generators) == 0

    @property
    def max_num_logprobs(self) -> int:
        if self.num_logprobs:
            return max(self.num_logprobs.values())
        else:
            return 0

    @property
    def no_logprob(self) -> bool:
        return len(self.num_logprobs) == 0

    @property
    def no_prompt_logprob(self) -> bool:
        return len(self.prompt_logprob_reqs) == 0
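
# A minimal usage sketch of InputBatch (illustrative only; it assumes a
# CachedRequestState `req` prepared by the model runner, and the values shown
# here are hypothetical):
#
#     batch = InputBatch(max_num_reqs=8, max_model_len=2048,
#                        max_num_blocks_per_req=128,
#                        device=torch.device("cuda:0"), pin_memory=True)
#     batch.add_request(req)                 # fills the CPU-side buffers
#     meta = batch.make_sampling_metadata()  # async H2D copy + SamplingMetadata
#     empty_idx = batch.remove_request(req.req_id)
#     batch.condense([empty_idx])            # repack the remaining requests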

vllm/v1/worker/gpu_worker.py (new file)
@ -0,0 +1,245 @@
"""A GPU worker class."""
import gc
import os
from typing import TYPE_CHECKING, Optional, Tuple

import torch
import torch.distributed

from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                         ModelConfig, ObservabilityConfig, ParallelConfig,
                         PromptAdapterConfig, SchedulerConfig,
                         SpeculativeConfig)
from vllm.distributed import (ensure_model_parallel_initialized,
                              init_distributed_environment,
                              set_custom_all_reduce)
from vllm.logger import init_logger
from vllm.model_executor import set_random_seed
from vllm.platforms import current_platform
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.worker.gpu_model_runner import GPUModelRunner

logger = init_logger(__name__)

if TYPE_CHECKING:
    from vllm.v1.core.scheduler import SchedulerOutput


class Worker:

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        cache_config: CacheConfig,
        load_config: LoadConfig,
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        speculative_config: Optional[SpeculativeConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
        prompt_adapter_config: Optional[PromptAdapterConfig] = None,
        observability_config: Optional[ObservabilityConfig] = None,
    ):
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.device_config = device_config
        self.cache_config = cache_config
        self.load_config = load_config
        self.local_rank = local_rank
        self.rank = rank
        self.distributed_init_method = distributed_init_method
        self.lora_config = lora_config
        self.speculative_config = speculative_config
        self.prompt_adapter_config = prompt_adapter_config
        self.observability_config = observability_config

        if self.model_config.trust_remote_code:
            # note: lazy import to avoid importing torch before initializing
            from vllm.utils import init_cached_hf_modules
            init_cached_hf_modules()

        self.model_runner = GPUModelRunner(
            model_config,
            parallel_config,
            scheduler_config,
            device_config,
            cache_config,
            load_config,
            lora_config=lora_config,
        )

    def initialize(self):
        if self.device_config.device.type == "cuda":
            # torch.distributed.all_reduce does not free the input tensor until
            # the synchronization point. This causes the memory usage to grow
            # as the number of all_reduce calls increases. This env var disables
            # this behavior.
            # Related issue:
            # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573
            os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"

            # This env var set by Ray causes exceptions with graph building.
            os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
            self.device = torch.device(f"cuda:{self.local_rank}")
            torch.cuda.set_device(self.device)

            _check_if_gpu_supports_dtype(self.model_config.dtype)
            gc.collect()
            torch.cuda.empty_cache()
            self.init_gpu_memory = torch.cuda.mem_get_info()[0]
        else:
            raise RuntimeError(
                f"Unsupported device type: {self.device_config.device}")
        # Initialize the distributed environment.
        init_worker_distributed_environment(self.parallel_config, self.rank,
                                            self.distributed_init_method,
                                            self.local_rank)
        # Set random seed.
        set_random_seed(self.model_config.seed)

    def load_model(self) -> None:
        self.model_runner.load_model()

    @torch.inference_mode()
    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Profiles the peak memory usage of the model to determine how many
        KV blocks may be allocated without OOMs.

        The engine will first conduct a profiling of the existing memory usage.
        Then, it calculates the maximum possible number of GPU and CPU blocks
        that can be allocated with the remaining free memory.

        .. tip::
            You may limit the usage of GPU memory
            by adjusting the `gpu_memory_utilization` parameter.
        """
        # Profile the memory usage of the model and get the maximum number of
        # cache blocks that can be allocated with the remaining free memory.
        torch.cuda.empty_cache()

        # Execute a forward pass with dummy inputs to profile the memory usage
        # of the model.
        self.model_runner.profile_run()

        # Calculate the number of blocks that can be allocated with the
        # profiled peak memory.
        torch.cuda.synchronize()
        free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
        # NOTE(woosuk): Here we assume that the other processes using the same
        # GPU did not change their memory usage during the profiling.
        peak_memory = self.init_gpu_memory - free_gpu_memory
        assert peak_memory > 0, (
            "Error in memory profiling. "
            f"Initial free memory {self.init_gpu_memory}, current free memory"
            f" {free_gpu_memory}. This happens when the GPU memory was "
            "not properly cleaned up before initializing the vLLM instance.")

        cache_block_size = _get_cache_block_size(self.cache_config,
                                                 self.model_config,
                                                 self.parallel_config)
        num_gpu_blocks = int(
            (total_gpu_memory * self.cache_config.gpu_memory_utilization -
             peak_memory) // cache_block_size)
        num_gpu_blocks = max(num_gpu_blocks, 0)
        # if self.model_runner.lora_manager:
        #     self.model_runner.remove_all_loras()
        gc.collect()
        torch.cuda.empty_cache()
        return num_gpu_blocks, 0
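
    # A worked example of the block-count arithmetic above (hypothetical
    # numbers): with 80 GiB of total GPU memory, gpu_memory_utilization=0.9, a
    # profiled peak of 20 GiB, and a 2 MiB cache block (see the worked example
    # under _get_cache_block_size below), the budget is
    #   80 GiB * 0.9 - 20 GiB = 52 GiB,
    # so num_gpu_blocks = 52 GiB // 2 MiB = 26,624 blocks, i.e. room for about
    # 26,624 * block_size (16) = 425,984 cached token slots.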

    def initialize_cache(self, num_gpu_blocks: int) -> None:
        """Allocate GPU KV cache with the specified number of blocks."""
        if num_gpu_blocks <= 0:
            raise ValueError("No available memory for the cache blocks. "
                             "Try increasing `gpu_memory_utilization` when "
                             "initializing the engine.")

        max_seq_len = self.cache_config.block_size * num_gpu_blocks
        max_model_len = self.model_config.max_model_len
        if max_model_len > max_seq_len:
            raise ValueError(
                f"The model's max seq len ({max_model_len}) "
                "is larger than the maximum number of tokens that can be "
                f"stored in KV cache ({max_seq_len}). Try increasing "
                "`gpu_memory_utilization` or decreasing `max_model_len` when "
                "initializing the engine.")

        self.model_runner.initialize_kv_cache(num_gpu_blocks)

    def compile_or_warm_up_model(self) -> None:
        if not self.model_config.enforce_eager:
            self.model_runner.capture_model()
        # Reset the seed to ensure that the random state is not affected by
        # the model initialization and profiling.
        set_random_seed(self.model_config.seed)

    @torch.inference_mode()
    def execute_model(
        self,
        scheduler_output: "SchedulerOutput",
    ) -> ModelRunnerOutput:
        output = self.model_runner.execute_model(scheduler_output)
        # TODO(woosuk): Send the output to the engine process.
        return output
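
# A rough sketch of how an executor is expected to drive this worker; the
# surrounding executor code is not shown here and the argument values are
# illustrative:
#
#     worker = Worker(model_config, parallel_config, scheduler_config,
#                     device_config, cache_config, load_config,
#                     local_rank=0, rank=0,
#                     distributed_init_method="tcp://127.0.0.1:29500")
#     worker.initialize()            # select device, init NCCL, set seed
#     worker.load_model()
#     num_gpu_blocks, _ = worker.determine_num_available_blocks()
#     worker.initialize_cache(num_gpu_blocks)
#     worker.compile_or_warm_up_model()
#     output = worker.execute_model(scheduler_output)  # per scheduler step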


def init_worker_distributed_environment(
    parallel_config: ParallelConfig,
    rank: int,
    distributed_init_method: Optional[str] = None,
    local_rank: int = -1,
) -> None:
    """Initialize the distributed environment."""
    set_custom_all_reduce(not parallel_config.disable_custom_all_reduce)

    init_distributed_environment(parallel_config.world_size, rank,
                                 distributed_init_method, local_rank)

    ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
                                      parallel_config.pipeline_parallel_size)


def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
    # Check if the GPU supports the dtype.
    if torch_dtype == torch.bfloat16:  # noqa: SIM102
        if not current_platform.has_device_capability(80):
            capability = current_platform.get_device_capability()
            gpu_name = current_platform.get_device_name()

            if capability is None:
                compute_str = "does not have a compute capability"
            else:
                version_str = capability.as_version_str()
                compute_str = f"has compute capability {version_str}"

            raise ValueError(
                "Bfloat16 is only supported on GPUs with compute capability "
                f"of at least 8.0. Your {gpu_name} GPU {compute_str}. "
                "You can use float16 instead by explicitly setting the "
                "`dtype` flag in CLI, for example: --dtype=half.")


def _get_cache_block_size(
    cache_config: CacheConfig,
    model_config: ModelConfig,
    parallel_config: ParallelConfig,
) -> int:
    head_size = model_config.get_head_size()
    num_heads = model_config.get_num_kv_heads(parallel_config)
    num_attention_layers = model_config.get_num_attention_layers(
        parallel_config)

    key_cache_block = cache_config.block_size * num_heads * head_size
    value_cache_block = key_cache_block
    total = num_attention_layers * (key_cache_block + value_cache_block)
    if cache_config.cache_dtype == "auto":
        dtype = model_config.dtype
    else:
        dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
    dtype_size = get_dtype_size(dtype)
    return dtype_size * total
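
# A worked example with hypothetical model dimensions (e.g. a Llama-3-8B-like
# model at tensor_parallel_size=1): block_size=16, 8 KV heads, head_size=128,
# 32 attention layers, fp16 cache:
#   key_cache_block  = 16 * 8 * 128           = 16,384 elements
#   total            = 32 * (16,384 + 16,384) = 1,048,576 elements
#   bytes per block  = 2 * 1,048,576          = 2,097,152 B = 2 MiB.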