# SPDX-License-Identifier: Apache-2.0
"""
Based on:
Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
Punica: Multi-Tenant LoRA Serving.
https://arxiv.org/abs/2310.18547
"""

from typing import List

import torch
import triton
import triton.language as tl

from vllm.utils import direct_register_custom_op

from .kernel_utils import do_shrink_kernel
from .utils import _get_lora_a_ptr


@triton.jit
def _sgmv_shrink_kernel(
        input_ptr,
        lora_ptr,  #1-3
        out_ptr,
        N,
        K,
        b_seq_start_loc,
        seq_lens,
        lora_indices,
        scaling,
        input_d0_stride,
        input_d1_stride,  # 1
        lora_d0_stride,
        lora_d1_stride,
        lora_d2_stride,  # 1
        output_d0_stride,
        output_d1_stride,
        output_d2_stride,  # 1
        BLOCK_M: tl.constexpr,
        BLOCK_N: tl.constexpr,
        BLOCK_K: tl.constexpr,
        EVEN_K: tl.constexpr,
        SPLIT_K: tl.constexpr,
        SLICE_NUM: tl.constexpr):
"""
|
|
The sgmv's shrink triton kernel is based on GroupGEMM+SPLIT-K.
|
|
The GEMM of Multi-LoRA can be considered as GroupGEMM. Additionally,
|
|
introducing SPLIT-K can improve performance
|
|
"""
    pid = tl.program_id(axis=0)
    pid_mix = tl.program_id(axis=1)
    cur_batch = tl.program_id(axis=2)
    cta_n_num = tl.cdiv(N, BLOCK_N)
    pid_m = pid // cta_n_num
    pid_n = pid % cta_n_num
    if SLICE_NUM == 1:
        slice_id: tl.constexpr = 0
        pid_sk = tl.program_id(axis=1)
    else:
        pid_mix = tl.program_id(axis=1)
        slice_id = pid_mix // SPLIT_K
        pid_sk = pid_mix % SPLIT_K

    M = tl.load(seq_lens + cur_batch)
    if pid_m * BLOCK_M >= M:
        return
    lora_index = tl.load(lora_indices + cur_batch)
    if lora_index == -1:
        return

    m_offset = tl.load(b_seq_start_loc + cur_batch)

    cta_m_len = min(BLOCK_M, M - (pid_m * BLOCK_M))
    cta_m_offset = m_offset + (pid_m * BLOCK_M)
    offset_m = tl.arange(0, BLOCK_M)
    ram = cta_m_offset + tl.max_contiguous(
        tl.multiple_of(offset_m % cta_m_len, BLOCK_M), BLOCK_M)

    do_shrink_kernel(
        pid_n,
        pid_sk,
        slice_id,
        lora_index,
        input_ptr,
        lora_ptr,
        out_ptr,
        N,
        K,
        cta_m_len,
        ram,
        # input strides
        input_d0_stride,
        input_d1_stride,
        # lora strides
        lora_d0_stride,
        lora_d1_stride,
        lora_d2_stride,
        # output strides
        output_d0_stride,
        output_d1_stride,
        output_d2_stride,
        scaling,
        BLOCK_M,
        BLOCK_N,
        BLOCK_K,
        EVEN_K,
        SPLIT_K,
        SLICE_NUM)
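

# The helper below is an illustrative, pure-PyTorch sketch of the grouped GEMM
# that the kernel above computes; it is not used anywhere in vLLM and its name
# is made up for illustration. The tensor shapes are assumptions (not asserted
# by this module): `inputs` is (num_tokens, hidden_size), each entry of
# `lora_a_weights` is (num_loras, rank, hidden_size), and `output_tensor` is
# (num_slices, num_tokens, rank).
def _sgmv_shrink_torch_reference(inputs, lora_a_weights, output_tensor,
                                 b_seq_start_loc, seq_len_tensor,
                                 lora_indices_tensor, scaling):
    for slice_id, lora_a in enumerate(lora_a_weights):
        for batch_id in range(seq_len_tensor.size(0)):
            lora_index = int(lora_indices_tensor[batch_id])
            if lora_index == -1:
                # An index of -1 disables LoRA for this sequence.
                continue
            start = int(b_seq_start_loc[batch_id])
            length = int(seq_len_tensor[batch_id])
            x = inputs[start:start + length]  # (seq_len, hidden_size)
            a = lora_a[lora_index]  # (rank, hidden_size)
            # Shrink: project the hidden states down to the LoRA rank and
            # apply the scaling factor.
            y = scaling * (x.float() @ a.float().t())
            output_tensor[slice_id, start:start + length, :a.size(0)] = y.to(
                output_tensor.dtype)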


@torch.inference_mode()
def _sgmv_shrink(
    inputs: torch.Tensor,
    lora_a_weights: List[torch.Tensor],
    output_tensor: torch.Tensor,
    b_seq_start_loc: torch.Tensor,
    seq_len_tensor: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    batches: int,
    max_seq_length: int,
    token_nums: int,
    scaling: float,
) -> None:
"""
|
|
Args:
|
|
inputs (torch.Tensor): input tensor
|
|
lora_a_weights (List[torch.Tensor]): lora'a weight
|
|
output_tensor (torch.Tensor): output tensor
|
|
b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative
|
|
sequence lengths of the sequences in the batch, used to index
|
|
into sequence. E.g., if the sequence length is [4, 6], it is
|
|
[0, 4].
|
|
seq_len_tensor (torch.Tensor): (batch_size,). Record the sequence
|
|
length of the sequences in the batch.
|
|
lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
|
|
corresponding to each batch. An index of -1 means no lora should be
|
|
applied.
|
|
batches (int): batch size
|
|
max_seq_length (int): The max sequence lengths of the sequences in the
|
|
batch.
|
|
token_nums (int): The token numbers in the batch. Used to verify if the
|
|
token numbers in the inputs matches the one in the metadata.
|
|
scaling (float): Scaling factor.
|
|
"""
    assert inputs.dtype == lora_a_weights[0].dtype
    assert inputs.dtype in [torch.float16, torch.bfloat16]
    for weight in lora_a_weights:
        assert weight.dtype in [torch.float16, torch.bfloat16]

    assert inputs.size(0) == token_nums
    assert inputs.size(1) == lora_a_weights[0].size(-1)
    assert b_seq_start_loc.size(0) == batches
    assert lora_indices_tensor.size(0) == batches
    assert inputs.is_contiguous()
    assert output_tensor.is_contiguous()
    (lora_ptr_tensor, lora_strides_d0, lora_strides_d1,
     lora_strides_d2) = _get_lora_a_ptr(lora_a_weights, b_seq_start_loc.device)
    # TODO: tune this config
    N, K = lora_a_weights[0].shape[-2:]  # K=hidden_size, N=rank
    BLOCK_M = 32
    BLOCK_N = 16
    BLOCK_K = 32
    SPLIT_K = 8
    EVEN_K = K % (BLOCK_K * SPLIT_K) == 0
    # Launch grid: axis 0 tiles (M, N), axis 1 packs SPLIT_K x num_slices,
    # axis 2 is the batch (one program group per sequence).
    grid = (
        triton.cdiv(max_seq_length, BLOCK_M) * triton.cdiv(N, BLOCK_N),
        SPLIT_K * len(lora_a_weights),
        batches,
    )
    _sgmv_shrink_kernel[grid](
        inputs,
        lora_ptr_tensor,
        output_tensor,
        N,
        K,
        b_seq_start_loc,
        seq_len_tensor,
        lora_indices_tensor,
        scaling,
        inputs.stride(0),
        inputs.stride(1),
        lora_strides_d0,
        lora_strides_d1,
        lora_strides_d2,
        output_tensor.stride(0),
        output_tensor.stride(1),
        output_tensor.stride(2),
        BLOCK_M,
        BLOCK_N,
        BLOCK_K,
        EVEN_K,
        SPLIT_K,
        len(lora_a_weights),
    )
    return
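

# Illustrative sketch (not used by vLLM) of how the per-batch metadata
# documented in `_sgmv_shrink` is typically derived from a list of sequence
# lengths. The helper name and the example values are made up for
# illustration.
def _example_sgmv_shrink_metadata():
    # Two sequences of lengths 4 and 6; the first uses LoRA 0 and the second
    # has no LoRA applied (index -1).
    seq_len_tensor = torch.tensor([4, 6])
    lora_indices_tensor = torch.tensor([0, -1])
    # b_seq_start_loc is the exclusive cumulative sum of the sequence
    # lengths: the start offset of each sequence in the token-packed
    # `inputs` tensor -> [0, 4].
    b_seq_start_loc = torch.zeros_like(seq_len_tensor)
    b_seq_start_loc[1:] = torch.cumsum(seq_len_tensor, dim=0)[:-1]
    batches = seq_len_tensor.size(0)  # 2
    token_nums = int(seq_len_tensor.sum())  # 10, must equal inputs.size(0)
    max_seq_length = int(seq_len_tensor.max())  # 6
    return (b_seq_start_loc, seq_len_tensor, lora_indices_tensor, batches,
            max_seq_length, token_nums)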


# No-op fake implementation used when registering `sgmv_shrink` as a torch
# custom op below; the real work happens in `_sgmv_shrink`.
def sgmv_shrink_fake(
    inputs: torch.Tensor,
    lora_a_weights: List[torch.Tensor],
    output_tensor: torch.Tensor,
    b_seq_start_loc: torch.Tensor,
    seq_len_tensor: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    batches: int,
    max_seq_length: int,
    token_nums: int,
    scaling: float,
) -> None:
    return


try:
    direct_register_custom_op(
        op_name="sgmv_shrink",
        op_func=_sgmv_shrink,
        mutates_args=["output_tensor"],
        fake_impl=sgmv_shrink_fake,
    )
    sgmv_shrink = torch.ops.vllm.sgmv_shrink

except AttributeError:
    # Custom-op registration is unavailable; fall back to calling the Python
    # wrapper directly.
    sgmv_shrink = _sgmv_shrink