Mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 14:53:52 +08:00)
[Core][Bugfix] Fix Online MM Beam Search (#19688)
Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
@@ -25,6 +25,25 @@ TEST_IMAGE_URLS = [
     "https://upload.wikimedia.org/wikipedia/commons/0/0b/RGBA_comp.png",
 ]
 
+EXPECTED_MM_BEAM_SEARCH_RES = [
+    [
+        "The image shows a wooden boardwalk leading through a",
+        "The image shows a wooden boardwalk extending into a",
+    ],
+    [
+        "The image shows two parrots perched on",
+        "The image shows two birds perched on a cur",
+    ],
+    [
+        "The image shows a Venn diagram with three over",
+        "This image shows a Venn diagram with three over",
+    ],
+    [
+        "This image displays a gradient of colors ranging from",
+        "This image displays a gradient of colors transitioning from",
+    ],
+]
+
 
 @pytest.fixture(scope="module")
 def server():
@@ -270,10 +289,13 @@ async def test_single_chat_session_image_base64encoded(
 
 @pytest.mark.asyncio
 @pytest.mark.parametrize("model_name", [MODEL_NAME])
-@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
+@pytest.mark.parametrize("image_idx", list(range(len(TEST_IMAGE_URLS))))
 async def test_single_chat_session_image_base64encoded_beamsearch(
-        client: openai.AsyncOpenAI, model_name: str, image_url: str,
+        client: openai.AsyncOpenAI, model_name: str, image_idx: int,
         base64_encoded_image: dict[str, str]):
+    # NOTE: This test also validates that we pass MM data through beam search
+    image_url = TEST_IMAGE_URLS[image_idx]
+    expected_res = EXPECTED_MM_BEAM_SEARCH_RES[image_idx]
 
     messages = [{
         "role":
@@ -297,10 +319,11 @@ async def test_single_chat_session_image_base64encoded_beamsearch(
         messages=messages,
         n=2,
         max_completion_tokens=10,
         temperature=0.0,
         extra_body=dict(use_beam_search=True))
     assert len(chat_completion.choices) == 2
-    assert chat_completion.choices[
-        0].message.content != chat_completion.choices[1].message.content
+    for actual, expected_str in zip(chat_completion.choices, expected_res):
+        assert actual.message.content == expected_str
 
 
 @pytest.mark.asyncio
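For reference, a minimal client-side sketch of the request shape this test exercises against a running vLLM OpenAI-compatible server. The base URL and model name are placeholders, and a plain remote image URL is used for brevity (the test itself sends a base64 data URL via the base64_encoded_image fixture); vLLM reads use_beam_search from extra_body.

import asyncio

import openai


async def main():
    # Placeholder endpoint and model; adjust to the server under test.
    client = openai.AsyncOpenAI(base_url="http://localhost:8000/v1",
                                api_key="EMPTY")
    chat_completion = await client.chat.completions.create(
        model="<vision-language-model>",
        messages=[{
            "role": "user",
            "content": [
                {"type": "image_url",
                 "image_url": {"url": "https://upload.wikimedia.org/wikipedia/commons/0/0b/RGBA_comp.png"}},
                {"type": "text", "text": "What's in this image?"},
            ],
        }],
        n=2,                      # two beams come back as two choices
        max_completion_tokens=10,
        temperature=0.0,
        extra_body=dict(use_beam_search=True),
    )
    for choice in chat_completion.choices:
        print(choice.message.content)


asyncio.run(main())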
@@ -88,9 +88,18 @@ class EngineClient(ABC):
         if processed_inputs["type"] == "embeds":
             raise NotImplementedError
 
-        prompt_token_ids = processed_inputs["prompt_token_ids"]
+        # This is a workaround to fix multimodal beam search; this is a
+        # bandaid fix for 2 small problems:
+        # 1. Multi_modal_data on the processed_inputs currently resolves to
+        #    `None`.
+        # 2. preprocessing above expands the multimodal placeholders. However,
+        #    this happens again in generation, so the double expansion causes
+        #    a mismatch.
+        # TODO - would be ideal to handle this more gracefully.
+        prompt_token_ids = prompt.get("prompt_token_ids")
+        multi_modal_data = prompt.get("multi_modal_data")
+
         prompt_text = processed_inputs.get("prompt")
-        multi_modal_data = processed_inputs.get("multi_modal_data")
         mm_processor_kwargs = processed_inputs.get("mm_processor_kwargs")
 
         tokenized_length = len(prompt_token_ids)
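To make the second point of that comment concrete, a toy illustration of the double expansion (the placeholder id and per-image feature count below are invented for the example, not vLLM's actual values):

# Toy illustration only; not vLLM code.
PLACEHOLDER = 32000        # hypothetical <image> placeholder token id
FEATURES_PER_IMAGE = 576   # hypothetical number of image feature tokens


def expand_placeholders(token_ids: list[int]) -> list[int]:
    # Stand-in for multimodal preprocessing: each placeholder is repeated once
    # per image feature so token positions line up with the image embeddings.
    out: list[int] = []
    for tok in token_ids:
        out.extend([PLACEHOLDER] * FEATURES_PER_IMAGE if tok == PLACEHOLDER
                   else [tok])
    return out


raw_prompt = [1, PLACEHOLDER, 450, 1967]   # e.g. "<s> <image> the image"
once = expand_placeholders(raw_prompt)     # what preprocessing produces
twice = expand_placeholders(once)          # what re-expanding in generation would produce

# Expanding the already-expanded ids no longer matches the single image's features:
assert len(once) == 3 + FEATURES_PER_IMAGE
assert len(twice) == 3 + FEATURES_PER_IMAGE * FEATURES_PER_IMAGE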
@@ -15,7 +15,8 @@ from tqdm.auto import tqdm
 from typing_extensions import TypeVar, deprecated
 
 from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
-                              BeamSearchSequence, get_beam_search_score)
+                              BeamSearchSequence,
+                              create_sort_beams_key_function)
 from vllm.config import (CompilationConfig, ModelDType, TokenizerMode,
                          is_init_field)
 from vllm.engine.arg_utils import (EngineArgs, HfOverrides, PoolerConfig,
@@ -575,10 +576,11 @@ class LLM:
         lora_requests = self._get_beam_search_lora_requests(
             lora_request, prompts)
 
-        def sort_beams_key(x: BeamSearchSequence) -> float:
-            return get_beam_search_score(x.tokens, x.cum_logprob,
-                                         tokenizer.eos_token_id,
-                                         length_penalty)
+        tokenizer = self.get_tokenizer()
+        sort_beams_key = create_sort_beams_key_function(
+            tokenizer.eos_token_id,
+            length_penalty,
+        )
 
         def create_tokens_prompt_from_beam(
                 beam: BeamSearchSequence) -> TokensPrompt:
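create_sort_beams_key_function itself is not part of this diff; a sketch of what a key factory of this shape could look like, assuming the conventional length-penalized score cum_logprob / seq_len ** length_penalty, with a stand-in Beam class in place of BeamSearchSequence:

from dataclasses import dataclass
from typing import Callable


@dataclass
class Beam:                  # stand-in for BeamSearchSequence
    tokens: list[int]
    cum_logprob: float


def make_sort_beams_key(eos_token_id: int,
                        length_penalty: float) -> Callable[[Beam], float]:
    # Higher key == better beam; a finished beam's trailing EOS is excluded
    # from the effective length before applying the length penalty.
    def key(beam: Beam) -> float:
        seq_len = len(beam.tokens)
        if beam.tokens and beam.tokens[-1] == eos_token_id:
            seq_len -= 1
        return beam.cum_logprob / (seq_len ** length_penalty)
    return key


# Usage: sort candidate beams best-first.
beams = [Beam([5, 7, 2], -1.2), Beam([5, 9, 11, 2], -1.5)]
beams.sort(key=make_sort_beams_key(eos_token_id=2, length_penalty=1.0),
           reverse=True)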
@@ -593,7 +595,6 @@ class LLM:
                     "mm_processor_kwargs"] = beam.mm_processor_kwargs
             return TokensPrompt(**token_prompt_kwargs)
 
-        tokenizer = self.get_tokenizer()
         # generate 2 * beam_width candidates at each step
         # following the huggingface transformers implementation
         # at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa
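Only the tail of create_tokens_prompt_from_beam appears in these hunks; a sketch of the overall shape they imply, with a plain dict standing in for vLLM's TokensPrompt and an illustrative Beam class in place of BeamSearchSequence:

from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class Beam:                                            # illustrative stand-in
    tokens: list[int]                                  # prompt + generated token ids
    multi_modal_data: Optional[dict[str, Any]] = None  # e.g. {"image": <PIL image>}
    mm_processor_kwargs: Optional[dict[str, Any]] = None


def create_tokens_prompt_from_beam(beam: Beam) -> dict[str, Any]:
    token_prompt_kwargs: dict[str, Any] = {"prompt_token_ids": beam.tokens}
    # Carry the raw multimodal inputs on every beam so each generation step
    # re-processes them against the unexpanded token ids (see the workaround
    # comment in EngineClient above).
    if beam.multi_modal_data is not None:
        token_prompt_kwargs["multi_modal_data"] = beam.multi_modal_data
    if beam.mm_processor_kwargs is not None:
        token_prompt_kwargs["mm_processor_kwargs"] = beam.mm_processor_kwargs
    return token_prompt_kwargs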