mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[Bug fix][Core] fixup ngram not setup correctly (#4551)
Co-authored-by: Lei Wen <wenlei03@qiyi.com> Co-authored-by: Cade Daniel <edacih@gmail.com> Co-authored-by: Cody Yu <hao.yu.cody@gmail.com>
This commit is contained in:
@ -55,7 +55,7 @@ class AsyncLLM:
|
||||
) -> None:
|
||||
if "disable_log_stats" not in kwargs:
|
||||
kwargs["disable_log_stats"] = True
|
||||
self.engine_args = AsyncEngineArgs(
|
||||
engine_args = AsyncEngineArgs(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
tokenizer_mode=tokenizer_mode,
|
||||
@ -76,6 +76,8 @@ class AsyncLLM:
|
||||
**kwargs,
|
||||
)
|
||||
self.request_counter = Counter()
|
||||
self.llm_engine = AsyncLLMEngine.from_engine_args(
|
||||
engine_args, usage_context=UsageContext.LLM_CLASS)
|
||||
|
||||
def generate(
|
||||
self,
|
||||
@ -88,9 +90,6 @@ class AsyncLLM:
|
||||
multi_modal_data: Optional[MultiModalData] = None,
|
||||
) -> List[RequestOutput]:
|
||||
|
||||
llm_engine = AsyncLLMEngine.from_engine_args(
|
||||
self.engine_args, usage_context=UsageContext.LLM_CLASS)
|
||||
|
||||
if prompts is None:
|
||||
raise ValueError("prompts must be provided.")
|
||||
if isinstance(prompts, str):
|
||||
@ -111,8 +110,8 @@ class AsyncLLM:
|
||||
|
||||
async def get_output(prompt, sampling_param) -> str:
|
||||
request_id = random_uuid()
|
||||
results_generator = llm_engine.generate(prompt, sampling_param,
|
||||
request_id)
|
||||
results_generator = self.llm_engine.generate(
|
||||
prompt, sampling_param, request_id)
|
||||
final_output = None
|
||||
async for request_output in results_generator:
|
||||
final_output = request_output
|
||||
@ -185,12 +184,25 @@ def create_llm_generator(baseline_or_test, request, common_llm_kwargs,
|
||||
return generator_outer
|
||||
|
||||
|
||||
def maybe_assert_ngram_worker(llm):
|
||||
# Verify the proposer worker is ngram if ngram is specified.
|
||||
if (not isinstance(llm, AsyncLLM)
|
||||
and llm.llm_engine.speculative_config is not None
|
||||
and llm.llm_engine.speculative_config.ngram_prompt_lookup_max > 0):
|
||||
from vllm.spec_decode.ngram_worker import NGramWorker
|
||||
assert isinstance(
|
||||
llm.llm_engine.model_executor.driver_worker.proposer_worker,
|
||||
NGramWorker)
|
||||
|
||||
|
||||
def get_output_from_llm_generator(
|
||||
llm_generator, prompts,
|
||||
sampling_params) -> Tuple[List[str], List[List[int]]]:
|
||||
tokens = []
|
||||
token_ids = []
|
||||
for llm in llm_generator():
|
||||
maybe_assert_ngram_worker(llm)
|
||||
|
||||
outputs = llm.generate(prompts, sampling_params, use_tqdm=True)
|
||||
token_ids = [output.outputs[0].token_ids for output in outputs]
|
||||
tokens = [output.outputs[0].text for output in outputs]
|
||||
|
@ -82,6 +82,10 @@ class GPUExecutor(ExecutorBase):
|
||||
draft_worker_kwargs.update(
|
||||
model_config=self.speculative_config.draft_model_config,
|
||||
parallel_config=self.speculative_config.draft_parallel_config,
|
||||
ngram_prompt_lookup_max=self.speculative_config.
|
||||
ngram_prompt_lookup_max,
|
||||
ngram_prompt_lookup_min=self.speculative_config.
|
||||
ngram_prompt_lookup_min,
|
||||
# TODO allow draft-model specific load config.
|
||||
#load_config=self.load_config,
|
||||
)
|
||||
|
@ -57,13 +57,10 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
||||
draft_worker_kwargs,
|
||||
) -> "SpecDecodeWorker":
|
||||
|
||||
if "ngram_prompt_lookup_max" in draft_worker_kwargs:
|
||||
ngram_prompt_lookup_max = (
|
||||
draft_worker_kwargs.pop("ngram_prompt_lookup_max"))
|
||||
ngram_prompt_lookup_min = (
|
||||
draft_worker_kwargs.pop("ngram_prompt_lookup_min"))
|
||||
else:
|
||||
ngram_prompt_lookup_max = 0
|
||||
ngram_prompt_lookup_max = (
|
||||
draft_worker_kwargs.pop("ngram_prompt_lookup_max"))
|
||||
ngram_prompt_lookup_min = (
|
||||
draft_worker_kwargs.pop("ngram_prompt_lookup_min"))
|
||||
|
||||
if ngram_prompt_lookup_max > 0:
|
||||
proposer_worker = NGramWorker(**draft_worker_kwargs)
|
||||
@ -72,6 +69,9 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
||||
else:
|
||||
proposer_worker = MultiStepWorker(**draft_worker_kwargs)
|
||||
|
||||
logger.info("Configuring SpecDecodeWorker with proposer=%s",
|
||||
type(proposer_worker))
|
||||
|
||||
return SpecDecodeWorker(
|
||||
proposer_worker,
|
||||
scorer_worker,
|
||||
|
Reference in New Issue
Block a user