mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 23:03:52 +08:00
Compare commits
110 Commits
v0.11.0rc6
...
woosuk/mod
Author | SHA1 | Date | |
---|---|---|---|
866eef50ca | |||
ad2cf805ad | |||
704def253c | |||
42f99150c1 | |||
17c2c106b1 | |||
72f0a71939 | |||
fe5472dc03 | |||
bc73f674bb | |||
631b5b47c1 | |||
42ffdd9179 | |||
8aee6e97e6 | |||
913b8e9569 | |||
158a46888e | |||
98ef239486 | |||
a66aa37f40 | |||
6f038fc4fb | |||
010e39ec7d | |||
396bbe67d3 | |||
c7f3e84b34 | |||
a8e7071924 | |||
4be2c66e37 | |||
d30c0d50a6 | |||
9c75d896a8 | |||
37478c18cf | |||
33672774f5 | |||
0d3de9e082 | |||
b405d78c07 | |||
8af87986aa | |||
af65838d1f | |||
52ca2f517a | |||
8deedfa42b | |||
b9c74487d2 | |||
31619ff412 | |||
d2be62378b | |||
86dade710d | |||
efda08481b | |||
82da219ff9 | |||
323a05b3c5 | |||
a98eff0762 | |||
67d8c0c21b | |||
2bb2cb13f4 | |||
e171e5bb67 | |||
8407fa02ed | |||
82e591f7eb | |||
330058f9b8 | |||
aabfaa08cf | |||
bc6463ac97 | |||
a4962833f9 | |||
3f50030cc8 | |||
cbdb47dc01 | |||
92f337faeb | |||
9050087250 | |||
c1d83f2bae | |||
91510260b2 | |||
c320a33c59 | |||
83d11373a4 | |||
dfc84b11a9 | |||
9f2becd3e6 | |||
e107680d8a | |||
f1981db101 | |||
69b17891a3 | |||
67852c1036 | |||
8b3c13c485 | |||
9a6fcca030 | |||
633f9f006d | |||
eb3742c72a | |||
e47bb9970b | |||
5c133fc860 | |||
caf963f2e9 | |||
9314a83b56 | |||
7a50a54390 | |||
787e59629c | |||
5f95309a6d | |||
286eeb91e8 | |||
6283995a6c | |||
0c56069c7e | |||
8e6cb9aa4a | |||
ead95fe5dc | |||
23eae07ea5 | |||
b16e2d9602 | |||
4c2a337e67 | |||
cc340e26af | |||
01bf16ede4 | |||
af7b6c5dd4 | |||
62d23b3006 | |||
ba1a58f51b | |||
22771e5d83 | |||
c11d1e6781 | |||
e696f78e05 | |||
efcb786d52 | |||
9ee9d0e274 | |||
405578121c | |||
19c0dfc469 | |||
e451045a66 | |||
efba25e21a | |||
b21393cd98 | |||
d6d719fb24 | |||
e570b0a4de | |||
a851aaa0fc | |||
b1d52734f7 | |||
65f93694be | |||
7b4b72e551 | |||
da9cd26c78 | |||
a1e3745150 | |||
48bca9a109 | |||
64c8cced18 | |||
79e5eb3643 | |||
c472982746 | |||
699bd7928e | |||
33a3a26ca5 |
@ -24,6 +24,8 @@ logger = init_logger(__name__)
|
||||
|
||||
|
||||
def kernel_warmup(worker: "Worker"):
|
||||
return
|
||||
|
||||
# Deep GEMM warmup
|
||||
do_deep_gemm_warmup = (envs.VLLM_USE_DEEP_GEMM
|
||||
and is_deep_gemm_supported()
|
||||
|
@ -864,6 +864,7 @@ class Scheduler(SchedulerInterface):
|
||||
model_runner_output: ModelRunnerOutput,
|
||||
) -> dict[int, EngineCoreOutputs]:
|
||||
sampled_token_ids = model_runner_output.sampled_token_ids
|
||||
num_sampled_tokens = model_runner_output.num_sampled_tokens
|
||||
logprobs = model_runner_output.logprobs
|
||||
prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict
|
||||
num_scheduled_tokens = scheduler_output.num_scheduled_tokens
|
||||
@ -881,7 +882,8 @@ class Scheduler(SchedulerInterface):
|
||||
# to avoid expensive operations inside the loop.
|
||||
stopped_running_reqs: set[Request] = set()
|
||||
stopped_preempted_reqs: set[Request] = set()
|
||||
for req_id, num_tokens_scheduled in num_scheduled_tokens.items():
|
||||
for req_index, req_id in enumerate(model_runner_output.req_ids):
|
||||
num_tokens_scheduled = num_scheduled_tokens[req_id]
|
||||
assert num_tokens_scheduled > 0
|
||||
request = self.requests.get(req_id)
|
||||
if request is None:
|
||||
@ -890,9 +892,13 @@ class Scheduler(SchedulerInterface):
|
||||
# in pipeline parallelism).
|
||||
continue
|
||||
|
||||
req_index = model_runner_output.req_id_to_index[req_id]
|
||||
generated_token_ids = sampled_token_ids[
|
||||
req_index] if sampled_token_ids else []
|
||||
generated_token_ids = []
|
||||
if sampled_token_ids is not None:
|
||||
assert num_sampled_tokens is not None
|
||||
n = num_sampled_tokens[req_index]
|
||||
if n > 0:
|
||||
generated_token_ids = sampled_token_ids[req_index, :n]
|
||||
generated_token_ids = generated_token_ids.tolist()
|
||||
|
||||
scheduled_spec_token_ids = (
|
||||
scheduler_output.scheduled_spec_decode_tokens.get(req_id))
|
||||
|
@ -3,6 +3,7 @@
|
||||
|
||||
import copy
|
||||
from dataclasses import dataclass, fields
|
||||
from functools import cached_property
|
||||
from math import prod
|
||||
from typing import Optional
|
||||
|
||||
@ -342,3 +343,10 @@ class KVCacheConfig:
|
||||
see `_get_kv_cache_config_uniform_page_size` for more details.
|
||||
"""
|
||||
kv_cache_groups: list[KVCacheGroupSpec]
|
||||
|
||||
@cached_property
|
||||
def block_sizes(self) -> list[int]:
|
||||
return [
|
||||
kv_cache_group.kv_cache_spec.block_size
|
||||
for kv_cache_group in self.kv_cache_groups
|
||||
]
|
||||
|
@ -5,6 +5,7 @@ from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, NamedTuple, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -15,11 +16,11 @@ if TYPE_CHECKING:
|
||||
class LogprobsLists(NamedTuple):
|
||||
|
||||
# [num_reqs, max_num_logprobs + 1]
|
||||
logprob_token_ids: list[list[int]]
|
||||
logprob_token_ids: np.ndarray
|
||||
# [num_reqs, max_num_logprobs + 1]
|
||||
logprobs: list[list[float]]
|
||||
logprobs: np.ndarray
|
||||
# [num_reqs]
|
||||
sampled_token_ranks: list[int]
|
||||
sampled_token_ranks: np.ndarray
|
||||
|
||||
def slice(self, start: int, end: int):
|
||||
return LogprobsLists(
|
||||
@ -40,9 +41,9 @@ class LogprobsTensors(NamedTuple):
|
||||
|
||||
def tolists(self):
|
||||
return LogprobsLists(
|
||||
self.logprob_token_ids.tolist(),
|
||||
self.logprobs.tolist(),
|
||||
self.selected_token_ranks.tolist(),
|
||||
self.logprob_token_ids.cpu().numpy(),
|
||||
self.logprobs.cpu().numpy(),
|
||||
self.selected_token_ranks.cpu().numpy(),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@ -89,20 +90,18 @@ class KVConnectorOutput:
|
||||
|
||||
|
||||
# ModelRunnerOutput is serialized and sent to the scheduler process.
|
||||
# This is expensive for torch.Tensor so prefer to use list instead.
|
||||
@dataclass
|
||||
class ModelRunnerOutput:
|
||||
|
||||
# [num_reqs]
|
||||
req_ids: list[str]
|
||||
# req_id -> index
|
||||
req_id_to_index: dict[str, int]
|
||||
|
||||
# num_reqs x num_generated_tokens
|
||||
# num_generated_tokens is the number of tokens
|
||||
# generated in the current step. It can be different for
|
||||
# each request due to speculative/jump decoding.
|
||||
sampled_token_ids: list[list[int]]
|
||||
sampled_token_ids: Optional[np.ndarray]
|
||||
num_sampled_tokens: Optional[np.ndarray]
|
||||
|
||||
# [num_reqs, max_num_logprobs + 1]
|
||||
# [num_reqs, max_num_logprobs + 1]
|
||||
@ -148,8 +147,8 @@ class DraftTokenIds:
|
||||
|
||||
|
||||
EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(req_ids=[],
|
||||
req_id_to_index={},
|
||||
sampled_token_ids=[],
|
||||
sampled_token_ids=None,
|
||||
num_sampled_tokens=None,
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[],
|
||||
|
0
vllm/v1/worker/gpu/__init__.py
Normal file
0
vllm/v1/worker/gpu/__init__.py
Normal file
48
vllm/v1/worker/gpu/async_utils.py
Normal file
48
vllm/v1/worker/gpu/async_utils.py
Normal file
@ -0,0 +1,48 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import torch
|
||||
|
||||
from vllm.v1.outputs import (AsyncModelRunnerOutput, LogprobsTensors,
|
||||
ModelRunnerOutput, SamplerOutput)
|
||||
|
||||
|
||||
class AsyncOutput(AsyncModelRunnerOutput):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_runner_output: ModelRunnerOutput,
|
||||
sampler_output: SamplerOutput,
|
||||
copy_stream: torch.cuda.Stream,
|
||||
):
|
||||
self.model_runner_output = model_runner_output
|
||||
self.sampler_output = sampler_output
|
||||
self.copy_stream = copy_stream
|
||||
self.copy_event = torch.cuda.Event()
|
||||
|
||||
default_stream = torch.cuda.current_stream()
|
||||
with torch.cuda.stream(self.copy_stream):
|
||||
self.copy_stream.wait_stream(default_stream)
|
||||
|
||||
self.sampled_token_ids = sampler_output.sampled_token_ids.to(
|
||||
"cpu", non_blocking=True)
|
||||
x = sampler_output.logprobs_tensors
|
||||
if x is not None:
|
||||
self.logprobs_tensors = LogprobsTensors(
|
||||
logprob_token_ids=x.logprob_token_ids.to(
|
||||
"cpu", non_blocking=True),
|
||||
logprobs=x.logprobs.to("cpu", non_blocking=True),
|
||||
selected_token_ranks=x.selected_token_ranks.to(
|
||||
"cpu", non_blocking=True),
|
||||
)
|
||||
else:
|
||||
self.logprobs_tensors = None
|
||||
self.copy_event.record()
|
||||
|
||||
def get_output(self) -> ModelRunnerOutput:
|
||||
self.copy_event.synchronize()
|
||||
self.model_runner_output.sampled_token_ids = (
|
||||
self.sampled_token_ids.numpy())
|
||||
if self.logprobs_tensors is not None:
|
||||
self.model_runner_output.logprobs = (
|
||||
self.logprobs_tensors.tolists())
|
||||
return self.model_runner_output
|
139
vllm/v1/worker/gpu/attn_utils.py
Normal file
139
vllm/v1/worker/gpu/attn_utils.py
Normal file
@ -0,0 +1,139 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionBackend, AttentionType
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.config import VllmConfig, get_layers_from_vllm_config
|
||||
from vllm.v1.attention.backends.utils import AttentionMetadataBuilder
|
||||
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
|
||||
KVCacheSpec, SlidingWindowSpec)
|
||||
from vllm.v1.worker.utils import bind_kv_cache
|
||||
|
||||
|
||||
def get_kv_cache_spec(
|
||||
vllm_config: VllmConfig,
|
||||
kv_cache_dtype: torch.dtype,
|
||||
) -> dict[str, KVCacheSpec]:
|
||||
block_size = vllm_config.cache_config.block_size
|
||||
use_mla = vllm_config.model_config.use_mla
|
||||
|
||||
kv_cache_spec: dict[str, KVCacheSpec] = {}
|
||||
attn_layers = get_layers_from_vllm_config(vllm_config, Attention)
|
||||
for layer_name, attn_module in attn_layers.items():
|
||||
assert attn_module.attn_type == AttentionType.DECODER
|
||||
if attn_module.sliding_window is not None:
|
||||
kv_cache_spec[layer_name] = SlidingWindowSpec(
|
||||
block_size=block_size,
|
||||
num_kv_heads=attn_module.num_kv_heads,
|
||||
head_size=attn_module.head_size,
|
||||
dtype=kv_cache_dtype,
|
||||
sliding_window=attn_module.sliding_window,
|
||||
use_mla=use_mla,
|
||||
)
|
||||
else:
|
||||
kv_cache_spec[layer_name] = FullAttentionSpec(
|
||||
block_size=block_size,
|
||||
num_kv_heads=attn_module.num_kv_heads,
|
||||
head_size=attn_module.head_size,
|
||||
dtype=kv_cache_dtype,
|
||||
use_mla=use_mla,
|
||||
)
|
||||
return kv_cache_spec
|
||||
|
||||
|
||||
def init_attn_backend(
|
||||
kv_cache_config: KVCacheConfig,
|
||||
vllm_config: VllmConfig,
|
||||
device: torch.device,
|
||||
):
|
||||
attn_backends: dict[str, AttentionBackend] = {}
|
||||
attn_metadata_builders: list[AttentionMetadataBuilder] = []
|
||||
|
||||
attn_layers = get_layers_from_vllm_config(vllm_config, Attention)
|
||||
for kv_cache_group_spec in kv_cache_config.kv_cache_groups:
|
||||
layer_names = kv_cache_group_spec.layer_names
|
||||
any_layer_name = next(iter(layer_names))
|
||||
|
||||
attn_backend = attn_layers[any_layer_name].get_attn_backend()
|
||||
for layer_name in layer_names:
|
||||
attn_backends[layer_name] = attn_backend
|
||||
|
||||
attn_metadata_builder = attn_backend.get_builder_cls()(
|
||||
kv_cache_group_spec.kv_cache_spec,
|
||||
layer_names,
|
||||
vllm_config,
|
||||
device,
|
||||
)
|
||||
attn_metadata_builders.append(attn_metadata_builder)
|
||||
return attn_backends, attn_metadata_builders
|
||||
|
||||
|
||||
def _allocate_kv_cache(
|
||||
kv_cache_config: KVCacheConfig,
|
||||
device: torch.device,
|
||||
):
|
||||
kv_cache_raw_tensors: dict[str, torch.Tensor] = {}
|
||||
for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
|
||||
tensor = torch.zeros(kv_cache_tensor.size,
|
||||
dtype=torch.int8,
|
||||
device=device)
|
||||
for layer_name in kv_cache_tensor.shared_by:
|
||||
kv_cache_raw_tensors[layer_name] = tensor
|
||||
|
||||
layer_names = set()
|
||||
for group in kv_cache_config.kv_cache_groups:
|
||||
for layer_name in group.layer_names:
|
||||
layer_names.add(layer_name)
|
||||
assert layer_names == set(kv_cache_raw_tensors.keys()
|
||||
), "Some layers are not correctly initialized"
|
||||
return kv_cache_raw_tensors
|
||||
|
||||
|
||||
def _reshape_kv_cache(
|
||||
kv_cache_config: KVCacheConfig,
|
||||
kv_cache_raw_tensors: dict[str, torch.Tensor],
|
||||
attn_backends: dict[str, AttentionBackend],
|
||||
) -> dict[str, torch.Tensor]:
|
||||
kv_caches: dict[str, torch.Tensor] = {}
|
||||
for kv_cache_group_spec in kv_cache_config.kv_cache_groups:
|
||||
kv_cache_spec = kv_cache_group_spec.kv_cache_spec
|
||||
for layer_name in kv_cache_group_spec.layer_names:
|
||||
raw_tensor = kv_cache_raw_tensors[layer_name]
|
||||
assert raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0
|
||||
num_blocks = (raw_tensor.numel() // kv_cache_spec.page_size_bytes)
|
||||
|
||||
attn_backend = attn_backends[layer_name]
|
||||
kv_cache_shape = attn_backend.get_kv_cache_shape(
|
||||
num_blocks, kv_cache_spec.block_size,
|
||||
kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
|
||||
|
||||
dtype = kv_cache_spec.dtype
|
||||
kv_cache_stride_order = attn_backend.get_kv_cache_stride_order()
|
||||
kv_cache_shape = tuple(kv_cache_shape[i]
|
||||
for i in kv_cache_stride_order)
|
||||
|
||||
inv_order = [
|
||||
kv_cache_stride_order.index(i)
|
||||
for i in range(len(kv_cache_stride_order))
|
||||
]
|
||||
|
||||
raw_tensor = raw_tensor.view(dtype)
|
||||
raw_tensor = raw_tensor.view(kv_cache_shape)
|
||||
kv_caches[layer_name] = raw_tensor.permute(*inv_order)
|
||||
return kv_caches
|
||||
|
||||
|
||||
def init_kv_cache(
|
||||
runner_kv_caches: list[torch.Tensor],
|
||||
forward_context: dict[str, Any],
|
||||
kv_cache_config: KVCacheConfig,
|
||||
attn_backends: dict[str, AttentionBackend],
|
||||
device: torch.device,
|
||||
):
|
||||
kv_cache_raw_tensors = _allocate_kv_cache(kv_cache_config, device)
|
||||
kv_caches = _reshape_kv_cache(kv_cache_config, kv_cache_raw_tensors,
|
||||
attn_backends)
|
||||
bind_kv_cache(kv_caches, forward_context, runner_kv_caches)
|
312
vllm/v1/worker/gpu/block_table.py
Normal file
312
vllm/v1/worker/gpu/block_table.py
Normal file
@ -0,0 +1,312 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections.abc import Iterable
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from vllm.utils import cdiv
|
||||
from vllm.v1.utils import CpuGpuBuffer
|
||||
|
||||
PAD_SLOT_ID = -1
|
||||
|
||||
|
||||
class BlockTables:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
block_sizes: list[int],
|
||||
max_num_reqs: int,
|
||||
max_num_batched_tokens: int,
|
||||
max_model_len: int,
|
||||
device: torch.device,
|
||||
pin_memory: bool,
|
||||
):
|
||||
self.block_sizes = block_sizes
|
||||
self.max_num_reqs = max_num_reqs
|
||||
self.max_num_batched_tokens = max_num_batched_tokens
|
||||
self.max_model_len = max_model_len
|
||||
self.device = device
|
||||
self.pin_memory = pin_memory
|
||||
|
||||
self.num_kv_cache_groups = len(self.block_sizes)
|
||||
# num_kv_cache_groups x [max_num_reqs, max_num_blocks]
|
||||
self.block_tables: list[torch.Tensor] = []
|
||||
for i in range(self.num_kv_cache_groups):
|
||||
block_size = self.block_sizes[i]
|
||||
max_num_blocks = cdiv(self.max_model_len, block_size)
|
||||
block_table = torch.zeros(
|
||||
self.max_num_reqs,
|
||||
max_num_blocks,
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
)
|
||||
self.block_tables.append(block_table)
|
||||
self.block_table_ptrs = self._make_ptr_tensor(self.block_tables)
|
||||
|
||||
# Block tables used for model's forward pass.
|
||||
# num_kv_cache_groups x [max_num_reqs, max_num_blocks]
|
||||
self.input_block_tables: list[torch.Tensor] = [
|
||||
torch.zeros_like(block_table) for block_table in self.block_tables
|
||||
]
|
||||
self.input_block_table_ptrs = self._make_ptr_tensor(
|
||||
self.input_block_tables)
|
||||
|
||||
self.block_table_strides = torch.tensor(
|
||||
[b.stride(0) for b in self.block_tables],
|
||||
dtype=torch.int64,
|
||||
device=self.device)
|
||||
self.block_sizes_tensor = torch.tensor(self.block_sizes,
|
||||
dtype=torch.int32,
|
||||
device=self.device)
|
||||
self.num_blocks = torch.zeros(self.num_kv_cache_groups,
|
||||
self.max_num_reqs,
|
||||
dtype=torch.int32,
|
||||
device=self.device)
|
||||
self.slot_mappings = torch.zeros(self.num_kv_cache_groups,
|
||||
self.max_num_batched_tokens,
|
||||
dtype=torch.int64,
|
||||
device=self.device)
|
||||
|
||||
# Misc buffers.
|
||||
self.req_indices = self._make_buffer(self.max_num_reqs,
|
||||
dtype=torch.int32)
|
||||
self.overwrite = self._make_buffer(self.max_num_reqs, dtype=torch.bool)
|
||||
self.cu_num_new_blocks = self._make_buffer(self.num_kv_cache_groups,
|
||||
self.max_num_reqs + 1,
|
||||
dtype=torch.int32)
|
||||
|
||||
def _make_buffer(self, *args, dtype: torch.dtype) -> CpuGpuBuffer:
|
||||
return CpuGpuBuffer(*args,
|
||||
dtype=dtype,
|
||||
pin_memory=self.pin_memory,
|
||||
device=self.device)
|
||||
|
||||
def _make_ptr_tensor(self, x: Iterable[torch.Tensor]) -> torch.Tensor:
|
||||
ptrs_tensor_cpu = torch.tensor([t.data_ptr() for t in x],
|
||||
dtype=torch.int64,
|
||||
device="cpu",
|
||||
pin_memory=self.pin_memory)
|
||||
return ptrs_tensor_cpu.to(self.device, non_blocking=True)
|
||||
|
||||
def append_block_ids(
|
||||
self,
|
||||
# [num_reqs]
|
||||
req_indices: list[int],
|
||||
# [num_kv_cache_groups, num_reqs + 1]
|
||||
cu_num_new_blocks: list[list[int]],
|
||||
# [num_kv_cache_groups, num_new_blocks]
|
||||
new_block_ids: list[list[int]],
|
||||
# [num_reqs]
|
||||
overwrite: list[bool],
|
||||
) -> None:
|
||||
num_reqs = len(req_indices)
|
||||
self.req_indices.np[:num_reqs] = req_indices
|
||||
self.overwrite.np[:num_reqs] = overwrite
|
||||
for i in range(self.num_kv_cache_groups):
|
||||
self.cu_num_new_blocks.np[i, :num_reqs + 1] = cu_num_new_blocks[i]
|
||||
|
||||
# NOTE(woosuk): Here, we cannot use a fixed-size buffer because there's
|
||||
# no clear upper bound to the number of new blocks in a single step.
|
||||
# NOTE(woosuk): The buffer has to be cached, because otherwise we cannot
|
||||
# guarantee that the buffer is not freed before the copy is completed.
|
||||
self.new_block_ids_cpu = torch.empty(
|
||||
self.num_kv_cache_groups,
|
||||
max(len(x) for x in new_block_ids),
|
||||
dtype=torch.int32,
|
||||
device="cpu",
|
||||
pin_memory=self.pin_memory,
|
||||
)
|
||||
new_block_ids_np = self.new_block_ids_cpu.numpy()
|
||||
for i in range(self.num_kv_cache_groups):
|
||||
new_block_ids_np[i, :len(new_block_ids[i])] = new_block_ids[i]
|
||||
new_block_ids_gpu = self.new_block_ids_cpu.to(self.device,
|
||||
non_blocking=True)
|
||||
|
||||
_append_block_ids_kernel[(self.num_kv_cache_groups, num_reqs)](
|
||||
self.req_indices.copy_to_gpu(num_reqs),
|
||||
self.cu_num_new_blocks.copy_to_gpu(),
|
||||
self.cu_num_new_blocks.gpu.stride(0),
|
||||
new_block_ids_gpu,
|
||||
new_block_ids_gpu.stride(0),
|
||||
self.overwrite.copy_to_gpu(num_reqs),
|
||||
self.block_table_strides,
|
||||
self.block_table_ptrs,
|
||||
self.num_blocks,
|
||||
self.num_blocks.stride(0),
|
||||
BLOCK_SIZE=1024,
|
||||
)
|
||||
|
||||
def gather_block_tables(
|
||||
self,
|
||||
idx_mapping: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, ...]:
|
||||
num_reqs = idx_mapping.shape[0]
|
||||
_gather_block_tables_kernel[(self.num_kv_cache_groups, num_reqs)](
|
||||
idx_mapping,
|
||||
self.block_table_ptrs,
|
||||
self.input_block_table_ptrs,
|
||||
self.block_table_strides,
|
||||
self.num_blocks,
|
||||
self.num_blocks.stride(0),
|
||||
BLOCK_SIZE=1024,
|
||||
)
|
||||
return tuple(block_table[:num_reqs]
|
||||
for block_table in self.input_block_tables)
|
||||
|
||||
def compute_slot_mappings(
|
||||
self,
|
||||
query_start_loc: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
num_reqs = query_start_loc.shape[0] - 1
|
||||
num_tokens = positions.shape[0]
|
||||
num_groups = self.num_kv_cache_groups
|
||||
_compute_slot_mappings_kernel[(num_groups, num_reqs + 1)](
|
||||
num_tokens,
|
||||
self.max_num_batched_tokens,
|
||||
query_start_loc,
|
||||
positions,
|
||||
self.input_block_table_ptrs,
|
||||
self.block_table_strides,
|
||||
self.block_sizes_tensor,
|
||||
self.slot_mappings,
|
||||
self.slot_mappings.stride(0),
|
||||
PAD_ID=PAD_SLOT_ID,
|
||||
BLOCK_SIZE=1024,
|
||||
)
|
||||
return self.slot_mappings[:, :num_tokens]
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _append_block_ids_kernel(
|
||||
# Inputs
|
||||
req_indices, # [num_reqs]
|
||||
cu_num_new_blocks_ptr, # [num_kv_cache_groups, num_reqs + 1]
|
||||
cu_num_new_blocks_stride,
|
||||
new_block_ids_ptr, # [num_kv_cache_groups, num_new_blocks]
|
||||
new_block_ids_stride,
|
||||
overwrite, # [num_reqs]
|
||||
block_table_strides, # [num_kv_cache_groups]
|
||||
# Outputs
|
||||
block_table_ptrs, # [num_kv_cache_groups]
|
||||
num_blocks_ptr, # [num_kv_cache_groups, max_num_reqs]
|
||||
num_blocks_stride,
|
||||
# Constants
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
group_id = tl.program_id(0)
|
||||
batch_idx = tl.program_id(1)
|
||||
req_idx = tl.load(req_indices + batch_idx)
|
||||
do_overwrite = tl.load(overwrite + batch_idx)
|
||||
|
||||
group_new_blocks_ptr = (cu_num_new_blocks_ptr +
|
||||
group_id * cu_num_new_blocks_stride)
|
||||
start_idx = tl.load(group_new_blocks_ptr + batch_idx)
|
||||
end_idx = tl.load(group_new_blocks_ptr + batch_idx + 1)
|
||||
num_new_blocks = end_idx - start_idx
|
||||
|
||||
group_num_blocks_ptr = num_blocks_ptr + group_id * num_blocks_stride
|
||||
if do_overwrite:
|
||||
dst_start_idx = 0
|
||||
else:
|
||||
dst_start_idx = tl.load(group_num_blocks_ptr + req_idx)
|
||||
dst_end_idx = dst_start_idx + num_new_blocks
|
||||
tl.store(group_num_blocks_ptr + req_idx, dst_end_idx)
|
||||
|
||||
# Destination
|
||||
block_table_ptr = _load_ptr(block_table_ptrs + group_id, tl.int32)
|
||||
block_table_stride = tl.load(block_table_strides + group_id)
|
||||
row_ptr = block_table_ptr + req_idx * block_table_stride
|
||||
|
||||
group_new_block_ids_ptr = (new_block_ids_ptr +
|
||||
group_id * new_block_ids_stride)
|
||||
for i in tl.range(0, num_new_blocks, BLOCK_SIZE):
|
||||
offset = i + tl.arange(0, BLOCK_SIZE)
|
||||
block_ids = tl.load(group_new_block_ids_ptr + start_idx + offset,
|
||||
mask=offset < num_new_blocks)
|
||||
tl.store(row_ptr + dst_start_idx + offset,
|
||||
block_ids,
|
||||
mask=offset < num_new_blocks)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _gather_block_tables_kernel(
|
||||
batch_idx_to_req_idx, # [batch_size]
|
||||
src_block_table_ptrs, # [num_kv_cache_groups]
|
||||
dst_block_table_ptrs, # [num_kv_cache_groups]
|
||||
block_table_strides, # [num_kv_cache_groups]
|
||||
num_blocks_ptr, # [num_kv_cache_groups, max_num_reqs]
|
||||
num_blocks_stride,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
# kv cache group id
|
||||
group_id = tl.program_id(0)
|
||||
batch_idx = tl.program_id(1)
|
||||
req_idx = tl.load(batch_idx_to_req_idx + batch_idx)
|
||||
|
||||
group_num_blocks_ptr = num_blocks_ptr + group_id * num_blocks_stride
|
||||
num_blocks = tl.load(group_num_blocks_ptr + req_idx)
|
||||
|
||||
stride = tl.load(block_table_strides + group_id)
|
||||
src_block_table_ptr = _load_ptr(src_block_table_ptrs + group_id, tl.int32)
|
||||
src_row_ptr = src_block_table_ptr + req_idx * stride
|
||||
dst_block_table_ptr = _load_ptr(dst_block_table_ptrs + group_id, tl.int32)
|
||||
dst_row_ptr = dst_block_table_ptr + batch_idx * stride
|
||||
|
||||
for i in tl.range(0, num_blocks, BLOCK_SIZE):
|
||||
offset = i + tl.arange(0, BLOCK_SIZE)
|
||||
block_ids = tl.load(src_row_ptr + offset, mask=offset < num_blocks)
|
||||
tl.store(dst_row_ptr + offset, block_ids, mask=offset < num_blocks)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _compute_slot_mappings_kernel(
|
||||
num_tokens,
|
||||
max_num_tokens,
|
||||
cu_num_tokens, # [num_reqs + 1]
|
||||
pos, # [num_tokens]
|
||||
block_table_ptrs, # [num_kv_cache_groups]
|
||||
block_table_strides, # [num_kv_cache_groups]
|
||||
page_sizes, # [num_kv_cache_groups]
|
||||
slot_mappings_ptr, # [num_kv_cache_groups, max_num_tokens]
|
||||
slot_mappings_stride,
|
||||
PAD_ID: tl.constexpr,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
# kv cache group id
|
||||
group_id = tl.program_id(0)
|
||||
req_idx = tl.program_id(1)
|
||||
slot_mapping_ptr = slot_mappings_ptr + group_id * slot_mappings_stride
|
||||
|
||||
if req_idx == tl.num_programs(1) - 1:
|
||||
# Pad remaining slots to -1. This is needed for CUDA graphs.
|
||||
for i in tl.range(num_tokens, max_num_tokens, BLOCK_SIZE):
|
||||
offset = i + tl.arange(0, BLOCK_SIZE)
|
||||
tl.store(slot_mapping_ptr + offset,
|
||||
PAD_ID,
|
||||
mask=offset < max_num_tokens)
|
||||
return
|
||||
|
||||
block_table_ptr = _load_ptr(block_table_ptrs + group_id, tl.int32)
|
||||
block_table_stride = tl.load(block_table_strides + group_id)
|
||||
page_size = tl.load(page_sizes + group_id)
|
||||
|
||||
start_idx = tl.load(cu_num_tokens + req_idx)
|
||||
end_idx = tl.load(cu_num_tokens + req_idx + 1)
|
||||
for i in tl.range(start_idx, end_idx, BLOCK_SIZE):
|
||||
offset = i + tl.arange(0, BLOCK_SIZE)
|
||||
positions = tl.load(pos + offset, mask=offset < end_idx, other=0)
|
||||
block_indices = positions // page_size
|
||||
block_numbers = tl.load(block_table_ptr +
|
||||
req_idx * block_table_stride + block_indices)
|
||||
slot_ids = block_numbers * page_size + positions % page_size
|
||||
tl.store(slot_mapping_ptr + offset, slot_ids, mask=offset < end_idx)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _load_ptr(ptr_to_ptr, elem_dtype):
|
||||
ptr = tl.load(ptr_to_ptr)
|
||||
ptr = tl.cast(ptr, tl.pointer_type(elem_dtype))
|
||||
return tl.multiple_of(ptr, 16)
|
58
vllm/v1/worker/gpu/dist_utils.py
Normal file
58
vllm/v1/worker/gpu/dist_utils.py
Normal file
@ -0,0 +1,58 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import torch
|
||||
|
||||
from vllm.distributed import tensor_model_parallel_all_gather
|
||||
from vllm.v1.outputs import SamplerOutput
|
||||
|
||||
|
||||
def evenly_split(
|
||||
n: int,
|
||||
tp_size: int,
|
||||
tp_rank: int,
|
||||
) -> tuple[int, int]:
|
||||
q = n // tp_size
|
||||
r = n % tp_size
|
||||
start = q * tp_rank + min(tp_rank, r)
|
||||
end = start + q + (1 if tp_rank < r else 0)
|
||||
return start, end
|
||||
|
||||
|
||||
def pad_and_all_gather(
|
||||
x: torch.Tensor,
|
||||
padded_size: int,
|
||||
) -> torch.Tensor:
|
||||
n = x.shape[0]
|
||||
if n != padded_size:
|
||||
padded_x = torch.empty(
|
||||
(padded_size, *x.shape[1:]),
|
||||
dtype=x.dtype,
|
||||
device=x.device,
|
||||
)
|
||||
padded_x[:n] = x
|
||||
else:
|
||||
padded_x = x
|
||||
|
||||
x = tensor_model_parallel_all_gather(padded_x)
|
||||
return x
|
||||
|
||||
|
||||
def all_gather_sampler_output(
|
||||
sampler_output: SamplerOutput,
|
||||
num_reqs: int,
|
||||
tp_size: int,
|
||||
) -> SamplerOutput:
|
||||
n = (num_reqs + tp_size - 1) // tp_size
|
||||
sampler_output.sampled_token_ids = pad_and_all_gather(
|
||||
sampler_output.sampled_token_ids, n)[:num_reqs]
|
||||
|
||||
# TODO(woosuk): 3 small all-gathers, could be merged into one.
|
||||
logprobs_tensors = sampler_output.logprobs_tensors
|
||||
if logprobs_tensors is not None:
|
||||
logprobs_tensors.logprob_token_ids = pad_and_all_gather(
|
||||
logprobs_tensors.logprob_token_ids, n)[:num_reqs]
|
||||
logprobs_tensors.logprobs = pad_and_all_gather(
|
||||
logprobs_tensors.logprobs, n)[:num_reqs]
|
||||
logprobs_tensors.selected_token_ranks = pad_and_all_gather(
|
||||
logprobs_tensors.selected_token_ranks, n)[:num_reqs]
|
||||
return sampler_output
|
247
vllm/v1/worker/gpu/input_batch.py
Normal file
247
vllm/v1/worker/gpu/input_batch.py
Normal file
@ -0,0 +1,247 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import numba
|
||||
import numba.types as types
|
||||
import numpy as np
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from vllm.utils import random_uuid
|
||||
from vllm.v1.utils import CpuGpuBuffer
|
||||
|
||||
|
||||
class InputBuffers:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_num_reqs: int,
|
||||
max_num_tokens: int,
|
||||
device: torch.device,
|
||||
pin_memory: bool,
|
||||
):
|
||||
self.max_num_reqs = max_num_reqs
|
||||
self.max_num_tokens = max_num_tokens
|
||||
self.device = device
|
||||
self.pin_memory = pin_memory
|
||||
|
||||
self.idx_mapping = self._make_buffer(max_num_reqs, dtype=torch.int32)
|
||||
self.input_ids = self._make_buffer(max_num_tokens, dtype=torch.int32)
|
||||
self.positions = self._make_buffer(max_num_tokens, dtype=torch.int64)
|
||||
self.query_start_loc = self._make_buffer(max_num_reqs + 1,
|
||||
dtype=torch.int32)
|
||||
self.seq_lens = self._make_buffer(max_num_reqs, dtype=torch.int32)
|
||||
|
||||
def _make_buffer(self, *args, dtype: torch.dtype) -> CpuGpuBuffer:
|
||||
return CpuGpuBuffer(*args,
|
||||
dtype=dtype,
|
||||
pin_memory=self.pin_memory,
|
||||
device=self.device)
|
||||
|
||||
|
||||
@dataclass
|
||||
class InputBatch:
|
||||
|
||||
# batch_idx -> req_id
|
||||
req_ids: list[str]
|
||||
num_reqs: int
|
||||
|
||||
# batch_idx -> req_state_idx
|
||||
idx_mapping: torch.Tensor
|
||||
idx_mapping_np: np.ndarray
|
||||
|
||||
# batch_idx -> num_scheduled_tokens
|
||||
num_scheduled_tokens: np.ndarray
|
||||
# sum(num_scheduled_tokens)
|
||||
num_tokens: int
|
||||
num_tokens_after_padding: int
|
||||
# [num_reqs]
|
||||
is_chunked_prefilling: np.ndarray
|
||||
|
||||
# [max_num_batched_tokens]
|
||||
input_ids: torch.Tensor
|
||||
# [max_num_batched_tokens]
|
||||
positions: torch.Tensor
|
||||
|
||||
# layer_name -> Metadata
|
||||
attn_metadata: dict[str, Any]
|
||||
|
||||
# [num_reqs]
|
||||
logits_indices: torch.Tensor
|
||||
|
||||
@classmethod
|
||||
def make_dummy(
|
||||
cls,
|
||||
num_reqs: int,
|
||||
num_tokens: int,
|
||||
device: torch.device,
|
||||
) -> "InputBatch":
|
||||
assert 0 < num_reqs <= num_tokens
|
||||
req_ids = [f"req_{i}_{random_uuid()}" for i in range(num_reqs)]
|
||||
idx_mapping_np = np.arange(num_reqs, dtype=np.int32)
|
||||
idx_mapping = torch.tensor(idx_mapping_np, device=device)
|
||||
num_scheduled_tokens = np.full(num_reqs,
|
||||
num_tokens // num_reqs,
|
||||
dtype=np.int32)
|
||||
num_scheduled_tokens[-1] += num_tokens % num_reqs
|
||||
is_chunked_prefilling = np.zeros(num_reqs, dtype=np.bool_)
|
||||
input_ids = torch.zeros(num_tokens, dtype=torch.int32, device=device)
|
||||
positions = torch.zeros(num_tokens, dtype=torch.int64, device=device)
|
||||
attn_metadata = defaultdict(lambda: None)
|
||||
logits_indices = torch.arange(num_reqs,
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
return cls(
|
||||
req_ids=req_ids,
|
||||
num_reqs=num_reqs,
|
||||
idx_mapping=idx_mapping,
|
||||
idx_mapping_np=idx_mapping_np,
|
||||
num_scheduled_tokens=num_scheduled_tokens,
|
||||
num_tokens=num_tokens,
|
||||
num_tokens_after_padding=num_tokens,
|
||||
is_chunked_prefilling=is_chunked_prefilling,
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
attn_metadata=attn_metadata,
|
||||
logits_indices=logits_indices,
|
||||
)
|
||||
|
||||
|
||||
# NOTE: With the type annotations, this function is pre-compiled
|
||||
# before the first call.
|
||||
@numba.jit(
|
||||
[
|
||||
types.none(
|
||||
types.int32[:], # idx_mapping
|
||||
types.int32[:, :], # token_ids
|
||||
types.int32[:], # num_computed_tokens
|
||||
types.int32[:], # num_scheduled_tokens
|
||||
types.int32[:], # input_ids
|
||||
types.int64[:], # positions
|
||||
types.int32[:], # query_start_loc
|
||||
types.int32[:], # seq_lens
|
||||
)
|
||||
],
|
||||
nopython=True,
|
||||
cache=True,
|
||||
)
|
||||
def _prepare_inputs(
|
||||
idx_mapping: np.ndarray, # batch_idx -> req_idx
|
||||
token_ids: np.ndarray, # [N, max_model_len]
|
||||
num_computed_tokens: np.ndarray, # [N]
|
||||
num_scheduled_tokens: np.ndarray, # [B]
|
||||
input_ids: np.ndarray, # [num_input_tokens]
|
||||
positions: np.ndarray, # [num_input_tokens]
|
||||
query_start_loc: np.ndarray, # [B + 1]
|
||||
seq_lens: np.ndarray, # [B]
|
||||
) -> None:
|
||||
num_reqs = num_scheduled_tokens.shape[0]
|
||||
query_start_loc[0] = 0
|
||||
|
||||
cu_num_tokens = 0
|
||||
for i in range(num_reqs):
|
||||
req_idx = idx_mapping[i]
|
||||
query_len = num_scheduled_tokens[i]
|
||||
start = num_computed_tokens[req_idx]
|
||||
end = start + query_len
|
||||
seq_lens[i] = end
|
||||
|
||||
start_idx = cu_num_tokens
|
||||
end_idx = start_idx + query_len
|
||||
input_ids[start_idx:end_idx] = token_ids[req_idx, start:end]
|
||||
positions[start_idx:end_idx] = np.arange(start, end, dtype=np.int64)
|
||||
|
||||
cu_num_tokens = end_idx
|
||||
query_start_loc[i + 1] = cu_num_tokens
|
||||
|
||||
# Pad the inputs for CUDA graphs.
|
||||
# Note: pad query_start_loc to be non-decreasing, as kernels
|
||||
# like FlashAttention requires that
|
||||
query_start_loc[num_reqs + 1:].fill(cu_num_tokens)
|
||||
# Fill unused with 0 for full cuda graph mode.
|
||||
seq_lens[num_reqs:].fill(0)
|
||||
|
||||
|
||||
def prepare_inputs(
|
||||
idx_mapping: np.ndarray,
|
||||
prompt_token_ids: np.ndarray,
|
||||
num_computed_tokens: np.ndarray,
|
||||
num_scheduled_tokens: np.ndarray,
|
||||
input_ids: CpuGpuBuffer,
|
||||
positions: CpuGpuBuffer,
|
||||
query_start_loc: CpuGpuBuffer,
|
||||
seq_lens: CpuGpuBuffer,
|
||||
num_tokens: int,
|
||||
) -> tuple[np.ndarray, np.ndarray]:
|
||||
_prepare_inputs(
|
||||
idx_mapping,
|
||||
prompt_token_ids,
|
||||
num_computed_tokens,
|
||||
num_scheduled_tokens,
|
||||
input_ids.np,
|
||||
positions.np,
|
||||
query_start_loc.np,
|
||||
seq_lens.np,
|
||||
)
|
||||
input_ids.copy_to_gpu(num_tokens)
|
||||
positions.copy_to_gpu(num_tokens)
|
||||
# NOTE(woosuk): We should copy the whole query_start_loc and seq_lens
|
||||
# tensors from CPU to GPU, because they may include paddings needed
|
||||
# for full CUDA graph mode.
|
||||
query_start_loc.copy_to_gpu()
|
||||
seq_lens.copy_to_gpu()
|
||||
|
||||
num_reqs = num_scheduled_tokens.shape[0]
|
||||
max_query_len = int(num_scheduled_tokens.max())
|
||||
max_seq_len = int(seq_lens.np[:num_reqs].max())
|
||||
return max_query_len, max_seq_len
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _combine_last_token_ids_kernel(
|
||||
input_ids_ptr,
|
||||
idx_mapping_ptr,
|
||||
last_token_ids_ptr,
|
||||
query_start_loc_ptr,
|
||||
seq_lens_ptr,
|
||||
num_tokens_ptr,
|
||||
):
|
||||
batch_idx = tl.program_id(0)
|
||||
req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
|
||||
|
||||
seq_len = tl.load(seq_lens_ptr + batch_idx)
|
||||
num_tokens = tl.load(num_tokens_ptr + req_state_idx)
|
||||
if seq_len < num_tokens:
|
||||
# Chunked prefilling.
|
||||
return
|
||||
|
||||
last_token_id = tl.load(last_token_ids_ptr + req_state_idx)
|
||||
if last_token_id == -1:
|
||||
return
|
||||
|
||||
end = tl.load(query_start_loc_ptr + batch_idx + 1)
|
||||
tl.store(input_ids_ptr + end - 1, last_token_id)
|
||||
|
||||
|
||||
def combine_last_token_ids(
|
||||
input_ids: torch.Tensor,
|
||||
idx_mapping: torch.Tensor,
|
||||
last_token_ids: torch.Tensor,
|
||||
query_start_loc: torch.Tensor,
|
||||
seq_lens: torch.Tensor,
|
||||
num_tokens: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
num_reqs = seq_lens.shape[0]
|
||||
_combine_last_token_ids_kernel[(num_reqs, )](
|
||||
input_ids,
|
||||
idx_mapping,
|
||||
last_token_ids,
|
||||
query_start_loc,
|
||||
seq_lens,
|
||||
num_tokens,
|
||||
)
|
||||
return input_ids
|
476
vllm/v1/worker/gpu/model_runner.py
Normal file
476
vllm/v1/worker/gpu/model_runner.py
Normal file
@ -0,0 +1,476 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import gc
|
||||
import time
|
||||
from copy import deepcopy
|
||||
from typing import Any, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed import get_tp_group
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.model_loader import get_model_loader
|
||||
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
|
||||
GiB_bytes, is_pin_memory_available)
|
||||
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput
|
||||
from vllm.v1.sample.sampler import SamplerOutput
|
||||
from vllm.v1.worker.gpu.async_utils import AsyncOutput
|
||||
from vllm.v1.worker.gpu.attn_utils import (get_kv_cache_spec,
|
||||
init_attn_backend, init_kv_cache)
|
||||
from vllm.v1.worker.gpu.block_table import BlockTables
|
||||
from vllm.v1.worker.gpu.dist_utils import (all_gather_sampler_output,
|
||||
evenly_split)
|
||||
from vllm.v1.worker.gpu.input_batch import (InputBatch, InputBuffers,
|
||||
combine_last_token_ids,
|
||||
prepare_inputs)
|
||||
from vllm.v1.worker.gpu.sampler import Sampler
|
||||
from vllm.v1.worker.gpu.states import RequestState, SamplingMetadata
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class GPUModelRunner:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
device: torch.device,
|
||||
):
|
||||
self.vllm_config = vllm_config
|
||||
self.model_config = vllm_config.model_config
|
||||
self.cache_config = vllm_config.cache_config
|
||||
self.compilation_config = vllm_config.compilation_config
|
||||
self.lora_config = vllm_config.lora_config
|
||||
self.load_config = vllm_config.load_config
|
||||
self.parallel_config = vllm_config.parallel_config
|
||||
self.scheduler_config = vllm_config.scheduler_config
|
||||
self.speculative_config = vllm_config.speculative_config
|
||||
self.observability_config = vllm_config.observability_config
|
||||
|
||||
self.device = device
|
||||
self.pin_memory = is_pin_memory_available()
|
||||
self.dtype = self.model_config.dtype
|
||||
self.kv_cache_dtype = self.dtype
|
||||
if self.cache_config.cache_dtype != "auto":
|
||||
# Quantized KV cache.
|
||||
self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
|
||||
self.cache_config.cache_dtype]
|
||||
self.is_pooling_model = False
|
||||
|
||||
self.vocab_size = self.model_config.get_vocab_size()
|
||||
self.max_model_len = self.model_config.max_model_len
|
||||
self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
|
||||
self.max_num_reqs = self.scheduler_config.max_num_seqs
|
||||
|
||||
self.use_async_scheduling = self.scheduler_config.async_scheduling
|
||||
assert self.use_async_scheduling
|
||||
self.output_copy_stream = torch.cuda.Stream()
|
||||
|
||||
self.req_states = RequestState(
|
||||
max_num_reqs=self.max_num_reqs,
|
||||
max_model_len=self.max_model_len,
|
||||
max_num_batched_tokens=self.max_num_tokens,
|
||||
vocab_size=self.vocab_size,
|
||||
device=self.device,
|
||||
pin_memory=self.pin_memory,
|
||||
)
|
||||
self.input_buffers = InputBuffers(
|
||||
max_num_reqs=self.max_num_reqs,
|
||||
max_num_tokens=self.max_num_tokens,
|
||||
device=self.device,
|
||||
pin_memory=self.pin_memory,
|
||||
)
|
||||
self.sampler = Sampler()
|
||||
|
||||
def get_supported_tasks(self) -> tuple[str]:
|
||||
return ("generate", )
|
||||
|
||||
def load_model(self, *args, **kwargs) -> None:
|
||||
time_before_load = time.perf_counter()
|
||||
with DeviceMemoryProfiler() as m:
|
||||
model_loader = get_model_loader(self.vllm_config.load_config)
|
||||
logger.info("Loading model from scratch...")
|
||||
self.model = model_loader.load_model(
|
||||
vllm_config=self.vllm_config,
|
||||
model_config=self.vllm_config.model_config,
|
||||
)
|
||||
time_after_load = time.perf_counter()
|
||||
|
||||
self.model_memory_usage = m.consumed_memory
|
||||
logger.info("Model loading took %.4f GiB and %.6f seconds",
|
||||
m.consumed_memory / GiB_bytes,
|
||||
time_after_load - time_before_load)
|
||||
|
||||
def get_model(self) -> nn.Module:
|
||||
return self.model
|
||||
|
||||
def get_kv_cache_spec(self):
|
||||
return get_kv_cache_spec(self.vllm_config, self.kv_cache_dtype)
|
||||
|
||||
def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
|
||||
kv_cache_config = deepcopy(kv_cache_config)
|
||||
self.kv_cache_config = kv_cache_config
|
||||
block_sizes = kv_cache_config.block_sizes
|
||||
|
||||
self.block_tables = BlockTables(
|
||||
block_sizes=block_sizes,
|
||||
max_num_reqs=self.max_num_reqs,
|
||||
max_num_batched_tokens=self.max_num_tokens,
|
||||
max_model_len=self.max_model_len,
|
||||
device=self.device,
|
||||
pin_memory=self.pin_memory,
|
||||
)
|
||||
|
||||
self.attn_backends, self.attn_metadata_builders = init_attn_backend(
|
||||
self.kv_cache_config,
|
||||
self.vllm_config,
|
||||
self.device,
|
||||
)
|
||||
|
||||
self.kv_caches: list[torch.Tensor] = []
|
||||
init_kv_cache(
|
||||
self.kv_caches,
|
||||
self.compilation_config.static_forward_context,
|
||||
self.kv_cache_config,
|
||||
self.attn_backends,
|
||||
self.device,
|
||||
)
|
||||
|
||||
def _dummy_run(
|
||||
self,
|
||||
num_tokens: int,
|
||||
*args,
|
||||
input_batch: Optional[InputBatch] = None,
|
||||
**kwargs,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
if input_batch is None:
|
||||
input_batch = InputBatch.make_dummy(
|
||||
num_reqs=min(num_tokens, self.max_num_reqs),
|
||||
num_tokens=num_tokens,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
with set_forward_context(
|
||||
input_batch.attn_metadata,
|
||||
self.vllm_config,
|
||||
num_tokens=num_tokens,
|
||||
):
|
||||
hidden_states = self.model(
|
||||
input_ids=input_batch.input_ids,
|
||||
positions=input_batch.positions,
|
||||
)
|
||||
sample_hidden_states = hidden_states[input_batch.logits_indices]
|
||||
return hidden_states, sample_hidden_states
|
||||
|
||||
def _dummy_sampler_run(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
) -> None:
|
||||
num_reqs = hidden_states.shape[0]
|
||||
sampling_metadata = SamplingMetadata.make_dummy(
|
||||
num_reqs=num_reqs,
|
||||
device=self.device,
|
||||
)
|
||||
logits = self.model.compute_logits(hidden_states)
|
||||
self.sampler(logits, sampling_metadata)
|
||||
|
||||
def profile_run(self) -> None:
|
||||
input_batch = InputBatch.make_dummy(
|
||||
num_reqs=self.max_num_reqs,
|
||||
num_tokens=self.max_num_tokens,
|
||||
device=self.device,
|
||||
)
|
||||
hidden_states, sample_hidden_states = self._dummy_run(
|
||||
self.max_num_tokens,
|
||||
input_batch=input_batch,
|
||||
)
|
||||
self._dummy_sampler_run(sample_hidden_states)
|
||||
torch.cuda.synchronize()
|
||||
del hidden_states, sample_hidden_states
|
||||
gc.collect()
|
||||
|
||||
def update_states(self, scheduler_output: SchedulerOutput) -> None:
|
||||
# for req_id in scheduler_output.preempted_req_ids:
|
||||
# self.req_states.remove_request(req_id)
|
||||
for req_id in scheduler_output.finished_req_ids:
|
||||
self.req_states.remove_request(req_id)
|
||||
|
||||
# TODO(woosuk): Change SchedulerOutput.
|
||||
req_indices: list[int] = []
|
||||
cu_num_new_blocks = tuple(
|
||||
[0] for _ in range(self.block_tables.num_kv_cache_groups))
|
||||
new_block_ids = tuple(
|
||||
[] for _ in range(self.block_tables.num_kv_cache_groups))
|
||||
overwrite: list[bool] = []
|
||||
|
||||
# Add new requests.
|
||||
for new_req_data in scheduler_output.scheduled_new_reqs:
|
||||
req_id = new_req_data.req_id
|
||||
self.req_states.add_request(
|
||||
req_id=req_id,
|
||||
prompt_token_ids=new_req_data.prompt_token_ids,
|
||||
num_computed_tokens=new_req_data.num_computed_tokens,
|
||||
sampling_params=new_req_data.sampling_params,
|
||||
)
|
||||
|
||||
req_index = self.req_states.req_id_to_index[req_id]
|
||||
req_indices.append(req_index)
|
||||
for i, block_ids in enumerate(new_req_data.block_ids):
|
||||
x = cu_num_new_blocks[i][-1]
|
||||
cu_num_new_blocks[i].append(x + len(block_ids))
|
||||
new_block_ids[i].extend(block_ids)
|
||||
overwrite.append(True)
|
||||
|
||||
# Add new blocks for the existing requests.
|
||||
cached_reqs = scheduler_output.scheduled_cached_reqs
|
||||
for i, req_id in enumerate(cached_reqs.req_ids):
|
||||
req_index = self.req_states.req_id_to_index[req_id]
|
||||
|
||||
req_new_block_ids = cached_reqs.new_block_ids[i]
|
||||
if req_new_block_ids is not None:
|
||||
req_indices.append(req_index)
|
||||
for group_id, block_ids in enumerate(req_new_block_ids):
|
||||
x = cu_num_new_blocks[group_id][-1]
|
||||
cu_num_new_blocks[group_id].append(x + len(block_ids))
|
||||
new_block_ids[group_id].extend(block_ids)
|
||||
overwrite.append(False)
|
||||
|
||||
if req_indices:
|
||||
self.block_tables.append_block_ids(
|
||||
req_indices=req_indices,
|
||||
cu_num_new_blocks=cu_num_new_blocks,
|
||||
new_block_ids=new_block_ids,
|
||||
overwrite=overwrite,
|
||||
)
|
||||
|
||||
def prepare_inputs(self, scheduler_output: SchedulerOutput) -> InputBatch:
|
||||
num_tokens = scheduler_output.total_num_scheduled_tokens
|
||||
assert num_tokens > 0
|
||||
num_reqs = len(scheduler_output.num_scheduled_tokens)
|
||||
|
||||
# Decode first, then prefill.
|
||||
# batch_idx -> req_id
|
||||
req_ids = sorted(scheduler_output.num_scheduled_tokens,
|
||||
key=scheduler_output.num_scheduled_tokens.get)
|
||||
num_scheduled_tokens = np.array(
|
||||
[scheduler_output.num_scheduled_tokens[i] for i in req_ids],
|
||||
dtype=np.int32)
|
||||
|
||||
# TODO(woosuk): Support CUDA graphs.
|
||||
num_tokens_after_padding = num_tokens
|
||||
|
||||
idx_mapping_list = [
|
||||
self.req_states.req_id_to_index[req_id] for req_id in req_ids
|
||||
]
|
||||
idx_mapping = self.input_buffers.idx_mapping
|
||||
idx_mapping.np[:num_reqs] = idx_mapping_list
|
||||
idx_mapping_np = idx_mapping.np[:num_reqs]
|
||||
idx_mapping = idx_mapping.copy_to_gpu(num_reqs)
|
||||
|
||||
# Block tables: num_kv_cache_groups x [num_reqs, max_num_blocks]
|
||||
block_tables = self.block_tables.gather_block_tables(idx_mapping)
|
||||
|
||||
max_query_len, max_seq_len = prepare_inputs(
|
||||
idx_mapping_np,
|
||||
self.req_states.prompt_token_ids,
|
||||
self.req_states.num_computed_tokens,
|
||||
num_scheduled_tokens,
|
||||
self.input_buffers.input_ids,
|
||||
self.input_buffers.positions,
|
||||
self.input_buffers.query_start_loc,
|
||||
self.input_buffers.seq_lens,
|
||||
num_tokens,
|
||||
)
|
||||
|
||||
query_start_loc = self.input_buffers.query_start_loc
|
||||
query_start_loc_gpu = query_start_loc.gpu[:num_reqs + 1]
|
||||
query_start_loc_cpu = query_start_loc.cpu[:num_reqs + 1]
|
||||
seq_lens_gpu = self.input_buffers.seq_lens.gpu[:num_reqs]
|
||||
seq_lens_cpu = self.input_buffers.seq_lens.cpu[:num_reqs]
|
||||
seq_lens_np = self.input_buffers.seq_lens.np[:num_reqs]
|
||||
|
||||
# Some input token ids are directly read from the last sampled tokens.
|
||||
combine_last_token_ids(
|
||||
self.input_buffers.input_ids.gpu,
|
||||
idx_mapping,
|
||||
self.req_states.last_sampled_tokens,
|
||||
query_start_loc_gpu,
|
||||
seq_lens_gpu,
|
||||
self.req_states.num_tokens.copy_to_gpu(),
|
||||
)
|
||||
|
||||
# Compute slot mappings: [num_kv_cache_groups, num_tokens]
|
||||
slot_mappings = self.block_tables.compute_slot_mappings(
|
||||
query_start_loc_gpu, self.input_buffers.positions.gpu[:num_tokens])
|
||||
|
||||
num_computed_tokens_cpu = torch.from_numpy(
|
||||
self.req_states.num_computed_tokens[idx_mapping_np])
|
||||
|
||||
# Whether the request is chunked-prefilling or not.
|
||||
is_chunked_prefilling = (
|
||||
seq_lens_np < self.req_states.num_tokens.np[idx_mapping_np])
|
||||
|
||||
# Logits indices to sample next token from.
|
||||
logits_indices = query_start_loc_gpu[1:] - 1
|
||||
num_logits_indices = logits_indices.size(0)
|
||||
|
||||
# Layer name -> attention metadata.
|
||||
attn_metadata: dict[str, Any] = {}
|
||||
kv_cache_groups = self.kv_cache_config.kv_cache_groups
|
||||
for i, kv_cache_spec in enumerate(kv_cache_groups):
|
||||
block_table = block_tables[i]
|
||||
slot_mapping = slot_mappings[i]
|
||||
|
||||
common_attn_metadata = CommonAttentionMetadata(
|
||||
query_start_loc=query_start_loc_gpu,
|
||||
query_start_loc_cpu=query_start_loc_cpu,
|
||||
seq_lens=seq_lens_gpu,
|
||||
seq_lens_cpu=seq_lens_cpu,
|
||||
num_computed_tokens_cpu=num_computed_tokens_cpu,
|
||||
num_reqs=num_reqs,
|
||||
num_actual_tokens=num_tokens,
|
||||
max_query_len=max_query_len,
|
||||
max_seq_len=max_seq_len,
|
||||
block_table_tensor=block_table,
|
||||
slot_mapping=slot_mapping,
|
||||
logits_indices_padded=None,
|
||||
num_logits_indices=num_logits_indices,
|
||||
causal=True,
|
||||
encoder_seq_lens=None,
|
||||
)
|
||||
|
||||
attn_metadata_builder = self.attn_metadata_builders[i]
|
||||
metadata = attn_metadata_builder.build(
|
||||
common_prefix_len=0,
|
||||
common_attn_metadata=common_attn_metadata,
|
||||
)
|
||||
for layer_name in kv_cache_spec.layer_names:
|
||||
attn_metadata[layer_name] = metadata
|
||||
|
||||
input_ids = self.input_buffers.input_ids.gpu[:num_tokens_after_padding]
|
||||
positions = self.input_buffers.positions.gpu[:num_tokens_after_padding]
|
||||
return InputBatch(
|
||||
req_ids=req_ids,
|
||||
num_reqs=num_reqs,
|
||||
idx_mapping=idx_mapping,
|
||||
idx_mapping_np=idx_mapping_np,
|
||||
num_scheduled_tokens=num_scheduled_tokens,
|
||||
num_tokens=num_tokens,
|
||||
num_tokens_after_padding=num_tokens_after_padding,
|
||||
is_chunked_prefilling=is_chunked_prefilling,
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
attn_metadata=attn_metadata,
|
||||
logits_indices=logits_indices,
|
||||
)
|
||||
|
||||
def sample(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
input_batch: InputBatch,
|
||||
) -> SamplerOutput:
|
||||
sample_hidden_states = hidden_states[input_batch.logits_indices]
|
||||
logits = self.model.compute_logits(sample_hidden_states)
|
||||
pos = input_batch.positions[input_batch.logits_indices]
|
||||
idx_mapping_np = input_batch.idx_mapping_np
|
||||
num_reqs = logits.shape[0]
|
||||
|
||||
# When the batch size is large enough, use DP sampler.
|
||||
tp_group = get_tp_group()
|
||||
tp_size = tp_group.world_size
|
||||
n = (num_reqs + tp_size - 1) // tp_size
|
||||
use_dp_sampler = tp_size > 1 and n > 32 # TODO(woosuk): Tune.
|
||||
if use_dp_sampler:
|
||||
# NOTE(woosuk): Make sure that no rank gets zero requests.
|
||||
tp_rank = tp_group.rank
|
||||
start, end = evenly_split(num_reqs, tp_size, tp_rank)
|
||||
logits = logits[start:end]
|
||||
pos = pos[start:end]
|
||||
idx_mapping_np = idx_mapping_np[start:end]
|
||||
|
||||
sampling_metadata = self.req_states.make_sampling_metadata(
|
||||
idx_mapping_np, pos)
|
||||
sampler_output = self.sampler(
|
||||
logits=logits,
|
||||
sampling_metadata=sampling_metadata,
|
||||
)
|
||||
|
||||
needs_prompt_logprobs = np.any(
|
||||
self.req_states.needs_prompt_logprobs[idx_mapping_np])
|
||||
assert not needs_prompt_logprobs
|
||||
|
||||
if use_dp_sampler:
|
||||
# All-gather the outputs.
|
||||
sampler_output = all_gather_sampler_output(
|
||||
sampler_output,
|
||||
num_reqs,
|
||||
tp_size,
|
||||
)
|
||||
return sampler_output
|
||||
|
||||
def postprocess(
|
||||
self,
|
||||
sampler_output: SamplerOutput,
|
||||
input_batch: InputBatch,
|
||||
) -> AsyncOutput:
|
||||
# Store the last sampled token ids.
|
||||
self.req_states.last_sampled_tokens[input_batch.idx_mapping] = (
|
||||
sampler_output.sampled_token_ids)
|
||||
|
||||
# Get the number of sampled tokens.
|
||||
# 0 if chunked-prefilling, 1 if not.
|
||||
is_chunked_prefilling = input_batch.is_chunked_prefilling
|
||||
num_sampled_tokens = (~is_chunked_prefilling).astype(np.int32)
|
||||
# Increment the number of tokens.
|
||||
idx_mapping_np = input_batch.idx_mapping_np
|
||||
self.req_states.num_tokens.np[idx_mapping_np] += num_sampled_tokens
|
||||
# Increment the number of computed tokens.
|
||||
self.req_states.num_computed_tokens[idx_mapping_np] += (
|
||||
input_batch.num_scheduled_tokens)
|
||||
|
||||
model_runner_output = ModelRunnerOutput(
|
||||
req_ids=input_batch.req_ids,
|
||||
sampled_token_ids=None,
|
||||
num_sampled_tokens=num_sampled_tokens,
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[],
|
||||
kv_connector_output=None,
|
||||
num_nans_in_logits=None,
|
||||
)
|
||||
return AsyncOutput(
|
||||
model_runner_output=model_runner_output,
|
||||
sampler_output=sampler_output,
|
||||
copy_stream=self.output_copy_stream,
|
||||
)
|
||||
|
||||
def execute_model(
|
||||
self,
|
||||
scheduler_output: SchedulerOutput,
|
||||
) -> AsyncOutput:
|
||||
self.update_states(scheduler_output)
|
||||
if scheduler_output.total_num_scheduled_tokens == 0:
|
||||
return EMPTY_MODEL_RUNNER_OUTPUT
|
||||
|
||||
input_batch = self.prepare_inputs(scheduler_output)
|
||||
num_tokens = input_batch.num_tokens_after_padding
|
||||
|
||||
with set_forward_context(
|
||||
input_batch.attn_metadata,
|
||||
self.vllm_config,
|
||||
num_tokens=num_tokens,
|
||||
):
|
||||
hidden_states = self.model(
|
||||
input_ids=input_batch.input_ids,
|
||||
positions=input_batch.positions,
|
||||
)
|
||||
|
||||
sampler_output = self.sample(hidden_states, input_batch)
|
||||
return self.postprocess(sampler_output, input_batch)
|
323
vllm/v1/worker/gpu/sampler.py
Normal file
323
vllm/v1/worker/gpu/sampler.py
Normal file
@ -0,0 +1,323 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from vllm.config import LogprobsMode
|
||||
from vllm.v1.outputs import LogprobsTensors, SamplerOutput
|
||||
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
|
||||
from vllm.v1.worker.gpu.states import SamplingMetadata
|
||||
|
||||
_SAMPLING_EPS = 1e-5
|
||||
|
||||
|
||||
class Sampler(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
logprobs_mode: LogprobsMode = "processed_logprobs",
|
||||
):
|
||||
super().__init__()
|
||||
assert logprobs_mode == "processed_logprobs"
|
||||
self.logprobs_mode = logprobs_mode
|
||||
|
||||
def forward(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> SamplerOutput:
|
||||
# Divide logits by temperature, in FP32.
|
||||
logits = apply_temperature(logits, sampling_metadata.temperature)
|
||||
|
||||
# Apply top_k and/or top_p.
|
||||
logits = apply_top_k_top_p(
|
||||
logits,
|
||||
sampling_metadata.top_k,
|
||||
sampling_metadata.top_p,
|
||||
)
|
||||
|
||||
# Compute the probabilities.
|
||||
probs = torch.softmax(logits, dim=-1, dtype=torch.float32)
|
||||
# Sample the next token (int64).
|
||||
sampled = gumbel_sample(
|
||||
probs,
|
||||
sampling_metadata.temperature,
|
||||
sampling_metadata.seeds,
|
||||
sampling_metadata.pos,
|
||||
)
|
||||
|
||||
logprobs_tensors = None
|
||||
num_logprobs = sampling_metadata.max_num_logprobs
|
||||
if num_logprobs is not None:
|
||||
logprobs_tensors = compute_logprobs(
|
||||
logits,
|
||||
num_logprobs,
|
||||
sampled,
|
||||
)
|
||||
|
||||
# These are GPU tensors.
|
||||
sampler_output = SamplerOutput(
|
||||
# The sampled tokens are expanded to 2D tensor with shape
|
||||
# [num_requests, 1], where each row represents one generated
|
||||
# token per request.
|
||||
sampled_token_ids=sampled.view(-1, 1),
|
||||
logprobs_tensors=logprobs_tensors,
|
||||
)
|
||||
return sampler_output
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _apply_temp_kernel(
|
||||
logits, # bf16[batch_size, vocab_size]
|
||||
logits_stride,
|
||||
output, # fp32[batch_size, vocab_size]
|
||||
output_stride,
|
||||
temperature,
|
||||
vocab_size,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
EPSILON: tl.constexpr,
|
||||
):
|
||||
batch_idx = tl.program_id(0)
|
||||
block_idx = tl.program_id(1)
|
||||
|
||||
temp = tl.load(temperature + batch_idx)
|
||||
if temp < EPSILON:
|
||||
# Greedy sampling. Don't apply temperature.
|
||||
# NOTE(woosuk): In this case, we assume that its logprobs are not used.
|
||||
temp = 1.0
|
||||
|
||||
offset = tl.arange(0, BLOCK_SIZE)
|
||||
block = block_idx * BLOCK_SIZE + offset
|
||||
|
||||
# Load the logits.
|
||||
x = tl.load(logits + batch_idx * logits_stride + block,
|
||||
mask=block < vocab_size)
|
||||
x = x.to(tl.float32)
|
||||
x = x / temp
|
||||
tl.store(output + batch_idx * output_stride + block,
|
||||
x,
|
||||
mask=block < vocab_size)
|
||||
|
||||
|
||||
def apply_temperature(
|
||||
logits: torch.Tensor,
|
||||
temperature: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
batch_size, vocab_size = logits.shape
|
||||
output = torch.empty_like(logits, dtype=torch.float32)
|
||||
BLOCK_SIZE = 8192
|
||||
_apply_temp_kernel[(batch_size, triton.cdiv(vocab_size, BLOCK_SIZE))](
|
||||
logits,
|
||||
logits.stride(0),
|
||||
output,
|
||||
output.stride(0),
|
||||
temperature,
|
||||
vocab_size,
|
||||
BLOCK_SIZE=BLOCK_SIZE,
|
||||
EPSILON=_SAMPLING_EPS,
|
||||
)
|
||||
return output
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _apply_gumbel_kernel(
|
||||
probs_ptr,
|
||||
probs_stride,
|
||||
seeds_ptr,
|
||||
pos_ptr,
|
||||
temp_ptr,
|
||||
vocab_size,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
EPSILON: tl.constexpr,
|
||||
):
|
||||
req_idx = tl.program_id(0)
|
||||
temp = tl.load(temp_ptr + req_idx)
|
||||
|
||||
if temp < EPSILON:
|
||||
# Greedy sampling. Don't apply gumbel noise.
|
||||
return
|
||||
|
||||
seed = tl.load(seeds_ptr + req_idx).to(tl.uint64)
|
||||
pos = tl.load(pos_ptr + req_idx).to(tl.uint64)
|
||||
gumbel_seed = seed ^ (pos * 0x9E3779B97F4A7C15)
|
||||
|
||||
block_id = tl.program_id(1)
|
||||
r_offset = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
||||
q = tl.rand(gumbel_seed, r_offset)
|
||||
|
||||
# NOTE(woosuk): This logic makes sure q is not 0.
|
||||
RMAX = 0.9999999403953552
|
||||
RMAX_LOG = -5.960464477539063e-08
|
||||
q = tl.where(q >= RMAX, RMAX_LOG, tl.math.log(q))
|
||||
q = -1.0 * q
|
||||
|
||||
p = tl.load(probs_ptr + req_idx * probs_stride + r_offset,
|
||||
mask=r_offset < vocab_size)
|
||||
p = p / q
|
||||
|
||||
tl.store(probs_ptr + req_idx * probs_stride + r_offset,
|
||||
p,
|
||||
mask=r_offset < vocab_size)
|
||||
|
||||
|
||||
def gumbel_sample(
|
||||
# fp32[num_reqs, vocab_size]
|
||||
probs: torch.Tensor,
|
||||
# fp32[num_reqs]
|
||||
temperature: torch.Tensor,
|
||||
# int64[num_reqs]
|
||||
seeds: torch.Tensor,
|
||||
# int64[num_reqs]
|
||||
pos: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
num_reqs = probs.shape[0]
|
||||
vocab_size = probs.shape[1]
|
||||
|
||||
# Update the probs in-place.
|
||||
BLOCK_SIZE = 8192
|
||||
_apply_gumbel_kernel[(num_reqs, triton.cdiv(vocab_size, BLOCK_SIZE))](
|
||||
probs,
|
||||
probs.stride(0),
|
||||
seeds,
|
||||
pos,
|
||||
temperature,
|
||||
vocab_size,
|
||||
BLOCK_SIZE,
|
||||
EPSILON=_SAMPLING_EPS,
|
||||
)
|
||||
# Sample the next token.
|
||||
return probs.argmax(dim=-1).view(-1)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _topk_log_softmax_kernel(
|
||||
output_ptr,
|
||||
logits_ptr,
|
||||
logits_stride,
|
||||
topk_ids_ptr,
|
||||
topk,
|
||||
vocab_size,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
PADDED_TOPK: tl.constexpr,
|
||||
):
|
||||
req_idx = tl.program_id(0)
|
||||
row_ptr = logits_ptr + req_idx * logits_stride
|
||||
|
||||
max_val = float("-inf")
|
||||
for i in range(0, vocab_size, BLOCK_SIZE):
|
||||
block = i + tl.arange(0, BLOCK_SIZE)
|
||||
l = tl.load(row_ptr + block,
|
||||
mask=block < vocab_size,
|
||||
other=float("-inf"))
|
||||
max_val = tl.max(tl.maximum(l, max_val))
|
||||
|
||||
se = 0.0
|
||||
for i in range(0, vocab_size, BLOCK_SIZE):
|
||||
block = i + tl.arange(0, BLOCK_SIZE)
|
||||
l = tl.load(row_ptr + block, mask=block < vocab_size, other=0.0)
|
||||
e = tl.exp(l - max_val)
|
||||
e = tl.where(block < vocab_size, e, 0.0)
|
||||
se += tl.sum(e)
|
||||
lse = tl.log(se)
|
||||
|
||||
k_offset = tl.arange(0, PADDED_TOPK)
|
||||
k_mask = k_offset < topk
|
||||
topk_ids = tl.load(topk_ids_ptr + req_idx * topk + k_offset, mask=k_mask)
|
||||
|
||||
l = tl.load(row_ptr + topk_ids, mask=k_mask)
|
||||
o = l - max_val - lse
|
||||
tl.store(output_ptr + req_idx * topk + k_offset, o, mask=k_mask)
|
||||
|
||||
|
||||
def compute_topk_logprobs(
|
||||
logits: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
batch_size, vocab_size = logits.shape
|
||||
topk = topk_ids.shape[1]
|
||||
output = torch.empty(
|
||||
batch_size,
|
||||
topk,
|
||||
dtype=torch.float32,
|
||||
device=logits.device,
|
||||
)
|
||||
_topk_log_softmax_kernel[(batch_size, )](
|
||||
output,
|
||||
logits,
|
||||
logits.stride(0),
|
||||
topk_ids,
|
||||
topk,
|
||||
vocab_size,
|
||||
BLOCK_SIZE=1024,
|
||||
PADDED_TOPK=triton.next_power_of_2(topk),
|
||||
)
|
||||
return output
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _ranks_kernel(
|
||||
output_ptr,
|
||||
logits_ptr,
|
||||
logits_stride,
|
||||
token_ids_ptr,
|
||||
vocab_size,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
req_idx = tl.program_id(0)
|
||||
row_ptr = logits_ptr + req_idx * logits_stride
|
||||
|
||||
token_id = tl.load(token_ids_ptr + req_idx)
|
||||
x = tl.load(row_ptr + token_id)
|
||||
|
||||
n = 0
|
||||
for i in range(0, vocab_size, BLOCK_SIZE):
|
||||
block = i + tl.arange(0, BLOCK_SIZE)
|
||||
l = tl.load(row_ptr + block,
|
||||
mask=block < vocab_size,
|
||||
other=float("-inf"))
|
||||
n += tl.sum((l > x).to(tl.int32))
|
||||
tl.store(output_ptr + req_idx, n)
|
||||
|
||||
|
||||
def compute_logprobs(
|
||||
logits: torch.Tensor,
|
||||
num_logprobs: int,
|
||||
sampled_token_ids: torch.Tensor,
|
||||
) -> LogprobsTensors:
|
||||
assert num_logprobs >= 0
|
||||
batch_size, vocab_size = logits.shape
|
||||
if num_logprobs == 0:
|
||||
logprob_token_ids = sampled_token_ids.unsqueeze(-1)
|
||||
else:
|
||||
topk_indices = torch.topk(logits, num_logprobs, dim=-1).indices
|
||||
logprob_token_ids = torch.cat(
|
||||
(sampled_token_ids.unsqueeze(-1), topk_indices), dim=1)
|
||||
|
||||
# NOTE(woosuk): Here, to save GPU memory, we do not materialize the full
|
||||
# logprobs tensor. Instead, we only compute and return the logprobs of
|
||||
# the topk + 1 tokens.
|
||||
logprobs = compute_topk_logprobs(
|
||||
logits,
|
||||
logprob_token_ids,
|
||||
)
|
||||
|
||||
token_ranks = torch.empty(
|
||||
batch_size,
|
||||
dtype=torch.int64,
|
||||
device=logits.device,
|
||||
)
|
||||
_ranks_kernel[(batch_size, )](
|
||||
token_ranks,
|
||||
logits,
|
||||
logits.stride(0),
|
||||
sampled_token_ids,
|
||||
vocab_size,
|
||||
BLOCK_SIZE=8192,
|
||||
)
|
||||
return LogprobsTensors(
|
||||
logprob_token_ids=logprob_token_ids,
|
||||
logprobs=logprobs,
|
||||
selected_token_ranks=token_ranks,
|
||||
)
|
229
vllm/v1/worker/gpu/states.py
Normal file
229
vllm/v1/worker/gpu/states.py
Normal file
@ -0,0 +1,229 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.v1.utils import CpuGpuBuffer
|
||||
|
||||
_NP_INT64_MIN = np.iinfo(np.int64).min
|
||||
_NP_INT64_MAX = np.iinfo(np.int64).max
|
||||
|
||||
|
||||
@dataclass
|
||||
class SamplingMetadata:
|
||||
|
||||
temperature: torch.Tensor
|
||||
|
||||
top_p: torch.Tensor | None
|
||||
top_k: torch.Tensor | None
|
||||
|
||||
seeds: torch.Tensor
|
||||
pos: torch.Tensor
|
||||
|
||||
# None means no logprobs, 0 means sampled token logprobs only
|
||||
max_num_logprobs: int | None
|
||||
|
||||
@classmethod
|
||||
def make_dummy(
|
||||
cls,
|
||||
num_reqs: int,
|
||||
device: torch.device,
|
||||
) -> "SamplingMetadata":
|
||||
assert num_reqs > 0
|
||||
temperature = torch.zeros(num_reqs, dtype=torch.float32, device=device)
|
||||
temperature[0] = 0.5
|
||||
top_p = torch.ones(num_reqs, dtype=torch.float32, device=device)
|
||||
top_p[0] = 0.99
|
||||
top_k = torch.ones(num_reqs, dtype=torch.int32, device=device)
|
||||
seeds = torch.zeros(num_reqs, dtype=torch.int64, device=device)
|
||||
pos = torch.zeros(num_reqs, dtype=torch.int64, device=device)
|
||||
max_num_logprobs = 20
|
||||
return cls(
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
top_k=top_k,
|
||||
seeds=seeds,
|
||||
pos=pos,
|
||||
max_num_logprobs=max_num_logprobs,
|
||||
)
|
||||
|
||||
|
||||
class RequestState:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_num_reqs: int,
|
||||
max_model_len: int,
|
||||
max_num_batched_tokens: int,
|
||||
vocab_size: int,
|
||||
device: torch.device,
|
||||
pin_memory: bool,
|
||||
):
|
||||
self.max_num_reqs = max_num_reqs
|
||||
self.max_model_len = max_model_len
|
||||
self.max_num_batched_tokens = max_num_batched_tokens
|
||||
self.vocab_size = vocab_size
|
||||
self.device = device
|
||||
self.pin_memory = pin_memory
|
||||
|
||||
self.req_id_to_index: dict[str, int] = {}
|
||||
self.index_to_req_id: dict[int, str] = {}
|
||||
self.free_indices = list(range(max_num_reqs))
|
||||
|
||||
# NOTE(woosuk): Strictly speaking, it contains prompt + some output
|
||||
# because of preemption.
|
||||
self.prompt_token_ids = np.zeros(
|
||||
(self.max_num_reqs, self.max_model_len),
|
||||
dtype=np.int32,
|
||||
)
|
||||
self.num_tokens = self._make_buffer(self.max_num_reqs,
|
||||
dtype=torch.int32)
|
||||
self.num_computed_tokens = np.zeros(self.max_num_reqs, dtype=np.int32)
|
||||
|
||||
# Last sampled tokens.
|
||||
self.last_sampled_tokens = torch.zeros(
|
||||
self.max_num_reqs,
|
||||
1,
|
||||
dtype=torch.int64,
|
||||
device=device,
|
||||
)
|
||||
|
||||
# Sampling parameters.
|
||||
self.temperature = self._make_param(self.max_num_reqs, torch.float32)
|
||||
self.top_p = self._make_param(self.max_num_reqs, torch.float32)
|
||||
self.top_k = self._make_param(self.max_num_reqs, torch.int32)
|
||||
self.seeds = self._make_param(self.max_num_reqs, torch.int64)
|
||||
|
||||
self.num_logprobs = np.empty(self.max_num_reqs, dtype=np.int32)
|
||||
# -1 means no logprobs are requested.
|
||||
self.num_logprobs.fill(-1)
|
||||
self.needs_prompt_logprobs = np.zeros(self.max_num_reqs, dtype=bool)
|
||||
|
||||
def _make_param(self, size: int, dtype: torch.dtype) -> "Param":
|
||||
return Param(size,
|
||||
dtype=dtype,
|
||||
device=self.device,
|
||||
pin_memory=self.pin_memory)
|
||||
|
||||
def _make_buffer(self, size: int, dtype: torch.dtype) -> CpuGpuBuffer:
|
||||
return CpuGpuBuffer(size,
|
||||
dtype=dtype,
|
||||
device=self.device,
|
||||
pin_memory=self.pin_memory)
|
||||
|
||||
@property
|
||||
def num_reqs(self) -> int:
|
||||
return len(self.req_id_to_index)
|
||||
|
||||
def add_request(
|
||||
self,
|
||||
req_id: str,
|
||||
prompt_token_ids: list[int],
|
||||
num_computed_tokens: int,
|
||||
sampling_params: SamplingParams,
|
||||
) -> None:
|
||||
assert len(self.free_indices) > 0
|
||||
req_idx = self.free_indices.pop()
|
||||
self.req_id_to_index[req_id] = req_idx
|
||||
self.index_to_req_id[req_idx] = req_id
|
||||
|
||||
# NOTE(woosuk): Strictly speaking, "prompt_len" here may include
|
||||
# output tokens, if the request is resumed from preemption.
|
||||
prompt_len = len(prompt_token_ids)
|
||||
self.prompt_token_ids[req_idx, :prompt_len] = prompt_token_ids
|
||||
self.num_tokens.np[req_idx] = prompt_len
|
||||
self.num_computed_tokens[req_idx] = num_computed_tokens
|
||||
# TODO(woosuk): Optimize.
|
||||
self.last_sampled_tokens[req_idx].fill_(-1)
|
||||
|
||||
self.temperature.np[req_idx] = sampling_params.temperature
|
||||
self.top_p.np[req_idx] = sampling_params.top_p
|
||||
if 0 < sampling_params.top_k < self.vocab_size:
|
||||
top_k = sampling_params.top_k
|
||||
else:
|
||||
top_k = self.vocab_size
|
||||
self.top_k.np[req_idx] = top_k
|
||||
|
||||
if sampling_params.seed is not None:
|
||||
seed = sampling_params.seed
|
||||
else:
|
||||
seed = np.random.randint(_NP_INT64_MIN, _NP_INT64_MAX)
|
||||
self.seeds.np[req_idx] = seed
|
||||
|
||||
if sampling_params.logprobs is not None:
|
||||
num_logprobs = sampling_params.logprobs
|
||||
else:
|
||||
num_logprobs = -1
|
||||
self.num_logprobs[req_idx] = num_logprobs
|
||||
|
||||
# For now, only support prompt logprobs for the prompt tokens.
|
||||
needs_prompt_logprobs = sampling_params.prompt_logprobs is not None
|
||||
self.needs_prompt_logprobs[req_idx] = needs_prompt_logprobs
|
||||
|
||||
def remove_request(self, req_id: str) -> None:
|
||||
req_idx = self.req_id_to_index.pop(req_id, None)
|
||||
if req_idx is None:
|
||||
# Request not found.
|
||||
return
|
||||
self.index_to_req_id.pop(req_idx, None)
|
||||
self.free_indices.append(req_idx)
|
||||
|
||||
def make_sampling_metadata(
|
||||
self,
|
||||
idx_mapping: np.ndarray,
|
||||
pos: torch.Tensor,
|
||||
) -> SamplingMetadata:
|
||||
temperature = self.temperature.np[idx_mapping]
|
||||
temperature = self.temperature.copy_np_to_gpu(temperature)
|
||||
|
||||
top_p = self.top_p.np[idx_mapping]
|
||||
no_top_p = np.all(top_p == 1.0)
|
||||
top_p = self.top_p.copy_np_to_gpu(top_p) if not no_top_p else None
|
||||
|
||||
top_k = self.top_k.np[idx_mapping]
|
||||
no_top_k = np.all(top_k == self.vocab_size)
|
||||
top_k = self.top_k.copy_np_to_gpu(top_k) if not no_top_k else None
|
||||
|
||||
seeds = self.seeds.np[idx_mapping]
|
||||
seeds = self.seeds.copy_np_to_gpu(seeds)
|
||||
|
||||
num_logprobs = self.num_logprobs[idx_mapping]
|
||||
max_num_logprobs = int(np.max(num_logprobs))
|
||||
if max_num_logprobs == -1:
|
||||
max_num_logprobs = None
|
||||
|
||||
return SamplingMetadata(
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
top_k=top_k,
|
||||
seeds=seeds,
|
||||
pos=pos,
|
||||
max_num_logprobs=max_num_logprobs,
|
||||
)
|
||||
|
||||
|
||||
class Param:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
size: int,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
pin_memory: bool,
|
||||
):
|
||||
self.buffer = CpuGpuBuffer(
|
||||
size,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
pin_memory=pin_memory,
|
||||
)
|
||||
self.np = np.zeros_like(self.buffer.np)
|
||||
|
||||
def copy_np_to_gpu(self, x: np.ndarray) -> torch.Tensor:
|
||||
n = x.shape[0]
|
||||
self.buffer.np[:n] = x
|
||||
return self.buffer.copy_to_gpu(n)
|
@ -31,7 +31,8 @@ from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
|
||||
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
|
||||
DraftTokenIds, ModelRunnerOutput)
|
||||
from vllm.v1.utils import report_usage_stats
|
||||
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
|
||||
# from vllm.v1.worker.gpu_model_runner import GPUModelRunner
|
||||
from vllm.v1.worker.gpu.model_runner import GPUModelRunner
|
||||
from vllm.v1.worker.utils import is_residual_scattered_for_sp
|
||||
from vllm.v1.worker.worker_base import WorkerBase
|
||||
|
||||
@ -341,7 +342,9 @@ class Worker(WorkerBase):
|
||||
self.model_runner._dummy_run(size,
|
||||
skip_eplb=True,
|
||||
remove_lora=False)
|
||||
self.model_runner.maybe_remove_all_loras(self.model_runner.lora_config)
|
||||
if self.model_runner.lora_config is not None:
|
||||
self.model_runner.maybe_remove_all_loras(
|
||||
self.model_runner.lora_config)
|
||||
|
||||
# Warmup and tune the kernels used during model execution before
|
||||
# cuda graph capture.
|
||||
@ -436,6 +439,9 @@ class Worker(WorkerBase):
|
||||
self,
|
||||
scheduler_output: "SchedulerOutput",
|
||||
) -> Optional[Union[ModelRunnerOutput, AsyncModelRunnerOutput]]:
|
||||
if len(get_pp_group().ranks) == 1:
|
||||
return self.model_runner.execute_model(scheduler_output)
|
||||
|
||||
intermediate_tensors = None
|
||||
forward_pass = scheduler_output.total_num_scheduled_tokens > 0
|
||||
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
|
||||
@ -454,8 +460,6 @@ class Worker(WorkerBase):
|
||||
|
||||
output = self.model_runner.execute_model(scheduler_output,
|
||||
intermediate_tensors)
|
||||
if isinstance(output, (ModelRunnerOutput, AsyncModelRunnerOutput)):
|
||||
return output
|
||||
|
||||
assert isinstance(output, IntermediateTensors)
|
||||
parallel_config = self.vllm_config.parallel_config
|
||||
@ -692,8 +696,9 @@ class Worker(WorkerBase):
|
||||
tensorizer_config=tensorizer_config, )
|
||||
|
||||
def shutdown(self) -> None:
|
||||
if runner := getattr(self, "model_runner", None):
|
||||
runner.ensure_kv_transfer_shutdown()
|
||||
# if runner := getattr(self, "model_runner", None):
|
||||
# runner.ensure_kv_transfer_shutdown()
|
||||
pass
|
||||
|
||||
|
||||
def init_worker_distributed_environment(
|
||||
|
Reference in New Issue
Block a user