EAGLE 3: Fix preamble so that measured speedup over Eagle 1 becomes 32% instead of 5% on MTBench (#25916)

Signed-off-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com>
This commit is contained in:
Ekagra Ranjan
2025-10-02 14:29:35 -04:00
committed by GitHub
parent 1e50f1be70
commit 1cab2f9cad

View File

@ -1151,6 +1151,12 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
help="Do not oversample if the dataset has " \
"fewer samples than num-prompts.",
)
parser.add_argument(
"--skip-chat-template",
action="store_true",
help=
"Skip applying chat template to prompt for datasets that support it.",
)
# group for dataset specific arguments
custom_group = parser.add_argument_group("custom dataset options")
@ -1161,12 +1167,6 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
help=
"Number of output tokens per request, used only for custom dataset.",
)
custom_group.add_argument(
"--custom-skip-chat-template",
action="store_true",
help=
"Skip applying chat template to prompt, used only for custom dataset.",
)
spec_bench_group = parser.add_argument_group("spec bench dataset options")
spec_bench_group.add_argument(
@ -1435,7 +1435,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
num_requests=args.num_prompts,
tokenizer=tokenizer,
output_len=args.custom_output_len,
skip_chat_template=args.custom_skip_chat_template,
skip_chat_template=args.skip_chat_template,
request_id_prefix=args.request_id_prefix,
no_oversample=args.no_oversample,
)
@ -1576,6 +1576,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
output_len=args.hf_output_len,
request_id_prefix=args.request_id_prefix,
no_oversample=args.no_oversample,
skip_chat_template=args.skip_chat_template,
**hf_kwargs
)
@ -1815,7 +1816,6 @@ class SpecBench(CustomDataset):
def sample(self, **kwargs) -> list:
# leverage CustomDataset sample
kwargs["skip_chat_template"] = False
return super().sample(**kwargs)
@ -2221,6 +2221,7 @@ class InstructCoderDataset(HuggingFaceDataset):
num_requests: int,
output_len: Optional[int] = None,
enable_multimodal_chat: bool = False,
skip_chat_template: bool = False,
request_id_prefix: str = "",
no_oversample: bool = False,
**kwargs) -> list:
@ -2236,14 +2237,15 @@ class InstructCoderDataset(HuggingFaceDataset):
)
# apply template
prompt = tokenizer.apply_chat_template(
[{
"role": "user",
"content": prompt
}],
add_generation_prompt=True,
tokenize=False,
)
if not skip_chat_template:
prompt = tokenizer.apply_chat_template(
[{
"role": "user",
"content": prompt
}],
add_generation_prompt=True,
tokenize=False,
)
prompt_len = len(tokenizer(prompt).input_ids)
sampled_requests.append(
@ -2284,6 +2286,7 @@ class MTBenchDataset(HuggingFaceDataset):
num_requests: int,
output_len: Optional[int] = None,
enable_multimodal_chat: bool = False,
skip_chat_template: bool = False,
request_id_prefix: str = "",
no_oversample: bool = False,
**kwargs,
@ -2298,14 +2301,15 @@ class MTBenchDataset(HuggingFaceDataset):
prompt = item["turns"][0]
# apply template
prompt = tokenizer.apply_chat_template(
[{
"role": "user",
"content": prompt
}],
add_generation_prompt=True,
tokenize=False,
)
if not skip_chat_template:
prompt = tokenizer.apply_chat_template(
[{
"role": "user",
"content": prompt
}],
add_generation_prompt=True,
tokenize=False,
)
prompt_len = len(tokenizer(prompt).input_ids)
sampled_requests.append(
@ -2349,6 +2353,7 @@ class BlazeditDataset(HuggingFaceDataset):
tokenizer: PreTrainedTokenizerBase,
num_requests: int,
output_len: Optional[int] = None,
skip_chat_template: bool = False,
request_id_prefix: str = "",
no_oversample: bool = False,
min_distance: float = 0.0,
@ -2372,7 +2377,7 @@ class BlazeditDataset(HuggingFaceDataset):
# template copied from
# https://github.com/ise-uiuc/blazedit/blob/7765137e656fd62de877422d2e4cf8de51228054/dataset/create_refined_dataset.py#L94-L105 # noqa: E501
instruction = f"""Given a code file, please apply the change requests and generate the new file.
prompt = f"""Given a code file, please apply the change requests and generate the new file.
Original file:
```python
@ -2385,14 +2390,15 @@ Change request:
Please generate the new code file in the "New file" section below.""" # noqa: E501
# apply template
prompt = tokenizer.apply_chat_template(
[{
"role": "user",
"content": instruction
}],
add_generation_prompt=True,
tokenize=False,
)
if not skip_chat_template:
prompt = tokenizer.apply_chat_template(
[{
"role": "user",
"content": prompt
}],
add_generation_prompt=True,
tokenize=False,
)
prompt_len = len(tokenizer(prompt).input_ids)