[CI/Build] Fix pre-commit errors (#13696)

Author: Cyrus Leung
Date: 2025-02-22 16:31:26 +08:00
Committed by: GitHub
Parent: 105b8ce4c0
Commit: 7f6bae561c
6 changed files with 24 additions and 17 deletions

@@ -43,8 +43,8 @@ def main(args: argparse.Namespace):
     # the engine will automatically process the request in multiple batches.
     llm = LLM(**dataclasses.asdict(engine_args))
     assert llm.llm_engine.model_config.max_model_len >= (
-        args.input_len + args.output_len), (
-            "Please ensure that max_model_len is greater than"
-            " the sum of input_len and output_len.")
+        args.input_len +
+        args.output_len), ("Please ensure that max_model_len is greater than"
+                           " the sum of input_len and output_len.")
     sampling_params = SamplingParams(

@@ -523,7 +523,7 @@ class OpenAIServing:
             return logprob.decoded_token
         return tokenizer.decode(token_id)
 
-    def _is_model_supported(self, model_name) -> bool:
+    def _is_model_supported(self, model_name: Optional[str]) -> bool:
         if not model_name:
             return True
         return self.models.is_base_model(model_name)
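The added Optional[str] annotation matches the existing falsy check: `if not model_name` treats both None and the empty string as "no model specified". A minimal standalone sketch of the same pattern (the registry and lookup below are illustrative stand-ins, not the vLLM implementation):

    from typing import Optional

    BASE_MODELS = {"facebook/opt-125m"}  # hypothetical registry

    def is_model_supported(model_name: Optional[str]) -> bool:
        # No model requested: fall back to the default and accept it.
        if not model_name:
            return True
        # Otherwise only accept names registered as base models.
        return model_name in BASE_MODELS

    assert is_model_supported(None)
    assert is_model_supported("")
    assert not is_model_supported("unknown/model")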

@@ -358,7 +358,12 @@ class ServingScores(OpenAIServing):
                 request.truncate_prompt_tokens,
             )
             return self.request_output_to_rerank_response(
-                final_res_batch, request_id, self._get_model_name(request.model), documents, top_n)
+                final_res_batch,
+                request_id,
+                self._get_model_name(request.model),
+                documents,
+                top_n,
+            )
         except asyncio.CancelledError:
             return self.create_error_response("Client disconnected")
         except ValueError as e:

@@ -168,12 +168,9 @@ def mamba_v2_sharded_weight_loader(
             # - compute the rank into the loaded shard.
             # - if there is replication, different TP shards will
             #   take from the same rank.
-            if duplicate_groups:
-                # NOTE: currently we only support duplication
-                # in the case where num_groups == 1
-                rank = 0
-            else:
-                rank = tp_rank
+            # NOTE: currently we only support duplication
+            # in the case where num_groups == 1
+            rank = 0 if duplicate_groups else tp_rank
 
             # - leftmost boundary index into loaded weight.
             loaded_skip = rank * shard_size
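The if/else collapses to a conditional expression: when groups are duplicated across tensor-parallel ranks, every rank reads the same (rank 0) slice of the checkpoint; otherwise each rank reads its own slice. A simplified, self-contained sketch of that slicing idea (an illustration only, not the vLLM loader itself):

    import torch

    def take_shard(loaded_weight: torch.Tensor, shard_size: int,
                   tp_rank: int, duplicate_groups: bool) -> torch.Tensor:
        # Ranks sharing a duplicated group all read the rank-0 slice.
        rank = 0 if duplicate_groups else tp_rank
        # Leftmost boundary index into the loaded weight.
        loaded_skip = rank * shard_size
        return loaded_weight[loaded_skip:loaded_skip + shard_size]

    full = torch.arange(8.0)
    assert take_shard(full, 2, tp_rank=3, duplicate_groups=False).tolist() == [6.0, 7.0]
    assert take_shard(full, 2, tp_rank=3, duplicate_groups=True).tolist() == [0.0, 1.0]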

@@ -1198,10 +1198,12 @@ class FlexibleArgumentParser(argparse.ArgumentParser):
        try:
            value = int(value)
        except ValueError:
-            raise argparse.ArgumentTypeError("Port must be an integer")
+            msg = "Port must be an integer"
+            raise argparse.ArgumentTypeError(msg) from None
        if not (1024 <= value <= 65535):
-            raise argparse.ArgumentTypeError("Port must be between 1024 and 65535")
+            raise argparse.ArgumentTypeError(
+                "Port must be between 1024 and 65535")
        return value
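Two common lint fixes appear here: the message is bound to a variable before raising (a common lint preference that keeps the raise line short), and `from None` suppresses implicit exception chaining, so the user sees only the ArgumentTypeError rather than a "During handling of the above exception..." traceback from the underlying ValueError. A small standalone sketch of the same validator pattern (the function name is illustrative, not taken from vLLM):

    import argparse

    def port_type(value: str) -> int:
        try:
            port = int(value)
        except ValueError:
            msg = "Port must be an integer"
            # `from None` hides the original ValueError from the traceback.
            raise argparse.ArgumentTypeError(msg) from None
        if not (1024 <= port <= 65535):
            raise argparse.ArgumentTypeError(
                "Port must be between 1024 and 65535")
        return port

    parser = argparse.ArgumentParser()
    parser.add_argument("--port", type=port_type, default=8000)
    print(parser.parse_args(["--port", "9000"]).port)  # 9000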

@@ -1319,13 +1319,16 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                 generators={},
                 max_num_logprobs=None,
                 no_penalties=True,
-                prompt_token_ids=torch.ones_like(logits, dtype=torch.int64),
+                prompt_token_ids=torch.ones_like(logits,
+                                                 dtype=torch.int64),
                 frequency_penalties=dummy_tensors(0.1),
                 presence_penalties=dummy_tensors(0.1),
                 repetition_penalties=dummy_tensors(0.1),
                 output_token_ids=[[] for _ in range(num_reqs)],
                 min_tokens={},
-                logit_bias=[None for _ in range(num_reqs)])
+                logit_bias=[None for _ in range(num_reqs)],
+                allowed_token_ids_mask=None,
+            )
             sampler_output = self.model.sample(
                 logits=logits, sampling_metadata=dummy_metadata)
         else:
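The extra allowed_token_ids_mask=None argument suggests the sampling-metadata container gained a field that every construction site must now pass explicitly, which is the kind of error the static checks run by pre-commit can flag. A minimal sketch of why such a field cannot simply be omitted, assuming a dataclass-style container with no default values (the class and field names here are illustrative):

    from dataclasses import dataclass
    from typing import Optional
    import torch

    @dataclass
    class DummySamplingMetadata:
        no_penalties: bool
        logit_bias: list
        allowed_token_ids_mask: Optional[torch.Tensor]

    # Omitting the new field would raise:
    #   TypeError: __init__() missing 1 required positional argument:
    #   'allowed_token_ids_mask'
    meta = DummySamplingMetadata(
        no_penalties=True,
        logit_bias=[None],
        allowed_token_ids_mask=None,
    )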