[Model] use autoWeightsLoader for gptoss (#22446)

Signed-off-by: calvin chen <wen.chen@dynamia.ai>
Calvin Chen
2025-08-20 18:16:27 +08:00
committed by GitHub
parent d983769c41
commit 103f1ec8d3


@@ -27,7 +27,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from vllm.utils import cdiv
from .utils import extract_layer_index, maybe_prefix
from .utils import (AutoWeightsLoader, WeightsMapper, extract_layer_index,
maybe_prefix)
class OAIAttention(nn.Module):
@@ -203,6 +204,7 @@ class GptOssModel(nn.Module):
super().__init__()
self.config = vllm_config.model_config.hf_config
self.quant_config = vllm_config.quant_config
self.parallel_config = vllm_config.parallel_config
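# Needed below so the per-model weight loaders can check
# parallel_config.enable_expert_parallel when sharding experts.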
self.config.hidden_size = self.config.hidden_size
self.embedding = VocabParallelEmbedding(
self.config.vocab_size,
@@ -225,8 +227,364 @@ class GptOssModel(nn.Module):
x = self.norm(x)
return x
def _load_weights_mxfp4(
self,
ep_rank_end: int,
ep_rank_start: int,
heads_per_rank: int,
head_start: int,
weights: Iterable[tuple[str, torch.Tensor]],
stacked_params_mapping: list[tuple[str, ...]],
) -> set[str]:
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
mxfp4_block = 32
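# MXFP4 groups values into blocks of 32 elements that share one scale,
# so all TP slicing below is kept aligned to 32-element blocks.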
use_ep = self.parallel_config.enable_expert_parallel
num_experts = self.config.num_local_experts
tp_rank = get_tensor_model_parallel_rank()
tp_size = get_tensor_model_parallel_world_size()
intermediate_size = self.config.intermediate_size
intermediate_size_block = intermediate_size // mxfp4_block
per_rank_intermediate_size_block = cdiv(intermediate_size_block,
tp_size)
per_rank_intermediate_size = (per_rank_intermediate_size_block *
mxfp4_block)
# Calculate common slicing bounds for current rank
tp_rank_start = tp_rank * per_rank_intermediate_size
tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size,
intermediate_size)
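# Worked example (hypothetical sizes): intermediate_size = 2880 and
# tp_size = 2 give intermediate_size_block = 90, per-rank blocks =
# cdiv(90, 2) = 45, per_rank_intermediate_size = 45 * 32 = 1440,
# so rank 0 slices [0, 1440) and rank 1 slices [1440, 2880).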
for name, weight in weights:
# FIXME(woosuk): Remove this after testing.
weight = weight.cuda()
if ".w13_weight_scale" in name:
# Handle MLP gate and up projection weight scales
if use_ep:
narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
else:
narrow_weight = weight[:,
2 * tp_rank_start:2 * tp_rank_end,
...]
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param,
narrow_weight,
weight_name=name,
shard_id=None,
expert_id=None)
loaded_params.add(name)
continue
elif ".w2_weight_scale" in name:
# Handle MLP down projection weight scales
if use_ep:
narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
else:
narrow_weight = weight[..., tp_rank_start //
mxfp4_block:tp_rank_end //
mxfp4_block]
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param,
narrow_weight,
weight_name=name,
shard_id=None,
expert_id=None)
loaded_params.add(name)
continue
elif ".w13_weight" in name:
# Handle MLP gate and up projection weights
# Flatten the weight from (E, 2 * N, block_size, entry_per_block)
# to (E, 2 * N, -1); shouldn't trigger a copy for contiguous tensors
weight = weight.view(num_experts, 2 * intermediate_size,
-1).contiguous()
# Extract gate and up projection parts
# since the weight is shuffled, we can slice directly
if use_ep:
narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
else:
narrow_weight = weight[:,
2 * tp_rank_start:2 * tp_rank_end,
...]
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param,
narrow_weight,
weight_name=name,
shard_id=None,
expert_id=None)
loaded_params.add(name)
continue
elif ".w2_weight" in name:
# Handle MLP down projection weights
# Same flattening here, but since 2 mxfp4 values are packed into
# 1 uint8, divide the last dim by 2
weight = weight.view(num_experts, -1,
intermediate_size // 2).contiguous()
if use_ep:
narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
else:
narrow_weight = weight[...,
tp_rank_start // 2:tp_rank_end // 2]
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param,
narrow_weight,
weight_name=name,
shard_id=None,
expert_id=None)
loaded_params.add(name)
continue
elif ".w13_bias" in name:
# Handle MLP gate and up projection biases
# Extract gate and up projection bias parts
if use_ep:
narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
else:
narrow_weight = weight[:,
2 * tp_rank_start:2 * tp_rank_end]
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param,
narrow_weight,
weight_name=name,
shard_id=None,
expert_id=None)
loaded_params.add(name)
continue
elif ".w2_bias" in name:
# Handle MLP down projection bias
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
if use_ep:
weight = weight[ep_rank_start:ep_rank_end, ...]
else:
# (only load on rank 0 to avoid duplication)
if tp_rank != 0:
weight.zero_()
weight_loader(param,
weight,
weight_name=name,
shard_id=None,
expert_id=None)
loaded_params.add(name)
continue
elif "sinks" in name:
# Handle attention sinks (distributed across ranks)
param = params_dict[name]
narrow_weight = weight.narrow(0, head_start, heads_per_rank)
param.data.copy_(narrow_weight)
loaded_params.add(name)
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
if weight_loader == default_weight_loader:
weight_loader(param, weight)
else:
weight_loader(param, weight, shard_id)
break
else:
# Handle all other weights with potential renaming
if name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, weight)
loaded_params.add(name)
return loaded_params
def _load_weights_other(
self,
ep_rank_start: int,
ep_rank_end: int,
heads_per_rank: int,
head_start: int,
weights: Iterable[tuple[str, torch.Tensor]],
stacked_params_mapping: list[tuple[str, ...]],
) -> set[str]:
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
use_ep = self.parallel_config.enable_expert_parallel
tp_rank = get_tensor_model_parallel_rank()
tp_size = get_tensor_model_parallel_world_size()
intermediate_size = self.config.intermediate_size
per_rank_intermediate_size = cdiv(intermediate_size, tp_size)
# Calculate common slicing bounds for current rank
tp_rank_start = tp_rank * per_rank_intermediate_size
tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size,
intermediate_size)
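# Unlike the MXFP4 path, unquantized weights need no 32-element block
# alignment, so the per-rank size is simply cdiv(intermediate_size,
# tp_size).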
for name, weight in weights:
if ".w13_weight" in name:
# Handle MLP gate and up projection weights
# Extract gate and up projection parts
if use_ep:
narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
else:
narrow_weight = weight[:, :,
2 * tp_rank_start:2 * tp_rank_end]
narrow_weight = narrow_weight.permute(0, 2, 1).contiguous()
param = params_dict[name]
param.copy_(narrow_weight)
loaded_params.add(name)
continue
elif ".w2_weight" in name:
# Handle MLP down projection weights
if use_ep:
narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
else:
narrow_weight = weight[:, tp_rank_start:tp_rank_end, :]
narrow_weight = narrow_weight.permute(0, 2, 1).contiguous()
param = params_dict[name]
param.copy_(narrow_weight)
loaded_params.add(name)
continue
elif ".w13_bias" in name:
# Handle MLP gate and up projection biases
# Extract gate and up projection bias parts
if use_ep:
narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
else:
narrow_weight = weight[:,
2 * tp_rank_start:2 * tp_rank_end]
param = params_dict[name]
param.copy_(narrow_weight)
loaded_params.add(name)
continue
elif ".w2_bias" in name:
# Handle MLP down projection bias
if use_ep:
weight = weight[ep_rank_start:ep_rank_end, ...]
else:
# (only load on rank 0 to avoid duplication)
if tp_rank != 0:
weight.zero_()
param = params_dict[name]
param.copy_(weight)
loaded_params.add(name)
continue
elif "sinks" in name:
# Handle attention sinks (distributed across ranks)
param = params_dict[name]
narrow_weight = weight.narrow(0, head_start, heads_per_rank)
param.data.copy_(narrow_weight)
loaded_params.add(name)
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
if weight_loader == default_weight_loader:
weight_loader(param, weight)
else:
weight_loader(param, weight, shard_id)
break
else:
# Handle all other weights with potential renaming
if name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, weight)
loaded_params.add(name)
return loaded_params
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".qkv", ".q_proj", "q"),
(".qkv", ".k_proj", "k"),
(".qkv", ".v_proj", "v"),
]
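# e.g. a checkpoint name containing ".q_proj" is loaded into the fused
# ".qkv" parameter as shard "q" by the stacked-params loop of the
# loaders above.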
tp_rank = get_tensor_model_parallel_rank()
tp_size = get_tensor_model_parallel_world_size()
# Attention heads per rank
heads_per_rank = self.config.num_attention_heads // tp_size
head_start = tp_rank * heads_per_rank
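# e.g. (hypothetical) 64 attention heads with tp_size = 8 give
# heads_per_rank = 8, so rank 2 copies sinks[16:24] in the loaders.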
ep_size = get_ep_group().world_size
ep_rank = get_ep_group().rank
num_experts = self.config.num_local_experts
experts_per_rank = num_experts // ep_size
ep_rank_start = ep_rank * experts_per_rank
ep_rank_end = (ep_rank + 1) * experts_per_rank
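# e.g. (hypothetical) 32 local experts with ep_size = 4 give
# experts_per_rank = 8, so EP rank 1 loads experts [8, 16).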
quant_method = (self.config.quantization_config['quant_method'] if
hasattr(self.config, "quantization_config") else None)
if quant_method == "mxfp4":
return self._load_weights_mxfp4(ep_rank_end, ep_rank_start,
heads_per_rank, head_start,
weights, stacked_params_mapping)
else:
return self._load_weights_other(ep_rank_end, ep_rank_start,
heads_per_rank, head_start,
weights, stacked_params_mapping)
class GptOssForCausalLM(nn.Module):
packed_modules_mapping = {"qkv": ["q_proj", "k_proj", "v_proj"]}
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_substr={
".self_attn.": ".attn.",
".post_attention_layernorm.": ".mlp.norm.",
},
orig_to_new_suffix={
".embed_tokens.weight": ".embedding.weight",
".input_layernorm.weight": ".attn.norm.weight",
".post_attention_layernorm.weight": ".mlp.norm.weight",
# MoE MXFP4 weights
".gate_up_proj_blocks": ".w13_weight",
".down_proj_blocks": ".w2_weight",
".gate_up_proj_scales": ".w13_weight_scale",
".down_proj_scales": ".w2_weight_scale",
# MoE other weights
".gate_up_proj": ".w13_weight",
".down_proj": ".w2_weight",
# MoE Bias
".gate_up_proj_bias": ".w13_bias",
".down_proj_bias": ".w2_bias",
},
)
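# Illustrative renames performed by hf_to_vllm_mapper (hypothetical
# checkpoint keys):
#   "model.layers.0.self_attn.q_proj.weight"
#       -> "model.layers.0.attn.q_proj.weight"            (substr rule)
#   "model.embed_tokens.weight" -> "model.embedding.weight"  (suffix rule)
#   "model.layers.0.mlp.experts.gate_up_proj_blocks"
#       -> "model.layers.0.mlp.experts.w13_weight"        (suffix rule)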
def __init__(
self,
@@ -235,16 +593,17 @@ class GptOssForCausalLM(nn.Module):
):
super().__init__()
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config.hf_config
self.config = vllm_config.model_config.hf_config
self.model = GptOssModel(
vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"),
)
self.lm_head = ParallelLMHead(
self.model_config.vocab_size,
self.model_config.hidden_size,
self.config.vocab_size,
self.config.hidden_size,
)
self.logits_processor = LogitsProcessor(self.model_config.vocab_size)
self.logits_processor = LogitsProcessor(self.config.vocab_size)
def forward(self,
input_ids: torch.Tensor,
@@ -261,354 +620,11 @@ class GptOssForCausalLM(nn.Module):
sampling_metadata)
return logits
def _load_weights_mxfp4(
self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
rename_mapping = {
"self_attn": "attn",
"input_layernorm.weight": "attn.norm.weight",
"post_attention_layernorm.weight": "mlp.norm.weight",
"embed_tokens": "embedding",
}
def maybe_rename(name: str) -> str:
for remap_name, new_name in rename_mapping.items():
if remap_name in name:
return name.replace(remap_name, new_name)
return name
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
mxfp4_block = 32
tp_rank = get_tensor_model_parallel_rank()
tp_size = get_tensor_model_parallel_world_size()
intermediate_size = self.model_config.intermediate_size
intermediate_size_block = intermediate_size // mxfp4_block
per_rank_intermediate_size_block = cdiv(intermediate_size_block,
tp_size)
per_rank_intermediate_size = (per_rank_intermediate_size_block *
mxfp4_block)
# Calculate common slicing bounds for current rank
tp_rank_start = tp_rank * per_rank_intermediate_size
tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size,
intermediate_size)
# Attention heads per rank
heads_per_rank = self.model_config.num_attention_heads // tp_size
head_start = tp_rank * heads_per_rank
use_ep = self.vllm_config.parallel_config.enable_expert_parallel
ep_size = get_ep_group().world_size
ep_rank = get_ep_group().rank
num_experts = self.model_config.num_local_experts
experts_per_rank = num_experts // ep_size
ep_rank_start = ep_rank * experts_per_rank
ep_rank_end = (ep_rank + 1) * experts_per_rank
for name, weight in weights:
# FIXME(woosuk): Remove this after testing.
weight = weight.cuda()
if "gate_up_proj_blocks" in name:
# Handle MLP gate and up projection weights
new_name = name.replace("gate_up_proj_blocks", "w13_weight")
# Flatten the weight from (E, 2 * N, block_size, entry_per_block)
# to (E, 2 * N, -1); shouldn't trigger a copy for contiguous tensors
weight = weight.view(num_experts, 2 * intermediate_size,
-1).contiguous()
# Extract gate and up projection parts
# since the weight is shuffled, we can slice directly
if use_ep:
narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
else:
narrow_weight = weight[:,
2 * tp_rank_start:2 * tp_rank_end,
...]
param = params_dict[new_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param,
narrow_weight,
weight_name=new_name,
shard_id=None,
expert_id=None)
loaded_params.add(new_name)
elif "down_proj_blocks" in name:
# Handle MLP down projection weights
new_name = name.replace("down_proj_blocks", "w2_weight")
# Same flattening here, but since 2 mxfp4 values are packed into
# 1 uint8, divide the last dim by 2
weight = weight.view(num_experts, -1,
intermediate_size // 2).contiguous()
if use_ep:
narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
else:
narrow_weight = weight[...,
tp_rank_start // 2:tp_rank_end // 2]
param = params_dict[new_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param,
narrow_weight,
weight_name=new_name,
shard_id=None,
expert_id=None)
loaded_params.add(new_name)
elif "gate_up_proj_scales" in name:
# Handle MLP gate and up projection weight scales
new_name = name.replace("gate_up_proj_scales",
"w13_weight_scale")
if use_ep:
narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
else:
narrow_weight = weight[:,
2 * tp_rank_start:2 * tp_rank_end,
...]
param = params_dict[new_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param,
narrow_weight,
weight_name=new_name,
shard_id=None,
expert_id=None)
loaded_params.add(new_name)
elif "down_proj_scales" in name:
# Handle MLP down projection weight scales
new_name = name.replace("down_proj_scales", "w2_weight_scale")
if use_ep:
narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
else:
narrow_weight = weight[..., tp_rank_start //
mxfp4_block:tp_rank_end //
mxfp4_block]
param = params_dict[new_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param,
narrow_weight,
weight_name=new_name,
shard_id=None,
expert_id=None)
loaded_params.add(new_name)
elif "gate_up_proj_bias" in name:
# Handle MLP gate and up projection biases
new_name = name.replace("gate_up_proj_bias", "w13_bias")
# Extract gate and up projection bias parts
if use_ep:
narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
else:
narrow_weight = weight[:,
2 * tp_rank_start:2 * tp_rank_end]
param = params_dict[new_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param,
narrow_weight,
weight_name=new_name,
shard_id=None,
expert_id=None)
loaded_params.add(new_name)
elif "down_proj_bias" in name:
# Handle MLP down projection bias
new_name = name.replace("down_proj_bias", "w2_bias")
param = params_dict[new_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
if use_ep:
weight = weight[ep_rank_start:ep_rank_end, ...]
else:
# (only load on rank 0 to avoid duplication)
if tp_rank != 0:
weight.zero_()
weight_loader(param,
weight,
weight_name=new_name,
shard_id=None,
expert_id=None)
loaded_params.add(new_name)
elif "sinks" in name:
# Handle attention sinks (distributed across ranks)
name = name.replace("self_attn", "attn")
param = params_dict[name]
narrow_weight = weight.narrow(0, head_start, heads_per_rank)
param.data.copy_(narrow_weight)
loaded_params.add(name)
elif "q_proj" in name or "k_proj" in name or "v_proj" in name:
shard_id = ("q" if "q_proj" in name else
"k" if "k_proj" in name else "v")
name = name.replace("self_attn", "attn")
param_name = name.replace(f"{shard_id}_proj", "qkv")
param = params_dict[param_name]
weight_loader = param.weight_loader
weight_loader(param, weight, loaded_shard_id=shard_id)
loaded_params.add(param_name)
else:
# Handle all other weights with potential renaming
renamed_name = maybe_rename(name)
if renamed_name not in params_dict:
continue
param = params_dict[renamed_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, weight)
loaded_params.add(renamed_name)
return loaded_params
def _load_weights_other(
self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
rename_mapping = {
"self_attn": "attn",
"input_layernorm.weight": "attn.norm.weight",
"post_attention_layernorm.weight": "mlp.norm.weight",
"embed_tokens": "embedding",
}
def maybe_rename(name: str) -> str:
for remap_name, new_name in rename_mapping.items():
if remap_name in name:
return name.replace(remap_name, new_name)
return name
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
tp_rank = get_tensor_model_parallel_rank()
tp_size = get_tensor_model_parallel_world_size()
intermediate_size = self.model_config.intermediate_size
per_rank_intermediate_size = cdiv(intermediate_size, tp_size)
# Calculate common slicing bounds for current rank
tp_rank_start = tp_rank * per_rank_intermediate_size
tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size,
intermediate_size)
# Attention heads per rank
heads_per_rank = self.model_config.num_attention_heads // tp_size
head_start = tp_rank * heads_per_rank
use_ep = self.vllm_config.parallel_config.enable_expert_parallel
ep_size = get_ep_group().world_size
ep_rank = get_ep_group().rank
num_experts = self.model_config.num_local_experts
experts_per_rank = num_experts // ep_size
ep_rank_start = ep_rank * experts_per_rank
ep_rank_end = (ep_rank + 1) * experts_per_rank
for name, weight in weights:
if ".experts.gate_up_proj" in name and "bias" not in name:
# Handle MLP gate and up projection weights
new_name = name.replace(".experts.gate_up_proj",
".experts.w13_weight")
# Extract gate and up projection parts
# since the weight is shuffled, we can slice directly
if use_ep:
narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
else:
narrow_weight = weight[:, :,
2 * tp_rank_start:2 * tp_rank_end]
narrow_weight = narrow_weight.permute(0, 2, 1).contiguous()
param = params_dict[new_name]
param.copy_(narrow_weight)
loaded_params.add(new_name)
elif ".experts.down_proj" in name and "bias" not in name:
# Handle MLP down projection weights
new_name = name.replace(".experts.down_proj",
".experts.w2_weight")
if use_ep:
narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
else:
narrow_weight = weight[:, tp_rank_start:tp_rank_end, :]
narrow_weight = narrow_weight.permute(0, 2, 1).contiguous()
param = params_dict[new_name]
param.copy_(narrow_weight)
loaded_params.add(new_name)
elif "gate_up_proj_bias" in name:
# Handle MLP gate and up projection biases
new_name = name.replace("gate_up_proj_bias", "w13_bias")
# Extract gate and up projection bias parts
if use_ep:
narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
else:
narrow_weight = weight[:,
2 * tp_rank_start:2 * tp_rank_end]
param = params_dict[new_name]
param.copy_(narrow_weight)
loaded_params.add(new_name)
elif "down_proj_bias" in name:
# Handle MLP down projection bias
new_name = name.replace("down_proj_bias", "w2_bias")
if use_ep:
weight = weight[ep_rank_start:ep_rank_end, ...]
else:
# (only load on rank 0 to avoid duplication)
if tp_rank != 0:
weight.zero_()
param = params_dict[new_name]
param.copy_(weight)
loaded_params.add(new_name)
elif "sinks" in name:
# Handle attention sinks (distributed across ranks)
name = name.replace("self_attn", "attn")
param = params_dict[name]
narrow_weight = weight.narrow(0, head_start, heads_per_rank)
param.data.copy_(narrow_weight)
loaded_params.add(name)
elif "q_proj" in name or "k_proj" in name or "v_proj" in name:
shard_id = ("q" if "q_proj" in name else
"k" if "k_proj" in name else "v")
name = name.replace("self_attn", "attn")
param_name = name.replace(f"{shard_id}_proj", "qkv")
param = params_dict[param_name]
weight_loader = param.weight_loader
weight_loader(param, weight, loaded_shard_id=shard_id)
loaded_params.add(param_name)
else:
# Handle all other weights with potential renaming
renamed_name = maybe_rename(name)
if renamed_name not in params_dict:
continue
param = params_dict[renamed_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, weight)
loaded_params.add(renamed_name)
return loaded_params
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
quant_method = (self.model_config.quantization_config['quant_method']
if hasattr(self.model_config, "quantization_config")
else None)
if quant_method == "mxfp4":
return self._load_weights_mxfp4(weights)
else:
return self._load_weights_other(weights)
loader = AutoWeightsLoader(
self,
skip_prefixes=(["lm_head."]
if self.config.tie_word_embeddings else None),
)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
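# AutoWeightsLoader, as used here, applies hf_to_vllm_mapper to each
# checkpoint name and dispatches the renamed weights to the matching
# submodule's load_weights (e.g. GptOssModel.load_weights above);
# skip_prefixes drops lm_head.* when word embeddings are tied.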