mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[Speculative decoding] Improve n-gram efficiency (#4724)
This commit is contained in:
@ -34,8 +34,8 @@ def test_ngram_algo_correctness_for_single_no_match():
|
||||
max_proposal_len=20,
|
||||
)
|
||||
|
||||
# set ngram window (0, 3], which is window=1/2/3
|
||||
ngram_worker.set_ngram_window_size(0, 3)
|
||||
# set ngram window [1, 3], which is window=1/2/3
|
||||
ngram_worker.set_ngram_window_size(1, 3)
|
||||
|
||||
prompts = [
|
||||
# shall find no candidate
|
||||
@ -90,8 +90,8 @@ def test_ngram_algo_correctness_for_batches_not_match_all():
|
||||
max_proposal_len=20,
|
||||
)
|
||||
|
||||
# set ngram window (0, 3], which is window=1/2/3
|
||||
ngram_worker.set_ngram_window_size(0, 3)
|
||||
# set ngram window [1, 3], which is window=1/2/3
|
||||
ngram_worker.set_ngram_window_size(1, 3)
|
||||
|
||||
prompts = [
|
||||
# shall find no candidate
|
||||
@ -128,11 +128,12 @@ def test_ngram_algo_correctness_for_batches_not_match_all():
|
||||
assert proposals.proposal_probs.shape[:-1] == torch.Size([5, proposal_len])
|
||||
assert proposals.proposal_lens.shape == torch.Size([5])
|
||||
|
||||
# the first sequence has no match so proposal_len should be overwritten to 0
|
||||
assert proposals.proposal_lens.tolist(
|
||||
) == [proposal_len for _ in range(4)] + [0]
|
||||
) == [0] + [proposal_len for _ in range(3)] + [0]
|
||||
|
||||
for i in range(proposal_len):
|
||||
assert proposals.proposal_token_ids[0][i] == 0
|
||||
assert proposals.proposal_token_ids[0][i] == -1
|
||||
assert proposals.proposal_token_ids[1][i] == prompts[1][i + 1]
|
||||
assert proposals.proposal_token_ids[2][i] == prompts[2][i + 3]
|
||||
assert proposals.proposal_token_ids[3][i] == prompts[3][i + 5]
|
||||
@ -167,8 +168,8 @@ def test_ngram_algo_correctness_for_batches_match_all():
|
||||
max_proposal_len=20,
|
||||
)
|
||||
|
||||
# set ngram window (0, 3], which is window=1/2/3
|
||||
ngram_worker.set_ngram_window_size(0, 3)
|
||||
# set ngram window [0, 3], which is window=1/2/3
|
||||
ngram_worker.set_ngram_window_size(1, 3)
|
||||
|
||||
prompts = [
|
||||
# shall find candidate 12,13,14,15,16
|
||||
|
@ -784,12 +784,15 @@ class SpeculativeConfig:
|
||||
draft_quantization = None
|
||||
|
||||
if speculative_model == "[ngram]":
|
||||
assert (ngram_prompt_lookup_max is not None
|
||||
and ngram_prompt_lookup_max > 0)
|
||||
if ngram_prompt_lookup_min is None:
|
||||
ngram_prompt_lookup_min = 0
|
||||
else:
|
||||
assert ngram_prompt_lookup_max > ngram_prompt_lookup_min
|
||||
ngram_prompt_lookup_min = 1
|
||||
if ngram_prompt_lookup_max is None or ngram_prompt_lookup_max < 1:
|
||||
raise ValueError(f"{ngram_prompt_lookup_max=} must be > 0")
|
||||
if ngram_prompt_lookup_min < 1:
|
||||
raise ValueError(f"{ngram_prompt_lookup_min=} must be > 0")
|
||||
if ngram_prompt_lookup_min > ngram_prompt_lookup_max:
|
||||
raise ValueError(f"{ngram_prompt_lookup_min=} cannot be "
|
||||
f"larger than {ngram_prompt_lookup_max=}")
|
||||
|
||||
# TODO: current we still need extract vocab_size from target model
|
||||
# config, in future, we may try refactor it out, and set
|
||||
|
@ -77,9 +77,11 @@ class NGramWorker(LoraNotSupportedWorkerBase):
|
||||
"""
|
||||
self._raise_if_unsupported(execute_model_req)
|
||||
|
||||
arr = []
|
||||
has_spec_out = False
|
||||
for seq_group_metadata in execute_model_req.seq_group_metadata_list:
|
||||
token_id_list = []
|
||||
token_prob_list = []
|
||||
for idx, seq_group_metadata in enumerate(
|
||||
execute_model_req.seq_group_metadata_list):
|
||||
seq_data = next(iter(seq_group_metadata.seq_data.values()))
|
||||
|
||||
input_ids = torch.as_tensor(seq_data.get_token_ids(),
|
||||
@ -89,59 +91,64 @@ class NGramWorker(LoraNotSupportedWorkerBase):
|
||||
|
||||
for ngram_size in range(
|
||||
min(self.ngram_prompt_lookup_max, input_length - 1),
|
||||
self.ngram_prompt_lookup_min,
|
||||
self.ngram_prompt_lookup_min - 1,
|
||||
-1,
|
||||
):
|
||||
ngram_tensor = input_ids[-1 * ngram_size:]
|
||||
windows = input_ids.unfold(dimension=0,
|
||||
size=ngram_size,
|
||||
step=1)
|
||||
matches = (windows == ngram_tensor).all(dim=1)
|
||||
match_indices = matches.nonzero(as_tuple=True)[0]
|
||||
if match_indices.size()[0] > 1:
|
||||
has_spec_out = True
|
||||
res = seq_data.get_token_ids()
|
||||
res = res[match_indices[0] + ngram_size:match_indices[0] +
|
||||
ngram_size + sample_len]
|
||||
res_len = len(res)
|
||||
# pad 0 towards output as sample_len tokens required
|
||||
res += [0] * (sample_len - res_len)
|
||||
ngram_tensor = input_ids[-ngram_size:]
|
||||
proposal_start_idx = None
|
||||
if ngram_size == 1:
|
||||
# Do not match itself and do not use unfold and all
|
||||
matches = (input_ids[:-1] == ngram_tensor)
|
||||
else:
|
||||
windows = input_ids.unfold(dimension=0,
|
||||
size=ngram_size,
|
||||
step=1)
|
||||
# Do not match itself
|
||||
matches = (windows[:-1] == ngram_tensor).all(dim=-1)
|
||||
|
||||
# first_match includes "values" (bool), indicating whether
|
||||
# the match is found, and "indices", indicating the index
|
||||
# of the first match.
|
||||
# Note that "first_match.values.item()" triggers GPU-CPU
|
||||
# sync so it is a bit inefficient, but we have not found
|
||||
# a better way to do this.
|
||||
first_match = matches.max(dim=-1)
|
||||
if first_match.values.item():
|
||||
proposal_start_idx = first_match.indices.add_(ngram_size)
|
||||
spec_indices = (
|
||||
proposal_start_idx).repeat(sample_len) + torch.arange(
|
||||
sample_len, device=self.device)
|
||||
spec_indices.clamp_(max=input_ids.shape[-1] - 1)
|
||||
res = input_ids.gather(dim=-1, index=spec_indices)
|
||||
token_id_list.append(res)
|
||||
token_prob_list.append(
|
||||
torch.nn.functional.one_hot(
|
||||
res,
|
||||
num_classes=self.vocab_size).to(torch.float32))
|
||||
has_spec_out = True
|
||||
break
|
||||
else:
|
||||
# if no candidate found, fill with 0
|
||||
res = [0] * sample_len
|
||||
|
||||
arr.append(res)
|
||||
token_id_list.append(None)
|
||||
token_prob_list.append(None)
|
||||
|
||||
if not has_spec_out:
|
||||
return None, False
|
||||
|
||||
outputs = []
|
||||
token_ids = torch.as_tensor(arr, dtype=torch.long, device=self.device)
|
||||
indices = token_ids.unsqueeze(2)
|
||||
outputs: List[Optional[SamplerOutput]] = []
|
||||
for idx in range(len(execute_model_req.seq_group_metadata_list)):
|
||||
if token_id_list[idx] is None:
|
||||
outputs.append(None)
|
||||
else:
|
||||
outputs.append(
|
||||
SamplerOutput(
|
||||
outputs=None,
|
||||
sampled_token_probs=token_prob_list[idx],
|
||||
logprobs=torch.zeros((sample_len, self.vocab_size),
|
||||
dtype=torch.float32,
|
||||
device=self.device),
|
||||
sampled_token_ids=token_id_list[idx],
|
||||
))
|
||||
|
||||
token_probs = torch.zeros(
|
||||
(len(execute_model_req.seq_group_metadata_list), sample_len,
|
||||
self.vocab_size),
|
||||
dtype=torch.float32,
|
||||
device=self.device,
|
||||
)
|
||||
token_probs.scatter_(2, indices, 1)
|
||||
token_logprobs = torch.zeros(
|
||||
(len(execute_model_req.seq_group_metadata_list), sample_len,
|
||||
self.vocab_size),
|
||||
dtype=torch.float32,
|
||||
device=self.device,
|
||||
)
|
||||
for i in range(len(execute_model_req.seq_group_metadata_list)):
|
||||
outputs.append(
|
||||
SamplerOutput(
|
||||
outputs=None,
|
||||
sampled_token_probs=token_probs[i],
|
||||
logprobs=token_logprobs[i],
|
||||
sampled_token_ids=token_ids[i],
|
||||
))
|
||||
return outputs, False
|
||||
|
||||
def get_spec_proposals(
|
||||
|
@ -73,6 +73,14 @@ class Top1Proposer(SpeculativeProposer):
|
||||
execute_model_req=nonzero_execute_model_req,
|
||||
sample_len=proposal_len,
|
||||
)
|
||||
(
|
||||
proposal_lens,
|
||||
maybe_sampler_output,
|
||||
nonzero_proposal_len_indices,
|
||||
) = self._remove_no_proposal_seqs(proposal_lens,
|
||||
maybe_sampler_output,
|
||||
nonzero_proposal_len_indices,
|
||||
transposed)
|
||||
else:
|
||||
# If no sequences can be speculated, set sampler output to None.
|
||||
maybe_sampler_output = None
|
||||
@ -140,6 +148,61 @@ class Top1Proposer(SpeculativeProposer):
|
||||
nonzero_proposal_len_indices,
|
||||
)
|
||||
|
||||
def _remove_no_proposal_seqs(self, proposal_lens, maybe_sampler_output,
|
||||
nonzero_proposal_len_indices, transposed):
|
||||
"""Remove sequences from nonzero_proposal_len_indices and reset
|
||||
their proposal_len to 0 the draft worker does not provide a proposal
|
||||
(maybe_sampler_output=None). This can avoid scoring overheads.
|
||||
"""
|
||||
|
||||
# If maybe_sampler_output is None, then the draft worker did not
|
||||
# provide a proposal for any sequence and thus no action needed.
|
||||
# Also we do not support transposed maybe_sampler_output for now
|
||||
# because it seems not straightforward for draft workers outputting
|
||||
# transposed sampler outputs to handle the case of no proposal.
|
||||
if maybe_sampler_output is None or transposed:
|
||||
return (proposal_lens, maybe_sampler_output,
|
||||
nonzero_proposal_len_indices)
|
||||
|
||||
new_proposal_lens: List[int] = []
|
||||
new_nonzero_proposal_len_indices: List[int] = []
|
||||
new_maybe_sampler_output: List[SamplerOutput] = []
|
||||
nonzero_proposal_len_idx_ptr = 0
|
||||
seq_idx = 0
|
||||
while seq_idx < len(
|
||||
proposal_lens) and nonzero_proposal_len_idx_ptr < len(
|
||||
nonzero_proposal_len_indices):
|
||||
if seq_idx < nonzero_proposal_len_indices[
|
||||
nonzero_proposal_len_idx_ptr]:
|
||||
# Sequence is not in the original nonzero_proposal_len_indices,
|
||||
# meaning that it has a proposal length of 0 before sending to
|
||||
# the draft worker.
|
||||
assert proposal_lens[seq_idx] == 0
|
||||
new_proposal_lens.append(0)
|
||||
else:
|
||||
# Sequence is in the original nonzero_proposal_len_indices
|
||||
if maybe_sampler_output[nonzero_proposal_len_idx_ptr] is None:
|
||||
# but does not have a proposal from the draft worker.
|
||||
new_proposal_lens.append(0)
|
||||
else:
|
||||
# and has a proposal from the draft worker. Add it to the
|
||||
# new nonzero proposal list and keep the sampler output.
|
||||
new_proposal_lens.append(proposal_lens[seq_idx])
|
||||
new_nonzero_proposal_len_indices.append(seq_idx)
|
||||
new_maybe_sampler_output.append(
|
||||
maybe_sampler_output[nonzero_proposal_len_idx_ptr])
|
||||
nonzero_proposal_len_idx_ptr += 1
|
||||
seq_idx += 1
|
||||
|
||||
# The remaining sequences should have proposal length of 0.
|
||||
new_proposal_lens.extend(proposal_lens[seq_idx:])
|
||||
|
||||
# We assume sampler_output will not be a list of all Nones.
|
||||
# In this case this function should not be called.
|
||||
assert new_maybe_sampler_output
|
||||
return (new_proposal_lens, new_maybe_sampler_output,
|
||||
new_nonzero_proposal_len_indices)
|
||||
|
||||
def _merge_outputs(
|
||||
self,
|
||||
batch_size: int,
|
||||
|
Reference in New Issue
Block a user