[Lora][Frontend]Add default local directory LoRA resolver plugin. (#16855)

Signed-off-by: jberkhahn <jaberkha@us.ibm.com>
Jonathan Berkhahn
2025-05-12 10:39:10 -07:00
committed by GitHub
parent d19110204c
commit 98ea35601c
9 changed files with 146 additions and 3 deletions

View File

@@ -628,7 +628,7 @@ steps:
- vllm/plugins/
- tests/plugins/
commands:
# begin platform plugin tests, all the code in-between runs on dummy platform
# begin platform plugin and general plugin tests, all the code in-between runs on dummy platform
- pip install -e ./plugins/vllm_add_dummy_platform
- pytest -v -s plugins_tests/test_platform_plugins.py
- pip uninstall vllm_add_dummy_platform -y
@@ -639,6 +639,7 @@ steps:
- pytest -v -s distributed/test_distributed_oot.py
- pytest -v -s entrypoints/openai/test_oot_registration.py # it needs a clean process
- pytest -v -s models/test_oot_registration.py # it needs a clean process
- pytest -v -s plugins/lora_resolvers # unit tests for in-tree lora resolver plugins
- label: Multi-step Tests (4 GPUs) # 36min
mirror_hardwares: [amdexperimental]

View File

@@ -159,9 +159,12 @@ Alternatively, you can use the LoRAResolver plugin to dynamically load LoRA adapters
You can set up multiple LoRAResolver plugins if you want to load LoRA adapters from different sources. For example, you might have one resolver for local files and another for S3 storage. vLLM will load the first LoRA adapter that it finds.
You can either install existing plugins or implement your own.
You can either install existing plugins or implement your own. By default, vLLM comes with a [resolver plugin that loads LoRA adapters from a local directory](https://github.com/vllm-project/vllm/tree/main/vllm/plugins/lora_resolvers).
To enable this resolver, set `VLLM_ALLOW_RUNTIME_LORA_UPDATING` to True, set `VLLM_PLUGINS` to include `lora_filesystem_resolver`, and then set `VLLM_LORA_RESOLVER_CACHE_DIR` to a local directory. When vLLM receives a request using a LoRA adapter `foobar`,
it will first look in that local directory for a directory `foobar` and attempt to load its contents as a LoRA adapter. If the load succeeds, the request completes as normal, and
that adapter is then available for normal use on the server.
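For example, once the server is running with those variables set, a client can trigger the resolver simply by naming the adapter in a request through the OpenAI-compatible API. This is only a sketch: the port, API key, adapter name `foobar`, and prompt are placeholder assumptions.

```python
# Sketch: request an adapter that the server has not loaded yet; the
# filesystem resolver looks it up under VLLM_LORA_RESOLVER_CACHE_DIR.
# Base URL, API key, adapter name, and prompt are illustrative assumptions.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

completion = client.completions.create(
    model="foobar",  # unrecognized LoRA name -> resolver is consulted
    prompt="Hello, my name is",
    max_tokens=32,
)
print(completion.choices[0].text)
```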
Steps to implement your own LoRAResolver plugin:
Alternatively, follow these example steps to implement your own plugin:
1. Implement the LoRAResolver interface.
Example of a simple S3 LoRAResolver implementation:
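The hunk ends here, so the documentation's actual S3 example is not reproduced in this view. Purely as an illustration of the interface used elsewhere in this commit (an async `resolve_lora(base_model_name, lora_name)` that returns an `Optional[LoRARequest]`), a sketch might look like the following; the bucket layout, the synchronous `boto3` calls, and the absence of error handling are assumptions, not the documented example.

```python
# Illustrative sketch only -- not the example from the vLLM documentation.
# Assumes adapters live under s3://<bucket>/<lora_name>/ and that boto3
# credentials are already configured in the environment.
import os
from typing import Optional

import boto3

from vllm.lora.request import LoRARequest
from vllm.lora.resolver import LoRAResolver


class S3LoRAResolver(LoRAResolver):

    def __init__(self, bucket: str, local_dir: str = "/tmp/lora_cache"):
        self.bucket = bucket
        self.local_dir = local_dir
        self.s3 = boto3.client("s3")

    async def resolve_lora(self, base_model_name: str,
                           lora_name: str) -> Optional[LoRARequest]:
        # List the adapter's files; bail out if nothing is stored under it.
        listing = self.s3.list_objects_v2(Bucket=self.bucket,
                                          Prefix=f"{lora_name}/")
        if "Contents" not in listing:
            return None

        # Download every object into a local directory vLLM can read from.
        local_path = os.path.join(self.local_dir, lora_name)
        os.makedirs(local_path, exist_ok=True)
        for obj in listing["Contents"]:
            filename = os.path.basename(obj["Key"])
            if filename:
                self.s3.download_file(self.bucket, obj["Key"],
                                      os.path.join(local_path, filename))

        return LoRARequest(lora_name=lora_name,
                           lora_int_id=abs(hash(lora_name)),
                           lora_path=local_path)
```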

View File

@@ -41,6 +41,9 @@ Slack="http://slack.vllm.ai/"
[project.scripts]
vllm = "vllm.entrypoints.cli.main:main"
[project.entry-points."vllm.general_plugins"]
lora_filesystem_resolver = "vllm.plugins.lora_resolvers.filesystem_resolver:register_filesystem_resolver"
[tool.setuptools_scm]
# no extra settings needed, presence enables setuptools-scm
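This entry point is what makes the resolver discoverable: vLLM's general-plugin loader iterates over the `vllm.general_plugins` entry-point group and calls each registered function at startup. The snippet below only illustrates that discovery mechanism with the standard library; it is not vLLM's actual loader, which also applies `VLLM_PLUGINS` filtering.

```python
# Illustration of entry-point based plugin discovery (Python 3.10+ API).
# vLLM's real loader adds filtering via VLLM_PLUGINS and logging; this
# only shows the underlying importlib.metadata mechanism.
from importlib.metadata import entry_points

for ep in entry_points(group="vllm.general_plugins"):
    register_fn = ep.load()  # e.g. register_filesystem_resolver
    register_fn()            # each plugin registers itself when called
    print(f"Loaded general plugin: {ep.name}")
```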

View File

View File

@@ -0,0 +1,65 @@
# SPDX-License-Identifier: Apache-2.0
import os
import shutil

import pytest
from huggingface_hub import snapshot_download

from vllm.plugins.lora_resolvers.filesystem_resolver import FilesystemResolver

MODEL_NAME = "mistralai/Mistral-7B-v0.1"
LORA_NAME = "typeof/zephyr-7b-beta-lora"
PA_NAME = "swapnilbp/llama_tweet_ptune"


@pytest.fixture(scope='module')
def adapter_cache(request, tmpdir_factory):
    # Create dir that mimics the structure of the adapter cache
    adapter_cache = tmpdir_factory.mktemp(
        request.module.__name__) / "adapter_cache"
    return adapter_cache


@pytest.fixture(scope="module")
def zephyr_lora_files():
    return snapshot_download(repo_id=LORA_NAME)


@pytest.fixture(scope="module")
def pa_files():
    return snapshot_download(repo_id=PA_NAME)


@pytest.mark.asyncio
async def test_filesystem_resolver(adapter_cache, zephyr_lora_files):
    model_files = adapter_cache / LORA_NAME
    shutil.copytree(zephyr_lora_files, model_files)

    fs_resolver = FilesystemResolver(adapter_cache)
    assert fs_resolver is not None

    lora_request = await fs_resolver.resolve_lora(MODEL_NAME, LORA_NAME)
    assert lora_request is not None
    assert lora_request.lora_name == LORA_NAME
    assert lora_request.lora_path == os.path.join(adapter_cache, LORA_NAME)


@pytest.mark.asyncio
async def test_missing_adapter(adapter_cache):
    fs_resolver = FilesystemResolver(adapter_cache)
    assert fs_resolver is not None

    missing_lora_request = await fs_resolver.resolve_lora(MODEL_NAME, "foobar")
    assert missing_lora_request is None


@pytest.mark.asyncio
async def test_nonlora_adapter(adapter_cache, pa_files):
    model_files = adapter_cache / PA_NAME
    shutil.copytree(pa_files, model_files)

    fs_resolver = FilesystemResolver(adapter_cache)
    assert fs_resolver is not None

    pa_request = await fs_resolver.resolve_lora(MODEL_NAME, PA_NAME)
    assert pa_request is None

View File

@@ -68,6 +68,7 @@ if TYPE_CHECKING:
    VLLM_ALLOW_LONG_MAX_MODEL_LEN: bool = False
    VLLM_RPC_TIMEOUT: int = 10000  # ms
    VLLM_PLUGINS: Optional[list[str]] = None
    VLLM_LORA_RESOLVER_CACHE_DIR: Optional[str] = None
    VLLM_TORCH_PROFILER_DIR: Optional[str] = None
    VLLM_USE_TRITON_AWQ: bool = False
    VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False
@@ -503,6 +504,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
    lambda: None if "VLLM_PLUGINS" not in os.environ else os.environ[
        "VLLM_PLUGINS"].split(","),

    # a local directory to look in for unrecognized LoRA adapters.
    # only works if plugins are enabled and
    # VLLM_ALLOW_RUNTIME_LORA_UPDATING is enabled.
    "VLLM_LORA_RESOLVER_CACHE_DIR":
    lambda: os.getenv("VLLM_LORA_RESOLVER_CACHE_DIR", None),

    # Enables torch profiler if set. Path to the directory where torch profiler
    # traces are saved. Note that it must be an absolute path.
    "VLLM_TORCH_PROFILER_DIR":

View File

@@ -0,0 +1,15 @@
# LoRA Resolver Plugins

This directory contains vLLM general plugins for dynamically discovering and loading LoRA adapters
via the LoRAResolver plugin framework.

Note that `VLLM_ALLOW_RUNTIME_LORA_UPDATING` must be set to true to allow LoRA resolver plugins
to work, and `VLLM_PLUGINS` must be set to include the desired resolver plugins.

# lora_filesystem_resolver
This LoRA Resolver is installed with vLLM by default.
To use it, set `VLLM_LORA_RESOLVER_CACHE_DIR` to a local directory. When vLLM receives a request
for a LoRA adapter `foobar` that it doesn't currently recognize, it will look in that local directory
for a subdirectory `foobar` containing a LoRA adapter. If such an adapter exists, it will
load that adapter and then service the request as normal. That adapter will then be available
for future requests.
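For instance, you can stage an adapter ahead of time by copying it into a subdirectory named after the adapter clients will request, mirroring what the unit tests in this PR do; the cache path and adapter repository below are placeholders.

```python
# Sketch: stage a LoRA adapter where the filesystem resolver can find it.
# The cache directory and adapter repository are placeholder assumptions.
import os
import shutil

from huggingface_hub import snapshot_download

cache_dir = "/path/to/lora_cache"   # value of VLLM_LORA_RESOLVER_CACHE_DIR
adapter_name = "foobar"             # name clients will pass as the model
adapter_files = snapshot_download(repo_id="typeof/zephyr-7b-beta-lora")

shutil.copytree(adapter_files, os.path.join(cache_dir, adapter_name))
```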

View File

View File

@@ -0,0 +1,49 @@
# SPDX-License-Identifier: Apache-2.0
import json
import os
from typing import Optional

import vllm.envs as envs
from vllm.lora.request import LoRARequest
from vllm.lora.resolver import LoRAResolver, LoRAResolverRegistry


class FilesystemResolver(LoRAResolver):

    def __init__(self, lora_cache_dir: str):
        self.lora_cache_dir = lora_cache_dir

    async def resolve_lora(self, base_model_name: str,
                           lora_name: str) -> Optional[LoRARequest]:
        lora_path = os.path.join(self.lora_cache_dir, lora_name)
        if os.path.exists(lora_path):
            adapter_config_path = os.path.join(self.lora_cache_dir, lora_name,
                                               "adapter_config.json")
            if os.path.exists(adapter_config_path):
                with open(adapter_config_path) as file:
                    adapter_config = json.load(file)
                    if adapter_config["peft_type"] == "LORA" and adapter_config[
                            "base_model_name_or_path"] == base_model_name:
                        lora_request = LoRARequest(lora_name=lora_name,
                                                   lora_int_id=abs(
                                                       hash(lora_name)),
                                                   lora_path=lora_path)
                        return lora_request
        return None


def register_filesystem_resolver():
    """Register the filesystem LoRA Resolver with vLLM"""

    lora_cache_dir = envs.VLLM_LORA_RESOLVER_CACHE_DIR
    if lora_cache_dir:
        if not os.path.exists(lora_cache_dir) or not os.path.isdir(
                lora_cache_dir):
            raise ValueError(
                "VLLM_LORA_RESOLVER_CACHE_DIR must be set to a valid directory \
                for Filesystem Resolver plugin to function")
        fs_resolver = FilesystemResolver(lora_cache_dir)
        LoRAResolverRegistry.register_resolver("Filesystem Resolver",
                                               fs_resolver)

    return
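As the resolver code shows, an adapter directory is only accepted if its `adapter_config.json` declares a `peft_type` of `"LORA"` and a `base_model_name_or_path` matching the served base model. A minimal sketch of such a file is below; a real config produced by PEFT contains many more fields, and the base model name here is just an example.

```python
# Sketch: write only the two fields FilesystemResolver checks.
# A real adapter_config.json generated by PEFT has many more entries.
import json

config = {
    "peft_type": "LORA",
    "base_model_name_or_path": "mistralai/Mistral-7B-v0.1",
}
with open("adapter_config.json", "w") as f:
    json.dump(config, f, indent=2)
```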