mirror of https://github.com/vllm-project/vllm.git
[Core] Relax the LoRA max rank (#26461)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
@@ -103,7 +103,7 @@ class LoRAConfig:
         # Setting the maximum rank to 512 should be able to satisfy the vast
         # majority of applications.
-        possible_max_ranks = (8, 16, 32, 64, 128, 256, 320, 512)
+        possible_max_ranks = (1, 8, 16, 32, 64, 128, 256, 320, 512)
         possible_lora_extra_vocab_size = (256, 512)
         if self.max_lora_rank not in possible_max_ranks:
             raise ValueError(
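For reference, a minimal standalone sketch of the relaxed validation, assuming the same tuple of supported ranks as in the hunk above; check_max_lora_rank is a hypothetical helper written for illustration, not part of vLLM:

# Hypothetical helper mirroring LoRAConfig's validation after this change.
def check_max_lora_rank(max_lora_rank: int) -> None:
    # Rank 1 is now accepted; 512 remains the upper bound.
    possible_max_ranks = (1, 8, 16, 32, 64, 128, 256, 320, 512)
    if max_lora_rank not in possible_max_ranks:
        raise ValueError(
            f"max_lora_rank ({max_lora_rank}) must be one of {possible_max_ranks}."
        )

check_max_lora_rank(1)    # accepted after this change
check_max_lora_rank(512)  # still valid; unsupported values keep raising ValueError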
@@ -28,8 +28,6 @@ logger = init_logger(__name__)

 # Defined as a mixin for GPUModelRunner
 class LoRAModelRunnerMixin:
-    LORA_WARMUP_RANK = 8
-
     def load_lora_model(
         self, model: nn.Module, vllm_config: VllmConfig, device: torch.device
     ) -> nn.Module:
@@ -96,7 +94,9 @@ class LoRAModelRunnerMixin:
         assert self.lora_manager is not None, "LoRA is not enabled"

         num_loras = lora_config.max_loras
-
+        lora_warmup_rank = (
+            lora_config.max_lora_rank if lora_config.max_lora_rank < 8 else 8
+        )
         # Make dummy lora requests
         lora_requests: set[LoRARequest] = {
             LoRARequest(
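The new conditional caps the dummy-LoRA warmup rank at 8 but follows max_lora_rank when it is smaller (such as the newly allowed rank 1). A minimal sketch of the same logic, with warmup_rank as a hypothetical standalone helper:

# Hypothetical helper equivalent to the conditional expression added above:
# warm up with max_lora_rank when it is below 8, otherwise with rank 8.
def warmup_rank(max_lora_rank: int) -> int:
    return max_lora_rank if max_lora_rank < 8 else 8

assert warmup_rank(1) == 1    # a rank-1 config now warms up at rank 1
assert warmup_rank(8) == 8
assert warmup_rank(512) == 8  # large configs still warm up at rank 8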
@@ -111,7 +111,7 @@ class LoRAModelRunnerMixin:
             # Add the dummy LoRAs here so _set_active_loras doesn't try to
             # load from disk.
             for lr in lora_requests:
-                self.lora_manager.add_dummy_lora(lr, rank=self.LORA_WARMUP_RANK)
+                self.lora_manager.add_dummy_lora(lr, rank=lora_warmup_rank)

             yield

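Taken together, the hunks let an engine be configured for very small adapters. A hedged usage sketch (the model name and adapter path below are placeholders, not taken from this commit):

from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

# Placeholder model and adapter path; any rank-1 LoRA adapter would do.
llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",
    enable_lora=True,
    max_lora_rank=1,  # accepted after this change
)

outputs = llm.generate(
    ["Hello, world"],
    SamplingParams(max_tokens=16),
    lora_request=LoRARequest("rank1-adapter", 1, "/path/to/rank1_lora"),
)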