[Lora] Load tuned multi-LoRA kernel configs from JSON files (#26319)

Signed-off-by: li2haipeng <44383182+li2haipeng@users.noreply.github.com>
Signed-off-by: Haipeng Li <li2haipeng@gmail.com>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
li2haipeng
2025-10-15 02:45:14 -07:00
committed by GitHub
parent db1764e4e0
commit d4d1a6024f
4 changed files with 198 additions and 16 deletions


@@ -0,0 +1,51 @@
# Multi-LoRA Tuning
**Note**: The LoRA configuration folder should be specified by exporting `VLLM_TUNED_CONFIG_FOLDER=/path/to/configs`. Without this, the shrink/expand kernels will use default configurations.
## Tuning Process
Multi-LoRA shrink/expand Triton kernel tuning follows a methodology similar to that of [Triton MoE tuning](https://github.com/vllm-project/vllm/blob/main/benchmarks/kernels/benchmark_moe.py).
**Step 1**
Define the search space. An example search space:
```python
block_m_range = [16, 32, 64, 128, 256]
block_n_range = [32, 64, 128, 256]
block_k_range = [32, 64, 128, 256]
num_warps_range = [4, 8]
num_stage_range = [2, 3, 4, 5]
num_ctas_range = [1]
split_k_range = [4, 8, 16, 32, 64]
```
**Step 2**
Get all hidden-state sizes and `num_slices` values that the target model uses for a specific TP size.
For example, we can acquire this information by adding a few prints to [add_lora_linear](https://github.com/li2haipeng/vllm/blob/multi_lora_v01011/vllm/lora/punica_wrapper/punica_gpu.py#L192):
```python
print(f"x_shape: {x.view(-1, x.shape[-1]).shape}")
print(f"num_sclises: {len(output_slices)}")
for i in range(len(output_slices)):
print(f"a{i} shape: {lora_a_stacked[i].shape}")
print(f"b{i} shape: {lora_b_stacked[i].shape}")
print("y_shape", y.shape)
```
**Step 3**
Benchmark the shrink/expand kernel runtime for each configuration generated from the pre-defined search space, i.e. perform a grid search for the optimal kernel configuration. vLLM's [benchmark_lora.py](https://github.com/vllm-project/vllm/blob/main/benchmarks/kernels/benchmark_lora.py) can be used to search for configurations for different shapes; a minimal sketch of the grid-search loop is shown below.
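For orientation only, here is a minimal sketch of such a grid search. It reuses the ranges defined in Step 1 and assumes a hypothetical `run_shrink_kernel` wrapper that launches the shrink kernel with the given compile-time constants; the real benchmarking logic lives in `benchmark_lora.py`.
```python
import itertools

import triton.testing

# Reuses block_m_range, block_n_range, block_k_range, num_warps_range,
# num_stage_range, num_ctas_range, split_k_range from the Step 1 search space.
best_config, best_ms = None, float("inf")
for block_m, block_n, block_k, num_warps, num_stages, num_ctas, split_k in itertools.product(
    block_m_range, block_n_range, block_k_range,
    num_warps_range, num_stage_range, num_ctas_range, split_k_range,
):
    # run_shrink_kernel is a hypothetical launcher for the Triton shrink kernel
    # with the given compile-time constants and a fixed problem shape.
    ms = triton.testing.do_bench(
        lambda: run_shrink_kernel(
            block_m=block_m, block_n=block_n, block_k=block_k,
            num_warps=num_warps, num_stages=num_stages,
            num_ctas=num_ctas, split_k=split_k,
        )
    )
    if ms < best_ms:
        best_ms = ms
        best_config = dict(
            block_m=block_m, block_n=block_n, block_k=block_k,
            num_warps=num_warps, num_stages=num_stages,
            num_ctas=num_ctas, split_k=split_k,
        )
```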
## Config Files
### File Name
For `shrink`, the config file is named `{gpu_name}_SHRINK.json`, e.g. `NVIDIA_H200_SHRINK.json`.
For `expand`, the config file is named `{gpu_name}_EXPAND_{add_inputs}.json`, e.g. `NVIDIA_H200_EXPAND_TRUE.json`.
The `gpu_name` is detected automatically via `torch.cuda.get_device_name()`, with spaces and hyphens replaced by underscores, as sketched below.
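A small sketch of the file-name derivation (mirroring the loader added to `vllm/lora/ops/triton_ops/utils.py` in this change):
```python
import torch

# Derive the config file name for the current GPU; spaces and hyphens in the
# device name are replaced with underscores.
gpu_name = torch.cuda.get_device_name().replace(" ", "_").replace("-", "_")

shrink_fname = f"{gpu_name}_SHRINK.json"  # e.g. NVIDIA_H200_SHRINK.json

add_inputs = True  # expand kernels are tuned separately per add_inputs value
expand_fname = f"{gpu_name}_EXPAND_{str(add_inputs).upper()}.json"  # e.g. NVIDIA_H200_EXPAND_TRUE.json
```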
### JSON Structure
Optimal kernel configurations are saved as JSON files with the nested structure `config_data[max_loras][num_slices][m][k][n]`, where the innermost value is the kernel configuration dictionary. An illustrative example is shown below.
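For illustration, a shrink config might look like the following (shown here as a Python literal; the keys are strings in the JSON files, the shape keys and values are made up, and `None` corresponds to JSON `null`):
```python
# Hypothetical excerpt of NVIDIA_H200_SHRINK.json.
# Nesting: config_data[max_loras][num_slices][m][k][n] -> kernel config dict.
example_config = {
    "1": {                      # max_loras
        "1": {                  # num_slices
            "16": {             # m: number of input tokens
                "4096": {       # k: hidden_size (for shrink)
                    "16": {     # n: LoRA rank (for shrink)
                        "block_m": 32,
                        "block_n": 16,
                        "block_k": 256,
                        "split_k": 64,
                        "num_warps": 4,
                        "num_ctas": 1,
                        "num_stages": 2,
                        "max_nreg": None,
                    }
                }
            }
        }
    }
}
```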


@@ -10,7 +10,7 @@ https://arxiv.org/abs/2310.18547
 import torch
 from vllm.lora.ops.triton_ops.kernel_utils import do_expand_kernel
-from vllm.lora.ops.triton_ops.utils import _get_lora_b_ptr
+from vllm.lora.ops.triton_ops.utils import _get_lora_b_ptr, get_lora_op_configs
 from vllm.triton_utils import tl, triton
 from vllm.utils import direct_register_custom_op
@@ -201,12 +201,21 @@ def _lora_expand(
     NUM_SLICES = len(lora_b_weights)
     # Triton kernel configs.
-    BLOCK_M = 64
-    BLOCK_N = 128
-    BLOCK_K = 16
-    NUM_WARPS = 4
-    NUM_CTAS = 1
-    NUM_STAGES = 2
+    kernel_config = get_lora_op_configs(
+        op_type="expand",
+        max_loras=MAX_LORAS,
+        batch=M,
+        hidden_size=MAX_N,
+        rank=K,
+        num_slices=NUM_SLICES,
+        add_inputs=add_inputs,
+    )
+    BLOCK_M = kernel_config["block_m"]
+    BLOCK_N = kernel_config["block_n"]
+    BLOCK_K = kernel_config["block_k"]
+    NUM_WARPS = kernel_config["num_warps"]
+    NUM_CTAS = kernel_config["num_ctas"]
+    NUM_STAGES = kernel_config["num_stages"]
     EVEN_K = K % BLOCK_K == 0  # type: ignore


@@ -10,7 +10,7 @@ https://arxiv.org/abs/2310.18547
 import torch
 from vllm.lora.ops.triton_ops.kernel_utils import do_shrink_kernel
-from vllm.lora.ops.triton_ops.utils import _get_lora_a_ptr
+from vllm.lora.ops.triton_ops.utils import _get_lora_a_ptr, get_lora_op_configs
 from vllm.triton_utils import tl, triton
 from vllm.utils import direct_register_custom_op
@@ -177,14 +177,21 @@ def _lora_shrink(
     MAX_LORAS = lora_ids.size(0)
     # Triton kernel configs
-    BLOCK_M = 32
-    BLOCK_N = 16
-    BLOCK_K = 256 if M < 128 else 32
-    SPLIT_K = 64 if M < 128 else 8
-    NUM_WARPS = 4
-    NUM_CTAS = 1
-    NUM_STAGES = 2
+    kernel_config = get_lora_op_configs(
+        "shrink",
+        max_loras=MAX_LORAS,
+        batch=M,
+        hidden_size=K,
+        rank=N,
+        num_slices=NUM_SLICES,
+    )
+    BLOCK_M = kernel_config["block_m"]
+    BLOCK_N = kernel_config["block_n"]
+    BLOCK_K = kernel_config["block_k"]
+    SPLIT_K = kernel_config["split_k"]
+    NUM_WARPS = kernel_config["num_warps"]
+    NUM_STAGES = kernel_config["num_stages"]
+    NUM_CTAS = kernel_config["num_ctas"]
     EVEN_K = K % (BLOCK_K * SPLIT_K) == 0  # type: ignore
     # TODO (varun): This grid formulation maximizes parallelization at the


@@ -1,8 +1,18 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import functools
+import json
+from pathlib import Path
+from typing import Any
+
 import torch
+
+from vllm import envs
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
 _LORA_A_PTR_DICT: dict[tuple[int, ...], tuple[torch.tensor, ...]] = {}
 _LORA_B_PTR_DICT: dict[tuple[int, ...], tuple[torch.tensor, ...]] = {}
@@ -133,3 +143,108 @@ def _get_lora_b_ptr(
         MAX_N,
     )
     return _LORA_B_PTR_DICT.get(key)
+
+
+@functools.lru_cache
+def load_lora_op_config(op_type: str, add_inputs: bool | None) -> dict | None:
+    user_defined_config_folder = envs.VLLM_TUNED_CONFIG_FOLDER
+    if user_defined_config_folder is not None:
+        gpu_name = torch.cuda.get_device_name()
+        gpu_name = gpu_name.replace(" ", "_")
+        gpu_name = gpu_name.replace("-", "_")
+
+        config_fname = None
+        if op_type == "shrink":
+            config_fname = f"{gpu_name}_{op_type.upper()}.json"
+        else:
+            assert op_type == "expand"
+            config_fname = (
+                f"{gpu_name}_{op_type.upper()}_{str(add_inputs).upper()}.json"
+            )
+
+        config_path = Path(f"{user_defined_config_folder}/{config_fname}")
+        if not config_path.exists():
+            logger.warning_once(f"No LoRA kernel configs found in {config_path}")
+            return None
+
+        # Load the tuned configs from the JSON file.
+        logger.info_once(f"Using tuned LoRA kernel configs from {config_path}.")
+        with open(str(config_path)) as f:
+            config_data = json.load(f)
+    else:
+        config_data = None
+
+    return config_data
+
+
+@functools.lru_cache
+def get_lora_op_configs(
+    op_type: str,
+    max_loras: int,
+    batch: int,
+    hidden_size: int,
+    rank: int,
+    num_slices: int,
+    add_inputs: bool | None = None,
+) -> dict[str, int | None]:
+    assert op_type in ["shrink", "expand"]
+
+    # Default config, used when no tuned config file is available.
+    default = {}
+    if op_type == "shrink":
+        default = {
+            "block_m": 32,
+            "block_n": 16,
+            "block_k": 256 if batch < 128 else 32,
+            "split_k": 64 if batch < 128 else 8,
+            "num_warps": 4,
+            "num_ctas": 1,
+            "num_stages": 2,
+            "max_nreg": None,
+        }
+    else:
+        default = {
+            "block_m": 64,
+            "block_n": 128,
+            "block_k": 16,
+            "num_warps": 4,
+            "num_ctas": 1,
+            "num_stages": 2,
+            "max_nreg": None,
+        }
+
+    m = batch
+    k, n = (hidden_size, rank) if op_type == "shrink" else (rank, hidden_size)
+
+    config_data: Any
+    config_data = load_lora_op_config(op_type, add_inputs)
+    if not config_data:
+        logger.warning_once("Using default LoRA kernel configs")
+        return default
+
+    # The config is structured as config_data[max_loras][num_slices][m][k][n] = {}
+    # slice by max_loras
+    config_data = (
+        config_data.get(str(max_loras))
+        or config_data[min(config_data.keys(), key=lambda x: abs(int(x) - max_loras))]
+    )
+    # slice by num_slices
+    config_data = config_data[str(num_slices)]
+    # slice by m
+    config_data = (
+        config_data.get(str(m))
+        or config_data[min(config_data.keys(), key=lambda x: abs(int(x) - m))]
+    )
+    # slice by k
+    config_data = (
+        config_data.get(str(k))
+        or config_data[min(config_data.keys(), key=lambda x: abs(int(x) - k))]
+    )
+    # slice by n
+    config_data = (
+        config_data.get(str(n))
+        or config_data[min(config_data.keys(), key=lambda x: abs(int(x) - n))]
+    )
+
+    assert config_data is not None
+    return config_data
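As a quick way to check that tuned configs are being picked up, the new helper can be called directly. This is a hedged smoke-test sketch: it assumes a CUDA device, that the environment variable is set before vLLM reads it, and uses placeholder path and shape values.
```python
# Hypothetical smoke test for the tuned-config loader.
import os

os.environ["VLLM_TUNED_CONFIG_FOLDER"] = "/path/to/configs"  # placeholder path

from vllm.lora.ops.triton_ops.utils import get_lora_op_configs

cfg = get_lora_op_configs(
    op_type="shrink",
    max_loras=1,       # placeholder shape values
    batch=16,
    hidden_size=4096,
    rank=16,
    num_slices=1,
)
print(cfg)  # tuned entry if a matching JSON file exists, otherwise the defaults
```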