mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[CI/Build] Fix pre-commit errors (#13696)
This commit is contained in:
@ -43,9 +43,9 @@ def main(args: argparse.Namespace):
|
||||
# the engine will automatically process the request in multiple batches.
|
||||
llm = LLM(**dataclasses.asdict(engine_args))
|
||||
assert llm.llm_engine.model_config.max_model_len >= (
|
||||
args.input_len + args.output_len), (
|
||||
"Please ensure that max_model_len is greater than"
|
||||
" the sum of input_len and output_len.")
|
||||
args.input_len +
|
||||
args.output_len), ("Please ensure that max_model_len is greater than"
|
||||
" the sum of input_len and output_len.")
|
||||
|
||||
sampling_params = SamplingParams(
|
||||
n=args.n,
|
||||
|
@ -523,7 +523,7 @@ class OpenAIServing:
|
||||
return logprob.decoded_token
|
||||
return tokenizer.decode(token_id)
|
||||
|
||||
def _is_model_supported(self, model_name) -> bool:
|
||||
def _is_model_supported(self, model_name: Optional[str]) -> bool:
|
||||
if not model_name:
|
||||
return True
|
||||
return self.models.is_base_model(model_name)
|
||||
|
@ -358,7 +358,12 @@ class ServingScores(OpenAIServing):
|
||||
request.truncate_prompt_tokens,
|
||||
)
|
||||
return self.request_output_to_rerank_response(
|
||||
final_res_batch, request_id, self._get_model_name(request.model), documents, top_n)
|
||||
final_res_batch,
|
||||
request_id,
|
||||
self._get_model_name(request.model),
|
||||
documents,
|
||||
top_n,
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
return self.create_error_response("Client disconnected")
|
||||
except ValueError as e:
|
||||
|
@ -134,7 +134,7 @@ def extra_groups_for_head_shards(ngroups: int, tp_size: int):
|
||||
return 0
|
||||
|
||||
# for n_groups == 1, this is exactly tp_size - n_groups
|
||||
return tp_size - ngroups
|
||||
return tp_size - ngroups
|
||||
|
||||
|
||||
def mamba_v2_sharded_weight_loader(
|
||||
@ -168,12 +168,9 @@ def mamba_v2_sharded_weight_loader(
|
||||
# - compute the rank into the loaded shard.
|
||||
# - if there is replication, different TP shards will
|
||||
# take from the same rank.
|
||||
if duplicate_groups:
|
||||
# NOTE: currently we only support duplication
|
||||
# in the case where num_groups == 1
|
||||
rank = 0
|
||||
else:
|
||||
rank = tp_rank
|
||||
# NOTE: currently we only support duplication
|
||||
# in the case where num_groups == 1
|
||||
rank = 0 if duplicate_groups else tp_rank
|
||||
|
||||
# - leftmost boundary index into loaded weight.
|
||||
loaded_skip = rank * shard_size
|
||||
@ -247,7 +244,7 @@ class MambaMixer2(CustomOp):
|
||||
assert num_heads % self.tp_size == 0, \
|
||||
"Tensor parallel world size must divide num heads."
|
||||
|
||||
|
||||
|
||||
assert (n_groups % self.tp_size) == 0 or n_groups == 1, \
|
||||
(
|
||||
"If tensor parallel world size does not divide num_heads, "
|
||||
|
@ -1198,10 +1198,12 @@ class FlexibleArgumentParser(argparse.ArgumentParser):
|
||||
try:
|
||||
value = int(value)
|
||||
except ValueError:
|
||||
raise argparse.ArgumentTypeError("Port must be an integer")
|
||||
msg = "Port must be an integer"
|
||||
raise argparse.ArgumentTypeError(msg) from None
|
||||
|
||||
if not (1024 <= value <= 65535):
|
||||
raise argparse.ArgumentTypeError("Port must be between 1024 and 65535")
|
||||
raise argparse.ArgumentTypeError(
|
||||
"Port must be between 1024 and 65535")
|
||||
|
||||
return value
|
||||
|
||||
|
@ -1319,13 +1319,16 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
generators={},
|
||||
max_num_logprobs=None,
|
||||
no_penalties=True,
|
||||
prompt_token_ids=torch.ones_like(logits, dtype=torch.int64),
|
||||
prompt_token_ids=torch.ones_like(logits,
|
||||
dtype=torch.int64),
|
||||
frequency_penalties=dummy_tensors(0.1),
|
||||
presence_penalties=dummy_tensors(0.1),
|
||||
repetition_penalties=dummy_tensors(0.1),
|
||||
output_token_ids=[[] for _ in range(num_reqs)],
|
||||
min_tokens={},
|
||||
logit_bias=[None for _ in range(num_reqs)])
|
||||
logit_bias=[None for _ in range(num_reqs)],
|
||||
allowed_token_ids_mask=None,
|
||||
)
|
||||
sampler_output = self.model.sample(
|
||||
logits=logits, sampling_metadata=dummy_metadata)
|
||||
else:
|
||||
|
Reference in New Issue
Block a user