# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from collections import OrderedDict
from typing import NamedTuple
from unittest.mock import patch

import pytest
from huggingface_hub.utils import HfHubHTTPError
from torch import nn

from vllm.lora.utils import (
    get_adapter_absolute_path,
    parse_fine_tuned_lora_name,
    replace_submodule,
)
from vllm.model_executor.models.utils import WeightsMapper


class LoRANameParserTestConfig(NamedTuple):
    name: str
    module_name: str
    is_lora_a: bool
    weights_mapper: WeightsMapper | None = None
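
# Each config pairs a checkpoint weight name with the (module_name, is_lora_a)
# tuple that parse_fine_tuned_lora_name is expected to return; the optional
# weights_mapper remaps checkpoint prefixes before the name is parsed.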


def test_parse_fine_tuned_lora_name_valid():
    fixture = [
        LoRANameParserTestConfig(
            "base_model.model.lm_head.lora_A.weight", "lm_head", True
        ),
        LoRANameParserTestConfig(
            "base_model.model.lm_head.lora_B.weight", "lm_head", False
        ),
        LoRANameParserTestConfig(
            "base_model.model.model.embed_tokens.lora_embedding_A",
            "model.embed_tokens",
            True,
        ),
        LoRANameParserTestConfig(
            "base_model.model.model.embed_tokens.lora_embedding_B",
            "model.embed_tokens",
            False,
        ),
        LoRANameParserTestConfig(
            "base_model.model.model.layers.9.mlp.down_proj.lora_A.weight",
            "model.layers.9.mlp.down_proj",
            True,
        ),
        LoRANameParserTestConfig(
            "base_model.model.model.layers.9.mlp.down_proj.lora_B.weight",
            "model.layers.9.mlp.down_proj",
            False,
        ),
        LoRANameParserTestConfig(
            "language_model.layers.9.mlp.down_proj.lora_A.weight",
            "language_model.layers.9.mlp.down_proj",
            True,
        ),
        LoRANameParserTestConfig(
            "language_model.layers.9.mlp.down_proj.lora_B.weight",
            "language_model.layers.9.mlp.down_proj",
            False,
        ),
        # Test with WeightsMapper
        LoRANameParserTestConfig(
            "base_model.model.model.layers.9.mlp.down_proj.lora_A.weight",
            "language_model.model.layers.9.mlp.down_proj",
            True,
            weights_mapper=WeightsMapper(
                orig_to_new_prefix={"model.": "language_model.model."}
            ),
        ),
        LoRANameParserTestConfig(
            "base_model.model.model.layers.9.mlp.down_proj.lora_B.weight",
            "language_model.model.layers.9.mlp.down_proj",
            False,
            weights_mapper=WeightsMapper(
                orig_to_new_prefix={"model.": "language_model.model."}
            ),
        ),
        LoRANameParserTestConfig(
            "model.layers.9.mlp.down_proj.lora_A.weight",
            "language_model.model.layers.9.mlp.down_proj",
            True,
            weights_mapper=WeightsMapper(
                orig_to_new_prefix={"model.": "language_model.model."}
            ),
        ),
        LoRANameParserTestConfig(
            "model.layers.9.mlp.down_proj.lora_B.weight",
            "language_model.model.layers.9.mlp.down_proj",
            False,
            weights_mapper=WeightsMapper(
                orig_to_new_prefix={"model.": "language_model.model."}
            ),
        ),
    ]
    for name, module_name, is_lora_a, weights_mapper in fixture:
        assert (module_name, is_lora_a) == parse_fine_tuned_lora_name(
            name, weights_mapper
        )


def test_parse_fine_tuned_lora_name_invalid():
    fixture = {
        "base_model.weight",
        "base_model.model.weight",
    }
    for name in fixture:
        with pytest.raises(ValueError, match="unsupported LoRA weight"):
            parse_fine_tuned_lora_name(name)


def test_replace_submodule():
    model = nn.Sequential(
        OrderedDict(
            [
                ("dense1", nn.Linear(764, 100)),
                ("act1", nn.ReLU()),
                ("dense2", nn.Linear(100, 50)),
                (
                    "seq1",
                    nn.Sequential(
                        OrderedDict(
                            [
                                ("dense1", nn.Linear(100, 10)),
                                ("dense2", nn.Linear(10, 50)),
                            ]
                        )
                    ),
                ),
                ("act2", nn.ReLU()),
                ("output", nn.Linear(50, 10)),
                ("outact", nn.Sigmoid()),
            ]
        )
    )

    sigmoid = nn.Sigmoid()

    replace_submodule(model, "act1", sigmoid)
    assert dict(model.named_modules())["act1"] == sigmoid

    dense2 = nn.Linear(1, 5)
    replace_submodule(model, "seq1.dense2", dense2)
    assert dict(model.named_modules())["seq1.dense2"] == dense2


# Unit tests for get_adapter_absolute_path
@patch("os.path.isabs")
def test_get_adapter_absolute_path_absolute(mock_isabs):
    path = "/absolute/path/to/lora"
    mock_isabs.return_value = True
    assert get_adapter_absolute_path(path) == path


@patch("os.path.expanduser")
def test_get_adapter_absolute_path_expanduser(mock_expanduser):
    # Path with ~ that needs to be expanded
    path = "~/relative/path/to/lora"
    absolute_path = "/home/user/relative/path/to/lora"
    mock_expanduser.return_value = absolute_path
    assert get_adapter_absolute_path(path) == absolute_path


@patch("os.path.exists")
@patch("os.path.abspath")
def test_get_adapter_absolute_path_local_existing(mock_abspath, mock_exist):
    # Relative path that exists locally
    path = "relative/path/to/lora"
    absolute_path = "/absolute/path/to/lora"
    mock_exist.return_value = True
    mock_abspath.return_value = absolute_path
    assert get_adapter_absolute_path(path) == absolute_path


@patch("huggingface_hub.snapshot_download")
@patch("os.path.exists")
def test_get_adapter_absolute_path_huggingface(mock_exist, mock_snapshot_download):
    # Hugging Face model identifier
    path = "org/repo"
    absolute_path = "/mock/snapshot/path"
    mock_exist.return_value = False
    mock_snapshot_download.return_value = absolute_path
    assert get_adapter_absolute_path(path) == absolute_path


@patch("huggingface_hub.snapshot_download")
@patch("os.path.exists")
def test_get_adapter_absolute_path_huggingface_error(
    mock_exist, mock_snapshot_download
):
    # Hugging Face model identifier with download error
    path = "org/repo"
    mock_exist.return_value = False
    mock_snapshot_download.side_effect = HfHubHTTPError("failed to query model info")
    assert get_adapter_absolute_path(path) == path
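
# When the Hub download raises HfHubHTTPError, get_adapter_absolute_path falls
# back to returning the original identifier unchanged, as asserted above.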