Mirror of https://github.com/vllm-project/vllm.git
mypy type checking for vllm/worker (#11418)
Signed-off-by: lucast2021 <lucast2021@headroyce.org>
Co-authored-by: lucast2021 <lucast2021@headroyce.org>
vllm/worker/cpu_worker.py
@@ -333,9 +333,8 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
     def prepare_worker_input(
             self, execute_model_req: ExecuteModelRequest) -> WorkerInput:
         assert execute_model_req is not None
-        virtual_engine = execute_model_req.virtual_engine
+        virtual_engine: int = execute_model_req.virtual_engine
         num_seq_groups: int = len(execute_model_req.seq_group_metadata_list)
-        blocks_to_copy = execute_model_req.blocks_to_copy
         blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy,
                                       device="cpu",
                                       dtype=torch.int64).view(-1, 2)
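The hunk above removes a redundant rebinding: assigning `blocks_to_copy` first to a list of block pairs and then to a `torch.Tensor` gives one name two incompatible types, which mypy flags. A minimal standalone sketch of the pattern (hypothetical helper, not vLLM code):

```python
# Minimal sketch (hypothetical helper, not vLLM code) of the fix above:
# rebinding one name from a list to a Tensor makes mypy infer the first
# type and reject the second assignment, so the redundant intermediate
# binding is dropped and the tensor is built directly.
from typing import List, Tuple

import torch


def blocks_to_copy_tensor(blocks: List[Tuple[int, int]]) -> torch.Tensor:
    # Build the (N, 2) copy table in one step; `blocks` keeps a single
    # type for its whole lifetime, which mypy can verify.
    return torch.tensor(blocks, device="cpu", dtype=torch.int64).view(-1, 2)


print(blocks_to_copy_tensor([(0, 1), (2, 3)]))
```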
vllm/worker/multi_step_model_runner.py
@@ -406,8 +406,9 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
             if not cont:
                 break
 
-    def _final_process_outputs(self, model_input: StatefulModelInput,
-                               output_proc_callback: Optional[Callable]):
+    def _final_process_outputs(
+            self, model_input: StatefulModelInput,
+            output_proc_callback: Optional[Callable]) -> List[SamplerOutput]:
         assert model_input.frozen_model_input is not None
 
         has_async_callback = output_proc_callback is not None
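This hunk adds a `-> List[SamplerOutput]` return annotation. Without one, mypy types the result as `Any` and stops checking anything derived from it. A minimal sketch with hypothetical names:

```python
# Minimal sketch (hypothetical names): an unannotated `def` returns Any,
# which silently disables checking of everything computed from the
# result; declaring the return type restores checking at call sites.
from typing import Callable, List, Optional


def final_process_outputs(
        outputs: List[str],
        output_proc_callback: Optional[Callable[[], None]]) -> List[str]:
    # mypy now verifies both this return value and how callers use it.
    if output_proc_callback is not None:
        output_proc_callback()
    return outputs


print(final_process_outputs(["a", "b"], None))
```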
vllm/worker/multi_step_model_runner.py
@@ -594,8 +595,8 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
         # should be [SamplerOutput]
         return output
 
-    def _update_sampling_metadata(self, sampling_metadata, num_seqs,
-                                  num_queries):
+    def _update_sampling_metadata(self, sampling_metadata: SamplingMetadata,
+                                  num_seqs: Optional[int], num_queries: int):
 
         assert sampling_metadata.num_prompts == 0
         assert len(sampling_metadata.seq_groups) == num_queries
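Here the previously untyped parameters gain annotations, including `Optional[int]` for `num_seqs`, which may be `None`. Once annotated, mypy requires the `None` case to be narrowed away before the value is used as an `int`. A minimal sketch, assuming simplified names:

```python
# Minimal sketch (simplified names): annotating a parameter as
# Optional[int] tells mypy the value may be None, so arithmetic on it
# is rejected until the None case is ruled out.
from typing import Optional


def update_counts(num_seqs: Optional[int], num_queries: int) -> int:
    if num_seqs is None:
        # Narrow Optional[int] to int before using it.
        num_seqs = num_queries
    return num_seqs + num_queries


print(update_counts(None, 4))  # -> 8
```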
vllm/worker/multi_step_model_runner.py
@@ -850,13 +851,13 @@ def _pythonize_sampler_output(
         seq_ids = seq_group.seq_ids
         next_token_ids = sample_result
         parent_ids = [0]
+        seq_outputs: List[SequenceOutput]
 
         if cache is not None:
             completion_seq_group_output: CompletionSequenceGroupOutput = \
                 cache.cached_completion_seq_group_output.get_object()
             completion_seq_group_output.samples.clear()
-            seq_outputs: List[
-                SequenceOutput] = completion_seq_group_output.samples
+            seq_outputs = completion_seq_group_output.samples
         else:
             seq_outputs = []
 
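This hunk hoists the variable annotation out of the `if` branch: declaring `seq_outputs: List[SequenceOutput]` once, before the branches, lets mypy type both the cached assignment and the empty-list fallback against the same declaration. A standalone sketch of the idiom:

```python
# Standalone sketch of the hoisted-annotation idiom: annotating a
# variable inside only one branch leaves the other branch's assignment
# (an empty list here) without a declared type; one up-front declaration
# covers every assignment that follows.
from typing import List


def collect(use_cache: bool, cached: List[int]) -> List[int]:
    items: List[int]  # declared once, before the branches
    if use_cache:
        items = cached
    else:
        items = []
    return items


print(collect(True, [1, 2]), collect(False, [1, 2]))
```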
vllm/worker/worker_base.py
@@ -452,7 +452,7 @@ class WorkerWrapperBase:
         self.worker = worker_class(*args, **kwargs)
         assert self.worker is not None
 
-    def execute_method(self, method, *args, **kwargs):
+    def execute_method(self, method: str, *args, **kwargs):
         try:
             target = self if self.worker is None else self.worker
             executor = getattr(target, method)
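The final hunk annotates the dispatch key as `method: str`. The attribute resolved by `getattr` remains dynamically typed, but the key itself is now checked at call sites. A minimal sketch with a hypothetical class:

```python
# Minimal sketch (hypothetical class) of typing a getattr dispatcher:
# the `method: str` annotation checks the dispatch key, while the
# resolved attribute itself stays dynamically typed (Any).
from typing import Any


class WorkerWrapper:
    def greet(self, name: str) -> str:
        return f"hello {name}"

    def execute_method(self, method: str, *args: Any, **kwargs: Any) -> Any:
        executor = getattr(self, method)
        return executor(*args, **kwargs)


print(WorkerWrapper().execute_method("greet", "world"))
```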