# SPDX-License-Identifier: Apache-2.0

from typing import Dict, List, Tuple

import torch

_LORA_A_PTR_DICT: Dict[Tuple[int, ...], Tuple[torch.Tensor, ...]] = {}
_LORA_B_PTR_DICT: Dict[Tuple[int, ...], Tuple[torch.Tensor, ...]] = {}
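# Both dicts are keyed on the `data_ptr()`s of the LoRA weight tensors, so a
# cache hit requires passing the exact same tensors (same storage) again.
# Entries are built once, during `profile_run`, and afterwards serve as a
# constant lookup table.
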
def _get_lora_a_ptr(lora_a_weights: List[torch.Tensor], device: torch.device):
    """
    `_LORA_A_PTR_DICT` collects the required information during `profile_run`.
    After that, it remains constant, and all subsequent usage is served from
    this lookup table (LUT).
    Refer to:
    https://github.com/triton-lang/triton/blob/release/3.1.x/python/tutorials/08-grouped-gemm.py
    """
    key = tuple(lora_weight.data_ptr() for lora_weight in lora_a_weights)

    if values := _LORA_A_PTR_DICT.get(key):
        return values

    lora_strides_d0 = []
    lora_strides_d1 = []
    lora_strides_d2 = []
    tensor_ptrs = []
    for lora_a_weight in lora_a_weights:
        if lora_a_weight.ndim == 4:  # shape: (lora_num, 1, size, rank)
            assert lora_a_weight.size(1) == 1
            lora_a_weight = lora_a_weight.squeeze(dim=1)
        else:
            assert lora_a_weight.ndim == 3  # shape: (lora_num, size, rank)
        assert lora_a_weight.is_contiguous()
        tensor_ptrs.append(lora_a_weight.data_ptr())
        lora_strides_d0.append(lora_a_weight.stride(0))
        lora_strides_d1.append(lora_a_weight.stride(1))
        lora_strides_d2.append(lora_a_weight.stride(2))

    if len(lora_a_weights) > 1:
        # Gather the slice pointers into a device tensor, as in Triton's
        # grouped-GEMM tutorial.
        lora_ptr_tensor = torch.tensor(tensor_ptrs, device=device)
    else:
        # With a single slice, the weight tensor itself serves as the base
        # pointer.
        lora_ptr_tensor = lora_a_weights[0]

    if (len(set(lora_strides_d0)) > 1 or len(set(lora_strides_d1)) > 1
            or len(set(lora_strides_d2)) > 1):
        raise ValueError("All LoRA weights must have the same stride.")

    _LORA_A_PTR_DICT[key] = (
        lora_ptr_tensor,
        lora_strides_d0[0],
        lora_strides_d1[0],
        lora_strides_d2[0],
    )
    return _LORA_A_PTR_DICT.get(key)
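
# A minimal usage sketch for `_get_lora_a_ptr` (hypothetical shapes; assumes a
# CUDA device is available):
#
#     device = torch.device("cuda")
#     lora_a = [torch.zeros(8, 4096, 16, device=device) for _ in range(2)]
#     first = _get_lora_a_ptr(lora_a, device)  # builds and caches the entry
#     again = _get_lora_a_ptr(lora_a, device)  # pure LUT hit
#     assert first is again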
def _get_lora_b_ptr(lora_weights: List[torch.Tensor], offset_start: int,
                    device: torch.device):
    """
    `_LORA_B_PTR_DICT` collects the required information during `profile_run`.
    After that, it remains constant, and all subsequent usage is served from
    this lookup table (LUT).
    Refer to:
    https://github.com/triton-lang/triton/blob/release/3.1.x/python/tutorials/08-grouped-gemm.py
    """
    key = tuple(lora_weight.data_ptr() for lora_weight in lora_weights)
    if values := _LORA_B_PTR_DICT.get(key):
        return values

    slice_offset_lst = []
    tensor_ptrs = []
    lora_strides_d0 = []
    lora_strides_d1 = []
    lora_strides_d2 = []
    hidden_sizes = []
    slice_offset = offset_start
    for lora_b_weight in lora_weights:
        if lora_b_weight.ndim == 4:  # shape: (lora_num, 1, size, rank)
            assert lora_b_weight.size(1) == 1
            lora_b_weight = lora_b_weight.squeeze(dim=1)
        else:
            assert lora_b_weight.ndim == 3  # shape: (lora_num, size, rank)
        assert lora_b_weight.is_contiguous()
        tensor_ptrs.append(lora_b_weight.data_ptr())
        lora_strides_d0.append(lora_b_weight.stride(0))
        lora_strides_d1.append(lora_b_weight.stride(1))
        lora_strides_d2.append(lora_b_weight.stride(2))
        slice_offset_lst.append(slice_offset)
        slice_offset += lora_b_weight.size(1)
        hidden_sizes.append(lora_b_weight.size(1))

    if len(lora_weights) > 1:
        # Note: these are device tensors.
        lora_ptr_tensor = torch.tensor(tensor_ptrs, device=device)
        slice_start_tensor = torch.tensor(slice_offset_lst, device=device)
    else:
        slice_start_tensor = slice_offset_lst[0]
        lora_ptr_tensor = lora_weights[0]

    # If every LoRA has the same strides and hidden size, there is no need to
    # use device tensors for storage; plain Python scalars suffice.
    if (len(set(lora_strides_d0)) == 1 and len(set(lora_strides_d1)) == 1
            and len(set(lora_strides_d2)) == 1
            and len(set(hidden_sizes)) == 1):
        lora_strides_d0_tensor = lora_strides_d0[0]
        lora_strides_d1_tensor = lora_strides_d1[0]
        lora_strides_d2_tensor = lora_strides_d2[0]
        hidden_sizes_tensor = hidden_sizes[0]
        same_stride = True
    else:
        lora_strides_d0_tensor = torch.tensor(lora_strides_d0, device=device)
        lora_strides_d1_tensor = torch.tensor(lora_strides_d1, device=device)
        lora_strides_d2_tensor = torch.tensor(lora_strides_d2, device=device)
        hidden_sizes_tensor = torch.tensor(hidden_sizes, device=device)
        same_stride = False
    # MAX_N is the maximum hidden size among all the lora_b weights.
    MAX_N = max(hidden_sizes)

    _LORA_B_PTR_DICT[key] = (slice_start_tensor, lora_ptr_tensor,
                             lora_strides_d0_tensor, lora_strides_d1_tensor,
                             lora_strides_d2_tensor, hidden_sizes_tensor,
                             same_stride, MAX_N)
    return _LORA_B_PTR_DICT.get(key)
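
# A minimal usage sketch for `_get_lora_b_ptr` (hypothetical shapes; assumes a
# CUDA device is available). Two lora_b slices with hidden sizes 4096 and 1024
# and `offset_start=0` yield slice offsets [0, 4096] and MAX_N == 4096:
#
#     device = torch.device("cuda")
#     lora_b = [
#         torch.zeros(8, 4096, 16, device=device),
#         torch.zeros(8, 1024, 16, device=device),
#     ]
#     (slice_starts, ptrs, d0, d1, d2, hidden,
#      same_stride, max_n) = _get_lora_b_ptr(lora_b, 0, device)
#     # same_stride is False here because the two slices differ in stride(0).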