Mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 14:53:52 +08:00)
[Misc][LoRA] Add PEFTHelper for LoRA (#11003)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
@@ -1,3 +1,4 @@
+import json
 import os
 from typing import Dict, List
 
@@ -13,6 +14,7 @@ from vllm.lora.layers import (ColumnParallelLinearWithLoRA,
 from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
 from vllm.lora.models import (LoRAMapping, LoRAModel, LoRAModelManager,
                               LRUCacheLoRAModelManager)
+from vllm.lora.peft_helper import PEFTHelper
 from vllm.lora.request import LoRARequest
 from vllm.lora.worker_manager import (LRUCacheWorkerLoRAManager,
                                       WorkerLoRAManager)
@@ -30,18 +32,68 @@ CUDA_DEVICES = [
 ]
 
 
+def test_peft_helper(sql_lora_files):
+    lora_config_path = os.path.join(sql_lora_files, "adapter_config.json")
+    with open(lora_config_path) as f:
+        config = json.load(f)
+    peft_helper = PEFTHelper.from_dict(config)
+    assert peft_helper.r == 8
+    assert peft_helper.lora_alpha == 16
+    assert peft_helper.target_modules == [
+        "q_proj",
+        "v_proj",
+        "k_proj",
+        "o_proj",
+        "gate_proj",
+        "up_proj",
+        "down_proj",
+        "embed_tokens",
+        "lm_head",
+    ]
+
+    expected_error = "vLLM only supports modules_to_save being None."
+    with pytest.raises(ValueError, match=expected_error):
+        config = dict(
+            r=8,
+            lora_alpha=16,
+            target_modules=["gate_proj"],
+            modules_to_save=["lm_head"],
+        )
+        PEFTHelper.from_dict(config)
+    expected_error = "vLLM does not yet support RSLoRA."
+    with pytest.raises(ValueError, match=expected_error):
+        config = dict(r=8,
+                      lora_alpha=16,
+                      target_modules=["gate_proj"],
+                      use_rslora=True)
+        PEFTHelper.from_dict(config)
+
+    expected_error = "vLLM does not yet support DoRA."
+    with pytest.raises(ValueError, match=expected_error):
+        config = dict(r=8,
+                      lora_alpha=16,
+                      target_modules=["gate_proj"],
+                      use_dora=True)
+        PEFTHelper.from_dict(config)
+
+
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 def test_from_lora_tensors(sql_lora_files, device):
     tensors = load_file(
         os.path.join(sql_lora_files, "adapter_model.safetensors"))
     new_embeddings = load_file(
         os.path.join(sql_lora_files, "new_embeddings.safetensors"))
+
+    lora_config_path = os.path.join(sql_lora_files, "adapter_config.json")
+    with open(lora_config_path) as f:
+        config = json.load(f)
+
+    peft_helper = PEFTHelper.from_dict(config)
     lora_model = LoRAModel.from_lora_tensors(
         1,
-        8,
-        16,
         tensors,
-        device,
+        peft_helper=peft_helper,
+        device=device,
         embeddings=new_embeddings,
         embedding_modules=EMBEDDING_MODULES,
         embedding_padding_modules=EMBEDDING_PADDING_MODULES)
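Not exercised by the test above, but implied by the PEFTHelper._validate_features implementation added later in this commit: when several unsupported options are enabled at once, the individual messages are joined with ', ' into a single ValueError. A minimal sketch with made-up config values:

# Hypothetical illustration, not part of the diff.
from vllm.lora.peft_helper import PEFTHelper

try:
    PEFTHelper.from_dict(
        dict(r=8,
             lora_alpha=16,
             target_modules=["gate_proj"],
             use_rslora=True,
             use_dora=True))
except ValueError as e:
    # Both messages, joined with ', ':
    # vLLM does not yet support RSLoRA., vLLM does not yet support DoRA.
    print(e)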
@@ -4,6 +4,7 @@ from typing import Sequence as GenericSequence
 import torch
 import torch.types
 
+from vllm.lora.peft_helper import PEFTHelper
 from vllm.utils import is_pin_memory_available
 
 
@@ -59,6 +60,23 @@ class LoRALayerWeights:
         return self.embeddings_tensor.shape[
             0] if self.embeddings_tensor is not None else 0
 
+    @classmethod
+    def from_config(
+        cls,
+        module_name: str,
+        peft_helper: PEFTHelper,
+        embeddings_tensor: Optional[torch.Tensor] = None,
+    ) -> "LoRALayerWeights":
+        return cls(
+            module_name,
+            peft_helper.r,
+            peft_helper.lora_alpha,
+            None,
+            None,
+            None,
+            embeddings_tensor,
+        )
+
     @classmethod
     def create_dummy_lora_weights(
             cls,
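The new LoRALayerWeights.from_config is shorthand for the positional constructor call it replaces in models.py, with rank and alpha pulled from the PEFTHelper. A minimal sketch under illustrative assumptions (the module name and tensor shape below are made up, not taken from the diff):

import torch

from vllm.lora.lora import LoRALayerWeights
from vllm.lora.peft_helper import PEFTHelper

peft_helper = PEFTHelper(r=8, lora_alpha=16, target_modules=["q_proj"])
embeddings_tensor = torch.zeros(4, 8)  # hypothetical embeddings tensor

# Old-style positional call, as removed from models.py in this commit:
old_style = LoRALayerWeights("q_proj", peft_helper.r, peft_helper.lora_alpha,
                             None, None, None, embeddings_tensor)
# New-style call; both objects carry rank=8, lora_alpha=16, no
# lora_a/lora_b/bias yet, and the same embeddings tensor.
new_style = LoRALayerWeights.from_config("q_proj", peft_helper,
                                         embeddings_tensor)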
@@ -21,6 +21,7 @@ from vllm.lora.layers import (BaseLayerWithLoRA,
                               LinearScalingRotaryEmbeddingWithLora,
                               LoRAMapping)
 from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
+from vllm.lora.peft_helper import PEFTHelper
 from vllm.lora.punica_wrapper import get_punica_wrapper
 from vllm.lora.utils import (from_layer, from_layer_logits_processor,
                              is_regex_target_modules,
@@ -104,14 +105,12 @@ class LoRAModel(AdapterModel):
     def from_lora_tensors(
         cls,
         lora_model_id: int,
-        rank: int,
-        lora_alpha: int,
         tensors: Dict[str, torch.Tensor],
+        peft_helper: PEFTHelper,
         device: str = "cuda",
         dtype: Optional[torch.dtype] = None,
         embeddings: Optional[Dict[str, torch.Tensor]] = None,
         target_embedding_padding: Optional[int] = None,
-        scaling_factor: Optional[float] = None,
         embedding_modules: Optional[Dict[str, str]] = None,
         embedding_padding_modules: Optional[List[str]] = None,
     ) -> "LoRAModel":
@@ -135,10 +134,9 @@ class LoRAModel(AdapterModel):
                         if pin_memory:
                             lora_embeddings_tensor = (
                                 lora_embeddings_tensor.pin_memory())
-                loras[module_name] = LoRALayerWeights(module_name, rank,
-                                                      lora_alpha, None, None,
-                                                      None,
-                                                      lora_embeddings_tensor)
+                loras[module_name] = LoRALayerWeights.from_config(
+                    module_name, peft_helper, lora_embeddings_tensor)
+
             if is_bias:
                 loras[module_name].bias = tensor.to(device=device,
                                                     dtype=dtype).t()
@@ -170,7 +168,11 @@ class LoRAModel(AdapterModel):
 
         for lora in loras.values():
             lora.optimize()
-        return cls(lora_model_id, rank, loras, scaling_factor=scaling_factor)
+
+        return cls(lora_model_id,
+                   peft_helper.r,
+                   loras,
+                   scaling_factor=peft_helper.vllm_scaling_factor)
 
     @classmethod
     def from_local_checkpoint(
@@ -212,6 +214,9 @@ class LoRAModel(AdapterModel):
                                                     "new_embeddings.bin")
         with open(lora_config_path) as f:
             config = json.load(f)
+
+        config["vllm_max_position_embeddings"] = max_position_embeddings
+        peft_helper = PEFTHelper.from_dict(config)
         if os.path.isfile(lora_tensor_path):
             tensors: Dict[str, torch.Tensor] = {}
             # Find unexpected modules.
@@ -242,7 +247,7 @@ class LoRAModel(AdapterModel):
             # When a bin file is provided, we rely on config to find unexpected
             # modules.
             unexpected_modules = []
-            target_modules = config["target_modules"]
+            target_modules = peft_helper.target_modules
             if not isinstance(target_modules, list):
                 target_modules = [target_modules]
             for module in target_modules:
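The helper keeps the Union[list[str], str] shape of the PEFT config, so the isinstance normalization in the context lines above is still needed when target_modules is given as a regex string. A hedged sketch with a made-up pattern:

# Hypothetical example; PEFT configs may use a regex string here.
from vllm.lora.peft_helper import PEFTHelper

helper = PEFTHelper.from_dict(
    dict(r=8, lora_alpha=16, target_modules=".*(q_proj|v_proj)$"))
target_modules = helper.target_modules
if not isinstance(target_modules, list):
    target_modules = [target_modules]  # same normalization as the loader
print(target_modules)  # ['.*(q_proj|v_proj)$']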
@@ -256,7 +261,7 @@ class LoRAModel(AdapterModel):
             # https://github.com/vllm-project/vllm/pull/5909. But there's no
             # other better mechanism.
             if unexpected_modules and not is_regex_target_modules(
-                    config["target_modules"], expected_lora_modules):
+                    peft_helper.target_modules, expected_lora_modules):
                 raise ValueError(
                     f"While loading {lora_dir}, expected"
                     f" target modules in {expected_lora_modules}"
@@ -274,30 +279,17 @@ class LoRAModel(AdapterModel):
             embeddings = torch.load(new_embeddings_bin_file_path,
                                     map_location=device)
 
-        rank = config["r"]
-        lora_alpha = config["lora_alpha"]
-        context_length = config.get("context_length", None)
-        scaling_factor = None
-        if context_length:
-            if max_position_embeddings is None:
-                max_position_embeddings = context_length
-            scaling_factor = float(
-                math.ceil(context_length / max_position_embeddings))
-
         return cls.from_lora_tensors(
             lora_model_id=get_lora_id()
             if lora_model_id is None else lora_model_id,
-            rank=rank,
-            lora_alpha=lora_alpha,
             tensors=tensors,
+            peft_helper=peft_helper,
             device=device,
             dtype=dtype,
             embeddings=embeddings,
             target_embedding_padding=target_embedding_padding,
-            scaling_factor=scaling_factor,
             embedding_modules=embedding_modules,
-            embedding_padding_modules=embedding_padding_modules,
-        )
+            embedding_padding_modules=embedding_padding_modules)
 
 
 class LoRAModelManager(AdapterModelManager):
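The long-LoRA scaling computation deleted above now happens in PEFTHelper.__post_init__ once the loader injects vllm_max_position_embeddings. A worked example with hypothetical numbers:

# ceil(16384 / 4096) = 4, so vllm_scaling_factor becomes 4.0.
from vllm.lora.peft_helper import PEFTHelper

helper = PEFTHelper.from_dict({
    "r": 8,
    "lora_alpha": 16,
    "target_modules": ["q_proj", "v_proj"],
    "context_length": 16384,
    "vllm_max_position_embeddings": 4096,
})
assert helper.vllm_scaling_factor == 4.0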
vllm/lora/peft_helper.py (new file, 70 lines)
@@ -0,0 +1,70 @@
+# Adapted from: https://github.com/huggingface/peft/blob/main/src/peft/tuners/lora/config.py
+
+import math
+from dataclasses import MISSING, dataclass, field, fields
+from typing import Literal, Optional, Union
+
+
+@dataclass
+class PEFTHelper:
+    # Required fields
+    r: int
+    lora_alpha: int
+    target_modules: Union[list[str], str]
+
+    bias: Literal["none", "all", "lora_only"] = field(default="none")
+    modules_to_save: Optional[list[str]] = field(default=None)
+    use_rslora: bool = field(default=False)
+    use_dora: bool = field(default=False)
+    # long lora field
+    context_length: int = field(default=0)
+    # Extra vllm field, start with 'vllm_' to avoid conflict
+    vllm_max_position_embeddings: Optional[int] = field(default=False)
+    vllm_scaling_factor: Optional[float] = field(default=None)
+
+    def _validate_features(self):
+        error_msg = []
+
+        if self.modules_to_save:
+            error_msg.append("vLLM only supports modules_to_save being None.")
+        if self.use_rslora:
+            error_msg.append("vLLM does not yet support RSLoRA.")
+
+        if self.use_dora:
+            error_msg.append("vLLM does not yet support DoRA.")
+
+        if error_msg:
+            raise ValueError(f"{', '.join(error_msg)}")
+
+    def __post_init__(self):
+        self._validate_features()
+        if self.context_length:
+            if self.vllm_max_position_embeddings is None:
+                self.vllm_max_position_embeddings = self.context_length
+            self.vllm_scaling_factor = float(
+                math.ceil(self.context_length /
+                          self.vllm_max_position_embeddings))
+
+    @classmethod
+    def from_dict(cls, config_dict: dict) -> "PEFTHelper":
+        # Get all field information from the class
+        class_fields = {f.name: f for f in fields(cls)}
+        # Check for required fields
+        required_fields = {
+            name
+            for name, f in class_fields.items()
+            if f.default is MISSING and f.default_factory is MISSING
+        }
+
+        # Identify any missing required fields
+        missing_fields = required_fields - set(config_dict.keys())
+        if missing_fields:
+            raise ValueError(
+                f"Missing required configuration fields: {missing_fields}")
+
+        # Filter out fields that aren't defined in the class
+        filtered_dict = {
+            k: v
+            for k, v in config_dict.items() if k in class_fields
+        }
+        return cls(**filtered_dict)
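A minimal usage sketch mirroring how LoRAModel.from_local_checkpoint consumes this class; the adapter directory and position limit below are placeholders, not values from the diff:

import json
import os

from vllm.lora.peft_helper import PEFTHelper

lora_dir = "/path/to/lora_adapter"  # hypothetical checkpoint directory
max_position_embeddings = 4096  # hypothetical base-model limit

with open(os.path.join(lora_dir, "adapter_config.json")) as f:
    config = json.load(f)

# Extra keys in the PEFT config are dropped by from_dict; unsupported
# features (modules_to_save, RSLoRA, DoRA) raise ValueError up front.
config["vllm_max_position_embeddings"] = max_position_embeddings
peft_helper = PEFTHelper.from_dict(config)
print(peft_helper.r, peft_helper.lora_alpha, peft_helper.vllm_scaling_factor)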