[Model] Support HF format of minimax (#20211)

Signed-off-by: mgoin <mgoin64@gmail.com>
Author: Michael Goin
Date: 2025-07-11 11:55:21 +09:00 (committed by GitHub)
parent 5923ab9524
commit 922f316441
3 changed files with 36 additions and 11 deletions


@@ -218,6 +218,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
                                             trust_remote_code=True),
     "MiniCPM3ForCausalLM": _HfExamplesInfo("openbmb/MiniCPM3-4B",
                                            trust_remote_code=True),
+    "MiniMaxForCausalLM": _HfExamplesInfo("MiniMaxAI/MiniMax-Text-01-hf",
+                                          min_transformers_version="4.53"),
     "MiniMaxText01ForCausalLM": _HfExamplesInfo("MiniMaxAI/MiniMax-Text-01",
                                                 trust_remote_code=True,
                                                 revision="a59aa9cbc53b9fb8742ca4e9e1531b9802b6fdc3"),  # noqa: E501
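Note: the new example entry points at the HF-format checkpoint, which is covered natively by transformers >= 4.53 and therefore needs no trust_remote_code, unlike the original repository entry kept below it. A minimal usage sketch under that assumption (the model is far too large to actually run this way without the usual distributed setup):

from vllm import LLM

# HF-format checkpoint: resolved via the "MiniMaxForCausalLM" architecture
# added to the registry in this commit; no trust_remote_code required
# (transformers >= 4.53 assumed).
llm = LLM(model="MiniMaxAI/MiniMax-Text-01-hf")

# Original checkpoint: still loads through the remote-code path.
# llm = LLM(model="MiniMaxAI/MiniMax-Text-01", trust_remote_code=True)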


@@ -667,16 +667,24 @@ class MiniMaxText01DecoderLayer(nn.Module):
             eps=config.rms_norm_eps)
         if config.attention_type == 0:
             self.layernorm_attention_alpha = getattr(
-                config, 'layernorm_linear_attention_alpha', 1)
+                config, 'layernorm_linear_attention_alpha',
+                getattr(config, 'linear_attn_alpha_factor', 1))
             self.layernorm_attention_beta = getattr(
-                config, 'layernorm_linear_attention_beta', 1)
+                config, 'layernorm_linear_attention_beta',
+                getattr(config, 'linear_attn_beta_factor', 1))
         else:
             self.layernorm_attention_alpha = getattr(
-                config, 'layernorm_full_attention_alpha', 1)
+                config, 'layernorm_full_attention_alpha',
+                getattr(config, 'full_attn_alpha_factor', 1))
             self.layernorm_attention_beta = getattr(
-                config, 'layernorm_full_attention_beta', 1)
-        self.layernorm_mlp_alpha = getattr(config, 'layernorm_mlp_alpha', 1)
-        self.layernorm_mlp_beta = getattr(config, 'layernorm_mlp_beta', 1)
+                config, 'layernorm_full_attention_beta',
+                getattr(config, 'full_attn_beta_factor', 1))
+        self.layernorm_mlp_alpha = getattr(
+            config, 'layernorm_mlp_alpha',
+            getattr(config, 'mlp_alpha_factor', 1))
+        self.layernorm_mlp_beta = getattr(
+            config, 'layernorm_mlp_beta', getattr(config, 'mlp_beta_factor',
+                                                  1))

         self.postnorm = getattr(config, 'postnorm', False)
         self.shared_moe = False
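Note: every factor above follows the same resolution order, legacy MiniMax config key first, then the HF-format key, then a default of 1. A compact sketch of that pattern (the helper name is illustrative only, not part of this diff):

def _resolve_factor(config, legacy_key: str, hf_key: str, default=1):
    # Prefer the legacy config attribute, fall back to the HF-format
    # attribute, and finally to the default of 1.
    return getattr(config, legacy_key, getattr(config, hf_key, default))

# e.g. for the linear-attention branch:
# alpha = _resolve_factor(config, 'layernorm_linear_attention_alpha',
#                         'linear_attn_alpha_factor')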
@@ -794,6 +802,18 @@ class MiniMaxText01Model(nn.Module):
         self.decoder_attention_types = getattr(
             config, "attn_type_list", False) or getattr(
                 config, "decoder_attention_types", False)
+        # The HF format uses "layer_types" instead of "attn_type_list"
+        # where "linear_attention" is 0 and "full_attention" is 1
+        if not self.decoder_attention_types and hasattr(config, "layer_types"):
+            self.decoder_attention_types = []
+            for layer_type in config.layer_types:
+                if layer_type == "linear_attention":
+                    self.decoder_attention_types.append(0)
+                elif layer_type == "full_attention":
+                    self.decoder_attention_types.append(1)
+                else:
+                    raise ValueError(f"Unsupported layer type: {layer_type}")
+        # Default to full attention
         if not self.decoder_attention_types:
             self.decoder_attention_types = [1] * config.num_hidden_layers
         self.num_layers = config.num_hidden_layers
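Note: the conversion above is effectively a two-entry lookup over the HF config's layer_types strings, with anything else rejected via ValueError. A toy equivalent (the layer pattern below is invented for illustration, not the real MiniMax config):

_LAYER_TYPE_TO_ATTN = {"linear_attention": 0, "full_attention": 1}

layer_types = ["linear_attention", "linear_attention", "full_attention"]  # toy config
decoder_attention_types = [_LAYER_TYPE_TO_ATTN[t] for t in layer_types]
assert decoder_attention_types == [0, 0, 1]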
@@ -1022,8 +1042,9 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid,
         else:
             self.lm_head = PPMissingLayer()
         self.lm_head.float()
-        flash_layer_count = sum(1 for attn_type in self.config.attn_type_list
-                                if attn_type == 1)
+        flash_layer_count = sum(
+            1 for attn_type in self.model.decoder_attention_types
+            if attn_type == 1)
         self.kv_cache = [torch.tensor([]) for _ in range(flash_layer_count)]
         return

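Note: with decoder_attention_types normalized on the model, KV-cache sizing no longer depends on the legacy attn_type_list attribute being present on the config. A toy illustration of the count (values invented):

decoder_attention_types = [0, 0, 0, 1, 0, 0, 0, 1]  # toy values: 0 = linear, 1 = full
flash_layer_count = sum(1 for t in decoder_attention_types if t == 1)
assert flash_layer_count == 2  # one KV-cache slot per full-attention layer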
@@ -1085,9 +1106,10 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid,
             return None

         def is_linear_attn_layer(layer_idx: int) -> bool:
-            if layer_idx is None or not hasattr(self.config, "attn_type_list"):
+            if layer_idx is None or layer_idx >= len(
+                    self.model.decoder_attention_types):
                 return False
-            return self.config.attn_type_list[layer_idx] == 0
+            return self.model.decoder_attention_types[layer_idx] == 0

         def is_moe_weight(name: str) -> bool:
             return "block_sparse_moe" in name and not name.endswith(".bias")
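Note: the predicate now bounds-checks against the normalized list instead of probing for attn_type_list, which HF-format configs do not carry. A standalone sketch of the same logic, detached from the class for clarity:

def is_linear_attn_layer(layer_idx, decoder_attention_types):
    # Missing or out-of-range indices are treated as "not linear attention".
    if layer_idx is None or layer_idx >= len(decoder_attention_types):
        return False
    return decoder_attention_types[layer_idx] == 0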
@@ -1275,7 +1297,7 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid,
         for name, loaded_weight in weights:
             weight_at_layer = which_layer(name)
             if weight_at_layer and weight_at_layer >= len(
-                    self.config.attn_type_list):
+                    self.model.decoder_attention_types):
                 continue

             if is_layer_norm_weight(name):


@@ -34,6 +34,7 @@ _TEXT_GENERATION_MODELS = {
     "AquilaModel": ("llama", "LlamaForCausalLM"),
     "AquilaForCausalLM": ("llama", "LlamaForCausalLM"),  # AquilaChat2
     "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
+    "MiniMaxForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
     "MiniMaxText01ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
     "MiniMaxM1ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
     # baichuan-7b, upper case 'C' in the class name
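Note: all three architecture strings, as they can appear in a checkpoint's config.json "architectures" field, alias the same implementation module and class. A toy lookup mirroring the dict entries above (this is not vLLM's actual registry API):

_MINIMAX_ARCHS = {
    "MiniMaxForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxText01ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxM1ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
}
module_name, class_name = _MINIMAX_ARCHS["MiniMaxForCausalLM"]
assert (module_name, class_name) == ("minimax_text_01", "MiniMaxText01ForCausalLM")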