[CI/Build] Fix pre-commit errors (#13696)

Author: Cyrus Leung
Date: 2025-02-22 16:31:26 +08:00 (committed by GitHub)
parent 105b8ce4c0
commit 7f6bae561c
6 changed files with 24 additions and 17 deletions

View File

@@ -43,9 +43,9 @@ def main(args: argparse.Namespace):
     # the engine will automatically process the request in multiple batches.
     llm = LLM(**dataclasses.asdict(engine_args))
     assert llm.llm_engine.model_config.max_model_len >= (
-        args.input_len + args.output_len), (
-            "Please ensure that max_model_len is greater than"
-            " the sum of input_len and output_len.")
+        args.input_len +
+        args.output_len), ("Please ensure that max_model_len is greater than"
+                           " the sum of input_len and output_len.")
     sampling_params = SamplingParams(
         n=args.n,
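The assertion being rewrapped here only checks that the configured context window can hold the prompt plus the requested completion. A minimal standalone sketch of that check, with made-up values in place of the parsed args:

```python
# Sketch of the length check from the diff above; input_len, output_len
# and max_model_len mirror the names in the benchmark script, but the
# concrete values below are illustrative only.
def check_fits(max_model_len: int, input_len: int, output_len: int) -> None:
    assert max_model_len >= input_len + output_len, (
        "Please ensure that max_model_len is greater than"
        " the sum of input_len and output_len.")


check_fits(max_model_len=4096, input_len=1024, output_len=512)  # passes
```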

View File

@@ -523,7 +523,7 @@ class OpenAIServing:
             return logprob.decoded_token
         return tokenizer.decode(token_id)
 
-    def _is_model_supported(self, model_name) -> bool:
+    def _is_model_supported(self, model_name: Optional[str]) -> bool:
         if not model_name:
             return True
         return self.models.is_base_model(model_name)
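The only change in this hunk is the `Optional[str]` annotation: `model_name` may be `None` or empty, in which case the check short-circuits to `True`. A standalone sketch of the same pattern, with a plain set standing in for the serving engine's model registry (`self.models` in the diff):

```python
from typing import Optional

# Stand-in for the served-model registry (self.models in the diff).
BASE_MODELS = {"example/base-model"}


def is_model_supported(model_name: Optional[str]) -> bool:
    # A missing or empty model name is treated as "supported" and falls
    # back to the default model.
    if not model_name:
        return True
    return model_name in BASE_MODELS


assert is_model_supported(None)
assert is_model_supported("example/base-model")
assert not is_model_supported("some-other-model")
```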

View File

@@ -358,7 +358,12 @@ class ServingScores(OpenAIServing):
                 request.truncate_prompt_tokens,
             )
             return self.request_output_to_rerank_response(
-                final_res_batch, request_id, self._get_model_name(request.model), documents, top_n)
+                final_res_batch,
+                request_id,
+                self._get_model_name(request.model),
+                documents,
+                top_n,
+            )
         except asyncio.CancelledError:
             return self.create_error_response("Client disconnected")
         except ValueError as e:
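The reformat above splits a long call into one argument per line with a trailing comma, keeping each line under the formatter's length limit and turning future argument additions into single-line diffs. A hypothetical call written in the same style (the function and its arguments are made up for illustration):

```python
# build_rerank_response is a made-up stand-in used only to show the
# one-argument-per-line call style with a trailing comma.
def build_rerank_response(results, request_id, model_name, documents, top_n):
    return {
        "id": request_id,
        "model": model_name,
        "results": list(zip(documents, results))[:top_n],
    }


response = build_rerank_response(
    [0.91, 0.12],
    "rerank-123",
    "example-model",
    ["doc-a", "doc-b"],
    1,
)
```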

View File

@@ -134,7 +134,7 @@ def extra_groups_for_head_shards(ngroups: int, tp_size: int):
         return 0
 
     # for n_groups == 1, this is exactly tp_size - n_groups
     return tp_size - ngroups
 
 
 def mamba_v2_sharded_weight_loader(
@@ -168,12 +168,9 @@ def mamba_v2_sharded_weight_loader(
         # - compute the rank into the loaded shard.
         # - if there is replication, different TP shards will
         #   take from the same rank.
-        if duplicate_groups:
-            # NOTE: currently we only support duplication
-            # in the case where num_groups == 1
-            rank = 0
-        else:
-            rank = tp_rank
+        # NOTE: currently we only support duplication
+        # in the case where num_groups == 1
+        rank = 0 if duplicate_groups else tp_rank
 
         # - leftmost boundary index into loaded weight.
         loaded_skip = rank * shard_size
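Collapsing the if/else into a conditional expression is the kind of simplification flagged by linters such as ruff's SIM108 rule (assuming that check is part of the pre-commit config here); the two forms are equivalent:

```python
# Before: a multi-line if/else that only chooses between two values.
def pick_rank_verbose(duplicate_groups: bool, tp_rank: int) -> int:
    if duplicate_groups:
        rank = 0
    else:
        rank = tp_rank
    return rank


# After: the equivalent conditional expression used in the diff.
def pick_rank(duplicate_groups: bool, tp_rank: int) -> int:
    return 0 if duplicate_groups else tp_rank


assert pick_rank(True, 3) == pick_rank_verbose(True, 3) == 0
assert pick_rank(False, 3) == pick_rank_verbose(False, 3) == 3
```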
@@ -247,7 +244,7 @@ class MambaMixer2(CustomOp):
         assert num_heads % self.tp_size == 0, \
             "Tensor parallel world size must divide num heads."
 
         assert (n_groups % self.tp_size) == 0 or n_groups == 1, \
             (
                 "If tensor parallel world size does not divide num_heads, "

View File

@@ -1198,10 +1198,12 @@ class FlexibleArgumentParser(argparse.ArgumentParser):
             try:
                 value = int(value)
             except ValueError:
-                raise argparse.ArgumentTypeError("Port must be an integer")
+                msg = "Port must be an integer"
+                raise argparse.ArgumentTypeError(msg) from None
 
             if not (1024 <= value <= 65535):
-                raise argparse.ArgumentTypeError("Port must be between 1024 and 65535")
+                raise argparse.ArgumentTypeError(
+                    "Port must be between 1024 and 65535")
             return value
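`raise ... from None` suppresses implicit exception chaining, so the user sees only the argparse error rather than a "During handling of the above exception, another exception occurred" traceback that also prints the underlying ValueError from `int()`. A self-contained sketch of the idiom used above:

```python
import argparse


def parse_port(value: str) -> int:
    try:
        port = int(value)
    except ValueError:
        msg = "Port must be an integer"
        # "from None" hides the original ValueError from the traceback.
        raise argparse.ArgumentTypeError(msg) from None
    if not (1024 <= port <= 65535):
        raise argparse.ArgumentTypeError(
            "Port must be between 1024 and 65535")
    return port


print(parse_port("8000"))  # 8000
```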

View File

@@ -1319,13 +1319,16 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                     generators={},
                     max_num_logprobs=None,
                     no_penalties=True,
-                    prompt_token_ids=torch.ones_like(logits, dtype=torch.int64),
+                    prompt_token_ids=torch.ones_like(logits,
+                                                     dtype=torch.int64),
                     frequency_penalties=dummy_tensors(0.1),
                     presence_penalties=dummy_tensors(0.1),
                     repetition_penalties=dummy_tensors(0.1),
                     output_token_ids=[[] for _ in range(num_reqs)],
                     min_tokens={},
-                    logit_bias=[None for _ in range(num_reqs)])
+                    logit_bias=[None for _ in range(num_reqs)],
+                    allowed_token_ids_mask=None,
+                )
                 sampler_output = self.model.sample(
                     logits=logits, sampling_metadata=dummy_metadata)
             else:
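Besides the line wrapping, the substantive edits in this hunk are the explicit `dtype` on `torch.ones_like` and the added `allowed_token_ids_mask=None` argument. A quick illustration of the `ones_like` call (assuming PyTorch is available; the logits tensor is a dummy):

```python
import torch

# Dummy float tensor standing in for the logits in the diff.
logits = torch.zeros(2, 5)

# ones_like copies the shape, while the explicit dtype overrides the
# source tensor's float dtype, yielding int64 placeholder token ids.
prompt_token_ids = torch.ones_like(logits, dtype=torch.int64)

assert prompt_token_ids.shape == logits.shape
assert prompt_token_ids.dtype == torch.int64
```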