mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[Model] use autoWeightsLoader for gptoss (#22446)
Signed-off-by: calvin chen <wen.chen@dynamia.ai>
This commit is contained in:
@ -27,7 +27,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils import cdiv
|
||||
|
||||
from .utils import extract_layer_index, maybe_prefix
|
||||
from .utils import (AutoWeightsLoader, WeightsMapper, extract_layer_index,
|
||||
maybe_prefix)
|
||||
|
||||
|
||||
class OAIAttention(nn.Module):
|
||||
@ -203,6 +204,7 @@ class GptOssModel(nn.Module):
|
||||
super().__init__()
|
||||
self.config = vllm_config.model_config.hf_config
|
||||
self.quant_config = vllm_config.quant_config
|
||||
self.parallel_config = vllm_config.parallel_config
|
||||
self.config.hidden_size = self.config.hidden_size
|
||||
self.embedding = VocabParallelEmbedding(
|
||||
self.config.vocab_size,
|
||||
@ -225,8 +227,364 @@ class GptOssModel(nn.Module):
|
||||
x = self.norm(x)
|
||||
return x
|
||||
|
||||
def _load_weights_mxfp4(
|
||||
self,
|
||||
ep_rank_end: int,
|
||||
ep_rank_start: int,
|
||||
heads_per_rank: int,
|
||||
head_start: int,
|
||||
weights: Iterable[tuple[str, torch.Tensor]],
|
||||
stacked_params_mapping: list[tuple[str, ...]],
|
||||
) -> set[str]:
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: set[str] = set()
|
||||
|
||||
mxfp4_block = 32
|
||||
use_ep = self.parallel_config.enable_expert_parallel
|
||||
num_experts = self.config.num_local_experts
|
||||
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
|
||||
intermediate_size = self.config.intermediate_size
|
||||
intermediate_size_block = intermediate_size // mxfp4_block
|
||||
per_rank_intermediate_size_block = cdiv(intermediate_size_block,
|
||||
tp_size)
|
||||
per_rank_intermediate_size = (per_rank_intermediate_size_block *
|
||||
mxfp4_block)
|
||||
|
||||
# Calculate common slicing bounds for current rank
|
||||
tp_rank_start = tp_rank * per_rank_intermediate_size
|
||||
tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size,
|
||||
intermediate_size)
|
||||
|
||||
for name, weight in weights:
|
||||
# FIXME(woosuk): Remove this after testing.
|
||||
weight = weight.cuda()
|
||||
|
||||
if ".w13_weight_scale" in name:
|
||||
# Handle MLP gate and up projection weights scale
|
||||
if use_ep:
|
||||
narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
|
||||
else:
|
||||
narrow_weight = weight[:,
|
||||
2 * tp_rank_start:2 * tp_rank_end,
|
||||
...]
|
||||
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param,
|
||||
narrow_weight,
|
||||
weight_name=name,
|
||||
shard_id=None,
|
||||
expert_id=None)
|
||||
loaded_params.add(name)
|
||||
continue
|
||||
elif ".w2_weight_scale" in name:
|
||||
# Handle MLP down projection weights
|
||||
if use_ep:
|
||||
narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
|
||||
else:
|
||||
narrow_weight = weight[..., tp_rank_start //
|
||||
mxfp4_block:tp_rank_end //
|
||||
mxfp4_block]
|
||||
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param,
|
||||
narrow_weight,
|
||||
weight_name=name,
|
||||
shard_id=None,
|
||||
expert_id=None)
|
||||
loaded_params.add(name)
|
||||
continue
|
||||
elif ".w13_weight" in name:
|
||||
# Handle MLP gate and up projection weights
|
||||
# flat weight from (E, 2 * N, block_size, entry_per_block)
|
||||
# to (E, 2 * N, -1), shouldn't trigger copy for contiguous
|
||||
weight = weight.view(num_experts, 2 * intermediate_size,
|
||||
-1).contiguous()
|
||||
|
||||
# Extract gate and up projection parts
|
||||
# since the weight is shuffled, we can slice directly
|
||||
if use_ep:
|
||||
narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
|
||||
else:
|
||||
narrow_weight = weight[:,
|
||||
2 * tp_rank_start:2 * tp_rank_end,
|
||||
...]
|
||||
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param,
|
||||
narrow_weight,
|
||||
weight_name=name,
|
||||
shard_id=None,
|
||||
expert_id=None)
|
||||
loaded_params.add(name)
|
||||
continue
|
||||
elif ".w2_weight" in name:
|
||||
# Handle MLP down projection weights
|
||||
# same flatten here, but since 2 mx4 value are packed in 1
|
||||
# uint8, divide by 2
|
||||
weight = weight.view(num_experts, -1,
|
||||
intermediate_size // 2).contiguous()
|
||||
if use_ep:
|
||||
narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
|
||||
else:
|
||||
narrow_weight = weight[...,
|
||||
tp_rank_start // 2:tp_rank_end // 2]
|
||||
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param,
|
||||
narrow_weight,
|
||||
weight_name=name,
|
||||
shard_id=None,
|
||||
expert_id=None)
|
||||
loaded_params.add(name)
|
||||
continue
|
||||
elif ".w13_bias" in name:
|
||||
# Handle MLP gate and up projection biases
|
||||
# Extract gate and up projection bias parts
|
||||
if use_ep:
|
||||
narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
|
||||
else:
|
||||
narrow_weight = weight[:,
|
||||
2 * tp_rank_start:2 * tp_rank_end]
|
||||
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param,
|
||||
narrow_weight,
|
||||
weight_name=name,
|
||||
shard_id=None,
|
||||
expert_id=None)
|
||||
loaded_params.add(name)
|
||||
continue
|
||||
elif ".w2_bias" in name:
|
||||
# Handle MLP down projection bias
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
if use_ep:
|
||||
weight = weight[ep_rank_start:ep_rank_end, ...]
|
||||
else:
|
||||
# (only load on rank 0 to avoid duplication)
|
||||
if tp_rank != 0:
|
||||
weight.zero_()
|
||||
weight_loader(param,
|
||||
weight,
|
||||
weight_name=name,
|
||||
shard_id=None,
|
||||
expert_id=None)
|
||||
loaded_params.add(name)
|
||||
continue
|
||||
elif "sinks" in name:
|
||||
# Handle attention sinks (distributed across ranks)
|
||||
param = params_dict[name]
|
||||
narrow_weight = weight.narrow(0, head_start, heads_per_rank)
|
||||
param.data.copy_(narrow_weight)
|
||||
loaded_params.add(name)
|
||||
continue
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
if weight_loader == default_weight_loader:
|
||||
weight_loader(param, weight)
|
||||
else:
|
||||
weight_loader(param, weight, shard_id)
|
||||
break
|
||||
else:
|
||||
# Handle all other weights with potential renaming
|
||||
if name not in params_dict:
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
def _load_weights_other(
|
||||
self,
|
||||
ep_rank_start: int,
|
||||
ep_rank_end: int,
|
||||
heads_per_rank: int,
|
||||
head_start: int,
|
||||
weights: Iterable[tuple[str, torch.Tensor]],
|
||||
stacked_params_mapping: list[tuple[str, ...]],
|
||||
) -> set[str]:
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: set[str] = set()
|
||||
|
||||
use_ep = self.parallel_config.enable_expert_parallel
|
||||
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
|
||||
intermediate_size = self.config.intermediate_size
|
||||
per_rank_intermediate_size = cdiv(intermediate_size, tp_size)
|
||||
# Calculate common slicing bounds for current rank
|
||||
tp_rank_start = tp_rank * per_rank_intermediate_size
|
||||
tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size,
|
||||
intermediate_size)
|
||||
|
||||
for name, weight in weights:
|
||||
if ".w13_weight" in name:
|
||||
# Handle MLP gate and up projection weights
|
||||
# Extract gate and up projection parts
|
||||
if use_ep:
|
||||
narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
|
||||
else:
|
||||
narrow_weight = weight[:, :,
|
||||
2 * tp_rank_start:2 * tp_rank_end]
|
||||
|
||||
narrow_weight = narrow_weight.permute(0, 2, 1).contiguous()
|
||||
param = params_dict[name]
|
||||
|
||||
param.copy_(narrow_weight)
|
||||
loaded_params.add(name)
|
||||
continue
|
||||
elif ".w2_weight" in name:
|
||||
# Handle MLP down projection weights
|
||||
if use_ep:
|
||||
narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
|
||||
else:
|
||||
narrow_weight = weight[:, tp_rank_start:tp_rank_end, :]
|
||||
narrow_weight = narrow_weight.permute(0, 2, 1).contiguous()
|
||||
param = params_dict[name]
|
||||
|
||||
param.copy_(narrow_weight)
|
||||
loaded_params.add(name)
|
||||
continue
|
||||
elif ".w13_bias" in name:
|
||||
# Handle MLP gate and up projection biases
|
||||
# Extract gate and up projection bias parts
|
||||
if use_ep:
|
||||
narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
|
||||
else:
|
||||
narrow_weight = weight[:,
|
||||
2 * tp_rank_start:2 * tp_rank_end]
|
||||
|
||||
param = params_dict[name]
|
||||
param.copy_(narrow_weight)
|
||||
loaded_params.add(name)
|
||||
continue
|
||||
elif ".w2_bias" in name:
|
||||
# Handle MLP down projection bias
|
||||
if use_ep:
|
||||
weight = weight[ep_rank_start:ep_rank_end, ...]
|
||||
else:
|
||||
# (only load on rank 0 to avoid duplication)
|
||||
if tp_rank != 0:
|
||||
weight.zero_()
|
||||
param = params_dict[name]
|
||||
param.copy_(weight)
|
||||
loaded_params.add(name)
|
||||
continue
|
||||
elif "sinks" in name:
|
||||
# Handle attention sinks (distributed across ranks)
|
||||
param = params_dict[name]
|
||||
narrow_weight = weight.narrow(0, head_start, heads_per_rank)
|
||||
param.data.copy_(narrow_weight)
|
||||
loaded_params.add(name)
|
||||
continue
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
if weight_loader == default_weight_loader:
|
||||
weight_loader(param, weight)
|
||||
else:
|
||||
weight_loader(param, weight, shard_id)
|
||||
break
|
||||
else:
|
||||
# Handle all other weights with potential renaming
|
||||
if name not in params_dict:
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> set[str]:
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
(".qkv", ".q_proj", "q"),
|
||||
(".qkv", ".k_proj", "k"),
|
||||
(".qkv", ".v_proj", "v"),
|
||||
]
|
||||
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
|
||||
# Attention heads per rank
|
||||
heads_per_rank = self.config.num_attention_heads // tp_size
|
||||
head_start = tp_rank * heads_per_rank
|
||||
|
||||
ep_size = get_ep_group().world_size
|
||||
ep_rank = get_ep_group().rank
|
||||
num_experts = self.config.num_local_experts
|
||||
experts_per_rank = num_experts // ep_size
|
||||
ep_rank_start = ep_rank * experts_per_rank
|
||||
ep_rank_end = (ep_rank + 1) * experts_per_rank
|
||||
|
||||
quant_method = (self.config.quantization_config['quant_method'] if
|
||||
hasattr(self.config, "quantization_config") else None)
|
||||
if quant_method == "mxfp4":
|
||||
return self._load_weights_mxfp4(ep_rank_end, ep_rank_start,
|
||||
heads_per_rank, head_start,
|
||||
weights, stacked_params_mapping)
|
||||
else:
|
||||
return self._load_weights_other(ep_rank_end, ep_rank_start,
|
||||
heads_per_rank, head_start,
|
||||
weights, stacked_params_mapping)
|
||||
|
||||
|
||||
class GptOssForCausalLM(nn.Module):
|
||||
packed_modules_mapping = {"qkv": ["q_proj", "k_proj", "v_proj"]}
|
||||
|
||||
hf_to_vllm_mapper = WeightsMapper(
|
||||
orig_to_new_substr={
|
||||
".self_attn.": ".attn.",
|
||||
".post_attention_layernorm.": ".mlp.norm.",
|
||||
},
|
||||
orig_to_new_suffix={
|
||||
".embed_tokens.weight": ".embedding.weight",
|
||||
".input_layernorm.weight": ".attn.norm.weight",
|
||||
".post_attention_layernorm.weight": ".mlp.norm.weight",
|
||||
|
||||
# MoE MXFP4 weights
|
||||
".gate_up_proj_blocks": ".w13_weight",
|
||||
".down_proj_blocks": ".w2_weight",
|
||||
".gate_up_proj_scales": ".w13_weight_scale",
|
||||
".down_proj_scales": ".w2_weight_scale",
|
||||
|
||||
# MoE other weights
|
||||
".gate_up_proj": ".w13_weight",
|
||||
".down_proj": ".w2_weight",
|
||||
|
||||
# MoE Bias
|
||||
".gate_up_proj_bias": ".w13_bias",
|
||||
".down_proj_bias": ".w2_bias",
|
||||
},
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -235,16 +593,17 @@ class GptOssForCausalLM(nn.Module):
|
||||
):
|
||||
super().__init__()
|
||||
self.vllm_config = vllm_config
|
||||
self.model_config = vllm_config.model_config.hf_config
|
||||
self.config = vllm_config.model_config.hf_config
|
||||
|
||||
self.model = GptOssModel(
|
||||
vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "model"),
|
||||
)
|
||||
self.lm_head = ParallelLMHead(
|
||||
self.model_config.vocab_size,
|
||||
self.model_config.hidden_size,
|
||||
self.config.vocab_size,
|
||||
self.config.hidden_size,
|
||||
)
|
||||
self.logits_processor = LogitsProcessor(self.model_config.vocab_size)
|
||||
self.logits_processor = LogitsProcessor(self.config.vocab_size)
|
||||
|
||||
def forward(self,
|
||||
input_ids: torch.Tensor,
|
||||
@ -261,354 +620,11 @@ class GptOssForCausalLM(nn.Module):
|
||||
sampling_metadata)
|
||||
return logits
|
||||
|
||||
def _load_weights_mxfp4(
|
||||
self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
|
||||
rename_mapping = {
|
||||
"self_attn": "attn",
|
||||
"input_layernorm.weight": "attn.norm.weight",
|
||||
"post_attention_layernorm.weight": "mlp.norm.weight",
|
||||
"embed_tokens": "embedding",
|
||||
}
|
||||
|
||||
def maybe_rename(name: str) -> str:
|
||||
for remap_name, new_name in rename_mapping.items():
|
||||
if remap_name in name:
|
||||
return name.replace(remap_name, new_name)
|
||||
return name
|
||||
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: set[str] = set()
|
||||
mxfp4_block = 32
|
||||
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
intermediate_size = self.model_config.intermediate_size
|
||||
intermediate_size_block = intermediate_size // mxfp4_block
|
||||
per_rank_intermediate_size_block = cdiv(intermediate_size_block,
|
||||
tp_size)
|
||||
per_rank_intermediate_size = (per_rank_intermediate_size_block *
|
||||
mxfp4_block)
|
||||
|
||||
# Calculate common slicing bounds for current rank
|
||||
tp_rank_start = tp_rank * per_rank_intermediate_size
|
||||
tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size,
|
||||
intermediate_size)
|
||||
|
||||
# Attention heads per rank
|
||||
heads_per_rank = self.model_config.num_attention_heads // tp_size
|
||||
head_start = tp_rank * heads_per_rank
|
||||
|
||||
use_ep = self.vllm_config.parallel_config.enable_expert_parallel
|
||||
ep_size = get_ep_group().world_size
|
||||
ep_rank = get_ep_group().rank
|
||||
num_experts = self.model_config.num_local_experts
|
||||
experts_per_rank = num_experts // ep_size
|
||||
ep_rank_start = ep_rank * experts_per_rank
|
||||
ep_rank_end = (ep_rank + 1) * experts_per_rank
|
||||
|
||||
for name, weight in weights:
|
||||
# FIXME(woosuk): Remove this after testing.
|
||||
weight = weight.cuda()
|
||||
|
||||
if "gate_up_proj_blocks" in name:
|
||||
# Handle MLP gate and up projection weights
|
||||
new_name = name.replace("gate_up_proj_blocks", "w13_weight")
|
||||
|
||||
# flat weight from (E, 2 * N, block_size, entry_per_block)
|
||||
# to (E, 2 * N, -1), shouldn't trigger copy for contiguous
|
||||
weight = weight.view(num_experts, 2 * intermediate_size,
|
||||
-1).contiguous()
|
||||
|
||||
# Extract gate and up projection parts
|
||||
# since the weight is shuffled, we can slice directly
|
||||
if use_ep:
|
||||
narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
|
||||
else:
|
||||
narrow_weight = weight[:,
|
||||
2 * tp_rank_start:2 * tp_rank_end,
|
||||
...]
|
||||
|
||||
param = params_dict[new_name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param,
|
||||
narrow_weight,
|
||||
weight_name=new_name,
|
||||
shard_id=None,
|
||||
expert_id=None)
|
||||
loaded_params.add(new_name)
|
||||
|
||||
elif "down_proj_blocks" in name:
|
||||
# Handle MLP down projection weights
|
||||
new_name = name.replace("down_proj_blocks", "w2_weight")
|
||||
# same flatten here, but since 2 mx4 value are packed in 1
|
||||
# uint8, divide by 2
|
||||
weight = weight.view(num_experts, -1,
|
||||
intermediate_size // 2).contiguous()
|
||||
if use_ep:
|
||||
narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
|
||||
else:
|
||||
narrow_weight = weight[...,
|
||||
tp_rank_start // 2:tp_rank_end // 2]
|
||||
|
||||
param = params_dict[new_name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param,
|
||||
narrow_weight,
|
||||
weight_name=new_name,
|
||||
shard_id=None,
|
||||
expert_id=None)
|
||||
loaded_params.add(new_name)
|
||||
|
||||
elif "gate_up_proj_scales" in name:
|
||||
# Handle MLP gate and up projection weights scale
|
||||
new_name = name.replace("gate_up_proj_scales",
|
||||
"w13_weight_scale")
|
||||
if use_ep:
|
||||
narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
|
||||
else:
|
||||
narrow_weight = weight[:,
|
||||
2 * tp_rank_start:2 * tp_rank_end,
|
||||
...]
|
||||
|
||||
param = params_dict[new_name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param,
|
||||
narrow_weight,
|
||||
weight_name=new_name,
|
||||
shard_id=None,
|
||||
expert_id=None)
|
||||
loaded_params.add(new_name)
|
||||
|
||||
elif "down_proj_scales" in name:
|
||||
# Handle MLP down projection weights
|
||||
new_name = name.replace("down_proj_scales", "w2_weight_scale")
|
||||
if use_ep:
|
||||
narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
|
||||
else:
|
||||
narrow_weight = weight[..., tp_rank_start //
|
||||
mxfp4_block:tp_rank_end //
|
||||
mxfp4_block]
|
||||
|
||||
param = params_dict[new_name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param,
|
||||
narrow_weight,
|
||||
weight_name=new_name,
|
||||
shard_id=None,
|
||||
expert_id=None)
|
||||
loaded_params.add(new_name)
|
||||
elif "gate_up_proj_bias" in name:
|
||||
# Handle MLP gate and up projection biases
|
||||
new_name = name.replace("gate_up_proj_bias", "w13_bias")
|
||||
|
||||
# Extract gate and up projection bias parts
|
||||
if use_ep:
|
||||
narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
|
||||
else:
|
||||
narrow_weight = weight[:,
|
||||
2 * tp_rank_start:2 * tp_rank_end]
|
||||
|
||||
param = params_dict[new_name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param,
|
||||
narrow_weight,
|
||||
weight_name=new_name,
|
||||
shard_id=None,
|
||||
expert_id=None)
|
||||
loaded_params.add(new_name)
|
||||
|
||||
elif "down_proj_bias" in name:
|
||||
# Handle MLP down projection bias
|
||||
new_name = name.replace("down_proj_bias", "w2_bias")
|
||||
param = params_dict[new_name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
if use_ep:
|
||||
weight = weight[ep_rank_start:ep_rank_end, ...]
|
||||
else:
|
||||
# (only load on rank 0 to avoid duplication)
|
||||
if tp_rank != 0:
|
||||
weight.zero_()
|
||||
weight_loader(param,
|
||||
weight,
|
||||
weight_name=new_name,
|
||||
shard_id=None,
|
||||
expert_id=None)
|
||||
loaded_params.add(new_name)
|
||||
elif "sinks" in name:
|
||||
# Handle attention sinks (distributed across ranks)
|
||||
name = name.replace("self_attn", "attn")
|
||||
param = params_dict[name]
|
||||
narrow_weight = weight.narrow(0, head_start, heads_per_rank)
|
||||
param.data.copy_(narrow_weight)
|
||||
loaded_params.add(name)
|
||||
elif "q_proj" in name or "k_proj" in name or "v_proj" in name:
|
||||
shard_id = ("q" if "q_proj" in name else
|
||||
"k" if "k_proj" in name else "v")
|
||||
name = name.replace("self_attn", "attn")
|
||||
param_name = name.replace(f"{shard_id}_proj", "qkv")
|
||||
param = params_dict[param_name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, weight, loaded_shard_id=shard_id)
|
||||
loaded_params.add(param_name)
|
||||
else:
|
||||
# Handle all other weights with potential renaming
|
||||
renamed_name = maybe_rename(name)
|
||||
if renamed_name not in params_dict:
|
||||
continue
|
||||
param = params_dict[renamed_name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, weight)
|
||||
loaded_params.add(renamed_name)
|
||||
|
||||
return loaded_params
|
||||
|
||||
def _load_weights_other(
|
||||
self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
|
||||
rename_mapping = {
|
||||
"self_attn": "attn",
|
||||
"input_layernorm.weight": "attn.norm.weight",
|
||||
"post_attention_layernorm.weight": "mlp.norm.weight",
|
||||
"embed_tokens": "embedding",
|
||||
}
|
||||
|
||||
def maybe_rename(name: str) -> str:
|
||||
for remap_name, new_name in rename_mapping.items():
|
||||
if remap_name in name:
|
||||
return name.replace(remap_name, new_name)
|
||||
return name
|
||||
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: set[str] = set()
|
||||
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
intermediate_size = self.model_config.intermediate_size
|
||||
|
||||
per_rank_intermediate_size = cdiv(intermediate_size, tp_size)
|
||||
# Calculate common slicing bounds for current rank
|
||||
tp_rank_start = tp_rank * per_rank_intermediate_size
|
||||
tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size,
|
||||
intermediate_size)
|
||||
|
||||
# Attention heads per rank
|
||||
heads_per_rank = self.model_config.num_attention_heads // tp_size
|
||||
head_start = tp_rank * heads_per_rank
|
||||
|
||||
use_ep = self.vllm_config.parallel_config.enable_expert_parallel
|
||||
ep_size = get_ep_group().world_size
|
||||
ep_rank = get_ep_group().rank
|
||||
num_experts = self.model_config.num_local_experts
|
||||
experts_per_rank = num_experts // ep_size
|
||||
ep_rank_start = ep_rank * experts_per_rank
|
||||
ep_rank_end = (ep_rank + 1) * experts_per_rank
|
||||
|
||||
for name, weight in weights:
|
||||
if ".experts.gate_up_proj" in name and "bias" not in name:
|
||||
# Handle MLP gate and up projection weights
|
||||
new_name = name.replace(".experts.gate_up_proj",
|
||||
".experts.w13_weight")
|
||||
|
||||
# Extract gate and up projection parts
|
||||
# since the weight is shuffled, we can slice directly
|
||||
if use_ep:
|
||||
narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
|
||||
else:
|
||||
narrow_weight = weight[:, :,
|
||||
2 * tp_rank_start:2 * tp_rank_end]
|
||||
|
||||
narrow_weight = narrow_weight.permute(0, 2, 1).contiguous()
|
||||
param = params_dict[new_name]
|
||||
|
||||
param.copy_(narrow_weight)
|
||||
loaded_params.add(new_name)
|
||||
|
||||
elif ".experts.down_proj" in name and "bias" not in name:
|
||||
# Handle MLP down projection weights
|
||||
new_name = name.replace(".experts.down_proj",
|
||||
".experts.w2_weight")
|
||||
|
||||
if use_ep:
|
||||
narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
|
||||
else:
|
||||
narrow_weight = weight[:, tp_rank_start:tp_rank_end, :]
|
||||
narrow_weight = narrow_weight.permute(0, 2, 1).contiguous()
|
||||
param = params_dict[new_name]
|
||||
|
||||
param.copy_(narrow_weight)
|
||||
loaded_params.add(new_name)
|
||||
|
||||
elif "gate_up_proj_bias" in name:
|
||||
# Handle MLP gate and up projection biases
|
||||
new_name = name.replace("gate_up_proj_bias", "w13_bias")
|
||||
|
||||
# Extract gate and up projection bias parts
|
||||
if use_ep:
|
||||
narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
|
||||
else:
|
||||
narrow_weight = weight[:,
|
||||
2 * tp_rank_start:2 * tp_rank_end]
|
||||
|
||||
param = params_dict[new_name]
|
||||
|
||||
param.copy_(narrow_weight)
|
||||
loaded_params.add(new_name)
|
||||
|
||||
elif "down_proj_bias" in name:
|
||||
# Handle MLP down projection bias
|
||||
new_name = name.replace("down_proj_bias", "w2_bias")
|
||||
|
||||
if use_ep:
|
||||
weight = weight[ep_rank_start:ep_rank_end, ...]
|
||||
else:
|
||||
# (only load on rank 0 to avoid duplication)
|
||||
if tp_rank != 0:
|
||||
weight.zero_()
|
||||
param = params_dict[new_name]
|
||||
param.copy_(weight)
|
||||
loaded_params.add(new_name)
|
||||
elif "sinks" in name:
|
||||
# Handle attention sinks (distributed across ranks)
|
||||
name = name.replace("self_attn", "attn")
|
||||
param = params_dict[name]
|
||||
narrow_weight = weight.narrow(0, head_start, heads_per_rank)
|
||||
param.data.copy_(narrow_weight)
|
||||
loaded_params.add(name)
|
||||
elif "q_proj" in name or "k_proj" in name or "v_proj" in name:
|
||||
shard_id = ("q" if "q_proj" in name else
|
||||
"k" if "k_proj" in name else "v")
|
||||
name = name.replace("self_attn", "attn")
|
||||
param_name = name.replace(f"{shard_id}_proj", "qkv")
|
||||
param = params_dict[param_name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, weight, loaded_shard_id=shard_id)
|
||||
loaded_params.add(param_name)
|
||||
else:
|
||||
# Handle all other weights with potential renaming
|
||||
|
||||
renamed_name = maybe_rename(name)
|
||||
if renamed_name not in params_dict:
|
||||
continue
|
||||
param = params_dict[renamed_name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, weight)
|
||||
loaded_params.add(renamed_name)
|
||||
|
||||
return loaded_params
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> set[str]:
|
||||
quant_method = (self.model_config.quantization_config['quant_method']
|
||||
if hasattr(self.model_config, "quantization_config")
|
||||
else None)
|
||||
if quant_method == "mxfp4":
|
||||
return self._load_weights_mxfp4(weights)
|
||||
else:
|
||||
return self._load_weights_other(weights)
|
||||
loader = AutoWeightsLoader(
|
||||
self,
|
||||
skip_prefixes=(["lm_head."]
|
||||
if self.config.tie_word_embeddings else None),
|
||||
)
|
||||
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
|
||||
|
Reference in New Issue
Block a user