mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[BugFix] Fix batch updates for pooling models (#23398)
Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
@ -50,6 +50,10 @@ class BatchUpdateBuilder:
|
||||
self.added = added or []
|
||||
self._is_removed_sorted = False
|
||||
|
||||
# Used to track changes in the pooling case
|
||||
# where we don't populate the added list.
|
||||
self.batch_changed = False
|
||||
|
||||
def _ensure_removed_sorted(self) -> None:
|
||||
"""Sort removed request indices in
|
||||
descending order.
|
||||
@ -80,6 +84,7 @@ class BatchUpdateBuilder:
|
||||
raise RuntimeError("Cannot register new removed request after"
|
||||
" self.removed has been read.")
|
||||
self._removed.append(index)
|
||||
self.batch_changed = True
|
||||
|
||||
def has_removed(self) -> bool:
|
||||
return bool(self._removed)
|
||||
@ -98,9 +103,15 @@ class BatchUpdateBuilder:
|
||||
return self._removed.pop()
|
||||
return None
|
||||
|
||||
def _is_update(self) -> bool:
|
||||
"""True if there is a batch state change"""
|
||||
return any((self._removed, self.moved, self.added))
|
||||
def reset(self) -> bool:
|
||||
"""Returns True if there were any changes to the batch."""
|
||||
self._is_removed_sorted = False
|
||||
self._removed.clear()
|
||||
self.moved.clear()
|
||||
self.added.clear()
|
||||
batch_changed = self.batch_changed
|
||||
self.batch_changed = False
|
||||
return batch_changed
|
||||
|
||||
def get_and_reset(self, batch_size: int) -> Optional[BatchUpdate]:
|
||||
"""Generate a logitsprocs batch update data structure and reset
|
||||
@ -114,7 +125,8 @@ class BatchUpdateBuilder:
|
||||
"""
|
||||
# Reset removal-sorting logic
|
||||
self._is_removed_sorted = False
|
||||
if not self._is_update():
|
||||
self.batch_changed = False
|
||||
if not any((self._removed, self.moved, self.added)):
|
||||
# No update; short-circuit
|
||||
return None
|
||||
# Build batch state update
|
||||
|
@ -65,8 +65,7 @@ class CachedRequestState:
|
||||
def get_token_id(self, idx: int) -> int:
|
||||
if idx < self.num_prompt_tokens:
|
||||
return self.prompt_token_ids[idx]
|
||||
else:
|
||||
return self.output_token_ids[idx - self.num_prompt_tokens]
|
||||
return self.output_token_ids[idx - self.num_prompt_tokens]
|
||||
|
||||
|
||||
class InputBatch:
|
||||
@ -261,30 +260,27 @@ class InputBatch:
|
||||
Not applicable to pooling models.
|
||||
"""
|
||||
|
||||
# Detailed added request metadata is only required for non-pooling
|
||||
# models, to support logitsprocs
|
||||
assert request.sampling_params
|
||||
|
||||
# Fill the next empty index if there is one.
|
||||
if (new_req_index := self.batch_update_builder.pop_removed()) is None:
|
||||
# Append to end otherwise.
|
||||
new_req_index = self.num_reqs
|
||||
|
||||
assert new_req_index < self.max_num_reqs
|
||||
self.batch_update_builder.added.append(
|
||||
(new_req_index, request.sampling_params, request.prompt_token_ids,
|
||||
request.output_token_ids))
|
||||
self.batch_update_builder.batch_changed = True
|
||||
if request.sampling_params:
|
||||
# Detailed added request metadata is only required for non-pooling
|
||||
# models, to support logitsprocs.
|
||||
self.batch_update_builder.added.append(
|
||||
(new_req_index, request.sampling_params,
|
||||
request.prompt_token_ids, request.output_token_ids))
|
||||
|
||||
return new_req_index
|
||||
|
||||
def add_request(
|
||||
self,
|
||||
request: "CachedRequestState",
|
||||
) -> int:
|
||||
if not self.is_pooling_model:
|
||||
# New request index bookkeeping for autoregressive models.
|
||||
req_index = self._register_add_request(request)
|
||||
else:
|
||||
req_index = self.num_reqs
|
||||
req_index = self._register_add_request(request)
|
||||
|
||||
req_id = request.req_id
|
||||
if req_index == len(self._req_ids):
|
||||
@ -389,7 +385,7 @@ class InputBatch:
|
||||
self.logits_processing_needs_token_ids[req_index] = (
|
||||
pooling_params.requires_token_ids)
|
||||
else:
|
||||
raise NotImplementedError(request)
|
||||
raise NotImplementedError("Unrecognized request type")
|
||||
|
||||
# Add request lora ID
|
||||
if request.lora_request:
|
||||
@ -419,13 +415,25 @@ class InputBatch:
|
||||
req_index = self.req_id_to_index.pop(req_id, None)
|
||||
if req_index is None:
|
||||
return None
|
||||
if not self.is_pooling_model:
|
||||
# Autoregressive models require bookkeeping of removed requests to
|
||||
# support logitsprocs.
|
||||
self.batch_update_builder.removed_append(req_index)
|
||||
|
||||
self.batch_update_builder.removed_append(req_index)
|
||||
self._req_ids[req_index] = None
|
||||
self.req_output_token_ids[req_index] = None
|
||||
|
||||
# LoRA
|
||||
lora_id = self.request_lora_mapping[req_index]
|
||||
if lora_id != 0:
|
||||
lora_req_ids = self.lora_id_to_request_ids[lora_id]
|
||||
lora_req_ids.discard(req_id)
|
||||
if not lora_req_ids:
|
||||
del self.lora_id_to_request_ids[lora_id]
|
||||
del self.lora_id_to_lora_request[lora_id]
|
||||
self.request_lora_mapping[req_index] = 0
|
||||
|
||||
if self.is_pooling_model:
|
||||
self.pooling_params.pop(req_id, None)
|
||||
return req_index
|
||||
|
||||
self.greedy_reqs.discard(req_id)
|
||||
self.random_reqs.discard(req_id)
|
||||
self.top_p_reqs.discard(req_id)
|
||||
@ -439,29 +447,14 @@ class InputBatch:
|
||||
self.num_prompt_logprobs.pop(req_id, None)
|
||||
self.in_progress_prompt_logprobs_cpu.pop(req_id, None)
|
||||
|
||||
# LoRA
|
||||
lora_id = self.request_lora_mapping[req_index]
|
||||
if lora_id != 0:
|
||||
lora_req_ids = self.lora_id_to_request_ids[lora_id]
|
||||
lora_req_ids.discard(req_id)
|
||||
if not lora_req_ids:
|
||||
del self.lora_id_to_request_ids[lora_id]
|
||||
del self.lora_id_to_lora_request[lora_id]
|
||||
self.request_lora_mapping[req_index] = 0
|
||||
|
||||
self.has_allowed_token_ids.discard(req_id)
|
||||
if self.allowed_token_ids_mask_cpu_tensor is not None:
|
||||
# False means we don't fill with -inf.
|
||||
self.allowed_token_ids_mask_cpu_tensor[req_index].fill_(False)
|
||||
self.bad_words_token_ids.pop(req_index, None)
|
||||
self.pooling_params.pop(req_id, None)
|
||||
return req_index
|
||||
|
||||
def swap_states(self, i1: int, i2: int) -> None:
|
||||
# For autoregressive models, track detailed request reordering info
|
||||
# to support logitsprocs
|
||||
self.batch_update_builder.moved.append(
|
||||
(i1, i2, MoveDirectionality.SWAP))
|
||||
old_id_i1 = self._req_ids[i1]
|
||||
old_id_i2 = self._req_ids[i2]
|
||||
self._req_ids[i1], self._req_ids[i2] =\
|
||||
@ -479,18 +472,6 @@ class InputBatch:
|
||||
self.num_prompt_tokens[i2], self.num_prompt_tokens[i1]
|
||||
self.num_computed_tokens_cpu[i1], self.num_computed_tokens_cpu[i2] =\
|
||||
self.num_computed_tokens_cpu[i2], self.num_computed_tokens_cpu[i1]
|
||||
self.temperature_cpu[i1], self.temperature_cpu[i2] =\
|
||||
self.temperature_cpu[i2], self.temperature_cpu[i1]
|
||||
self.top_p_cpu[i1], self.top_p_cpu[i2] =\
|
||||
self.top_p_cpu[i2], self.top_p_cpu[i1]
|
||||
self.top_k_cpu[i1], self.top_k_cpu[i2] =\
|
||||
self.top_k_cpu[i2], self.top_k_cpu[i1]
|
||||
self.frequency_penalties_cpu[i1], self.frequency_penalties_cpu[i2] =\
|
||||
self.frequency_penalties_cpu[i2], self.frequency_penalties_cpu[i1]
|
||||
self.presence_penalties_cpu[i1], self.presence_penalties_cpu[i2] =\
|
||||
self.presence_penalties_cpu[i2], self.presence_penalties_cpu[i1]
|
||||
self.repetition_penalties_cpu[i1], self.repetition_penalties_cpu[i2] =\
|
||||
self.repetition_penalties_cpu[i2], self.repetition_penalties_cpu[i1]
|
||||
|
||||
# NOTE: the following is unsafe
|
||||
# self.token_ids_cpu[i1, ...], self.token_ids_cpu[i2, ...], =\
|
||||
@ -501,18 +482,41 @@ class InputBatch:
|
||||
self.token_ids_cpu[i1, ...] = self.token_ids_cpu[i2, ...]
|
||||
self.token_ids_cpu[i2, ...] = tmp
|
||||
|
||||
self.block_table.swap_row(i1, i2)
|
||||
|
||||
self.request_lora_mapping[i1], self.request_lora_mapping[i2] = \
|
||||
self.request_lora_mapping[i2], self.request_lora_mapping[i1]
|
||||
|
||||
if self.is_pooling_model:
|
||||
# Sampling and logits parameters don't apply to pooling models.
|
||||
return
|
||||
|
||||
# For autoregressive models, track detailed request reordering info
|
||||
# to support logitsprocs.
|
||||
self.batch_update_builder.moved.append(
|
||||
(i1, i2, MoveDirectionality.SWAP))
|
||||
|
||||
self.temperature_cpu[i1], self.temperature_cpu[i2] = \
|
||||
self.temperature_cpu[i2], self.temperature_cpu[i1]
|
||||
self.top_p_cpu[i1], self.top_p_cpu[i2] = \
|
||||
self.top_p_cpu[i2], self.top_p_cpu[i1]
|
||||
self.top_k_cpu[i1], self.top_k_cpu[i2] = \
|
||||
self.top_k_cpu[i2], self.top_k_cpu[i1]
|
||||
self.frequency_penalties_cpu[i1], self.frequency_penalties_cpu[i2] = \
|
||||
self.frequency_penalties_cpu[i2], self.frequency_penalties_cpu[i1]
|
||||
self.presence_penalties_cpu[i1], self.presence_penalties_cpu[i2] = \
|
||||
self.presence_penalties_cpu[i2], self.presence_penalties_cpu[i1]
|
||||
self.repetition_penalties_cpu[i1], self.repetition_penalties_cpu[i2] = \
|
||||
self.repetition_penalties_cpu[i2], self.repetition_penalties_cpu[i1]
|
||||
|
||||
swap_dict_values(self.generators, i1, i2)
|
||||
swap_dict_values(self.bad_words_token_ids, i1, i2)
|
||||
|
||||
self.request_lora_mapping[i1], self.request_lora_mapping[i2] =\
|
||||
self.request_lora_mapping[i2], self.request_lora_mapping[i1]
|
||||
|
||||
if self.allowed_token_ids_mask_cpu_tensor is not None:
|
||||
self.allowed_token_ids_mask_cpu_tensor[i1], \
|
||||
self.allowed_token_ids_mask_cpu_tensor[i2] =\
|
||||
self.allowed_token_ids_mask_cpu_tensor[i2], \
|
||||
self.allowed_token_ids_mask_cpu_tensor[i1]
|
||||
self.block_table.swap_row(i1, i2)
|
||||
|
||||
def condense(self) -> None:
|
||||
"""Slide non-empty requests down into lower, empty indices.
|
||||
@ -529,12 +533,6 @@ class InputBatch:
|
||||
"""
|
||||
num_reqs = self.num_reqs
|
||||
|
||||
if self.is_pooling_model:
|
||||
# Will be contiguous in pooling case, just trim the lists.
|
||||
del self._req_ids[num_reqs:]
|
||||
del self.req_output_token_ids[num_reqs:]
|
||||
return
|
||||
|
||||
if not (empty_req_indices := self.batch_update_builder.removed):
|
||||
# All removed requests were replaced by added requests, or else no
|
||||
# requests were removed at all. No condense() needed
|
||||
@ -562,11 +560,6 @@ class InputBatch:
|
||||
# Move active request down into empty request
|
||||
# index.
|
||||
self.batch_update_builder.pop_removed()
|
||||
# Autoregressive models require detailed tracking of condense
|
||||
# operations to support logitsprocs
|
||||
self.batch_update_builder.moved.append(
|
||||
(last_req_index, empty_index,
|
||||
MoveDirectionality.UNIDIRECTIONAL))
|
||||
req_id = self._req_ids[last_req_index]
|
||||
output_token_ids = self.req_output_token_ids[last_req_index]
|
||||
assert req_id is not None
|
||||
@ -587,6 +580,21 @@ class InputBatch:
|
||||
self.num_computed_tokens_cpu[
|
||||
empty_index] = self.num_computed_tokens_cpu[last_req_index]
|
||||
self.block_table.move_row(last_req_index, empty_index)
|
||||
|
||||
self.request_lora_mapping[empty_index] = self.request_lora_mapping[
|
||||
last_req_index]
|
||||
|
||||
if self.is_pooling_model:
|
||||
last_req_index -= 1
|
||||
# Samping state not used by pooling models.
|
||||
continue
|
||||
|
||||
# Autoregressive models require detailed tracking of condense
|
||||
# operations to support logitsprocs
|
||||
self.batch_update_builder.moved.append(
|
||||
(last_req_index, empty_index,
|
||||
MoveDirectionality.UNIDIRECTIONAL))
|
||||
|
||||
self.temperature_cpu[empty_index] = self.temperature_cpu[
|
||||
last_req_index]
|
||||
self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index]
|
||||
@ -601,9 +609,6 @@ class InputBatch:
|
||||
if generator is not None:
|
||||
self.generators[empty_index] = generator
|
||||
|
||||
self.request_lora_mapping[empty_index] = self.request_lora_mapping[
|
||||
last_req_index]
|
||||
|
||||
# TODO convert these to LogitsProcessors
|
||||
if self.allowed_token_ids_mask_cpu_tensor is not None:
|
||||
self.allowed_token_ids_mask_cpu_tensor[
|
||||
@ -626,8 +631,9 @@ class InputBatch:
|
||||
"""Apply any batch updates to sampling metadata."""
|
||||
|
||||
if self.is_pooling_model:
|
||||
# Batch changes every step for pooling models.
|
||||
self.sampling_metadata = self._make_sampling_metadata()
|
||||
batch_changed = self.batch_update_builder.reset()
|
||||
if batch_changed:
|
||||
self.sampling_metadata = self._make_sampling_metadata()
|
||||
return
|
||||
|
||||
# For non-pooling models - generate and apply logitsprocs update;
|
||||
@ -720,7 +726,8 @@ class InputBatch:
|
||||
)
|
||||
|
||||
def _make_prompt_token_ids_tensor(self) -> torch.Tensor:
|
||||
max_prompt_len = self.num_prompt_tokens[:self.num_reqs].max()
|
||||
num_reqs = self.num_reqs
|
||||
max_prompt_len = self.num_prompt_tokens[:num_reqs].max()
|
||||
prompt_token_ids_cpu_tensor = torch.empty(
|
||||
(self.num_reqs, max_prompt_len),
|
||||
device="cpu",
|
||||
@ -728,11 +735,10 @@ class InputBatch:
|
||||
pin_memory=self.pin_memory,
|
||||
)
|
||||
prompt_token_ids = prompt_token_ids_cpu_tensor.numpy()
|
||||
prompt_token_ids[:] = self.token_ids_cpu[:self.
|
||||
num_reqs, :max_prompt_len]
|
||||
prompt_token_ids[:] = self.token_ids_cpu[:num_reqs, :max_prompt_len]
|
||||
# Use the value of vocab_size as a pad since we don't have a
|
||||
# token_id of this value.
|
||||
for i in range(self.num_reqs):
|
||||
for i in range(num_reqs):
|
||||
prompt_token_ids[i, self.num_prompt_tokens[i]:] = self.vocab_size
|
||||
return prompt_token_ids_cpu_tensor.to(device=self.device,
|
||||
non_blocking=True)
|
||||
|
@ -1489,10 +1489,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
for raw_output, seq_len, prompt_len in zip(
|
||||
raw_pooler_output, seq_lens_cpu, pooling_metadata.prompt_lens):
|
||||
|
||||
if seq_len == prompt_len:
|
||||
pooler_output.append(raw_output.data)
|
||||
else:
|
||||
pooler_output.append(None)
|
||||
output = raw_output.data if seq_len == prompt_len else None
|
||||
pooler_output.append(output)
|
||||
|
||||
return ModelRunnerOutput(
|
||||
req_ids=self.input_batch.req_ids,
|
||||
@ -1522,7 +1520,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
# Prepare the decoder inputs.
|
||||
(attn_metadata, logits_indices, spec_decode_metadata,
|
||||
num_scheduled_tokens_np, spec_decode_common_attn_metadata,
|
||||
max_query_len) = (self._prepare_inputs(scheduler_output))
|
||||
max_query_len) = self._prepare_inputs(scheduler_output)
|
||||
|
||||
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
|
||||
if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
|
||||
|
Reference in New Issue
Block a user