[Frontend] Optimize beam search performance by limiting concurrency (#23599)

Signed-off-by: Chen Zhang <zhangch99@outlook.com>
Chen Zhang
2025-08-26 21:59:14 -07:00
committed by GitHub
parent 3210264421
commit 142ac08030
4 changed files with 136 additions and 64 deletions
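In short, this change adds an optional concurrency_limit argument to LLM.beam_search so that beam-search prompts are processed in bounded windows instead of all at once, threads the same option through the VllmRunner.generate_beam_search test helper, and adds a regression test. A minimal usage sketch of the new argument follows; the model name and limit value are illustrative and not taken from this commit:

    from vllm import LLM
    from vllm.inputs import TextPrompt
    from vllm.sampling_params import BeamSearchParams

    # Illustrative model and limit; any vLLM-supported model would do.
    llm = LLM(model="facebook/opt-125m")
    prompts = [TextPrompt(prompt=p) for p in ["Hello, my name is",
                                              "The capital of France is"]]
    outputs = llm.beam_search(
        prompts,
        BeamSearchParams(beam_width=4, max_tokens=32),
        concurrency_limit=4,  # None (the default) keeps the previous all-at-once behavior
    )
    for out in outputs:
        print(out.sequences[0].tokens)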

View File

@@ -96,7 +96,6 @@ def run_vllm(
end = time.perf_counter()
else:
assert lora_requests is None, "BeamSearch API does not support LoRA"
prompts = [request.prompt for request in requests]
# output_len should be the same for all requests.
output_len = requests[0].expected_output_len
for request in requests:

View File

@@ -1022,15 +1022,17 @@ class VllmRunner:
images: Optional[PromptImageInput] = None,
videos: Optional[PromptVideoInput] = None,
audios: Optional[PromptAudioInput] = None,
concurrency_limit: Optional[int] = None,
) -> list[tuple[list[list[int]], list[str]]]:
inputs = self.get_inputs(prompts,
images=images,
videos=videos,
audios=audios)
outputs = self.llm.beam_search(
inputs,
BeamSearchParams(beam_width=beam_width, max_tokens=max_tokens))
outputs = self.llm.beam_search(inputs,
BeamSearchParams(beam_width=beam_width,
max_tokens=max_tokens),
concurrency_limit=concurrency_limit)
returned_outputs = []
for output in outputs:
token_ids = [x.tokens for x in output.sequences]

View File

@@ -67,6 +67,59 @@ def test_beam_search_single_input(
f"vLLM: {vllm_output_ids}")
@pytest.mark.skip_v1 # FIXME: This fails on V1 right now.
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", MAX_TOKENS)
@pytest.mark.parametrize("beam_width", BEAM_WIDTHS)
def test_beam_search_with_concurrency_limit(
hf_runner,
vllm_runner,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
beam_width: int,
) -> None:
# example_prompts[1], [3] and [7] fail for an unknown reason even without a
# concurrency limit. Skip them for now.
example_prompts = (example_prompts[:8])
concurrency_limit = 2
assert len(example_prompts) > concurrency_limit
with vllm_runner(model, dtype=dtype) as vllm_model:
outputs_with_limit = vllm_model.generate_beam_search(
example_prompts,
beam_width,
max_tokens,
concurrency_limit=concurrency_limit)
outputs_without_limit = []
for i in range(0, len(example_prompts), concurrency_limit):
outputs_without_limit.extend(
vllm_model.generate_beam_search(
example_prompts[i:i + concurrency_limit], beam_width,
max_tokens))
correct = True
for i in range(len(example_prompts)):
output_ids_with_limit, output_texts_with_limit = outputs_with_limit[i]
output_ids_without_limit, output_texts_without_limit = (
outputs_without_limit[i])
for j, (text_with_limit, text_without_limit) in enumerate(
zip(output_texts_with_limit, output_texts_without_limit)):
print(f">>>{j}-th with limit output:")
print(text_with_limit)
print(f">>>{j}-th without limit output:")
print(text_without_limit)
assert len(output_ids_with_limit) == len(output_ids_without_limit)
for j in range(len(output_ids_with_limit)):
if output_ids_with_limit[j] != output_ids_without_limit[j]:
print(f"Test{i} output{j}:\n+limit: {output_ids_with_limit}\n"
f"-limit: {output_ids_without_limit}")
correct = False
assert correct
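The new test pins down the intended invariant: limiting concurrency must not change results, so running all prompts with concurrency_limit=2 has to match running the same prompts manually in chunks of two. Restated compactly as a hypothetical helper (not part of the commit; prompts are assumed to already be vLLM prompt objects):

    def limited_matches_manual_chunking(llm, prompts, params, k):
        # Outputs produced with a concurrency limit of k ...
        limited = llm.beam_search(prompts, params, concurrency_limit=k)
        # ... must agree token-for-token with manually chunked runs of size k.
        manual = []
        for i in range(0, len(prompts), k):
            manual.extend(llm.beam_search(prompts[i:i + k], params))
        return all(
            [s.tokens for s in a.sequences] == [s.tokens for s in b.sequences]
            for a, b in zip(limited, manual))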
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", MAX_TOKENS)
@pytest.mark.parametrize("beam_width", MM_BEAM_WIDTHS)

View File

@@ -523,6 +523,7 @@ class LLM:
params: BeamSearchParams,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
use_tqdm: bool = False,
concurrency_limit: Optional[int] = None,
) -> list[BeamSearchOutput]:
"""
Generate sequences using beam search.
@@ -533,6 +534,8 @@
params: The beam search parameters.
lora_request: LoRA request to use for generation, if any.
use_tqdm: Whether to use tqdm to display the progress bar.
concurrency_limit: The maximum number of concurrent requests.
If None, the number of concurrent requests is unlimited.
"""
# TODO: how does beam search work together with length penalty,
# frequency penalty, and stopping criteria, etc.?
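As the next hunk shows, passing use_tqdm together with a concurrency limit disables the progress bar, and a limit of None is replaced by len(prompts), i.e. a single window covering every prompt, which preserves the previous behavior. A standalone sketch of that windowing in plain Python, independent of vLLM:

    def iter_windows(prompts, concurrency_limit=None):
        # None means "process everything in one window", mirroring the defaulting below.
        if concurrency_limit is None:
            concurrency_limit = len(prompts)
        for start in range(0, len(prompts), concurrency_limit):
            yield prompts[start:start + concurrency_limit]

    assert list(iter_windows(list("abcde"), 2)) == [["a", "b"], ["c", "d"], ["e"]]
    assert list(iter_windows(list("abcde"))) == [["a", "b", "c", "d", "e"]]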
@@ -551,6 +554,15 @@
length_penalty,
)
if use_tqdm and concurrency_limit is not None:
logger.warning(
"Progress bar is not supported when using concurrency_limit. "
"Disabling progress bar.")
use_tqdm = False
if concurrency_limit is None:
concurrency_limit = len(prompts)
def create_tokens_prompt_from_beam(
beam: BeamSearchSequence) -> TokensPrompt:
token_prompt_kwargs: TokensPrompt = {
@@ -595,73 +607,79 @@
**mm_kwargs,
), )
token_iter = range(max_tokens)
if use_tqdm:
token_iter = tqdm(token_iter,
desc="Beam search",
unit="token",
unit_scale=False)
logger.warning(
"The progress bar shows the upper bound on token steps and "
"may finish early due to stopping conditions. It does not "
"reflect instance-level progress.")
for prompt_start in range(0, len(prompts), concurrency_limit):
instances_batch = instances[prompt_start:prompt_start +
concurrency_limit]
for _ in token_iter:
all_beams: list[BeamSearchSequence] = list(
sum((instance.beams for instance in instances), []))
pos = [0] + list(
itertools.accumulate(
len(instance.beams) for instance in instances))
instance_start_and_end: list[tuple[int, int]] = list(
zip(pos[:-1], pos[1:]))
token_iter = range(max_tokens)
if use_tqdm:
token_iter = tqdm(token_iter,
desc="Beam search",
unit="token",
unit_scale=False)
logger.warning(
"The progress bar shows the upper bound on token steps and "
"may finish early due to stopping conditions. It does not "
"reflect instance-level progress.")
for _ in token_iter:
all_beams: list[BeamSearchSequence] = list(
sum((instance.beams for instance in instances_batch), []))
pos = [0] + list(
itertools.accumulate(
len(instance.beams) for instance in instances_batch))
instance_start_and_end: list[tuple[int, int]] = list(
zip(pos[:-1], pos[1:]))
if len(all_beams) == 0:
break
if len(all_beams) == 0:
break
# create the corresponding batch entries for prompt & optional lora
prompts_batch, lora_req_batch = zip(
*[(create_tokens_prompt_from_beam(beam), beam.lora_request)
for beam in all_beams])
# create corresponding batch entries for prompt & optional lora
prompts_batch, lora_req_batch = zip(
*[(create_tokens_prompt_from_beam(beam), beam.lora_request)
for beam in all_beams])
# only runs for one step
# we don't need to use tqdm here
output = self.generate(prompts_batch,
sampling_params=beam_search_params,
use_tqdm=False,
lora_request=lora_req_batch)
# only runs for one step
# we don't need to use tqdm here
output = self.generate(prompts_batch,
sampling_params=beam_search_params,
use_tqdm=False,
lora_request=lora_req_batch)
for (start, end), instance in zip(instance_start_and_end,
instances):
instance_new_beams = []
for i in range(start, end):
current_beam = all_beams[i]
result = output[i]
for (start, end), instance in zip(instance_start_and_end,
instances_batch):
instance_new_beams = []
for i in range(start, end):
current_beam = all_beams[i]
result = output[i]
if result.outputs[0].logprobs is not None:
# if `result.outputs[0].logprobs` is None, it means
# the sequence completed because it hit max-model-len
# or was aborted. we don't need to add it to the new beams.
logprobs = result.outputs[0].logprobs[0]
for token_id, logprob_obj in logprobs.items():
new_beam = BeamSearchSequence(
tokens=current_beam.tokens + [token_id],
logprobs=current_beam.logprobs + [logprobs],
lora_request=current_beam.lora_request,
cum_logprob=current_beam.cum_logprob +
logprob_obj.logprob,
multi_modal_data=current_beam.multi_modal_data,
mm_processor_kwargs=current_beam.
mm_processor_kwargs)
if result.outputs[0].logprobs is not None:
# if `result.outputs[0].logprobs` is None, it means
# the sequence completed because it hit
# max-model-len or was aborted. we don't need to add
# it to the new beams.
logprobs = result.outputs[0].logprobs[0]
for token_id, logprob_obj in logprobs.items():
new_beam = BeamSearchSequence(
tokens=current_beam.tokens + [token_id],
logprobs=current_beam.logprobs +
[logprobs],
lora_request=current_beam.lora_request,
cum_logprob=current_beam.cum_logprob +
logprob_obj.logprob,
multi_modal_data=current_beam.
multi_modal_data,
mm_processor_kwargs=current_beam.
mm_processor_kwargs)
if token_id == tokenizer.eos_token_id and \
not ignore_eos:
instance.completed.append(new_beam)
else:
instance_new_beams.append(new_beam)
sorted_beams = sorted(instance_new_beams,
key=sort_beams_key,
reverse=True)
instance.beams = sorted_beams[:beam_width]
if token_id == tokenizer.eos_token_id and \
not ignore_eos:
instance.completed.append(new_beam)
else:
instance_new_beams.append(new_beam)
sorted_beams = sorted(instance_new_beams,
key=sort_beams_key,
reverse=True)
instance.beams = sorted_beams[:beam_width]
outputs = []
for instance in instances: