[BugFix] Fix server crash on empty prompt (#7746)

Signed-off-by: Max de Bayser <mbayser@br.ibm.com>

Authored by Maximilien de Bayser on 2024-08-23 10:12:44 -03:00, committed by GitHub
Parent: faeddb565d
Commit: e25fee57c2

3 changed files with 39 additions and 0 deletions


@@ -0,0 +1,9 @@
+import pytest
+
+from vllm import LLM
+
+
+def test_empty_prompt():
+    llm = LLM(model="gpt2")
+    with pytest.raises(ValueError, match='Prompt cannot be empty'):
+        llm.generate([""])


@@ -0,0 +1,22 @@
+# imports for prompt validation tests
+import re
+
+import openai
+import pytest
+
+from ...utils import RemoteOpenAIServer
+
+
+@pytest.mark.asyncio
+async def test_empty_prompt():
+    model_name = "gpt2"
+    server_args = ["--enforce-eager"]
+    with RemoteOpenAIServer(model_name, server_args) as remote_server:
+        client = remote_server.get_async_client()
+
+        with pytest.raises(openai.BadRequestError,
+                           match=re.compile('.+Prompt cannot be empty.+')):
+            await client.completions.create(model=model_name,
+                                            prompt="",
+                                            max_tokens=5,
+                                            temperature=0.0)


@@ -591,6 +591,7 @@ class LLMEngine:
         prompt_adapter_request: Optional[PromptAdapterRequest],
         trace_headers: Optional[Mapping[str, str]] = None,
     ) -> None:
+        self._validate_model_inputs(processed_inputs)
         # Create the sequences.
         block_size = self.cache_config.block_size
         seq_id = next(self.seq_counter)
@@ -1647,3 +1648,10 @@ class LLMEngine:
 
     def is_embedding_model(self):
         return self.model_config.is_embedding_model
+
+    def _validate_model_inputs(self, inputs: Union[LLMInputs,
+                                                   EncoderDecoderLLMInputs]):
+        prompt_key = "encoder_prompt_token_ids" \
+            if self.is_encoder_decoder_model() else "prompt_token_ids"
+        if not inputs.get(prompt_key):
+            raise ValueError("Prompt cannot be empty")
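
For illustration, a minimal stand-alone sketch of the check that _validate_model_inputs performs. Plain dicts stand in for vLLM's LLMInputs/EncoderDecoderLLMInputs typed dicts, and a boolean flag stands in for self.is_encoder_decoder_model(); both stand-ins are assumptions for this sketch, not vLLM API.

    # Hypothetical stand-alone version of the validation above.
    def validate_model_inputs(inputs: dict, is_encoder_decoder: bool) -> None:
        # Encoder-decoder models are prompted via the encoder side.
        prompt_key = ("encoder_prompt_token_ids"
                      if is_encoder_decoder else "prompt_token_ids")
        # `not` rejects both a missing key and an empty token list.
        if not inputs.get(prompt_key):
            raise ValueError("Prompt cannot be empty")

    validate_model_inputs({"prompt_token_ids": [50256]}, False)  # passes
    try:
        validate_model_inputs({"prompt_token_ids": []}, False)
    except ValueError as err:
        print(err)  # -> Prompt cannot be empty

Because the check runs before any sequences are created, the error surfaces as an ordinary exception rather than a crash: directly as a ValueError from LLM.generate in the first test, and as an HTTP 400 BadRequestError through the OpenAI-compatible server in the second.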