mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
This commit is contained in:
@ -250,7 +250,7 @@ class GPTJForCausalLM(nn.Module):
|
||||
if att_weight_name not in name:
|
||||
continue
|
||||
param = state_dict[name.replace(att_weight_name, "qkv_proj")]
|
||||
shard_size = param.shape[1]
|
||||
shard_size = param.shape[0] // 3
|
||||
loaded_weight = loaded_weight[shard_size * tp_rank:shard_size *
|
||||
(tp_rank + 1)]
|
||||
param_slice = param.data[shard_size * stride_id:shard_size *
|
||||
|
Reference in New Issue
Block a user