[Bugfix] Validate lora adapters to avoid crashing server (#11727)

Signed-off-by: Joe Runde <Joseph.Runde@ibm.com>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
Joe Runde
2025-01-10 00:56:36 -07:00
committed by GitHub
parent cf5f000d21
commit ac2f3f7fee
15 changed files with 460 additions and 172 deletions

View File

@ -0,0 +1,269 @@
import asyncio
import json
import shutil
from contextlib import suppress
import openai # use the official client for correctness check
import pytest
import pytest_asyncio
# downloading lora to test lora requests
from huggingface_hub import snapshot_download
from ...utils import RemoteOpenAIServer
# any model with a chat template should work here
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
# technically this needs Mistral-7B-v0.1 as base, but we're not testing
# generation quality here
LORA_NAME = "typeof/zephyr-7b-beta-lora"
@pytest.fixture(scope="module")
def zephyr_lora_files():
return snapshot_download(repo_id=LORA_NAME)
@pytest.fixture(scope="module")
def server_with_lora_modules_json(zephyr_lora_files):
# Define the json format LoRA module configurations
lora_module_1 = {
"name": "zephyr-lora",
"path": zephyr_lora_files,
"base_model_name": MODEL_NAME
}
lora_module_2 = {
"name": "zephyr-lora2",
"path": zephyr_lora_files,
"base_model_name": MODEL_NAME
}
args = [
# use half precision for speed and memory savings in CI environment
"--dtype",
"bfloat16",
"--max-model-len",
"8192",
"--enforce-eager",
# lora config below
"--enable-lora",
"--lora-modules",
json.dumps(lora_module_1),
json.dumps(lora_module_2),
"--max-lora-rank",
"64",
"--max-cpu-loras",
"2",
"--max-num-seqs",
"64",
]
# Enable the /v1/load_lora_adapter endpoint
envs = {"VLLM_ALLOW_RUNTIME_LORA_UPDATING": "True"}
with RemoteOpenAIServer(MODEL_NAME, args, env_dict=envs) as remote_server:
yield remote_server
@pytest_asyncio.fixture
async def client(server_with_lora_modules_json):
async with server_with_lora_modules_json.get_async_client(
) as async_client:
yield async_client
@pytest.mark.asyncio
async def test_static_lora_lineage(client: openai.AsyncOpenAI,
zephyr_lora_files):
models = await client.models.list()
models = models.data
served_model = models[0]
lora_models = models[1:]
assert served_model.id == MODEL_NAME
assert served_model.root == MODEL_NAME
assert served_model.parent is None
assert all(lora_model.root == zephyr_lora_files
for lora_model in lora_models)
assert all(lora_model.parent == MODEL_NAME for lora_model in lora_models)
assert lora_models[0].id == "zephyr-lora"
assert lora_models[1].id == "zephyr-lora2"
@pytest.mark.asyncio
async def test_dynamic_lora_lineage(client: openai.AsyncOpenAI,
zephyr_lora_files):
response = await client.post("load_lora_adapter",
cast_to=str,
body={
"lora_name": "zephyr-lora-3",
"lora_path": zephyr_lora_files
})
# Ensure adapter loads before querying /models
assert "success" in response
models = await client.models.list()
models = models.data
dynamic_lora_model = models[-1]
assert dynamic_lora_model.root == zephyr_lora_files
assert dynamic_lora_model.parent == MODEL_NAME
assert dynamic_lora_model.id == "zephyr-lora-3"
@pytest.mark.asyncio
async def test_dynamic_lora_not_found(client: openai.AsyncOpenAI):
with pytest.raises(openai.NotFoundError):
await client.post("load_lora_adapter",
cast_to=str,
body={
"lora_name": "notfound",
"lora_path": "/not/an/adapter"
})
@pytest.mark.asyncio
async def test_dynamic_lora_invalid_files(client: openai.AsyncOpenAI,
tmp_path):
invalid_files = tmp_path / "invalid_files"
invalid_files.mkdir()
(invalid_files / "adapter_config.json").write_text("this is not json")
with pytest.raises(openai.BadRequestError):
await client.post("load_lora_adapter",
cast_to=str,
body={
"lora_name": "invalid-json",
"lora_path": str(invalid_files)
})
@pytest.mark.asyncio
async def test_dynamic_lora_invalid_lora_rank(client: openai.AsyncOpenAI,
tmp_path, zephyr_lora_files):
invalid_rank = tmp_path / "invalid_rank"
# Copy adapter from zephyr_lora_files to invalid_rank
shutil.copytree(zephyr_lora_files, invalid_rank)
with open(invalid_rank / "adapter_config.json") as f:
adapter_config = json.load(f)
print(adapter_config)
# assert False
# Change rank to invalid value
adapter_config["r"] = 1024
with open(invalid_rank / "adapter_config.json", "w") as f:
json.dump(adapter_config, f)
with pytest.raises(openai.BadRequestError,
match="is greater than max_lora_rank"):
await client.post("load_lora_adapter",
cast_to=str,
body={
"lora_name": "invalid-json",
"lora_path": str(invalid_rank)
})
@pytest.mark.asyncio
async def test_multiple_lora_adapters(client: openai.AsyncOpenAI, tmp_path,
zephyr_lora_files):
"""Validate that many loras can be dynamically registered and inferenced
with concurrently"""
# This test file configures the server with --max-cpu-loras=2 and this test
# will concurrently load 10 adapters, so it should flex the LRU cache
async def load_and_run_adapter(adapter_name: str):
await client.post("load_lora_adapter",
cast_to=str,
body={
"lora_name": adapter_name,
"lora_path": str(zephyr_lora_files)
})
for _ in range(3):
await client.completions.create(
model=adapter_name,
prompt=["Hello there", "Foo bar bazz buzz"],
max_tokens=5,
)
lora_tasks = []
for i in range(10):
lora_tasks.append(
asyncio.create_task(load_and_run_adapter(f"adapter_{i}")))
results, _ = await asyncio.wait(lora_tasks)
for r in results:
assert not isinstance(r, Exception), f"Got exception {r}"
@pytest.mark.asyncio
async def test_loading_invalid_adapters_does_not_break_others(
client: openai.AsyncOpenAI, tmp_path, zephyr_lora_files):
invalid_files = tmp_path / "invalid_files"
invalid_files.mkdir()
(invalid_files / "adapter_config.json").write_text("this is not json")
stop_good_requests_event = asyncio.Event()
async def run_good_requests(client):
# Run chat completions requests until event set
results = []
while not stop_good_requests_event.is_set():
try:
batch = await client.completions.create(
model="zephyr-lora",
prompt=["Hello there", "Foo bar bazz buzz"],
max_tokens=5,
)
results.append(batch)
except Exception as e:
results.append(e)
return results
# Create task to run good requests
good_task = asyncio.create_task(run_good_requests(client))
# Run a bunch of bad adapter loads
for _ in range(25):
with suppress(openai.NotFoundError):
await client.post("load_lora_adapter",
cast_to=str,
body={
"lora_name": "notfound",
"lora_path": "/not/an/adapter"
})
for _ in range(25):
with suppress(openai.BadRequestError):
await client.post("load_lora_adapter",
cast_to=str,
body={
"lora_name": "invalid",
"lora_path": str(invalid_files)
})
# Ensure all the running requests with lora adapters succeeded
stop_good_requests_event.set()
results = await good_task
for r in results:
assert not isinstance(r, Exception), f"Got exception {r}"
# Ensure we can load another adapter and run it
await client.post("load_lora_adapter",
cast_to=str,
body={
"lora_name": "valid",
"lora_path": zephyr_lora_files
})
await client.completions.create(
model="valid",
prompt=["Hello there", "Foo bar bazz buzz"],
max_tokens=5,
)

View File

@ -1,109 +0,0 @@
import json
import openai # use the official client for correctness check
import pytest
import pytest_asyncio
# downloading lora to test lora requests
from huggingface_hub import snapshot_download
from ...utils import RemoteOpenAIServer
# any model with a chat template should work here
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
# technically this needs Mistral-7B-v0.1 as base, but we're not testing
# generation quality here
LORA_NAME = "typeof/zephyr-7b-beta-lora"
@pytest.fixture(scope="module")
def zephyr_lora_files():
return snapshot_download(repo_id=LORA_NAME)
@pytest.fixture(scope="module")
def server_with_lora_modules_json(zephyr_lora_files):
# Define the json format LoRA module configurations
lora_module_1 = {
"name": "zephyr-lora",
"path": zephyr_lora_files,
"base_model_name": MODEL_NAME
}
lora_module_2 = {
"name": "zephyr-lora2",
"path": zephyr_lora_files,
"base_model_name": MODEL_NAME
}
args = [
# use half precision for speed and memory savings in CI environment
"--dtype",
"bfloat16",
"--max-model-len",
"8192",
"--enforce-eager",
# lora config below
"--enable-lora",
"--lora-modules",
json.dumps(lora_module_1),
json.dumps(lora_module_2),
"--max-lora-rank",
"64",
"--max-cpu-loras",
"2",
"--max-num-seqs",
"64",
]
# Enable the /v1/load_lora_adapter endpoint
envs = {"VLLM_ALLOW_RUNTIME_LORA_UPDATING": "True"}
with RemoteOpenAIServer(MODEL_NAME, args, env_dict=envs) as remote_server:
yield remote_server
@pytest_asyncio.fixture
async def client_for_lora_lineage(server_with_lora_modules_json):
async with server_with_lora_modules_json.get_async_client(
) as async_client:
yield async_client
@pytest.mark.asyncio
async def test_static_lora_lineage(client_for_lora_lineage: openai.AsyncOpenAI,
zephyr_lora_files):
models = await client_for_lora_lineage.models.list()
models = models.data
served_model = models[0]
lora_models = models[1:]
assert served_model.id == MODEL_NAME
assert served_model.root == MODEL_NAME
assert served_model.parent is None
assert all(lora_model.root == zephyr_lora_files
for lora_model in lora_models)
assert all(lora_model.parent == MODEL_NAME for lora_model in lora_models)
assert lora_models[0].id == "zephyr-lora"
assert lora_models[1].id == "zephyr-lora2"
@pytest.mark.asyncio
async def test_dynamic_lora_lineage(
client_for_lora_lineage: openai.AsyncOpenAI, zephyr_lora_files):
response = await client_for_lora_lineage.post("load_lora_adapter",
cast_to=str,
body={
"lora_name":
"zephyr-lora-3",
"lora_path":
zephyr_lora_files
})
# Ensure adapter loads before querying /models
assert "success" in response
models = await client_for_lora_lineage.models.list()
models = models.data
dynamic_lora_model = models[-1]
assert dynamic_lora_model.root == zephyr_lora_files
assert dynamic_lora_model.parent == MODEL_NAME
assert dynamic_lora_model.id == "zephyr-lora-3"

View File

@ -52,7 +52,7 @@ async def _async_serving_chat_init():
engine = MockEngine()
model_config = await engine.get_model_config()
models = OpenAIServingModels(model_config, BASE_MODEL_PATHS)
models = OpenAIServingModels(engine, model_config, BASE_MODEL_PATHS)
serving_completion = OpenAIServingChat(engine,
model_config,
models,
@ -73,7 +73,8 @@ def test_serving_chat_should_set_correct_max_tokens():
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
mock_engine.errored = False
models = OpenAIServingModels(base_model_paths=BASE_MODEL_PATHS,
models = OpenAIServingModels(engine_client=mock_engine,
base_model_paths=BASE_MODEL_PATHS,
model_config=MockModelConfig())
serving_chat = OpenAIServingChat(mock_engine,
MockModelConfig(),
@ -116,7 +117,8 @@ def test_serving_chat_could_load_correct_generation_config():
mock_engine.errored = False
# Initialize the serving chat
models = OpenAIServingModels(base_model_paths=BASE_MODEL_PATHS,
models = OpenAIServingModels(engine_client=mock_engine,
base_model_paths=BASE_MODEL_PATHS,
model_config=mock_model_config)
serving_chat = OpenAIServingChat(mock_engine,
mock_model_config,

View File

@ -4,6 +4,7 @@ from unittest.mock import MagicMock
import pytest
from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.openai.protocol import (ErrorResponse,
LoadLoraAdapterRequest,
UnloadLoraAdapterRequest)
@ -21,13 +22,16 @@ LORA_UNLOADING_SUCCESS_MESSAGE = (
async def _async_serving_models_init() -> OpenAIServingModels:
mock_model_config = MagicMock(spec=ModelConfig)
mock_engine_client = MagicMock(spec=EngineClient)
# Set the max_model_len attribute to avoid missing attribute
mock_model_config.max_model_len = 2048
serving_models = OpenAIServingModels(base_model_paths=BASE_MODEL_PATHS,
serving_models = OpenAIServingModels(engine_client=mock_engine_client,
base_model_paths=BASE_MODEL_PATHS,
model_config=mock_model_config,
lora_modules=None,
prompt_adapters=None)
await serving_models.init_static_loras()
return serving_models
@ -113,5 +117,5 @@ async def test_unload_lora_adapter_not_found():
request = UnloadLoraAdapterRequest(lora_name="nonexistent_adapter")
response = await serving_models.unload_lora_adapter(request)
assert isinstance(response, ErrorResponse)
assert response.type == "InvalidUserInput"
assert response.code == HTTPStatus.BAD_REQUEST
assert response.type == "NotFoundError"
assert response.code == HTTPStatus.NOT_FOUND

View File

@ -1,6 +1,3 @@
import json
import os
import openai
import pytest
@ -10,16 +7,7 @@ MODEL_NAME = "meta-llama/Llama-3.2-1B"
@pytest.mark.asyncio
async def test_shutdown_on_engine_failure(tmp_path):
# Use a bad adapter to crash the engine
# (This test will fail when that bug is fixed)
adapter_path = tmp_path / "bad_adapter"
os.mkdir(adapter_path)
with open(adapter_path / "adapter_model_config.json", "w") as f:
json.dump({"not": "real"}, f)
with open(adapter_path / "adapter_model.safetensors", "wb") as f:
f.write(b"this is fake")
async def test_shutdown_on_engine_failure():
# dtype, max-len etc set so that this can run in CI
args = [
"--dtype",
@ -29,9 +17,6 @@ async def test_shutdown_on_engine_failure(tmp_path):
"--enforce-eager",
"--max-num-seqs",
"128",
"--enable-lora",
"--lora-modules",
f"bad-adapter={tmp_path / 'bad_adapter'}",
]
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
@ -39,9 +24,13 @@ async def test_shutdown_on_engine_failure(tmp_path):
with pytest.raises(
(openai.APIConnectionError, openai.InternalServerError)):
# This crashes the engine
await client.completions.create(model="bad-adapter",
prompt="Hello, my name is")
# Asking for lots of prompt logprobs will currently crash the
# engine. This may change in the future when that bug is fixed
prompt = "Hello " * 4000
await client.completions.create(
model=MODEL_NAME,
prompt=prompt,
extra_body={"prompt_logprobs": 10})
# Now the server should shut down
return_code = remote_server.proc.wait(timeout=8)

View File

@ -1257,6 +1257,10 @@ class AsyncLLMEngine(EngineClient):
else:
self.engine.model_executor._run_workers("stop_profile")
async def add_lora(self, lora_request: LoRARequest) -> None:
"""Load a new LoRA adapter into the engine for future requests."""
self.engine.add_lora(lora_request)
# TODO(v1): Remove this class proxy when V1 goes default.
if envs.VLLM_USE_V1:

View File

@ -1,4 +1,5 @@
from dataclasses import dataclass
import uuid
from dataclasses import dataclass, field
from enum import Enum
from typing import List, Mapping, Optional, Union, overload
@ -120,10 +121,23 @@ class RPCUProfileRequest(Enum):
STOP_PROFILE = 2
RPC_REQUEST_T = Union[RPCProcessRequest, RPCAbortRequest, RPCStartupRequest,
RPCUProfileRequest]
@dataclass
class RPCLoadAdapterRequest:
lora_request: LoRARequest
# Set the default value of request_id to a new UUID
request_id: str = field(default_factory=lambda: str(uuid.uuid4()))
REQUEST_OUTPUTS_T = Union[List[RequestOutput], RPCError]
@dataclass
class RPCAdapterLoadedResponse:
request_id: str
RPC_REQUEST_T = Union[RPCProcessRequest, RPCAbortRequest, RPCStartupRequest,
RPCUProfileRequest, RPCLoadAdapterRequest]
REQUEST_OUTPUTS_T = Union[List[RequestOutput], RPCAdapterLoadedResponse,
RPCError]
def ENGINE_DEAD_ERROR(

View File

@ -25,8 +25,10 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
IPC_HEALTH_EXT, IPC_INPUT_EXT,
IPC_OUTPUT_EXT, RPC_REQUEST_T,
VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
RPCError, RPCProcessRequest,
RPCStartupRequest, RPCStartupResponse,
RPCAdapterLoadedResponse, RPCError,
RPCLoadAdapterRequest,
RPCProcessRequest, RPCStartupRequest,
RPCStartupResponse,
RPCUProfileRequest)
from vllm.engine.protocol import EngineClient
# yapf: enable
@ -240,17 +242,22 @@ class MQLLMEngineClient(EngineClient):
queue = self.output_queues.get(request_id)
if queue is not None:
queue.put_nowait(exception)
# Put each output into the appropriate queue.
elif isinstance(request_outputs, RPCAdapterLoadedResponse):
self._add_output(request_outputs)
else:
# Put each output into the appropriate steam.
for request_output in request_outputs:
queue = self.output_queues.get(
request_output.request_id)
if queue is not None:
queue.put_nowait(request_output)
self._add_output(request_output)
except asyncio.CancelledError:
logger.debug("Shutting down MQLLMEngineClient output handler.")
def _add_output(self, request_output: Union[RequestOutput,
RPCAdapterLoadedResponse]):
queue = self.output_queues.get(request_output.request_id)
if queue is not None:
queue.put_nowait(request_output)
async def setup(self):
"""Setup the client before it starts sending server requests."""
@ -659,3 +666,24 @@ class MQLLMEngineClient(EngineClient):
await self._send_one_way_rpc_request(
request=RPCUProfileRequest.STOP_PROFILE, socket=self.input_socket)
async def add_lora(self, lora_request: LoRARequest) -> None:
"""Load a new LoRA adapter into the engine for future requests."""
# Uses the same I/O as generate requests
request = RPCLoadAdapterRequest(lora_request)
# Create output queue for this requests.
queue: asyncio.Queue[Union[None, BaseException]] = asyncio.Queue()
self.output_queues[request.request_id] = queue
# Send the request
request_bytes = pickle.dumps(request)
await self.input_socket.send_multipart((request_bytes, ), copy=False)
# Wait for the response
request_output = await queue.get()
self.output_queues.pop(request.request_id)
# Raise on error, otherwise happily return None
if isinstance(request_output, BaseException):
raise request_output

View File

@ -14,8 +14,10 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
IPC_HEALTH_EXT, IPC_INPUT_EXT,
IPC_OUTPUT_EXT, REQUEST_OUTPUTS_T,
VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
RPCError, RPCProcessRequest,
RPCStartupRequest, RPCStartupResponse,
RPCAdapterLoadedResponse, RPCError,
RPCLoadAdapterRequest,
RPCProcessRequest, RPCStartupRequest,
RPCStartupResponse,
RPCUProfileRequest)
# yapf: enable
from vllm.executor.gpu_executor import GPUExecutor
@ -234,6 +236,8 @@ class MQLLMEngine:
self.start_profile()
else:
self.stop_profile()
elif isinstance(request, RPCLoadAdapterRequest):
self._handle_load_adapter_request(request)
else:
raise ValueError("Unknown RPCRequest Type: "
f"{type(request)}")
@ -284,6 +288,19 @@ class MQLLMEngine:
if self.log_requests:
logger.info("Aborted request %s.", request.request_id)
def _handle_load_adapter_request(self, request: RPCLoadAdapterRequest):
try:
self.engine.add_lora(request.lora_request)
except BaseException as e:
# Send back an error if the adater fails to load
rpc_err = RPCError(request_id=request.request_id,
is_engine_errored=False,
exception=e)
self._send_outputs(rpc_err)
# Otherwise, send back the successful load message
self._send_outputs(
RPCAdapterLoadedResponse(request_id=request.request_id))
def _health_check(self):
# Send unhealthy if engine has already errored
if self._errored_with is not None:
@ -296,7 +313,11 @@ class MQLLMEngine:
self._send_unhealthy(e)
def _send_outputs(self, outputs: REQUEST_OUTPUTS_T):
"""Send List of RequestOutput to RPCClient."""
"""Send outputs back to the engine client. These can be:
- Exceptions
- A list of generation outputs
- A response from loading a lora adapter
"""
if outputs:
try:
from ray.exceptions import RayTaskError

View File

@ -270,3 +270,8 @@ class EngineClient(ABC):
async def stop_profile(self) -> None:
"""Start profiling the engine"""
...
@abstractmethod
async def add_lora(self, lora_request: LoRARequest) -> None:
"""Load a new LoRA adapter into the engine for future requests."""
...

View File

@ -662,7 +662,7 @@ def build_app(args: Namespace) -> FastAPI:
return app
def init_app_state(
async def init_app_state(
engine_client: EngineClient,
model_config: ModelConfig,
state: State,
@ -690,12 +690,13 @@ def init_app_state(
logger.info("Using supplied chat template:\n%s", resolved_chat_template)
state.openai_serving_models = OpenAIServingModels(
engine_client=engine_client,
model_config=model_config,
base_model_paths=base_model_paths,
lora_modules=args.lora_modules,
prompt_adapters=args.prompt_adapters,
)
# TODO: The chat template is now broken for lora adapters :(
await state.openai_serving_models.init_static_loras()
state.openai_serving_chat = OpenAIServingChat(
engine_client,
model_config,
@ -794,7 +795,7 @@ async def run_server(args, **uvicorn_kwargs) -> None:
app = build_app(args)
model_config = await engine_client.get_model_config()
init_app_state(engine_client, model_config, app.state, args)
await init_app_state(engine_client, model_config, app.state, args)
shutdown_task = await serve_http(
app,

View File

@ -215,6 +215,7 @@ async def main(args):
# Create the openai serving objects.
openai_serving_models = OpenAIServingModels(
engine_client=engine,
model_config=model_config,
base_model_paths=base_model_paths,
lora_modules=None,

View File

@ -5,15 +5,19 @@ from http import HTTPStatus
from typing import List, Optional, Union
from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.openai.protocol import (ErrorResponse,
LoadLoraAdapterRequest,
ModelCard, ModelList,
ModelPermission,
UnloadLoraAdapterRequest)
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.utils import AtomicCounter
logger = init_logger(__name__)
@dataclass
class BaseModelPath:
@ -45,6 +49,7 @@ class OpenAIServingModels:
def __init__(
self,
engine_client: EngineClient,
model_config: ModelConfig,
base_model_paths: List[BaseModelPath],
*,
@ -55,20 +60,11 @@ class OpenAIServingModels:
self.base_model_paths = base_model_paths
self.max_model_len = model_config.max_model_len
self.engine_client = engine_client
self.static_lora_modules = lora_modules
self.lora_requests: List[LoRARequest] = []
self.lora_id_counter = AtomicCounter(0)
self.lora_requests = []
if lora_modules is not None:
self.lora_requests = [
LoRARequest(lora_name=lora.name,
lora_int_id=i,
lora_path=lora.path,
base_model_name=lora.base_model_name
if lora.base_model_name
and self.is_base_model(lora.base_model_name) else
self.base_model_paths[0].name)
for i, lora in enumerate(lora_modules, start=1)
]
self.prompt_adapter_requests = []
if prompt_adapters is not None:
@ -84,6 +80,19 @@ class OpenAIServingModels:
prompt_adapter_local_path=prompt_adapter.local_path,
prompt_adapter_num_virtual_tokens=num_virtual_tokens))
async def init_static_loras(self):
"""Loads all static LoRA modules.
Raises if any fail to load"""
if self.static_lora_modules is None:
return
for lora in self.static_lora_modules:
load_request = LoadLoraAdapterRequest(lora_path=lora.path,
lora_name=lora.name)
load_result = await self.load_lora_adapter(
request=load_request, base_model_name=lora.base_model_name)
if isinstance(load_result, ErrorResponse):
raise ValueError(load_result.message)
def is_base_model(self, model_name):
return any(model.name == model_name for model in self.base_model_paths)
@ -129,17 +138,47 @@ class OpenAIServingModels:
async def load_lora_adapter(
self,
request: LoadLoraAdapterRequest) -> Union[ErrorResponse, str]:
request: LoadLoraAdapterRequest,
base_model_name: Optional[str] = None
) -> Union[ErrorResponse, str]:
error_check_ret = await self._check_load_lora_adapter_request(request)
if error_check_ret is not None:
return error_check_ret
lora_name, lora_path = request.lora_name, request.lora_path
unique_id = self.lora_id_counter.inc(1)
self.lora_requests.append(
LoRARequest(lora_name=lora_name,
lora_int_id=unique_id,
lora_path=lora_path))
lora_request = LoRARequest(lora_name=lora_name,
lora_int_id=unique_id,
lora_path=lora_path)
if base_model_name is not None and self.is_base_model(base_model_name):
lora_request.base_model_name = base_model_name
# Validate that the adapter can be loaded into the engine
# This will also pre-load it for incoming requests
try:
await self.engine_client.add_lora(lora_request)
except ValueError as e:
# Adapter not found or lora configuration errors
if "No adapter found" in str(e):
return create_error_response(message=str(e),
err_type="NotFoundError",
status_code=HTTPStatus.NOT_FOUND)
else:
return create_error_response(
message=str(e),
err_type="BadRequestError",
status_code=HTTPStatus.BAD_REQUEST)
except BaseException as e:
# Some other unexpected problem loading the adapter, e.g. malformed
# input files.
# More detailed error messages for the user would be nicer here
return create_error_response(message=str(e),
err_type="BadRequestError",
status_code=HTTPStatus.BAD_REQUEST)
self.lora_requests.append(lora_request)
logger.info("Loaded new LoRA adapter: name '%s', path '%s'", lora_name,
lora_path)
return f"Success: LoRA adapter '{lora_name}' added successfully."
async def unload_lora_adapter(
@ -155,6 +194,7 @@ class OpenAIServingModels:
lora_request for lora_request in self.lora_requests
if lora_request.lora_name != lora_name
]
logger.info("Removed LoRA adapter: name '%s'", lora_name)
return f"Success: LoRA adapter '{lora_name}' removed successfully."
async def _check_load_lora_adapter_request(
@ -195,8 +235,8 @@ class OpenAIServingModels:
return create_error_response(
message=
f"The lora adapter '{request.lora_name}' cannot be found.",
err_type="InvalidUserInput",
status_code=HTTPStatus.BAD_REQUEST)
err_type="NotFoundError",
status_code=HTTPStatus.NOT_FOUND)
return None

View File

@ -115,6 +115,14 @@ class WorkerLoRAManager(AbstractWorkerManager):
embedding_padding_modules=self.embedding_padding_modules,
weights_mapper=hf_to_vllm_mapper)
except FileNotFoundError as e:
# FileNotFoundError should be raised if both
# - No adapter found to download from huggingface (or in
# offline mode)
# - No local adapter files found at `lora_request.lora_path`
raise ValueError(
f"Loading lora {lora_request.lora_name} failed: No adapter "
f"found for {lora_path}") from e
except Exception as e:
raise RuntimeError(f"Loading lora {lora_path} failed") from e
if lora.rank > self.lora_config.max_lora_rank:
@ -209,12 +217,19 @@ class LRUCacheWorkerLoRAManager(WorkerLoRAManager):
def add_adapter(self, lora_request: LoRARequest) -> bool:
if lora_request.lora_int_id not in self.list_adapters():
# Remove before we load the new lora to save memory
# Load the new adapter first to ensure it is actually valid, before
# evicting any existing adapters.
# This may cause the # of loaded lora adapters to very temporarily
# exceed `--max-cpu-loras`.
lora = self._load_adapter(lora_request)
# Loading succeeded, now check if we will exceed cache capacity and
# evict if the oldest adapter if so
if len(self._adapter_manager) + 1 > self._adapter_manager.capacity:
assert isinstance(self._adapter_manager,
LRUCacheLoRAModelManager)
self._adapter_manager.remove_oldest_adapter()
lora = self._load_adapter(lora_request)
# Then add the new adapter to the cache
loaded = self._adapter_manager.add_adapter(lora)
else:
# If the lora is already loaded, just touch it to

View File

@ -339,3 +339,7 @@ class AsyncLLM(EngineClient):
@property
def dead_error(self) -> BaseException:
return Exception() # TODO: implement
async def add_lora(self, lora_request: LoRARequest) -> None:
"""Load a new LoRA adapter into the engine for future requests."""
raise NotImplementedError("LoRA not yet supported in V1")