[Feature] [Spec decode]: Enable MLPSpeculator/Medusa and prompt_logprobs with ChunkedPrefill (#10132)

Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: wallashss <wallashss@ibm.com>
Co-authored-by: wallashss <wallashss@ibm.com>
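For context, a minimal sketch of how the feature introduced by this commit is exercised (values are illustrative and taken from the test kwargs in the diff below; this is not an excerpt of the commit itself): a draft model is combined with chunked prefill, and prompt logprobs are requested through SamplingParams.

from vllm import LLM, SamplingParams

# Illustrative configuration only: small target/draft pair with chunked
# prefill enabled, mirroring the kwargs used by the e2e tests in this diff.
llm = LLM(
    model="JackFram/llama-160m",
    speculative_model="JackFram/llama-68m",
    num_speculative_tokens=3,
    enable_chunked_prefill=True,
    max_num_batched_tokens=4,   # chunk size used by the tests
    max_num_seqs=4,
    disable_logprobs_during_spec_decoding=False,
    enforce_eager=True,
)
params = SamplingParams(temperature=0.0, max_tokens=32,
                        logprobs=2, prompt_logprobs=2)
outputs = llm.generate(["The president of the United States is"], params)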
@@ -2,6 +2,7 @@ from itertools import cycle
from typing import List, Optional, Sequence, Tuple, Union

import pytest
import torch

from vllm import LLM, SamplingParams
from vllm.distributed import cleanup_dist_env_and_memory

@@ -154,6 +155,8 @@ def _check_logprobs_when_output_disabled(
         spec_pos_logprob) = next(iter(spec_pos_logprobs.items()))
        assert spec_pos_logprob.rank == -1
        assert spec_pos_logprob.logprob == 0.0
        if isinstance(spec_pos_logprob_token_id, torch.Tensor):
            spec_pos_logprob_token_id = spec_pos_logprob_token_id.item()
        assert spec_pos_logprob_token_id in baseline_pos_logprobs

@@ -244,7 +247,8 @@ run_equality_correctness_test_tp(model,
                                     batch_size: int,
                                     max_output_len: int,
                                     seed: int = 0,
                                     temperature: float = 0.0):
                                     temperature: float = 0.0,
                                     logprobs: Optional[int] = None):
    """Helper method that compares the outputs of both the baseline LLM and
    the test LLM. It asserts greedy equality, e.g. that the outputs are exactly
    the same when temperature is zero.

@@ -257,7 +261,6 @@ run_equality_correctness_test_tp(model,
    results = []

    prompts = [prompt for prompt, _ in zip(cycle(PROMPTS), range(batch_size))]

    for args, env in ((arg1, env1), (arg2, env2)):
        with RemoteOpenAIServer(model,
                                args,

@@ -269,12 +272,14 @@ run_equality_correctness_test_tp(model,
                prompt=prompts,
                max_tokens=max_output_len,
                seed=seed,
                temperature=temperature)
                temperature=temperature,
                logprobs=logprobs)

            results.append({
                "test":
                "seeded_sampling",
                "text": [choice.text for choice in completion.choices],
                "logprobs": [choice.logprobs for choice in completion.choices],
                "finish_reason":
                [choice.finish_reason for choice in completion.choices],
                "usage":

@@ -284,7 +289,15 @@ run_equality_correctness_test_tp(model,
    n = len(results) // 2
    arg1_results = results[:n]
    arg2_results = results[n:]
    # Separate logprobs to avoid asserting exact equality.
    arg1_logprobs = [r.pop("logprobs") for r in arg1_results]
    arg2_logprobs = [r.pop("logprobs") for r in arg2_results]

    for arg1_result, arg2_result in zip(arg1_results, arg2_results):
        assert arg1_result == arg2_result, (
            f"Results for {model=} are not the same with {arg1=} and {arg2=}. "
            f"{arg1_result=} != {arg2_result=}")
    if logprobs:
        for logs1, logs2 in zip(arg1_logprobs, arg2_logprobs):
            for l1, l2 in zip(logs1, logs2):
                assert l1.tokens == l2.tokens
@@ -2,6 +2,8 @@
tensor parallelism.
"""

from typing import Optional

import pytest
import torch

@@ -154,15 +156,20 @@ def test_draft_model_tp_lt_target_model_tp2(model, common_llm_kwargs,
        "--speculative-draft-tensor-parallel-size",
        "1",
    ])])
@pytest.mark.parametrize("logprobs", [None, 2])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_chunked_prefill_tp2(model, common_llm_kwargs,
                                         per_test_common_llm_kwargs,
                                         baseline_llm_kwargs, test_llm_kwargs,
                                         logprobs: Optional[int],
                                         batch_size: int, seed: int):
    """Verify spec decode works well with same and different TP size for
    the draft model with chunked prefill.
    """
    if logprobs:
        test_llm_kwargs.extend(
            ["--disable_logprobs_during_spec_decoding", "False"])
    run_equality_correctness_test_tp(model,
                                     common_llm_kwargs,
                                     per_test_common_llm_kwargs,

@@ -171,4 +178,5 @@ def test_spec_decode_chunked_prefill_tp2(model, common_llm_kwargs,
                                     batch_size,
                                     max_output_len=32,
                                     seed=seed,
                                     temperature=0.0)
                                     temperature=0.0,
                                     logprobs=logprobs)
@@ -4,26 +4,27 @@ import pytest

from vllm import SamplingParams

from ..utils import maybe_enable_chunked_prefill
from .conftest import run_equality_correctness_test


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "model_name": "JackFram/llama-68m",
        "model_name": "JackFram/llama-160m",

        # Skip cuda graph recording for fast test.
        "enforce_eager": True,
        "enforce_eager": True
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs",
                         [{
                             "speculative_model": "JackFram/llama-160m",
                             "speculative_model": "JackFram/llama-68m",
                             "num_speculative_tokens": 3,
                             "disable_logprobs_during_spec_decoding": False,
                         }, {
                             "speculative_model": "JackFram/llama-160m",
                             "speculative_model": "JackFram/llama-68m",
                             "num_speculative_tokens": 3,
                             "disable_logprobs_during_spec_decoding": True,
                         }])

@@ -36,12 +37,15 @@ from .conftest import run_equality_correctness_test
])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("logprobs", [1, 6])
@pytest.mark.parametrize("prefill_chunk_size", [-1, 4, 12])
def test_logprobs_equality(vllm_runner, common_llm_kwargs,
                           per_test_common_llm_kwargs, baseline_llm_kwargs,
                           test_llm_kwargs, batch_size: int, output_len: int,
                           seed: int, logprobs: int):
    """Verify output logprobs are equal with and without speculative decoding.
                           seed: int, logprobs: int, prefill_chunk_size: int):
    """Verify output logprobs are equal with and without speculative decoding,
    as well as with and without chunked prefill.
    """
    maybe_enable_chunked_prefill(prefill_chunk_size, common_llm_kwargs)
    run_equality_correctness_test(vllm_runner,
                                  common_llm_kwargs,
                                  per_test_common_llm_kwargs,
@@ -21,6 +21,7 @@ correctess for the target model outputs.

import pytest

from ..utils import maybe_enable_chunked_prefill
from .conftest import run_equality_correctness_test

# main model

@@ -67,12 +68,14 @@ PRECISION = "float32"
])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("prefill_chunk_size", [-1, 32])
def test_medusa_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
                                       per_test_common_llm_kwargs,
                                       baseline_llm_kwargs, test_llm_kwargs,
                                       batch_size: int, output_len: int,
                                       seed: int):
                                       seed: int, prefill_chunk_size: int):
    """Verify greedy equality with different batch size."""
    maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
    run_equality_correctness_test(vllm_runner,
                                  common_llm_kwargs,
                                  per_test_common_llm_kwargs,

@@ -119,12 +122,15 @@ def test_medusa_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("logprobs", [1, 6])
@pytest.mark.parametrize("prefill_chunk_size", [-1, 32])
def test_medusa_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
                                    per_test_common_llm_kwargs,
                                    baseline_llm_kwargs, test_llm_kwargs,
                                    batch_size: int, output_len: int,
                                    seed: int, logprobs: int):
                                    seed: int, logprobs: int,
                                    prefill_chunk_size: int):
    """Verify greedy equality with different batch size."""
    maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
    run_equality_correctness_test(vllm_runner,
                                  common_llm_kwargs,
                                  per_test_common_llm_kwargs,

@@ -167,12 +173,14 @@ def test_medusa_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("prefill_chunk_size", [-1, 32])
def test_medusa_e2e_greedy_correctness_cuda_graph(
        vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
        baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
        seed: int):
        seed: int, prefill_chunk_size: int):
    """Verify greedy equality with cuda graph enabled and different
    batch sizes."""
    maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
    run_equality_correctness_test(vllm_runner,
                                  common_llm_kwargs,
                                  per_test_common_llm_kwargs,

@@ -217,13 +225,15 @@ def test_medusa_e2e_greedy_correctness_cuda_graph(
])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("prefill_chunk_size", [-1, 32])
def test_medusa_e2e_greedy_correctness_with_preemption(
        vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
        baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
        seed: int):
        seed: int, prefill_chunk_size: int):
    """Verify greedy equality, even when some sequences are preempted mid-
    generation.
    """
    maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
    run_equality_correctness_test(vllm_runner,
                                  common_llm_kwargs,
                                  per_test_common_llm_kwargs,

@@ -267,13 +277,15 @@ def test_medusa_e2e_greedy_correctness_with_preemption(
    32,
])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("prefill_chunk_size", [-1, 32])
def test_medusa_different_k(vllm_runner, common_llm_kwargs,
                            per_test_common_llm_kwargs, baseline_llm_kwargs,
                            test_llm_kwargs, batch_size: int, output_len: int,
                            seed: int):
                            seed: int, prefill_chunk_size: int):
    """Verify that medusa speculative decoding produces exact equality
    to without spec decode with different values of num_speculative_tokens.
    """
    maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
    run_equality_correctness_test(vllm_runner,
                                  common_llm_kwargs,
                                  per_test_common_llm_kwargs,

@@ -313,14 +325,17 @@ def test_medusa_different_k(vllm_runner, common_llm_kwargs,
    32,
])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("prefill_chunk_size", [-1, 32])
def test_medusa_disable_queue(vllm_runner, common_llm_kwargs,
                              per_test_common_llm_kwargs, baseline_llm_kwargs,
                              test_llm_kwargs, batch_size: int,
                              output_len: int, seed: int):
                              output_len: int, seed: int,
                              prefill_chunk_size: int):
    """Verify that medusa speculative decoding produces exact equality
    to without spec decode when speculation is disabled for large
    batch sizes.
    """
    maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
    run_equality_correctness_test(vllm_runner,
                                  common_llm_kwargs,
                                  per_test_common_llm_kwargs,

@@ -361,12 +376,14 @@ def test_medusa_disable_queue(vllm_runner, common_llm_kwargs,
    32,
])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("prefill_chunk_size", [-1, 32])
def test_mqa_scorer(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
                    baseline_llm_kwargs, test_llm_kwargs, batch_size: int,
                    output_len: int, seed: int):
                    output_len: int, seed: int, prefill_chunk_size: int):
    """Verify that speculative decoding generates the same output
    with batch expansion scorer and mqa scorer.
    """
    maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
    run_equality_correctness_test(vllm_runner,
                                  common_llm_kwargs,
                                  per_test_common_llm_kwargs,
@@ -25,6 +25,7 @@ import pytest

from vllm.model_executor.layers.vocab_parallel_embedding import pad_vocab_size

from ..utils import maybe_enable_chunked_prefill
from .conftest import run_equality_correctness_test

# main model

@@ -66,14 +67,16 @@ PRECISION = "float32"
@pytest.mark.parametrize("output_len", [
    128,
])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("batch_size", [4, 32])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("prefill_chunk_size", [-1, 32])
def test_mlp_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
                                    per_test_common_llm_kwargs,
                                    baseline_llm_kwargs, test_llm_kwargs,
                                    batch_size: int, output_len: int,
                                    seed: int):
                                    seed: int, prefill_chunk_size: int):
    """Verify greedy equality with different batch size."""
    maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
    run_equality_correctness_test(vllm_runner,
                                  common_llm_kwargs,
                                  per_test_common_llm_kwargs,

@@ -116,12 +119,19 @@ def test_mlp_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("logprobs", [1, 6])
@pytest.mark.parametrize("prefill_chunk_size", [-1, 4])
def test_mlp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
                                 per_test_common_llm_kwargs,
                                 baseline_llm_kwargs, test_llm_kwargs,
                                 batch_size: int, output_len: int, seed: int,
                                 logprobs: int):
                                 logprobs: int, prefill_chunk_size: int):
    """Verify greedy equality with different batch size."""
    maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
    # NOTE Test is sensitive enough st if we don't enable chunked prefill
    # scheduling on baseline too, we get slightly different logprobs, ending
    # up sampling different tokens at the tail (ie top tokens don't change).
    # TL;DR: sd+cp == org+cp but sd+cp != org..is this expected?
    maybe_enable_chunked_prefill(prefill_chunk_size, baseline_llm_kwargs)
    run_equality_correctness_test(vllm_runner,
                                  common_llm_kwargs,
                                  per_test_common_llm_kwargs,

@@ -162,12 +172,15 @@ def test_mlp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
@pytest.mark.parametrize("output_len", [2048])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("prefill_chunk_size", [-1, 4])
def test_mlp_e2e_acceptance_rate(vllm_runner, common_llm_kwargs,
                                 per_test_common_llm_kwargs,
                                 baseline_llm_kwargs, test_llm_kwargs,
                                 batch_size: int, output_len: int, seed: int):
                                 batch_size: int, output_len: int,
                                 prefill_chunk_size: int, seed: int):
    """Verify acceptance rate with different batch size and large output
    length."""
    maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
    run_equality_correctness_test(vllm_runner,
                                  common_llm_kwargs,
                                  per_test_common_llm_kwargs,

@@ -204,13 +217,17 @@ def test_mlp_e2e_acceptance_rate(vllm_runner, common_llm_kwargs,
@pytest.mark.parametrize("output_len", [64])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("temperature", [1.0])
@pytest.mark.parametrize("prefill_chunk_size", [-1, 4])
@pytest.mark.parametrize("seed", [1])
def test_mlp_e2e_seeded_correctness(vllm_runner, common_llm_kwargs,
                                    per_test_common_llm_kwargs,
                                    baseline_llm_kwargs, test_llm_kwargs,
                                    batch_size: int, output_len: int,
                                    temperature: float, seed: int):
                                    temperature: float,
                                    prefill_chunk_size: int, seed: int):
    """Verify seeded runs produce the same output."""
    maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
    maybe_enable_chunked_prefill(prefill_chunk_size, baseline_llm_kwargs)
    run_equality_correctness_test(vllm_runner,
                                  common_llm_kwargs,
                                  per_test_common_llm_kwargs,

@@ -266,14 +283,16 @@ def test_mlp_e2e_seeded_correctness(vllm_runner, common_llm_kwargs,
    128,
])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("prefill_chunk_size", [-1, 4])
@pytest.mark.parametrize("seed", [1])
def test_mlp_e2e_greedy_correctness_with_preemption(
        vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
        baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
        seed: int):
        prefill_chunk_size: int, seed: int):
    """Verify greedy equality, even when some sequences are preempted mid-
    generation.
    """
    maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
    run_equality_correctness_test(vllm_runner,
                                  common_llm_kwargs,
                                  per_test_common_llm_kwargs,

@@ -317,12 +336,14 @@ def test_mlp_e2e_greedy_correctness_with_preemption(
])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("prefill_chunk_size", [-1, 4])
def test_mlp_e2e_greedy_correctness_with_padding(
        vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
        baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
        seed: int):
        prefill_chunk_size: int, seed: int):
    """Verify greedy equality when the vocab dimension is padded
    """
    maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)

    # Default pad_to is 64, test model has vocab_size of 32000
    def patched_pad_vocab_size(vocab_size, pad_to=None):

@@ -373,14 +394,16 @@ def test_mlp_e2e_greedy_correctness_with_padding(
    # Use smaller output len for fast test.
    32,
])
@pytest.mark.parametrize("prefill_chunk_size", [-1, 4])
@pytest.mark.parametrize("seed", [1])
def test_mlp_different_k(vllm_runner, common_llm_kwargs,
                         per_test_common_llm_kwargs, baseline_llm_kwargs,
                         test_llm_kwargs, batch_size: int, seed: int,
                         output_len: int):
                         test_llm_kwargs, batch_size: int,
                         prefill_chunk_size: int, seed: int, output_len: int):
    """Verify that mlp speculative decoding produces exact equality
    to without spec decode with different values of num_speculative_tokens.
    """
    maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
    run_equality_correctness_test(vllm_runner,
                                  common_llm_kwargs,
                                  per_test_common_llm_kwargs,

@@ -418,15 +441,21 @@ def test_mlp_different_k(vllm_runner, common_llm_kwargs,
    # Use smaller output len for fast test.
    32,
])
# Speculative decoding is disabled when sequences reach decoding and the batch
# consists of single-token requests. Hence we set `max_num_seqs`
# >= `speculative_disable_by_batch_size` to test feature interaction.
@pytest.mark.parametrize("prefill_chunk_size", [-1, 4])
@pytest.mark.parametrize("seed", [1])
def test_mlp_disable_queue(vllm_runner, common_llm_kwargs,
                           per_test_common_llm_kwargs, baseline_llm_kwargs,
                           test_llm_kwargs, batch_size: int, seed: int,
                           test_llm_kwargs, batch_size: int,
                           prefill_chunk_size: int, seed: int,
                           output_len: int):
    """Verify that mlp speculative decoding produces exact equality
    to without spec decode when speculation is disabled for large
    batch sizes.
    """
    maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
    run_equality_correctness_test(vllm_runner,
                                  common_llm_kwargs,
                                  per_test_common_llm_kwargs,

@@ -460,13 +489,15 @@ def test_mlp_disable_queue(vllm_runner, common_llm_kwargs,
    # Use smaller output len for fast test.
    32,
])
@pytest.mark.parametrize("prefill_chunk_size", [-1, 4])
@pytest.mark.parametrize("seed", [1])
def test_mqa_scorer(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
                    baseline_llm_kwargs, test_llm_kwargs, batch_size: int,
                    output_len: int, seed: int):
                    output_len: int, prefill_chunk_size: int, seed: int):
    """Verify that speculative decoding generates the same output
    with batch expansion scorer and mqa scorer.
    """
    maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
    run_equality_correctness_test(vllm_runner,
                                  common_llm_kwargs,
                                  per_test_common_llm_kwargs,
@@ -147,20 +147,20 @@ def test_spec_decode_e2e_with_detokenization(test_llm_generator,
    },
])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
    {
        "speculative_model": "JackFram/llama-68m",
        "num_speculative_tokens": 5,
        "enable_chunked_prefill": False,
    },
    {
        "speculative_model": "JackFram/llama-68m",
        "num_speculative_tokens": 5,
        "enable_chunked_prefill": True,
        "max_num_batched_tokens": 4,
        "max_num_seqs": 4,
    },
])
@pytest.mark.parametrize("test_llm_kwargs",
                         [{
                             "speculative_model": "JackFram/llama-68m",
                             "num_speculative_tokens": 5,
                             "enable_chunked_prefill": False,
                             "disable_logprobs_during_spec_decoding": False
                         }, {
                             "speculative_model": "JackFram/llama-68m",
                             "num_speculative_tokens": 3,
                             "enable_chunked_prefill": True,
                             "max_num_batched_tokens": 4,
                             "max_num_seqs": 4,
                             "disable_logprobs_during_spec_decoding": False
                         }])
@pytest.mark.parametrize(
    "output_len",
    [

@@ -192,6 +192,9 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1(
        batch_size,
        max_output_len=output_len,
        seed=seed,
        prompt_logprobs=2,
        logprobs=2,
        disable_logprobs=False,
        temperature=0.0,
        ensure_all_accepted=ensure_all_accepted)
@@ -26,6 +26,7 @@ for the target model outputs.

import pytest

from ..utils import maybe_enable_chunked_prefill
from .conftest import run_equality_correctness_test


@@ -49,11 +50,13 @@ from .conftest import run_equality_correctness_test
        "speculative_model": "[ngram]",
        "num_speculative_tokens": 5,
        "ngram_prompt_lookup_max": 3,
        "speculative_disable_mqa_scorer": False,
    },
    {
        "speculative_model": "[ngram]",
        "num_speculative_tokens": 5,
        "ngram_prompt_lookup_max": 3,
        "speculative_disable_mqa_scorer": True,
    },
])
@pytest.mark.parametrize("output_len", [

@@ -68,15 +71,7 @@ def test_ngram_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
                                      batch_size: int, output_len: int,
                                      prefill_chunk_size: int, seed: int):
    """Verify greedy equality on a tiny model with different batch size."""
    if prefill_chunk_size > 0:
        common_llm_kwargs.update(
            **{
                "enable_chunked_prefill": True,
                "max_num_batched_tokens": prefill_chunk_size,
                "max_num_seqs": prefill_chunk_size
            })
    else:
        common_llm_kwargs["enable_chunked_prefill"] = False
    maybe_enable_chunked_prefill(prefill_chunk_size, common_llm_kwargs)
    run_equality_correctness_test(vllm_runner,
                                  common_llm_kwargs,
                                  per_test_common_llm_kwargs,
@@ -60,6 +60,7 @@ def test_scorer(model_name: str, batch_size: int, max_propose_len: int,
    num_gpu_blocks = 2048 // block_size
    scorer_worker = create_worker(Worker, model_name, block_size,
                                  num_gpu_blocks, seed)
    scorer_worker.model_runner.disable_logprobs = True  # accessed by mqa_scorer
    scorer_worker.model_runner.model.sampler.include_gpu_probs_tensor = True
    scorer_worker.model_runner.model.sampler.\
        should_modify_greedy_probs_inplace = True
@@ -754,6 +754,7 @@ def test_populate_seq_ids_with_bonus_tokens():
        seq_group_metadata_list=seq_group_metadata_list,
        accepted_token_ids=accepted_token_ids,
        target_logprobs=target_token_logprobs,
        prompt_logprobs=None,
        k=k,
        stage_times=(0, 0, 0))
    # Verify that _seq_with_bonus_token_in_last_step contains the following:
@@ -274,3 +274,15 @@ def create_batch(batch_size,
        prompts, num_gpu_blocks, block_size, final_prompt_lens,
        prev_output_tokens, seq_ids)
    return seq_group_metadata_list, prompts, prev_output_tokens


def maybe_enable_chunked_prefill(prefill_chunk_size, llm_kwargs):
    if prefill_chunk_size > 0:
        llm_kwargs.update(
            **{
                "enable_chunked_prefill": True,
                "max_num_batched_tokens": prefill_chunk_size,
                "max_num_seqs": prefill_chunk_size
            })
    else:
        llm_kwargs["enable_chunked_prefill"] = False
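A quick usage sketch of the new test helper above (values are illustrative): a positive chunk size turns chunked prefill on and also caps the scheduler budget, while any non-positive value explicitly disables it.

llm_kwargs = {"model_name": "JackFram/llama-68m"}
maybe_enable_chunked_prefill(4, llm_kwargs)
# llm_kwargs now contains:
#   {"model_name": "JackFram/llama-68m", "enable_chunked_prefill": True,
#    "max_num_batched_tokens": 4, "max_num_seqs": 4}

maybe_enable_chunked_prefill(-1, llm_kwargs)
# llm_kwargs["enable_chunked_prefill"] is reset to False; the token/seq caps
# set earlier are left untouched by the helper.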
@@ -1685,7 +1685,8 @@ class SpeculativeConfig:
                raise ValueError("Expect the batch size threshold of disabling "
                                 "speculative decoding is > 1, but got "
                                 f"{speculative_disable_by_batch_size=}")

            if (enable_chunked_prefill and speculative_model == "eagle"):
                raise ValueError("Chunked prefill and EAGLE are not compatible.")
            # TODO: The user should be able to specify revision/max model len
            # for the draft model. It is not currently supported.
            draft_revision = None

@@ -1752,12 +1753,6 @@ class SpeculativeConfig:
                        f"num_speculative_tokens={n_predict}, but "
                        f"{num_speculative_tokens=} was provided.")

                if enable_chunked_prefill and draft_hf_config.model_type in (
                        "medusa", "mlp_speculator", "eagle"):
                    raise ValueError(
                        "Chunked prefill and hidden-state based draft models are "
                        "not compatible.")

            speculative_draft_tensor_parallel_size = \
                SpeculativeConfig._verify_and_get_draft_model_tensor_parallel_size(
                    target_parallel_config,
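The two hunks above move the compatibility check: the blanket rejection of chunked prefill with hidden-state draft models (medusa, mlp_speculator, eagle) is dropped, and only the EAGLE combination is still refused. A toy mirror of the resulting behaviour (names are illustrative, not the real SpeculativeConfig API):

def check_draft_compat(enable_chunked_prefill: bool, draft_model_type: str) -> None:
    # After this commit, only EAGLE drafts remain incompatible with
    # chunked prefill; medusa and mlp_speculator drafts are accepted.
    if enable_chunked_prefill and draft_model_type == "eagle":
        raise ValueError("Chunked prefill and EAGLE are not compatible.")

check_draft_compat(True, "medusa")          # ok
check_draft_compat(True, "mlp_speculator")  # ok
# check_draft_compat(True, "eagle")         # would raise ValueError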
@@ -1010,8 +1010,23 @@ class LLMEngine:
            self.speculative_config
        # Organize outputs by [step][sequence group] instead of
        # [sequence group][step].
        outputs_by_sequence_group = create_output_by_sequence_group(
            outputs, num_seq_groups=len(seq_group_metadata_list))
        if self.scheduler_config.is_multi_step:
            outputs_by_sequence_group = create_output_by_sequence_group(
                outputs, len(seq_group_metadata_list))
        elif self.speculative_config:
            # Decodes are multi-steps while prefills are not, outputting at
            # most 1 token. Separate them so that we can trigger chunk
            # processing without having to pad or copy over prompts K times
            # to match decodes structure (costly with prompt_logprobs).
            num_prefills = sum(sg.is_prompt
                               for sg in seq_group_metadata_list)
            prefills, decodes = outputs[:num_prefills], outputs[
                num_prefills:]
            outputs_by_sequence_group = create_output_by_sequence_group(
                decodes,
                num_seq_groups=len(seq_group_metadata_list) - num_prefills)
            outputs_by_sequence_group = [p.outputs for p in prefills
                                         ] + outputs_by_sequence_group
            # We have outputs for multiple steps submitted in a single burst,
            # so invalidate is_first_step_output.
            is_first_step_output = None
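A toy walk-through of the regrouping above (illustrative data only, not the engine's real types): prefill requests contribute one output each, decode requests contribute one output per speculative step, and the result is reordered to [sequence group][step] with prefills first.

# Flat outputs for 2 prefills (P0, P1) and 2 decodes over 3 steps.
outputs = ["P0", "P1",
           ["d0@s0", "d1@s0"], ["d0@s1", "d1@s1"], ["d0@s2", "d1@s2"]]
num_prefills = 2
prefills, decode_steps = outputs[:num_prefills], outputs[num_prefills:]
# Transpose [step][seq group] -> [seq group][step], then prepend prefills.
by_seq_group = [[step[i] for step in decode_steps] for i in range(2)]
by_seq_group = [[p] for p in prefills] + by_seq_group
# -> [['P0'], ['P1'], ['d0@s0', 'd0@s1', 'd0@s2'], ['d1@s0', 'd1@s1', 'd1@s2']]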
@@ -83,13 +83,13 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):

        if not non_spec_indices:
            # All sequence groups in batch have spec decoding enabled
            contracted = self._contract_batch_all_spec(
            return self._contract_batch_all_spec(
                target_sampler_output=target_sampler_output,
                proposals=proposals,
            )
        else:
            # Batch has a mix of spec decode enabled and disabled seq groups
            contracted = self._contract_batch(
            return self._contract_batch(
                execute_model_req.seq_group_metadata_list,
                target_sampler_output=target_sampler_output,
                proposals=proposals,

@@ -99,14 +99,6 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
                k=execute_model_req.num_lookahead_slots,
            )

        all_tokens, all_probs, spec_logprobs, all_hidden_states = contracted
        return SpeculativeScores(
            probs=all_probs,
            token_ids=all_tokens,
            logprobs=spec_logprobs,
            hidden_states=all_hidden_states,
        )

    def _expand_batch(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],

@@ -143,13 +135,57 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
        return (spec_indices, non_spec_indices, target_seq_group_metadata_list,
                num_scoring_tokens)

    def _contract_non_speculative(
            self, scores: SpeculativeScores,
            seq_group_metadata_list: List[SequenceGroupMetadata],
            non_spec_indices: List[int], non_spec_outputs: SpeculativeScores,
            has_prompt_log: bool) -> SpeculativeScores:
        """
        Augment input `scores` with non-speculative requests outputs.
        This includes decode requests with speculation turned off, as well
        as prefill requests when `enable_chunked_prefill` is set.
        For the latter, prefills are further separated into terminal and
        non-terminal chunks (from which no token is sampled).
        """
        if not non_spec_indices:
            return scores

        if has_prompt_log:
            # When prompt_logprobs is enabled, prefills yield output token
            # (and respective prob) in the last entry (prompt|out):
            # [.|.|.|prefill0_out|.|prefill1_out|decode0_out|..].
            # With chunked prefill, non-terminal chunks have -1 on each
            # position: they're still picked, but they're discarded later.
            seq_meta = seq_group_metadata_list
            nospec_sizes = torch.tensor([
                seq_meta[i].token_chunk_size if seq_meta[i].is_prompt else 1
                for i in non_spec_indices
            ])
            nospec_sampled_token_idxs = torch.cumsum(nospec_sizes, 0).add_(-1)
        else:
            # In this case only sampled tokens are returned, select all.
            nospec_sampled_token_idxs = list(
                range(len(non_spec_outputs.token_ids)))

        scores.token_ids[non_spec_indices, :1] = \
            non_spec_outputs.token_ids[nospec_sampled_token_idxs].unsqueeze(1)
        scores.probs[non_spec_indices, :1, :] = \
            non_spec_outputs.probs[nospec_sampled_token_idxs].unsqueeze(1)
        scores.logprobs[non_spec_indices, :1, :] = \
            non_spec_outputs.logprobs[nospec_sampled_token_idxs].unsqueeze(1)
        if scores.hidden_states is not None:
            assert non_spec_outputs.hidden_states is not None
            scores.hidden_states[non_spec_indices, :1, :] = \
                non_spec_outputs.hidden_states[nospec_sampled_token_idxs].unsqueeze(1)
        return scores
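To make the index arithmetic in `_contract_non_speculative` concrete, a small standalone example (illustrative sizes): with prompt logprobs enabled the scorer returns one row per prompt token of each chunk plus one row per plain decode, and `cumsum(sizes) - 1` picks the last row of each request, i.e. its sampled token.

import torch

# Hypothetical non-speculative batch: a prefill chunk of 3 prompt tokens,
# a prefill chunk of 5 prompt tokens, and one decode request (size 1).
nospec_sizes = torch.tensor([3, 5, 1])
nospec_sampled_token_idxs = torch.cumsum(nospec_sizes, 0).add_(-1)
print(nospec_sampled_token_idxs)  # tensor([2, 7, 8])
# Rows 2, 7 and 8 of the flattened scorer output are the sampled tokens;
# the remaining rows carry prompt entries (or -1 placeholders for
# non-terminal chunks) and are dropped later.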
    def _contract_batch(
        self, contracted_seq_group_metadata_list: List[SequenceGroupMetadata],
        target_sampler_output: SamplerOutput, proposals: SpeculativeProposals,
        num_scoring_tokens: int, non_spec_indices: List[int],
        spec_indices: List[int], k: int
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor,
               Optional[torch.Tensor]]:
            self,
            contracted_seq_group_metadata_list: List[SequenceGroupMetadata],
            target_sampler_output: SamplerOutput,
            proposals: SpeculativeProposals, num_scoring_tokens: int,
            non_spec_indices: List[int], spec_indices: List[int],
            k: int) -> SpeculativeScores:
        """Contract the expanded batch back into its original size.
        This maps the scores of speculative tokens back to their original
        sequences.

@@ -195,23 +231,28 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
        else:
            all_hidden_states = None

        # Rule out prefills that produce no tokens.
        non_spec_indices = [
            idx for idx in non_spec_indices
            if contracted_seq_group_metadata_list[idx].do_sample
        ]
        if len(non_spec_indices):
            all_tokens[non_spec_indices, :1] = \
                non_spec_target_token_ids.unsqueeze(1)
            all_probs[non_spec_indices, :1, :] = \
                non_spec_target_probs.unsqueeze(1)
            all_logprobs[non_spec_indices, :1, :] = \
                non_spec_target_logprobs.unsqueeze(1)
            if all_hidden_states is not None:
                assert non_spec_target_hidden_states is not None
                all_hidden_states[non_spec_indices, :1, :] = \
                    non_spec_target_hidden_states.unsqueeze(1)
        has_prompt_log = any((sg.sampling_params.prompt_logprobs
                              and sg.sampling_params.prompt_logprobs > 0)
                             for sg in contracted_seq_group_metadata_list)
        # When prompt logprobs is enabled, lens of returned tensors go from
        # n_sampled (requests with do_sample=True) to n_prompt+n_prefills.
        # We adjust stride accordingly to get the generated tokens and
        # their probs, but pass on prompt_logprobs as is.
        prompt_logprobs = None
        if (not self._scorer_worker.model_runner.disable_logprobs\
                and has_prompt_log):
            prompt_logprobs = [
                o.prompt_logprobs for o in target_sampler_output.outputs
            ]
        elif not has_prompt_log:
            # When prompt logprobs are not to be returned,
            # we can ignore non-terminal chunks (no out token).
            non_spec_indices = [
                idx for idx in non_spec_indices
                if contracted_seq_group_metadata_list[idx].do_sample
            ]

        # "Contract" speculative.
        if spec_indices:
            all_tokens[spec_indices] = target_token_ids
            all_probs[spec_indices] = target_probs

@@ -219,14 +260,27 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
            if all_hidden_states is not None:
                all_hidden_states[spec_indices] = target_hidden_states

        return all_tokens, all_probs, all_logprobs, all_hidden_states
        spec_scores = SpeculativeScores(probs=all_probs,
                                        token_ids=all_tokens,
                                        logprobs=all_logprobs,
                                        hidden_states=all_hidden_states,
                                        prompt_logprobs=prompt_logprobs)

        non_spec_outputs = SpeculativeScores(
            probs=non_spec_target_probs,
            token_ids=non_spec_target_token_ids,
            logprobs=non_spec_target_logprobs,
            hidden_states=non_spec_target_hidden_states)
        # Contract remaining nonspec entries based on non_spec_indices, if any.
        return self._contract_non_speculative(
            spec_scores, contracted_seq_group_metadata_list, non_spec_indices,
            non_spec_outputs, has_prompt_log)

    def _contract_batch_all_spec(
        self,
        target_sampler_output: SamplerOutput,
        proposals: SpeculativeProposals,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor,
               Optional[torch.Tensor]]:
    ) -> SpeculativeScores:
        """Contract the expanded batch back into its original size.
        This maps the scores of speculative tokens back to their original
        sequences.

@@ -250,8 +304,11 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
            target_hidden_states = target_hidden_states.reshape(
                *target_token_ids.shape, target_hidden_states.shape[-1])

        return (target_token_ids, target_probs, target_logprobs,
                target_hidden_states)
        return SpeculativeScores(probs=target_probs,
                                 token_ids=target_token_ids,
                                 logprobs=target_logprobs,
                                 hidden_states=target_hidden_states,
                                 prompt_logprobs=None)

    def _create_scoring_model_input(
        self,
@@ -1,10 +1,10 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Optional, Set, Union
from typing import List, Optional, Set, Union

import torch

from vllm.sequence import ExecuteModelRequest
from vllm.sequence import ExecuteModelRequest, PromptLogprobs
from vllm.worker.worker_base import WorkerBase


@@ -54,6 +54,10 @@ class SpeculativeScores:
    # Optional last hidden states from the scoring model.
    hidden_states: Optional[torch.Tensor] = None

    # Scoring model may also return logprobs for prompt tokens
    # for each request, when chunked prefill is enabled.
    prompt_logprobs: Optional[List[PromptLogprobs]] = None

    def __repr__(self):
        return (f"SpeculativeScores("
                f"probs={self.probs.shape}, "
@@ -72,9 +72,15 @@ class MQAScorer(SpeculativeScorer):
        target_token_ids = target_sampler_output.sampled_token_ids
        target_probs = target_sampler_output.sampled_token_probs
        target_logprobs = target_sampler_output.logprobs
        prompt_logprobs = None

        # If all requests have the same number of query tokens, we can avoid
        # the for loop to build output for better performance.
        if min(all_proposal_lengths) == k:
            # Regular decodes only.
            assert all(not sg.is_prompt
                       for sg in target_seq_group_metadata_list
                       if sg.is_prompt)
            bs, _ = proposals.proposal_token_ids.shape
            all_tokens = target_token_ids.reshape(bs, k + 1)
            all_probs = target_probs.reshape(bs, k + 1, self._vocab_size)

@@ -88,19 +94,56 @@ class MQAScorer(SpeculativeScorer):
            all_logprobs = target_logprobs.new_full(size=all_probs.shape,
                                                    fill_value=-float("inf"))
            target_token_ids = target_token_ids.flatten()
            start_loc = 0
            for i, (proposed_len, seq_meta) in enumerate(
                    zip(all_proposal_lengths, target_seq_group_metadata_list)):

            # When prompt logprobs is enabled, lens of returned tensors go from
            # n_sampled (requests with do_sample=True) to n_prompt+n_prefills.
            # We adjust stride accordingly to get the generated tokens and
            # their probs, but pass on prompt_logprobs as is, since it may be
            # that n_prompts >> K.
            has_prompt_log = any((sg.sampling_params.prompt_logprobs
                                  and sg.sampling_params.prompt_logprobs > 0)
                                 for sg in target_seq_group_metadata_list)
            # TODO (NickLucche) we should surface `disable_logprobs` as to not
            # break abstraction to get its value.
            if (not self._scorer_worker.model_runner.disable_logprobs\
                    and has_prompt_log):
                prompt_logprobs = [
                    o.prompt_logprobs for o in target_sampler_output.outputs
                ]

            # Split loop into prefill|decode for readability.
            start_loc, i = 0, 0
            while i < len(target_seq_group_metadata_list
                          ) and target_seq_group_metadata_list[i].is_prompt:
                seq_meta = target_seq_group_metadata_list[i]
                end_loc = start_loc
                if has_prompt_log:
                    end_loc += seq_meta.token_chunk_size
                elif seq_meta.do_sample:
                    end_loc += 1

                # Skip chunks with no output tokens.
                if seq_meta.do_sample:
                    output_len = proposed_len + 1
                    end_loc = start_loc + output_len
                    all_tokens[
                        i, :output_len] = target_token_ids[start_loc:end_loc]
                    all_probs[i, :output_len] = target_probs[start_loc:end_loc]
                    all_logprobs[
                        i, :output_len] = target_logprobs[start_loc:end_loc]
                    start_loc = end_loc
                    # Get sampled token (last position in chunk) and its prob.
                    all_tokens[i, 0] = target_token_ids[end_loc - 1]
                    all_probs[i, 0] = target_probs[end_loc - 1]
                    all_logprobs[i, 0] = target_logprobs[end_loc - 1]

                i += 1
                start_loc = end_loc
            # Decodes.
            while i < len(target_seq_group_metadata_list):
                proposed_len, seq_meta = all_proposal_lengths[
                    i], target_seq_group_metadata_list[i]
                output_len = proposed_len + 1
                end_loc = start_loc + output_len
                all_tokens[
                    i, :output_len] = target_token_ids[start_loc:end_loc]
                all_probs[i, :output_len] = target_probs[start_loc:end_loc]
                all_logprobs[
                    i, :output_len] = target_logprobs[start_loc:end_loc]
                start_loc = end_loc
                i += 1

        hidden_states = None
        if target_sampler_output.hidden_states is not None:

@@ -110,4 +153,5 @@ class MQAScorer(SpeculativeScorer):
        return SpeculativeScores(probs=all_probs,
                                 token_ids=all_tokens,
                                 logprobs=all_logprobs,
                                 hidden_states=hidden_states)
                                 hidden_states=hidden_states,
                                 prompt_logprobs=prompt_logprobs)
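A toy walk-through of the stride bookkeeping above (illustrative numbers, plain Python): with prompt logprobs enabled each prefill chunk spans token_chunk_size rows of the flattened scorer output and only its last row holds a sampled token, while each decode spans proposed_len + 1 rows.

# Hypothetical flattened output: one terminal prefill chunk of 4 prompt tokens
# followed by one decode with proposed_len = 2 (k = 2).
start_loc = 0

# Prefill: rows [0, 4) belong to the chunk; the sampled token is row 3.
token_chunk_size = 4
end_loc = start_loc + token_chunk_size
sampled_row = end_loc - 1       # -> 3
start_loc = end_loc             # -> 4

# Decode: rows [4, 7) hold the bonus token plus the 2 scored proposal tokens.
proposed_len = 2
end_loc = start_loc + proposed_len + 1   # -> 7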
@@ -563,50 +563,57 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
            (seq_id, seq_data) for sg in \
                execute_model_req.seq_group_metadata_list \
            for seq_id, seq_data in sg.seq_data.items()
            if sg.do_sample  # ignore empty token sequences
        ]
        completion_seq_group_output_list: List[
            CompletionSequenceGroupOutput] = []
        output_index = 0
        # Make sure the non-terminal prefill chunks are still aligned with
        # their own empty output.
        for seq_group_meta in execute_model_req.seq_group_metadata_list:
        for idx, seq_group_meta in enumerate(
                execute_model_req.seq_group_metadata_list):
            needs_prompt_logprobs = seq_output_prompt_logprobs[idx]
            seq_id, seq_data = seq_data_entries[idx]
            if needs_prompt_logprobs:
                prompt_token_ids = seq_data.get_prompt_token_ids()

                # Some of these sequences may belong to non-terminal chunks,
                # which may still have to report logprobs for prompts.
                start = 1 if seq_data._num_computed_tokens == 0 \
                    else seq_data._num_computed_tokens
                end = (seq_data._num_computed_tokens + \
                       seq_group_meta.token_chunk_size)
                prompt_token_ids = prompt_token_ids[start:end]
                prompt_logprobs = [
                    create_logprobs_output(
                        token_id=p_token_id,
                        token_id_logprob_rank=-1,
                        token_id_logprob=0.0,
                        topk_token_ids=[],
                        topk_logprobs=[],
                    ) for p_token_id in prompt_token_ids
                ]
            else:
                prompt_logprobs = None

            # Since we can get chunks here, we dont always have a sampled token
            # (only on last chunk) but we still have to provide an output.
            if not seq_group_meta.do_sample:
                completion_seq_group_output_list.append(
                    CompletionSequenceGroupOutput(samples=[],
                                                  prompt_logprobs=None))
            else:
                # Sequence with output.
                seq_id, seq_data = seq_data_entries[output_index]
                needs_prompt_logprobs = seq_output_prompt_logprobs[
                    output_index]
                if needs_prompt_logprobs:
                    prompt_token_ids = seq_data.get_prompt_token_ids()
                    prompt_logprobs = [
                        create_logprobs_output(
                            token_id=p_token_id,
                            token_id_logprob_rank=-1,
                            token_id_logprob=0.0,
                            topk_token_ids=[],
                            topk_logprobs=[],
                        )
                        # no prompt logprobs for the first token
                        for p_token_id in prompt_token_ids[1:]
                    ]
                else:
                    prompt_logprobs = None
                completion_seq_group_output_list.append(
                    create_sequence_group_output(
                        token_id=sampled_token_ids_list[output_index][0],
                        token_id_logprob_rank=-1,
                        token_id_logprob=0.0,
                        seq_id=seq_id,
                        topk_token_ids=[],
                        topk_logprobs=[],
                        prompt_logprobs=prompt_logprobs))
                output_index += 1
                    CompletionSequenceGroupOutput(
                        samples=[], prompt_logprobs=prompt_logprobs))
                continue

            # Sequence with output.
            completion_seq_group_output_list.append(
                create_sequence_group_output(
                    token_id=sampled_token_ids_list[output_index][0],
                    token_id_logprob_rank=-1,
                    token_id_logprob=0.0,
                    seq_id=seq_id,
                    topk_token_ids=[],
                    topk_logprobs=[],
                    prompt_logprobs=prompt_logprobs))
            output_index += 1

        return [SamplerOutput(outputs=completion_seq_group_output_list)]
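A small worked example of the chunk slicing above (illustrative numbers): for a 10-token prompt processed in chunks of 4, each chunk reports prompt logprobs only for its own token range, and the very first prompt token never gets one.

prompt_token_ids = list(range(100, 110))  # 10 hypothetical prompt token ids
num_computed = 0
chunk_size = 4
chunks = []
while num_computed < len(prompt_token_ids):
    chunk_size_i = min(chunk_size, len(prompt_token_ids) - num_computed)
    # First chunk skips position 0: there is no logprob for the first token.
    start = 1 if num_computed == 0 else num_computed
    end = num_computed + chunk_size_i
    chunks.append(prompt_token_ids[start:end])
    num_computed = end
print(chunks)  # [[101, 102, 103], [104, 105, 106, 107], [108, 109]]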
@@ -624,24 +631,27 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
        assert len(sampler_output) == 1
        sampler_output = sampler_output[0]

        # Store hidden states from target model execution.
        # Store hidden states from target model execution, BxD.
        hidden_states = sampler_output.hidden_states
        if hidden_states is not None:
            # remove hidden_states for prompt tokens
            # TODO Enable `return_hidden_states`: prefill chunks hidden states
            # are pruned by the logits processor. Also, they should be arranged
            # back into full-prefill latent. Address it to enable MLPSpeculator.
            if any(seq.is_prompt
                   for seq in execute_model_req.seq_group_metadata_list):
            # Only decodes and prefill terminal chunks need a hidden state.
            seq_group_meta_with_hidden = [
                sg for sg in execute_model_req.seq_group_metadata_list
                if sg.do_sample
            ]
            if any(seq.is_prompt for seq in seq_group_meta_with_hidden):
                # Drop hidden_states with no prediction (eg non-terminal chunks)
                hidden_states = hidden_states[
                    torch.where(sampler_output.sampled_token_ids -
                                VLLM_INVALID_TOKEN_ID)[0]]
            if self.previous_hidden_states is None:
            if self.previous_hidden_states is None and len(
                    seq_group_meta_with_hidden):
                self.previous_hidden_states = HiddenStates(
                    hidden_states, execute_model_req.seq_group_metadata_list)
            else:
                self.previous_hidden_states.update(
                    hidden_states, execute_model_req.seq_group_metadata_list)
                    hidden_states, seq_group_meta_with_hidden)
            elif self.previous_hidden_states and len(
                    seq_group_meta_with_hidden):
                self.previous_hidden_states.update(hidden_states,
                                                   seq_group_meta_with_hidden)

        if not skip_proposer:
            # We prepare the prefill hidden states here so that there no

@@ -752,13 +762,13 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
        ]
        if len(non_spec_indices):
            all_hidden_states = proposal_scores.hidden_states
            # TODO fix `return_hidden_states`, same as in `_run_no_spec`
            if all_hidden_states is not None:
                prefill_hidden_states = all_hidden_states[non_spec_indices]
                execute_model_req.previous_hidden_states = \
                    prepare_prefill_hidden_states(prefill_hidden_states)
            # Sync proposer KV cache for prefills.
            prefill_req = execute_model_req.clone(non_spec_seqs)
            # TODO avoid sampling here?
            self.proposer_worker.execute_model(prefill_req)

        with Timer() as verification_timer:

@@ -774,6 +784,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
            execute_model_req.seq_group_metadata_list,
            accepted_token_ids,
            target_logprobs=target_logprobs,
            prompt_logprobs=proposal_scores.prompt_logprobs
            if not self._disable_logprobs else None,
            k=execute_model_req.num_lookahead_slots,
            stage_times=stage_times)
@@ -845,19 +857,32 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
        # metadata.
        accepted_token_ids[original_indices] = accepted_token_ids.clone()

        # B x K+1 x D
        hidden_states = proposal_scores.hidden_states
        if hidden_states is not None:
            # Only get terminal hidden states for next step
            terminal_metadata = [
                sg for sg in seq_group_metadata_list if sg.do_sample
            ]

            # Contract hidden states based on accepted tokens
            hs_size = hidden_states.shape[-1]

            accepted_index = accepted_token_ids + 1  # Convert -1 to 0
            accepted_index = accepted_index.count_nonzero(dim=1).add_(-1)
            index = accepted_index[:, None, None].expand(-1, 1, hs_size)
            accepted_index = accepted_index.count_nonzero(dim=1).add_(-1)  # b
            # Drop non-terminal prefill chunks hidden states.
            hidden_states = hidden_states[
                accepted_index != VLLM_INVALID_TOKEN_ID]
            accepted_index = accepted_index[
                accepted_index != VLLM_INVALID_TOKEN_ID]
            assert len(accepted_index) == hidden_states.shape[0] == len(
                terminal_metadata)
            index = accepted_index[:, None, None].expand(-1, 1,
                                                         hs_size)  # b x 1 x d
            second_last_token_hidden_states = hidden_states[:, -2]  # b x d
            hidden_states = hidden_states.gather(1, index).squeeze(1)  # b x d
            # Store hidden states from target model for subsequent decode step
            self.previous_hidden_states = HiddenStates(
                hidden_states, seq_group_metadata_list,
                hidden_states, terminal_metadata,
                second_last_token_hidden_states)
        return accepted_token_ids, logprobs
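A small torch example of the accepted-token indexing above (toy values): rejected slots are marked -1, so shifting by one and counting non-zero entries per row gives the number of accepted tokens, and subtracting one yields the position of the last accepted token whose hidden state seeds the next step.

import torch

# Toy accepted_token_ids for 3 decode requests with k + 1 = 3 slots each.
accepted = torch.tensor([[11, 12, -1],    # 2 accepted -> last index 1
                         [21, -1, -1],    # 1 accepted -> last index 0
                         [31, 32, 33]])   # 3 accepted -> last index 2
accepted_index = (accepted + 1).count_nonzero(dim=1).add_(-1)
print(accepted_index)  # tensor([1, 0, 2])
# Expanding to (b, 1, hidden_size) and calling gather(1, index) then selects,
# per request, the hidden state of the last accepted token.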
@@ -866,6 +891,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
        seq_group_metadata_list: List[SequenceGroupMetadata],
        accepted_token_ids: torch.Tensor,  # shape: [batch_size, k+1]
        target_logprobs: torch.Tensor,  # shape: [batch_size, k+1, vocab_size]
        prompt_logprobs: Optional[
            torch.Tensor],  # shape: [nprompt_tokens, vocab_size]
        k: int,
        stage_times: Tuple[float, float, float],
    ) -> List[SamplerOutput]:

@@ -909,15 +936,89 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):

        # Construct the output on a per-step, per-sequence basis.
        # Non-terminal prefill chunks will end up here as rows with just -1s
        # i.e mixed-batch [[-1, 1576], [-1, 29884], [-1, -1], [-1, -1]]
        # i.e mixed-batch [[-1, 1576], [-1, 29884], [-1, -1], [-1, -1]] while
        # terminal chunks will only have one generated token at time 0.
        sampler_output_list: List[SamplerOutput] = []

        # Prefills are not multi-step (return at most 1 token), in order to
        # avoid padding or repetition to fit decodes, we separate them.
        for i, sg in enumerate(seq_group_metadata_list):
            if not sg.is_prompt:
                # Requests are ordered as prefills|decodes=>no more prefills.
                break
            num_logprobs = num_logprobs_per_seq[i]
            seq_kwargs = dict(token_id=-1,
                              token_id_logprob_rank=0,
                              token_id_logprob=-float('inf'),
                              topk_token_ids=[-1] * num_logprobs,
                              topk_logprobs=[-float('inf')] * num_logprobs,
                              seq_id=seq_ids[i])
            # Terminal chunk, has token.
            if sg.do_sample:
                seq_kwargs.update(
                    dict(
                        token_id=accepted_token_ids[i][0].item(),
                        token_id_logprob_rank=accepted_token_id_ranks_by_step[
                            0][i],
                        token_id_logprob=accepted_token_id_logprobs_by_step[0]
                        [i],
                        topk_token_ids=topk_indices_by_step[0][i]
                        [:num_logprobs],
                        # output only so step is 0
                        topk_logprobs=topk_logprobs_by_step[0][i]
                        [:num_logprobs],
                    ))
            needs_plogs = (sg.sampling_params.prompt_logprobs
                           and sg.sampling_params.prompt_logprobs > 0)
            plogs = None
            if prompt_logprobs is not None:
                # Even non-terminal prompt chunks can have logprobs here.
                plogs = prompt_logprobs[i]
            elif needs_plogs:
                # Prompt logprobs are requested but `_disable_logprobs` is set.
                seq_data = next(iter(sg.seq_data.values()))
                # Get only the tokens in this chunk!
                prompt_token_ids = seq_data.get_prompt_token_ids()
                prompt_token_ids = prompt_token_ids[
                    seq_data.
                    _num_computed_tokens:seq_data._num_computed_tokens +
                    sg.token_chunk_size]

                is_first_chunk = seq_data._num_computed_tokens == 0
                # There's no prob generated for the first token in a sequence.
                if is_first_chunk:
                    prompt_token_ids = prompt_token_ids[1:]
                plogs = [
                    create_logprobs_output(
                        token_id=p_token_id,
                        token_id_logprob_rank=-1,
                        token_id_logprob=0.0,
                        topk_token_ids=[],
                        topk_logprobs=[],
                    ) for p_token_id in prompt_token_ids
                ]
            seq_kwargs.update(dict(prompt_logprobs=plogs))

            sampler_output_list.append(
                SamplerOutput(
                    outputs=[create_sequence_group_output(
                        **seq_kwargs)]))  # type: ignore

        # Decodes, create one SamplerOutput per-step (at most K+1).
        for step_index in range(num_steps):
            if all(token_id == -1
                   for token_id in accepted_token_ids_by_step[step_index]):
            if all(token_id == -1 for sg, token_id in zip(
                    seq_group_metadata_list,
                    accepted_token_ids_by_step[step_index])
                   if not sg.is_prompt):
                break

            step_output_token_ids: List[CompletionSequenceGroupOutput] = []
            for sequence_index in range(batch_size):
                seq_meta = seq_group_metadata_list[sequence_index]
                # Prompts already processed above.
                if seq_meta.is_prompt:
                    continue

                # Each sequence may have a different num_logprobs; retrieve it.
                num_logprobs = num_logprobs_per_seq[sequence_index]
                step_output_token_ids.append(

@@ -952,6 +1053,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
        # This is periodic because the rejection sampler emits metrics
        # periodically.
        self._maybe_log_stage_times(*stage_times)
        # First `n_prefills` entries will contain prefills SamplerOutput when
        # chunked prefill is enabled, the rest is decodes in multi-step format.
        return sampler_output_list

    def _maybe_log_stage_times(self, average_time_per_proposal_tok_ms: float,