fix weight loading for GQA with TP (#2379)

This commit is contained in:
Chenhui Zhang
2024-01-16 07:43:59 +08:00
committed by GitHub
parent bfc072addf
commit f780504d12

View File

@@ -423,7 +423,10 @@ class QKVParallelLinear(ColumnParallelLinear):
shard_offset = shard_offset // param.pack_factor
param_data = param_data.narrow(output_dim, shard_offset,
shard_size)
shard_id = tp_rank // self.num_kv_head_replicas
if loaded_shard_id == "q":
shard_id = tp_rank
else:
shard_id = tp_rank // self.num_kv_head_replicas
start_idx = shard_id * shard_size
loaded_weight = loaded_weight.narrow(output_dim, start_idx,
shard_size)