[mypy] Fix mypy warnings in api_server.py (#11941)

Signed-off-by: Fred Reiss <frreiss@us.ibm.com>
This commit is contained in:
Fred Reiss
2025-01-10 17:04:58 -08:00
committed by GitHub
parent d45cbe70f5
commit c9f09a4fe8

View File

@ -14,7 +14,7 @@ from argparse import Namespace
from contextlib import asynccontextmanager
from functools import partial
from http import HTTPStatus
from typing import AsyncIterator, Optional, Set, Tuple
from typing import AsyncIterator, Dict, Optional, Set, Tuple, Union
import uvloop
from fastapi import APIRouter, FastAPI, HTTPException, Request
@ -420,6 +420,8 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request):
"use the Pooling API (`/pooling`) instead.")
res = await fallback_handler.create_pooling(request, raw_request)
generator: Union[ErrorResponse, EmbeddingResponse]
if isinstance(res, PoolingResponse):
generator = EmbeddingResponse(
id=res.id,
@ -494,7 +496,7 @@ async def create_score_v1(request: ScoreRequest, raw_request: Request):
return await create_score(request, raw_request)
TASK_HANDLERS = {
TASK_HANDLERS: Dict[str, Dict[str, tuple]] = {
"generate": {
"messages": (ChatCompletionRequest, create_chat_completion),
"default": (CompletionRequest, create_completion),
@ -652,7 +654,7 @@ def build_app(args: Namespace) -> FastAPI:
module_path, object_name = middleware.rsplit(".", 1)
imported = getattr(importlib.import_module(module_path), object_name)
if inspect.isclass(imported):
app.add_middleware(imported)
app.add_middleware(imported) # type: ignore[arg-type]
elif inspect.iscoroutinefunction(imported):
app.middleware("http")(imported)
else: