This commit is contained in:
Arthur
2025-10-17 10:20:46 +02:00
parent a08b927826
commit bfb804756d
12 changed files with 176 additions and 194 deletions

View File

@ -5,7 +5,7 @@
#
# Either we keep it here, or we move it to the config, but for newcomers, seeing this here may be confusing.
from .core_model_loading import Concatenate, MergeModuleList, WeightConversion, Fp8Quantize, Shard
from .core_model_loading import Concatenate, Fp8Quantize, MergeModuleList, Shard, WeightConversion
_checkpoint_conversion_mapping = {

View File

@ -14,14 +14,13 @@
# limitations under the License.
"""Core helpers for loading model checkpoints."""
from collections import defaultdict
from __future__ import annotations
import re
import math
import re
import time
from abc import abstractmethod
from collections import OrderedDict
from collections import defaultdict
from collections.abc import Sequence
from contextlib import nullcontext
from dataclasses import dataclass
@ -51,14 +50,13 @@ except AttributeError:
)
try:
from torch.profiler import ProfilerActivity, profile as torch_profile
from torch.profiler import ProfilerActivity
from torch.profiler import profile as torch_profile
except (ImportError, AttributeError):
ProfilerActivity = None
torch_profile = None
class ConversionOps:
"""Base class for weight conversion operations.
@ -308,6 +306,7 @@ class SplitModuleList(ConversionOps):
result.append(list(splits))
return result
class Cast(ConversionOps):
"""
Casts the tensor to a given dtype
@ -316,6 +315,7 @@ class Cast(ConversionOps):
def __init__(self, dtype):
self.dtype = dtype
class To(ConversionOps):
"""
Transfers the tensor to the provided device, potentially using a stream.
@ -327,9 +327,11 @@ class To(ConversionOps):
if is_fsdp_enabled():
param_device = "cpu" if is_local_dist_rank_0() else "meta"
"""
def __init__(self, device):
self.device = device
class Shard(ConversionOps):
"""Shard tensors along a specific dimension.
@ -555,16 +557,18 @@ def convert_state_dict(model, state_dict, weight_mapping, tp_plan, quantization_
# 1. we need to find which key we have (so we keep track of which pattern was matched)
converted_state_dict: dict[str, torch.Tensor] = {}
used_operations: list[ConversionOps] = []
keys_to_convert = [ rf"{ '|'.join(k.source_keys) if isinstance(k.source_keys, list) else k.source_keys}" for k in weight_mapping ]
keys_to_convert = [
rf"{'|'.join(k.source_keys) if isinstance(k.source_keys, list) else k.source_keys}" for k in weight_mapping
]
# tensor parallel is also a conversion scheme! So add it to the keys to convert!
# quantization as well! But for quantization we would need to get the module and check whether it is a linear layer.
for k,v in state_dict.items():
if re.sub(rf"^({ '|'.join(keys_to_convert) })$", "", k) == k:
for k, v in state_dict.items():
if re.sub(rf"^({'|'.join(keys_to_convert)})$", "", k) == k:
converted_state_dict[k] = v
else:
# we replace the whole key by the matched pattern so that we can find it later
pattern = re.sub(rf"^({ '|'.join(keys_to_convert) })$", r"\1", k)
pattern = re.sub(rf"^({'|'.join(keys_to_convert)})$", r"\1", k)
collected_keys[pattern][k] += [v] # we collect all tensors that match the pattern
converter = weight_mapping[pattern]
if pattern in tp_plan:  # If we want this to work, the conversion needs to be explicit, no?
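
Because the match test above is easy to misread at a glance, here is a minimal standalone sketch (with made-up checkpoint key names) of the trick used in `convert_state_dict`: substituting the full-match pattern with an empty string leaves the key unchanged only when no conversion pattern applies.

```python
import re

# Hypothetical source-key regexes, mirroring how keys_to_convert is built above:
# each WeightConversion contributes its source_keys, OR-ed together when given as a list.
keys_to_convert = [r"model\.layers\.\d+\.mlp\.experts\.\d+\.w\d\.weight", r"lm_head\.weight"]
pattern = rf"^({'|'.join(keys_to_convert)})$"

for key in ["model.layers.0.mlp.experts.3.w1.weight", "model.embed_tokens.weight"]:
    if re.sub(pattern, "", key) == key:
        print(f"{key}: no pattern matched, copied through unchanged")
    else:
        print(f"{key}: collected for conversion")
```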

View File

@ -512,10 +512,8 @@ def accelerate_disk_offload(
checkpoint_files,
device_map,
checkpoint_keys,
key_renaming_mapping,
sharded_metadata,
dtype,
reverse_key_renaming_mapping,
):
disk_only_shard_files = []
if disk_offload_folder is not None:
@ -534,19 +532,13 @@ def accelerate_disk_offload(
weight_map = dict.fromkeys(checkpoint_keys, checkpoint_files[0])
else:
folder = os.path.sep.join(checkpoint_files[0].split(os.path.sep)[:-1])
# Fix the weight map keys according to the key mapping
weight_map = {
key_renaming_mapping[k]: v
for k, v in sharded_metadata["weight_map"].items()
if k in key_renaming_mapping
}
weight_map = {k: os.path.join(folder, v) for k, v in weight_map.items()}
# Find potential checkpoints containing only offloaded weights
disk_only_shard_files = get_disk_only_shard_files(device_map, weight_map)
disk_offload_index = {
name: {
"safetensors_file": file,
"weight_name": reverse_key_renaming_mapping[name],
"weight_name": name,
"dtype": str_dtype,
}
for name, file in weight_map.items()
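
For reference, a minimal sketch (illustrative file path and dtype only) of the `disk_offload_index` built above after this change: since checkpoint keys are already renamed before this point, `weight_name` is simply the final parameter name instead of being looked up through `reverse_key_renaming_mapping`.

```python
# Illustrative entry only; the path and dtype are made up.
disk_offload_index = {
    "model.layers.0.mlp.down_proj.weight": {
        "safetensors_file": "/path/to/checkpoint/model-00002-of-00002.safetensors",
        "weight_name": "model.layers.0.mlp.down_proj.weight",
        "dtype": "bfloat16",
    },
}
```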

View File

@ -26,13 +26,13 @@ import sys
import warnings
from abc import abstractmethod
from collections import defaultdict
from collections.abc import Callable
from collections.abc import Callable, Sequence
from concurrent.futures import ThreadPoolExecutor, as_completed
from contextlib import contextmanager
from enum import Enum
from functools import partial, wraps
from threading import Thread
from typing import Any, Optional, Sequence, TypeVar, Union, get_type_hints
from typing import Any, Optional, TypeVar, Union, get_type_hints
from zipfile import is_zipfile
import torch
@ -45,6 +45,8 @@ from torch.distributions import constraints
from torch.utils.checkpoint import checkpoint
from .configuration_utils import PreTrainedConfig
from .conversion_mapping import _checkpoint_conversion_mapping as DEFAULT_WEIGHT_CONVERSION_MAPPING
from .core_model_loading import WeightConversion, convert_state_dict
from .distributed import DistributedConfig
from .dynamic_module_utils import custom_object_save
from .generation import CompileConfig, GenerationConfig
@ -59,7 +61,6 @@ from .integrations.accelerate import (
init_empty_weights,
)
from .integrations.deepspeed import _load_state_dict_into_zero3_model
from .core_model_loading import QuantizationOp, Shard, WeightConversion, convert_state_dict
from .integrations.eager_paged import eager_paged_attention_forward
from .integrations.flash_attention import flash_attention_forward
from .integrations.flash_paged import paged_attention_forward
@ -125,7 +126,6 @@ from .utils.import_utils import (
is_torchdynamo_compiling,
)
from .utils.quantization_config import QuantizationMethod
from .conversion_mapping import _checkpoint_conversion_mapping as DEFAULT_WEIGHT_CONVERSION_MAPPING
if is_accelerate_available():
@ -734,8 +734,6 @@ def load_shard_file(args):
state_dict = {key_renaming_mapping[k]: v for k, v in state_dict.items() if k in key_renaming_mapping}
def load_shard_files_with_threadpool(args_list):
num_workers = int(os.environ.get("HF_PARALLEL_LOADING_WORKERS", "8"))
@ -4407,7 +4405,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
if model_type is not None:
weight_conversions = DEFAULT_WEIGHT_CONVERSION_MAPPING.get(model_type)
if gguf_file:
if hf_quantizer is not None:
raise ValueError(
@ -4749,7 +4746,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
# correctly initialize the missing (and potentially mismatched) keys
model._initialize_missing_keys(missing_keys + mismatched_keys, is_quantized)
is_offloaded_safetensors = False
# This offload index if for params explicitly on the "disk" in the device_map
disk_offload_index = None
@ -4761,10 +4757,9 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
checkpoint_files,
device_map,
checkpoint_keys,
key_renaming_mapping,
new_state_dict.keys(),
sharded_metadata,
dtype,
reverse_key_renaming_mapping,
)
# To be able to iterate, even if we don't use it if the state_dict is already provided
elif state_dict is not None:
@ -4799,8 +4794,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
device_mesh=device_mesh,
)
# Save offloaded index if needed
if disk_offload_index is not None and len(disk_offload_index) > 0 and not is_offloaded_safetensors:
save_offload_index(disk_offload_index, disk_offload_folder)

View File

@ -243,38 +243,45 @@ class HunYuanMoEV1Gate(nn.Module):
return logits
class HunYuanMoEV1Experts(nn.ModuleList):
"""
ModuleList of experts.
"""
class HunYuanMoEV1Experts(nn.Module):
"""Collection of expert weights stored as 3D tensors."""
def __init__(self, config: HunYuanMoEV1Config):
super().__init__()
self.top_k = config.num_experts_per_tok
self.num_experts = config.num_local_experts
for _ in range(self.num_experts):
self.append(HunYuanMoEV1MLP(config))
self.hidden_dim = config.hidden_size
self.intermediate_dim = config.intermediate_size
self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim))
self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim))
self.act_fn = ACT2FN[config.hidden_act]
def forward(
self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor
self,
hidden_states: torch.Tensor,
top_k_index: torch.Tensor,
top_k_weights: torch.Tensor,
) -> torch.Tensor:
"""
Args:
hidden_states: (batch_size * sequence_length, hidden_dim)
top_k_index: (batch_size * sequence_length, top_k)
top_k_weights: (batch_size * sequence_length, top_k)
Returns:
(batch_size * sequence_length, hidden_dim)
"""
final_hidden_states = torch.zeros_like(hidden_states)
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0)
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
for expert_idx in expert_hit:
idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
current_state = hidden_states[None, top_x].reshape(-1, hidden_states.shape[-1])
current_hidden_states = self[expert_idx](current_state) * top_k_weights[top_x, idx, None]
final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0)
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero(as_tuple=False).flatten()
for expert_idx in expert_hit.tolist():
expert_selection = expert_mask[expert_idx].squeeze(0)
top_indices, token_positions = torch.where(expert_selection)
if token_positions.numel() == 0:
continue
current_state = hidden_states.index_select(0, token_positions)
gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1)
current_hidden_states = self.act_fn(up)
current_hidden_states = current_hidden_states * gate
current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])
routing_weights = top_k_weights[token_positions, top_indices].unsqueeze(-1)
current_hidden_states = current_hidden_states * routing_weights.to(current_hidden_states.dtype)
final_hidden_states.index_add_(0, token_positions, current_hidden_states.to(final_hidden_states.dtype))
return final_hidden_states
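
All of the refactored `*Experts` modules in this commit follow the same fused layout, so one toy sketch (made-up sizes, random weights) may help: per expert, `gate_up_proj` stacks the two input projections along the output dimension, so a single `linear` followed by a feature-dim `chunk(2, dim=-1)` recovers the two halves that the old per-expert MLPs computed with separate layers. The gate/up naming below simply mirrors the code above, where the activated half is the second chunk.

```python
import torch
from torch import nn

num_experts, hidden_dim, intermediate_dim, tokens = 4, 8, 16, 5
gate_up_proj = torch.randn(num_experts, 2 * intermediate_dim, hidden_dim)
down_proj = torch.randn(num_experts, hidden_dim, intermediate_dim)
act_fn = nn.SiLU()  # stands in for ACT2FN[config.hidden_act]

x = torch.randn(tokens, hidden_dim)  # tokens routed to expert 0
gate, up = nn.functional.linear(x, gate_up_proj[0]).chunk(2, dim=-1)
out = nn.functional.linear(act_fn(up) * gate, down_proj[0])
print(out.shape)  # torch.Size([5, 8]), back to hidden_dim
```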

View File

@ -557,38 +557,45 @@ class JambaMLP(nn.Module):
return down_proj
class JambaExperts(nn.ModuleList):
"""
ModuleList of experts.
"""
class JambaExperts(nn.Module):
"""Collection of expert weights stored as 3D tensors."""
def __init__(self, config: JambaConfig):
super().__init__()
self.top_k = config.num_experts_per_tok
self.num_experts = config.num_local_experts
for _ in range(self.num_experts):
self.append(JambaMLP(config))
self.hidden_dim = config.hidden_size
self.intermediate_dim = config.intermediate_size
self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim))
self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim))
self.act_fn = ACT2FN[config.hidden_act]
def forward(
self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor
self,
hidden_states: torch.Tensor,
top_k_index: torch.Tensor,
top_k_weights: torch.Tensor,
) -> torch.Tensor:
"""
Args:
hidden_states: (batch_size * sequence_length, hidden_dim)
top_k_index: (batch_size * sequence_length, top_k)
top_k_weights: (batch_size * sequence_length, top_k)
Returns:
(batch_size * sequence_length, hidden_dim)
"""
final_hidden_states = torch.zeros_like(hidden_states)
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0)
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
for expert_idx in expert_hit:
idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
current_state = hidden_states[None, top_x].reshape(-1, hidden_states.shape[-1])
current_hidden_states = self[expert_idx](current_state) * top_k_weights[top_x, idx, None]
final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0)
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero(as_tuple=False).flatten()
for expert_idx in expert_hit.tolist():
expert_selection = expert_mask[expert_idx].squeeze(0)
top_indices, token_positions = torch.where(expert_selection)
if token_positions.numel() == 0:
continue
current_state = hidden_states.index_select(0, token_positions)
gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1)
current_hidden_states = self.act_fn(up)
current_hidden_states = current_hidden_states * gate
current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])
routing_weights = top_k_weights[token_positions, top_indices].unsqueeze(-1)
current_hidden_states = current_hidden_states * routing_weights.to(current_hidden_states.dtype)
final_hidden_states.index_add_(0, token_positions, current_hidden_states.to(final_hidden_states.dtype))
return final_hidden_states

View File

@ -132,9 +132,9 @@ class MiniMaxConfig(PreTrainedConfig):
"layers.*.self_attn.v_proj": "colwise",
"layers.*.self_attn.o_proj": "rowwise",
"layers.*.block_sparse_moe.gate": "colwise_rep", # we need to replicate here to correctly route experts
"layers.*.block_sparse_moe.experts.*.w1": "colwise",
"layers.*.block_sparse_moe.experts.*.w2": "rowwise",
"layers.*.block_sparse_moe.experts.*.w3": "colwise",
"layers.*.block_sparse_moe.experts.w1": "colwise",
"layers.*.block_sparse_moe.experts.w2": "rowwise",
"layers.*.block_sparse_moe.experts.w3": "colwise",
}
base_model_pp_plan = {
"embed_tokens": (["input_ids"], ["inputs_embeds"]),

View File

@ -387,56 +387,45 @@ class MiniMaxAttention(nn.Module):
return attn_output, attn_weights
class MiniMaxMLP(nn.Module):
class MiniMaxExperts(nn.Module):
"""Collection of expert weights stored as 3D tensors."""
def __init__(self, config: MiniMaxConfig):
super().__init__()
self.ffn_dim = config.intermediate_size
self.num_experts = config.num_local_experts
self.hidden_dim = config.hidden_size
self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)
self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
self.intermediate_dim = config.intermediate_size
self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim))
self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim))
self.act_fn = ACT2FN[config.hidden_act]
def forward(self, hidden_states):
current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
current_hidden_states = self.w2(current_hidden_states)
return current_hidden_states
class MiniMaxExperts(nn.ModuleList):
"""
ModuleList of experts.
"""
def __init__(self, config: MiniMaxConfig):
super().__init__()
self.top_k = config.num_experts_per_tok
self.num_experts = config.num_local_experts
for _ in range(self.num_experts):
self.append(MiniMaxMLP(config))
def forward(
self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor
self,
hidden_states: torch.Tensor,
top_k_index: torch.Tensor,
top_k_weights: torch.Tensor,
) -> torch.Tensor:
"""
Args:
hidden_states: (batch_size * sequence_length, hidden_dim)
top_k_index: (batch_size * sequence_length, top_k)
top_k_weights: (batch_size * sequence_length, top_k)
Returns:
(batch_size * sequence_length, hidden_dim)
"""
final_hidden_states = torch.zeros_like(hidden_states)
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0)
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
for expert_idx in expert_hit:
idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
current_state = hidden_states[None, top_x].reshape(-1, hidden_states.shape[-1])
current_hidden_states = self[expert_idx](current_state) * top_k_weights[top_x, idx, None]
final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0)
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero(as_tuple=False).flatten()
for expert_idx in expert_hit.tolist():
expert_selection = expert_mask[expert_idx].squeeze(0)
top_indices, token_positions = torch.where(expert_selection)
if token_positions.numel() == 0:
continue
current_state = hidden_states.index_select(0, token_positions)
gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1)
current_hidden_states = self.act_fn(up)
current_hidden_states = current_hidden_states * gate
current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])
routing_weights = top_k_weights[token_positions, top_indices].unsqueeze(-1)
current_hidden_states = current_hidden_states * routing_weights.to(current_hidden_states.dtype)
final_hidden_states.index_add_(0, token_positions, current_hidden_states.to(final_hidden_states.dtype))
return final_hidden_states

View File

@ -61,18 +61,10 @@ class MixtralExperts(nn.Module):
self.num_experts = config.num_local_experts
self.hidden_dim = config.hidden_size
self.intermediate_dim = config.intermediate_size
self.w1 = nn.Parameter(torch.empty(self.num_experts, self.intermediate_dim, self.hidden_dim))
self.w2 = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim))
self.w3 = nn.Parameter(torch.empty(self.num_experts, self.intermediate_dim, self.hidden_dim))
self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim))
self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim))
self.act_fn = ACT2FN[config.hidden_act]
def reset_parameters(self, initializer_range: float):
nn.init.normal_(self.w1, mean=0.0, std=initializer_range)
nn.init.normal_(self.w2, mean=0.0, std=initializer_range)
nn.init.normal_(self.w3, mean=0.0, std=initializer_range)
def forward(
self,
hidden_states: torch.Tensor,
@ -91,11 +83,10 @@ class MixtralExperts(nn.Module):
continue
current_state = hidden_states.index_select(0, token_positions)
current_hidden_states = nn.functional.linear(current_state, self.w1[expert_idx])
current_hidden_states = self.act_fn(current_hidden_states)
gate_hidden_states = nn.functional.linear(current_state, self.w3[expert_idx])
current_hidden_states = current_hidden_states * gate_hidden_states
current_hidden_states = nn.functional.linear(current_hidden_states, self.w2[expert_idx])
gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1)
current_hidden_states = self.act_fn(up)
current_hidden_states = current_hidden_states * gate
current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])
routing_weights = top_k_weights[token_positions, top_indices].unsqueeze(-1)
current_hidden_states = current_hidden_states * routing_weights.to(current_hidden_states.dtype)
@ -378,11 +369,6 @@ class MixtralPreTrainedModel(PreTrainedModel):
"hidden_states": MixtralDecoderLayer,
"attentions": MixtralAttention,
}
def _init_weights(self, module):
super()._init_weights(module)
if isinstance(module, MixtralExperts):
initializer_range = getattr(self.config, "initializer_range", 0.02)
module.reset_parameters(initializer_range)
@auto_docstring

View File

@ -259,12 +259,6 @@ class MixtralPreTrainedModel(MistralPreTrainedModel):
"attentions": MixtralAttention,
}
def _init_weights(self, module):
super()._init_weights(module)
if isinstance(module, MixtralExperts):
initializer_range = getattr(self.config, "initializer_range", 0.02)
module.reset_parameters(initializer_range)
class MixtralModel(MistralModel):
def forward(

View File

@ -265,13 +265,11 @@ class OlmoeAttention(nn.Module):
return attn_output, attn_weights
class OlmoeExperts(nn.ModuleList):
"""
ModuleList of experts.
"""
class OlmoeExperts(nn.Module):
"""Collection of expert weights stored as 3D tensors."""
def __init__(self, config):
super().__init__()
nn.ModuleList.__init__(self)
for _ in range(config.num_experts):
self.append(OlmoeMLP(config))
self.num_experts = config.num_experts
@ -279,25 +277,32 @@ class OlmoeExperts(nn.ModuleList):
self.norm_topk_prob = config.norm_topk_prob
def forward(
self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor
self,
hidden_states: torch.Tensor,
top_k_index: torch.Tensor,
top_k_weights: torch.Tensor,
) -> torch.Tensor:
"""
Args:
hidden_states: (batch_size * sequence_length, hidden_dim)
top_k_index: (batch_size * sequence_length, top_k)
top_k_weights: (batch_size * sequence_length, top_k)
Returns:
(batch_size * sequence_length, hidden_dim)
"""
final_hidden_states = torch.zeros_like(hidden_states)
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0)
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
for expert_idx in expert_hit:
idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
current_state = hidden_states[None, top_x].reshape(-1, hidden_states.shape[-1])
current_hidden_states = self[expert_idx](current_state) * top_k_weights[top_x, idx, None]
final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0)
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero(as_tuple=False).flatten()
for expert_idx in expert_hit.tolist():
expert_selection = expert_mask[expert_idx].squeeze(0)
top_indices, token_positions = torch.where(expert_selection)
if token_positions.numel() == 0:
continue
current_state = hidden_states.index_select(0, token_positions)
gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1)
current_hidden_states = self.act_fn(up)
current_hidden_states = current_hidden_states * gate
current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])
routing_weights = top_k_weights[token_positions, top_indices].unsqueeze(-1)
current_hidden_states = current_hidden_states * routing_weights.to(current_hidden_states.dtype)
final_hidden_states.index_add_(0, token_positions, current_hidden_states.to(final_hidden_states.dtype))
return final_hidden_states

View File

@ -260,37 +260,42 @@ class Qwen2MoeAttention(nn.Module):
return attn_output, attn_weights
class Qwen2MoeExperts(nn.ModuleList):
"""
ModuleList of experts.
"""
class Qwen2MoeExperts(nn.Module):
"""Collection of expert weights stored as 3D tensors."""
def __init__(self, config):
super().__init__()
nn.ModuleList.__init__(self)
self.num_experts = config.num_experts
for _ in range(config.num_experts):
self.append(Qwen2MoeMLP(config, intermediate_size=config.moe_intermediate_size))
def forward(
self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor
self,
hidden_states: torch.Tensor,
top_k_index: torch.Tensor,
top_k_weights: torch.Tensor,
) -> torch.Tensor:
"""
Args:
hidden_states: (batch_size * sequence_length, hidden_dim)
top_k_index: (batch_size * sequence_length, top_k)
top_k_weights: (batch_size * sequence_length, top_k)
Returns:
(batch_size * sequence_length, hidden_dim)
"""
final_hidden_states = torch.zeros_like(hidden_states)
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0)
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
for expert_idx in expert_hit:
idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
current_state = hidden_states[None, top_x].reshape(-1, hidden_states.shape[-1])
current_hidden_states = self[expert_idx](current_state) * top_k_weights[top_x, idx, None]
final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0)
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero(as_tuple=False).flatten()
for expert_idx in expert_hit.tolist():
expert_selection = expert_mask[expert_idx].squeeze(0)
top_indices, token_positions = torch.where(expert_selection)
if token_positions.numel() == 0:
continue
current_state = hidden_states.index_select(0, token_positions)
gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1)
current_hidden_states = self.act_fn(up)
current_hidden_states = current_hidden_states * gate
current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])
routing_weights = top_k_weights[token_positions, top_indices].unsqueeze(-1)
current_hidden_states = current_hidden_states * routing_weights.to(current_hidden_states.dtype)
final_hidden_states.index_add_(0, token_positions, current_hidden_states.to(final_hidden_states.dtype))
return final_hidden_states
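
Finally, a small sketch (toy routing tensors) of the masking bookkeeping shared by all the `Experts.forward` implementations above: `one_hot` plus `permute` yields, for each expert, which (top-k slot, token) pairs selected it, `expert_hit` lists only the experts that received at least one token, and `torch.where` recovers the token positions to gather for a given expert.

```python
import torch

num_experts = 4
top_k_index = torch.tensor([[0, 2], [2, 3], [0, 3]])  # (tokens=3, top_k=2)

expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts).permute(2, 1, 0)
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero(as_tuple=False).flatten()
print(expert_hit.tolist())  # [0, 2, 3]: expert 1 received no tokens and is skipped

top_indices, token_positions = torch.where(expert_mask[0])
print(token_positions.tolist())  # [0, 2]: tokens 0 and 2 are gathered for expert 0
print(top_indices.tolist())      # [0, 0]: both picked expert 0 in their first top-k slot
```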