[BugFix] Fix batch updates for pooling models (#23398)

Signed-off-by: Nick Hill <nhill@redhat.com>
Author: Nick Hill
Date: 2025-08-22 17:20:41 -07:00
Committed by: GitHub
Parent: 24d0c9e6ed
Commit: c80c53a30f
3 changed files with 95 additions and 79 deletions

View File

@@ -50,6 +50,10 @@ class BatchUpdateBuilder:
self.added = added or []
self._is_removed_sorted = False
# Used to track changes in the pooling case
# where we don't populate the added list.
self.batch_changed = False
def _ensure_removed_sorted(self) -> None:
"""Sort removed request indices in
descending order.
@@ -80,6 +84,7 @@ class BatchUpdateBuilder:
raise RuntimeError("Cannot register new removed request after"
" self.removed has been read.")
self._removed.append(index)
self.batch_changed = True
def has_removed(self) -> bool:
return bool(self._removed)
@@ -98,9 +103,15 @@ class BatchUpdateBuilder:
return self._removed.pop()
return None
def _is_update(self) -> bool:
"""True if there is a batch state change"""
return any((self._removed, self.moved, self.added))
def reset(self) -> bool:
"""Returns True if there were any changes to the batch."""
self._is_removed_sorted = False
self._removed.clear()
self.moved.clear()
self.added.clear()
batch_changed = self.batch_changed
self.batch_changed = False
return batch_changed
def get_and_reset(self, batch_size: int) -> Optional[BatchUpdate]:
"""Generate a logitsprocs batch update data structure and reset
@@ -114,7 +125,8 @@
"""
# Reset removal-sorting logic
self._is_removed_sorted = False
if not self._is_update():
self.batch_changed = False
if not any((self._removed, self.moved, self.added)):
# No update; short-circuit
return None
# Build batch state update
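Note: the hunks above add a batch_changed flag so that pooling models, which never populate the added list, can still report that the batch composition changed. Below is a minimal stand-alone sketch of this dirty-flag pattern; MiniBatchUpdateBuilder is illustrative only and not the real vLLM class.

class MiniBatchUpdateBuilder:
    """Illustrative stand-in for the dirty-flag pattern described above."""

    def __init__(self) -> None:
        self._removed: list[int] = []
        self.added: list[tuple] = []
        # Tracks changes in the pooling case, where `added` stays empty.
        self.batch_changed = False

    def removed_append(self, index: int) -> None:
        self._removed.append(index)
        self.batch_changed = True

    def reset(self) -> bool:
        """Clear pending updates; return True if the batch changed."""
        self._removed.clear()
        self.added.clear()
        changed = self.batch_changed
        self.batch_changed = False
        return changed


builder = MiniBatchUpdateBuilder()
builder.removed_append(3)           # pooling path: no `added` metadata
assert builder.reset() is True      # the caller learns the batch changed
assert builder.reset() is False     # nothing pending on the next call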

View File

@@ -65,8 +65,7 @@ class CachedRequestState:
def get_token_id(self, idx: int) -> int:
if idx < self.num_prompt_tokens:
return self.prompt_token_ids[idx]
else:
return self.output_token_ids[idx - self.num_prompt_tokens]
return self.output_token_ids[idx - self.num_prompt_tokens]
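Note: the simplified get_token_id above treats a request's token stream as prompt tokens followed by output tokens. A tiny self-contained sketch of that indexing scheme, with made-up token ids:

# Prompt tokens first, then generated output tokens (made-up values).
prompt_token_ids = [11, 12, 13]
output_token_ids = [21, 22]
num_prompt_tokens = len(prompt_token_ids)

def get_token_id(idx: int) -> int:
    # Indices below num_prompt_tokens read from the prompt; the rest
    # read from the generated output, offset by the prompt length.
    if idx < num_prompt_tokens:
        return prompt_token_ids[idx]
    return output_token_ids[idx - num_prompt_tokens]

assert [get_token_id(i) for i in range(5)] == [11, 12, 13, 21, 22]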
class InputBatch:
@@ -261,30 +260,27 @@ class InputBatch:
Not applicable to pooling models.
"""
# Detailed added request metadata is only required for non-pooling
# models, to support logitsprocs
assert request.sampling_params
# Fill the next empty index if there is one.
if (new_req_index := self.batch_update_builder.pop_removed()) is None:
# Append to end otherwise.
new_req_index = self.num_reqs
assert new_req_index < self.max_num_reqs
self.batch_update_builder.added.append(
(new_req_index, request.sampling_params, request.prompt_token_ids,
request.output_token_ids))
self.batch_update_builder.batch_changed = True
if request.sampling_params:
# Detailed added request metadata is only required for non-pooling
# models, to support logitsprocs.
self.batch_update_builder.added.append(
(new_req_index, request.sampling_params,
request.prompt_token_ids, request.output_token_ids))
return new_req_index
def add_request(
self,
request: "CachedRequestState",
) -> int:
if not self.is_pooling_model:
# New request index bookkeeping for autoregressive models.
req_index = self._register_add_request(request)
else:
req_index = self.num_reqs
req_index = self._register_add_request(request)
req_id = request.req_id
if req_index == len(self._req_ids):
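Note: after this hunk, pooling and autoregressive requests share _register_add_request: slot assignment (reuse a freed index, otherwise append at the end) is common, and only the logitsprocs `added` metadata is gated on sampling_params. A rough sketch of that flow follows; assign_slot and its parameters are illustrative, not the real method signature.

from typing import Any, Optional

def assign_slot(free_slots: list[int], num_reqs: int, max_num_reqs: int,
                sampling_params: Optional[Any],
                added: list[tuple]) -> int:
    # Fill a freed slot if one exists; otherwise append at the end.
    new_index = free_slots.pop() if free_slots else num_reqs
    assert new_index < max_num_reqs
    # Detailed added-request metadata only matters for sampling models
    # (logitsprocs); pooling requests just mark the batch as changed.
    if sampling_params is not None:
        added.append((new_index, sampling_params))
    return new_index

added: list[tuple] = []
# Pooling-style request: a slot is assigned but `added` stays empty.
print(assign_slot([2], num_reqs=4, max_num_reqs=8,
                  sampling_params=None, added=added))   # -> 2
print(added)                                            # -> []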
@@ -389,7 +385,7 @@ class InputBatch:
self.logits_processing_needs_token_ids[req_index] = (
pooling_params.requires_token_ids)
else:
raise NotImplementedError(request)
raise NotImplementedError("Unrecognized request type")
# Add request lora ID
if request.lora_request:
@@ -419,13 +415,25 @@ class InputBatch:
req_index = self.req_id_to_index.pop(req_id, None)
if req_index is None:
return None
if not self.is_pooling_model:
# Autoregressive models require bookkeeping of removed requests to
# support logitsprocs.
self.batch_update_builder.removed_append(req_index)
self.batch_update_builder.removed_append(req_index)
self._req_ids[req_index] = None
self.req_output_token_ids[req_index] = None
# LoRA
lora_id = self.request_lora_mapping[req_index]
if lora_id != 0:
lora_req_ids = self.lora_id_to_request_ids[lora_id]
lora_req_ids.discard(req_id)
if not lora_req_ids:
del self.lora_id_to_request_ids[lora_id]
del self.lora_id_to_lora_request[lora_id]
self.request_lora_mapping[req_index] = 0
if self.is_pooling_model:
self.pooling_params.pop(req_id, None)
return req_index
self.greedy_reqs.discard(req_id)
self.random_reqs.discard(req_id)
self.top_p_reqs.discard(req_id)
@@ -439,29 +447,14 @@ class InputBatch:
self.num_prompt_logprobs.pop(req_id, None)
self.in_progress_prompt_logprobs_cpu.pop(req_id, None)
# LoRA
lora_id = self.request_lora_mapping[req_index]
if lora_id != 0:
lora_req_ids = self.lora_id_to_request_ids[lora_id]
lora_req_ids.discard(req_id)
if not lora_req_ids:
del self.lora_id_to_request_ids[lora_id]
del self.lora_id_to_lora_request[lora_id]
self.request_lora_mapping[req_index] = 0
self.has_allowed_token_ids.discard(req_id)
if self.allowed_token_ids_mask_cpu_tensor is not None:
# False means we don't fill with -inf.
self.allowed_token_ids_mask_cpu_tensor[req_index].fill_(False)
self.bad_words_token_ids.pop(req_index, None)
self.pooling_params.pop(req_id, None)
return req_index
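Note: the reordered remove_request clears state shared by all model types (the request-id slot and the LoRA mapping) before the pooling early return, so only sampling-specific bookkeeping is skipped for pooling models. A condensed sketch of that control flow; _clear_lora_state and _clear_sampling_state are hypothetical placeholders for the inline code in the diff.

from typing import Optional

def remove_request(batch, req_id: str) -> Optional[int]:
    req_index = batch.req_id_to_index.pop(req_id, None)
    if req_index is None:
        return None
    # Sets batch_changed so pooling models also see the update.
    batch.batch_update_builder.removed_append(req_index)
    batch._req_ids[req_index] = None

    # Cleanup shared by pooling and autoregressive models (e.g. LoRA).
    batch._clear_lora_state(req_index, req_id)

    if batch.is_pooling_model:
        # Pooling models keep no sampling state; stop here.
        batch.pooling_params.pop(req_id, None)
        return req_index

    # Sampling-only cleanup: greedy/random sets, logprobs, bad words, ...
    batch._clear_sampling_state(req_index, req_id)
    return req_index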
def swap_states(self, i1: int, i2: int) -> None:
# For autoregressive models, track detailed request reordering info
# to support logitsprocs
self.batch_update_builder.moved.append(
(i1, i2, MoveDirectionality.SWAP))
old_id_i1 = self._req_ids[i1]
old_id_i2 = self._req_ids[i2]
self._req_ids[i1], self._req_ids[i2] =\
@@ -479,18 +472,6 @@ class InputBatch:
self.num_prompt_tokens[i2], self.num_prompt_tokens[i1]
self.num_computed_tokens_cpu[i1], self.num_computed_tokens_cpu[i2] =\
self.num_computed_tokens_cpu[i2], self.num_computed_tokens_cpu[i1]
self.temperature_cpu[i1], self.temperature_cpu[i2] =\
self.temperature_cpu[i2], self.temperature_cpu[i1]
self.top_p_cpu[i1], self.top_p_cpu[i2] =\
self.top_p_cpu[i2], self.top_p_cpu[i1]
self.top_k_cpu[i1], self.top_k_cpu[i2] =\
self.top_k_cpu[i2], self.top_k_cpu[i1]
self.frequency_penalties_cpu[i1], self.frequency_penalties_cpu[i2] =\
self.frequency_penalties_cpu[i2], self.frequency_penalties_cpu[i1]
self.presence_penalties_cpu[i1], self.presence_penalties_cpu[i2] =\
self.presence_penalties_cpu[i2], self.presence_penalties_cpu[i1]
self.repetition_penalties_cpu[i1], self.repetition_penalties_cpu[i2] =\
self.repetition_penalties_cpu[i2], self.repetition_penalties_cpu[i1]
# NOTE: the following is unsafe
# self.token_ids_cpu[i1, ...], self.token_ids_cpu[i2, ...], =\
@@ -501,18 +482,41 @@ class InputBatch:
self.token_ids_cpu[i1, ...] = self.token_ids_cpu[i2, ...]
self.token_ids_cpu[i2, ...] = tmp
self.block_table.swap_row(i1, i2)
self.request_lora_mapping[i1], self.request_lora_mapping[i2] = \
self.request_lora_mapping[i2], self.request_lora_mapping[i1]
if self.is_pooling_model:
# Sampling and logits parameters don't apply to pooling models.
return
# For autoregressive models, track detailed request reordering info
# to support logitsprocs.
self.batch_update_builder.moved.append(
(i1, i2, MoveDirectionality.SWAP))
self.temperature_cpu[i1], self.temperature_cpu[i2] = \
self.temperature_cpu[i2], self.temperature_cpu[i1]
self.top_p_cpu[i1], self.top_p_cpu[i2] = \
self.top_p_cpu[i2], self.top_p_cpu[i1]
self.top_k_cpu[i1], self.top_k_cpu[i2] = \
self.top_k_cpu[i2], self.top_k_cpu[i1]
self.frequency_penalties_cpu[i1], self.frequency_penalties_cpu[i2] = \
self.frequency_penalties_cpu[i2], self.frequency_penalties_cpu[i1]
self.presence_penalties_cpu[i1], self.presence_penalties_cpu[i2] = \
self.presence_penalties_cpu[i2], self.presence_penalties_cpu[i1]
self.repetition_penalties_cpu[i1], self.repetition_penalties_cpu[i2] = \
self.repetition_penalties_cpu[i2], self.repetition_penalties_cpu[i1]
swap_dict_values(self.generators, i1, i2)
swap_dict_values(self.bad_words_token_ids, i1, i2)
self.request_lora_mapping[i1], self.request_lora_mapping[i2] =\
self.request_lora_mapping[i2], self.request_lora_mapping[i1]
if self.allowed_token_ids_mask_cpu_tensor is not None:
self.allowed_token_ids_mask_cpu_tensor[i1], \
self.allowed_token_ids_mask_cpu_tensor[i2] =\
self.allowed_token_ids_mask_cpu_tensor[i2], \
self.allowed_token_ids_mask_cpu_tensor[i1]
self.block_table.swap_row(i1, i2)
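Note: swap_states now swaps all model-agnostic per-request state (request ids, token counts, token buffers, block-table rows, LoRA mapping) before the pooling early return; only the logitsprocs `moved` record and the sampling tensors sit below it. A minimal sketch of this "shared first, early return, sampling last" layout; shared_arrays and sampling_arrays are illustrative stand-ins for the real per-request arrays.

def swap_states(batch, i1: int, i2: int) -> None:
    # 1) State every model type needs is swapped above the early return.
    for arr in batch.shared_arrays:        # e.g. req ids, token counts, LoRA
        arr[i1], arr[i2] = arr[i2], arr[i1]

    if batch.is_pooling_model:
        # 2) Sampling and logits parameters don't apply to pooling models.
        return

    # 3) Autoregressive-only: record the swap for logitsprocs and swap
    #    the sampling parameter arrays.
    batch.moved.append((i1, i2, "SWAP"))
    for arr in batch.sampling_arrays:      # e.g. temperature, top_p, top_k
        arr[i1], arr[i2] = arr[i2], arr[i1]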
def condense(self) -> None:
"""Slide non-empty requests down into lower, empty indices.
@@ -529,12 +533,6 @@
"""
num_reqs = self.num_reqs
if self.is_pooling_model:
# Will be contiguous in pooling case, just trim the lists.
del self._req_ids[num_reqs:]
del self.req_output_token_ids[num_reqs:]
return
if not (empty_req_indices := self.batch_update_builder.removed):
# All removed requests were replaced by added requests, or else no
# requests were removed at all. No condense() needed
@@ -562,11 +560,6 @@
# Move active request down into empty request
# index.
self.batch_update_builder.pop_removed()
# Autoregressive models require detailed tracking of condense
# operations to support logitsprocs
self.batch_update_builder.moved.append(
(last_req_index, empty_index,
MoveDirectionality.UNIDIRECTIONAL))
req_id = self._req_ids[last_req_index]
output_token_ids = self.req_output_token_ids[last_req_index]
assert req_id is not None
@@ -587,6 +580,21 @@
self.num_computed_tokens_cpu[
empty_index] = self.num_computed_tokens_cpu[last_req_index]
self.block_table.move_row(last_req_index, empty_index)
self.request_lora_mapping[empty_index] = self.request_lora_mapping[
last_req_index]
if self.is_pooling_model:
last_req_index -= 1
# Sampling state not used by pooling models.
continue
# Autoregressive models require detailed tracking of condense
# operations to support logitsprocs
self.batch_update_builder.moved.append(
(last_req_index, empty_index,
MoveDirectionality.UNIDIRECTIONAL))
self.temperature_cpu[empty_index] = self.temperature_cpu[
last_req_index]
self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index]
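Note: condense no longer takes a separate trim-only path for pooling models; every moved request copies the shared state (token ids, counts, block-table row, LoRA mapping) and then continues before the sampling-specific copies. A compact sketch of the loop shape; empty_slots, highest_active_index, copy_shared_state, and copy_sampling_state are illustrative helpers, not the real attributes.

def condense(batch) -> None:
    # Slide the highest-index active request down into each empty slot.
    for empty in sorted(batch.empty_slots):            # lowest holes first
        last = batch.highest_active_index()
        if last <= empty:
            break                                      # already contiguous
        # Shared per-request state: token ids, block table, LoRA mapping.
        batch.copy_shared_state(src=last, dst=empty)
        if batch.is_pooling_model:
            continue                                   # no sampling state
        # Autoregressive models also track the move for logitsprocs and
        # copy the sampling tensors.
        batch.moved.append((last, empty, "UNIDIRECTIONAL"))
        batch.copy_sampling_state(src=last, dst=empty)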
@@ -601,9 +609,6 @@
if generator is not None:
self.generators[empty_index] = generator
self.request_lora_mapping[empty_index] = self.request_lora_mapping[
last_req_index]
# TODO convert these to LogitsProcessors
if self.allowed_token_ids_mask_cpu_tensor is not None:
self.allowed_token_ids_mask_cpu_tensor[
@@ -626,8 +631,9 @@
"""Apply any batch updates to sampling metadata."""
if self.is_pooling_model:
# Batch changes every step for pooling models.
self.sampling_metadata = self._make_sampling_metadata()
batch_changed = self.batch_update_builder.reset()
if batch_changed:
self.sampling_metadata = self._make_sampling_metadata()
return
# For non-pooling models - generate and apply logitsprocs update;
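Note: for pooling models, the refresh above now asks the builder whether anything changed and only then rebuilds the sampling metadata, instead of rebuilding it unconditionally every step. A sketch of that guard, with make_metadata standing in for the real _make_sampling_metadata:

def refresh_metadata(batch) -> None:
    if batch.is_pooling_model:
        # reset() clears pending adds/removes and reports whether the
        # batch composition changed since the last refresh.
        if batch.batch_update_builder.reset():
            batch.sampling_metadata = batch.make_metadata()
        return
    # Non-pooling models build and apply a full logitsprocs BatchUpdate.
    ...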
@@ -720,7 +726,8 @@ class InputBatch:
)
def _make_prompt_token_ids_tensor(self) -> torch.Tensor:
max_prompt_len = self.num_prompt_tokens[:self.num_reqs].max()
num_reqs = self.num_reqs
max_prompt_len = self.num_prompt_tokens[:num_reqs].max()
prompt_token_ids_cpu_tensor = torch.empty(
(self.num_reqs, max_prompt_len),
device="cpu",
@@ -728,11 +735,10 @@ class InputBatch:
pin_memory=self.pin_memory,
)
prompt_token_ids = prompt_token_ids_cpu_tensor.numpy()
prompt_token_ids[:] = self.token_ids_cpu[:self.
num_reqs, :max_prompt_len]
prompt_token_ids[:] = self.token_ids_cpu[:num_reqs, :max_prompt_len]
# Use the value of vocab_size as a pad since we don't have a
# token_id of this value.
for i in range(self.num_reqs):
for i in range(num_reqs):
prompt_token_ids[i, self.num_prompt_tokens[i]:] = self.vocab_size
return prompt_token_ids_cpu_tensor.to(device=self.device,
non_blocking=True)
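Note: _make_prompt_token_ids_tensor above builds a (num_reqs, max_prompt_len) CPU tensor, copies each request's prompt tokens into it, and pads every row's tail with vocab_size (a value no real token id can take) before the non-blocking copy to the device. A small self-contained sketch of the padding scheme with made-up sizes:

import torch

vocab_size = 100                                   # made-up vocabulary size
prompts = [[1, 2, 3], [4, 5]]                      # made-up prompts
num_reqs = len(prompts)
max_prompt_len = max(len(p) for p in prompts)

prompt_token_ids_cpu = torch.empty((num_reqs, max_prompt_len),
                                   dtype=torch.int64)
arr = prompt_token_ids_cpu.numpy()                 # shares memory, CPU only
for i, p in enumerate(prompts):
    arr[i, :len(p)] = p
    # Pad with vocab_size, which cannot collide with a real token id.
    arr[i, len(p):] = vocab_size

print(prompt_token_ids_cpu)
# tensor([[  1,   2,   3],
#         [  4,   5, 100]])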

View File

@@ -1489,10 +1489,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
for raw_output, seq_len, prompt_len in zip(
raw_pooler_output, seq_lens_cpu, pooling_metadata.prompt_lens):
if seq_len == prompt_len:
pooler_output.append(raw_output.data)
else:
pooler_output.append(None)
output = raw_output.data if seq_len == prompt_len else None
pooler_output.append(output)
return ModelRunnerOutput(
req_ids=self.input_batch.req_ids,
@@ -1522,7 +1520,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Prepare the decoder inputs.
(attn_metadata, logits_indices, spec_decode_metadata,
num_scheduled_tokens_np, spec_decode_common_attn_metadata,
max_query_len) = (self._prepare_inputs(scheduler_output))
max_query_len) = self._prepare_inputs(scheduler_output)
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE