Compare commits

...

110 Commits

Author SHA1 Message Date
866eef50ca minor
Signed-off-by: Woosuk Kwon <woosuk@thinkingmachines.ai>
2025-09-24 15:29:27 +00:00
ad2cf805ad Merge branch 'main' into woosuk/model-runner-v2 2025-09-24 08:19:25 -07:00
704def253c Merge branch 'main' into woosuk/model-runner-v2 2025-09-23 21:08:15 +00:00
42f99150c1 fix
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-23 09:23:21 -07:00
17c2c106b1 Merge branch 'main' into woosuk/model-runner-v2 2025-09-23 09:22:58 -07:00
72f0a71939 assert
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-21 19:37:18 -07:00
fe5472dc03 Merge branch 'main' into woosuk/model-runner-v2 2025-09-21 18:56:48 -07:00
bc73f674bb compute_logits
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-21 11:26:33 -07:00
631b5b47c1 Merge branch 'main' into woosuk/model-runner-v2 2025-09-21 11:25:18 -07:00
42ffdd9179 wip
Signed-off-by: Woosuk Kwon <woosuk@thinkingmachines.ai>
2025-09-20 22:15:07 +00:00
8aee6e97e6 64-bit for gumbel seed
Signed-off-by: Woosuk Kwon <woosuk@thinkingmachines.ai>
2025-09-20 11:43:01 +00:00
913b8e9569 Merge branch 'main' into woosuk/model-runner-v2 2025-09-20 11:18:35 +00:00
158a46888e random uuid
Signed-off-by: Woosuk Kwon <woosuk@thinkingmachines.ai>
2025-09-20 11:17:45 +00:00
98ef239486 minor
Signed-off-by: Woosuk Kwon <woosuk@thinkingmachines.ai>
2025-09-19 23:55:46 +00:00
a66aa37f40 minor:
Signed-off-by: Woosuk Kwon <woosuk@thinkingmachines.ai>
2025-09-19 23:47:20 +00:00
6f038fc4fb Merge branch 'main' into woosuk/model-runner-v2 2025-09-19 20:30:04 +00:00
010e39ec7d minor
Signed-off-by: Woosuk Kwon <woosuk@thinkingmachines.ai>
2025-09-19 19:07:46 +00:00
396bbe67d3 Merge branch 'main' into woosuk/model-runner-v2 2025-09-19 18:53:18 +00:00
c7f3e84b34 minor
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-19 09:49:40 -07:00
a8e7071924 minor
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-19 08:33:47 -07:00
4be2c66e37 fix
Signed-off-by: Woosuk Kwon <woosuk@thinkingmachines.ai>
2025-09-19 09:35:38 +00:00
d30c0d50a6 refactor
Signed-off-by: Woosuk Kwon <woosuk@thinkingmachines.ai>
2025-09-19 07:17:53 +00:00
9c75d896a8 minor
Signed-off-by: Woosuk Kwon <woosuk@thinkingmachines.ai>
2025-09-19 07:11:37 +00:00
37478c18cf async output
Signed-off-by: Woosuk Kwon <woosuk@thinkingmachines.ai>
2025-09-19 07:10:42 +00:00
33672774f5 Merge branch 'main' into woosuk/model-runner-v2 2025-09-19 06:52:46 +00:00
0d3de9e082 fix
Signed-off-by: Woosuk Kwon <woosuk@thinkingmachines.ai>
2025-09-19 06:50:56 +00:00
b405d78c07 DP sampler
Signed-off-by: Woosuk Kwon <woosuk@thinkingmachines.ai>
2025-09-19 06:46:46 +00:00
8af87986aa fix
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-18 18:37:30 -07:00
af65838d1f dummy run
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-18 18:29:18 -07:00
52ca2f517a sample
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-18 17:39:43 -07:00
8deedfa42b -inf
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-18 17:24:00 -07:00
b9c74487d2 logprobs
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-18 17:23:02 -07:00
31619ff412 fix
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-18 16:38:56 -07:00
d2be62378b fix
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-18 16:33:18 -07:00
86dade710d fix
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-18 16:32:00 -07:00
efda08481b minor
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-18 16:31:01 -07:00
82da219ff9 Implement topk_logprobs
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-18 16:29:38 -07:00
323a05b3c5 update
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-18 15:51:36 -07:00
a98eff0762 minor
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-18 15:21:30 -07:00
67d8c0c21b fix
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-18 15:15:31 -07:00
2bb2cb13f4 revert
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-18 14:54:19 -07:00
e171e5bb67 merge
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-18 14:53:32 -07:00
8407fa02ed fix
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-18 14:52:23 -07:00
82e591f7eb remove
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-18 14:35:25 -07:00
330058f9b8 fix
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-18 14:30:29 -07:00
aabfaa08cf fix
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-18 14:14:03 -07:00
bc6463ac97 hash
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-18 13:49:52 -07:00
a4962833f9 minor
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-18 13:20:37 -07:00
3f50030cc8 fix
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-18 13:11:46 -07:00
cbdb47dc01 working
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-18 13:10:35 -07:00
92f337faeb minor
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-18 12:44:21 -07:00
9050087250 update
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-18 12:37:29 -07:00
c1d83f2bae merge
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-18 12:13:56 -07:00
91510260b2 task
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-16 01:06:10 -07:00
c320a33c59 skip warmup
Signed-off-by: Woosuk Kwon <woosuk@thinkingmachines.ai>
2025-09-16 07:21:25 +00:00
83d11373a4 wip
Signed-off-by: Woosuk Kwon <woosuk@thinkingmachines.ai>
2025-09-16 07:21:25 +00:00
dfc84b11a9 wip
Signed-off-by: Woosuk Kwon <woosuk@thinkingmachines.ai>
2025-09-16 07:21:25 +00:00
9f2becd3e6 merge
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-16 00:16:42 -07:00
e107680d8a wip
Signed-off-by: Woosuk Kwon <woosuk@thinkingmachines.ai>
2025-09-15 21:19:18 +00:00
f1981db101 minor
Signed-off-by: Woosuk Kwon <woosuk@thinkingmachines.ai>
2025-09-15 19:53:58 +00:00
69b17891a3 chunked prefilling
Signed-off-by: Woosuk Kwon <woosuk@thinkingmachines.ai>
2025-09-15 19:41:17 +00:00
67852c1036 minor
Signed-off-by: Woosuk Kwon <woosuk@thinkingmachines.ai>
2025-09-15 19:23:54 +00:00
8b3c13c485 wip
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-15 11:17:54 -07:00
9a6fcca030 fix
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-14 15:56:42 -07:00
633f9f006d Merge branch 'main' into woosuk/input-prep 2025-09-14 08:03:28 -07:00
eb3742c72a fix
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-13 19:19:40 -07:00
e47bb9970b fix
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-13 19:19:07 -07:00
5c133fc860 reorder
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-13 19:17:40 -07:00
caf963f2e9 fix
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-13 19:13:08 -07:00
9314a83b56 Merge branch 'main' into woosuk/input-prep 2025-09-14 00:44:56 +00:00
7a50a54390 Merge branch 'main' into woosuk/input-prep 2025-09-13 21:33:54 +00:00
787e59629c wip
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-08 16:42:26 -07:00
5f95309a6d rename
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-07 12:01:45 -07:00
286eeb91e8 merge
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-07 11:16:37 -07:00
6283995a6c minor
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-06 21:18:16 -07:00
0c56069c7e merge
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-06 16:35:45 -07:00
8e6cb9aa4a minor
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-06 12:23:02 -07:00
ead95fe5dc merge
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-06 10:56:27 -07:00
23eae07ea5 merge
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-04 20:19:22 -07:00
b16e2d9602 fix
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-01 02:10:48 -07:00
4c2a337e67 merge
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-01 01:45:29 -07:00
cc340e26af top_p top_k
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-01 01:30:08 -07:00
01bf16ede4 fix
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-01 01:16:26 -07:00
af7b6c5dd4 fix
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-08-31 23:50:20 -07:00
62d23b3006 fix
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-08-31 21:00:16 -07:00
ba1a58f51b MAX_SPEC_LEN
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-08-31 20:43:25 -07:00
22771e5d83 work
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-08-31 20:41:38 -07:00
c11d1e6781 optimize spec
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-08-31 16:40:54 -07:00
e696f78e05 minor
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-08-31 13:29:58 -07:00
efcb786d52 merge
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-08-31 10:44:36 -07:00
9ee9d0e274 fix
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-08-28 15:02:07 -07:00
405578121c minor
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-08-28 13:19:10 -07:00
19c0dfc469 minor
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-08-28 13:08:07 -07:00
e451045a66 fix
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-08-28 12:55:13 -07:00
efba25e21a minor
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-08-28 12:39:15 -07:00
b21393cd98 Merge branch 'main' into woosuk/input-prep 2025-08-28 09:58:08 -07:00
d6d719fb24 Merge branch 'main' into woosuk/input-prep 2025-08-28 09:57:49 -07:00
e570b0a4de merge
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-08-27 21:45:11 -07:00
a851aaa0fc simplify
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-08-25 09:23:05 -07:00
b1d52734f7 fix
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-08-25 08:55:12 -07:00
65f93694be merge
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-08-25 08:54:32 -07:00
7b4b72e551 fix
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-08-24 18:49:23 -07:00
da9cd26c78 Merge branch 'main' into woosuk/input-prep 2025-08-24 18:36:33 -07:00
a1e3745150 wip
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-08-24 18:36:18 -07:00
48bca9a109 merge
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-08-23 11:30:29 -07:00
64c8cced18 rename
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-08-22 01:48:35 -07:00
79e5eb3643 wip
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-08-22 01:37:43 -07:00
c472982746 merge
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-08-21 21:40:44 -07:00
699bd7928e Merge branch 'main' into woosuk/input-prep 2025-08-17 19:28:38 -07:00
33a3a26ca5 wip
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-08-17 14:38:24 -07:00
14 changed files with 1874 additions and 22 deletions

View File

@@ -24,6 +24,8 @@ logger = init_logger(__name__)
def kernel_warmup(worker: "Worker"):
return
# Deep GEMM warmup
do_deep_gemm_warmup = (envs.VLLM_USE_DEEP_GEMM
and is_deep_gemm_supported()

View File

@@ -864,6 +864,7 @@ class Scheduler(SchedulerInterface):
model_runner_output: ModelRunnerOutput,
) -> dict[int, EngineCoreOutputs]:
sampled_token_ids = model_runner_output.sampled_token_ids
num_sampled_tokens = model_runner_output.num_sampled_tokens
logprobs = model_runner_output.logprobs
prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict
num_scheduled_tokens = scheduler_output.num_scheduled_tokens
@@ -881,7 +882,8 @@
# to avoid expensive operations inside the loop.
stopped_running_reqs: set[Request] = set()
stopped_preempted_reqs: set[Request] = set()
for req_id, num_tokens_scheduled in num_scheduled_tokens.items():
for req_index, req_id in enumerate(model_runner_output.req_ids):
num_tokens_scheduled = num_scheduled_tokens[req_id]
assert num_tokens_scheduled > 0
request = self.requests.get(req_id)
if request is None:
@@ -890,9 +892,13 @@
# in pipeline parallelism).
continue
req_index = model_runner_output.req_id_to_index[req_id]
generated_token_ids = sampled_token_ids[
req_index] if sampled_token_ids else []
generated_token_ids = []
if sampled_token_ids is not None:
assert num_sampled_tokens is not None
n = num_sampled_tokens[req_index]
if n > 0:
generated_token_ids = sampled_token_ids[req_index, :n]
generated_token_ids = generated_token_ids.tolist()
scheduled_spec_token_ids = (
scheduler_output.scheduled_spec_decode_tokens.get(req_id))
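For orientation, the hunk above switches the scheduler from a ragged list of lists to a padded per-request array plus a per-request count. A minimal NumPy sketch of that layout and of the slicing the loop performs (illustrative only; the padding value -1 is made up, only the shapes follow the diff):

import numpy as np

# One row per request, padded to the max tokens sampled in this step
# (rows can differ in length due to speculative decoding).
# num_sampled_tokens holds each row's valid length; 0 means the request
# was still chunk-prefilling and produced no token.
sampled_token_ids = np.array([[11, 12, -1],
                              [21, -1, -1],
                              [-1, -1, -1]], dtype=np.int32)
num_sampled_tokens = np.array([2, 1, 0], dtype=np.int32)

for req_index in range(len(num_sampled_tokens)):
    n = num_sampled_tokens[req_index]
    generated = sampled_token_ids[req_index, :n].tolist() if n > 0 else []
    print(req_index, generated)
# 0 [11, 12]
# 1 [21]
# 2 []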

View File

@@ -3,6 +3,7 @@
import copy
from dataclasses import dataclass, fields
from functools import cached_property
from math import prod
from typing import Optional
@@ -342,3 +343,10 @@
see `_get_kv_cache_config_uniform_page_size` for more details.
"""
kv_cache_groups: list[KVCacheGroupSpec]
@cached_property
def block_sizes(self) -> list[int]:
return [
kv_cache_group.kv_cache_spec.block_size
for kv_cache_group in self.kv_cache_groups
]

View File

@@ -5,6 +5,7 @@ from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, NamedTuple, Optional
import numpy as np
import torch
if TYPE_CHECKING:
@@ -15,11 +16,11 @@ if TYPE_CHECKING:
class LogprobsLists(NamedTuple):
# [num_reqs, max_num_logprobs + 1]
logprob_token_ids: list[list[int]]
logprob_token_ids: np.ndarray
# [num_reqs, max_num_logprobs + 1]
logprobs: list[list[float]]
logprobs: np.ndarray
# [num_reqs]
sampled_token_ranks: list[int]
sampled_token_ranks: np.ndarray
def slice(self, start: int, end: int):
return LogprobsLists(
@@ -40,9 +41,9 @@ class LogprobsTensors(NamedTuple):
def tolists(self):
return LogprobsLists(
self.logprob_token_ids.tolist(),
self.logprobs.tolist(),
self.selected_token_ranks.tolist(),
self.logprob_token_ids.cpu().numpy(),
self.logprobs.cpu().numpy(),
self.selected_token_ranks.cpu().numpy(),
)
@staticmethod
@@ -89,20 +90,18 @@ class KVConnectorOutput:
# ModelRunnerOutput is serialized and sent to the scheduler process.
# This is expensive for torch.Tensor so prefer to use list instead.
@dataclass
class ModelRunnerOutput:
# [num_reqs]
req_ids: list[str]
# req_id -> index
req_id_to_index: dict[str, int]
# num_reqs x num_generated_tokens
# num_generated_tokens is the number of tokens
# generated in the current step. It can be different for
# each request due to speculative/jump decoding.
sampled_token_ids: list[list[int]]
sampled_token_ids: Optional[np.ndarray]
num_sampled_tokens: Optional[np.ndarray]
# [num_reqs, max_num_logprobs + 1]
# [num_reqs, max_num_logprobs + 1]
@@ -148,8 +147,8 @@ class DraftTokenIds:
EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(req_ids=[],
req_id_to_index={},
sampled_token_ids=[],
sampled_token_ids=None,
num_sampled_tokens=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],

View File

View File

@@ -0,0 +1,48 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.v1.outputs import (AsyncModelRunnerOutput, LogprobsTensors,
ModelRunnerOutput, SamplerOutput)
class AsyncOutput(AsyncModelRunnerOutput):
def __init__(
self,
model_runner_output: ModelRunnerOutput,
sampler_output: SamplerOutput,
copy_stream: torch.cuda.Stream,
):
self.model_runner_output = model_runner_output
self.sampler_output = sampler_output
self.copy_stream = copy_stream
self.copy_event = torch.cuda.Event()
default_stream = torch.cuda.current_stream()
with torch.cuda.stream(self.copy_stream):
self.copy_stream.wait_stream(default_stream)
self.sampled_token_ids = sampler_output.sampled_token_ids.to(
"cpu", non_blocking=True)
x = sampler_output.logprobs_tensors
if x is not None:
self.logprobs_tensors = LogprobsTensors(
logprob_token_ids=x.logprob_token_ids.to(
"cpu", non_blocking=True),
logprobs=x.logprobs.to("cpu", non_blocking=True),
selected_token_ranks=x.selected_token_ranks.to(
"cpu", non_blocking=True),
)
else:
self.logprobs_tensors = None
self.copy_event.record()
def get_output(self) -> ModelRunnerOutput:
self.copy_event.synchronize()
self.model_runner_output.sampled_token_ids = (
self.sampled_token_ids.numpy())
if self.logprobs_tensors is not None:
self.model_runner_output.logprobs = (
self.logprobs_tensors.tolists())
return self.model_runner_output
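A standalone sketch of the copy-on-a-side-stream pattern AsyncOutput uses above (hypothetical helper, not part of the diff): make the side stream wait on work already queued on the default stream, start a non-blocking device-to-host copy, record an event, and only synchronize when the CPU result is needed. For the copy to truly overlap, the CPU destination would need to be pinned memory.

import torch

def async_copy_to_cpu(gpu_tensor: torch.Tensor,
                      copy_stream: torch.cuda.Stream):
    # Same pattern as AsyncOutput.__init__ above.
    default_stream = torch.cuda.current_stream()
    copy_event = torch.cuda.Event()
    with torch.cuda.stream(copy_stream):
        copy_stream.wait_stream(default_stream)
        cpu_tensor = gpu_tensor.to("cpu", non_blocking=True)
        copy_event.record()
    return cpu_tensor, copy_event

if torch.cuda.is_available():
    stream = torch.cuda.Stream()
    sampled = torch.randint(0, 32000, (8, 1), device="cuda")
    cpu_sampled, event = async_copy_to_cpu(sampled, stream)
    # ... the default stream can keep running the next step here ...
    event.synchronize()               # mirrors AsyncOutput.get_output()
    print(cpu_sampled.numpy().shape)  # (8, 1)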

View File

@@ -0,0 +1,139 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any
import torch
from vllm.attention.backends.abstract import AttentionBackend, AttentionType
from vllm.attention.layer import Attention
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.v1.attention.backends.utils import AttentionMetadataBuilder
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheSpec, SlidingWindowSpec)
from vllm.v1.worker.utils import bind_kv_cache
def get_kv_cache_spec(
vllm_config: VllmConfig,
kv_cache_dtype: torch.dtype,
) -> dict[str, KVCacheSpec]:
block_size = vllm_config.cache_config.block_size
use_mla = vllm_config.model_config.use_mla
kv_cache_spec: dict[str, KVCacheSpec] = {}
attn_layers = get_layers_from_vllm_config(vllm_config, Attention)
for layer_name, attn_module in attn_layers.items():
assert attn_module.attn_type == AttentionType.DECODER
if attn_module.sliding_window is not None:
kv_cache_spec[layer_name] = SlidingWindowSpec(
block_size=block_size,
num_kv_heads=attn_module.num_kv_heads,
head_size=attn_module.head_size,
dtype=kv_cache_dtype,
sliding_window=attn_module.sliding_window,
use_mla=use_mla,
)
else:
kv_cache_spec[layer_name] = FullAttentionSpec(
block_size=block_size,
num_kv_heads=attn_module.num_kv_heads,
head_size=attn_module.head_size,
dtype=kv_cache_dtype,
use_mla=use_mla,
)
return kv_cache_spec
def init_attn_backend(
kv_cache_config: KVCacheConfig,
vllm_config: VllmConfig,
device: torch.device,
):
attn_backends: dict[str, AttentionBackend] = {}
attn_metadata_builders: list[AttentionMetadataBuilder] = []
attn_layers = get_layers_from_vllm_config(vllm_config, Attention)
for kv_cache_group_spec in kv_cache_config.kv_cache_groups:
layer_names = kv_cache_group_spec.layer_names
any_layer_name = next(iter(layer_names))
attn_backend = attn_layers[any_layer_name].get_attn_backend()
for layer_name in layer_names:
attn_backends[layer_name] = attn_backend
attn_metadata_builder = attn_backend.get_builder_cls()(
kv_cache_group_spec.kv_cache_spec,
layer_names,
vllm_config,
device,
)
attn_metadata_builders.append(attn_metadata_builder)
return attn_backends, attn_metadata_builders
def _allocate_kv_cache(
kv_cache_config: KVCacheConfig,
device: torch.device,
):
kv_cache_raw_tensors: dict[str, torch.Tensor] = {}
for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
tensor = torch.zeros(kv_cache_tensor.size,
dtype=torch.int8,
device=device)
for layer_name in kv_cache_tensor.shared_by:
kv_cache_raw_tensors[layer_name] = tensor
layer_names = set()
for group in kv_cache_config.kv_cache_groups:
for layer_name in group.layer_names:
layer_names.add(layer_name)
assert layer_names == set(kv_cache_raw_tensors.keys()
), "Some layers are not correctly initialized"
return kv_cache_raw_tensors
def _reshape_kv_cache(
kv_cache_config: KVCacheConfig,
kv_cache_raw_tensors: dict[str, torch.Tensor],
attn_backends: dict[str, AttentionBackend],
) -> dict[str, torch.Tensor]:
kv_caches: dict[str, torch.Tensor] = {}
for kv_cache_group_spec in kv_cache_config.kv_cache_groups:
kv_cache_spec = kv_cache_group_spec.kv_cache_spec
for layer_name in kv_cache_group_spec.layer_names:
raw_tensor = kv_cache_raw_tensors[layer_name]
assert raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0
num_blocks = (raw_tensor.numel() // kv_cache_spec.page_size_bytes)
attn_backend = attn_backends[layer_name]
kv_cache_shape = attn_backend.get_kv_cache_shape(
num_blocks, kv_cache_spec.block_size,
kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
dtype = kv_cache_spec.dtype
kv_cache_stride_order = attn_backend.get_kv_cache_stride_order()
kv_cache_shape = tuple(kv_cache_shape[i]
for i in kv_cache_stride_order)
inv_order = [
kv_cache_stride_order.index(i)
for i in range(len(kv_cache_stride_order))
]
raw_tensor = raw_tensor.view(dtype)
raw_tensor = raw_tensor.view(kv_cache_shape)
kv_caches[layer_name] = raw_tensor.permute(*inv_order)
return kv_caches
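_reshape_kv_cache above views each raw byte buffer in the backend's physical (stride-ordered) shape and then applies the inverse permutation so callers see the logical layout. A toy round trip of just that permutation step, with made-up shapes and stride order (the dtype reinterpretation is omitted):

import torch
from math import prod

logical_shape = (2, 4, 8, 16)   # e.g. (num_blocks, block_size, num_kv_heads, head_size)
stride_order = (0, 2, 1, 3)     # hypothetical backend-reported layout
physical_shape = tuple(logical_shape[i] for i in stride_order)

raw = torch.zeros(prod(physical_shape), dtype=torch.int8)
inv_order = [stride_order.index(i) for i in range(len(stride_order))]
# View in the physical layout, then permute back to the logical one.
kv_cache = raw.view(physical_shape).permute(*inv_order)
assert tuple(kv_cache.shape) == logical_shape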
def init_kv_cache(
runner_kv_caches: list[torch.Tensor],
forward_context: dict[str, Any],
kv_cache_config: KVCacheConfig,
attn_backends: dict[str, AttentionBackend],
device: torch.device,
):
kv_cache_raw_tensors = _allocate_kv_cache(kv_cache_config, device)
kv_caches = _reshape_kv_cache(kv_cache_config, kv_cache_raw_tensors,
attn_backends)
bind_kv_cache(kv_caches, forward_context, runner_kv_caches)

View File

@@ -0,0 +1,312 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable
import torch
import triton
import triton.language as tl
from vllm.utils import cdiv
from vllm.v1.utils import CpuGpuBuffer
PAD_SLOT_ID = -1
class BlockTables:
def __init__(
self,
block_sizes: list[int],
max_num_reqs: int,
max_num_batched_tokens: int,
max_model_len: int,
device: torch.device,
pin_memory: bool,
):
self.block_sizes = block_sizes
self.max_num_reqs = max_num_reqs
self.max_num_batched_tokens = max_num_batched_tokens
self.max_model_len = max_model_len
self.device = device
self.pin_memory = pin_memory
self.num_kv_cache_groups = len(self.block_sizes)
# num_kv_cache_groups x [max_num_reqs, max_num_blocks]
self.block_tables: list[torch.Tensor] = []
for i in range(self.num_kv_cache_groups):
block_size = self.block_sizes[i]
max_num_blocks = cdiv(self.max_model_len, block_size)
block_table = torch.zeros(
self.max_num_reqs,
max_num_blocks,
dtype=torch.int32,
device=self.device,
)
self.block_tables.append(block_table)
self.block_table_ptrs = self._make_ptr_tensor(self.block_tables)
# Block tables used for model's forward pass.
# num_kv_cache_groups x [max_num_reqs, max_num_blocks]
self.input_block_tables: list[torch.Tensor] = [
torch.zeros_like(block_table) for block_table in self.block_tables
]
self.input_block_table_ptrs = self._make_ptr_tensor(
self.input_block_tables)
self.block_table_strides = torch.tensor(
[b.stride(0) for b in self.block_tables],
dtype=torch.int64,
device=self.device)
self.block_sizes_tensor = torch.tensor(self.block_sizes,
dtype=torch.int32,
device=self.device)
self.num_blocks = torch.zeros(self.num_kv_cache_groups,
self.max_num_reqs,
dtype=torch.int32,
device=self.device)
self.slot_mappings = torch.zeros(self.num_kv_cache_groups,
self.max_num_batched_tokens,
dtype=torch.int64,
device=self.device)
# Misc buffers.
self.req_indices = self._make_buffer(self.max_num_reqs,
dtype=torch.int32)
self.overwrite = self._make_buffer(self.max_num_reqs, dtype=torch.bool)
self.cu_num_new_blocks = self._make_buffer(self.num_kv_cache_groups,
self.max_num_reqs + 1,
dtype=torch.int32)
def _make_buffer(self, *args, dtype: torch.dtype) -> CpuGpuBuffer:
return CpuGpuBuffer(*args,
dtype=dtype,
pin_memory=self.pin_memory,
device=self.device)
def _make_ptr_tensor(self, x: Iterable[torch.Tensor]) -> torch.Tensor:
ptrs_tensor_cpu = torch.tensor([t.data_ptr() for t in x],
dtype=torch.int64,
device="cpu",
pin_memory=self.pin_memory)
return ptrs_tensor_cpu.to(self.device, non_blocking=True)
def append_block_ids(
self,
# [num_reqs]
req_indices: list[int],
# [num_kv_cache_groups, num_reqs + 1]
cu_num_new_blocks: list[list[int]],
# [num_kv_cache_groups, num_new_blocks]
new_block_ids: list[list[int]],
# [num_reqs]
overwrite: list[bool],
) -> None:
num_reqs = len(req_indices)
self.req_indices.np[:num_reqs] = req_indices
self.overwrite.np[:num_reqs] = overwrite
for i in range(self.num_kv_cache_groups):
self.cu_num_new_blocks.np[i, :num_reqs + 1] = cu_num_new_blocks[i]
# NOTE(woosuk): Here, we cannot use a fixed-size buffer because there's
# no clear upper bound to the number of new blocks in a single step.
# NOTE(woosuk): The buffer has to be cached, because otherwise we cannot
# guarantee that the buffer is not freed before the copy is completed.
self.new_block_ids_cpu = torch.empty(
self.num_kv_cache_groups,
max(len(x) for x in new_block_ids),
dtype=torch.int32,
device="cpu",
pin_memory=self.pin_memory,
)
new_block_ids_np = self.new_block_ids_cpu.numpy()
for i in range(self.num_kv_cache_groups):
new_block_ids_np[i, :len(new_block_ids[i])] = new_block_ids[i]
new_block_ids_gpu = self.new_block_ids_cpu.to(self.device,
non_blocking=True)
_append_block_ids_kernel[(self.num_kv_cache_groups, num_reqs)](
self.req_indices.copy_to_gpu(num_reqs),
self.cu_num_new_blocks.copy_to_gpu(),
self.cu_num_new_blocks.gpu.stride(0),
new_block_ids_gpu,
new_block_ids_gpu.stride(0),
self.overwrite.copy_to_gpu(num_reqs),
self.block_table_strides,
self.block_table_ptrs,
self.num_blocks,
self.num_blocks.stride(0),
BLOCK_SIZE=1024,
)
def gather_block_tables(
self,
idx_mapping: torch.Tensor,
) -> tuple[torch.Tensor, ...]:
num_reqs = idx_mapping.shape[0]
_gather_block_tables_kernel[(self.num_kv_cache_groups, num_reqs)](
idx_mapping,
self.block_table_ptrs,
self.input_block_table_ptrs,
self.block_table_strides,
self.num_blocks,
self.num_blocks.stride(0),
BLOCK_SIZE=1024,
)
return tuple(block_table[:num_reqs]
for block_table in self.input_block_tables)
def compute_slot_mappings(
self,
query_start_loc: torch.Tensor,
positions: torch.Tensor,
) -> torch.Tensor:
num_reqs = query_start_loc.shape[0] - 1
num_tokens = positions.shape[0]
num_groups = self.num_kv_cache_groups
_compute_slot_mappings_kernel[(num_groups, num_reqs + 1)](
num_tokens,
self.max_num_batched_tokens,
query_start_loc,
positions,
self.input_block_table_ptrs,
self.block_table_strides,
self.block_sizes_tensor,
self.slot_mappings,
self.slot_mappings.stride(0),
PAD_ID=PAD_SLOT_ID,
BLOCK_SIZE=1024,
)
return self.slot_mappings[:, :num_tokens]
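The per-token slot id the Triton kernel writes is block_table[pos // block_size] * block_size + pos % block_size, with the unused tail padded to PAD_SLOT_ID for CUDA graphs. A NumPy reference for a single request, with made-up block ids:

import numpy as np

def slot_mapping_reference(positions: np.ndarray,
                           block_table: np.ndarray,
                           block_size: int) -> np.ndarray:
    # slot = physical_block_id * block_size + offset_within_block
    block_indices = positions // block_size
    block_numbers = block_table[block_indices]
    return block_numbers * block_size + positions % block_size

block_table = np.array([7, 3, 9], dtype=np.int64)   # logical block i -> physical block id
positions = np.arange(0, 6, dtype=np.int64)          # token positions 0..5
print(slot_mapping_reference(positions, block_table, block_size=4))
# [28 29 30 31 12 13]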
@triton.jit
def _append_block_ids_kernel(
# Inputs
req_indices, # [num_reqs]
cu_num_new_blocks_ptr, # [num_kv_cache_groups, num_reqs + 1]
cu_num_new_blocks_stride,
new_block_ids_ptr, # [num_kv_cache_groups, num_new_blocks]
new_block_ids_stride,
overwrite, # [num_reqs]
block_table_strides, # [num_kv_cache_groups]
# Outputs
block_table_ptrs, # [num_kv_cache_groups]
num_blocks_ptr, # [num_kv_cache_groups, max_num_reqs]
num_blocks_stride,
# Constants
BLOCK_SIZE: tl.constexpr,
):
group_id = tl.program_id(0)
batch_idx = tl.program_id(1)
req_idx = tl.load(req_indices + batch_idx)
do_overwrite = tl.load(overwrite + batch_idx)
group_new_blocks_ptr = (cu_num_new_blocks_ptr +
group_id * cu_num_new_blocks_stride)
start_idx = tl.load(group_new_blocks_ptr + batch_idx)
end_idx = tl.load(group_new_blocks_ptr + batch_idx + 1)
num_new_blocks = end_idx - start_idx
group_num_blocks_ptr = num_blocks_ptr + group_id * num_blocks_stride
if do_overwrite:
dst_start_idx = 0
else:
dst_start_idx = tl.load(group_num_blocks_ptr + req_idx)
dst_end_idx = dst_start_idx + num_new_blocks
tl.store(group_num_blocks_ptr + req_idx, dst_end_idx)
# Destination
block_table_ptr = _load_ptr(block_table_ptrs + group_id, tl.int32)
block_table_stride = tl.load(block_table_strides + group_id)
row_ptr = block_table_ptr + req_idx * block_table_stride
group_new_block_ids_ptr = (new_block_ids_ptr +
group_id * new_block_ids_stride)
for i in tl.range(0, num_new_blocks, BLOCK_SIZE):
offset = i + tl.arange(0, BLOCK_SIZE)
block_ids = tl.load(group_new_block_ids_ptr + start_idx + offset,
mask=offset < num_new_blocks)
tl.store(row_ptr + dst_start_idx + offset,
block_ids,
mask=offset < num_new_blocks)
@triton.jit
def _gather_block_tables_kernel(
batch_idx_to_req_idx, # [batch_size]
src_block_table_ptrs, # [num_kv_cache_groups]
dst_block_table_ptrs, # [num_kv_cache_groups]
block_table_strides, # [num_kv_cache_groups]
num_blocks_ptr, # [num_kv_cache_groups, max_num_reqs]
num_blocks_stride,
BLOCK_SIZE: tl.constexpr,
):
# kv cache group id
group_id = tl.program_id(0)
batch_idx = tl.program_id(1)
req_idx = tl.load(batch_idx_to_req_idx + batch_idx)
group_num_blocks_ptr = num_blocks_ptr + group_id * num_blocks_stride
num_blocks = tl.load(group_num_blocks_ptr + req_idx)
stride = tl.load(block_table_strides + group_id)
src_block_table_ptr = _load_ptr(src_block_table_ptrs + group_id, tl.int32)
src_row_ptr = src_block_table_ptr + req_idx * stride
dst_block_table_ptr = _load_ptr(dst_block_table_ptrs + group_id, tl.int32)
dst_row_ptr = dst_block_table_ptr + batch_idx * stride
for i in tl.range(0, num_blocks, BLOCK_SIZE):
offset = i + tl.arange(0, BLOCK_SIZE)
block_ids = tl.load(src_row_ptr + offset, mask=offset < num_blocks)
tl.store(dst_row_ptr + offset, block_ids, mask=offset < num_blocks)
@triton.jit
def _compute_slot_mappings_kernel(
num_tokens,
max_num_tokens,
cu_num_tokens, # [num_reqs + 1]
pos, # [num_tokens]
block_table_ptrs, # [num_kv_cache_groups]
block_table_strides, # [num_kv_cache_groups]
page_sizes, # [num_kv_cache_groups]
slot_mappings_ptr, # [num_kv_cache_groups, max_num_tokens]
slot_mappings_stride,
PAD_ID: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
# kv cache group id
group_id = tl.program_id(0)
req_idx = tl.program_id(1)
slot_mapping_ptr = slot_mappings_ptr + group_id * slot_mappings_stride
if req_idx == tl.num_programs(1) - 1:
# Pad remaining slots to -1. This is needed for CUDA graphs.
for i in tl.range(num_tokens, max_num_tokens, BLOCK_SIZE):
offset = i + tl.arange(0, BLOCK_SIZE)
tl.store(slot_mapping_ptr + offset,
PAD_ID,
mask=offset < max_num_tokens)
return
block_table_ptr = _load_ptr(block_table_ptrs + group_id, tl.int32)
block_table_stride = tl.load(block_table_strides + group_id)
page_size = tl.load(page_sizes + group_id)
start_idx = tl.load(cu_num_tokens + req_idx)
end_idx = tl.load(cu_num_tokens + req_idx + 1)
for i in tl.range(start_idx, end_idx, BLOCK_SIZE):
offset = i + tl.arange(0, BLOCK_SIZE)
positions = tl.load(pos + offset, mask=offset < end_idx, other=0)
block_indices = positions // page_size
block_numbers = tl.load(block_table_ptr +
req_idx * block_table_stride + block_indices)
slot_ids = block_numbers * page_size + positions % page_size
tl.store(slot_mapping_ptr + offset, slot_ids, mask=offset < end_idx)
@triton.jit
def _load_ptr(ptr_to_ptr, elem_dtype):
ptr = tl.load(ptr_to_ptr)
ptr = tl.cast(ptr, tl.pointer_type(elem_dtype))
return tl.multiple_of(ptr, 16)

View File

@@ -0,0 +1,58 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.distributed import tensor_model_parallel_all_gather
from vllm.v1.outputs import SamplerOutput
def evenly_split(
n: int,
tp_size: int,
tp_rank: int,
) -> tuple[int, int]:
q = n // tp_size
r = n % tp_size
start = q * tp_rank + min(tp_rank, r)
end = start + q + (1 if tp_rank < r else 0)
return start, end
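evenly_split assigns each tensor-parallel rank a contiguous slice of the batch, spreading the remainder over the first ranks so no rank gets more than one extra request. A worked example with made-up numbers (the function is re-declared locally so the snippet runs on its own):

def evenly_split(n: int, tp_size: int, tp_rank: int) -> tuple[int, int]:
    q, r = divmod(n, tp_size)
    start = q * tp_rank + min(tp_rank, r)
    end = start + q + (1 if tp_rank < r else 0)
    return start, end

# 10 requests over 4 ranks -> contiguous, nearly equal ranges.
print([evenly_split(10, 4, rank) for rank in range(4)])
# [(0, 3), (3, 6), (6, 8), (8, 10)]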
def pad_and_all_gather(
x: torch.Tensor,
padded_size: int,
) -> torch.Tensor:
n = x.shape[0]
if n != padded_size:
padded_x = torch.empty(
(padded_size, *x.shape[1:]),
dtype=x.dtype,
device=x.device,
)
padded_x[:n] = x
else:
padded_x = x
x = tensor_model_parallel_all_gather(padded_x)
return x
def all_gather_sampler_output(
sampler_output: SamplerOutput,
num_reqs: int,
tp_size: int,
) -> SamplerOutput:
n = (num_reqs + tp_size - 1) // tp_size
sampler_output.sampled_token_ids = pad_and_all_gather(
sampler_output.sampled_token_ids, n)[:num_reqs]
# TODO(woosuk): 3 small all-gathers, could be merged into one.
logprobs_tensors = sampler_output.logprobs_tensors
if logprobs_tensors is not None:
logprobs_tensors.logprob_token_ids = pad_and_all_gather(
logprobs_tensors.logprob_token_ids, n)[:num_reqs]
logprobs_tensors.logprobs = pad_and_all_gather(
logprobs_tensors.logprobs, n)[:num_reqs]
logprobs_tensors.selected_token_ranks = pad_and_all_gather(
logprobs_tensors.selected_token_ranks, n)[:num_reqs]
return sampler_output

View File

@@ -0,0 +1,247 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections import defaultdict
from dataclasses import dataclass
from typing import Any
import numba
import numba.types as types
import numpy as np
import torch
import triton
import triton.language as tl
from vllm.utils import random_uuid
from vllm.v1.utils import CpuGpuBuffer
class InputBuffers:
def __init__(
self,
max_num_reqs: int,
max_num_tokens: int,
device: torch.device,
pin_memory: bool,
):
self.max_num_reqs = max_num_reqs
self.max_num_tokens = max_num_tokens
self.device = device
self.pin_memory = pin_memory
self.idx_mapping = self._make_buffer(max_num_reqs, dtype=torch.int32)
self.input_ids = self._make_buffer(max_num_tokens, dtype=torch.int32)
self.positions = self._make_buffer(max_num_tokens, dtype=torch.int64)
self.query_start_loc = self._make_buffer(max_num_reqs + 1,
dtype=torch.int32)
self.seq_lens = self._make_buffer(max_num_reqs, dtype=torch.int32)
def _make_buffer(self, *args, dtype: torch.dtype) -> CpuGpuBuffer:
return CpuGpuBuffer(*args,
dtype=dtype,
pin_memory=self.pin_memory,
device=self.device)
@dataclass
class InputBatch:
# batch_idx -> req_id
req_ids: list[str]
num_reqs: int
# batch_idx -> req_state_idx
idx_mapping: torch.Tensor
idx_mapping_np: np.ndarray
# batch_idx -> num_scheduled_tokens
num_scheduled_tokens: np.ndarray
# sum(num_scheduled_tokens)
num_tokens: int
num_tokens_after_padding: int
# [num_reqs]
is_chunked_prefilling: np.ndarray
# [max_num_batched_tokens]
input_ids: torch.Tensor
# [max_num_batched_tokens]
positions: torch.Tensor
# layer_name -> Metadata
attn_metadata: dict[str, Any]
# [num_reqs]
logits_indices: torch.Tensor
@classmethod
def make_dummy(
cls,
num_reqs: int,
num_tokens: int,
device: torch.device,
) -> "InputBatch":
assert 0 < num_reqs <= num_tokens
req_ids = [f"req_{i}_{random_uuid()}" for i in range(num_reqs)]
idx_mapping_np = np.arange(num_reqs, dtype=np.int32)
idx_mapping = torch.tensor(idx_mapping_np, device=device)
num_scheduled_tokens = np.full(num_reqs,
num_tokens // num_reqs,
dtype=np.int32)
num_scheduled_tokens[-1] += num_tokens % num_reqs
is_chunked_prefilling = np.zeros(num_reqs, dtype=np.bool_)
input_ids = torch.zeros(num_tokens, dtype=torch.int32, device=device)
positions = torch.zeros(num_tokens, dtype=torch.int64, device=device)
attn_metadata = defaultdict(lambda: None)
logits_indices = torch.arange(num_reqs,
dtype=torch.int32,
device=device)
return cls(
req_ids=req_ids,
num_reqs=num_reqs,
idx_mapping=idx_mapping,
idx_mapping_np=idx_mapping_np,
num_scheduled_tokens=num_scheduled_tokens,
num_tokens=num_tokens,
num_tokens_after_padding=num_tokens,
is_chunked_prefilling=is_chunked_prefilling,
input_ids=input_ids,
positions=positions,
attn_metadata=attn_metadata,
logits_indices=logits_indices,
)
# NOTE: With the type annotations, this function is pre-compiled
# before the first call.
@numba.jit(
[
types.none(
types.int32[:], # idx_mapping
types.int32[:, :], # token_ids
types.int32[:], # num_computed_tokens
types.int32[:], # num_scheduled_tokens
types.int32[:], # input_ids
types.int64[:], # positions
types.int32[:], # query_start_loc
types.int32[:], # seq_lens
)
],
nopython=True,
cache=True,
)
def _prepare_inputs(
idx_mapping: np.ndarray, # batch_idx -> req_idx
token_ids: np.ndarray, # [N, max_model_len]
num_computed_tokens: np.ndarray, # [N]
num_scheduled_tokens: np.ndarray, # [B]
input_ids: np.ndarray, # [num_input_tokens]
positions: np.ndarray, # [num_input_tokens]
query_start_loc: np.ndarray, # [B + 1]
seq_lens: np.ndarray, # [B]
) -> None:
num_reqs = num_scheduled_tokens.shape[0]
query_start_loc[0] = 0
cu_num_tokens = 0
for i in range(num_reqs):
req_idx = idx_mapping[i]
query_len = num_scheduled_tokens[i]
start = num_computed_tokens[req_idx]
end = start + query_len
seq_lens[i] = end
start_idx = cu_num_tokens
end_idx = start_idx + query_len
input_ids[start_idx:end_idx] = token_ids[req_idx, start:end]
positions[start_idx:end_idx] = np.arange(start, end, dtype=np.int64)
cu_num_tokens = end_idx
query_start_loc[i + 1] = cu_num_tokens
# Pad the inputs for CUDA graphs.
# Note: pad query_start_loc to be non-decreasing, as kernels
# like FlashAttention require that
query_start_loc[num_reqs + 1:].fill(cu_num_tokens)
# Fill unused with 0 for full cuda graph mode.
seq_lens[num_reqs:].fill(0)
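A worked example of the flattening _prepare_inputs performs, in plain NumPy with idx_mapping taken as the identity (toy numbers, no numba): request 0 has 4 computed and 3 scheduled tokens, request 1 has 0 and 2.

import numpy as np

token_ids = np.arange(200, dtype=np.int32).reshape(2, 100)  # toy [N, max_model_len]
num_computed_tokens = np.array([4, 0], dtype=np.int32)
num_scheduled_tokens = np.array([3, 2], dtype=np.int32)

query_start_loc = np.concatenate(([0], np.cumsum(num_scheduled_tokens)))
seq_lens = num_computed_tokens + num_scheduled_tokens
positions = np.concatenate([
    np.arange(c, c + s, dtype=np.int64)
    for c, s in zip(num_computed_tokens, num_scheduled_tokens)
])
input_ids = np.concatenate([
    token_ids[i, c:c + s]
    for i, (c, s) in enumerate(zip(num_computed_tokens, num_scheduled_tokens))
])
print(query_start_loc)  # [0 3 5]
print(seq_lens)         # [7 2]
print(positions)        # [4 5 6 0 1]
print(input_ids)        # [  4   5   6 100 101]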
def prepare_inputs(
idx_mapping: np.ndarray,
prompt_token_ids: np.ndarray,
num_computed_tokens: np.ndarray,
num_scheduled_tokens: np.ndarray,
input_ids: CpuGpuBuffer,
positions: CpuGpuBuffer,
query_start_loc: CpuGpuBuffer,
seq_lens: CpuGpuBuffer,
num_tokens: int,
) -> tuple[np.ndarray, np.ndarray]:
_prepare_inputs(
idx_mapping,
prompt_token_ids,
num_computed_tokens,
num_scheduled_tokens,
input_ids.np,
positions.np,
query_start_loc.np,
seq_lens.np,
)
input_ids.copy_to_gpu(num_tokens)
positions.copy_to_gpu(num_tokens)
# NOTE(woosuk): We should copy the whole query_start_loc and seq_lens
# tensors from CPU to GPU, because they may include paddings needed
# for full CUDA graph mode.
query_start_loc.copy_to_gpu()
seq_lens.copy_to_gpu()
num_reqs = num_scheduled_tokens.shape[0]
max_query_len = int(num_scheduled_tokens.max())
max_seq_len = int(seq_lens.np[:num_reqs].max())
return max_query_len, max_seq_len
@triton.jit
def _combine_last_token_ids_kernel(
input_ids_ptr,
idx_mapping_ptr,
last_token_ids_ptr,
query_start_loc_ptr,
seq_lens_ptr,
num_tokens_ptr,
):
batch_idx = tl.program_id(0)
req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
seq_len = tl.load(seq_lens_ptr + batch_idx)
num_tokens = tl.load(num_tokens_ptr + req_state_idx)
if seq_len < num_tokens:
# Chunked prefilling.
return
last_token_id = tl.load(last_token_ids_ptr + req_state_idx)
if last_token_id == -1:
return
end = tl.load(query_start_loc_ptr + batch_idx + 1)
tl.store(input_ids_ptr + end - 1, last_token_id)
def combine_last_token_ids(
input_ids: torch.Tensor,
idx_mapping: torch.Tensor,
last_token_ids: torch.Tensor,
query_start_loc: torch.Tensor,
seq_lens: torch.Tensor,
num_tokens: torch.Tensor,
) -> torch.Tensor:
num_reqs = seq_lens.shape[0]
_combine_last_token_ids_kernel[(num_reqs, )](
input_ids,
idx_mapping,
last_token_ids,
query_start_loc,
seq_lens,
num_tokens,
)
return input_ids

View File

@@ -0,0 +1,476 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import gc
import time
from copy import deepcopy
from typing import Any, Optional
import numpy as np
import torch
import torch.nn as nn
from vllm.config import VllmConfig
from vllm.distributed import get_tp_group
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model_loader
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
GiB_bytes, is_pin_memory_available)
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput
from vllm.v1.sample.sampler import SamplerOutput
from vllm.v1.worker.gpu.async_utils import AsyncOutput
from vllm.v1.worker.gpu.attn_utils import (get_kv_cache_spec,
init_attn_backend, init_kv_cache)
from vllm.v1.worker.gpu.block_table import BlockTables
from vllm.v1.worker.gpu.dist_utils import (all_gather_sampler_output,
evenly_split)
from vllm.v1.worker.gpu.input_batch import (InputBatch, InputBuffers,
combine_last_token_ids,
prepare_inputs)
from vllm.v1.worker.gpu.sampler import Sampler
from vllm.v1.worker.gpu.states import RequestState, SamplingMetadata
logger = init_logger(__name__)
class GPUModelRunner:
def __init__(
self,
vllm_config: VllmConfig,
device: torch.device,
):
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
self.cache_config = vllm_config.cache_config
self.compilation_config = vllm_config.compilation_config
self.lora_config = vllm_config.lora_config
self.load_config = vllm_config.load_config
self.parallel_config = vllm_config.parallel_config
self.scheduler_config = vllm_config.scheduler_config
self.speculative_config = vllm_config.speculative_config
self.observability_config = vllm_config.observability_config
self.device = device
self.pin_memory = is_pin_memory_available()
self.dtype = self.model_config.dtype
self.kv_cache_dtype = self.dtype
if self.cache_config.cache_dtype != "auto":
# Quantized KV cache.
self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
self.cache_config.cache_dtype]
self.is_pooling_model = False
self.vocab_size = self.model_config.get_vocab_size()
self.max_model_len = self.model_config.max_model_len
self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
self.max_num_reqs = self.scheduler_config.max_num_seqs
self.use_async_scheduling = self.scheduler_config.async_scheduling
assert self.use_async_scheduling
self.output_copy_stream = torch.cuda.Stream()
self.req_states = RequestState(
max_num_reqs=self.max_num_reqs,
max_model_len=self.max_model_len,
max_num_batched_tokens=self.max_num_tokens,
vocab_size=self.vocab_size,
device=self.device,
pin_memory=self.pin_memory,
)
self.input_buffers = InputBuffers(
max_num_reqs=self.max_num_reqs,
max_num_tokens=self.max_num_tokens,
device=self.device,
pin_memory=self.pin_memory,
)
self.sampler = Sampler()
def get_supported_tasks(self) -> tuple[str]:
return ("generate", )
def load_model(self, *args, **kwargs) -> None:
time_before_load = time.perf_counter()
with DeviceMemoryProfiler() as m:
model_loader = get_model_loader(self.vllm_config.load_config)
logger.info("Loading model from scratch...")
self.model = model_loader.load_model(
vllm_config=self.vllm_config,
model_config=self.vllm_config.model_config,
)
time_after_load = time.perf_counter()
self.model_memory_usage = m.consumed_memory
logger.info("Model loading took %.4f GiB and %.6f seconds",
m.consumed_memory / GiB_bytes,
time_after_load - time_before_load)
def get_model(self) -> nn.Module:
return self.model
def get_kv_cache_spec(self):
return get_kv_cache_spec(self.vllm_config, self.kv_cache_dtype)
def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
kv_cache_config = deepcopy(kv_cache_config)
self.kv_cache_config = kv_cache_config
block_sizes = kv_cache_config.block_sizes
self.block_tables = BlockTables(
block_sizes=block_sizes,
max_num_reqs=self.max_num_reqs,
max_num_batched_tokens=self.max_num_tokens,
max_model_len=self.max_model_len,
device=self.device,
pin_memory=self.pin_memory,
)
self.attn_backends, self.attn_metadata_builders = init_attn_backend(
self.kv_cache_config,
self.vllm_config,
self.device,
)
self.kv_caches: list[torch.Tensor] = []
init_kv_cache(
self.kv_caches,
self.compilation_config.static_forward_context,
self.kv_cache_config,
self.attn_backends,
self.device,
)
def _dummy_run(
self,
num_tokens: int,
*args,
input_batch: Optional[InputBatch] = None,
**kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
if input_batch is None:
input_batch = InputBatch.make_dummy(
num_reqs=min(num_tokens, self.max_num_reqs),
num_tokens=num_tokens,
device=self.device,
)
with set_forward_context(
input_batch.attn_metadata,
self.vllm_config,
num_tokens=num_tokens,
):
hidden_states = self.model(
input_ids=input_batch.input_ids,
positions=input_batch.positions,
)
sample_hidden_states = hidden_states[input_batch.logits_indices]
return hidden_states, sample_hidden_states
def _dummy_sampler_run(
self,
hidden_states: torch.Tensor,
) -> None:
num_reqs = hidden_states.shape[0]
sampling_metadata = SamplingMetadata.make_dummy(
num_reqs=num_reqs,
device=self.device,
)
logits = self.model.compute_logits(hidden_states)
self.sampler(logits, sampling_metadata)
def profile_run(self) -> None:
input_batch = InputBatch.make_dummy(
num_reqs=self.max_num_reqs,
num_tokens=self.max_num_tokens,
device=self.device,
)
hidden_states, sample_hidden_states = self._dummy_run(
self.max_num_tokens,
input_batch=input_batch,
)
self._dummy_sampler_run(sample_hidden_states)
torch.cuda.synchronize()
del hidden_states, sample_hidden_states
gc.collect()
def update_states(self, scheduler_output: SchedulerOutput) -> None:
# for req_id in scheduler_output.preempted_req_ids:
# self.req_states.remove_request(req_id)
for req_id in scheduler_output.finished_req_ids:
self.req_states.remove_request(req_id)
# TODO(woosuk): Change SchedulerOutput.
req_indices: list[int] = []
cu_num_new_blocks = tuple(
[0] for _ in range(self.block_tables.num_kv_cache_groups))
new_block_ids = tuple(
[] for _ in range(self.block_tables.num_kv_cache_groups))
overwrite: list[bool] = []
# Add new requests.
for new_req_data in scheduler_output.scheduled_new_reqs:
req_id = new_req_data.req_id
self.req_states.add_request(
req_id=req_id,
prompt_token_ids=new_req_data.prompt_token_ids,
num_computed_tokens=new_req_data.num_computed_tokens,
sampling_params=new_req_data.sampling_params,
)
req_index = self.req_states.req_id_to_index[req_id]
req_indices.append(req_index)
for i, block_ids in enumerate(new_req_data.block_ids):
x = cu_num_new_blocks[i][-1]
cu_num_new_blocks[i].append(x + len(block_ids))
new_block_ids[i].extend(block_ids)
overwrite.append(True)
# Add new blocks for the existing requests.
cached_reqs = scheduler_output.scheduled_cached_reqs
for i, req_id in enumerate(cached_reqs.req_ids):
req_index = self.req_states.req_id_to_index[req_id]
req_new_block_ids = cached_reqs.new_block_ids[i]
if req_new_block_ids is not None:
req_indices.append(req_index)
for group_id, block_ids in enumerate(req_new_block_ids):
x = cu_num_new_blocks[group_id][-1]
cu_num_new_blocks[group_id].append(x + len(block_ids))
new_block_ids[group_id].extend(block_ids)
overwrite.append(False)
if req_indices:
self.block_tables.append_block_ids(
req_indices=req_indices,
cu_num_new_blocks=cu_num_new_blocks,
new_block_ids=new_block_ids,
overwrite=overwrite,
)
def prepare_inputs(self, scheduler_output: SchedulerOutput) -> InputBatch:
num_tokens = scheduler_output.total_num_scheduled_tokens
assert num_tokens > 0
num_reqs = len(scheduler_output.num_scheduled_tokens)
# Decode first, then prefill.
# batch_idx -> req_id
req_ids = sorted(scheduler_output.num_scheduled_tokens,
key=scheduler_output.num_scheduled_tokens.get)
num_scheduled_tokens = np.array(
[scheduler_output.num_scheduled_tokens[i] for i in req_ids],
dtype=np.int32)
# TODO(woosuk): Support CUDA graphs.
num_tokens_after_padding = num_tokens
idx_mapping_list = [
self.req_states.req_id_to_index[req_id] for req_id in req_ids
]
idx_mapping = self.input_buffers.idx_mapping
idx_mapping.np[:num_reqs] = idx_mapping_list
idx_mapping_np = idx_mapping.np[:num_reqs]
idx_mapping = idx_mapping.copy_to_gpu(num_reqs)
# Block tables: num_kv_cache_groups x [num_reqs, max_num_blocks]
block_tables = self.block_tables.gather_block_tables(idx_mapping)
max_query_len, max_seq_len = prepare_inputs(
idx_mapping_np,
self.req_states.prompt_token_ids,
self.req_states.num_computed_tokens,
num_scheduled_tokens,
self.input_buffers.input_ids,
self.input_buffers.positions,
self.input_buffers.query_start_loc,
self.input_buffers.seq_lens,
num_tokens,
)
query_start_loc = self.input_buffers.query_start_loc
query_start_loc_gpu = query_start_loc.gpu[:num_reqs + 1]
query_start_loc_cpu = query_start_loc.cpu[:num_reqs + 1]
seq_lens_gpu = self.input_buffers.seq_lens.gpu[:num_reqs]
seq_lens_cpu = self.input_buffers.seq_lens.cpu[:num_reqs]
seq_lens_np = self.input_buffers.seq_lens.np[:num_reqs]
# Some input token ids are directly read from the last sampled tokens.
combine_last_token_ids(
self.input_buffers.input_ids.gpu,
idx_mapping,
self.req_states.last_sampled_tokens,
query_start_loc_gpu,
seq_lens_gpu,
self.req_states.num_tokens.copy_to_gpu(),
)
# Compute slot mappings: [num_kv_cache_groups, num_tokens]
slot_mappings = self.block_tables.compute_slot_mappings(
query_start_loc_gpu, self.input_buffers.positions.gpu[:num_tokens])
num_computed_tokens_cpu = torch.from_numpy(
self.req_states.num_computed_tokens[idx_mapping_np])
# Whether the request is chunked-prefilling or not.
is_chunked_prefilling = (
seq_lens_np < self.req_states.num_tokens.np[idx_mapping_np])
# Logits indices to sample next token from.
logits_indices = query_start_loc_gpu[1:] - 1
num_logits_indices = logits_indices.size(0)
# Layer name -> attention metadata.
attn_metadata: dict[str, Any] = {}
kv_cache_groups = self.kv_cache_config.kv_cache_groups
for i, kv_cache_spec in enumerate(kv_cache_groups):
block_table = block_tables[i]
slot_mapping = slot_mappings[i]
common_attn_metadata = CommonAttentionMetadata(
query_start_loc=query_start_loc_gpu,
query_start_loc_cpu=query_start_loc_cpu,
seq_lens=seq_lens_gpu,
seq_lens_cpu=seq_lens_cpu,
num_computed_tokens_cpu=num_computed_tokens_cpu,
num_reqs=num_reqs,
num_actual_tokens=num_tokens,
max_query_len=max_query_len,
max_seq_len=max_seq_len,
block_table_tensor=block_table,
slot_mapping=slot_mapping,
logits_indices_padded=None,
num_logits_indices=num_logits_indices,
causal=True,
encoder_seq_lens=None,
)
attn_metadata_builder = self.attn_metadata_builders[i]
metadata = attn_metadata_builder.build(
common_prefix_len=0,
common_attn_metadata=common_attn_metadata,
)
for layer_name in kv_cache_spec.layer_names:
attn_metadata[layer_name] = metadata
input_ids = self.input_buffers.input_ids.gpu[:num_tokens_after_padding]
positions = self.input_buffers.positions.gpu[:num_tokens_after_padding]
return InputBatch(
req_ids=req_ids,
num_reqs=num_reqs,
idx_mapping=idx_mapping,
idx_mapping_np=idx_mapping_np,
num_scheduled_tokens=num_scheduled_tokens,
num_tokens=num_tokens,
num_tokens_after_padding=num_tokens_after_padding,
is_chunked_prefilling=is_chunked_prefilling,
input_ids=input_ids,
positions=positions,
attn_metadata=attn_metadata,
logits_indices=logits_indices,
)
def sample(
self,
hidden_states: torch.Tensor,
input_batch: InputBatch,
) -> SamplerOutput:
sample_hidden_states = hidden_states[input_batch.logits_indices]
logits = self.model.compute_logits(sample_hidden_states)
pos = input_batch.positions[input_batch.logits_indices]
idx_mapping_np = input_batch.idx_mapping_np
num_reqs = logits.shape[0]
# When the batch size is large enough, use DP sampler.
tp_group = get_tp_group()
tp_size = tp_group.world_size
n = (num_reqs + tp_size - 1) // tp_size
use_dp_sampler = tp_size > 1 and n > 32 # TODO(woosuk): Tune.
if use_dp_sampler:
# NOTE(woosuk): Make sure that no rank gets zero requests.
tp_rank = tp_group.rank
start, end = evenly_split(num_reqs, tp_size, tp_rank)
logits = logits[start:end]
pos = pos[start:end]
idx_mapping_np = idx_mapping_np[start:end]
sampling_metadata = self.req_states.make_sampling_metadata(
idx_mapping_np, pos)
sampler_output = self.sampler(
logits=logits,
sampling_metadata=sampling_metadata,
)
needs_prompt_logprobs = np.any(
self.req_states.needs_prompt_logprobs[idx_mapping_np])
assert not needs_prompt_logprobs
if use_dp_sampler:
# All-gather the outputs.
sampler_output = all_gather_sampler_output(
sampler_output,
num_reqs,
tp_size,
)
return sampler_output
def postprocess(
self,
sampler_output: SamplerOutput,
input_batch: InputBatch,
) -> AsyncOutput:
# Store the last sampled token ids.
self.req_states.last_sampled_tokens[input_batch.idx_mapping] = (
sampler_output.sampled_token_ids)
# Get the number of sampled tokens.
# 0 if chunked-prefilling, 1 if not.
is_chunked_prefilling = input_batch.is_chunked_prefilling
num_sampled_tokens = (~is_chunked_prefilling).astype(np.int32)
# Increment the number of tokens.
idx_mapping_np = input_batch.idx_mapping_np
self.req_states.num_tokens.np[idx_mapping_np] += num_sampled_tokens
# Increment the number of computed tokens.
self.req_states.num_computed_tokens[idx_mapping_np] += (
input_batch.num_scheduled_tokens)
model_runner_output = ModelRunnerOutput(
req_ids=input_batch.req_ids,
sampled_token_ids=None,
num_sampled_tokens=num_sampled_tokens,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
kv_connector_output=None,
num_nans_in_logits=None,
)
return AsyncOutput(
model_runner_output=model_runner_output,
sampler_output=sampler_output,
copy_stream=self.output_copy_stream,
)
def execute_model(
self,
scheduler_output: SchedulerOutput,
) -> AsyncOutput:
self.update_states(scheduler_output)
if scheduler_output.total_num_scheduled_tokens == 0:
return EMPTY_MODEL_RUNNER_OUTPUT
input_batch = self.prepare_inputs(scheduler_output)
num_tokens = input_batch.num_tokens_after_padding
with set_forward_context(
input_batch.attn_metadata,
self.vllm_config,
num_tokens=num_tokens,
):
hidden_states = self.model(
input_ids=input_batch.input_ids,
positions=input_batch.positions,
)
sampler_output = self.sample(hidden_states, input_batch)
return self.postprocess(sampler_output, input_batch)

View File

@@ -0,0 +1,323 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import torch.nn as nn
import triton
import triton.language as tl
from vllm.config import LogprobsMode
from vllm.v1.outputs import LogprobsTensors, SamplerOutput
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
from vllm.v1.worker.gpu.states import SamplingMetadata
_SAMPLING_EPS = 1e-5
class Sampler(nn.Module):
def __init__(
self,
logprobs_mode: LogprobsMode = "processed_logprobs",
):
super().__init__()
assert logprobs_mode == "processed_logprobs"
self.logprobs_mode = logprobs_mode
def forward(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> SamplerOutput:
# Divide logits by temperature, in FP32.
logits = apply_temperature(logits, sampling_metadata.temperature)
# Apply top_k and/or top_p.
logits = apply_top_k_top_p(
logits,
sampling_metadata.top_k,
sampling_metadata.top_p,
)
# Compute the probabilities.
probs = torch.softmax(logits, dim=-1, dtype=torch.float32)
# Sample the next token (int64).
sampled = gumbel_sample(
probs,
sampling_metadata.temperature,
sampling_metadata.seeds,
sampling_metadata.pos,
)
logprobs_tensors = None
num_logprobs = sampling_metadata.max_num_logprobs
if num_logprobs is not None:
logprobs_tensors = compute_logprobs(
logits,
num_logprobs,
sampled,
)
# These are GPU tensors.
sampler_output = SamplerOutput(
# The sampled tokens are expanded to 2D tensor with shape
# [num_requests, 1], where each row represents one generated
# token per request.
sampled_token_ids=sampled.view(-1, 1),
logprobs_tensors=logprobs_tensors,
)
return sampler_output
@triton.jit
def _apply_temp_kernel(
logits, # bf16[batch_size, vocab_size]
logits_stride,
output, # fp32[batch_size, vocab_size]
output_stride,
temperature,
vocab_size,
BLOCK_SIZE: tl.constexpr,
EPSILON: tl.constexpr,
):
batch_idx = tl.program_id(0)
block_idx = tl.program_id(1)
temp = tl.load(temperature + batch_idx)
if temp < EPSILON:
# Greedy sampling. Don't apply temperature.
# NOTE(woosuk): In this case, we assume that its logprobs are not used.
temp = 1.0
offset = tl.arange(0, BLOCK_SIZE)
block = block_idx * BLOCK_SIZE + offset
# Load the logits.
x = tl.load(logits + batch_idx * logits_stride + block,
mask=block < vocab_size)
x = x.to(tl.float32)
x = x / temp
tl.store(output + batch_idx * output_stride + block,
x,
mask=block < vocab_size)
def apply_temperature(
logits: torch.Tensor,
temperature: torch.Tensor,
) -> torch.Tensor:
batch_size, vocab_size = logits.shape
output = torch.empty_like(logits, dtype=torch.float32)
BLOCK_SIZE = 8192
_apply_temp_kernel[(batch_size, triton.cdiv(vocab_size, BLOCK_SIZE))](
logits,
logits.stride(0),
output,
output.stride(0),
temperature,
vocab_size,
BLOCK_SIZE=BLOCK_SIZE,
EPSILON=_SAMPLING_EPS,
)
return output
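An eager-PyTorch equivalent of apply_temperature, mainly to make the greedy branch explicit: rows whose temperature is below the epsilon are divided by 1.0, which is harmless because greedy sampling only looks at the argmax (illustrative reference, not part of the diff):

import torch

def apply_temperature_reference(logits: torch.Tensor,
                                temperature: torch.Tensor,
                                eps: float = 1e-5) -> torch.Tensor:
    # Cast to FP32 and divide by temperature; greedy rows keep temp = 1.0.
    temp = torch.where(temperature < eps,
                       torch.ones_like(temperature), temperature)
    return logits.float() / temp.unsqueeze(1)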
@triton.jit
def _apply_gumbel_kernel(
probs_ptr,
probs_stride,
seeds_ptr,
pos_ptr,
temp_ptr,
vocab_size,
BLOCK_SIZE: tl.constexpr,
EPSILON: tl.constexpr,
):
req_idx = tl.program_id(0)
temp = tl.load(temp_ptr + req_idx)
if temp < EPSILON:
# Greedy sampling. Don't apply gumbel noise.
return
seed = tl.load(seeds_ptr + req_idx).to(tl.uint64)
pos = tl.load(pos_ptr + req_idx).to(tl.uint64)
gumbel_seed = seed ^ (pos * 0x9E3779B97F4A7C15)
block_id = tl.program_id(1)
r_offset = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
q = tl.rand(gumbel_seed, r_offset)
# NOTE(woosuk): This logic makes sure q is not 0.
RMAX = 0.9999999403953552
RMAX_LOG = -5.960464477539063e-08
q = tl.where(q >= RMAX, RMAX_LOG, tl.math.log(q))
q = -1.0 * q
p = tl.load(probs_ptr + req_idx * probs_stride + r_offset,
mask=r_offset < vocab_size)
p = p / q
tl.store(probs_ptr + req_idx * probs_stride + r_offset,
p,
mask=r_offset < vocab_size)
def gumbel_sample(
# fp32[num_reqs, vocab_size]
probs: torch.Tensor,
# fp32[num_reqs]
temperature: torch.Tensor,
# int64[num_reqs]
seeds: torch.Tensor,
# int64[num_reqs]
pos: torch.Tensor,
) -> torch.Tensor:
num_reqs = probs.shape[0]
vocab_size = probs.shape[1]
# Update the probs in-place.
BLOCK_SIZE = 8192
_apply_gumbel_kernel[(num_reqs, triton.cdiv(vocab_size, BLOCK_SIZE))](
probs,
probs.stride(0),
seeds,
pos,
temperature,
vocab_size,
BLOCK_SIZE,
EPSILON=_SAMPLING_EPS,
)
# Sample the next token.
return probs.argmax(dim=-1).view(-1)
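gumbel_sample is the Gumbel-max trick in its exponential-race form: dividing the probabilities by i.i.d. Exp(1) noise (q = -log u) and taking the argmax yields a sample from Categorical(probs). A plain-PyTorch reference that skips the kernel's counter-based, per-(seed, position) RNG:

import torch

def gumbel_sample_reference(probs: torch.Tensor,
                            generator: torch.Generator | None = None) -> torch.Tensor:
    # argmax_i probs_i / E_i with E_i ~ Exp(1) samples index i with
    # probability probs_i / sum_j probs_j.
    u = torch.rand(probs.shape, device=probs.device, generator=generator)
    u = u.clamp_min(torch.finfo(u.dtype).tiny)  # avoid log(0), like the RMAX guard above
    exp_noise = -torch.log(u)
    return torch.argmax(probs / exp_noise, dim=-1)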
@triton.jit
def _topk_log_softmax_kernel(
output_ptr,
logits_ptr,
logits_stride,
topk_ids_ptr,
topk,
vocab_size,
BLOCK_SIZE: tl.constexpr,
PADDED_TOPK: tl.constexpr,
):
req_idx = tl.program_id(0)
row_ptr = logits_ptr + req_idx * logits_stride
max_val = float("-inf")
for i in range(0, vocab_size, BLOCK_SIZE):
block = i + tl.arange(0, BLOCK_SIZE)
l = tl.load(row_ptr + block,
mask=block < vocab_size,
other=float("-inf"))
max_val = tl.max(tl.maximum(l, max_val))
se = 0.0
for i in range(0, vocab_size, BLOCK_SIZE):
block = i + tl.arange(0, BLOCK_SIZE)
l = tl.load(row_ptr + block, mask=block < vocab_size, other=0.0)
e = tl.exp(l - max_val)
e = tl.where(block < vocab_size, e, 0.0)
se += tl.sum(e)
lse = tl.log(se)
k_offset = tl.arange(0, PADDED_TOPK)
k_mask = k_offset < topk
topk_ids = tl.load(topk_ids_ptr + req_idx * topk + k_offset, mask=k_mask)
l = tl.load(row_ptr + topk_ids, mask=k_mask)
o = l - max_val - lse
tl.store(output_ptr + req_idx * topk + k_offset, o, mask=k_mask)
def compute_topk_logprobs(
logits: torch.Tensor,
topk_ids: torch.Tensor,
) -> torch.Tensor:
batch_size, vocab_size = logits.shape
topk = topk_ids.shape[1]
output = torch.empty(
batch_size,
topk,
dtype=torch.float32,
device=logits.device,
)
_topk_log_softmax_kernel[(batch_size, )](
output,
logits,
logits.stride(0),
topk_ids,
topk,
vocab_size,
BLOCK_SIZE=1024,
PADDED_TOPK=triton.next_power_of_2(topk),
)
return output
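# An eager reference for compute_topk_logprobs, useful only on small shapes:
# it materializes the full [batch_size, vocab_size] fp32 log-softmax, which is
# exactly the allocation the fused kernel above avoids.
def compute_topk_logprobs_ref(
    logits: torch.Tensor,    # [batch_size, vocab_size]
    topk_ids: torch.Tensor,  # int64[batch_size, topk]
) -> torch.Tensor:
    logprobs = torch.log_softmax(logits.to(torch.float32), dim=-1)
    return logprobs.gather(dim=1, index=topk_ids)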
@triton.jit
def _ranks_kernel(
output_ptr,
logits_ptr,
logits_stride,
token_ids_ptr,
vocab_size,
BLOCK_SIZE: tl.constexpr,
):
req_idx = tl.program_id(0)
row_ptr = logits_ptr + req_idx * logits_stride
token_id = tl.load(token_ids_ptr + req_idx)
x = tl.load(row_ptr + token_id)
n = 0
for i in range(0, vocab_size, BLOCK_SIZE):
block = i + tl.arange(0, BLOCK_SIZE)
l = tl.load(row_ptr + block,
mask=block < vocab_size,
other=float("-inf"))
n += tl.sum((l > x).to(tl.int32))
tl.store(output_ptr + req_idx, n)
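# What _ranks_kernel computes, as a hedged eager sketch: for each request, the
# number of vocab entries whose logit is strictly greater than the sampled
# token's logit (so the top-1 token yields 0, and ties are not broken).
def token_ranks_ref(
    logits: torch.Tensor,     # [batch_size, vocab_size]
    token_ids: torch.Tensor,  # int64[batch_size]
) -> torch.Tensor:
    chosen = logits.gather(1, token_ids.view(-1, 1))  # [batch_size, 1]
    return (logits > chosen).sum(dim=-1)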
def compute_logprobs(
logits: torch.Tensor,
num_logprobs: int,
sampled_token_ids: torch.Tensor,
) -> LogprobsTensors:
assert num_logprobs >= 0
batch_size, vocab_size = logits.shape
if num_logprobs == 0:
logprob_token_ids = sampled_token_ids.unsqueeze(-1)
else:
topk_indices = torch.topk(logits, num_logprobs, dim=-1).indices
logprob_token_ids = torch.cat(
(sampled_token_ids.unsqueeze(-1), topk_indices), dim=1)
# NOTE(woosuk): Here, to save GPU memory, we do not materialize the full
# logprobs tensor. Instead, we only compute and return the logprobs of
# the topk + 1 tokens.
logprobs = compute_topk_logprobs(
logits,
logprob_token_ids,
)
token_ranks = torch.empty(
batch_size,
dtype=torch.int64,
device=logits.device,
)
_ranks_kernel[(batch_size, )](
token_ranks,
logits,
logits.stride(0),
sampled_token_ids,
vocab_size,
BLOCK_SIZE=8192,
)
return LogprobsTensors(
logprob_token_ids=logprob_token_ids,
logprobs=logprobs,
selected_token_ranks=token_ranks,
)
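# A hedged end-to-end sketch of how the pieces above compose, on CPU random
# data and using the eager *_ref sketches defined alongside the kernels. The
# softmax between apply_temperature and gumbel_sample is assumed to happen in
# the sampler code earlier in this file; shapes and values are illustrative.
def _sampling_path_sketch() -> None:
    batch_size, vocab_size, k = 4, 128, 5
    logits = torch.randn(batch_size, vocab_size)
    temperature = torch.tensor([0.0, 0.7, 1.0, 1.3])  # row 0 is greedy

    scaled = apply_temperature_ref(logits, temperature)        # fp32[B, V]
    probs = torch.softmax(scaled, dim=-1)
    sampled = gumbel_sample_ref(probs, temperature)            # int64[B]

    topk_ids = torch.topk(logits, k, dim=-1).indices
    ids = torch.cat((sampled.view(-1, 1), topk_ids), dim=1)    # [B, k + 1]
    logprobs = compute_topk_logprobs_ref(logits, ids)          # [B, k + 1]
    ranks = token_ranks_ref(logits, sampled)                   # int64[B]
    assert sampled[0] == logits[0].argmax()  # the greedy row picks the argmax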

View File

@@ -0,0 +1,229 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import Optional
import numpy as np
import torch
from vllm.sampling_params import SamplingParams
from vllm.v1.utils import CpuGpuBuffer
_NP_INT64_MIN = np.iinfo(np.int64).min
_NP_INT64_MAX = np.iinfo(np.int64).max
@dataclass
class SamplingMetadata:
temperature: torch.Tensor
top_p: torch.Tensor | None
top_k: torch.Tensor | None
seeds: torch.Tensor
pos: torch.Tensor
# None means no logprobs, 0 means sampled token logprobs only
max_num_logprobs: int | None
@classmethod
def make_dummy(
cls,
num_reqs: int,
device: torch.device,
) -> "SamplingMetadata":
assert num_reqs > 0
temperature = torch.zeros(num_reqs, dtype=torch.float32, device=device)
temperature[0] = 0.5
top_p = torch.ones(num_reqs, dtype=torch.float32, device=device)
top_p[0] = 0.99
top_k = torch.ones(num_reqs, dtype=torch.int32, device=device)
seeds = torch.zeros(num_reqs, dtype=torch.int64, device=device)
pos = torch.zeros(num_reqs, dtype=torch.int64, device=device)
max_num_logprobs = 20
return cls(
temperature=temperature,
top_p=top_p,
top_k=top_k,
seeds=seeds,
pos=pos,
max_num_logprobs=max_num_logprobs,
)
class RequestState:
def __init__(
self,
max_num_reqs: int,
max_model_len: int,
max_num_batched_tokens: int,
vocab_size: int,
device: torch.device,
pin_memory: bool,
):
self.max_num_reqs = max_num_reqs
self.max_model_len = max_model_len
self.max_num_batched_tokens = max_num_batched_tokens
self.vocab_size = vocab_size
self.device = device
self.pin_memory = pin_memory
self.req_id_to_index: dict[str, int] = {}
self.index_to_req_id: dict[int, str] = {}
self.free_indices = list(range(max_num_reqs))
# NOTE(woosuk): Strictly speaking, this may contain prompt + some output
# tokens when the request is resumed after preemption.
self.prompt_token_ids = np.zeros(
(self.max_num_reqs, self.max_model_len),
dtype=np.int32,
)
self.num_tokens = self._make_buffer(self.max_num_reqs,
dtype=torch.int32)
self.num_computed_tokens = np.zeros(self.max_num_reqs, dtype=np.int32)
# Last sampled tokens.
self.last_sampled_tokens = torch.zeros(
self.max_num_reqs,
1,
dtype=torch.int64,
device=device,
)
# Sampling parameters.
self.temperature = self._make_param(self.max_num_reqs, torch.float32)
self.top_p = self._make_param(self.max_num_reqs, torch.float32)
self.top_k = self._make_param(self.max_num_reqs, torch.int32)
self.seeds = self._make_param(self.max_num_reqs, torch.int64)
self.num_logprobs = np.empty(self.max_num_reqs, dtype=np.int32)
# -1 means no logprobs are requested.
self.num_logprobs.fill(-1)
self.needs_prompt_logprobs = np.zeros(self.max_num_reqs, dtype=bool)
def _make_param(self, size: int, dtype: torch.dtype) -> "Param":
return Param(size,
dtype=dtype,
device=self.device,
pin_memory=self.pin_memory)
def _make_buffer(self, size: int, dtype: torch.dtype) -> CpuGpuBuffer:
return CpuGpuBuffer(size,
dtype=dtype,
device=self.device,
pin_memory=self.pin_memory)
@property
def num_reqs(self) -> int:
return len(self.req_id_to_index)
def add_request(
self,
req_id: str,
prompt_token_ids: list[int],
num_computed_tokens: int,
sampling_params: SamplingParams,
) -> None:
assert len(self.free_indices) > 0
req_idx = self.free_indices.pop()
self.req_id_to_index[req_id] = req_idx
self.index_to_req_id[req_idx] = req_id
# NOTE(woosuk): Strictly speaking, "prompt_len" here may include
# output tokens, if the request is resumed from preemption.
prompt_len = len(prompt_token_ids)
self.prompt_token_ids[req_idx, :prompt_len] = prompt_token_ids
self.num_tokens.np[req_idx] = prompt_len
self.num_computed_tokens[req_idx] = num_computed_tokens
# TODO(woosuk): Optimize.
self.last_sampled_tokens[req_idx].fill_(-1)
self.temperature.np[req_idx] = sampling_params.temperature
self.top_p.np[req_idx] = sampling_params.top_p
if 0 < sampling_params.top_k < self.vocab_size:
top_k = sampling_params.top_k
else:
top_k = self.vocab_size
self.top_k.np[req_idx] = top_k
if sampling_params.seed is not None:
seed = sampling_params.seed
else:
seed = np.random.randint(_NP_INT64_MIN, _NP_INT64_MAX)
self.seeds.np[req_idx] = seed
if sampling_params.logprobs is not None:
num_logprobs = sampling_params.logprobs
else:
num_logprobs = -1
self.num_logprobs[req_idx] = num_logprobs
# For now, only support prompt logprobs for the prompt tokens.
needs_prompt_logprobs = sampling_params.prompt_logprobs is not None
self.needs_prompt_logprobs[req_idx] = needs_prompt_logprobs
def remove_request(self, req_id: str) -> None:
req_idx = self.req_id_to_index.pop(req_id, None)
if req_idx is None:
# Request not found.
return
self.index_to_req_id.pop(req_idx, None)
self.free_indices.append(req_idx)
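# The slot-recycling scheme above, reduced to its core as a hedged sketch:
# a fixed pool of row indices is handed out on add and returned on remove, so
# each request's state lives at a stable row for its lifetime and freed rows
# get reused. Names here are illustrative, not part of the real class.
def _slot_recycling_sketch() -> None:
    free_indices = list(range(4))
    req_id_to_index: dict[str, int] = {}

    def add(req_id: str) -> int:
        idx = free_indices.pop()
        req_id_to_index[req_id] = idx
        return idx

    def remove(req_id: str) -> None:
        idx = req_id_to_index.pop(req_id, None)
        if idx is not None:
            free_indices.append(idx)

    first = add("req-a")          # takes a row from the pool
    remove("req-a")               # returns it
    assert add("req-b") == first  # the freed row is reused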
def make_sampling_metadata(
self,
idx_mapping: np.ndarray,
pos: torch.Tensor,
) -> SamplingMetadata:
temperature = self.temperature.np[idx_mapping]
temperature = self.temperature.copy_np_to_gpu(temperature)
top_p = self.top_p.np[idx_mapping]
no_top_p = np.all(top_p == 1.0)
top_p = self.top_p.copy_np_to_gpu(top_p) if not no_top_p else None
top_k = self.top_k.np[idx_mapping]
no_top_k = np.all(top_k == self.vocab_size)
top_k = self.top_k.copy_np_to_gpu(top_k) if not no_top_k else None
seeds = self.seeds.np[idx_mapping]
seeds = self.seeds.copy_np_to_gpu(seeds)
num_logprobs = self.num_logprobs[idx_mapping]
max_num_logprobs = int(np.max(num_logprobs))
if max_num_logprobs == -1:
max_num_logprobs = None
return SamplingMetadata(
temperature=temperature,
top_p=top_p,
top_k=top_k,
seeds=seeds,
pos=pos,
max_num_logprobs=max_num_logprobs,
)
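# The None-collapse above in isolation, as a small hedged check: when every
# request in the batch keeps the default top_p / top_k, the metadata carries
# None so the sampler can skip that step for the whole batch. Values below
# are illustrative only.
def _none_collapse_sketch(vocab_size: int = 32000) -> None:
    top_p = np.array([1.0, 1.0, 1.0], dtype=np.float32)
    top_k = np.array([vocab_size, 40, vocab_size], dtype=np.int32)
    assert np.all(top_p == 1.0)              # -> top_p becomes None
    assert not np.all(top_k == vocab_size)   # -> top_k is copied to the GPU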
class Param:
def __init__(
self,
size: int,
dtype: torch.dtype,
device: torch.device,
pin_memory: bool,
):
self.buffer = CpuGpuBuffer(
size,
dtype=dtype,
device=device,
pin_memory=pin_memory,
)
self.np = np.zeros_like(self.buffer.np)
def copy_np_to_gpu(self, x: np.ndarray) -> torch.Tensor:
n = x.shape[0]
self.buffer.np[:n] = x
return self.buffer.copy_to_gpu(n)
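# The staging pattern Param builds on, sketched with a minimal stand-in for
# CpuGpuBuffer. This stand-in is an assumption for illustration only: the real
# vllm.v1.utils.CpuGpuBuffer may differ, but from its use above it exposes a
# numpy view (`.np`) of a pinned CPU tensor and a copy_to_gpu(n) that moves
# the first n elements to the device and returns that slice.
class _CpuGpuBufferSketch:

    def __init__(self, size: int, dtype: torch.dtype, device: torch.device,
                 pin_memory: bool):
        self.cpu = torch.zeros(size, dtype=dtype, pin_memory=pin_memory)
        self.np = self.cpu.numpy()
        self.gpu = torch.zeros(size, dtype=dtype, device=device)

    def copy_to_gpu(self, n: int) -> torch.Tensor:
        # Stage on the (pinned) CPU buffer, then issue one async H2D copy.
        self.gpu[:n].copy_(self.cpu[:n], non_blocking=True)
        return self.gpu[:n]

# Param keeps a second, plain numpy array (`self.np`) sized like the buffer,
# so per-step gathers such as `self.temperature.np[idx_mapping]` read request
# state without touching the pinned staging area until copy_np_to_gpu().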

View File

@@ -31,7 +31,8 @@ from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
DraftTokenIds, ModelRunnerOutput)
from vllm.v1.utils import report_usage_stats
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
# from vllm.v1.worker.gpu_model_runner import GPUModelRunner
from vllm.v1.worker.gpu.model_runner import GPUModelRunner
from vllm.v1.worker.utils import is_residual_scattered_for_sp
from vllm.v1.worker.worker_base import WorkerBase
@@ -341,7 +342,9 @@ class Worker(WorkerBase):
self.model_runner._dummy_run(size,
skip_eplb=True,
remove_lora=False)
self.model_runner.maybe_remove_all_loras(self.model_runner.lora_config)
if self.model_runner.lora_config is not None:
self.model_runner.maybe_remove_all_loras(
self.model_runner.lora_config)
# Warmup and tune the kernels used during model execution before
# cuda graph capture.
@@ -436,6 +439,9 @@ class Worker(WorkerBase):
self,
scheduler_output: "SchedulerOutput",
) -> Optional[Union[ModelRunnerOutput, AsyncModelRunnerOutput]]:
if len(get_pp_group().ranks) == 1:
return self.model_runner.execute_model(scheduler_output)
intermediate_tensors = None
forward_pass = scheduler_output.total_num_scheduled_tokens > 0
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
@@ -454,8 +460,6 @@ class Worker(WorkerBase):
output = self.model_runner.execute_model(scheduler_output,
intermediate_tensors)
if isinstance(output, (ModelRunnerOutput, AsyncModelRunnerOutput)):
return output
assert isinstance(output, IntermediateTensors)
parallel_config = self.vllm_config.parallel_config
@@ -692,8 +696,9 @@ class Worker(WorkerBase):
tensorizer_config=tensorizer_config, )
def shutdown(self) -> None:
if runner := getattr(self, "model_runner", None):
runner.ensure_kv_transfer_shutdown()
# if runner := getattr(self, "model_runner", None):
# runner.ensure_kv_transfer_shutdown()
pass
def init_worker_distributed_environment(