[Misc][LoRA] Refactor and clean MergedQKVParallelLinearWithLora implementation (#10958)

Signed-off-by: Isotr0py <2037008807@qq.com>
Author: Isotr0py
Date: 2024-12-07 18:33:49 +08:00 (committed by GitHub)
Parent: f13cf9ad50
Commit: b26b4cd03c

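The refactor below generalizes the two-slice logic in MergedColumnParallelLinearWithLoRA so that every slice is described by a (shard_id, shard_size) pair stored in self.output_ids / self.output_slices; MergedQKVParallelLinearWithLora can then inherit the same slicing and copy loops instead of duplicating them for q/k/v. A minimal standalone sketch of that shared slicing pattern (the toy slice_lora_b helper and the tensor shapes here are illustrative, not the vLLM API):

import torch

def slice_lora_b(lora_b, output_ids, output_slices):
    # Take this TP rank's column shard of every sublora; subloras may be None.
    out = []
    for sub, shard_id, shard_size in zip(lora_b, output_ids, output_slices):
        out.append(None if sub is None else
                   sub[:, shard_size * shard_id:shard_size * (shard_id + 1)])
    return out

# Two equal slices (the MergedColumnParallelLinear case), tp_rank = 1, tp_size = 2.
rank = 8
merged = [torch.randn(rank, 256), torch.randn(rank, 256)]
print([tuple(t.shape)
       for t in slice_lora_b(merged, output_ids=(1, 1), output_slices=(128, 128))])

# Three slices (the QKV case): q shards by tp_rank, k/v by the kv shard id;
# a missing sublora simply stays None.
qkv = [torch.randn(rank, 1024), torch.randn(rank, 256), None]
print([None if t is None else tuple(t.shape)
       for t in slice_lora_b(qkv, output_ids=(1, 0, 0), output_slices=(512, 128, 128))])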

@@ -542,10 +542,20 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
Both slices must have the same size.
"""
def __init__(self, base_layer: MergedColumnParallelLinear) -> None:
def __init__(
self, base_layer: Union[MergedColumnParallelLinear,
QKVParallelLinear]) -> None:
super().__init__(base_layer)
# There are two LoRA layers
self.n_slices = len(self.base_layer.output_sizes)
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
# The output_sizes in MergedColumnParallelLinear is not sharded by tp,
# so we divide it by tp_size to get the correct slice sizes.
output_sizes = self.base_layer.output_sizes
self.output_slices = tuple(
divide(output_size, self.tp_size) for output_size in output_sizes)
self.n_slices = len(self.output_slices)
self.output_ids = (self.tp_rank, ) * self.n_slices
def create_lora_weights(
self,
@@ -559,15 +569,6 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
"""
self.lora_config = lora_config
if not (len(self.base_layer.output_sizes) == self.n_slices == 2
and self.base_layer.output_sizes[0]
== self.base_layer.output_sizes[1]):
raise ValueError(
"LoRAColumnParallelLinear2Slice requires 2 slices with "
"the same size.")
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
lora_a_output_size_per_partition = (
lora_config.max_lora_rank if not lora_config.fully_sharded_loras
else divide(lora_config.max_lora_rank, self.tp_size))
@@ -585,22 +586,20 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
torch.zeros(
max_loras,
1,
self.output_size // 2,
output_size,
lora_config.max_lora_rank,
dtype=lora_config.lora_dtype,
device=self.device,
) for _ in range(self.n_slices))
) for output_size in self.output_slices)
if lora_config.bias_enabled:
self.lora_bias_stacked = tuple(
torch.zeros(
max_loras,
1,
self.output_size // 2,
output_size,
dtype=lora_config.lora_dtype,
device=self.device,
) for _ in range(self.n_slices))
self.output_dim = self.lora_b_stacked[0].shape[2]
self.output_slices = (self.output_dim, self.output_dim)
) for output_size in self.output_slices)
def slice_lora_a(
self, lora_a: List[Union[torch.Tensor, None]]
@@ -610,27 +609,21 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
def slice_lora_b(
self, lora_b: List[Union[torch.Tensor, None]]
) -> List[Union[torch.Tensor, None]]:
#NOTE: lora_b contains 2 subloras, and each sublora could be None.
shard_size = self.output_dim
start_idx = self.tp_rank * shard_size
end_idx = (self.tp_rank + 1) * shard_size
lora_b = [
lora_b[0][:, start_idx:end_idx] if lora_b[0] is not None else None,
lora_b[1][:, start_idx:end_idx] if lora_b[1] is not None else None,
]
for i, (shard_id, shard_size) in enumerate(
zip(self.output_ids, self.output_slices)):
if (lora_b_i := lora_b[i]) is not None:
lora_b[i] = lora_b_i[:, shard_size * shard_id:shard_size *
(shard_id + 1)]
return lora_b
def slice_bias(
self, bias: List[Union[torch.Tensor,
None]]) -> List[Union[torch.Tensor, None]]:
# NOTE : each bias could be None.
shard_size = self.output_dim
start_idx = self.tp_rank * shard_size
end_idx = (self.tp_rank + 1) * shard_size
bias = [
bias[0][start_idx:end_idx] if bias[0] is not None else None,
bias[1][start_idx:end_idx] if bias[1] is not None else None
]
for i, (shard_id, shard_size) in enumerate(
zip(self.output_ids, self.output_slices)):
if (bias_i := bias[i]) is not None:
bias[i] = bias_i[shard_size * shard_id:shard_size *
(shard_id + 1)]
return bias
def set_lora(
@@ -649,30 +642,25 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
if lora_bias is not None:
lora_bias = self.slice_bias(lora_bias)
if lora_a[0] is not None:
self.lora_a_stacked[0][
index, 0, :lora_a[0].shape[1], :lora_a[0].shape[0]].copy_(
lora_a[0].T, non_blocking=True)
self.lora_b_stacked[0][
index, 0, :lora_b[0].shape[1], :lora_b[0].shape[0]].copy_(
lora_b[0].T, non_blocking=True)
if lora_bias is not None and lora_bias[0] is not None:
for i in range(self.n_slices):
if (lora_a_i := lora_a[i]) is not None:
self.lora_a_stacked[i][
index, 0, :lora_a_i.shape[1], :lora_a_i.shape[0]].copy_(
lora_a_i.T, non_blocking=True)
if (lora_b_i := lora_b[i]) is not None:
self.lora_b_stacked[i][
index, 0, :lora_b_i.shape[1], :lora_b_i.shape[0]].copy_(
lora_b_i.T, non_blocking=True)
if lora_bias is not None:
self.lora_bias_stacked = cast(Tuple[torch.Tensor, ...],
self.lora_bias_stacked)
self.lora_bias_stacked[0][index, 0, :lora_bias[0].shape[0]].copy_(
lora_bias[0].T, non_blocking=True)
if lora_a[1] is not None:
self.lora_a_stacked[1][
index, 0, :lora_a[1].shape[1], :lora_a[1].shape[0]].copy_(
lora_a[1].T, non_blocking=True)
self.lora_b_stacked[1][
index, 0, :lora_b[1].shape[1], :lora_b[1].shape[0]].copy_(
lora_b[1].T, non_blocking=True)
if lora_bias is not None and lora_bias[1] is not None:
self.lora_bias_stacked = cast(Tuple[torch.Tensor, ...],
self.lora_bias_stacked)
self.lora_bias_stacked[1][index, 0, :lora_bias[1].shape[0]].copy_(
lora_bias[1].T, non_blocking=True)
for i in range(self.n_slices):
if (lora_bias_i := lora_bias[i]) is not None:
self.lora_bias_stacked[i][index,
0, :lora_bias_i.shape[0]].copy_(
lora_bias_i.T,
non_blocking=True)
@classmethod
@_not_fully_sharded_can_replace
@@ -755,8 +743,8 @@ class QKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
packed_modules_list) == 1
class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
"""ColumnParallelLinear layer that is composed of 3 sublayers (slices)
class MergedQKVParallelLinearWithLora(MergedColumnParallelLinearWithLoRA):
"""MergedColumnParallelLinear layer that is composed of 3 sublayers (slices)
packed together in qkv proj fashion
(q_proj + k_proj + v_proj -> qkv_proj).
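The per-slice shard ids added in the next hunk differ between q and k/v: the q slice is sharded by tp_rank, while k and v are sharded by tp_rank // num_kv_head_replicas, since GQA models replicate kv heads across ranks once tp_size exceeds the number of kv heads. A rough numeric illustration of how output_slices and output_ids come out; the head counts and the replica formula below are assumptions made for the example, not taken from this diff:

# Hypothetical GQA configuration, chosen only to make the numbers concrete.
head_size = 128
total_num_heads, total_num_kv_heads = 32, 4
tp_size = 8
for tp_rank in range(tp_size):
    num_heads = total_num_heads // tp_size                  # q heads on this rank
    num_kv_heads = max(1, total_num_kv_heads // tp_size)    # kv heads on this rank
    num_kv_head_replicas = max(1, tp_size // total_num_kv_heads)
    output_slices = (num_heads * head_size,
                     num_kv_heads * head_size,
                     num_kv_heads * head_size)
    output_ids = (tp_rank,
                  tp_rank // num_kv_head_replicas,
                  tp_rank // num_kv_head_replicas)
    print(tp_rank, output_slices, output_ids)
# Ranks 0/1 share kv shard 0, ranks 2/3 share kv shard 1, and so on,
# while every rank gets its own 512-wide q shard.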
@@ -773,6 +761,24 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
self.q_proj_shard_size = (self.base_layer.num_heads *
self.base_layer.head_size)
self.kv_proj_shard_size = (self.base_layer.num_kv_heads *
self.base_layer.head_size)
self.q_shard_id = self.tp_rank
self.kv_shard_id = self.tp_rank // self.base_layer.num_kv_head_replicas
self.output_slices = (
self.q_proj_shard_size,
self.kv_proj_shard_size,
self.kv_proj_shard_size,
)
self.output_ids = (
self.q_shard_id,
self.kv_shard_id,
self.kv_shard_id,
)
def create_lora_weights(
self,
max_loras: int,
@@ -783,216 +789,7 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
The main reason for overloading this function is to handle inconsistent
weight dimensions in qkv lora.
"""
self.lora_config = lora_config
if not (len(self.base_layer.output_sizes) == self.n_slices == 3):
raise ValueError(
"LoRAColumnParallelLinear3Slice requires 3 slices.")
self.q_proj_shard_size = (self.base_layer.num_heads *
self.base_layer.head_size)
self.kv_proj_shard_size = (self.base_layer.num_kv_heads *
self.base_layer.head_size)
self.q_shard_id = self.tp_rank
self.kv_shard_id = self.tp_rank // self.base_layer.num_kv_head_replicas
lora_a_output_size_per_partition = (
lora_config.max_lora_rank if not lora_config.fully_sharded_loras
else divide(lora_config.max_lora_rank, self.tp_size))
# q, k, v
self.lora_a_stacked = (
torch.zeros(
max_loras,
1,
lora_a_output_size_per_partition,
self.input_size,
dtype=lora_config.lora_dtype,
device=self.device,
),
torch.zeros(
max_loras,
1,
lora_a_output_size_per_partition,
self.input_size,
dtype=lora_config.lora_dtype,
device=self.device,
),
torch.zeros(
max_loras,
1,
lora_a_output_size_per_partition,
self.input_size,
dtype=lora_config.lora_dtype,
device=self.device,
),
)
self.lora_b_stacked = (
torch.zeros(
max_loras,
1,
self.q_proj_shard_size,
lora_config.max_lora_rank,
dtype=lora_config.lora_dtype,
device=self.device,
),
torch.zeros(
max_loras,
1,
self.kv_proj_shard_size,
lora_config.max_lora_rank,
dtype=lora_config.lora_dtype,
device=self.device,
),
torch.zeros(
max_loras,
1,
self.kv_proj_shard_size,
lora_config.max_lora_rank,
dtype=lora_config.lora_dtype,
device=self.device,
),
)
if lora_config.bias_enabled:
self.lora_bias_stacked = (
torch.zeros(
max_loras,
1,
self.q_proj_shard_size,
dtype=lora_config.lora_dtype,
device=self.device,
),
torch.zeros(
max_loras,
1,
self.kv_proj_shard_size,
dtype=lora_config.lora_dtype,
device=self.device,
),
torch.zeros(
max_loras,
1,
self.kv_proj_shard_size,
dtype=lora_config.lora_dtype,
device=self.device,
),
)
self.output_slices = (
self.q_proj_shard_size,
self.kv_proj_shard_size,
self.kv_proj_shard_size,
)
self.packed_indices: Optional[torch.Tensor] = None
self.standard_indices: Optional[torch.Tensor] = None
# lazily initialized.
self.indices: torch.Tensor
self.indices_len: List[int]
def slice_lora_a(
self, lora_a: List[Union[torch.Tensor, None]]
) -> List[Union[torch.Tensor, None]]:
return lora_a
def slice_lora_b(
self, lora_b: List[Union[torch.Tensor, None]]
) -> List[Union[torch.Tensor, None]]:
lora_b_q, lora_b_k, lora_b_v = None, None, None
if lora_b[0] is not None:
lora_b_q = lora_b[0][:, self.q_proj_shard_size *
self.q_shard_id:self.q_proj_shard_size *
(self.q_shard_id + 1), ]
if lora_b[1] is not None:
lora_b_k = lora_b[1][:, self.kv_proj_shard_size *
self.kv_shard_id:self.kv_proj_shard_size *
(self.kv_shard_id + 1), ]
if lora_b[2] is not None:
lora_b_v = lora_b[2][:, self.kv_proj_shard_size *
self.kv_shard_id:self.kv_proj_shard_size *
(self.kv_shard_id + 1), ]
lora_b = [lora_b_q, lora_b_k, lora_b_v]
return lora_b
def slice_bias(
self, bias: List[Union[torch.Tensor,
None]]) -> List[Union[torch.Tensor, None]]:
bias_q, bias_k, bias_v = bias
if bias_q is not None:
bias_q = bias_q[self.q_proj_shard_size *
self.q_shard_id:self.q_proj_shard_size *
(self.q_shard_id + 1)]
if bias_k is not None:
bias_k = bias_k[self.kv_proj_shard_size *
self.kv_shard_id:self.kv_proj_shard_size *
(self.kv_shard_id + 1)]
if bias_v is not None:
bias_v = bias_v[self.kv_proj_shard_size *
self.kv_shard_id:self.kv_proj_shard_size *
(self.kv_shard_id + 1)]
bias = [bias_q, bias_k, bias_v]
return bias
def set_lora(
self,
index: int,
lora_a: torch.Tensor,
lora_b: torch.Tensor,
embeddings_tensor: Optional[torch.Tensor],
lora_bias: Optional[torch.Tensor] = None,
):
self.reset_lora(index)
if self.tp_size > 1:
lora_a = self.slice_lora_a(lora_a)
lora_b = self.slice_lora_b(lora_b)
if lora_bias is not None:
lora_bias = self.slice_bias(lora_bias)
if lora_b[0] is not None:
lora_b_q = lora_b[0]
self.lora_b_stacked[0][
index, 0, :lora_b_q.shape[1], :lora_b_q.shape[0]].copy_(
lora_b_q.T, non_blocking=True)
if lora_b[1] is not None:
lora_b_k = lora_b[1]
self.lora_b_stacked[1][
index, 0, :lora_b_k.shape[1], :lora_b_k.shape[0]].copy_(
lora_b_k.T, non_blocking=True)
if lora_b[2] is not None:
lora_b_v = lora_b[2]
self.lora_b_stacked[2][
index, 0, :lora_b_v.shape[1], :lora_b_v.shape[0]].copy_(
lora_b_v.T, non_blocking=True)
if lora_a[0] is not None:
self.lora_a_stacked[0][
index, 0, :lora_a[0].shape[1], :lora_a[0].shape[0]].copy_(
lora_a[0].T, non_blocking=True)
if lora_a[1] is not None:
self.lora_a_stacked[1][
index, 0, :lora_a[1].shape[1], :lora_a[1].shape[0]].copy_(
lora_a[1].T, non_blocking=True)
if lora_a[2] is not None:
self.lora_a_stacked[2][
index, 0, :lora_a[2].shape[1], :lora_a[2].shape[0]].copy_(
lora_a[2].T, non_blocking=True)
if lora_bias is not None:
self.lora_bias_stacked = cast(Tuple[torch.Tensor, ...],
self.lora_bias_stacked)
if lora_bias[0] is not None:
self.lora_bias_stacked[0][index,
0, :lora_bias[0].shape[0]].copy_(
lora_bias[0].T,
non_blocking=True)
if lora_bias[1] is not None:
self.lora_bias_stacked[1][index,
0, :lora_bias[1].shape[0]].copy_(
lora_bias[1].T,
non_blocking=True)
if lora_bias[2] is not None:
self.lora_bias_stacked[2][index,
0, :lora_bias[2].shape[0]].copy_(
lora_bias[2].T,
non_blocking=True)
super().create_lora_weights(max_loras, lora_config, model_config)
@classmethod
@_not_fully_sharded_can_replace