[Speculative decoding] Support target-model logprobs (#4378)

Cade Daniel
2024-05-03 15:52:01 -07:00
committed by GitHub
parent 43c413ec57
commit ab50275111
15 changed files with 727 additions and 86 deletions

View File

@ -1,9 +1,13 @@
import asyncio
import time
from itertools import cycle
from typing import List, Optional, Tuple, Union
from typing import Dict, List, Optional, Tuple, Union
import pytest
import ray
import torch
from pynvml import (nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo,
nvmlInit)
from tests.conftest import cleanup
from vllm import LLM
@ -13,7 +17,7 @@ from vllm.lora.request import LoRARequest
from vllm.model_executor.utils import set_random_seed
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
from vllm.sequence import MultiModalData
from vllm.sequence import Logprob, MultiModalData
from vllm.usage.usage_lib import UsageContext
from vllm.utils import Counter, random_uuid
@ -153,12 +157,19 @@ def create_llm_generator(baseline_or_test, request, common_llm_kwargs,
test_name = request.node.name
def generator_inner():
print(f'Creating {baseline_or_test=} LLM for {test_name=}. {kwargs=}')
wait_for_gpu_memory_to_clear(
devices=list(range(torch.cuda.device_count())),
threshold_bytes=2 * 2**30,
timeout_s=60,
)
use_async = False
if "use_async" in kwargs:
use_async = kwargs.pop("use_async")
print(f'{use_async=}')
print(f'Creating {baseline_or_test=} LLM for {test_name=}. {kwargs=}')
llm = AsyncLLM(**kwargs) if use_async else LLM(**kwargs)
set_random_seed(seed)
@ -188,6 +199,20 @@ def get_output_from_llm_generator(
return tokens, token_ids
def get_logprobs_from_llm_generator(
llm_generator, prompts,
sampling_params) -> List[List[Dict[int, Logprob]]]:
"""Returns a dict of (token_id: Logprob) for each generated position, for
each sequence in the batch.
"""
for llm in llm_generator():
outputs = llm.generate(prompts, sampling_params, use_tqdm=True)
logprobs = [output.outputs[0].logprobs[:] for output in outputs]
del llm
return logprobs
def run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
@ -243,3 +268,38 @@ def run_greedy_equality_correctness_test(baseline_llm_generator,
print(f'{i=} {baseline_token_ids=}')
print(f'{i=} {spec_token_ids=}')
assert baseline_token_ids == spec_token_ids
def wait_for_gpu_memory_to_clear(devices: List[int],
threshold_bytes: int,
timeout_s: float = 120) -> None:
# Use nvml instead of pytorch to reduce measurement error from torch cuda
# context.
nvmlInit()
start_time = time.time()
while True:
output = {}
output_raw = {}
for device in devices:
dev_handle = nvmlDeviceGetHandleByIndex(device)
mem_info = nvmlDeviceGetMemoryInfo(dev_handle)
gb_used = mem_info.used / 2**30
output_raw[device] = gb_used
output[device] = f'{gb_used:.02f}'
print('gpu memory used (GB): ', end='')
for k, v in output.items():
print(f'{k}={v}; ', end='')
print('')
dur_s = time.time() - start_time
if all(v <= (threshold_bytes / 2**30) for v in output_raw.values()):
print(f'Done waiting for free GPU memory on devices {devices=} '
f'({threshold_bytes/2**30=}) {dur_s=:.02f}')
break
if dur_s >= timeout_s:
raise ValueError(f'Memory of devices {devices=} not free after '
f'{dur_s=:.02f} ({threshold_bytes/2**30=})')
time.sleep(5)

View File

@ -0,0 +1,335 @@
import math
from itertools import cycle
import pytest
from vllm import SamplingParams
from .conftest import get_logprobs_from_llm_generator
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model": "JackFram/llama-68m",
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True,
"max_logprobs": 6,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{
"speculative_model": "JackFram/llama-160m",
"num_speculative_tokens": 3,
}])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
7,
])
@pytest.mark.parametrize("seed", [1])
def test_logprobs_equality(baseline_llm_generator, test_llm_generator,
batch_size: int, output_len: int):
"""Verify output logprobs are equal with and without speculative decoding.
"""
run_greedy_logprobs_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
force_output_len=True)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model": "JackFram/llama-68m",
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True,
"max_logprobs": 6,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{
"speculative_model": "JackFram/llama-160m",
"num_speculative_tokens": 3,
}])
@pytest.mark.parametrize("batch_size", [1])
@pytest.mark.parametrize("num_logprobs", [6])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
7,
])
@pytest.mark.parametrize("seed", [1])
def test_diff_num_logprobs(baseline_llm_generator, test_llm_generator,
batch_size: int, output_len: int,
num_logprobs: int):
"""Verify output logprobs are equal with and without spec decode.
This specifies a number of logprobs >1.
"""
run_greedy_logprobs_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
force_output_len=True,
logprob_rank=num_logprobs)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model": "JackFram/llama-68m",
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{
"speculative_model": "JackFram/llama-160m",
"num_speculative_tokens": 3,
}, {
"speculative_model": "JackFram/llama-160m",
"num_speculative_tokens": 6,
}])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("seed", [1])
def test_logprobs_different_k(baseline_llm_generator, test_llm_generator,
batch_size: int, output_len: int):
"""Veriy logprob greedy equality with different speculation lens.
"""
run_greedy_logprobs_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
force_output_len=True)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model": "JackFram/llama-68m",
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
"test_llm_kwargs",
[{
"speculative_model": "JackFram/llama-160m",
"num_speculative_tokens": 3,
        # Artificially limit the draft model max model len; this forces vLLM
        # to skip speculation once the sequences grow beyond (32 - k) tokens.
"speculative_max_model_len": 32,
}])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("seed", [1])
def test_logprobs_when_skip_speculation(baseline_llm_generator,
test_llm_generator, batch_size: int,
output_len: int):
"""Verify logprobs greedy equality when some sequences skip speculation.
"""
run_greedy_logprobs_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
force_output_len=True)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model": "JackFram/llama-68m",
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{
"speculative_model": "JackFram/llama-160m",
"num_speculative_tokens": 3,
}])
@pytest.mark.parametrize("batch_size", [1])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("seed", [1])
def test_logprobs_temp_1(baseline_llm_generator, test_llm_generator,
batch_size: int, output_len: int):
"""Verify at least one logprob result has num_logprobs+1, which tests the
case where the sampled token is not in top-k logprobs.
Ideally, this test should validate equality with non-spec by getting
logprobs. This is left as future improvement.
"""
batch_size = 8
max_output_len = output_len
force_output_len = True
logprob_rank = 5
temperature = 1.0
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
"San Francisco is know for its",
"Facebook was created in 2004 by",
"Curious George is a",
"Python 3.11 brings improvements to its",
]
prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))]
    # If the test requires that we generate max_output_len tokens, then set the
    # sampling params to ignore the eos token.
ignore_eos = force_output_len
sampling_params = SamplingParams(
max_tokens=max_output_len,
ignore_eos=ignore_eos,
temperature=temperature,
logprobs=logprob_rank,
)
spec_batch_logprobs = get_logprobs_from_llm_generator(
test_llm_generator, prompts, sampling_params)
num_returned_logprobs = [
len(logprob_dict) for seq_logprobs in spec_batch_logprobs
for logprob_dict in seq_logprobs
]
# Assert one of the returned logprobs has > num_logprobs (indicating the
# sampled token is not in top-k).
assert any([
num_returned > logprob_rank for num_returned in num_returned_logprobs
])
def run_greedy_logprobs_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len,
force_output_len: bool,
logprob_rank: int = 1):
"""Helper method that compares the logprobs outputs of both the baseline LLM
and the test LLM. It asserts greedy equality of the logprobs when the
temperature is zero.
"""
temperature = 0.0
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
"San Francisco is know for its",
"Facebook was created in 2004 by",
"Curious George is a",
"Python 3.11 brings improvements to its",
]
prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))]
    # If the test requires that we generate max_output_len tokens, then set the
    # sampling params to ignore the eos token.
ignore_eos = force_output_len
sampling_params = SamplingParams(
max_tokens=max_output_len,
ignore_eos=ignore_eos,
temperature=temperature,
logprobs=logprob_rank,
)
spec_batch_logprobs = get_logprobs_from_llm_generator(
test_llm_generator, prompts, sampling_params)
baseline_batch_logprobs = get_logprobs_from_llm_generator(
baseline_llm_generator, prompts, sampling_params)
assert len(baseline_batch_logprobs) == len(prompts)
assert len(spec_batch_logprobs) == len(prompts)
# For each sequence in the batch.
for i, (baseline_logprobs, spec_logprobs) in enumerate(
zip(baseline_batch_logprobs, spec_batch_logprobs)):
assert len(spec_logprobs) == len(baseline_logprobs)
# For each generated position of the sequence.
for pos, (spec_pos_logprobs, baseline_pos_logprobs) in enumerate(
zip(spec_logprobs, baseline_logprobs)):
# Map rank to token/logprob in spec output.
spec_rank_to_token_id = {
value.rank: key
for key, value in spec_pos_logprobs.items()
}
spec_rank_to_logprob = {
value.rank: value.logprob
for key, value in spec_pos_logprobs.items()
}
# Map rank to token/logprob in baseline output.
baseline_rank_to_token_id = {
value.rank: key
for key, value in baseline_pos_logprobs.items()
}
baseline_rank_to_logprob = {
value.rank: value.logprob
for key, value in baseline_pos_logprobs.items()
}
# Assert set of ranks returned is equal.
assert set(spec_rank_to_token_id.keys()) == set(
baseline_rank_to_token_id.keys())
# Assert each logprob/token id is correct, keyed by rank.
for rank in sorted(set(spec_rank_to_token_id.keys())):
assert spec_rank_to_token_id[
rank] == baseline_rank_to_token_id[rank], f"{rank}"
assert math.isclose(
a=spec_rank_to_logprob[rank],
b=baseline_rank_to_logprob[rank],
abs_tol=1e-1,
)
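For reference, a minimal end-to-end sketch of the setup these tests exercise, reusing the model names and engine knobs from the parametrizations above; it is an illustrative sketch rather than a canonical recipe:

from vllm import LLM, SamplingParams

# Draft model llama-160m proposes 3 tokens per step for target model
# llama-68m, with up to 6 logprobs requested per position.
llm = LLM(
    model="JackFram/llama-68m",
    speculative_model="JackFram/llama-160m",
    num_speculative_tokens=3,
    use_v2_block_manager=True,
    enforce_eager=True,
    max_logprobs=6,
)
outputs = llm.generate(
    ["The capital of France is"],
    SamplingParams(max_tokens=7, temperature=0.0, logprobs=6),
)
# Each generated position now carries a Dict[int, Logprob] scored by the
# target model rather than placeholder zeros.
print(outputs[0].outputs[0].logprobs)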

View File

@ -41,24 +41,17 @@ from .conftest import (get_output_from_llm_generator,
@pytest.mark.parametrize(
"common_llm_kwargs",
[
{
# Use a small model for a fast test.
# Note this is repeated in the test body; to initialize a tokenizer.
"model": "JackFram/llama-68m",
[{
# Use a small model for a fast test.
# Note this is repeated in the test body; to initialize a tokenizer.
"model": "JackFram/llama-68m",
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True,
# whether use AsyncLLM engine
"use_async": async_mode,
}
# Try both async and sync engine execution
for async_mode in [True, False]
])
# Required for spec decode.
"use_v2_block_manager": True,
}])
@pytest.mark.parametrize(
"per_test_common_llm_kwargs",
[
@ -117,6 +110,44 @@ def test_spec_decode_e2e_with_detokenization(test_llm_generator,
assert actual_tokens.strip() == expected_tokens.strip()
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Use a small model for a fast test.
# Note this is repeated in the test body; to initialize a tokenizer.
"model": "JackFram/llama-68m",
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True,
# Use AsyncLLM engine
"use_async": True,
}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
},
])
@pytest.mark.parametrize("test_llm_kwargs", [{}])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_e2e_with_async_engine(test_llm_generator,
baseline_llm_generator,
batch_size: int):
"""Verify spec decode works well with async LLM engine.
"""
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=32,
force_output_len=True)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{

View File

@ -292,6 +292,10 @@ def test_draft_proposals_full_speculation_len():
vocab_size,
device=device,
dtype=torch.float32),
logprobs=torch.rand(batch_size,
vocab_size,
device=device,
dtype=torch.float32),
sampled_token_ids=torch.randint(low=0,
high=vocab_size,
size=(batch_size, ),
@ -392,6 +396,10 @@ def test_draft_proposals_mixed_k():
vocab_size,
device=device,
dtype=torch.float32),
logprobs=torch.rand(expected_num_proposal_seqs,
vocab_size,
device=device,
dtype=torch.float32),
sampled_token_ids=torch.randint(
low=0,
high=vocab_size,

View File

@ -192,8 +192,14 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int):
vocab_size,
dtype=torch.float32,
device='cuda')
target_token_logprobs = torch.rand(1,
batch_size * (k + 1),
vocab_size,
dtype=torch.float32,
device='cuda')
target_output = create_sampler_output_list(target_token_ids,
target_token_probs)
target_token_probs,
target_token_logprobs)
target_worker.execute_model.return_value = [target_output[0]]
@ -273,8 +279,14 @@ def test_correctly_formats_output(k: int, batch_size: int):
vocab_size,
dtype=torch.float32,
device='cuda')
target_token_logprobs = torch.rand(1,
batch_size * (k + 1),
vocab_size,
dtype=torch.float32,
device='cuda')
target_output = create_sampler_output_list(target_token_ids,
target_token_probs)
target_token_probs,
target_token_logprobs)
target_worker.execute_model.return_value = [target_output[0]]
@ -294,7 +306,9 @@ def test_correctly_formats_output(k: int, batch_size: int):
num_lookahead_slots=k)
expected_output = create_sampler_output_list(
rejection_sampler_output.transpose(0, 1), [None for _ in range(k + 1)])
token_ids=rejection_sampler_output.transpose(0, 1),
probs=[None for _ in range(k + 1)],
logprobs=[None for _ in range(k + 1)])
seq_ids = [
next(iter(seq_group_metadata.seq_data.keys()))
@ -328,7 +342,6 @@ def test_correctly_formats_output(k: int, batch_size: int):
continue
assert actual_by_step[i].output_token == expected_by_step[
i].output_token
assert actual_by_step[i].logprobs == expected_by_step[i].logprobs
@pytest.mark.parametrize('k', [1, 2])
@ -387,8 +400,14 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool):
vocab_size,
dtype=torch.float32,
device='cuda')
target_token_logprobs = torch.rand(1,
batch_size * (k + 1),
vocab_size,
dtype=torch.float32,
device='cuda')
target_output = create_sampler_output_list(target_token_ids,
target_token_probs)
target_token_probs,
target_token_logprobs)
target_worker.execute_model.return_value = [target_output[0]]

View File

@ -201,6 +201,7 @@ def assert_logprobs_dict_allclose(
def create_sampler_output_list(
token_ids: torch.Tensor,
probs: Iterable[Optional[torch.Tensor]],
logprobs: Iterable[Optional[torch.Tensor]],
seq_ids: Optional[List[int]] = None) -> List[SamplerOutput]:
num_steps, batch_size = token_ids.shape
token_ids_by_step = token_ids.tolist()
@ -222,6 +223,7 @@ def create_sampler_output_list(
) for seq_index, token_id in enumerate(token_ids_by_step[step])
],
sampled_token_probs=probs[step],
logprobs=logprobs[step],
sampled_token_ids=token_ids[step])
for step in range(num_steps)
]

View File

@ -1,3 +1,4 @@
import functools
from typing import Callable, List
from transformers import PreTrainedTokenizer
@ -8,8 +9,8 @@ from vllm.engine.output_processor.interfaces import (
from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.logger import init_logger
from vllm.sampling_params import SamplingParams
from vllm.sequence import (Logprob, Sequence, SequenceGroup,
SequenceGroupOutput, SequenceOutput, SequenceStatus)
from vllm.sequence import (Sequence, SequenceGroup, SequenceGroupOutput,
SequenceOutput, SequenceStatus)
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.utils import Counter
@ -48,10 +49,14 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
outputs: List[SequenceGroupOutput]) -> None:
# TODO(sang): Prompt logprob currently not implemented in multi step
# workers.
self._log_prompt_logprob_unsupported_warning_once()
@staticmethod
@functools.lru_cache()
def _log_prompt_logprob_unsupported_warning_once():
logger.warning(
"Prompt logprob is not supported by multi step workers. "
"(e.g., speculative decode uses multi step workers).")
pass
def process_outputs(self, sequence_group: SequenceGroup,
outputs: List[SequenceGroupOutput]) -> None:
@ -89,6 +94,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
valid_samples: List[SequenceOutput],
sampling_params: SamplingParams) -> None:
output_token_ids = [sample.output_token for sample in valid_samples]
output_logprobs = [sample.logprobs for sample in valid_samples]
# Truncate to max_tokens if necessary.
remaining_tokens = sampling_params.max_tokens - (seq.get_output_len() +
@ -113,11 +119,11 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
# Incrementally append tokens to the sequence, as if we had only one new
# token.
for output_token_id in output_token_ids:
for output_token_id, output_logprob in zip(output_token_ids,
output_logprobs):
seq.append_token_id(
token_id=output_token_id,
# TODO emit logprobs in multi-step decoding.
logprobs={output_token_id: Logprob(0.0)},
logprobs=output_logprob,
)
new_char_count = 0
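As an aside, a minimal standalone sketch of the log-once idiom used by _log_prompt_logprob_unsupported_warning_once above; the logger here is a stand-in for vLLM's own:

import functools
import logging

logger = logging.getLogger("example")

@functools.lru_cache()
def _warn_once() -> None:
    # The zero-argument call is memoized, so the body runs only once.
    logger.warning("Prompt logprob is not supported by multi step workers.")

for _ in range(3):
    _warn_once()  # emits a single warning despite repeated calls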

View File

@ -103,8 +103,7 @@ class Sampler(nn.Module):
if self.include_gpu_probs_tensor:
assert maybe_sampled_tokens_tensor is not None
sampled_tokens_tensor = maybe_sampled_tokens_tensor
on_device_tensors = (probs, sampled_tokens_tensor)
on_device_tensors = (probs, logprobs, maybe_sampled_tokens_tensor)
else:
on_device_tensors = None
@ -965,8 +964,7 @@ def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor,
has implications on the overall design of the sampler, e.g. how to record
accurate logprobs for the user, so this improvement is deferred to later.
"""
logprobs[sample_indices, :] = -float('inf')
logprobs[sample_indices, greedy_samples] = 0.0
# NOTE: logprobs are not modified so they can be returned to the user.
probs[sample_indices, :] = 0
probs[sample_indices, greedy_samples] = 1.0
@ -976,7 +974,8 @@ def _build_sampler_output(
sampling_metadata: SamplingMetadata,
prompt_logprobs: List[Optional[PromptLogprobs]],
sample_logprobs: List[SampleLogprobs],
on_device_tensors: Optional[Tuple[torch.Tensor, torch.Tensor]],
on_device_tensors: Optional[Tuple[torch.Tensor, torch.Tensor,
torch.Tensor]],
) -> SamplerOutput:
"""Construct Python objects with the output of sampling.
@ -1005,14 +1004,17 @@ def _build_sampler_output(
# If not specified, store None values in SamplerOutput.
if on_device_tensors is not None:
sampled_token_probs, sampled_token_ids = on_device_tensors
(sampled_token_probs, logprobs_tensor,
sampled_token_ids) = on_device_tensors
else:
sampled_token_probs, sampled_token_ids = (None, None)
sampled_token_probs, logprobs_tensor, sampled_token_ids = (None, None,
None)
return SamplerOutput(
outputs=sampler_output,
sampled_token_probs=sampled_token_probs,
sampled_token_ids=sampled_token_ids,
logprobs=logprobs_tensor,
)

View File

@ -700,6 +700,9 @@ class SamplerOutput:
# On-device tensor containing probabilities of each token.
sampled_token_probs: Optional["torch.Tensor"] = None
# On-device tensor containing the logprobs of each token.
logprobs: Optional["torch.Tensor"] = None
# On-device tensor containing the sampled token ids.
sampled_token_ids: Optional["torch.Tensor"] = None

View File

@ -94,7 +94,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
assert len(target_sampler_output) == 1, "expected single-step output"
target_sampler_output = target_sampler_output[0]
all_tokens, all_probs = self._contract_batch(
all_tokens, all_probs, spec_logprobs = self._contract_batch(
contracted_bs=len(seq_group_metadata_list),
target_sampler_output=target_sampler_output,
proposals=proposals,
@ -107,6 +107,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
return SpeculativeScores(
probs=all_probs,
token_ids=all_tokens,
logprobs=spec_logprobs,
)
def _expand_batch(
@ -148,12 +149,12 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
return (spec_indices, non_spec_indices, target_seq_group_metadata_list,
num_scoring_tokens)
def _contract_batch(self, contracted_bs: int,
target_sampler_output: List[SamplerOutput],
proposals: SpeculativeProposals,
num_scoring_tokens: int, non_spec_indices: List[int],
spec_indices: List[int],
k: int) -> Tuple[torch.Tensor, torch.Tensor]:
def _contract_batch(
self, contracted_bs: int,
target_sampler_output: List[SamplerOutput],
proposals: SpeculativeProposals, num_scoring_tokens: int,
non_spec_indices: List[int], spec_indices: List[int],
k: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Contract the expanded batch back into its original size.
This maps the scores of speculative tokens back to their original
sequences.
@ -161,8 +162,9 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
contracted_bs is the original batch size, and the batch size that the
target_sampler_output will be contracted to.
"""
(target_token_ids, target_probs, non_spec_target_token_ids,
non_spec_target_probs) = self._split_scoring_output(
(target_token_ids, target_probs, target_logprobs,
non_spec_target_token_ids, non_spec_target_probs,
non_spec_target_logprobs) = self._split_scoring_output(
target_sampler_output, num_scoring_tokens)
# Map distinct sequences used to score each token
@ -179,6 +181,8 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
spec_expanded_bs, k + 1)
target_probs = target_probs.squeeze().reshape(spec_expanded_bs, k + 1,
self._vocab_size)
target_logprobs = target_logprobs.squeeze().reshape(
spec_expanded_bs, k + 1, self._vocab_size)
all_tokens = torch.full(size=(contracted_bs, k + 1),
fill_value=-1,
@ -189,16 +193,26 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
self._vocab_size,
device=self._device,
dtype=torch.float32)
all_logprobs = torch.full(size=(
contracted_bs,
k + 1,
self._vocab_size,
),
fill_value=-float("inf"),
device=self._device,
dtype=torch.float32)
if non_spec_indices:
all_tokens[non_spec_indices, :1] = non_spec_target_token_ids
all_probs[non_spec_indices, :1, :] = non_spec_target_probs
all_logprobs[non_spec_indices, :1, :] = non_spec_target_logprobs
if spec_indices:
all_tokens[spec_indices] = target_token_ids
all_probs[spec_indices] = target_probs
all_logprobs[spec_indices] = target_logprobs
return all_tokens, all_probs
return all_tokens, all_probs, all_logprobs
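A small sketch of the padding convention used above for the contracted batch; the sizes are arbitrary and chosen only for illustration:

import torch

B, k, V = 4, 3, 16  # contracted batch size, proposal len, vocab size (arbitrary)
all_probs = torch.zeros(B, k + 1, V)                     # probability-0 padding
all_logprobs = torch.full((B, k + 1, V), -float("inf"))  # log(0) padding
# The two paddings are consistent: exp(-inf) == 0.
assert torch.equal(all_logprobs.exp(), all_probs)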
def _create_scoring_model_input(
self,
@ -308,7 +322,8 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
def _split_scoring_output(
self, sampler_output: SamplerOutput, num_scoring_tokens: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
torch.Tensor, torch.Tensor]:
"""Split the target model output into speculative and non-speculative
output.
"""
@ -328,21 +343,29 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
) = sampler_output.sampled_token_probs.split(split_sizes)
(spec_sampled_tokens, non_spec_sampled_tokens
) = sampler_output.sampled_token_ids.flatten().split(split_sizes)
(
spec_logprobs,
non_spec_logprobs,
) = sampler_output.logprobs.split(split_sizes)
# Convert scores to tensors.
sampler_output.sampled_token_probs = spec_probs
sampler_output.sampled_token_ids = spec_sampled_tokens
target_token_ids, target_probs = sampler_output_to_torch(
[sampler_output], True)
sampler_output.logprobs = spec_logprobs
(target_token_ids, target_probs,
target_logprobs) = sampler_output_to_torch([sampler_output], True)
# Convert non-speculative output tokens to tensors.
sampler_output.sampled_token_probs = non_spec_probs
sampler_output.sampled_token_ids = non_spec_sampled_tokens
non_spec_target_token_ids, non_spec_target_probs = (
sampler_output_to_torch([sampler_output], True))
sampler_output.logprobs = non_spec_logprobs
(non_spec_target_token_ids, non_spec_target_probs,
non_spec_target_logprobs) = sampler_output_to_torch([sampler_output],
True)
return (target_token_ids, target_probs, non_spec_target_token_ids,
non_spec_target_probs)
return (target_token_ids, target_probs, target_logprobs,
non_spec_target_token_ids, non_spec_target_probs,
non_spec_target_logprobs)
def _create_target_seq_id_iterator(
self, seq_ids: List[SeqId]) -> Iterator[TargetSeqId]:

View File

@ -38,6 +38,11 @@ class SpeculativeScores:
# Probabilities of the speculative tokens according to the scoring model.
probs: torch.Tensor
# Log-probabilities of the speculative tokens according to the scoring
# model. These values can be used to generate Logprob objects that are
# returned to the user.
logprobs: torch.Tensor
# Token ids sampled from the scoring model. Used for speculative bonus
# tokens and also non-speculative normal decoding.
token_ids: torch.Tensor

View File

@ -140,11 +140,17 @@ class NGramWorker(LoraNotSupportedWorkerBase):
device=self.device,
)
token_probs.scatter_(2, indices, 1)
token_logprobs = torch.zeros(
(len(seq_group_metadata_list), sample_len, self.vocab_size),
dtype=torch.float32,
device=self.device,
)
for i in range(len(seq_group_metadata_list)):
outputs.append(
SamplerOutput(
outputs=None,
sampled_token_probs=token_probs[i],
                    logprobs=token_logprobs[i],
sampled_token_ids=token_ids[i],
))
return outputs, False

View File

@ -5,15 +5,16 @@ import torch
from vllm.logger import init_logger
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
from vllm.sequence import (Logprob, SamplerOutput, SequenceGroupMetadata,
SequenceGroupOutput, SequenceOutput)
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeScorer, SpeculativeScores)
from vllm.spec_decode.metrics import AsyncMetricsCollector
from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.ngram_worker import NGramWorker
from vllm.spec_decode.util import (get_all_seq_ids, nvtx_range,
from vllm.spec_decode.util import (create_sequence_group_output,
get_all_num_logprobs, get_all_seq_ids,
get_sampled_token_logprobs, nvtx_range,
split_batch_by_proposal_len)
from vllm.worker.worker_base import LoraNotSupportedWorkerBase, WorkerBase
@ -258,6 +259,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
# overhead when the engine runs in a different process than the workers.
sampler_output.probs = None
sampler_output.sampled_tokens = None
sampler_output.logprobs = None
return [sampler_output]
@nvtx_range("spec_decode_worker._run_speculative_decoding_step")
@ -298,12 +300,15 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
)
#logger.info("verify proposals")
accepted_token_ids = self._verify_tokens(seq_group_metadata_list,
proposal_scores, proposals, k)
accepted_token_ids, target_logprobs = self._verify_tokens(
seq_group_metadata_list, proposal_scores, proposals, k)
#logger.info("create output list")
return self._create_output_sampler_list(seq_group_metadata_list,
accepted_token_ids, k)
return self._create_output_sampler_list(
seq_group_metadata_list,
accepted_token_ids,
target_logprobs=target_logprobs,
k=k)
@nvtx_range("spec_decode_worker._verify_tokens")
def _verify_tokens(
@ -312,9 +317,12 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
proposal_scores: SpeculativeScores,
proposals: SpeculativeProposals,
max_proposal_len: int,
) -> torch.Tensor:
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Determine which speculative tokens are accepted using the
probabilities of each token according to the proposer and scorer models.
Returns a tuple of Tensors, one for the accepted token ids and one for
the logprobs according to the scoring model.
"""
proposal_lens_list = proposals.proposal_lens.tolist()
@ -361,17 +369,19 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
non_spec_token_ids[:, 1:] = -1
accepted_token_ids = torch.cat(
[accepted_token_ids, non_spec_token_ids])
logprobs = proposal_scores.logprobs
# Rearrange so that results are in the order of the original seq group
# metadata.
accepted_token_ids[original_indices] = accepted_token_ids.clone()
return accepted_token_ids
return accepted_token_ids, logprobs
def _create_output_sampler_list(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
accepted_token_ids: torch.Tensor, # shape: [batch_size, k+1]
target_logprobs: torch.Tensor, # shape: [batch_size, k+1, vocab_size]
k: int,
) -> List[SamplerOutput]:
"""Given the accepted token ids, create a list of SamplerOutput.
@ -379,30 +389,68 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
The output is padded with -1 tokens such that each sequence has
the same number of outputs.
"""
seq_ids = get_all_seq_ids(seq_group_metadata_list)
batch_size, num_steps = accepted_token_ids.shape
# shape: [k+1, batch_size]
accepted_token_ids_by_step = accepted_token_ids.transpose(0,
1).tolist()
# Organize input tensors by step instead of by sequence.
target_logprobs_by_step = target_logprobs.transpose(0, 1)
accepted_token_ids_by_step = accepted_token_ids.transpose(0, 1)
# Get the logprobs/rank of the accepted tokens.
(accepted_token_id_ranks_by_step,
accepted_token_id_logprobs_by_step) = get_sampled_token_logprobs(
logprob_tensor=target_logprobs_by_step,
sampled_token_ids=accepted_token_ids_by_step,
)
# Get the top-k logprobs (which may or may not include the logprob of
# the accepted token).
(topk_logprobs_by_step,
topk_indices_by_step) = target_logprobs_by_step.topk(
k=self.scorer_worker.model_config.max_logprobs,
dim=-1,
)
# Get the sequence ids and num_logprobs (sampling parameter) in the
# batch.
seq_ids = get_all_seq_ids(seq_group_metadata_list)
num_logprobs_per_seq = get_all_num_logprobs(seq_group_metadata_list)
# Serialize all tensors to CPU Python lists.
accepted_token_ids_by_step = accepted_token_ids_by_step.tolist()
accepted_token_id_ranks_by_step = (
accepted_token_id_ranks_by_step.tolist())
accepted_token_id_logprobs_by_step = (
accepted_token_id_logprobs_by_step.tolist())
topk_logprobs_by_step = topk_logprobs_by_step.tolist()
topk_indices_by_step = topk_indices_by_step.tolist()
# Construct the output on a per-step, per-sequence basis.
sampler_output_list = []
for token_ids_by_step in accepted_token_ids_by_step:
if all(token_id == -1 for token_id in token_ids_by_step):
for step_index in range(num_steps):
if all(token_id == -1
for token_id in accepted_token_ids_by_step[step_index]):
break
step_output_token_ids = []
for token_id, seq_id in zip(token_ids_by_step, seq_ids):
for sequence_index in range(batch_size):
# Each sequence may have a different num_logprobs; retrieve it.
num_logprobs = num_logprobs_per_seq[sequence_index]
step_output_token_ids.append(
SequenceGroupOutput(
samples=[
SequenceOutput(
parent_seq_id=seq_id,
output_token=token_id,
# TODO Add verifier logprobs.
logprobs={token_id: Logprob(0.0)},
)
],
prompt_logprobs=None,
create_sequence_group_output(
token_id=accepted_token_ids_by_step[step_index]
[sequence_index],
token_id_logprob_rank=accepted_token_id_ranks_by_step[
step_index][sequence_index],
token_id_logprob=accepted_token_id_logprobs_by_step[
step_index][sequence_index],
seq_id=seq_ids[sequence_index],
topk_token_ids=topk_indices_by_step[step_index]
[sequence_index][:num_logprobs],
topk_logprobs=topk_logprobs_by_step[step_index]
[sequence_index][:num_logprobs],
))
sampler_output_list.append(
SamplerOutput(outputs=step_output_token_ids))
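A shape-level sketch of the per-step reorganization performed above; the sizes are arbitrary and max_logprobs stands in for scorer_worker.model_config.max_logprobs:

import torch

batch_size, k, vocab_size, max_logprobs = 2, 3, 11, 5  # arbitrary example sizes
target_logprobs = torch.randn(batch_size, k + 1, vocab_size).log_softmax(dim=-1)
accepted_token_ids = torch.randint(0, vocab_size, (batch_size, k + 1))

# Organize by step: the leading dim becomes num_steps == k + 1.
target_logprobs_by_step = target_logprobs.transpose(0, 1)
accepted_token_ids_by_step = accepted_token_ids.transpose(0, 1)
assert target_logprobs_by_step.shape == (k + 1, batch_size, vocab_size)

# Top-k candidate logprobs per step/sequence, independent of the accepted token.
topk_logprobs, topk_indices = target_logprobs_by_step.topk(k=max_logprobs, dim=-1)
assert topk_indices.shape == (k + 1, batch_size, max_logprobs)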

View File

@ -166,7 +166,7 @@ class Top1Proposer(SpeculativeProposer):
return proposal_tokens, proposal_probs, proposal_lens_tensor
sampler_output = maybe_sampler_output
proposal_tokens, proposal_probs = sampler_output_to_torch(
proposal_tokens, proposal_probs, _ = sampler_output_to_torch(
sampler_output, sampler_transposed)
# Now, reformat the output GPU tensors such that each sequence has

View File

@ -1,10 +1,11 @@
from contextlib import contextmanager
from itertools import chain
from typing import List, Tuple
from typing import Dict, List, Tuple
import torch
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.sequence import (Logprob, SamplerOutput, SequenceGroupMetadata,
SequenceGroupOutput, SequenceOutput)
SeqId = int
@ -21,6 +22,89 @@ def get_all_seq_ids(
]))
def get_all_num_logprobs(
seq_group_metadata_list: List[SequenceGroupMetadata]) -> List[int]:
"""Given a list of SequenceGroupMetadata, create a list of all num_logprobs.
If the sampling params do not call for any logprobs, return 0 for that
sequence.
"""
all_num_logprobs = []
for seq_group_metadata in seq_group_metadata_list:
num_logprobs = seq_group_metadata.sampling_params.logprobs
if seq_group_metadata.sampling_params.logprobs is None:
num_logprobs = 0
all_num_logprobs.append(num_logprobs)
return all_num_logprobs
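A quick check of the defaulting behavior described above, using vLLM's SamplingParams:

from vllm import SamplingParams

# Users who never ask for logprobs get SamplingParams.logprobs == None,
# which get_all_num_logprobs maps to 0 so no top-k entries are attached.
assert SamplingParams().logprobs is None
assert SamplingParams(logprobs=6).logprobs == 6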
def get_sampled_token_logprobs(
# shape [num_steps, batch_size, vocab_size]
logprob_tensor: torch.Tensor,
sampled_token_ids: torch.Tensor, # shape [num_steps, batch_size]
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Get the logprobs for the sampled tokens. Returns the ranks and logprobs.
"""
num_steps, batch_size, vocab_size = logprob_tensor.shape
selected_logprobs = logprob_tensor[torch.arange(num_steps).unsqueeze(1),
torch.arange(batch_size),
sampled_token_ids, ]
expanded_selected_logprobs = selected_logprobs.unsqueeze(-1).expand(
-1, -1, vocab_size)
sampled_token_ids_ranks = (logprob_tensor >=
expanded_selected_logprobs).sum(-1)
return sampled_token_ids_ranks, selected_logprobs
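A tiny numeric check of the rank computation above; the values are illustrative rather than a normalized distribution, and the import path follows the vllm.spec_decode.util imports shown earlier in this diff:

import torch
from vllm.spec_decode.util import get_sampled_token_logprobs

# One step, one sequence, vocab of 4; token id 2 is the sampled token.
logprob_tensor = torch.tensor([[[-3.0, -0.5, -1.25, -2.0]]])
sampled_token_ids = torch.tensor([[2]])
ranks, logprobs = get_sampled_token_logprobs(logprob_tensor, sampled_token_ids)
# Two vocab entries (-0.5 and -1.25) are >= -1.25, so the sampled token ranks 2nd.
assert ranks.item() == 2
assert logprobs.item() == -1.25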
def create_sequence_group_output(
token_id: int,
token_id_logprob_rank: int,
token_id_logprob: float,
seq_id: SeqId,
topk_token_ids: List[int],
topk_logprobs: List[float],
) -> SequenceGroupOutput:
"""Create a SequenceGroupOutput given the sampling results.
Args:
token_id (int): The sampled token for the sequence.
token_id_logprob_rank (int): The logprob rank of the sampled token.
token_id_logprob (float): The logprob value of the sampled token.
seq_id (int): The sequence id.
topk_token_ids (List[int]): The list of top-k token ids.
topk_logprobs (List[float]): The list of top-k logprobs.
"""
# vLLM logprobs always include the sampled token. In addition, the user may
# request topk-logprobs (where top-k varies per user up to max_logprobs).
logprobs: Dict[int, Logprob] = {
token_id: Logprob(
logprob=token_id_logprob,
rank=token_id_logprob_rank,
),
}
logprobs.update({
topk_token_ids[topk_logprob_index]: Logprob(
logprob=topk_logprobs[topk_logprob_index],
rank=topk_logprob_index + 1,
)
for topk_logprob_index, _ in enumerate(topk_token_ids)
})
return SequenceGroupOutput(
samples=[
SequenceOutput(parent_seq_id=seq_id,
output_token=token_id,
logprobs=logprobs)
],
# TODO add prompt logprobs support.
prompt_logprobs=None,
)
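A short sketch of the resulting logprobs dict: when the sampled token falls outside the reported top-k, the dict gains one extra entry, which is the property test_logprobs_temp_1 asserts. The import path follows this diff's additions:

from vllm.spec_decode.util import create_sequence_group_output

# Sampled token 7 is not among the two reported top-k ids, so the dict
# holds num_logprobs + 1 = 3 entries.
output = create_sequence_group_output(
    token_id=7,
    token_id_logprob_rank=3,
    token_id_logprob=-2.5,
    seq_id=0,
    topk_token_ids=[11, 5],
    topk_logprobs=[-0.5, -1.25],
)
sample_logprobs = output.samples[0].logprobs
assert len(sample_logprobs) == 3
assert sample_logprobs[7].rank == 3 and sample_logprobs[11].rank == 1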
def split_batch_by_proposal_len(
seq_group_metadata_list: List[SequenceGroupMetadata],
proposal_lens: List[int], select_proposal_len_zero: bool
@ -49,8 +133,8 @@ def split_batch_by_proposal_len(
def sampler_output_to_torch(
sampler_output_list: List[SamplerOutput],
sampler_transposed: bool) -> Tuple[torch.Tensor, torch.Tensor]:
sampler_output_list: List[SamplerOutput], sampler_transposed: bool
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Utility function which converts a list of SamplerOutput to tensors.
sampler_transposed here is used as the indicator for whether
@ -76,6 +160,15 @@ def sampler_output_to_torch(
if sampler_transposed:
sampled_token_probs = sampled_token_probs.transpose(0, 1)
# shape: [batch_size, num_sampler_output, vocab_size]
sampled_token_logprobs = torch.stack(
[sampler_output.logprobs for sampler_output in sampler_output_list],
dim=0,
)
if sampler_transposed:
sampled_token_logprobs = sampled_token_logprobs.transpose(0, 1)
# shape: [batch_size, num_sampler_output]
sampled_token_ids = torch.stack(
[
@ -87,7 +180,7 @@ def sampler_output_to_torch(
if sampler_transposed:
sampled_token_ids = sampled_token_ids.transpose(0, 1)
return sampled_token_ids, sampled_token_probs
return sampled_token_ids, sampled_token_probs, sampled_token_logprobs
def maybe_mock_device_tensors(sampler_output: SamplerOutput, batch_size: int,