update
@@ -5,7 +5,7 @@
#
# Either we keep it here, or we move it to the config, but for newcomers, seeing this is kinda weird no?

from .core_model_loading import Concatenate, MergeModuleList, WeightConversion, Fp8Quantize, Shard
from .core_model_loading import Concatenate, Fp8Quantize, MergeModuleList, Shard, WeightConversion


_checkpoint_conversion_mapping = {
@@ -14,14 +14,13 @@
# limitations under the License.
"""Core helpers for loading model checkpoints."""

from collections import defaultdict
from __future__ import annotations
import re

import math
import re
import time
from abc import abstractmethod
from collections import OrderedDict
from collections import defaultdict
from collections.abc import Sequence
from contextlib import nullcontext
from dataclasses import dataclass
@@ -51,14 +50,13 @@ except AttributeError:
    )

try:
    from torch.profiler import ProfilerActivity, profile as torch_profile
    from torch.profiler import ProfilerActivity
    from torch.profiler import profile as torch_profile
except (ImportError, AttributeError):
    ProfilerActivity = None
    torch_profile = None

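A minimal usage sketch for the guarded import above, assuming callers check for None before profiling; the profiled workload and the profiler arguments are illustrative, not part of this commit.

if torch_profile is not None:
    with torch_profile(activities=[ProfilerActivity.CPU]) as prof:
        load_checkpoint()  # hypothetical workload to profile
    print(prof.key_averages().table(sort_by="cpu_time_total"))
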
class ConversionOps:
    """Base class for weight conversion operations.

@@ -308,6 +306,7 @@ class SplitModuleList(ConversionOps):
            result.append(list(splits))
        return result


class Cast(ConversionOps):
    """
    Casts the tensor to a given dtype
@@ -316,6 +315,7 @@ class Cast(ConversionOps):
    def __init__(self, dtype):
        self.dtype = dtype


class To(ConversionOps):
    """
    Transfers the tensor to the provided device potentially using a stream?
@@ -327,9 +327,11 @@ class To(ConversionOps):
    if is_fsdp_enabled():
        param_device = "cpu" if is_local_dist_rank_0() else "meta"
    """

    def __init__(self, device):
        self.device = device

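For orientation, a rough sketch of what a Cast-style conversion op might do; the convert method name and its single-tensor signature are assumptions, since the ConversionOps interface is not shown in this excerpt.

class CastSketch(ConversionOps):
    """Illustrative only: cast incoming tensors to a target dtype."""

    def __init__(self, dtype):
        self.dtype = dtype

    def convert(self, tensor):  # hypothetical hook name, not the real API
        return tensor.to(self.dtype)
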
class Shard(ConversionOps):
    """Shard tensors along a specific dimension.

@@ -555,21 +557,23 @@ def convert_state_dict(model, state_dict, weight_mapping, tp_plan, quantization_
    # 1. we need to find which key we have (so we keep track of which pattern was matched)
    converted_state_dict: dict[str, torch.Tensor] = {}
    used_operations: list[ConversionOps] = []
    keys_to_convert = [ rf"{ '|'.join(k.source_keys) if isinstance(k.source_keys, list) else k.source_keys}" for k in weight_mapping ]
    keys_to_convert = [
        rf"{'|'.join(k.source_keys) if isinstance(k.source_keys, list) else k.source_keys}" for k in weight_mapping
    ]
    # tensor parallel is also a conversion scheme! So add it to the keys to convert!
    # quantization as well! But for quantization we would need to get the module, check if its a linear?

    for k,v in state_dict.items():
        if re.sub(rf"^({ '|'.join(keys_to_convert) })$", "", k) == k:
    for k, v in state_dict.items():
        if re.sub(rf"^({'|'.join(keys_to_convert)})$", "", k) == k:
            converted_state_dict[k] = v
        else:
            # we replace the whole key by the matched pattern so that we can find it later
            pattern = re.sub(rf"^({ '|'.join(keys_to_convert) })$", r"\1", k)
            collected_keys[pattern][k] += [v] # we collect all tensors that match the pattern
            pattern = re.sub(rf"^({'|'.join(keys_to_convert)})$", r"\1", k)
            collected_keys[pattern][k] += [v]  # we collect all tensors that match the pattern
            converter = weight_mapping[pattern]
            if pattern in tp_plan: # If we want this to work conversion needs to be explicit no?
            if pattern in tp_plan:  # If we want this to work conversion needs to be explicit no?
                if converter.distributed_operation is None:
                    converter.distributed_operation = Shard(0) # for now
                    converter.distributed_operation = Shard(0)  # for now
            # TODO: use `param_needs_quantization` !
            if pattern in quantization_config.conversion_mapping:
                if converter.quantize_operations is None:
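A small standalone sketch of the matching test used above (the keys and patterns below are made up for illustration): a checkpoint key is copied as-is only if substituting the joined source-key alternation leaves it unchanged, i.e. no conversion pattern matched it.

import re

keys_to_convert = [r"model\.layers\.\d+\.mlp\.w1", r"model\.layers\.\d+\.mlp\.w3"]
alternation = "|".join(keys_to_convert)

for key in ["model.layers.0.mlp.w1", "lm_head.weight"]:
    needs_conversion = re.sub(rf"^({alternation})$", "", key) != key
    print(key, "-> convert" if needs_conversion else "-> copy as-is")
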
@@ -512,10 +512,8 @@ def accelerate_disk_offload(
    checkpoint_files,
    device_map,
    checkpoint_keys,
    key_renaming_mapping,
    sharded_metadata,
    dtype,
    reverse_key_renaming_mapping,
):
    disk_only_shard_files = []
    if disk_offload_folder is not None:
@@ -534,19 +532,13 @@ def accelerate_disk_offload(
        weight_map = dict.fromkeys(checkpoint_keys, checkpoint_files[0])
    else:
        folder = os.path.sep.join(checkpoint_files[0].split(os.path.sep)[:-1])
        # Fix the weight map keys according to the key mapping
        weight_map = {
            key_renaming_mapping[k]: v
            for k, v in sharded_metadata["weight_map"].items()
            if k in key_renaming_mapping
        }
        weight_map = {k: os.path.join(folder, v) for k, v in weight_map.items()}
    # Find potential checkpoints containing only offloaded weights
    disk_only_shard_files = get_disk_only_shard_files(device_map, weight_map)
    disk_offload_index = {
        name: {
            "safetensors_file": file,
            "weight_name": reverse_key_renaming_mapping[name],
            "weight_name": name,
            "dtype": str_dtype,
        }
        for name, file in weight_map.items()
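For orientation, a hypothetical entry of the resulting disk_offload_index after this change (file path and parameter name are invented): the recorded weight name now matches the already-renamed parameter key instead of the original checkpoint key.

disk_offload_index = {
    "model.layers.0.mlp.experts.gate_up_proj": {
        "safetensors_file": "/path/to/model-00001-of-00002.safetensors",
        "weight_name": "model.layers.0.mlp.experts.gate_up_proj",
        "dtype": "float16",
    },
}
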
@@ -26,13 +26,13 @@ import sys
import warnings
from abc import abstractmethod
from collections import defaultdict
from collections.abc import Callable
from collections.abc import Callable, Sequence
from concurrent.futures import ThreadPoolExecutor, as_completed
from contextlib import contextmanager
from enum import Enum
from functools import partial, wraps
from threading import Thread
from typing import Any, Optional, Sequence, TypeVar, Union, get_type_hints
from typing import Any, Optional, TypeVar, Union, get_type_hints
from zipfile import is_zipfile

import torch
@@ -45,6 +45,8 @@ from torch.distributions import constraints
from torch.utils.checkpoint import checkpoint

from .configuration_utils import PreTrainedConfig
from .conversion_mapping import _checkpoint_conversion_mapping as DEFAULT_WEIGHT_CONVERSION_MAPPING
from .core_model_loading import WeightConversion, convert_state_dict
from .distributed import DistributedConfig
from .dynamic_module_utils import custom_object_save
from .generation import CompileConfig, GenerationConfig
@@ -59,7 +61,6 @@ from .integrations.accelerate import (
    init_empty_weights,
)
from .integrations.deepspeed import _load_state_dict_into_zero3_model
from .core_model_loading import QuantizationOp, Shard, WeightConversion, convert_state_dict
from .integrations.eager_paged import eager_paged_attention_forward
from .integrations.flash_attention import flash_attention_forward
from .integrations.flash_paged import paged_attention_forward
@@ -125,7 +126,6 @@ from .utils.import_utils import (
    is_torchdynamo_compiling,
)
from .utils.quantization_config import QuantizationMethod
from .conversion_mapping import _checkpoint_conversion_mapping as DEFAULT_WEIGHT_CONVERSION_MAPPING


if is_accelerate_available():
@@ -734,8 +734,6 @@ def load_shard_file(args):
    state_dict = {key_renaming_mapping[k]: v for k, v in state_dict.items() if k in key_renaming_mapping}



def load_shard_files_with_threadpool(args_list):
    num_workers = int(os.environ.get("HF_PARALLEL_LOADING_WORKERS", "8"))

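The worker count above is read from the environment when shards are loaded, so it can be tuned without code changes; a minimal sketch (the value and checkpoint name are placeholders, not part of this commit):

import os

os.environ["HF_PARALLEL_LOADING_WORKERS"] = "16"  # must be set before the shards are loaded
# model = AutoModelForCausalLM.from_pretrained("org/model")  # hypothetical caller that triggers loading
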
@@ -4407,7 +4405,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
        if model_type is not None:
            weight_conversions = DEFAULT_WEIGHT_CONVERSION_MAPPING.get(model_type)


        if gguf_file:
            if hf_quantizer is not None:
                raise ValueError(
@@ -4700,7 +4697,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH

        if weight_mapping:
            merged_state_dict = {}
            for file in checkpoint_files: # TODO this is sequential but supposed to be fast
            for file in checkpoint_files:  # TODO this is sequential but supposed to be fast
                merged_state_dict.update(
                    load_state_dict(file, is_quantized=is_quantized, map_location="meta", weights_only=weights_only)
                )
@@ -4749,7 +4746,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
        # correctly initialize the missing (and potentially mismatched) keys
        model._initialize_missing_keys(missing_keys + mismatched_keys, is_quantized)


        is_offloaded_safetensors = False
        # This offload index if for params explicitly on the "disk" in the device_map
        disk_offload_index = None
@@ -4761,10 +4757,9 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
                checkpoint_files,
                device_map,
                checkpoint_keys,
                key_renaming_mapping,
                new_state_dict.keys(),
                sharded_metadata,
                dtype,
                reverse_key_renaming_mapping,
            )
        # To be able to iterate, even if we don't use it if the state_dict is already provided
        elif state_dict is not None:
@@ -4799,8 +4794,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
                device_mesh=device_mesh,
            )


        # Save offloaded index if needed
        if disk_offload_index is not None and len(disk_offload_index) > 0 and not is_offloaded_safetensors:
            save_offload_index(disk_offload_index, disk_offload_folder)
@@ -243,38 +243,45 @@ class HunYuanMoEV1Gate(nn.Module):
        return logits


class HunYuanMoEV1Experts(nn.ModuleList):
    """
    ModuleList of experts.
    """
class HunYuanMoEV1Experts(nn.Module):
    """Collection of expert weights stored as 3D tensors."""

    def __init__(self, config: HunYuanMoEV1Config):
        super().__init__()
        self.top_k = config.num_experts_per_tok
        self.num_experts = config.num_local_experts
        for _ in range(self.num_experts):
            self.append(HunYuanMoEV1MLP(config))
        self.hidden_dim = config.hidden_size
        self.intermediate_dim = config.intermediate_size
        self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim))
        self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim))
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(
        self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor
        self,
        hidden_states: torch.Tensor,
        top_k_index: torch.Tensor,
        top_k_weights: torch.Tensor,
    ) -> torch.Tensor:
        """
        Args:
            hidden_states: (batch_size * sequence_length, hidden_dim)
            selected_experts: (batch_size * sequence_length, top_k)
            routing_weights: (batch_size * sequence_length, top_k)
        Returns:
            (batch_size * sequence_length, hidden_dim)
        """
        final_hidden_states = torch.zeros_like(hidden_states)
        expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0)

        expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
        for expert_idx in expert_hit:
            idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
            current_state = hidden_states[None, top_x].reshape(-1, hidden_states.shape[-1])
            current_hidden_states = self[expert_idx](current_state) * top_k_weights[top_x, idx, None]
            final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
        expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0)
        expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero(as_tuple=False).flatten()

        for expert_idx in expert_hit.tolist():
            expert_selection = expert_mask[expert_idx].squeeze(0)
            top_indices, token_positions = torch.where(expert_selection)
            if token_positions.numel() == 0:
                continue

            current_state = hidden_states.index_select(0, token_positions)
            gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2)
            current_hidden_states = self.act_fn(up)
            current_hidden_states = current_hidden_states * gate
            current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])

            routing_weights = top_k_weights[token_positions, top_indices].unsqueeze(-1)
            current_hidden_states = current_hidden_states * routing_weights.to(current_hidden_states.dtype)
            final_hidden_states.index_add_(0, token_positions, current_hidden_states.to(final_hidden_states.dtype))

        return final_hidden_states
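The refactors above and below (HunYuan, Jamba, MiniMax, Olmoe, Qwen2Moe) replace per-expert MLP modules with stacked 3D expert weights. A small self-contained sketch of that routing pattern follows; the dimensions, the chunk axis, and the SwiGLU gate/up ordering are assumptions for illustration, not the exact repository code.

import torch
import torch.nn as nn

num_experts, hidden_dim, intermediate_dim, top_k = 4, 8, 16, 2
gate_up_proj = torch.randn(num_experts, 2 * intermediate_dim, hidden_dim)  # stacked per-expert gate/up weights
down_proj = torch.randn(num_experts, hidden_dim, intermediate_dim)         # stacked per-expert down weights
act_fn = nn.SiLU()

hidden_states = torch.randn(10, hidden_dim)               # (num_tokens, hidden_dim)
top_k_index = torch.randint(0, num_experts, (10, top_k))  # chosen experts per token
top_k_weights = torch.rand(10, top_k)                     # routing weights per token

final_hidden_states = torch.zeros_like(hidden_states)
expert_mask = nn.functional.one_hot(top_k_index, num_classes=num_experts).permute(2, 1, 0)  # (E, top_k, T)
for expert_idx in expert_mask.sum(dim=(-1, -2)).nonzero(as_tuple=False).flatten().tolist():
    top_indices, token_positions = torch.where(expert_mask[expert_idx])
    current = hidden_states.index_select(0, token_positions)
    gate, up = nn.functional.linear(current, gate_up_proj[expert_idx]).chunk(2, dim=-1)  # split the fused projection
    out = nn.functional.linear(act_fn(gate) * up, down_proj[expert_idx])
    out = out * top_k_weights[token_positions, top_indices].unsqueeze(-1)
    final_hidden_states.index_add_(0, token_positions, out)
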
@@ -557,38 +557,45 @@ class JambaMLP(nn.Module):
        return down_proj


class JambaExperts(nn.ModuleList):
    """
    ModuleList of experts.
    """
class JambaExperts(nn.Module):
    """Collection of expert weights stored as 3D tensors."""

    def __init__(self, config: JambaConfig):
        super().__init__()
        self.top_k = config.num_experts_per_tok
        self.num_experts = config.num_local_experts
        for _ in range(self.num_experts):
            self.append(JambaMLP(config))
        self.hidden_dim = config.hidden_size
        self.intermediate_dim = config.intermediate_size
        self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim))
        self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim))
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(
        self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor
        self,
        hidden_states: torch.Tensor,
        top_k_index: torch.Tensor,
        top_k_weights: torch.Tensor,
    ) -> torch.Tensor:
        """
        Args:
            hidden_states: (batch_size * sequence_length, hidden_dim)
            selected_experts: (batch_size * sequence_length, top_k)
            routing_weights: (batch_size * sequence_length, top_k)
        Returns:
            (batch_size * sequence_length, hidden_dim)
        """
        final_hidden_states = torch.zeros_like(hidden_states)
        expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0)

        expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
        for expert_idx in expert_hit:
            idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
            current_state = hidden_states[None, top_x].reshape(-1, hidden_states.shape[-1])
            current_hidden_states = self[expert_idx](current_state) * top_k_weights[top_x, idx, None]
            final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
        expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0)
        expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero(as_tuple=False).flatten()

        for expert_idx in expert_hit.tolist():
            expert_selection = expert_mask[expert_idx].squeeze(0)
            top_indices, token_positions = torch.where(expert_selection)
            if token_positions.numel() == 0:
                continue

            current_state = hidden_states.index_select(0, token_positions)
            gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2)
            current_hidden_states = self.act_fn(up)
            current_hidden_states = current_hidden_states * gate
            current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])

            routing_weights = top_k_weights[token_positions, top_indices].unsqueeze(-1)
            current_hidden_states = current_hidden_states * routing_weights.to(current_hidden_states.dtype)
            final_hidden_states.index_add_(0, token_positions, current_hidden_states.to(final_hidden_states.dtype))

        return final_hidden_states
@@ -132,9 +132,9 @@ class MiniMaxConfig(PreTrainedConfig):
        "layers.*.self_attn.v_proj": "colwise",
        "layers.*.self_attn.o_proj": "rowwise",
        "layers.*.block_sparse_moe.gate": "colwise_rep", # we need to replicate here to correctly route experts
        "layers.*.block_sparse_moe.experts.*.w1": "colwise",
        "layers.*.block_sparse_moe.experts.*.w2": "rowwise",
        "layers.*.block_sparse_moe.experts.*.w3": "colwise",
        "layers.*.block_sparse_moe.experts.w1": "colwise",
        "layers.*.block_sparse_moe.experts.w2": "rowwise",
        "layers.*.block_sparse_moe.experts.w3": "colwise",
    }
    base_model_pp_plan = {
        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
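Because the experts now live as stacked parameters directly on the experts module, a single tensor-parallel entry addresses every expert at once; an illustrative fragment (keys shortened, not the full plan):

base_model_tp_plan = {
    "layers.*.block_sparse_moe.experts.w1": "colwise",  # one entry shards every expert's w1 slice
    "layers.*.block_sparse_moe.experts.w2": "rowwise",
}
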
@@ -387,56 +387,45 @@ class MiniMaxAttention(nn.Module):
        return attn_output, attn_weights


class MiniMaxMLP(nn.Module):
class MiniMaxExperts(nn.Module):
    """Collection of expert weights stored as 3D tensors."""

    def __init__(self, config: MiniMaxConfig):
        super().__init__()
        self.ffn_dim = config.intermediate_size
        self.num_experts = config.num_local_experts
        self.hidden_dim = config.hidden_size

        self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
        self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)
        self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)

        self.intermediate_dim = config.intermediate_size
        self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim))
        self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim))
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, hidden_states):
        current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
        current_hidden_states = self.w2(current_hidden_states)
        return current_hidden_states


class MiniMaxExperts(nn.ModuleList):
    """
    ModuleList of experts.
    """

    def __init__(self, config: MiniMaxConfig):
        super().__init__()
        self.top_k = config.num_experts_per_tok
        self.num_experts = config.num_local_experts
        for _ in range(self.num_experts):
            self.append(MiniMaxMLP(config))

    def forward(
        self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor
        self,
        hidden_states: torch.Tensor,
        top_k_index: torch.Tensor,
        top_k_weights: torch.Tensor,
    ) -> torch.Tensor:
        """
        Args:
            hidden_states: (batch_size * sequence_length, hidden_dim)
            selected_experts: (batch_size * sequence_length, top_k)
            routing_weights: (batch_size * sequence_length, top_k)
        Returns:
            (batch_size * sequence_length, hidden_dim)
        """
        final_hidden_states = torch.zeros_like(hidden_states)
        expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0)

        expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
        for expert_idx in expert_hit:
            idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
            current_state = hidden_states[None, top_x].reshape(-1, hidden_states.shape[-1])
            current_hidden_states = self[expert_idx](current_state) * top_k_weights[top_x, idx, None]
            final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
        expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0)
        expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero(as_tuple=False).flatten()

        for expert_idx in expert_hit.tolist():
            expert_selection = expert_mask[expert_idx].squeeze(0)
            top_indices, token_positions = torch.where(expert_selection)
            if token_positions.numel() == 0:
                continue

            current_state = hidden_states.index_select(0, token_positions)
            gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2)
            current_hidden_states = self.act_fn(up)
            current_hidden_states = current_hidden_states * gate
            current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])

            routing_weights = top_k_weights[token_positions, top_indices].unsqueeze(-1)
            current_hidden_states = current_hidden_states * routing_weights.to(current_hidden_states.dtype)
            final_hidden_states.index_add_(0, token_positions, current_hidden_states.to(final_hidden_states.dtype))

        return final_hidden_states
@@ -61,18 +61,10 @@ class MixtralExperts(nn.Module):
        self.num_experts = config.num_local_experts
        self.hidden_dim = config.hidden_size
        self.intermediate_dim = config.intermediate_size

        self.w1 = nn.Parameter(torch.empty(self.num_experts, self.intermediate_dim, self.hidden_dim))
        self.w2 = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim))
        self.w3 = nn.Parameter(torch.empty(self.num_experts, self.intermediate_dim, self.hidden_dim))

        self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim))
        self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim))
        self.act_fn = ACT2FN[config.hidden_act]

    def reset_parameters(self, initializer_range: float):
        nn.init.normal_(self.w1, mean=0.0, std=initializer_range)
        nn.init.normal_(self.w2, mean=0.0, std=initializer_range)
        nn.init.normal_(self.w3, mean=0.0, std=initializer_range)

    def forward(
        self,
        hidden_states: torch.Tensor,
@@ -91,11 +83,10 @@ class MixtralExperts(nn.Module):
                continue

            current_state = hidden_states.index_select(0, token_positions)
            current_hidden_states = nn.functional.linear(current_state, self.w1[expert_idx])
            current_hidden_states = self.act_fn(current_hidden_states)
            gate_hidden_states = nn.functional.linear(current_state, self.w3[expert_idx])
            current_hidden_states = current_hidden_states * gate_hidden_states
            current_hidden_states = nn.functional.linear(current_hidden_states, self.w2[expert_idx])
            gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2)
            current_hidden_states = self.act_fn(up)
            current_hidden_states = current_hidden_states * gate
            current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])

            routing_weights = top_k_weights[token_positions, top_indices].unsqueeze(-1)
            current_hidden_states = current_hidden_states * routing_weights.to(current_hidden_states.dtype)
@@ -378,11 +369,6 @@ class MixtralPreTrainedModel(PreTrainedModel):
        "hidden_states": MixtralDecoderLayer,
        "attentions": MixtralAttention,
    }
    def _init_weights(self, module):
        super()._init_weights(module)
        if isinstance(module, MixtralExperts):
            initializer_range = getattr(self.config, "initializer_range", 0.02)
            module.reset_parameters(initializer_range)


    @auto_docstring
@@ -259,12 +259,6 @@ class MixtralPreTrainedModel(MistralPreTrainedModel):
        "attentions": MixtralAttention,
    }

    def _init_weights(self, module):
        super()._init_weights(module)
        if isinstance(module, MixtralExperts):
            initializer_range = getattr(self.config, "initializer_range", 0.02)
            module.reset_parameters(initializer_range)


class MixtralModel(MistralModel):
    def forward(
@@ -265,13 +265,11 @@ class OlmoeAttention(nn.Module):
        return attn_output, attn_weights


class OlmoeExperts(nn.ModuleList):
    """
    ModuleList of experts.
    """
class OlmoeExperts(nn.Module):
    """Collection of expert weights stored as 3D tensors."""

    def __init__(self, config):
        super().__init__()
        nn.ModuleList.__init__(self)
        for _ in range(config.num_experts):
            self.append(OlmoeMLP(config))
        self.num_experts = config.num_experts
@@ -279,25 +277,32 @@ class OlmoeExperts(nn.ModuleList):
        self.norm_topk_prob = config.norm_topk_prob

    def forward(
        self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor
        self,
        hidden_states: torch.Tensor,
        top_k_index: torch.Tensor,
        top_k_weights: torch.Tensor,
    ) -> torch.Tensor:
        """
        Args:
            hidden_states: (batch_size * sequence_length, hidden_dim)
            selected_experts: (batch_size * sequence_length, top_k)
            routing_weights: (batch_size * sequence_length, top_k)
        Returns:
            (batch_size * sequence_length, hidden_dim)
        """
        final_hidden_states = torch.zeros_like(hidden_states)
        expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0)

        expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
        for expert_idx in expert_hit:
            idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
            current_state = hidden_states[None, top_x].reshape(-1, hidden_states.shape[-1])
            current_hidden_states = self[expert_idx](current_state) * top_k_weights[top_x, idx, None]
            final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
        expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0)
        expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero(as_tuple=False).flatten()

        for expert_idx in expert_hit.tolist():
            expert_selection = expert_mask[expert_idx].squeeze(0)
            top_indices, token_positions = torch.where(expert_selection)
            if token_positions.numel() == 0:
                continue

            current_state = hidden_states.index_select(0, token_positions)
            gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2)
            current_hidden_states = self.act_fn(up)
            current_hidden_states = current_hidden_states * gate
            current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])

            routing_weights = top_k_weights[token_positions, top_indices].unsqueeze(-1)
            current_hidden_states = current_hidden_states * routing_weights.to(current_hidden_states.dtype)
            final_hidden_states.index_add_(0, token_positions, current_hidden_states.to(final_hidden_states.dtype))

        return final_hidden_states
@@ -260,37 +260,42 @@ class Qwen2MoeAttention(nn.Module):
        return attn_output, attn_weights


class Qwen2MoeExperts(nn.ModuleList):
    """
    ModuleList of experts.
    """
class Qwen2MoeExperts(nn.Module):
    """Collection of expert weights stored as 3D tensors."""

    def __init__(self, config):
        super().__init__()
        nn.ModuleList.__init__(self)
        self.num_experts = config.num_experts
        for _ in range(config.num_experts):
            self.append(Qwen2MoeMLP(config, intermediate_size=config.moe_intermediate_size))

    def forward(
        self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor
        self,
        hidden_states: torch.Tensor,
        top_k_index: torch.Tensor,
        top_k_weights: torch.Tensor,
    ) -> torch.Tensor:
        """
        Args:
            hidden_states: (batch_size * sequence_length, hidden_dim)
            selected_experts: (batch_size * sequence_length, top_k)
            routing_weights: (batch_size * sequence_length, top_k)
        Returns:
            (batch_size * sequence_length, hidden_dim)
        """
        final_hidden_states = torch.zeros_like(hidden_states)
        expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0)

        expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
        for expert_idx in expert_hit:
            idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
            current_state = hidden_states[None, top_x].reshape(-1, hidden_states.shape[-1])
            current_hidden_states = self[expert_idx](current_state) * top_k_weights[top_x, idx, None]
            final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
        expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0)
        expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero(as_tuple=False).flatten()

        for expert_idx in expert_hit.tolist():
            expert_selection = expert_mask[expert_idx].squeeze(0)
            top_indices, token_positions = torch.where(expert_selection)
            if token_positions.numel() == 0:
                continue

            current_state = hidden_states.index_select(0, token_positions)
            gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2)
            current_hidden_states = self.act_fn(up)
            current_hidden_states = current_hidden_states * gate
            current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])

            routing_weights = top_k_weights[token_positions, top_indices].unsqueeze(-1)
            current_hidden_states = current_hidden_states * routing_weights.to(current_hidden_states.dtype)
            final_hidden_states.index_add_(0, token_positions, current_hidden_states.to(final_hidden_states.dtype))

        return final_hidden_states