[Core][Bugfix] Fix Online MM Beam Search (#19688)

Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
This commit is contained in:
Alex Brooks
2025-06-19 11:18:07 -06:00
committed by GitHub
parent 01220ce89a
commit ead2110297
3 changed files with 45 additions and 12 deletions

View File

@ -25,6 +25,25 @@ TEST_IMAGE_URLS = [
"https://upload.wikimedia.org/wikipedia/commons/0/0b/RGBA_comp.png",
]
EXPECTED_MM_BEAM_SEARCH_RES = [
[
"The image shows a wooden boardwalk leading through a",
"The image shows a wooden boardwalk extending into a",
],
[
"The image shows two parrots perched on",
"The image shows two birds perched on a cur",
],
[
"The image shows a Venn diagram with three over",
"This image shows a Venn diagram with three over",
],
[
"This image displays a gradient of colors ranging from",
"This image displays a gradient of colors transitioning from",
],
]
@pytest.fixture(scope="module")
def server():
@ -270,10 +289,13 @@ async def test_single_chat_session_image_base64encoded(
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
@pytest.mark.parametrize("image_idx", list(range(len(TEST_IMAGE_URLS))))
async def test_single_chat_session_image_base64encoded_beamsearch(
client: openai.AsyncOpenAI, model_name: str, image_url: str,
client: openai.AsyncOpenAI, model_name: str, image_idx: int,
base64_encoded_image: dict[str, str]):
# NOTE: This test also validates that we pass MM data through beam search
image_url = TEST_IMAGE_URLS[image_idx]
expected_res = EXPECTED_MM_BEAM_SEARCH_RES[image_idx]
messages = [{
"role":
@ -297,10 +319,11 @@ async def test_single_chat_session_image_base64encoded_beamsearch(
messages=messages,
n=2,
max_completion_tokens=10,
temperature=0.0,
extra_body=dict(use_beam_search=True))
assert len(chat_completion.choices) == 2
assert chat_completion.choices[
0].message.content != chat_completion.choices[1].message.content
for actual, expected_str in zip(chat_completion.choices, expected_res):
assert actual.message.content == expected_str
@pytest.mark.asyncio

View File

@ -88,9 +88,18 @@ class EngineClient(ABC):
if processed_inputs["type"] == "embeds":
raise NotImplementedError
prompt_token_ids = processed_inputs["prompt_token_ids"]
# This is a workaround to fix multimodal beam search; this is a
# bandaid fix for 2 small problems:
# 1. Multi_modal_data on the processed_inputs currently resolves to
# `None`.
# 2. preprocessing above expands the multimodal placeholders. However,
# this happens again in generation, so the double expansion causes
# a mismatch.
# TODO - would be ideal to handle this more gracefully.
prompt_token_ids = prompt.get("prompt_token_ids")
multi_modal_data = prompt.get("multi_modal_data")
prompt_text = processed_inputs.get("prompt")
multi_modal_data = processed_inputs.get("multi_modal_data")
mm_processor_kwargs = processed_inputs.get("mm_processor_kwargs")
tokenized_length = len(prompt_token_ids)

View File

@ -15,7 +15,8 @@ from tqdm.auto import tqdm
from typing_extensions import TypeVar, deprecated
from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
BeamSearchSequence, get_beam_search_score)
BeamSearchSequence,
create_sort_beams_key_function)
from vllm.config import (CompilationConfig, ModelDType, TokenizerMode,
is_init_field)
from vllm.engine.arg_utils import (EngineArgs, HfOverrides, PoolerConfig,
@ -575,10 +576,11 @@ class LLM:
lora_requests = self._get_beam_search_lora_requests(
lora_request, prompts)
def sort_beams_key(x: BeamSearchSequence) -> float:
return get_beam_search_score(x.tokens, x.cum_logprob,
tokenizer = self.get_tokenizer()
sort_beams_key = create_sort_beams_key_function(
tokenizer.eos_token_id,
length_penalty)
length_penalty,
)
def create_tokens_prompt_from_beam(
beam: BeamSearchSequence) -> TokensPrompt:
@ -593,7 +595,6 @@ class LLM:
"mm_processor_kwargs"] = beam.mm_processor_kwargs
return TokensPrompt(**token_prompt_kwargs)
tokenizer = self.get_tokenizer()
# generate 2 * beam_width candidates at each step
# following the huggingface transformers implementation
# at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa