diff --git a/tests/spec_decode/e2e/conftest.py b/tests/spec_decode/e2e/conftest.py index 5d3469c421..0eb784a9c5 100644 --- a/tests/spec_decode/e2e/conftest.py +++ b/tests/spec_decode/e2e/conftest.py @@ -1,4 +1,5 @@ import asyncio +from itertools import cycle from typing import List, Optional, Tuple, Union import pytest @@ -185,3 +186,60 @@ def get_output_from_llm_generator( del llm return tokens, token_ids + + +def run_greedy_equality_correctness_test(baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len, + force_output_len: bool, + print_tokens: bool = False): + """Helper method that compares the outputs of both the baseline LLM and + the test LLM. It asserts greedy equality, e.g. that the outputs are exactly + the same when temperature is zero. + """ + temperature = 0.0 + + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + "San Francisco is know for its", + "Facebook was created in 2004 by", + "Curious George is a", + "Python 3.11 brings improvements to its", + ] + + prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))] + + # If the test requires that we generated max_output_len tokens, then set the + # sampling params to ignore eos token. + ignore_eos = force_output_len + + sampling_params = SamplingParams( + max_tokens=max_output_len, + ignore_eos=ignore_eos, + temperature=temperature, + ) + + spec_batch_tokens, spec_batch_token_ids = get_output_from_llm_generator( + test_llm_generator, prompts, sampling_params) + + (baseline_batch_tokens, + baseline_batch_token_ids) = get_output_from_llm_generator( + baseline_llm_generator, prompts, sampling_params) + + assert len(baseline_batch_token_ids) == len(prompts) + assert len(spec_batch_token_ids) == len(prompts) + + for i, (baseline_token_ids, baseline_tokens, spec_token_ids, + spec_tokens) in enumerate( + zip(baseline_batch_token_ids, baseline_batch_tokens, + spec_batch_token_ids, spec_batch_tokens)): + if print_tokens: + print(f'{i=} {baseline_tokens=}') + print(f'{i=} {spec_tokens=}') + print(f'{i=} {baseline_token_ids=}') + print(f'{i=} {spec_token_ids=}') + assert baseline_token_ids == spec_token_ids diff --git a/tests/spec_decode/e2e/test_correctness.py b/tests/spec_decode/e2e/test_multistep_correctness.py similarity index 88% rename from tests/spec_decode/e2e/test_correctness.py rename to tests/spec_decode/e2e/test_multistep_correctness.py index ab8d913fb8..f99e0f6778 100644 --- a/tests/spec_decode/e2e/test_correctness.py +++ b/tests/spec_decode/e2e/test_multistep_correctness.py @@ -35,7 +35,8 @@ from transformers import AutoTokenizer from vllm import SamplingParams -from .conftest import get_output_from_llm_generator +from .conftest import (get_output_from_llm_generator, + run_greedy_equality_correctness_test) @pytest.mark.parametrize( @@ -545,60 +546,3 @@ def test_many_k(baseline_llm_generator, test_llm_generator, batch_size: int, batch_size, max_output_len=output_len, force_output_len=True) - - -def run_greedy_equality_correctness_test(baseline_llm_generator, - test_llm_generator, - batch_size, - max_output_len, - force_output_len: bool, - print_tokens: bool = False): - """Helper method that compares the outputs of both the baseline LLM and - the test LLM. It asserts greedy equality, e.g. that the outputs are exactly - the same when temperature is zero. 
-    """
-    temperature = 0.0
-
-    prompts = [
-        "Hello, my name is",
-        "The president of the United States is",
-        "The capital of France is",
-        "The future of AI is",
-        "San Francisco is know for its",
-        "Facebook was created in 2004 by",
-        "Curious George is a",
-        "Python 3.11 brings improvements to its",
-    ]
-
-    prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))]
-
-    # If the test requires that we generated max_output_len tokens, then set the
-    # sampling params to ignore eos token.
-    ignore_eos = force_output_len
-
-    sampling_params = SamplingParams(
-        max_tokens=max_output_len,
-        ignore_eos=ignore_eos,
-        temperature=temperature,
-    )
-
-    spec_batch_tokens, spec_batch_token_ids = get_output_from_llm_generator(
-        test_llm_generator, prompts, sampling_params)
-
-    (baseline_batch_tokens,
-     baseline_batch_token_ids) = get_output_from_llm_generator(
-         baseline_llm_generator, prompts, sampling_params)
-
-    assert len(baseline_batch_token_ids) == len(prompts)
-    assert len(spec_batch_token_ids) == len(prompts)
-
-    for i, (baseline_token_ids, baseline_tokens, spec_token_ids,
-            spec_tokens) in enumerate(
-                zip(baseline_batch_token_ids, baseline_batch_tokens,
-                    spec_batch_token_ids, spec_batch_tokens)):
-        if print_tokens:
-            print(f'{i=} {baseline_tokens=}')
-            print(f'{i=} {spec_tokens=}')
-        print(f'{i=} {baseline_token_ids=}')
-        print(f'{i=} {spec_token_ids=}')
-        assert baseline_token_ids == spec_token_ids
diff --git a/tests/spec_decode/e2e/test_ngram_correctness.py b/tests/spec_decode/e2e/test_ngram_correctness.py
new file mode 100644
index 0000000000..44ef400c91
--- /dev/null
+++ b/tests/spec_decode/e2e/test_ngram_correctness.py
@@ -0,0 +1,172 @@
+"""This docstring details important information on the testing methodology.
+
+Most of the tests rely on "greedy equality", where we expect the output of
+speculative decoding on a sequence to exactly match the output of normal non-
+speculative decoding.
+
+Since speculative decoding with rejection sampling guarantees that the output
+distribution matches the target model's output distribution (up to hardware
+numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy
+equality.
+
+The idea behind ngram lookup comes from https://github.com/apoorvumang/prompt-lookup-decoding,
+and was merged into the transformers code base: https://github.com/huggingface/transformers/pull/27775.
+Since no model is needed to generate the proposals, these test cases can be
+much simpler than the multi-step draft-model ones.
+
+However, we still need to verify that the following scenarios pass:
+    * Batch size 1 greedy equality
+    * Batch size >1 greedy equality
+    * Test greedy equality under preemption
+    * Test greedy equality under various ngram sizes / speculative sizes
+
+With these tests, we can say at least that ngram speculation does not break
+the correctness of the target model outputs.
+"""
+
+import pytest
+
+from .conftest import run_greedy_equality_correctness_test
+
+
+@pytest.mark.parametrize(
+    "common_llm_kwargs",
+    [{
+        # Skip cuda graph recording for fast test.
+        "enforce_eager": True,
+
+        # Required for spec decode.
+        "use_v2_block_manager": True,
+
+        # Print spec metrics.
+ "disable_log_stats": False, + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [ + { + "model": "JackFram/llama-68m", + }, +]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [ + { + "speculative_model": "[ngram]", + "num_speculative_tokens": 5, + "ngram_prompt_lookup_max": 3, + }, +]) +@pytest.mark.parametrize("output_len", [ + 256, +]) +@pytest.mark.parametrize("batch_size", [1, 64]) +@pytest.mark.parametrize("seed", [1]) +def test_ngram_e2e_greedy_correctness(baseline_llm_generator, + test_llm_generator, batch_size: int, + output_len: int): + """Verify greedy equality on a tiny model with different batch size.""" + run_greedy_equality_correctness_test(baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len=output_len, + force_output_len=True) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + "block_size": 8, + # 2 for small prompt, 256//8 for generated. + "num_gpu_blocks_override": 2 + 256 // 8, + "max_model_len": (2 + 256 // 8) * 8, + + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [ + { + "model": "JackFram/llama-160m", + }, +]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [ + { + "speculative_model": "[ngram]", + "num_speculative_tokens": 5, + "ngram_prompt_lookup_max": 3, + }, +]) +@pytest.mark.parametrize( + "output_len", + [ + # Use small output len for fast test. + 256, + ]) +@pytest.mark.parametrize("batch_size", [4]) +@pytest.mark.parametrize("seed", [1]) +def test_ngram_e2e_greedy_correctness_with_preemption(baseline_llm_generator, + test_llm_generator, + batch_size: int, + output_len: int): + """Verify greedy equality, even when some sequences are preempted mid- + generation. + """ + run_greedy_equality_correctness_test(baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len=output_len, + force_output_len=True) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + "model": "JackFram/llama-68m", + + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize( + "test_llm_kwargs", + [ + { + "speculative_model": "[ngram]", + "num_speculative_tokens": k, + "ngram_prompt_lookup_max": 3, + } + # Try a range of common k, as well as large speculation. + for k in [1, 3, 5] + ] + [ + { + "speculative_model": "[ngram]", + "num_speculative_tokens": k, + "ngram_prompt_lookup_max": 1, + } + # Try a range of common k, as well as large speculation. + for k in [1, 3, 5] + ]) +@pytest.mark.parametrize("batch_size", [2]) +@pytest.mark.parametrize( + "output_len", + [ + # Use smaller output len for fast test. + 32, + ]) +@pytest.mark.parametrize("seed", [1]) +def test_ngram_different_k(baseline_llm_generator, test_llm_generator, + batch_size: int, output_len: int): + """Verify that ngram speculative decoding produces exact equality + to without spec decode with many different values of k and + different ngram_prompt_lookup_max. 
+ """ + run_greedy_equality_correctness_test(baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len=output_len, + force_output_len=True) diff --git a/tests/spec_decode/test_multi_step_worker.py b/tests/spec_decode/test_multi_step_worker.py index e7aaa1ff4e..98f2731de9 100644 --- a/tests/spec_decode/test_multi_step_worker.py +++ b/tests/spec_decode/test_multi_step_worker.py @@ -6,8 +6,8 @@ import torch from vllm.model_executor.utils import set_random_seed from vllm.sequence import SamplerOutput -from vllm.spec_decode.multi_step_worker import (DraftModelTop1Proposer, - MultiStepWorker) +from vllm.spec_decode.multi_step_worker import MultiStepWorker +from vllm.spec_decode.top1_proposer import Top1Proposer from vllm.worker.worker import Worker from .utils import (assert_logprobs_dict_allclose, create_batch, @@ -117,8 +117,8 @@ def test_same_output_for_single_step(): zero_kv_cache(multi_step_worker.cache_engine) set_random_seed(seed) - actual_output = multi_step_worker.execute_model_multi_step( - **multi_step_execute_model_data.to_dict(), num_steps=num_steps) + actual_output, _ = multi_step_worker.sampler_output( + **multi_step_execute_model_data.to_dict(), sample_len=num_steps) assert len(actual_output) == num_steps actual_output = actual_output[0] @@ -200,8 +200,8 @@ def test_same_output_for_multi_step(): # Run multi-step. zero_kv_cache(multi_step_worker.cache_engine) set_random_seed(seed) - multi_step_output = multi_step_worker.execute_model_multi_step( - **execute_model_data.to_dict(), num_steps=num_steps) + multi_step_output, _ = multi_step_worker.sampler_output( + **execute_model_data.to_dict(), sample_len=num_steps) # Run single-step repeatedly. zero_kv_cache(worker.cache_engine) @@ -266,7 +266,7 @@ def test_same_output_for_multi_step(): @torch.inference_mode() def test_draft_proposals_full_speculation_len(): - """Verify DraftModelTop1Proposer correctly handles case where all sequences + """Verify Top1Proposer correctly handles case where all sequences can speculate. """ k = 10 @@ -275,13 +275,13 @@ def test_draft_proposals_full_speculation_len(): device = 'cuda:0' draft_worker = MagicMock() - proposer = DraftModelTop1Proposer( - draft_worker=draft_worker, + proposer = Top1Proposer( + worker=draft_worker, device=device, - max_model_len=2048, vocab_size=vocab_size, + max_proposal_len=2048, ) - draft_worker.execute_model_multi_step.return_value = [ + draft_worker.sampler_output.return_value = [ SamplerOutput( outputs=[], sampled_token_probs=torch.rand(batch_size, @@ -294,13 +294,13 @@ def test_draft_proposals_full_speculation_len(): device=device, dtype=torch.long), ) for _ in range(k) - ] + ], True execute_model_data, _, _ = create_batch(batch_size, k) proposals = proposer.get_proposals( **execute_model_data.to_dict(), - max_proposal_len=k, + proposal_len=k, ) assert torch.is_tensor(proposals.proposal_token_ids) @@ -315,7 +315,7 @@ def test_draft_proposals_full_speculation_len(): @torch.inference_mode() def test_draft_proposals_no_speculations(): - """Verify DraftModelTop1Proposer correctly handles case where no sequences + """Verify Top1Proposer correctly handles case where no sequences can speculate. 
""" k = 10 @@ -325,11 +325,11 @@ def test_draft_proposals_no_speculations(): prompt_len = 10 draft_worker = MagicMock() - proposer = DraftModelTop1Proposer( - draft_worker=draft_worker, + proposer = Top1Proposer( + worker=draft_worker, device=device, - max_model_len=prompt_len + k - 1, vocab_size=vocab_size, + max_proposal_len=prompt_len + k - 1, ) execute_model_data, _, _ = create_batch(batch_size, @@ -338,7 +338,7 @@ def test_draft_proposals_no_speculations(): proposals = proposer.get_proposals( **execute_model_data.to_dict(), - max_proposal_len=k, + proposal_len=k, ) assert torch.is_tensor(proposals.proposal_token_ids) @@ -353,7 +353,7 @@ def test_draft_proposals_no_speculations(): @torch.inference_mode() def test_draft_proposals_mixed_k(): - """Verify DraftModelTop1Proposer correctly handles case some sequences can + """Verify Top1Proposer correctly handles case some sequences can speculate and some can't. """ k = 10 @@ -374,14 +374,14 @@ def test_draft_proposals_mixed_k(): for _ in range(expected_num_no_proposal_seqs)] + [small_prompt_len] draft_worker = MagicMock() - proposer = DraftModelTop1Proposer( - draft_worker=draft_worker, + proposer = Top1Proposer( + worker=draft_worker, device=device, - max_model_len=long_prompt_len + prev_output_token_len + k - 1, vocab_size=vocab_size, + max_proposal_len=long_prompt_len + prev_output_token_len + k - 1, ) - draft_worker.execute_model_multi_step.return_value = [ + draft_worker.sampler_output.return_value = [ SamplerOutput( outputs=[], sampled_token_probs=torch.rand(expected_num_proposal_seqs, @@ -395,7 +395,7 @@ def test_draft_proposals_mixed_k(): device=device, dtype=torch.long), ) for _ in range(k) - ] + ], True execute_model_data, _, _ = create_batch( batch_size, @@ -406,7 +406,7 @@ def test_draft_proposals_mixed_k(): proposals = proposer.get_proposals( **execute_model_data.to_dict(), - max_proposal_len=k, + proposal_len=k, ) assert torch.is_tensor(proposals.proposal_token_ids) diff --git a/tests/spec_decode/test_ngram_worker.py b/tests/spec_decode/test_ngram_worker.py new file mode 100644 index 0000000000..ee41350157 --- /dev/null +++ b/tests/spec_decode/test_ngram_worker.py @@ -0,0 +1,206 @@ +import torch + +from vllm.spec_decode.ngram_worker import NGramWorker +from vllm.spec_decode.top1_proposer import Top1Proposer + +from .utils import (create_execute_model_data, + create_seq_group_metadata_from_prompts, create_worker) + + +def test_ngram_algo_correctness_for_single_no_match(): + """Verify our ngram algo find the right candidate in the prompt + + For the scenario cannot find any candidate in one single batch + """ + block_size = 32 + num_gpu_blocks = 2048 // block_size + seed = 100 + model_name = 'JackFram/llama-68m' + vocab_size = 32_000 + device = 'cuda:0' + + ngram_worker = create_worker( + NGramWorker, + model_name, + block_size, + num_gpu_blocks, + seed, + ) + + proposer = Top1Proposer( + worker=ngram_worker, + device=device, + vocab_size=vocab_size, + max_proposal_len=20, + ) + + # set ngram window (0, 3], which is window=1/2/3 + ngram_worker.set_ngram_window_size(0, 3) + + prompts = [ + # shall find no candidate + [1, 2, 3, 4, 5, 6, 7], + ] + + proposal_len = 5 + final_seq_lens = [len(prompt) + proposal_len for prompt in prompts] + ngram_sampler_output_data = create_execute_model_data( + seq_group_metadata_list=create_seq_group_metadata_from_prompts( + prompts, num_gpu_blocks, block_size, + final_seq_lens=final_seq_lens)) + + proposals = proposer.get_proposals( + **ngram_sampler_output_data.to_dict(), + 
proposal_len=proposal_len, + ) + + assert torch.is_tensor(proposals.proposal_token_ids) + assert torch.is_tensor(proposals.proposal_probs) + + assert proposals.proposal_token_ids.shape == torch.Size([1, proposal_len]) + assert proposals.proposal_probs.shape[:-1] == torch.Size([1, proposal_len]) + assert proposals.proposal_lens.shape == torch.Size([1]) + assert proposals.proposal_lens.tolist() == [0] + + +def test_ngram_algo_correctness_for_batches_not_match_all(): + """Verify our ngram algo find the right candidate in the prompt + + For the scenario find some candidate not full in batchs + """ + block_size = 32 + num_gpu_blocks = 2048 // block_size + seed = 100 + model_name = 'JackFram/llama-68m' + vocab_size = 32_000 + device = 'cuda:0' + + ngram_worker = create_worker( + NGramWorker, + model_name, + block_size, + num_gpu_blocks, + seed, + ) + + proposer = Top1Proposer( + worker=ngram_worker, + device=device, + vocab_size=vocab_size, + max_proposal_len=20, + ) + + # set ngram window (0, 3], which is window=1/2/3 + ngram_worker.set_ngram_window_size(0, 3) + + prompts = [ + # shall find no candidate + [1, 2, 3, 4, 5, 6, 7], + # shall find candidate 12,13,14,15,16 + [11, 12, 13, 14, 15, 16, 11], + # shall find candidate 23,24,25,26,21 + [21, 21, 22, 23, 24, 25, 26, 21, 22], + # shall find candidate 34,35,36,37,38 + [31, 32, 31, 32, 33, 34, 35, 36, 37, 38, 31, 32, 33], + # shall find no candidate as exceed max_proposal_len + [ + 31, 32, 31, 32, 31, 32, 31, 32, 31, 32, 31, 32, 33, 34, 35, 36, 37, + 38, 31, 32, 33 + ], + ] + + proposal_len = 5 + final_seq_lens = [len(prompt) + proposal_len for prompt in prompts] + ngram_sampler_output_data = create_execute_model_data( + seq_group_metadata_list=create_seq_group_metadata_from_prompts( + prompts, num_gpu_blocks, block_size, + final_seq_lens=final_seq_lens)) + + proposals = proposer.get_proposals( + **ngram_sampler_output_data.to_dict(), + proposal_len=proposal_len, + ) + + assert torch.is_tensor(proposals.proposal_token_ids) + assert torch.is_tensor(proposals.proposal_probs) + + assert proposals.proposal_token_ids.shape == torch.Size([5, proposal_len]) + assert proposals.proposal_probs.shape[:-1] == torch.Size([5, proposal_len]) + assert proposals.proposal_lens.shape == torch.Size([5]) + + assert proposals.proposal_lens.tolist( + ) == [proposal_len for _ in range(4)] + [0] + + for i in range(proposal_len): + assert proposals.proposal_token_ids[0][i] == 0 + assert proposals.proposal_token_ids[1][i] == prompts[1][i + 1] + assert proposals.proposal_token_ids[2][i] == prompts[2][i + 3] + assert proposals.proposal_token_ids[3][i] == prompts[3][i + 5] + assert proposals.proposal_token_ids[4][i] == -1 + + +def test_ngram_algo_correctness_for_batches_match_all(): + """Verify our ngram algo find the right candidate in the prompt + + For the scenario find candidate in all batchs + """ + + block_size = 32 + num_gpu_blocks = 2048 // block_size + seed = 100 + model_name = 'JackFram/llama-68m' + vocab_size = 32_000 + device = 'cuda:0' + + ngram_worker = create_worker( + NGramWorker, + model_name, + block_size, + num_gpu_blocks, + seed, + ) + + proposer = Top1Proposer( + worker=ngram_worker, + device=device, + vocab_size=vocab_size, + max_proposal_len=20, + ) + + # set ngram window (0, 3], which is window=1/2/3 + ngram_worker.set_ngram_window_size(0, 3) + + prompts = [ + # shall find candidate 12,13,14,15,16 + [11, 12, 13, 14, 15, 16, 11], + # shall find candidate 23,24,25,26,21 + [21, 21, 22, 23, 24, 25, 26, 21, 22], + # shall find candidate 34,35,36,37,38 + 
[31, 32, 31, 32, 33, 34, 35, 36, 37, 38, 31, 32, 33], + ] + + proposal_len = 5 + final_seq_lens = [len(prompt) + proposal_len for prompt in prompts] + ngram_sampler_output_data = create_execute_model_data( + seq_group_metadata_list=create_seq_group_metadata_from_prompts( + prompts, num_gpu_blocks, block_size, + final_seq_lens=final_seq_lens)) + + proposals = proposer.get_proposals( + **ngram_sampler_output_data.to_dict(), + proposal_len=proposal_len, + ) + + assert torch.is_tensor(proposals.proposal_token_ids) + assert torch.is_tensor(proposals.proposal_probs) + + assert proposals.proposal_token_ids.shape == torch.Size([3, proposal_len]) + assert proposals.proposal_probs.shape[:-1] == torch.Size([3, proposal_len]) + assert proposals.proposal_lens.shape == torch.Size([3]) + + assert proposals.proposal_lens.tolist() == [proposal_len for _ in range(3)] + + for i in range(proposal_len): + assert proposals.proposal_token_ids[0][i] == prompts[0][i + 1] + assert proposals.proposal_token_ids[1][i] == prompts[1][i + 3] + assert proposals.proposal_token_ids[2][i] == prompts[2][i + 5] diff --git a/vllm/config.py b/vllm/config.py index db4398adda..257d49b6e8 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -682,6 +682,8 @@ class SpeculativeConfig: speculative_max_model_len: Optional[int], enable_chunked_prefill: bool, use_v2_block_manager: bool, + ngram_prompt_lookup_max: Optional[int], + ngram_prompt_lookup_min: Optional[int], ) -> Optional["SpeculativeConfig"]: """Create a SpeculativeConfig if possible, else return None. @@ -708,6 +710,10 @@ class SpeculativeConfig: use_v2_block_manager (bool): Whether vLLM is configured to use the v2 block manager or not. Used for raising an error since the v2 block manager is required with spec decode. + ngram_prompt_lookup_max (Optional[int]): Max size of ngram token + window, if provided. + ngram_prompt_lookup_min (Optional[int]): Min size of ngram token + window, if provided. Returns: Optional["SpeculativeConfig"]: An instance of SpeculativeConfig if @@ -742,39 +748,57 @@ class SpeculativeConfig: draft_code_revision = None draft_quantization = None - draft_model_config = ModelConfig( - model=speculative_model, - tokenizer=target_model_config.tokenizer, - tokenizer_mode=target_model_config.tokenizer_mode, - trust_remote_code=target_model_config.trust_remote_code, - dtype=target_model_config.dtype, - seed=target_model_config.seed, - revision=draft_revision, - code_revision=draft_code_revision, - tokenizer_revision=target_model_config.tokenizer_revision, - max_model_len=None, - quantization=draft_quantization, - enforce_eager=target_model_config.enforce_eager, - max_context_len_to_capture=target_model_config. - max_context_len_to_capture, - max_logprobs=target_model_config.max_logprobs, - ) + if speculative_model == "[ngram]": + assert (ngram_prompt_lookup_max is not None + and ngram_prompt_lookup_max > 0) + if ngram_prompt_lookup_min is None: + ngram_prompt_lookup_min = 0 + else: + assert ngram_prompt_lookup_max > ngram_prompt_lookup_min - draft_model_config.max_model_len = ( - SpeculativeConfig._maybe_override_draft_max_model_len( - speculative_max_model_len, - draft_model_config.max_model_len, - target_model_config.max_model_len, - )) + # TODO: current we still need extract vocab_size from target model + # config, in future, we may try refactor it out, and set + # draft related config as None here. 
+ draft_model_config = target_model_config + draft_parallel_config = target_parallel_config + else: + ngram_prompt_lookup_max = 0 + ngram_prompt_lookup_min = 0 + draft_model_config = ModelConfig( + model=speculative_model, + tokenizer=target_model_config.tokenizer, + tokenizer_mode=target_model_config.tokenizer_mode, + trust_remote_code=target_model_config.trust_remote_code, + dtype=target_model_config.dtype, + seed=target_model_config.seed, + revision=draft_revision, + code_revision=draft_code_revision, + tokenizer_revision=target_model_config.tokenizer_revision, + max_model_len=None, + quantization=draft_quantization, + enforce_eager=target_model_config.enforce_eager, + max_context_len_to_capture=target_model_config. + max_context_len_to_capture, + max_logprobs=target_model_config.max_logprobs, + ) - draft_parallel_config = ( - SpeculativeConfig.create_draft_parallel_config( - target_parallel_config)) + draft_model_config.max_model_len = ( + SpeculativeConfig._maybe_override_draft_max_model_len( + speculative_max_model_len, + draft_model_config.max_model_len, + target_model_config.max_model_len, + )) + + draft_parallel_config = ( + SpeculativeConfig.create_draft_parallel_config( + target_parallel_config)) return SpeculativeConfig( draft_model_config, draft_parallel_config, num_speculative_tokens, + ngram_prompt_lookup_max, + ngram_prompt_lookup_min, ) @staticmethod @@ -842,6 +866,8 @@ class SpeculativeConfig: draft_model_config: ModelConfig, draft_parallel_config: ParallelConfig, num_speculative_tokens: int, + ngram_prompt_lookup_max: int, + ngram_prompt_lookup_min: int, ): """Create a SpeculativeConfig object. @@ -854,6 +880,8 @@ class SpeculativeConfig: self.draft_model_config = draft_model_config self.draft_parallel_config = draft_parallel_config self.num_speculative_tokens = num_speculative_tokens + self.ngram_prompt_lookup_max = ngram_prompt_lookup_max + self.ngram_prompt_lookup_min = ngram_prompt_lookup_min self._verify_args() @@ -877,7 +905,10 @@ class SpeculativeConfig: return self.num_speculative_tokens def __repr__(self) -> str: - draft_model = self.draft_model_config.model + if self.ngram_prompt_lookup_max > 0: + draft_model = "[ngram]" + else: + draft_model = self.draft_model_config.model num_spec_tokens = self.num_speculative_tokens return f"SpeculativeConfig({draft_model=}, {num_spec_tokens=})" diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index bd6437ee44..7637616ae6 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -75,6 +75,8 @@ class EngineArgs: speculative_model: Optional[str] = None num_speculative_tokens: Optional[int] = None speculative_max_model_len: Optional[int] = None + ngram_prompt_lookup_max: Optional[int] = None + ngram_prompt_lookup_min: Optional[int] = None def __post_init__(self): if self.tokenizer is None: @@ -449,6 +451,20 @@ class EngineArgs: 'draft model. 
Sequences over this length will skip ' 'speculation.') + parser.add_argument( + '--ngram-prompt-lookup-max', + type=int, + default=EngineArgs.ngram_prompt_lookup_max, + help='Max size of window for ngram prompt lookup in speculative ' + 'decoding.') + + parser.add_argument( + '--ngram-prompt-lookup-min', + type=int, + default=EngineArgs.ngram_prompt_lookup_min, + help='Min size of window for ngram prompt lookup in speculative ' + 'decoding.') + parser.add_argument('--model-loader-extra-config', type=str, default=EngineArgs.model_loader_extra_config, @@ -502,6 +518,8 @@ class EngineArgs: speculative_max_model_len=self.speculative_max_model_len, enable_chunked_prefill=self.enable_chunked_prefill, use_v2_block_manager=self.use_v2_block_manager, + ngram_prompt_lookup_max=self.ngram_prompt_lookup_max, + ngram_prompt_lookup_min=self.ngram_prompt_lookup_min, ) scheduler_config = SchedulerConfig( diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py index 527a14ff6c..a58856a12f 100644 --- a/vllm/executor/gpu_executor.py +++ b/vllm/executor/gpu_executor.py @@ -73,7 +73,6 @@ class GPUExecutor(ExecutorBase): """ assert self.speculative_config is not None - from vllm.spec_decode.multi_step_worker import MultiStepWorker from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker target_worker = self._create_worker() @@ -86,10 +85,11 @@ class GPUExecutor(ExecutorBase): # TODO allow draft-model specific load config. #load_config=self.load_config, ) - draft_worker = MultiStepWorker(**draft_worker_kwargs) - spec_decode_worker = SpecDecodeWorker.from_workers( - proposer_worker=draft_worker, scorer_worker=target_worker) + spec_decode_worker = SpecDecodeWorker.create_worker( + scorer_worker=target_worker, + draft_worker_kwargs=draft_worker_kwargs, + ) assert self.parallel_config.world_size == 1, ( "GPUExecutor only supports single GPU.") diff --git a/vllm/spec_decode/batch_expansion.py b/vllm/spec_decode/batch_expansion.py index c29b838f85..8b113e9347 100644 --- a/vllm/spec_decode/batch_expansion.py +++ b/vllm/spec_decode/batch_expansion.py @@ -333,13 +333,13 @@ class BatchExpansionTop1Scorer(SpeculativeScorer): sampler_output.sampled_token_probs = spec_probs sampler_output.sampled_token_ids = spec_sampled_tokens target_token_ids, target_probs = sampler_output_to_torch( - [sampler_output]) + [sampler_output], True) # Convert non-speculative output tokens to tensors. 
sampler_output.sampled_token_probs = non_spec_probs sampler_output.sampled_token_ids = non_spec_sampled_tokens non_spec_target_token_ids, non_spec_target_probs = ( - sampler_output_to_torch([sampler_output])) + sampler_output_to_torch([sampler_output], True)) return (target_token_ids, target_probs, non_spec_target_token_ids, non_spec_target_probs) diff --git a/vllm/spec_decode/multi_step_worker.py b/vllm/spec_decode/multi_step_worker.py index 7cf338bbae..d031bc85af 100644 --- a/vllm/spec_decode/multi_step_worker.py +++ b/vllm/spec_decode/multi_step_worker.py @@ -1,12 +1,11 @@ import copy -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Tuple import torch from vllm.sequence import SamplerOutput, SequenceGroupMetadata -from vllm.spec_decode.interfaces import (SpeculativeProposals, - SpeculativeProposer) -from vllm.spec_decode.util import sampler_output_to_torch +from vllm.spec_decode.interfaces import SpeculativeProposals +from vllm.spec_decode.top1_proposer import Top1Proposer from vllm.worker.worker import Worker @@ -26,29 +25,37 @@ class MultiStepWorker(Worker): super().__init__(*args, **kwargs) # Lazy initialization list. - self._proposer: DraftModelTop1Proposer + self._proposer: Top1Proposer def init_device(self): super().init_device() - self._proposer = DraftModelTop1Proposer( + self._proposer = Top1Proposer( self, self.device, - self.max_model_len, self.vocab_size, + max_proposal_len=self.max_model_len, ) + def set_include_gpu_probs_tensor(self): + # Need include_gpu_probs_tensor for multi_step_worker + self.model_runner.model.sampler.include_gpu_probs_tensor = True + @torch.inference_mode() - def execute_model_multi_step( + def sampler_output( self, seq_group_metadata_list: List[SequenceGroupMetadata], blocks_to_swap_in: Dict[int, int], blocks_to_swap_out: Dict[int, int], blocks_to_copy: Dict[int, List[int]], - num_steps: int, - ) -> List[SamplerOutput]: - """Run the model forward pass num_steps times. Returns the list of - sampler output, one per model forward pass. + sample_len: int, + ) -> Tuple[List[SamplerOutput], bool]: + """Run the model forward pass sample_len times. Returns the list of + sampler output, one per model forward pass, along with indicator of + whether torch tensor in sampler output need to be transposed in latter + sampler_output_to_torch logic. + + For multi step worker, this indicator shall be True. """ self._raise_if_unsupported(seq_group_metadata_list, blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy) @@ -58,12 +65,12 @@ class MultiStepWorker(Worker): copied_seq_group_metadata_list = self._shallow_copy_inputs( seq_group_metadata_list) - # Assert enough KV space for num_steps tokens per sequence. - self._assert_enough_kv_space(seq_group_metadata_list, num_steps) + # Assert enough KV space for sample_len tokens per sequence. + self._assert_enough_kv_space(seq_group_metadata_list, sample_len) - # Run model num_steps times. + # Run model sample_len times. 
model_outputs = [] - for _ in range(num_steps): + for _ in range(sample_len): model_output = super().execute_model( seq_group_metadata_list=copied_seq_group_metadata_list, blocks_to_swap_in=blocks_to_swap_in, @@ -78,7 +85,7 @@ class MultiStepWorker(Worker): copied_seq_group_metadata_list) model_outputs.append(model_output) - return model_outputs + return model_outputs, True def get_spec_proposals( self, @@ -206,171 +213,3 @@ class MultiStepWorker(Worker): for seq_group_metadata in seq_group_metadata_list): raise NotImplementedError( "MultiStepWorker does not support beam search.") - - -class DraftModelTop1Proposer(SpeculativeProposer): - """Helper class which separates out sequences which would exceed the max - model length when speculated upon. - - This allows combinations of models such as JackFram/llama-68m draft with - meta-llama/Llama2-13b-chat-hf, as llama-68m has max_position_embeddings of - 2048 while Llama2-13b has max_position_embeddings of 4096. - - We treat the sequences which exceed the proposal draft model length as - "non-spec sequences". Essentially they skip the draft model and go through - normal decoding in the target model. - - Currently, only proposal_lens of 0 and k are supported, where k is a global - batch proposal length. In the future vLLM should support per-sequence - proposal lengths. - """ - - def __init__( - self, - draft_worker: MultiStepWorker, - device: str, - max_model_len: int, - vocab_size: int, - ): - self._draft_worker = draft_worker - self._device = device - self._max_model_len = max_model_len - self._vocab_size = vocab_size - - def get_proposals( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - blocks_to_swap_in: Dict[int, int], - blocks_to_swap_out: Dict[int, int], - blocks_to_copy: Dict[int, List[int]], - max_proposal_len: int, - ) -> SpeculativeProposals: - """Get speculative proposals given the input batch. - - Sequences which would exceed the max model length are skipped during - speculation. - """ - - # Split speculative- and non-speculative- sequences. - (proposal_lens, nonzero_proposal_len_seqs, - nonzero_proposal_len_indices) = self._split_by_max_model_len( - seq_group_metadata_list, max_proposal_len) - - if nonzero_proposal_len_seqs: - # Speculate tokens using the draft worker for the speculative - # sequences. - maybe_sampler_output = self._draft_worker.execute_model_multi_step( - seq_group_metadata_list=nonzero_proposal_len_seqs, - blocks_to_swap_in=blocks_to_swap_in, - blocks_to_swap_out=blocks_to_swap_out, - blocks_to_copy=blocks_to_copy, - num_steps=max_proposal_len, - ) - else: - # If no sequences can be speculated, set sampler output to None. - maybe_sampler_output = None - - # Combine speculative- and non-speculative sequences into the same - # representation. - proposal_tokens, proposal_probs, proposal_lens = self._merge_outputs( - batch_size=len(seq_group_metadata_list), - max_proposal_len=max_proposal_len, - maybe_sampler_output=maybe_sampler_output, - proposal_lens=proposal_lens, - nonzero_proposal_len_indices=nonzero_proposal_len_indices, - ) - - proposals = SpeculativeProposals( - proposal_token_ids=proposal_tokens, - proposal_probs=proposal_probs, - proposal_lens=proposal_lens, - ) - - return proposals - - def _split_by_max_model_len( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - max_proposal_len: int, - ) -> Tuple[List[int], List[SequenceGroupMetadata], List[int]]: - """Determine which sequences would exceed the max model length. 
- """ - - proposal_lens: List[int] = [] - nonzero_proposal_len_seqs: List[SequenceGroupMetadata] = [] - nonzero_proposal_len_indices: List[int] = [] - for i, seq_group_metadata in enumerate(seq_group_metadata_list): - seq_data = next(iter(seq_group_metadata.seq_data.values())) - seq_len = seq_data.get_len() - - # Currently only proposal lens of 0 or the global batch proposal len - # are supported. - if seq_len + max_proposal_len < self._max_model_len: - proposal_lens.append(max_proposal_len) - nonzero_proposal_len_seqs.append(seq_group_metadata) - nonzero_proposal_len_indices.append(i) - else: - proposal_lens.append(0) - - return (proposal_lens, nonzero_proposal_len_seqs, - nonzero_proposal_len_indices) - - def _merge_outputs( - self, - batch_size: int, - max_proposal_len: int, - maybe_sampler_output: Optional[SamplerOutput], - proposal_lens: List[int], - nonzero_proposal_len_indices: List[int], - ) -> Tuple[torch.Tensor, torch.tensor, torch.Tensor]: - """After speculations are produced, merge the speculation results with - the skipped sequences. - """ - if maybe_sampler_output is None: - # If no speculative tokens, the sampler output will be None. - # In this case we return empty proposals. - proposal_tokens = torch.full(size=( - batch_size, - max_proposal_len, - ), - fill_value=-1, - dtype=torch.long, - device=self._device) - proposal_probs = torch.zeros(batch_size, - max_proposal_len, - self._vocab_size, - dtype=torch.float32, - device=self._device) - proposal_lens_tensor = torch.zeros(len(proposal_lens), - dtype=torch.long, - device=self._device) - return proposal_tokens, proposal_probs, proposal_lens_tensor - - sampler_output = maybe_sampler_output - proposal_tokens, proposal_probs = sampler_output_to_torch( - sampler_output) - - # Now, reformat the output GPU tensors such that each sequence has - # a proposal. the proposal can be empty, e.g. [-1, -1, -1] - - entire_proposal_tokens = torch.full(size=(batch_size, - *proposal_tokens.shape[1:]), - fill_value=-1, - dtype=torch.long, - device=self._device) - entire_proposal_tokens[nonzero_proposal_len_indices] = proposal_tokens - entire_proposal_probs = torch.zeros(batch_size, - *proposal_probs.shape[1:], - dtype=torch.float32, - device=self._device) - entire_proposal_probs[nonzero_proposal_len_indices] = proposal_probs - - proposal_tokens, proposal_probs = (entire_proposal_tokens, - entire_proposal_probs) - - proposal_lens_tensor = torch.zeros(batch_size, - dtype=torch.long, - device=self._device) - proposal_lens_tensor[nonzero_proposal_len_indices] = max_proposal_len - - return proposal_tokens, proposal_probs, proposal_lens_tensor diff --git a/vllm/spec_decode/ngram_worker.py b/vllm/spec_decode/ngram_worker.py new file mode 100644 index 0000000000..696ca96432 --- /dev/null +++ b/vllm/spec_decode/ngram_worker.py @@ -0,0 +1,190 @@ +from typing import Dict, List, Optional, Tuple + +import torch + +from vllm.sequence import SamplerOutput, SequenceGroupMetadata +from vllm.spec_decode.interfaces import SpeculativeProposals +from vllm.spec_decode.top1_proposer import Top1Proposer +from vllm.worker.worker_base import LoraNotSupportedWorkerBase + + +class NGramWorker(LoraNotSupportedWorkerBase): + """NGramWorker provides a light drafter without need for model. + + Current NGramWorker only implement prompt lookup decoding, + and in future we may also do RAG type drafter and other scenerios + which don't rely on LLM model to give proposals. 
+    """
+
+    def __init__(self, *args, **kwargs):
+        # Get local_rank/vocab_size from the worker kwargs.
+        self.local_rank = kwargs["local_rank"]
+        self.vocab_size = kwargs["model_config"].get_vocab_size()
+
+        # Lazy initialization list.
+        self._proposer: Top1Proposer
+
+    def set_ngram_window_size(self, ngram_prompt_lookup_min: int,
+                              ngram_prompt_lookup_max: int):
+        # Search for candidates with window sizes between
+        # ngram_prompt_lookup_min and ngram_prompt_lookup_max.
+        self.ngram_prompt_lookup_max = ngram_prompt_lookup_max
+        self.ngram_prompt_lookup_min = ngram_prompt_lookup_min
+
+    def init_device(self):
+        self.device = torch.device(f"cuda:{self.local_rank}")
+        self.load_model = lambda *args, **kwargs: None
+
+        # Currently only Top1Proposer is supported.
+        self._proposer = Top1Proposer(
+            self,
+            device=self.device,
+            vocab_size=self.vocab_size,
+        )
+
+    def set_include_gpu_probs_tensor(self):
+        # NGram doesn't need the GPU sampler.
+        pass
+
+    def execute_model(
+        self,
+        seq_group_metadata_list: List[SequenceGroupMetadata],
+        blocks_to_swap_in: Optional[Dict[int, int]],
+        blocks_to_swap_out: Optional[Dict[int, int]],
+        blocks_to_copy: Optional[Dict[int, List[int]]],
+    ) -> None:
+        """NGram doesn't depend on model execution, so this is a no-op."""
+        pass
+
+    def determine_num_available_blocks(self) -> None:
+        """NGram doesn't depend on model execution, so no blocks to check."""
+        pass
+
+    def initialize_cache(self, num_gpu_blocks: int,
+                         num_cpu_blocks: int) -> None:
+        """There is no cache to handle, so this is a no-op."""
+        pass
+
+    def get_cache_block_size_bytes(self):
+        """Return the size of a cache block in bytes."""
+        return 0
+
+    def sampler_output(
+        self,
+        seq_group_metadata_list: List[SequenceGroupMetadata],
+        blocks_to_swap_in: Dict[int, int],
+        blocks_to_swap_out: Dict[int, int],
+        blocks_to_copy: Dict[int, List[int]],
+        sample_len: int,
+    ) -> Tuple[Optional[List[SamplerOutput]], bool]:
+        """Use the ngram match algorithm to pick proposal candidates. Returns
+        the list of sampler outputs, one per SequenceGroupMetadata.
+
+        The ngram worker already returns outputs in the per-sequence layout,
+        so the transpose indicator passed to sampler_output_to_torch is False.
+ """ + self._raise_if_unsupported( + seq_group_metadata_list, + blocks_to_swap_in, + blocks_to_swap_out, + blocks_to_copy, + ) + + arr = [] + has_spec_out = False + for seq_group_metadata in seq_group_metadata_list: + seq_data = next(iter(seq_group_metadata.seq_data.values())) + + input_ids = torch.as_tensor(seq_data.get_token_ids(), + dtype=torch.long, + device=self.device) + input_length = seq_data.get_len() + + for ngram_size in range( + min(self.ngram_prompt_lookup_max, input_length - 1), + self.ngram_prompt_lookup_min, + -1, + ): + ngram_tensor = input_ids[-1 * ngram_size:] + windows = input_ids.unfold(dimension=0, + size=ngram_size, + step=1) + matches = (windows == ngram_tensor).all(dim=1) + match_indices = matches.nonzero(as_tuple=True)[0] + if match_indices.size()[0] > 1: + has_spec_out = True + res = seq_data.get_token_ids() + res = res[match_indices[0] + ngram_size:match_indices[0] + + ngram_size + sample_len] + res_len = len(res) + # pad 0 towards output as sample_len tokens required + res += [0] * (sample_len - res_len) + + break + else: + # if no candidate found, fill with 0 + res = [0] * sample_len + + arr.append(res) + + if not has_spec_out: + return None, False + + outputs = [] + token_ids = torch.as_tensor(arr, dtype=torch.long, device=self.device) + indices = token_ids.unsqueeze(2) + + token_probs = torch.zeros( + (len(seq_group_metadata_list), sample_len, self.vocab_size), + dtype=torch.float32, + device=self.device, + ) + token_probs.scatter_(2, indices, 1) + for i in range(len(seq_group_metadata_list)): + outputs.append( + SamplerOutput( + outputs=None, + sampled_token_probs=token_probs[i], + sampled_token_ids=token_ids[i], + )) + return outputs, False + + def get_spec_proposals( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + blocks_to_swap_in: Dict[int, int], + blocks_to_swap_out: Dict[int, int], + blocks_to_copy: Dict[int, List[int]], + max_proposal_len: int, + ) -> SpeculativeProposals: + """Produce speculations given an input batch of sequences. The number of + speculative tokens per sequence is determined by max_proposal_len. + """ + + return self._proposer.get_proposals( + seq_group_metadata_list, + blocks_to_swap_in, + blocks_to_swap_out, + blocks_to_copy, + max_proposal_len, + ) + + def _raise_if_unsupported( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + blocks_to_swap_in: Dict[int, int], + blocks_to_swap_out: Dict[int, int], + blocks_to_copy: Dict[int, List[int]], + ) -> None: + """NGramWorker does not yet implement support for cache swap + operations or beam search. 
+ """ + if any([blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy]): + raise NotImplementedError( + "NGramWorker does not support cache operations") + + if any( + len(seq_group_metadata.seq_data.keys()) != 1 + for seq_group_metadata in seq_group_metadata_list): + raise NotImplementedError( + "NGramWorker does not support beam search.") diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index 4e70ea9686..e33bb4f3f6 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -12,6 +12,7 @@ from vllm.spec_decode.interfaces import (SpeculativeProposals, SpeculativeScorer, SpeculativeScores) from vllm.spec_decode.metrics import AsyncMetricsCollector from vllm.spec_decode.multi_step_worker import MultiStepWorker +from vllm.spec_decode.ngram_worker import NGramWorker from vllm.spec_decode.util import (get_all_seq_ids, nvtx_range, split_batch_by_proposal_len) from vllm.worker.worker_base import LoraNotSupportedWorkerBase, WorkerBase @@ -48,8 +49,27 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): """ @classmethod - def from_workers(cls, proposer_worker: MultiStepWorker, - scorer_worker: WorkerBase) -> "SpecDecodeWorker": + def create_worker( + cls, + scorer_worker: WorkerBase, + draft_worker_kwargs, + ) -> "SpecDecodeWorker": + + if "ngram_prompt_lookup_max" in draft_worker_kwargs: + ngram_prompt_lookup_max = ( + draft_worker_kwargs.pop("ngram_prompt_lookup_max")) + ngram_prompt_lookup_min = ( + draft_worker_kwargs.pop("ngram_prompt_lookup_min")) + else: + ngram_prompt_lookup_max = 0 + + if ngram_prompt_lookup_max > 0: + proposer_worker = NGramWorker(**draft_worker_kwargs) + proposer_worker.set_ngram_window_size(ngram_prompt_lookup_min, + ngram_prompt_lookup_max) + else: + proposer_worker = MultiStepWorker(**draft_worker_kwargs) + return SpecDecodeWorker( proposer_worker, scorer_worker, @@ -59,7 +79,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): def __init__( self, - proposer_worker: MultiStepWorker, + proposer_worker: WorkerBase, scorer_worker: WorkerBase, rejection_sampler: RejectionSampler, metrics_collector: Optional[AsyncMetricsCollector] = None, @@ -134,8 +154,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): """ (self.scorer_worker.model_runner.model.sampler.include_gpu_probs_tensor ) = True - (self.proposer_worker.model_runner.model.sampler. - include_gpu_probs_tensor) = True + self.proposer_worker.set_include_gpu_probs_tensor() def determine_num_available_blocks(self) -> Tuple[int, int]: """Determine the number of cache blocks to use. @@ -183,8 +202,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): "speculative decoding " "requires non-None seq_group_metadata_list") - logger.info("spec_decode_worker.execute_model num_lookahead_slots=%d", - num_lookahead_slots) + #logger.info("spec_decode_worker.execute_model num_lookahead_slots=%d", + # num_lookahead_slots) # If no spec tokens, call the proposer and scorer workers normally. # Used for prefill. @@ -216,7 +235,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): proposer and scorer model so that the KV cache is consistent between the two. 
""" - logger.info("run proposer worker no spec") + #logger.info("run proposer worker no spec") self.proposer_worker.execute_model( seq_group_metadata_list=seq_group_metadata_list, @@ -225,7 +244,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): blocks_to_copy=blocks_to_copy, ) - logger.info("run target worker no spec") + #logger.info("run target worker no spec") sampler_output = self.scorer_worker.execute_model( seq_group_metadata_list=seq_group_metadata_list, blocks_to_swap_in=blocks_to_swap_in, @@ -259,7 +278,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): sequence. """ - logger.info("get spec proposals") + #logger.info("get spec proposals") # Generate proposals using draft worker. assert blocks_to_swap_in is not None assert blocks_to_swap_out is not None @@ -268,7 +287,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): seq_group_metadata_list, blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy, k) - logger.info("score proposals") + #logger.info("score proposals") proposal_scores = self.scorer.score_proposals( seq_group_metadata_list, blocks_to_swap_in, @@ -278,11 +297,11 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): proposals, ) - logger.info("verify proposals") + #logger.info("verify proposals") accepted_token_ids = self._verify_tokens(seq_group_metadata_list, proposal_scores, proposals, k) - logger.info("create output list") + #logger.info("create output list") return self._create_output_sampler_list(seq_group_metadata_list, accepted_token_ids, k) diff --git a/vllm/spec_decode/top1_proposer.py b/vllm/spec_decode/top1_proposer.py new file mode 100644 index 0000000000..6766a2deb8 --- /dev/null +++ b/vllm/spec_decode/top1_proposer.py @@ -0,0 +1,200 @@ +from typing import Dict, List, Optional, Tuple + +import torch + +from vllm.sequence import SamplerOutput, SequenceGroupMetadata +from vllm.spec_decode.interfaces import (SpeculativeProposals, + SpeculativeProposer) +from vllm.spec_decode.util import sampler_output_to_torch +from vllm.worker.worker_base import WorkerBase + + +class Top1Proposer(SpeculativeProposer): + """Helper class which separates out sequences which would exceed the max + model length when speculated upon. + + This allows combinations of models such as JackFram/llama-68m draft with + meta-llama/Llama2-13b-chat-hf, as llama-68m has max_position_embeddings of + 2048 while Llama2-13b has max_position_embeddings of 4096. + + We treat the sequences which exceed the proposal draft model length as + "non-spec sequences". Essentially they skip the draft model and go through + normal decoding in the target model. + + Currently, only proposal_lens of 0 and k are supported, where k is a global + batch proposal length. In the future vLLM should support per-sequence + proposal lengths. + """ + + def __init__( + self, + worker: WorkerBase, + device: str, + vocab_size: int, + max_proposal_len: Optional[int] = None, + ): + self._worker = worker + self._device = device + self.max_proposal_len = max_proposal_len + self._vocab_size = vocab_size + + def get_proposals( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + blocks_to_swap_in: Dict[int, int], + blocks_to_swap_out: Dict[int, int], + blocks_to_copy: Dict[int, List[int]], + proposal_len: int, + ) -> SpeculativeProposals: + """Get speculative proposals given the input batch. + + Sequences which would exceed the max model length are skipped during + speculation. + """ + + # Split speculative- and non-speculative- sequences. 
+        (
+            proposal_lens,
+            nonzero_proposal_len_seqs,
+            nonzero_proposal_len_indices,
+        ) = self._split_by_max_model_len(seq_group_metadata_list, proposal_len)
+
+        if nonzero_proposal_len_seqs:
+            # Speculate tokens using the draft worker for the speculative
+            # sequences.
+            # If sampler_transposed is true, maybe_sampler_output is a list
+            # with proposal_len entries of [batch]-shaped token ids; if it is
+            # false, it is a list with batch entries of [proposal_len]-shaped
+            # token ids.
+            maybe_sampler_output, transposed = self._worker.sampler_output(
+                seq_group_metadata_list=nonzero_proposal_len_seqs,
+                blocks_to_swap_in=blocks_to_swap_in,
+                blocks_to_swap_out=blocks_to_swap_out,
+                blocks_to_copy=blocks_to_copy,
+                sample_len=proposal_len,
+            )
+        else:
+            # If no sequences can be speculated, set sampler output to None.
+            maybe_sampler_output = None
+            transposed = False
+
+        # Combine speculative- and non-speculative sequences into the same
+        # representation.
+        proposal_tokens, proposal_probs, proposal_lens = self._merge_outputs(
+            batch_size=len(seq_group_metadata_list),
+            proposal_len=proposal_len,
+            maybe_sampler_output=maybe_sampler_output,
+            proposal_lens=proposal_lens,
+            nonzero_proposal_len_indices=nonzero_proposal_len_indices,
+            sampler_transposed=transposed,
+        )
+
+        proposals = SpeculativeProposals(
+            proposal_token_ids=proposal_tokens,
+            proposal_probs=proposal_probs,
+            proposal_lens=proposal_lens,
+        )
+
+        return proposals
+
+    def _split_by_max_model_len(
+        self,
+        seq_group_metadata_list: List[SequenceGroupMetadata],
+        proposal_len: int,
+    ) -> Tuple[List[int], List[SequenceGroupMetadata], List[int]]:
+        """Determine which sequences would exceed the max model length."""
+
+        proposal_lens: List[int] = []
+        nonzero_proposal_len_seqs: List[SequenceGroupMetadata] = []
+        nonzero_proposal_len_indices: List[int] = []
+        for i, seq_group_metadata in enumerate(seq_group_metadata_list):
+            seq_data = next(iter(seq_group_metadata.seq_data.values()))
+            seq_len = seq_data.get_len()
+
+            # Currently only proposal lens of 0 or the global batch proposal len
+            # are supported.
+            # If max_proposal_len is defined, a sequence gets a nonzero
+            # proposal only if it stays within that limit.
+            if (self.max_proposal_len is None
+                    or seq_len + proposal_len < self.max_proposal_len):
+                proposal_lens.append(proposal_len)
+                nonzero_proposal_len_seqs.append(seq_group_metadata)
+                nonzero_proposal_len_indices.append(i)
+            else:
+                proposal_lens.append(0)
+
+        return (
+            proposal_lens,
+            nonzero_proposal_len_seqs,
+            nonzero_proposal_len_indices,
+        )
+
+    def _merge_outputs(
+        self,
+        batch_size: int,
+        proposal_len: int,
+        maybe_sampler_output: Optional[SamplerOutput],
+        proposal_lens: List[int],
+        nonzero_proposal_len_indices: List[int],
+        sampler_transposed: bool,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """After speculations are produced, merge the speculation results with
+        the skipped sequences.
+        """
+        if maybe_sampler_output is None:
+            # If no speculative tokens, the sampler output will be None.
+            # In this case we return empty proposals.
+ proposal_tokens = torch.full( + size=( + batch_size, + proposal_len, + ), + fill_value=-1, + dtype=torch.long, + device=self._device, + ) + proposal_probs = torch.zeros( + batch_size, + proposal_len, + self._vocab_size, + dtype=torch.float32, + device=self._device, + ) + proposal_lens_tensor = torch.zeros(len(proposal_lens), + dtype=torch.long, + device=self._device) + return proposal_tokens, proposal_probs, proposal_lens_tensor + + sampler_output = maybe_sampler_output + proposal_tokens, proposal_probs = sampler_output_to_torch( + sampler_output, sampler_transposed) + + # Now, reformat the output GPU tensors such that each sequence has + # a proposal. the proposal can be empty, e.g. [-1, -1, -1] + + entire_proposal_tokens = torch.full( + size=(batch_size, *proposal_tokens.shape[1:]), + fill_value=-1, + dtype=torch.long, + device=self._device, + ) + entire_proposal_tokens[nonzero_proposal_len_indices] = proposal_tokens + entire_proposal_probs = torch.zeros( + batch_size, + *proposal_probs.shape[1:], + dtype=torch.float32, + device=self._device, + ) + entire_proposal_probs[nonzero_proposal_len_indices] = proposal_probs + + proposal_tokens, proposal_probs = ( + entire_proposal_tokens, + entire_proposal_probs, + ) + + proposal_lens_tensor = torch.zeros(batch_size, + dtype=torch.long, + device=self._device) + proposal_lens_tensor[nonzero_proposal_len_indices] = proposal_len + + return proposal_tokens, proposal_probs, proposal_lens_tensor diff --git a/vllm/spec_decode/util.py b/vllm/spec_decode/util.py index eb6d4ca1da..894d2fd915 100644 --- a/vllm/spec_decode/util.py +++ b/vllm/spec_decode/util.py @@ -49,10 +49,13 @@ def split_batch_by_proposal_len( def sampler_output_to_torch( - sampler_output_list: List[SamplerOutput], -) -> Tuple[torch.Tensor, torch.Tensor]: + sampler_output_list: List[SamplerOutput], + sampler_transposed: bool) -> Tuple[torch.Tensor, torch.Tensor]: """Utility function which converts a list of SamplerOutput to tensors. + sampler_transposed here is used as the indicator for whether + we need do additional tensor transpose logic here. + Returns: sampled_token_ids: torch.Tensor shape: [batch_size, len(sampler_output_list)] @@ -68,7 +71,10 @@ def sampler_output_to_torch( for sampler_output in sampler_output_list ], dim=0, - ).transpose(0, 1) + ) + + if sampler_transposed: + sampled_token_probs = sampled_token_probs.transpose(0, 1) # shape: [batch_size, num_sampler_output] sampled_token_ids = torch.stack( @@ -77,7 +83,9 @@ def sampler_output_to_torch( for sampler_output in sampler_output_list ], dim=0, - ).transpose(0, 1) + ) + if sampler_transposed: + sampled_token_ids = sampled_token_ids.transpose(0, 1) return sampled_token_ids, sampled_token_probs
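Note for reviewers (illustration only, not part of the patch): the following self-contained sketch reproduces the core prompt-lookup match that NGramWorker.sampler_output performs for a single sequence. The helper name ngram_propose and its default window sizes are assumptions made for this example; only PyTorch is required.

import torch


def ngram_propose(token_ids, sample_len, lookup_min=0, lookup_max=3):
    """Return up to sample_len proposed tokens by matching the trailing ngram
    against earlier positions in the sequence, or None if nothing matches."""
    input_ids = torch.as_tensor(token_ids, dtype=torch.long)
    input_length = input_ids.shape[0]

    # Try the largest window first, as the worker does.
    for ngram_size in range(min(lookup_max, input_length - 1), lookup_min, -1):
        ngram_tensor = input_ids[-ngram_size:]
        # All sliding windows of length ngram_size over the sequence.
        windows = input_ids.unfold(dimension=0, size=ngram_size, step=1)
        matches = (windows == ngram_tensor).all(dim=1)
        match_indices = matches.nonzero(as_tuple=True)[0]
        # The trailing window always matches itself, so require >1 hits.
        if match_indices.size(0) > 1:
            start = int(match_indices[0]) + ngram_size
            proposal = token_ids[start:start + sample_len]
            # Pad with 0s so the proposal always has sample_len tokens.
            return proposal + [0] * (sample_len - len(proposal))
    return None


# Matches the second case in tests/spec_decode/test_ngram_worker.py: the
# trailing "11" also appears at position 0, so the proposal continues from
# there and prints [12, 13, 14, 15, 16].
print(ngram_propose([11, 12, 13, 14, 15, 16, 11], sample_len=5))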