Compare commits

14 Commits

Author SHA1 Message Date
3ee3c563dd Merge branch 'main' into siglip_and_check_model_changes 2025-10-30 15:10:47 +01:00
02c324f43f Fix: Gemma3TextConfig rope scaling assignments (#41934)
* Fix: Gemma3TextConfig rope scaling assignments

* Fix: type annotation for rope_parameters
2025-10-30 12:23:54 +00:00
b47b35637f Fix rope_parameters for gemma3 weights conversion script (#41922)
Fix rope_parameters for gemma3 weights conversion script.

Co-authored-by: Douglas Reid <21148125+douglas-reid@users.noreply.github.com>
2025-10-30 11:49:18 +00:00
f54d0db71d Merge branch 'main' into siglip_and_check_model_changes 2025-10-30 08:48:53 +01:00
40a9dc87d3 reorder/simplify 2025-10-30 08:47:35 +01:00
91d34b0a99 fix initialization 2025-10-29 18:30:22 +01:00
448dd635e3 attn support 2025-10-29 18:01:33 +01:00
807983c2a7 missing docstring 2025-10-29 17:40:21 +01:00
5aa7610d12 Merge branch 'siglip_and_check_model_changes' of github.com:huggingface/transformers into siglip_and_check_model_changes 2025-10-29 17:38:35 +01:00
fe7c9228a4 fix tests 2025-10-29 17:38:22 +01:00
082dcf21d1 Merge branch 'main' into siglip_and_check_model_changes 2025-10-29 17:14:58 +01:00
4f93734169 fixup 2025-10-29 17:14:07 +01:00
76a14c7008 correct inheritance + protect executorch 2025-10-29 17:13:43 +01:00
ca68be8560 handle inputs from non-automapped encoder layers 2025-10-29 10:47:30 +01:00
86 changed files with 1461 additions and 3164 deletions

View File

@@ -1,83 +0,0 @@
# coding=utf-8
# Copyright (C) 2025 the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .core_model_loading import Concatenate, MergeModulelist, WeightConverter
_checkpoint_conversion_mapping = {
"mixtral": [
WeightConverter(
source_keys=[
"block_sparse_moe.experts.*.w1.weight",
"block_sparse_moe.experts.*.w3.weight",
], # you give me a list of 2 keys, I collect a list of tensors
target_keys="mlp.experts.gate_up_proj", # target key gets the list of two tensors
operations=[
MergeModulelist(
dim=0
), # each process has two lists of tensors, we cat each list. -> we end up with 2 tensors
Concatenate(dim=1), # each process has 2 tensors, gate and up, we concat them into gate_up
], # the loader adds the shard operation here; sharding can't happen after concat/merge, so it must come first
),
WeightConverter(
source_keys=[
"block_sparse_moe.experts.*.w2.weight",
],
target_keys="mlp.experts.down_proj", # target key gets the list of two tensors
operations=[
MergeModulelist(
dim=0
), # each process has two lists of tensors, we cat each list. -> we end up with 2 tensors
], # the loader adds the shard operation here; sharding can't happen after concat/merge, so it must come first
),
# WeightConverter(
# ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"],
# "self_attn.qkv_proj",
# Concatenate(dim=0), # more like stack?
# ),
WeightConverter("*.block_sparse_moe.", "*.mlp."),
],
"qwen2_moe": [
WeightConverter(
source_keys=[
"mlp.experts.*.gate_proj.weight",
"mlp.experts.*.up_proj.weight",
],
target_keys="mlp.experts.gate_up_proj",
operations=[MergeModulelist(dim=0), Concatenate(dim=1)],
),
WeightConverter(
source_keys=["mlp.experts.*.down_proj.weight"],
target_keys="mlp.experts.down_proj",
operations=[MergeModulelist(dim=0)],
),
],
}
_checkpoint_conversion_mapping["phimoe"] = _checkpoint_conversion_mapping["mixtral"].copy()
_checkpoint_conversion_mapping["deepseek_v2"] = _checkpoint_conversion_mapping["qwen2_moe"].copy()
_checkpoint_conversion_mapping["deepseek_v3"] = _checkpoint_conversion_mapping["qwen2_moe"].copy()
_checkpoint_conversion_mapping["dot1"] = _checkpoint_conversion_mapping["qwen2_moe"].copy()
_checkpoint_conversion_mapping["ernie_4_5_moe"] = _checkpoint_conversion_mapping["qwen2_moe"].copy()
_checkpoint_conversion_mapping["glm4_moe"] = _checkpoint_conversion_mapping["qwen2_moe"].copy()
_checkpoint_conversion_mapping["glm4v_moe"] = _checkpoint_conversion_mapping["qwen2_moe"].copy()
_checkpoint_conversion_mapping["jamba"] = _checkpoint_conversion_mapping["qwen2_moe"].copy()
_checkpoint_conversion_mapping["lfm2_moe"] = _checkpoint_conversion_mapping["mixtral"].copy()
_checkpoint_conversion_mapping["long_cat_flash"] = _checkpoint_conversion_mapping["qwen2_moe"].copy()
_checkpoint_conversion_mapping["qwen3_moe"] = _checkpoint_conversion_mapping["qwen2_moe"].copy()
_checkpoint_conversion_mapping["qwen3_omni_moe"] = _checkpoint_conversion_mapping["qwen2_moe"].copy()
_checkpoint_conversion_mapping["qwen3_next"] = _checkpoint_conversion_mapping["qwen2_moe"].copy()
_checkpoint_conversion_mapping["qwen3_vl_moe"] = _checkpoint_conversion_mapping["qwen2_moe"].copy()
_checkpoint_conversion_mapping["hunyuan_v1_moe"] = _checkpoint_conversion_mapping["qwen2_moe"].copy()
_checkpoint_conversion_mapping["minimax"] = _checkpoint_conversion_mapping["mixtral"].copy()
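For reference, here is a minimal sketch (not part of the changeset) of what the "mixtral" entry above does to per-expert tensors, using plain PyTorch only; the expert count and shapes below are toy assumptions:

import torch

num_experts, hidden, inter = 4, 8, 16
# block_sparse_moe.experts.*.w1.weight and *.w3.weight, one tensor per expert
w1 = [torch.randn(inter, hidden) for _ in range(num_experts)]
w3 = [torch.randn(inter, hidden) for _ in range(num_experts)]
# MergeModulelist(dim=0): stack each list -> two tensors of shape (num_experts, inter, hidden)
merged = [torch.stack(group, dim=0) for group in (w1, w3)]
# Concatenate(dim=1): fuse gate and up -> mlp.experts.gate_up_proj of shape (num_experts, 2 * inter, hidden)
gate_up_proj = torch.cat(merged, dim=1)
assert gate_up_proj.shape == (num_experts, 2 * inter, hidden)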

View File

@@ -1,576 +0,0 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Core helpers for loading model checkpoints."""
from __future__ import annotations
import itertools
import os
import re
import threading
from abc import abstractmethod
from collections import defaultdict
from collections.abc import Sequence
from concurrent.futures import Future, ThreadPoolExecutor
from dataclasses import dataclass, field
from functools import partial
from typing import Any, Optional, Union
import torch
from torch.distributed.tensor import DTensor
from .integrations.finegrained_fp8 import Fp8Quantize
from .integrations.tensor_parallel import ALL_PARALLEL_STYLES
from .utils import logging
logger = logging.get_logger(__name__)
def _glob_to_regex_src(glob: str, *, digits_only: bool = True) -> str:
"""
Convert a glob with '*' into a regex *source* string.
'*' matches (\\d+) if digits_only else (.+). Inner groups are non-capturing.
"""
star = r"(\d+)" if digits_only else r"(.+)"
return re.escape(glob).replace(r"\*", star)
def build_glob_alt(
globs: list[str],
*,
digits_only: bool = True,
allow_prefix: bool = True,
) -> tuple[re.Pattern, dict[str, str]]:
"""
Build one compiled regex alternation with a named group per glob.
- digits_only: '*' => digits only (\\d+) if True, else any chars (.+)
- allow_prefix: if True, allow arbitrary prefix before the pattern
(keeps '$' so we still require a full suffix match)
Returns (compiled_regex, name->glob map).
"""
name_map: dict[str, str] = {}
parts: list[str] = []
# If we keep using .match(), we must handle prefix allowance in the pattern itself.
prefix_src = r".*" if allow_prefix else r"^"
for i, g in enumerate(globs):
name = f"g{i}"
name_map[name] = g
pat_src = _glob_to_regex_src(g, digits_only=digits_only)
# Each branch is fully wrapped and uniquely named.
parts.append(f"(?P<{name}>{prefix_src}{pat_src})")
alt_src = "|".join(parts)
return re.compile(alt_src), name_map
def match_glob(key: str, alt: re.Pattern, name_map: dict[str, str]) -> Optional[str]:
"""
Match the key against the alternation; return the original glob string that matched.
"""
m = alt.match(key)
if not m:
return None
return name_map.get(m.lastgroup)
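# Illustrative usage of the helpers above (not part of the original file); the key below is an assumed example:
#
#     alt, names = build_glob_alt(["model.layers.*.mlp.experts.*.gate_proj.weight"])
#     match_glob("model.layers.3.mlp.experts.7.gate_proj.weight", alt, names)
#     -> "model.layers.*.mlp.experts.*.gate_proj.weight"
#
# With digits_only=True (the default) each '*' only matches digits, so a key such as
# "model.layers.final.mlp.experts.7.gate_proj.weight" would not match.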
class ConversionOps:
"""Base class for weight conversion operations."""
# Reusable scratch buffer to avoid reallocations.
_buffer: Optional[torch.Tensor] = None
# The inverse operation class, will be used when saving the checkpoint
_inverse_op: type[ConversionOps]
def _ensure_buffer(
self,
required_shape: torch.Size,
*,
dtype: torch.dtype,
device: torch.device,
growth_factor: float = 1.5,
) -> torch.Tensor:
"""Ensure a pre-allocated buffer large enough for ``required_shape`` exists."""
required_elems = 1
for dim in required_shape:
required_elems *= int(dim)
need_new = (
self._buffer is None
or self._buffer.dtype != dtype
or self._buffer.device != device
or self._buffer.numel() < required_elems
)
if need_new:
capacity = max(required_elems, int(required_elems * growth_factor))
self._buffer = torch.empty(capacity, dtype=dtype, device=device)
return self._buffer[:required_elems].view(required_shape)
def clear_cache(self) -> None:
"""Free any cached buffers."""
self._buffer = None
@abstractmethod
def convert(self, value: Union[Sequence[torch.Tensor], torch.Tensor], *args, **kwargs) -> torch.Tensor:
raise NotImplementedError
class Chunk(ConversionOps):
"""Split a tensor along ``dim`` into equally sized chunks or using explicit ``sizes``."""
_inverse_op: type[ConversionOps]
def __init__(self, dim: int = 0, chunks: Optional[int] = None, sizes: Optional[Sequence[int]] = None):
if chunks is None and sizes is None:
raise ValueError("`chunks` or `sizes` must be provided for Chunk operations.")
if chunks is not None and chunks <= 0:
raise ValueError("`chunks` must be a strictly positive integer.")
self.dim = dim
self.chunks = chunks
self.sizes = list(sizes) if sizes is not None else None
self._inverse_op = Concatenate
def convert(self, value: torch.Tensor) -> list[torch.Tensor]:
if not isinstance(value, torch.Tensor):
raise TypeError("Chunk expects a torch.Tensor as input.")
if self.sizes is not None:
return list(torch.split(value, self.sizes, dim=self.dim))
return list(torch.chunk(value, self.chunks, dim=self.dim))
class Concatenate(ConversionOps):
"""Concatenate tensors along `dim` using a reusable buffer."""
_inverse_op: type[ConversionOps]
def __init__(self, dim: int = 0):
self.dim = dim
self._inverse_op = Chunk
@torch.no_grad
def convert(self, value: Sequence[torch.Tensor]) -> torch.Tensor:
if isinstance(value[0], list):
value = [v[0] for v in value]
tensors = value
if not tensors:
raise ValueError("Concatenate requires at least one tensor to concatenate.")
out_shape = list(tensors[0].shape)
out_shape[self.dim] = sum([t.size(self.dim) for t in tensors])
with torch.no_grad(): # we use staging buffers
out = self._ensure_buffer(torch.Size(out_shape), dtype=tensors[0].dtype, device=tensors[0].device)
torch.cat(tuple(tensors), dim=self.dim, out=out)
# offset = 0
# for tensor in tensors:
# index = [slice(None)] * tensor.ndim
# index[self.dim] = slice(offset, offset + tensor.shape[self.dim])
# out[tuple(index)].copy_(tensor, non_blocking=tensor.is_cuda)
# offset += tensor.shape[self.dim]
return out.clone()  # clone so the staging buffer can safely be reused for the next conversion
class MergeModulelist(Concatenate):
"""
Merge a list of tensors into a single tensor along the first dimension.
We explicitly define this because for EP or TP you want to make sure you know what you are doing!
"""
def __init__(self, dim: int = 0):
super().__init__(dim=dim)
self._inverse_op = SplitModulelist
def convert(self, value: Sequence[torch.Tensor]) -> list[torch.Tensor]:
merged = []
with torch.no_grad(): # we use staging buffers
for group in value:
if not isinstance(group, Sequence) or len(group) == 0:
raise ValueError("MergeModulelist requires non-empty sub-sequences.")
group = [k for k in group if k.ndim]
out_shape = list(group[0].shape)
out_shape.insert(self.dim, len(group))
out = self._ensure_buffer(torch.Size(out_shape), dtype=group[0].dtype, device=group[0].device)
torch.stack(tuple(group), dim=self.dim, out=out)
# for off, tensor in enumerate(group):
# out[off].copy_(tensor, non_blocking=tensor.is_cuda)
# torch.as_tensor(numpy.stack(batch))
merged.append(out.clone()) # TODO have a single staging tensor here as well!
return merged
class SplitModulelist(ConversionOps):
"""Inverse of :class:`MergeModulelist` using explicit split sizes per group."""
def __init__(self, sizes: Sequence[Sequence[int]], dim: int = 0):
if not isinstance(sizes, Sequence) or not all(isinstance(sub, Sequence) and sub for sub in sizes):
raise ValueError("`sizes` must be a sequence of non-empty sequences of integers.")
self.sizes = [list(sub) for sub in sizes]
self.dim = dim
self._inverse_op = MergeModulelist
def convert(self, value: Sequence[torch.Tensor], *, context: dict[str, Any]) -> list[list[torch.Tensor]]:
if not isinstance(value, Sequence):
raise TypeError("SplitModulelist expects a sequence of tensors.")
if len(value) != len(self.sizes):
raise ValueError("Number of tensors does not match the provided split specifications.")
result: list[list[torch.Tensor]] = []
for tensor, split_sizes in zip(value, self.sizes):
if not isinstance(tensor, torch.Tensor):
raise TypeError("SplitModulelist can only split torch.Tensor instances.")
splits = torch.split(tensor, split_sizes, dim=self.dim)
result.append(list(splits))
return result
class Cast(ConversionOps):
"""
Casts the tensor to a given dtype
"""
def __init__(self, dtype):
self.dtype = dtype
def convert(self, realized_value):
return realized_value.to(self.dtype)
class To(ConversionOps):
"""
Transfers the tensor to the provided device potentially using a stream?
if param_device == "disk":
if not is_safetensors:
disk_offload_index = offload_weight(param, param_name, disk_offload_folder, disk_offload_index)
elif not is_quantized or not hf_quantizer.param_needs_quantization(model, param_name):
if is_fsdp_enabled():
param_device = "cpu" if is_local_dist_rank_0() else "meta"
"""
def __init__(self, device):
self.device = device
def convert(self, realized_value):
with torch.device(self.device):
out = [[x[...] for x in inner] if isinstance(inner, list) else inner[...] for inner in realized_value]
return out
@dataclass(slots=True)
class WeightConverter:
r"""
A weight converter that acts on a pattern of source keys.
The source keys are collected and grouped according to the target keys.
Wildcards are matched as glob patterns, so be precise about what you want to match. For example:
`model.layers.*.experts.*` -> acts on the experts of every layer
{"model.layers.*.experts.*": []}
whereas an explicit index such as
`model.layers.1.experts.*` -> only acts on layer 1.
{"model.layers.1.experts.*": [], }
- source_keys: str | list[str] (wildcards '*' match digits)
- target_keys: str | list[str] | None
- distributed_operation / operations / quantization_operations are ALWAYS lists.
"""
source_keys: Union[str, list[str]]
target_keys: Optional[Union[str, list[str]]] = None
operations: list[ConversionOps] = field(default_factory=list, repr=False)
distributed_operation: dict[str, ConversionOps] = field(default_factory=dict, compare=False, repr=False)
quantization_operation: dict[str, ConversionOps] = field(default_factory=dict, compare=False, repr=False)
_compiled: tuple[tuple[str, re.Pattern], ...] = field(default_factory=tuple, compare=False, repr=False)
_regex_pat: tuple[re.Pattern, dict[str, str]] = field(default_factory=tuple, compare=False, repr=False)
def __post_init__(self):
if not isinstance(self.source_keys, list):
self.source_keys = [self.source_keys]
if not isinstance(self.target_keys, list):
if self.target_keys is None:
self.target_keys = self.source_keys
else:
self.target_keys = [self.target_keys]
self._regex_pat = build_glob_alt(self.source_keys)
def set_param_for_module(
model, k, v, meta_model_state_dict, empty_tensor, mismatch_keys, missing_keys, misc, distributed_operation
):
try:
module_path, _, param_name = k.rpartition(".")
module_obj = model.get_submodule(module_path) if module_path else model
param_value = v[0] if isinstance(v, list) else v[:]
ref = meta_model_state_dict.get(k, empty_tensor)
use_dtensor = hasattr(distributed_operation, "use_dtensor") and distributed_operation.use_dtensor
if not isinstance(param_value, torch.nn.Parameter):
if distributed_operation != {} and use_dtensor:
param_value = DTensor.from_local(
param_value,
distributed_operation.device_mesh,
distributed_operation.shard,
run_check=False,
shape=ref.size(),
stride=ref.stride(),
)
else:
pass  # TODO for "local" stuff, it will trigger mismatched keys, no?
param_value = torch.nn.Parameter(param_value, requires_grad=param_value.is_floating_point())
if ref is not None and ref.shape != param_value.shape:
mismatch_keys.add((k, param_value.shape, ref.shape))
if k in missing_keys:
missing_keys.remove(k)
setattr(module_obj, param_name, param_value)
except Exception as e:
misc[k] = f"{e} for {k} on {list(module_obj.state_dict().keys())}"
@dataclass(slots=True)
class ConversionEntry:
weight_converter: WeightConverter
collected_tensors: dict = field(default_factory=lambda: defaultdict(dict))
# Tune these to your storage:
GLOBAL_WORKERS = min(32, (os.cpu_count() or 8) * 2) # NVMe: 8-16; HDD/NFS: 2-4
PER_FILE_LIMIT = 4 # concurrent reads per file
def _materialize_copy(x):
# PyTorch: this runs in C and releases the GIL; good for threads.
return x[...]
def spawn_materialize(EXEC, _file_sems, file_id, t) -> Future:
sem = _file_sems[file_id]
def _job():
with sem:
return _materialize_copy(t)
return EXEC.submit(_job)
def spawn_tp_materialize(EXEC, _file_sems, file_id, t, sharding_method, empty_tensor, tensor_idx) -> Future:
sem = _file_sems[file_id]
def _job():
with sem:
return sharding_method.shard_tensor(t, empty_tensor, tensor_idx=tensor_idx)[0]
return EXEC.submit(_job)
def dot_natural_key(s: str):
parts = s.split(".")
for i, p in enumerate(parts):
# whole-segment digits -> int; otherwise leave as str
if p.isdigit():
parts[i] = int(p)
return parts
def convert_and_load_state_dict_in_model(
model,
state_dict,
weight_mapping,
tp_plan,
quantizer,
device_map=None,
keep_in_dtype=None,
device_mesh=None,
profile: bool = False,
):
"""
Convert a state dict according to a weight mapping (one WeightConverter per glob pattern),
collecting tensors per *layer instance* (the concrete indices captured from '*').
"""
tp_plan = tp_plan or {} # {glob_pattern: plan_obj_or_key}
device_map = device_map or {} # {exact_target_key: device}
keep_in_dtype = keep_in_dtype or {} # {glob_pattern: dtype}
weight_mapping = weight_mapping or {} # {glob_pattern: WeightConverter}
meta_model_state_dict = model.state_dict()
missing_keys = set(meta_model_state_dict.keys())
if model.config.tie_word_embeddings and "lm_head.weight" in missing_keys:
missing_keys.remove("lm_head.weight")
misc = {}
mismatch_keys = set()
unexpected_keys = set()
# Global executor + per-file semaphores
EXEC = ThreadPoolExecutor(max_workers=GLOBAL_WORKERS)
_file_sems = defaultdict(lambda: threading.Semaphore(PER_FILE_LIMIT))
_patterns = list(itertools.chain.from_iterable([k.source_keys for k in weight_mapping]))
source_to_target = {sk: k for k in weight_mapping for sk in k.source_keys}
weight_pattern_alt, weight_pattern_by_group_name = build_glob_alt(_patterns)
tp_plan_alt, tp_plan_by_group_name = build_glob_alt(list(tp_plan.keys()))
dtype_policy_alt, dtype_policy_by_group_name = build_glob_alt(list(keep_in_dtype.keys()))
state_dict = sorted(state_dict.items(), key=lambda kv: dot_natural_key(kv[0]))
# 1. Create the conversion entries
by_conversion_pattern: dict[str, ConversionEntry] = {}
for original_key, (file_id, tensor) in state_dict:
matched_pattern = match_glob(original_key, weight_pattern_alt, weight_pattern_by_group_name)
if matched_pattern is not None:
converter = source_to_target[matched_pattern]  # TODO make sure it's the ref
sub_with_extractor = partial(re.sub, _glob_to_regex_src(matched_pattern), string=original_key)
entry_key = "|".join(converter.target_keys)
target_key = "|".join(map(sub_with_extractor, [k.replace("*", "\\1") for k in converter.target_keys]))
entry: ConversionEntry = by_conversion_pattern.setdefault(entry_key, ConversionEntry(converter))
converter_key = sub_with_extractor(matched_pattern)
else:
converter = WeightConverter(original_key)
converter_key = entry_key = target_key = original_key
entry = by_conversion_pattern.setdefault(converter_key, ConversionEntry(converter))
prefix = model.base_model_prefix
new_target_key = []
for t in target_key.split("|"): # let's correct the keys
if t.startswith(prefix) and meta_model_state_dict.get(t.replace(f"{prefix}.", "")) is not None:
t = t.replace(f"{prefix}.", "")
elif meta_model_state_dict.get(f"{prefix}.{t}") is not None:
t = f"{prefix}.{t}"
new_target_key.append(t)
target_key = "|".join(new_target_key)
for t in target_key.split("|"):
empty_tensor = meta_model_state_dict.get(t)
if empty_tensor is None:
unexpected_keys.add(t)
continue
if quantizer is not None and quantizer.param_needs_quantization(model, t):
    if quantizer.__class__.__name__ == "FineGrainedFP8HfQuantizer":
        converter.quantization_operation[t] = Fp8Quantize()  # TODO support other methods
    else:
        raise ValueError("This quantization method is gonna be supported SOOOON")
first_target_key = target_key.split("|")[0]
future = None
if device_mesh:
if matched_tp_pattern := match_glob(first_target_key, tp_plan_alt, tp_plan_by_group_name):
empty_tensor = meta_model_state_dict.get(first_target_key)
if getattr(converter, "distributed_operation", {}) == {}:
converter.distributed_operation = ALL_PARALLEL_STYLES[model.tp_plan[matched_tp_pattern]]
converter.distributed_operation.device_mesh = device_mesh
converter.distributed_operation.rank = device_map[""].index
converter.distributed_operation.empty_tensor = empty_tensor.clone()
shard_index = len(entry.collected_tensors[target_key].get(converter_key, []))
future = spawn_tp_materialize(
EXEC, _file_sems, file_id, tensor, converter.distributed_operation, empty_tensor, shard_index
)
if future is None: # If not TP, async move tensors
future = spawn_materialize(EXEC, _file_sems, file_id, tensor)
entry.collected_tensors[target_key].setdefault(converter_key, []).append(future)
# 2. Actually convert the ckpt
inverse_converters = {}
keys = list(by_conversion_pattern.keys())
total_layers = sum(len(by_conversion_pattern[key].collected_tensors) for key in keys)
progress_bar = logging.tqdm(total=total_layers, desc="Converting weights", leave=False) if total_layers else None
try:
for key in keys[::-1]:  # reverse so that simple (un-converted) keys are processed first
group = by_conversion_pattern.pop(key)
converter = group.weight_converter
operations = converter.operations if isinstance(converter.operations, list) else [converter.operations]
for layer_name, tensors_for_this_layer in group.collected_tensors.items():
concrete_target_keys = layer_name.split("|")
if bool(set(concrete_target_keys) - unexpected_keys):
values = [[k.result() for k in inner] for inner in tensors_for_this_layer.values()]
for op in operations:
try:
values = op.convert(values)
except Exception as e:
misc[layer_name] = (
f"{e}\nError: {op.__class__.__name__} on tensors collected from {converter.source_keys}. Ckpt contains: {values}"
)
values = [values] if not isinstance(values, list) else values
realized_value = {k: t for k, t in zip(concrete_target_keys, values) if k not in unexpected_keys}
for k in list(realized_value.keys()).copy():
if op := converter.quantization_operation.get(k):
try:
realized_value.update(
op.convert({k: realized_value.pop(k)}, quant_config=quantizer.quantization_config)
)
except Exception as e:
misc[layer_name] = f"{op.__class__.__name__}: {e}"
if progress_bar is not None:
progress_bar.set_postfix_str(layer_name, refresh=False)
progress_bar.update()
for k, output_value in realized_value.items():
matched_dtype_pattern = match_glob(k, dtype_policy_alt, dtype_policy_by_group_name)
if matched_dtype_pattern is not None:
op = Cast(keep_in_dtype[matched_dtype_pattern])
output_value = op.convert(output_value)
for src in converter.source_keys: # what should happen to k when we meet k at saving
inverse_converters[k] = {src: converter}
set_param_for_module(
model,
k,
output_value,
meta_model_state_dict,
empty_tensor,
mismatch_keys,
missing_keys,
misc,
converter.distributed_operation,
)
del group
for op in operations:
op.clear_cache()
finally:
if progress_bar is not None:
progress_bar.close()
model.inverse_converters = inverse_converters
EXEC.shutdown(wait=True)
return missing_keys, unexpected_keys, mismatch_keys, misc
# TODO this is not done yet!
def revert_weight_conversion(model, state_dict):
reverse_key_mapping = getattr(model, "inverse_converters", {})
original_state_dict = {}
for key, value in state_dict.items():
for pattern, inverse_converter in reverse_key_mapping.items():
# TODO FIXME you name it
replacement = inverse_converter.lstrip("^") # strip off un-needed chars and patterns
replacement = re.sub(r"\(.*\)", "", replacement)
key, n_replace = re.subn(pattern, replacement, key)
# Early exit of the loop
if n_replace > 0:
break
original_state_dict[key] = value
state_dict = original_state_dict
return state_dict
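As a side note, here is a minimal, self-contained sketch of the loading pattern used above: a global thread pool combined with a per-file semaphore that caps concurrent reads per checkpoint shard. The worker count, the limits, and the stand-in read function are assumptions for illustration:

import threading
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor

GLOBAL_WORKERS = 16   # assumed; tune to your storage
PER_FILE_LIMIT = 4    # max concurrent reads per shard file

executor = ThreadPoolExecutor(max_workers=GLOBAL_WORKERS)
file_sems = defaultdict(lambda: threading.Semaphore(PER_FILE_LIMIT))

def read_tensor(file_id, key):
    # stand-in for the real materialization (e.g. slicing a lazily-mapped safetensors tensor)
    return f"{file_id}:{key}"

def spawn_read(file_id, key):
    sem = file_sems[file_id]
    def _job():
        with sem:  # at most PER_FILE_LIMIT jobs touch this file at once
            return read_tensor(file_id, key)
    return executor.submit(_job)

futures = [spawn_read(f"shard-{i % 2}", f"weight.{i}") for i in range(8)]
results = [f.result() for f in futures]
executor.shutdown(wait=True)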

View File

@@ -1635,12 +1635,7 @@ class GenerationMixin(ContinuousMixin):
# TransformersKwargs are model-agnostic attention and generation arguments such as 'output_attentions'
for key, value in model_kwargs.items():
if (
value is not None
and key not in model_args
and key not in TransformersKwargs.__optional_keys__
and key != "debug_io"
):
if value is not None and key not in model_args and key not in TransformersKwargs.__optional_keys__:
unused_model_args.append(key)
if unused_model_args:

View File

@@ -512,8 +512,10 @@ def accelerate_disk_offload(
checkpoint_files,
device_map,
checkpoint_keys,
key_renaming_mapping,
sharded_metadata,
dtype,
reverse_key_renaming_mapping,
):
disk_only_shard_files = []
if disk_offload_folder is not None:
@@ -532,13 +534,19 @@
weight_map = dict.fromkeys(checkpoint_keys, checkpoint_files[0])
else:
folder = os.path.sep.join(checkpoint_files[0].split(os.path.sep)[:-1])
# Fix the weight map keys according to the key mapping
weight_map = {
key_renaming_mapping[k]: v
for k, v in sharded_metadata["weight_map"].items()
if k in key_renaming_mapping
}
weight_map = {k: os.path.join(folder, v) for k, v in weight_map.items()}
# Find potential checkpoints containing only offloaded weights
disk_only_shard_files = get_disk_only_shard_files(device_map, weight_map)
disk_offload_index = {
name: {
"safetensors_file": file,
"weight_name": name,
"weight_name": reverse_key_renaming_mapping[name],
"dtype": str_dtype,
}
for name, file in weight_map.items()

View File

@@ -13,11 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from collections.abc import Sequence
from typing import Any, Optional, Union
from typing import Optional
from ..core_model_loading import ConversionOps
from ..utils import is_accelerate_available, is_torch_accelerator_available, is_torch_available, logging
@@ -33,18 +30,6 @@ if is_accelerate_available():
logger = logging.get_logger(__name__)
try:
_FP8_DTYPE = torch.float8_e4m3fn
_FP8_MIN = torch.finfo(_FP8_DTYPE).min
_FP8_MAX = torch.finfo(_FP8_DTYPE).max
_FP8_IS_INT = False
except AttributeError:
_FP8_DTYPE = torch.int8
_FP8_MIN, _FP8_MAX = -127, 127
_FP8_IS_INT = True
logger.warning_once(
"torch.float8_e4m3fn not available; falling back to int8 emulation for Fp8Quantize operations."
)
# Copied from https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/main/inference/kernel.py
@@ -347,12 +332,6 @@ class FP8Linear(nn.Linear):
if self.weight.element_size() > 1:
return F.linear(input, self.weight, self.bias)
else:
if isinstance(self.weight, torch.distributed.tensor.DTensor):
weight = self.weight._local_tensor.contiguous()
scale_inv = self.weight_scale_inv._local_tensor.contiguous()
else:
weight = self.weight
scale_inv = self.weight_scale_inv
# Context manager used to switch among the available accelerators
device_type = torch.accelerator.current_accelerator().type if is_torch_accelerator_available() else "cuda"
torch_accelerator_module = getattr(torch, device_type, torch.cuda)
@@ -360,9 +339,9 @@
qinput, scale = act_quant(input, self.block_size[1])
output = w8a8_block_fp8_matmul_triton(
qinput,
weight,
self.weight,
scale,
scale_inv,
self.weight_scale_inv,
self.block_size,
output_dtype=input.dtype,
)
@@ -374,120 +353,6 @@
return output.to(dtype=input.dtype)
def _ceil_div(a, b):
return (a + b - 1) // b
class FP8Expert(nn.Module):
dtype = torch.float8_e4m3fn
def __init__(self, config, block_size, device):
super().__init__()
from ..activations import ACT2FN
self.block_size = block_size
self.num_experts = config.num_local_experts
self.hidden_dim = config.hidden_size
self.intermediate_dim = config.intermediate_size
Wg_out, Wg_in = 2 * self.intermediate_dim, self.hidden_dim
Wd_out, Wd_in = self.hidden_dim, self.intermediate_dim
self.gate_up_proj = nn.Parameter(
torch.empty(self.num_experts, Wg_out, Wg_in, dtype=FP8Expert.dtype, device=device)
)
self.down_proj = nn.Parameter(
torch.empty(self.num_experts, Wd_out, Wd_in, dtype=FP8Expert.dtype, device=device)
)
# Create inverse scale tiles only when using 1-byte types (fp8)
if self.gate_up_proj.element_size() == 1:
bo, bi = self.block_size
# gate_up tiles: ceil(Wg_out/bo) x ceil(Wg_in/bi)
gu_scale_o = _ceil_div(Wg_out, bo)
gu_scale_i = _ceil_div(Wg_in, bi)
self.gate_up_proj_scales_inv = nn.Parameter(
torch.empty(self.num_experts, gu_scale_o, gu_scale_i, dtype=torch.float32, device=device)
)
# down tiles: ceil(Wd_out/bo) x ceil(Wd_in/bi)
dp_scale_o = _ceil_div(Wd_out, bo)
dp_scale_i = _ceil_div(Wd_in, bi)
self.down_proj_scales_inv = nn.Parameter(
torch.empty(self.num_experts, dp_scale_o, dp_scale_i, dtype=torch.float32, device=device)
)
else:
# Match FP8Linear behavior when not using 1-byte weights
self.register_parameter("gate_up_proj_scale_inv", None)
self.register_parameter("down_proj_scale_inv", None)
# (Optional) bias per projection — many MoEs omit bias; keep None to match your FP8Linear default
self.register_parameter("gate_up_bias", None)
self.register_parameter("down_bias", None)
# Activation used in the MLP (same as your config / ACT2FN)
# Keep a handle here; actual usage happens in forward of your MoE block
self.act_fn = ACT2FN[config.hidden_act]
def forward(
self,
hidden_states: torch.Tensor,
top_k_index: torch.Tensor,
top_k_weights: torch.Tensor,
) -> torch.Tensor:
final_hidden_states = torch.zeros_like(hidden_states)
num_experts = top_k_weights.shape[1]
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1).permute(2, 1, 0)
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
for expert_idx in expert_hit:
expert_idx = expert_idx[0]
if expert_idx == num_experts:
continue
with torch.no_grad():
_, token_idx = torch.where(expert_mask[expert_idx])
# current_state = hidden_states[token_idx]
current_state = hidden_states.index_select(0, token_idx)
gate, up = self.linear(
current_state, self.gate_up_proj[expert_idx], self.gate_up_proj_scales_inv[expert_idx]
).chunk(2, dim=-1)
current_hidden_states = self.act_fn(gate) * up
current_hidden_states = self.linear(
current_hidden_states, self.down_proj[expert_idx], self.down_proj_scales_inv[expert_idx]
)
routing_weights = top_k_weights[token_idx, expert_idx].unsqueeze(-1)
current_hidden_states = current_hidden_states * routing_weights.to(current_hidden_states.dtype)
final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))
return final_hidden_states
def linear(self, input: torch.Tensor, weight: torch.Tensor, weight_scale_inv: torch.Tensor) -> torch.Tensor:
if weight.element_size() > 1:
return F.linear(input, weight, self.bias)
else:
# Context manager used to switch among the available accelerators
device_type = torch.accelerator.current_accelerator().type if is_torch_accelerator_available() else "cuda"
torch_accelerator_module = getattr(torch, device_type, torch.cuda)
with torch_accelerator_module.device(input.device):
qinput, scale = act_quant(input, self.block_size[1])
output = w8a8_block_fp8_matmul_triton(
qinput,
weight,
scale,
weight_scale_inv,
self.block_size,
output_dtype=input.dtype,
)
# Blocks the CPU until all accelerator operations on the specified device are complete. It is used to ensure that the results of the
# preceding operations are ready before proceeding
torch_accelerator_module.synchronize()
return output.to(dtype=input.dtype)
# TODO: we do need this.... but not recursive...
def _replace_with_fp8_linear(
model,
tp_plan=None,
@@ -496,48 +361,40 @@ def _replace_with_fp8_linear(
quantization_config=None,
has_been_replaced=False,
):
iterator = list(model.named_parameters()).copy()
for name, empty_tensor in iterator:
current_key_name = name
name = name.rsplit(".", 1)[0] if "." in name else name
module = model.get_submodule(name)
"""Replace Linear layers with FP8Linear."""
if current_key_name is None:
current_key_name = []
current_key_name_str = re.sub(r"\d+", "*", current_key_name)
if not any(key in current_key_name_str for key in (modules_to_not_convert or [])):
with init_empty_weights():
if (
"gate_up_proj" in current_key_name
or "down_proj" in current_key_name
and "experts" in current_key_name
): # Experts!
in_features = empty_tensor.size(-2)
out_features = empty_tensor.size(-1)
model.set_submodule(
name,
FP8Expert(
config=model.config,
block_size=quantization_config.weight_block_size,
device=empty_tensor.device,
),
)
for name, module in model.named_children():
current_key_name.append(name)
elif isinstance(module, nn.Linear):
in_features = module.in_features
out_features = module.out_features
model.set_submodule(
name,
FP8Linear(
in_features=in_features,
out_features=out_features,
bias=module.bias is not None,
device=module.weight.device,
dtype=module.weight.dtype,
activation_scheme=quantization_config.activation_scheme,
block_size=quantization_config.weight_block_size,
),
if isinstance(module, nn.Linear) and name not in (modules_to_not_convert or []):
current_key_name_str = ".".join(current_key_name)
if not any(key in current_key_name_str for key in (modules_to_not_convert or [])):
with init_empty_weights():
model._modules[name] = FP8Linear(
in_features=module.in_features,
out_features=module.out_features,
bias=module.bias is not None,
device=module.weight.device,
dtype=module.weight.dtype,
activation_scheme=quantization_config.activation_scheme,
block_size=quantization_config.weight_block_size,
)
has_been_replaced = True
# when changing a layer the TP PLAN for that layer should be updated. TODO
has_been_replaced = True
# when changing a layer the TP PLAN for that layer should be updated. TODO
if len(list(module.children())) > 0:
_, has_been_replaced = _replace_with_fp8_linear(
module,
tp_plan,
modules_to_not_convert,
current_key_name,
quantization_config,
has_been_replaced=has_been_replaced,
)
current_key_name.pop(-1)
return model, has_been_replaced
@@ -548,7 +405,7 @@ def replace_with_fp8_linear(
quantization_config=None,
):
"""Helper function to replace model layers with FP8 versions."""
modules_to_not_convert += ["lm_head"]
modules_to_not_convert = ["lm_head"] if modules_to_not_convert is None else modules_to_not_convert
if quantization_config.modules_to_not_convert is not None:
modules_to_not_convert.extend(quantization_config.modules_to_not_convert)
@@ -567,133 +424,3 @@
)
return model
class QuantizationOp(ConversionOps):
"""Base class for quantization operations."""
pass
class Fp8Quantize(QuantizationOp):
"""
A quantization operation that creates two tensors, weight and scale out of a weight.
"""
_inverse_op: type[ConversionOps]
def __init__(self, block_size: Optional[tuple[int, int]] = None):
self.block_size = block_size
self._inverse_op = Fp8Dequantize
def convert(self, input_dict: torch.Tensor, *, quant_config: dict[str, Any]) -> dict[str, torch.Tensor]:
# Unpack single key/value (value may be wrapped in a list)
target_keys, value = tuple(input_dict.items())[0]
value = value[0] if isinstance(value, list) else value
# Resolve block size (support dict-like or attr-like quant_config)
block_size = None
if quant_config is not None:
if isinstance(quant_config, dict):
block_size = quant_config.get("weight_block_size")
else:
block_size = getattr(quant_config, "weight_block_size", None)
if block_size is None:
block_size = (value.shape[-2], value.shape[-1])
block_m, block_n = block_size
rows, cols = value.shape[-2], value.shape[-1]
# Enforce exact tiling like your original
if rows % block_m != 0 or cols % block_n != 0:
raise ValueError(
f"Matrix dimensions ({rows}, {cols}) must be divisible by block sizes ({block_m}, {block_n}). for {target_keys}"
)
# Leading dims can be empty (2D) or include num_experts/... (3D+)
leading_shape = value.shape[:-2]
rows_tiles = rows // block_m
cols_tiles = cols // block_n
original_shape = value.shape
value_fp32 = value.to(torch.float32)
# Reshape to (..., rows_tiles, block_m, cols_tiles, block_n)
reshaped = value_fp32.reshape(*leading_shape, rows_tiles, block_m, cols_tiles, block_n)
# Per-tile max-abs over the block dims
# dims: block_m is at -3, block_n is at -1 after the reshape
max_abs = reshaped.abs().amax(dim=(-3, -1))
safe_max_abs = torch.where(max_abs > 0, max_abs, torch.ones_like(max_abs))
# Tile scale (we store inverse scale like your Linear: weight_scale_inv)
scales = _FP8_MAX / safe_max_abs
scales = torch.where(max_abs > 0, scales, torch.ones_like(scales)) # keep zeros stable
# Broadcast scales back over the block dims and quantize
# max_abs/scales shape: (..., rows_tiles, cols_tiles)
scales_broadcast = scales.unsqueeze(-1).unsqueeze(-3) # -> (..., rows_tiles, 1, cols_tiles, 1)
scaled = reshaped * scales_broadcast
if _FP8_IS_INT:
quantized = torch.clamp(scaled.round(), min=_FP8_MIN, max=_FP8_MAX).to(_FP8_DTYPE)
else:
quantized = torch.clamp(scaled, min=_FP8_MIN, max=_FP8_MAX).to(_FP8_DTYPE)
quantized = quantized.reshape(original_shape)
inv_scales = (1.0 / scales).to(torch.float32) # shape: (*leading, rows_tiles, cols_tiles)
if target_keys.endswith("weight"):
scale_key = target_keys.rsplit(".", 1)[0] + ".weight_scale_inv"
else:
scale_key = target_keys + "_scales_inv"
# Return both quantized weights and per-tile inverse scales (keeps leading dims, e.g., num_experts)
return {
target_keys: quantized,
scale_key: inv_scales,
}
class Fp8Dequantize(QuantizationOp):
"""Inverse operation of :class:`Fp8Quantize`. Takes a pair (weight, scale) and reconstructs the fp32 tensor."""
def __init__(self, block_size: Optional[tuple[int, int]] = None):
self.block_size = block_size
self._inverse_op = Fp8Quantize
def convert(
self,
value: Union[Sequence[torch.Tensor], dict[str, torch.Tensor]],
*,
context: dict[str, Any],
) -> torch.Tensor:
if isinstance(value, dict):
tensors = list(value.values())
else:
tensors = list(value) if isinstance(value, Sequence) else [value]
if len(tensors) != 2:
raise ValueError("Fp8Dequantize expects exactly two tensors: quantized weights and scales.")
quantized, scales = tensors
if not isinstance(quantized, torch.Tensor) or not isinstance(scales, torch.Tensor):
raise TypeError("Fp8Dequantize expects tensors as inputs.")
quantized_fp32 = quantized.to(torch.float32)
rows, cols = quantized_fp32.shape[-2:]
block_size = self.block_size
if block_size is None:
quant_config = context.get("quantization_config")
block_size = getattr(quant_config, "weight_block_size", None)
if block_size is None:
block_size = (rows, cols)
block_m, block_n = block_size
if rows % block_m != 0 or cols % block_n != 0:
raise ValueError(
f"Matrix dimensions ({rows}, {cols}) must be divisible by block sizes ({block_m}, {block_n})."
)
reshaped = quantized_fp32.reshape(-1, rows // block_m, block_m, cols // block_n, block_n)
expanded_scales = scales.to(torch.float32).reshape(-1, rows // block_m, cols // block_n)
expanded_scales = expanded_scales.unsqueeze(-1).unsqueeze(2)
dequantized = reshaped * expanded_scales
return dequantized.reshape(quantized_fp32.shape)
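For intuition, here is a self-contained sketch of the tile-wise scheme that `Fp8Quantize` and `Fp8Dequantize` implement above: per-tile max-abs scaling, with inverse scales stored next to the weight. The 2x2 block, the 4x4 weight, and the hard-coded 448.0 maximum (the finite max of `torch.float8_e4m3fn`) are toy assumptions, and the actual cast to float8 is only indicated by a comment so the snippet runs on any build:

import torch

QMAX = 448.0  # assumed: torch.finfo(torch.float8_e4m3fn).max
block_m, block_n = 2, 2
w = torch.randn(4, 4)                                     # dimensions must be divisible by the block size
tiles = w.reshape(4 // block_m, block_m, 4 // block_n, block_n)
max_abs = tiles.abs().amax(dim=(-3, -1)).clamp(min=1e-12)  # per-tile max magnitude
scales = QMAX / max_abs                                    # one scale per tile
q = (tiles * scales.unsqueeze(-1).unsqueeze(-3)).clamp(-QMAX, QMAX).reshape(4, 4)  # cast to float8 here when available
weight_scale_inv = 1.0 / scales                            # stored alongside the weight as ".weight_scale_inv"
deq = (q.reshape(2, block_m, 2, block_n) * weight_scale_inv.unsqueeze(-1).unsqueeze(-3)).reshape(4, 4)
assert torch.allclose(deq, w, atol=1e-5)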

View File

@@ -236,7 +236,7 @@ class PeftAdapterMixin:
**adapter_kwargs,
)
peft_config.inference_mode = not is_trainable
# TODO: WE NEED TOO APPLY OUR DYNAMIC WEIGHT CONVERSION AT SOME POINT HERE!
# Create and add fresh new adapters into the model.
inject_adapter_in_model(peft_config, self, adapter_name, **peft_load_kwargs)

View File

@@ -18,7 +18,6 @@ import operator
import os
import re
from functools import partial, reduce
from typing import Optional
import torch
import torch.distributed as dist
@@ -307,7 +306,7 @@ def repack_weights(
return final_ordered_tensor
def get_tensor_shard(param, empty_param, device_mesh, rank, dim, tensor_idx: Optional[int] = None):
def get_tensor_shard(param, empty_param, device_mesh, rank, dim):
"""
Generalized tensor sharding across a multi-dimensional device mesh.
Extract only the fraction of the parameter owned by the given `rank` when the parameter is sharded along the provided `dim`.
@@ -359,57 +358,32 @@ def get_tensor_shard(param, empty_param, device_mesh, rank, dim, tensor_idx: Opt
rank (int): Global rank of the current process/device.
dim (int): Dimension along which to shard the tensor.
"""
param_dim = empty_param.ndim
param_dim = empty_param.dim()
if dim < 0:
dim = param_dim + dim
if dim >= param_dim:
raise ValueError(f"dim {dim} is out of bounds for tensor of dimension {param_dim}")
# Flatten the mesh to get the total number of devices
mesh_shape = device_mesh.shape
world_size = reduce(operator.mul, mesh_shape)
if dim < 0:
dim = param_dim + dim
if empty_param.dim() == 3 and dim == 1 and len(param.get_shape()) == 2:
dim = 0
elif empty_param.dim() == 3 and dim == 2 and len(param.get_shape()) == 2:
dim = 0
shard_size = math.ceil(empty_param.size(dim) / world_size)
start = rank * shard_size
end = min(start + shard_size, empty_param.size(dim))
if dim >= param_dim:
raise ValueError(f"dim {dim} is out of bounds for tensor of dimension {param_dim}")
if rank >= world_size:
raise ValueError(f"Rank {rank} is out of bounds for mesh size {world_size}")
# We have the full tensor here, not just one shard of it. In that case we assume the weight was saved
# properly; under TP, a colwise layer should not use this path but rather be declared packed_colwise,
# which tells the loader to read from a packed tensor (this also covers the module-list case).
# Here we handle potential chunking / layer splits. The only "hard" case is collecting q, k, v and
# merging them into qkv: we still shard along dim=0, so nothing changes. The remaining special case is
# when the empty param has 3 dims and the shard dim is 0 -> the whole tensor is placed on a specific
# rank (selected via the input tensor_idx).
dimensions = param.get_shape()
shard_size = math.ceil(empty_param.shape[dim] / world_size)
start = rank * shard_size
if empty_param.dim() == 3 and dim == 0 and len(param.get_shape()) == 2:
# special case we don't "shard" just send this entire tensor to the correct rank.
if start <= tensor_idx < end:
# this tensor does need to be materialized on this device:
return param[:]
else:
return torch.empty([], dtype=torch.int64, device=rank)
slice_indices = [slice(None)] * len(param.get_shape())
if start < param.get_shape()[dim]:
# Construct slicing index dynamically
end = min(start + shard_size, empty_param.shape[dim])
slice_indices = [slice(None)] * param_dim
if start < empty_param.shape[dim]:
slice_indices[dim] = slice(start, end)
param = param[tuple(slice_indices)]
if isinstance(param, list): # TODO handle the modulelist case!
param = [p[:] for p in param]
return param
return param[tuple(slice_indices)]
dimensions = list(param.shape)
dimensions[dim] = 0
return torch.empty(tuple(dimensions), dtype=torch.int64) # empty allocates memory....
return torch.empty(tuple(dimensions), dtype=torch.int64)
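# Worked example (illustrative, not in the original file): sharding a dim-0 size of 10 across a flattened
# mesh of world_size = 4 gives shard_size = ceil(10 / 4) = 3, so
#   rank 0 -> rows [0:3], rank 1 -> rows [3:6], rank 2 -> rows [6:9], rank 3 -> rows [9:10],
# and a rank whose start index falls past the end of the dimension receives the empty placeholder tensor above.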
def distribute_module(
@@ -436,8 +410,6 @@ class TensorParallelLayer:
"""
use_dtensor = True
device_mes = None
rank = None
@staticmethod
def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh): ...
@@ -565,9 +537,6 @@ class ReplicateParallel(TensorParallelLayer):
def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
return outputs.to_local() if use_local_output and isinstance(outputs, DTensor) else outputs
def shard_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
return param[...].to(param_casting_dtype)
def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
param = param[...].to(param_casting_dtype)
if to_contiguous:
@@ -609,25 +578,17 @@ class ColwiseParallel(TensorParallelLayer):
input_tensor = input_tensor.redistribute(placements=desired_input_layouts, async_op=False)
return input_tensor
def shard_tensor(self, param, empty_param, param_type=None, tensor_idx=None):
device_mesh = self.device_mesh
rank = self.rank
if param_type == "bias":
parameter = get_tensor_shard(param, empty_param, device_mesh, rank, -1, tensor_idx)
shard = [Shard(-1)]
else:
shard = [Shard(-2)]
parameter = get_tensor_shard(param, empty_param, device_mesh, rank, -2, tensor_idx)
self.shard = shard
return parameter, shard
def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
# colwise shard weight/bias to Shard(0), weight be Shard(-2) (0 if you have 1 dim only)
# means Colwise as Linear is input * weight^T + bias, where
# weight would become Shard(1)
parameter, shard = self.shard_tensor(
param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh
)
if param_type == "bias":
parameter = get_tensor_shard(param, empty_param, device_mesh, rank, -1)
shard = [Shard(-1)]
else:
shard = [Shard(-2)]
parameter = get_tensor_shard(param, empty_param, device_mesh, rank, -2)
parameter = parameter.to(param_casting_dtype)
if to_contiguous:
parameter = parameter.contiguous()
@@ -647,14 +608,6 @@ class ColwiseParallel(TensorParallelLayer):
class PackedColwiseParallel(ColwiseParallel):
def shard_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
return get_packed_weights(param, empty_param, device_mesh, rank, -2), [Shard(-2)]
def create_nn_parameter(
self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh
):
return nn.Parameter(param, requires_grad=param.is_floating_point())
def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
# colwise shard weight/bias to Shard(0), weight be Shard(-2) (0 if you have 1 dim only)
# means Colwise as Linear is input * weight^T + bias, where
@@ -701,18 +654,6 @@ class RowwiseParallel(TensorParallelLayer):
self.use_local_output = use_local_output
self.use_dtensor = use_dtensor
def shard_tensor(self, param, empty_param, param_type=None, tensor_idx=None):
device_mesh = self.device_mesh
rank = self.rank
if param_type == "bias":
shard = [Replicate()]
parameter = param[:]
else:
parameter = get_tensor_shard(param, empty_param, device_mesh, rank, -1, tensor_idx=tensor_idx)
shard = [Shard(-1)]
self.shard = shard
return parameter, shard
def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
# Rowwise shard weight to Shard(1), bias to Replicate(), weight be Shard(1)
# means Rowwise as nn.Linear is input * weight^T + bias, where
@@ -784,9 +725,6 @@ class RowwiseParallel(TensorParallelLayer):
class PackedRowwiseParallel(RowwiseParallel):
def shard_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
return get_packed_weights(param, empty_param, device_mesh, rank, -1), [Shard(-1)]
def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
# colwise shard weight/bias to Shard(0), weight be Shard(-2) (0 if you have 1 dim only)
# means Colwise as Linear is input * weight^T + bias, where
@@ -979,9 +917,6 @@ class RouterParallel(TensorParallelLayer):
) # masking class for one hot
return router_scores, router_indices
def shard_tensor(self, param, *args, **kwargs):
return param[:], None
def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
# TODO: i'd like for this to be the default
param = param[...].to(param_casting_dtype)

View File

@@ -26,7 +26,7 @@ import sys
import warnings
from abc import abstractmethod
from collections import defaultdict
from collections.abc import Callable, Sequence
from collections.abc import Callable
from concurrent.futures import ThreadPoolExecutor, as_completed
from contextlib import contextmanager
from enum import Enum
@@ -45,17 +45,17 @@ from torch.distributions import constraints
from torch.utils.checkpoint import checkpoint
from .configuration_utils import PreTrainedConfig
from .conversion_mapping import _checkpoint_conversion_mapping as DEFAULT_WEIGHT_CONVERSION_MAPPING
from .core_model_loading import WeightConverter, convert_and_load_state_dict_in_model, revert_weight_conversion
from .distributed import DistributedConfig
from .dynamic_module_utils import custom_object_save
from .generation import CompileConfig, GenerationConfig
from .integrations import PeftAdapterMixin, deepspeed_config, is_deepspeed_zero3_enabled, is_fsdp_enabled
from .integrations.accelerate import (
_get_device_map,
accelerate_disk_offload,
accelerate_dispatch,
check_and_set_device_map,
expand_device_map,
find_tied_parameters,
init_empty_weights,
)
from .integrations.deepspeed import _load_state_dict_into_zero3_model
@@ -122,7 +122,6 @@ from .utils.import_utils import (
is_sagemaker_mp_enabled,
is_tracing,
)
from .utils.loading_report import log_state_dict_report
from .utils.quantization_config import QuantizationMethod
@@ -131,6 +130,7 @@ if is_accelerate_available():
from accelerate.utils import (
extract_model_from_parallel,
offload_weight,
save_offload_index,
)
from accelerate.utils.modeling import get_state_dict_from_offload
@@ -730,6 +730,25 @@ def load_shard_file(args):
# Fix the key names
state_dict = {key_renaming_mapping[k]: v for k, v in state_dict.items() if k in key_renaming_mapping}
error_msgs = []
if is_deepspeed_zero3_enabled() and not is_quantized:
error_msgs += _load_state_dict_into_zero3_model(model, state_dict)
# Skip it with fsdp on ranks other than 0
elif not (is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized):
disk_offload_index = _load_state_dict_into_meta_model(
model,
state_dict,
shard_file,
reverse_key_renaming_mapping,
device_map=device_map,
disk_offload_folder=disk_offload_folder,
disk_offload_index=disk_offload_index,
hf_quantizer=hf_quantizer,
device_mesh=device_mesh,
)
return error_msgs, disk_offload_index
def load_shard_files_with_threadpool(args_list):
num_workers = int(os.environ.get("HF_PARALLEL_LOADING_WORKERS", "8"))
@@ -1155,6 +1174,104 @@ def _get_dtype(
return config, dtype, dtype_orig
def _find_missing_and_unexpected_keys(
model: "PreTrainedModel",
original_checkpoint_keys: list[str],
checkpoint_keys: list[str],
loading_base_model_from_task_state_dict: bool,
hf_quantizer: Optional[HfQuantizer],
) -> tuple[list[str], list[str]]:
"""Find missing keys (keys that are part of the model parameters but were NOT found in the loaded state dict keys) and unexpected keys
(keys found in the loaded state dict keys, but that are NOT part of the model parameters)
"""
prefix = model.base_model_prefix
# Compute expected keys, i.e. keys that the full model expects
expected_keys = list(model.state_dict().keys())
if hf_quantizer is not None:
expected_keys = hf_quantizer.update_expected_keys(model, expected_keys, checkpoint_keys)
# Adjust prefix of the keys to make them match loaded keys before removing them
missing_keys = sorted(set(expected_keys) - set(checkpoint_keys))
unexpected_keys = set(checkpoint_keys) - set(expected_keys)
# If a module has the same name under the base and task specific model, we have to re-add it to unexpected keys
if loading_base_model_from_task_state_dict:
task_specific_keys = [k for k in original_checkpoint_keys if not k.startswith(f"{prefix}.")]
unexpected_keys.update(task_specific_keys)
# Remove nonpersistent buffers from unexpected keys: they are not in the expected keys (model state dict), but
# may be in the loaded keys. Note that removing all buffers does the job, as they were part of the expected keys anyway
model_buffers = {n for n, _ in model.named_buffers()}
unexpected_keys = sorted(unexpected_keys - model_buffers)
tied_params = find_tied_parameters(model)
for group in tied_params:
missing_in_group = [k for k in missing_keys if k in group]
if len(missing_in_group) > 0 and len(missing_in_group) < len(group):
missing_keys = [k for k in missing_keys if k not in missing_in_group]
if hf_quantizer is not None:
missing_keys = hf_quantizer.update_missing_keys(model, missing_keys, prefix)
unexpected_keys = hf_quantizer.update_unexpected_keys(model, unexpected_keys)
return missing_keys, unexpected_keys
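# Illustrative example (assumed keys): with expected_keys = {"embed.weight", "lm_head.weight", "norm.weight"}
# and checkpoint_keys = {"embed.weight", "norm.weight", "rotary.inv_freq"}, the function returns
# missing_keys = ["lm_head.weight"] and unexpected_keys = ["rotary.inv_freq"] (unless "rotary.inv_freq" is a
# named buffer of the model, in which case it is dropped; tied parameters can further shrink missing_keys).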
def _find_mismatched_keys(
model: "PreTrainedModel",
state_dict: Optional[dict],
checkpoint_files: Optional[list[str]],
ignore_mismatched_sizes: bool,
keys_to_rename_mapping: dict[str, str],
is_quantized: bool,
weights_only: bool,
) -> tuple[list[str], list[tuple[int, int]]]:
"""
Find potential shape mismatches between the different state dicts and the model parameters, but only if `ignore_mismatched_sizes`
is True. Otherwise, return immediately and any shape mismatch that may exist will be raised later on. This avoids checking
every parameter in advance, as shape mismatches are extremely rare in practice. If we want to ignore them, however, we do
need to check in advance, because we need to know which parameters have to be moved back from meta to cpu and initialized
correctly. Indeed, since our model initialization takes place at the module level, and not the weight level, in the
case of a sharded checkpoint we cannot correctly initialize the weights according to `model._init_weights()` if we perform
this check on each state dict at loading time (after the first loaded checkpoint, there is no way to initialize only the
mismatched weights, if any, without overwriting the previously loaded weights as well, because the whole module would be
initialized, not only the weights that are mismatched).
"""
# An error will be raised later on anyway if there is a mismatch - this avoids running the rest of this function
# if there are no mismatch (which is almost always the case)
if not ignore_mismatched_sizes:
return [], []
if state_dict is not None:
checkpoint_files = [""]
model_state_dict = model.state_dict()
mismatched_keys = []
mismatched_shapes = []
for shard_file in checkpoint_files:
# If shard_file is "", we use the existing state_dict instead of loading it
if shard_file != "":
state_dict = load_state_dict(
shard_file, is_quantized=is_quantized, map_location="meta", weights_only=weights_only
)
# Fix the key names
new_state_dict = {keys_to_rename_mapping[k]: v for k, v in state_dict.items() if k in keys_to_rename_mapping}
for key, tensor in new_state_dict.items():
if key in model_state_dict and tensor.shape != model_state_dict[key].shape:
# This skips size mismatches for 4-bit weights. Two 4-bit values share an 8-bit container, causing size differences.
# Without matching with module type or parameter type it seems like a practical way to detect valid 4bit weights.
if not (
is_quantized and tensor.shape[-1] == 1 and tensor.numel() * 2 == model_state_dict[key].numel()
):
mismatched_keys.append(key)
mismatched_shapes.append((tensor.shape, model_state_dict[key].shape))
return mismatched_keys, mismatched_shapes
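# Numeric illustration of the 4-bit check above (assumed sizes): a (4096, 4096) fp16 weight has
# 16_777_216 elements; packed 4-bit it is stored as a (8_388_608, 1) uint8 tensor, and
# 8_388_608 * 2 == 16_777_216, so the shape difference is not reported as a mismatch.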
class PipelineParallel(Enum):
inputs = 0
outputs = 1
@@ -1560,8 +1677,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
# to also prevent bfloat16 casting, use the _keep_in_fp32_modules_strict flag
_keep_in_fp32_modules_strict = None
_dtype_per_modules: Optional[dict[str, torch.dtype]] = None
# a list of `re` patterns of `state_dict` keys that should be removed from the list of missing
# keys we find (keys inside the model but not in the checkpoint) and avoid unnecessary warnings.
_keys_to_ignore_on_load_missing = None
@@ -1732,9 +1847,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
self._keep_in_fp32_modules = copy.copy(self.__class__._keep_in_fp32_modules)
self._keep_in_fp32_modules_strict = copy.copy(self.__class__._keep_in_fp32_modules_strict)
if isinstance(self._keep_in_fp32_modules, dict):
self._dtype_per_modules = dict.fromkeys(self._keep_in_fp32_modules.keys(), torch.float32)
self._no_split_modules = self._no_split_modules or []
_CAN_RECORD_REGISTRY[str(self.__class__)] = self._can_record_outputs # added for executorch support only
@@ -2525,34 +2637,29 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
# 0.02 is the standard default value across the library
std = getattr(self.config.get_text_config(), "initializer_range", 0.02)
try:
if isinstance(
module, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.ConvTranspose1d, nn.ConvTranspose2d)
):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.MultiheadAttention):
# This uses torch's original init
module._reset_parameters()
# We cannot use `isinstance` on the RMSNorms or LayerNorms, as they usually are custom modules which change names
# between modelings (because they are prefixed with the model name)
elif (
isinstance(module, (nn.GroupNorm, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d))
or "LayerNorm" in module.__class__.__name__
or "RMSNorm" in module.__class__.__name__
):
# Norms can exist without weights (in which case they are None from torch primitives)
if hasattr(module, "weight") and module.weight is not None:
module.weight.data.fill_(1.0)
if hasattr(module, "bias") and module.bias is not None:
module.bias.data.zero_()
except Exception as e:
logger.warning_once(f"Failed to init: {str(e)}")
if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.ConvTranspose1d, nn.ConvTranspose2d)):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.MultiheadAttention):
# This uses torch's original init
module._reset_parameters()
# We cannot use `isinstance` on the RMSNorms or LayerNorms, as they usually are custom modules which change names
# between modelings (because they are prefixed with the model name)
elif (
isinstance(module, (nn.GroupNorm, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d))
or "LayerNorm" in module.__class__.__name__
or "RMSNorm" in module.__class__.__name__
):
# Norms can exist without weights (in which case they are None from torch primitives)
if hasattr(module, "weight") and module.weight is not None:
module.weight.data.fill_(1.0)
if hasattr(module, "bias") and module.bias is not None:
module.bias.data.zero_()
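As a quick sanity check of the default branch above, a tiny standalone snippet (toy layer; 0.02 is the library-wide default mentioned in the comment):
import torch
from torch import nn
std = 0.02
layer = nn.Linear(16, 16)
layer.weight.data.normal_(mean=0.0, std=std)
if layer.bias is not None:
    layer.bias.data.zero_()
print(round(layer.weight.data.std().item(), 3), layer.bias.data.abs().sum().item())  # ~0.02 0.0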
def _initialize_weights(self, module):
"""
@ -3350,7 +3457,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
variant: Optional[str] = None,
token: Optional[Union[str, bool]] = None,
save_peft_format: bool = True,
save_original_format: bool = False,
**kwargs,
):
"""
@ -3399,10 +3505,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
For backward compatibility with PEFT library, in case adapter weights are attached to the model, all
keys of the state dict of adapters need to be prepended with `base_model.model`. Advanced users can
disable this behaviour by setting `save_peft_format` to `False`.
save_original_format (`bool`, *optional*, defaults to `False`):
For backward compatibility with previous versions of `transformers`, you can save the checkpoint with
its reverse mapping. The reverse mapping needs to exist even if the model was not loaded from a legacy
checkpoint.
kwargs (`dict[str, Any]`, *optional*):
Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
"""
@ -3542,18 +3644,24 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
module_map[name + f".{key}"] = module
state_dict = model_to_save.state_dict()
if (
any(
allowed_name in class_name.__name__.lower()
for class_name in self.__class__.__mro__[:-1]
for allowed_name in VLMS
)
or save_original_format
if any(
allowed_name in class_name.__name__.lower()
for class_name in self.__class__.__mro__[:-1]
for allowed_name in VLMS
):
# MEGA BIG TODO HERE: self._conversion_ops needs to be used to save the final ckpt
# using what was loaded. Actually self._conversion_ops won't work because we need it
# even if the files are not legacy -> thus no conversion happened
state_dict = revert_weight_conversion(self, state_dict)
reverse_key_mapping = {v: k for k, v in self._checkpoint_conversion_mapping.items()}
original_state_dict = {}
for key, value in state_dict.items():
for pattern, replacement in reverse_key_mapping.items():
replacement = replacement.lstrip("^") # strip off un-needed chars and patterns
replacement = re.sub(r"\(.*\)", "", replacement)
key, n_replace = re.subn(pattern, replacement, key)
# Early exit of the loop
if n_replace > 0:
break
original_state_dict[key] = value
state_dict = original_state_dict
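A small standalone sketch of what that reverse renaming does to a single key; the mapping below is hypothetical, in the spirit of the VLM-style `_checkpoint_conversion_mapping` entries this branch targets:
import re
checkpoint_conversion_mapping = {"^vision_tower": "model.vision_tower"}  # assumed forward mapping
reverse_key_mapping = {v: k for k, v in checkpoint_conversion_mapping.items()}
key = "model.vision_tower.encoder.layers.0.self_attn.q_proj.weight"
for pattern, replacement in reverse_key_mapping.items():
    replacement = replacement.lstrip("^")  # strip off un-needed chars and patterns
    replacement = re.sub(r"\(.*\)", "", replacement)
    key, n_replace = re.subn(pattern, replacement, key)
    if n_replace > 0:  # early exit once a rule has fired
        break
print(key)  # vision_tower.encoder.layers.0.self_attn.q_proj.weight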
# Translate state_dict from smp to hf if saving with smp >= 1.10
if IS_SAGEMAKER_MP_POST_1_10:
@ -3721,8 +3829,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
if safe_serialization:
# At some point we will need to deal better with save_function (used for TPU and other distributed
# joyfulness), but for now this is enough. # TODO: we should definitely parallelize this; we are otherwise just waiting
# too much before scheduling the next write when it's on a different
# joyfulness), but for now this is enough.
safe_save_file(shard, os.path.join(save_directory, shard_file), metadata=metadata)
else:
save_function(shard, os.path.join(save_directory, shard_file))
@ -4180,7 +4287,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
commit_hash = kwargs.pop("_commit_hash", None)
variant = kwargs.pop("variant", None)
adapter_kwargs = kwargs.pop("adapter_kwargs", {})
adapter_name = kwargs.pop("adapter_name", "default")
generation_config = kwargs.pop("generation_config", None)
gguf_file = kwargs.pop("gguf_file", None)
@ -4279,7 +4385,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
commit_hash = getattr(config, "_commit_hash", commit_hash)
download_kwargs_with_commit["commit_hash"] = commit_hash
profile_weight_conversion = kwargs.pop("profile_weight_conversion", False)
# Because some composite configs call super().__init__ before instantiating the sub-configs, we need this call
# to correctly redispatch recursively if the kwarg is provided
@ -4290,11 +4395,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
config, quantization_config, dtype, device_map, weights_only, user_agent
)
weight_conversions: Optional[list[WeightConverter]] = None
model_type = getattr(config, "model_type", None)
if model_type is not None:
weight_conversions = DEFAULT_WEIGHT_CONVERSION_MAPPING.get(model_type)
if gguf_file:
if hf_quantizer is not None:
raise ValueError(
@ -4354,6 +4454,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
model.upcast_modules_in_fp32(hf_quantizer, dtype)
# Make sure to tie the weights correctly
model.tie_weights()
# make sure we use the model's config since the __init__ call might have copied it
config = model.config
@ -4393,8 +4494,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
device_mesh=device_mesh,
key_mapping=key_mapping,
weights_only=weights_only,
weight_mapping=weight_conversions,
profile_weight_conversion=profile_weight_conversion,
)
model.tie_weights() # make sure token embedding weights are still tied if needed
@ -4415,7 +4514,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
)
# for device_map="auto" : dispatch model with hooks on all devices if necessary (not needed with a tp_plan, so we skip it as it slightly
# harm performances). TODO: replace with native PP
# harm performances).
if device_map is not None and device_mesh is None:
accelerate_dispatch(model, hf_quantizer, device_map, offload_folder, offload_index, offload_buffers)
@ -4574,16 +4673,97 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
device_mesh: Optional["torch.distributed.device_mesh.DeviceMesh"] = None,
key_mapping: Optional[dict[str, str]] = None,
weights_only: bool = True,
weight_mapping: Optional[Sequence[WeightConverter]] = None,
profile_weight_conversion: bool = False,
):
# TODO: we should only be calling hf_quantizer.skip_placement or something like that
is_quantized = hf_quantizer is not None
is_hqq_or_quark = is_quantized and hf_quantizer.quantization_config.quant_method in {
QuantizationMethod.HQQ,
QuantizationMethod.QUARK,
}
# Model's definition arriving here is final (TP hooks added, quantized layers replaced)
# Get all the keys of the state dicts that we have to initialize the model with
if sharded_metadata is not None:
original_checkpoint_keys = sharded_metadata["all_checkpoint_keys"]
elif state_dict is not None:
original_checkpoint_keys = list(state_dict.keys())
else:
original_checkpoint_keys = list(
load_state_dict(checkpoint_files[0], map_location="meta", weights_only=weights_only).keys()
)
# Check if we are in a special state, i.e. loading from a state dict coming from a different architecture
prefix = model.base_model_prefix
has_prefix_module = any(s.startswith(prefix) for s in original_checkpoint_keys) if len(prefix) > 0 else False
expects_prefix_module = hasattr(model, prefix) if len(prefix) > 0 else False
loading_task_model_from_base_state_dict = not has_prefix_module and expects_prefix_module
loading_base_model_from_task_state_dict = has_prefix_module and not expects_prefix_module
# Find the key names that the model expects from the serialized keys
key_renaming_mapping = model._get_key_renaming_mapping(
original_checkpoint_keys,
key_mapping,
loading_base_model_from_task_state_dict,
loading_task_model_from_base_state_dict,
)
checkpoint_keys = list(key_renaming_mapping.values())
# Find missing and unexpected keys from the state dict
missing_keys, unexpected_keys = _find_missing_and_unexpected_keys(
model, original_checkpoint_keys, checkpoint_keys, loading_base_model_from_task_state_dict, hf_quantizer
)
# Find all the keys with shape mismatch (if we ignore the mismatch, the weights need to be newly initialized the
# same way as missing keys)
mismatched_keys, mismatched_shapes = _find_mismatched_keys(
model,
state_dict,
checkpoint_files,
ignore_mismatched_sizes,
key_renaming_mapping,
is_quantized,
weights_only,
)
# We need to update both the mapping and the list of checkpoint keys to remove the mismatched and unexpected ones
key_renaming_mapping = {
k: v for k, v in key_renaming_mapping.items() if v not in mismatched_keys and v not in unexpected_keys
}
checkpoint_keys = list(key_renaming_mapping.values())
# Move missing (and potentially mismatched) keys back to cpu from meta device (because they won't be moved when
# loading the weights as they are not in the loaded state dict)
model._move_missing_keys_from_meta_to_cpu(missing_keys + mismatched_keys, dtype, hf_quantizer)
# correctly initialize the missing (and potentially mismatched) keys
model._initialize_missing_keys(missing_keys + mismatched_keys, is_quantized)
# Get reverse key mapping
reverse_key_renaming_mapping = {v: k for k, v in key_renaming_mapping.items()}
is_offloaded_safetensors = False
# This offload index is for params explicitly set to "disk" in the device_map
disk_offload_index = None
disk_only_shard_files = []
# Prepare parameters offloading if needed
if device_map is not None and "disk" in device_map.values():
disk_offload_index, disk_only_shard_files, is_offloaded_safetensors = accelerate_disk_offload(
disk_offload_folder,
checkpoint_files,
device_map,
checkpoint_keys,
key_renaming_mapping,
sharded_metadata,
dtype,
reverse_key_renaming_mapping,
)
# Give us something to iterate over when the state_dict is provided directly (the dummy entry is never read)
elif state_dict is not None:
checkpoint_files = [""]
# Compute expected model keys
expected_keys = list(model.state_dict().keys())
if hf_quantizer is not None:
expected_keys = hf_quantizer.update_expected_keys(model, expected_keys, checkpoint_keys)
if logger.level >= logging.WARNING:
verify_tp_plan(expected_keys, getattr(model, "_tp_plan", None))
@ -4592,79 +4772,46 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
expanded_device_map = expand_device_map(device_map, expected_keys)
caching_allocator_warmup(model, expanded_device_map, hf_quantizer)
# Now we read all the files to get a pointer to each physical weight
merged_state_dict = {}
all_pointer = set()
if device_map is None:
device_map = {"": "cpu"}
keys = sorted(device_map.keys(), key=len, reverse=True)
tp_plan = getattr(model, "_tp_plan", None)
keep_in_dtype = None # TODO use keep_in
error_msgs = []
misc = {}
if is_deepspeed_zero3_enabled() and not is_quantized:
error_msgs += _load_state_dict_into_zero3_model(model, state_dict)
else:
if checkpoint_files is not None:
pattern = re.compile(r"(" + "|".join(map(re.escape, keys)) + r")")
if sharded_metadata is None:
k_v_iterator = dict.fromkeys(
safe_open(checkpoint_files[0], framework="pt").keys(), "model.safetensors"
).items()
else:
k_v_iterator = sharded_metadata["weight_map"].items()
for k, v in k_v_iterator:
key = pattern.match(k).group(1)
if key is not None and key != "":
device = device_map[key]
else:
device = device_map[""]
if isinstance(device, torch.device):
device = device.index # safetensors only
file_pointer = safe_open(
os.path.join(checkpoint_files[0].rsplit("/", 1)[0], v), framework="pt", device=device
)
all_pointer.add(file_pointer)
merged_state_dict[k] = (v, file_pointer.get_slice(k)) # don't materialize yet
elif state_dict is not None:
merged_state_dict = {k: ("", v) for k, v in state_dict.items()}
else:
raise ValueError("Neither a state dict nor checkpoint files were found.")
missing_keys, unexpected_keys, mismatched_keys, misc = convert_and_load_state_dict_in_model(
model,
merged_state_dict,
weight_mapping,
tp_plan,
hf_quantizer,
# Prepare compatible argument tuples for serial and parallel shard loading
args_list = [
(
shard_file,
state_dict,
disk_only_shard_files,
is_quantized,
device_map,
keep_in_dtype,
device_mesh=device_mesh,
profile=profile_weight_conversion,
hf_quantizer,
key_renaming_mapping,
weights_only,
model,
reverse_key_renaming_mapping,
disk_offload_folder,
disk_offload_index,
device_mesh,
)
for shard_file in checkpoint_files
]
for k in all_pointer: # finally close all opened file pointers
k.__exit__(None, None, None)
error_msgs = []
new_state_dict = model.state_dict()
if (
os.environ.get("HF_ENABLE_PARALLEL_LOADING", "").upper() in ENV_VARS_TRUE_VALUES
and not is_deepspeed_zero3_enabled()
):
_error_msgs, disk_offload_index = load_shard_files_with_threadpool(args_list)
error_msgs += _error_msgs
else:
if len(args_list) > 1:
args_list = logging.tqdm(args_list, desc="Loading checkpoint shards")
# Post-processing
# Check if we are in a special state, i.e. loading from a state dict coming from a different architecture
prefix = model.base_model_prefix
has_prefix_module = any(s.startswith(prefix) for s in new_state_dict.keys()) if len(prefix) > 0 else False
expects_prefix_module = hasattr(model, prefix) if len(prefix) > 0 else False
loading_task_model_from_base_state_dict = not has_prefix_module and expects_prefix_module
for args in args_list:
_error_msgs, disk_offload_index = load_shard_file(args)
error_msgs += _error_msgs
# Move missing (and potentially mismatched) keys back to cpu from meta device (because they won't be moved when
# loading the weights as they are not in the loaded state dict)
miss_and_mismatched = missing_keys | {k[0] for k in mismatched_keys}
model._move_missing_keys_from_meta_to_cpu(miss_and_mismatched, dtype, hf_quantizer)
# correctly initialize the missing (and potentially mismatched) keys
model._initialize_missing_keys(miss_and_mismatched, is_quantized)
# Save offloaded index if needed
if disk_offload_index is not None and len(disk_offload_index) > 0 and not is_offloaded_safetensors:
save_offload_index(disk_offload_index, disk_offload_folder)
disk_offload_index = None
# Post-processing for tensor parallelism
if device_mesh is not None:
@ -4703,19 +4850,48 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
missing_keys, unexpected_keys = model._adjust_missing_and_unexpected_keys(
missing_keys, unexpected_keys, loading_task_model_from_base_state_dict
)
log_state_dict_report(
model=model,
pretrained_model_name_or_path=pretrained_model_name_or_path,
logger=logger,
error_msgs=error_msgs,
unexpected_keys=unexpected_keys,
missing_keys=missing_keys,
mismatched_keys=mismatched_keys,
mismatched_shapes=mismatched_keys,
misc=misc,
ignore_mismatched_sizes=ignore_mismatched_sizes,
)
disk_offload_index = None
# TODO: separate this into another function: it's not core....
# All potential warnings/infos
if len(error_msgs) > 0:
error_msg = "\n\t".join(error_msgs)
if "size mismatch" in error_msg:
error_msg += (
"\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method."
)
raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
if len(unexpected_keys) > 0:
archs = [] if model.config.architectures is None else model.config.architectures
warner = logger.warning if model.__class__.__name__ in archs else logger.info
warner(
f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
f" initializing {model.__class__.__name__}: {update_key_name(unexpected_keys)}\n- This IS expected if you are"
f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or"
" with another architecture (e.g. initializing a BertForSequenceClassification model from a"
" BertForPreTraining model).\n- This IS NOT expected if you are initializing"
f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly identical"
" (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
)
if len(missing_keys) > 0:
logger.warning(
f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
f" {pretrained_model_name_or_path} and are newly initialized: {update_key_name(missing_keys)}\nYou should probably"
" TRAIN this model on a down-stream task to be able to use it for predictions and inference."
)
if len(mismatched_keys) > 0:
mismatched_warning = "\n".join(
[
f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
for key, (shape1, shape2) in zip(mismatched_keys, mismatched_shapes)
]
)
logger.warning(
f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able"
" to use it for predictions and inference."
)
return model, missing_keys, unexpected_keys, mismatched_keys, disk_offload_index, error_msgs
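For reference, a toy, plain-Python sketch (made-up prefix and keys) of the base-vs-task prefix detection performed at the top of this method and repeated in the post-processing above:
prefix = "bert"  # stands in for model.base_model_prefix
checkpoint_keys = ["embeddings.word_embeddings.weight", "encoder.layer.0.output.dense.weight"]
class ToyTaskModel:  # stand-in for a task model wrapping a base model under `prefix`
    bert = object()
model = ToyTaskModel()
has_prefix_module = any(k.startswith(prefix) for k in checkpoint_keys)  # False: keys are unprefixed
expects_prefix_module = hasattr(model, prefix)  # True: the model has a `bert` submodule
loading_task_model_from_base_state_dict = not has_prefix_module and expects_prefix_module
print(loading_task_model_from_base_state_dict)  # True -> checkpoint keys will be remapped under the prefix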
def retrieve_modules_from_names(self, names, add_prefix=False, remove_prefix=False):
@ -4925,8 +5101,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
if not is_quantized or not hf_quantizer.param_needs_quantization(self, key):
_load_parameter_into_model(self, key, value)
else:
# hf_quantizer.create_quantized_param(self, value, key, "cpu")
pass
hf_quantizer.create_quantized_param(self, value, key, "cpu")
def _initialize_missing_keys(self, missing_keys: list[str], is_quantized: bool) -> None:
"""Initialize the missing keys (keys that are part of the model parameters, but were NOT found in the loaded state dicts), according to
View File
@ -42,44 +42,37 @@ from ...utils.generic import check_model_inputs
from .configuration_deepseek_v2 import DeepseekV2Config
class DeepseekV2Experts(nn.Module):
"""Collection of expert weights stored as 3D tensors."""
class DeepseekV2Experts(nn.ModuleList):
"""
ModuleList of experts.
"""
def __init__(self, config):
super().__init__()
self.num_experts = config.n_routed_experts
self.hidden_dim = config.hidden_size
self.intermediate_dim = config.moe_intermediate_size
self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim))
self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim))
self.act_fn = ACT2FN[config.hidden_act]
for _ in range(config.n_routed_experts):
self.append(DeepseekV2MLP(config, intermediate_size=config.moe_intermediate_size))
def forward(
self,
hidden_states: torch.Tensor,
top_k_index: torch.Tensor,
top_k_weights: torch.Tensor,
self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor
) -> torch.Tensor:
"""
Args:
hidden_states: (batch_size * sequence_length, hidden_dim)
top_k_index: (batch_size * sequence_length, top_k)
top_k_weights: (batch_size * sequence_length, top_k)
Returns:
(batch_size * sequence_length, hidden_dim)
"""
final_hidden_states = torch.zeros_like(hidden_states)
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0)
num_experts = top_k_weights.shape[1]
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1).permute(2, 1, 0)
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
for expert_idx in expert_hit:
expert_idx = expert_idx[0]
if expert_idx == num_experts:
continue
with torch.no_grad():
_, token_idx = torch.where(expert_mask[expert_idx])
current_state = hidden_states[token_idx]
gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1)
current_hidden_states = self.act_fn(gate) * up
current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])
routing_weights = top_k_weights[token_idx, expert_idx].unsqueeze(-1)
current_hidden_states = current_hidden_states * routing_weights.to(current_hidden_states.dtype)
final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))
idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
current_state = hidden_states[None, top_x].reshape(-1, hidden_states.shape[-1])
current_hidden_states = self[expert_idx](current_state) * top_k_weights[top_x, idx, None]
final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
return final_hidden_states
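To make the masking in this kind of expert loop easier to follow, here is a toy illustration with made-up routing tensors (not model code):
import torch
top_k_index = torch.tensor([[0, 2], [1, 2]])  # 2 tokens, each routed to its top-2 experts
num_experts = 3
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts).permute(2, 1, 0)
# expert_mask[e, k, t] == 1 when token t picked expert e in its k-th top slot
idx, top_x = torch.where(expert_mask[2].squeeze(0))
print(idx.tolist(), top_x.tolist())  # [1, 1] [0, 1]: both tokens hit expert 2, each in slot 1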
View File
@ -224,10 +224,12 @@ def apply_rotary_emb(
return xq_out, xk_out
class DeepseekV2Experts(Qwen2MoeExperts):
class DeepseekV2Experts(Qwen2MoeExperts, nn.ModuleList):
def __init__(self, config):
super().__init__(config)
nn.ModuleList.__init__(self)
self.num_experts = config.n_routed_experts
for _ in range(config.n_routed_experts):
self.append(DeepseekV2MLP(config, intermediate_size=config.moe_intermediate_size))
class DeepseekV2Moe(nn.Module):
View File
@ -149,44 +149,37 @@ class DeepseekV3TopkRouter(nn.Module):
return router_logits
class DeepseekV3NaiveMoe(nn.Module):
"""Collection of expert weights stored as 3D tensors."""
class DeepseekV3NaiveMoe(nn.ModuleList):
"""
ModuleList of experts.
"""
def __init__(self, config):
super().__init__()
self.num_experts = config.num_local_experts
self.hidden_dim = config.hidden_size
self.intermediate_dim = config.intermediate_size
self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim))
self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim))
self.act_fn = ACT2FN[config.hidden_act]
for _ in range(self.num_experts):
self.append(DeepseekV3MLP(config, intermediate_size=config.moe_intermediate_size))
def forward(
self,
hidden_states: torch.Tensor,
top_k_index: torch.Tensor,
top_k_weights: torch.Tensor,
self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor
) -> torch.Tensor:
"""
Args:
hidden_states: (batch_size * sequence_length, hidden_dim)
top_k_index: (batch_size * sequence_length, top_k)
top_k_weights: (batch_size * sequence_length, top_k)
Returns:
(batch_size * sequence_length, hidden_dim)
"""
final_hidden_states = torch.zeros_like(hidden_states)
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0)
num_experts = top_k_weights.shape[1]
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1).permute(2, 1, 0)
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
for expert_idx in expert_hit:
expert_idx = expert_idx[0]
if expert_idx == num_experts:
continue
with torch.no_grad():
_, token_idx = torch.where(expert_mask[expert_idx])
current_state = hidden_states[token_idx]
gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1)
current_hidden_states = self.act_fn(gate) * up
current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])
routing_weights = top_k_weights[token_idx, expert_idx].unsqueeze(-1)
current_hidden_states = current_hidden_states * routing_weights.to(current_hidden_states.dtype)
final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))
idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
current_state = hidden_states[None, top_x].reshape(-1, hidden_states.shape[-1])
current_hidden_states = self[expert_idx](current_state) * top_k_weights[top_x, idx, None]
final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
return final_hidden_states
View File
@ -102,10 +102,12 @@ class DeepseekV3TopkRouter(nn.Module):
return router_logits
class DeepseekV3NaiveMoe(MixtralExperts):
class DeepseekV3NaiveMoe(MixtralExperts, nn.ModuleList):
def __init__(self, config):
super().__init__(config)
nn.ModuleList.__init__(self)
self.num_experts = config.num_local_experts
for _ in range(self.num_experts):
self.append(DeepseekV3MLP(config, intermediate_size=config.moe_intermediate_size))
class DeepseekV3MoE(nn.Module):
View File
@ -305,44 +305,37 @@ class Dots1TopkRouter(nn.Module):
return router_logits
class Dots1NaiveMoe(nn.Module):
"""Collection of expert weights stored as 3D tensors."""
class Dots1NaiveMoe(nn.ModuleList):
"""
ModuleList of experts.
"""
def __init__(self, config):
super().__init__()
self.num_experts = config.num_local_experts
self.hidden_dim = config.hidden_size
self.intermediate_dim = config.intermediate_size
self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim))
self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim))
self.act_fn = ACT2FN[config.hidden_act]
for _ in range(self.num_experts):
self.append(Dots1MLP(config, intermediate_size=config.moe_intermediate_size))
def forward(
self,
hidden_states: torch.Tensor,
top_k_index: torch.Tensor,
top_k_weights: torch.Tensor,
self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor
) -> torch.Tensor:
"""
Args:
hidden_states: (batch_size * sequence_length, hidden_dim)
top_k_index: (batch_size * sequence_length, top_k)
top_k_weights: (batch_size * sequence_length, top_k)
Returns:
(batch_size * sequence_length, hidden_dim)
"""
final_hidden_states = torch.zeros_like(hidden_states)
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0)
num_experts = top_k_weights.shape[1]
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1).permute(2, 1, 0)
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
for expert_idx in expert_hit:
expert_idx = expert_idx[0]
if expert_idx == num_experts:
continue
with torch.no_grad():
_, token_idx = torch.where(expert_mask[expert_idx])
current_state = hidden_states[token_idx]
gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1)
current_hidden_states = self.act_fn(gate) * up
current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])
routing_weights = top_k_weights[token_idx, expert_idx].unsqueeze(-1)
current_hidden_states = current_hidden_states * routing_weights.to(current_hidden_states.dtype)
final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))
idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
current_state = hidden_states[None, top_x].reshape(-1, hidden_states.shape[-1])
current_hidden_states = self[expert_idx](current_state) * top_k_weights[top_x, idx, None]
final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
return final_hidden_states
View File
@ -315,95 +315,62 @@ class Ernie4_5_MoeStatics(nn.Module):
return hidden_states + self.e_score_correction_bias.squeeze()
class Ernie4_5_MoeExperts(nn.Module):
class Ernie4_5_MoeExperts(nn.ModuleList):
def __init__(self, config):
super().__init__()
self.num_experts = config.moe_num_experts
self.hidden_dim = config.hidden_size
self.intermediate_dim = config.moe_intermediate_size
self.use_bias = config.use_bias
self.act_fn = ACT2FN[config.hidden_act]
self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim))
self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim))
if self.use_bias:
self.gate_up_proj_bias = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim))
self.down_proj_bias = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim))
else:
self.gate_up_proj_bias = None
self.down_proj_bias = None
for _ in range(self.num_experts):
self.append(Ernie4_5_MoeMLP(config, config.moe_intermediate_size))
def forward(
self, hidden_states: torch.Tensor, selected_experts: torch.Tensor, routing_weights: torch.Tensor
) -> torch.Tensor:
final_hidden_states = torch.zeros_like(hidden_states)
if selected_experts.numel() == 0:
return final_hidden_states
expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
for expert_idx in expert_hit:
expert_idx = int(expert_idx.item())
idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
current_state = hidden_states[None, top_x].reshape(-1, hidden_states.shape[-1])
gate_inputs = F.linear(
current_state,
self.gate_up_proj[expert_idx],
None if self.gate_up_proj_bias is None else self.gate_up_proj_bias[expert_idx],
)
gate, up = gate_inputs.chunk(2, dim=-1)
current_hidden_states = self.act_fn(gate) * up
current_hidden_states = F.linear(
current_hidden_states,
self.down_proj[expert_idx],
None if self.down_proj_bias is None else self.down_proj_bias[expert_idx],
)
current_hidden_states = current_hidden_states * routing_weights[top_x, idx, None]
current_hidden_states = self[expert_idx](current_state) * routing_weights[top_x, idx, None]
final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
return final_hidden_states
class Ernie4_5_MoeTopKRouter(nn.Module):
def __init__(self, config):
super().__init__()
self.linear = nn.Linear(config.hidden_size, config.moe_num_experts, bias=False, dtype=torch.float32)
self.moe_statics = Ernie4_5_MoeStatics(config)
self.top_k = config.moe_k
self.norm_min = config.moe_norm_min
def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
device_type = (
hidden_states.device.type
if isinstance(hidden_states.device.type, str) and hidden_states.device.type != "mps"
else "cpu"
)
with torch.autocast(device_type=device_type, enabled=False): # Force float32
router_logits = self.linear(hidden_states.float())
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
_, selected_experts = torch.topk(self.moe_statics(routing_weights), self.top_k, dim=-1)
routing_weights = torch.gather(routing_weights, dim=-1, index=selected_experts)
routing_weights = routing_weights / torch.clamp(
routing_weights.sum(dim=-1, keepdim=True), min=self.norm_min
)
routing_weights = routing_weights.to(router_logits.dtype)
return routing_weights, selected_experts
class Ernie4_5_MoeSparseMoeBlock(nn.Module):
def __init__(self, config):
super().__init__()
self.hidden_dim = config.hidden_size
self.num_experts = config.moe_num_experts
self.top_k = config.moe_k
self.router = Ernie4_5_MoeTopKRouter(config)
self.norm_min = config.moe_norm_min
self.gate = nn.Linear(config.hidden_size, config.moe_num_experts, bias=False, dtype=torch.float32)
self.moe_statics = Ernie4_5_MoeStatics(config)
self.experts = Ernie4_5_MoeExperts(config)
self.shared_experts = None
if config.moe_num_shared_experts > 0:
self.shared_experts = Ernie4_5_MoeMLP(config, config.moe_intermediate_size * config.moe_num_shared_experts)
def route_tokens_to_experts(self, hidden_states):
device_type = (
hidden_states.device.type
if isinstance(hidden_states.device.type, str) and hidden_states.device.type != "mps"
else "cpu"
)
with torch.autocast(device_type=device_type, enabled=False): # Force float32
router_logits = self.gate(hidden_states.float())
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
_, selected_experts = torch.topk(self.moe_statics(routing_weights), self.top_k, dim=-1)
routing_weights = torch.gather(routing_weights, dim=-1, index=selected_experts)
routing_weights = routing_weights / torch.clamp(
routing_weights.sum(dim=-1, keepdim=True), min=self.norm_min
)
routing_weights = routing_weights.to(router_logits.dtype)
return selected_experts, routing_weights
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
batch_size, sequence_length, _ = hidden_states.shape
hidden_states = hidden_states.view(-1, self.hidden_dim)
@ -411,7 +378,7 @@ class Ernie4_5_MoeSparseMoeBlock(nn.Module):
if self.shared_experts is not None:
shared_output = self.shared_experts(hidden_states)
routing_weights, selected_experts = self.router(hidden_states)
selected_experts, routing_weights = self.route_tokens_to_experts(hidden_states)
final_hidden_states = self.experts(hidden_states, selected_experts, routing_weights)
if self.shared_experts is not None:
@ -487,11 +454,11 @@ class Ernie4_5_MoePreTrainedModel(PreTrainedModel):
_can_compile_fullgraph = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported)
_supports_attention_backend = True
_can_record_outputs = {
"router_logits": OutputRecorder(Ernie4_5_MoeTopKRouter, layer_name="mlp.router", index=0),
"router_logits": OutputRecorder(nn.Linear, layer_name="mlp.gate", index=0),
"hidden_states": Ernie4_5_MoeDecoderLayer,
"attentions": Ernie4_5_MoeAttention,
}
_keep_in_fp32_modules_strict = ["router"]
_keep_in_fp32_modules_strict = ["gate", "moe_statics"]
# Not supporting multi-token prediction (MTP) atm
_keys_to_ignore_on_load_unexpected = ["mtp"]
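A compact, self-contained sketch of the `route_tokens_to_experts` path above (random inputs; the statics module is reduced to a plain bias tensor purely for illustration):
import torch
import torch.nn.functional as F
hidden_states = torch.randn(5, 8)  # 5 tokens, hidden_size 8 (toy dimensions)
gate = torch.nn.Linear(8, 4, bias=False).float()  # 4 experts, gate kept in float32
e_score_correction_bias = torch.zeros(4)  # stand-in for Ernie4_5_MoeStatics
top_k, norm_min = 2, 1e-12
router_logits = gate(hidden_states.float())
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
# experts are selected on bias-corrected scores, but the gathered weights stay uncorrected
_, selected_experts = torch.topk(routing_weights + e_score_correction_bias, top_k, dim=-1)
routing_weights = torch.gather(routing_weights, dim=-1, index=selected_experts)
routing_weights = routing_weights / torch.clamp(routing_weights.sum(dim=-1, keepdim=True), min=norm_min)
print(selected_experts.shape, routing_weights.shape)  # torch.Size([5, 2]) torch.Size([5, 2])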
View File
@ -19,7 +19,6 @@ import torch
import torch.nn.functional as F
from torch import nn
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...masking_utils import create_causal_mask
from ...modeling_outputs import MoeModelOutputWithPast
@ -97,95 +96,62 @@ class Ernie4_5_MoeStatics(nn.Module):
return hidden_states + self.e_score_correction_bias.squeeze()
class Ernie4_5_MoeExperts(nn.Module):
class Ernie4_5_MoeExperts(nn.ModuleList):
def __init__(self, config):
super().__init__()
self.num_experts = config.moe_num_experts
self.hidden_dim = config.hidden_size
self.intermediate_dim = config.moe_intermediate_size
self.use_bias = config.use_bias
self.act_fn = ACT2FN[config.hidden_act]
self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim))
self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim))
if self.use_bias:
self.gate_up_proj_bias = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim))
self.down_proj_bias = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim))
else:
self.gate_up_proj_bias = None
self.down_proj_bias = None
for _ in range(self.num_experts):
self.append(Ernie4_5_MoeMLP(config, config.moe_intermediate_size))
def forward(
self, hidden_states: torch.Tensor, selected_experts: torch.Tensor, routing_weights: torch.Tensor
) -> torch.Tensor:
final_hidden_states = torch.zeros_like(hidden_states)
if selected_experts.numel() == 0:
return final_hidden_states
expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
for expert_idx in expert_hit:
expert_idx = int(expert_idx.item())
idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
current_state = hidden_states[None, top_x].reshape(-1, hidden_states.shape[-1])
gate_inputs = F.linear(
current_state,
self.gate_up_proj[expert_idx],
None if self.gate_up_proj_bias is None else self.gate_up_proj_bias[expert_idx],
)
gate, up = gate_inputs.chunk(2, dim=-1)
current_hidden_states = self.act_fn(gate) * up
current_hidden_states = F.linear(
current_hidden_states,
self.down_proj[expert_idx],
None if self.down_proj_bias is None else self.down_proj_bias[expert_idx],
)
current_hidden_states = current_hidden_states * routing_weights[top_x, idx, None]
current_hidden_states = self[expert_idx](current_state) * routing_weights[top_x, idx, None]
final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
return final_hidden_states
class Ernie4_5_MoeTopKRouter(nn.Module):
def __init__(self, config):
super().__init__()
self.linear = nn.Linear(config.hidden_size, config.moe_num_experts, bias=False, dtype=torch.float32)
self.moe_statics = Ernie4_5_MoeStatics(config)
self.top_k = config.moe_k
self.norm_min = config.moe_norm_min
def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
device_type = (
hidden_states.device.type
if isinstance(hidden_states.device.type, str) and hidden_states.device.type != "mps"
else "cpu"
)
with torch.autocast(device_type=device_type, enabled=False): # Force float32
router_logits = self.linear(hidden_states.float())
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
_, selected_experts = torch.topk(self.moe_statics(routing_weights), self.top_k, dim=-1)
routing_weights = torch.gather(routing_weights, dim=-1, index=selected_experts)
routing_weights = routing_weights / torch.clamp(
routing_weights.sum(dim=-1, keepdim=True), min=self.norm_min
)
routing_weights = routing_weights.to(router_logits.dtype)
return routing_weights, selected_experts
class Ernie4_5_MoeSparseMoeBlock(nn.Module):
def __init__(self, config):
super().__init__()
self.hidden_dim = config.hidden_size
self.num_experts = config.moe_num_experts
self.top_k = config.moe_k
self.router = Ernie4_5_MoeTopKRouter(config)
self.norm_min = config.moe_norm_min
self.gate = nn.Linear(config.hidden_size, config.moe_num_experts, bias=False, dtype=torch.float32)
self.moe_statics = Ernie4_5_MoeStatics(config)
self.experts = Ernie4_5_MoeExperts(config)
self.shared_experts = None
if config.moe_num_shared_experts > 0:
self.shared_experts = Ernie4_5_MoeMLP(config, config.moe_intermediate_size * config.moe_num_shared_experts)
def route_tokens_to_experts(self, hidden_states):
device_type = (
hidden_states.device.type
if isinstance(hidden_states.device.type, str) and hidden_states.device.type != "mps"
else "cpu"
)
with torch.autocast(device_type=device_type, enabled=False): # Force float32
router_logits = self.gate(hidden_states.float())
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
_, selected_experts = torch.topk(self.moe_statics(routing_weights), self.top_k, dim=-1)
routing_weights = torch.gather(routing_weights, dim=-1, index=selected_experts)
routing_weights = routing_weights / torch.clamp(
routing_weights.sum(dim=-1, keepdim=True), min=self.norm_min
)
routing_weights = routing_weights.to(router_logits.dtype)
return selected_experts, routing_weights
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
batch_size, sequence_length, _ = hidden_states.shape
hidden_states = hidden_states.view(-1, self.hidden_dim)
@ -193,7 +159,7 @@ class Ernie4_5_MoeSparseMoeBlock(nn.Module):
if self.shared_experts is not None:
shared_output = self.shared_experts(hidden_states)
routing_weights, selected_experts = self.router(hidden_states)
selected_experts, routing_weights = self.route_tokens_to_experts(hidden_states)
final_hidden_states = self.experts(hidden_states, selected_experts, routing_weights)
if self.shared_experts is not None:
@ -227,11 +193,11 @@ class Ernie4_5_MoeDecoderLayer(Qwen3MoeDecoderLayer):
class Ernie4_5_MoePreTrainedModel(MixtralPreTrainedModel):
config: Ernie4_5_MoeConfig
_no_split_modules = ["Ernie4_5_MoeDecoderLayer"]
_keep_in_fp32_modules_strict = ["router"]
_keep_in_fp32_modules_strict = ["gate", "moe_statics"]
# Not supporting multi-token prediction (MTP) atm
_keys_to_ignore_on_load_unexpected = ["mtp"]
_can_record_outputs = {
"router_logits": OutputRecorder(Ernie4_5_MoeTopKRouter, layer_name="mlp.router", index=0),
"router_logits": OutputRecorder(nn.Linear, layer_name="mlp.gate", index=0),
"hidden_states": Ernie4_5_MoeDecoderLayer,
"attentions": Ernie4_5_MoeAttention,
}
View File
@ -109,7 +109,6 @@ class FlexOlmoConfig(PreTrainedConfig):
model_type = "flex_olmo"
keys_to_ignore_at_inference = ["past_key_values"]
attribute_map = {"num_local_experts": "num_experts"}
base_model_tp_plan = {
"layers.*.self_attn.q_proj": "colwise_rep", # we need to replicate here due to the added norm on q and k
"layers.*.self_attn.k_proj": "colwise_rep", # we need to replicate here due to the added norm on q and k
View File
@ -23,7 +23,6 @@ from collections.abc import Callable
from typing import Optional, Union
import torch
import torch.nn.functional as F
from torch import nn
from ...activations import ACT2FN
@ -292,78 +291,64 @@ class FlexOlmoAttention(nn.Module):
return attn_output, attn_weights
class FlexOlmoExperts(nn.Module):
"""Collection of expert weights stored as 3D tensors."""
class FlexOlmoExperts(nn.ModuleList):
"""
ModuleList of experts.
"""
def __init__(self, config: FlexOlmoConfig):
super().__init__()
self.num_experts = config.num_local_experts
self.hidden_dim = config.hidden_size
self.intermediate_dim = config.intermediate_size
self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim))
self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim))
self.act_fn = ACT2FN[config.hidden_act]
def forward(
self,
hidden_states: torch.Tensor,
top_k_index: torch.Tensor,
top_k_weights: torch.Tensor,
) -> torch.Tensor:
final_hidden_states = torch.zeros_like(hidden_states)
num_experts = top_k_weights.shape[1]
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1).permute(2, 1, 0)
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
for expert_idx in expert_hit:
expert_idx = expert_idx[0]
if expert_idx == num_experts:
continue
with torch.no_grad():
_, token_idx = torch.where(expert_mask[expert_idx])
current_state = hidden_states[token_idx]
gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1)
current_hidden_states = self.act_fn(gate) * up
current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])
routing_weights = top_k_weights[token_idx, expert_idx].unsqueeze(-1)
current_hidden_states = current_hidden_states * routing_weights.to(current_hidden_states.dtype)
final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))
return final_hidden_states
class FlexOlmoTopKRouter(nn.Module):
def __init__(self, config):
super().__init__()
self.top_k = config.num_experts_per_tok
for _ in range(config.num_experts):
self.append(FlexOlmoMLP(config))
self.num_experts = config.num_experts
self.top_k = config.num_experts_per_tok
self.norm_topk_prob = config.norm_topk_prob
self.hidden_dim = config.hidden_size
self.weight = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim))
def forward(self, hidden_states):
hidden_states = hidden_states.reshape(-1, self.hidden_dim)
router_logits = F.linear(hidden_states, self.weight) # (seq_len, num_experts)
router_logits = torch.nn.functional.softmax(router_logits, dtype=torch.float, dim=-1)
router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) # (seq_len, top_k)
if self.norm_topk_prob:
router_top_value /= router_top_value.sum(dim=-1, keepdim=True)
router_top_value = router_top_value.to(router_logits.dtype)
router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value)
return router_scores, router_indices
def forward(
self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor
) -> torch.Tensor:
"""
Args:
hidden_states: (batch_size * sequence_length, hidden_dim)
top_k_index: (batch_size * sequence_length, top_k)
top_k_weights: (batch_size * sequence_length, top_k)
Returns:
(batch_size * sequence_length, hidden_dim)
"""
final_hidden_states = torch.zeros_like(hidden_states)
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0)
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
for expert_idx in expert_hit:
idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
current_state = hidden_states[None, top_x].reshape(-1, hidden_states.shape[-1])
current_hidden_states = self[expert_idx](current_state) * top_k_weights[top_x, idx, None]
final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
return final_hidden_states
class FlexOlmoSparseMoeBlock(nn.Module):
def __init__(self, config):
super().__init__()
self.gate = FlexOlmoTopKRouter(config)
self.num_experts = config.num_experts
self.top_k = config.num_experts_per_tok
self.norm_topk_prob = config.norm_topk_prob
self.gate = nn.Linear(config.hidden_size, self.num_experts, bias=False)
self.experts = FlexOlmoExperts(config)
def route_tokens_to_experts(self, hidden_states, router_logits):
routing_weights = torch.nn.functional.softmax(router_logits.float(), dim=-1)
top_k_weights, top_k_index = torch.topk(routing_weights, self.top_k, dim=-1)
if self.norm_topk_prob:
top_k_weights /= top_k_weights.sum(dim=-1, keepdim=True)
top_k_weights = top_k_weights.to(hidden_states.dtype)
return top_k_index, top_k_weights
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
batch_size, sequence_length, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
top_k_weights, top_k_index = self.gate(hidden_states)
router_logits = self.gate(hidden_states)
top_k_index, top_k_weights = self.route_tokens_to_experts(hidden_states, router_logits)
final_hidden_states = self.experts(hidden_states, top_k_index, top_k_weights).reshape(
batch_size, sequence_length, hidden_dim
)
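The removed `FlexOlmoTopKRouter` materialized a dense score matrix with `scatter_`, while the new block keeps only the top-k weights; a toy comparison of the two shapes (random logits, standalone):
import torch
router_logits = torch.randn(3, 4)  # 3 tokens, 4 experts
probs = torch.nn.functional.softmax(router_logits, dim=-1)
top_k_weights, top_k_index = torch.topk(probs, k=2, dim=-1)
top_k_weights = top_k_weights / top_k_weights.sum(dim=-1, keepdim=True)  # norm_topk_prob=True
# old-style dense (tokens, num_experts) scores, zeros outside each token's top-k
router_scores = torch.zeros_like(probs).scatter_(1, top_k_index, top_k_weights)
print(router_scores.shape, top_k_weights.shape)  # torch.Size([3, 4]) torch.Size([3, 2])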
View File
@ -156,7 +156,7 @@ class Gemma3TextConfig(PreTrainedConfig):
layer_types: Optional[list[str]] = None,
final_logit_softcapping: Optional[float] = None,
attn_logit_softcapping: Optional[float] = None,
rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None,
rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
use_bidirectional_attention: Optional[bool] = False,
**kwargs,
):
@ -186,10 +186,16 @@ class Gemma3TextConfig(PreTrainedConfig):
self.final_logit_softcapping = final_logit_softcapping
self.attn_logit_softcapping = attn_logit_softcapping
self.layer_types = layer_types
# Backward compatibility: if a legacy `rope_scaling` kwarg is passed, fold it into `rope_parameters`
rope_scaling = kwargs.pop("rope_scaling", None)
if rope_scaling is not None:
rope_parameters = {"sliding_attention": {"rope_type": "default"}, "full_attention": rope_scaling}
if (rope_scaling := kwargs.pop("rope_scaling", None)) is not None:
if rope_parameters is None:
rope_parameters = {"sliding_attention": {"rope_type": "default"}, "full_attention": rope_scaling}
elif "full_attention" in rope_parameters:
rope_parameters["full_attention"].update(rope_scaling)
else:
rope_parameters.update(rope_scaling)
self.rope_parameters = rope_parameters
self.use_bidirectional_attention = use_bidirectional_attention
if use_bidirectional_attention:
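A plain-Python sketch of what this backward-compatibility branch produces when only a hypothetical legacy `rope_scaling` value is passed:
rope_scaling = {"rope_type": "linear", "factor": 8.0}  # assumed legacy kwarg value
rope_parameters = None  # nothing passed explicitly
if rope_scaling is not None:
    if rope_parameters is None:
        rope_parameters = {"sliding_attention": {"rope_type": "default"}, "full_attention": rope_scaling}
    elif "full_attention" in rope_parameters:
        rope_parameters["full_attention"].update(rope_scaling)
    else:
        rope_parameters.update(rope_scaling)
print(rope_parameters["full_attention"])  # {'rope_type': 'linear', 'factor': 8.0}: applied to global attention only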
View File
@ -191,7 +191,10 @@ _VARIANTS = {
num_hidden_layers=34,
num_key_value_heads=4,
sliding_window=1024,
rope_parameters={"rope_type": "linear", "factor": 8.0}, # used for global RoPE only
rope_parameters={
"full_attention": {"rope_type": "linear", "factor": 8.0},
"sliding_attention": {"rope_type": "default"},
},
rope_theta=1_000_000,
rope_local_base_freq=10_000,
attn_logit_softcapping=None,
@ -209,7 +212,10 @@ _VARIANTS = {
num_hidden_layers=48,
num_key_value_heads=8,
sliding_window=1024,
rope_parameters={"rope_type": "linear", "factor": 8.0}, # used for global RoPE only
rope_parameters={
"full_attention": {"rope_type": "linear", "factor": 8.0},
"sliding_attention": {"rope_type": "default"},
},
rope_theta=1_000_000,
rope_local_base_freq=10_000,
attn_logit_softcapping=None,
@ -227,7 +233,10 @@ _VARIANTS = {
num_key_value_heads=16,
head_dim=128,
sliding_window=1024,
rope_parameters={"rope_type": "linear", "factor": 8.0}, # used for global RoPE only
rope_parameters={
"full_attention": {"rope_type": "linear", "factor": 8.0},
"sliding_attention": {"rope_type": "default"},
},
rope_theta=1_000_000,
rope_local_base_freq=10_000,
attn_logit_softcapping=None,
View File
@ -171,7 +171,7 @@ class Gemma3TextConfig(Gemma2Config, PreTrainedConfig):
layer_types: Optional[list[str]] = None,
final_logit_softcapping: Optional[float] = None,
attn_logit_softcapping: Optional[float] = None,
rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None,
rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
use_bidirectional_attention: Optional[bool] = False,
**kwargs,
):
@ -201,10 +201,16 @@ class Gemma3TextConfig(Gemma2Config, PreTrainedConfig):
self.final_logit_softcapping = final_logit_softcapping
self.attn_logit_softcapping = attn_logit_softcapping
self.layer_types = layer_types
# Backward compatibility: if a legacy `rope_scaling` kwarg is passed, fold it into `rope_parameters`
rope_scaling = kwargs.pop("rope_scaling", None)
if rope_scaling is not None:
rope_parameters = {"sliding_attention": {"rope_type": "default"}, "full_attention": rope_scaling}
if (rope_scaling := kwargs.pop("rope_scaling", None)) is not None:
if rope_parameters is None:
rope_parameters = {"sliding_attention": {"rope_type": "default"}, "full_attention": rope_scaling}
elif "full_attention" in rope_parameters:
rope_parameters["full_attention"].update(rope_scaling)
else:
rope_parameters.update(rope_scaling)
self.rope_parameters = rope_parameters
self.use_bidirectional_attention = use_bidirectional_attention
if use_bidirectional_attention:
View File
@ -330,44 +330,37 @@ class Glm4MoeRMSNorm(nn.Module):
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
class Glm4MoeNaiveMoe(nn.Module):
"""Collection of expert weights stored as 3D tensors."""
class Glm4MoeNaiveMoe(nn.ModuleList):
"""
ModuleList of experts.
"""
def __init__(self, config):
super().__init__()
self.num_experts = config.num_local_experts
self.hidden_dim = config.hidden_size
self.intermediate_dim = config.intermediate_size
self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim))
self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim))
self.act_fn = ACT2FN[config.hidden_act]
for _ in range(self.num_experts):
self.append(Glm4MoeMLP(config, intermediate_size=config.moe_intermediate_size))
def forward(
self,
hidden_states: torch.Tensor,
top_k_index: torch.Tensor,
top_k_weights: torch.Tensor,
self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor
) -> torch.Tensor:
"""
Args:
hidden_states: (batch_size * sequence_length, hidden_dim)
top_k_index: (batch_size * sequence_length, top_k)
top_k_weights: (batch_size * sequence_length, top_k)
Returns:
(batch_size * sequence_length, hidden_dim)
"""
final_hidden_states = torch.zeros_like(hidden_states)
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0)
num_experts = top_k_weights.shape[1]
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1).permute(2, 1, 0)
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
for expert_idx in expert_hit:
expert_idx = expert_idx[0]
if expert_idx == num_experts:
continue
with torch.no_grad():
_, token_idx = torch.where(expert_mask[expert_idx])
current_state = hidden_states[token_idx]
gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1)
current_hidden_states = self.act_fn(gate) * up
current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])
routing_weights = top_k_weights[token_idx, expert_idx].unsqueeze(-1)
current_hidden_states = current_hidden_states * routing_weights.to(current_hidden_states.dtype)
final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))
idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
current_state = hidden_states[None, top_x].reshape(-1, hidden_states.shape[-1])
current_hidden_states = self[expert_idx](current_state) * top_k_weights[top_x, idx, None]
final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
return final_hidden_states
View File
@ -1424,6 +1424,8 @@ class Glm4vForConditionalGeneration(Glm4vPreTrainedModel, GenerationMixin):
**kwargs: Unpack[TransformersKwargs],
) -> Union[tuple, Glm4vCausalLMOutputWithPast]:
r"""
rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
The rope index difference between sequence length and multimodal rope.
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
@ -1432,8 +1434,6 @@ class Glm4vForConditionalGeneration(Glm4vPreTrainedModel, GenerationMixin):
The temporal, height and width of feature shape of each image in LLM.
video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
The temporal, height and width of feature shape of each video in LLM.
rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
The rope index difference between sequence length and multimodal rope.
Example:
View File
@ -351,44 +351,37 @@ class Glm4vMoeTextTopkRouter(nn.Module):
return router_logits
class Glm4vMoeTextNaiveMoe(nn.Module):
"""Collection of expert weights stored as 3D tensors."""
class Glm4vMoeTextNaiveMoe(nn.ModuleList):
"""
ModuleList of experts.
"""
def __init__(self, config):
super().__init__()
self.num_experts = config.num_local_experts
self.hidden_dim = config.hidden_size
self.intermediate_dim = config.intermediate_size
self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim))
self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim))
self.act_fn = ACT2FN[config.hidden_act]
for _ in range(self.num_experts):
self.append(Glm4vMoeTextMLP(config, intermediate_size=config.moe_intermediate_size))
def forward(
self,
hidden_states: torch.Tensor,
top_k_index: torch.Tensor,
top_k_weights: torch.Tensor,
self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor
) -> torch.Tensor:
"""
Args:
hidden_states: (batch_size * sequence_length, hidden_dim)
top_k_index: (batch_size * sequence_length, top_k)
top_k_weights: (batch_size * sequence_length, top_k)
Returns:
(batch_size * sequence_length, hidden_dim)
"""
final_hidden_states = torch.zeros_like(hidden_states)
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0)
num_experts = top_k_weights.shape[1]
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1).permute(2, 1, 0)
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
for expert_idx in expert_hit:
expert_idx = expert_idx[0]
if expert_idx == num_experts:
continue
with torch.no_grad():
_, token_idx = torch.where(expert_mask[expert_idx])
current_state = hidden_states[token_idx]
gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1)
current_hidden_states = self.act_fn(gate) * up
current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])
routing_weights = top_k_weights[token_idx, expert_idx].unsqueeze(-1)
current_hidden_states = current_hidden_states * routing_weights.to(current_hidden_states.dtype)
final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))
idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
current_state = hidden_states[None, top_x].reshape(-1, hidden_states.shape[-1])
current_hidden_states = self[expert_idx](current_state) * top_k_weights[top_x, idx, None]
final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
return final_hidden_states
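The dispatch pattern above (one_hot → permute → torch.where → per-expert MLP → index_add_) recurs throughout this diff. A minimal, self-contained sketch of it, with invented dimensions and plain nn.Linear modules standing in for the expert MLPs:

import torch
from torch import nn

num_experts, top_k, hidden, tokens = 4, 2, 8, 5
experts = nn.ModuleList([nn.Linear(hidden, hidden, bias=False) for _ in range(num_experts)])

hidden_states = torch.randn(tokens, hidden)
probs = torch.softmax(torch.randn(tokens, num_experts), dim=-1)
top_k_weights, top_k_index = torch.topk(probs, top_k, dim=-1)

final = torch.zeros_like(hidden_states)
# expert_mask[e, k, t] == 1 iff expert e is token t's k-th choice -> shape (num_experts, top_k, tokens)
expert_mask = nn.functional.one_hot(top_k_index, num_classes=num_experts).permute(2, 1, 0)
for expert_idx in expert_mask.sum(dim=(-1, -2)).nonzero():  # only experts that actually received tokens
    expert_idx = int(expert_idx)
    idx, top_x = torch.where(expert_mask[expert_idx])       # routing slot and token index for each hit
    out = experts[expert_idx](hidden_states[top_x]) * top_k_weights[top_x, idx, None]
    final.index_add_(0, top_x, out)                         # accumulate weighted expert outputs per token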

View File

@ -411,9 +411,10 @@ class GraniteMoeDecoderLayer(GradientCheckpointingLayer):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = GraniteMoeAttention(config=config, layer_idx=layer_idx)
self.block_sparse_moe = GraniteMoeMoE(config)
self.input_layernorm = GraniteMoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = GraniteMoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.block_sparse_moe = GraniteMoeMoE(config)
self.residual_multiplier = config.residual_multiplier # Only diff with mixtral!
def forward(

View File

@ -105,8 +105,7 @@ class GraniteMoeDecoderLayer(MixtralDecoderLayer):
self.block_sparse_moe = GraniteMoeMoE(config)
self.input_layernorm = GraniteMoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = GraniteMoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
del self.mlp
self.block_sparse_moe = GraniteMoeMoE(config)
self.residual_multiplier = config.residual_multiplier # Only diff with mixtral!
def forward(

View File

@ -1119,9 +1119,10 @@ class GraniteMoeHybridDecoderLayer(GradientCheckpointingLayer):
self.hidden_size = config.hidden_size
# Either attention or mamba will be initialized, depending on the layer type.
self.self_attn = None
self.block_sparse_moe = GraniteMoeHybridMoE(config)
self.input_layernorm = GraniteMoeHybridRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = GraniteMoeHybridRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.block_sparse_moe = GraniteMoeHybridMoE(config)
self.residual_multiplier = config.residual_multiplier # Only diff with mixtral!
self.shared_mlp = GraniteMoeHybridMLP(config)
self.mamba = None

View File

@ -401,9 +401,10 @@ class GraniteMoeSharedDecoderLayer(GradientCheckpointingLayer):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = GraniteMoeSharedAttention(config=config, layer_idx=layer_idx)
self.block_sparse_moe = GraniteMoeSharedMoE(config)
self.input_layernorm = GraniteMoeSharedRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = GraniteMoeSharedRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.block_sparse_moe = GraniteMoeSharedMoE(config)
self.residual_multiplier = config.residual_multiplier # Only diff with mixtral!
self.shared_mlp = None if config.shared_intermediate_size == 0 else GraniteMoeSharedMLP(config)

View File

@ -243,44 +243,38 @@ class HunYuanMoEV1Gate(nn.Module):
return logits
class HunYuanMoEV1Experts(nn.Module):
"""Collection of expert weights stored as 3D tensors."""
class HunYuanMoEV1Experts(nn.ModuleList):
"""
ModuleList of experts.
"""
def __init__(self, config: HunYuanMoEV1Config):
super().__init__()
self.top_k = config.num_experts_per_tok
self.num_experts = config.num_local_experts
self.hidden_dim = config.hidden_size
self.intermediate_dim = config.intermediate_size
self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim))
self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim))
self.act_fn = ACT2FN[config.hidden_act]
for _ in range(self.num_experts):
self.append(HunYuanMoEV1MLP(config))
def forward(
self,
hidden_states: torch.Tensor,
top_k_index: torch.Tensor,
top_k_weights: torch.Tensor,
self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor
) -> torch.Tensor:
"""
Args:
hidden_states: (batch_size * sequence_length, hidden_dim)
selected_experts: (batch_size * sequence_length, top_k)
routing_weights: (batch_size * sequence_length, top_k)
Returns:
(batch_size * sequence_length, hidden_dim)
"""
final_hidden_states = torch.zeros_like(hidden_states)
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0)
num_experts = top_k_weights.shape[1]
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1).permute(2, 1, 0)
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
for expert_idx in expert_hit:
expert_idx = expert_idx[0]
if expert_idx == num_experts:
continue
with torch.no_grad():
_, token_idx = torch.where(expert_mask[expert_idx])
current_state = hidden_states[token_idx]
gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1)
current_hidden_states = self.act_fn(gate) * up
current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])
routing_weights = top_k_weights[token_idx, expert_idx].unsqueeze(-1)
current_hidden_states = current_hidden_states * routing_weights.to(current_hidden_states.dtype)
final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))
idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
current_state = hidden_states[None, top_x].reshape(-1, hidden_states.shape[-1])
current_hidden_states = self[expert_idx](current_state) * top_k_weights[top_x, idx, None]
final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
return final_hidden_states

View File

@ -557,44 +557,38 @@ class JambaMLP(nn.Module):
return down_proj
class JambaExperts(nn.Module):
"""Collection of expert weights stored as 3D tensors."""
class JambaExperts(nn.ModuleList):
"""
ModuleList of experts.
"""
def __init__(self, config: JambaConfig):
super().__init__()
self.top_k = config.num_experts_per_tok
self.num_experts = config.num_local_experts
self.hidden_dim = config.hidden_size
self.intermediate_dim = config.intermediate_size
self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim))
self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim))
self.act_fn = ACT2FN[config.hidden_act]
for _ in range(self.num_experts):
self.append(JambaMLP(config))
def forward(
self,
hidden_states: torch.Tensor,
top_k_index: torch.Tensor,
top_k_weights: torch.Tensor,
self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor
) -> torch.Tensor:
"""
Args:
hidden_states: (batch_size * sequence_length, hidden_dim)
selected_experts: (batch_size * sequence_length, top_k)
routing_weights: (batch_size * sequence_length, top_k)
Returns:
(batch_size * sequence_length, hidden_dim)
"""
final_hidden_states = torch.zeros_like(hidden_states)
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0)
num_experts = top_k_weights.shape[1]
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1).permute(2, 1, 0)
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
for expert_idx in expert_hit:
expert_idx = expert_idx[0]
if expert_idx == num_experts:
continue
with torch.no_grad():
_, token_idx = torch.where(expert_mask[expert_idx])
current_state = hidden_states[token_idx]
gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1)
current_hidden_states = self.act_fn(gate) * up
current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])
routing_weights = top_k_weights[token_idx, expert_idx].unsqueeze(-1)
current_hidden_states = current_hidden_states * routing_weights.to(current_hidden_states.dtype)
final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))
idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
current_state = hidden_states[None, top_x].reshape(-1, hidden_states.shape[-1])
current_hidden_states = self[expert_idx](current_state) * top_k_weights[top_x, idx, None]
final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
return final_hidden_states

View File

@ -24,7 +24,6 @@ import torch
import torch.nn.functional as F
from torch import nn
from ...activations import ACT2FN
from ...cache_utils import Cache
from ...generation import GenerationMixin
from ...integrations import use_kernel_forward_from_hub
@ -145,44 +144,37 @@ class Lfm2MoeMLP(nn.Module):
return self.w2(F.silu(self.w1(x)) * self.w3(x))
class Lfm2MoeExperts(nn.Module):
"""Collection of expert weights stored as 3D tensors."""
class Lfm2MoeExperts(nn.ModuleList):
"""
ModuleList of experts.
"""
def __init__(self, config):
super().__init__()
self.num_experts = config.num_experts
self.hidden_dim = config.hidden_size
self.intermediate_dim = config.moe_intermediate_size
self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim))
self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim))
self.act_fn = ACT2FN[config.hidden_act]
for _ in range(config.num_experts):
self.append(Lfm2MoeMLP(config, intermediate_size=config.moe_intermediate_size))
def forward(
self,
hidden_states: torch.Tensor,
top_k_index: torch.Tensor,
top_k_weights: torch.Tensor,
self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor
) -> torch.Tensor:
"""
Args:
hidden_states: (batch_size * sequence_length, hidden_dim)
selected_experts: (batch_size * sequence_length, top_k)
routing_weights: (batch_size * sequence_length, top_k)
Returns:
(batch_size * sequence_length, hidden_dim)
"""
final_hidden_states = torch.zeros_like(hidden_states)
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0)
num_experts = top_k_weights.shape[1]
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1).permute(2, 1, 0)
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
for expert_idx in expert_hit:
expert_idx = expert_idx[0]
if expert_idx == num_experts:
continue
with torch.no_grad():
_, token_idx = torch.where(expert_mask[expert_idx])
current_state = hidden_states[token_idx]
gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1)
current_hidden_states = self.act_fn(gate) * up
current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])
routing_weights = top_k_weights[token_idx, expert_idx].unsqueeze(-1)
current_hidden_states = current_hidden_states * routing_weights.to(current_hidden_states.dtype)
final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))
idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
current_state = hidden_states[None, top_x].reshape(-1, hidden_states.shape[-1])
current_hidden_states = self[expert_idx](current_state) * top_k_weights[top_x, idx, None]
final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
return final_hidden_states

View File

@ -164,7 +164,7 @@ class LongcatFlashTopkRouter(nn.Module):
topk_indices = self.get_topk_indices(scores)
topk_weights = scores.gather(1, topk_indices)
topk_weights = topk_weights * self.routed_scaling_factor
return topk_weights.to(router_logits.dtype), topk_indices
return topk_indices, topk_weights
@torch.no_grad()
def get_topk_indices(self, scores):
@ -173,51 +173,29 @@ class LongcatFlashTopkRouter(nn.Module):
return topk_indices
class LongcatFlashExperts(nn.Module):
class LongcatFlashExperts(nn.ModuleList):
def __init__(self, config):
super().__init__()
self.intermediate_size = config.expert_ffn_hidden_size
self.hidden_size = config.hidden_size
self.num_routed_experts = config.n_routed_experts
self.zero_expert_num = config.zero_expert_num or 0
self.total_experts = self.num_routed_experts + self.zero_expert_num
self.act_fn = ACT2FN[config.hidden_act]
self.num_experts = config.n_routed_experts + config.zero_expert_num
self.zero_expert_num = config.zero_expert_num
if self.num_routed_experts > 0:
self.gate_up_proj = nn.Parameter(
torch.empty(self.total_experts, 2 * self.intermediate_size, self.hidden_size)
)
self.down_proj = nn.Parameter(
torch.empty(self.num_routed_experts, self.hidden_size, self.intermediate_size)
)
else:
self.register_parameter("gate_up_proj", None)
self.register_parameter("down_proj", None)
self.extend(
[LongcatFlashMLP(config, intermediate_size=self.intermediate_size) for _ in range(self.num_experts)]
+ [nn.Identity() for _ in range(self.zero_expert_num)]
)
def forward(self, hidden_states, top_k_index, top_k_weights):
final_hidden_states = torch.zeros_like(hidden_states)
if top_k_index.numel() == 0:
return final_hidden_states
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0)
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.total_experts).permute(2, 1, 0)
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero(as_tuple=False)
for expert_idx_tensor in expert_hit:
expert_idx = int(expert_idx_tensor.item())
selection_idx, token_idx = torch.where(expert_mask[expert_idx].squeeze(0))
if token_idx.numel() == 0:
continue
current_state = hidden_states[token_idx]
if expert_idx >= self.num_routed_experts or self.gate_up_proj is None:
current_hidden_states = current_state
else:
gate, up = F.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1)
current_hidden_states = self.act_fn(gate) * up
current_hidden_states = F.linear(current_hidden_states, self.down_proj[expert_idx])
current_hidden_states = current_hidden_states * top_k_weights[token_idx, selection_idx, None]
final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(hidden_states.dtype))
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
for expert_idx in expert_hit:
idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
current_state = hidden_states[None, top_x].reshape(-1, hidden_states.shape[-1])
current_hidden_states = self[expert_idx](current_state) * top_k_weights[top_x, idx, None]
final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
return final_hidden_states
@ -237,7 +215,7 @@ class LongcatFlashMoE(nn.Module):
def forward(self, hidden_states):
orig_shape = hidden_states.shape
topk_weights, topk_indices = self.router(hidden_states)
topk_indices, topk_weights = self.router(hidden_states)
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
hidden_states = self.experts(hidden_states, topk_indices, topk_weights).view(*orig_shape)
return hidden_states
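The distinctive part of the LongcatFlash refactor is that the zero_expert_num trailing slots become plain nn.Identity modules, so a token routed to one of them contributes routing_weight * hidden_state, i.e. a weighted skip path. A tiny sketch of that behaviour (sizes invented for illustration):

import torch
from torch import nn

n_routed, n_zero, hidden = 2, 1, 4
experts = nn.ModuleList(
    [nn.Linear(hidden, hidden, bias=False) for _ in range(n_routed)]
    + [nn.Identity() for _ in range(n_zero)]
)

x = torch.randn(1, hidden)
weight = torch.tensor([[0.3]])
zero_slot = n_routed                           # first zero-expert index
contribution = experts[zero_slot](x) * weight
assert torch.allclose(contribution, 0.3 * x)   # identity expert == weighted pass-through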

View File

@ -20,7 +20,6 @@ import torch
import torch.nn.functional as F
from torch import nn
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...masking_utils import create_causal_mask
from ...modeling_flash_attention_utils import FlashAttentionKwargs
@ -91,54 +90,32 @@ class LongcatFlashTopkRouter(DeepseekV3TopkRouter):
topk_indices = self.get_topk_indices(scores)
topk_weights = scores.gather(1, topk_indices)
topk_weights = topk_weights * self.routed_scaling_factor
return topk_weights.to(router_logits.dtype), topk_indices
return topk_indices, topk_weights
class LongcatFlashExperts(nn.Module):
class LongcatFlashExperts(nn.ModuleList):
def __init__(self, config):
super().__init__()
self.intermediate_size = config.expert_ffn_hidden_size
self.hidden_size = config.hidden_size
self.num_routed_experts = config.n_routed_experts
self.zero_expert_num = config.zero_expert_num or 0
self.total_experts = self.num_routed_experts + self.zero_expert_num
self.act_fn = ACT2FN[config.hidden_act]
self.num_experts = config.n_routed_experts + config.zero_expert_num
self.zero_expert_num = config.zero_expert_num
if self.num_routed_experts > 0:
self.gate_up_proj = nn.Parameter(
torch.empty(self.total_experts, 2 * self.intermediate_size, self.hidden_size)
)
self.down_proj = nn.Parameter(
torch.empty(self.num_routed_experts, self.hidden_size, self.intermediate_size)
)
else:
self.register_parameter("gate_up_proj", None)
self.register_parameter("down_proj", None)
self.extend(
[LongcatFlashMLP(config, intermediate_size=self.intermediate_size) for _ in range(self.num_experts)]
+ [nn.Identity() for _ in range(self.zero_expert_num)]
)
def forward(self, hidden_states, top_k_index, top_k_weights):
final_hidden_states = torch.zeros_like(hidden_states)
if top_k_index.numel() == 0:
return final_hidden_states
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0)
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.total_experts).permute(2, 1, 0)
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero(as_tuple=False)
for expert_idx_tensor in expert_hit:
expert_idx = int(expert_idx_tensor.item())
selection_idx, token_idx = torch.where(expert_mask[expert_idx].squeeze(0))
if token_idx.numel() == 0:
continue
current_state = hidden_states[token_idx]
if expert_idx >= self.num_routed_experts or self.gate_up_proj is None:
current_hidden_states = current_state
else:
gate, up = F.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1)
current_hidden_states = self.act_fn(gate) * up
current_hidden_states = F.linear(current_hidden_states, self.down_proj[expert_idx])
current_hidden_states = current_hidden_states * top_k_weights[token_idx, selection_idx, None]
final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(hidden_states.dtype))
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
for expert_idx in expert_hit:
idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
current_state = hidden_states[None, top_x].reshape(-1, hidden_states.shape[-1])
current_hidden_states = self[expert_idx](current_state) * top_k_weights[top_x, idx, None]
final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
return final_hidden_states
@ -158,7 +135,7 @@ class LongcatFlashMoE(nn.Module):
def forward(self, hidden_states):
orig_shape = hidden_states.shape
topk_weights, topk_indices = self.router(hidden_states)
topk_indices, topk_weights = self.router(hidden_states)
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
hidden_states = self.experts(hidden_states, topk_indices, topk_weights).view(*orig_shape)
return hidden_states

View File

@ -452,63 +452,57 @@ class MiniMaxAttention(nn.Module):
return attn_output, attn_weights
class MiniMaxExperts(nn.Module):
"""Collection of expert weights stored as 3D tensors."""
class MiniMaxMLP(nn.Module):
def __init__(self, config: MiniMaxConfig):
super().__init__()
self.ffn_dim = config.intermediate_size
self.hidden_dim = config.hidden_size
self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)
self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
self.act_fn = ACT2FN[config.hidden_act]
def forward(self, hidden_states):
current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
current_hidden_states = self.w2(current_hidden_states)
return current_hidden_states
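MiniMaxMLP (like MixtralMLP below) is a gated SwiGLU-style feed-forward: w2(act(w1(x)) * w3(x)). A quick shape check with invented sizes (silu stands in for whatever ACT2FN[config.hidden_act] resolves to):

import torch
from torch import nn

hidden_dim, ffn_dim = 8, 32
w1 = nn.Linear(hidden_dim, ffn_dim, bias=False)   # gate projection
w3 = nn.Linear(hidden_dim, ffn_dim, bias=False)   # up projection
w2 = nn.Linear(ffn_dim, hidden_dim, bias=False)   # down projection

x = torch.randn(5, hidden_dim)
y = w2(nn.functional.silu(w1(x)) * w3(x))
print(y.shape)                                    # torch.Size([5, 8])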
class MiniMaxExperts(nn.ModuleList):
"""
ModuleList of experts.
"""
def __init__(self, config: MiniMaxConfig):
super().__init__()
self.num_experts = config.num_local_experts
self.hidden_dim = config.hidden_size
self.intermediate_dim = config.intermediate_size
self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim))
self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim))
self.act_fn = ACT2FN[config.hidden_act]
def forward(
self,
hidden_states: torch.Tensor,
top_k_index: torch.Tensor,
top_k_weights: torch.Tensor,
) -> torch.Tensor:
final_hidden_states = torch.zeros_like(hidden_states)
num_experts = top_k_weights.shape[1]
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1).permute(2, 1, 0)
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
for expert_idx in expert_hit:
expert_idx = expert_idx[0]
if expert_idx == num_experts:
continue
with torch.no_grad():
_, token_idx = torch.where(expert_mask[expert_idx])
current_state = hidden_states[token_idx]
gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1)
current_hidden_states = self.act_fn(gate) * up
current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])
routing_weights = top_k_weights[token_idx, expert_idx].unsqueeze(-1)
current_hidden_states = current_hidden_states * routing_weights.to(current_hidden_states.dtype)
final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))
return final_hidden_states
class MiniMaxTopKRouter(nn.Module):
def __init__(self, config):
super().__init__()
self.top_k = config.num_experts_per_tok
self.num_experts = config.num_local_experts
self.hidden_dim = config.hidden_size
self.weight = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim))
for _ in range(self.num_experts):
self.append(MiniMaxMLP(config))
def forward(self, hidden_states):
hidden_states = hidden_states.reshape(-1, self.hidden_dim)
router_logits = F.linear(hidden_states, self.weight) # (seq_len, num_experts)
router_logits = torch.nn.functional.softmax(router_logits.float(), dim=-1)
router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) # (seq_len, top_k)
router_top_value /= router_top_value.sum(dim=-1, keepdim=True)
router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value)
return router_scores, router_indices
def forward(
self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor
) -> torch.Tensor:
"""
Args:
hidden_states: (batch_size * sequence_length, hidden_dim)
selected_experts: (batch_size * sequence_length, top_k)
routing_weights: (batch_size * sequence_length, top_k)
Returns:
(batch_size * sequence_length, hidden_dim)
"""
final_hidden_states = torch.zeros_like(hidden_states)
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0)
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
for expert_idx in expert_hit:
idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
current_state = hidden_states[None, top_x].reshape(-1, hidden_states.shape[-1])
current_hidden_states = self[expert_idx](current_state) * top_k_weights[top_x, idx, None]
final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
return final_hidden_states
class MiniMaxSparseMoeBlock(nn.Module):
@ -516,15 +510,22 @@ class MiniMaxSparseMoeBlock(nn.Module):
super().__init__()
self.top_k = config.num_experts_per_tok
self.jitter_noise = config.router_jitter_noise
self.gate = MiniMaxTopKRouter(config)
self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)
self.experts = MiniMaxExperts(config)
def route_tokens_to_experts(self, router_logits):
routing_weights = torch.nn.functional.softmax(router_logits.float(), dim=-1)
top_k_weights, top_k_index = torch.topk(routing_weights, self.top_k, dim=-1)
top_k_weights /= top_k_weights.sum(dim=-1, keepdim=True)
return top_k_index, top_k_weights.to(router_logits.dtype)
def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
batch_size, sequence_length, hidden_dim = hidden_states.shape
if self.training and self.jitter_noise > 0:
hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise)
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
top_k_weights, top_k_index = self.gate(hidden_states)
router_logits = self.gate(hidden_states)
top_k_index, top_k_weights = self.route_tokens_to_experts(router_logits)
hidden_states = self.experts(hidden_states, top_k_index, top_k_weights.to(hidden_states.dtype))
hidden_states = hidden_states.reshape(batch_size, sequence_length, hidden_dim)
return hidden_states
@ -536,6 +537,8 @@ class MiniMaxDecoderLayer(GradientCheckpointingLayer):
self.hidden_size = config.hidden_size
self.self_attn = MiniMaxAttention(config, layer_idx)
self.block_sparse_moe = MiniMaxSparseMoeBlock(config)
self.input_layernorm = MiniMaxRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = MiniMaxRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@ -543,7 +546,7 @@ class MiniMaxDecoderLayer(GradientCheckpointingLayer):
self.layer_type = config.layer_types[layer_idx] if hasattr(config, "layer_types") else None
self.mlp_alpha_factor = config.mlp_alpha_factor
self.mlp_beta_factor = config.mlp_beta_factor
self.block_sparse_moe = MiniMaxSparseMoeBlock(config)
if self.layer_type == "linear_attention":
self.self_attn = MiniMaxLightningAttention(config, layer_idx)
self.attn_alpha_factor = config.linear_attn_alpha_factor

View File

@ -476,8 +476,7 @@ class MiniMaxDecoderLayer(MixtralDecoderLayer, GradientCheckpointingLayer):
self.layer_type = config.layer_types[layer_idx] if hasattr(config, "layer_types") else None
self.mlp_alpha_factor = config.mlp_alpha_factor
self.mlp_beta_factor = config.mlp_beta_factor
del self.mlp
self.block_sparse_moe = MiniMaxSparseMoeBlock(config)
if self.layer_type == "linear_attention":
self.self_attn = MiniMaxLightningAttention(config, layer_idx)
self.attn_alpha_factor = config.linear_attn_alpha_factor

View File

@ -115,16 +115,14 @@ class MixtralConfig(PreTrainedConfig):
model_type = "mixtral"
keys_to_ignore_at_inference = ["past_key_values"]
base_model_tp_plan = {
"layers.*.self_attn.q_proj": "local_colwise",
"layers.*.self_attn.k_proj": "local_colwise",
"layers.*.self_attn.v_proj": "local_colwise",
"layers.*.self_attn.o_proj": "local_rowwise",
"layers.*.self_attn": "gather",
"layers.*.mlp.gate": "ep_router", # we need to replicate here to correctly route experts
"layers.*.mlp.experts.gate_up_proj": "local_colwise",
"layers.*.mlp.experts.down_proj": "local_rowwise",
"layers.*.mlp.experts": "gather",
# "layers.*.mlp.experts.gate_up_proj": "local_packed_rowwise" ? if you load from
"layers.*.self_attn.q_proj": "colwise",
"layers.*.self_attn.k_proj": "colwise",
"layers.*.self_attn.v_proj": "colwise",
"layers.*.self_attn.o_proj": "rowwise",
"layers.*.block_sparse_moe.gate": "colwise_rep", # we need to replicate here to correctly route experts
"layers.*.block_sparse_moe.experts.*.w1": "colwise",
"layers.*.block_sparse_moe.experts.*.w2": "rowwise",
"layers.*.block_sparse_moe.experts.*.w3": "colwise",
}
base_model_pp_plan = {
"embed_tokens": (["input_ids"], ["inputs_embeds"]),

View File

@ -28,7 +28,6 @@ from collections.abc import Callable
from typing import Optional, Union
import torch
import torch.nn.functional as F
from torch import nn
from transformers.utils.generic import check_model_inputs
@ -54,63 +53,57 @@ from ...utils.generic import OutputRecorder
from .configuration_mixtral import MixtralConfig
class MixtralExperts(nn.Module):
"""Collection of expert weights stored as 3D tensors."""
class MixtralMLP(nn.Module):
def __init__(self, config: MixtralConfig):
super().__init__()
self.ffn_dim = config.intermediate_size
self.hidden_dim = config.hidden_size
self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)
self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
self.act_fn = ACT2FN[config.hidden_act]
def forward(self, hidden_states):
current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
current_hidden_states = self.w2(current_hidden_states)
return current_hidden_states
class MixtralExperts(nn.ModuleList):
"""
ModuleList of experts.
"""
def __init__(self, config: MixtralConfig):
super().__init__()
self.num_experts = config.num_local_experts
self.hidden_dim = config.hidden_size
self.intermediate_dim = config.intermediate_size
self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim))
self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim))
self.act_fn = ACT2FN[config.hidden_act]
def forward(
self,
hidden_states: torch.Tensor,
top_k_index: torch.Tensor,
top_k_weights: torch.Tensor,
) -> torch.Tensor:
final_hidden_states = torch.zeros_like(hidden_states)
num_experts = top_k_weights.shape[1]
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1).permute(2, 1, 0)
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
for expert_idx in expert_hit:
expert_idx = expert_idx[0]
if expert_idx == num_experts:
continue
with torch.no_grad():
_, token_idx = torch.where(expert_mask[expert_idx])
current_state = hidden_states[token_idx]
gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1)
current_hidden_states = self.act_fn(gate) * up
current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])
routing_weights = top_k_weights[token_idx, expert_idx].unsqueeze(-1)
current_hidden_states = current_hidden_states * routing_weights.to(current_hidden_states.dtype)
final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))
return final_hidden_states
class MixtralTopKRouter(nn.Module):
def __init__(self, config):
super().__init__()
self.top_k = config.num_experts_per_tok
self.num_experts = config.num_local_experts
self.hidden_dim = config.hidden_size
self.weight = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim))
for _ in range(self.num_experts):
self.append(MixtralMLP(config))
def forward(self, hidden_states):
hidden_states = hidden_states.reshape(-1, self.hidden_dim)
router_logits = F.linear(hidden_states, self.weight) # (seq_len, num_experts)
router_logits = torch.nn.functional.softmax(router_logits.float(), dim=-1)
router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) # (seq_len, top_k)
router_top_value /= router_top_value.sum(dim=-1, keepdim=True)
router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value)
return router_scores, router_indices
def forward(
self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor
) -> torch.Tensor:
"""
Args:
hidden_states: (batch_size * sequence_length, hidden_dim)
selected_experts: (batch_size * sequence_length, top_k)
routing_weights: (batch_size * sequence_length, top_k)
Returns:
(batch_size * sequence_length, hidden_dim)
"""
final_hidden_states = torch.zeros_like(hidden_states)
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0)
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
for expert_idx in expert_hit:
idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
current_state = hidden_states[None, top_x].reshape(-1, hidden_states.shape[-1])
current_hidden_states = self[expert_idx](current_state) * top_k_weights[top_x, idx, None]
final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
return final_hidden_states
class MixtralSparseMoeBlock(nn.Module):
@ -118,15 +111,22 @@ class MixtralSparseMoeBlock(nn.Module):
super().__init__()
self.top_k = config.num_experts_per_tok
self.jitter_noise = config.router_jitter_noise
self.gate = MixtralTopKRouter(config)
self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)
self.experts = MixtralExperts(config)
def route_tokens_to_experts(self, router_logits):
routing_weights = torch.nn.functional.softmax(router_logits.float(), dim=-1)
top_k_weights, top_k_index = torch.topk(routing_weights, self.top_k, dim=-1)
top_k_weights /= top_k_weights.sum(dim=-1, keepdim=True)
return top_k_index, top_k_weights.to(router_logits.dtype)
def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
batch_size, sequence_length, hidden_dim = hidden_states.shape
if self.training and self.jitter_noise > 0:
hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise)
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
top_k_weights, top_k_index = self.gate(hidden_states)
router_logits = self.gate(hidden_states)
top_k_index, top_k_weights = self.route_tokens_to_experts(router_logits)
hidden_states = self.experts(hidden_states, top_k_index, top_k_weights.to(hidden_states.dtype))
hidden_states = hidden_states.reshape(batch_size, sequence_length, hidden_dim)
return hidden_states
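The new route_tokens_to_experts splits routing into: softmax over all experts, take the top-k, then renormalize the kept weights so they sum to 1 (for the kept subset this equals a softmax over just those logits). A worked example with made-up logits:

import torch

top_k = 2
router_logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])      # one token, four experts
probs = torch.softmax(router_logits.float(), dim=-1)
top_k_weights, top_k_index = torch.topk(probs, top_k, dim=-1)
top_k_weights = top_k_weights / top_k_weights.sum(dim=-1, keepdim=True)
print(top_k_index)    # tensor([[0, 1]])
print(top_k_weights)  # tensor([[0.7311, 0.2689]]) == softmax([2.0, 1.0])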
@ -359,7 +359,7 @@ class MixtralDecoderLayer(GradientCheckpointingLayer):
self.self_attn = MixtralAttention(config, layer_idx)
self.mlp = MixtralSparseMoeBlock(config)
self.block_sparse_moe = MixtralSparseMoeBlock(config)
self.input_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@ -387,7 +387,7 @@ class MixtralDecoderLayer(GradientCheckpointingLayer):
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = self.block_sparse_moe(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
@ -405,7 +405,7 @@ class MixtralPreTrainedModel(PreTrainedModel):
_can_compile_fullgraph = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported)
_supports_attention_backend = True
_can_record_outputs = {
"router_logits": OutputRecorder(nn.Linear, layer_name="mlp.gate", index=0),
"router_logits": OutputRecorder(nn.Linear, layer_name="block_sparse_moe.gate", index=0),
"hidden_states": MixtralDecoderLayer,
"attentions": MixtralAttention,
}

View File

@ -22,7 +22,6 @@
from typing import Optional, Union
import torch
import torch.nn.functional as F
from torch import nn
from ...activations import ACT2FN
@ -132,63 +131,57 @@ def load_balancing_loss_func(
return overall_loss * num_experts
class MixtralExperts(nn.Module):
"""Collection of expert weights stored as 3D tensors."""
class MixtralMLP(nn.Module):
def __init__(self, config: MixtralConfig):
super().__init__()
self.ffn_dim = config.intermediate_size
self.hidden_dim = config.hidden_size
self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)
self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
self.act_fn = ACT2FN[config.hidden_act]
def forward(self, hidden_states):
current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
current_hidden_states = self.w2(current_hidden_states)
return current_hidden_states
class MixtralExperts(nn.ModuleList):
"""
ModuleList of experts.
"""
def __init__(self, config: MixtralConfig):
super().__init__()
self.num_experts = config.num_local_experts
self.hidden_dim = config.hidden_size
self.intermediate_dim = config.intermediate_size
self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim))
self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim))
self.act_fn = ACT2FN[config.hidden_act]
def forward(
self,
hidden_states: torch.Tensor,
top_k_index: torch.Tensor,
top_k_weights: torch.Tensor,
) -> torch.Tensor:
final_hidden_states = torch.zeros_like(hidden_states)
num_experts = top_k_weights.shape[1]
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1).permute(2, 1, 0)
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
for expert_idx in expert_hit:
expert_idx = expert_idx[0]
if expert_idx == num_experts:
continue
with torch.no_grad():
_, token_idx = torch.where(expert_mask[expert_idx])
current_state = hidden_states[token_idx]
gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1)
current_hidden_states = self.act_fn(gate) * up
current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])
routing_weights = top_k_weights[token_idx, expert_idx].unsqueeze(-1)
current_hidden_states = current_hidden_states * routing_weights.to(current_hidden_states.dtype)
final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))
return final_hidden_states
class MixtralTopKRouter(nn.Module):
def __init__(self, config):
super().__init__()
self.top_k = config.num_experts_per_tok
self.num_experts = config.num_local_experts
self.hidden_dim = config.hidden_size
self.weight = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim))
for _ in range(self.num_experts):
self.append(MixtralMLP(config))
def forward(self, hidden_states):
hidden_states = hidden_states.reshape(-1, self.hidden_dim)
router_logits = F.linear(hidden_states, self.weight) # (seq_len, num_experts)
router_logits = torch.nn.functional.softmax(router_logits.float(), dim=-1)
router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) # (seq_len, top_k)
router_top_value /= router_top_value.sum(dim=-1, keepdim=True)
router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value)
return router_scores, router_indices
def forward(
self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor
) -> torch.Tensor:
"""
Args:
hidden_states: (batch_size * sequence_length, hidden_dim)
selected_experts: (batch_size * sequence_length, top_k)
routing_weights: (batch_size * sequence_length, top_k)
Returns:
(batch_size * sequence_length, hidden_dim)
"""
final_hidden_states = torch.zeros_like(hidden_states)
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0)
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
for expert_idx in expert_hit:
idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
current_state = hidden_states[None, top_x].reshape(-1, hidden_states.shape[-1])
current_hidden_states = self[expert_idx](current_state) * top_k_weights[top_x, idx, None]
final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
return final_hidden_states
class MixtralSparseMoeBlock(nn.Module):
@ -196,15 +189,22 @@ class MixtralSparseMoeBlock(nn.Module):
super().__init__()
self.top_k = config.num_experts_per_tok
self.jitter_noise = config.router_jitter_noise
self.gate = MixtralTopKRouter(config)
self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)
self.experts = MixtralExperts(config)
def route_tokens_to_experts(self, router_logits):
routing_weights = torch.nn.functional.softmax(router_logits.float(), dim=-1)
top_k_weights, top_k_index = torch.topk(routing_weights, self.top_k, dim=-1)
top_k_weights /= top_k_weights.sum(dim=-1, keepdim=True)
return top_k_index, top_k_weights.to(router_logits.dtype)
def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
batch_size, sequence_length, hidden_dim = hidden_states.shape
if self.training and self.jitter_noise > 0:
hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise)
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
top_k_weights, top_k_index = self.gate(hidden_states)
router_logits = self.gate(hidden_states)
top_k_index, top_k_weights = self.route_tokens_to_experts(router_logits)
hidden_states = self.experts(hidden_states, top_k_index, top_k_weights.to(hidden_states.dtype))
hidden_states = hidden_states.reshape(batch_size, sequence_length, hidden_dim)
return hidden_states
@ -229,7 +229,7 @@ class MixtralDecoderLayer(GradientCheckpointingLayer):
self.self_attn = MixtralAttention(config, layer_idx)
self.mlp = MixtralSparseMoeBlock(config)
self.block_sparse_moe = MixtralSparseMoeBlock(config)
self.input_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@ -257,7 +257,7 @@ class MixtralDecoderLayer(GradientCheckpointingLayer):
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = self.block_sparse_moe(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
@ -265,7 +265,7 @@ class MixtralDecoderLayer(GradientCheckpointingLayer):
class MixtralPreTrainedModel(MistralPreTrainedModel):
_can_compile_fullgraph = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported)
_can_record_outputs = {
"router_logits": OutputRecorder(nn.Linear, layer_name="mlp.gate", index=0),
"router_logits": OutputRecorder(nn.Linear, layer_name="block_sparse_moe.gate", index=0),
"hidden_states": MixtralDecoderLayer,
"attentions": MixtralAttention,
}
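With the gate now a plain nn.Linear, router_logits are recorded straight off that layer via OutputRecorder under "block_sparse_moe.gate". As a library-agnostic illustration of what that amounts to (a forward hook here is only a stand-in sketch, not the actual OutputRecorder mechanism):

import torch
from torch import nn

gate = nn.Linear(16, 8, bias=False)   # stands in for layers.*.block_sparse_moe.gate
captured = []
handle = gate.register_forward_hook(lambda module, inputs, output: captured.append(output.detach()))

_ = gate(torch.randn(4, 16))
handle.remove()
print(captured[0].shape)              # torch.Size([4, 8]) -- per-token router logits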

View File

@ -104,7 +104,6 @@ class OlmoeConfig(PreTrainedConfig):
model_type = "olmoe"
keys_to_ignore_at_inference = ["past_key_values"]
attribute_map = {"num_local_experts": "num_experts"}
def __init__(
self,

View File

@ -20,7 +20,6 @@ from collections.abc import Callable
from typing import Optional, Union
import torch
import torch.nn.functional as F
from torch import nn
from ...activations import ACT2FN
@ -295,78 +294,64 @@ class OlmoeAttention(nn.Module):
return attn_output, attn_weights
class OlmoeExperts(nn.Module):
"""Collection of expert weights stored as 3D tensors."""
class OlmoeExperts(nn.ModuleList):
"""
ModuleList of experts.
"""
def __init__(self, config: OlmoeConfig):
super().__init__()
self.num_experts = config.num_local_experts
self.hidden_dim = config.hidden_size
self.intermediate_dim = config.intermediate_size
self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim))
self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim))
self.act_fn = ACT2FN[config.hidden_act]
def forward(
self,
hidden_states: torch.Tensor,
top_k_index: torch.Tensor,
top_k_weights: torch.Tensor,
) -> torch.Tensor:
final_hidden_states = torch.zeros_like(hidden_states)
num_experts = top_k_weights.shape[1]
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1).permute(2, 1, 0)
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
for expert_idx in expert_hit:
expert_idx = expert_idx[0]
if expert_idx == num_experts:
continue
with torch.no_grad():
_, token_idx = torch.where(expert_mask[expert_idx])
current_state = hidden_states[token_idx]
gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1)
current_hidden_states = self.act_fn(gate) * up
current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])
routing_weights = top_k_weights[token_idx, expert_idx].unsqueeze(-1)
current_hidden_states = current_hidden_states * routing_weights.to(current_hidden_states.dtype)
final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))
return final_hidden_states
class OlmoeTopKRouter(nn.Module):
def __init__(self, config):
super().__init__()
self.top_k = config.num_experts_per_tok
for _ in range(config.num_experts):
self.append(OlmoeMLP(config))
self.num_experts = config.num_experts
self.top_k = config.num_experts_per_tok
self.norm_topk_prob = config.norm_topk_prob
self.hidden_dim = config.hidden_size
self.weight = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim))
def forward(self, hidden_states):
hidden_states = hidden_states.reshape(-1, self.hidden_dim)
router_logits = F.linear(hidden_states, self.weight) # (seq_len, num_experts)
router_logits = torch.nn.functional.softmax(router_logits, dtype=torch.float, dim=-1)
router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) # (seq_len, top_k)
if self.norm_topk_prob:
router_top_value /= router_top_value.sum(dim=-1, keepdim=True)
router_top_value = router_top_value.to(router_logits.dtype)
router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value)
return router_scores, router_indices
def forward(
self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor
) -> torch.Tensor:
"""
Args:
hidden_states: (batch_size * sequence_length, hidden_dim)
selected_experts: (batch_size * sequence_length, top_k)
routing_weights: (batch_size * sequence_length, top_k)
Returns:
(batch_size * sequence_length, hidden_dim)
"""
final_hidden_states = torch.zeros_like(hidden_states)
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0)
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
for expert_idx in expert_hit:
idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
current_state = hidden_states[None, top_x].reshape(-1, hidden_states.shape[-1])
current_hidden_states = self[expert_idx](current_state) * top_k_weights[top_x, idx, None]
final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
return final_hidden_states
class OlmoeSparseMoeBlock(nn.Module):
def __init__(self, config):
super().__init__()
self.gate = OlmoeTopKRouter(config)
self.num_experts = config.num_experts
self.top_k = config.num_experts_per_tok
self.norm_topk_prob = config.norm_topk_prob
self.gate = nn.Linear(config.hidden_size, self.num_experts, bias=False)
self.experts = OlmoeExperts(config)
def route_tokens_to_experts(self, hidden_states, router_logits):
routing_weights = torch.nn.functional.softmax(router_logits.float(), dim=-1)
top_k_weights, top_k_index = torch.topk(routing_weights, self.top_k, dim=-1)
if self.norm_topk_prob:
top_k_weights /= top_k_weights.sum(dim=-1, keepdim=True)
top_k_weights = top_k_weights.to(hidden_states.dtype)
return top_k_index, top_k_weights
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
batch_size, sequence_length, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
top_k_weights, top_k_index = self.gate(hidden_states)
router_logits = self.gate(hidden_states)
top_k_index, top_k_weights = self.route_tokens_to_experts(hidden_states, router_logits)
final_hidden_states = self.experts(hidden_states, top_k_index, top_k_weights).reshape(
batch_size, sequence_length, hidden_dim
)
@ -426,7 +411,7 @@ class OlmoePreTrainedModel(PreTrainedModel):
_supports_flash_attn = True
_supports_sdpa = True
_can_record_outputs = {
"router_logits": OutputRecorder(OlmoeTopKRouter, index=0),
"router_logits": OutputRecorder(nn.Linear, layer_name="gate", index=1),
"hidden_states": OlmoeDecoderLayer,
"attentions": OlmoeAttention,
}

View File

@ -35,7 +35,6 @@ from ..llama.modeling_llama import (
eager_attention_forward,
)
from ..mixtral.modeling_mixtral import MixtralExperts, MixtralForCausalLM, MixtralModel
from ..qwen2_moe.modeling_qwen2_moe import Qwen2MoeTopKRouter
from .configuration_olmoe import OlmoeConfig
@ -116,24 +115,38 @@ class OlmoeAttention(LlamaAttention):
return attn_output, attn_weights
class OlmoeExperts(MixtralExperts):
pass
class OlmoeTopKRouter(Qwen2MoeTopKRouter):
pass
class OlmoeExperts(MixtralExperts, nn.ModuleList):
def __init__(self, config):
nn.ModuleList.__init__(self)
for _ in range(config.num_experts):
self.append(OlmoeMLP(config))
self.num_experts = config.num_experts
self.top_k = config.num_experts_per_tok
self.norm_topk_prob = config.norm_topk_prob
class OlmoeSparseMoeBlock(nn.Module):
def __init__(self, config):
super().__init__()
self.gate = OlmoeTopKRouter(config)
self.num_experts = config.num_experts
self.top_k = config.num_experts_per_tok
self.norm_topk_prob = config.norm_topk_prob
self.gate = nn.Linear(config.hidden_size, self.num_experts, bias=False)
self.experts = OlmoeExperts(config)
def route_tokens_to_experts(self, hidden_states, router_logits):
routing_weights = torch.nn.functional.softmax(router_logits.float(), dim=-1)
top_k_weights, top_k_index = torch.topk(routing_weights, self.top_k, dim=-1)
if self.norm_topk_prob:
top_k_weights /= top_k_weights.sum(dim=-1, keepdim=True)
top_k_weights = top_k_weights.to(hidden_states.dtype)
return top_k_index, top_k_weights
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
batch_size, sequence_length, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
top_k_weights, top_k_index = self.gate(hidden_states)
router_logits = self.gate(hidden_states)
top_k_index, top_k_weights = self.route_tokens_to_experts(hidden_states, router_logits)
final_hidden_states = self.experts(hidden_states, top_k_index, top_k_weights).reshape(
batch_size, sequence_length, hidden_dim
)
@ -160,7 +173,7 @@ class OlmoePreTrainedModel(PreTrainedModel):
_supports_flash_attn = True
_supports_sdpa = True
_can_record_outputs = {
"router_logits": OutputRecorder(OlmoeTopKRouter, index=0),
"router_logits": OutputRecorder(nn.Linear, layer_name="gate", index=1),
"hidden_states": OlmoeDecoderLayer,
"attentions": OlmoeAttention,
}
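The modular OlmoeExperts above reuses MixtralExperts' forward while making the container an nn.ModuleList, initialising the list base explicitly with nn.ModuleList.__init__(self). A stripped-down sketch of that cooperative-inheritance pattern (class names and the toy forward are invented):

import torch
from torch import nn

class DispatchMixin(nn.Module):
    # stands in for an inherited experts forward; here it just sums the child outputs
    def forward(self, x):
        return sum(expert(x) for expert in self)

class ToyExperts(DispatchMixin, nn.ModuleList):
    def __init__(self, num_experts, hidden):
        nn.ModuleList.__init__(self)              # build the list container directly
        for _ in range(num_experts):
            self.append(nn.Linear(hidden, hidden, bias=False))

experts = ToyExperts(3, 4)
print(experts(torch.randn(2, 4)).shape)           # torch.Size([2, 4])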

View File

@ -262,6 +262,24 @@ class PhimoeAttention(nn.Module):
return attn_output, attn_weights
class PhimoeMLP(nn.Module):
def __init__(self, config: PhimoeConfig):
super().__init__()
self.ffn_dim = config.intermediate_size
self.hidden_dim = config.hidden_size
self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)
self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
self.act_fn = ACT2FN[config.hidden_act]
def forward(self, hidden_states):
current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
current_hidden_states = self.w2(current_hidden_states)
return current_hidden_states
class PhimoeMultiplier(torch.autograd.Function):
@staticmethod
def forward(
@ -324,47 +342,58 @@ class PhimoeMultiplier(torch.autograd.Function):
)
class PhimoeExperts(nn.Module):
"""Collection of expert weights stored as 3D tensors."""
class PhimoeExperts(nn.ModuleList):
"""
ModuleList of experts.
"""
def __init__(self, config: PhimoeConfig):
super().__init__()
self.top_k = config.num_experts_per_tok
self.num_experts = config.num_local_experts
self.hidden_dim = config.hidden_size
self.intermediate_dim = config.intermediate_size
self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim))
self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim))
self.act_fn = ACT2FN[config.hidden_act]
for _ in range(self.num_experts):
self.append(PhimoeMLP(config))
def forward(
self,
hidden_states: torch.Tensor,
top_k_index: torch.Tensor,
top_k_weights: torch.Tensor,
self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor
) -> torch.Tensor:
"""
Args:
hidden_states: (batch_size * sequence_length, hidden_dim)
selected_experts: (batch_size * sequence_length, top_k)
routing_weights: (batch_size * sequence_length, top_k)
Returns:
(batch_size * sequence_length, hidden_dim)
"""
final_hidden_states = torch.zeros_like(hidden_states)
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0)
num_experts = top_k_weights.shape[1]
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1).permute(2, 1, 0)
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
for expert_idx in expert_hit:
expert_idx = expert_idx[0]
if expert_idx == num_experts:
continue
with torch.no_grad():
_, token_idx = torch.where(expert_mask[expert_idx])
current_state = hidden_states[token_idx]
gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1)
current_hidden_states = self.act_fn(gate) * up
current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])
routing_weights = top_k_weights[token_idx, expert_idx].unsqueeze(-1)
current_hidden_states = current_hidden_states * routing_weights.to(current_hidden_states.dtype)
final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))
idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
current_state = hidden_states[None, top_x].reshape(-1, hidden_states.shape[-1])
current_hidden_states = self[expert_idx](current_state) * top_k_weights[top_x, idx, None]
final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
return final_hidden_states
class PhimoeRouter(nn.Linear):
def __init__(self, config: PhimoeConfig):
super().__init__(config.hidden_size, config.num_local_experts, bias=False)
self.top_k = config.num_experts_per_tok
self.hidden_dim = config.hidden_size
self.router_jitter_noise = config.router_jitter_noise
self.input_jitter_noise = config.router_jitter_noise
def forward(self, hidden_states):
if self.training and self.input_jitter_noise > 0:
hidden_states *= torch.empty_like(hidden_states).uniform_(
1.0 - self.input_jitter_noise, 1.0 + self.input_jitter_noise
)
router_logits = super().forward(hidden_states)
return router_logits
def sparsemixer(scores, jitter_eps, training, top_k=2):
"""
Sparse mixer function to select top-k experts and compute multipliers.
@ -488,27 +517,6 @@ def sparsemixer(scores, jitter_eps, training, top_k=2):
)
class PhimoeTopKRouter(nn.Linear):
def __init__(self, config: PhimoeConfig):
super().__init__(config.hidden_size, config.num_local_experts, bias=False)
self.router_jitter_noise = config.router_jitter_noise
self.input_jitter_noise = config.input_jitter_noise
def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
if self.training and self.input_jitter_noise > 0:
hidden_states *= torch.empty_like(hidden_states).uniform_(
1.0 - self.input_jitter_noise, 1.0 + self.input_jitter_noise
)
router_logits = super().forward(hidden_states)
routing_weights, selected_experts = sparsemixer(
router_logits,
jitter_eps=self.router_jitter_noise,
training=self.training,
)
routing_weights = torch.zeros_like(router_logits).scatter_(1, selected_experts, routing_weights)
return routing_weights, selected_experts
class PhimoeSparseMoeBlock(nn.Module):
"""
This implementation is
@ -527,10 +535,19 @@ class PhimoeSparseMoeBlock(nn.Module):
self.ffn_dim = config.intermediate_size
self.num_experts = config.num_local_experts
self.top_k = config.num_experts_per_tok
self.router = PhimoeTopKRouter(config)
self.router_jitter_noise = config.router_jitter_noise
self.gate = PhimoeRouter(config)
self.experts = PhimoeExperts(config)
self.input_jitter_noise = config.input_jitter_noise
def route_tokens_to_experts(self, router_logits):
routing_weights, selected_experts = sparsemixer(
router_logits,
jitter_eps=self.router_jitter_noise,
training=self.training,
)
return routing_weights, selected_experts
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
batch_size, sequence_length, hidden_dim = hidden_states.shape
if self.training and self.input_jitter_noise > 0:
@ -540,7 +557,8 @@ class PhimoeSparseMoeBlock(nn.Module):
batch_size, sequence_length, hidden_dim = hidden_states.shape
hidden_states = hidden_states.reshape(-1, hidden_dim)
routing_weights, selected_experts = self.router(hidden_states)
router_logits = self.gate(hidden_states)
routing_weights, selected_experts = self.route_tokens_to_experts(router_logits)
final_hidden_states = self.experts(hidden_states, selected_experts, routing_weights)
return final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
@ -573,7 +591,7 @@ class PhimoeDecoderLayer(GradientCheckpointingLayer):
self.self_attn = PhimoeAttention(config, layer_idx)
self.mlp = PhimoeSparseMoeBlock(config)
self.block_sparse_moe = PhimoeSparseMoeBlock(config)
self.input_layernorm = PhimoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = PhimoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@ -601,7 +619,7 @@ class PhimoeDecoderLayer(GradientCheckpointingLayer):
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = self.block_sparse_moe(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
@ -619,7 +637,7 @@ class PhimoePreTrainedModel(PreTrainedModel):
_can_compile_fullgraph = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported)
_supports_attention_backend = True
_can_record_outputs = {
"router_logits": OutputRecorder(PhimoeTopKRouter, layer_name="mlp.router", index=0),
"router_logits": OutputRecorder(nn.Linear, layer_name="mlp.gate", index=0),
"hidden_states": PhimoeDecoderLayer,
"attentions": PhimoeAttention,
}
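Several of the routers touched in this diff (MixtralTopKRouter, OlmoeTopKRouter, the old PhimoeTopKRouter) scattered the selected top-k weights back into a dense (tokens, num_experts) matrix with scatter_. A toy of what that operation does, with invented numbers:

import torch

num_experts = 4
selected_experts = torch.tensor([[0, 2]])    # one token routed to experts 0 and 2
top_k_weights = torch.tensor([[0.7, 0.3]])
dense = torch.zeros(1, num_experts).scatter_(1, selected_experts, top_k_weights)
print(dense)   # tensor([[0.7000, 0.0000, 0.3000, 0.0000]])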

View File

@ -30,6 +30,7 @@ from ..mixtral.modeling_mixtral import (
MixtralDecoderLayer,
MixtralExperts,
MixtralForCausalLM,
MixtralMLP,
MixtralModel,
MixtralPreTrainedModel,
MixtralRotaryEmbedding,
@ -86,6 +87,10 @@ class PhimoeAttention(LlamaAttention):
pass
class PhimoeMLP(MixtralMLP):
pass
class PhimoeMultiplier(torch.autograd.Function):
@staticmethod
def forward(
@ -271,29 +276,30 @@ def sparsemixer(scores, jitter_eps, training, top_k=2):
)
class PhimoeExperts(MixtralExperts):
pass
class PhimoeExperts(MixtralExperts, nn.ModuleList):
def __init__(self, config: PhimoeConfig):
nn.ModuleList.__init__(self)
self.top_k = config.num_experts_per_tok
self.num_experts = config.num_local_experts
for _ in range(self.num_experts):
self.append(PhimoeMLP(config))
class PhimoeTopKRouter(nn.Linear):
class PhimoeRouter(nn.Linear):
def __init__(self, config: PhimoeConfig):
super().__init__(config.hidden_size, config.num_local_experts, bias=False)
self.top_k = config.num_experts_per_tok
self.hidden_dim = config.hidden_size
self.router_jitter_noise = config.router_jitter_noise
self.input_jitter_noise = config.input_jitter_noise
self.input_jitter_noise = config.router_jitter_noise
def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
def forward(self, hidden_states):
if self.training and self.input_jitter_noise > 0:
hidden_states *= torch.empty_like(hidden_states).uniform_(
1.0 - self.input_jitter_noise, 1.0 + self.input_jitter_noise
)
router_logits = super().forward(hidden_states)
routing_weights, selected_experts = sparsemixer(
router_logits,
jitter_eps=self.router_jitter_noise,
training=self.training,
)
routing_weights = torch.zeros_like(router_logits).scatter_(1, selected_experts, routing_weights)
return routing_weights, selected_experts
return router_logits
class PhimoeSparseMoeBlock(nn.Module):
@ -314,10 +320,19 @@ class PhimoeSparseMoeBlock(nn.Module):
self.ffn_dim = config.intermediate_size
self.num_experts = config.num_local_experts
self.top_k = config.num_experts_per_tok
self.router = PhimoeTopKRouter(config)
self.router_jitter_noise = config.router_jitter_noise
self.gate = PhimoeRouter(config)
self.experts = PhimoeExperts(config)
self.input_jitter_noise = config.input_jitter_noise
def route_tokens_to_experts(self, router_logits):
routing_weights, selected_experts = sparsemixer(
router_logits,
jitter_eps=self.router_jitter_noise,
training=self.training,
)
return routing_weights, selected_experts
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
batch_size, sequence_length, hidden_dim = hidden_states.shape
if self.training and self.input_jitter_noise > 0:
@ -327,7 +342,8 @@ class PhimoeSparseMoeBlock(nn.Module):
batch_size, sequence_length, hidden_dim = hidden_states.shape
hidden_states = hidden_states.reshape(-1, hidden_dim)
routing_weights, selected_experts = self.router(hidden_states)
router_logits = self.gate(hidden_states)
routing_weights, selected_experts = self.route_tokens_to_experts(router_logits)
final_hidden_states = self.experts(hidden_states, selected_experts, routing_weights)
return final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
@ -338,7 +354,7 @@ class PhimoeDecoderLayer(MixtralDecoderLayer):
class PhimoePreTrainedModel(MixtralPreTrainedModel):
_can_record_outputs = {
"router_logits": OutputRecorder(PhimoeTopKRouter, layer_name="mlp.router", index=0),
"router_logits": OutputRecorder(nn.Linear, layer_name="mlp.gate", index=0),
"hidden_states": PhimoeDecoderLayer,
"attentions": PhimoeAttention,
}

View File

@ -289,81 +289,66 @@ class Qwen2MoeAttention(nn.Module):
return attn_output, attn_weights
class Qwen2MoeExperts(nn.Module):
"""Collection of expert weights stored as 3D tensors."""
class Qwen2MoeExperts(nn.ModuleList):
"""
ModuleList of experts.
"""
def __init__(self, config):
super().__init__()
self.num_experts = config.num_experts
self.hidden_dim = config.hidden_size
self.intermediate_dim = config.moe_intermediate_size
self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim))
self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim))
self.act_fn = ACT2FN[config.hidden_act]
for _ in range(config.num_experts):
self.append(Qwen2MoeMLP(config, intermediate_size=config.moe_intermediate_size))
def forward(
self,
hidden_states: torch.Tensor,
top_k_index: torch.Tensor,
top_k_weights: torch.Tensor,
self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor
) -> torch.Tensor:
"""
Args:
hidden_states: (batch_size * sequence_length, hidden_dim)
selected_experts: (batch_size * sequence_length, top_k)
routing_weights: (batch_size * sequence_length, top_k)
Returns:
(batch_size * sequence_length, hidden_dim)
"""
final_hidden_states = torch.zeros_like(hidden_states)
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0)
num_experts = top_k_weights.shape[1]
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1).permute(2, 1, 0)
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
for expert_idx in expert_hit:
expert_idx = expert_idx[0]
if expert_idx == num_experts:
continue
with torch.no_grad():
_, token_idx = torch.where(expert_mask[expert_idx])
current_state = hidden_states[token_idx]
gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1)
current_hidden_states = self.act_fn(gate) * up
current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])
routing_weights = top_k_weights[token_idx, expert_idx].unsqueeze(-1)
current_hidden_states = current_hidden_states * routing_weights.to(current_hidden_states.dtype)
final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))
idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
current_state = hidden_states[None, top_x].reshape(-1, hidden_states.shape[-1])
current_hidden_states = self[expert_idx](current_state) * top_k_weights[top_x, idx, None]
final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
return final_hidden_states
class Qwen2MoeTopKRouter(nn.Module):
def __init__(self, config):
super().__init__()
self.top_k = config.num_experts_per_tok
self.num_experts = config.num_experts
self.norm_topk_prob = config.norm_topk_prob
self.hidden_dim = config.hidden_size
self.weight = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim))
def forward(self, hidden_states):
hidden_states = hidden_states.reshape(-1, self.hidden_dim)
router_logits = F.linear(hidden_states, self.weight) # (seq_len, num_experts)
router_logits = torch.nn.functional.softmax(router_logits, dtype=torch.float, dim=-1)
router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) # (seq_len, top_k)
if self.norm_topk_prob:
router_top_value /= router_top_value.sum(dim=-1, keepdim=True)
router_top_value = router_top_value.to(router_logits.dtype)
router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value)
return router_scores, router_indices
class Qwen2MoeSparseMoeBlock(nn.Module):
def __init__(self, config):
super().__init__()
self.gate = Qwen2MoeTopKRouter(config)
# gating
self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)
self.experts = Qwen2MoeExperts(config)
self.num_experts_per_tok = config.num_experts_per_tok
self.norm_topk_prob = config.norm_topk_prob
self.shared_expert = Qwen2MoeMLP(config, intermediate_size=config.shared_expert_intermediate_size)
self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False)
def route_tokens_to_experts(self, hidden_states, router_logits):
routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(routing_weights, self.num_experts_per_tok, dim=-1)
if self.norm_topk_prob:
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
routing_weights = routing_weights.to(router_logits.dtype)
return selected_experts, routing_weights
def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
batch_size, sequence_length, hidden_dim = hidden_states.shape
hidden_states_reshaped = hidden_states.view(-1, hidden_dim)
shared_expert_output = self.shared_expert(hidden_states_reshaped)
routing_weights, selected_experts = self.gate(hidden_states_reshaped)
router_logits = self.gate(hidden_states_reshaped)
selected_experts, routing_weights = self.route_tokens_to_experts(hidden_states_reshaped, router_logits)
expert_output = self.experts(hidden_states_reshaped, selected_experts, routing_weights)
shared_expert_output = F.sigmoid(self.shared_expert_gate(hidden_states_reshaped)) * shared_expert_output
@ -434,7 +419,7 @@ class Qwen2MoePreTrainedModel(PreTrainedModel):
_can_compile_fullgraph = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported)
_supports_attention_backend = True
_can_record_outputs = {
"router_logits": OutputRecorder(Qwen2MoeTopKRouter, index=0),
"router_logits": OutputRecorder(nn.Linear, layer_name="mlp.gate", index=0),
"hidden_states": Qwen2MoeDecoderLayer,
"attentions": Qwen2MoeAttention,
}
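
The `ModuleList`-based experts above dispatch tokens with a one-hot mask over the selected expert indices, run only the experts that were actually hit, and accumulate the weighted outputs with `index_add_`. A self-contained toy version of that loop (the `ToyMLP`, the sizes, and the omission of the extra padding class / `expert_idx == num_experts` skip are simplifications for illustration):

```python
import torch
from torch import nn
import torch.nn.functional as F

class ToyMLP(nn.Module):
    def __init__(self, hidden, inner):
        super().__init__()
        self.up = nn.Linear(hidden, inner)
        self.down = nn.Linear(inner, hidden)

    def forward(self, x):
        return self.down(F.silu(self.up(x)))

def moe_dispatch(experts, hidden_states, top_k_index, top_k_weights):
    """hidden_states: (tokens, hidden); top_k_index / top_k_weights: (tokens, top_k)."""
    final = torch.zeros_like(hidden_states)
    num_experts = len(experts)
    # (num_experts, top_k, tokens): expert_mask[e, k, t] == 1 iff token t picked expert e in slot k.
    expert_mask = F.one_hot(top_k_index, num_classes=num_experts).permute(2, 1, 0)
    expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
    for expert_idx in expert_hit:
        expert_idx = int(expert_idx[0])
        slot, token_idx = torch.where(expert_mask[expert_idx])
        out = experts[expert_idx](hidden_states[token_idx])
        out = out * top_k_weights[token_idx, slot, None]
        final.index_add_(0, token_idx, out)
    return final

torch.manual_seed(0)
hidden, tokens, n_exp, top_k = 16, 5, 4, 2
experts = nn.ModuleList(ToyMLP(hidden, 32) for _ in range(n_exp))
x = torch.randn(tokens, hidden)
weights, idx = torch.topk(F.softmax(torch.randn(tokens, n_exp), dim=-1), top_k, dim=-1)
print(moe_dispatch(experts, x, idx, weights).shape)  # torch.Size([5, 16])
```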

View File

@ -82,47 +82,40 @@ class Qwen2MoeAttention(LlamaAttention):
self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
class Qwen2MoeExperts(MixtralExperts):
class Qwen2MoeExperts(MixtralExperts, nn.Module):
def __init__(self, config):
super().__init__(config)
nn.ModuleList.__init__(self)
self.num_experts = config.num_experts
self.intermediate_dim = config.moe_intermediate_size
class Qwen2MoeTopKRouter(nn.Module):
def __init__(self, config):
super().__init__()
self.top_k = config.num_experts_per_tok
self.num_experts = config.num_experts
self.norm_topk_prob = config.norm_topk_prob
self.hidden_dim = config.hidden_size
self.weight = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim))
def forward(self, hidden_states):
hidden_states = hidden_states.reshape(-1, self.hidden_dim)
router_logits = F.linear(hidden_states, self.weight) # (seq_len, num_experts)
router_logits = torch.nn.functional.softmax(router_logits, dtype=torch.float, dim=-1)
router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) # (seq_len, top_k)
if self.norm_topk_prob:
router_top_value /= router_top_value.sum(dim=-1, keepdim=True)
router_top_value = router_top_value.to(router_logits.dtype)
router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value)
return router_scores, router_indices
for _ in range(config.num_experts):
self.append(Qwen2MoeMLP(config, intermediate_size=config.moe_intermediate_size))
class Qwen2MoeSparseMoeBlock(nn.Module):
def __init__(self, config):
super().__init__()
self.gate = Qwen2MoeTopKRouter(config)
# gating
self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)
self.experts = Qwen2MoeExperts(config)
self.num_experts_per_tok = config.num_experts_per_tok
self.norm_topk_prob = config.norm_topk_prob
self.shared_expert = Qwen2MoeMLP(config, intermediate_size=config.shared_expert_intermediate_size)
self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False)
def route_tokens_to_experts(self, hidden_states, router_logits):
routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(routing_weights, self.num_experts_per_tok, dim=-1)
if self.norm_topk_prob:
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
routing_weights = routing_weights.to(router_logits.dtype)
return selected_experts, routing_weights
def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
batch_size, sequence_length, hidden_dim = hidden_states.shape
hidden_states_reshaped = hidden_states.view(-1, hidden_dim)
shared_expert_output = self.shared_expert(hidden_states_reshaped)
routing_weights, selected_experts = self.gate(hidden_states_reshaped)
router_logits = self.gate(hidden_states_reshaped)
selected_experts, routing_weights = self.route_tokens_to_experts(hidden_states_reshaped, router_logits)
expert_output = self.experts(hidden_states_reshaped, selected_experts, routing_weights)
shared_expert_output = F.sigmoid(self.shared_expert_gate(hidden_states_reshaped)) * shared_expert_output
@ -150,7 +143,7 @@ class Qwen2MoeDecoderLayer(LlamaDecoderLayer, nn.Module):
@auto_docstring
class Qwen2MoePreTrainedModel(MixtralPreTrainedModel):
_can_record_outputs = {
"router_logits": OutputRecorder(Qwen2MoeTopKRouter, index=0),
"router_logits": OutputRecorder(nn.Linear, layer_name="mlp.gate", index=0),
"hidden_states": Qwen2MoeDecoderLayer,
"attentions": Qwen2MoeAttention,
}
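
The shared `route_tokens_to_experts` pattern added to these blocks is just a softmax over the gate logits, a top-k selection, and an optional renormalisation so the kept weights sum to one per token. A standalone sketch with illustrative values:

```python
import torch
import torch.nn.functional as F

def route_tokens_to_experts(router_logits, top_k, norm_topk_prob=True):
    # Probabilities over all experts, computed in float32 for stability.
    routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float)
    routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
    if norm_topk_prob:
        # Renormalise the kept weights so they sum to 1 per token.
        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
    return selected_experts, routing_weights.to(router_logits.dtype)

logits = torch.tensor([[2.0, 0.5, 0.1, -1.0]])
experts, weights = route_tokens_to_experts(logits, top_k=2)
print(experts)  # tensor([[0, 1]])
print(weights)  # ≈ tensor([[0.8175, 0.1825]])
```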

View File

@ -209,78 +209,61 @@ class Qwen3MoeMLP(nn.Module):
return down_proj
class Qwen3Moe(nn.Module):
"""Collection of expert weights stored as 3D tensors."""
class Qwen3MoeExperts(nn.ModuleList):
"""
ModuleList of experts.
"""
def __init__(self, config):
def __init__(self, config: Qwen3MoeConfig):
super().__init__()
self.num_experts = config.num_experts
self.hidden_dim = config.hidden_size
self.intermediate_dim = config.moe_intermediate_size
self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim))
self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim))
self.act_fn = ACT2FN[config.hidden_act]
for _ in range(self.num_experts):
self.append(Qwen3MoeMLP(config, intermediate_size=config.moe_intermediate_size))
def forward(
self,
hidden_states: torch.Tensor,
top_k_index: torch.Tensor,
top_k_weights: torch.Tensor,
self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor
) -> torch.Tensor:
"""
Args:
hidden_states: (batch_size * sequence_length, hidden_dim)
selected_experts: (batch_size * sequence_length, top_k)
routing_weights: (batch_size * sequence_length, top_k)
Returns:
(batch_size * sequence_length, hidden_dim)
"""
final_hidden_states = torch.zeros_like(hidden_states)
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0)
num_experts = top_k_weights.shape[1]
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1).permute(2, 1, 0)
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
for expert_idx in expert_hit:
expert_idx = expert_idx[0]
if expert_idx == num_experts:
continue
with torch.no_grad():
_, token_idx = torch.where(expert_mask[expert_idx])
current_state = hidden_states[token_idx]
gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1)
current_hidden_states = self.act_fn(gate) * up
current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])
routing_weights = top_k_weights[token_idx, expert_idx].unsqueeze(-1)
current_hidden_states = current_hidden_states * routing_weights.to(current_hidden_states.dtype)
final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))
idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
current_state = hidden_states[None, top_x].reshape(-1, hidden_states.shape[-1])
current_hidden_states = self[expert_idx](current_state) * top_k_weights[top_x, idx, None]
final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
return final_hidden_states
class Qwen3MoeTopKRouter(nn.Module):
def __init__(self, config):
super().__init__()
self.top_k = config.num_experts_per_tok
self.num_experts = config.num_experts
self.norm_topk_prob = config.norm_topk_prob
self.hidden_dim = config.hidden_size
self.weight = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim))
def forward(self, hidden_states):
hidden_states = hidden_states.reshape(-1, self.hidden_dim)
router_logits = F.linear(hidden_states, self.weight) # (seq_len, num_experts)
router_logits = torch.nn.functional.softmax(router_logits, dtype=torch.float, dim=-1)
router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) # (seq_len, top_k)
if self.norm_topk_prob:
router_top_value /= router_top_value.sum(dim=-1, keepdim=True)
router_top_value = router_top_value.to(router_logits.dtype)
router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value)
return router_scores, router_indices
class Qwen3MoeSparseMoeBlock(nn.Module):
def __init__(self, config: Qwen3MoeConfig):
super().__init__()
self.experts = Qwen3Moe(config)
self.router = Qwen3MoeTopKRouter(config)
self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)
self.experts = Qwen3MoeExperts(config)
self.num_experts_per_tok = config.num_experts_per_tok
self.norm_topk_prob = config.norm_topk_prob
def route_tokens_to_experts(self, hidden_states, router_logits):
routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(routing_weights, self.num_experts_per_tok, dim=-1)
if self.norm_topk_prob:
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
routing_weights = routing_weights.to(router_logits.dtype)
return selected_experts, routing_weights
def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
batch_size, sequence_length, hidden_dim = hidden_states.shape
hidden_states_reshaped = hidden_states.view(-1, hidden_dim)
routing_weights, selected_experts = self.router(hidden_states_reshaped)
router_logits = self.gate(hidden_states_reshaped)
selected_experts, routing_weights = self.route_tokens_to_experts(hidden_states_reshaped, router_logits)
final_hidden_states = self.experts(hidden_states_reshaped, selected_experts, routing_weights)
return final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
@ -367,7 +350,7 @@ class Qwen3MoePreTrainedModel(PreTrainedModel):
_can_compile_fullgraph = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported)
_supports_attention_backend = True
_can_record_outputs = {
"router_logits": OutputRecorder(Qwen3MoeTopKRouter, layer_name="mlp.router", index=0),
"router_logits": OutputRecorder(nn.Linear, layer_name="mlp.gate", index=0),
"hidden_states": Qwen3MoeDecoderLayer,
"attentions": Qwen3MoeAttention,
}
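
With the per-model routers replaced by plain `nn.Linear` gates, router logits are recorded by targeting that linear layer (`OutputRecorder(nn.Linear, layer_name="mlp.gate", index=0)` above). Outside of transformers' own recording machinery, the same effect can be had with an ordinary forward hook; a hedged sketch on a toy block, where the module names and sizes are illustrative:

```python
import torch
from torch import nn

class ToyBlock(nn.Module):
    def __init__(self, hidden=8, num_experts=4):
        super().__init__()
        self.mlp = nn.ModuleDict({"gate": nn.Linear(hidden, num_experts, bias=False)})

    def forward(self, x):
        return self.mlp["gate"](x)

model = nn.Sequential(ToyBlock())  # stand-in for a stack of decoder layers
recorded = []

def record_router_logits(module, inputs, output):
    recorded.append(output.detach())

# Hook every linear whose qualified name ends in "mlp.gate", mirroring the layer_name filter.
for name, module in model.named_modules():
    if isinstance(module, nn.Linear) and name.endswith("mlp.gate"):
        module.register_forward_hook(record_router_logits)

model(torch.randn(3, 8))
print([t.shape for t in recorded])  # [torch.Size([3, 4])]
```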

View File

@ -17,6 +17,7 @@
from typing import Optional, Union
import torch
import torch.nn.functional as F
from torch import nn
from ...cache_utils import Cache
@ -31,12 +32,13 @@ from ..llama.modeling_llama import (
LlamaRMSNorm,
)
from ..mixtral.modeling_mixtral import (
MixtralExperts,
MixtralForCausalLM,
MixtralModel,
MixtralPreTrainedModel,
load_balancing_loss_func,
)
from ..qwen2_moe.modeling_qwen2_moe import Qwen2MoeDecoderLayer, Qwen2MoeExperts, Qwen2MoeMLP, Qwen2MoeTopKRouter
from ..qwen2_moe.modeling_qwen2_moe import Qwen2MoeDecoderLayer, Qwen2MoeMLP
from ..qwen3.modeling_qwen3 import Qwen3Attention
from .configuration_qwen3_moe import Qwen3MoeConfig
@ -55,24 +57,35 @@ class Qwen3MoeMLP(Qwen2MoeMLP):
pass
class Qwen3Moe(Qwen2MoeExperts):
pass
class Qwen3MoeTopKRouter(Qwen2MoeTopKRouter):
pass
class Qwen3MoeExperts(MixtralExperts, nn.ModuleList):
def __init__(self, config: Qwen3MoeConfig):
nn.ModuleList.__init__(self)
self.num_experts = config.num_experts
for _ in range(self.num_experts):
self.append(Qwen3MoeMLP(config, intermediate_size=config.moe_intermediate_size))
class Qwen3MoeSparseMoeBlock(nn.Module):
def __init__(self, config: Qwen3MoeConfig):
super().__init__()
self.experts = Qwen3Moe(config)
self.router = Qwen3MoeTopKRouter(config)
self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)
self.experts = Qwen3MoeExperts(config)
self.num_experts_per_tok = config.num_experts_per_tok
self.norm_topk_prob = config.norm_topk_prob
def route_tokens_to_experts(self, hidden_states, router_logits):
routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(routing_weights, self.num_experts_per_tok, dim=-1)
if self.norm_topk_prob:
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
routing_weights = routing_weights.to(router_logits.dtype)
return selected_experts, routing_weights
def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
batch_size, sequence_length, hidden_dim = hidden_states.shape
hidden_states_reshaped = hidden_states.view(-1, hidden_dim)
routing_weights, selected_experts = self.router(hidden_states_reshaped)
router_logits = self.gate(hidden_states_reshaped)
selected_experts, routing_weights = self.route_tokens_to_experts(hidden_states_reshaped, router_logits)
final_hidden_states = self.experts(hidden_states_reshaped, selected_experts, routing_weights)
return final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
@ -87,7 +100,7 @@ class Qwen3MoeDecoderLayer(Qwen2MoeDecoderLayer):
class Qwen3MoePreTrainedModel(MixtralPreTrainedModel):
_can_record_outputs = {
"router_logits": OutputRecorder(Qwen3MoeTopKRouter, layer_name="mlp.router", index=0),
"router_logits": OutputRecorder(nn.Linear, layer_name="mlp.gate", index=0),
"hidden_states": Qwen3MoeDecoderLayer,
"attentions": Qwen3MoeAttention,
}

View File

@ -819,81 +819,66 @@ class Qwen3NextMLP(nn.Module):
return down_proj
class Qwen3NextExperts(nn.Module):
"""Collection of expert weights stored as 3D tensors."""
class Qwen3NextExperts(nn.ModuleList):
"""
ModuleList of experts.
"""
def __init__(self, config):
super().__init__()
self.num_experts = config.num_experts
self.hidden_dim = config.hidden_size
self.intermediate_dim = config.moe_intermediate_size
self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim))
self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim))
self.act_fn = ACT2FN[config.hidden_act]
for _ in range(config.num_experts):
self.append(Qwen3NextMLP(config, intermediate_size=config.moe_intermediate_size))
def forward(
self,
hidden_states: torch.Tensor,
top_k_index: torch.Tensor,
top_k_weights: torch.Tensor,
self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor
) -> torch.Tensor:
"""
Args:
hidden_states: (batch_size * sequence_length, hidden_dim)
selected_experts: (batch_size * sequence_length, top_k)
routing_weights: (batch_size * sequence_length, top_k)
Returns:
(batch_size * sequence_length, hidden_dim)
"""
final_hidden_states = torch.zeros_like(hidden_states)
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0)
num_experts = top_k_weights.shape[1]
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1).permute(2, 1, 0)
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
for expert_idx in expert_hit:
expert_idx = expert_idx[0]
if expert_idx == num_experts:
continue
with torch.no_grad():
_, token_idx = torch.where(expert_mask[expert_idx])
current_state = hidden_states[token_idx]
gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1)
current_hidden_states = self.act_fn(gate) * up
current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])
routing_weights = top_k_weights[token_idx, expert_idx].unsqueeze(-1)
current_hidden_states = current_hidden_states * routing_weights.to(current_hidden_states.dtype)
final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))
idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
current_state = hidden_states[None, top_x].reshape(-1, hidden_states.shape[-1])
current_hidden_states = self[expert_idx](current_state) * top_k_weights[top_x, idx, None]
final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
return final_hidden_states
class Qwen3NextTopKRouter(nn.Module):
def __init__(self, config):
super().__init__()
self.top_k = config.num_experts_per_tok
self.num_experts = config.num_experts
self.norm_topk_prob = config.norm_topk_prob
self.hidden_dim = config.hidden_size
self.weight = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim))
def forward(self, hidden_states):
hidden_states = hidden_states.reshape(-1, self.hidden_dim)
router_logits = F.linear(hidden_states, self.weight) # (seq_len, num_experts)
router_logits = torch.nn.functional.softmax(router_logits, dtype=torch.float, dim=-1)
router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) # (seq_len, top_k)
if self.norm_topk_prob:
router_top_value /= router_top_value.sum(dim=-1, keepdim=True)
router_top_value = router_top_value.to(router_logits.dtype)
router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value)
return router_scores, router_indices
class Qwen3NextSparseMoeBlock(nn.Module):
def __init__(self, config):
super().__init__()
self.gate = Qwen3NextTopKRouter(config)
# gating
self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)
self.experts = Qwen3NextExperts(config)
self.num_experts_per_tok = config.num_experts_per_tok
self.norm_topk_prob = config.norm_topk_prob
self.shared_expert = Qwen3NextMLP(config, intermediate_size=config.shared_expert_intermediate_size)
self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False)
def route_tokens_to_experts(self, hidden_states, router_logits):
routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(routing_weights, self.num_experts_per_tok, dim=-1)
if self.norm_topk_prob:
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
routing_weights = routing_weights.to(router_logits.dtype)
return selected_experts, routing_weights
def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
batch_size, sequence_length, hidden_dim = hidden_states.shape
hidden_states_reshaped = hidden_states.view(-1, hidden_dim)
shared_expert_output = self.shared_expert(hidden_states_reshaped)
routing_weights, selected_experts = self.gate(hidden_states_reshaped)
router_logits = self.gate(hidden_states_reshaped)
selected_experts, routing_weights = self.route_tokens_to_experts(hidden_states_reshaped, router_logits)
expert_output = self.experts(hidden_states_reshaped, selected_experts, routing_weights)
shared_expert_output = F.sigmoid(self.shared_expert_gate(hidden_states_reshaped)) * shared_expert_output

View File
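
Qwen3Next (like Qwen2Moe above) also carries a shared expert whose contribution is scaled per token by a sigmoid gate before being combined with the routed expert output. A small sketch of that blending step, with stand-in modules and illustrative sizes:

```python
import torch
from torch import nn

hidden = 16
x = torch.randn(5, hidden)                       # (tokens, hidden)
routed_output = torch.randn(5, hidden)           # stand-in for self.experts(...)
shared_expert = nn.Linear(hidden, hidden)        # stand-in for the shared MLP
shared_expert_gate = nn.Linear(hidden, 1, bias=False)

# Per-token scalar in (0, 1) deciding how much of the shared expert to mix in.
gate = torch.sigmoid(shared_expert_gate(x))      # (tokens, 1)
output = routed_output + gate * shared_expert(x)
print(output.shape)  # torch.Size([5, 16])
```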

@ -1323,43 +1323,37 @@ class Qwen3OmniMoeThinkerTextMLP(nn.Module):
return down_proj
class Qwen3OmniMoeThinkerTextExperts(nn.Module):
class Qwen3OmniMoeThinkerTextExperts(nn.ModuleList):
"""
ModuleList of experts.
"""
def __init__(self, config: Qwen3OmniMoeThinkerConfig):
nn.ModuleList.__init__(self)
super().__init__()
self.num_experts = config.num_experts
for _ in range(self.num_experts):
self.append(Qwen3OmniMoeThinkerTextMLP(config, intermediate_size=config.moe_intermediate_size))
def forward(
self,
hidden_states: torch.Tensor,
top_k_index: torch.Tensor,
top_k_weights: torch.Tensor,
self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor
) -> torch.Tensor:
"""
Args:
hidden_states: (batch_size * sequence_length, hidden_dim)
selected_experts: (batch_size * sequence_length, top_k)
routing_weights: (batch_size * sequence_length, top_k)
Returns:
(batch_size * sequence_length, hidden_dim)
"""
final_hidden_states = torch.zeros_like(hidden_states)
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0)
num_experts = top_k_weights.shape[1]
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1).permute(2, 1, 0)
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
for expert_idx in expert_hit:
expert_idx = expert_idx[0]
if expert_idx == num_experts:
continue
with torch.no_grad():
_, token_idx = torch.where(expert_mask[expert_idx])
current_state = hidden_states[token_idx]
gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1)
current_hidden_states = self.act_fn(gate) * up
current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])
routing_weights = top_k_weights[token_idx, expert_idx].unsqueeze(-1)
current_hidden_states = current_hidden_states * routing_weights.to(current_hidden_states.dtype)
final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))
idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
current_state = hidden_states[None, top_x].reshape(-1, hidden_states.shape[-1])
current_hidden_states = self[expert_idx](current_state) * top_k_weights[top_x, idx, None]
final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
return final_hidden_states
@ -2774,83 +2768,68 @@ class Qwen3OmniMoeTalkerTextMLP(nn.Module):
return down_proj
class Qwen3OmniMoeTalkerTextExperts(nn.Module):
"""Collection of expert weights stored as 3D tensors."""
class Qwen3OmniMoeTalkerTextExperts(nn.ModuleList):
"""
ModuleList of experts.
"""
def __init__(self, config):
super().__init__()
self.num_experts = config.num_experts
self.hidden_dim = config.hidden_size
self.intermediate_dim = config.moe_intermediate_size
self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim))
self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim))
self.act_fn = ACT2FN[config.hidden_act]
for _ in range(config.num_experts):
self.append(Qwen3OmniMoeTalkerTextMLP(config, intermediate_size=config.moe_intermediate_size))
def forward(
self,
hidden_states: torch.Tensor,
top_k_index: torch.Tensor,
top_k_weights: torch.Tensor,
self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor
) -> torch.Tensor:
"""
Args:
hidden_states: (batch_size * sequence_length, hidden_dim)
selected_experts: (batch_size * sequence_length, top_k)
routing_weights: (batch_size * sequence_length, top_k)
Returns:
(batch_size * sequence_length, hidden_dim)
"""
final_hidden_states = torch.zeros_like(hidden_states)
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0)
num_experts = top_k_weights.shape[1]
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1).permute(2, 1, 0)
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
for expert_idx in expert_hit:
expert_idx = expert_idx[0]
if expert_idx == num_experts:
continue
with torch.no_grad():
_, token_idx = torch.where(expert_mask[expert_idx])
current_state = hidden_states[token_idx]
gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1)
current_hidden_states = self.act_fn(gate) * up
current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])
routing_weights = top_k_weights[token_idx, expert_idx].unsqueeze(-1)
current_hidden_states = current_hidden_states * routing_weights.to(current_hidden_states.dtype)
final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))
idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
current_state = hidden_states[None, top_x].reshape(-1, hidden_states.shape[-1])
current_hidden_states = self[expert_idx](current_state) * top_k_weights[top_x, idx, None]
final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
return final_hidden_states
class Qwen3OmniMoeTalkerTextTopKRouter(nn.Module):
def __init__(self, config):
super().__init__()
self.top_k = config.num_experts_per_tok
self.num_experts = config.num_experts
self.norm_topk_prob = config.norm_topk_prob
self.hidden_dim = config.hidden_size
self.weight = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim))
def forward(self, hidden_states):
hidden_states = hidden_states.reshape(-1, self.hidden_dim)
router_logits = F.linear(hidden_states, self.weight) # (seq_len, num_experts)
router_logits = torch.nn.functional.softmax(router_logits, dtype=torch.float, dim=-1)
router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) # (seq_len, top_k)
if self.norm_topk_prob:
router_top_value /= router_top_value.sum(dim=-1, keepdim=True)
router_top_value = router_top_value.to(router_logits.dtype)
router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value)
return router_scores, router_indices
class Qwen3OmniMoeTalkerTextSparseMoeBlock(nn.Module):
def __init__(self, config):
super().__init__()
self.gate = Qwen3OmniMoeTalkerTextTopKRouter(config)
# gating
self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)
self.experts = Qwen3OmniMoeTalkerTextExperts(config)
self.num_experts_per_tok = config.num_experts_per_tok
self.norm_topk_prob = config.norm_topk_prob
self.shared_expert = Qwen3OmniMoeTalkerTextMLP(
config, intermediate_size=config.shared_expert_intermediate_size
)
self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False)
def route_tokens_to_experts(self, hidden_states, router_logits):
routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(routing_weights, self.num_experts_per_tok, dim=-1)
if self.norm_topk_prob:
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
routing_weights = routing_weights.to(router_logits.dtype)
return selected_experts, routing_weights
def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
batch_size, sequence_length, hidden_dim = hidden_states.shape
hidden_states_reshaped = hidden_states.view(-1, hidden_dim)
shared_expert_output = self.shared_expert(hidden_states_reshaped)
routing_weights, selected_experts = self.gate(hidden_states_reshaped)
router_logits = self.gate(hidden_states_reshaped)
selected_experts, routing_weights = self.route_tokens_to_experts(hidden_states_reshaped, router_logits)
expert_output = self.experts(hidden_states_reshaped, selected_experts, routing_weights)
shared_expert_output = F.sigmoid(self.shared_expert_gate(hidden_states_reshaped)) * shared_expert_output

View File

@ -365,27 +365,6 @@ class Qwen3VLMoeTextDecoderLayer(GradientCheckpointingLayer):
return hidden_states
class Qwen3VLMoeTextTopKRouter(nn.Module):
def __init__(self, config):
super().__init__()
self.top_k = config.num_experts_per_tok
self.num_experts = config.num_experts
self.norm_topk_prob = config.norm_topk_prob
self.hidden_dim = config.hidden_size
self.weight = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim))
def forward(self, hidden_states):
hidden_states = hidden_states.reshape(-1, self.hidden_dim)
router_logits = F.linear(hidden_states, self.weight) # (seq_len, num_experts)
router_logits = torch.nn.functional.softmax(router_logits, dtype=torch.float, dim=-1)
router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) # (seq_len, top_k)
if self.norm_topk_prob:
router_top_value /= router_top_value.sum(dim=-1, keepdim=True)
router_top_value = router_top_value.to(router_logits.dtype)
router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value)
return router_scores, router_indices
@auto_docstring
class Qwen3VLMoePreTrainedModel(PreTrainedModel):
config: Qwen3VLMoeConfig
@ -399,7 +378,7 @@ class Qwen3VLMoePreTrainedModel(PreTrainedModel):
_can_compile_fullgraph = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported)
_supports_attention_backend = True
_can_record_outputs = {
"router_logits": OutputRecorder(Qwen3VLMoeTextTopKRouter, layer_name="mlp.router", index=0),
"router_logits": OutputRecorder(nn.Linear, layer_name="mlp.gate", index=0),
"hidden_states": Qwen3VLMoeTextDecoderLayer,
"attentions": Qwen3VLMoeTextAttention,
}
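
The deleted `Qwen3VLMoeTextTopKRouter` stored its weight as a bare `nn.Parameter` of shape `(num_experts, hidden_size)` and applied it with `F.linear`. Mathematically this is the same projection an `nn.Linear(hidden_size, num_experts, bias=False)` gate performs, since both compute `x @ weight.T` with that weight layout; a quick check of the equivalence (checkpoint key naming is a separate concern not covered here):

```python
import torch
from torch import nn
import torch.nn.functional as F

hidden_size, num_experts = 16, 8
weight = torch.randn(num_experts, hidden_size)  # same layout the old router stored
x = torch.randn(4, hidden_size)

gate = nn.Linear(hidden_size, num_experts, bias=False)
with torch.no_grad():
    gate.weight.copy_(weight)

# The bare F.linear projection and the nn.Linear module produce identical logits.
assert torch.allclose(F.linear(x, weight), gate(x))
```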

View File

@ -510,7 +510,7 @@ class SiglipPreTrainedModel(PreTrainedModel):
nn.init.xavier_uniform_(module.fc2.weight)
nn.init.normal_(module.fc1.bias, std=1e-6)
nn.init.normal_(module.fc2.bias, std=1e-6)
elif isinstance(module, SiglipMultiheadAttentionPoolingHead):
elif "MultiheadAttentionPoolingHead" in module.__class__.__name__:
nn.init.xavier_uniform_(module.probe.data)
nn.init.xavier_uniform_(module.attention.in_proj_weight.data)
nn.init.zeros_(module.attention.in_proj_bias.data)
@ -678,9 +678,14 @@ class SiglipTextModel(SiglipPreTrainedModel):
)
class SiglipVisionTransformer(nn.Module):
class SiglipVisionTransformer(SiglipPreTrainedModel):
_can_record_outputs = {
"hidden_states": SiglipEncoderLayer,
"attentions": SiglipAttention,
}
def __init__(self, config: SiglipVisionConfig):
super().__init__()
super().__init__(config)
self.config = config
embed_dim = config.hidden_size
@ -691,6 +696,7 @@ class SiglipVisionTransformer(nn.Module):
if self.use_head:
self.head = SiglipMultiheadAttentionPoolingHead(config)
@check_model_inputs(tie_last_hidden_states=False)
@auto_docstring
def forward(
self,

View File
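
The `_init_weights` branch above now keys on the class name (`"MultiheadAttentionPoolingHead" in module.__class__.__name__`) so that Siglip-style pooling heads defined in sibling modelling files hit the same branch without cross-imports. A minimal sketch of that duck-typed dispatch; the class below is a stand-in, not the actual Siglip head:

```python
import torch
from torch import nn

class ToyMultiheadAttentionPoolingHead(nn.Module):
    def __init__(self, hidden=8):
        super().__init__()
        self.probe = nn.Parameter(torch.zeros(1, 1, hidden))
        self.attention = nn.MultiheadAttention(hidden, num_heads=2, batch_first=True)

def init_module(module):
    # Name-based check covers every *MultiheadAttentionPoolingHead variant,
    # including subclasses defined in other modelling files.
    if "MultiheadAttentionPoolingHead" in module.__class__.__name__:
        nn.init.xavier_uniform_(module.probe.data)
        nn.init.xavier_uniform_(module.attention.in_proj_weight.data)
        nn.init.zeros_(module.attention.in_proj_bias.data)

head = ToyMultiheadAttentionPoolingHead()
init_module(head)
print(bool(head.probe.abs().sum() > 0))  # True
```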

@ -349,99 +349,6 @@ class Siglip2EncoderLayer(GradientCheckpointingLayer):
return hidden_states
class Siglip2Encoder(nn.Module):
"""
Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
[`Siglip2EncoderLayer`].
Args:
config: Siglip2Config
"""
def __init__(self, config: Siglip2Config):
super().__init__()
self.config = config
self.layers = nn.ModuleList([Siglip2EncoderLayer(config) for _ in range(config.num_hidden_layers)])
self.gradient_checkpointing = False
# Ignore copy
@auto_docstring
def forward(
self,
inputs_embeds,
attention_mask: Optional[torch.Tensor] = None,
**kwargs: Unpack[TransformersKwargs],
) -> BaseModelOutput:
hidden_states = inputs_embeds
for encoder_layer in self.layers:
hidden_states = encoder_layer(
hidden_states,
attention_mask,
**kwargs,
)
return BaseModelOutput(last_hidden_state=hidden_states)
class Siglip2VisionTransformer(nn.Module):
def __init__(self, config: Siglip2VisionConfig):
super().__init__()
self.config = config
embed_dim = config.hidden_size
self.embeddings = Siglip2VisionEmbeddings(config)
self.encoder = Siglip2Encoder(config)
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
self.use_head = True if not hasattr(config, "vision_use_head") else config.vision_use_head
if self.use_head:
self.head = Siglip2MultiheadAttentionPoolingHead(config)
@auto_docstring
def forward(
self,
pixel_values: torch.FloatTensor,
attention_mask: torch.Tensor,
spatial_shapes: torch.LongTensor,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
) -> BaseModelOutputWithPooling:
r"""
spatial_shapes (`torch.LongTensor` of shape `(batch_size, 2)`):
Tensor containing the spatial dimensions (height, width) of the input images.
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
hidden_states = self.embeddings(pixel_values, spatial_shapes)
if attention_mask is not None and self.config._attn_implementation != "flash_attention_2":
# [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len]
encoder_attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)
else:
encoder_attention_mask = attention_mask
encoder_outputs: BaseModelOutput = self.encoder(
inputs_embeds=hidden_states,
attention_mask=encoder_attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
last_hidden_state = encoder_outputs.last_hidden_state
last_hidden_state = self.post_layernorm(last_hidden_state)
pooler_output = self.head(last_hidden_state, attention_mask) if self.use_head else None
return BaseModelOutputWithPooling(
last_hidden_state=last_hidden_state,
pooler_output=pooler_output,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)
def _trunc_normal_(tensor, mean, std, a, b):
# Cut & paste from PyTorch official master until it's in a few official releases - RW
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
@ -585,7 +492,7 @@ class Siglip2PreTrainedModel(PreTrainedModel):
nn.init.xavier_uniform_(module.fc2.weight)
nn.init.normal_(module.fc1.bias, std=1e-6)
nn.init.normal_(module.fc2.bias, std=1e-6)
elif isinstance(module, Siglip2MultiheadAttentionPoolingHead):
elif "MultiheadAttentionPoolingHead" in module.__class__.__name__:
nn.init.xavier_uniform_(module.probe.data)
nn.init.xavier_uniform_(module.attention.in_proj_weight.data)
nn.init.zeros_(module.attention.in_proj_bias.data)
@ -607,6 +514,105 @@ class Siglip2PreTrainedModel(PreTrainedModel):
module.weight.data.fill_(1.0)
class Siglip2Encoder(nn.Module):
"""
Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
[`Siglip2EncoderLayer`].
Args:
config: Siglip2Config
"""
def __init__(self, config: Siglip2Config):
super().__init__()
self.config = config
self.layers = nn.ModuleList([Siglip2EncoderLayer(config) for _ in range(config.num_hidden_layers)])
self.gradient_checkpointing = False
# Ignore copy
@auto_docstring
def forward(
self,
inputs_embeds,
attention_mask: Optional[torch.Tensor] = None,
**kwargs: Unpack[TransformersKwargs],
) -> BaseModelOutput:
hidden_states = inputs_embeds
for encoder_layer in self.layers:
hidden_states = encoder_layer(
hidden_states,
attention_mask,
**kwargs,
)
return BaseModelOutput(last_hidden_state=hidden_states)
class Siglip2VisionTransformer(Siglip2PreTrainedModel):
_can_record_outputs = {
"hidden_states": Siglip2EncoderLayer,
"attentions": Siglip2Attention,
}
def __init__(self, config: Siglip2VisionConfig):
super().__init__(config)
self.config = config
embed_dim = config.hidden_size
self.embeddings = Siglip2VisionEmbeddings(config)
self.encoder = Siglip2Encoder(config)
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
self.use_head = True if not hasattr(config, "vision_use_head") else config.vision_use_head
if self.use_head:
self.head = Siglip2MultiheadAttentionPoolingHead(config)
@check_model_inputs(tie_last_hidden_states=False)
@auto_docstring
def forward(
self,
pixel_values: torch.FloatTensor,
attention_mask: torch.Tensor,
spatial_shapes: torch.LongTensor,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
) -> BaseModelOutputWithPooling:
r"""
spatial_shapes (`torch.LongTensor` of shape `(batch_size, 2)`):
Tensor containing the spatial dimensions (height, width) of the input images.
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
hidden_states = self.embeddings(pixel_values, spatial_shapes)
if attention_mask is not None and self.config._attn_implementation != "flash_attention_2":
# [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len]
encoder_attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)
else:
encoder_attention_mask = attention_mask
encoder_outputs: BaseModelOutput = self.encoder(
inputs_embeds=hidden_states,
attention_mask=encoder_attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
last_hidden_state = encoder_outputs.last_hidden_state
last_hidden_state = self.post_layernorm(last_hidden_state)
pooler_output = self.head(last_hidden_state, attention_mask) if self.use_head else None
return BaseModelOutputWithPooling(
last_hidden_state=last_hidden_state,
pooler_output=pooler_output,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)
class Siglip2TextEmbeddings(nn.Module):
def __init__(self, config: Siglip2TextConfig):
super().__init__()

View File
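
The vision transformer above expands a `[batch_size, seq_len]` padding mask into a 4D additive mask before handing it to the encoder (via `_prepare_4d_attention_mask`) unless flash attention is used. Roughly, that expansion looks like the standalone sketch below, which broadcasts over the target-length dimension and is not the library helper itself:

```python
import torch

def expand_padding_mask(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    """[batch, src_len] with 1 = keep, 0 = pad -> [batch, 1, 1, src_len] additive mask."""
    inverted = 1.0 - mask[:, None, None, :].to(dtype)
    # Padded positions get a large negative bias so softmax effectively ignores them.
    return inverted.masked_fill(inverted.bool(), torch.finfo(dtype).min)

mask = torch.tensor([[1, 1, 0]])
print(expand_padding_mask(mask, torch.float32))
# tensor([[[[ 0.0000e+00,  0.0000e+00, -3.4028e+38]]]])
```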

@ -37,6 +37,7 @@ from transformers.models.siglip.modeling_siglip import (
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
from ...utils import auto_docstring, filter_out_non_signature_kwargs
from ...utils.generic import check_model_inputs
class Siglip2TextConfig(SiglipTextConfig):
@ -230,6 +231,10 @@ class Siglip2VisionEmbeddings(nn.Module):
return embeddings
class Siglip2PreTrainedModel(SiglipPreTrainedModel):
pass
class Siglip2VisionTransformer(SiglipVisionTransformer):
def __init__(self, config: Siglip2VisionConfig):
super().__init__(config)
@ -280,10 +285,6 @@ class Siglip2VisionTransformer(SiglipVisionTransformer):
)
class Siglip2PreTrainedModel(SiglipPreTrainedModel):
pass
class Siglip2TextModel(SiglipTextModel):
pass
@ -314,6 +315,8 @@ class Siglip2MultiheadAttentionPoolingHead(SiglipMultiheadAttentionPoolingHead):
class Siglip2VisionModel(SiglipVisionModel):
# Update: add `spatial_shapes` and `pixel_attention_mask`
@check_model_inputs(tie_last_hidden_states=False)
@auto_docstring
def forward(
self,
pixel_values: torch.FloatTensor,

View File

@ -59,15 +59,11 @@ class HfQuantizer(ABC):
requires_parameters_quantization (`bool`):
Whether the quantization method requires to create a new Parameter. For example, for bitsandbytes, it is
required to create a new xxxParameter in order to properly quantize the model.
requires_full_weights (`bool`):
Whether the quantization method needs the full (non-sharded) weights for conversion. If set to `False`, only
the relevant tensor slices will be provided during weight loading.
"""
requires_calibration = False
required_packages = None
requires_parameters_quantization = False
requires_full_weights = True
def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs):
self.quantization_config = quantization_config

View File

@ -75,8 +75,6 @@ class FineGrainedFP8HfQuantizer(HfQuantizer):
dtype = torch.float32
return dtype
# TODO: make this into a `ConversionType` ops -> potentially requires all weights on all ranks
# depending on the layer type (moe -> no if ep)
def create_quantized_param(
self,
model: "PreTrainedModel",
@ -95,9 +93,8 @@ class FineGrainedFP8HfQuantizer(HfQuantizer):
if tensor_name == "weight" and param_value.dtype != torch.float8_e4m3fn:
raise ValueError("Expect quantized weights but got an unquantized weight")
else:
return
# if tensor_name == "weight_scale_inv":
# raise ValueError("Expect unquantized weights but got a quantized weight_scale")
if tensor_name == "weight_scale_inv":
raise ValueError("Expect unquantized weights but got a quantized weight_scale")
param_value = param_value.to(target_device)
@ -140,10 +137,10 @@ class FineGrainedFP8HfQuantizer(HfQuantizer):
_load_parameter_into_model(model, param_name.rsplit(".", 1)[0] + ".weight_scale_inv", scale)
def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool:
from ..integrations.finegrained_fp8 import FP8Expert, FP8Linear
from ..integrations.finegrained_fp8 import FP8Linear
module, tensor_name = get_module_from_name(model, param_name)
if isinstance(module, (FP8Linear, FP8Expert)):
if isinstance(module, FP8Linear):
if self.pre_quantized or tensor_name == "bias":
return False
else:
@ -185,9 +182,6 @@ class FineGrainedFP8HfQuantizer(HfQuantizer):
not_missing_keys.append(missing)
return [k for k in missing_keys if k not in not_missing_keys]
# TODO: similarly, just as we have a weight weight remapping we
# need to have a cleaner way to remap the quantized keys.
# 1. A SINGLE normal_key -> quantized keys used for ckpt renaming and for TP_plan as well
def update_tp_plan(self, config):
if "Qwen3" in config.__class__.__name__:
text_plan = {

View File

@ -877,7 +877,13 @@ def check_model_inputs(tie_last_hidden_states=True):
def wrapped_forward(*args, **kwargs):
if key == "hidden_states" and len(collected_outputs[key]) == 0:
collected_outputs[key] += (args[0],)
output = orig_forward(*args, **kwargs)
if kwargs.get("debug_io", False):
with model_addition_debugger_context(
module, kwargs.get("debug_io_dir", "~/model_debug"), kwargs.get("prune_layers")
):
output = orig_forward(*args, **kwargs)
else:
output = orig_forward(*args, **kwargs)
if not isinstance(output, tuple):
collected_outputs[key] += (output,)
elif output[index] is not None:
@ -918,13 +924,7 @@ def check_model_inputs(tie_last_hidden_states=True):
monkey_patched_layers.append((module, original_forward))
try:
if kwargs.get("debug_io", False):
with model_addition_debugger_context(
self, kwargs.get("debug_io_dir", "model_debug"), kwargs.get("prune_layers")
):
outputs = func(self, *args, **kwargs)
else:
outputs = func(self, *args, **kwargs)
outputs = func(self, *args, **kwargs)
except TypeError as original_exception:
# If we get a TypeError, it's possible that the model is not receiving the recordable kwargs correctly.
# Get a TypeError even after removing the recordable kwargs -> re-raise the original exception

View File
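
`check_model_inputs` works by temporarily swapping each recordable submodule's `forward` for a wrapper that stashes the output, then restoring the originals afterwards; the hunk above only moves the `debug_io` context into that wrapper. The general monkey-patch-and-restore pattern, sketched on a toy model independent of the decorator:

```python
import torch
from torch import nn

model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4))
collected, patched = [], []

for module in model:
    orig_forward = module.forward

    def wrapped_forward(*args, _orig=orig_forward, **kwargs):
        output = _orig(*args, **kwargs)
        collected.append(output.detach())  # record the intermediate hidden state
        return output

    module.forward = wrapped_forward
    patched.append((module, orig_forward))

try:
    model(torch.randn(2, 4))
finally:
    # Always restore the original forwards, as the decorator does.
    for module, orig_forward in patched:
        module.forward = orig_forward

print(len(collected))  # 2
```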

@ -1176,12 +1176,9 @@ def is_mistral_common_available() -> bool:
@lru_cache
def is_opentelemetry_available() -> bool:
try:
return _is_package_available("opentelemetry") and version.parse(
importlib.metadata.version("opentelemetry-api")
) >= version.parse("1.30.0")
except Exception as _:
return False
return _is_package_available("opentelemetry") and version.parse(
importlib.metadata.version("opentelemetry-api")
) >= version.parse("1.30.0")
def check_torch_load_is_safe() -> None:

View File
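
The simplified `is_opentelemetry_available` keeps the package check plus a metadata version comparison and no longer swallows exceptions. The underlying pattern, sketched independently of transformers' `_is_package_available` helper:

```python
import importlib.metadata
import importlib.util

from packaging import version

def package_at_least(name: str, dist_name: str, minimum: str) -> bool:
    # Importable at all?
    if importlib.util.find_spec(name) is None:
        return False
    # Note: a missing distribution now raises PackageNotFoundError instead of returning False.
    return version.parse(importlib.metadata.version(dist_name)) >= version.parse(minimum)

print(package_at_least("opentelemetry", "opentelemetry-api", "1.30.0"))
```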

@ -1,236 +0,0 @@
import logging
import re
import shutil
import sys
from collections import OrderedDict, defaultdict
from collections.abc import Iterable
from typing import Any, Optional
_DIGIT_RX = re.compile(r"(?<=\.)(\d+)(?=\.|$)") # numbers between dots or at the end
def _pattern_of(key: str) -> str:
"""Replace every dot-delimited integer with '*' to get the structure."""
return _DIGIT_RX.sub("*", key)
def _fmt_indices(values: list[int]) -> str:
"""Format a list of ints as single number, {a, b, ...}, or first...last."""
if len(values) == 1:
return str(values[0])
values = sorted(values)
if len(values) > 10:
return f"{values[0]}...{values[-1]}"
return ", ".join(map(str, values))
def update_key_name(mapping: dict[str, Any]) -> dict[str, Any]:
"""
Merge keys like 'layers.0.x', 'layers.1.x' into 'layers.{0, 1}.x'
BUT only merge together keys that have the exact same value.
Returns a new dict {merged_key: value}.
"""
# (pattern, value) -> list[set[int]] (per-star index values)
not_mapping = False
if not isinstance(mapping, dict):
mapping = {k: k for k in mapping}
not_mapping = True
bucket: dict[tuple[str, Any], list[set[int]]] = defaultdict(list)
for key, val in mapping.items():
digs = _DIGIT_RX.findall(key)
patt = _pattern_of(key)
for i, d in enumerate(digs):
if len(bucket[patt]) <= i:
bucket[patt].append(set())
bucket[patt][i].add(int(d))
bucket[patt].append(val)
out_items = {}
for patt, values in bucket.items():
sets, val = values[:-1], values[-1]
parts = patt.split("*") # stars are between parts
final = parts[0]
for i in range(1, len(parts)):
# i-1 is the star index before parts[i]
if i - 1 < len(sets) and sets[i - 1]:
insert = _fmt_indices(sorted(sets[i - 1]))
if len(sets[i - 1]) > 1:
final += "{" + insert + "}"
else:
final += insert
else:
# If no digits observed for this star position, keep a literal '*'
final += "*"
final += parts[i]
out_items[final] = val
# Stable ordering by merged key
out = OrderedDict(out_items)
if not_mapping:
return out.keys()
return out
class ANSI:
palette = {
"reset": "",
"red": "",
"yellow": "",
"orange": "",
"purple": "",
"bold": "",
"italic": "",
"dim": "",
}
def __init__(self, enable):
self.enable = enable
def __getitem__(self, key):
return self.palette[key] if self.enable else ""
_ansi_re = re.compile(r"\x1b\[[0-9;]*m")
def _strip_ansi(s: str) -> str:
return _ansi_re.sub("", str(s))
def _pad(text, width):
t = str(text)
pad = max(0, width - len(_strip_ansi(t)))
return t + " " * pad
def _make_table(rows, headers):
# compute display widths while ignoring ANSI codes
cols = list(zip(*([headers] + rows))) if rows else [headers]
widths = [max(len(_strip_ansi(x)) for x in col) for col in cols]
header_line = " | ".join(_pad(h, w) for h, w in zip(headers, widths))
sep_line = "-+-".join("-" * w for w in widths)
body = [" | ".join(_pad(c, w) for c, w in zip(r, widths)) for r in rows]
return "\n".join([header_line, sep_line] + body)
def _color(s, color, ansi):
# ansi returns empty strings when disabled, so safe to interpolate
return f"{ansi[color]}{s}{ansi['reset']}"
def _get_terminal_width(default=80):
try:
return shutil.get_terminal_size().columns
except Exception:
return default
def log_state_dict_report(
*,
model,
pretrained_model_name_or_path,
logger: Optional[logging.Logger] = None,
error_msgs: Optional[Iterable[str]] = None,
unexpected_keys=None,
missing_keys=None,
mismatched_keys=None,
mismatched_shapes=None,
ignore_mismatched_sizes=True,
misc=None,
limit_rows=50, # safety for huge checkpoints
color=True, # allow disabling for plain logs
min_width_full_table=60, # terminal min width to attempt full table
):
"""Log a readable report about state_dict loading issues.
This version is terminal-size aware: for very small terminals it falls back to a compact
Key | Status view so output doesn't wrap badly.
"""
if logger is None:
logger = logging.getLogger(__name__)
error_msgs = error_msgs or []
unexpected_keys = unexpected_keys or []
missing_keys = missing_keys or []
mismatched_keys = mismatched_keys or []
mismatched_shapes = mismatched_shapes or []
misc = misc or {}
# Detect whether the current stdout supports ANSI colors; allow callers to pass `color=False` to force no color
color_enabled = bool(color and sys.stdout.isatty())
ansi = ANSI(color_enabled)
if error_msgs:
error_msg = "\n\t".join(error_msgs)
if "size mismatch" in error_msg:
error_msg += (
"\n\tYou may consider adding `ignore_mismatched_sizes=True` to `from_pretrained(...)` if appropriate."
)
raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
term_w = _get_terminal_width()
rows = []
if unexpected_keys:
for k in update_key_name(unexpected_keys):
status = "UNEXPECTED"
status = _color(status, "orange", ansi)
rows.append([k, status, "", ""])
if missing_keys:
for k in update_key_name(missing_keys):
status = "MISSING"
status = _color(status, "red", ansi)
rows.append([k, status, ""])
if mismatched_keys:
iterator = {a: (b, c) for a, b, c in mismatched_shapes}
for key, (shape_ckpt, shape_model) in update_key_name(iterator).items():
status = "MISMATCH"
status = _color(status, "yellow", ansi)
data = [key, status]
if term_w > limit_rows:
data.append(
" ".join(["Reinit due to size mismatch", f"ckpt: {str(shape_ckpt)} vs model:{str(shape_model)}"])
)
rows.append(data)
if misc:
for k, v in update_key_name(misc).items():
status = "MISC"
status = _color(status, "purple", ansi)
_details = v[:term_w]
rows.append([k, status, _details])
if not rows:
print(f"No key issues when initializing {model.__class__.__name__} from {pretrained_model_name_or_path}.")
return
headers = ["Key", "Status"]
if term_w > 200:
headers += ["Details"]
else:
headers += ["", ""]
table = _make_table(rows, headers=headers)
prelude = (
f"{ansi['bold']}{model.__class__.__name__} LOAD REPORT{ansi['reset']} from: {pretrained_model_name_or_path}\n"
)
tips = f"\n\n{ansi['italic']}Notes:"
if unexpected_keys:
tips += f"\n- {_color('UNEXPECTED', 'orange', ansi) + ansi['italic']}\t:can be ignored when loading from different task/architecture; not ok if you expect identical arch."
if missing_keys:
tips += f"\n- {_color('MISSING', 'red', ansi) + ansi['italic']}\t:those params were newly initialized because missing form the checkpoint. Consider training on your downstream task."
if mismatched_keys:
tips += f"\n- {_color('MISMATCH', 'yellow', ansi) + ansi['italic']}\t:ckpt weights were loaded, but they did not match the original empty weight."
if misc:
tips += f"\n- {_color('MISC', 'purple', ansi) + ansi['italic']}\t:originate from the conversion scheme"
tips += f"{ansi['reset']}"
logger.warning(prelude + table + tips)
if not ignore_mismatched_sizes and mismatched_keys:
raise RuntimeError(
"You set `ignore_mismatched_sizes` to `False`, thus raising an error. For details look at the above report!"
)

View File
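
For context on the deleted report helpers: `_pattern_of` collapses dot-delimited layer indices to `*`, and `_fmt_indices` renders the collected indices back compactly, which is how keys like `layers.0.x` through `layers.31.x` merge into a single `layers.0...31.x` row. A tiny standalone rerun of those two helpers as they appeared above:

```python
import re

_DIGIT_RX = re.compile(r"(?<=\.)(\d+)(?=\.|$)")

def _pattern_of(key: str) -> str:
    # Replace every dot-delimited integer with '*' to expose the key's structure.
    return _DIGIT_RX.sub("*", key)

def _fmt_indices(values):
    if len(values) == 1:
        return str(values[0])
    values = sorted(values)
    if len(values) > 10:
        return f"{values[0]}...{values[-1]}"
    return ", ".join(map(str, values))

keys = [f"model.layers.{i}.mlp.gate.weight" for i in range(32)]
print(_pattern_of(keys[0]))           # model.layers.*.mlp.gate.weight
print(_fmt_indices(list(range(32))))  # 0...31
```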

@ -232,7 +232,7 @@ class AutoformerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True)
self.assertEqual(info["missing_keys"], set())
self.assertEqual(info["missing_keys"], [])
def test_encoder_decoder_model_standalone(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_common()

View File

@ -539,7 +539,7 @@ class BarkSemanticModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Te
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True)
self.assertEqual(info["missing_keys"], set())
self.assertEqual(info["missing_keys"], [])
def test_decoder_model_past_with_large_inputs(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
@ -625,7 +625,7 @@ class BarkCoarseModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Test
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True)
self.assertEqual(info["missing_keys"], set())
self.assertEqual(info["missing_keys"], [])
def test_decoder_model_past_with_large_inputs(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
@ -708,7 +708,7 @@ class BarkFineModelTest(ModelTesterMixin, unittest.TestCase):
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True)
self.assertEqual(info["missing_keys"], set())
self.assertEqual(info["missing_keys"], [])
def test_inputs_embeds(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

View File

@ -438,7 +438,7 @@ class BartModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True)
self.assertEqual(info["missing_keys"], set())
self.assertEqual(info["missing_keys"], [])
def test_decoder_model_past_with_large_inputs(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()

View File

@ -297,7 +297,7 @@ class BigBirdPegasusModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineT
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True)
self.assertEqual(info["missing_keys"], set())
self.assertEqual(info["missing_keys"], [])
def test_decoder_model_past_with_large_inputs(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()

View File

@ -241,7 +241,7 @@ class BlenderbotModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True)
self.assertEqual(info["missing_keys"], set())
self.assertEqual(info["missing_keys"], [])
def test_decoder_model_past_with_large_inputs(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()

View File

@ -246,7 +246,7 @@ class BlenderbotSmallModelTest(ModelTesterMixin, GenerationTesterMixin, Pipeline
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True)
self.assertEqual(info["missing_keys"], set())
self.assertEqual(info["missing_keys"], [])
def test_decoder_model_past_with_large_inputs(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()

View File

@ -200,7 +200,7 @@ class FastSpeech2ConformerModelTest(ModelTesterMixin, unittest.TestCase):
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
_, info = FastSpeech2ConformerModel.from_pretrained(tmpdirname, output_loading_info=True)
self.assertEqual(info["missing_keys"], set())
self.assertEqual(info["missing_keys"], [])
def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
@ -618,7 +618,7 @@ class FastSpeech2ConformerWithHifiGanTest(ModelTesterMixin, unittest.TestCase):
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
_, info = FastSpeech2ConformerWithHifiGan.from_pretrained(tmpdirname, output_loading_info=True)
self.assertEqual(info["missing_keys"], set())
self.assertEqual(info["missing_keys"], [])
def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()

View File

@ -248,7 +248,7 @@ class FSMTModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True)
self.assertEqual(info["missing_keys"], set())
self.assertEqual(info["missing_keys"], [])
def test_ensure_weights_are_shared(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs()

View File

@ -218,7 +218,7 @@ class InformerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True)
self.assertEqual(info["missing_keys"], set())
self.assertEqual(info["missing_keys"], [])
def test_encoder_decoder_model_standalone(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_common()

View File

@ -316,7 +316,7 @@ class LEDModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True)
self.assertEqual(info["missing_keys"], set())
self.assertEqual(info["missing_keys"], [])
def test_decoder_model_past_with_large_inputs(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()

View File

@ -272,7 +272,7 @@ class M2M100ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True)
self.assertEqual(info["missing_keys"], set())
self.assertEqual(info["missing_keys"], [])
def test_decoder_model_past_with_large_inputs(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()

View File

@ -246,7 +246,7 @@ class MarianModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True)
self.assertEqual(info["missing_keys"], set())
self.assertEqual(info["missing_keys"], [])
def test_decoder_model_past_with_large_inputs(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()

View File

@ -265,7 +265,7 @@ class MBartModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True)
self.assertEqual(info["missing_keys"], set())
self.assertEqual(info["missing_keys"], [])
def test_decoder_model_past_with_large_inputs(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()

View File

@ -462,7 +462,7 @@ class MvpModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True)
self.assertEqual(info["missing_keys"], set())
self.assertEqual(info["missing_keys"], [])
def test_decoder_model_past_with_large_inputs(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()

View File

@ -274,7 +274,7 @@ class NllbMoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True)
self.assertEqual(info["missing_keys"], set())
self.assertEqual(info["missing_keys"], [])
def test_decoder_model_past_with_large_inputs(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs()

View File

@ -256,7 +256,7 @@ class OPTModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True)
self.assertEqual(info["missing_keys"], set())
self.assertEqual(info["missing_keys"], [])
def test_decoder_model_past_with_large_inputs(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()

View File

@ -276,7 +276,7 @@ class PatchTSMixerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.Test
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True)
self.assertEqual(info["missing_keys"], set())
self.assertEqual(info["missing_keys"], [])
def test_hidden_states_output(self):
def check_hidden_states_output(inputs_dict, config, model_class):

View File

@ -208,7 +208,7 @@ class PatchTSTModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True)
self.assertEqual(info["missing_keys"], set())
self.assertEqual(info["missing_keys"], [])
def test_hidden_states_output(self):
def check_hidden_states_output(inputs_dict, config, model_class):

View File

@ -253,7 +253,7 @@ class PegasusModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True)
self.assertEqual(info["missing_keys"], set())
self.assertEqual(info["missing_keys"], [])
def test_decoder_model_past_with_large_inputs(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()

View File

@ -231,7 +231,7 @@ class PegasusXModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True)
self.assertEqual(info["missing_keys"], set())
self.assertEqual(info["missing_keys"], [])
def test_decoder_model_past_with_large_inputs(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()

View File

@ -261,7 +261,7 @@ class PLBartModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True)
self.assertEqual(info["missing_keys"], set())
self.assertEqual(info["missing_keys"], [])
def test_decoder_model_past_with_large_inputs(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()

View File

@ -282,7 +282,7 @@ class Speech2TextModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTest
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True)
self.assertEqual(info["missing_keys"], set())
self.assertEqual(info["missing_keys"], [])
def test_model_forward(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()

View File

@ -354,7 +354,7 @@ class SpeechT5ForSpeechToTextTest(ModelTesterMixin, unittest.TestCase, Generatio
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True)
self.assertEqual(info["missing_keys"], set())
self.assertEqual(info["missing_keys"], [])
def test_model_forward(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
@ -859,7 +859,7 @@ class SpeechT5ForTextToSpeechTest(ModelTesterMixin, unittest.TestCase):
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True)
self.assertEqual(info["missing_keys"], set())
self.assertEqual(info["missing_keys"], [])
def test_model_forward(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
@ -1359,7 +1359,7 @@ class SpeechT5ForSpeechToSpeechTest(ModelTesterMixin, unittest.TestCase):
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True)
self.assertEqual(info["missing_keys"], set())
self.assertEqual(info["missing_keys"], [])
def test_model_forward(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()

View File

@ -205,7 +205,7 @@ class TimeSeriesTransformerModelTest(ModelTesterMixin, PipelineTesterMixin, unit
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True)
self.assertEqual(info["missing_keys"], set())
self.assertEqual(info["missing_keys"], [])
def test_encoder_decoder_model_standalone(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_common()

View File

@ -422,7 +422,7 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True)
self.assertEqual(info["missing_keys"], set())
self.assertEqual(info["missing_keys"], [])
def test_model_forward(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()

View File

@ -1924,7 +1924,7 @@ class ModelTesterMixin:
v, reloaded_state[k], msg=lambda x: f"{model_class.__name__}: Tensor {k}: {x}"
)
# Checking there was no complain of missing weights
self.assertEqual(infos["missing_keys"], set())
self.assertEqual(infos["missing_keys"], [])
# Checking the tensor sharing are correct
ptrs = defaultdict(list)
@ -1958,7 +1958,7 @@ class ModelTesterMixin:
v, reloaded_state[k], msg=lambda x: f"{model_class.__name__}: Tensor {k}: {x}"
)
# Checking there was no complain of missing weights
self.assertEqual(infos["missing_keys"], set())
self.assertEqual(infos["missing_keys"], [])
def test_tied_weights_keys(self):
original_config, _ = self.model_tester.prepare_config_and_inputs_for_common()
@ -2485,17 +2485,17 @@ class ModelTesterMixin:
new_model = AutoModelForSequenceClassification.from_pretrained(
tmp_dir, num_labels=42, ignore_mismatched_sizes=True
)
self.assertIn("Reinit due to size mismatch", cl.out)
self.assertIn("the shapes did not match", cl.out)
new_model.to(torch_device)
inputs = self._prepare_for_class(inputs_dict, model_class)
logits = new_model(**inputs).logits
self.assertEqual(logits.shape[1], 2) # we still want to load :)
self.assertEqual(logits.shape[1], 42)
with CaptureLogger(logger) as cl:
new_model_without_prefix = AutoModel.from_pretrained(
tmp_dir, vocab_size=10, ignore_mismatched_sizes=True
)
self.assertIn("Reinit due to size mismatch", cl.out)
self.assertIn("the shapes did not match", cl.out)
input_ids = ids_tensor((2, 8), 10)
new_model_without_prefix.to(torch_device)
if self.is_encoder_decoder:
@ -2536,7 +2536,7 @@ class ModelTesterMixin:
with CaptureLogger(logger) as cl:
new_model = model_class.from_pretrained(tmp_dir, num_labels=42, ignore_mismatched_sizes=True)
self.assertIn("Reinit due to size mismatch", cl.out)
self.assertIn("the shapes did not match", cl.out)
# Find the name of the module with the mismatched size
top_linear_modules = [
@ -3895,125 +3895,7 @@ class ModelTesterMixin:
):
self.assertEqual(k1, k2)
self.assertEqual(v1.dtype, v2.dtype)
self.assertTrue((v1 == v2).all())
@require_torch
def test_weight_conversion_operations_roundtrip():
import torch
from transformers.core_model_loading import (
Chunk,
Concatenate,
Fp8Dequantize,
Fp8Quantize,
MergeModuleList,
Shard,
WeightConversion,
convert_state_dict,
)
state_dict = {
"experts.0.w1.weight": torch.tensor([[0.0, 1.0], [2.0, 3.0]]),
"experts.1.w1.weight": torch.tensor([[4.0, 5.0], [6.0, 7.0]]),
"experts.0.w3.weight": torch.tensor([[10.0, 11.0], [12.0, 13.0]]),
"experts.1.w3.weight": torch.tensor([[14.0, 15.0], [16.0, 17.0]]),
"self_attn.q_proj.weight": torch.tensor([[1.0, 2.0], [3.0, 4.0]]),
"self_attn.k_proj.weight": torch.tensor([[5.0, 6.0], [7.0, 8.0]]),
"self_attn.v_proj.weight": torch.tensor([[9.0, 10.0], [11.0, 12.0]]),
"self_attn.out_proj.weight": torch.arange(12.0).reshape(6, 2),
"mlp.w2.weight": torch.tensor([[1.0, 0.0], [0.0, 1.0]]),
}
forward_mapping = [
WeightConversion(
["experts.*.w1.weight", "experts.*.w3.weight"],
"experts.gate_up_proj.weight",
[MergeModuleList(dim=0), Concatenate(dim=0), Fp8Quantize(block_size=(1, 1))],
),
WeightConversion(
["self_attn.q_proj.weight", "self_attn.k_proj.weight", "self_attn.v_proj.weight"],
"self_attn.qkv_proj.weight",
Concatenate(dim=0),
),
WeightConversion(
"self_attn.out_proj.weight",
["self_attn.out_proj.weight.shard0", "self_attn.out_proj.weight.shard1"],
Shard(dim=0, world_size=2, return_all=True),
),
WeightConversion("mlp.w2.weight", "mlp.down_proj.weight"),
]
converted_state, _ = convert_state_dict(None, state_dict, forward_mapping, tp_plan=None, quantization_config=None)
expected_qkv = torch.cat(
(
state_dict["self_attn.q_proj.weight"],
state_dict["self_attn.k_proj.weight"],
state_dict["self_attn.v_proj.weight"],
),
dim=0,
)
torch.testing.assert_close(converted_state["self_attn.qkv_proj.weight"], expected_qkv)
reconstructed_out_proj = torch.cat(
(converted_state["self_attn.out_proj.weight.shard0"], converted_state["self_attn.out_proj.weight.shard1"]),
dim=0,
)
torch.testing.assert_close(reconstructed_out_proj, state_dict["self_attn.out_proj.weight"])
torch.testing.assert_close(converted_state["mlp.down_proj.weight"], state_dict["mlp.w2.weight"])
inverse_mapping = [
WeightConversion(
["experts.gate_up_proj.weight", "experts.gate_up_proj.scale"],
"experts.gate_up_proj.dequantized",
Fp8Dequantize(block_size=(1, 1)),
),
WeightConversion(
"experts.gate_up_proj.dequantized",
["experts.w1.concat", "experts.w3.concat"],
Chunk(dim=0, sizes=[4, 4]),
),
WeightConversion(
"experts.w1.concat",
["experts.0.w1.weight", "experts.1.w1.weight"],
Chunk(dim=0, sizes=[2, 2]),
),
WeightConversion(
"experts.w3.concat",
["experts.0.w3.weight", "experts.1.w3.weight"],
Chunk(dim=0, sizes=[2, 2]),
),
WeightConversion(
"self_attn.qkv_proj.weight",
[
"self_attn.q_proj.weight",
"self_attn.k_proj.weight",
"self_attn.v_proj.weight",
],
Chunk(dim=0, sizes=[2, 2, 2]),
),
WeightConversion(
["self_attn.out_proj.weight.shard0", "self_attn.out_proj.weight.shard1"],
"self_attn.out_proj.weight",
Concatenate(dim=0),
),
WeightConversion("mlp.down_proj.weight", "mlp.w2.weight"),
]
roundtrip_state, _ = convert_state_dict(
None, converted_state, inverse_mapping, tp_plan=None, quantization_config=None
)
torch.testing.assert_close(roundtrip_state["experts.0.w1.weight"], state_dict["experts.0.w1.weight"])
torch.testing.assert_close(roundtrip_state["experts.1.w1.weight"], state_dict["experts.1.w1.weight"])
torch.testing.assert_close(roundtrip_state["experts.0.w3.weight"], state_dict["experts.0.w3.weight"])
torch.testing.assert_close(roundtrip_state["experts.1.w3.weight"], state_dict["experts.1.w3.weight"])
torch.testing.assert_close(roundtrip_state["self_attn.q_proj.weight"], state_dict["self_attn.q_proj.weight"])
torch.testing.assert_close(roundtrip_state["self_attn.k_proj.weight"], state_dict["self_attn.k_proj.weight"])
torch.testing.assert_close(roundtrip_state["self_attn.v_proj.weight"], state_dict["self_attn.v_proj.weight"])
torch.testing.assert_close(roundtrip_state["self_attn.out_proj.weight"], state_dict["self_attn.out_proj.weight"])
torch.testing.assert_close(roundtrip_state["mlp.w2.weight"], state_dict["mlp.w2.weight"])
self.assertTrue((v1 == v2).all())
global_rng = random.Random()
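
The removed round-trip test above boils down to a concat/chunk identity on plain tensors; a minimal sketch under made-up shapes (not the loader's implementation):

import torch

# Forward: Concatenate(dim=0) fuses q/k/v into a single qkv projection.
q, k, v = (torch.randn(2, 4) for _ in range(3))
qkv = torch.cat([q, k, v], dim=0)

# Inverse: Chunk(dim=0, ...) splits it back; the round trip is lossless.
q2, k2, v2 = torch.chunk(qkv, 3, dim=0)
assert torch.equal(q, q2) and torch.equal(k, k2) and torch.equal(v, v2)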

View File

@ -1,112 +0,0 @@
# Copyright 2019 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from transformers.core_model_loading import build_glob_alt, match_glob
class TestWeightGlobMatching(unittest.TestCase):
def setUp(self):
self.weight_globs_digits = [
"model.layers.*.mlp.gate_up_proj.weight",
"model.layers.*.self_attn.q_proj.weight",
"embed_tokens.weight",
]
self.alt_digits, self.map_digits = build_glob_alt(self.weight_globs_digits, digits_only=True)
self.weight_globs_any = [
"model.layers.*.mlp.gate_up_proj.weight",
"model.layers.*.self_attn.q_proj.weight",
"embed_tokens.weight",
]
self.alt_any, self.map_any = build_glob_alt(self.weight_globs_any, digits_only=False)
def test_exact_match(self):
self.assertEqual(match_glob("embed_tokens.weight", self.alt_digits, self.map_digits), "embed_tokens.weight")
def test_digits_only_star_accepts_digits(self):
self.assertEqual(
match_glob("model.layers.0.mlp.gate_up_proj.weight", self.alt_digits, self.map_digits),
"model.layers.*.mlp.gate_up_proj.weight",
)
self.assertEqual(
match_glob("model.layers.12.self_attn.q_proj.weight", self.alt_digits, self.map_digits),
"model.layers.*.self_attn.q_proj.weight",
)
def test_digits_only_star_rejects_nondigits(self):
# 'a' is not digits, so it should not match with digits_only=True
self.assertIsNone(match_glob("model.layers.a.mlp.gate_up_proj.weight", self.alt_digits, self.map_digits))
def test_anychar_star_accepts_nondigits(self):
self.assertEqual(
match_glob("model.layers.a.mlp.gate_up_proj.weight", self.alt_any, self.map_any),
"model.layers.*.mlp.gate_up_proj.weight",
)
self.assertEqual(
match_glob("model.layers.00x.mlp.gate_up_proj.weight", self.alt_any, self.map_any),
"model.layers.*.mlp.gate_up_proj.weight",
)
def test_no_match(self):
self.assertIsNone(match_glob("model.layers.0.mlp.up_proj.weight", self.alt_digits, self.map_digits))
def test_leftmost_alternative_wins_for_overlapping_patterns(self):
# Overlapping patterns: both could match; ensure leftmost wins
globs = [
"model.layers.*.mlp.*.weight", # broader (first)
"model.layers.0.mlp.gate_up_proj.weight", # more specific (second)
]
alt, mapping = build_glob_alt(globs, digits_only=False)
# Both branches match; Python's regex picks the leftmost alternative → index 0
self.assertEqual(
match_glob("model.layers.0.mlp.gate_up_proj.weight", alt, mapping), "model.layers.*.mlp.*.weight"
)
def test_multiple_patterns_same_prefix(self):
globs = [
"model.layers.*.self_attn.q_proj.weight",
"model.layers.*.self_attn.k_proj.weight",
"model.layers.*.self_attn.v_proj.weight",
]
alt, mapping = build_glob_alt(globs, digits_only=True)
self.assertEqual(
match_glob("model.layers.3.self_attn.q_proj.weight", alt, mapping),
"model.layers.*.self_attn.q_proj.weight",
)
self.assertEqual(
match_glob("model.layers.3.self_attn.k_proj.weight", alt, mapping),
"model.layers.*.self_attn.k_proj.weight",
)
self.assertEqual(
match_glob("model.layers.3.self_attn.v_proj.weight", alt, mapping),
"model.layers.*.self_attn.v_proj.weight",
)
def test_anchor_full_match_only(self):
# Make sure partial strings don't match—anchors ^...$ are in each branch
self.assertIsNone(match_glob("foo.model.layers.0.mlp.gate_up_proj.weight.bar", self.alt_any, self.map_any))
def test_large_batch_performance_smoke(self):
# Not a perf benchmark, but ensures building and matching a larger alternation is OK
globs = [f"model.layers.*.mlp.block{i}.weight" for i in range(200)]
alt, mapping = build_glob_alt(globs, digits_only=True)
key = "model.layers.123.mlp.block57.weight"
self.assertEqual(match_glob(key, alt, mapping), "model.layers.*.mlp.block57.weight")
if __name__ == "__main__":
unittest.main()
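
The leftmost-wins behavior exercised above comes straight from Python's `re` alternation; a standalone illustration with hypothetical patterns:

import re

# Branches of an alternation are tried left to right, so when several anchored
# branches could match the same key, the first one listed wins.
alt = re.compile(
    r"^(model\.layers\.\d+\.mlp\..+\.weight)$"            # broader pattern, listed first
    r"|^(model\.layers\.0\.mlp\.gate_up_proj\.weight)$"   # more specific, listed second
)
m = alt.match("model.layers.0.mlp.gate_up_proj.weight")
print(m.lastindex)  # 1 -> the broader, leftmost branch captured the key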

View File

@ -1,216 +0,0 @@
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
import unittest
import torch
import torch.nn as nn
from transformers.core_model_loading import (
Chunk,
Concatenate,
MergeModulelist,
WeightConverter,
_apply_star_subst,
_glob_to_regex_src,
build_glob_alt,
convert_and_load_state_dict_in_model,
glob_to_re,
match_glob,
)
class TestGlobRegexHelpers(unittest.TestCase):
def test_glob_to_regex_src_digits_only(self):
pattern = _glob_to_regex_src("model.layers.*.mlp.weight", digits_only=True)
self.assertEqual(pattern, r"model\.layers\.(\d+)\.mlp\.weight")
def test_glob_to_regex_src_any_chars(self):
pattern = _glob_to_regex_src("model.layers.*.mlp.weight", digits_only=False)
self.assertEqual(pattern, r"model\.layers\.(.+)\.mlp\.weight")
def test_glob_to_re_fullmatch(self):
regex_src = glob_to_re("model.layers.*.mlp.weight", digits_only=True)
regex = re.compile(f"^{regex_src}$")
self.assertIsNotNone(regex.fullmatch("model.layers.12.mlp.weight"))
self.assertIsNone(regex.fullmatch("model.layers.foo.mlp.weight"))
def test_apply_star_subst(self):
pattern = "model.layers.*.block.*.weight"
replaced = _apply_star_subst(pattern, ["03", "attn"])
self.assertEqual(replaced, "model.layers.03.block.attn.weight")
def test_build_glob_alt_without_prefix(self):
globs = ["model.layers.*.weight"]
alt, mapping = build_glob_alt(globs, allow_prefix=False)
self.assertIsNone(match_glob("foo.model.layers.0.weight", alt, mapping))
self.assertEqual(match_glob("model.layers.0.weight", alt, mapping), "model.layers.*.weight")
def test_build_glob_alt_with_prefix(self):
globs = ["layers.*.weight"]
alt, mapping = build_glob_alt(globs, allow_prefix=True)
self.assertEqual(match_glob("model.layers.0.weight", alt, mapping), "layers.*.weight")
class DummyParamModule(nn.Module):
def __init__(self, shape):
super().__init__()
self.weight = nn.Parameter(torch.zeros(shape))
class DummySelfAttn(nn.Module):
def __init__(self):
super().__init__()
self.q_proj = DummyParamModule((1, 2))
self.k_proj = DummyParamModule((1, 2))
self.v_proj = DummyParamModule((1, 2))
class DummyExperts(nn.Module):
def __init__(self):
super().__init__()
self.gate_up_proj = DummyParamModule((2, 4, 2))
self.down_proj = DummyParamModule((2, 2, 2))
class DummyLayer(nn.Module):
def __init__(self):
super().__init__()
self.self_attn = DummySelfAttn()
self.experts = DummyExperts()
class DummyTopModel(nn.Module):
def __init__(self):
super().__init__()
self.layers = nn.ModuleList([DummyLayer(), DummyLayer()])
class DummyMLP(nn.Module):
def __init__(self):
super().__init__()
self.down_proj = DummyParamModule((2, 2))
class DummyRoot(nn.Module):
def __init__(self):
super().__init__()
self.model = DummyTopModel()
self.mlp = DummyMLP()
class TestConvertAndLoadStateDict(unittest.TestCase):
def test_moe_and_qkv_conversion(self):
model = DummyRoot()
raw_tensors = {
"model.layers.0.experts.0.w1.weight": torch.tensor([[0.0, 1.0], [2.0, 3.0]]),
"model.layers.0.experts.1.w1.weight": torch.tensor([[10.0, 11.0], [12.0, 13.0]]),
"model.layers.0.experts.0.w3.weight": torch.tensor([[4.0, 5.0], [6.0, 7.0]]),
"model.layers.0.experts.1.w3.weight": torch.tensor([[14.0, 15.0], [16.0, 17.0]]),
"model.layers.0.experts.0.w2.weight": torch.tensor([[20.0, 21.0], [22.0, 23.0]]),
"model.layers.0.experts.1.w2.weight": torch.tensor([[24.0, 25.0], [26.0, 27.0]]),
"model.layers.1.experts.0.w1.weight": torch.tensor([[30.0, 31.0], [32.0, 33.0]]),
"model.layers.1.experts.1.w1.weight": torch.tensor([[34.0, 35.0], [36.0, 37.0]]),
"model.layers.1.experts.0.w3.weight": torch.tensor([[38.0, 39.0], [40.0, 41.0]]),
"model.layers.1.experts.1.w3.weight": torch.tensor([[42.0, 43.0], [44.0, 45.0]]),
"model.layers.1.experts.0.w2.weight": torch.tensor([[46.0, 47.0], [48.0, 49.0]]),
"model.layers.1.experts.1.w2.weight": torch.tensor([[50.0, 51.0], [52.0, 53.0]]),
"model.layers.0.self_attn.qkv_proj.weight": torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]),
"model.layers.1.self_attn.qkv_proj.weight": torch.tensor([[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]]),
"mlp.w2.weight": torch.tensor([[60.0, 61.0], [62.0, 63.0]]),
}
state_dict = {k: (k, v.clone()) for k, v in raw_tensors.items()}
weight_mapping = [
WeightConverter(
["model.layers.*.experts.*.w1.weight", "model.layers.*.experts.*.w3.weight"],
"model.layers.*.experts.gate_up_proj.weight",
operations=[MergeModulelist(dim=0), Concatenate(dim=1)],
),
WeightConverter(
"model.layers.*.experts.*.w2.weight",
"model.layers.*.experts.down_proj.weight",
operations=[MergeModulelist(dim=0)],
),
WeightConverter(
"model.layers.*.self_attn.qkv_proj.weight",
[
"model.layers.*.self_attn.q_proj.weight",
"model.layers.*.self_attn.k_proj.weight",
"model.layers.*.self_attn.v_proj.weight",
],
operations=[Concatenate(dim=0), Chunk(dim=0, chunks=3)],
),
WeightConverter("mlp.w2.weight", "mlp.down_proj.weight"),
]
missing, unexpected, mismatch, misc = convert_and_load_state_dict_in_model(
model, state_dict, weight_mapping, tp_plan=None, quantizer=None
)
self.assertEqual(missing, set())
self.assertEqual(unexpected, set())
self.assertEqual(mismatch, set())
self.assertEqual(misc, {})
model_state = model.state_dict()
def cat_gate(layer_prefix: str) -> torch.Tensor:
w1 = [
raw_tensors[f"{layer_prefix}.experts.0.w1.weight"],
raw_tensors[f"{layer_prefix}.experts.1.w1.weight"],
]
w3 = [
raw_tensors[f"{layer_prefix}.experts.0.w3.weight"],
raw_tensors[f"{layer_prefix}.experts.1.w3.weight"],
]
return torch.cat([torch.stack(w1, dim=0), torch.stack(w3, dim=0)], dim=1)
torch.testing.assert_close(
model_state["model.layers.0.experts.gate_up_proj.weight"], cat_gate("model.layers.0")
)
torch.testing.assert_close(
model_state["model.layers.1.experts.gate_up_proj.weight"], cat_gate("model.layers.1")
)
def stack_down(layer_prefix: str) -> torch.Tensor:
return torch.stack(
[
raw_tensors[f"{layer_prefix}.experts.0.w2.weight"],
raw_tensors[f"{layer_prefix}.experts.1.w2.weight"],
],
dim=0,
)
torch.testing.assert_close(
model_state["model.layers.0.experts.down_proj.weight"], stack_down("model.layers.0")
)
torch.testing.assert_close(
model_state["model.layers.1.experts.down_proj.weight"], stack_down("model.layers.1")
)
for layer_idx in range(2):
key = f"model.layers.{layer_idx}.self_attn.qkv_proj.weight"
expected_q, expected_k, expected_v = torch.chunk(raw_tensors[key], chunks=3, dim=0)
prefix = f"model.layers.{layer_idx}.self_attn"
torch.testing.assert_close(model_state[f"{prefix}.q_proj.weight"], expected_q)
torch.testing.assert_close(model_state[f"{prefix}.k_proj.weight"], expected_k)
torch.testing.assert_close(model_state[f"{prefix}.v_proj.weight"], expected_v)
torch.testing.assert_close(model_state["mlp.down_proj.weight"], raw_tensors["mlp.w2.weight"])
if __name__ == "__main__":
unittest.main()
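
At the tensor level, the MergeModulelist + Concatenate pipeline checked above is a stack followed by a cat, mirroring the test's own cat_gate helper; a small, purely illustrative sketch:

import torch

num_experts, out_dim, in_dim = 2, 2, 2
w1 = [torch.randn(out_dim, in_dim) for _ in range(num_experts)]  # per-expert gate weights
w3 = [torch.randn(out_dim, in_dim) for _ in range(num_experts)]  # per-expert up weights

# MergeModulelist(dim=0): stack the per-expert tensors -> (num_experts, 2, 2)
stacked_w1 = torch.stack(w1, dim=0)
stacked_w3 = torch.stack(w3, dim=0)

# Concatenate(dim=1): fuse gate and up into gate_up_proj -> (num_experts, 4, 2)
gate_up = torch.cat([stacked_w1, stacked_w3], dim=1)
assert gate_up.shape == (num_experts, 4, 2)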

View File

@ -90,6 +90,8 @@ PRIVATE_MODELS = [
"Kosmos2_5TextForCausalLM",
"Kosmos2_5VisionModel",
"SmolVLMVisionTransformer",
"SiglipVisionTransformer",
"Siglip2VisionTransformer",
"AriaTextForCausalLM",
"AriaTextModel",
"Phi4MultimodalAudioModel",
@ -358,7 +360,9 @@ IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [
"SegGptForImageSegmentation",
"SiglipVisionModel",
"SiglipTextModel",
"SiglipVisionTransformer",
"Siglip2VisionModel",
"Siglip2VisionTransformer",
"Siglip2TextModel",
"ChameleonVQVAE", # no autoclass for VQ-VAE models
"VitPoseForPoseEstimation",