Mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 14:53:52 +08:00)
[Lora]Load tuned multi-lora kernel configs from json files (#26319)

Signed-off-by: li2haipeng <44383182+li2haipeng@users.noreply.github.com>
Signed-off-by: Haipeng Li <li2haipeng@gmail.com>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
vllm/lora/ops/triton_ops/README_TUNING.md (new file, +51 lines)

@@ -0,0 +1,51 @@
# Multi-LoRA Tuning

**Note**: The LoRA configuration folder should be specified by exporting `VLLM_TUNED_CONFIG_FOLDER=/path/to/configs`. Without this, the shrink/expand kernels will use default configurations.
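
For example, one way to point vLLM at a tuned-config folder from Python (a minimal sketch; the path and model name are illustrative, and exporting the variable in the shell before launching vLLM works the same way):

```python
import os

# Illustrative path; must be set before the LoRA kernels look up their configs.
os.environ["VLLM_TUNED_CONFIG_FOLDER"] = "/path/to/configs"

from vllm import LLM

# enable_lora=True activates the multi-LoRA shrink/expand kernels.
llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct", enable_lora=True)
```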

## Tuning Process

Multi-LoRA shrink/expand Triton kernel tuning follows a methodology similar to [Triton MoE tuning](https://github.com/vllm-project/vllm/blob/main/benchmarks/kernels/benchmark_moe.py).

**Step 1**
Define the search space. An example search space:

```python
block_m_range = [16, 32, 64, 128, 256]
block_n_range = [32, 64, 128, 256]
block_k_range = [32, 64, 128, 256]
num_warps_range = [4, 8]
num_stage_range = [2, 3, 4, 5]
num_ctas_range = [1]
split_k_range = [4, 8, 16, 32, 64]
```

**Step 2**
Get all hidden_state sizes and num_slices that the target model uses for a specific TP size.

For example, we can acquire this information by simply checking [add_lora_linear](https://github.com/li2haipeng/vllm/blob/multi_lora_v01011/vllm/lora/punica_wrapper/punica_gpu.py#L192):

```python
print(f"x_shape: {x.view(-1, x.shape[-1]).shape}")
print(f"num_slices: {len(output_slices)}")
for i in range(len(output_slices)):
    print(f"a{i} shape: {lora_a_stacked[i].shape}")
    print(f"b{i} shape: {lora_b_stacked[i].shape}")
print("y_shape", y.shape)
```

**Step 3**
Benchmark the shrink/expand kernel runtime with the different kernel configurations generated from the pre-defined search space, performing a grid search to find the optimal kernel configuration. vLLM's [benchmark_lora.py](https://github.com/vllm-project/vllm/blob/main/benchmarks/kernels/benchmark_lora.py) can be used to search for configurations for different shapes.
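
The grid search itself is an exhaustive sweep over the Step 1 ranges. A minimal sketch of how the candidates could be enumerated (the key names follow the tuned-config entries used by `get_lora_op_configs`; `split_k` applies to the shrink kernel only, and the actual timing loop is left to `benchmark_lora.py`):

```python
import itertools

# Enumerate every candidate shrink-kernel config from the Step 1 ranges.
# Each candidate would then be benchmarked for every (m, k, n, num_slices)
# shape collected in Step 2, keeping the fastest config per shape.
search_space = {
    "block_m": [16, 32, 64, 128, 256],
    "block_n": [32, 64, 128, 256],
    "block_k": [32, 64, 128, 256],
    "num_warps": [4, 8],
    "num_stages": [2, 3, 4, 5],
    "num_ctas": [1],
    "split_k": [4, 8, 16, 32, 64],  # shrink kernel only
}
candidate_configs = [
    dict(zip(search_space, values))
    for values in itertools.product(*search_space.values())
]
print(f"{len(candidate_configs)} candidate configs per shape")
```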

## Config Files

### File Name

For `shrink`, the config file is named `{gpu_name}_SHRINK.json`, e.g. `NVIDIA_H200_SHRINK.json`.

For `expand`, the config file is named `{gpu_name}_EXPAND_{add_input}.json`, e.g. `NVIDIA_H200_EXPAND_TRUE.json`.

The `gpu_name` can be automatically detected by calling `torch.cuda.get_device_name()`.
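
As a quick check, the expected file name for the current GPU can be derived the same way the loader normalizes the device name (a sketch mirroring `load_lora_op_config`; `tuned_config_filename` is an illustrative helper, not a vLLM API):

```python
import torch

def tuned_config_filename(op_type: str, add_inputs: bool | None = None) -> str:
    # Spaces and hyphens in the device name are replaced with underscores,
    # then the op type (and, for expand, the add_inputs flag) is appended.
    gpu_name = torch.cuda.get_device_name().replace(" ", "_").replace("-", "_")
    if op_type == "shrink":
        return f"{gpu_name}_SHRINK.json"
    assert op_type == "expand"
    return f"{gpu_name}_EXPAND_{str(add_inputs).upper()}.json"

# e.g. on an H200: "NVIDIA_H200_SHRINK.json" and "NVIDIA_H200_EXPAND_TRUE.json"
print(tuned_config_filename("shrink"), tuned_config_filename("expand", True))
```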

### Json Structure

Optimal kernel configuration files are saved as JSON files with the structure `config_data[max_loras][num_slices][m][k][n]`.
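
For illustration, a tuned shrink config with a single entry could be produced like this (a sketch only: the shapes and tile values are made-up examples, and the leaf keys follow the defaults returned by `get_lora_op_configs`):

```python
import json

# One hypothetical entry: max_loras=1, num_slices=2, m=256, k=4096, n=16.
# Keys are strings because get_lora_op_configs looks them up with str(...).
config_data = {
    "1": {                      # max_loras
        "2": {                  # num_slices
            "256": {            # m (number of tokens)
                "4096": {       # k
                    "16": {     # n
                        "block_m": 32,
                        "block_n": 16,
                        "block_k": 32,
                        "split_k": 8,
                        "num_warps": 4,
                        "num_ctas": 1,
                        "num_stages": 2,
                        "max_nreg": None,
                    }
                }
            }
        }
    }
}

with open("NVIDIA_H200_SHRINK.json", "w") as f:
    json.dump(config_data, f, indent=4)
```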
@@ -10,7 +10,7 @@ https://arxiv.org/abs/2310.18547
 import torch

 from vllm.lora.ops.triton_ops.kernel_utils import do_expand_kernel
-from vllm.lora.ops.triton_ops.utils import _get_lora_b_ptr
+from vllm.lora.ops.triton_ops.utils import _get_lora_b_ptr, get_lora_op_configs
 from vllm.triton_utils import tl, triton
 from vllm.utils import direct_register_custom_op

@@ -201,12 +201,21 @@ def _lora_expand(
     NUM_SLICES = len(lora_b_weights)

     # Triton kernel configs.
-    BLOCK_M = 64
-    BLOCK_N = 128
-    BLOCK_K = 16
-    NUM_WARPS = 4
-    NUM_CTAS = 1
-    NUM_STAGES = 2
+    kernel_config = get_lora_op_configs(
+        op_type="expand",
+        max_loras=MAX_LORAS,
+        batch=M,
+        hidden_size=MAX_N,
+        rank=K,
+        num_slices=NUM_SLICES,
+        add_inputs=add_inputs,
+    )
+    BLOCK_M = kernel_config["block_m"]
+    BLOCK_N = kernel_config["block_n"]
+    BLOCK_K = kernel_config["block_k"]
+    NUM_WARPS = kernel_config["num_warps"]
+    NUM_CTAS = kernel_config["num_ctas"]
+    NUM_STAGES = kernel_config["num_stages"]

     EVEN_K = K % BLOCK_K == 0  # type: ignore

@@ -10,7 +10,7 @@ https://arxiv.org/abs/2310.18547
 import torch

 from vllm.lora.ops.triton_ops.kernel_utils import do_shrink_kernel
-from vllm.lora.ops.triton_ops.utils import _get_lora_a_ptr
+from vllm.lora.ops.triton_ops.utils import _get_lora_a_ptr, get_lora_op_configs
 from vllm.triton_utils import tl, triton
 from vllm.utils import direct_register_custom_op

@@ -177,14 +177,21 @@ def _lora_shrink(
     MAX_LORAS = lora_ids.size(0)

     # Triton kernel configs
-    BLOCK_M = 32
-    BLOCK_N = 16
-    BLOCK_K = 256 if M < 128 else 32
-    SPLIT_K = 64 if M < 128 else 8
-    NUM_WARPS = 4
-    NUM_CTAS = 1
-    NUM_STAGES = 2
-
+    kernel_config = get_lora_op_configs(
+        "shrink",
+        max_loras=MAX_LORAS,
+        batch=M,
+        hidden_size=K,
+        rank=N,
+        num_slices=NUM_SLICES,
+    )
+    BLOCK_M = kernel_config["block_m"]
+    BLOCK_N = kernel_config["block_n"]
+    BLOCK_K = kernel_config["block_k"]
+    SPLIT_K = kernel_config["split_k"]
+    NUM_WARPS = kernel_config["num_warps"]
+    NUM_STAGES = kernel_config["num_stages"]
+    NUM_CTAS = kernel_config["num_ctas"]
     EVEN_K = K % (BLOCK_K * SPLIT_K) == 0  # type: ignore

     # TODO (varun): This grid formulation maximizes parallelization at the
@@ -1,8 +1,18 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import functools
import json
from pathlib import Path
from typing import Any

import torch

from vllm import envs
from vllm.logger import init_logger

logger = init_logger(__name__)

_LORA_A_PTR_DICT: dict[tuple[int, ...], tuple[torch.tensor, ...]] = {}
_LORA_B_PTR_DICT: dict[tuple[int, ...], tuple[torch.tensor, ...]] = {}

@@ -133,3 +143,108 @@ def _get_lora_b_ptr(
        MAX_N,
    )
    return _LORA_B_PTR_DICT.get(key)


@functools.lru_cache
def load_lora_op_config(op_type: str, add_inputs: bool | None) -> dict | None:
    user_defined_config_folder = envs.VLLM_TUNED_CONFIG_FOLDER
    if user_defined_config_folder is not None:
        gpu_name = torch.cuda.get_device_name()
        gpu_name = gpu_name.replace(" ", "_")
        gpu_name = gpu_name.replace("-", "_")

        config_fname = None
        if op_type == "shrink":
            config_fname = f"{gpu_name}_{op_type.upper()}.json"
        else:
            assert op_type == "expand"
            config_fname = (
                f"{gpu_name}_{op_type.upper()}_{str(add_inputs).upper()}.json"
            )

        config_path = Path(f"{user_defined_config_folder}/{config_fname}")
        if not config_path.exists():
            logger.warning_once(f"No LoRA kernel configs found in {config_path}")
            return None

        # Load json
        logger.info_once(f"Using tuned LoRA kernel configs from {config_path}.")
        with open(str(config_path)) as f:
            config_data = json.load(f)
    else:
        config_data = None

    return config_data


@functools.lru_cache
def get_lora_op_configs(
    op_type: str,
    max_loras: int,
    batch: int,
    hidden_size: int,
    rank: int,
    num_slices: int,
    add_inputs: bool | None = None,
) -> dict[str, int | None]:
    assert op_type in ["shrink", "expand"]

    # default config
    default = {}
    if op_type == "shrink":
        default = {
            "block_m": 32,
            "block_n": 16,
            "block_k": 256 if batch < 128 else 32,
            "split_k": 64 if batch < 128 else 8,
            "num_warps": 4,
            "num_ctas": 1,
            "num_stages": 2,
            "max_nreg": None,
        }
    else:
        default = {
            "block_m": 64,
            "block_n": 128,
            "block_k": 16,
            "num_warps": 4,
            "num_ctas": 1,
            "num_stages": 2,
            "max_nreg": None,
        }
    m = batch

    k, n = (hidden_size, rank) if op_type == "shrink" else (rank, hidden_size)

    config_data: Any
    config_data = load_lora_op_config(op_type, add_inputs)
    if not config_data:
        logger.warning_once("Using default LoRA kernel configs")
        return default

    # config is structured as config_data[max_loras][num_slices][m][k][n] = {}
    # slice by max_loras
    config_data = (
        config_data.get(str(max_loras))
        or config_data[min(config_data.keys(), key=lambda x: abs(int(x) - max_loras))]
    )
    # slice by num_slices
    config_data = config_data[str(num_slices)]
    # slice by m
    config_data = (
        config_data.get(str(m))
        or config_data[min(config_data.keys(), key=lambda x: abs(int(x) - m))]
    )
    # slice by k
    config_data = (
        config_data.get(str(k))
        or config_data[min(config_data.keys(), key=lambda x: abs(int(x) - k))]
    )
    # slice by n
    config_data = (
        config_data.get(str(n))
        or config_data[min(config_data.keys(), key=lambda x: abs(int(x) - n))]
    )

    assert config_data is not None
    return config_data
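
A quick illustration of how the new helper is queried (the shapes are made-up examples; with no tuned JSON present under `VLLM_TUNED_CONFIG_FOLDER`, the call simply falls back to the defaults shown above):

```python
from vllm.lora.ops.triton_ops.utils import get_lora_op_configs

# Illustrative shapes: 256 tokens, hidden_size=4096, rank=16, two output slices.
cfg = get_lora_op_configs(
    "shrink",
    max_loras=1,
    batch=256,
    hidden_size=4096,
    rank=16,
    num_slices=2,
)
# Without a tuned file this returns the shrink defaults, e.g.
# {"block_m": 32, "block_n": 16, "block_k": 32, "split_k": 8, ...};
# with one, it returns the nearest tuned entry for these shapes.
print(cfg)
```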