mirror of
https://github.com/vllm-project/vllm-ascend.git
synced 2025-10-20 21:53:54 +08:00
### What this PR does / why we need it?

Adds v1 pooling task support while changing as little existing code as possible. Note that `vllm.v1.worker.gpu_input_batch` is copied down into vllm-ascend: the upstream interface changes frequently, so keeping a local copy decouples us from those changes.

### How was this patch tested?

CI passed with the newly added and existing tests. A simple test, adapted from https://www.modelscope.cn/models/Qwen/Qwen3-Embedding-0.6B, was also run locally first, as shown below:

```python
import os

import torch
from vllm import LLM

os.environ["VLLM_USE_MODELSCOPE"] = "True"


def get_detailed_instruct(task_description: str, query: str) -> str:
    return f'Instruct: {task_description}\nQuery:{query}'


# Each query must come with a one-sentence instruction that describes the task
task = 'Given a web search query, retrieve relevant passages that answer the query'
queries = [
    get_detailed_instruct(task, 'What is the capital of China?'),
    get_detailed_instruct(task, 'Explain gravity')
]
# No need to add instruction for retrieval documents
documents = [
    "The capital of China is Beijing.",
    "Gravity is a force that attracts two bodies towards each other. It gives weight to physical objects and is responsible for the movement of planets around the sun."
]
input_texts = queries + documents

model = LLM(model="Qwen/Qwen3-Embedding-0.6B", task="embed")
outputs = model.embed(input_texts)
embeddings = torch.tensor([o.outputs.embedding for o in outputs])
scores = (embeddings[:2] @ embeddings[2:].T)
print(scores.tolist())
# [[0.7620252966880798, 0.14078938961029053], [0.1358368694782257, 0.6013815999031067]]
```

---------

Signed-off-by: wangli <wangli858794774@gmail.com>
Signed-off-by: wangli <858794774@qq.com>
Co-authored-by: wangli <858794774@qq.com>
237 lines
8.2 KiB
Python
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
# Adapted from vllm-project/vllm/tests/utils.py
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import functools
import os
import signal
import subprocess
import sys
import time
from collections.abc import Sequence
from typing import Callable, Optional

import openai
import requests
import torch
import torch.nn.functional as F
from typing_extensions import ParamSpec
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.openai.cli_args import make_arg_parser
from vllm.model_executor.model_loader import get_model_loader
from vllm.utils import FlexibleArgumentParser, get_open_port

_P = ParamSpec("_P")


class RemoteOpenAIServer:
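    """Helper that launches ``vllm serve`` in a subprocess and exposes
    OpenAI-compatible clients pointed at it.

    Illustrative sketch only (the model name and serve args below are
    placeholders, not values required by this helper)::

        with RemoteOpenAIServer("Qwen/Qwen3-Embedding-0.6B",
                                ["--task", "embed"]) as server:
            client = server.get_client()
            resp = client.embeddings.create(
                model="Qwen/Qwen3-Embedding-0.6B",
                input=["What is the capital of China?"])
            embedding = resp.data[0].embedding
    """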
    DUMMY_API_KEY = "token-abc123"  # vLLM's OpenAI server does not need API key

    def __init__(self,
                 model: str,
                 vllm_serve_args: list[str],
                 *,
                 env_dict: Optional[dict[str, str]] = None,
                 seed: Optional[int] = 0,
                 auto_port: bool = True,
                 max_wait_seconds: Optional[float] = None) -> None:
        if auto_port:
            if "-p" in vllm_serve_args or "--port" in vllm_serve_args:
                raise ValueError("You have manually specified the port "
                                 "when `auto_port=True`.")

            # Don't mutate the input args
            vllm_serve_args = vllm_serve_args + [
                "--port", str(get_open_port())
            ]
        if seed is not None:
            if "--seed" in vllm_serve_args:
                raise ValueError("You have manually specified the seed "
                                 f"when `seed={seed}`.")

            vllm_serve_args = vllm_serve_args + ["--seed", str(seed)]

        parser = FlexibleArgumentParser(
            description="vLLM's remote OpenAI server.")
        parser = make_arg_parser(parser)
        args = parser.parse_args(["--model", model, *vllm_serve_args])
        self.host = str(args.host or 'localhost')
        self.port = int(args.port)

        self.show_hidden_metrics = \
            args.show_hidden_metrics_for_version is not None

        # download the model before starting the server to avoid timeout
        is_local = os.path.isdir(model)
        if not is_local:
            engine_args = AsyncEngineArgs.from_cli_args(args)
            model_config = engine_args.create_model_config()
            load_config = engine_args.create_load_config()

            model_loader = get_model_loader(load_config)
            model_loader.download_model(model_config)

        env = os.environ.copy()
        # the current process might initialize cuda,
        # to be safe, we should use spawn method
        env['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn'
        if env_dict is not None:
            env.update(env_dict)
        self.proc = subprocess.Popen(
            ["vllm", "serve", model, *vllm_serve_args],
            env=env,
            stdout=sys.stdout,
            stderr=sys.stderr,
        )
        max_wait_seconds = max_wait_seconds or 240
        self._wait_for_server(url=self.url_for("health"),
                              timeout=max_wait_seconds)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.proc.terminate()
        try:
            self.proc.wait(8)
        except subprocess.TimeoutExpired:
            # force kill if needed
            self.proc.kill()

    def _wait_for_server(self, *, url: str, timeout: float):
        # run health check
        start = time.time()
        while True:
            try:
                if requests.get(url).status_code == 200:
                    break
            except Exception:
                # this exception can only be raised by requests.get,
                # which means the server is not ready yet.
                # the stack trace is not useful, so we suppress it
                # by using `raise from None`.
                result = self.proc.poll()
                if result is not None and result != 0:
                    raise RuntimeError("Server exited unexpectedly.") from None

                time.sleep(0.5)
                if time.time() - start > timeout:
                    raise RuntimeError(
                        "Server failed to start in time.") from None

    @property
    def url_root(self) -> str:
        return f"http://{self.host}:{self.port}"

    def url_for(self, *parts: str) -> str:
        return self.url_root + "/" + "/".join(parts)

    def get_client(self, **kwargs):
        if "timeout" not in kwargs:
            kwargs["timeout"] = 600
        return openai.OpenAI(
            base_url=self.url_for("v1"),
            api_key=self.DUMMY_API_KEY,
            max_retries=0,
            **kwargs,
        )

    def get_async_client(self, **kwargs):
        if "timeout" not in kwargs:
            kwargs["timeout"] = 600
        return openai.AsyncOpenAI(base_url=self.url_for("v1"),
                                  api_key=self.DUMMY_API_KEY,
                                  max_retries=0,
                                  **kwargs)


def fork_new_process_for_each_test(
        f: Callable[_P, None]) -> Callable[_P, None]:
    """Decorator to fork a new process for each test function.
    See https://github.com/vllm-project/vllm/issues/7053 for more details.
    """

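    # Illustrative usage in a test module (the test name is hypothetical):
    #
    #     @fork_new_process_for_each_test
    #     def test_something():
    #         ...
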
    @functools.wraps(f)
    def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None:
        # Make the process the leader of its own process group
        # to avoid sending SIGTERM to the parent process
        os.setpgrp()
        from _pytest.outcomes import Skipped
        pid = os.fork()
        print(f"Fork a new process to run a test {pid}")
        if pid == 0:
            try:
                f(*args, **kwargs)
            except Skipped as e:
                # convert Skipped to exit code 0
                print(str(e))
                os._exit(0)
            except Exception:
                import traceback
                traceback.print_exc()
                os._exit(1)
            else:
                os._exit(0)
        else:
            pgid = os.getpgid(pid)
            _pid, _exitcode = os.waitpid(pid, 0)
            # ignore SIGTERM signal itself
            old_signal_handler = signal.signal(signal.SIGTERM, signal.SIG_IGN)
            # kill all child processes
            os.killpg(pgid, signal.SIGTERM)
            # restore the signal handler
            signal.signal(signal.SIGTERM, old_signal_handler)
            assert _exitcode == 0, (f"function {f} failed when called with"
                                    f" args {args} and kwargs {kwargs}")

    return wrapper


def matryoshka_fy(tensor: torch.Tensor, dimensions: int):
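    """Truncate each embedding to its first ``dimensions`` components and
    L2-normalize the result (Matryoshka-style dimension reduction).

    Hedged example of the intent (values made up for illustration): a row
    ``[3.0, 4.0, 0.1, 0.2]`` of a batch, truncated to ``dimensions=2``,
    becomes ``[3.0, 4.0]`` and normalizes to ``[0.6, 0.8]``.
    """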
    tensor = torch.tensor(tensor)
    tensor = tensor[..., :dimensions]
    tensor = F.normalize(tensor, p=2, dim=1)
    return tensor


def check_embeddings_close(
    *,
    embeddings_0_lst: Sequence[list[float]],
    embeddings_1_lst: Sequence[list[float]],
    name_0: str,
    name_1: str,
    tol: float = 1e-3,
) -> None:
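    """Assert that two sets of embeddings match prompt by prompt: each pair
    must have the same length and a cosine similarity of at least ``1 - tol``.

    Hedged usage sketch (the argument values are placeholders)::

        check_embeddings_close(embeddings_0_lst=ref_embeddings,
                               embeddings_1_lst=out_embeddings,
                               name_0="reference",
                               name_1="vllm-ascend")
    """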
    assert len(embeddings_0_lst) == len(embeddings_1_lst)

    for prompt_idx, (embeddings_0, embeddings_1) in enumerate(
            zip(embeddings_0_lst, embeddings_1_lst)):
        assert len(embeddings_0) == len(embeddings_1), (
            f"Length mismatch: {len(embeddings_0)} vs. {len(embeddings_1)}")

        sim = F.cosine_similarity(torch.tensor(embeddings_0),
                                  torch.tensor(embeddings_1),
                                  dim=0)

        fail_msg = (f"Test{prompt_idx}:"
                    f"\nCosine similarity: \t{sim:.4f}"
                    f"\n{name_0}:\t{embeddings_0[:16]!r}"
                    f"\n{name_1}:\t{embeddings_1[:16]!r}")

        assert sim >= 1 - tol, fail_msg