Compare commits

...

1 Commit

Author SHA1 Message Date
bfff9bcd1d [V1] TPU - Remove self.kv_caches 2025-03-05 20:42:05 +00:00
4 changed files with 90 additions and 44 deletions

View File

@@ -10,10 +10,14 @@ prompts = [
"The future of AI is",
]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
sampling_params = SamplingParams() #temperature=0.8, top_p=0.95)
# Create an LLM.
llm = LLM(model="facebook/opt-125m")
# llm = LLM(model="facebook/opt-125m")
llm = LLM(model="Qwen/Qwen2-1.5B-Instruct",
max_num_seqs=16,
max_model_len=128,
enforce_eager=True)
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
@@ -21,4 +25,4 @@ outputs = llm.generate(prompts, sampling_params)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

View File

@@ -12,6 +12,7 @@ import torch.distributed as dist
import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.platforms import current_platform
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
@@ -33,14 +34,16 @@ class DPMetadata:
@dataclass
class ForwardContext:
# copy from vllm_config.compilation_config.static_forward_context
# Copy from vllm_config.compilation_config.static_forward_context
no_compile_layers: dict[str, Any]
# TODO: extend to support per-layer dynamic forward context
# TODO: Extend to support per-layer dynamic forward context
attn_metadata: "AttentionMetadata" # set dynamically for each forward pass
# TODO: remove after making all virtual_engines share the same kv cache
# TODO: Remove after making all virtual_engines share the same kv cache
virtual_engine: int # set dynamically for each forward pass
# set dynamically for each forward pass
# Set dynamically for each forward pass
dp_metadata: Optional[DPMetadata] = None
# Whether this is a profile run (before KV cache init)
is_profile_run: bool = False
_forward_context: Optional[ForwardContext] = None
@@ -58,7 +61,8 @@ def get_forward_context() -> ForwardContext:
def set_forward_context(attn_metadata: Any,
vllm_config: VllmConfig,
virtual_engine: int = 0,
num_tokens: int = 0):
num_tokens: int = 0,
is_profile_run: bool = False):
"""A context manager that stores the current forward context,
can be attention metadata, etc.
Here we can inject common logic for every model forward pass.
@@ -93,12 +97,15 @@ def set_forward_context(attn_metadata: Any,
global _forward_context
prev_context = _forward_context
_forward_context = ForwardContext(
no_compile_layers=vllm_config.compilation_config.
static_forward_context,
virtual_engine=virtual_engine,
attn_metadata=attn_metadata,
dp_metadata=dp_metadata)
dp_metadata=dp_metadata,
is_profile_run=is_profile_run)
try:
yield
finally:
@@ -111,10 +118,17 @@ def set_forward_context(attn_metadata: Any,
else:
# for v1 attention backends
batchsize = attn_metadata.num_input_tokens
# we use synchronous scheduling right now,
# adding a sync point here should not affect
# scheduling of the next batch
torch.cuda.synchronize()
if current_platform.is_tpu():
import torch_xla.core.xla_model as xm
xm.mark_step()
xm.wait_device_ops()
else:
torch.cuda.synchronize()
now = time.perf_counter()
# time measurement is in milliseconds
batchsize_forward_time[batchsize].append(
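
A minimal usage sketch of the `is_profile_run` flag added above. It assumes `set_forward_context` and `get_forward_context` are imported from `vllm.forward_context`; `model`, `cfg`, and `metadata` are placeholders for a loaded model, a `VllmConfig`, and an attention-metadata object, not names from this diff:

```python
# Hedged sketch, not code from this change.
from vllm.forward_context import get_forward_context, set_forward_context

def profiling_forward(model, input_ids, positions, cfg, metadata):
    # With is_profile_run=True, anything that reads the forward context
    # (e.g. ModelWrapperV1 in the TPU model runner below) can skip
    # KV-cache-dependent work, since no KV cache exists yet during profiling.
    with set_forward_context(metadata, cfg, virtual_engine=0,
                             is_profile_run=True):
        assert get_forward_context().is_profile_run
        return model(input_ids=input_ids, positions=positions)
```

On TPU, the timing path above also replaces `torch.cuda.synchronize()` with `xm.mark_step()` followed by `xm.wait_device_ops()`: the former cuts the lazily traced XLA graph, the latter blocks until all queued device work has finished, giving a comparable sync point for the batch-time measurement.
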

View File

@@ -30,7 +30,6 @@ from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheSpec)
from vllm.v1.outputs import LogprobsTensors, ModelRunnerOutput
from vllm.v1.utils import bind_kv_cache
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
if TYPE_CHECKING:
@@ -104,9 +103,6 @@ class TPUModelRunner:
self.max_num_encoder_input_tokens = encoder_compute_budget
self.encoder_cache_size = encoder_cache_size
# Lazy initialization
# self.model: nn.Module # Set after load_model
self.kv_caches: list[torch.Tensor] = []
# req_id -> (input_id -> encoder_output)
self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {}
@@ -582,7 +578,6 @@ class TPUModelRunner:
hidden_states = self.model(
input_ids=input_ids,
positions=self.position_ids,
kv_caches=self.kv_caches,
inputs_embeds=inputs_embeds,
)
hidden_states = hidden_states[:total_num_scheduled_tokens]
@@ -680,8 +675,8 @@ class TPUModelRunner:
def _dummy_run(
self,
kv_caches,
num_tokens: int,
is_profile_run: bool,
) -> None:
if self.is_multimodal_model:
input_ids = None
@@ -728,15 +723,28 @@ class TPUModelRunner:
torch._dynamo.mark_dynamic(attn_metadata.query_start_loc, 0)
torch._dynamo.mark_dynamic(attn_metadata.context_lens, 0)
with set_forward_context(attn_metadata, self.vllm_config, 0):
with set_forward_context(attn_metadata,
self.vllm_config,
0,
is_profile_run=is_profile_run):
assert self.model is not None
self.model(
input_ids=input_ids,
positions=position_ids,
kv_caches=kv_caches,
inputs_embeds=inputs_embeds,
)
# This is used before KV cache init
def profile_run(self, num_tokens) -> None:
self._dummy_run(num_tokens=num_tokens, is_profile_run=True)
# This is used after KV cache init
def dummy_run(
self,
num_tokens: int,
) -> None:
self._dummy_run(num_tokens=num_tokens, is_profile_run=False)
def capture_model(self) -> None:
"""Compile the model."""
@@ -745,7 +753,7 @@ class TPUModelRunner:
start = time.perf_counter()
num_tokens = 16
while True:
self._dummy_run(self.kv_caches, num_tokens)
self.dummy_run(num_tokens)
logger.info(" -- num_tokens: %d", num_tokens)
xm.mark_step()
xm.wait_device_ops()
@@ -769,6 +777,7 @@ class TPUModelRunner:
kv_caches: dict[str, torch.Tensor] = {}
kv_cache_shape_prev = None
for layer_name, layer_spec in kv_cache_config.kv_cache_spec.items():
tensor_config = kv_cache_config.tensors[layer_name]
assert tensor_config.size % layer_spec.page_size_bytes == 0
@@ -779,6 +788,12 @@ class TPUModelRunner:
layer_spec.head_size)
dtype = layer_spec.dtype
# Ensure all "kv_cache_shape" are the same across the model
if kv_cache_shape_prev is None:
kv_cache_shape_prev = kv_cache_shape
else:
assert kv_cache_shape == kv_cache_shape_prev
tpu_k_cache = torch.zeros(kv_cache_shape,
dtype=dtype,
device=self.device)
@@ -788,10 +803,16 @@ class TPUModelRunner:
else:
raise NotImplementedError
bind_kv_cache(
kv_caches,
self.vllm_config.compilation_config.static_forward_context,
self.kv_caches)
# ModelWrapperV1 needs to know the KV cache shape
self.model.set_kv_cache_shape(kv_cache_shape_prev)
# Associates each attention layer in the `forward_context` with the
# initialized KV cache.
forward_context = self.vllm_config.compilation_config \
.static_forward_context
for layer_name, kv_cache in kv_caches.items():
# NOTE: Use list because of v0 PP virtual engine.
forward_context[layer_name].kv_cache = [kv_cache]
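
To make the resulting structure concrete, here is a toy version of what the binding loop above leaves behind; the layer names and the bare `Layer` class are illustrative stand-ins, not vLLM's real attention modules:

```python
# Toy illustration of the kv_cache binding; not vLLM code.
import torch

class Layer:          # stand-in for an attention layer object
    kv_cache: list    # one entry per virtual engine

static_forward_context = {
    "model.layers.0.attn": Layer(),
    "model.layers.1.attn": Layer(),
}
kv_caches = {name: torch.zeros(2, 4, 8, 16)   # made-up per-layer cache tensors
             for name in static_forward_context}

for layer_name, kv_cache in kv_caches.items():
    # A one-element list mirrors v0, where index i held virtual engine i's
    # cache; the v1 TPU runner uses a single virtual engine.
    static_forward_context[layer_name].kv_cache = [kv_cache]
```

This is what the removed `bind_kv_cache` call used to do, except that it also appended the tensors to `self.kv_caches`, the runner-side list this change eliminates.
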
class ModelWrapperV1(nn.Module):
@@ -799,12 +820,15 @@ class ModelWrapperV1(nn.Module):
def __init__(self, model: nn.Module):
super().__init__()
self.model = model
self.kv_cache_shape = None
def set_kv_cache_shape(self, kv_cache_shape):
self.kv_cache_shape = kv_cache_shape
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: list[tuple[torch.Tensor, torch.Tensor]],
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Executes the forward pass of the model and samples the next token.
@@ -817,16 +841,20 @@ class ModelWrapperV1(nn.Module):
inputs_embeds: The input embeddings of shape [num_tokens,
hidden_size]. It is used for multimodal models.
"""
# Skip this in memory profiling at initialization.
if kv_caches[0][0].numel() > 0:
attn_metadata = get_forward_context().attn_metadata
# index_copy_(slot_mapping) only works when the inserted dimension
# is 0. However, the KV cache in the Pallas backend has the shape
# [num_kv_heads, num_blocks, block_size, head_size]. To make it
# work, we need to flatten the first three dimensions and modify
# the slot_mapping accordingly.
# kv_caches: list[tuple[torch.Tensor, torch.Tensor]]
num_kv_heads, num_blocks, block_size, _ = kv_caches[0][0].shape
forward_context = get_forward_context()
attn_metadata = forward_context.attn_metadata
# index_copy_(slot_mapping) only works when the inserted dimension
# is 0. However, the KV cache in the Pallas backend has the shape
# [num_kv_heads, num_blocks, block_size, head_size]. To make it
# work, we need to flatten the first three dimensions and modify
# the slot_mapping accordingly.
#
# Note: We skip this step during the first profiling run (before KV init)
if not forward_context.is_profile_run:
assert self.kv_cache_shape # Ensure initialized
num_kv_heads, num_blocks, block_size, _ = self.kv_cache_shape
slot_mapping = attn_metadata.slot_mapping
slot_mapping = slot_mapping.flatten()
head_indicies = torch.arange(0,
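
The comment block above is easier to follow with a worked toy example of the indexing trick: flatten the first three cache dimensions, shift each head's slots into its own stripe, and a single `index_copy_` along dim 0 can then scatter all heads at once. Shapes and values below are made up and this is not the Pallas backend's actual code:

```python
# Toy reconstruction of the flattened slot_mapping write; not vLLM code.
import torch

num_kv_heads, num_blocks, block_size, head_size = 2, 4, 8, 16
kv = torch.zeros(num_kv_heads, num_blocks, block_size, head_size)
flat = kv.view(num_kv_heads * num_blocks * block_size, head_size)  # shares storage

slot_mapping = torch.tensor([5, 12])   # one flat block-slot per token
new_keys = torch.randn(slot_mapping.numel(), num_kv_heads, head_size)

# Each head owns a contiguous stripe of num_blocks * block_size rows in
# `flat`, so its slots are offset by head_index * num_blocks * block_size.
head_offsets = torch.arange(num_kv_heads) * num_blocks * block_size
idx = (slot_mapping[:, None] + head_offsets[None, :]).flatten()

# Row order of `idx` is (token, head), matching the reshape of `new_keys`.
flat.index_copy_(0, idx, new_keys.reshape(-1, head_size))

# Sanity check: token 1, head 1 landed at block 12 // 8 = 1, offset 12 % 8 = 4.
assert torch.equal(kv[1, 1, 4], new_keys[1, 1])
```
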

View File

@@ -21,7 +21,6 @@ from vllm.v1.core.scheduler import SchedulerOutput
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheSpec)
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.utils import bind_kv_cache
from vllm.v1.worker.tpu_model_runner import TPUModelRunner
logger = init_logger(__name__)
@@ -128,18 +127,19 @@ class TPUWorker:
else:
raise NotImplementedError
runner_kv_caches: list[torch.Tensor] = []
bind_kv_cache(
kv_caches,
self.vllm_config.compilation_config.static_forward_context,
runner_kv_caches)
# Associates each attention layer in the `forward_context` with the
# initialized KV cache.
forward_context = self.vllm_config.compilation_config \
.static_forward_context
for layer_name, kv_cache in kv_caches.items():
# NOTE: Use list because of v0 PP virtual engine.
forward_context[layer_name].kv_cache = [kv_cache]
self.model_runner._dummy_run(
runner_kv_caches,
num_tokens=self.scheduler_config.max_num_batched_tokens,
)
self.model_runner.profile_run(
num_tokens=self.scheduler_config.max_num_batched_tokens)
# Synchronize before measuring the memory usage.
xm.mark_step()
xm.wait_device_ops()
# Get the maximum amount of memory used by the model weights and