[CI/Build] Fix pre-commit errors (#13696)
@@ -43,9 +43,9 @@ def main(args: argparse.Namespace):
     # the engine will automatically process the request in multiple batches.
     llm = LLM(**dataclasses.asdict(engine_args))
     assert llm.llm_engine.model_config.max_model_len >= (
-        args.input_len + args.output_len), (
-            "Please ensure that max_model_len is greater than"
-            " the sum of input_len and output_len.")
+        args.input_len +
+        args.output_len), ("Please ensure that max_model_len is greater than"
+                           " the sum of input_len and output_len.")
 
     sampling_params = SamplingParams(
         n=args.n,
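This hunk only re-wraps the existing assertion so it fits the formatter's line-length rules; the logic is unchanged. For reference, a stripped-down sketch of the same guard, with an illustrative function name and bare parameters rather than the benchmark's actual code:

def check_context_fits(max_model_len: int, input_len: int,
                       output_len: int) -> None:
    # The engine must hold the prompt plus every generated token in one
    # sequence, so the context window has to cover input_len + output_len.
    assert max_model_len >= (
        input_len +
        output_len), ("Please ensure that max_model_len is greater than"
                      " the sum of input_len and output_len.")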
@@ -523,7 +523,7 @@ class OpenAIServing:
             return logprob.decoded_token
         return tokenizer.decode(token_id)
 
-    def _is_model_supported(self, model_name) -> bool:
+    def _is_model_supported(self, model_name: Optional[str]) -> bool:
         if not model_name:
             return True
         return self.models.is_base_model(model_name)
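The only change here is the parameter annotation: callers may pass None for the model name, and the falsy check already handles that, so Optional[str] is the accurate type and keeps type checkers such as mypy happy. A self-contained sketch of the pattern, with a hypothetical base_models set standing in for self.models.is_base_model:

from typing import Optional

def is_model_supported(model_name: Optional[str],
                       base_models: set[str]) -> bool:
    # None or an empty string means no explicit model was requested,
    # which is treated as supported.
    if not model_name:
        return True
    return model_name in base_models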
@@ -358,7 +358,12 @@ class ServingScores(OpenAIServing):
                 request.truncate_prompt_tokens,
             )
             return self.request_output_to_rerank_response(
-                final_res_batch, request_id, self._get_model_name(request.model), documents, top_n)
+                final_res_batch,
+                request_id,
+                self._get_model_name(request.model),
+                documents,
+                top_n,
+            )
         except asyncio.CancelledError:
             return self.create_error_response("Client disconnected")
         except ValueError as e:
@@ -134,7 +134,7 @@ def extra_groups_for_head_shards(ngroups: int, tp_size: int):
         return 0
 
     # for n_groups == 1, this is exactly tp_size - n_groups
     return tp_size - ngroups
 
 
 def mamba_v2_sharded_weight_loader(
@@ -168,12 +168,9 @@ def mamba_v2_sharded_weight_loader(
             # - compute the rank into the loaded shard.
             # - if there is replication, different TP shards will
             #   take from the same rank.
-            if duplicate_groups:
-                # NOTE: currently we only support duplication
-                # in the case where num_groups == 1
-                rank = 0
-            else:
-                rank = tp_rank
+            # NOTE: currently we only support duplication
+            # in the case where num_groups == 1
+            rank = 0 if duplicate_groups else tp_rank
 
             # - leftmost boundary index into loaded weight.
             loaded_skip = rank * shard_size
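Here the multi-line if/else is collapsed into a single conditional expression, the kind of simplification that linters such as ruff's flake8-simplify rules flag; behaviour is identical. A minimal before/after illustration, using stand-in values for the loader's locals:

# Stand-ins for the loader's locals, for illustration only.
duplicate_groups = True
tp_rank = 3

# Before: a multi-line if/else just to pick a value for `rank`.
if duplicate_groups:
    rank = 0
else:
    rank = tp_rank

# After: the equivalent one-line conditional expression.
rank = 0 if duplicate_groups else tp_rank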
@@ -247,7 +244,7 @@ class MambaMixer2(CustomOp):
         assert num_heads % self.tp_size == 0, \
             "Tensor parallel world size must divide num heads."
 
 
         assert (n_groups % self.tp_size) == 0 or n_groups == 1, \
             (
                 "If tensor parallel world size does not divide num_heads, "
@@ -1198,10 +1198,12 @@ class FlexibleArgumentParser(argparse.ArgumentParser):
         try:
             value = int(value)
         except ValueError:
-            raise argparse.ArgumentTypeError("Port must be an integer")
+            msg = "Port must be an integer"
+            raise argparse.ArgumentTypeError(msg) from None
 
         if not (1024 <= value <= 65535):
-            raise argparse.ArgumentTypeError("Port must be between 1024 and 65535")
+            raise argparse.ArgumentTypeError(
+                "Port must be between 1024 and 65535")
 
         return value
 
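Two separate pre-commit complaints are addressed in this hunk: the over-long ArgumentTypeError lines are wrapped, and the re-raise inside the except block now uses `from None`, which suppresses the implicit "During handling of the above exception, another exception occurred" chaining so users only see the friendly message (the pattern checks like flake8-bugbear's B904 ask for). A self-contained sketch of the resulting validator, with an illustrative function name:

import argparse

def parse_port(value: str) -> int:
    # Reject non-numeric input with a clean error, hiding the original
    # ValueError traceback via `from None`.
    try:
        port = int(value)
    except ValueError:
        msg = "Port must be an integer"
        raise argparse.ArgumentTypeError(msg) from None

    # Only allow non-privileged, valid TCP ports.
    if not (1024 <= port <= 65535):
        raise argparse.ArgumentTypeError(
            "Port must be between 1024 and 65535")
    return port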
@@ -1319,13 +1319,16 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                     generators={},
                     max_num_logprobs=None,
                     no_penalties=True,
-                    prompt_token_ids=torch.ones_like(logits, dtype=torch.int64),
+                    prompt_token_ids=torch.ones_like(logits,
+                                                     dtype=torch.int64),
                     frequency_penalties=dummy_tensors(0.1),
                     presence_penalties=dummy_tensors(0.1),
                     repetition_penalties=dummy_tensors(0.1),
                     output_token_ids=[[] for _ in range(num_reqs)],
                     min_tokens={},
-                    logit_bias=[None for _ in range(num_reqs)])
+                    logit_bias=[None for _ in range(num_reqs)],
+                    allowed_token_ids_mask=None,
+                )
                 sampler_output = self.model.sample(
                     logits=logits, sampling_metadata=dummy_metadata)
             else:
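Besides splitting the torch.ones_like call for line length, the dummy sampling metadata now names allowed_token_ids_mask=None explicitly and closes the call on its own line, presumably because the metadata type gained that field without a default, which forces every construction site to pass it. A generic illustration with a hypothetical dataclass (not vLLM's actual SamplingMetadata):

from dataclasses import dataclass
from typing import Optional

@dataclass
class DummySamplingConfig:
    max_num_logprobs: Optional[int]
    no_penalties: bool
    # Newly added field with no default: callers must pass it explicitly.
    allowed_token_ids_mask: Optional[list[bool]]

cfg = DummySamplingConfig(
    max_num_logprobs=None,
    no_penalties=True,
    allowed_token_ids_mask=None,
)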