Mirror of https://github.com/vllm-project/vllm.git, synced 2025-10-20 14:53:52 +08:00
[CI/Build] Replace vllm.entrypoints.openai.api_server entrypoint with vllm serve command (#25967)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
@@ -181,18 +181,14 @@ launch_vllm_server() {
   if echo "$common_params" | jq -e 'has("fp8")' >/dev/null; then
     echo "Key 'fp8' exists in common params. Use neuralmagic fp8 model for convenience."
     model=$(echo "$common_params" | jq -r '.neuralmagic_quantized_model')
-    server_command="python3 \
-      -m vllm.entrypoints.openai.api_server \
+    server_command="vllm serve $model \
       -tp $tp \
-      --model $model \
       --port $port \
       $server_args"
   else
     echo "Key 'fp8' does not exist in common params."
-    server_command="python3 \
-      -m vllm.entrypoints.openai.api_server \
+    server_command="vllm serve $model \
       -tp $tp \
-      --model $model \
       --port $port \
       $server_args"
   fi
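Every hunk in this commit makes the same substitution, so for reference here is a minimal before/after sketch of the two invocations (the model name and port are illustrative, not taken from the diff):

```bash
# Old entrypoint: the model is passed via --model
python3 -m vllm.entrypoints.openai.api_server \
    --model meta-llama/Llama-3.1-8B-Instruct \
    --port 8000

# New CLI: the model is a positional argument of `vllm serve`
vllm serve meta-llama/Llama-3.1-8B-Instruct \
    --port 8000
```

Both commands start the same OpenAI-compatible server; only the way it is invoked changes.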
@@ -365,8 +365,7 @@ run_serving_tests() {
       continue
     fi

-    server_command="$server_envs python3 \
-      -m vllm.entrypoints.openai.api_server \
+    server_command="$server_envs vllm serve \
       $server_args"

     # run the server
@@ -18,7 +18,7 @@ vllm bench throughput --input-len 256 --output-len 256 --output-json throughput_
 bench_throughput_exit_code=$?

 # run server-based benchmarks and upload the result to buildkite
-python3 -m vllm.entrypoints.openai.api_server --model meta-llama/Llama-2-7b-chat-hf &
+vllm serve meta-llama/Llama-2-7b-chat-hf &
 server_pid=$!
 wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json

@@ -55,9 +55,7 @@ benchmark() {
     output_len=$2


-    CUDA_VISIBLE_DEVICES=0 python3 \
-        -m vllm.entrypoints.openai.api_server \
-        --model $model \
+    CUDA_VISIBLE_DEVICES=0 vllm serve $model \
         --port 8100 \
         --max-model-len 10000 \
         --gpu-memory-utilization 0.6 \
@@ -65,9 +63,7 @@ benchmark() {
     '{"kv_connector":"P2pNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2,"kv_buffer_size":5e9}' &


-    CUDA_VISIBLE_DEVICES=1 python3 \
-        -m vllm.entrypoints.openai.api_server \
-        --model $model \
+    CUDA_VISIBLE_DEVICES=1 vllm serve $model \
         --port 8200 \
         --max-model-len 10000 \
         --gpu-memory-utilization 0.6 \
@@ -38,16 +38,12 @@ wait_for_server() {
 launch_chunked_prefill() {
   model="meta-llama/Meta-Llama-3.1-8B-Instruct"
   # disagg prefill
-  CUDA_VISIBLE_DEVICES=0 python3 \
-    -m vllm.entrypoints.openai.api_server \
-    --model $model \
+  CUDA_VISIBLE_DEVICES=0 vllm serve $model \
     --port 8100 \
     --max-model-len 10000 \
     --enable-chunked-prefill \
     --gpu-memory-utilization 0.6 &
-  CUDA_VISIBLE_DEVICES=1 python3 \
-    -m vllm.entrypoints.openai.api_server \
-    --model $model \
+  CUDA_VISIBLE_DEVICES=1 vllm serve $model \
     --port 8200 \
     --max-model-len 10000 \
     --enable-chunked-prefill \
@@ -62,18 +58,14 @@ launch_chunked_prefill() {
 launch_disagg_prefill() {
   model="meta-llama/Meta-Llama-3.1-8B-Instruct"
   # disagg prefill
-  CUDA_VISIBLE_DEVICES=0 python3 \
-    -m vllm.entrypoints.openai.api_server \
-    --model $model \
+  CUDA_VISIBLE_DEVICES=0 vllm serve $model \
     --port 8100 \
     --max-model-len 10000 \
     --gpu-memory-utilization 0.6 \
     --kv-transfer-config \
     '{"kv_connector":"P2pNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2,"kv_buffer_size":5e9}' &

-  CUDA_VISIBLE_DEVICES=1 python3 \
-    -m vllm.entrypoints.openai.api_server \
-    --model $model \
+  CUDA_VISIBLE_DEVICES=1 vllm serve $model \
     --port 8200 \
     --max-model-len 10000 \
     --gpu-memory-utilization 0.6 \
@@ -565,5 +565,5 @@ ENTRYPOINT ["./sagemaker-entrypoint.sh"]

 FROM vllm-openai-base AS vllm-openai

-ENTRYPOINT ["python3", "-m", "vllm.entrypoints.openai.api_server"]
+ENTRYPOINT ["vllm", "serve"]
 #################### OPENAI API SERVER ####################

@@ -177,4 +177,4 @@ RUN --mount=type=cache,target=/root/.cache/uv \
     --mount=type=bind,from=vllm-build,src=/workspace/vllm/dist,target=dist \
     uv pip install dist/*.whl

-ENTRYPOINT ["python3", "-m", "vllm.entrypoints.openai.api_server"]
+ENTRYPOINT ["vllm", "serve"]

@@ -314,4 +314,4 @@ WORKDIR /workspace/

 RUN ln -s /workspace/vllm/tests && ln -s /workspace/vllm/examples && ln -s /workspace/vllm/benchmarks

-ENTRYPOINT ["python", "-m", "vllm.entrypoints.openai.api_server"]
+ENTRYPOINT ["vllm", "serve"]

@@ -309,4 +309,4 @@ USER 2000
 WORKDIR /home/vllm

 # Set the default entrypoint
-ENTRYPOINT ["python", "-m", "vllm.entrypoints.openai.api_server"]
+ENTRYPOINT ["vllm", "serve"]

@@ -69,4 +69,4 @@ RUN --mount=type=cache,target=/root/.cache/pip \

 # install development dependencies (for testing)
 RUN python3 -m pip install -e tests/vllm_test_utils
-ENTRYPOINT ["python3", "-m", "vllm.entrypoints.openai.api_server"]
+ENTRYPOINT ["vllm", "serve"]
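With the image ENTRYPOINT switched to `["vllm", "serve"]`, whatever arguments are passed to `docker run` are appended after the subcommand, so the model is given positionally. A usage sketch, assuming the published `vllm/vllm-openai` image and an illustrative model:

```bash
docker run --gpus all -p 8000:8000 \
    --env "HUGGING_FACE_HUB_TOKEN=<token>" \
    vllm/vllm-openai:latest \
    meta-llama/Llama-3.1-8B-Instruct --max-model-len 8192
# Runs `vllm serve meta-llama/Llama-3.1-8B-Instruct --max-model-len 8192`
# inside the container.
```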
@@ -661,8 +661,7 @@ Benchmark the performance of multi-modal requests in vLLM.
 Start vLLM:

 ```bash
-python -m vllm.entrypoints.openai.api_server \
-    --model Qwen/Qwen2.5-VL-7B-Instruct \
+vllm serve Qwen/Qwen2.5-VL-7B-Instruct \
     --dtype bfloat16 \
     --limit-mm-per-prompt '{"image": 1}' \
     --allowed-local-media-path /path/to/sharegpt4v/images

@@ -688,8 +687,7 @@ vllm bench serve \
 Start vLLM:

 ```bash
-python -m vllm.entrypoints.openai.api_server \
-    --model Qwen/Qwen2.5-VL-7B-Instruct \
+vllm serve Qwen/Qwen2.5-VL-7B-Instruct \
     --dtype bfloat16 \
     --limit-mm-per-prompt '{"video": 1}' \
     --allowed-local-media-path /path/to/sharegpt4video/videos
@@ -39,8 +39,7 @@ Refer to <gh-file:examples/offline_inference/simple_profiling.py> for an example

 ```bash
 VLLM_TORCH_PROFILER_DIR=./vllm_profile \
-python -m vllm.entrypoints.openai.api_server \
-    --model meta-llama/Meta-Llama-3-70B
+vllm serve meta-llama/Meta-Llama-3-70B
 ```

 vllm bench command:
@@ -19,8 +19,7 @@ pip install -U "autogen-agentchat" "autogen-ext[openai]"
 1. Start the vLLM server with the supported chat completion model, e.g.

     ```bash
-    python -m vllm.entrypoints.openai.api_server \
-        --model mistralai/Mistral-7B-Instruct-v0.2
+    vllm serve mistralai/Mistral-7B-Instruct-v0.2
     ```

 1. Call it with AutoGen:
@@ -20,7 +20,7 @@ To get started with Open WebUI using vLLM, follow these steps:
     For example:

     ```console
-    python -m vllm.entrypoints.openai.api_server --host 0.0.0.0 --port 8000
+    vllm serve <model> --host 0.0.0.0 --port 8000
     ```

 3. Start the Open WebUI Docker container:
@@ -32,6 +32,7 @@ See the vLLM SkyPilot YAML for serving, [serving.yaml](https://github.com/skypil
 ports: 8081 # Expose to internet traffic.

 envs:
+  PYTHONUNBUFFERED: 1
   MODEL_NAME: meta-llama/Meta-Llama-3-8B-Instruct
   HF_TOKEN: <your-huggingface-token> # Change to your own huggingface token, or use --env to pass.

@@ -47,9 +48,8 @@ See the vLLM SkyPilot YAML for serving, [serving.yaml](https://github.com/skypil
 run: |
   conda activate vllm
   echo 'Starting vllm api server...'
-  python -u -m vllm.entrypoints.openai.api_server \
+  vllm serve $MODEL_NAME \
     --port 8081 \
-    --model $MODEL_NAME \
     --trust-remote-code \
     --tensor-parallel-size $SKYPILOT_NUM_GPUS_PER_NODE \
     2>&1 | tee api_server.log &
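A note on the `PYTHONUNBUFFERED: 1` entry added to `envs` above: it presumably preserves the unbuffered log output that the dropped `python -u` flag used to provide, since the environment variable and the flag are equivalent in that respect:

```bash
# Old: unbuffered output requested on the interpreter command line
python -u -m vllm.entrypoints.openai.api_server --model "$MODEL_NAME"

# New: unbuffered output requested via the environment instead
PYTHONUNBUFFERED=1 vllm serve "$MODEL_NAME"
```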
@@ -131,6 +131,7 @@ SkyPilot can scale up the service to multiple service replicas with built-in aut
 ports: 8081 # Expose to internet traffic.

 envs:
+  PYTHONUNBUFFERED: 1
   MODEL_NAME: meta-llama/Meta-Llama-3-8B-Instruct
   HF_TOKEN: <your-huggingface-token> # Change to your own huggingface token, or use --env to pass.

@@ -146,9 +147,8 @@ SkyPilot can scale up the service to multiple service replicas with built-in aut
 run: |
   conda activate vllm
   echo 'Starting vllm api server...'
-  python -u -m vllm.entrypoints.openai.api_server \
+  vllm serve $MODEL_NAME \
     --port 8081 \
-    --model $MODEL_NAME \
     --trust-remote-code \
     --tensor-parallel-size $SKYPILOT_NUM_GPUS_PER_NODE \
     2>&1 | tee api_server.log
@@ -243,6 +243,7 @@ This will scale the service up to when the QPS exceeds 2 for each replica.
 ports: 8081 # Expose to internet traffic.

 envs:
+  PYTHONUNBUFFERED: 1
   MODEL_NAME: meta-llama/Meta-Llama-3-8B-Instruct
   HF_TOKEN: <your-huggingface-token> # Change to your own huggingface token, or use --env to pass.

@@ -258,9 +259,8 @@ This will scale the service up to when the QPS exceeds 2 for each replica.
 run: |
   conda activate vllm
   echo 'Starting vllm api server...'
-  python -u -m vllm.entrypoints.openai.api_server \
+  vllm serve $MODEL_NAME \
     --port 8081 \
-    --model $MODEL_NAME \
     --trust-remote-code \
     --tensor-parallel-size $SKYPILOT_NUM_GPUS_PER_NODE \
     2>&1 | tee api_server.log
@@ -69,6 +69,11 @@ Sometimes you may see the API server entrypoint used directly instead of via the
 python -m vllm.entrypoints.openai.api_server --model <model>
 ```

+!!! warning
+    `python -m vllm.entrypoints.openai.api_server` is deprecated
+    and may become unsupported in a future release.
+
 That code can be found in <gh-file:vllm/entrypoints/openai/api_server.py>.

 More details on the API server can be found in the [OpenAI-Compatible Server](../serving/openai_compatible_server.md) document.
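As the warning above and the parser change later in this commit spell out, `vllm serve` expects the model either as a positional argument or from a config file rather than via `--model`. A short sketch of the two supported forms (the config file path and contents are illustrative):

```bash
# 1. Model as a positional argument
vllm serve <model>

# 2. Model taken from a config file passed with --config
vllm serve --config ./serve_config.yaml
```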
@@ -64,8 +64,7 @@ To enable sleep mode in a vLLM server you need to initialize it with the flag `V
 When using the flag `VLLM_SERVER_DEV_MODE=1` you enable development endpoints, and these endpoints should not be exposed to users.

 ```bash
-VLLM_SERVER_DEV_MODE=1 python -m vllm.entrypoints.openai.api_server \
-    --model Qwen/Qwen3-0.6B \
+VLLM_SERVER_DEV_MODE=1 vllm serve Qwen/Qwen3-0.6B \
     --enable-sleep-mode \
     --port 8000
 ```
@@ -48,10 +48,9 @@ The following code configures vLLM in an offline mode to use speculative decodin
 To perform the same with an online mode launch the server:

 ```bash
-python -m vllm.entrypoints.openai.api_server \
+vllm serve facebook/opt-6.7b \
     --host 0.0.0.0 \
     --port 8000 \
-    --model facebook/opt-6.7b \
     --seed 42 \
     -tp 1 \
     --gpu_memory_utilization 0.8 \
@@ -67,8 +67,7 @@ docker run -it \
 XPU platform supports **tensor parallel** inference/serving and also supports **pipeline parallel** as a beta feature for online serving. For **pipeline parallel**, we support it on single node with mp as the backend. For example, a reference execution like following:

 ```bash
-python -m vllm.entrypoints.openai.api_server \
-    --model=facebook/opt-13b \
+vllm serve facebook/opt-13b \
     --dtype=bfloat16 \
     --max_model_len=1024 \
     --distributed-executor-backend=mp \
@@ -21,4 +21,4 @@ while IFS='=' read -r key value; do
 done < <(env | grep "^${PREFIX}")

 # Pass the collected arguments to the main entrypoint
-exec python3 -m vllm.entrypoints.openai.api_server "${ARGS[@]}"
+exec vllm serve "${ARGS[@]}"
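The arguments collected from the environment are forwarded to `vllm serve` verbatim, so if the loop above happens to produce a `--model` flag it still works: the CLI's back-compatibility shim (see the `FlexibleArgumentParser` hunk below) rewrites it into the positional form. An illustrative trace with hypothetical values:

```bash
# Suppose the environment loop produced:
ARGS=(--model my-org/my-model --port 8080)

# The entrypoint then executes:
exec vllm serve "${ARGS[@]}"
# i.e. `vllm serve --model my-org/my-model --port 8080`, which the CLI
# rewrites to `vllm serve my-org/my-model --port 8080` before parsing.
```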
@@ -786,13 +786,43 @@ def test_model_specification(parser_with_config, cli_config_file,
     parser_with_config.parse_args(['serve', '--config', cli_config_file])

     # Test using --model option raises error
-    with pytest.raises(
-            ValueError,
-            match=
-        ("With `vllm serve`, you should provide the model as a positional "
-         "argument or in a config file instead of via the `--model` option."),
-    ):
-        parser_with_config.parse_args(['serve', '--model', 'my-model'])
+    # with pytest.raises(
+    #         ValueError,
+    #         match=
+    #     ("With `vllm serve`, you should provide the model as a positional "
+    #      "argument or in a config file instead of via the `--model` option."),
+    # ):
+    #     parser_with_config.parse_args(['serve', '--model', 'my-model'])
+
+    # Test using --model option back-compatibility
+    # (when back-compatibility ends, the above test should be uncommented
+    # and the below test should be removed)
+    args = parser_with_config.parse_args([
+        'serve',
+        '--tensor-parallel-size',
+        '2',
+        '--model',
+        'my-model',
+        '--trust-remote-code',
+        '--port',
+        '8001',
+    ])
+    assert args.model is None
+    assert args.tensor_parallel_size == 2
+    assert args.trust_remote_code is True
+    assert args.port == 8001
+
+    args = parser_with_config.parse_args([
+        'serve',
+        '--tensor-parallel-size=2',
+        '--model=my-model',
+        '--trust-remote-code',
+        '--port=8001',
+    ])
+    assert args.model is None
+    assert args.tensor_parallel_size == 2
+    assert args.trust_remote_code is True
+    assert args.port == 8001

     # Test other config values are preserved
     args = parser_with_config.parse_args([
@@ -1855,13 +1855,37 @@ class FlexibleArgumentParser(ArgumentParser):

         # Check for --model in command line arguments first
         if args and args[0] == "serve":
-            model_in_cli_args = any(arg == '--model' for arg in args)
-
-            if model_in_cli_args:
-                raise ValueError(
+            try:
+                model_idx = next(
+                    i for i, arg in enumerate(args)
+                    if arg == "--model" or arg.startswith("--model="))
+                logger.warning(
                     "With `vllm serve`, you should provide the model as a "
                     "positional argument or in a config file instead of via "
-                    "the `--model` option.")
+                    "the `--model` option. "
+                    "The `--model` option will be removed in v0.13.")
+
+                if args[model_idx] == "--model":
+                    model_tag = args[model_idx + 1]
+                    rest_start_idx = model_idx + 2
+                else:
+                    model_tag = args[model_idx].removeprefix("--model=")
+                    rest_start_idx = model_idx + 1
+
+                # Move <model> to the front, e,g:
+                # [Before]
+                # vllm serve -tp 2 --model <model> --enforce-eager --port 8001
+                # [After]
+                # vllm serve <model> -tp 2 --enforce-eager --port 8001
+                args = [
+                    "serve",
+                    model_tag,
+                    *args[1:model_idx],
+                    *args[rest_start_idx:],
+                ]
+                print("args", args)
+            except StopIteration:
+                pass

         if '--config' in args:
             args = self._pull_args_from_config(args)
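To make the back-compatibility behavior above easier to follow in isolation, here is a minimal standalone sketch of the same rearrangement (simplified: no logging and no config handling; the function name is mine, not vLLM's):

```python
def move_model_to_positional(args: list[str]) -> list[str]:
    """Rewrite `serve ... --model <model> ...` as `serve <model> ...`."""
    if not args or args[0] != "serve":
        return args
    try:
        # Find either `--model <value>` or `--model=<value>`.
        model_idx = next(i for i, arg in enumerate(args)
                         if arg == "--model" or arg.startswith("--model="))
    except StopIteration:
        return args  # No --model flag; nothing to rewrite.
    if args[model_idx] == "--model":
        model_tag = args[model_idx + 1]
        rest_start_idx = model_idx + 2
    else:
        model_tag = args[model_idx].removeprefix("--model=")
        rest_start_idx = model_idx + 1
    # Move the model to the front and drop the flag itself.
    return ["serve", model_tag, *args[1:model_idx], *args[rest_start_idx:]]


print(move_model_to_positional(
    ["serve", "-tp", "2", "--model", "my-model", "--port", "8001"]))
# -> ['serve', 'my-model', '-tp', '2', '--port', '8001']
```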