mirror of https://github.com/huggingface/transformers.git (synced 2025-10-20 17:13:56 +08:00)

Commit: update
@@ -5,7 +5,7 @@
 #
 # Either we keep it here, or we move it to the config, but for newcomers, seeing this is kinda weird no?
 
-from .core_model_loading import Concatenate, MergeModuleList, WeightConversion, Fp8Quantize, Shard
+from .core_model_loading import Concatenate, Fp8Quantize, MergeModuleList, Shard, WeightConversion
 
 
 _checkpoint_conversion_mapping = {
@@ -14,14 +14,13 @@
 # limitations under the License.
 """Core helpers for loading model checkpoints."""
 
-from collections import defaultdict
 from __future__ import annotations
-import re
 
 import math
+import re
 import time
 from abc import abstractmethod
-from collections import OrderedDict
+from collections import defaultdict
 from collections.abc import Sequence
 from contextlib import nullcontext
 from dataclasses import dataclass
@@ -51,14 +50,13 @@ except AttributeError:
     )
 
 try:
-    from torch.profiler import ProfilerActivity, profile as torch_profile
+    from torch.profiler import ProfilerActivity
+    from torch.profiler import profile as torch_profile
 except (ImportError, AttributeError):
     ProfilerActivity = None
     torch_profile = None
-
-
 
 
 class ConversionOps:
     """Base class for weight conversion operations.
 
@@ -308,6 +306,7 @@ class SplitModuleList(ConversionOps):
             result.append(list(splits))
         return result
 
+
 class Cast(ConversionOps):
     """
     Casts the tensor to a given dtype
@@ -316,6 +315,7 @@ class Cast(ConversionOps):
     def __init__(self, dtype):
         self.dtype = dtype
 
+
 class To(ConversionOps):
     """
     Transfers the tensor to the provided device potentially using a stream?
@@ -327,9 +327,11 @@ class To(ConversionOps):
     if is_fsdp_enabled():
         param_device = "cpu" if is_local_dist_rank_0() else "meta"
     """
+
     def __init__(self, device):
         self.device = device
 
+
 class Shard(ConversionOps):
     """Shard tensors along a specific dimension.
 
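Cast, To, and Shard above are small, composable operations applied to checkpoint tensors while they are loaded. As a rough sketch of that composition (the convert() method and the op classes below are illustrative stand-ins, not the repository's actual ConversionOps interface):

import torch


class CastSketch:
    def __init__(self, dtype):
        self.dtype = dtype

    def convert(self, tensor: torch.Tensor) -> torch.Tensor:
        # Reinterpret the values in the target dtype (e.g. fp16 -> fp32).
        return tensor.to(self.dtype)


class ToSketch:
    def __init__(self, device):
        self.device = device

    def convert(self, tensor: torch.Tensor) -> torch.Tensor:
        # Move the tensor to the target device; non_blocking lets the copy
        # overlap with compute when a CUDA stream is available.
        return tensor.to(self.device, non_blocking=True)


def apply_ops(tensor: torch.Tensor, ops) -> torch.Tensor:
    # Apply the conversion ops left to right, threading the tensor through.
    for op in ops:
        tensor = op.convert(tensor)
    return tensor


weight = torch.randn(4, 4, dtype=torch.float16)
converted = apply_ops(weight, [CastSketch(torch.float32), ToSketch("cpu")])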
@@ -555,7 +557,9 @@ def convert_state_dict(model, state_dict, weight_mapping, tp_plan, quantization_
     # 1. we need to find which key we have (so we keep track of which pattern was matched)
     converted_state_dict: dict[str, torch.Tensor] = {}
     used_operations: list[ConversionOps] = []
-    keys_to_convert = [ rf"{ '|'.join(k.source_keys) if isinstance(k.source_keys, list) else k.source_keys}" for k in weight_mapping ]
+    keys_to_convert = [
+        rf"{'|'.join(k.source_keys) if isinstance(k.source_keys, list) else k.source_keys}" for k in weight_mapping
+    ]
     # tensor parallel is also a conversion scheme! So add it to the keys to convert!
     # quantization as well! But for quantization we would need to get the module, check if its a linear?
 
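keys_to_convert turns each WeightConversion's source_keys into one regex alternation per mapping entry. A minimal sketch of how such patterns can then be matched against incoming checkpoint keys to recover which conversion applies (the patterns and the helper below are hypothetical, not taken from convert_state_dict):

import re

# Hypothetical source-key patterns in the same spirit as WeightConversion.source_keys.
patterns = [
    r"model\.layers\.\d+\.mlp\.experts\.\d+\.w1\.weight|model\.layers\.\d+\.mlp\.experts\.\d+\.w3\.weight",
    r"model\.layers\.\d+\.self_attn\.q_proj\.weight",
]
compiled = [re.compile(p) for p in patterns]


def matching_pattern(checkpoint_key: str):
    # Return the index of the first pattern that matches, so the caller knows
    # which conversion (and which ops) to apply to this tensor.
    for i, pattern in enumerate(compiled):
        if pattern.fullmatch(checkpoint_key):
            return i
    return None


assert matching_pattern("model.layers.0.mlp.experts.3.w1.weight") == 0
assert matching_pattern("model.layers.0.self_attn.q_proj.weight") == 1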
@@ -512,10 +512,8 @@ def accelerate_disk_offload(
     checkpoint_files,
     device_map,
     checkpoint_keys,
-    key_renaming_mapping,
     sharded_metadata,
     dtype,
-    reverse_key_renaming_mapping,
 ):
     disk_only_shard_files = []
     if disk_offload_folder is not None:
@@ -534,19 +532,13 @@ def accelerate_disk_offload(
             weight_map = dict.fromkeys(checkpoint_keys, checkpoint_files[0])
         else:
             folder = os.path.sep.join(checkpoint_files[0].split(os.path.sep)[:-1])
-            # Fix the weight map keys according to the key mapping
-            weight_map = {
-                key_renaming_mapping[k]: v
-                for k, v in sharded_metadata["weight_map"].items()
-                if k in key_renaming_mapping
-            }
             weight_map = {k: os.path.join(folder, v) for k, v in weight_map.items()}
         # Find potential checkpoints containing only offloaded weights
         disk_only_shard_files = get_disk_only_shard_files(device_map, weight_map)
         disk_offload_index = {
             name: {
                 "safetensors_file": file,
-                "weight_name": reverse_key_renaming_mapping[name],
+                "weight_name": name,
                 "dtype": str_dtype,
             }
             for name, file in weight_map.items()
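get_disk_only_shard_files prunes shard files whose every weight ends up offloaded to disk, so those shards never need to be loaded into RAM. A small sketch of that idea, under the simplifying assumption that device_map is keyed by full parameter names (the real device_map is usually keyed by module prefixes):

from collections import defaultdict


def disk_only_shards(device_map, weight_map):
    # Group parameter names by the shard file that stores them.
    params_per_file = defaultdict(list)
    for param_name, shard_file in weight_map.items():
        params_per_file[shard_file].append(param_name)

    # A shard is "disk only" when every parameter it contains is offloaded to disk.
    return [
        shard_file
        for shard_file, params in params_per_file.items()
        if all(device_map.get(name) == "disk" for name in params)
    ]


device_map = {"lm_head.weight": "disk", "model.embed_tokens.weight": 0}
weight_map = {"lm_head.weight": "shard-2.safetensors", "model.embed_tokens.weight": "shard-1.safetensors"}
assert disk_only_shards(device_map, weight_map) == ["shard-2.safetensors"]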
@@ -26,13 +26,13 @@ import sys
 import warnings
 from abc import abstractmethod
 from collections import defaultdict
-from collections.abc import Callable
+from collections.abc import Callable, Sequence
 from concurrent.futures import ThreadPoolExecutor, as_completed
 from contextlib import contextmanager
 from enum import Enum
 from functools import partial, wraps
 from threading import Thread
-from typing import Any, Optional, Sequence, TypeVar, Union, get_type_hints
+from typing import Any, Optional, TypeVar, Union, get_type_hints
 from zipfile import is_zipfile
 
 import torch
@@ -45,6 +45,8 @@ from torch.distributions import constraints
 from torch.utils.checkpoint import checkpoint
 
 from .configuration_utils import PreTrainedConfig
+from .conversion_mapping import _checkpoint_conversion_mapping as DEFAULT_WEIGHT_CONVERSION_MAPPING
+from .core_model_loading import WeightConversion, convert_state_dict
 from .distributed import DistributedConfig
 from .dynamic_module_utils import custom_object_save
 from .generation import CompileConfig, GenerationConfig
@@ -59,7 +61,6 @@ from .integrations.accelerate import (
     init_empty_weights,
 )
 from .integrations.deepspeed import _load_state_dict_into_zero3_model
-from .core_model_loading import QuantizationOp, Shard, WeightConversion, convert_state_dict
 from .integrations.eager_paged import eager_paged_attention_forward
 from .integrations.flash_attention import flash_attention_forward
 from .integrations.flash_paged import paged_attention_forward
@@ -125,7 +126,6 @@ from .utils.import_utils import (
     is_torchdynamo_compiling,
 )
 from .utils.quantization_config import QuantizationMethod
-from .conversion_mapping import _checkpoint_conversion_mapping as DEFAULT_WEIGHT_CONVERSION_MAPPING
 
 
 if is_accelerate_available():
@@ -734,8 +734,6 @@ def load_shard_file(args):
         state_dict = {key_renaming_mapping[k]: v for k, v in state_dict.items() if k in key_renaming_mapping}
 
-
-
 
 def load_shard_files_with_threadpool(args_list):
     num_workers = int(os.environ.get("HF_PARALLEL_LOADING_WORKERS", "8"))
 
@@ -4407,7 +4405,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
         if model_type is not None:
             weight_conversions = DEFAULT_WEIGHT_CONVERSION_MAPPING.get(model_type)
-
 
         if gguf_file:
             if hf_quantizer is not None:
                 raise ValueError(
@@ -4749,7 +4746,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
         # correctly initialize the missing (and potentially mismatched) keys
         model._initialize_missing_keys(missing_keys + mismatched_keys, is_quantized)
-
 
         is_offloaded_safetensors = False
         # This offload index if for params explicitly on the "disk" in the device_map
         disk_offload_index = None
@@ -4761,10 +4757,9 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
                 checkpoint_files,
                 device_map,
                 checkpoint_keys,
-                key_renaming_mapping,
+                new_state_dict.keys(),
                 sharded_metadata,
                 dtype,
-                reverse_key_renaming_mapping,
             )
         # To be able to iterate, even if we don't use it if the state_dict is already provided
         elif state_dict is not None:
@@ -4799,8 +4794,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
                 device_mesh=device_mesh,
             )
-
-
 
         # Save offloaded index if needed
         if disk_offload_index is not None and len(disk_offload_index) > 0 and not is_offloaded_safetensors:
             save_offload_index(disk_offload_index, disk_offload_folder)
@@ -243,38 +243,45 @@ class HunYuanMoEV1Gate(nn.Module):
         return logits
 
 
-class HunYuanMoEV1Experts(nn.ModuleList):
-    """
-    ModuleList of experts.
-    """
+class HunYuanMoEV1Experts(nn.Module):
+    """Collection of expert weights stored as 3D tensors."""
 
     def __init__(self, config: HunYuanMoEV1Config):
         super().__init__()
-        self.top_k = config.num_experts_per_tok
         self.num_experts = config.num_local_experts
-        for _ in range(self.num_experts):
-            self.append(HunYuanMoEV1MLP(config))
+        self.hidden_dim = config.hidden_size
+        self.intermediate_dim = config.intermediate_size
+        self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim))
+        self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim))
+        self.act_fn = ACT2FN[config.hidden_act]
 
     def forward(
-        self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor
+        self,
+        hidden_states: torch.Tensor,
+        top_k_index: torch.Tensor,
+        top_k_weights: torch.Tensor,
     ) -> torch.Tensor:
-        """
-        Args:
-            hidden_states: (batch_size * sequence_length, hidden_dim)
-            selected_experts: (batch_size * sequence_length, top_k)
-            routing_weights: (batch_size * sequence_length, top_k)
-        Returns:
-            (batch_size * sequence_length, hidden_dim)
-        """
         final_hidden_states = torch.zeros_like(hidden_states)
-        expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0)
-
-        expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
-        for expert_idx in expert_hit:
-            idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
-            current_state = hidden_states[None, top_x].reshape(-1, hidden_states.shape[-1])
-            current_hidden_states = self[expert_idx](current_state) * top_k_weights[top_x, idx, None]
-            final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
+        expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0)
+        expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero(as_tuple=False).flatten()
+        for expert_idx in expert_hit.tolist():
+            expert_selection = expert_mask[expert_idx].squeeze(0)
+            top_indices, token_positions = torch.where(expert_selection)
+            if token_positions.numel() == 0:
+                continue
+
+            current_state = hidden_states.index_select(0, token_positions)
+            gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2)
+            current_hidden_states = self.act_fn(up)
+            current_hidden_states = current_hidden_states * gate
+            current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])
+
+            routing_weights = top_k_weights[token_positions, top_indices].unsqueeze(-1)
+            current_hidden_states = current_hidden_states * routing_weights.to(current_hidden_states.dtype)
+            final_hidden_states.index_add_(0, token_positions, current_hidden_states.to(final_hidden_states.dtype))
+
         return final_hidden_states
 
 
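The routing bookkeeping shared by these rewritten forward passes is the least obvious part: one_hot(top_k_index) is permuted so that indexing by expert yields, for each token, which of its top-k slots picked that expert. A self-contained example with made-up shapes:

import torch

num_experts, top_k = 4, 2
# Two tokens; token 0 routes to experts (1, 3), token 1 routes to experts (0, 1).
top_k_index = torch.tensor([[1, 3], [0, 1]])

# (tokens, top_k, num_experts) -> (num_experts, top_k, tokens)
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts).permute(2, 1, 0)

# Experts that received at least one token: 0, 1 and 3.
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero(as_tuple=False).flatten()
assert expert_hit.tolist() == [0, 1, 3]

# For expert 1: which top-k slot selected it, and at which token positions.
top_indices, token_positions = torch.where(expert_mask[1])
assert token_positions.tolist() == [0, 1]  # both tokens use expert 1
assert top_indices.tolist() == [0, 1]      # as slot 0 for token 0, slot 1 for token 1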
@@ -557,38 +557,45 @@ class JambaMLP(nn.Module):
         return down_proj
 
 
-class JambaExperts(nn.ModuleList):
-    """
-    ModuleList of experts.
-    """
+class JambaExperts(nn.Module):
+    """Collection of expert weights stored as 3D tensors."""
 
     def __init__(self, config: JambaConfig):
         super().__init__()
-        self.top_k = config.num_experts_per_tok
         self.num_experts = config.num_local_experts
-        for _ in range(self.num_experts):
-            self.append(JambaMLP(config))
+        self.hidden_dim = config.hidden_size
+        self.intermediate_dim = config.intermediate_size
+        self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim))
+        self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim))
+        self.act_fn = ACT2FN[config.hidden_act]
 
     def forward(
-        self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor
+        self,
+        hidden_states: torch.Tensor,
+        top_k_index: torch.Tensor,
+        top_k_weights: torch.Tensor,
    ) -> torch.Tensor:
-        """
-        Args:
-            hidden_states: (batch_size * sequence_length, hidden_dim)
-            selected_experts: (batch_size * sequence_length, top_k)
-            routing_weights: (batch_size * sequence_length, top_k)
-        Returns:
-            (batch_size * sequence_length, hidden_dim)
-        """
         final_hidden_states = torch.zeros_like(hidden_states)
-        expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0)
-
-        expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
-        for expert_idx in expert_hit:
-            idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
-            current_state = hidden_states[None, top_x].reshape(-1, hidden_states.shape[-1])
-            current_hidden_states = self[expert_idx](current_state) * top_k_weights[top_x, idx, None]
-            final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
+        expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0)
+        expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero(as_tuple=False).flatten()
+        for expert_idx in expert_hit.tolist():
+            expert_selection = expert_mask[expert_idx].squeeze(0)
+            top_indices, token_positions = torch.where(expert_selection)
+            if token_positions.numel() == 0:
+                continue
+
+            current_state = hidden_states.index_select(0, token_positions)
+            gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2)
+            current_hidden_states = self.act_fn(up)
+            current_hidden_states = current_hidden_states * gate
+            current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])
+
+            routing_weights = top_k_weights[token_positions, top_indices].unsqueeze(-1)
+            current_hidden_states = current_hidden_states * routing_weights.to(current_hidden_states.dtype)
+            final_hidden_states.index_add_(0, token_positions, current_hidden_states.to(final_hidden_states.dtype))
+
         return final_hidden_states
 
 
@@ -132,9 +132,9 @@ class MiniMaxConfig(PreTrainedConfig):
         "layers.*.self_attn.v_proj": "colwise",
         "layers.*.self_attn.o_proj": "rowwise",
         "layers.*.block_sparse_moe.gate": "colwise_rep",  # we need to replicate here to correctly route experts
-        "layers.*.block_sparse_moe.experts.*.w1": "colwise",
-        "layers.*.block_sparse_moe.experts.*.w2": "rowwise",
-        "layers.*.block_sparse_moe.experts.*.w3": "colwise",
+        "layers.*.block_sparse_moe.experts.w1": "colwise",
+        "layers.*.block_sparse_moe.experts.w2": "rowwise",
+        "layers.*.block_sparse_moe.experts.w3": "colwise",
     }
     base_model_pp_plan = {
         "embed_tokens": (["input_ids"], ["inputs_embeds"]),
@@ -387,56 +387,45 @@ class MiniMaxAttention(nn.Module):
         return attn_output, attn_weights
 
 
-class MiniMaxMLP(nn.Module):
-    def __init__(self, config: MiniMaxConfig):
-        super().__init__()
-        self.ffn_dim = config.intermediate_size
-        self.hidden_dim = config.hidden_size
-        self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
-        self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)
-        self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
-        self.act_fn = ACT2FN[config.hidden_act]
-
-    def forward(self, hidden_states):
-        current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
-        current_hidden_states = self.w2(current_hidden_states)
-        return current_hidden_states
-
-
-class MiniMaxExperts(nn.ModuleList):
-    """
-    ModuleList of experts.
-    """
+class MiniMaxExperts(nn.Module):
+    """Collection of expert weights stored as 3D tensors."""
 
     def __init__(self, config: MiniMaxConfig):
         super().__init__()
-        self.top_k = config.num_experts_per_tok
-        self.num_experts = config.num_local_experts
-        for _ in range(self.num_experts):
-            self.append(MiniMaxMLP(config))
+        self.num_experts = config.num_local_experts
+        self.hidden_dim = config.hidden_size
+        self.intermediate_dim = config.intermediate_size
+        self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim))
+        self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim))
+        self.act_fn = ACT2FN[config.hidden_act]
 
     def forward(
-        self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor
+        self,
+        hidden_states: torch.Tensor,
+        top_k_index: torch.Tensor,
+        top_k_weights: torch.Tensor,
     ) -> torch.Tensor:
-        """
-        Args:
-            hidden_states: (batch_size * sequence_length, hidden_dim)
-            selected_experts: (batch_size * sequence_length, top_k)
-            routing_weights: (batch_size * sequence_length, top_k)
-        Returns:
-            (batch_size * sequence_length, hidden_dim)
-        """
         final_hidden_states = torch.zeros_like(hidden_states)
-        expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0)
-
-        expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
-        for expert_idx in expert_hit:
-            idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
-            current_state = hidden_states[None, top_x].reshape(-1, hidden_states.shape[-1])
-            current_hidden_states = self[expert_idx](current_state) * top_k_weights[top_x, idx, None]
-            final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
+        expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0)
+        expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero(as_tuple=False).flatten()
+        for expert_idx in expert_hit.tolist():
+            expert_selection = expert_mask[expert_idx].squeeze(0)
+            top_indices, token_positions = torch.where(expert_selection)
+            if token_positions.numel() == 0:
+                continue
+
+            current_state = hidden_states.index_select(0, token_positions)
+            gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2)
+            current_hidden_states = self.act_fn(up)
+            current_hidden_states = current_hidden_states * gate
+            current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])
+
+            routing_weights = top_k_weights[token_positions, top_indices].unsqueeze(-1)
+            current_hidden_states = current_hidden_states * routing_weights.to(current_hidden_states.dtype)
+            final_hidden_states.index_add_(0, token_positions, current_hidden_states.to(final_hidden_states.dtype))
+
         return final_hidden_states
 
 
@@ -61,18 +61,10 @@ class MixtralExperts(nn.Module):
         self.num_experts = config.num_local_experts
         self.hidden_dim = config.hidden_size
         self.intermediate_dim = config.intermediate_size
-        self.w1 = nn.Parameter(torch.empty(self.num_experts, self.intermediate_dim, self.hidden_dim))
-        self.w2 = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim))
-        self.w3 = nn.Parameter(torch.empty(self.num_experts, self.intermediate_dim, self.hidden_dim))
+        self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim))
+        self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim))
 
         self.act_fn = ACT2FN[config.hidden_act]
 
-    def reset_parameters(self, initializer_range: float):
-        nn.init.normal_(self.w1, mean=0.0, std=initializer_range)
-        nn.init.normal_(self.w2, mean=0.0, std=initializer_range)
-        nn.init.normal_(self.w3, mean=0.0, std=initializer_range)
-
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -91,11 +83,10 @@ class MixtralExperts(nn.Module):
                 continue
 
             current_state = hidden_states.index_select(0, token_positions)
-            current_hidden_states = nn.functional.linear(current_state, self.w1[expert_idx])
-            current_hidden_states = self.act_fn(current_hidden_states)
-            gate_hidden_states = nn.functional.linear(current_state, self.w3[expert_idx])
-            current_hidden_states = current_hidden_states * gate_hidden_states
-            current_hidden_states = nn.functional.linear(current_hidden_states, self.w2[expert_idx])
+            gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2)
+            current_hidden_states = self.act_fn(up)
+            current_hidden_states = current_hidden_states * gate
+            current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])
 
             routing_weights = top_k_weights[token_positions, top_indices].unsqueeze(-1)
             current_hidden_states = current_hidden_states * routing_weights.to(current_hidden_states.dtype)
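A sketch of how the old per-expert w1/w2/w3 tensors could be packed into the fused gate_up_proj/down_proj parameters introduced here. The stacking order is inferred from the rewritten forward pass (the activation is applied to the second chunk), silu stands in for the configured activation, and both choices are assumptions rather than the conversion the repository actually ships:

import torch

# Hypothetical per-expert checkpoint tensors in the old layout (sizes are made up).
num_experts, hidden, intermediate = 2, 8, 16
w1 = [torch.randn(intermediate, hidden) for _ in range(num_experts)]
w2 = [torch.randn(hidden, intermediate) for _ in range(num_experts)]
w3 = [torch.randn(intermediate, hidden) for _ in range(num_experts)]

# Fuse into the new 3D layout. Reading the rewritten forward pass at face value,
# the first half of gate_up_proj plays the role of the old w3 and the second
# half the old w1; if the actual conversion uses the opposite order, swap them.
gate_up_proj = torch.stack([torch.cat([w3[e], w1[e]], dim=0) for e in range(num_experts)])
down_proj = torch.stack(w2)

assert gate_up_proj.shape == (num_experts, 2 * intermediate, hidden)
assert down_proj.shape == (num_experts, hidden, intermediate)

# Equivalence check for one expert, assuming the chunk is over the output dim.
x = torch.randn(3, hidden)
gate, up = torch.nn.functional.linear(x, gate_up_proj[0]).chunk(2, dim=-1)
old = torch.nn.functional.silu(torch.nn.functional.linear(x, w1[0])) * torch.nn.functional.linear(x, w3[0])
new = torch.nn.functional.silu(up) * gate
assert torch.allclose(old, new, atol=1e-5)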
@@ -378,11 +369,6 @@ class MixtralPreTrainedModel(PreTrainedModel):
         "hidden_states": MixtralDecoderLayer,
         "attentions": MixtralAttention,
     }
-    def _init_weights(self, module):
-        super()._init_weights(module)
-        if isinstance(module, MixtralExperts):
-            initializer_range = getattr(self.config, "initializer_range", 0.02)
-            module.reset_parameters(initializer_range)
 
 
 @auto_docstring
@@ -259,12 +259,6 @@ class MixtralPreTrainedModel(MistralPreTrainedModel):
         "attentions": MixtralAttention,
     }
 
-    def _init_weights(self, module):
-        super()._init_weights(module)
-        if isinstance(module, MixtralExperts):
-            initializer_range = getattr(self.config, "initializer_range", 0.02)
-            module.reset_parameters(initializer_range)
-
 
 class MixtralModel(MistralModel):
     def forward(
@@ -265,13 +265,11 @@ class OlmoeAttention(nn.Module):
         return attn_output, attn_weights
 
 
-class OlmoeExperts(nn.ModuleList):
-    """
-    ModuleList of experts.
-    """
+class OlmoeExperts(nn.Module):
+    """Collection of expert weights stored as 3D tensors."""
 
     def __init__(self, config):
-        super().__init__()
+        nn.ModuleList.__init__(self)
         for _ in range(config.num_experts):
             self.append(OlmoeMLP(config))
         self.num_experts = config.num_experts
@@ -279,25 +277,32 @@ class OlmoeExperts(nn.ModuleList):
         self.norm_topk_prob = config.norm_topk_prob
 
     def forward(
-        self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor
+        self,
+        hidden_states: torch.Tensor,
+        top_k_index: torch.Tensor,
+        top_k_weights: torch.Tensor,
     ) -> torch.Tensor:
-        """
-        Args:
-            hidden_states: (batch_size * sequence_length, hidden_dim)
-            selected_experts: (batch_size * sequence_length, top_k)
-            routing_weights: (batch_size * sequence_length, top_k)
-        Returns:
-            (batch_size * sequence_length, hidden_dim)
-        """
         final_hidden_states = torch.zeros_like(hidden_states)
-        expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0)
-
-        expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
-        for expert_idx in expert_hit:
-            idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
-            current_state = hidden_states[None, top_x].reshape(-1, hidden_states.shape[-1])
-            current_hidden_states = self[expert_idx](current_state) * top_k_weights[top_x, idx, None]
-            final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
+        expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0)
+        expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero(as_tuple=False).flatten()
+        for expert_idx in expert_hit.tolist():
+            expert_selection = expert_mask[expert_idx].squeeze(0)
+            top_indices, token_positions = torch.where(expert_selection)
+            if token_positions.numel() == 0:
+                continue
+
+            current_state = hidden_states.index_select(0, token_positions)
+            gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2)
+            current_hidden_states = self.act_fn(up)
+            current_hidden_states = current_hidden_states * gate
+            current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])
+
+            routing_weights = top_k_weights[token_positions, top_indices].unsqueeze(-1)
+            current_hidden_states = current_hidden_states * routing_weights.to(current_hidden_states.dtype)
+            final_hidden_states.index_add_(0, token_positions, current_hidden_states.to(final_hidden_states.dtype))
+
         return final_hidden_states
 
 
@@ -260,37 +260,42 @@ class Qwen2MoeAttention(nn.Module):
         return attn_output, attn_weights
 
 
-class Qwen2MoeExperts(nn.ModuleList):
-    """
-    ModuleList of experts.
-    """
+class Qwen2MoeExperts(nn.Module):
+    """Collection of expert weights stored as 3D tensors."""
 
     def __init__(self, config):
-        super().__init__()
+        nn.ModuleList.__init__(self)
         self.num_experts = config.num_experts
         for _ in range(config.num_experts):
             self.append(Qwen2MoeMLP(config, intermediate_size=config.moe_intermediate_size))
 
     def forward(
-        self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor
+        self,
+        hidden_states: torch.Tensor,
+        top_k_index: torch.Tensor,
+        top_k_weights: torch.Tensor,
     ) -> torch.Tensor:
-        """
-        Args:
-            hidden_states: (batch_size * sequence_length, hidden_dim)
-            selected_experts: (batch_size * sequence_length, top_k)
-            routing_weights: (batch_size * sequence_length, top_k)
-        Returns:
-            (batch_size * sequence_length, hidden_dim)
-        """
         final_hidden_states = torch.zeros_like(hidden_states)
-        expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0)
-
-        expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
-        for expert_idx in expert_hit:
-            idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
-            current_state = hidden_states[None, top_x].reshape(-1, hidden_states.shape[-1])
-            current_hidden_states = self[expert_idx](current_state) * top_k_weights[top_x, idx, None]
-            final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
+        expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0)
+        expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero(as_tuple=False).flatten()
+        for expert_idx in expert_hit.tolist():
+            expert_selection = expert_mask[expert_idx].squeeze(0)
+            top_indices, token_positions = torch.where(expert_selection)
+            if token_positions.numel() == 0:
+                continue
+
+            current_state = hidden_states.index_select(0, token_positions)
+            gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2)
+            current_hidden_states = self.act_fn(up)
+            current_hidden_states = current_hidden_states * gate
+            current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])
+
+            routing_weights = top_k_weights[token_positions, top_indices].unsqueeze(-1)
+            current_hidden_states = current_hidden_states * routing_weights.to(current_hidden_states.dtype)
+            final_hidden_states.index_add_(0, token_positions, current_hidden_states.to(final_hidden_states.dtype))
+
        return final_hidden_states
 
 