mypy type checking for vllm/worker (#11418)

Signed-off-by: lucast2021 <lucast2021@headroyce.org>
Co-authored-by: lucast2021 <lucast2021@headroyce.org>
Authored by Lucas Tucker on 2024-12-23 07:55:49 -06:00; committed by GitHub
parent f30581c518
commit e51719ae72
3 changed files with 9 additions and 9 deletions

vllm/worker/cpu_worker.py

@@ -333,9 +333,8 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
     def prepare_worker_input(
             self, execute_model_req: ExecuteModelRequest) -> WorkerInput:
         assert execute_model_req is not None
-        virtual_engine = execute_model_req.virtual_engine
+        virtual_engine: int = execute_model_req.virtual_engine
         num_seq_groups: int = len(execute_model_req.seq_group_metadata_list)
-        blocks_to_copy = execute_model_req.blocks_to_copy
         blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy,
                                       device="cpu",
                                       dtype=torch.int64).view(-1, 2)
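The first hunk works around mypy's rule that a name keeps the type inferred from its first assignment: the intermediate blocks_to_copy list binding is dropped so the name is only ever a torch.Tensor, and virtual_engine gets an explicit int annotation. A minimal sketch of the error the removed line would trigger, using hypothetical names rather than the vllm code:

# Sketch only (hypothetical names, not vllm code): mypy pins a variable to the
# type of its first assignment, so re-binding it to a Tensor is rejected.
from typing import List, Tuple

import torch


def as_copy_tensor(raw_blocks: List[Tuple[int, int]]) -> torch.Tensor:
    blocks_to_copy = raw_blocks  # inferred as List[Tuple[int, int]]
    # mypy: Incompatible types in assignment (expression has type "Tensor",
    # variable has type "List[Tuple[int, int]]")
    blocks_to_copy = torch.tensor(raw_blocks, dtype=torch.int64).view(-1, 2)
    return blocks_to_copy


def as_copy_tensor_fixed(raw_blocks: List[Tuple[int, int]]) -> torch.Tensor:
    # Build the tensor directly so the name has exactly one type.
    return torch.tensor(raw_blocks, dtype=torch.int64).view(-1, 2)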

vllm/worker/multi_step_model_runner.py

@@ -406,8 +406,9 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
             if not cont:
                 break

-    def _final_process_outputs(self, model_input: StatefulModelInput,
-                                output_proc_callback: Optional[Callable]):
+    def _final_process_outputs(
+            self, model_input: StatefulModelInput,
+            output_proc_callback: Optional[Callable]) -> List[SamplerOutput]:
         assert model_input.frozen_model_input is not None

         has_async_callback = output_proc_callback is not None
@@ -594,8 +595,8 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
         # should be [SamplerOutput]
         return output

-    def _update_sampling_metadata(self, sampling_metadata, num_seqs,
-                                   num_queries):
+    def _update_sampling_metadata(self, sampling_metadata: SamplingMetadata,
+                                   num_seqs: Optional[int], num_queries: int):

         assert sampling_metadata.num_prompts == 0
         assert len(sampling_metadata.seq_groups) == num_queries
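The two signature changes above (_final_process_outputs and _update_sampling_metadata) follow the same idea: once the parameters and return value are annotated, mypy type-checks the body and every call site instead of treating the function as untyped. A short illustration with simplified placeholder types, not the actual vllm signatures:

# Sketch only (simplified placeholder types, not the vllm signatures): explicit
# parameter and return annotations let mypy flag misuse at call sites.
from typing import Callable, List, Optional


def final_process_outputs(
        outputs: List[str],
        output_proc_callback: Optional[Callable[[List[str]], None]]
) -> List[str]:
    if output_proc_callback is not None:
        output_proc_callback(outputs)
    return outputs


# mypy: Argument 1 to "final_process_outputs" has incompatible type
# "List[int]"; expected "List[str]"
final_process_outputs([1, 2, 3], None)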
@@ -850,13 +851,13 @@ def _pythonize_sampler_output(
         seq_ids = seq_group.seq_ids
         next_token_ids = sample_result
         parent_ids = [0]
+        seq_outputs: List[SequenceOutput]

         if cache is not None:
             completion_seq_group_output: CompletionSequenceGroupOutput = \
                 cache.cached_completion_seq_group_output.get_object()
             completion_seq_group_output.samples.clear()
-            seq_outputs: List[
-                SequenceOutput] = completion_seq_group_output.samples
+            seq_outputs = completion_seq_group_output.samples
         else:
             seq_outputs = []

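The seq_outputs change hoists the annotation out of the if branch: declaring the variable's type once, before the conditional, gives both assignments the same declared type and removes the annotation that was split across a continuation line. The shape of the pattern, with made-up names:

# Sketch only (made-up names, not vllm code): a bare annotation declares the
# type up front, and each branch then assigns a compatible value.
from typing import List


def pick_outputs(use_cache: bool, cached: List[str]) -> List[str]:
    outputs: List[str]  # declaration only, no value yet
    if use_cache:
        outputs = cached
    else:
        outputs = []
    return outputs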

vllm/worker/worker_base.py

@@ -452,7 +452,7 @@ class WorkerWrapperBase:
         self.worker = worker_class(*args, **kwargs)
         assert self.worker is not None

-    def execute_method(self, method, *args, **kwargs):
+    def execute_method(self, method: str, *args, **kwargs):
         try:
             target = self if self.worker is None else self.worker
             executor = getattr(target, method)
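The worker_base change only records that execute_method takes the method name as a string before handing it to getattr. A reduced stand-in for that dispatch pattern, not the real WorkerWrapperBase:

# Sketch only (reduced stand-in, not the real WorkerWrapperBase): the method
# name is a plain string looked up with getattr on the wrapped object.
from typing import Optional


class Worker:

    def ping(self, msg: str) -> str:
        return f"pong: {msg}"


class Wrapper:

    def __init__(self, worker: Optional[Worker] = None) -> None:
        self.worker = worker

    def execute_method(self, method: str, *args, **kwargs):
        # Dispatch on the wrapped worker when present, otherwise on self.
        target = self if self.worker is None else self.worker
        executor = getattr(target, method)
        return executor(*args, **kwargs)


print(Wrapper(Worker()).execute_method("ping", "hello"))  # pong: hello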