[Bugfix]: Fix final_res_batch list index out of range error (#21055)

Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com>
Authored by Chauncey on 2025-07-17 15:29:09 +08:00, committed by GitHub
parent c5b8b5953a
commit fdc5b43d20
2 changed files with 78 additions and 40 deletions
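
For context (a simplified sketch, not the serving code itself): the non-streaming completion path read final_res_batch[0].kv_transfer_params without first checking whether the batch was empty, and a request whose only input is an empty prompt_embeds list appears to produce zero prompts, so the index lookup raised "list index out of range". The patch guards both that lookup and the cached-token accounting, roughly like this:

    from typing import Optional

    def pick_kv_transfer_params(final_res_batch: list) -> Optional[dict]:
        # Illustrative only: simplified from the diff below, not the actual
        # vLLM serving code. Before the fix the equivalent of
        #   final_res_batch[0].kv_transfer_params
        # ran unconditionally, which raises IndexError when the batch is
        # empty (e.g. a {"prompt_embeds": []} request yields zero prompts).
        kv_transfer_params = None
        if final_res_batch:
            kv_transfer_params = final_res_batch[0].kv_transfer_params
        return kv_transfer_params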


@@ -7,6 +7,7 @@ import openai # use the official client for correctness check
import pytest
import pytest_asyncio
import regex as re
import requests
from openai import BadRequestError
from tests.utils import RemoteOpenAIServer
@@ -26,7 +27,8 @@ def default_server_args():
"2048",
"--max-num-seqs",
"128",
"--enforce-eager"
"--enforce-eager",
"--enable-prompt-tokens-details",
]
@@ -679,3 +681,17 @@ async def test_invalid_grammar(client: openai.AsyncOpenAI, model_name: str):
prompt=prompt,
extra_body={"guided_grammar": invalid_simplified_sql_grammar},
)
@pytest.mark.asyncio
async def test_completion_with_empty_prompt_embeds(
client: openai.AsyncOpenAI) -> None:
"""Test completion with empty prompt embeds."""
payload: dict[str, list] = {"prompt_embeds": []}
headers: dict[str, str] = {"Content-Type": "application/json"}
# base_url = http://localhost:8000/v1/completions
response = requests.post(f"{client.base_url}completions",
headers=headers,
json=payload)
assert response.status_code == 200, (
f"Expected status code 200, got {response.status_code}. ")


@@ -60,20 +60,25 @@ class OpenAIServingCompletion(OpenAIServing):
enable_prompt_tokens_details: bool = False,
enable_force_include_usage: bool = False,
):
super().__init__(engine_client=engine_client,
model_config=model_config,
models=models,
request_logger=request_logger,
return_tokens_as_token_ids=return_tokens_as_token_ids,
enable_force_include_usage=enable_force_include_usage)
super().__init__(
engine_client=engine_client,
model_config=model_config,
models=models,
request_logger=request_logger,
return_tokens_as_token_ids=return_tokens_as_token_ids,
enable_force_include_usage=enable_force_include_usage,
)
self.enable_prompt_tokens_details = enable_prompt_tokens_details
self.default_sampling_params = (
self.model_config.get_diff_sampling_param())
if self.default_sampling_params:
source = self.model_config.generation_config
source = "model" if source == "auto" else source
logger.info("Using default completion sampling params from %s: %s",
source, self.default_sampling_params)
logger.info(
"Using default completion sampling params from %s: %s",
source,
self.default_sampling_params,
)
async def create_completion(
self,
@@ -172,23 +177,28 @@ class OpenAIServingCompletion(OpenAIServing):
max_model_len=self.max_model_len,
request=request,
input_length=input_length,
default_sampling_params=self.default_sampling_params)
default_sampling_params=self.default_sampling_params,
)
if request.use_beam_search:
sampling_params = request.to_beam_search_params(
max_tokens, self.default_sampling_params)
else:
sampling_params = request.to_sampling_params(
max_tokens, self.model_config.logits_processor_pattern,
self.default_sampling_params)
max_tokens,
self.model_config.logits_processor_pattern,
self.default_sampling_params,
)
request_id_item = f"{request_id}-{i}"
self._log_inputs(request_id_item,
request_prompts[i],
params=sampling_params,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request)
self._log_inputs(
request_id_item,
request_prompts[i],
params=sampling_params,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
)
trace_headers = (None if raw_request is None else await
self._get_trace_headers(raw_request.headers))
@@ -245,7 +255,8 @@ class OpenAIServingCompletion(OpenAIServing):
num_prompts=num_prompts,
tokenizer=tokenizer,
request_metadata=request_metadata,
enable_force_include_usage=self.enable_force_include_usage)
enable_force_include_usage=self.enable_force_include_usage,
)
# Non-streaming response
final_res_batch: list[Optional[RequestOutput]] = [None] * num_prompts
@@ -321,10 +332,10 @@ class OpenAIServingCompletion(OpenAIServing):
stream_options = request.stream_options
if stream_options:
include_usage = stream_options.include_usage or \
enable_force_include_usage
include_continuous_usage = include_usage and \
stream_options.continuous_usage_stats
include_usage = (stream_options.include_usage
or enable_force_include_usage)
include_continuous_usage = (include_usage and
stream_options.continuous_usage_stats)
else:
include_usage, include_continuous_usage = False, False
@@ -370,7 +381,8 @@ class OpenAIServingCompletion(OpenAIServing):
# echo the prompt and first token
delta_text = prompt_text + output.text
delta_token_ids = [
*prompt_token_ids, *output.token_ids
*prompt_token_ids,
*output.token_ids,
]
out_logprobs = [
*(prompt_logprobs or []),
@@ -383,8 +395,8 @@ class OpenAIServingCompletion(OpenAIServing):
delta_token_ids = output.token_ids
out_logprobs = output.logprobs
if not delta_text and not delta_token_ids \
and not previous_num_tokens[i]:
if (not delta_text and not delta_token_ids
and not previous_num_tokens[i]):
# Chunked prefill case, don't return empty chunks
continue
@@ -420,7 +432,8 @@ class OpenAIServingCompletion(OpenAIServing):
finish_reason=finish_reason,
stop_reason=stop_reason,
)
])
],
)
if include_continuous_usage:
prompt_tokens = num_prompt_tokens[prompt_idx]
completion_tokens = previous_num_tokens[i]
@@ -438,7 +451,8 @@ class OpenAIServingCompletion(OpenAIServing):
final_usage_info = UsageInfo(
prompt_tokens=total_prompt_tokens,
completion_tokens=total_completion_tokens,
total_tokens=total_prompt_tokens + total_completion_tokens)
total_tokens=total_prompt_tokens + total_completion_tokens,
)
if self.enable_prompt_tokens_details and num_cached_tokens:
final_usage_info.prompt_tokens_details = PromptTokenUsageInfo(
@@ -452,8 +466,8 @@ class OpenAIServingCompletion(OpenAIServing):
choices=[],
usage=final_usage_info,
)
final_usage_data = (final_usage_chunk.model_dump_json(
exclude_unset=False, exclude_none=True))
final_usage_data = final_usage_chunk.model_dump_json(
exclude_unset=False, exclude_none=True)
yield f"data: {final_usage_data}\n\n"
# report to FastAPI middleware aggregate usage across all choices
@@ -478,8 +492,10 @@ class OpenAIServingCompletion(OpenAIServing):
choices: list[CompletionResponseChoice] = []
num_prompt_tokens = 0
num_generated_tokens = 0
kv_transfer_params = None
last_final_res = None
for final_res in final_res_batch:
last_final_res = final_res
prompt_token_ids = final_res.prompt_token_ids
assert prompt_token_ids is not None
prompt_logprobs = clamp_prompt_logprobs(final_res.prompt_logprobs)
@@ -548,19 +564,22 @@ class OpenAIServingCompletion(OpenAIServing):
total_tokens=num_prompt_tokens + num_generated_tokens,
)
if self.enable_prompt_tokens_details and final_res.num_cached_tokens:
if (self.enable_prompt_tokens_details and last_final_res
and last_final_res.num_cached_tokens):
usage.prompt_tokens_details = PromptTokenUsageInfo(
cached_tokens=final_res.num_cached_tokens)
cached_tokens=last_final_res.num_cached_tokens)
request_metadata.final_usage_info = usage
if final_res_batch:
kv_transfer_params = final_res_batch[0].kv_transfer_params
return CompletionResponse(
id=request_id,
created=created_time,
model=model_name,
choices=choices,
usage=usage,
kv_transfer_params=final_res_batch[0].kv_transfer_params)
kv_transfer_params=kv_transfer_params,
)
def _create_completion_logprobs(
self,
@@ -579,8 +598,9 @@ class OpenAIServingCompletion(OpenAIServing):
last_token_len = 0
should_return_as_token_id = return_as_token_id if \
return_as_token_id is not None else self.return_tokens_as_token_ids
should_return_as_token_id = (return_as_token_id
if return_as_token_id is not None else
self.return_tokens_as_token_ids)
for i, token_id in enumerate(token_ids):
step_top_logprobs = top_logprobs[i]
if step_top_logprobs is None:
@@ -612,10 +632,12 @@ class OpenAIServingCompletion(OpenAIServing):
out_top_logprobs.append({
# Convert float("-inf") to the
# JSON-serializable float that OpenAI uses
self._get_decoded_token(top_lp[1],
top_lp[0],
tokenizer,
return_as_token_id=should_return_as_token_id):
self._get_decoded_token(
top_lp[1],
top_lp[0],
tokenizer,
return_as_token_id=should_return_as_token_id,
):
max(top_lp[1].logprob, -9999.0)
for i, top_lp in enumerate(step_top_logprobs.items())
if num_output_top_logprobs >= i