mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 14:53:52 +08:00)
[Misc][Speculative decoding] Typos and typing fixes (#6467)
Co-authored-by: caishangming.csm <caishangming.csm@alibaba-inc.com>
@@ -43,7 +43,7 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
         )

     def set_include_gpu_probs_tensor(self) -> None:
-        # Need include_gpu_probs_tensor for multi_step_worker
+        # Need include_gpu_probs_tensor for MultiStepWorker
         self.model_runner.model.sampler.include_gpu_probs_tensor = True

     @torch.inference_mode()
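A note on the flag this comment documents: when set, the sampler keeps its per-step probability tensors on the GPU so the speculative pipeline can verify draft tokens without a device-to-host copy. A minimal sketch of the pattern; apart from the `include_gpu_probs_tensor` name, everything here is illustrative rather than vLLM's actual sampler:

    import torch

    class TinySampler:
        """Toy sampler that can optionally expose its GPU probs."""

        def __init__(self) -> None:
            self.include_gpu_probs_tensor = False

        def __call__(self, logits: torch.Tensor):
            probs = torch.softmax(logits, dim=-1)
            next_tokens = torch.argmax(probs, dim=-1)
            # With the flag set, also return the full probs tensor (on the
            # same device as the logits) for downstream verification.
            gpu_probs = probs if self.include_gpu_probs_tensor else None
            return next_tokens, gpu_probs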
@@ -13,7 +13,7 @@ from vllm.worker.worker_base import LoraNotSupportedWorkerBase
 class NGramWorker(NonLLMProposerWorkerBase, LoraNotSupportedWorkerBase):
     """NGramWorker provides a light drafter without need for model.

-    Current NGramWorker only implement prompt lookup decoding,
+    Current NGramWorker only implements prompt lookup decoding,
     and in future we may also do RAG type drafter and other scenarios
     which don't rely on LLM model to give proposals.
     """
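For context on the docstring's term: prompt lookup decoding drafts tokens without any model by matching the sequence's trailing n-gram against earlier occurrences in the context and proposing the tokens that followed the match. A self-contained sketch of the idea (not vLLM's implementation; the helper name is made up):

    from typing import List, Optional

    def prompt_lookup_propose(token_ids: List[int], ngram_size: int,
                              num_speculative: int) -> Optional[List[int]]:
        """Propose draft tokens via n-gram match against the context."""
        if len(token_ids) <= ngram_size:
            return None
        tail = token_ids[-ngram_size:]
        # Scan right to left so the most recent match wins.
        for start in range(len(token_ids) - ngram_size - 1, -1, -1):
            if token_ids[start:start + ngram_size] == tail:
                follow = token_ids[start + ngram_size:
                                   start + ngram_size + num_speculative]
                return follow or None
        return None

    # 'the cat sat on the' -> matches 'the', proposes 'cat sat'.
    print(prompt_lookup_propose([1, 2, 3, 4, 1], ngram_size=1,
                                num_speculative=2))  # [2, 3]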
@@ -37,7 +37,7 @@ class NGramWorker(NonLLMProposerWorkerBase, LoraNotSupportedWorkerBase):
         self.device = torch.device(f"cuda:{self.local_rank}")
         self.load_model = lambda *args, **kwargs: None

-        # Current only support Top1Proposer
+        # Current NGramWorker only supports Top1Proposer
         self._proposer = Top1Proposer(
             weakref.proxy(self),  # type: ignore[arg-type]
             device=self.device,
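A side note on the unchanged `weakref.proxy(self)` in the context lines: the worker owns the proposer while the proposer needs to call back into the worker, and handing over a proxy instead of a strong reference keeps that cycle collectable. A minimal illustration with hypothetical class names:

    import weakref

    class Proposer:
        def __init__(self, worker) -> None:
            # Store a proxy, not a strong reference, so the worker can be
            # garbage-collected despite the worker <-> proposer cycle.
            self._worker = worker

        def device(self) -> str:
            return self._worker.device  # attribute access forwards to the worker

    class Worker:
        def __init__(self) -> None:
            self.device = "cuda:0"
            self._proposer = Proposer(weakref.proxy(self))

    w = Worker()
    print(w._proposer.device())  # -> cuda:0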
@@ -24,7 +24,7 @@ class ProposerWorkerBase(LoraNotSupportedWorkerBase, SpeculativeProposer):
     ) -> Tuple[Optional[List[SamplerOutput]], bool]:
         raise NotImplementedError

-    def set_include_gpu_probs_tensor(self):
+    def set_include_gpu_probs_tensor(self) -> None:
         """Implementation optional"""
         pass

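The added `-> None` is more than cosmetic: an unannotated method is inferred as returning `Any`, so a caller that mistakenly uses the result goes unchecked. A small demonstration of what a checker such as mypy can flag once the hunk's annotation is in place (the class here is illustrative):

    class ProposerHook:
        def set_include_gpu_probs_tensor(self) -> None:
            """Implementation optional"""

        def untyped_hook(self):
            """Same no-op body, but its return type is inferred as Any."""

    hook = ProposerHook()
    flag = hook.set_include_gpu_probs_tensor()  # mypy: does not return a value
    legacy = hook.untyped_hook()                # silently typed as Any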
@@ -206,7 +206,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):

         self.probs_dtype = self.spec_decode_sampler.probs_dtype
         self.token_id_dtype = self.spec_decode_sampler.token_id_dtype
-        # Lazy initiazliation.
+        # Lazy initialization.
         self.scorer: SpeculativeScorer

         # Hidden states from target model to pass to proposer
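The line under the corrected comment, `self.scorer: SpeculativeScorer`, is a bare annotation: it declares the attribute's type in `__init__` without binding a value, leaving the actual assignment to a later initialization step. A minimal sketch of that pattern with stand-in classes:

    class Scorer:
        def score(self, value: int) -> int:
            return value * 2

    class Worker:
        def __init__(self) -> None:
            # Lazy initialization: declare the type now, bind later
            # (e.g. once the device is known). No attribute exists yet.
            self.scorer: Scorer

        def init_device(self) -> None:
            self.scorer = Scorer()

    w = Worker()
    w.init_device()  # reading w.scorer before this raises AttributeError
    print(w.scorer.score(21))  # -> 42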
@@ -138,7 +138,7 @@ class Top1Proposer(SpeculativeProposer):

         # Currently only proposal lens of 0 or the global batch proposal len
         # are supported.
-        # If max_proposal_len is defined, then we shall no exccess this
+        # If max_proposal_len is defined, then we shall no exceed this
         # quota for nonzero_proposal
         new_k = 0
         if (self.max_proposal_len is None
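The comments fixed here describe a cap on speculation: when `max_proposal_len` is set, sequences that would overrun it get a proposal length of 0 and are skipped. A hedged sketch of that shape; the helper name and exact comparison are illustrative, not vLLM's:

    from typing import List, Optional

    def split_proposal_lens(seq_lens: List[int], k: int,
                            max_proposal_len: Optional[int]) -> List[int]:
        """Per-sequence proposal lens: k while within quota, else 0."""
        proposal_lens = []
        for seq_len in seq_lens:
            if max_proposal_len is None or seq_len + k < max_proposal_len:
                proposal_lens.append(k)
            else:
                proposal_lens.append(0)  # skip speculation for this sequence
        return proposal_lens

    print(split_proposal_lens([10, 2040], k=8, max_proposal_len=2048))  # [8, 0]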
@@ -219,7 +219,7 @@ class Top1Proposer(SpeculativeProposer):
         proposal_lens: List[int],
         nonzero_proposal_len_indices: List[int],
         sampler_transposed: bool,
-    ) -> Tuple[torch.Tensor, torch.tensor, torch.Tensor]:
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """After speculations are produced, merge the speculation results with
         the skipped sequences.
         """
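The one-character fix in this hunk matters because `torch.tensor` (lowercase) is a factory function while `torch.Tensor` is the class; only the class is valid inside an annotation, and a type checker will flag the former. A quick demonstration:

    import torch

    print(type(torch.Tensor))  # a class, usable in annotations
    print(type(torch.tensor))  # a builtin function, not a type

    def make() -> torch.Tensor:         # annotate with the class...
        return torch.tensor([1, 2, 3])  # ...construct with the factory

    assert isinstance(make(), torch.Tensor)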