[Speculative decoding] Support target-model logprobs (#4378)
@@ -1,9 +1,13 @@
import asyncio
import time
from itertools import cycle
from typing import List, Optional, Tuple, Union
from typing import Dict, List, Optional, Tuple, Union

import pytest
import ray
import torch
from pynvml import (nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo,
                    nvmlInit)

from tests.conftest import cleanup
from vllm import LLM
@@ -13,7 +17,7 @@ from vllm.lora.request import LoRARequest
from vllm.model_executor.utils import set_random_seed
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
from vllm.sequence import MultiModalData
from vllm.sequence import Logprob, MultiModalData
from vllm.usage.usage_lib import UsageContext
from vllm.utils import Counter, random_uuid

@@ -153,12 +157,19 @@ def create_llm_generator(baseline_or_test, request, common_llm_kwargs,
    test_name = request.node.name

    def generator_inner():
        print(f'Creating {baseline_or_test=} LLM for {test_name=}. {kwargs=}')

        wait_for_gpu_memory_to_clear(
            devices=list(range(torch.cuda.device_count())),
            threshold_bytes=2 * 2**30,
            timeout_s=60,
        )

        use_async = False
        if "use_async" in kwargs:
            use_async = kwargs.pop("use_async")
        print(f'{use_async=}')

        print(f'Creating {baseline_or_test=} LLM for {test_name=}. {kwargs=}')
        llm = AsyncLLM(**kwargs) if use_async else LLM(**kwargs)
        set_random_seed(seed)

@@ -188,6 +199,20 @@ def get_output_from_llm_generator(
    return tokens, token_ids


def get_logprobs_from_llm_generator(
        llm_generator, prompts,
        sampling_params) -> List[List[Dict[int, Logprob]]]:
    """Returns a dict of (token_id: Logprob) for each generated position, for
    each sequence in the batch.
    """
    for llm in llm_generator():
        outputs = llm.generate(prompts, sampling_params, use_tqdm=True)
        logprobs = [output.outputs[0].logprobs[:] for output in outputs]
        del llm

    return logprobs


def run_greedy_equality_correctness_test(baseline_llm_generator,
                                         test_llm_generator,
                                         batch_size,
@@ -243,3 +268,38 @@ def run_greedy_equality_correctness_test(baseline_llm_generator,
        print(f'{i=} {baseline_token_ids=}')
        print(f'{i=} {spec_token_ids=}')
        assert baseline_token_ids == spec_token_ids


def wait_for_gpu_memory_to_clear(devices: List[int],
                                 threshold_bytes: int,
                                 timeout_s: float = 120) -> None:
    # Use nvml instead of pytorch to reduce measurement error from torch cuda
    # context.
    nvmlInit()
    start_time = time.time()
    while True:
        output = {}
        output_raw = {}
        for device in devices:
            dev_handle = nvmlDeviceGetHandleByIndex(device)
            mem_info = nvmlDeviceGetMemoryInfo(dev_handle)
            gb_used = mem_info.used / 2**30
            output_raw[device] = gb_used
            output[device] = f'{gb_used:.02f}'

        print('gpu memory used (GB): ', end='')
        for k, v in output.items():
            print(f'{k}={v}; ', end='')
        print('')

        dur_s = time.time() - start_time
        if all(v <= (threshold_bytes / 2**30) for v in output_raw.values()):
            print(f'Done waiting for free GPU memory on devices {devices=} '
                  f'({threshold_bytes/2**30=}) {dur_s=:.02f}')
            break

        if dur_s >= timeout_s:
            raise ValueError(f'Memory of devices {devices=} not free after '
                             f'{dur_s=:.02f} ({threshold_bytes/2**30=})')

        time.sleep(5)

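For orientation, a minimal sketch (not part of this commit) of how a caller can walk the structure returned by get_logprobs_from_llm_generator above; it assumes only that each generated position maps token ids to Logprob objects carrying .logprob and .rank, as in the helper's type hint.

# Illustrative sketch: batch_logprobs is List[List[Dict[int, Logprob]]] --
# one list per sequence, one dict per generated position, keyed by token id.
def print_top_tokens(batch_logprobs):
    for seq_index, seq_logprobs in enumerate(batch_logprobs):
        for pos, pos_logprobs in enumerate(seq_logprobs):
            # The rank-1 entry is the highest-probability token at this position.
            top_token_id = min(pos_logprobs, key=lambda t: pos_logprobs[t].rank)
            print(f"seq={seq_index} pos={pos} top={top_token_id} "
                  f"logprob={pos_logprobs[top_token_id].logprob:.3f}")
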
tests/spec_decode/e2e/test_logprobs.py (new file, 335 lines)
@@ -0,0 +1,335 @@
import math
from itertools import cycle

import pytest

from vllm import SamplingParams

from .conftest import get_logprobs_from_llm_generator


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "model": "JackFram/llama-68m",

        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Required for spec decode.
        "use_v2_block_manager": True,
        "max_logprobs": 6,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{
    "speculative_model": "JackFram/llama-160m",
    "num_speculative_tokens": 3,
}])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize(
    "output_len",
    [
        # Use smaller output len for fast test.
        7,
    ])
@pytest.mark.parametrize("seed", [1])
def test_logprobs_equality(baseline_llm_generator, test_llm_generator,
                           batch_size: int, output_len: int):
    """Verify output logprobs are equal with and without speculative decoding.
    """
    run_greedy_logprobs_correctness_test(baseline_llm_generator,
                                         test_llm_generator,
                                         batch_size,
                                         max_output_len=output_len,
                                         force_output_len=True)


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "model": "JackFram/llama-68m",

        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Required for spec decode.
        "use_v2_block_manager": True,
        "max_logprobs": 6,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{
    "speculative_model": "JackFram/llama-160m",
    "num_speculative_tokens": 3,
}])
@pytest.mark.parametrize("batch_size", [1])
@pytest.mark.parametrize("num_logprobs", [6])
@pytest.mark.parametrize(
    "output_len",
    [
        # Use smaller output len for fast test.
        7,
    ])
@pytest.mark.parametrize("seed", [1])
def test_diff_num_logprobs(baseline_llm_generator, test_llm_generator,
                           batch_size: int, output_len: int,
                           num_logprobs: int):
    """Verify output logprobs are equal with and without spec decode.
    This specifies a number of logprobs >1.
    """
    run_greedy_logprobs_correctness_test(baseline_llm_generator,
                                         test_llm_generator,
                                         batch_size,
                                         max_output_len=output_len,
                                         force_output_len=True,
                                         logprob_rank=num_logprobs)


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "model": "JackFram/llama-68m",

        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Required for spec decode.
        "use_v2_block_manager": True
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{
    "speculative_model": "JackFram/llama-160m",
    "num_speculative_tokens": 3,
}, {
    "speculative_model": "JackFram/llama-160m",
    "num_speculative_tokens": 6,
}])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize(
    "output_len",
    [
        # Use smaller output len for fast test.
        32,
    ])
@pytest.mark.parametrize("seed", [1])
def test_logprobs_different_k(baseline_llm_generator, test_llm_generator,
                              batch_size: int, output_len: int):
    """Verify logprob greedy equality with different speculation lens.
    """
    run_greedy_logprobs_correctness_test(baseline_llm_generator,
                                         test_llm_generator,
                                         batch_size,
                                         max_output_len=output_len,
                                         force_output_len=True)


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "model": "JackFram/llama-68m",

        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Required for spec decode.
        "use_v2_block_manager": True
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
    "test_llm_kwargs",
    [{
        "speculative_model": "JackFram/llama-160m",
        "num_speculative_tokens": 3,

        # Artificially limit the draft model max model len; this forces vLLM
        # to skip speculation once the sequences grow beyond 32-k tokens.
        "speculative_max_model_len": 32,
    }])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize(
    "output_len",
    [
        # Use smaller output len for fast test.
        32,
    ])
@pytest.mark.parametrize("seed", [1])
def test_logprobs_when_skip_speculation(baseline_llm_generator,
                                        test_llm_generator, batch_size: int,
                                        output_len: int):
    """Verify logprobs greedy equality when some sequences skip speculation.
    """
    run_greedy_logprobs_correctness_test(baseline_llm_generator,
                                         test_llm_generator,
                                         batch_size,
                                         max_output_len=output_len,
                                         force_output_len=True)


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "model": "JackFram/llama-68m",

        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Required for spec decode.
        "use_v2_block_manager": True
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{
    "speculative_model": "JackFram/llama-160m",
    "num_speculative_tokens": 3,
}])
@pytest.mark.parametrize("batch_size", [1])
@pytest.mark.parametrize(
    "output_len",
    [
        # Use smaller output len for fast test.
        32,
    ])
@pytest.mark.parametrize("seed", [1])
def test_logprobs_temp_1(baseline_llm_generator, test_llm_generator,
                         batch_size: int, output_len: int):
    """Verify at least one logprob result has num_logprobs+1, which tests the
    case where the sampled token is not in top-k logprobs.

    Ideally, this test should validate equality with non-spec by getting
    logprobs. This is left as future improvement.
    """
    batch_size = 8
    max_output_len = output_len
    force_output_len = True
    logprob_rank = 5

    temperature = 1.0

    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
        "San Francisco is know for its",
        "Facebook was created in 2004 by",
        "Curious George is a",
        "Python 3.11 brings improvements to its",
    ]

    prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))]

    # If the test requires that we generate max_output_len tokens, then set the
    # sampling params to ignore eos token.
    ignore_eos = force_output_len

    sampling_params = SamplingParams(
        max_tokens=max_output_len,
        ignore_eos=ignore_eos,
        temperature=temperature,
        logprobs=logprob_rank,
    )

    spec_batch_logprobs = get_logprobs_from_llm_generator(
        test_llm_generator, prompts, sampling_params)

    num_returned_logprobs = [
        len(logprob_dict) for seq_logprobs in spec_batch_logprobs
        for logprob_dict in seq_logprobs
    ]

    # Assert one of the returned logprobs has > num_logprobs (indicating the
    # sampled token is not in top-k).
    assert any([
        num_returned > logprob_rank for num_returned in num_returned_logprobs
    ])


def run_greedy_logprobs_correctness_test(baseline_llm_generator,
                                         test_llm_generator,
                                         batch_size,
                                         max_output_len,
                                         force_output_len: bool,
                                         logprob_rank: int = 1):
    """Helper method that compares the logprobs outputs of both the baseline LLM
    and the test LLM. It asserts greedy equality of the logprobs when the
    temperature is zero.
    """
    temperature = 0.0

    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
        "San Francisco is know for its",
        "Facebook was created in 2004 by",
        "Curious George is a",
        "Python 3.11 brings improvements to its",
    ]

    prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))]

    # If the test requires that we generate max_output_len tokens, then set the
    # sampling params to ignore eos token.
    ignore_eos = force_output_len

    sampling_params = SamplingParams(
        max_tokens=max_output_len,
        ignore_eos=ignore_eos,
        temperature=temperature,
        logprobs=logprob_rank,
    )

    spec_batch_logprobs = get_logprobs_from_llm_generator(
        test_llm_generator, prompts, sampling_params)
    baseline_batch_logprobs = get_logprobs_from_llm_generator(
        baseline_llm_generator, prompts, sampling_params)

    assert len(baseline_batch_logprobs) == len(prompts)
    assert len(spec_batch_logprobs) == len(prompts)

    # For each sequence in the batch.
    for i, (baseline_logprobs, spec_logprobs) in enumerate(
            zip(baseline_batch_logprobs, spec_batch_logprobs)):
        assert len(spec_logprobs) == len(baseline_logprobs)

        # For each generated position of the sequence.
        for pos, (spec_pos_logprobs, baseline_pos_logprobs) in enumerate(
                zip(spec_logprobs, baseline_logprobs)):

            # Map rank to token/logprob in spec output.
            spec_rank_to_token_id = {
                value.rank: key
                for key, value in spec_pos_logprobs.items()
            }
            spec_rank_to_logprob = {
                value.rank: value.logprob
                for key, value in spec_pos_logprobs.items()
            }

            # Map rank to token/logprob in baseline output.
            baseline_rank_to_token_id = {
                value.rank: key
                for key, value in baseline_pos_logprobs.items()
            }
            baseline_rank_to_logprob = {
                value.rank: value.logprob
                for key, value in baseline_pos_logprobs.items()
            }

            # Assert set of ranks returned is equal.
            assert set(spec_rank_to_token_id.keys()) == set(
                baseline_rank_to_token_id.keys())

            # Assert each logprob/token id is correct, keyed by rank.
            for rank in sorted(set(spec_rank_to_token_id.keys())):
                assert spec_rank_to_token_id[
                    rank] == baseline_rank_to_token_id[rank], f"{rank}"
                assert math.isclose(
                    a=spec_rank_to_logprob[rank],
                    b=baseline_rank_to_logprob[rank],
                    abs_tol=1e-1,
                )
@@ -41,24 +41,17 @@ from .conftest import (get_output_from_llm_generator,

@pytest.mark.parametrize(
    "common_llm_kwargs",
    [
        {
            # Use a small model for a fast test.
            # Note this is repeated in the test body; to initialize a tokenizer.
            "model": "JackFram/llama-68m",
    [{
        # Use a small model for a fast test.
        # Note this is repeated in the test body; to initialize a tokenizer.
        "model": "JackFram/llama-68m",

            # Skip cuda graph recording for fast test.
            "enforce_eager": True,
        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

            # Required for spec decode.
            "use_v2_block_manager": True,

            # whether use AsyncLLM engine
            "use_async": async_mode,
        }
        # Try both async and sync engine execution
        for async_mode in [True, False]
    ])
        # Required for spec decode.
        "use_v2_block_manager": True,
    }])
@pytest.mark.parametrize(
    "per_test_common_llm_kwargs",
    [
@@ -117,6 +110,44 @@ def test_spec_decode_e2e_with_detokenization(test_llm_generator,
        assert actual_tokens.strip() == expected_tokens.strip()


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        # Use a small model for a fast test.
        # Note this is repeated in the test body; to initialize a tokenizer.
        "model": "JackFram/llama-68m",

        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Required for spec decode.
        "use_v2_block_manager": True,

        # Use AsyncLLM engine
        "use_async": True,
    }])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [
    {
        "speculative_model": "JackFram/llama-68m",
        "num_speculative_tokens": 5,
    },
])
@pytest.mark.parametrize("test_llm_kwargs", [{}])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_e2e_with_async_engine(test_llm_generator,
                                           baseline_llm_generator,
                                           batch_size: int):
    """Verify spec decode works well with async LLM engine.
    """
    run_greedy_equality_correctness_test(baseline_llm_generator,
                                         test_llm_generator,
                                         batch_size,
                                         max_output_len=32,
                                         force_output_len=True)


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{

@@ -292,6 +292,10 @@ def test_draft_proposals_full_speculation_len():
                                           vocab_size,
                                           device=device,
                                           dtype=torch.float32),
            logprobs=torch.rand(batch_size,
                                vocab_size,
                                device=device,
                                dtype=torch.float32),
            sampled_token_ids=torch.randint(low=0,
                                            high=vocab_size,
                                            size=(batch_size, ),
@@ -392,6 +396,10 @@ def test_draft_proposals_mixed_k():
                                           vocab_size,
                                           device=device,
                                           dtype=torch.float32),
            logprobs=torch.rand(expected_num_proposal_seqs,
                                vocab_size,
                                device=device,
                                dtype=torch.float32),
            sampled_token_ids=torch.randint(
                low=0,
                high=vocab_size,

@@ -192,8 +192,14 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int):
                                    vocab_size,
                                    dtype=torch.float32,
                                    device='cuda')
    target_token_logprobs = torch.rand(1,
                                       batch_size * (k + 1),
                                       vocab_size,
                                       dtype=torch.float32,
                                       device='cuda')
    target_output = create_sampler_output_list(target_token_ids,
                                               target_token_probs)
                                               target_token_probs,
                                               target_token_logprobs)

    target_worker.execute_model.return_value = [target_output[0]]

@@ -273,8 +279,14 @@ def test_correctly_formats_output(k: int, batch_size: int):
                                    vocab_size,
                                    dtype=torch.float32,
                                    device='cuda')
    target_token_logprobs = torch.rand(1,
                                       batch_size * (k + 1),
                                       vocab_size,
                                       dtype=torch.float32,
                                       device='cuda')
    target_output = create_sampler_output_list(target_token_ids,
                                               target_token_probs)
                                               target_token_probs,
                                               target_token_logprobs)

    target_worker.execute_model.return_value = [target_output[0]]

@@ -294,7 +306,9 @@ def test_correctly_formats_output(k: int, batch_size: int):
                                        num_lookahead_slots=k)

    expected_output = create_sampler_output_list(
        rejection_sampler_output.transpose(0, 1), [None for _ in range(k + 1)])
        token_ids=rejection_sampler_output.transpose(0, 1),
        probs=[None for _ in range(k + 1)],
        logprobs=[None for _ in range(k + 1)])

    seq_ids = [
        next(iter(seq_group_metadata.seq_data.keys()))
@@ -328,7 +342,6 @@ def test_correctly_formats_output(k: int, batch_size: int):
            continue
        assert actual_by_step[i].output_token == expected_by_step[
            i].output_token
        assert actual_by_step[i].logprobs == expected_by_step[i].logprobs


@pytest.mark.parametrize('k', [1, 2])
@@ -387,8 +400,14 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool):
                                    vocab_size,
                                    dtype=torch.float32,
                                    device='cuda')
    target_token_logprobs = torch.rand(1,
                                       batch_size * (k + 1),
                                       vocab_size,
                                       dtype=torch.float32,
                                       device='cuda')
    target_output = create_sampler_output_list(target_token_ids,
                                               target_token_probs)
                                               target_token_probs,
                                               target_token_logprobs)

    target_worker.execute_model.return_value = [target_output[0]]

@@ -201,6 +201,7 @@ def assert_logprobs_dict_allclose(
def create_sampler_output_list(
        token_ids: torch.Tensor,
        probs: Iterable[Optional[torch.Tensor]],
        logprobs: Iterable[Optional[torch.Tensor]],
        seq_ids: Optional[List[int]] = None) -> List[SamplerOutput]:
    num_steps, batch_size = token_ids.shape
    token_ids_by_step = token_ids.tolist()
@@ -222,6 +223,7 @@ def create_sampler_output_list(
                ) for seq_index, token_id in enumerate(token_ids_by_step[step])
            ],
            sampled_token_probs=probs[step],
            logprobs=logprobs[step],
            sampled_token_ids=token_ids[step])
        for step in range(num_steps)
    ]

@@ -1,3 +1,4 @@
import functools
from typing import Callable, List

from transformers import PreTrainedTokenizer
@@ -8,8 +9,8 @@ from vllm.engine.output_processor.interfaces import (
from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.logger import init_logger
from vllm.sampling_params import SamplingParams
from vllm.sequence import (Logprob, Sequence, SequenceGroup,
                           SequenceGroupOutput, SequenceOutput, SequenceStatus)
from vllm.sequence import (Sequence, SequenceGroup, SequenceGroupOutput,
                           SequenceOutput, SequenceStatus)
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.utils import Counter

@@ -48,10 +49,14 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
                               outputs: List[SequenceGroupOutput]) -> None:
        # TODO(sang): Prompt logprob currently not implemented in multi step
        # workers.
        self._log_prompt_logprob_unsupported_warning_once()

    @staticmethod
    @functools.lru_cache()
    def _log_prompt_logprob_unsupported_warning_once():
        logger.warning(
            "Prompt logprob is not supported by multi step workers. "
            "(e.g., speculative decode uses multi step workers).")
        pass

    def process_outputs(self, sequence_group: SequenceGroup,
                        outputs: List[SequenceGroupOutput]) -> None:
@@ -89,6 +94,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
                                 valid_samples: List[SequenceOutput],
                                 sampling_params: SamplingParams) -> None:
        output_token_ids = [sample.output_token for sample in valid_samples]
        output_logprobs = [sample.logprobs for sample in valid_samples]

        # Truncate to max_tokens if necessary.
        remaining_tokens = sampling_params.max_tokens - (seq.get_output_len() +
@@ -113,11 +119,11 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):

        # Incrementally append tokens to the sequence, as if we had only one new
        # token.
        for output_token_id in output_token_ids:
        for output_token_id, output_logprob in zip(output_token_ids,
                                                   output_logprobs):
            seq.append_token_id(
                token_id=output_token_id,
                # TODO emit logprobs in multi-step decoding.
                logprobs={output_token_id: Logprob(0.0)},
                logprobs=output_logprob,
            )

            new_char_count = 0

@@ -103,8 +103,7 @@ class Sampler(nn.Module):

        if self.include_gpu_probs_tensor:
            assert maybe_sampled_tokens_tensor is not None
            sampled_tokens_tensor = maybe_sampled_tokens_tensor
            on_device_tensors = (probs, sampled_tokens_tensor)
            on_device_tensors = (probs, logprobs, maybe_sampled_tokens_tensor)
        else:
            on_device_tensors = None

@@ -965,8 +964,7 @@ def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor,
    has implications on the overall design of the sampler, e.g. how to record
    accurate logprobs for the user, so this improvement is deferred to later.
    """
    logprobs[sample_indices, :] = -float('inf')
    logprobs[sample_indices, greedy_samples] = 0.0
    # NOTE: logprobs are not modified so they can be returned to the user.
    probs[sample_indices, :] = 0
    probs[sample_indices, greedy_samples] = 1.0

@@ -976,7 +974,8 @@ def _build_sampler_output(
    sampling_metadata: SamplingMetadata,
    prompt_logprobs: List[Optional[PromptLogprobs]],
    sample_logprobs: List[SampleLogprobs],
    on_device_tensors: Optional[Tuple[torch.Tensor, torch.Tensor]],
    on_device_tensors: Optional[Tuple[torch.Tensor, torch.Tensor,
                                      torch.Tensor]],
) -> SamplerOutput:
    """Construct Python objects with the output of sampling.

@@ -1005,14 +1004,17 @@ def _build_sampler_output(

    # If not specified, store None values in SamplerOutput.
    if on_device_tensors is not None:
        sampled_token_probs, sampled_token_ids = on_device_tensors
        (sampled_token_probs, logprobs_tensor,
         sampled_token_ids) = on_device_tensors
    else:
        sampled_token_probs, sampled_token_ids = (None, None)
        sampled_token_probs, logprobs_tensor, sampled_token_ids = (None, None,
                                                                   None)

    return SamplerOutput(
        outputs=sampler_output,
        sampled_token_probs=sampled_token_probs,
        sampled_token_ids=sampled_token_ids,
        logprobs=logprobs_tensor,
    )


@@ -700,6 +700,9 @@ class SamplerOutput:
    # On-device tensor containing probabilities of each token.
    sampled_token_probs: Optional["torch.Tensor"] = None

    # On-device tensor containing the logprobs of each token.
    logprobs: Optional["torch.Tensor"] = None

    # On-device tensor containing the sampled token ids.
    sampled_token_ids: Optional["torch.Tensor"] = None

@@ -94,7 +94,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
        assert len(target_sampler_output) == 1, "expected single-step output"
        target_sampler_output = target_sampler_output[0]

        all_tokens, all_probs = self._contract_batch(
        all_tokens, all_probs, spec_logprobs = self._contract_batch(
            contracted_bs=len(seq_group_metadata_list),
            target_sampler_output=target_sampler_output,
            proposals=proposals,
@@ -107,6 +107,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
        return SpeculativeScores(
            probs=all_probs,
            token_ids=all_tokens,
            logprobs=spec_logprobs,
        )

    def _expand_batch(
@@ -148,12 +149,12 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
        return (spec_indices, non_spec_indices, target_seq_group_metadata_list,
                num_scoring_tokens)

    def _contract_batch(self, contracted_bs: int,
                        target_sampler_output: List[SamplerOutput],
                        proposals: SpeculativeProposals,
                        num_scoring_tokens: int, non_spec_indices: List[int],
                        spec_indices: List[int],
                        k: int) -> Tuple[torch.Tensor, torch.Tensor]:
    def _contract_batch(
            self, contracted_bs: int,
            target_sampler_output: List[SamplerOutput],
            proposals: SpeculativeProposals, num_scoring_tokens: int,
            non_spec_indices: List[int], spec_indices: List[int],
            k: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Contract the expanded batch back into its original size.
        This maps the scores of speculative tokens back to their original
        sequences.
@@ -161,8 +162,9 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
        contracted_bs is the original batch size, and the batch size that the
        target_sampler_output will be contracted to.
        """
        (target_token_ids, target_probs, non_spec_target_token_ids,
         non_spec_target_probs) = self._split_scoring_output(
        (target_token_ids, target_probs, target_logprobs,
         non_spec_target_token_ids, non_spec_target_probs,
         non_spec_target_logprobs) = self._split_scoring_output(
             target_sampler_output, num_scoring_tokens)

        # Map distinct sequences used to score each token
@@ -179,6 +181,8 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
            spec_expanded_bs, k + 1)
        target_probs = target_probs.squeeze().reshape(spec_expanded_bs, k + 1,
                                                      self._vocab_size)
        target_logprobs = target_logprobs.squeeze().reshape(
            spec_expanded_bs, k + 1, self._vocab_size)

        all_tokens = torch.full(size=(contracted_bs, k + 1),
                                fill_value=-1,
@@ -189,16 +193,26 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
                                self._vocab_size,
                                device=self._device,
                                dtype=torch.float32)
        all_logprobs = torch.full(size=(
            contracted_bs,
            k + 1,
            self._vocab_size,
        ),
                                  fill_value=-float("inf"),
                                  device=self._device,
                                  dtype=torch.float32)

        if non_spec_indices:
            all_tokens[non_spec_indices, :1] = non_spec_target_token_ids
            all_probs[non_spec_indices, :1, :] = non_spec_target_probs
            all_logprobs[non_spec_indices, :1, :] = non_spec_target_logprobs

        if spec_indices:
            all_tokens[spec_indices] = target_token_ids
            all_probs[spec_indices] = target_probs
            all_logprobs[spec_indices] = target_logprobs

        return all_tokens, all_probs
        return all_tokens, all_probs, all_logprobs

    def _create_scoring_model_input(
            self,
@@ -308,7 +322,8 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):

    def _split_scoring_output(
            self, sampler_output: SamplerOutput, num_scoring_tokens: int
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
               torch.Tensor, torch.Tensor]:
        """Split the target model output into speculative and non-speculative
        output.
        """
@@ -328,21 +343,29 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
         ) = sampler_output.sampled_token_probs.split(split_sizes)
        (spec_sampled_tokens, non_spec_sampled_tokens
         ) = sampler_output.sampled_token_ids.flatten().split(split_sizes)
        (
            spec_logprobs,
            non_spec_logprobs,
        ) = sampler_output.logprobs.split(split_sizes)

        # Convert scores to tensors.
        sampler_output.sampled_token_probs = spec_probs
        sampler_output.sampled_token_ids = spec_sampled_tokens
        target_token_ids, target_probs = sampler_output_to_torch(
            [sampler_output], True)
        sampler_output.logprobs = spec_logprobs
        (target_token_ids, target_probs,
         target_logprobs) = sampler_output_to_torch([sampler_output], True)

        # Convert non-speculative output tokens to tensors.
        sampler_output.sampled_token_probs = non_spec_probs
        sampler_output.sampled_token_ids = non_spec_sampled_tokens
        non_spec_target_token_ids, non_spec_target_probs = (
            sampler_output_to_torch([sampler_output], True))
        sampler_output.logprobs = non_spec_logprobs
        (non_spec_target_token_ids, non_spec_target_probs,
         non_spec_target_logprobs) = sampler_output_to_torch([sampler_output],
                                                             True)

        return (target_token_ids, target_probs, non_spec_target_token_ids,
                non_spec_target_probs)
        return (target_token_ids, target_probs, target_logprobs,
                non_spec_target_token_ids, non_spec_target_probs,
                non_spec_target_logprobs)

    def _create_target_seq_id_iterator(
            self, seq_ids: List[SeqId]) -> Iterator[TargetSeqId]:

@@ -38,6 +38,11 @@ class SpeculativeScores:
    # Probabilities of the speculative tokens according to the scoring model.
    probs: torch.Tensor

    # Log-probabilities of the speculative tokens according to the scoring
    # model. These values can be used to generate Logprob objects that are
    # returned to the user.
    logprobs: torch.Tensor

    # Token ids sampled from the scoring model. Used for speculative bonus
    # tokens and also non-speculative normal decoding.
    token_ids: torch.Tensor

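For reference, a hedged sketch (not part of this commit) of the tensor shapes a SpeculativeScores instance carries after this change; the shapes are assumed from how _contract_batch builds the tensors in the batch_expansion hunk above.

# Illustrative shapes only, assumed from _contract_batch above.
import torch

batch_size, k, vocab_size = 4, 3, 32000
token_ids = torch.full((batch_size, k + 1), -1, dtype=torch.long)      # [batch, k+1]
probs = torch.zeros(batch_size, k + 1, vocab_size)                     # [batch, k+1, vocab]
logprobs = torch.full((batch_size, k + 1, vocab_size), -float("inf"))  # [batch, k+1, vocab]
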
@@ -140,11 +140,17 @@ class NGramWorker(LoraNotSupportedWorkerBase):
            device=self.device,
        )
        token_probs.scatter_(2, indices, 1)
        token_logprobs = torch.zeros(
            (len(seq_group_metadata_list), sample_len, self.vocab_size),
            dtype=torch.float32,
            device=self.device,
        )
        for i in range(len(seq_group_metadata_list)):
            outputs.append(
                SamplerOutput(
                    outputs=None,
                    sampled_token_probs=token_probs[i],
                    logprobs=token_logprobs,
                    sampled_token_ids=token_ids[i],
                ))
        return outputs, False

@@ -5,15 +5,16 @@ import torch

from vllm.logger import init_logger
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
from vllm.sequence import (Logprob, SamplerOutput, SequenceGroupMetadata,
                           SequenceGroupOutput, SequenceOutput)
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
from vllm.spec_decode.interfaces import (SpeculativeProposals,
                                         SpeculativeScorer, SpeculativeScores)
from vllm.spec_decode.metrics import AsyncMetricsCollector
from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.ngram_worker import NGramWorker
from vllm.spec_decode.util import (get_all_seq_ids, nvtx_range,
from vllm.spec_decode.util import (create_sequence_group_output,
                                   get_all_num_logprobs, get_all_seq_ids,
                                   get_sampled_token_logprobs, nvtx_range,
                                   split_batch_by_proposal_len)
from vllm.worker.worker_base import LoraNotSupportedWorkerBase, WorkerBase

@@ -258,6 +259,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
        # overhead when the engine runs in a different process than the workers.
        sampler_output.probs = None
        sampler_output.sampled_tokens = None
        sampler_output.logprobs = None
        return [sampler_output]

    @nvtx_range("spec_decode_worker._run_speculative_decoding_step")
@@ -298,12 +300,15 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
        )

        #logger.info("verify proposals")
        accepted_token_ids = self._verify_tokens(seq_group_metadata_list,
                                                 proposal_scores, proposals, k)
        accepted_token_ids, target_logprobs = self._verify_tokens(
            seq_group_metadata_list, proposal_scores, proposals, k)

        #logger.info("create output list")
        return self._create_output_sampler_list(seq_group_metadata_list,
                                                accepted_token_ids, k)
        return self._create_output_sampler_list(
            seq_group_metadata_list,
            accepted_token_ids,
            target_logprobs=target_logprobs,
            k=k)

    @nvtx_range("spec_decode_worker._verify_tokens")
    def _verify_tokens(
@@ -312,9 +317,12 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
        proposal_scores: SpeculativeScores,
        proposals: SpeculativeProposals,
        max_proposal_len: int,
    ) -> torch.Tensor:
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Determine which speculative tokens are accepted using the
        probabilities of each token according to the proposer and scorer models.

        Returns a tuple of Tensors, one for the accepted token ids and one for
        the logprobs according to the scoring model.
        """
        proposal_lens_list = proposals.proposal_lens.tolist()

@@ -361,17 +369,19 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
        non_spec_token_ids[:, 1:] = -1
        accepted_token_ids = torch.cat(
            [accepted_token_ids, non_spec_token_ids])
        logprobs = proposal_scores.logprobs

        # Rearrange so that results are in the order of the original seq group
        # metadata.
        accepted_token_ids[original_indices] = accepted_token_ids.clone()

        return accepted_token_ids
        return accepted_token_ids, logprobs

    def _create_output_sampler_list(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        accepted_token_ids: torch.Tensor,  # shape: [batch_size, k+1]
        target_logprobs: torch.Tensor,  # shape: [batch_size, k+1, vocab_size]
        k: int,
    ) -> List[SamplerOutput]:
        """Given the accepted token ids, create a list of SamplerOutput.
@@ -379,30 +389,68 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
        The output is padded with -1 tokens such that each sequence has
        the same number of outputs.
        """
        seq_ids = get_all_seq_ids(seq_group_metadata_list)
        batch_size, num_steps = accepted_token_ids.shape

        # shape: [k+1, batch_size]
        accepted_token_ids_by_step = accepted_token_ids.transpose(0,
                                                                  1).tolist()
        # Organize input tensors by step instead of by sequence.
        target_logprobs_by_step = target_logprobs.transpose(0, 1)
        accepted_token_ids_by_step = accepted_token_ids.transpose(0, 1)

        # Get the logprobs/rank of the accepted tokens.
        (accepted_token_id_ranks_by_step,
         accepted_token_id_logprobs_by_step) = get_sampled_token_logprobs(
             logprob_tensor=target_logprobs_by_step,
             sampled_token_ids=accepted_token_ids_by_step,
         )

        # Get the top-k logprobs (which may or may not include the logprob of
        # the accepted token).
        (topk_logprobs_by_step,
         topk_indices_by_step) = target_logprobs_by_step.topk(
             k=self.scorer_worker.model_config.max_logprobs,
             dim=-1,
         )

        # Get the sequence ids and num_logprobs (sampling parameter) in the
        # batch.
        seq_ids = get_all_seq_ids(seq_group_metadata_list)
        num_logprobs_per_seq = get_all_num_logprobs(seq_group_metadata_list)

        # Serialize all tensors to CPU Python lists.
        accepted_token_ids_by_step = accepted_token_ids_by_step.tolist()
        accepted_token_id_ranks_by_step = (
            accepted_token_id_ranks_by_step.tolist())
        accepted_token_id_logprobs_by_step = (
            accepted_token_id_logprobs_by_step.tolist())
        topk_logprobs_by_step = topk_logprobs_by_step.tolist()
        topk_indices_by_step = topk_indices_by_step.tolist()

        # Construct the output on a per-step, per-sequence basis.
        sampler_output_list = []
        for token_ids_by_step in accepted_token_ids_by_step:
            if all(token_id == -1 for token_id in token_ids_by_step):
        for step_index in range(num_steps):
            if all(token_id == -1
                   for token_id in accepted_token_ids_by_step[step_index]):
                break

            step_output_token_ids = []
            for token_id, seq_id in zip(token_ids_by_step, seq_ids):
            for sequence_index in range(batch_size):
                # Each sequence may have a different num_logprobs; retrieve it.
                num_logprobs = num_logprobs_per_seq[sequence_index]

                step_output_token_ids.append(
                    SequenceGroupOutput(
                        samples=[
                            SequenceOutput(
                                parent_seq_id=seq_id,
                                output_token=token_id,
                                # TODO Add verifier logprobs.
                                logprobs={token_id: Logprob(0.0)},
                            )
                        ],
                        prompt_logprobs=None,
                    create_sequence_group_output(
                        token_id=accepted_token_ids_by_step[step_index]
                        [sequence_index],
                        token_id_logprob_rank=accepted_token_id_ranks_by_step[
                            step_index][sequence_index],
                        token_id_logprob=accepted_token_id_logprobs_by_step[
                            step_index][sequence_index],
                        seq_id=seq_ids[sequence_index],
                        topk_token_ids=topk_indices_by_step[step_index]
                        [sequence_index][:num_logprobs],
                        topk_logprobs=topk_logprobs_by_step[step_index]
                        [sequence_index][:num_logprobs],
                    ))

            sampler_output_list.append(
                SamplerOutput(outputs=step_output_token_ids))

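A small sketch (not part of this commit) of the top-k truncation idea used in _create_output_sampler_list above: top-k is computed once per step up to the model's max_logprobs, then sliced per sequence down to that sequence's requested num_logprobs. The max_logprobs value of 6 and the per-sequence counts below are hypothetical.

# Illustrative sketch of per-sequence top-k truncation.
import torch

max_logprobs = 6
step_logprobs = torch.randn(2, 32)       # [batch_size, vocab_size] for one step
topk_vals, topk_ids = step_logprobs.topk(k=max_logprobs, dim=-1)

num_logprobs_per_seq = [1, 5]            # per-sequence sampling_params.logprobs
for seq_index, num_logprobs in enumerate(num_logprobs_per_seq):
    ids = topk_ids[seq_index][:num_logprobs].tolist()
    vals = topk_vals[seq_index][:num_logprobs].tolist()
    # ids/vals are what would be packed into Logprob entries for this sequence.
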
@@ -166,7 +166,7 @@ class Top1Proposer(SpeculativeProposer):
            return proposal_tokens, proposal_probs, proposal_lens_tensor

        sampler_output = maybe_sampler_output
        proposal_tokens, proposal_probs = sampler_output_to_torch(
        proposal_tokens, proposal_probs, _ = sampler_output_to_torch(
            sampler_output, sampler_transposed)

        # Now, reformat the output GPU tensors such that each sequence has

@@ -1,10 +1,11 @@
from contextlib import contextmanager
from itertools import chain
from typing import List, Tuple
from typing import Dict, List, Tuple

import torch

from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.sequence import (Logprob, SamplerOutput, SequenceGroupMetadata,
                           SequenceGroupOutput, SequenceOutput)

SeqId = int

@@ -21,6 +22,89 @@ def get_all_seq_ids(
        ]))


def get_all_num_logprobs(
        seq_group_metadata_list: List[SequenceGroupMetadata]) -> List[int]:
    """Given a list of SequenceGroupMetadata, create a list of all num_logprobs.

    If the sampling params do not call for any logprobs, return 0 for that
    sequence.
    """

    all_num_logprobs = []
    for seq_group_metadata in seq_group_metadata_list:
        num_logprobs = seq_group_metadata.sampling_params.logprobs
        if seq_group_metadata.sampling_params.logprobs is None:
            num_logprobs = 0
        all_num_logprobs.append(num_logprobs)

    return all_num_logprobs


def get_sampled_token_logprobs(
        # shape [num_steps, batch_size, vocab_size]
        logprob_tensor: torch.Tensor,
        sampled_token_ids: torch.Tensor,  # shape [num_steps, batch_size]
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Get the logprobs for the sampled tokens. Returns the ranks and logprobs.
    """
    num_steps, batch_size, vocab_size = logprob_tensor.shape

    selected_logprobs = logprob_tensor[torch.arange(num_steps).unsqueeze(1),
                                       torch.arange(batch_size),
                                       sampled_token_ids, ]
    expanded_selected_logprobs = selected_logprobs.unsqueeze(-1).expand(
        -1, -1, vocab_size)
    sampled_token_ids_ranks = (logprob_tensor >=
                               expanded_selected_logprobs).sum(-1)

    return sampled_token_ids_ranks, selected_logprobs


def create_sequence_group_output(
        token_id: int,
        token_id_logprob_rank: int,
        token_id_logprob: float,
        seq_id: SeqId,
        topk_token_ids: List[int],
        topk_logprobs: List[float],
) -> SequenceGroupOutput:
    """Create a SequenceGroupOutput given the sampling results.

    Args:
        token_id (int): The sampled token for the sequence.
        token_id_logprob_rank (int): The logprob rank of the sampled token.
        token_id_logprob (float): The logprob value of the sampled token.
        seq_id (int): The sequence id.
        topk_token_ids (List[int]): The list of top-k token ids.
        topk_logprobs (List[float]): The list of top-k logprobs.
    """
    # vLLM logprobs always include the sampled token. In addition, the user may
    # request topk-logprobs (where top-k varies per user up to max_logprobs).
    logprobs: Dict[int, Logprob] = {
        token_id: Logprob(
            logprob=token_id_logprob,
            rank=token_id_logprob_rank,
        ),
    }
    logprobs.update({
        topk_token_ids[topk_logprob_index]: Logprob(
            logprob=topk_logprobs[topk_logprob_index],
            rank=topk_logprob_index + 1,
        )
        for topk_logprob_index, _ in enumerate(topk_token_ids)
    })

    return SequenceGroupOutput(
        samples=[
            SequenceOutput(parent_seq_id=seq_id,
                           output_token=token_id,
                           logprobs=logprobs)
        ],
        # TODO add prompt logprobs support.
        prompt_logprobs=None,
    )


def split_batch_by_proposal_len(
        seq_group_metadata_list: List[SequenceGroupMetadata],
        proposal_lens: List[int], select_proposal_len_zero: bool
@@ -49,8 +133,8 @@ def split_batch_by_proposal_len(


def sampler_output_to_torch(
        sampler_output_list: List[SamplerOutput],
        sampler_transposed: bool) -> Tuple[torch.Tensor, torch.Tensor]:
        sampler_output_list: List[SamplerOutput], sampler_transposed: bool
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Utility function which converts a list of SamplerOutput to tensors.

    sampler_transposed here is used as the indicator for whether
@@ -76,6 +160,15 @@ def sampler_output_to_torch(
    if sampler_transposed:
        sampled_token_probs = sampled_token_probs.transpose(0, 1)

    # shape: [batch_size, num_sampler_output, vocab_size]
    sampled_token_logprobs = torch.stack(
        [sampler_output.logprobs for sampler_output in sampler_output_list],
        dim=0,
    )

    if sampler_transposed:
        sampled_token_logprobs = sampled_token_logprobs.transpose(0, 1)

    # shape: [batch_size, num_sampler_output]
    sampled_token_ids = torch.stack(
        [
@@ -87,7 +180,7 @@ def sampler_output_to_torch(
    if sampler_transposed:
        sampled_token_ids = sampled_token_ids.transpose(0, 1)

    return sampled_token_ids, sampled_token_probs
    return sampled_token_ids, sampled_token_probs, sampled_token_logprobs


def maybe_mock_device_tensors(sampler_output: SamplerOutput, batch_size: int,

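A small worked example (not part of this commit) of the rank computation in get_sampled_token_logprobs above: the rank of the sampled token is the number of vocabulary entries whose logprob is greater than or equal to its own, so the argmax token gets rank 1. The tiny vocabulary below is purely illustrative.

# Illustrative sketch of the rank computation used by get_sampled_token_logprobs.
import torch

logprob_tensor = torch.log(torch.tensor([[[0.1, 0.7, 0.2]]]))  # [num_steps=1, batch=1, vocab=3]
sampled_token_ids = torch.tensor([[2]])                        # token 2 sampled

selected = logprob_tensor[torch.arange(1).unsqueeze(1),
                          torch.arange(1),
                          sampled_token_ids]                   # logprob of token 2
ranks = (logprob_tensor >= selected.unsqueeze(-1)).sum(-1)
print(ranks)  # tensor([[2]]): token 2 has the second-highest logprob
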