[Misc] Reduce LoRA-related static variable (#13166)
@@ -23,6 +23,7 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models.interfaces import SupportsLoRA
from vllm.platforms import current_platform

@@ -98,9 +99,13 @@ def dist_init_torch_only():
                             backend=backend)


class DummyLoRAModel(nn.Sequential, SupportsLoRA):
    pass


@pytest.fixture
def dummy_model() -> nn.Module:
    model = nn.Sequential(
    model = DummyLoRAModel(
        OrderedDict([
            ("dense1", ColumnParallelLinear(764, 100)),
            ("dense2", RowParallelLinear(100, 50)),
@@ -121,12 +126,13 @@ def dummy_model() -> nn.Module:
            ("sampler", Sampler())
        ]))
    model.config = MagicMock()
    model.embedding_modules = {"lm_head": "lm_head"}
    return model


@pytest.fixture
def dummy_model_gate_up() -> nn.Module:
    model = nn.Sequential(
    model = DummyLoRAModel(
        OrderedDict([
            ("dense1", ColumnParallelLinear(764, 100)),
            ("dense2", RowParallelLinear(100, 50)),
@@ -147,6 +153,13 @@ def dummy_model_gate_up() -> nn.Module:
            ("sampler", Sampler())
        ]))
    model.config = MagicMock()
    model.packed_modules_mapping = {
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }
    model.embedding_modules = {"lm_head": "lm_head"}
    return model
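The fixture change above swaps the plain nn.Sequential container for a DummyLoRAModel that mixes in SupportsLoRA, and attaches embedding_modules (and, for the gate_up variant, packed_modules_mapping) directly to the instance, so the test models no longer need a static supported_lora_modules list. A minimal standalone sketch of the same pattern, using plain torch layers as stand-ins for the parallel linear classes:

    # Illustrative sketch only; nn.Linear stands in for ColumnParallelLinear etc.
    from collections import OrderedDict

    import torch.nn as nn

    class DummyLoRAModel(nn.Sequential):
        """Stand-in for a model class that opts into LoRA support."""

    model = DummyLoRAModel(
        OrderedDict([
            ("dense1", nn.Linear(764, 100)),
            ("dense2", nn.Linear(100, 50)),
            ("lm_head", nn.Linear(50, 512)),
        ]))
    # LoRA metadata now lives on the instance rather than in a static class list.
    model.embedding_modules = {"lm_head": "lm_head"}
    model.packed_modules_mapping = {}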
@@ -12,6 +12,12 @@ from vllm.model_executor.models.utils import WeightsMapper

lora_lst = [
    "baichuan7B", "baichuan7B-zero", "baichuan7B-zero-regex", "chatglm3-6b"
]
BAICHUAN_LORA_MODULES = [
    "W_pack",
    "o_proj",
    "gate_up_proj",
    "down_proj",
]


@pytest.mark.parametrize("lora_name", lora_lst)
@@ -22,12 +28,11 @@ def test_load_checkpoints(
    baichuan_regex_lora_files,
    chatglm3_lora_files,
):
    supported_lora_modules = BaiChuanBaseForCausalLM.supported_lora_modules
    packed_modules_mapping = BaiChuanBaseForCausalLM.packed_modules_mapping
    embedding_modules = BaiChuanBaseForCausalLM.embedding_modules
    embed_padding_modules = BaiChuanBaseForCausalLM.embedding_padding_modules
    expected_lora_modules: List[str] = []
    for module in supported_lora_modules:
    for module in BAICHUAN_LORA_MODULES:
        if module in packed_modules_mapping:
            expected_lora_modules.extend(packed_modules_mapping[module])
        else:
@@ -90,12 +95,12 @@ def test_load_checkpoints(


def test_lora_weights_mapping(baichuan_lora_files):
    supported_lora_modules = BaiChuanBaseForCausalLM.supported_lora_modules

    packed_modules_mapping = BaiChuanBaseForCausalLM.packed_modules_mapping
    embedding_modules = BaiChuanBaseForCausalLM.embedding_modules
    embed_padding_modules = BaiChuanBaseForCausalLM.embedding_padding_modules
    expected_lora_modules: List[str] = []
    for module in supported_lora_modules:
    for module in BAICHUAN_LORA_MODULES:
        if module in packed_modules_mapping:
            expected_lora_modules.extend(packed_modules_mapping[module])
        else:

@@ -11,17 +11,20 @@ from vllm.model_executor.models.llama import LlamaForCausalLM

# Provide absolute path and huggingface lora ids
lora_fixture_name = ["sql_lora_files", "sql_lora_huggingface_id"]
LLAMA_LORA_MODULES = [
    "qkv_proj", "o_proj", "gate_up_proj", "down_proj", "embed_tokens",
    "lm_head"
]


@pytest.mark.parametrize("lora_fixture_name", lora_fixture_name)
def test_load_checkpoints_from_huggingface(lora_fixture_name, request):
    lora_name = request.getfixturevalue(lora_fixture_name)
    supported_lora_modules = LlamaForCausalLM.supported_lora_modules
    packed_modules_mapping = LlamaForCausalLM.packed_modules_mapping
    embedding_modules = LlamaForCausalLM.embedding_modules
    embed_padding_modules = LlamaForCausalLM.embedding_padding_modules
    expected_lora_modules: List[str] = []
    for module in supported_lora_modules:
    for module in LLAMA_LORA_MODULES:
        if module in packed_modules_mapping:
            expected_lora_modules.extend(packed_modules_mapping[module])
        else:
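Since the static supported_lora_modules attribute is removed from the model classes, these tests now start from hard-coded lists (BAICHUAN_LORA_MODULES, LLAMA_LORA_MODULES) and expand packed entries through the class's packed_modules_mapping, which is kept. A small self-contained sketch of that expansion logic (the mapping values here are illustrative):

    # Illustrative: expand packed modules into their constituent projections.
    from typing import Dict, List

    packed_modules_mapping: Dict[str, List[str]] = {
        "gate_up_proj": ["gate_proj", "up_proj"],
        "W_pack": ["W_pack"],
    }
    lora_modules = ["W_pack", "o_proj", "gate_up_proj", "down_proj"]

    expected_lora_modules: List[str] = []
    for module in lora_modules:
        if module in packed_modules_mapping:
            expected_lora_modules.extend(packed_modules_mapping[module])
        else:
            expected_lora_modules.append(module)

    assert expected_lora_modules == [
        "W_pack", "o_proj", "gate_proj", "up_proj", "down_proj"
    ]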
@@ -19,7 +19,6 @@ from vllm.lora.peft_helper import PEFTHelper
from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import (LRUCacheWorkerLoRAManager,
                                      WorkerLoRAManager)
from vllm.model_executor.layers.linear import RowParallelLinear
from vllm.platforms import current_platform

EMBEDDING_MODULES = {

@@ -114,19 +113,16 @@ def create_packed_lora(

def test_replace_submodules(dist_init, dummy_model):
    model = dummy_model
    model.supported_lora_modules = ["dense1", "layer1.dense2"]
    model.packed_modules_mapping = {}
    manager = LoRAModelManager(
        model, 1, 1, 1,
        LoRAConfig(max_lora_rank=8, max_cpu_loras=8, max_loras=8),
        torch.device(DEVICES[0]))
    model = manager.model

    assert isinstance(model.get_submodule("dense1"),
                      ColumnParallelLinearWithLoRA)
    assert isinstance(model.get_submodule("layer1.dense1"),
                      ColumnParallelLinearWithLoRA)
    assert isinstance(model.get_submodule("dense2"), RowParallelLinear)
    assert isinstance(model.get_submodule("dense2"), RowParallelLinearWithLoRA)
    assert isinstance(model.get_submodule("layer1.dense2"),
                      RowParallelLinearWithLoRA)

@@ -134,8 +130,6 @@ def test_replace_submodules(dist_init, dummy_model):
@pytest.mark.parametrize("device", DEVICES)
def test_lora_model_manager(dist_init, dummy_model, device):
    model = dummy_model
    model.supported_lora_modules = ["dense1", "dense2", "lm_head"]
    model.packed_modules_mapping = {}
    model_lora1 = create_lora(1,
                              model, ["layer1.dense1", "dense2", "lm_head"],
                              device=device)

@@ -190,13 +184,18 @@ def test_lora_model_manager(dist_init, dummy_model, device):

    assert manager.device == device
    assert manager.punica_wrapper.device == device
    assert hasattr(manager, "supported_lora_modules")
    assert sorted(manager.supported_lora_modules) == [
        "dense1",
        "dense2",
        "lm_head",
        "output",
    ]


@pytest.mark.parametrize("device", DEVICES)
def test_lora_lru_cache_model_manager(dist_init, dummy_model, device):
    model = dummy_model
    model.supported_lora_modules = ["dense1", "dense2", "lm_head"]
    model.packed_modules_mapping = {}
    model_lora1 = create_lora(1,
                              model, ["layer1.dense1", "dense2", "lm_head"],
                              device=device)

@@ -289,8 +288,6 @@ def test_lru_lora_model_manager(dist_init, dummy_model, device):
    # This tests just the LRU cache functionality, everything else is
    # tested in test_lora_model_manager
    model = dummy_model
    model.supported_lora_modules = ["dense1", "dense2", "lm_head"]
    model.packed_modules_mapping = {}
    model_lora1 = create_lora(1,
                              model, ["layer1.dense1", "dense2", "lm_head"],
                              device=device)

@@ -572,13 +569,6 @@ def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
@pytest.mark.parametrize("device", DEVICES)
def test_packed_loras(dist_init, dummy_model_gate_up, device):
    model = dummy_model_gate_up
    model.supported_lora_modules = ["gate_up_proj"]
    model.packed_modules_mapping = {
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }
    model_lora = create_packed_lora(
        1,
        model,
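These manager tests stop assigning model.supported_lora_modules by hand; only packed_modules_mapping stays on the dummy models, and the new assertion checks the module names the manager discovered from the model itself (the extra "output" entry is presumably another submodule of the dummy fixture picked up by the traversal). A toy illustration of the kind of suffix-based discovery being asserted, with plain torch layers and made-up names:

    # Illustrative: collect the trailing name segment of every Linear submodule.
    import torch.nn as nn

    model = nn.ModuleDict({
        "layer1": nn.ModuleDict({
            "dense1": nn.Linear(8, 8),
            "dense2": nn.Linear(8, 8),
        }),
        "lm_head": nn.Linear(8, 32),
        "output": nn.Linear(32, 32),
    })
    found = sorted({
        name.split(".")[-1]
        for name, module in model.named_modules()
        if isinstance(module, nn.Linear)
    })
    assert found == ["dense1", "dense2", "lm_head", "output"]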
@@ -26,6 +26,7 @@ from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
from vllm.lora.peft_helper import PEFTHelper
from vllm.lora.punica_wrapper import get_punica_wrapper
from vllm.lora.utils import (from_layer, from_layer_logits_processor,
                             get_supported_lora_modules,
                             is_regex_target_modules,
                             parse_fine_tuned_lora_name, replace_submodule)
from vllm.model_executor.models import SupportsLoRA, supports_multimodal

@@ -332,15 +333,15 @@ class LoRAModelManager(AdapterModelManager):
        # Used for long context lora.
        self.scaling_factor_to_offset: Dict[float, int] = {}
        super().__init__(model)
        if hasattr(self.model, "supported_lora_modules"):
            self.supported_lora_modules = copy.deepcopy(
                self.model.supported_lora_modules)
            if lora_config.long_lora_scaling_factors:
                # We need to replace rotary emb layer to do batch computation
                # for long lora.
                self.supported_lora_modules.append("rotary_emb")
            self.packed_modules_mapping = copy.deepcopy(
                self.model.packed_modules_mapping)
        self.supported_lora_modules = get_supported_lora_modules(self.model)
        assert self.supported_lora_modules, (
            "No supported LoRA modules found in "
            f"{self.model.__class__.__name__}.")
        if lora_config.long_lora_scaling_factors:
            # We need to replace rotary emb layer to do batch computation
            # for long lora.
            self.supported_lora_modules.append("rotary_emb")
        self.packed_modules_mapping = copy.deepcopy(
            self.model.packed_modules_mapping)
        # Used to indicate whether the model is a multimodal model
        self.supports_mm: bool = (
            supports_multimodal(self.model)

@@ -756,7 +757,7 @@ def create_lora_manager(
        lora_manager_cls: Type[LoRAModelManager] = LoRAModelManager,
        **kwargs) -> LoRAModelManager:
    """Create a LoRA adapter for a given model."""
    if not hasattr(model, "supported_lora_modules"):
    if not hasattr(model, "packed_modules_mapping"):
        raise ValueError(f"Model {type(model)} is not supported for LoRA.")
    lora_manager = lora_manager_cls(
        model=model,
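This hunk is the core of the change: LoRAModelManager used to deep-copy a supported_lora_modules list that every model class had to declare (and silently skipped models without one), whereas it now always derives the list from the live model via get_supported_lora_modules and asserts that the result is non-empty; create_lora_manager correspondingly gates on packed_modules_mapping instead of the removed attribute. A condensed sketch of the new initialization logic, not the exact vLLM code (init_lora_module_state is a hypothetical helper name):

    # Condensed, illustrative version of the new constructor logic.
    import copy
    from typing import Dict, List, Optional, Tuple

    from vllm.lora.utils import get_supported_lora_modules

    def init_lora_module_state(
            model,
            long_lora_scaling_factors: Optional[Tuple[float, ...]] = None
    ) -> Tuple[List[str], Dict[str, List[str]]]:
        # Discover LoRA-able modules from the instantiated model instead of
        # reading a static class attribute.
        supported: List[str] = get_supported_lora_modules(model)
        assert supported, (
            f"No supported LoRA modules found in {model.__class__.__name__}.")
        if long_lora_scaling_factors:
            # Rotary embedding layers are replaced for long-context LoRA.
            supported.append("rotary_emb")
        packed: Dict[str, List[str]] = copy.deepcopy(model.packed_modules_mapping)
        return supported, packed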
@@ -29,6 +29,7 @@ from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA,
                              ReplicatedLinearWithLoRA,
                              RowParallelLinearWithLoRA,
                              VocabParallelEmbeddingWithLoRA)
from vllm.model_executor.layers.linear import LinearBase
# yapf: enable
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead

@@ -68,6 +69,14 @@ def from_layer(layer: nn.Module,
            ret = lora_cls(layer)
            ret.create_lora_weights(max_loras, lora_config, model_config)
            return ret

    # The case for HFCompatibleLinear
    if (hasattr(layer, "get_lora_class")
            and layer.__class__.__name__ == "HFCompatibleLinear"):
        lora_cls = layer.get_lora_class(lora_config.fully_sharded_loras)
        ret = lora_cls(layer)
        ret.create_lora_weights(max_loras, lora_config, model_config)
        return ret
    return layer


@@ -170,6 +179,23 @@ def is_regex_target_modules(load_modules: Union[str, List[str]],
    return False


def get_supported_lora_modules(model: nn.Module) -> List[str]:
    """
    In vLLM, all linear layers support LoRA.
    """
    supported_lora_modules: Set[str] = set()
    # step 1: traverse the model to get all the linear suffixes.
    for name, module in model.named_modules():
        if isinstance(module, (LinearBase, )):
            supported_lora_modules.add(name.split(".")[-1])
    # step 2: get the embedding modules if the model's embedding_modules
    # is not empty.
    if model.embedding_modules:
        for name in model.embedding_modules:
            supported_lora_modules.add(name)
    return list(supported_lora_modules)


def get_adapter_absolute_path(lora_path: str) -> str:
    """
    Resolves the given lora_path to an absolute local path.
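The new helper walks named_modules(), keeps the last dotted segment of every LinearBase layer's name, and unions in the keys of model.embedding_modules. A toy illustration of the same rule with plain torch modules (vLLM itself keys on LinearBase; the classes below are made up for the example):

    # Illustrative: what suffix-based discovery returns for a tiny model.
    import torch.nn as nn

    class TinyBlock(nn.Module):
        def __init__(self):
            super().__init__()
            self.qkv_proj = nn.Linear(16, 48)
            self.o_proj = nn.Linear(16, 16)

    class TinyModel(nn.Module):
        embedding_modules = {"lm_head": "output_embeddings"}

        def __init__(self):
            super().__init__()
            self.layers = nn.ModuleList(TinyBlock() for _ in range(2))
            self.lm_head = nn.Linear(16, 100)

    model = TinyModel()
    names = {
        name.split(".")[-1]
        for name, module in model.named_modules()
        if isinstance(module, nn.Linear)
    }
    names |= set(model.embedding_modules)
    # Repeated suffixes across layers collapse into a single entry each.
    assert names == {"qkv_proj", "o_proj", "lm_head"}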
@@ -84,9 +84,10 @@ class WorkerLoRAManager(AbstractWorkerManager):

    def _load_adapter(self, lora_request: LoRARequest) -> LoRAModel:
        try:
            model = self._adapter_manager.model
            supported_lora_modules = model.supported_lora_modules
            packed_modules_mapping = model.packed_modules_mapping
            supported_lora_modules = (
                self._adapter_manager.supported_lora_modules)
            packed_modules_mapping = (
                self._adapter_manager.packed_modules_mapping)
            expected_lora_modules: List[str] = []
            for module in supported_lora_modules:
                if module in packed_modules_mapping:

@@ -107,6 +108,7 @@ class WorkerLoRAManager(AbstractWorkerManager):

            # For some models like Qwen2VL, we need to use hf_to_vllm_mapper
            # to ensure correct loading of lora weights.
            model = self._adapter_manager.model
            hf_to_vllm_mapper = None
            if (hasattr(model, "hf_to_vllm_mapper")
                    and model.hf_to_vllm_mapper is not None):
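On the worker side, adapter validation now pulls supported_lora_modules and packed_modules_mapping from the adapter manager, which computed them at construction time, instead of reading static attributes off the model class. A small mock-based sketch of where the values come from after this change (FakeAdapterManager is a stand-in for illustration):

    # Illustrative: the worker consults the adapter manager's derived state.
    class FakeAdapterManager:
        supported_lora_modules = ["qkv_proj", "o_proj", "gate_up_proj", "down_proj"]
        packed_modules_mapping = {
            "qkv_proj": ["q_proj", "k_proj", "v_proj"],
            "gate_up_proj": ["gate_proj", "up_proj"],
        }

    manager = FakeAdapterManager()
    # Previously these were read from model.supported_lora_modules and
    # model.packed_modules_mapping on the model class itself.
    supported_lora_modules = manager.supported_lora_modules
    packed_modules_mapping = manager.packed_modules_mapping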
@@ -342,15 +342,6 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
            "up_proj",
        ],
    }
    # LoRA specific attributes
    supported_lora_modules = [
        "W_pack",
        "o_proj",
        "gate_up_proj",
        "down_proj",
    ]
    embedding_modules = {}
    embedding_padding_modules = []

    def __init__(
        self,

@@ -389,12 +389,6 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj",
        "o_proj",
        "embed_tokens",
        "lm_head",
    ]
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",

@@ -477,16 +477,6 @@ class ChatGLMForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP):
        "query_key_value": ["query_key_value"],
        "dense_h_to_4h": ["dense_h_to_4h"]
    }
    # LoRA specific attributes
    supported_lora_modules = [
        "query_key_value",
        "dense",
        "dense_h_to_4h",
        "dense_4h_to_h",
    ]

    embedding_modules = {}
    embedding_padding_modules = []

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        config = vllm_config.model_config.hf_config

@@ -357,11 +357,7 @@ class CohereForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
        ],
    }
    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj", "o_proj", "gate_up_proj", "down_proj", "embed_tokens"
    ]
    embedding_modules = {"embed_tokens": "input_embeddings"}
    embedding_padding_modules = []

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

@@ -415,14 +415,6 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj",
        "out_proj",
        "gate_up_proj",
        "c_proj",
        "wte",
        "lm_head",
    ]
    embedding_modules = {
        "wte": "input_embeddings",
        "lm_head": "output_embeddings",

@@ -344,18 +344,6 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
        ],
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj",
        "o_proj",
        "gate_up_proj",
        "down_proj",
    ]

    # Gemma does not apply LoRA to the embedding layer.
    embedding_modules = {}
    embedding_padding_modules = []

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config

@@ -390,17 +390,6 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
        ],
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj",
        "o_proj",
        "gate_up_proj",
        "down_proj",
    ]
    # Gemma does not apply LoRA to the embedding layer.
    embedding_modules = {}
    embedding_padding_modules = []

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config

@@ -534,21 +534,6 @@ class GLM4VForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP,
        "dense_h_to_4h": ["dense_h_to_4h"],
        "merged_proj": ["gate_proj", "dense_h_to_4h"]
    }
    # LoRA specific attributes
    supported_lora_modules = [
        "query_key_value",
        "dense",
        "dense_h_to_4h",
        "dense_4h_to_h",
        # vision
        "fc1",
        "fc2",
        "merged_proj",
        "linear_proj"
    ]

    embedding_modules = {}
    embedding_padding_modules = []

    def get_mm_mapping(self) -> MultiModelKeys:
        """

@@ -261,15 +261,12 @@ class GPTBigCodeModel(nn.Module):
class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    packed_modules_mapping = {"c_attn": ["c_attn"]}

    supported_lora_modules = ["c_fc", "c_proj", "wte", "c_attn"]

    # LoRA specific attributes
    embedding_modules = {
        "wte": "input_embeddings",
        "lm_head": "output_embeddings",
    }

    embedding_padding_modules = []

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config

@@ -351,10 +351,6 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj", "o_proj", "gate_up_proj", "down_proj", "embed_tokens",
        "lm_head"
    ]
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",

@@ -329,13 +329,6 @@ class GraniteMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj",
        "o_proj",
        "embed_tokens",
        "lm_head",
        "layer",
    ]
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",

@@ -597,21 +597,6 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
            "up_proj",
        ],
    }
    # LoRA specific attributes
    supported_lora_modules = [
        # vision_model
        "fc1",
        "fc2",
        "out_proj",
        # text_model
        "qkv_proj",  # same name with vision encoder
        "o_proj",
        "gate_up_proj",
        "down_proj",
    ]

    embedding_modules = {}
    embedding_padding_modules = []

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
@@ -118,11 +118,11 @@ class SupportsLoRA(Protocol):
    There is no need to redefine this flag if this class is in the
    MRO of your model class.
    """

    packed_modules_mapping: ClassVar[Dict[str, List[str]]]
    supported_lora_modules: ClassVar[List[str]]
    embedding_modules: ClassVar[Dict[str, str]]
    embedding_padding_modules: ClassVar[List[str]]
    # The `embedding_modules` and `embedding_padding_modules`
    # are empty by default.
    embedding_modules: ClassVar[Dict[str, str]] = {}
    embedding_padding_modules: ClassVar[List[str]] = []
    packed_modules_mapping: ClassVar[Dict[str, List[str]]] = {}


# We can't use runtime_checkable with ClassVar for issubclass checks
@@ -132,7 +132,6 @@ class _SupportsLoRAType(Protocol):
    supports_lora: Literal[True]

    packed_modules_mapping: Dict[str, List[str]]
    supported_lora_modules: List[str]
    embedding_modules: Dict[str, str]
    embedding_padding_modules: List[str]

@@ -155,7 +154,6 @@ def supports_lora(
    if not result:
        lora_attrs = (
            "packed_modules_mapping",
            "supported_lora_modules",
            "embedding_modules",
            "embedding_padding_modules",
        )
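With supported_lora_modules dropped from the SupportsLoRA protocol and the remaining attributes given empty defaults, a model class only needs to declare the LoRA metadata that actually differs from those defaults. A hedged sketch of what a minimal declaration can look like after this change (the class name and mapping are illustrative):

    # Illustrative: only non-default LoRA metadata has to be declared now.
    import torch.nn as nn

    from vllm.model_executor.models.interfaces import SupportsLoRA

    class MyModelForCausalLM(nn.Module, SupportsLoRA):
        # Packing information is still needed to map merged projections
        # back to the checkpoint's per-projection LoRA weights.
        packed_modules_mapping = {
            "qkv_proj": ["q_proj", "k_proj", "v_proj"],
            "gate_up_proj": ["gate_proj", "up_proj"],
        }
        # embedding_modules / embedding_padding_modules fall back to the new
        # empty defaults; no static supported_lora_modules list is required.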
@@ -329,16 +329,6 @@ class InternLM2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
        "gate_up_proj": ["w1", "w3"],
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "wqkv",
        "wo",
        "gate_up_proj",
        "w2",
    ]
    embedding_modules = {}
    embedding_padding_modules = []

    def __init__(self,
                 *,
                 vllm_config: VllmConfig,

@@ -380,10 +380,6 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj", "o_proj", "embed_tokens", "lm_head", "up_proj",
        "down_proj", "gate_proj", "out_proj", "in_proj", "x_proj"
    ]
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",

@@ -452,10 +452,6 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj", "o_proj", "gate_up_proj", "down_proj", "embed_tokens",
        "lm_head"
    ]
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings"

@@ -522,14 +522,6 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj",
        "o_proj",
        "gate_up_proj",
        "down_proj",
        "embed_tokens",
        "lm_head",
    ]
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",

@@ -227,21 +227,5 @@ class MiniCPM3ForCausalLM(MiniCPMForCausalLM):
        ],
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "kv_a_proj_with_mqa",
        "q_a_proj",
        "q_b_proj",
        "kv_b_proj",
        "o_proj",
        "gate_up_proj",
        "down_proj",
        "embed_tokens",
        "lm_head",
    ]

    # `embedding_modules` and `embedding_padding_modules`
    # are inherited from MiniCPMForCausalLM

    def _init_model(self, *, vllm_config: VllmConfig, prefix: str = ""):
        return MiniCPM3Model(vllm_config=vllm_config, prefix=prefix)

@@ -1228,23 +1228,6 @@ class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA):
            "up_proj",
        ],
    }
    # LoRA specific attributes
    supported_lora_modules = [
        # vision encoder
        "fc1",
        "fc2",
        "out_proj",
        # language model
        "qkv_proj",  # same name with vision encoder
        "o_proj",
        "gate_up_proj",
        "down_proj",
        # resampler
        "kv_proj",
    ]

    embedding_modules = {}
    embedding_padding_modules = []

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(vllm_config=vllm_config, prefix=prefix)

@@ -1338,23 +1321,6 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA):
            "up_proj",
        ],
    }
    # LoRA specific attributes
    supported_lora_modules = [
        # vision encoder
        "fc1",
        "fc2",
        "out_proj",
        # language model
        "qkv_proj",  # same name with vision encoder
        "o_proj",
        "gate_up_proj",
        "down_proj",
        # resampler
        "kv_proj",
    ]

    embedding_modules = {}
    embedding_padding_modules = []

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(vllm_config=vllm_config, prefix=prefix)

@@ -1460,13 +1426,6 @@ class MiniCPMV(MiniCPMVBaseModel, SupportsMultiModal, SupportsLoRA):
    which is not conducive to the current integration logic of LoRA and
    bitsandbytes in vLLM. Therefore, it is necessary to separate them.
    """
    # Ensure that the LoRA support check passes when the class is not
    # initialized, but set all these attributes to empty.
    # These will be updated when an instance class is selected
    packed_modules_mapping = {}
    supported_lora_modules = []
    embedding_modules = {}
    embedding_padding_modules = []

    def __new__(cls, *, vllm_config: VllmConfig, prefix: str = ""):
        config = vllm_config.model_config.hf_config

@@ -1487,7 +1446,6 @@ class MiniCPMV(MiniCPMVBaseModel, SupportsMultiModal, SupportsLoRA):
        # quant_config references base class members,
        # so update values before init is called
        cls.packed_modules_mapping.update(instance_cls.packed_modules_mapping)
        cls.supported_lora_modules += instance_cls.supported_lora_modules
        cls.embedding_modules.update(instance_cls.embedding_modules)
        cls.embedding_padding_modules += instance_cls.embedding_padding_modules
        return instance_cls(vllm_config=vllm_config, prefix=prefix)

@@ -332,10 +332,6 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj", "o_proj", "embed_tokens", "lm_head", "w1", "w2", "w3",
        "gate"
    ]
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",

@@ -1440,26 +1440,6 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
        "merged_linear": ["gate_proj", "up_proj"]  # image_projector
    }

    # LoRA specific attributes
    supported_lora_modules = [
        # language model
        "qkv_proj",
        "o_proj",
        "gate_up_proj",
        "down_proj",  # same name with image_projector
        # vision tower
        "wq",
        "wk",
        "wv",
        "wo",
        "w1",
        "w2",
        # image_projector
        "merged_linear",
    ]
    embedding_modules = {}
    embedding_padding_modules = []

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config

@@ -389,9 +389,6 @@ class NemotronForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj", "o_proj", "up_proj", "down_proj", "embed_tokens", "lm_head"
    ]
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",

@@ -273,17 +273,6 @@ class PhiForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
        ]
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj",
        "dense",
        "fc1",
        "fc2",
    ]

    embedding_modules = {}
    embedding_padding_modules = []

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config

@@ -526,16 +526,6 @@ class PhiMoEForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj",
        "o_proj",
        "embed_tokens",
        "lm_head",
        "w1",
        "w2",
        "w3",
        "gate",
    ]
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",

@@ -354,15 +354,6 @@ class QWenLMHeadModel(QWenBaseModel, SupportsPP, SupportsLoRA):
            "w1",
        ],
    }
    # LoRA specific attributes
    supported_lora_modules = [
        "c_attn",
        "gate_up_proj",
        "c_proj",
    ]

    embedding_modules = {}
    embedding_padding_modules = []

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        config = vllm_config.model_config.hf_config

@@ -430,16 +430,6 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
        ],
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj",
        "o_proj",
        "gate_up_proj",
        "down_proj",
    ]
    embedding_modules = {}
    embedding_padding_modules = []

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config

@@ -528,16 +518,6 @@ class Qwen2EmbeddingModel(nn.Module, SupportsLoRA, SupportsPP):
        ],
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj",
        "o_proj",
        "gate_up_proj",
        "down_proj",
    ]
    embedding_modules = {}
    embedding_padding_modules = []

    hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):

@@ -734,27 +734,6 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
            "up_proj",
        ],
    }
    # LoRA specific attributes
    supported_lora_modules = [
        # language model
        "qkv_proj",
        "o_proj",
        "gate_up_proj",
        "down_proj",  # Same name with vision encoder
        # vision tower
        "qkv",
        "gate_proj",
        "up_proj",
        "attn.proj",  # Distinguish patch_embed.proj
        "fc1",
        "fc2",
        # projector
        "mlp.0",
        "mlp.2"
    ]

    embedding_modules = {}
    embedding_padding_modules = []

    # To ensure correct weight loading and mapping.
    hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={

@@ -47,16 +47,6 @@ class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP):
        ],
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj",
        "o_proj",
        "gate_up_proj",
        "down_proj",
    ]
    embedding_modules = {}
    embedding_padding_modules = []

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config

@@ -1048,24 +1048,6 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
        ],
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj",
        "o_proj",
        "gate_up_proj",
        "down_proj",
        # vision tower
        "qkv",
        "attn.proj",  # Distinguish patch_embed.proj
        "fc1",
        "fc2",
        # projector
        "mlp.0",
        "mlp.2"
    ]
    embedding_modules = {}
    embedding_padding_modules = []

    # To ensure correct weight loading and mapping.
    hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={
        "lm_head.": "language_model.lm_head.",

@@ -667,21 +667,6 @@ class QwenVLForConditionalGeneration(QWenBaseModel, SupportsPP, SupportsLoRA,
            "w1",
        ],
    }
    # LoRA specific attributes
    supported_lora_modules = [
        "c_attn",
        "gate_up_proj",
        "c_proj",
        # visual module
        "out_proj",
        "in_proj",
        "c_fc",
        # resampler
        "kv_proj",
    ]

    embedding_modules = {}
    embedding_padding_modules = []

    def get_mm_mapping(self) -> MultiModelKeys:
        """

@@ -386,14 +386,6 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj",
        "o_proj",
        "gate_up_proj",
        "down_proj",
        "embed_tokens",
        "lm_head",
    ]
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
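Every per-model hunk above is the same mechanical edit: the static supported_lora_modules list is deleted (for MiniCPMV, together with the class-level merging of those lists in __new__), and empty embedding_modules / embedding_padding_modules declarations become redundant with the new protocol defaults, while packed_modules_mapping and any non-empty embedding mappings stay. Roughly, a typical class body shrinks like this (the removed part is shown as comments; the mapping is illustrative):

    # Before (removed by this commit):
    #     # LoRA specific attributes
    #     supported_lora_modules = [
    #         "qkv_proj", "o_proj", "gate_up_proj", "down_proj",
    #     ]
    #     embedding_modules = {}
    #     embedding_padding_modules = []
    #
    # After: only the packing information remains on the class.
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }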
@@ -27,6 +27,11 @@ from vllm.config import VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.utils import divide
from vllm.logger import init_logger
from vllm.lora.fully_sharded_layers import (
    ColumnParallelLinearWithShardedLoRA, RowParallelLinearWithShardedLoRA)
from vllm.lora.layers import (ColumnParallelLinearWithLoRA,
                              ReplicatedLinearWithLoRA,
                              RowParallelLinearWithLoRA)
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               ReplicatedLinear,
                                               RowParallelLinear)

@@ -103,6 +108,23 @@ def replace_linear_class(
        "rowwise": RowParallelLinear,
    }.get(style, ReplicatedLinear)

    lora_linear_cls = {
        ColumnParallelLinear: {
            True: ColumnParallelLinearWithShardedLoRA,  # fully sharded
            False: ColumnParallelLinearWithLoRA  # not fully sharded
        },
        RowParallelLinear: {
            True: RowParallelLinearWithShardedLoRA,
            False: RowParallelLinearWithLoRA
        },
        # ReplicatedLinear doesn't support fully sharded LoRA yet,
        # so we use the same class for both cases.
        ReplicatedLinear: {
            True: ReplicatedLinearWithLoRA,
            False: ReplicatedLinearWithLoRA
        }
    }

    class HFCompatibleLinear(vllm_linear_cls):
        """
        Wrapper class that removes `output_bias` from returned output.
@@ -111,6 +133,19 @@ def replace_linear_class(
        def forward(self, input: torch.Tensor) -> torch.Tensor:
            return super().forward(input)[0]

        @classmethod
        def get_lora_class(cls, fully_sharded: bool = False):
            """
            Get the LoRA class corresponding to the current transformer
            linear class.

            Args:
                fully_sharded (bool): If True, select the LoRA class variant
                    that supports fully sharded LoRA. Defaults to False.

            """
            return lora_linear_cls[vllm_linear_cls][fully_sharded]

    return HFCompatibleLinear(
        input_size=linear.in_features,
        output_size=linear.out_features,
|
||||
"gate_up_proj": ["gate_proj", "up_proj"]
|
||||
}
|
||||
|
||||
# LoRA specific attributes
|
||||
# TODO : Add LoRA to the audio tower and projector.
|
||||
supported_lora_modules = [
|
||||
"qkv_proj", "o_proj", "gate_up_proj", "down_proj"
|
||||
]
|
||||
embedding_modules = {}
|
||||
embedding_padding_modules = []
|
||||
|
||||
hf_to_vllm_mapper = WeightsMapper(
|
||||
orig_to_new_prefix={"audio_tower.model.encoder.": "audio_tower."})
|
||||
|
||||
|
@ -650,9 +650,6 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
|
||||
logger.info(msg)
|
||||
|
||||
if self.lora_config:
|
||||
assert hasattr(self.model, "supported_lora_modules"
|
||||
) and self.model.supported_lora_modules, (
|
||||
"Model does not support LoRA")
|
||||
assert hasattr(self.model, "embedding_modules"
|
||||
), "Model does not have embedding_modules"
|
||||
assert hasattr(
|
||||
|