[Frontend] Optimize beam search performance by limiting concurrency (#23599)
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
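In short, LLM.beam_search gains an optional concurrency_limit argument that caps how many prompts' beam groups are expanded per engine step. A minimal usage sketch of the new API surface; the model name, prompts, and parameter values are illustrative, not taken from this commit:

    from vllm import LLM
    from vllm.inputs import TextPrompt
    from vllm.sampling_params import BeamSearchParams

    # Model name, prompts, and parameter values are placeholders.
    llm = LLM(model="facebook/opt-125m")
    params = BeamSearchParams(beam_width=4, max_tokens=32)
    prompts = [
        TextPrompt(prompt="The capital of France is"),
        TextPrompt(prompt="The future of AI is"),
    ]

    # With concurrency_limit=2, at most two prompts are beam-searched per
    # engine step; with the default None, all prompts run concurrently as
    # before this commit.
    outputs = llm.beam_search(prompts, params, concurrency_limit=2)
    for output in outputs:
        print(output.sequences[0].tokens)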
@@ -96,7 +96,6 @@ def run_vllm(
        end = time.perf_counter()
    else:
        assert lora_requests is None, "BeamSearch API does not support LoRA"
-        prompts = [request.prompt for request in requests]
        # output_len should be the same for all requests.
        output_len = requests[0].expected_output_len
        for request in requests:

@@ -1022,15 +1022,17 @@ class VllmRunner:
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
+        concurrency_limit: Optional[int] = None,
    ) -> list[tuple[list[list[int]], list[str]]]:
        inputs = self.get_inputs(prompts,
                                 images=images,
                                 videos=videos,
                                 audios=audios)

-        outputs = self.llm.beam_search(
-            inputs,
-            BeamSearchParams(beam_width=beam_width, max_tokens=max_tokens))
+        outputs = self.llm.beam_search(inputs,
+                                       BeamSearchParams(beam_width=beam_width,
+                                                        max_tokens=max_tokens),
+                                       concurrency_limit=concurrency_limit)
        returned_outputs = []
        for output in outputs:
            token_ids = [x.tokens for x in output.sequences]

@@ -67,6 +67,59 @@ def test_beam_search_single_input(
                f"vLLM: {vllm_output_ids}")


+@pytest.mark.skip_v1  # FIXME: This fails on V1 right now.
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("dtype", ["half"])
+@pytest.mark.parametrize("max_tokens", MAX_TOKENS)
+@pytest.mark.parametrize("beam_width", BEAM_WIDTHS)
+def test_beam_search_with_concurrency_limit(
+    hf_runner,
+    vllm_runner,
+    example_prompts,
+    model: str,
+    dtype: str,
+    max_tokens: int,
+    beam_width: int,
+) -> None:
+    # example_prompts[1], [3], and [7] fail for an unknown reason even
+    # without a concurrency limit; skip them for now.
+    example_prompts = example_prompts[:8]
+    concurrency_limit = 2
+    assert len(example_prompts) > concurrency_limit
+    with vllm_runner(model, dtype=dtype) as vllm_model:
+        outputs_with_limit = vllm_model.generate_beam_search(
+            example_prompts,
+            beam_width,
+            max_tokens,
+            concurrency_limit=concurrency_limit)
+        outputs_without_limit = []
+
+        for i in range(0, len(example_prompts), concurrency_limit):
+            outputs_without_limit.extend(
+                vllm_model.generate_beam_search(
+                    example_prompts[i:i + concurrency_limit], beam_width,
+                    max_tokens))
+
+    correct = True
+    for i in range(len(example_prompts)):
+        output_ids_with_limit, output_texts_with_limit = outputs_with_limit[i]
+        output_ids_without_limit, output_texts_without_limit = (
+            outputs_without_limit[i])
+        for j, (text_with_limit, text_without_limit) in enumerate(
+                zip(output_texts_with_limit, output_texts_without_limit)):
+            print(f">>>{j}-th with limit output:")
+            print(text_with_limit)
+            print(f">>>{j}-th without limit output:")
+            print(text_without_limit)
+        assert len(output_ids_with_limit) == len(output_ids_without_limit)
+        for j in range(len(output_ids_with_limit)):
+            if output_ids_with_limit[j] != output_ids_without_limit[j]:
+                print(f"Test{i} output{j}:\n+limit: {output_ids_with_limit}\n"
+                      f"-limit: {output_ids_without_limit}")
+                correct = False
+    assert correct
+
+
 @pytest.mark.parametrize("dtype", ["half"])
 @pytest.mark.parametrize("max_tokens", MAX_TOKENS)
 @pytest.mark.parametrize("beam_width", MM_BEAM_WIDTHS)

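The test pins down the intended semantics: a concurrency limit must not change results, only scheduling. Prompts are processed in contiguous chunks and the results concatenated in order, which is exactly the partitioning the test builds by hand. A self-contained sketch of that chunking rule; the chunked helper is hypothetical, for illustration only, not part of vLLM:

    # Contiguous chunking of at most `limit` items per chunk.
    def chunked(items: list, limit: int) -> list[list]:
        """Split items into contiguous chunks of at most `limit` elements."""
        return [items[i:i + limit] for i in range(0, len(items), limit)]

    # The test's manual loop is this partitioning followed by in-order
    # concatenation, so results must line up index-by-index with the
    # concurrency-limited run.
    assert chunked([1, 2, 3, 4, 5], 2) == [[1, 2], [3, 4], [5]]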
@@ -523,6 +523,7 @@ class LLM:
        params: BeamSearchParams,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        use_tqdm: bool = False,
+        concurrency_limit: Optional[int] = None,
    ) -> list[BeamSearchOutput]:
        """
        Generate sequences using beam search.

@@ -533,6 +534,8 @@ class LLM:
            params: The beam search parameters.
            lora_request: LoRA request to use for generation, if any.
            use_tqdm: Whether to use tqdm to display the progress bar.
+            concurrency_limit: The maximum number of concurrent requests.
+                If None, the number of concurrent requests is unlimited.
        """
        # TODO: how does beam search work together with length penalty,
        # frequency penalty, and stopping criteria, etc.?

@@ -551,6 +554,15 @@ class LLM:
            length_penalty,
        )

+        if use_tqdm and concurrency_limit is not None:
+            logger.warning(
+                "Progress bar is not supported when using concurrency_limit. "
+                "Disabling progress bar.")
+            use_tqdm = False
+
+        if concurrency_limit is None:
+            concurrency_limit = len(prompts)
+
        def create_tokens_prompt_from_beam(
                beam: BeamSearchSequence) -> TokensPrompt:
            token_prompt_kwargs: TokensPrompt = {

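This normalization makes the batching loop in the next hunk uniform: concurrency_limit defaults to len(prompts), i.e. the old all-at-once behavior. Since each beam-search instance keeps at most beam_width live beams (see instance.beams = sorted_beams[:beam_width] below), each generate step then carries at most concurrency_limit * beam_width sequences. A back-of-the-envelope sketch; the numbers are made up:

    # Illustrative upper bound on the per-step batch handed to generate().
    num_prompts = 64
    beam_width = 8
    concurrency_limit = 4

    unlimited_batch = num_prompts * beam_width      # 512 sequences per step
    limited_batch = concurrency_limit * beam_width  # 32 sequences per step
    print(unlimited_batch, limited_batch)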
@@ -595,73 +607,79 @@ class LLM:
                    **mm_kwargs,
                ), )

+        for prompt_start in range(0, len(prompts), concurrency_limit):
+            instances_batch = instances[prompt_start:prompt_start +
+                                        concurrency_limit]
+
            token_iter = range(max_tokens)
            if use_tqdm:
                token_iter = tqdm(token_iter,
                                  desc="Beam search",
                                  unit="token",
                                  unit_scale=False)
                logger.warning(
                    "The progress bar shows the upper bound on token steps and "
                    "may finish early due to stopping conditions. It does not "
                    "reflect instance-level progress.")
-
            for _ in token_iter:
                all_beams: list[BeamSearchSequence] = list(
-                sum((instance.beams for instance in instances), []))
+                    sum((instance.beams for instance in instances_batch), []))
                pos = [0] + list(
                    itertools.accumulate(
-                    len(instance.beams) for instance in instances))
+                        len(instance.beams) for instance in instances_batch))
                instance_start_and_end: list[tuple[int, int]] = list(
                    zip(pos[:-1], pos[1:]))

                if len(all_beams) == 0:
                    break

-            # create the corresponding batch entries for prompt & optional lora
+                # create corresponding batch entries for prompt & optional lora
                prompts_batch, lora_req_batch = zip(
                    *[(create_tokens_prompt_from_beam(beam), beam.lora_request)
                      for beam in all_beams])

                # only runs for one step
                # we don't need to use tqdm here
                output = self.generate(prompts_batch,
                                       sampling_params=beam_search_params,
                                       use_tqdm=False,
                                       lora_request=lora_req_batch)

                for (start, end), instance in zip(instance_start_and_end,
-                                              instances):
+                                                  instances_batch):
                    instance_new_beams = []
                    for i in range(start, end):
                        current_beam = all_beams[i]
                        result = output[i]

                        if result.outputs[0].logprobs is not None:
                            # if `result.outputs[0].logprobs` is None, it means
-                            # the sequence is completed because of the max-model-len
-                            # or abortion. we don't need to add it to the new beams.
+                            # the sequence is completed because of the
+                            # max-model-len or abortion. we don't need to add
+                            # it to the new beams.
                            logprobs = result.outputs[0].logprobs[0]
                            for token_id, logprob_obj in logprobs.items():
                                new_beam = BeamSearchSequence(
                                    tokens=current_beam.tokens + [token_id],
-                                    logprobs=current_beam.logprobs + [logprobs],
+                                    logprobs=current_beam.logprobs +
+                                    [logprobs],
                                    lora_request=current_beam.lora_request,
                                    cum_logprob=current_beam.cum_logprob +
                                    logprob_obj.logprob,
-                                    multi_modal_data=current_beam.multi_modal_data,
+                                    multi_modal_data=current_beam.
+                                    multi_modal_data,
                                    mm_processor_kwargs=current_beam.
                                    mm_processor_kwargs)

                                if token_id == tokenizer.eos_token_id and \
                                        not ignore_eos:
                                    instance.completed.append(new_beam)
                                else:
                                    instance_new_beams.append(new_beam)
                    sorted_beams = sorted(instance_new_beams,
                                          key=sort_beams_key,
                                          reverse=True)
                    instance.beams = sorted_beams[:beam_width]

        outputs = []
        for instance in instances:
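Note that the chunking cannot reorder results: instances is built one per prompt, each chunk is a contiguous slice, and the final collection loop above walks the full instances list in its original order. A tiny sketch of that order-preservation argument; the prompt values are placeholders:

    # Contiguous chunking followed by in-order concatenation is the
    # identity on ordering, so outputs stay aligned with input prompts.
    prompts = ["p0", "p1", "p2", "p3", "p4"]
    limit = 2

    order = []
    for start in range(0, len(prompts), limit):
        order.extend(prompts[start:start + limit])

    assert order == prompts  # chunked processing never reorders outputs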