Mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 23:03:52 +08:00)

Compare commits: v0.11.1rc1...bind_kv_ca (1 commit)

Commit: bfff9bcd1d
@@ -10,10 +10,14 @@ prompts = [
     "The future of AI is",
 ]
 # Create a sampling params object.
-sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
+sampling_params = SamplingParams() #temperature=0.8, top_p=0.95)

 # Create an LLM.
-llm = LLM(model="facebook/opt-125m")
+# llm = LLM(model="facebook/opt-125m")
+llm = LLM(model="Qwen/Qwen2-1.5B-Instruct",
+          max_num_seqs=16,
+          max_model_len=128,
+          enforce_eager=True)
 # Generate texts from the prompts. The output is a list of RequestOutput objects
 # that contain the prompt, generated text, and other information.
 outputs = llm.generate(prompts, sampling_params)
@@ -21,4 +25,4 @@ outputs = llm.generate(prompts, sampling_params)
 for output in outputs:
     prompt = output.prompt
     generated_text = output.outputs[0].text
-    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
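For orientation, a minimal sketch of how the example reads after this change (the file appears to be one of the offline-inference examples; only one illustrative prompt is shown, since the hunk does not include the full prompt list):

from vllm import LLM, SamplingParams

prompts = [
    "The future of AI is",
]

# Default sampling parameters; the tuned values are left in a comment.
sampling_params = SamplingParams()  # temperature=0.8, top_p=0.95

# Small Qwen2 model, short context, eager mode -- a lightweight setup for
# TPU bring-up and debugging.
llm = LLM(model="Qwen/Qwen2-1.5B-Instruct",
          max_num_seqs=16,
          max_model_len=128,
          enforce_eager=True)

outputs = llm.generate(prompts, sampling_params)
for output in outputs:
    print(f"Prompt: {output.prompt!r}, "
          f"Generated text: {output.outputs[0].text!r}")

The remaining hunks appear to touch vllm/forward_context.py, vllm/v1/worker/tpu_model_runner.py, and vllm/v1/worker/tpu_worker.py, judging by their imports and class names.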
@@ -12,6 +12,7 @@ import torch.distributed as dist

 import vllm.envs as envs
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
+from vllm.platforms import current_platform

 if TYPE_CHECKING:
     from vllm.attention.backends.abstract import AttentionMetadata
@@ -33,14 +34,16 @@ class DPMetadata:

 @dataclass
 class ForwardContext:
-    # copy from vllm_config.compilation_config.static_forward_context
+    # Copy from vllm_config.compilation_config.static_forward_context
     no_compile_layers: dict[str, Any]
-    # TODO: extend to support per-layer dynamic forward context
+    # TODO: Extend to support per-layer dynamic forward context
     attn_metadata: "AttentionMetadata"  # set dynamically for each forward pass
-    # TODO: remove after making all virtual_engines share the same kv cache
+    # TODO: Remove after making all virtual_engines share the same kv cache
     virtual_engine: int  # set dynamically for each forward pass
-    # set dynamically for each forward pass
+    # Set dynamically for each forward pass
     dp_metadata: Optional[DPMetadata] = None
+    # Whether this is a profile run (before KV cache init)
+    is_profile_run: bool = False


 _forward_context: Optional[ForwardContext] = None
@@ -58,7 +61,8 @@ def get_forward_context() -> ForwardContext:
 def set_forward_context(attn_metadata: Any,
                         vllm_config: VllmConfig,
                         virtual_engine: int = 0,
-                        num_tokens: int = 0):
+                        num_tokens: int = 0,
+                        is_profile_run: bool = False):
     """A context manager that stores the current forward context,
     can be attention metadata, etc.
     Here we can inject common logic for every model forward pass.
@@ -93,12 +97,15 @@ def set_forward_context(attn_metadata: Any,

     global _forward_context
     prev_context = _forward_context
     _forward_context = ForwardContext(
         no_compile_layers=vllm_config.compilation_config.
         static_forward_context,
         virtual_engine=virtual_engine,
         attn_metadata=attn_metadata,
-        dp_metadata=dp_metadata)
+        dp_metadata=dp_metadata,
+        is_profile_run=is_profile_run)

     try:
         yield
     finally:
@@ -111,10 +118,17 @@ def set_forward_context(attn_metadata: Any,
             else:
                 # for v1 attention backends
                 batchsize = attn_metadata.num_input_tokens

             # we use synchronous scheduling right now,
             # adding a sync point here should not affect
             # scheduling of the next batch
-            torch.cuda.synchronize()
+            if current_platform.is_tpu():
+                import torch_xla.core.xla_model as xm
+                xm.mark_step()
+                xm.wait_device_ops()
+            else:
+                torch.cuda.synchronize()

             now = time.perf_counter()
             # time measurement is in milliseconds
             batchsize_forward_time[batchsize].append(
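Taken together, the forward_context.py hunks add an is_profile_run flag that travels from the caller into the per-pass ForwardContext, and make the post-forward timing sync TPU-aware. A minimal sketch of the intended flow (the wrapper function here is illustrative, not code from this commit):

from vllm.forward_context import get_forward_context, set_forward_context

def run_forward(model, inputs, attn_metadata, vllm_config, profiling: bool):
    with set_forward_context(attn_metadata,
                             vllm_config,
                             virtual_engine=0,
                             is_profile_run=profiling):
        # Any code running inside the context can now ask whether this is the
        # profiling pass (i.e. before the KV caches exist) and skip
        # cache-dependent work such as slot_mapping fix-ups.
        if get_forward_context().is_profile_run:
            pass
        return model(**inputs)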
@@ -30,7 +30,6 @@ from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                         KVCacheSpec)
 from vllm.v1.outputs import LogprobsTensors, ModelRunnerOutput
-from vllm.v1.utils import bind_kv_cache
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch

 if TYPE_CHECKING:
@@ -104,9 +103,6 @@ class TPUModelRunner:
         self.max_num_encoder_input_tokens = encoder_compute_budget
         self.encoder_cache_size = encoder_cache_size

         # Lazy initialization
         # self.model: nn.Module  # Set after load_model
-        self.kv_caches: list[torch.Tensor] = []
         # req_id -> (input_id -> encoder_output)
         self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {}

@@ -582,7 +578,6 @@ class TPUModelRunner:
         hidden_states = self.model(
             input_ids=input_ids,
             positions=self.position_ids,
-            kv_caches=self.kv_caches,
             inputs_embeds=inputs_embeds,
         )
         hidden_states = hidden_states[:total_num_scheduled_tokens]
@@ -680,8 +675,8 @@ class TPUModelRunner:

     def _dummy_run(
         self,
-        kv_caches,
         num_tokens: int,
+        is_profile_run: bool,
     ) -> None:
         if self.is_multimodal_model:
             input_ids = None
@@ -728,15 +723,28 @@ class TPUModelRunner:
         torch._dynamo.mark_dynamic(attn_metadata.query_start_loc, 0)
         torch._dynamo.mark_dynamic(attn_metadata.context_lens, 0)

-        with set_forward_context(attn_metadata, self.vllm_config, 0):
+        with set_forward_context(attn_metadata,
+                                 self.vllm_config,
+                                 0,
+                                 is_profile_run=is_profile_run):
             assert self.model is not None
             self.model(
                 input_ids=input_ids,
                 positions=position_ids,
-                kv_caches=kv_caches,
                 inputs_embeds=inputs_embeds,
             )

+    # This is used before KV cache init
+    def profile_run(self, num_tokens) -> None:
+        self._dummy_run(num_tokens=num_tokens, is_profile_run=True)
+
+    # This is used after KV cache init
+    def dummy_run(
+        self,
+        num_tokens: int,
+    ) -> None:
+        self._dummy_run(num_tokens=num_tokens, is_profile_run=False)
+
     def capture_model(self) -> None:
         """Compile the model."""

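The new wrappers split the pre-KV-cache profiling pass from the ordinary warm-up passes. Roughly how the two entry points are meant to be called, based on the TPUWorker hunk further down and the capture_model() hunk that follows (runner and scheduler_config are placeholders for an initialized TPUModelRunner and its scheduler config):

def memory_profile(runner, scheduler_config) -> None:
    # Pre-KV-cache sizing pass (mirrors the TPUWorker hunk further down).
    runner.profile_run(num_tokens=scheduler_config.max_num_batched_tokens)

def warm_up(runner) -> None:
    # Post-KV-cache warm-up, as in the capture_model() hunk that follows.
    runner.dummy_run(num_tokens=16)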
@@ -745,7 +753,7 @@ class TPUModelRunner:
         start = time.perf_counter()
         num_tokens = 16
         while True:
-            self._dummy_run(self.kv_caches, num_tokens)
+            self.dummy_run(num_tokens)
             logger.info(" -- num_tokens: %d", num_tokens)
             xm.mark_step()
             xm.wait_device_ops()
@@ -769,6 +777,7 @@ class TPUModelRunner:

         kv_caches: dict[str, torch.Tensor] = {}

+        kv_cache_shape_prev = None
         for layer_name, layer_spec in kv_cache_config.kv_cache_spec.items():
             tensor_config = kv_cache_config.tensors[layer_name]
             assert tensor_config.size % layer_spec.page_size_bytes == 0
@@ -779,6 +788,12 @@ class TPUModelRunner:
                                        layer_spec.head_size)
                 dtype = layer_spec.dtype

+                # Ensure all "kv_cache_shape" are the same across the model
+                if kv_cache_shape_prev is None:
+                    kv_cache_shape_prev = kv_cache_shape
+                else:
+                    assert kv_cache_shape == kv_cache_shape_prev
+
                 tpu_k_cache = torch.zeros(kv_cache_shape,
                                           dtype=dtype,
                                           device=self.device)
@@ -788,10 +803,16 @@ class TPUModelRunner:
             else:
                 raise NotImplementedError

-        bind_kv_cache(
-            kv_caches,
-            self.vllm_config.compilation_config.static_forward_context,
-            self.kv_caches)
+        # ModelWrapperV1 needs to know the KV cache shape
+        self.model.set_kv_cache_shape(kv_cache_shape_prev)
+
+        # Associates each attention layer in the `forward_context` with the
+        # initialized KV cache.
+        forward_context = self.vllm_config.compilation_config \
+            .static_forward_context
+        for layer_name, kv_cache in kv_caches.items():
+            # NOTE: Use list because of v0 PP virtual engine.
+            forward_context[layer_name].kv_cache = [kv_cache]


 class ModelWrapperV1(nn.Module):
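This inline loop replaces the removed bind_kv_cache helper: every attention layer registered in static_forward_context gets its kv_cache attribute pointed at the freshly allocated TPU tensors, wrapped in a one-element list for the v0 pipeline-parallel virtual-engine convention. A sketch, under the assumption that layers read the attribute back the same way, of how a layer's cache can be looked up during a forward pass:

from vllm.forward_context import get_forward_context

def lookup_layer_kv_cache(layer_name: str):
    # no_compile_layers is the same mapping as static_forward_context (see the
    # ForwardContext dataclass above), so it holds the objects bound in the
    # loop; the per-layer list is indexed by the current virtual engine.
    ctx = get_forward_context()
    layer = ctx.no_compile_layers[layer_name]
    return layer.kv_cache[ctx.virtual_engine]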
@@ -799,12 +820,15 @@ class ModelWrapperV1(nn.Module):
     def __init__(self, model: nn.Module):
         super().__init__()
         self.model = model
+        self.kv_cache_shape = None
+
+    def set_kv_cache_shape(self, kv_cache_shape):
+        self.kv_cache_shape = kv_cache_shape

     def forward(
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: list[tuple[torch.Tensor, torch.Tensor]],
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Executes the forward pass of the model and samples the next token.
@@ -817,16 +841,20 @@ class ModelWrapperV1(nn.Module):
             inputs_embeds: The input embeddings of shape [num_tokens,
                 hidden_size]. It is used for multimodal models.
         """
-        # Skip this in memory profiling at initialization.
-        if kv_caches[0][0].numel() > 0:
-            attn_metadata = get_forward_context().attn_metadata
-            # index_copy_(slot_mapping) only works when the inserted dimension
-            # is 0. However, the KV cache in the Pallas backend has the shape
-            # [num_kv_heads, num_blocks, block_size, head_size]. To make it
-            # work, we need to flatten the first three dimensions and modify
-            # the slot_mapping accordingly.
-            # kv_caches: list[tuple[torch.Tensor, torch.Tensor]]
-            num_kv_heads, num_blocks, block_size, _ = kv_caches[0][0].shape
+        forward_context = get_forward_context()
+        attn_metadata = forward_context.attn_metadata
+
+        # index_copy_(slot_mapping) only works when the inserted dimension
+        # is 0. However, the KV cache in the Pallas backend has the shape
+        # [num_kv_heads, num_blocks, block_size, head_size]. To make it
+        # work, we need to flatten the first three dimensions and modify
+        # the slot_mapping accordingly.
+        #
+        # Note: We skip this step during first profiling run (before KV init)
+        if not forward_context.is_profile_run:
+            assert self.kv_cache_shape  # Ensure initialized
+            num_kv_heads, num_blocks, block_size, _ = self.kv_cache_shape

             slot_mapping = attn_metadata.slot_mapping
             slot_mapping = slot_mapping.flatten()
             head_indicies = torch.arange(0,
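The comment block above describes collapsing the first three KV-cache dimensions so that index_copy_ can write along dim 0. A small worked sketch of that index arithmetic (the numbers are illustrative, not taken from the commit):

# KV cache shape: [num_kv_heads, num_blocks, block_size, head_size].
num_kv_heads, num_blocks, block_size = 8, 512, 16

def flat_slot(head_idx: int, block_idx: int, block_offset: int) -> int:
    # Index into the flattened view of shape
    # [num_kv_heads * num_blocks * block_size, head_size].
    return (head_idx * num_blocks + block_idx) * block_size + block_offset

# The scheduler's slot_mapping addresses only (block_idx, block_offset), so it
# has to be replicated once per KV head with a per-head offset added on top.
per_head_offset = num_blocks * block_size
assert flat_slot(3, 10, 5) == 3 * per_head_offset + 10 * block_size + 5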
@@ -21,7 +21,6 @@ from vllm.v1.core.scheduler import SchedulerOutput
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                         KVCacheSpec)
 from vllm.v1.outputs import ModelRunnerOutput
-from vllm.v1.utils import bind_kv_cache
 from vllm.v1.worker.tpu_model_runner import TPUModelRunner

 logger = init_logger(__name__)
@@ -128,18 +127,19 @@ class TPUWorker:
         else:
             raise NotImplementedError

-        runner_kv_caches: list[torch.Tensor] = []
-        bind_kv_cache(
-            kv_caches,
-            self.vllm_config.compilation_config.static_forward_context,
-            runner_kv_caches)
+        # Associates each attention layer in the `forward_context` with the
+        # initialized KV cache.
+        forward_context = self.vllm_config.compilation_config \
+            .static_forward_context
+        for layer_name, kv_cache in kv_caches.items():
+            # NOTE: Use list because of v0 PP virtual engine.
+            forward_context[layer_name].kv_cache = [kv_cache]

-        self.model_runner._dummy_run(
-            runner_kv_caches,
-            num_tokens=self.scheduler_config.max_num_batched_tokens,
-        )
+        self.model_runner.profile_run(
+            num_tokens=self.scheduler_config.max_num_batched_tokens)

         # Synchronize before measuring the memory usage.
         xm.mark_step()
         xm.wait_device_ops()

         # Get the maximum amount of memory used by the model weights and