[Misc] Minimum requirements for SageMaker compatibility (#11576)
Dockerfile — 13 changed lines
@@ -234,8 +234,8 @@ RUN mv vllm test_docs/
 #################### TEST IMAGE ####################

 #################### OPENAI API SERVER ####################
-# openai api server alternative
-FROM vllm-base AS vllm-openai
+# base openai image with additional requirements, for any subsequent openai-style images
+FROM vllm-base AS vllm-openai-base

 # install additional dependencies for openai api server
 RUN --mount=type=cache,target=/root/.cache/pip \
@@ -247,5 +247,14 @@ RUN --mount=type=cache,target=/root/.cache/pip \

 ENV VLLM_USAGE_SOURCE production-docker-image

+# define sagemaker first, so it is not default from `docker build`
+FROM vllm-openai-base AS vllm-sagemaker
+
+COPY examples/sagemaker-entrypoint.sh .
+RUN chmod +x sagemaker-entrypoint.sh
+ENTRYPOINT ["./sagemaker-entrypoint.sh"]
+
+FROM vllm-openai-base AS vllm-openai
+
 ENTRYPOINT ["python3", "-m", "vllm.entrypoints.openai.api_server"]
 #################### OPENAI API SERVER ####################
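Because a bare `docker build` produces the last stage in the file, declaring `vllm-sagemaker` before the re-declared `vllm-openai` stage keeps the OpenAI image as the default; the SageMaker image must be requested by name. A minimal sketch (image tags are illustrative):

    # build the SageMaker-compatible image by selecting its stage explicitly
    docker build --target vllm-sagemaker -t vllm-sagemaker .

    # a plain build still yields the standard OpenAI server image
    docker build -t vllm-openai .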
examples/sagemaker-entrypoint.sh — new file, 24 lines
@@ -0,0 +1,24 @@
+#!/bin/bash
+
+# Define the prefix for environment variables to look for
+PREFIX="SM_VLLM_"
+ARG_PREFIX="--"
+
+# Initialize an array for storing the arguments
+# port 8080 required by sagemaker, https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-inference-code.html#your-algorithms-inference-code-container-response
+ARGS=(--port 8080)
+
+# Loop through all environment variables
+while IFS='=' read -r key value; do
+    # Remove the prefix from the key, convert to lowercase, and replace underscores with dashes
+    arg_name=$(echo "${key#"${PREFIX}"}" | tr '[:upper:]' '[:lower:]' | tr '_' '-')
+
+    # Add the argument name and value to the ARGS array
+    ARGS+=("${ARG_PREFIX}${arg_name}")
+    if [ -n "$value" ]; then
+        ARGS+=("$value")
+    fi
+done < <(env | grep "^${PREFIX}")
+
+# Pass the collected arguments to the main entrypoint
+exec python3 -m vllm.entrypoints.openai.api_server "${ARGS[@]}"
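The loop turns every `SM_VLLM_*` environment variable into a CLI flag: the prefix is stripped, the name is lower-cased, and underscores become dashes; an empty value yields a bare flag, which is how boolean switches pass through. A hypothetical mapping (the variable values here are chosen for illustration):

    # a SageMaker environment configuration such as
    #   SM_VLLM_MODEL=mistralai/Mistral-7B-Instruct-v0.2
    #   SM_VLLM_MAX_MODEL_LEN=4096
    #   SM_VLLM_ENFORCE_EAGER=        (empty value -> bare flag)
    # makes the entrypoint exec:
    python3 -m vllm.entrypoints.openai.api_server \
        --port 8080 \
        --model mistralai/Mistral-7B-Instruct-v0.2 \
        --max-model-len 4096 \
        --enforce-eager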
vllm/entrypoints/openai/api_server.py

@@ -16,7 +16,7 @@ from http import HTTPStatus
 from typing import AsyncIterator, Optional, Set, Tuple

 import uvloop
-from fastapi import APIRouter, FastAPI, Request
+from fastapi import APIRouter, FastAPI, HTTPException, Request
 from fastapi.exceptions import RequestValidationError
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse, Response, StreamingResponse
@@ -44,11 +44,15 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                               CompletionResponse,
                                               DetokenizeRequest,
                                               DetokenizeResponse,
+                                              EmbeddingChatRequest,
+                                              EmbeddingCompletionRequest,
                                               EmbeddingRequest,
                                               EmbeddingResponse,
                                               EmbeddingResponseData,
                                               ErrorResponse,
                                               LoadLoraAdapterRequest,
+                                              PoolingChatRequest,
+                                              PoolingCompletionRequest,
                                               PoolingRequest, PoolingResponse,
                                               ScoreRequest, ScoreResponse,
                                               TokenizeRequest,
@@ -310,6 +314,12 @@ async def health(raw_request: Request) -> Response:
     return Response(status_code=200)


+@router.api_route("/ping", methods=["GET", "POST"])
+async def ping(raw_request: Request) -> Response:
+    """Ping check. Endpoint required for SageMaker"""
+    return await health(raw_request)
+
+
 @router.post("/tokenize")
 @with_cancellation
 async def tokenize(request: TokenizeRequest, raw_request: Request):
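SageMaker's container contract expects the server to answer health probes on `/ping`; the new route accepts both GET and POST and simply delegates to the existing `/health` handler. A quick smoke test against a locally running server (the host is illustrative; 8080 is the port the entrypoint binds):

    curl -i http://localhost:8080/ping              # GET probe
    curl -i -X POST http://localhost:8080/ping      # POST also returns the same 200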
@@ -483,6 +493,54 @@ async def create_score_v1(request: ScoreRequest, raw_request: Request):
     return await create_score(request, raw_request)


+TASK_HANDLERS = {
+    "generate": {
+        "messages": (ChatCompletionRequest, create_chat_completion),
+        "default": (CompletionRequest, create_completion),
+    },
+    "embed": {
+        "messages": (EmbeddingChatRequest, create_embedding),
+        "default": (EmbeddingCompletionRequest, create_embedding),
+    },
+    "score": {
+        "default": (ScoreRequest, create_score),
+    },
+    "reward": {
+        "messages": (PoolingChatRequest, create_pooling),
+        "default": (PoolingCompletionRequest, create_pooling),
+    },
+    "classify": {
+        "messages": (PoolingChatRequest, create_pooling),
+        "default": (PoolingCompletionRequest, create_pooling),
+    },
+}
+
+
+@router.post("/invocations")
+async def invocations(raw_request: Request):
+    """
+    For SageMaker, routes requests to other handlers based on model `task`.
+    """
+    body = await raw_request.json()
+    task = raw_request.app.state.task
+
+    if task not in TASK_HANDLERS:
+        raise HTTPException(
+            status_code=400,
+            detail=f"Unsupported task: '{task}' for '/invocations'. "
+            f"Expected one of {set(TASK_HANDLERS.keys())}")
+
+    handler_config = TASK_HANDLERS[task]
+    if "messages" in body:
+        request_model, handler = handler_config["messages"]
+    else:
+        request_model, handler = handler_config["default"]
+
+    # this is required since we lose the FastAPI automatic casting
+    request = request_model.model_validate(body)
+    return await handler(request, raw_request)
+
+
 if envs.VLLM_TORCH_PROFILER_DIR:
     logger.warning(
         "Torch Profiler is enabled in the API server. This should ONLY be "
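Dispatch is two-level: the model's configured task, stored on app state by `init_app_state` (see the final hunk below), selects a handler table, and the presence of a `messages` key in the body picks the chat-style handler over the default. For a model served with the `generate` task, the two request shapes below would be validated as `ChatCompletionRequest` and `CompletionRequest` respectively (host and model name are illustrative):

    # body has "messages" -> routed to create_chat_completion
    curl -s http://localhost:8080/invocations \
        -H 'Content-Type: application/json' \
        -d '{"model": "my-model", "messages": [{"role": "user", "content": "Hello"}]}'

    # no "messages" key -> routed to create_completion
    curl -s http://localhost:8080/invocations \
        -H 'Content-Type: application/json' \
        -d '{"model": "my-model", "prompt": "Hello"}'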
@@ -687,6 +745,7 @@ def init_app_state(
         chat_template=resolved_chat_template,
         chat_template_content_format=args.chat_template_content_format,
     )
+    state.task = model_config.task


 def create_server_socket(addr: Tuple[str, int]) -> socket.socket: