mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
fix weight loading for GQA with TP (#2379)
This commit is contained in:
@ -423,7 +423,10 @@ class QKVParallelLinear(ColumnParallelLinear):
|
||||
shard_offset = shard_offset // param.pack_factor
|
||||
param_data = param_data.narrow(output_dim, shard_offset,
|
||||
shard_size)
|
||||
shard_id = tp_rank // self.num_kv_head_replicas
|
||||
if loaded_shard_id == "q":
|
||||
shard_id = tp_rank
|
||||
else:
|
||||
shard_id = tp_rank // self.num_kv_head_replicas
|
||||
start_idx = shard_id * shard_size
|
||||
loaded_weight = loaded_weight.narrow(output_dim, start_idx,
|
||||
shard_size)
|
||||
|
Reference in New Issue
Block a user