Mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 14:53:52 +08:00)
EAGLE 3: Fix preamble so that measured speedup over Eagle 1 becomes 32% instead of 5% on MTBench (#25916)
Signed-off-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com>
@@ -1151,6 +1151,12 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
         help="Do not oversample if the dataset has " \
         "fewer samples than num-prompts.",
     )
+    parser.add_argument(
+        "--skip-chat-template",
+        action="store_true",
+        help=
+        "Skip applying chat template to prompt for datasets that support it.",
+    )
 
     # group for dataset specific arguments
     custom_group = parser.add_argument_group("custom dataset options")
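A minimal sketch of how the new shared flag behaves once registered; plain argparse stands in here for vLLM's argparse-compatible FlexibleArgumentParser, and all other dataset options are omitted:

import argparse

# Stand-in for add_dataset_parser(): only the new shared flag is registered.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--skip-chat-template",
    action="store_true",
    help="Skip applying chat template to prompt for datasets that support it.",
)

args = parser.parse_args(["--skip-chat-template"])
print(args.skip_chat_template)  # True; defaults to False when the flag is absent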
@@ -1161,12 +1167,6 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
         help=
         "Number of output tokens per request, used only for custom dataset.",
     )
-    custom_group.add_argument(
-        "--custom-skip-chat-template",
-        action="store_true",
-        help=
-        "Skip applying chat template to prompt, used only for custom dataset.",
-    )
 
     spec_bench_group = parser.add_argument_group("spec bench dataset options")
     spec_bench_group.add_argument(
@@ -1435,7 +1435,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
             num_requests=args.num_prompts,
             tokenizer=tokenizer,
             output_len=args.custom_output_len,
-            skip_chat_template=args.custom_skip_chat_template,
+            skip_chat_template=args.skip_chat_template,
             request_id_prefix=args.request_id_prefix,
             no_oversample=args.no_oversample,
         )
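A hypothetical, stripped-down view of the forwarding changed in this hunk; the real get_samples() dispatches on the selected dataset and passes many more options, so the names here are illustrative only:

# Hedged sketch: only the skip_chat_template plumbing is shown.
def get_samples_sketch(args, tokenizer, dataset):
    return dataset.sample(
        tokenizer=tokenizer,
        num_requests=args.num_prompts,
        output_len=args.custom_output_len,
        # Previously this read args.custom_skip_chat_template; the custom
        # dataset now honors the shared --skip-chat-template flag instead.
        skip_chat_template=args.skip_chat_template,
        request_id_prefix=args.request_id_prefix,
        no_oversample=args.no_oversample,
    )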
@@ -1576,6 +1576,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
             output_len=args.hf_output_len,
             request_id_prefix=args.request_id_prefix,
             no_oversample=args.no_oversample,
+            skip_chat_template=args.skip_chat_template,
             **hf_kwargs
         )
 
@@ -1815,7 +1816,6 @@ class SpecBench(CustomDataset):
 
     def sample(self, **kwargs) -> list:
         # leverage CustomDataset sample
-        kwargs["skip_chat_template"] = False
         return super().sample(**kwargs)
 
 
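A toy sketch of the delegation this hunk leaves in place; ToyCustomDataset and ToySpecBench are made-up stand-ins for CustomDataset and SpecBench, shown only to illustrate that the caller's choice now flows through unchanged:

class ToyCustomDataset:
    def sample(self, skip_chat_template: bool = False, **kwargs) -> list:
        return [f"skip_chat_template={skip_chat_template}"]

class ToySpecBench(ToyCustomDataset):
    def sample(self, **kwargs) -> list:
        # Previously this forced kwargs["skip_chat_template"] = False;
        # now the shared CLI flag is forwarded as-is.
        return super().sample(**kwargs)

print(ToySpecBench().sample(skip_chat_template=True))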
@@ -2221,6 +2221,7 @@ class InstructCoderDataset(HuggingFaceDataset):
                num_requests: int,
                output_len: Optional[int] = None,
                enable_multimodal_chat: bool = False,
+               skip_chat_template: bool = False,
                request_id_prefix: str = "",
                no_oversample: bool = False,
                **kwargs) -> list:
@@ -2236,14 +2237,15 @@ class InstructCoderDataset(HuggingFaceDataset):
             )
 
             # apply template
-            prompt = tokenizer.apply_chat_template(
-                [{
-                    "role": "user",
-                    "content": prompt
-                }],
-                add_generation_prompt=True,
-                tokenize=False,
-            )
+            if not skip_chat_template:
+                prompt = tokenizer.apply_chat_template(
+                    [{
+                        "role": "user",
+                        "content": prompt
+                    }],
+                    add_generation_prompt=True,
+                    tokenize=False,
+                )
 
             prompt_len = len(tokenizer(prompt).input_ids)
             sampled_requests.append(
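The guard added here (and again in the MTBench and Blazedit hunks below) follows one pattern. A self-contained sketch with a toy tokenizer stub in place of a real PreTrainedTokenizerBase, so the control flow runs without downloading a model:

# Toy stand-in for a Hugging Face tokenizer: only apply_chat_template is
# stubbed, and it wraps the user turn in a made-up template.
class StubTokenizer:
    def apply_chat_template(self, messages, add_generation_prompt=True,
                            tokenize=False):
        return f"<|user|>{messages[0]['content']}<|assistant|>"


def render_prompt(tokenizer, prompt: str, skip_chat_template: bool) -> str:
    # Mirrors the guarded block added above: the chat template is applied
    # only when skip_chat_template is False.
    if not skip_chat_template:
        prompt = tokenizer.apply_chat_template(
            [{"role": "user", "content": prompt}],
            add_generation_prompt=True,
            tokenize=False,
        )
    return prompt


tok = StubTokenizer()
print(render_prompt(tok, "Write a sorting function.", skip_chat_template=False))
print(render_prompt(tok, "Write a sorting function.", skip_chat_template=True))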
@@ -2284,6 +2286,7 @@ class MTBenchDataset(HuggingFaceDataset):
                num_requests: int,
                output_len: Optional[int] = None,
                enable_multimodal_chat: bool = False,
+               skip_chat_template: bool = False,
                request_id_prefix: str = "",
                no_oversample: bool = False,
                **kwargs,
@@ -2298,14 +2301,15 @@ class MTBenchDataset(HuggingFaceDataset):
             prompt = item["turns"][0]
 
             # apply template
-            prompt = tokenizer.apply_chat_template(
-                [{
-                    "role": "user",
-                    "content": prompt
-                }],
-                add_generation_prompt=True,
-                tokenize=False,
-            )
+            if not skip_chat_template:
+                prompt = tokenizer.apply_chat_template(
+                    [{
+                        "role": "user",
+                        "content": prompt
+                    }],
+                    add_generation_prompt=True,
+                    tokenize=False,
+                )
 
             prompt_len = len(tokenizer(prompt).input_ids)
             sampled_requests.append(
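Since prompt_len just after this block is computed from whatever string prompt ends up being, skipping the template changes the input length recorded for the same MTBench turn. A toy illustration in which whitespace counting stands in for real tokenization and the template wrapper is made up:

# Illustration only: a fake chat wrapper and whitespace "token" counting
# stand in for a real tokenizer and its input_ids.
def fake_apply_chat_template(user_prompt: str) -> str:
    return f"<|user|>\n{user_prompt}\n<|assistant|>\n"

def fake_token_count(text: str) -> int:
    return len(text.split())

raw = "Write a short story about a robot learning to paint."
templated = fake_apply_chat_template(raw)

# With --skip-chat-template the raw turn is measured; by default the
# templated string is, so the recorded prompt_len differs.
print(fake_token_count(raw))        # skip_chat_template=True
print(fake_token_count(templated))  # skip_chat_template=False (the default)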
@@ -2349,6 +2353,7 @@ class BlazeditDataset(HuggingFaceDataset):
                tokenizer: PreTrainedTokenizerBase,
                num_requests: int,
                output_len: Optional[int] = None,
+               skip_chat_template: bool = False,
                request_id_prefix: str = "",
                no_oversample: bool = False,
                min_distance: float = 0.0,
@@ -2372,7 +2377,7 @@ class BlazeditDataset(HuggingFaceDataset):
 
             # template copied from
             # https://github.com/ise-uiuc/blazedit/blob/7765137e656fd62de877422d2e4cf8de51228054/dataset/create_refined_dataset.py#L94-L105 # noqa: E501
-            instruction = f"""Given a code file, please apply the change requests and generate the new file.
+            prompt = f"""Given a code file, please apply the change requests and generate the new file.
 
 Original file:
 ```python
@@ -2385,14 +2390,15 @@ Change request:
 Please generate the new code file in the "New file" section below.""" # noqa: E501
 
             # apply template
-            prompt = tokenizer.apply_chat_template(
-                [{
-                    "role": "user",
-                    "content": instruction
-                }],
-                add_generation_prompt=True,
-                tokenize=False,
-            )
+            if not skip_chat_template:
+                prompt = tokenizer.apply_chat_template(
+                    [{
+                        "role": "user",
+                        "content": prompt
+                    }],
+                    add_generation_prompt=True,
+                    tokenize=False,
+                )
 
             prompt_len = len(tokenizer(prompt).input_ids)
 