vllm-project/vllm (mirror of https://github.com/vllm-project/vllm.git)
[Frontend] Optimize beam search performance by limiting concurrency (#23599)
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
@@ -96,7 +96,6 @@ def run_vllm(
         end = time.perf_counter()
     else:
         assert lora_requests is None, "BeamSearch API does not support LoRA"
         prompts = [request.prompt for request in requests]
         # output_len should be the same for all requests.
         output_len = requests[0].expected_output_len
         for request in requests:
@@ -1022,15 +1022,17 @@ class VllmRunner:
         images: Optional[PromptImageInput] = None,
         videos: Optional[PromptVideoInput] = None,
         audios: Optional[PromptAudioInput] = None,
+        concurrency_limit: Optional[int] = None,
     ) -> list[tuple[list[list[int]], list[str]]]:
         inputs = self.get_inputs(prompts,
                                  images=images,
                                  videos=videos,
                                  audios=audios)

-        outputs = self.llm.beam_search(
-            inputs,
-            BeamSearchParams(beam_width=beam_width, max_tokens=max_tokens))
+        outputs = self.llm.beam_search(inputs,
+                                       BeamSearchParams(beam_width=beam_width,
+                                                        max_tokens=max_tokens),
+                                       concurrency_limit=concurrency_limit)
         returned_outputs = []
         for output in outputs:
             token_ids = [x.tokens for x in output.sequences]
@@ -67,6 +67,59 @@ def test_beam_search_single_input(
                                             f"vLLM: {vllm_output_ids}")


+@pytest.mark.skip_v1  # FIXME: This fails on V1 right now.
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("dtype", ["half"])
+@pytest.mark.parametrize("max_tokens", MAX_TOKENS)
+@pytest.mark.parametrize("beam_width", BEAM_WIDTHS)
+def test_beam_search_with_concurrency_limit(
+    hf_runner,
+    vllm_runner,
+    example_prompts,
+    model: str,
+    dtype: str,
+    max_tokens: int,
+    beam_width: int,
+) -> None:
+    # example_prompts[1], [3], and [7] fail for an unknown reason even
+    # without a concurrency limit; skip them for now.
+    example_prompts = example_prompts[:8]
+    concurrency_limit = 2
+    assert len(example_prompts) > concurrency_limit
+    with vllm_runner(model, dtype=dtype) as vllm_model:
+        outputs_with_limit = vllm_model.generate_beam_search(
+            example_prompts,
+            beam_width,
+            max_tokens,
+            concurrency_limit=concurrency_limit)
+        outputs_without_limit = []
+
+        for i in range(0, len(example_prompts), concurrency_limit):
+            outputs_without_limit.extend(
+                vllm_model.generate_beam_search(
+                    example_prompts[i:i + concurrency_limit], beam_width,
+                    max_tokens))
+
+    correct = True
+    for i in range(len(example_prompts)):
+        output_ids_with_limit, output_texts_with_limit = outputs_with_limit[i]
+        output_ids_without_limit, output_texts_without_limit = (
+            outputs_without_limit[i])
+        for j, (text_with_limit, text_without_limit) in enumerate(
+                zip(output_texts_with_limit, output_texts_without_limit)):
+            print(f">>>{j}-th with limit output:")
+            print(text_with_limit)
+            print(f">>>{j}-th without limit output:")
+            print(text_without_limit)
+        assert len(output_ids_with_limit) == len(output_ids_without_limit)
+        for j in range(len(output_ids_with_limit)):
+            if output_ids_with_limit[j] != output_ids_without_limit[j]:
+                print(f"Test{i} output{j}:\n+limit: {output_ids_with_limit}\n"
+                      f"-limit: {output_ids_without_limit}")
+                correct = False
+    assert correct
+
+
 @pytest.mark.parametrize("dtype", ["half"])
 @pytest.mark.parametrize("max_tokens", MAX_TOKENS)
 @pytest.mark.parametrize("beam_width", MM_BEAM_WIDTHS)
@@ -523,6 +523,7 @@ class LLM:
         params: BeamSearchParams,
         lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
         use_tqdm: bool = False,
+        concurrency_limit: Optional[int] = None,
     ) -> list[BeamSearchOutput]:
         """
         Generate sequences using beam search.
@@ -533,6 +534,8 @@ class LLM:
             params: The beam search parameters.
             lora_request: LoRA request to use for generation, if any.
             use_tqdm: Whether to use tqdm to display the progress bar.
+            concurrency_limit: The maximum number of concurrent requests.
+                If None, the number of concurrent requests is unlimited.
         """
         # TODO: how does beam search work together with length penalty,
         # frequency, penalty, and stopping criteria, etc.?
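Aside (illustration, not part of the diff): a minimal usage sketch of the new argument from the public API. The model name, prompts, and parameter values are assumptions for illustration; the result fields follow those used elsewhere in this diff (sequences, tokens) plus the decoded text field.

from vllm import LLM
from vllm.inputs import TextPrompt
from vllm.sampling_params import BeamSearchParams

llm = LLM(model="facebook/opt-125m")  # illustrative model choice

prompts = [
    TextPrompt(prompt="The capital of France is"),
    TextPrompt(prompt="The largest planet in the solar system is"),
    TextPrompt(prompt="Water boils at"),
]
params = BeamSearchParams(beam_width=4, max_tokens=32)

# Expand at most two prompts' beam groups concurrently; with the default
# (concurrency_limit=None) all prompts are expanded together, as before.
outputs = llm.beam_search(prompts, params, concurrency_limit=2)
for out in outputs:
    best = out.sequences[0]  # beams are returned best-first
    print(best.text)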
@@ -551,6 +554,15 @@ class LLM:
                 length_penalty,
             )

+        if use_tqdm and concurrency_limit is not None:
+            logger.warning(
+                "Progress bar is not supported when using concurrency_limit. "
+                "Disabling progress bar.")
+            use_tqdm = False
+
+        if concurrency_limit is None:
+            concurrency_limit = len(prompts)
+
         def create_tokens_prompt_from_beam(
                 beam: BeamSearchSequence) -> TokensPrompt:
             token_prompt_kwargs: TokensPrompt = {
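Aside (illustration, not part of the diff): a tiny worked example of the defaulting above and the batch granularity it implies. The prompt values are made up.

prompts = ["p0", "p1", "p2", "p3", "p4"]

# Default: no limit given, so the limit becomes len(prompts) and the whole
# prompt list is expanded as a single batch (the previous behavior).
concurrency_limit = None
if concurrency_limit is None:
    concurrency_limit = len(prompts)  # -> 5

# With an explicit limit of 2 the prompts are walked in chunks of 2.
limit = 2
batches = [prompts[i:i + limit] for i in range(0, len(prompts), limit)]
print(batches)  # [['p0', 'p1'], ['p2', 'p3'], ['p4']]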
@@ -595,73 +607,79 @@ class LLM:
                     **mm_kwargs,
                 ), )

-        token_iter = range(max_tokens)
-        if use_tqdm:
-            token_iter = tqdm(token_iter,
-                              desc="Beam search",
-                              unit="token",
-                              unit_scale=False)
-            logger.warning(
-                "The progress bar shows the upper bound on token steps and "
-                "may finish early due to stopping conditions. It does not "
-                "reflect instance-level progress.")
+        for prompt_start in range(0, len(prompts), concurrency_limit):
+            instances_batch = instances[prompt_start:prompt_start +
+                                        concurrency_limit]

-        for _ in token_iter:
-            all_beams: list[BeamSearchSequence] = list(
-                sum((instance.beams for instance in instances), []))
-            pos = [0] + list(
-                itertools.accumulate(
-                    len(instance.beams) for instance in instances))
-            instance_start_and_end: list[tuple[int, int]] = list(
-                zip(pos[:-1], pos[1:]))
+            token_iter = range(max_tokens)
+            if use_tqdm:
+                token_iter = tqdm(token_iter,
+                                  desc="Beam search",
+                                  unit="token",
+                                  unit_scale=False)
+                logger.warning(
+                    "The progress bar shows the upper bound on token steps and "
+                    "may finish early due to stopping conditions. It does not "
+                    "reflect instance-level progress.")
+            for _ in token_iter:
+                all_beams: list[BeamSearchSequence] = list(
+                    sum((instance.beams for instance in instances_batch), []))
+                pos = [0] + list(
+                    itertools.accumulate(
+                        len(instance.beams) for instance in instances_batch))
+                instance_start_and_end: list[tuple[int, int]] = list(
+                    zip(pos[:-1], pos[1:]))

-            if len(all_beams) == 0:
-                break
+                if len(all_beams) == 0:
+                    break

-            # create the corresponding batch entries for prompt & optional lora
-            prompts_batch, lora_req_batch = zip(
-                *[(create_tokens_prompt_from_beam(beam), beam.lora_request)
-                  for beam in all_beams])
+                # create corresponding batch entries for prompt & optional lora
+                prompts_batch, lora_req_batch = zip(
+                    *[(create_tokens_prompt_from_beam(beam), beam.lora_request)
+                      for beam in all_beams])

-            # only runs for one step
-            # we don't need to use tqdm here
-            output = self.generate(prompts_batch,
-                                   sampling_params=beam_search_params,
-                                   use_tqdm=False,
-                                   lora_request=lora_req_batch)
+                # only runs for one step
+                # we don't need to use tqdm here
+                output = self.generate(prompts_batch,
+                                       sampling_params=beam_search_params,
+                                       use_tqdm=False,
+                                       lora_request=lora_req_batch)

-            for (start, end), instance in zip(instance_start_and_end,
-                                              instances):
-                instance_new_beams = []
-                for i in range(start, end):
-                    current_beam = all_beams[i]
-                    result = output[i]
+                for (start, end), instance in zip(instance_start_and_end,
+                                                  instances_batch):
+                    instance_new_beams = []
+                    for i in range(start, end):
+                        current_beam = all_beams[i]
+                        result = output[i]

-                    if result.outputs[0].logprobs is not None:
-                        # if `result.outputs[0].logprobs` is None, it means
-                        # the sequence is completed because of the max-model-len
-                        # or abortion. we don't need to add it to the new beams.
-                        logprobs = result.outputs[0].logprobs[0]
-                        for token_id, logprob_obj in logprobs.items():
-                            new_beam = BeamSearchSequence(
-                                tokens=current_beam.tokens + [token_id],
-                                logprobs=current_beam.logprobs + [logprobs],
-                                lora_request=current_beam.lora_request,
-                                cum_logprob=current_beam.cum_logprob +
-                                logprob_obj.logprob,
-                                multi_modal_data=current_beam.multi_modal_data,
-                                mm_processor_kwargs=current_beam.
-                                mm_processor_kwargs)
+                        if result.outputs[0].logprobs is not None:
+                            # if `result.outputs[0].logprobs` is None, it means
+                            # the sequence is completed because of the
+                            # max-model-len or abortion. we don't need to add
+                            # it to the new beams.
+                            logprobs = result.outputs[0].logprobs[0]
+                            for token_id, logprob_obj in logprobs.items():
+                                new_beam = BeamSearchSequence(
+                                    tokens=current_beam.tokens + [token_id],
+                                    logprobs=current_beam.logprobs +
+                                    [logprobs],
+                                    lora_request=current_beam.lora_request,
+                                    cum_logprob=current_beam.cum_logprob +
+                                    logprob_obj.logprob,
+                                    multi_modal_data=current_beam.
+                                    multi_modal_data,
+                                    mm_processor_kwargs=current_beam.
+                                    mm_processor_kwargs)

-                            if token_id == tokenizer.eos_token_id and \
-                                not ignore_eos:
-                                instance.completed.append(new_beam)
-                            else:
-                                instance_new_beams.append(new_beam)
-                sorted_beams = sorted(instance_new_beams,
-                                      key=sort_beams_key,
-                                      reverse=True)
-                instance.beams = sorted_beams[:beam_width]
+                                if token_id == tokenizer.eos_token_id and \
+                                    not ignore_eos:
+                                    instance.completed.append(new_beam)
+                                else:
+                                    instance_new_beams.append(new_beam)
+                    sorted_beams = sorted(instance_new_beams,
+                                          key=sort_beams_key,
+                                          reverse=True)
+                    instance.beams = sorted_beams[:beam_width]

         outputs = []
         for instance in instances:
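Aside (illustration, not part of the diff): a condensed, self-contained sketch of the batched loop structure introduced above. The helper name score_top_k is a hypothetical stand-in for the single decode step that the real method performs through self.generate(), and length-penalty-aware sorting is simplified to the cumulative logprob.

from dataclasses import dataclass, field

@dataclass
class Beam:
    tokens: list[int]
    cum_logprob: float = 0.0

@dataclass
class Instance:
    beams: list[Beam]
    completed: list[Beam] = field(default_factory=list)

def beam_search_batched(instances, max_tokens, beam_width, eos_id,
                        score_top_k, concurrency_limit):
    # Walk the requests in groups of `concurrency_limit` instead of expanding
    # every prompt's beams in one giant batch.
    for start in range(0, len(instances), concurrency_limit):
        batch = instances[start:start + concurrency_limit]
        for _ in range(max_tokens):
            # Flatten every live beam of this group into one decode request.
            all_beams = [beam for inst in batch for beam in inst.beams]
            if not all_beams:
                break
            # One scoring step for the whole group: per beam, a dict of
            # {token_id: logprob} candidate continuations.
            step_logprobs = score_top_k([beam.tokens for beam in all_beams])
            pos = 0
            for inst in batch:
                new_beams = []
                beam_logprobs = step_logprobs[pos:pos + len(inst.beams)]
                for beam, logprobs in zip(inst.beams, beam_logprobs):
                    for token_id, lp in logprobs.items():
                        grown = Beam(beam.tokens + [token_id],
                                     beam.cum_logprob + lp)
                        # Finished beams go to `completed`, others stay live.
                        if token_id == eos_id:
                            inst.completed.append(grown)
                        else:
                            new_beams.append(grown)
                pos += len(inst.beams)
                inst.beams = sorted(new_beams,
                                    key=lambda b: b.cum_logprob,
                                    reverse=True)[:beam_width]
    return instances

After the batches finish, the real method merges each instance's completed and surviving beams and returns the best beam_width sequences per prompt.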