Mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 14:53:52 +08:00)
[Misc][LoRA] Refactor and clean MergedQKVParallelLinearWithLora implementation (#10958)
Signed-off-by: Isotr0py <2037008807@qq.com>
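
The refactor collapses the hand-written per-slice branches (slice 0 / slice 1 for gate_up_proj, q / k / v for qkv_proj) into generic loops driven by two tuples: output_slices, the per-rank shard width of each packed slice, and output_ids, the shard index of each slice that the current rank owns. The following is a minimal, self-contained sketch of that slicing scheme; the helper name and the toy shapes are assumptions for illustration, and only the loop body mirrors the diff below.

# Sketch only: the loop mirrors the slice_lora_b introduced in this commit,
# but slice_packed_lora_b and the toy shapes are hypothetical.
from typing import List, Optional, Tuple

import torch


def slice_packed_lora_b(
        lora_b: List[Optional[torch.Tensor]],
        output_slices: Tuple[int, ...],  # per-rank shard width of each slice
        output_ids: Tuple[int, ...],  # shard index owned by this rank
) -> List[Optional[torch.Tensor]]:
    for i, (shard_id, shard_size) in enumerate(zip(output_ids, output_slices)):
        if (lora_b_i := lora_b[i]) is not None:
            lora_b[i] = lora_b_i[:, shard_size * shard_id:shard_size *
                                 (shard_id + 1)]
    return lora_b


# qkv_proj example: q sharded 4 ways, k/v replicated across pairs of ranks.
lora_rank, q_shard, kv_shard = 8, 128, 32
lora_b = [
    torch.randn(lora_rank, 4 * q_shard),  # q
    torch.randn(lora_rank, 2 * kv_shard),  # k
    None,  # no v adapter in this example
]
sliced = slice_packed_lora_b(lora_b, (q_shard, kv_shard, kv_shard), (1, 0, 0))
print([None if t is None else tuple(t.shape) for t in sliced])
# [(8, 128), (8, 32), None]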
@@ -542,10 +542,20 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
     Both slices must have the same size.
     """
 
-    def __init__(self, base_layer: MergedColumnParallelLinear) -> None:
+    def __init__(
+            self, base_layer: Union[MergedColumnParallelLinear,
+                                    QKVParallelLinear]) -> None:
         super().__init__(base_layer)
         # There are two LoRA layers
-        self.n_slices = len(self.base_layer.output_sizes)
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.tp_rank = get_tensor_model_parallel_rank()
+        # the output_sizes in MergedColumnParallelLinear is not sharded by tp
+        # we need to divide it by the tp_size to get correct slices size
+        output_sizes = self.base_layer.output_sizes
+        self.output_slices = tuple(
+            divide(output_size, self.tp_size) for output_size in output_sizes)
+        self.n_slices = len(self.output_slices)
+        self.output_ids = (self.tp_rank, ) * self.n_slices
 
     def create_lora_weights(
         self,
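For intuition on the new __init__ above: MergedColumnParallelLinear keeps the full (unsharded) output_sizes, so the LoRA wrapper divides each entry by tp_size, and every slice on a given rank uses that rank's shard id. A toy example with assumed sizes:

# Assumed sizes, for illustration only (gate_proj/up_proj of width 11008, tp_size=4).
tp_size, tp_rank = 4, 1
output_sizes = (11008, 11008)
output_slices = tuple(size // tp_size for size in output_sizes)
output_ids = (tp_rank, ) * len(output_slices)
print(output_slices, output_ids)  # (2752, 2752) (1, 1)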
@@ -559,15 +569,6 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
         """
         self.lora_config = lora_config
 
-        if not (len(self.base_layer.output_sizes) == self.n_slices == 2
-                and self.base_layer.output_sizes[0]
-                == self.base_layer.output_sizes[1]):
-            raise ValueError(
-                "LoRAColumnParallelLinear2Slice requires 2 slices with "
-                "the same size.")
-        self.tp_size = get_tensor_model_parallel_world_size()
-        self.tp_rank = get_tensor_model_parallel_rank()
-
         lora_a_output_size_per_partition = (
             lora_config.max_lora_rank if not lora_config.fully_sharded_loras
             else divide(lora_config.max_lora_rank, self.tp_size))
@@ -585,22 +586,20 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
             torch.zeros(
                 max_loras,
                 1,
-                self.output_size // 2,
+                output_size,
                 lora_config.max_lora_rank,
                 dtype=lora_config.lora_dtype,
                 device=self.device,
-            ) for _ in range(self.n_slices))
+            ) for output_size in self.output_slices)
         if lora_config.bias_enabled:
             self.lora_bias_stacked = tuple(
                 torch.zeros(
                     max_loras,
                     1,
-                    self.output_size // 2,
+                    output_size,
                     dtype=lora_config.lora_dtype,
                     device=self.device,
-                ) for _ in range(self.n_slices))
-        self.output_dim = self.lora_b_stacked[0].shape[2]
-        self.output_slices = (self.output_dim, self.output_dim)
+                ) for output_size in self.output_slices)
 
     def slice_lora_a(
             self, lora_a: List[Union[torch.Tensor, None]]
@@ -610,27 +609,21 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
     def slice_lora_b(
             self, lora_b: List[Union[torch.Tensor, None]]
     ) -> List[Union[torch.Tensor, None]]:
-        #NOTE: lora_b contains 2 subloras, and each sublora could be None.
-        shard_size = self.output_dim
-        start_idx = self.tp_rank * shard_size
-        end_idx = (self.tp_rank + 1) * shard_size
-        lora_b = [
-            lora_b[0][:, start_idx:end_idx] if lora_b[0] is not None else None,
-            lora_b[1][:, start_idx:end_idx] if lora_b[1] is not None else None,
-        ]
+        for i, (shard_id, shard_size) in enumerate(
+                zip(self.output_ids, self.output_slices)):
+            if (lora_b_i := lora_b[i]) is not None:
+                lora_b[i] = lora_b_i[:, shard_size * shard_id:shard_size *
+                                     (shard_id + 1)]
         return lora_b
 
     def slice_bias(
         self, bias: List[Union[torch.Tensor,
                                None]]) -> List[Union[torch.Tensor, None]]:
-        # NOTE : each bias could be None.
-        shard_size = self.output_dim
-        start_idx = self.tp_rank * shard_size
-        end_idx = (self.tp_rank + 1) * shard_size
-        bias = [
-            bias[0][start_idx:end_idx] if bias[0] is not None else None,
-            bias[1][start_idx:end_idx] if bias[1] is not None else None
-        ]
+        for i, (shard_id, shard_size) in enumerate(
+                zip(self.output_ids, self.output_slices)):
+            if (bias_i := bias[i]) is not None:
+                bias[i] = bias_i[shard_size * shard_id:shard_size *
+                                 (shard_id + 1)]
         return bias
 
     def set_lora(
@@ -649,30 +642,25 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
             if lora_bias is not None:
                 lora_bias = self.slice_bias(lora_bias)
 
-        if lora_a[0] is not None:
-            self.lora_a_stacked[0][
-                index, 0, :lora_a[0].shape[1], :lora_a[0].shape[0]].copy_(
-                    lora_a[0].T, non_blocking=True)
-            self.lora_b_stacked[0][
-                index, 0, :lora_b[0].shape[1], :lora_b[0].shape[0]].copy_(
-                    lora_b[0].T, non_blocking=True)
-        if lora_bias is not None and lora_bias[0] is not None:
+        for i in range(self.n_slices):
+            if (lora_a_i := lora_a[i]) is not None:
+                self.lora_a_stacked[i][
+                    index, 0, :lora_a_i.shape[1], :lora_a_i.shape[0]].copy_(
+                        lora_a_i.T, non_blocking=True)
+            if (lora_b_i := lora_b[i]) is not None:
+                self.lora_b_stacked[i][
+                    index, 0, :lora_b_i.shape[1], :lora_b_i.shape[0]].copy_(
+                        lora_b_i.T, non_blocking=True)
+
+        if lora_bias is not None:
             self.lora_bias_stacked = cast(Tuple[torch.Tensor, ...],
                                           self.lora_bias_stacked)
-            self.lora_bias_stacked[0][index, 0, :lora_bias[0].shape[0]].copy_(
-                lora_bias[0].T, non_blocking=True)
-        if lora_a[1] is not None:
-            self.lora_a_stacked[1][
-                index, 0, :lora_a[1].shape[1], :lora_a[1].shape[0]].copy_(
-                    lora_a[1].T, non_blocking=True)
-            self.lora_b_stacked[1][
-                index, 0, :lora_b[1].shape[1], :lora_b[1].shape[0]].copy_(
-                    lora_b[1].T, non_blocking=True)
-        if lora_bias is not None and lora_bias[1] is not None:
-            self.lora_bias_stacked = cast(Tuple[torch.Tensor, ...],
-                                          self.lora_bias_stacked)
-            self.lora_bias_stacked[1][index, 0, :lora_bias[1].shape[0]].copy_(
-                lora_bias[1].T, non_blocking=True)
+            for i in range(self.n_slices):
+                if (lora_bias_i := lora_bias[i]) is not None:
+                    self.lora_bias_stacked[i][index,
+                                              0, :lora_bias_i.shape[0]].copy_(
+                                                  lora_bias_i.T,
+                                                  non_blocking=True)
 
     @classmethod
     @_not_fully_sharded_can_replace
@@ -755,8 +743,8 @@ class QKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
             packed_modules_list) == 1
 
 
-class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
-    """ColumnParallelLinear layer that is composed of 3 sublayers (slices)
+class MergedQKVParallelLinearWithLora(MergedColumnParallelLinearWithLoRA):
+    """MergedColumnParallelLinear layer that is composed of 3 sublayers (slices)
     packed together in qkv proj fashion
     (q_proj + k_proj + v_proj -> qkv_proj).
 
@@ -773,6 +761,24 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
         self.tp_size = get_tensor_model_parallel_world_size()
         self.tp_rank = get_tensor_model_parallel_rank()
 
+        self.q_proj_shard_size = (self.base_layer.num_heads *
+                                  self.base_layer.head_size)
+        self.kv_proj_shard_size = (self.base_layer.num_kv_heads *
+                                   self.base_layer.head_size)
+        self.q_shard_id = self.tp_rank
+        self.kv_shard_id = self.tp_rank // self.base_layer.num_kv_head_replicas
+
+        self.output_slices = (
+            self.q_proj_shard_size,
+            self.kv_proj_shard_size,
+            self.kv_proj_shard_size,
+        )
+        self.output_ids = (
+            self.q_shard_id,
+            self.kv_shard_id,
+            self.kv_shard_id,
+        )
+
     def create_lora_weights(
         self,
         max_loras: int,
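A worked example of the shard metadata set in the __init__ above, with assumed values (head_size=128, 4 query heads and 1 KV head per rank, tp_rank=3, num_kv_head_replicas=2):

# Toy numbers, not taken from the diff; they only illustrate the arithmetic.
head_size, num_heads, num_kv_heads = 128, 4, 1  # per-rank head counts (assumed)
tp_rank, num_kv_head_replicas = 3, 2  # assumed tensor-parallel layout

q_proj_shard_size = num_heads * head_size  # 512
kv_proj_shard_size = num_kv_heads * head_size  # 128
q_shard_id = tp_rank  # 3
kv_shard_id = tp_rank // num_kv_head_replicas  # 1: ranks 2 and 3 share KV shard 1

output_slices = (q_proj_shard_size, kv_proj_shard_size, kv_proj_shard_size)
output_ids = (q_shard_id, kv_shard_id, kv_shard_id)
print(output_slices, output_ids)  # (512, 128, 128) (3, 1, 1)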
@@ -783,216 +789,7 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
         The main reason for overloading this function is to handle inconsistent
         weight dimensions in qkv lora.
         """
-        self.lora_config = lora_config
-
-        if not (len(self.base_layer.output_sizes) == self.n_slices == 3):
-            raise ValueError(
-                "LoRAColumnParallelLinear3Slice requires 3 slices.")
-
-        self.q_proj_shard_size = (self.base_layer.num_heads *
-                                  self.base_layer.head_size)
-        self.kv_proj_shard_size = (self.base_layer.num_kv_heads *
-                                   self.base_layer.head_size)
-        self.q_shard_id = self.tp_rank
-        self.kv_shard_id = self.tp_rank // self.base_layer.num_kv_head_replicas
-
-        lora_a_output_size_per_partition = (
-            lora_config.max_lora_rank if not lora_config.fully_sharded_loras
-            else divide(lora_config.max_lora_rank, self.tp_size))
-        # q, k, v
-        self.lora_a_stacked = (
-            torch.zeros(
-                max_loras,
-                1,
-                lora_a_output_size_per_partition,
-                self.input_size,
-                dtype=lora_config.lora_dtype,
-                device=self.device,
-            ),
-            torch.zeros(
-                max_loras,
-                1,
-                lora_a_output_size_per_partition,
-                self.input_size,
-                dtype=lora_config.lora_dtype,
-                device=self.device,
-            ),
-            torch.zeros(
-                max_loras,
-                1,
-                lora_a_output_size_per_partition,
-                self.input_size,
-                dtype=lora_config.lora_dtype,
-                device=self.device,
-            ),
-        )
-        self.lora_b_stacked = (
-            torch.zeros(
-                max_loras,
-                1,
-                self.q_proj_shard_size,
-                lora_config.max_lora_rank,
-                dtype=lora_config.lora_dtype,
-                device=self.device,
-            ),
-            torch.zeros(
-                max_loras,
-                1,
-                self.kv_proj_shard_size,
-                lora_config.max_lora_rank,
-                dtype=lora_config.lora_dtype,
-                device=self.device,
-            ),
-            torch.zeros(
-                max_loras,
-                1,
-                self.kv_proj_shard_size,
-                lora_config.max_lora_rank,
-                dtype=lora_config.lora_dtype,
-                device=self.device,
-            ),
-        )
-        if lora_config.bias_enabled:
-            self.lora_bias_stacked = (
-                torch.zeros(
-                    max_loras,
-                    1,
-                    self.q_proj_shard_size,
-                    dtype=lora_config.lora_dtype,
-                    device=self.device,
-                ),
-                torch.zeros(
-                    max_loras,
-                    1,
-                    self.kv_proj_shard_size,
-                    dtype=lora_config.lora_dtype,
-                    device=self.device,
-                ),
-                torch.zeros(
-                    max_loras,
-                    1,
-                    self.kv_proj_shard_size,
-                    dtype=lora_config.lora_dtype,
-                    device=self.device,
-                ),
-            )
-        self.output_slices = (
-            self.q_proj_shard_size,
-            self.kv_proj_shard_size,
-            self.kv_proj_shard_size,
-        )
-        self.packed_indices: Optional[torch.Tensor] = None
-        self.standard_indices: Optional[torch.Tensor] = None
-        # lazily initialized.
-        self.indices: torch.Tensor
-        self.indices_len: List[int]
-
-    def slice_lora_a(
-            self, lora_a: List[Union[torch.Tensor, None]]
-    ) -> List[Union[torch.Tensor, None]]:
-        return lora_a
-
-    def slice_lora_b(
-            self, lora_b: List[Union[torch.Tensor, None]]
-    ) -> List[Union[torch.Tensor, None]]:
-        lora_b_q, lora_b_k, lora_b_v = None, None, None
-        if lora_b[0] is not None:
-            lora_b_q = lora_b[0][:, self.q_proj_shard_size *
-                                 self.q_shard_id:self.q_proj_shard_size *
-                                 (self.q_shard_id + 1), ]
-        if lora_b[1] is not None:
-            lora_b_k = lora_b[1][:, self.kv_proj_shard_size *
-                                 self.kv_shard_id:self.kv_proj_shard_size *
-                                 (self.kv_shard_id + 1), ]
-        if lora_b[2] is not None:
-            lora_b_v = lora_b[2][:, self.kv_proj_shard_size *
-                                 self.kv_shard_id:self.kv_proj_shard_size *
-                                 (self.kv_shard_id + 1), ]
-        lora_b = [lora_b_q, lora_b_k, lora_b_v]
-        return lora_b
-
-    def slice_bias(
-        self, bias: List[Union[torch.Tensor,
-                               None]]) -> List[Union[torch.Tensor, None]]:
-        bias_q, bias_k, bias_v = bias
-        if bias_q is not None:
-            bias_q = bias_q[self.q_proj_shard_size *
-                            self.q_shard_id:self.q_proj_shard_size *
-                            (self.q_shard_id + 1)]
-        if bias_k is not None:
-            bias_k = bias_k[self.kv_proj_shard_size *
-                            self.kv_shard_id:self.kv_proj_shard_size *
-                            (self.kv_shard_id + 1)]
-        if bias_v is not None:
-            bias_v = bias_v[self.kv_proj_shard_size *
-                            self.kv_shard_id:self.kv_proj_shard_size *
-                            (self.kv_shard_id + 1)]
-        bias = [bias_q, bias_k, bias_v]
-        return bias
-
-    def set_lora(
-        self,
-        index: int,
-        lora_a: torch.Tensor,
-        lora_b: torch.Tensor,
-        embeddings_tensor: Optional[torch.Tensor],
-        lora_bias: Optional[torch.Tensor] = None,
-    ):
-        self.reset_lora(index)
-
-        if self.tp_size > 1:
-            lora_a = self.slice_lora_a(lora_a)
-            lora_b = self.slice_lora_b(lora_b)
-            if lora_bias is not None:
-                lora_bias = self.slice_bias(lora_bias)
-
-        if lora_b[0] is not None:
-            lora_b_q = lora_b[0]
-            self.lora_b_stacked[0][
-                index, 0, :lora_b_q.shape[1], :lora_b_q.shape[0]].copy_(
-                    lora_b_q.T, non_blocking=True)
-        if lora_b[1] is not None:
-            lora_b_k = lora_b[1]
-            self.lora_b_stacked[1][
-                index, 0, :lora_b_k.shape[1], :lora_b_k.shape[0]].copy_(
-                    lora_b_k.T, non_blocking=True)
-        if lora_b[2] is not None:
-            lora_b_v = lora_b[2]
-            self.lora_b_stacked[2][
-                index, 0, :lora_b_v.shape[1], :lora_b_v.shape[0]].copy_(
-                    lora_b_v.T, non_blocking=True)
-
-        if lora_a[0] is not None:
-            self.lora_a_stacked[0][
-                index, 0, :lora_a[0].shape[1], :lora_a[0].shape[0]].copy_(
-                    lora_a[0].T, non_blocking=True)
-        if lora_a[1] is not None:
-            self.lora_a_stacked[1][
-                index, 0, :lora_a[1].shape[1], :lora_a[1].shape[0]].copy_(
-                    lora_a[1].T, non_blocking=True)
-        if lora_a[2] is not None:
-            self.lora_a_stacked[2][
-                index, 0, :lora_a[2].shape[1], :lora_a[2].shape[0]].copy_(
-                    lora_a[2].T, non_blocking=True)
-
-        if lora_bias is not None:
-            self.lora_bias_stacked = cast(Tuple[torch.Tensor, ...],
-                                          self.lora_bias_stacked)
-            if lora_bias[0] is not None:
-                self.lora_bias_stacked[0][index,
-                                          0, :lora_bias[0].shape[0]].copy_(
-                                              lora_bias[0].T,
-                                              non_blocking=True)
-            if lora_bias[1] is not None:
-                self.lora_bias_stacked[1][index,
-                                          0, :lora_bias[1].shape[0]].copy_(
-                                              lora_bias[1].T,
-                                              non_blocking=True)
-            if lora_bias[2] is not None:
-                self.lora_bias_stacked[2][index,
-                                          0, :lora_bias[2].shape[0]].copy_(
-                                              lora_bias[2].T,
-                                              non_blocking=True)
+        super().create_lora_weights(max_loras, lora_config, model_config)
 
     @classmethod
     @_not_fully_sharded_can_replace
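With the hand-written q/k/v buffers deleted above, allocation now happens in the inherited MergedColumnParallelLinearWithLoRA.create_lora_weights, which sizes one buffer per entry of output_slices (see the tuple comprehension earlier in this diff). A simplified sketch of what that amounts to, as a standalone function with an assumed signature rather than vLLM's actual API:

# Sketch: one torch.zeros per output slice replaces the per-projection buffers.
from typing import Tuple

import torch


def allocate_lora_b_stacked(
        max_loras: int,
        max_lora_rank: int,
        output_slices: Tuple[int, ...],
        dtype: torch.dtype = torch.float16,
        device: str = "cpu",
) -> Tuple[torch.Tensor, ...]:
    return tuple(
        torch.zeros(max_loras, 1, output_size, max_lora_rank,
                    dtype=dtype, device=device)
        for output_size in output_slices)


# qkv_proj with toy per-rank shard sizes (512, 128, 128) and LoRA rank 16:
stacked = allocate_lora_b_stacked(4, 16, (512, 128, 128))
print([tuple(t.shape) for t in stacked])
# [(4, 1, 512, 16), (4, 1, 128, 16), (4, 1, 128, 16)]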