Mirror of https://github.com/huggingface/transformers.git (synced 2025-10-31 09:04:37 +08:00)

Compare commits: refactor-w ... siglip_and (14 commits)
| SHA1 |
|---|
| 3ee3c563dd |
| 02c324f43f |
| b47b35637f |
| f54d0db71d |
| 40a9dc87d3 |
| 91d34b0a99 |
| 448dd635e3 |
| 807983c2a7 |
| 5aa7610d12 |
| fe7c9228a4 |
| 082dcf21d1 |
| 4f93734169 |
| 76a14c7008 |
| ca68be8560 |
| @ -1,83 +0,0 @@ | ||||
| # coding=utf-8 | ||||
| # Copyright (C) 2025 the HuggingFace Inc. team. All rights reserved. | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
|  | ||||
| from .core_model_loading import Concatenate, MergeModulelist, WeightConverter | ||||
|  | ||||
|  | ||||
| _checkpoint_conversion_mapping = { | ||||
|     "mixtral": [ | ||||
|         WeightConverter( | ||||
|             source_keys=[ | ||||
|                 "block_sparse_moe.experts.*.w1.weight", | ||||
|                 "block_sparse_moe.experts.*.w3.weight", | ||||
|             ],  # two source patterns: the loader collects one list of per-expert tensors for each | ||||
|             target_keys="mlp.experts.gate_up_proj",  # the single target key receives both lists | ||||
|             operations=[ | ||||
|                 MergeModulelist( | ||||
|                     dim=0 | ||||
|                 ),  # stack each per-expert list along dim 0 -> we end up with 2 tensors (w1 and w3) | ||||
|                 Concatenate(dim=1),  # concatenate the two stacked tensors (gate and up) into gate_up | ||||
|             ],  # the loader should insert its shard operation here; we cannot shard after the merge/concat, so it has to come first | ||||
|         ), | ||||
|         WeightConverter( | ||||
|             source_keys=[ | ||||
|                 "block_sparse_moe.experts.*.w2.weight", | ||||
|             ], | ||||
|             target_keys="mlp.experts.down_proj",  # target key gets the list of two tensors | ||||
|             operations=[ | ||||
|                 MergeModulelist( | ||||
|                     dim=0 | ||||
|                 ),  # stack the per-expert list along dim 0 into a single tensor | ||||
|             ],  # the loader should insert its shard operation here; we cannot shard after the merge, so it has to come first | ||||
|         ), | ||||
|         # WeightConverter( | ||||
|         #     ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"], | ||||
|         #     "self_attn.qkv_proj", | ||||
|         #     Concatenate(dim=0),  # more like stack? | ||||
|         # ), | ||||
|         WeightConverter("*.block_sparse_moe.", "*.mlp."), | ||||
|     ], | ||||
|     "qwen2_moe": [ | ||||
|         WeightConverter( | ||||
|             source_keys=[ | ||||
|                 "mlp.experts.*.gate_proj.weight", | ||||
|                 "mlp.experts.*.up_proj.weight", | ||||
|             ], | ||||
|             target_keys="mlp.experts.gate_up_proj", | ||||
|             operations=[MergeModulelist(dim=0), Concatenate(dim=1)], | ||||
|         ), | ||||
|         WeightConverter( | ||||
|             source_keys=["mlp.experts.*.down_proj.weight"], | ||||
|             target_keys="mlp.experts.down_proj", | ||||
|             operations=[MergeModulelist(dim=0)], | ||||
|         ), | ||||
|     ], | ||||
| } | ||||
| _checkpoint_conversion_mapping["phimoe"] = _checkpoint_conversion_mapping["mixtral"].copy() | ||||
| _checkpoint_conversion_mapping["deepseek_v2"] = _checkpoint_conversion_mapping["qwen2_moe"].copy() | ||||
| _checkpoint_conversion_mapping["deepseek_v3"] = _checkpoint_conversion_mapping["qwen2_moe"].copy() | ||||
| _checkpoint_conversion_mapping["dot1"] = _checkpoint_conversion_mapping["qwen2_moe"].copy() | ||||
| _checkpoint_conversion_mapping["ernie_4_5_moe"] = _checkpoint_conversion_mapping["qwen2_moe"].copy() | ||||
| _checkpoint_conversion_mapping["glm4_moe"] = _checkpoint_conversion_mapping["qwen2_moe"].copy() | ||||
| _checkpoint_conversion_mapping["glm4v_moe"] = _checkpoint_conversion_mapping["qwen2_moe"].copy() | ||||
| _checkpoint_conversion_mapping["jamba"] = _checkpoint_conversion_mapping["qwen2_moe"].copy() | ||||
| _checkpoint_conversion_mapping["lfm2_moe"] = _checkpoint_conversion_mapping["mixtral"].copy() | ||||
| _checkpoint_conversion_mapping["long_cat_flash"] = _checkpoint_conversion_mapping["qwen2_moe"].copy() | ||||
| _checkpoint_conversion_mapping["qwen3_moe"] = _checkpoint_conversion_mapping["qwen2_moe"].copy() | ||||
| _checkpoint_conversion_mapping["qwen3_omni_moe"] = _checkpoint_conversion_mapping["qwen2_moe"].copy() | ||||
| _checkpoint_conversion_mapping["qwen3_next"] = _checkpoint_conversion_mapping["qwen2_moe"].copy() | ||||
| _checkpoint_conversion_mapping["qwen3_vl_moe"] = _checkpoint_conversion_mapping["qwen2_moe"].copy() | ||||
| _checkpoint_conversion_mapping["hunyuan_v1_moe"] = _checkpoint_conversion_mapping["qwen2_moe"].copy() | ||||
| _checkpoint_conversion_mapping["minimax"] = _checkpoint_conversion_mapping["mixtral"].copy() | ||||
| @ -1,576 +0,0 @@ | ||||
| # coding=utf-8 | ||||
| # Copyright 2025 The HuggingFace Inc. team. All rights reserved. | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| """Core helpers for loading model checkpoints.""" | ||||
|  | ||||
| from __future__ import annotations | ||||
|  | ||||
| import itertools | ||||
| import os | ||||
| import re | ||||
| import threading | ||||
| from abc import abstractmethod | ||||
| from collections import defaultdict | ||||
| from collections.abc import Sequence | ||||
| from concurrent.futures import Future, ThreadPoolExecutor | ||||
| from dataclasses import dataclass, field | ||||
| from functools import partial | ||||
| from typing import Any, Optional, Union | ||||
|  | ||||
| import torch | ||||
| from torch.distributed.tensor import DTensor | ||||
|  | ||||
| from .integrations.finegrained_fp8 import Fp8Quantize | ||||
| from .integrations.tensor_parallel import ALL_PARALLEL_STYLES | ||||
| from .utils import logging | ||||
|  | ||||
|  | ||||
| logger = logging.get_logger(__name__) | ||||
|  | ||||
|  | ||||
| def _glob_to_regex_src(glob: str, *, digits_only: bool = True) -> str: | ||||
|     """ | ||||
|     Convert a glob with '*' into a regex *source* string. | ||||
|     '*' becomes a capturing group: (\\d+) if digits_only else (.+), so the matched index can be re-injected into the target key. | ||||
|     """ | ||||
|     star = r"(\d+)" if digits_only else r"(.+)" | ||||
|     return re.escape(glob).replace(r"\*", star) | ||||
|  | ||||
|  | ||||
| def build_glob_alt( | ||||
|     globs: list[str], | ||||
|     *, | ||||
|     digits_only: bool = True, | ||||
|     allow_prefix: bool = True, | ||||
| ) -> tuple[re.Pattern, dict[str, str]]: | ||||
|     """ | ||||
|     Build one compiled regex alternation with a named group per glob. | ||||
|     - digits_only: '*' => digits only (\\d+) if True, else any chars (.+) | ||||
|     - allow_prefix: if True, allow an arbitrary prefix before the pattern | ||||
|                     (patterns are matched with re.match, so they are only anchored at the start) | ||||
|     Returns (compiled_regex, name->glob map). | ||||
|     """ | ||||
|     name_map: dict[str, str] = {} | ||||
|     parts: list[str] = [] | ||||
|  | ||||
|     # If we keep using .match(), we must handle prefix allowance in the pattern itself. | ||||
|     prefix_src = r".*" if allow_prefix else r"^" | ||||
|  | ||||
|     for i, g in enumerate(globs): | ||||
|         name = f"g{i}" | ||||
|         name_map[name] = g | ||||
|         pat_src = _glob_to_regex_src(g, digits_only=digits_only) | ||||
|         # Each branch is fully wrapped and uniquely named. | ||||
|         parts.append(f"(?P<{name}>{prefix_src}{pat_src})") | ||||
|  | ||||
|     alt_src = "|".join(parts) | ||||
|     return re.compile(alt_src), name_map | ||||
|  | ||||
|  | ||||
| def match_glob(key: str, alt: re.Pattern, name_map: dict[str, str]) -> Optional[str]: | ||||
|     """ | ||||
|     Match the key against the alternation; return the original glob string that matched. | ||||
|     """ | ||||
|     m = alt.match(key) | ||||
|     if not m: | ||||
|         return None | ||||
|     return name_map.get(m.lastgroup) | ||||
|  | ||||
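
A usage sketch of the three helpers above (the checkpoint key is hypothetical):

```python
alt, name_map = build_glob_alt(
    ["mlp.experts.*.gate_proj.weight", "mlp.experts.*.down_proj.weight"]
)

key = "model.layers.3.mlp.experts.7.gate_proj.weight"  # made-up key
print(match_glob(key, alt, name_map))
# -> "mlp.experts.*.gate_proj.weight"
```
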
|  | ||||
| class ConversionOps: | ||||
|     """Base class for weight conversion operations.""" | ||||
|  | ||||
|     # Reusable scratch buffer to avoid reallocations. | ||||
|     _buffer: Optional[torch.Tensor] = None | ||||
|     # The inverse operation class; used when saving the checkpoint | ||||
|     _inverse_op: type[ConversionOps] | ||||
|  | ||||
|     def _ensure_buffer( | ||||
|         self, | ||||
|         required_shape: torch.Size, | ||||
|         *, | ||||
|         dtype: torch.dtype, | ||||
|         device: torch.device, | ||||
|         growth_factor: float = 1.5, | ||||
|     ) -> torch.Tensor: | ||||
|         """Ensure a pre-allocated buffer large enough for ``required_shape`` exists.""" | ||||
|  | ||||
|         required_elems = 1 | ||||
|         for dim in required_shape: | ||||
|             required_elems *= int(dim) | ||||
|  | ||||
|         need_new = ( | ||||
|             self._buffer is None | ||||
|             or self._buffer.dtype != dtype | ||||
|             or self._buffer.device != device | ||||
|             or self._buffer.numel() < required_elems | ||||
|         ) | ||||
|  | ||||
|         if need_new: | ||||
|             capacity = max(required_elems, int(required_elems * growth_factor)) | ||||
|             self._buffer = torch.empty(capacity, dtype=dtype, device=device) | ||||
|  | ||||
|         return self._buffer[:required_elems].view(required_shape) | ||||
|  | ||||
|     def clear_cache(self) -> None: | ||||
|         """Free any cached buffers.""" | ||||
|         self._buffer = None | ||||
|  | ||||
|     @abstractmethod | ||||
|     def convert(self, value: Union[Sequence[torch.Tensor], torch.Tensor], *args, **kwargs) -> torch.Tensor: | ||||
|         raise NotImplementedError | ||||
|  | ||||
|  | ||||
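
As a minimal sketch, a new operation only needs to implement `convert`; the `Transpose` op below is hypothetical and not part of the file (it assumes the `ConversionOps` base and `torch` import above):

```python
class Transpose(ConversionOps):
    """Hypothetical op: swap two dimensions of the incoming tensor."""

    def __init__(self, dim0: int = 0, dim1: int = 1):
        self.dim0 = dim0
        self.dim1 = dim1

    def convert(self, value: torch.Tensor) -> torch.Tensor:
        # A real op would also declare its `_inverse_op` for saving.
        return value.transpose(self.dim0, self.dim1).contiguous()
```
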
| class Chunk(ConversionOps): | ||||
|     """Split a tensor along ``dim`` into equally sized chunks or using explicit ``sizes``.""" | ||||
|  | ||||
|     _inverse_op: type[ConversionOps] | ||||
|  | ||||
|     def __init__(self, dim: int = 0, chunks: Optional[int] = None, sizes: Optional[Sequence[int]] = None): | ||||
|         if chunks is None and sizes is None: | ||||
|             raise ValueError("`chunks` or `sizes` must be provided for Chunk operations.") | ||||
|         if chunks is not None and chunks <= 0: | ||||
|             raise ValueError("`chunks` must be a strictly positive integer.") | ||||
|         self.dim = dim | ||||
|         self.chunks = chunks | ||||
|         self.sizes = list(sizes) if sizes is not None else None | ||||
|         self._inverse_op = Concatenate | ||||
|  | ||||
|     def convert(self, value: torch.Tensor) -> list[torch.Tensor]: | ||||
|         if not isinstance(value, torch.Tensor): | ||||
|             raise TypeError("Chunk expects a torch.Tensor as input.") | ||||
|         if self.sizes is not None: | ||||
|             return list(torch.split(value, self.sizes, dim=self.dim)) | ||||
|         return list(torch.chunk(value, self.chunks, dim=self.dim)) | ||||
|  | ||||
|  | ||||
| class Concatenate(ConversionOps): | ||||
|     """Concatenate tensors along `dim` using a reusable buffer.""" | ||||
|  | ||||
|     _inverse_op: type[ConversionOps] | ||||
|  | ||||
|     def __init__(self, dim: int = 0): | ||||
|         self.dim = dim | ||||
|         self._inverse_op = Chunk | ||||
|  | ||||
|     @torch.no_grad | ||||
|     def convert(self, value: Sequence[torch.Tensor]) -> torch.Tensor: | ||||
|         if isinstance(value[0], list): | ||||
|             value = [v[0] for v in value] | ||||
|         tensors = value | ||||
|         if not tensors: | ||||
|             raise ValueError("Fuse requires at least one tensor to concatenate.") | ||||
|  | ||||
|         out_shape = list(tensors[0].shape) | ||||
|         out_shape[self.dim] = sum([t.size(self.dim) for t in tensors]) | ||||
|  | ||||
|         with torch.no_grad():  # we use staging buffers | ||||
|             out = self._ensure_buffer(torch.Size(out_shape), dtype=tensors[0].dtype, device=tensors[0].device) | ||||
|             torch.cat(tuple(tensors), dim=self.dim, out=out) | ||||
|             # offset = 0 | ||||
|             # for tensor in tensors: | ||||
|             #     index = [slice(None)] * tensor.ndim | ||||
|             #     index[self.dim] = slice(offset, offset + tensor.shape[self.dim]) | ||||
|             #     out[tuple(index)].copy_(tensor, non_blocking=tensor.is_cuda) | ||||
|             #     offset += tensor.shape[self.dim] | ||||
|         return out.clone()  # clone so the shared staging buffer can be reused for the next conversion | ||||
|  | ||||
|  | ||||
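
A small round-trip sketch of the two ops above (shapes are illustrative):

```python
import torch

w = torch.randn(6, 4)

halves = Chunk(dim=0, chunks=2).convert(w)     # two (3, 4) tensors
restored = Concatenate(dim=0).convert(halves)  # back to (6, 4), via the staging buffer

assert torch.equal(restored, w)
```
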
| class MergeModulelist(Concatenate): | ||||
|     """ | ||||
|     Stack a list of tensors into a single tensor along a new dimension (``dim``). | ||||
|     We define this explicitly (rather than reusing Concatenate) because for EP or TP you must be deliberate about how expert weights are merged! | ||||
|  | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, dim: int = 0): | ||||
|         super().__init__(dim=dim) | ||||
|         self._inverse_op = SplitModulelist | ||||
|  | ||||
|     def convert(self, value: Sequence[torch.Tensor]) -> list[torch.Tensor]: | ||||
|         merged = [] | ||||
|         with torch.no_grad():  # we use staging buffers | ||||
|             for group in value: | ||||
|                 if not isinstance(group, Sequence) or len(group) == 0: | ||||
|                     raise ValueError("MergeModulelist requires non-empty sub-sequences.") | ||||
|                 group = [k for k in group if k.ndim] | ||||
|                 out_shape = list(group[0].shape) | ||||
|                 out_shape.insert(self.dim, len(group)) | ||||
|                 out = self._ensure_buffer(torch.Size(out_shape), dtype=group[0].dtype, device=group[0].device) | ||||
|                 torch.stack(tuple(group), dim=self.dim, out=out) | ||||
|                 # for off, tensor in enumerate(group): | ||||
|                 #     out[off].copy_(tensor, non_blocking=tensor.is_cuda) | ||||
|                 # torch.as_tensor(numpy.stack(batch)) | ||||
|                 merged.append(out.clone())  # TODO have a single staging tensor here as well! | ||||
|         return merged | ||||
|  | ||||
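
Note that `MergeModulelist.convert` expects a list of groups (one group per source pattern) and stacks each group along a new dimension; a small sketch with illustrative shapes:

```python
import torch

experts = [torch.randn(8, 16) for _ in range(4)]   # 4 per-expert matrices
(down_proj,) = MergeModulelist(dim=0).convert([experts])
print(down_proj.shape)  # torch.Size([4, 8, 16])
```
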
|  | ||||
| class SplitModulelist(ConversionOps): | ||||
|     """Inverse of :class:`MergeModulelist` using explicit split sizes per group.""" | ||||
|  | ||||
|     def __init__(self, sizes: Sequence[Sequence[int]], dim: int = 0): | ||||
|         if not isinstance(sizes, Sequence) or not all(isinstance(sub, Sequence) and sub for sub in sizes): | ||||
|             raise ValueError("`sizes` must be a sequence of non-empty sequences of integers.") | ||||
|         self.sizes = [list(sub) for sub in sizes] | ||||
|         self.dim = dim | ||||
|         self._inverse_op = MergeModulelist | ||||
|  | ||||
|     def convert(self, value: Sequence[torch.Tensor], *, context: dict[str, Any]) -> list[list[torch.Tensor]]: | ||||
|         if not isinstance(value, Sequence): | ||||
|             raise TypeError("SplitModulelist expects a sequence of tensors.") | ||||
|         if len(value) != len(self.sizes): | ||||
|             raise ValueError("Number of tensors does not match the provided split specifications.") | ||||
|  | ||||
|         result: list[list[torch.Tensor]] = [] | ||||
|         for tensor, split_sizes in zip(value, self.sizes): | ||||
|             if not isinstance(tensor, torch.Tensor): | ||||
|                 raise TypeError("SplitModulelist can only split torch.Tensor instances.") | ||||
|             splits = torch.split(tensor, split_sizes, dim=self.dim) | ||||
|             result.append(list(splits)) | ||||
|         return result | ||||
|  | ||||
|  | ||||
| class Cast(ConversionOps): | ||||
|     """ | ||||
|     Casts the tensor to a given dtype | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, dtype): | ||||
|         self.dtype = dtype | ||||
|  | ||||
|     def convert(self, realized_value): | ||||
|         return realized_value.to(self.dtype) | ||||
|  | ||||
|  | ||||
| class To(ConversionOps): | ||||
|     """ | ||||
|     Transfers the tensor to the provided device (a dedicated transfer stream may eventually be used here). | ||||
|  | ||||
|     if param_device == "disk": | ||||
|         if not is_safetensors: | ||||
|             disk_offload_index = offload_weight(param, param_name, disk_offload_folder, disk_offload_index) | ||||
|     elif not is_quantized or not hf_quantizer.param_needs_quantization(model, param_name): | ||||
|         if is_fsdp_enabled(): | ||||
|             param_device = "cpu" if is_local_dist_rank_0() else "meta" | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, device): | ||||
|         self.device = device | ||||
|  | ||||
|     def convert(self, realized_value): | ||||
|         with torch.device(self.device): | ||||
|             out = [[x[...] for x in inner] if isinstance(inner, list) else inner[...] for inner in realized_value] | ||||
|         return out | ||||
|  | ||||
|  | ||||
| @dataclass(slots=True) | ||||
| class WeightConverter: | ||||
|     r""" | ||||
|     A weight converter that acts on a pattern of source keys. | ||||
|     The source keys are collected and grouped according to the target keys. | ||||
|  | ||||
|     Wildcards are matched as glob patterns, so be precise about what you match. For example: | ||||
|     `model.layers.*.experts.*` -> acts on every layer and every expert | ||||
|     {"model.layers.*.experts.*": []} | ||||
|     whereas a pattern with an explicit index such as | ||||
|     `model.layers.1.experts.*` is layer specific. | ||||
|     {"model.layers.1.experts.*": [], } | ||||
|     - source_keys: str | list[str] (wildcards '*' match digits) | ||||
|     - target_keys: str | list[str] | None | ||||
|     - distributed_operation / operations / quantization_operations are ALWAYS lists. | ||||
|     """ | ||||
|  | ||||
|     source_keys: Union[str, list[str]] | ||||
|     target_keys: Optional[Union[str, list[str]]] = None | ||||
|     operations: list[ConversionOps] = field(default_factory=list, repr=False) | ||||
|  | ||||
|     distributed_operation: dict[str, ConversionOps] = field(default_factory=dict, compare=False, repr=False) | ||||
|     quantization_operation: dict[str, ConversionOps] = field(default_factory=dict, compare=False, repr=False) | ||||
|     _compiled: tuple[tuple[str, re.Pattern], ...] = field(default_factory=tuple, compare=False, repr=False) | ||||
|     _regex_pat: tuple[re.Pattern, dict[str, str]] = field(default_factory=tuple, compare=False, repr=False) | ||||
|  | ||||
|     def __post_init__(self): | ||||
|         if not isinstance(self.source_keys, list): | ||||
|             self.source_keys = [self.source_keys] | ||||
|         if not isinstance(self.target_keys, list): | ||||
|             if self.target_keys is None: | ||||
|                 self.target_keys = self.source_keys | ||||
|             else: | ||||
|                 self.target_keys = [self.target_keys] | ||||
|         self._regex_pat = build_glob_alt(self.source_keys) | ||||
|  | ||||
|  | ||||
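
For illustration, the wildcard scoping described in the docstring (all keys below are hypothetical, and '*' only matches digits):

```python
all_layers = WeightConverter("model.layers.*.experts.*.weight")
one_layer = WeightConverter("model.layers.1.experts.*.weight")

key = "model.layers.7.experts.3.weight"
assert match_glob(key, *all_layers._regex_pat) is not None  # acts on every layer
assert match_glob(key, *one_layer._regex_pat) is None       # layer 1 only
```
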
| def set_param_for_module( | ||||
|     model, k, v, meta_model_state_dict, empty_tensor, mismatch_keys, missing_keys, misc, distributed_operation | ||||
| ): | ||||
|     try: | ||||
|         module_path, _, param_name = k.rpartition(".") | ||||
|         module_obj = model.get_submodule(module_path) if module_path else model | ||||
|         param_value = v[0] if isinstance(v, list) else v[:] | ||||
|         ref = meta_model_state_dict.get(k, empty_tensor) | ||||
|         use_dtensor = hasattr(distributed_operation, "use_dtensor") and distributed_operation.use_dtensor | ||||
|         if not isinstance(param_value, torch.nn.Parameter): | ||||
|             if distributed_operation != {} and use_dtensor: | ||||
|                 param_value = DTensor.from_local( | ||||
|                     param_value, | ||||
|                     distributed_operation.device_mesh, | ||||
|                     distributed_operation.shard, | ||||
|                     run_check=False, | ||||
|                     shape=ref.size(), | ||||
|                     stride=ref.stride(), | ||||
|                 ) | ||||
|             else: | ||||
|                 pass  # TODO for "local" stuff, it will trigger missmatched no? | ||||
|             param_value = torch.nn.Parameter(param_value, requires_grad=param_value.is_floating_point()) | ||||
|  | ||||
|         if ref is not None and ref.shape != param_value.shape: | ||||
|             mismatch_keys.add((k, param_value.shape, ref.shape)) | ||||
|  | ||||
|         if k in missing_keys: | ||||
|             missing_keys.remove(k) | ||||
|  | ||||
|         setattr(module_obj, param_name, param_value) | ||||
|     except Exception as e: | ||||
|         misc[k] = f"{e} for {k} on {list(module_obj.state_dict().keys())}" | ||||
|  | ||||
|  | ||||
| @dataclass(slots=True) | ||||
| class ConversionEntry: | ||||
|     weight_converter: WeightConverter | ||||
|     collected_tensors: dict = field(default_factory=lambda: defaultdict(dict)) | ||||
|  | ||||
|  | ||||
| # Tune these to your storage: | ||||
| GLOBAL_WORKERS = min(32, (os.cpu_count() or 8) * 2)  # NVMe: 8-16; HDD/NFS: 2-4 | ||||
| PER_FILE_LIMIT = 4  # concurrent reads per file | ||||
|  | ||||
|  | ||||
| def _materialize_copy(x): | ||||
|     # PyTorch: this runs in C and releases the GIL; good for threads. | ||||
|     return x[...] | ||||
|  | ||||
|  | ||||
| def spawn_materialize(EXEC, _file_sems, file_id, t) -> Future: | ||||
|     sem = _file_sems[file_id] | ||||
|  | ||||
|     def _job(): | ||||
|         with sem: | ||||
|             return _materialize_copy(t) | ||||
|  | ||||
|     return EXEC.submit(_job) | ||||
|  | ||||
|  | ||||
| def spawn_tp_materialize(EXEC, _file_sems, file_id, t, sharding_method, empty_tensor, tensor_idx) -> Future: | ||||
|     sem = _file_sems[file_id] | ||||
|  | ||||
|     def _job(): | ||||
|         with sem: | ||||
|             return sharding_method.shard_tensor(t, empty_tensor, tensor_idx=tensor_idx)[0] | ||||
|  | ||||
|     return EXEC.submit(_job) | ||||
|  | ||||
|  | ||||
| def dot_natural_key(s: str): | ||||
|     parts = s.split(".") | ||||
|     for i, p in enumerate(parts): | ||||
|         # whole-segment digits -> int; otherwise leave as str | ||||
|         if p.isdigit(): | ||||
|             parts[i] = int(p) | ||||
|     return parts | ||||
|  | ||||
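
A quick sketch of the natural ordering this key produces (keys are hypothetical):

```python
keys = ["model.layers.10.mlp.weight", "model.layers.2.mlp.weight"]
print(sorted(keys, key=dot_natural_key))
# ['model.layers.2.mlp.weight', 'model.layers.10.mlp.weight'] -- 2 before 10, unlike a plain string sort
```
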
|  | ||||
| def convert_and_load_state_dict_in_model( | ||||
|     model, | ||||
|     state_dict, | ||||
|     weight_mapping, | ||||
|     tp_plan, | ||||
|     quantizer, | ||||
|     device_map=None, | ||||
|     keep_in_dtype=None, | ||||
|     device_mesh=None, | ||||
|     profile: bool = False, | ||||
| ): | ||||
|     """ | ||||
|     Convert a state dict according to a weight mapping (one WeightConverter per glob pattern), | ||||
|     collecting tensors per *layer instance* (the concrete indices captured from '*'). | ||||
|     """ | ||||
|     tp_plan = tp_plan or {}  # {glob_pattern: plan_obj_or_key} | ||||
|     device_map = device_map or {}  # {exact_target_key: device} | ||||
|     keep_in_dtype = keep_in_dtype or {}  # {glob_pattern: dtype} | ||||
|     weight_mapping = weight_mapping or {}  # {glob_pattern: WeightConverter} | ||||
|     meta_model_state_dict = model.state_dict() | ||||
|     missing_keys = set(meta_model_state_dict.keys()) | ||||
|     if model.config.tie_word_embeddings and "lm_head.weight" in missing_keys: | ||||
|         missing_keys.remove("lm_head.weight") | ||||
|  | ||||
|     misc = {} | ||||
|     mismatch_keys = set() | ||||
|     unexpected_keys = set() | ||||
|     # Global executor + per-file semaphores | ||||
|     EXEC = ThreadPoolExecutor(max_workers=GLOBAL_WORKERS) | ||||
|     _file_sems = defaultdict(lambda: threading.Semaphore(PER_FILE_LIMIT)) | ||||
|  | ||||
|     _patterns = list(itertools.chain.from_iterable([k.source_keys for k in weight_mapping])) | ||||
|     source_to_target = {sk: k for k in weight_mapping for sk in k.source_keys} | ||||
|     weight_pattern_alt, weight_pattern_by_group_name = build_glob_alt(_patterns) | ||||
|     tp_plan_alt, tp_plan_by_group_name = build_glob_alt(list(tp_plan.keys())) | ||||
|     dtype_policy_alt, dtype_policy_by_group_name = build_glob_alt(list(keep_in_dtype.keys())) | ||||
|  | ||||
|     state_dict = sorted(state_dict.items(), key=lambda kv: dot_natural_key(kv[0])) | ||||
|     # 1. Create the conversion entries | ||||
|     by_conversion_pattern: dict[str, ConversionEntry] = {} | ||||
|     for original_key, (file_id, tensor) in state_dict: | ||||
|         matched_pattern = match_glob(original_key, weight_pattern_alt, weight_pattern_by_group_name) | ||||
|         if matched_pattern is not None: | ||||
|             converter = source_to_target[matched_pattern]  # TODO make sure it's the ref | ||||
|             sub_with_extractor = partial(re.sub, _glob_to_regex_src(matched_pattern), string=original_key) | ||||
|             entry_key = "|".join(converter.target_keys) | ||||
|             target_key = "|".join(map(sub_with_extractor, [k.replace("*", "\\1") for k in converter.target_keys])) | ||||
|             entry: ConversionEntry = by_conversion_pattern.setdefault(entry_key, ConversionEntry(converter)) | ||||
|             converter_key = sub_with_extractor(matched_pattern) | ||||
|         else: | ||||
|             converter = WeightConverter(original_key) | ||||
|             converter_key = entry_key = target_key = original_key | ||||
|             entry = by_conversion_pattern.setdefault(converter_key, ConversionEntry(converter)) | ||||
|  | ||||
|         prefix = model.base_model_prefix | ||||
|         new_target_key = [] | ||||
|         for t in target_key.split("|"):  # let's correct the keys | ||||
|             if t.startswith(prefix) and meta_model_state_dict.get(t.replace(f"{prefix}.", "")) is not None: | ||||
|                 t = t.replace(f"{prefix}.", "") | ||||
|             elif meta_model_state_dict.get(f"{prefix}.{t}") is not None: | ||||
|                 t = f"{prefix}.{t}" | ||||
|             new_target_key.append(t) | ||||
|         target_key = "|".join(new_target_key) | ||||
|  | ||||
|         for t in target_key.split("|"): | ||||
|             empty_tensor = meta_model_state_dict.get(t) | ||||
|             if empty_tensor is None: | ||||
|                 unexpected_keys.add(t) | ||||
|                 continue | ||||
|             if quantizer is not None and quantizer.param_needs_quantization(model, t): | ||||
|                 if quantizer.__class__.__name__ == "FineGrainedFP8HfQuantizer": | ||||
|                     converter.quantization_operation[t] = Fp8Quantize()  # TODO support other methods | ||||
|                 else: | ||||
|                     # Only fine-grained FP8 is wired into this loading path for now. | ||||
|                     raise ValueError("This quantization method is not supported by the converter-based loading path yet.") | ||||
|  | ||||
|         first_target_key = target_key.split("|")[0] | ||||
|         future = None | ||||
|         if device_mesh: | ||||
|             if matched_tp_pattern := match_glob(first_target_key, tp_plan_alt, tp_plan_by_group_name): | ||||
|                 empty_tensor = meta_model_state_dict.get(first_target_key) | ||||
|                 if getattr(converter, "distributed_operation", {}) == {}: | ||||
|                     converter.distributed_operation = ALL_PARALLEL_STYLES[model.tp_plan[matched_tp_pattern]] | ||||
|                     converter.distributed_operation.device_mesh = device_mesh | ||||
|                     converter.distributed_operation.rank = device_map[""].index | ||||
|                     converter.distributed_operation.empty_tensor = empty_tensor.clone() | ||||
|                 shard_index = len(entry.collected_tensors[target_key].get(converter_key, [])) | ||||
|                 future = spawn_tp_materialize( | ||||
|                     EXEC, _file_sems, file_id, tensor, converter.distributed_operation, empty_tensor, shard_index | ||||
|                 ) | ||||
|  | ||||
|         if future is None:  # If not TP, async move tensors | ||||
|             future = spawn_materialize(EXEC, _file_sems, file_id, tensor) | ||||
|         entry.collected_tensors[target_key].setdefault(converter_key, []).append(future) | ||||
|  | ||||
|     # 2. Actually convert the ckpt | ||||
|     inverse_converters = {} | ||||
|     keys = list(by_conversion_pattern.keys()) | ||||
|     total_layers = sum(len(by_conversion_pattern[key].collected_tensors) for key in keys) | ||||
|     progress_bar = logging.tqdm(total=total_layers, desc="Converting weights", leave=False) if total_layers else None | ||||
|  | ||||
|     try: | ||||
|         for key in keys[::-1]:  # reversed so that the simple, un-converted keys are processed first | ||||
|             group = by_conversion_pattern.pop(key) | ||||
|             converter = group.weight_converter | ||||
|             operations = converter.operations if isinstance(converter.operations, list) else [converter.operations] | ||||
|             for layer_name, tensors_for_this_layer in group.collected_tensors.items(): | ||||
|                 concrete_target_keys = layer_name.split("|") | ||||
|                 if bool(set(concrete_target_keys) - unexpected_keys): | ||||
|                     values = [[k.result() for k in inner] for inner in tensors_for_this_layer.values()] | ||||
|  | ||||
|                     for op in operations: | ||||
|                         try: | ||||
|                             values = op.convert(values) | ||||
|                         except Exception as e: | ||||
|                             misc[layer_name] = ( | ||||
|                                 f"{e}\nError: {op.__class__.__name__} on tensors collected from {converter.source_keys}. Ckpt contains: {values}" | ||||
|                             ) | ||||
|  | ||||
|                     values = [values] if not isinstance(values, list) else values | ||||
|                     realized_value = {k: t for k, t in zip(concrete_target_keys, values) if k not in unexpected_keys} | ||||
|  | ||||
|                     for k in list(realized_value.keys()).copy(): | ||||
|                         if op := converter.quantization_operation.get(k): | ||||
|                             try: | ||||
|                                 realized_value.update( | ||||
|                                     op.convert({k: realized_value.pop(k)}, quant_config=quantizer.quantization_config) | ||||
|                                 ) | ||||
|                             except Exception as e: | ||||
|                                 misc[layer_name] = f"{op.__class__.__name__}: {e}" | ||||
|  | ||||
|                     if progress_bar is not None: | ||||
|                         progress_bar.set_postfix_str(layer_name, refresh=False) | ||||
|                         progress_bar.update() | ||||
|  | ||||
|                     for k, output_value in realized_value.items(): | ||||
|                         matched_dtype_pattern = match_glob(k, dtype_policy_alt, dtype_policy_by_group_name) | ||||
|                         if matched_dtype_pattern is not None: | ||||
|                             op = Cast(keep_in_dtype[matched_dtype_pattern]) | ||||
|                             output_value = op.convert(output_value) | ||||
|  | ||||
|                         for src in converter.source_keys:  # record how to map k back to its source keys at save time | ||||
|                             inverse_converters[k] = {src: converter} | ||||
|                         set_param_for_module( | ||||
|                             model, | ||||
|                             k, | ||||
|                             output_value, | ||||
|                             meta_model_state_dict, | ||||
|                             empty_tensor, | ||||
|                             mismatch_keys, | ||||
|                             missing_keys, | ||||
|                             misc, | ||||
|                             converter.distributed_operation, | ||||
|                         ) | ||||
|  | ||||
|             del group | ||||
|             for op in operations: | ||||
|                 op.clear_cache() | ||||
|     finally: | ||||
|         if progress_bar is not None: | ||||
|             progress_bar.close() | ||||
|     model.inverse_converters = inverse_converters | ||||
|     EXEC.shutdown(wait=True) | ||||
|     return missing_keys, unexpected_keys, mismatch_keys, misc | ||||
|  | ||||
|  | ||||
| # TODO this is not done yet! | ||||
| def revert_weight_conversion(model, state_dict): | ||||
|     reverse_key_mapping = getattr(model, "inverse_converters", {}) | ||||
|     original_state_dict = {} | ||||
|     for key, value in state_dict.items(): | ||||
|         for pattern, inverse_converter in reverse_key_mapping.items(): | ||||
|             # TODO FIXME you name it | ||||
|             replacement = inverse_converter.lstrip("^")  # strip off un-needed chars and patterns | ||||
|             replacement = re.sub(r"\(.*\)", "", replacement) | ||||
|             key, n_replace = re.subn(pattern, replacement, key) | ||||
|             # Early exit of the loop | ||||
|             if n_replace > 0: | ||||
|                 break | ||||
|         original_state_dict[key] = value | ||||
|     state_dict = original_state_dict | ||||
|     return state_dict | ||||
| @ -1635,12 +1635,7 @@ class GenerationMixin(ContinuousMixin): | ||||
|  | ||||
|         # TransformersKwargs are model-agnostic attention and generation arguments such as 'output_attentions' | ||||
|         for key, value in model_kwargs.items(): | ||||
|             if ( | ||||
|                 value is not None | ||||
|                 and key not in model_args | ||||
|                 and key not in TransformersKwargs.__optional_keys__ | ||||
|                 and key != "debug_io" | ||||
|             ): | ||||
|             if value is not None and key not in model_args and key not in TransformersKwargs.__optional_keys__: | ||||
|                 unused_model_args.append(key) | ||||
|  | ||||
|         if unused_model_args: | ||||
|  | ||||
| @ -512,8 +512,10 @@ def accelerate_disk_offload( | ||||
|     checkpoint_files, | ||||
|     device_map, | ||||
|     checkpoint_keys, | ||||
|     key_renaming_mapping, | ||||
|     sharded_metadata, | ||||
|     dtype, | ||||
|     reverse_key_renaming_mapping, | ||||
| ): | ||||
|     disk_only_shard_files = [] | ||||
|     if disk_offload_folder is not None: | ||||
| @ -532,13 +534,19 @@ def accelerate_disk_offload( | ||||
|             weight_map = dict.fromkeys(checkpoint_keys, checkpoint_files[0]) | ||||
|         else: | ||||
|             folder = os.path.sep.join(checkpoint_files[0].split(os.path.sep)[:-1]) | ||||
|             # Fix the weight map keys according to the key mapping | ||||
|             weight_map = { | ||||
|                 key_renaming_mapping[k]: v | ||||
|                 for k, v in sharded_metadata["weight_map"].items() | ||||
|                 if k in key_renaming_mapping | ||||
|             } | ||||
|             weight_map = {k: os.path.join(folder, v) for k, v in weight_map.items()} | ||||
|             # Find potential checkpoints containing only offloaded weights | ||||
|             disk_only_shard_files = get_disk_only_shard_files(device_map, weight_map) | ||||
|         disk_offload_index = { | ||||
|             name: { | ||||
|                 "safetensors_file": file, | ||||
|                 "weight_name": name, | ||||
|                 "weight_name": reverse_key_renaming_mapping[name], | ||||
|                 "dtype": str_dtype, | ||||
|             } | ||||
|             for name, file in weight_map.items() | ||||
|  | ||||
| @ -13,11 +13,8 @@ | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
|  | ||||
| import re | ||||
| from collections.abc import Sequence | ||||
| from typing import Any, Optional, Union | ||||
| from typing import Optional | ||||
|  | ||||
| from ..core_model_loading import ConversionOps | ||||
| from ..utils import is_accelerate_available, is_torch_accelerator_available, is_torch_available, logging | ||||
|  | ||||
|  | ||||
| @ -33,18 +30,6 @@ if is_accelerate_available(): | ||||
|  | ||||
|  | ||||
| logger = logging.get_logger(__name__) | ||||
| try: | ||||
|     _FP8_DTYPE = torch.float8_e4m3fn | ||||
|     _FP8_MIN = torch.finfo(_FP8_DTYPE).min | ||||
|     _FP8_MAX = torch.finfo(_FP8_DTYPE).max | ||||
|     _FP8_IS_INT = False | ||||
| except AttributeError: | ||||
|     _FP8_DTYPE = torch.int8 | ||||
|     _FP8_MIN, _FP8_MAX = -127, 127 | ||||
|     _FP8_IS_INT = True | ||||
|     logger.warning_once( | ||||
|         "torch.float8_e4m3fn not available; falling back to int8 emulation for Fp8Quantize operations." | ||||
|     ) | ||||
|  | ||||
|  | ||||
| # Copied from https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/main/inference/kernel.py | ||||
| @ -347,12 +332,6 @@ class FP8Linear(nn.Linear): | ||||
|         if self.weight.element_size() > 1: | ||||
|             return F.linear(input, self.weight, self.bias) | ||||
|         else: | ||||
|             if isinstance(self.weight, torch.distributed.tensor.DTensor): | ||||
|                 weight = self.weight._local_tensor.contiguous() | ||||
|                 scale_inv = self.weight_scale_inv._local_tensor.contiguous() | ||||
|             else: | ||||
|                 weight = self.weight | ||||
|                 scale_inv = self.weight_scale_inv | ||||
|             # Context manager used to switch among the available accelerators | ||||
|             device_type = torch.accelerator.current_accelerator().type if is_torch_accelerator_available() else "cuda" | ||||
|             torch_accelerator_module = getattr(torch, device_type, torch.cuda) | ||||
| @ -360,9 +339,9 @@ class FP8Linear(nn.Linear): | ||||
|                 qinput, scale = act_quant(input, self.block_size[1]) | ||||
|                 output = w8a8_block_fp8_matmul_triton( | ||||
|                     qinput, | ||||
|                     weight, | ||||
|                     self.weight, | ||||
|                     scale, | ||||
|                     scale_inv, | ||||
|                     self.weight_scale_inv, | ||||
|                     self.block_size, | ||||
|                     output_dtype=input.dtype, | ||||
|                 ) | ||||
| @ -374,120 +353,6 @@ class FP8Linear(nn.Linear): | ||||
|             return output.to(dtype=input.dtype) | ||||
|  | ||||
|  | ||||
| def _ceil_div(a, b): | ||||
|     return (a + b - 1) // b | ||||
|  | ||||
|  | ||||
| class FP8Expert(nn.Module): | ||||
|     dtype = torch.float8_e4m3fn | ||||
|  | ||||
|     def __init__(self, config, block_size, device): | ||||
|         super().__init__() | ||||
|  | ||||
|         from ..activations import ACT2FN | ||||
|  | ||||
|         self.block_size = block_size | ||||
|         self.num_experts = config.num_local_experts | ||||
|         self.hidden_dim = config.hidden_size | ||||
|         self.intermediate_dim = config.intermediate_size | ||||
|  | ||||
|         Wg_out, Wg_in = 2 * self.intermediate_dim, self.hidden_dim | ||||
|         Wd_out, Wd_in = self.hidden_dim, self.intermediate_dim | ||||
|  | ||||
|         self.gate_up_proj = nn.Parameter( | ||||
|             torch.empty(self.num_experts, Wg_out, Wg_in, dtype=FP8Expert.dtype, device=device) | ||||
|         ) | ||||
|         self.down_proj = nn.Parameter( | ||||
|             torch.empty(self.num_experts, Wd_out, Wd_in, dtype=FP8Expert.dtype, device=device) | ||||
|         ) | ||||
|  | ||||
|         # Create inverse scale tiles only when using 1-byte types (fp8) | ||||
|         if self.gate_up_proj.element_size() == 1: | ||||
|             bo, bi = self.block_size | ||||
|  | ||||
|             # gate_up tiles: ceil(Wg_out/bo) x ceil(Wg_in/bi) | ||||
|             gu_scale_o = _ceil_div(Wg_out, bo) | ||||
|             gu_scale_i = _ceil_div(Wg_in, bi) | ||||
|             self.gate_up_proj_scales_inv = nn.Parameter( | ||||
|                 torch.empty(self.num_experts, gu_scale_o, gu_scale_i, dtype=torch.float32, device=device) | ||||
|             ) | ||||
|  | ||||
|             # down tiles: ceil(Wd_out/bo) x ceil(Wd_in/bi) | ||||
|             dp_scale_o = _ceil_div(Wd_out, bo) | ||||
|             dp_scale_i = _ceil_div(Wd_in, bi) | ||||
|             self.down_proj_scales_inv = nn.Parameter( | ||||
|                 torch.empty(self.num_experts, dp_scale_o, dp_scale_i, dtype=torch.float32, device=device) | ||||
|             ) | ||||
|         else: | ||||
|             # Match FP8Linear behavior when not using 1-byte weights | ||||
|             self.register_parameter("gate_up_proj_scale_inv", None) | ||||
|             self.register_parameter("down_proj_scale_inv", None) | ||||
|  | ||||
|         # (Optional) bias per projection: many MoEs omit bias; keep None to match the FP8Linear default | ||||
|         self.register_parameter("gate_up_bias", None) | ||||
|         self.register_parameter("down_bias", None) | ||||
|  | ||||
|         # Activation used in the MLP (resolved from the config via ACT2FN) | ||||
|         # Keep a handle here; actual usage happens in the MoE block's forward | ||||
|         self.act_fn = ACT2FN[config.hidden_act] | ||||
|  | ||||
|     def forward( | ||||
|         self, | ||||
|         hidden_states: torch.Tensor, | ||||
|         top_k_index: torch.Tensor, | ||||
|         top_k_weights: torch.Tensor, | ||||
|     ) -> torch.Tensor: | ||||
|         final_hidden_states = torch.zeros_like(hidden_states) | ||||
|  | ||||
|         num_experts = top_k_weights.shape[1] | ||||
|         expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1).permute(2, 1, 0) | ||||
|         expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() | ||||
|         for expert_idx in expert_hit: | ||||
|             expert_idx = expert_idx[0] | ||||
|             if expert_idx == num_experts: | ||||
|                 continue | ||||
|             with torch.no_grad(): | ||||
|                 _, token_idx = torch.where(expert_mask[expert_idx]) | ||||
|             # current_state = hidden_states[token_idx] | ||||
|             current_state = hidden_states.index_select(0, token_idx) | ||||
|             gate, up = self.linear( | ||||
|                 current_state, self.gate_up_proj[expert_idx], self.gate_up_proj_scales_inv[expert_idx] | ||||
|             ).chunk(2, dim=-1) | ||||
|             current_hidden_states = self.act_fn(gate) * up | ||||
|             current_hidden_states = self.linear( | ||||
|                 current_hidden_states, self.down_proj[expert_idx], self.down_proj_scales_inv[expert_idx] | ||||
|             ) | ||||
|  | ||||
|             routing_weights = top_k_weights[token_idx, expert_idx].unsqueeze(-1) | ||||
|             current_hidden_states = current_hidden_states * routing_weights.to(current_hidden_states.dtype) | ||||
|             final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype)) | ||||
|  | ||||
|         return final_hidden_states | ||||
|  | ||||
|     def linear(self, input: torch.Tensor, weight: torch.Tensor, weight_scale_inv: torch.Tensor) -> torch.Tensor: | ||||
|         if weight.element_size() > 1: | ||||
|             return F.linear(input, weight, self.bias) | ||||
|         else: | ||||
|             # Context manager used to switch among the available accelerators | ||||
|             device_type = torch.accelerator.current_accelerator().type if is_torch_accelerator_available() else "cuda" | ||||
|             torch_accelerator_module = getattr(torch, device_type, torch.cuda) | ||||
|             with torch_accelerator_module.device(input.device): | ||||
|                 qinput, scale = act_quant(input, self.block_size[1]) | ||||
|                 output = w8a8_block_fp8_matmul_triton( | ||||
|                     qinput, | ||||
|                     weight, | ||||
|                     scale, | ||||
|                     weight_scale_inv, | ||||
|                     self.block_size, | ||||
|                     output_dtype=input.dtype, | ||||
|                 ) | ||||
|             # Blocks the CPU until all accelerator operations on the specified device are complete. It is used to ensure that the results of the | ||||
|             # preceding operations are ready before proceeding | ||||
|             torch_accelerator_module.synchronize() | ||||
|             return output.to(dtype=input.dtype) | ||||
|  | ||||
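
The routing loop in `forward` relies on a one-hot mask to find which tokens hit each expert; a simplified standalone sketch of that idiom (shapes are illustrative and the extra padding class used above is omitted):

```python
import torch

num_tokens, top_k, num_experts = 5, 2, 4
top_k_index = torch.randint(0, num_experts, (num_tokens, top_k))  # chosen experts per token

# mask[expert, k, token] == 1 when `token` routed its k-th slot to `expert`
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts).permute(2, 1, 0)

for expert_idx in range(num_experts):
    _, token_idx = torch.where(expert_mask[expert_idx])
    # token_idx lists every token that selected this expert (possibly empty)
```
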
|  | ||||
| # TODO: we do need this.... but not recursive... | ||||
| def _replace_with_fp8_linear( | ||||
|     model, | ||||
|     tp_plan=None, | ||||
| @ -496,48 +361,40 @@ def _replace_with_fp8_linear( | ||||
|     quantization_config=None, | ||||
|     has_been_replaced=False, | ||||
| ): | ||||
|     iterator = list(model.named_parameters()).copy() | ||||
|     for name, empty_tensor in iterator: | ||||
|         current_key_name = name | ||||
|         name = name.rsplit(".", 1)[0] if "." in name else name | ||||
|         module = model.get_submodule(name) | ||||
|     """Replace Linear layers with FP8Linear.""" | ||||
|     if current_key_name is None: | ||||
|         current_key_name = [] | ||||
|  | ||||
|         current_key_name_str = re.sub(r"\d+", "*", current_key_name) | ||||
|         if not any(key in current_key_name_str for key in (modules_to_not_convert or [])): | ||||
|             with init_empty_weights(): | ||||
|                 if ( | ||||
|                     "gate_up_proj" in current_key_name | ||||
|                     or "down_proj" in current_key_name | ||||
|                     and "experts" in current_key_name | ||||
|                 ):  # Experts! | ||||
|                     in_features = empty_tensor.size(-2) | ||||
|                     out_features = empty_tensor.size(-1) | ||||
|                     model.set_submodule( | ||||
|                         name, | ||||
|                         FP8Expert( | ||||
|                             config=model.config, | ||||
|                             block_size=quantization_config.weight_block_size, | ||||
|                             device=empty_tensor.device, | ||||
|                         ), | ||||
|                     ) | ||||
|     for name, module in model.named_children(): | ||||
|         current_key_name.append(name) | ||||
|  | ||||
|                 elif isinstance(module, nn.Linear): | ||||
|                     in_features = module.in_features | ||||
|                     out_features = module.out_features | ||||
|                     model.set_submodule( | ||||
|                         name, | ||||
|                         FP8Linear( | ||||
|                             in_features=in_features, | ||||
|                             out_features=out_features, | ||||
|                             bias=module.bias is not None, | ||||
|                             device=module.weight.device, | ||||
|                             dtype=module.weight.dtype, | ||||
|                             activation_scheme=quantization_config.activation_scheme, | ||||
|                             block_size=quantization_config.weight_block_size, | ||||
|                         ), | ||||
|         if isinstance(module, nn.Linear) and name not in (modules_to_not_convert or []): | ||||
|             current_key_name_str = ".".join(current_key_name) | ||||
|             if not any(key in current_key_name_str for key in (modules_to_not_convert or [])): | ||||
|                 with init_empty_weights(): | ||||
|                     model._modules[name] = FP8Linear( | ||||
|                         in_features=module.in_features, | ||||
|                         out_features=module.out_features, | ||||
|                         bias=module.bias is not None, | ||||
|                         device=module.weight.device, | ||||
|                         dtype=module.weight.dtype, | ||||
|                         activation_scheme=quantization_config.activation_scheme, | ||||
|                         block_size=quantization_config.weight_block_size, | ||||
|                     ) | ||||
|                 has_been_replaced = True | ||||
|         # when changing a layer the TP PLAN for that layer should be updated. TODO | ||||
|                     has_been_replaced = True | ||||
|             # when changing a layer the TP PLAN for that layer should be updated. TODO | ||||
|  | ||||
|         if len(list(module.children())) > 0: | ||||
|             _, has_been_replaced = _replace_with_fp8_linear( | ||||
|                 module, | ||||
|                 tp_plan, | ||||
|                 modules_to_not_convert, | ||||
|                 current_key_name, | ||||
|                 quantization_config, | ||||
|                 has_been_replaced=has_been_replaced, | ||||
|             ) | ||||
|  | ||||
|         current_key_name.pop(-1) | ||||
|  | ||||
|     return model, has_been_replaced | ||||
|  | ||||
| @ -548,7 +405,7 @@ def replace_with_fp8_linear( | ||||
|     quantization_config=None, | ||||
| ): | ||||
|     """Helper function to replace model layers with FP8 versions.""" | ||||
|     modules_to_not_convert += ["lm_head"] | ||||
|     modules_to_not_convert = ["lm_head"] if modules_to_not_convert is None else modules_to_not_convert | ||||
|  | ||||
|     if quantization_config.modules_to_not_convert is not None: | ||||
|         modules_to_not_convert.extend(quantization_config.modules_to_not_convert) | ||||
| @ -567,133 +424,3 @@ def replace_with_fp8_linear( | ||||
|         ) | ||||
|  | ||||
|     return model | ||||
|  | ||||
|  | ||||
| class QuantizationOp(ConversionOps): | ||||
|     """Base class for quantization operations.""" | ||||
|  | ||||
|     pass | ||||
|  | ||||
|  | ||||
| class Fp8Quantize(QuantizationOp): | ||||
|     """ | ||||
|     A quantization operation that turns a single weight into two tensors: the quantized weight and its per-block inverse scale. | ||||
|     """ | ||||
|  | ||||
|     _inverse_op: type[ConversionOps] | ||||
|  | ||||
|     def __init__(self, block_size: Optional[tuple[int, int]] = None): | ||||
|         self.block_size = block_size | ||||
|         self._inverse_op = Fp8Dequantize | ||||
|  | ||||
|     def convert(self, input_dict: torch.Tensor, *, quant_config: dict[str, Any]) -> dict[str, torch.Tensor]: | ||||
|         # Unpack single key/value (value may be wrapped in a list) | ||||
|         target_keys, value = tuple(input_dict.items())[0] | ||||
|         value = value[0] if isinstance(value, list) else value | ||||
|  | ||||
|         # Resolve block size (support dict-like or attr-like quant_config) | ||||
|         block_size = None | ||||
|         if quant_config is not None: | ||||
|             if isinstance(quant_config, dict): | ||||
|                 block_size = quant_config.get("weight_block_size") | ||||
|             else: | ||||
|                 block_size = getattr(quant_config, "weight_block_size", None) | ||||
|         if block_size is None: | ||||
|             block_size = (value.shape[-2], value.shape[-1]) | ||||
|  | ||||
|         block_m, block_n = block_size | ||||
|         rows, cols = value.shape[-2], value.shape[-1] | ||||
|  | ||||
|         # Enforce exact tiling: both dimensions must be multiples of the block size | ||||
|         if rows % block_m != 0 or cols % block_n != 0: | ||||
|             raise ValueError( | ||||
|                 f"Matrix dimensions ({rows}, {cols}) must be divisible by block sizes ({block_m}, {block_n}). for {target_keys}" | ||||
|             ) | ||||
|  | ||||
|         # Leading dims can be empty (2D) or include num_experts/... (3D+) | ||||
|         leading_shape = value.shape[:-2] | ||||
|         rows_tiles = rows // block_m | ||||
|         cols_tiles = cols // block_n | ||||
|  | ||||
|         original_shape = value.shape | ||||
|         value_fp32 = value.to(torch.float32) | ||||
|  | ||||
|         # Reshape to (..., rows_tiles, block_m, cols_tiles, block_n) | ||||
|         reshaped = value_fp32.reshape(*leading_shape, rows_tiles, block_m, cols_tiles, block_n) | ||||
|  | ||||
|         # Per-tile max-abs over the block dims | ||||
|         # dims: block_m is at -3, block_n is at -1 after the reshape | ||||
|         max_abs = reshaped.abs().amax(dim=(-3, -1)) | ||||
|         safe_max_abs = torch.where(max_abs > 0, max_abs, torch.ones_like(max_abs)) | ||||
|  | ||||
|         # Tile scale (we store the inverse scale, matching FP8Linear's weight_scale_inv) | ||||
|         scales = _FP8_MAX / safe_max_abs | ||||
|         scales = torch.where(max_abs > 0, scales, torch.ones_like(scales))  # keep zeros stable | ||||
|  | ||||
|         # Broadcast scales back over the block dims and quantize | ||||
|         # max_abs/scales shape: (..., rows_tiles, cols_tiles) | ||||
|         scales_broadcast = scales.unsqueeze(-1).unsqueeze(-3)  # -> (..., rows_tiles, 1, cols_tiles, 1) | ||||
|         scaled = reshaped * scales_broadcast | ||||
|  | ||||
|         if _FP8_IS_INT: | ||||
|             quantized = torch.clamp(scaled.round(), min=_FP8_MIN, max=_FP8_MAX).to(_FP8_DTYPE) | ||||
|         else: | ||||
|             quantized = torch.clamp(scaled, min=_FP8_MIN, max=_FP8_MAX).to(_FP8_DTYPE) | ||||
|  | ||||
|         quantized = quantized.reshape(original_shape) | ||||
|  | ||||
|         inv_scales = (1.0 / scales).to(torch.float32)  # shape: (*leading, rows_tiles, cols_tiles) | ||||
|         if target_keys.endswith("weight"): | ||||
|             scale_key = target_keys.rsplit(".", 1)[0] + ".weight_scale_inv" | ||||
|         else: | ||||
|             scale_key = target_keys + "_scales_inv" | ||||
|  | ||||
|         # Return both quantized weights and per-tile inverse scales (keeps leading dims, e.g., num_experts) | ||||
|         return { | ||||
|             target_keys: quantized, | ||||
|             scale_key: inv_scales, | ||||
|         } | ||||
|  | ||||
|  | ||||
| class Fp8Dequantize(QuantizationOp): | ||||
|     """Inverse operation of :class:`Fp8Quantize`. Takes a pair (weight, scale) and reconstructs the fp32 tensor.""" | ||||
|  | ||||
|     def __init__(self, block_size: Optional[tuple[int, int]] = None): | ||||
|         self.block_size = block_size | ||||
|         self._inverse_op = Fp8Quantize | ||||
|  | ||||
|     def convert( | ||||
|         self, | ||||
|         value: Union[Sequence[torch.Tensor], dict[str, torch.Tensor]], | ||||
|         *, | ||||
|         context: dict[str, Any], | ||||
|     ) -> torch.Tensor: | ||||
|         if isinstance(value, dict): | ||||
|             tensors = list(value.values()) | ||||
|         else: | ||||
|             tensors = list(value) if isinstance(value, Sequence) else [value] | ||||
|         if len(tensors) != 2: | ||||
|             raise ValueError("Fp8Dequantize expects exactly two tensors: quantized weights and scales.") | ||||
|         quantized, scales = tensors | ||||
|         if not isinstance(quantized, torch.Tensor) or not isinstance(scales, torch.Tensor): | ||||
|             raise TypeError("Fp8Dequantize expects tensors as inputs.") | ||||
|  | ||||
|         quantized_fp32 = quantized.to(torch.float32) | ||||
|         rows, cols = quantized_fp32.shape[-2:] | ||||
|         block_size = self.block_size | ||||
|         if block_size is None: | ||||
|             quant_config = context.get("quantization_config") | ||||
|             block_size = getattr(quant_config, "weight_block_size", None) | ||||
|         if block_size is None: | ||||
|             block_size = (rows, cols) | ||||
|         block_m, block_n = block_size | ||||
|         if rows % block_m != 0 or cols % block_n != 0: | ||||
|             raise ValueError( | ||||
|                 f"Matrix dimensions ({rows}, {cols}) must be divisible by block sizes ({block_m}, {block_n})." | ||||
|             ) | ||||
|  | ||||
|         reshaped = quantized_fp32.reshape(-1, rows // block_m, block_m, cols // block_n, block_n) | ||||
|         expanded_scales = scales.to(torch.float32).reshape(-1, rows // block_m, cols // block_n) | ||||
|         expanded_scales = expanded_scales.unsqueeze(-1).unsqueeze(2) | ||||
|         dequantized = reshaped * expanded_scales | ||||
|         return dequantized.reshape(quantized_fp32.shape) | ||||
|  | ||||
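For reference, the two operations above are inverses of each other: `Fp8Quantize` stores per-tile inverse scales next to the quantized weights, and `Fp8Dequantize` multiplies them back in. Below is a minimal standalone sketch of the same block-wise scaling math, assuming a recent PyTorch with `torch.float8_e4m3fn` support and a 128x128 block size; the function names are illustrative, not the module's API.

```python
import torch

# Illustrative sketch of block-wise FP8 scaling; assumes dimensions divisible by the block size.
FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # ~448 for e4m3fn

def blockwise_quantize(w: torch.Tensor, block=(128, 128)):
    bm, bn = block
    rows, cols = w.shape[-2:]
    tiles = w.float().reshape(*w.shape[:-2], rows // bm, bm, cols // bn, bn)
    max_abs = tiles.abs().amax(dim=(-3, -1)).clamp(min=1e-12)
    scale = FP8_MAX / max_abs                                    # per-tile scale
    q = (tiles * scale.unsqueeze(-1).unsqueeze(-3)).clamp(-FP8_MAX, FP8_MAX)
    return q.reshape(w.shape).to(torch.float8_e4m3fn), 1.0 / scale  # quantized weight + inverse scales

def blockwise_dequantize(q: torch.Tensor, inv_scale: torch.Tensor, block=(128, 128)):
    bm, bn = block
    rows, cols = q.shape[-2:]
    tiles = q.float().reshape(*q.shape[:-2], rows // bm, bm, cols // bn, bn)
    return (tiles * inv_scale.unsqueeze(-1).unsqueeze(-3)).reshape(q.shape)

w = torch.randn(256, 512)
q, inv_s = blockwise_quantize(w)
print((w - blockwise_dequantize(q, inv_s)).abs().max())  # small reconstruction error
```

The per-tile max-abs scaling keeps one outlier from collapsing the dynamic range of the whole matrix, which is the point of storing scales at block granularity rather than per tensor.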
| @ -236,7 +236,7 @@ class PeftAdapterMixin: | ||||
|                 **adapter_kwargs, | ||||
|             ) | ||||
|             peft_config.inference_mode = not is_trainable | ||||
|         # TODO: WE NEED TO APPLY OUR DYNAMIC WEIGHT CONVERSION AT SOME POINT HERE! | ||||
|  | ||||
|         # Create and add fresh new adapters into the model. | ||||
|         inject_adapter_in_model(peft_config, self, adapter_name, **peft_load_kwargs) | ||||
|  | ||||
|  | ||||
| @ -18,7 +18,6 @@ import operator | ||||
| import os | ||||
| import re | ||||
| from functools import partial, reduce | ||||
| from typing import Optional | ||||
|  | ||||
| import torch | ||||
| import torch.distributed as dist | ||||
| @ -307,7 +306,7 @@ def repack_weights( | ||||
|     return final_ordered_tensor | ||||
|  | ||||
|  | ||||
| def get_tensor_shard(param, empty_param, device_mesh, rank, dim, tensor_idx: Optional[int] = None): | ||||
| def get_tensor_shard(param, empty_param, device_mesh, rank, dim): | ||||
|     """ | ||||
|     Generalized tensor sharding across a multi-dimensional device mesh. | ||||
|     Extract only the fraction of the parameter owned by the given `rank` when the parameter is sharded along the provided `dim`. | ||||
| @ -359,57 +358,32 @@ def get_tensor_shard(param, empty_param, device_mesh, rank, dim, tensor_idx: Opt | ||||
|         rank (int): Global rank of the current process/device. | ||||
|         dim (int): Dimension along which to shard the tensor. | ||||
|     """ | ||||
|     param_dim = empty_param.ndim | ||||
|     param_dim = empty_param.dim() | ||||
|  | ||||
|     if dim < 0: | ||||
|         dim = param_dim + dim | ||||
|     if dim >= param_dim: | ||||
|         raise ValueError(f"dim {dim} is out of bounds for tensor of dimension {param_dim}") | ||||
|  | ||||
|     # Flatten the mesh to get the total number of devices | ||||
|     mesh_shape = device_mesh.shape | ||||
|     world_size = reduce(operator.mul, mesh_shape) | ||||
|     if dim < 0: | ||||
|         dim = param_dim + dim | ||||
|     if empty_param.dim() == 3 and dim == 1 and len(param.get_shape()) == 2: | ||||
|         dim = 0 | ||||
|     elif empty_param.dim() == 3 and dim == 2 and len(param.get_shape()) == 2: | ||||
|         dim = 0 | ||||
|  | ||||
|     shard_size = math.ceil(empty_param.size(dim) / world_size) | ||||
|     start = rank * shard_size | ||||
|     end = min(start + shard_size, empty_param.size(dim)) | ||||
|  | ||||
|     if dim >= param_dim: | ||||
|         raise ValueError(f"dim {dim} is out of bounds for tensor of dimension {param_dim}") | ||||
|  | ||||
|     if rank >= world_size: | ||||
|         raise ValueError(f"Rank {rank} is out of bounds for mesh size {world_size}") | ||||
|  | ||||
|     # We have the full tensor here, not one part of it. | ||||
|     # In that case, we assume the weight was saved properly: with TP, a colwise layer should not use this path | ||||
|     # but be marked packed_colwise instead, to signal that it reads from a packed tensor. That also covers the | ||||
|     # module-list case, as well as potential chunking / layer splits. | ||||
|     # The only "hard" case is when q, k, v are collected and merged into qkv: we still shard along dim=0, so | ||||
|     # nothing changes there. The remaining case is a 3D empty param with shard dim 0, where we place the whole | ||||
|     # tensor on a specific device (selected via the input tensor_idx). | ||||
|     dimensions = param.get_shape() | ||||
|     shard_size = math.ceil(empty_param.shape[dim] / world_size) | ||||
|     start = rank * shard_size | ||||
|  | ||||
|     if empty_param.dim() == 3 and dim == 0 and len(param.get_shape()) == 2: | ||||
|         # special case we don't "shard" just send this entire tensor to the correct rank. | ||||
|         if start <= tensor_idx < end: | ||||
|             # this tensor does need to be materialized on this device: | ||||
|             return param[:] | ||||
|         else: | ||||
|             return torch.empty([], dtype=torch.int64, device=rank) | ||||
|  | ||||
|     slice_indices = [slice(None)] * len(param.get_shape()) | ||||
|  | ||||
|     if start < param.get_shape()[dim]: | ||||
|     # Construct slicing index dynamically | ||||
|     end = min(start + shard_size, empty_param.shape[dim]) | ||||
|     slice_indices = [slice(None)] * param_dim | ||||
|     if start < empty_param.shape[dim]: | ||||
|         slice_indices[dim] = slice(start, end) | ||||
|         param = param[tuple(slice_indices)] | ||||
|         if isinstance(param, list):  # TODO handle the modulelist case! | ||||
|             param = [p[:] for p in param] | ||||
|         return param | ||||
|  | ||||
|         return param[tuple(slice_indices)] | ||||
|     dimensions = list(param.shape) | ||||
|     dimensions[dim] = 0 | ||||
|     return torch.empty(tuple(dimensions), dtype=torch.int64)  # note: torch.empty still allocates memory | ||||
|     return torch.empty(tuple(dimensions), dtype=torch.int64) | ||||
|  | ||||
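The index arithmetic used by `get_tensor_shard` boils down to giving each rank a contiguous chunk of `ceil(size / world_size)` elements along the sharded dimension (the last rank may receive a shorter, possibly empty, slice). A small sketch of that slicing, independent of the device mesh and with illustrative names:

```python
import math
import torch

def shard_slice(tensor: torch.Tensor, dim: int, rank: int, world_size: int) -> torch.Tensor:
    # Each rank owns a contiguous chunk of ceil(size / world_size) elements along `dim`.
    shard_size = math.ceil(tensor.size(dim) / world_size)
    start = rank * shard_size
    end = min(start + shard_size, tensor.size(dim))
    index = [slice(None)] * tensor.dim()
    index[dim] = slice(start, end)
    return tensor[tuple(index)]

weight = torch.arange(12).reshape(6, 2)
print(shard_slice(weight, dim=0, rank=1, world_size=4).shape)  # torch.Size([2, 2]) -> rows 2..3
```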
|  | ||||
| def distribute_module( | ||||
| @ -436,8 +410,6 @@ class TensorParallelLayer: | ||||
|     """ | ||||
|  | ||||
|     use_dtensor = True | ||||
|     device_mes = None | ||||
|     rank = None | ||||
|  | ||||
|     @staticmethod | ||||
|     def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh): ... | ||||
| @ -565,9 +537,6 @@ class ReplicateParallel(TensorParallelLayer): | ||||
|     def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh): | ||||
|         return outputs.to_local() if use_local_output and isinstance(outputs, DTensor) else outputs | ||||
|  | ||||
|     def shard_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh): | ||||
|         return param[...].to(param_casting_dtype) | ||||
|  | ||||
|     def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh): | ||||
|         param = param[...].to(param_casting_dtype) | ||||
|         if to_contiguous: | ||||
| @ -609,25 +578,17 @@ class ColwiseParallel(TensorParallelLayer): | ||||
|             input_tensor = input_tensor.redistribute(placements=desired_input_layouts, async_op=False) | ||||
|         return input_tensor | ||||
|  | ||||
|     def shard_tensor(self, param, empty_param, param_type=None, tensor_idx=None): | ||||
|         device_mesh = self.device_mesh | ||||
|         rank = self.rank | ||||
|         if param_type == "bias": | ||||
|             parameter = get_tensor_shard(param, empty_param, device_mesh, rank, -1, tensor_idx) | ||||
|             shard = [Shard(-1)] | ||||
|         else: | ||||
|             shard = [Shard(-2)] | ||||
|             parameter = get_tensor_shard(param, empty_param, device_mesh, rank, -2, tensor_idx) | ||||
|         self.shard = shard | ||||
|         return parameter, shard | ||||
|  | ||||
|     def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh): | ||||
|         # Colwise: shard the weight along Shard(-2) and the bias along Shard(-1) (both are dim 0 of their own shape). | ||||
|         # Since Linear computes input @ weight^T + bias, sharding the stored weight along dim -2 (out_features) | ||||
|         # corresponds to Shard(1) of the effective weight^T. | ||||
|         parameter, shard = self.shard_tensor( | ||||
|             param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh | ||||
|         ) | ||||
|         if param_type == "bias": | ||||
|             parameter = get_tensor_shard(param, empty_param, device_mesh, rank, -1) | ||||
|             shard = [Shard(-1)] | ||||
|         else: | ||||
|             shard = [Shard(-2)] | ||||
|             parameter = get_tensor_shard(param, empty_param, device_mesh, rank, -2) | ||||
|  | ||||
|         parameter = parameter.to(param_casting_dtype) | ||||
|         if to_contiguous: | ||||
|             parameter = parameter.contiguous() | ||||
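Concretely, an `nn.Linear` stores its weight as `(out_features, in_features)`, so `Shard(-2)` splits the output features (column-wise parallelism) while `Shard(-1)` splits the input features (row-wise parallelism, used further below). A toy sketch of what each rank ends up holding, assuming an even split across 2 ranks:

```python
import torch

out_features, in_features, world_size = 8, 4, 2
weight = torch.randn(out_features, in_features)  # nn.Linear stores weight as (out, in)

# Column-wise parallel: shard the output dimension (dim -2).
colwise_shards = weight.chunk(world_size, dim=-2)
print([tuple(s.shape) for s in colwise_shards])  # [(4, 4), (4, 4)]

# Row-wise parallel: shard the input dimension (dim -1); the bias stays replicated.
rowwise_shards = weight.chunk(world_size, dim=-1)
print([tuple(s.shape) for s in rowwise_shards])  # [(8, 2), (8, 2)]
```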
| @ -647,14 +608,6 @@ class ColwiseParallel(TensorParallelLayer): | ||||
|  | ||||
|  | ||||
| class PackedColwiseParallel(ColwiseParallel): | ||||
|     def shard_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh): | ||||
|         return get_packed_weights(param, empty_param, device_mesh, rank, -2), [Shard(-2)] | ||||
|  | ||||
|     def create_nn_parameter( | ||||
|         self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh | ||||
|     ): | ||||
|         return nn.Parameter(param, requires_grad=param.is_floating_point()) | ||||
|  | ||||
|     def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh): | ||||
|         # colwise shard weight/bias to Shard(0), weight be Shard(-2) (0 if you have 1 dim only) | ||||
|         # means Colwise as Linear is input * weight^T + bias, where | ||||
| @ -701,18 +654,6 @@ class RowwiseParallel(TensorParallelLayer): | ||||
|         self.use_local_output = use_local_output | ||||
|         self.use_dtensor = use_dtensor | ||||
|  | ||||
|     def shard_tensor(self, param, empty_param, param_type=None, tensor_idx=None): | ||||
|         device_mesh = self.device_mesh | ||||
|         rank = self.rank | ||||
|         if param_type == "bias": | ||||
|             shard = [Replicate()] | ||||
|             parameter = param[:] | ||||
|         else: | ||||
|             parameter = get_tensor_shard(param, empty_param, device_mesh, rank, -1, tensor_idx=tensor_idx) | ||||
|             shard = [Shard(-1)] | ||||
|         self.shard = shard | ||||
|         return parameter, shard | ||||
|  | ||||
|     def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh): | ||||
|         # Rowwise shard weight to Shard(1), bias to Replicate(), weight be Shard(1) | ||||
|         # means Rowwise as nn.Linear is input * weight^T + bias, where | ||||
| @ -784,9 +725,6 @@ class RowwiseParallel(TensorParallelLayer): | ||||
|  | ||||
|  | ||||
| class PackedRowwiseParallel(RowwiseParallel): | ||||
|     def shard_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh): | ||||
|         return get_packed_weights(param, empty_param, device_mesh, rank, -1), [Shard(-1)] | ||||
|  | ||||
|     def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh): | ||||
|         # colwise shard weight/bias to Shard(0), weight be Shard(-2) (0 if you have 1 dim only) | ||||
|         # means Colwise as Linear is input * weight^T + bias, where | ||||
| @ -979,9 +917,6 @@ class RouterParallel(TensorParallelLayer): | ||||
|         )  # masking class for one hot | ||||
|         return router_scores, router_indices | ||||
|  | ||||
|     def shard_tensor(self, param, *args, **kwargs): | ||||
|         return param[:], None | ||||
|  | ||||
|     def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh): | ||||
|         # TODO: i'd like for this to be the default | ||||
|         param = param[...].to(param_casting_dtype) | ||||
|  | ||||
| @ -26,7 +26,7 @@ import sys | ||||
| import warnings | ||||
| from abc import abstractmethod | ||||
| from collections import defaultdict | ||||
| from collections.abc import Callable, Sequence | ||||
| from collections.abc import Callable | ||||
| from concurrent.futures import ThreadPoolExecutor, as_completed | ||||
| from contextlib import contextmanager | ||||
| from enum import Enum | ||||
| @ -45,17 +45,17 @@ from torch.distributions import constraints | ||||
| from torch.utils.checkpoint import checkpoint | ||||
|  | ||||
| from .configuration_utils import PreTrainedConfig | ||||
| from .conversion_mapping import _checkpoint_conversion_mapping as DEFAULT_WEIGHT_CONVERSION_MAPPING | ||||
| from .core_model_loading import WeightConverter, convert_and_load_state_dict_in_model, revert_weight_conversion | ||||
| from .distributed import DistributedConfig | ||||
| from .dynamic_module_utils import custom_object_save | ||||
| from .generation import CompileConfig, GenerationConfig | ||||
| from .integrations import PeftAdapterMixin, deepspeed_config, is_deepspeed_zero3_enabled, is_fsdp_enabled | ||||
| from .integrations.accelerate import ( | ||||
|     _get_device_map, | ||||
|     accelerate_disk_offload, | ||||
|     accelerate_dispatch, | ||||
|     check_and_set_device_map, | ||||
|     expand_device_map, | ||||
|     find_tied_parameters, | ||||
|     init_empty_weights, | ||||
| ) | ||||
| from .integrations.deepspeed import _load_state_dict_into_zero3_model | ||||
| @ -122,7 +122,6 @@ from .utils.import_utils import ( | ||||
|     is_sagemaker_mp_enabled, | ||||
|     is_tracing, | ||||
| ) | ||||
| from .utils.loading_report import log_state_dict_report | ||||
| from .utils.quantization_config import QuantizationMethod | ||||
|  | ||||
|  | ||||
| @ -131,6 +130,7 @@ if is_accelerate_available(): | ||||
|     from accelerate.utils import ( | ||||
|         extract_model_from_parallel, | ||||
|         offload_weight, | ||||
|         save_offload_index, | ||||
|     ) | ||||
|     from accelerate.utils.modeling import get_state_dict_from_offload | ||||
|  | ||||
| @ -730,6 +730,25 @@ def load_shard_file(args): | ||||
|     # Fix the key names | ||||
|     state_dict = {key_renaming_mapping[k]: v for k, v in state_dict.items() if k in key_renaming_mapping} | ||||
|  | ||||
|     error_msgs = [] | ||||
|     if is_deepspeed_zero3_enabled() and not is_quantized: | ||||
|         error_msgs += _load_state_dict_into_zero3_model(model, state_dict) | ||||
|     # Skip it with fsdp on ranks other than 0 | ||||
|     elif not (is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized): | ||||
|         disk_offload_index = _load_state_dict_into_meta_model( | ||||
|             model, | ||||
|             state_dict, | ||||
|             shard_file, | ||||
|             reverse_key_renaming_mapping, | ||||
|             device_map=device_map, | ||||
|             disk_offload_folder=disk_offload_folder, | ||||
|             disk_offload_index=disk_offload_index, | ||||
|             hf_quantizer=hf_quantizer, | ||||
|             device_mesh=device_mesh, | ||||
|         ) | ||||
|  | ||||
|     return error_msgs, disk_offload_index | ||||
|  | ||||
|  | ||||
| def load_shard_files_with_threadpool(args_list): | ||||
|     num_workers = int(os.environ.get("HF_PARALLEL_LOADING_WORKERS", "8")) | ||||
| @ -1155,6 +1174,104 @@ def _get_dtype( | ||||
|     return config, dtype, dtype_orig | ||||
|  | ||||
|  | ||||
| def _find_missing_and_unexpected_keys( | ||||
|     model: "PreTrainedModel", | ||||
|     original_checkpoint_keys: list[str], | ||||
|     checkpoint_keys: list[str], | ||||
|     loading_base_model_from_task_state_dict: bool, | ||||
|     hf_quantizer: Optional[HfQuantizer], | ||||
| ) -> tuple[list[str], list[str]]: | ||||
|     """Find missing keys (keys that are part of the model parameters but were NOT found in the loaded state dict keys) and unexpected keys | ||||
|     (keys found in the loaded state dict keys, but that are NOT part of the model parameters) | ||||
|     """ | ||||
|     prefix = model.base_model_prefix | ||||
|  | ||||
|     # Compute expected keys, i.e. keys that the full model expects | ||||
|     expected_keys = list(model.state_dict().keys()) | ||||
|     if hf_quantizer is not None: | ||||
|         expected_keys = hf_quantizer.update_expected_keys(model, expected_keys, checkpoint_keys) | ||||
|  | ||||
|     # Adjust prefix of the keys to make them match loaded keys before removing them | ||||
|     missing_keys = sorted(set(expected_keys) - set(checkpoint_keys)) | ||||
|     unexpected_keys = set(checkpoint_keys) - set(expected_keys) | ||||
|     # If a module has the same name under the base and task specific model, we have to re-add it to unexpected keys | ||||
|     if loading_base_model_from_task_state_dict: | ||||
|         task_specific_keys = [k for k in original_checkpoint_keys if not k.startswith(f"{prefix}.")] | ||||
|         unexpected_keys.update(task_specific_keys) | ||||
|  | ||||
|     # Remove nonpersistent buffers from unexpected keys: they are not in the expected keys (model state dict), but | ||||
|     # may be in the loaded keys. Note that removing all buffers does the job, as they were part of the expected keys anyway | ||||
|     model_buffers = {n for n, _ in model.named_buffers()} | ||||
|     unexpected_keys = sorted(unexpected_keys - model_buffers) | ||||
|  | ||||
|     tied_params = find_tied_parameters(model) | ||||
|     for group in tied_params: | ||||
|         missing_in_group = [k for k in missing_keys if k in group] | ||||
|         if len(missing_in_group) > 0 and len(missing_in_group) < len(group): | ||||
|             missing_keys = [k for k in missing_keys if k not in missing_in_group] | ||||
|  | ||||
|     if hf_quantizer is not None: | ||||
|         missing_keys = hf_quantizer.update_missing_keys(model, missing_keys, prefix) | ||||
|         unexpected_keys = hf_quantizer.update_unexpected_keys(model, unexpected_keys) | ||||
|  | ||||
|     return missing_keys, unexpected_keys | ||||
|  | ||||
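The classification above is, at its core, set arithmetic between the keys the model defines and the keys the checkpoint provides, followed by the tied-parameter and quantizer adjustments. A toy illustration with hypothetical key names:

```python
expected_keys = {"embed.weight", "layer.0.weight", "layer.0.bias", "lm_head.weight"}
checkpoint_keys = {"embed.weight", "layer.0.weight", "layer.0.bias", "rotary.inv_freq"}

missing_keys = sorted(expected_keys - checkpoint_keys)      # in the model, not in the checkpoint
unexpected_keys = sorted(checkpoint_keys - expected_keys)   # in the checkpoint, not in the model

print(missing_keys)     # ['lm_head.weight']  -> may be fine if tied to embed.weight
print(unexpected_keys)  # ['rotary.inv_freq'] -> typically a non-persistent buffer
```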
|  | ||||
| def _find_mismatched_keys( | ||||
|     model: "PreTrainedModel", | ||||
|     state_dict: Optional[dict], | ||||
|     checkpoint_files: Optional[list[str]], | ||||
|     ignore_mismatched_sizes: bool, | ||||
|     keys_to_rename_mapping: dict[str, str], | ||||
|     is_quantized: bool, | ||||
|     weights_only: bool, | ||||
| ) -> tuple[list[str], list[tuple[int, int]]]: | ||||
|     """ | ||||
|     Find potential shape mismatch between the different state dicts and the model parameters, but only if `ignore_mismatched_sizes` | ||||
|     is True. Otherwise, return immediately and any shape mismatch that may exist will be raised later on. This avoids checking | ||||
|     every parameter in advance, as shape mismatch are extremely rare in practice. If we want to ignore them however, we do | ||||
|     need to check in advance as we need to know which parameters we need to move back from meta to cpu, and initialize | ||||
|     correctly. Indeed, as our model initialization takes place at the module level, and not the weight level, in the | ||||
|     case of a sharded checkpoint we cannot correctly initialize the weights according to `model._init_weights()` if we perform | ||||
|     this check on each state dict at loading time (after the first loaded checkpoint, there are no way to initialize only the | ||||
|     mismatched weights if any, without overwriting the previously loaded weights as well because all the module will be | ||||
|     initialized, not only the weights that are mismatched). | ||||
|     """ | ||||
|  | ||||
|     # An error will be raised later on anyway if there is a mismatch - this avoids running the rest of this function | ||||
|     # if there is no mismatch (which is almost always the case) | ||||
|     if not ignore_mismatched_sizes: | ||||
|         return [], [] | ||||
|  | ||||
|     if state_dict is not None: | ||||
|         checkpoint_files = [""] | ||||
|  | ||||
|     model_state_dict = model.state_dict() | ||||
|     mismatched_keys = [] | ||||
|     mismatched_shapes = [] | ||||
|     for shard_file in checkpoint_files: | ||||
|         # If shard_file is "", we use the existing state_dict instead of loading it | ||||
|         if shard_file != "": | ||||
|             state_dict = load_state_dict( | ||||
|                 shard_file, is_quantized=is_quantized, map_location="meta", weights_only=weights_only | ||||
|             ) | ||||
|  | ||||
|         # Fix the key names | ||||
|         new_state_dict = {keys_to_rename_mapping[k]: v for k, v in state_dict.items() if k in keys_to_rename_mapping} | ||||
|  | ||||
|         for key, tensor in new_state_dict.items(): | ||||
|             if key in model_state_dict and tensor.shape != model_state_dict[key].shape: | ||||
|                 # This skips size mismatches for 4-bit weights: two 4-bit values share an 8-bit container, causing size differences. | ||||
|                 # Without inspecting the module or parameter type, this is a practical heuristic to detect valid 4-bit weights. | ||||
|                 if not ( | ||||
|                     is_quantized and tensor.shape[-1] == 1 and tensor.numel() * 2 == model_state_dict[key].numel() | ||||
|                 ): | ||||
|                     mismatched_keys.append(key) | ||||
|                     mismatched_shapes.append((tensor.shape, model_state_dict[key].shape)) | ||||
|  | ||||
|     return mismatched_keys, mismatched_shapes | ||||
|  | ||||
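The 4-bit exception in the loop above relies on the packing used by bitsandbytes-style quantization, where two 4-bit values share one 8-bit container and the stored tensor is flattened into a single column with half as many elements. A sketch of that shape check in isolation (shapes are illustrative):

```python
import torch

model_param = torch.empty(64, 32)            # fp16/fp32 parameter as the model defines it
packed_4bit = torch.empty(64 * 32 // 2, 1)   # 4-bit storage: two values packed per byte, flattened to one column

is_valid_4bit_shape = packed_4bit.shape[-1] == 1 and packed_4bit.numel() * 2 == model_param.numel()
print(is_valid_4bit_shape)  # True -> not reported as a size mismatch when the model is quantized
```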
|  | ||||
| class PipelineParallel(Enum): | ||||
|     inputs = 0 | ||||
|     outputs = 1 | ||||
| @ -1560,8 +1677,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH | ||||
|     # to also prevent bfloat16 casting, use the _keep_in_fp32_modules_strict flag | ||||
|     _keep_in_fp32_modules_strict = None | ||||
|  | ||||
|     _dtype_per_modules: Optional[dict[str, torch.dtype]] = None | ||||
|  | ||||
|     # a list of `re` patterns of `state_dict` keys that should be removed from the list of missing | ||||
|     # keys we find (keys inside the model but not in the checkpoint) and avoid unnecessary warnings. | ||||
|     _keys_to_ignore_on_load_missing = None | ||||
| @ -1732,9 +1847,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH | ||||
|         self._keep_in_fp32_modules = copy.copy(self.__class__._keep_in_fp32_modules) | ||||
|         self._keep_in_fp32_modules_strict = copy.copy(self.__class__._keep_in_fp32_modules_strict) | ||||
|  | ||||
|         if isinstance(self._keep_in_fp32_modules, dict): | ||||
|             self._dtype_per_modules = dict.fromkeys(self._keep_in_fp32_modules.keys(), torch.float32) | ||||
|  | ||||
|         self._no_split_modules = self._no_split_modules or [] | ||||
|         _CAN_RECORD_REGISTRY[str(self.__class__)] = self._can_record_outputs  # added for executorch support only | ||||
|  | ||||
| @ -2525,34 +2637,29 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH | ||||
|             # 0.02 is the standard default value across the library | ||||
|             std = getattr(self.config.get_text_config(), "initializer_range", 0.02) | ||||
|  | ||||
|         try: | ||||
|             if isinstance( | ||||
|                 module, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.ConvTranspose1d, nn.ConvTranspose2d) | ||||
|             ): | ||||
|                 module.weight.data.normal_(mean=0.0, std=std) | ||||
|                 if module.bias is not None: | ||||
|                     module.bias.data.zero_() | ||||
|             elif isinstance(module, nn.Embedding): | ||||
|                 module.weight.data.normal_(mean=0.0, std=std) | ||||
|                 if module.padding_idx is not None: | ||||
|                     module.weight.data[module.padding_idx].zero_() | ||||
|             elif isinstance(module, nn.MultiheadAttention): | ||||
|                 # This uses torch's original init | ||||
|                 module._reset_parameters() | ||||
|             # We cannot use `isinstance` on the RMSNorms or LayerNorms, as they usually are custom modules which change names | ||||
|             # between modelings (because they are prefixed with the model name) | ||||
|             elif ( | ||||
|                 isinstance(module, (nn.GroupNorm, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)) | ||||
|                 or "LayerNorm" in module.__class__.__name__ | ||||
|                 or "RMSNorm" in module.__class__.__name__ | ||||
|             ): | ||||
|                 # Norms can exist without weights (in which case they are None from torch primitives) | ||||
|                 if hasattr(module, "weight") and module.weight is not None: | ||||
|                     module.weight.data.fill_(1.0) | ||||
|                 if hasattr(module, "bias") and module.bias is not None: | ||||
|                     module.bias.data.zero_() | ||||
|         except Exception as e: | ||||
|             logger.warning_once(f"Failed to init: {str(e)}") | ||||
|         if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.ConvTranspose1d, nn.ConvTranspose2d)): | ||||
|             module.weight.data.normal_(mean=0.0, std=std) | ||||
|             if module.bias is not None: | ||||
|                 module.bias.data.zero_() | ||||
|         elif isinstance(module, nn.Embedding): | ||||
|             module.weight.data.normal_(mean=0.0, std=std) | ||||
|             if module.padding_idx is not None: | ||||
|                 module.weight.data[module.padding_idx].zero_() | ||||
|         elif isinstance(module, nn.MultiheadAttention): | ||||
|             # This uses torch's original init | ||||
|             module._reset_parameters() | ||||
|         # We cannot use `isinstance` on the RMSNorms or LayerNorms, as they usually are custom modules which change names | ||||
|         # between modelings (because they are prefixed with the model name) | ||||
|         elif ( | ||||
|             isinstance(module, (nn.GroupNorm, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)) | ||||
|             or "LayerNorm" in module.__class__.__name__ | ||||
|             or "RMSNorm" in module.__class__.__name__ | ||||
|         ): | ||||
|             # Norms can exist without weights (in which case they are None from torch primitives) | ||||
|             if hasattr(module, "weight") and module.weight is not None: | ||||
|                 module.weight.data.fill_(1.0) | ||||
|             if hasattr(module, "bias") and module.bias is not None: | ||||
|                 module.bias.data.zero_() | ||||
|  | ||||
|     def _initialize_weights(self, module): | ||||
|         """ | ||||
| @ -3350,7 +3457,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH | ||||
|         variant: Optional[str] = None, | ||||
|         token: Optional[Union[str, bool]] = None, | ||||
|         save_peft_format: bool = True, | ||||
|         save_original_format: bool = False, | ||||
|         **kwargs, | ||||
|     ): | ||||
|         """ | ||||
| @ -3399,10 +3505,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH | ||||
|                 For backward compatibility with PEFT library, in case adapter weights are attached to the model, all | ||||
|                 keys of the state dict of adapters needs to be prepended with `base_model.model`. Advanced users can | ||||
|                 disable this behaviours by setting `save_peft_format` to `False`. | ||||
|             save_original_format (`bool`, *optional*, defaults to `False`): | ||||
|                 For backward compatibility with previous versions of `transformers`, you can save the checkpoint with | ||||
|                 its reverse mapping. The reverse mapping needs to exist even if the model was loaded from a non-legacy | ||||
|                 checkpoint. | ||||
|             kwargs (`dict[str, Any]`, *optional*): | ||||
|                 Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method. | ||||
|         """ | ||||
| @ -3542,18 +3644,24 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH | ||||
|                         module_map[name + f".{key}"] = module | ||||
|             state_dict = model_to_save.state_dict() | ||||
|  | ||||
|         if ( | ||||
|             any( | ||||
|                 allowed_name in class_name.__name__.lower() | ||||
|                 for class_name in self.__class__.__mro__[:-1] | ||||
|                 for allowed_name in VLMS | ||||
|             ) | ||||
|             or save_original_format | ||||
|         if any( | ||||
|             allowed_name in class_name.__name__.lower() | ||||
|             for class_name in self.__class__.__mro__[:-1] | ||||
|             for allowed_name in VLMS | ||||
|         ): | ||||
|             # MEGA BIG TODO HERE: self._conversion_ops needs to be used to save the final ckpt | ||||
|             # using what was loaded. Actually self._conversion_ops alone won't work, because we need the reverse | ||||
|             # mapping even when the files are not legacy -> in that case no conversion happened at load time. | ||||
|             state_dict = revert_weight_conversion(self, state_dict) | ||||
|             reverse_key_mapping = {v: k for k, v in self._checkpoint_conversion_mapping.items()} | ||||
|  | ||||
|             original_state_dict = {} | ||||
|             for key, value in state_dict.items(): | ||||
|                 for pattern, replacement in reverse_key_mapping.items(): | ||||
|                     replacement = replacement.lstrip("^")  # strip off un-needed chars and patterns | ||||
|                     replacement = re.sub(r"\(.*\)", "", replacement) | ||||
|                     key, n_replace = re.subn(pattern, replacement, key) | ||||
|                     # Early exit of the loop | ||||
|                     if n_replace > 0: | ||||
|                         break | ||||
|                 original_state_dict[key] = value | ||||
|             state_dict = original_state_dict | ||||
|  | ||||
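The loop above replays the load-time key renames backwards with `re.subn`, stopping at the first pattern that matches. A simplified sketch with a hypothetical reverse mapping (the real patterns come from `_checkpoint_conversion_mapping`, and the actual code additionally strips `^` anchors and capture groups from the replacement):

```python
import re

# Hypothetical reverse mapping: serialized (new) name pattern -> original (legacy) name.
reverse_key_mapping = {r"^model\.language_model\.": "model.", r"^model\.visual\.": "visual."}

state_dict_keys = ["model.language_model.layers.0.mlp.up_proj.weight", "model.visual.patch_embed.weight"]
original_keys = []
for key in state_dict_keys:
    for pattern, replacement in reverse_key_mapping.items():
        key, n_replace = re.subn(pattern, replacement, key)
        if n_replace > 0:
            break  # early exit once one pattern matched
    original_keys.append(key)

print(original_keys)  # ['model.layers.0.mlp.up_proj.weight', 'visual.patch_embed.weight']
```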
|         # Translate state_dict from smp to hf if saving with smp >= 1.10 | ||||
|         if IS_SAGEMAKER_MP_POST_1_10: | ||||
| @ -3721,8 +3829,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH | ||||
|  | ||||
|             if safe_serialization: | ||||
|                 # At some point we will need to deal better with save_function (used for TPU and other distributed | ||||
|                 # joyfulness), but for now this is enough. # TODO: we should definitely parallelize this, we are otherwise just waiting | ||||
|                 # too much before scheduling the next write when it's on a different | ||||
|                 # joyfulness), but for now this is enough. | ||||
|                 safe_save_file(shard, os.path.join(save_directory, shard_file), metadata=metadata) | ||||
|             else: | ||||
|                 save_function(shard, os.path.join(save_directory, shard_file)) | ||||
| @ -4180,7 +4287,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH | ||||
|         commit_hash = kwargs.pop("_commit_hash", None) | ||||
|         variant = kwargs.pop("variant", None) | ||||
|         adapter_kwargs = kwargs.pop("adapter_kwargs", {}) | ||||
|  | ||||
|         adapter_name = kwargs.pop("adapter_name", "default") | ||||
|         generation_config = kwargs.pop("generation_config", None) | ||||
|         gguf_file = kwargs.pop("gguf_file", None) | ||||
| @ -4279,7 +4385,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH | ||||
|             commit_hash = getattr(config, "_commit_hash", commit_hash) | ||||
|  | ||||
|         download_kwargs_with_commit["commit_hash"] = commit_hash | ||||
|         profile_weight_conversion = kwargs.pop("profile_weight_conversion", False) | ||||
|  | ||||
|         # Because some composite configs call super().__init__ before instantiating the sub-configs, we need this call | ||||
|         # to correctly redispatch recursively if the kwarg is provided | ||||
| @ -4290,11 +4395,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH | ||||
|             config, quantization_config, dtype, device_map, weights_only, user_agent | ||||
|         ) | ||||
|  | ||||
|         weight_conversions: Optional[list[WeightConverter]] = None | ||||
|         model_type = getattr(config, "model_type", None) | ||||
|         if model_type is not None: | ||||
|             weight_conversions = DEFAULT_WEIGHT_CONVERSION_MAPPING.get(model_type) | ||||
|  | ||||
|         if gguf_file: | ||||
|             if hf_quantizer is not None: | ||||
|                 raise ValueError( | ||||
| @ -4354,6 +4454,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH | ||||
|         model.upcast_modules_in_fp32(hf_quantizer, dtype) | ||||
|         # Make sure to tie the weights correctly | ||||
|         model.tie_weights() | ||||
|  | ||||
|         # make sure we use the model's config since the __init__ call might have copied it | ||||
|         config = model.config | ||||
|  | ||||
| @ -4393,8 +4494,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH | ||||
|             device_mesh=device_mesh, | ||||
|             key_mapping=key_mapping, | ||||
|             weights_only=weights_only, | ||||
|             weight_mapping=weight_conversions, | ||||
|             profile_weight_conversion=profile_weight_conversion, | ||||
|         ) | ||||
|  | ||||
|         model.tie_weights()  # make sure token embedding weights are still tied if needed | ||||
| @ -4415,7 +4514,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH | ||||
|             ) | ||||
|  | ||||
|         # for device_map="auto" : dispatch model with hooks on all devices if necessary (not needed with a tp_plan, so we skip it as it slightly | ||||
|         # harm performances). TODO: replace with native PP | ||||
|         # harm performances). | ||||
|         if device_map is not None and device_mesh is None: | ||||
|             accelerate_dispatch(model, hf_quantizer, device_map, offload_folder, offload_index, offload_buffers) | ||||
|  | ||||
| @ -4574,16 +4673,97 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH | ||||
|         device_mesh: Optional["torch.distributed.device_mesh.DeviceMesh"] = None, | ||||
|         key_mapping: Optional[dict[str, str]] = None, | ||||
|         weights_only: bool = True, | ||||
|         weight_mapping: Optional[Sequence[WeightConverter]] = None, | ||||
|         profile_weight_conversion: bool = False, | ||||
|     ): | ||||
|         # TODO: we should only be calling hf_quantizer.skip_placement or something like that | ||||
|         is_quantized = hf_quantizer is not None | ||||
|         is_hqq_or_quark = is_quantized and hf_quantizer.quantization_config.quant_method in { | ||||
|             QuantizationMethod.HQQ, | ||||
|             QuantizationMethod.QUARK, | ||||
|         } | ||||
|         # The model's definition arriving here is final (TP hooks added, quantized layers replaced) | ||||
|  | ||||
|         # Get all the keys of the state dicts that we have to initialize the model with | ||||
|         if sharded_metadata is not None: | ||||
|             original_checkpoint_keys = sharded_metadata["all_checkpoint_keys"] | ||||
|         elif state_dict is not None: | ||||
|             original_checkpoint_keys = list(state_dict.keys()) | ||||
|         else: | ||||
|             original_checkpoint_keys = list( | ||||
|                 load_state_dict(checkpoint_files[0], map_location="meta", weights_only=weights_only).keys() | ||||
|             ) | ||||
|  | ||||
|         # Check if we are in a special state, i.e. loading from a state dict coming from a different architecture | ||||
|         prefix = model.base_model_prefix | ||||
|         has_prefix_module = any(s.startswith(prefix) for s in original_checkpoint_keys) if len(prefix) > 0 else False | ||||
|         expects_prefix_module = hasattr(model, prefix) if len(prefix) > 0 else False | ||||
|         loading_task_model_from_base_state_dict = not has_prefix_module and expects_prefix_module | ||||
|         loading_base_model_from_task_state_dict = has_prefix_module and not expects_prefix_module | ||||
|  | ||||
|         # Find the key names that the model expects from the serialized keys | ||||
|         key_renaming_mapping = model._get_key_renaming_mapping( | ||||
|             original_checkpoint_keys, | ||||
|             key_mapping, | ||||
|             loading_base_model_from_task_state_dict, | ||||
|             loading_task_model_from_base_state_dict, | ||||
|         ) | ||||
|         checkpoint_keys = list(key_renaming_mapping.values()) | ||||
|  | ||||
|         # Find missing and unexpected keys from the state dict | ||||
|         missing_keys, unexpected_keys = _find_missing_and_unexpected_keys( | ||||
|             model, original_checkpoint_keys, checkpoint_keys, loading_base_model_from_task_state_dict, hf_quantizer | ||||
|         ) | ||||
|         # Find all the keys with shape mismatch (if we ignore the mismatch, the weights need to be newly initialized the | ||||
|         # same way as missing keys) | ||||
|         mismatched_keys, mismatched_shapes = _find_mismatched_keys( | ||||
|             model, | ||||
|             state_dict, | ||||
|             checkpoint_files, | ||||
|             ignore_mismatched_sizes, | ||||
|             key_renaming_mapping, | ||||
|             is_quantized, | ||||
|             weights_only, | ||||
|         ) | ||||
|  | ||||
|         # We need to update both the mapping and the list of checkpoint keys to remove the mismatched and unexpected ones | ||||
|         key_renaming_mapping = { | ||||
|             k: v for k, v in key_renaming_mapping.items() if v not in mismatched_keys and v not in unexpected_keys | ||||
|         } | ||||
|         checkpoint_keys = list(key_renaming_mapping.values()) | ||||
|  | ||||
|         # Move missing (and potentially mismatched) keys back to cpu from meta device (because they won't be moved when | ||||
|         # loading the weights as they are not in the loaded state dict) | ||||
|         model._move_missing_keys_from_meta_to_cpu(missing_keys + mismatched_keys, dtype, hf_quantizer) | ||||
|  | ||||
|         # correctly initialize the missing (and potentially mismatched) keys | ||||
|         model._initialize_missing_keys(missing_keys + mismatched_keys, is_quantized) | ||||
|  | ||||
|         # Get reverse key mapping | ||||
|         reverse_key_renaming_mapping = {v: k for k, v in key_renaming_mapping.items()} | ||||
|  | ||||
|         is_offloaded_safetensors = False | ||||
|         # This offload index is for params explicitly assigned to "disk" in the device_map | ||||
|         disk_offload_index = None | ||||
|         disk_only_shard_files = [] | ||||
|         # Prepare parameter offloading if needed | ||||
|         if device_map is not None and "disk" in device_map.values(): | ||||
|             disk_offload_index, disk_only_shard_files, is_offloaded_safetensors = accelerate_disk_offload( | ||||
|                 disk_offload_folder, | ||||
|                 checkpoint_files, | ||||
|                 device_map, | ||||
|                 checkpoint_keys, | ||||
|                 key_renaming_mapping, | ||||
|                 sharded_metadata, | ||||
|                 dtype, | ||||
|                 reverse_key_renaming_mapping, | ||||
|             ) | ||||
|         # To be able to iterate below, even if we don't use it when the state_dict is already provided | ||||
|         elif state_dict is not None: | ||||
|             checkpoint_files = [""] | ||||
|  | ||||
|         # Compute expected model keys | ||||
|         expected_keys = list(model.state_dict().keys()) | ||||
|         if hf_quantizer is not None: | ||||
|             expected_keys = hf_quantizer.update_expected_keys(model, expected_keys, checkpoint_keys) | ||||
|  | ||||
|         if logger.level >= logging.WARNING: | ||||
|             verify_tp_plan(expected_keys, getattr(model, "_tp_plan", None)) | ||||
|  | ||||
| @ -4592,79 +4772,46 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH | ||||
|             expanded_device_map = expand_device_map(device_map, expected_keys) | ||||
|             caching_allocator_warmup(model, expanded_device_map, hf_quantizer) | ||||
|  | ||||
|         # Now we read all the files to get a pointer to each physical weight | ||||
|         merged_state_dict = {} | ||||
|         all_pointer = set() | ||||
|  | ||||
|         if device_map is None: | ||||
|             device_map = {"": "cpu"} | ||||
|         keys = sorted(device_map.keys(), key=len, reverse=True) | ||||
|         tp_plan = getattr(model, "_tp_plan", None) | ||||
|         keep_in_dtype = None  # TODO use keep_in | ||||
|         error_msgs = [] | ||||
|         misc = {} | ||||
|  | ||||
|         if is_deepspeed_zero3_enabled() and not is_quantized: | ||||
|             error_msgs += _load_state_dict_into_zero3_model(model, state_dict) | ||||
|         else: | ||||
|             if checkpoint_files is not None: | ||||
|                 pattern = re.compile(r"(" + "|".join(map(re.escape, keys)) + r")") | ||||
|                 if sharded_metadata is None: | ||||
|                     k_v_iterator = dict.fromkeys( | ||||
|                         safe_open(checkpoint_files[0], framework="pt").keys(), "model.safetensors" | ||||
|                     ).items() | ||||
|                 else: | ||||
|                     k_v_iterator = sharded_metadata["weight_map"].items() | ||||
|  | ||||
|                 for k, v in k_v_iterator: | ||||
|                     key = pattern.match(k).group(1) | ||||
|                     if key is not None and key != "": | ||||
|                         device = device_map[key] | ||||
|                     else: | ||||
|                         device = device_map[""] | ||||
|                         if isinstance(device, torch.device): | ||||
|                             device = device.index  # safetensors only | ||||
|                     file_pointer = safe_open( | ||||
|                         os.path.join(checkpoint_files[0].rsplit("/", 1)[0], v), framework="pt", device=device | ||||
|                     ) | ||||
|                     all_pointer.add(file_pointer) | ||||
|                     merged_state_dict[k] = (v, file_pointer.get_slice(k))  # don't materialize yet | ||||
|             elif state_dict is not None: | ||||
|                 merged_state_dict = {k: ("", v) for k, v in state_dict.items()} | ||||
|             else: | ||||
|                 raise ValueError("Neither a state dict nor checkpoint files were found.") | ||||
|  | ||||
|             missing_keys, unexpected_keys, mismatched_keys, misc = convert_and_load_state_dict_in_model( | ||||
|                 model, | ||||
|                 merged_state_dict, | ||||
|                 weight_mapping, | ||||
|                 tp_plan, | ||||
|                 hf_quantizer, | ||||
|         # Prepare and normalize arguments for serial and parallel shard loading | ||||
|         args_list = [ | ||||
|             ( | ||||
|                 shard_file, | ||||
|                 state_dict, | ||||
|                 disk_only_shard_files, | ||||
|                 is_quantized, | ||||
|                 device_map, | ||||
|                 keep_in_dtype, | ||||
|                 device_mesh=device_mesh, | ||||
|                 profile=profile_weight_conversion, | ||||
|                 hf_quantizer, | ||||
|                 key_renaming_mapping, | ||||
|                 weights_only, | ||||
|                 model, | ||||
|                 reverse_key_renaming_mapping, | ||||
|                 disk_offload_folder, | ||||
|                 disk_offload_index, | ||||
|                 device_mesh, | ||||
|             ) | ||||
|             for shard_file in checkpoint_files | ||||
|         ] | ||||
|  | ||||
|         for k in all_pointer:  # finally close all opened file pointers | ||||
|             k.__exit__(None, None, None) | ||||
|         error_msgs = [] | ||||
|  | ||||
|         new_state_dict = model.state_dict() | ||||
|         if ( | ||||
|             os.environ.get("HF_ENABLE_PARALLEL_LOADING", "").upper() in ENV_VARS_TRUE_VALUES | ||||
|             and not is_deepspeed_zero3_enabled() | ||||
|         ): | ||||
|             _error_msgs, disk_offload_index = load_shard_files_with_threadpool(args_list) | ||||
|             error_msgs += _error_msgs | ||||
|         else: | ||||
|             if len(args_list) > 1: | ||||
|                 args_list = logging.tqdm(args_list, desc="Loading checkpoint shards") | ||||
|  | ||||
|         # Post-processing | ||||
|         # Check if we are in a special state, i.e. loading from a state dict coming from a different architecture | ||||
|         prefix = model.base_model_prefix | ||||
|         has_prefix_module = any(s.startswith(prefix) for s in new_state_dict.keys()) if len(prefix) > 0 else False | ||||
|         expects_prefix_module = hasattr(model, prefix) if len(prefix) > 0 else False | ||||
|         loading_task_model_from_base_state_dict = not has_prefix_module and expects_prefix_module | ||||
|             for args in args_list: | ||||
|                 _error_msgs, disk_offload_index = load_shard_file(args) | ||||
|                 error_msgs += _error_msgs | ||||
|  | ||||
|         # Move missing (and potentially mismatched) keys back to cpu from meta device (because they won't be moved when | ||||
|         # loading the weights as they are not in the loaded state dict) | ||||
|         miss_and_mismatched = missing_keys | {k[0] for k in mismatched_keys} | ||||
|         model._move_missing_keys_from_meta_to_cpu(miss_and_mismatched, dtype, hf_quantizer) | ||||
|  | ||||
|         # correctly initialize the missing (and potentially mismatched) keys | ||||
|         model._initialize_missing_keys(miss_and_mismatched, is_quantized) | ||||
|         # Save offloaded index if needed | ||||
|         if disk_offload_index is not None and len(disk_offload_index) > 0 and not is_offloaded_safetensors: | ||||
|             save_offload_index(disk_offload_index, disk_offload_folder) | ||||
|             disk_offload_index = None | ||||
|  | ||||
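Whether the shard files above are loaded serially or through the thread pool is driven by environment variables checked in the branch above and in `load_shard_files_with_threadpool`. A usage sketch, where the checkpoint path is a placeholder:

```python
import os

# Opt in to multi-threaded shard loading before the model is instantiated.
os.environ["HF_ENABLE_PARALLEL_LOADING"] = "1"   # any value in ENV_VARS_TRUE_VALUES enables it
os.environ["HF_PARALLEL_LOADING_WORKERS"] = "4"  # defaults to 8 workers when unset

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("path/to/sharded-checkpoint")  # placeholder path
```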
|         # Post-processing for tensor parallelism | ||||
|         if device_mesh is not None: | ||||
| @ -4703,19 +4850,48 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH | ||||
|         missing_keys, unexpected_keys = model._adjust_missing_and_unexpected_keys( | ||||
|             missing_keys, unexpected_keys, loading_task_model_from_base_state_dict | ||||
|         ) | ||||
|         log_state_dict_report( | ||||
|             model=model, | ||||
|             pretrained_model_name_or_path=pretrained_model_name_or_path, | ||||
|             logger=logger, | ||||
|             error_msgs=error_msgs, | ||||
|             unexpected_keys=unexpected_keys, | ||||
|             missing_keys=missing_keys, | ||||
|             mismatched_keys=mismatched_keys, | ||||
|             mismatched_shapes=mismatched_keys, | ||||
|             misc=misc, | ||||
|             ignore_mismatched_sizes=ignore_mismatched_sizes, | ||||
|         ) | ||||
|         disk_offload_index = None | ||||
|  | ||||
|         # TODO: separate this in another function: it's not core.... | ||||
|         # All potential warnings/infos | ||||
|         if len(error_msgs) > 0: | ||||
|             error_msg = "\n\t".join(error_msgs) | ||||
|             if "size mismatch" in error_msg: | ||||
|                 error_msg += ( | ||||
|                     "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method." | ||||
|                 ) | ||||
|             raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}") | ||||
|         if len(unexpected_keys) > 0: | ||||
|             archs = [] if model.config.architectures is None else model.config.architectures | ||||
|             warner = logger.warning if model.__class__.__name__ in archs else logger.info | ||||
|             warner( | ||||
|                 f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when" | ||||
|                 f" initializing {model.__class__.__name__}: {update_key_name(unexpected_keys)}\n- This IS expected if you are" | ||||
|                 f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or" | ||||
|                 " with another architecture (e.g. initializing a BertForSequenceClassification model from a" | ||||
|                 " BertForPreTraining model).\n- This IS NOT expected if you are initializing" | ||||
|                 f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly identical" | ||||
|                 " (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)." | ||||
|             ) | ||||
|         if len(missing_keys) > 0: | ||||
|             logger.warning( | ||||
|                 f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at" | ||||
|                 f" {pretrained_model_name_or_path} and are newly initialized: {update_key_name(missing_keys)}\nYou should probably" | ||||
|                 " TRAIN this model on a down-stream task to be able to use it for predictions and inference." | ||||
|             ) | ||||
|         if len(mismatched_keys) > 0: | ||||
|             mismatched_warning = "\n".join( | ||||
|                 [ | ||||
|                     f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated" | ||||
|                     for key, (shape1, shape2) in zip(mismatched_keys, mismatched_shapes) | ||||
|                 ] | ||||
|             ) | ||||
|             logger.warning( | ||||
|                 f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at" | ||||
|                 f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not" | ||||
|                 f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able" | ||||
|                 " to use it for predictions and inference." | ||||
|             ) | ||||
|  | ||||
|         return model, missing_keys, unexpected_keys, mismatched_keys, disk_offload_index, error_msgs | ||||
|  | ||||
|     def retrieve_modules_from_names(self, names, add_prefix=False, remove_prefix=False): | ||||
| @ -4925,8 +5101,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH | ||||
|                 if not is_quantized or not hf_quantizer.param_needs_quantization(self, key): | ||||
|                     _load_parameter_into_model(self, key, value) | ||||
|                 else: | ||||
|                     # hf_quantizer.create_quantized_param(self, value, key, "cpu") | ||||
|                     pass | ||||
|                     hf_quantizer.create_quantized_param(self, value, key, "cpu") | ||||
|  | ||||
|     def _initialize_missing_keys(self, missing_keys: list[str], is_quantized: bool) -> None: | ||||
|         """Initialize the missing keys (keys that are part of the model parameters, but were NOT found in the loaded state dicts), according to | ||||
|  | ||||
| @ -42,44 +42,37 @@ from ...utils.generic import check_model_inputs | ||||
| from .configuration_deepseek_v2 import DeepseekV2Config | ||||
|  | ||||
|  | ||||
| class DeepseekV2Experts(nn.Module): | ||||
|     """Collection of expert weights stored as 3D tensors.""" | ||||
| class DeepseekV2Experts(nn.ModuleList): | ||||
|     """ | ||||
|     ModuleList of experts. | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, config): | ||||
|         super().__init__() | ||||
|         self.num_experts = config.n_routed_experts | ||||
|         self.hidden_dim = config.hidden_size | ||||
|         self.intermediate_dim = config.moe_intermediate_size | ||||
|         self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim)) | ||||
|         self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim)) | ||||
|         self.act_fn = ACT2FN[config.hidden_act] | ||||
|         for _ in range(config.n_routed_experts): | ||||
|             self.append(DeepseekV2MLP(config, intermediate_size=config.moe_intermediate_size)) | ||||
|  | ||||
|     def forward( | ||||
|         self, | ||||
|         hidden_states: torch.Tensor, | ||||
|         top_k_index: torch.Tensor, | ||||
|         top_k_weights: torch.Tensor, | ||||
|         self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor | ||||
|     ) -> torch.Tensor: | ||||
|         """ | ||||
|         Args: | ||||
|             hidden_states: (batch_size * sequence_length, hidden_dim) | ||||
|             top_k_index: (batch_size * sequence_length, top_k) | ||||
|             top_k_weights: (batch_size * sequence_length, top_k) | ||||
|         Returns: | ||||
|             (batch_size * sequence_length, hidden_dim) | ||||
|         """ | ||||
|         final_hidden_states = torch.zeros_like(hidden_states) | ||||
|         expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0) | ||||
|  | ||||
|         num_experts = top_k_weights.shape[1] | ||||
|         expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1).permute(2, 1, 0) | ||||
|         expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() | ||||
|         for expert_idx in expert_hit: | ||||
|             expert_idx = expert_idx[0] | ||||
|             if expert_idx == num_experts: | ||||
|                 continue | ||||
|             with torch.no_grad(): | ||||
|                 _, token_idx = torch.where(expert_mask[expert_idx]) | ||||
|             current_state = hidden_states[token_idx] | ||||
|             gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1) | ||||
|             current_hidden_states = self.act_fn(gate) * up | ||||
|             current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx]) | ||||
|  | ||||
|             routing_weights = top_k_weights[token_idx, expert_idx].unsqueeze(-1) | ||||
|             current_hidden_states = current_hidden_states * routing_weights.to(current_hidden_states.dtype) | ||||
|             final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype)) | ||||
|  | ||||
|             idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0)) | ||||
|             current_state = hidden_states[None, top_x].reshape(-1, hidden_states.shape[-1]) | ||||
|             current_hidden_states = self[expert_idx](current_state) * top_k_weights[top_x, idx, None] | ||||
|             final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) | ||||
|         return final_hidden_states | ||||
|  | ||||
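The dispatch loop above first builds a one-hot mask over the routed expert indices, then gathers, per expert, the tokens that selected it before running that expert's MLP. A standalone sketch of the masking step with toy sizes (this uses the simpler `num_classes=num_experts` form, without the extra padding index used above):

```python
import torch

num_experts, top_k, num_tokens = 4, 2, 3
top_k_index = torch.tensor([[0, 2], [2, 3], [0, 1]])  # (num_tokens, top_k) routed experts per token

# (num_experts, top_k, num_tokens): mask[e, k, t] == 1 iff token t routed its k-th slot to expert e
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts).permute(2, 1, 0)
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()

for expert_idx in expert_hit[:, 0].tolist():
    slot_idx, token_idx = torch.where(expert_mask[expert_idx])
    print(expert_idx, token_idx.tolist())  # e.g. expert 0 -> tokens [0, 2]
```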
|  | ||||
|  | ||||
| @ -224,10 +224,12 @@ def apply_rotary_emb( | ||||
|     return xq_out, xk_out | ||||
|  | ||||
|  | ||||
| class DeepseekV2Experts(Qwen2MoeExperts): | ||||
| class DeepseekV2Experts(Qwen2MoeExperts, nn.ModuleList): | ||||
|     def __init__(self, config): | ||||
|         super().__init__(config) | ||||
|         nn.ModuleList.__init__(self) | ||||
|         self.num_experts = config.n_routed_experts | ||||
|         for _ in range(config.n_routed_experts): | ||||
|             self.append(DeepseekV2MLP(config, intermediate_size=config.moe_intermediate_size)) | ||||
|  | ||||
|  | ||||
| class DeepseekV2Moe(nn.Module): | ||||
|  | ||||
| @ -149,44 +149,37 @@ class DeepseekV3TopkRouter(nn.Module): | ||||
|         return router_logits | ||||
|  | ||||
|  | ||||
| class DeepseekV3NaiveMoe(nn.Module): | ||||
|     """Collection of expert weights stored as 3D tensors.""" | ||||
| class DeepseekV3NaiveMoe(nn.ModuleList): | ||||
|     """ | ||||
|     ModuleList of experts. | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, config): | ||||
|         super().__init__() | ||||
|         self.num_experts = config.num_local_experts | ||||
|         self.hidden_dim = config.hidden_size | ||||
|         self.intermediate_dim = config.intermediate_size | ||||
|         self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim)) | ||||
|         self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim)) | ||||
|         self.act_fn = ACT2FN[config.hidden_act] | ||||
|         for _ in range(self.num_experts): | ||||
|             self.append(DeepseekV3MLP(config, intermediate_size=config.moe_intermediate_size)) | ||||
|  | ||||
|     def forward( | ||||
|         self, | ||||
|         hidden_states: torch.Tensor, | ||||
|         top_k_index: torch.Tensor, | ||||
|         top_k_weights: torch.Tensor, | ||||
|         self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor | ||||
|     ) -> torch.Tensor: | ||||
|         """ | ||||
|         Args: | ||||
|             hidden_states: (batch_size * sequence_length, hidden_dim) | ||||
|             top_k_index: (batch_size * sequence_length, top_k) | ||||
|             top_k_weights: (batch_size * sequence_length, top_k) | ||||
|         Returns: | ||||
|             (batch_size * sequence_length, hidden_dim) | ||||
|         """ | ||||
|         final_hidden_states = torch.zeros_like(hidden_states) | ||||
|         expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0) | ||||
|  | ||||
|         num_experts = top_k_weights.shape[1] | ||||
|         expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1).permute(2, 1, 0) | ||||
|         expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() | ||||
|         for expert_idx in expert_hit: | ||||
|             expert_idx = expert_idx[0] | ||||
|             if expert_idx == num_experts: | ||||
|                 continue | ||||
|             with torch.no_grad(): | ||||
|                 _, token_idx = torch.where(expert_mask[expert_idx]) | ||||
|             current_state = hidden_states[token_idx] | ||||
|             gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1) | ||||
|             current_hidden_states = self.act_fn(gate) * up | ||||
|             current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx]) | ||||
|  | ||||
|             routing_weights = top_k_weights[token_idx, expert_idx].unsqueeze(-1) | ||||
|             current_hidden_states = current_hidden_states * routing_weights.to(current_hidden_states.dtype) | ||||
|             final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype)) | ||||
|  | ||||
|             idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0)) | ||||
|             current_state = hidden_states[None, top_x].reshape(-1, hidden_states.shape[-1]) | ||||
|             current_hidden_states = self[expert_idx](current_state) * top_k_weights[top_x, idx, None] | ||||
|             final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) | ||||
|         return final_hidden_states | ||||
|  | ||||
|  | ||||
|  | ||||
| @ -102,10 +102,12 @@ class DeepseekV3TopkRouter(nn.Module): | ||||
|         return router_logits | ||||
|  | ||||
|  | ||||
| class DeepseekV3NaiveMoe(MixtralExperts): | ||||
| class DeepseekV3NaiveMoe(MixtralExperts, nn.ModuleList): | ||||
|     def __init__(self, config): | ||||
|         super().__init__(config) | ||||
|         nn.ModuleList.__init__(self) | ||||
|         self.num_experts = config.num_local_experts | ||||
|         for _ in range(self.num_experts): | ||||
|             self.append(DeepseekV3MLP(config, intermediate_size=config.moe_intermediate_size)) | ||||
|  | ||||
|  | ||||
| class DeepseekV3MoE(nn.Module): | ||||
|  | ||||
| @ -305,44 +305,37 @@ class Dots1TopkRouter(nn.Module): | ||||
|         return router_logits | ||||
|  | ||||
|  | ||||
| class Dots1NaiveMoe(nn.Module): | ||||
|     """Collection of expert weights stored as 3D tensors.""" | ||||
| class Dots1NaiveMoe(nn.ModuleList): | ||||
|     """ | ||||
|     ModuleList of experts. | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, config): | ||||
|         super().__init__() | ||||
|         self.num_experts = config.num_local_experts | ||||
|         self.hidden_dim = config.hidden_size | ||||
|         self.intermediate_dim = config.intermediate_size | ||||
|         self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim)) | ||||
|         self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim)) | ||||
|         self.act_fn = ACT2FN[config.hidden_act] | ||||
|         for _ in range(self.num_experts): | ||||
|             self.append(Dots1MLP(config, intermediate_size=config.moe_intermediate_size)) | ||||
|  | ||||
|     def forward( | ||||
|         self, | ||||
|         hidden_states: torch.Tensor, | ||||
|         top_k_index: torch.Tensor, | ||||
|         top_k_weights: torch.Tensor, | ||||
|         self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor | ||||
|     ) -> torch.Tensor: | ||||
|         """ | ||||
|         Args: | ||||
|             hidden_states: (batch_size * sequence_length, hidden_dim) | ||||
|             top_k_index: (batch_size * sequence_length, top_k) | ||||
|             top_k_weights: (batch_size * sequence_length, top_k) | ||||
|         Returns: | ||||
|             (batch_size * sequence_length, hidden_dim) | ||||
|         """ | ||||
|         final_hidden_states = torch.zeros_like(hidden_states) | ||||
|         expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0) | ||||
|  | ||||
|         num_experts = top_k_weights.shape[1] | ||||
|         expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1).permute(2, 1, 0) | ||||
|         expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() | ||||
|         for expert_idx in expert_hit: | ||||
|             expert_idx = expert_idx[0] | ||||
|             if expert_idx == num_experts: | ||||
|                 continue | ||||
|             with torch.no_grad(): | ||||
|                 _, token_idx = torch.where(expert_mask[expert_idx]) | ||||
|             current_state = hidden_states[token_idx] | ||||
|             gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1) | ||||
|             current_hidden_states = self.act_fn(gate) * up | ||||
|             current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx]) | ||||
|  | ||||
|             routing_weights = top_k_weights[token_idx, expert_idx].unsqueeze(-1) | ||||
|             current_hidden_states = current_hidden_states * routing_weights.to(current_hidden_states.dtype) | ||||
|             final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype)) | ||||
|  | ||||
|             idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0)) | ||||
|             current_state = hidden_states[None, top_x].reshape(-1, hidden_states.shape[-1]) | ||||
|             current_hidden_states = self[expert_idx](current_state) * top_k_weights[top_x, idx, None] | ||||
|             final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) | ||||
|         return final_hidden_states | ||||
|  | ||||
|  | ||||
|  | ||||
| @ -315,95 +315,62 @@ class Ernie4_5_MoeStatics(nn.Module): | ||||
|         return hidden_states + self.e_score_correction_bias.squeeze() | ||||
|  | ||||
|  | ||||
| class Ernie4_5_MoeExperts(nn.Module): | ||||
| class Ernie4_5_MoeExperts(nn.ModuleList): | ||||
|     def __init__(self, config): | ||||
|         super().__init__() | ||||
|         self.num_experts = config.moe_num_experts | ||||
|         self.hidden_dim = config.hidden_size | ||||
|         self.intermediate_dim = config.moe_intermediate_size | ||||
|         self.use_bias = config.use_bias | ||||
|         self.act_fn = ACT2FN[config.hidden_act] | ||||
|  | ||||
|         self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim)) | ||||
|         self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim)) | ||||
|         if self.use_bias: | ||||
|             self.gate_up_proj_bias = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim)) | ||||
|             self.down_proj_bias = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim)) | ||||
|         else: | ||||
|             self.gate_up_proj_bias = None | ||||
|             self.down_proj_bias = None | ||||
|         for _ in range(self.num_experts): | ||||
|             self.append(Ernie4_5_MoeMLP(config, config.moe_intermediate_size)) | ||||
|  | ||||
|     def forward( | ||||
|         self, hidden_states: torch.Tensor, selected_experts: torch.Tensor, routing_weights: torch.Tensor | ||||
|     ) -> torch.Tensor: | ||||
|         final_hidden_states = torch.zeros_like(hidden_states) | ||||
|         if selected_experts.numel() == 0: | ||||
|             return final_hidden_states | ||||
|  | ||||
|         expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0) | ||||
|  | ||||
|         expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() | ||||
|         for expert_idx in expert_hit: | ||||
|             expert_idx = int(expert_idx.item()) | ||||
|             idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0)) | ||||
|             current_state = hidden_states[None, top_x].reshape(-1, hidden_states.shape[-1]) | ||||
|             gate_inputs = F.linear( | ||||
|                 current_state, | ||||
|                 self.gate_up_proj[expert_idx], | ||||
|                 None if self.gate_up_proj_bias is None else self.gate_up_proj_bias[expert_idx], | ||||
|             ) | ||||
|             gate, up = gate_inputs.chunk(2, dim=-1) | ||||
|             current_hidden_states = self.act_fn(gate) * up | ||||
|             current_hidden_states = F.linear( | ||||
|                 current_hidden_states, | ||||
|                 self.down_proj[expert_idx], | ||||
|                 None if self.down_proj_bias is None else self.down_proj_bias[expert_idx], | ||||
|             ) | ||||
|             current_hidden_states = current_hidden_states * routing_weights[top_x, idx, None] | ||||
|             current_hidden_states = self[expert_idx](current_state) * routing_weights[top_x, idx, None] | ||||
|             final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) | ||||
|         return final_hidden_states | ||||
|  | ||||
|  | ||||
| class Ernie4_5_MoeTopKRouter(nn.Module): | ||||
|     def __init__(self, config): | ||||
|         super().__init__() | ||||
|         self.linear = nn.Linear(config.hidden_size, config.moe_num_experts, bias=False, dtype=torch.float32) | ||||
|         self.moe_statics = Ernie4_5_MoeStatics(config) | ||||
|         self.top_k = config.moe_k | ||||
|         self.norm_min = config.moe_norm_min | ||||
|  | ||||
|     def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: | ||||
|         device_type = ( | ||||
|             hidden_states.device.type | ||||
|             if isinstance(hidden_states.device.type, str) and hidden_states.device.type != "mps" | ||||
|             else "cpu" | ||||
|         ) | ||||
|  | ||||
|         with torch.autocast(device_type=device_type, enabled=False):  # Force float32 | ||||
|             router_logits = self.linear(hidden_states.float()) | ||||
|             routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) | ||||
|             _, selected_experts = torch.topk(self.moe_statics(routing_weights), self.top_k, dim=-1) | ||||
|             routing_weights = torch.gather(routing_weights, dim=-1, index=selected_experts) | ||||
|             routing_weights = routing_weights / torch.clamp( | ||||
|                 routing_weights.sum(dim=-1, keepdim=True), min=self.norm_min | ||||
|             ) | ||||
|         routing_weights = routing_weights.to(router_logits.dtype) | ||||
|         return routing_weights, selected_experts | ||||
|  | ||||
|  | ||||
| class Ernie4_5_MoeSparseMoeBlock(nn.Module): | ||||
|     def __init__(self, config): | ||||
|         super().__init__() | ||||
|         self.hidden_dim = config.hidden_size | ||||
|         self.num_experts = config.moe_num_experts | ||||
|         self.top_k = config.moe_k | ||||
|         self.router = Ernie4_5_MoeTopKRouter(config) | ||||
|         self.norm_min = config.moe_norm_min | ||||
|  | ||||
|         self.gate = nn.Linear(config.hidden_size, config.moe_num_experts, bias=False, dtype=torch.float32) | ||||
|         self.moe_statics = Ernie4_5_MoeStatics(config) | ||||
|         self.experts = Ernie4_5_MoeExperts(config) | ||||
|  | ||||
|         self.shared_experts = None | ||||
|         if config.moe_num_shared_experts > 0: | ||||
|             self.shared_experts = Ernie4_5_MoeMLP(config, config.moe_intermediate_size * config.moe_num_shared_experts) | ||||
|  | ||||
|     def route_tokens_to_experts(self, hidden_states): | ||||
|         device_type = ( | ||||
|             hidden_states.device.type | ||||
|             if isinstance(hidden_states.device.type, str) and hidden_states.device.type != "mps" | ||||
|             else "cpu" | ||||
|         ) | ||||
|  | ||||
|         with torch.autocast(device_type=device_type, enabled=False):  # Force float32 | ||||
|             router_logits = self.gate(hidden_states.float()) | ||||
|             routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) | ||||
|             _, selected_experts = torch.topk(self.moe_statics(routing_weights), self.top_k, dim=-1) | ||||
|             routing_weights = torch.gather(routing_weights, dim=-1, index=selected_experts) | ||||
|             routing_weights = routing_weights / torch.clamp( | ||||
|                 routing_weights.sum(dim=-1, keepdim=True), min=self.norm_min | ||||
|             ) | ||||
|         routing_weights = routing_weights.to(router_logits.dtype) | ||||
|         return selected_experts, routing_weights | ||||
|  | ||||
|     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: | ||||
|         batch_size, sequence_length, _ = hidden_states.shape | ||||
|         hidden_states = hidden_states.view(-1, self.hidden_dim) | ||||
| @ -411,7 +378,7 @@ class Ernie4_5_MoeSparseMoeBlock(nn.Module): | ||||
|         if self.shared_experts is not None: | ||||
|             shared_output = self.shared_experts(hidden_states) | ||||
|  | ||||
|         routing_weights, selected_experts = self.router(hidden_states) | ||||
|         selected_experts, routing_weights = self.route_tokens_to_experts(hidden_states) | ||||
|         final_hidden_states = self.experts(hidden_states, selected_experts, routing_weights) | ||||
|  | ||||
|         if self.shared_experts is not None: | ||||
| @ -487,11 +454,11 @@ class Ernie4_5_MoePreTrainedModel(PreTrainedModel): | ||||
|     _can_compile_fullgraph = False  # MoE models don't work with torch.compile (`torch.where(condition)` not supported) | ||||
|     _supports_attention_backend = True | ||||
|     _can_record_outputs = { | ||||
|         "router_logits": OutputRecorder(Ernie4_5_MoeTopKRouter, layer_name="mlp.router", index=0), | ||||
|         "router_logits": OutputRecorder(nn.Linear, layer_name="mlp.gate", index=0), | ||||
|         "hidden_states": Ernie4_5_MoeDecoderLayer, | ||||
|         "attentions": Ernie4_5_MoeAttention, | ||||
|     } | ||||
|     _keep_in_fp32_modules_strict = ["router"] | ||||
|     _keep_in_fp32_modules_strict = ["gate", "moe_statics"] | ||||
|     # Not supporting multi-token prediction (MTP) atm | ||||
|     _keys_to_ignore_on_load_unexpected = ["mtp"] | ||||
|  | ||||
|  | ||||
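`route_tokens_to_experts` above picks experts from bias-corrected scores (`moe_statics` adds a per-expert correction term) but scales the expert outputs with the original softmax probabilities, renormalized over the selected top-k. A hedged, standalone restatement of that routing math, with toy sizes and a zero tensor standing in for the learned correction bias:

import torch

torch.manual_seed(0)
num_tokens, num_experts, top_k, norm_min = 4, 8, 2, 1e-12

router_logits = torch.randn(num_tokens, num_experts)
routing_weights = torch.softmax(router_logits.float(), dim=-1)

# Stand-in for the learned `e_score_correction_bias` held by Ernie4_5_MoeStatics.
e_score_correction_bias = torch.zeros(num_experts)
_, selected_experts = torch.topk(routing_weights + e_score_correction_bias, top_k, dim=-1)

# The weights that actually scale expert outputs come from the *uncorrected* softmax.
top_k_weights = torch.gather(routing_weights, dim=-1, index=selected_experts)
top_k_weights = top_k_weights / torch.clamp(top_k_weights.sum(dim=-1, keepdim=True), min=norm_min)

print(selected_experts.shape, top_k_weights.shape)  # torch.Size([4, 2]) torch.Size([4, 2])
print(top_k_weights.sum(dim=-1))                    # every row sums to 1 after renormalization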
| @ -19,7 +19,6 @@ import torch | ||||
| import torch.nn.functional as F | ||||
| from torch import nn | ||||
|  | ||||
| from ...activations import ACT2FN | ||||
| from ...cache_utils import Cache, DynamicCache | ||||
| from ...masking_utils import create_causal_mask | ||||
| from ...modeling_outputs import MoeModelOutputWithPast | ||||
| @ -97,95 +96,62 @@ class Ernie4_5_MoeStatics(nn.Module): | ||||
|         return hidden_states + self.e_score_correction_bias.squeeze() | ||||
|  | ||||
|  | ||||
| class Ernie4_5_MoeExperts(nn.Module): | ||||
| class Ernie4_5_MoeExperts(nn.ModuleList): | ||||
|     def __init__(self, config): | ||||
|         super().__init__() | ||||
|         self.num_experts = config.moe_num_experts | ||||
|         self.hidden_dim = config.hidden_size | ||||
|         self.intermediate_dim = config.moe_intermediate_size | ||||
|         self.use_bias = config.use_bias | ||||
|         self.act_fn = ACT2FN[config.hidden_act] | ||||
|  | ||||
|         self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim)) | ||||
|         self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim)) | ||||
|         if self.use_bias: | ||||
|             self.gate_up_proj_bias = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim)) | ||||
|             self.down_proj_bias = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim)) | ||||
|         else: | ||||
|             self.gate_up_proj_bias = None | ||||
|             self.down_proj_bias = None | ||||
|         for _ in range(self.num_experts): | ||||
|             self.append(Ernie4_5_MoeMLP(config, config.moe_intermediate_size)) | ||||
|  | ||||
|     def forward( | ||||
|         self, hidden_states: torch.Tensor, selected_experts: torch.Tensor, routing_weights: torch.Tensor | ||||
|     ) -> torch.Tensor: | ||||
|         final_hidden_states = torch.zeros_like(hidden_states) | ||||
|         if selected_experts.numel() == 0: | ||||
|             return final_hidden_states | ||||
|  | ||||
|         expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0) | ||||
|  | ||||
|         expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() | ||||
|         for expert_idx in expert_hit: | ||||
|             expert_idx = int(expert_idx.item()) | ||||
|             idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0)) | ||||
|             current_state = hidden_states[None, top_x].reshape(-1, hidden_states.shape[-1]) | ||||
|             gate_inputs = F.linear( | ||||
|                 current_state, | ||||
|                 self.gate_up_proj[expert_idx], | ||||
|                 None if self.gate_up_proj_bias is None else self.gate_up_proj_bias[expert_idx], | ||||
|             ) | ||||
|             gate, up = gate_inputs.chunk(2, dim=-1) | ||||
|             current_hidden_states = self.act_fn(gate) * up | ||||
|             current_hidden_states = F.linear( | ||||
|                 current_hidden_states, | ||||
|                 self.down_proj[expert_idx], | ||||
|                 None if self.down_proj_bias is None else self.down_proj_bias[expert_idx], | ||||
|             ) | ||||
|             current_hidden_states = current_hidden_states * routing_weights[top_x, idx, None] | ||||
|             current_hidden_states = self[expert_idx](current_state) * routing_weights[top_x, idx, None] | ||||
|             final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) | ||||
|         return final_hidden_states | ||||
|  | ||||
|  | ||||
| class Ernie4_5_MoeTopKRouter(nn.Module): | ||||
|     def __init__(self, config): | ||||
|         super().__init__() | ||||
|         self.linear = nn.Linear(config.hidden_size, config.moe_num_experts, bias=False, dtype=torch.float32) | ||||
|         self.moe_statics = Ernie4_5_MoeStatics(config) | ||||
|         self.top_k = config.moe_k | ||||
|         self.norm_min = config.moe_norm_min | ||||
|  | ||||
|     def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: | ||||
|         device_type = ( | ||||
|             hidden_states.device.type | ||||
|             if isinstance(hidden_states.device.type, str) and hidden_states.device.type != "mps" | ||||
|             else "cpu" | ||||
|         ) | ||||
|  | ||||
|         with torch.autocast(device_type=device_type, enabled=False):  # Force float32 | ||||
|             router_logits = self.linear(hidden_states.float()) | ||||
|             routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) | ||||
|             _, selected_experts = torch.topk(self.moe_statics(routing_weights), self.top_k, dim=-1) | ||||
|             routing_weights = torch.gather(routing_weights, dim=-1, index=selected_experts) | ||||
|             routing_weights = routing_weights / torch.clamp( | ||||
|                 routing_weights.sum(dim=-1, keepdim=True), min=self.norm_min | ||||
|             ) | ||||
|         routing_weights = routing_weights.to(router_logits.dtype) | ||||
|         return routing_weights, selected_experts | ||||
|  | ||||
|  | ||||
| class Ernie4_5_MoeSparseMoeBlock(nn.Module): | ||||
|     def __init__(self, config): | ||||
|         super().__init__() | ||||
|         self.hidden_dim = config.hidden_size | ||||
|         self.num_experts = config.moe_num_experts | ||||
|         self.top_k = config.moe_k | ||||
|         self.router = Ernie4_5_MoeTopKRouter(config) | ||||
|         self.norm_min = config.moe_norm_min | ||||
|  | ||||
|         self.gate = nn.Linear(config.hidden_size, config.moe_num_experts, bias=False, dtype=torch.float32) | ||||
|         self.moe_statics = Ernie4_5_MoeStatics(config) | ||||
|         self.experts = Ernie4_5_MoeExperts(config) | ||||
|  | ||||
|         self.shared_experts = None | ||||
|         if config.moe_num_shared_experts > 0: | ||||
|             self.shared_experts = Ernie4_5_MoeMLP(config, config.moe_intermediate_size * config.moe_num_shared_experts) | ||||
|  | ||||
|     def route_tokens_to_experts(self, hidden_states): | ||||
|         device_type = ( | ||||
|             hidden_states.device.type | ||||
|             if isinstance(hidden_states.device.type, str) and hidden_states.device.type != "mps" | ||||
|             else "cpu" | ||||
|         ) | ||||
|  | ||||
|         with torch.autocast(device_type=device_type, enabled=False):  # Force float32 | ||||
|             router_logits = self.gate(hidden_states.float()) | ||||
|             routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) | ||||
|             _, selected_experts = torch.topk(self.moe_statics(routing_weights), self.top_k, dim=-1) | ||||
|             routing_weights = torch.gather(routing_weights, dim=-1, index=selected_experts) | ||||
|             routing_weights = routing_weights / torch.clamp( | ||||
|                 routing_weights.sum(dim=-1, keepdim=True), min=self.norm_min | ||||
|             ) | ||||
|         routing_weights = routing_weights.to(router_logits.dtype) | ||||
|         return selected_experts, routing_weights | ||||
|  | ||||
|     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: | ||||
|         batch_size, sequence_length, _ = hidden_states.shape | ||||
|         hidden_states = hidden_states.view(-1, self.hidden_dim) | ||||
| @ -193,7 +159,7 @@ class Ernie4_5_MoeSparseMoeBlock(nn.Module): | ||||
|         if self.shared_experts is not None: | ||||
|             shared_output = self.shared_experts(hidden_states) | ||||
|  | ||||
|         routing_weights, selected_experts = self.router(hidden_states) | ||||
|         selected_experts, routing_weights = self.route_tokens_to_experts(hidden_states) | ||||
|         final_hidden_states = self.experts(hidden_states, selected_experts, routing_weights) | ||||
|  | ||||
|         if self.shared_experts is not None: | ||||
| @ -227,11 +193,11 @@ class Ernie4_5_MoeDecoderLayer(Qwen3MoeDecoderLayer): | ||||
| class Ernie4_5_MoePreTrainedModel(MixtralPreTrainedModel): | ||||
|     config: Ernie4_5_MoeConfig | ||||
|     _no_split_modules = ["Ernie4_5_MoeDecoderLayer"] | ||||
|     _keep_in_fp32_modules_strict = ["router"] | ||||
|     _keep_in_fp32_modules_strict = ["gate", "moe_statics"] | ||||
|     # Not supporting multi-token prediction (MTP) atm | ||||
|     _keys_to_ignore_on_load_unexpected = ["mtp"] | ||||
|     _can_record_outputs = { | ||||
|         "router_logits": OutputRecorder(Ernie4_5_MoeTopKRouter, layer_name="mlp.router", index=0), | ||||
|         "router_logits": OutputRecorder(nn.Linear, layer_name="mlp.gate", index=0), | ||||
|         "hidden_states": Ernie4_5_MoeDecoderLayer, | ||||
|         "attentions": Ernie4_5_MoeAttention, | ||||
|     } | ||||
|  | ||||
| @ -109,7 +109,6 @@ class FlexOlmoConfig(PreTrainedConfig): | ||||
|  | ||||
|     model_type = "flex_olmo" | ||||
|     keys_to_ignore_at_inference = ["past_key_values"] | ||||
|     attribute_map = {"num_local_experts": "num_experts"} | ||||
|     base_model_tp_plan = { | ||||
|         "layers.*.self_attn.q_proj": "colwise_rep",  # we need to replicate here due to the added norm on q and k | ||||
|         "layers.*.self_attn.k_proj": "colwise_rep",  # we need to replicate here due to the added norm on q and k | ||||
|  | ||||
| @ -23,7 +23,6 @@ from collections.abc import Callable | ||||
| from typing import Optional, Union | ||||
|  | ||||
| import torch | ||||
| import torch.nn.functional as F | ||||
| from torch import nn | ||||
|  | ||||
| from ...activations import ACT2FN | ||||
| @ -292,78 +291,64 @@ class FlexOlmoAttention(nn.Module): | ||||
|         return attn_output, attn_weights | ||||
|  | ||||
|  | ||||
| class FlexOlmoExperts(nn.Module): | ||||
|     """Collection of expert weights stored as 3D tensors.""" | ||||
| class FlexOlmoExperts(nn.ModuleList): | ||||
|     """ | ||||
|     ModuleList of experts. | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, config: FlexOlmoConfig): | ||||
|         super().__init__() | ||||
|         self.num_experts = config.num_local_experts | ||||
|         self.hidden_dim = config.hidden_size | ||||
|         self.intermediate_dim = config.intermediate_size | ||||
|         self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim)) | ||||
|         self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim)) | ||||
|         self.act_fn = ACT2FN[config.hidden_act] | ||||
|  | ||||
|     def forward( | ||||
|         self, | ||||
|         hidden_states: torch.Tensor, | ||||
|         top_k_index: torch.Tensor, | ||||
|         top_k_weights: torch.Tensor, | ||||
|     ) -> torch.Tensor: | ||||
|         final_hidden_states = torch.zeros_like(hidden_states) | ||||
|  | ||||
|         num_experts = top_k_weights.shape[1] | ||||
|         expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1).permute(2, 1, 0) | ||||
|         expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() | ||||
|         for expert_idx in expert_hit: | ||||
|             expert_idx = expert_idx[0] | ||||
|             if expert_idx == num_experts: | ||||
|                 continue | ||||
|             with torch.no_grad(): | ||||
|                 _, token_idx = torch.where(expert_mask[expert_idx]) | ||||
|             current_state = hidden_states[token_idx] | ||||
|             gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1) | ||||
|             current_hidden_states = self.act_fn(gate) * up | ||||
|             current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx]) | ||||
|  | ||||
|             routing_weights = top_k_weights[token_idx, expert_idx].unsqueeze(-1) | ||||
|             current_hidden_states = current_hidden_states * routing_weights.to(current_hidden_states.dtype) | ||||
|             final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype)) | ||||
|  | ||||
|         return final_hidden_states | ||||
|  | ||||
|  | ||||
| class FlexOlmoTopKRouter(nn.Module): | ||||
|     def __init__(self, config): | ||||
|         super().__init__() | ||||
|         self.top_k = config.num_experts_per_tok | ||||
|         for _ in range(config.num_experts): | ||||
|             self.append(FlexOlmoMLP(config)) | ||||
|         self.num_experts = config.num_experts | ||||
|         self.top_k = config.num_experts_per_tok | ||||
|         self.norm_topk_prob = config.norm_topk_prob | ||||
|         self.hidden_dim = config.hidden_size | ||||
|         self.weight = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim)) | ||||
|  | ||||
|     def forward(self, hidden_states): | ||||
|         hidden_states = hidden_states.reshape(-1, self.hidden_dim) | ||||
|         router_logits = F.linear(hidden_states, self.weight)  # (seq_len, num_experts) | ||||
|         router_logits = torch.nn.functional.softmax(router_logits, dtype=torch.float, dim=-1) | ||||
|         router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1)  # (seq_len, top_k) | ||||
|         if self.norm_topk_prob: | ||||
|             router_top_value /= router_top_value.sum(dim=-1, keepdim=True) | ||||
|         router_top_value = router_top_value.to(router_logits.dtype) | ||||
|         router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value) | ||||
|         return router_scores, router_indices | ||||
|     def forward( | ||||
|         self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor | ||||
|     ) -> torch.Tensor: | ||||
|         """ | ||||
|         Args: | ||||
|             hidden_states: (batch_size * sequence_length, hidden_dim) | ||||
|             top_k_index: (batch_size * sequence_length, top_k) | ||||
|             top_k_weights: (batch_size * sequence_length, top_k) | ||||
|         Returns: | ||||
|             (batch_size * sequence_length, hidden_dim) | ||||
|         """ | ||||
|         final_hidden_states = torch.zeros_like(hidden_states) | ||||
|         expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0) | ||||
|  | ||||
|         expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() | ||||
|         for expert_idx in expert_hit: | ||||
|             idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0)) | ||||
|             current_state = hidden_states[None, top_x].reshape(-1, hidden_states.shape[-1]) | ||||
|             current_hidden_states = self[expert_idx](current_state) * top_k_weights[top_x, idx, None] | ||||
|             final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) | ||||
|         return final_hidden_states | ||||
|  | ||||
|  | ||||
| class FlexOlmoSparseMoeBlock(nn.Module): | ||||
|     def __init__(self, config): | ||||
|         super().__init__() | ||||
|         self.gate = FlexOlmoTopKRouter(config) | ||||
|         self.num_experts = config.num_experts | ||||
|         self.top_k = config.num_experts_per_tok | ||||
|         self.norm_topk_prob = config.norm_topk_prob | ||||
|         self.gate = nn.Linear(config.hidden_size, self.num_experts, bias=False) | ||||
|         self.experts = FlexOlmoExperts(config) | ||||
|  | ||||
|     def route_tokens_to_experts(self, hidden_states, router_logits): | ||||
|         routing_weights = torch.nn.functional.softmax(router_logits.float(), dim=-1) | ||||
|         top_k_weights, top_k_index = torch.topk(routing_weights, self.top_k, dim=-1) | ||||
|         if self.norm_topk_prob: | ||||
|             top_k_weights /= top_k_weights.sum(dim=-1, keepdim=True) | ||||
|         top_k_weights = top_k_weights.to(hidden_states.dtype) | ||||
|         return top_k_index, top_k_weights | ||||
|  | ||||
|     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: | ||||
|         batch_size, sequence_length, hidden_dim = hidden_states.shape | ||||
|         hidden_states = hidden_states.view(-1, hidden_dim) | ||||
|         top_k_weights, top_k_index = self.gate(hidden_states) | ||||
|         router_logits = self.gate(hidden_states) | ||||
|         top_k_index, top_k_weights = self.route_tokens_to_experts(hidden_states, router_logits) | ||||
|         final_hidden_states = self.experts(hidden_states, top_k_index, top_k_weights).reshape( | ||||
|             batch_size, sequence_length, hidden_dim | ||||
|         ) | ||||
|  | ||||
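In the reworked FlexOlmo block the gate is a plain `nn.Linear` returning logits, and `route_tokens_to_experts` handles the rest: softmax in float32, top-k selection, and renormalization only when `norm_topk_prob` is set (unlike the Ernie router above, there is no score-correction bias). A short standalone restatement with toy shapes; the free function below is an illustrative sketch, not the method itself:

import torch


def route_tokens_to_experts(router_logits, top_k, norm_topk_prob, dtype=torch.float32):
    routing_weights = torch.nn.functional.softmax(router_logits.float(), dim=-1)
    top_k_weights, top_k_index = torch.topk(routing_weights, top_k, dim=-1)
    if norm_topk_prob:
        top_k_weights /= top_k_weights.sum(dim=-1, keepdim=True)
    return top_k_index, top_k_weights.to(dtype)


router_logits = torch.randn(5, 8)  # (num_tokens, num_experts), as produced by the nn.Linear gate
top_k_index, top_k_weights = route_tokens_to_experts(router_logits, top_k=2, norm_topk_prob=True)
print(top_k_index.shape, top_k_weights.sum(dim=-1))  # torch.Size([5, 2]); each row sums to 1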
| @ -156,7 +156,7 @@ class Gemma3TextConfig(PreTrainedConfig): | ||||
|         layer_types: Optional[list[str]] = None, | ||||
|         final_logit_softcapping: Optional[float] = None, | ||||
|         attn_logit_softcapping: Optional[float] = None, | ||||
|         rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None, | ||||
|         rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None, | ||||
|         use_bidirectional_attention: Optional[bool] = False, | ||||
|         **kwargs, | ||||
|     ): | ||||
| @ -186,10 +186,16 @@ class Gemma3TextConfig(PreTrainedConfig): | ||||
|         self.final_logit_softcapping = final_logit_softcapping | ||||
|         self.attn_logit_softcapping = attn_logit_softcapping | ||||
|         self.layer_types = layer_types | ||||
|  | ||||
|         # Try to set `rope_scaling` if available, otherwise use `rope_parameters` | ||||
|         rope_scaling = kwargs.pop("rope_scaling", None) | ||||
|         if rope_scaling is not None: | ||||
|             rope_parameters = {"sliding_attention": {"rope_type": "default"}, "full_attention": rope_scaling} | ||||
|         if (rope_scaling := kwargs.pop("rope_scaling", None)) is not None: | ||||
|             if rope_parameters is None: | ||||
|                 rope_parameters = {"sliding_attention": {"rope_type": "default"}, "full_attention": rope_scaling} | ||||
|             elif "full_attention" in rope_parameters: | ||||
|                 rope_parameters["full_attention"].update(rope_scaling) | ||||
|             else: | ||||
|                 rope_parameters.update(rope_scaling) | ||||
|  | ||||
|         self.rope_parameters = rope_parameters | ||||
|         self.use_bidirectional_attention = use_bidirectional_attention | ||||
|         if use_bidirectional_attention: | ||||
|  | ||||
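The new branch above folds a legacy `rope_scaling` kwarg into `rope_parameters`: it builds the per-layer-type dict when none was given, updates the `full_attention` entry when the dict is already keyed by layer type, and merges into a flat dict otherwise. A sketch of the same decision logic as a standalone helper; the function name and toy dicts are illustrative only:

from typing import Optional


def merge_rope_scaling(rope_parameters: Optional[dict], rope_scaling: Optional[dict]) -> Optional[dict]:
    if rope_scaling is not None:
        if rope_parameters is None:
            # Legacy checkpoints only described the global RoPE; sliding layers get the default.
            rope_parameters = {"sliding_attention": {"rope_type": "default"}, "full_attention": rope_scaling}
        elif "full_attention" in rope_parameters:
            # Already keyed by layer type: the legacy dict only overrides the global entry.
            rope_parameters["full_attention"].update(rope_scaling)
        else:
            # Flat dict covering all layers: merge the legacy keys in place.
            rope_parameters.update(rope_scaling)
    return rope_parameters


print(merge_rope_scaling(None, {"rope_type": "linear", "factor": 8.0}))
# {'sliding_attention': {'rope_type': 'default'}, 'full_attention': {'rope_type': 'linear', 'factor': 8.0}}
print(merge_rope_scaling({"full_attention": {"rope_type": "default"}}, {"rope_type": "linear", "factor": 8.0}))
# {'full_attention': {'rope_type': 'linear', 'factor': 8.0}}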
| @ -191,7 +191,10 @@ _VARIANTS = { | ||||
|             num_hidden_layers=34, | ||||
|             num_key_value_heads=4, | ||||
|             sliding_window=1024, | ||||
|             rope_parameters={"rope_type": "linear", "factor": 8.0},  # used for global RoPE only | ||||
|             rope_parameters={ | ||||
|                 "full_attention": {"rope_type": "linear", "factor": 8.0}, | ||||
|                 "sliding_attention": {"rope_type": "default"}, | ||||
|             }, | ||||
|             rope_theta=1_000_000, | ||||
|             rope_local_base_freq=10_000, | ||||
|             attn_logit_softcapping=None, | ||||
| @ -209,7 +212,10 @@ _VARIANTS = { | ||||
|             num_hidden_layers=48, | ||||
|             num_key_value_heads=8, | ||||
|             sliding_window=1024, | ||||
|             rope_parameters={"rope_type": "linear", "factor": 8.0},  # used for global RoPE only | ||||
|             rope_parameters={ | ||||
|                 "full_attention": {"rope_type": "linear", "factor": 8.0}, | ||||
|                 "sliding_attention": {"rope_type": "default"}, | ||||
|             }, | ||||
|             rope_theta=1_000_000, | ||||
|             rope_local_base_freq=10_000, | ||||
|             attn_logit_softcapping=None, | ||||
| @ -227,7 +233,10 @@ _VARIANTS = { | ||||
|             num_key_value_heads=16, | ||||
|             head_dim=128, | ||||
|             sliding_window=1024, | ||||
|             rope_parameters={"rope_type": "linear", "factor": 8.0},  # used for global RoPE only | ||||
|             rope_parameters={ | ||||
|                 "full_attention": {"rope_type": "linear", "factor": 8.0}, | ||||
|                 "sliding_attention": {"rope_type": "default"}, | ||||
|             }, | ||||
|             rope_theta=1_000_000, | ||||
|             rope_local_base_freq=10_000, | ||||
|             attn_logit_softcapping=None, | ||||
|  | ||||
| @ -171,7 +171,7 @@ class Gemma3TextConfig(Gemma2Config, PreTrainedConfig): | ||||
|         layer_types: Optional[list[str]] = None, | ||||
|         final_logit_softcapping: Optional[float] = None, | ||||
|         attn_logit_softcapping: Optional[float] = None, | ||||
|         rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None, | ||||
|         rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None, | ||||
|         use_bidirectional_attention: Optional[bool] = False, | ||||
|         **kwargs, | ||||
|     ): | ||||
| @ -201,10 +201,16 @@ class Gemma3TextConfig(Gemma2Config, PreTrainedConfig): | ||||
|         self.final_logit_softcapping = final_logit_softcapping | ||||
|         self.attn_logit_softcapping = attn_logit_softcapping | ||||
|         self.layer_types = layer_types | ||||
|  | ||||
|         # Try to set `rope_scaling` if available, otherwise use `rope_parameters` | ||||
|         rope_scaling = kwargs.pop("rope_scaling", None) | ||||
|         if rope_scaling is not None: | ||||
|             rope_parameters = {"sliding_attention": {"rope_type": "default"}, "full_attention": rope_scaling} | ||||
|         if (rope_scaling := kwargs.pop("rope_scaling", None)) is not None: | ||||
|             if rope_parameters is None: | ||||
|                 rope_parameters = {"sliding_attention": {"rope_type": "default"}, "full_attention": rope_scaling} | ||||
|             elif "full_attention" in rope_parameters: | ||||
|                 rope_parameters["full_attention"].update(rope_scaling) | ||||
|             else: | ||||
|                 rope_parameters.update(rope_scaling) | ||||
|  | ||||
|         self.rope_parameters = rope_parameters | ||||
|         self.use_bidirectional_attention = use_bidirectional_attention | ||||
|         if use_bidirectional_attention: | ||||
|  | ||||
| @ -330,44 +330,37 @@ class Glm4MoeRMSNorm(nn.Module): | ||||
|         return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" | ||||
|  | ||||
|  | ||||
| class Glm4MoeNaiveMoe(nn.Module): | ||||
|     """Collection of expert weights stored as 3D tensors.""" | ||||
| class Glm4MoeNaiveMoe(nn.ModuleList): | ||||
|     """ | ||||
|     ModuleList of experts. | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, config): | ||||
|         super().__init__() | ||||
|         self.num_experts = config.num_local_experts | ||||
|         self.hidden_dim = config.hidden_size | ||||
|         self.intermediate_dim = config.intermediate_size | ||||
|         self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim)) | ||||
|         self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim)) | ||||
|         self.act_fn = ACT2FN[config.hidden_act] | ||||
|         for _ in range(self.num_experts): | ||||
|             self.append(Glm4MoeMLP(config, intermediate_size=config.moe_intermediate_size)) | ||||
|  | ||||
|     def forward( | ||||
|         self, | ||||
|         hidden_states: torch.Tensor, | ||||
|         top_k_index: torch.Tensor, | ||||
|         top_k_weights: torch.Tensor, | ||||
|         self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor | ||||
|     ) -> torch.Tensor: | ||||
|         """ | ||||
|         Args: | ||||
|             hidden_states: (batch_size * sequence_length, hidden_dim) | ||||
|             top_k_index: (batch_size * sequence_length, top_k) | ||||
|             top_k_weights: (batch_size * sequence_length, top_k) | ||||
|         Returns: | ||||
|             (batch_size * sequence_length, hidden_dim) | ||||
|         """ | ||||
|         final_hidden_states = torch.zeros_like(hidden_states) | ||||
|         expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0) | ||||
|  | ||||
|         num_experts = top_k_weights.shape[1] | ||||
|         expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1).permute(2, 1, 0) | ||||
|         expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() | ||||
|         for expert_idx in expert_hit: | ||||
|             expert_idx = expert_idx[0] | ||||
|             if expert_idx == num_experts: | ||||
|                 continue | ||||
|             with torch.no_grad(): | ||||
|                 _, token_idx = torch.where(expert_mask[expert_idx]) | ||||
|             current_state = hidden_states[token_idx] | ||||
|             gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1) | ||||
|             current_hidden_states = self.act_fn(gate) * up | ||||
|             current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx]) | ||||
|  | ||||
|             routing_weights = top_k_weights[token_idx, expert_idx].unsqueeze(-1) | ||||
|             current_hidden_states = current_hidden_states * routing_weights.to(current_hidden_states.dtype) | ||||
|             final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype)) | ||||
|  | ||||
|             idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0)) | ||||
|             current_state = hidden_states[None, top_x].reshape(-1, hidden_states.shape[-1]) | ||||
|             current_hidden_states = self[expert_idx](current_state) * top_k_weights[top_x, idx, None] | ||||
|             final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) | ||||
|         return final_hidden_states | ||||
|  | ||||
|  | ||||
|  | ||||
| @ -1424,6 +1424,8 @@ class Glm4vForConditionalGeneration(Glm4vPreTrainedModel, GenerationMixin): | ||||
|         **kwargs: Unpack[TransformersKwargs], | ||||
|     ) -> Union[tuple, Glm4vCausalLMOutputWithPast]: | ||||
|         r""" | ||||
|         rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): | ||||
|             The rope index difference between sequence length and multimodal rope. | ||||
|         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): | ||||
|             Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., | ||||
|             config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored | ||||
| @ -1432,8 +1434,6 @@ class Glm4vForConditionalGeneration(Glm4vPreTrainedModel, GenerationMixin): | ||||
|             The temporal, height and width of feature shape of each image in LLM. | ||||
|         video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): | ||||
|             The temporal, height and width of feature shape of each video in LLM. | ||||
|         rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): | ||||
|             The rope index difference between sequence length and multimodal rope. | ||||
|  | ||||
|         Example: | ||||
|  | ||||
|  | ||||
| @ -351,44 +351,37 @@ class Glm4vMoeTextTopkRouter(nn.Module): | ||||
|         return router_logits | ||||
|  | ||||
|  | ||||
| class Glm4vMoeTextNaiveMoe(nn.Module): | ||||
|     """Collection of expert weights stored as 3D tensors.""" | ||||
| class Glm4vMoeTextNaiveMoe(nn.ModuleList): | ||||
|     """ | ||||
|     ModuleList of experts. | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, config): | ||||
|         super().__init__() | ||||
|         self.num_experts = config.num_local_experts | ||||
|         self.hidden_dim = config.hidden_size | ||||
|         self.intermediate_dim = config.intermediate_size | ||||
|         self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim)) | ||||
|         self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim)) | ||||
|         self.act_fn = ACT2FN[config.hidden_act] | ||||
|         for _ in range(self.num_experts): | ||||
|             self.append(Glm4vMoeTextMLP(config, intermediate_size=config.moe_intermediate_size)) | ||||
|  | ||||
|     def forward( | ||||
|         self, | ||||
|         hidden_states: torch.Tensor, | ||||
|         top_k_index: torch.Tensor, | ||||
|         top_k_weights: torch.Tensor, | ||||
|         self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor | ||||
|     ) -> torch.Tensor: | ||||
|         """ | ||||
|         Args: | ||||
|             hidden_states: (batch_size * sequence_length, hidden_dim) | ||||
|             top_k_index: (batch_size * sequence_length, top_k) | ||||
|             top_k_weights: (batch_size * sequence_length, top_k) | ||||
|         Returns: | ||||
|             (batch_size * sequence_length, hidden_dim) | ||||
|         """ | ||||
|         final_hidden_states = torch.zeros_like(hidden_states) | ||||
|         expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0) | ||||
|  | ||||
|         num_experts = top_k_weights.shape[1] | ||||
|         expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1).permute(2, 1, 0) | ||||
|         expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() | ||||
|         for expert_idx in expert_hit: | ||||
|             expert_idx = expert_idx[0] | ||||
|             if expert_idx == num_experts: | ||||
|                 continue | ||||
|             with torch.no_grad(): | ||||
|                 _, token_idx = torch.where(expert_mask[expert_idx]) | ||||
|             current_state = hidden_states[token_idx] | ||||
|             gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1) | ||||
|             current_hidden_states = self.act_fn(gate) * up | ||||
|             current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx]) | ||||
|  | ||||
|             routing_weights = top_k_weights[token_idx, expert_idx].unsqueeze(-1) | ||||
|             current_hidden_states = current_hidden_states * routing_weights.to(current_hidden_states.dtype) | ||||
|             final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype)) | ||||
|  | ||||
|             idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0)) | ||||
|             current_state = hidden_states[None, top_x].reshape(-1, hidden_states.shape[-1]) | ||||
|             current_hidden_states = self[expert_idx](current_state) * top_k_weights[top_x, idx, None] | ||||
|             final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) | ||||
|         return final_hidden_states | ||||
|  | ||||
|  | ||||
|  | ||||
| @ -411,9 +411,10 @@ class GraniteMoeDecoderLayer(GradientCheckpointingLayer): | ||||
|         super().__init__() | ||||
|         self.hidden_size = config.hidden_size | ||||
|         self.self_attn = GraniteMoeAttention(config=config, layer_idx=layer_idx) | ||||
|         self.block_sparse_moe = GraniteMoeMoE(config) | ||||
|         self.input_layernorm = GraniteMoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps) | ||||
|         self.post_attention_layernorm = GraniteMoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps) | ||||
|         self.block_sparse_moe = GraniteMoeMoE(config) | ||||
|  | ||||
|         self.residual_multiplier = config.residual_multiplier  # Only diff with mixtral! | ||||
|  | ||||
|     def forward( | ||||
|  | ||||
| @ -105,8 +105,7 @@ class GraniteMoeDecoderLayer(MixtralDecoderLayer): | ||||
|         self.block_sparse_moe = GraniteMoeMoE(config) | ||||
|         self.input_layernorm = GraniteMoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps) | ||||
|         self.post_attention_layernorm = GraniteMoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps) | ||||
|         del self.mlp | ||||
|         self.block_sparse_moe = GraniteMoeMoE(config) | ||||
|  | ||||
|         self.residual_multiplier = config.residual_multiplier  # Only diff with mixtral! | ||||
|  | ||||
|     def forward( | ||||
|  | ||||
| @ -1119,9 +1119,10 @@ class GraniteMoeHybridDecoderLayer(GradientCheckpointingLayer): | ||||
|         self.hidden_size = config.hidden_size | ||||
|         # Either attention or mamba will be initialized, depending on the layer type. | ||||
|         self.self_attn = None | ||||
|         self.block_sparse_moe = GraniteMoeHybridMoE(config) | ||||
|         self.input_layernorm = GraniteMoeHybridRMSNorm(config.hidden_size, eps=config.rms_norm_eps) | ||||
|         self.post_attention_layernorm = GraniteMoeHybridRMSNorm(config.hidden_size, eps=config.rms_norm_eps) | ||||
|         self.block_sparse_moe = GraniteMoeHybridMoE(config) | ||||
|  | ||||
|         self.residual_multiplier = config.residual_multiplier  # Only diff with mixtral! | ||||
|         self.shared_mlp = GraniteMoeHybridMLP(config) | ||||
|         self.mamba = None | ||||
|  | ||||
| @ -401,9 +401,10 @@ class GraniteMoeSharedDecoderLayer(GradientCheckpointingLayer): | ||||
|         super().__init__() | ||||
|         self.hidden_size = config.hidden_size | ||||
|         self.self_attn = GraniteMoeSharedAttention(config=config, layer_idx=layer_idx) | ||||
|         self.block_sparse_moe = GraniteMoeSharedMoE(config) | ||||
|         self.input_layernorm = GraniteMoeSharedRMSNorm(config.hidden_size, eps=config.rms_norm_eps) | ||||
|         self.post_attention_layernorm = GraniteMoeSharedRMSNorm(config.hidden_size, eps=config.rms_norm_eps) | ||||
|         self.block_sparse_moe = GraniteMoeSharedMoE(config) | ||||
|  | ||||
|         self.residual_multiplier = config.residual_multiplier  # Only diff with mixtral! | ||||
|         self.shared_mlp = None if config.shared_intermediate_size == 0 else GraniteMoeSharedMLP(config) | ||||
|  | ||||
|  | ||||
| @ -243,44 +243,38 @@ class HunYuanMoEV1Gate(nn.Module): | ||||
|         return logits | ||||
|  | ||||
|  | ||||
| class HunYuanMoEV1Experts(nn.Module): | ||||
|     """Collection of expert weights stored as 3D tensors.""" | ||||
| class HunYuanMoEV1Experts(nn.ModuleList): | ||||
|     """ | ||||
|     ModuleList of experts. | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, config: HunYuanMoEV1Config): | ||||
|         super().__init__() | ||||
|         self.top_k = config.num_experts_per_tok | ||||
|         self.num_experts = config.num_local_experts | ||||
|         self.hidden_dim = config.hidden_size | ||||
|         self.intermediate_dim = config.intermediate_size | ||||
|         self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim)) | ||||
|         self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim)) | ||||
|         self.act_fn = ACT2FN[config.hidden_act] | ||||
|         for _ in range(self.num_experts): | ||||
|             self.append(HunYuanMoEV1MLP(config)) | ||||
|  | ||||
|     def forward( | ||||
|         self, | ||||
|         hidden_states: torch.Tensor, | ||||
|         top_k_index: torch.Tensor, | ||||
|         top_k_weights: torch.Tensor, | ||||
|         self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor | ||||
|     ) -> torch.Tensor: | ||||
|         """ | ||||
|         Args: | ||||
|             hidden_states: (batch_size * sequence_length, hidden_dim) | ||||
|             top_k_index: (batch_size * sequence_length, top_k) | ||||
|             top_k_weights: (batch_size * sequence_length, top_k) | ||||
|         Returns: | ||||
|             (batch_size * sequence_length, hidden_dim) | ||||
|         """ | ||||
|         final_hidden_states = torch.zeros_like(hidden_states) | ||||
|         expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0) | ||||
|  | ||||
|         num_experts = top_k_weights.shape[1] | ||||
|         expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1).permute(2, 1, 0) | ||||
|         expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() | ||||
|         for expert_idx in expert_hit: | ||||
|             expert_idx = expert_idx[0] | ||||
|             if expert_idx == num_experts: | ||||
|                 continue | ||||
|             with torch.no_grad(): | ||||
|                 _, token_idx = torch.where(expert_mask[expert_idx]) | ||||
|             current_state = hidden_states[token_idx] | ||||
|             gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1) | ||||
|             current_hidden_states = self.act_fn(gate) * up | ||||
|             current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx]) | ||||
|  | ||||
|             routing_weights = top_k_weights[token_idx, expert_idx].unsqueeze(-1) | ||||
|             current_hidden_states = current_hidden_states * routing_weights.to(current_hidden_states.dtype) | ||||
|             final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype)) | ||||
|  | ||||
|             idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0)) | ||||
|             current_state = hidden_states[None, top_x].reshape(-1, hidden_states.shape[-1]) | ||||
|             current_hidden_states = self[expert_idx](current_state) * top_k_weights[top_x, idx, None] | ||||
|             final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) | ||||
|         return final_hidden_states | ||||
|  | ||||
|  | ||||
|  | ||||
| @ -557,44 +557,38 @@ class JambaMLP(nn.Module): | ||||
|         return down_proj | ||||
|  | ||||
|  | ||||
| class JambaExperts(nn.Module): | ||||
|     """Collection of expert weights stored as 3D tensors.""" | ||||
| class JambaExperts(nn.ModuleList): | ||||
|     """ | ||||
|     ModuleList of experts. | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, config: JambaConfig): | ||||
|         super().__init__() | ||||
|         self.top_k = config.num_experts_per_tok | ||||
|         self.num_experts = config.num_local_experts | ||||
|         self.hidden_dim = config.hidden_size | ||||
|         self.intermediate_dim = config.intermediate_size | ||||
|         self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim)) | ||||
|         self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim)) | ||||
|         self.act_fn = ACT2FN[config.hidden_act] | ||||
|         for _ in range(self.num_experts): | ||||
|             self.append(JambaMLP(config)) | ||||
|  | ||||
|     def forward( | ||||
|         self, | ||||
|         hidden_states: torch.Tensor, | ||||
|         top_k_index: torch.Tensor, | ||||
|         top_k_weights: torch.Tensor, | ||||
|         self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor | ||||
|     ) -> torch.Tensor: | ||||
|         """ | ||||
|         Args: | ||||
|             hidden_states: (batch_size * sequence_length, hidden_dim) | ||||
|             top_k_index: (batch_size * sequence_length, top_k) | ||||
|             top_k_weights: (batch_size * sequence_length, top_k) | ||||
|         Returns: | ||||
|             (batch_size * sequence_length, hidden_dim) | ||||
|         """ | ||||
|         final_hidden_states = torch.zeros_like(hidden_states) | ||||
|         expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0) | ||||
|  | ||||
|         num_experts = top_k_weights.shape[1] | ||||
|         expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1).permute(2, 1, 0) | ||||
|         expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() | ||||
|         for expert_idx in expert_hit: | ||||
|             expert_idx = expert_idx[0] | ||||
|             if expert_idx == num_experts: | ||||
|                 continue | ||||
|             with torch.no_grad(): | ||||
|                 _, token_idx = torch.where(expert_mask[expert_idx]) | ||||
|             current_state = hidden_states[token_idx] | ||||
|             gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1) | ||||
|             current_hidden_states = self.act_fn(gate) * up | ||||
|             current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx]) | ||||
|  | ||||
|             routing_weights = top_k_weights[token_idx, expert_idx].unsqueeze(-1) | ||||
|             current_hidden_states = current_hidden_states * routing_weights.to(current_hidden_states.dtype) | ||||
|             final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype)) | ||||
|  | ||||
|             idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0)) | ||||
|             current_state = hidden_states[None, top_x].reshape(-1, hidden_states.shape[-1]) | ||||
|             current_hidden_states = self[expert_idx](current_state) * top_k_weights[top_x, idx, None] | ||||
|             final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) | ||||
|         return final_hidden_states | ||||
|  | ||||
|  | ||||
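# --- Illustrative sketch (not part of the diff): the two routing-weight layouts that appear above.
# The removed branch reads a dense score matrix of shape (tokens, num_experts) as
# weights[token_idx, expert_idx]; the restored branch reads top-k weights of shape (tokens, top_k)
# as weights[top_x, idx]. Toy values below are assumptions for demonstration only.
import torch

logits = torch.randn(5, 4)                                    # (tokens, num_experts)
probs = torch.softmax(logits, dim=-1)
top_k_weights, top_k_index = torch.topk(probs, k=2, dim=-1)   # both (tokens, top_k)
dense_scores = torch.zeros_like(probs).scatter_(1, top_k_index, top_k_weights)

token, slot = 3, 1
expert = top_k_index[token, slot]
assert torch.allclose(dense_scores[token, expert], top_k_weights[token, slot])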
|  | ||||
@@ -24,7 +24,6 @@ import torch
| import torch.nn.functional as F | ||||
| from torch import nn | ||||
|  | ||||
| from ...activations import ACT2FN | ||||
| from ...cache_utils import Cache | ||||
| from ...generation import GenerationMixin | ||||
| from ...integrations import use_kernel_forward_from_hub | ||||
@@ -145,44 +144,37 @@ class Lfm2MoeMLP(nn.Module):
|         return self.w2(F.silu(self.w1(x)) * self.w3(x)) | ||||
|  | ||||
|  | ||||
| class Lfm2MoeExperts(nn.Module): | ||||
|     """Collection of expert weights stored as 3D tensors.""" | ||||
| class Lfm2MoeExperts(nn.ModuleList): | ||||
|     """ | ||||
|     ModuleList of experts. | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, config): | ||||
|         super().__init__() | ||||
|         self.num_experts = config.num_experts | ||||
|         self.hidden_dim = config.hidden_size | ||||
|         self.intermediate_dim = config.moe_intermediate_size | ||||
|         self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim)) | ||||
|         self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim)) | ||||
|         self.act_fn = ACT2FN[config.hidden_act] | ||||
|         for _ in range(config.num_experts): | ||||
|             self.append(Lfm2MoeMLP(config, intermediate_size=config.moe_intermediate_size)) | ||||
|  | ||||
|     def forward( | ||||
|         self, | ||||
|         hidden_states: torch.Tensor, | ||||
|         top_k_index: torch.Tensor, | ||||
|         top_k_weights: torch.Tensor, | ||||
|         self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor | ||||
|     ) -> torch.Tensor: | ||||
|         """ | ||||
|         Args: | ||||
|             hidden_states: (batch_size * sequence_length, hidden_dim) | ||||
            top_k_index: (batch_size * sequence_length, top_k)
            top_k_weights: (batch_size * sequence_length, top_k)
|         Returns: | ||||
|             (batch_size * sequence_length, hidden_dim) | ||||
|         """ | ||||
|         final_hidden_states = torch.zeros_like(hidden_states) | ||||
|         expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0) | ||||
|  | ||||
|         num_experts = top_k_weights.shape[1] | ||||
|         expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1).permute(2, 1, 0) | ||||
|         expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() | ||||
|         for expert_idx in expert_hit: | ||||
|             expert_idx = expert_idx[0] | ||||
|             if expert_idx == num_experts: | ||||
|                 continue | ||||
|             with torch.no_grad(): | ||||
|                 _, token_idx = torch.where(expert_mask[expert_idx]) | ||||
|             current_state = hidden_states[token_idx] | ||||
|             gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1) | ||||
|             current_hidden_states = self.act_fn(gate) * up | ||||
|             current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx]) | ||||
|  | ||||
|             routing_weights = top_k_weights[token_idx, expert_idx].unsqueeze(-1) | ||||
|             current_hidden_states = current_hidden_states * routing_weights.to(current_hidden_states.dtype) | ||||
|             final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype)) | ||||
|  | ||||
|             idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0)) | ||||
|             current_state = hidden_states[None, top_x].reshape(-1, hidden_states.shape[-1]) | ||||
|             current_hidden_states = self[expert_idx](current_state) * top_k_weights[top_x, idx, None] | ||||
|             final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) | ||||
|         return final_hidden_states | ||||
|  | ||||
|  | ||||
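# --- Illustrative sketch (not part of the diff): the `num_classes=num_experts + 1` trick used by
# the removed branch. An index equal to `num_experts` (presumably marking an unrouted / dropped
# slot) falls into an extra one-hot class that the loop skips, so it never reaches a real expert.
# Toy values below are assumptions for demonstration only.
import torch

num_experts = 3
top_k_index = torch.tensor([[0, 3], [1, 2]])  # 3 acts as a "no expert" sentinel
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1).permute(2, 1, 0)
for expert_idx in range(num_experts + 1):
    if expert_idx == num_experts:
        continue  # sentinel class: nothing to compute
    _, token_idx = torch.where(expert_mask[expert_idx])
    print(expert_idx, token_idx.tolist())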
|  | ||||
@@ -164,7 +164,7 @@ class LongcatFlashTopkRouter(nn.Module):
|         topk_indices = self.get_topk_indices(scores) | ||||
|         topk_weights = scores.gather(1, topk_indices) | ||||
|         topk_weights = topk_weights * self.routed_scaling_factor | ||||
|         return topk_weights.to(router_logits.dtype), topk_indices | ||||
|         return topk_indices, topk_weights | ||||
|  | ||||
|     @torch.no_grad() | ||||
|     def get_topk_indices(self, scores): | ||||
@@ -173,51 +173,29 @@ class LongcatFlashTopkRouter(nn.Module):
|         return topk_indices | ||||
|  | ||||
|  | ||||
| class LongcatFlashExperts(nn.Module): | ||||
| class LongcatFlashExperts(nn.ModuleList): | ||||
|     def __init__(self, config): | ||||
|         super().__init__() | ||||
|         self.intermediate_size = config.expert_ffn_hidden_size | ||||
|         self.hidden_size = config.hidden_size | ||||
|         self.num_routed_experts = config.n_routed_experts | ||||
|         self.zero_expert_num = config.zero_expert_num or 0 | ||||
|         self.total_experts = self.num_routed_experts + self.zero_expert_num | ||||
|         self.act_fn = ACT2FN[config.hidden_act] | ||||
|         self.num_experts = config.n_routed_experts + config.zero_expert_num | ||||
|         self.zero_expert_num = config.zero_expert_num | ||||
|  | ||||
|         if self.num_routed_experts > 0: | ||||
|             self.gate_up_proj = nn.Parameter( | ||||
|                 torch.empty(self.total_experts, 2 * self.intermediate_size, self.hidden_size) | ||||
|             ) | ||||
|             self.down_proj = nn.Parameter( | ||||
|                 torch.empty(self.num_routed_experts, self.hidden_size, self.intermediate_size) | ||||
|             ) | ||||
|         else: | ||||
|             self.register_parameter("gate_up_proj", None) | ||||
|             self.register_parameter("down_proj", None) | ||||
|         self.extend( | ||||
            [LongcatFlashMLP(config, intermediate_size=self.intermediate_size) for _ in range(config.n_routed_experts)]  # routed experts only; zero-expert indices map to the trailing nn.Identity modules
|             + [nn.Identity() for _ in range(self.zero_expert_num)] | ||||
|         ) | ||||
|  | ||||
|     def forward(self, hidden_states, top_k_index, top_k_weights): | ||||
|         final_hidden_states = torch.zeros_like(hidden_states) | ||||
|         if top_k_index.numel() == 0: | ||||
|             return final_hidden_states | ||||
|         expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0) | ||||
|  | ||||
|         expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.total_experts).permute(2, 1, 0) | ||||
|  | ||||
|         expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero(as_tuple=False) | ||||
|         for expert_idx_tensor in expert_hit: | ||||
|             expert_idx = int(expert_idx_tensor.item()) | ||||
|             selection_idx, token_idx = torch.where(expert_mask[expert_idx].squeeze(0)) | ||||
|             if token_idx.numel() == 0: | ||||
|                 continue | ||||
|             current_state = hidden_states[token_idx] | ||||
|  | ||||
|             if expert_idx >= self.num_routed_experts or self.gate_up_proj is None: | ||||
|                 current_hidden_states = current_state | ||||
|             else: | ||||
|                 gate, up = F.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1) | ||||
|                 current_hidden_states = self.act_fn(gate) * up | ||||
|                 current_hidden_states = F.linear(current_hidden_states, self.down_proj[expert_idx]) | ||||
|  | ||||
|             current_hidden_states = current_hidden_states * top_k_weights[token_idx, selection_idx, None] | ||||
|             final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(hidden_states.dtype)) | ||||
|         expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() | ||||
|         for expert_idx in expert_hit: | ||||
|             idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0)) | ||||
|             current_state = hidden_states[None, top_x].reshape(-1, hidden_states.shape[-1]) | ||||
|             current_hidden_states = self[expert_idx](current_state) * top_k_weights[top_x, idx, None] | ||||
|             final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) | ||||
|         return final_hidden_states | ||||
|  | ||||
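# --- Illustrative sketch (not part of the diff): "zero experts" as nn.Identity in the ModuleList
# layout above. A token routed to a zero expert gets its hidden state back unchanged, scaled only
# by its routing weight. Sizes and the tiny Linear experts are assumptions for demonstration only.
import torch
from torch import nn

hidden = 8
experts = nn.ModuleList(
    [nn.Linear(hidden, hidden) for _ in range(2)]  # routed experts at indices 0-1
    + [nn.Identity()]                              # zero expert at index 2
)
x = torch.randn(3, hidden)
weight = torch.tensor([[0.4], [0.7], [1.0]])
zero_out = experts[2](x) * weight
assert torch.allclose(zero_out, x * weight)        # identity expert: scaled pass-through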
|  | ||||
@@ -237,7 +215,7 @@ class LongcatFlashMoE(nn.Module):
|  | ||||
|     def forward(self, hidden_states): | ||||
|         orig_shape = hidden_states.shape | ||||
|         topk_weights, topk_indices = self.router(hidden_states) | ||||
|         topk_indices, topk_weights = self.router(hidden_states) | ||||
|         hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) | ||||
|         hidden_states = self.experts(hidden_states, topk_indices, topk_weights).view(*orig_shape) | ||||
|         return hidden_states | ||||
|  | ||||
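# --- Illustrative sketch (not part of the diff): the shape round-trip performed by the MoE block
# above, with a stand-in for the experts call. All names and sizes are assumptions for
# demonstration only.
import torch

def experts_stub(flat_hidden, top_k_index, top_k_weights):
    return flat_hidden  # stand-in for LongcatFlashExperts.forward

batch, seq, hidden, top_k, n_experts = 2, 5, 8, 2, 4
hidden_states = torch.randn(batch, seq, hidden)
orig_shape = hidden_states.shape
flat = hidden_states.view(-1, hidden)                         # (batch * seq, hidden)
topk_indices = torch.randint(0, n_experts, (batch * seq, top_k))
topk_weights = torch.rand(batch * seq, top_k)
out = experts_stub(flat, topk_indices, topk_weights).view(*orig_shape)
assert out.shape == orig_shape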
@@ -20,7 +20,6 @@ import torch
| import torch.nn.functional as F | ||||
| from torch import nn | ||||
|  | ||||
| from ...activations import ACT2FN | ||||
| from ...cache_utils import Cache, DynamicCache | ||||
| from ...masking_utils import create_causal_mask | ||||
| from ...modeling_flash_attention_utils import FlashAttentionKwargs | ||||
@@ -91,54 +90,32 @@ class LongcatFlashTopkRouter(DeepseekV3TopkRouter):
|         topk_indices = self.get_topk_indices(scores) | ||||
|         topk_weights = scores.gather(1, topk_indices) | ||||
|         topk_weights = topk_weights * self.routed_scaling_factor | ||||
|         return topk_weights.to(router_logits.dtype), topk_indices | ||||
|         return topk_indices, topk_weights | ||||
|  | ||||
|  | ||||
| class LongcatFlashExperts(nn.Module): | ||||
| class LongcatFlashExperts(nn.ModuleList): | ||||
|     def __init__(self, config): | ||||
|         super().__init__() | ||||
|         self.intermediate_size = config.expert_ffn_hidden_size | ||||
|         self.hidden_size = config.hidden_size | ||||
|         self.num_routed_experts = config.n_routed_experts | ||||
|         self.zero_expert_num = config.zero_expert_num or 0 | ||||
|         self.total_experts = self.num_routed_experts + self.zero_expert_num | ||||
|         self.act_fn = ACT2FN[config.hidden_act] | ||||
|         self.num_experts = config.n_routed_experts + config.zero_expert_num | ||||
|         self.zero_expert_num = config.zero_expert_num | ||||
|  | ||||
|         if self.num_routed_experts > 0: | ||||
|             self.gate_up_proj = nn.Parameter( | ||||
|                 torch.empty(self.total_experts, 2 * self.intermediate_size, self.hidden_size) | ||||
|             ) | ||||
|             self.down_proj = nn.Parameter( | ||||
|                 torch.empty(self.num_routed_experts, self.hidden_size, self.intermediate_size) | ||||
|             ) | ||||
|         else: | ||||
|             self.register_parameter("gate_up_proj", None) | ||||
|             self.register_parameter("down_proj", None) | ||||
|         self.extend( | ||||
            [LongcatFlashMLP(config, intermediate_size=self.intermediate_size) for _ in range(config.n_routed_experts)]  # routed experts only; zero-expert indices map to the trailing nn.Identity modules
|             + [nn.Identity() for _ in range(self.zero_expert_num)] | ||||
|         ) | ||||
|  | ||||
|     def forward(self, hidden_states, top_k_index, top_k_weights): | ||||
|         final_hidden_states = torch.zeros_like(hidden_states) | ||||
|         if top_k_index.numel() == 0: | ||||
|             return final_hidden_states | ||||
|         expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0) | ||||
|  | ||||
|         expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.total_experts).permute(2, 1, 0) | ||||
|  | ||||
|         expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero(as_tuple=False) | ||||
|         for expert_idx_tensor in expert_hit: | ||||
|             expert_idx = int(expert_idx_tensor.item()) | ||||
|             selection_idx, token_idx = torch.where(expert_mask[expert_idx].squeeze(0)) | ||||
|             if token_idx.numel() == 0: | ||||
|                 continue | ||||
|             current_state = hidden_states[token_idx] | ||||
|  | ||||
|             if expert_idx >= self.num_routed_experts or self.gate_up_proj is None: | ||||
|                 current_hidden_states = current_state | ||||
|             else: | ||||
|                 gate, up = F.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1) | ||||
|                 current_hidden_states = self.act_fn(gate) * up | ||||
|                 current_hidden_states = F.linear(current_hidden_states, self.down_proj[expert_idx]) | ||||
|  | ||||
|             current_hidden_states = current_hidden_states * top_k_weights[token_idx, selection_idx, None] | ||||
|             final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(hidden_states.dtype)) | ||||
|         expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() | ||||
|         for expert_idx in expert_hit: | ||||
|             idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0)) | ||||
|             current_state = hidden_states[None, top_x].reshape(-1, hidden_states.shape[-1]) | ||||
|             current_hidden_states = self[expert_idx](current_state) * top_k_weights[top_x, idx, None] | ||||
|             final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) | ||||
|         return final_hidden_states | ||||
|  | ||||
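# --- Illustrative sketch (not part of the diff): `expert_hit` above is simply "which experts
# received at least one token", so the Python loop skips idle experts. Toy values are assumptions
# for demonstration only.
import torch

top_k_index = torch.tensor([[0, 1], [0, 1], [1, 0]])  # expert 2 is never selected
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=3).permute(2, 1, 0)
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
print(expert_hit.flatten().tolist())  # [0, 1] -> the loop never visits expert 2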
|  | ||||
@@ -158,7 +135,7 @@ class LongcatFlashMoE(nn.Module):
|  | ||||
|     def forward(self, hidden_states): | ||||
|         orig_shape = hidden_states.shape | ||||
|         topk_weights, topk_indices = self.router(hidden_states) | ||||
|         topk_indices, topk_weights = self.router(hidden_states) | ||||
|         hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) | ||||
|         hidden_states = self.experts(hidden_states, topk_indices, topk_weights).view(*orig_shape) | ||||
|         return hidden_states | ||||
|  | ||||
@@ -452,63 +452,57 @@ class MiniMaxAttention(nn.Module):
|         return attn_output, attn_weights | ||||
|  | ||||
|  | ||||
| class MiniMaxExperts(nn.Module): | ||||
|     """Collection of expert weights stored as 3D tensors.""" | ||||
| class MiniMaxMLP(nn.Module): | ||||
|     def __init__(self, config: MiniMaxConfig): | ||||
|         super().__init__() | ||||
|         self.ffn_dim = config.intermediate_size | ||||
|         self.hidden_dim = config.hidden_size | ||||
|  | ||||
|         self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) | ||||
|         self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False) | ||||
|         self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) | ||||
|  | ||||
|         self.act_fn = ACT2FN[config.hidden_act] | ||||
|  | ||||
|     def forward(self, hidden_states): | ||||
|         current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states) | ||||
|         current_hidden_states = self.w2(current_hidden_states) | ||||
|         return current_hidden_states | ||||
|  | ||||
|  | ||||
| class MiniMaxExperts(nn.ModuleList): | ||||
|     """ | ||||
|     ModuleList of experts. | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, config: MiniMaxConfig): | ||||
|         super().__init__() | ||||
|         self.num_experts = config.num_local_experts | ||||
|         self.hidden_dim = config.hidden_size | ||||
|         self.intermediate_dim = config.intermediate_size | ||||
|         self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim)) | ||||
|         self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim)) | ||||
|         self.act_fn = ACT2FN[config.hidden_act] | ||||
|  | ||||
|     def forward( | ||||
|         self, | ||||
|         hidden_states: torch.Tensor, | ||||
|         top_k_index: torch.Tensor, | ||||
|         top_k_weights: torch.Tensor, | ||||
|     ) -> torch.Tensor: | ||||
|         final_hidden_states = torch.zeros_like(hidden_states) | ||||
|  | ||||
|         num_experts = top_k_weights.shape[1] | ||||
|         expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1).permute(2, 1, 0) | ||||
|         expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() | ||||
|         for expert_idx in expert_hit: | ||||
|             expert_idx = expert_idx[0] | ||||
|             if expert_idx == num_experts: | ||||
|                 continue | ||||
|             with torch.no_grad(): | ||||
|                 _, token_idx = torch.where(expert_mask[expert_idx]) | ||||
|             current_state = hidden_states[token_idx] | ||||
|             gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1) | ||||
|             current_hidden_states = self.act_fn(gate) * up | ||||
|             current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx]) | ||||
|  | ||||
|             routing_weights = top_k_weights[token_idx, expert_idx].unsqueeze(-1) | ||||
|             current_hidden_states = current_hidden_states * routing_weights.to(current_hidden_states.dtype) | ||||
|             final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype)) | ||||
|  | ||||
|         return final_hidden_states | ||||
|  | ||||
|  | ||||
| class MiniMaxTopKRouter(nn.Module): | ||||
|     def __init__(self, config): | ||||
|         super().__init__() | ||||
|         self.top_k = config.num_experts_per_tok | ||||
|         self.num_experts = config.num_local_experts | ||||
|         self.hidden_dim = config.hidden_size | ||||
|         self.weight = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim)) | ||||
|         for _ in range(self.num_experts): | ||||
|             self.append(MiniMaxMLP(config)) | ||||
|  | ||||
|     def forward(self, hidden_states): | ||||
|         hidden_states = hidden_states.reshape(-1, self.hidden_dim) | ||||
|         router_logits = F.linear(hidden_states, self.weight)  # (seq_len, num_experts) | ||||
|         router_logits = torch.nn.functional.softmax(router_logits.float(), dim=-1) | ||||
|         router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1)  # (seq_len, top_k) | ||||
|         router_top_value /= router_top_value.sum(dim=-1, keepdim=True) | ||||
|         router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value) | ||||
|         return router_scores, router_indices | ||||
|     def forward( | ||||
|         self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor | ||||
|     ) -> torch.Tensor: | ||||
|         """ | ||||
|         Args: | ||||
|             hidden_states: (batch_size * sequence_length, hidden_dim) | ||||
            top_k_index: (batch_size * sequence_length, top_k)
            top_k_weights: (batch_size * sequence_length, top_k)
|         Returns: | ||||
|             (batch_size * sequence_length, hidden_dim) | ||||
|         """ | ||||
|         final_hidden_states = torch.zeros_like(hidden_states) | ||||
|         expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0) | ||||
|  | ||||
|         expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() | ||||
|         for expert_idx in expert_hit: | ||||
|             idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0)) | ||||
|             current_state = hidden_states[None, top_x].reshape(-1, hidden_states.shape[-1]) | ||||
|             current_hidden_states = self[expert_idx](current_state) * top_k_weights[top_x, idx, None] | ||||
|             final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) | ||||
|         return final_hidden_states | ||||
|  | ||||
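# --- Illustrative sketch (not part of the diff): the per-expert w1/w2/w3 modules above compute the
# same function as the stacked gate_up_proj / down_proj tensors they replace, so weights can be
# ported by concatenation. SiLU and the sizes below are assumptions for demonstration only.
import torch
from torch import nn

hidden, ffn = 6, 10
w1 = nn.Linear(hidden, ffn, bias=False)   # gate
w3 = nn.Linear(hidden, ffn, bias=False)   # up
w2 = nn.Linear(ffn, hidden, bias=False)   # down
x = torch.randn(4, hidden)
mlp_out = w2(nn.functional.silu(w1(x)) * w3(x))

gate_up_proj = torch.cat([w1.weight, w3.weight], dim=0)       # (2 * ffn, hidden)
down_proj = w2.weight                                         # (hidden, ffn)
gate, up = nn.functional.linear(x, gate_up_proj).chunk(2, dim=-1)
fused_out = nn.functional.linear(nn.functional.silu(gate) * up, down_proj)
assert torch.allclose(mlp_out, fused_out, atol=1e-6)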
|  | ||||
| class MiniMaxSparseMoeBlock(nn.Module): | ||||
@@ -516,15 +510,22 @@ class MiniMaxSparseMoeBlock(nn.Module):
|         super().__init__() | ||||
|         self.top_k = config.num_experts_per_tok | ||||
|         self.jitter_noise = config.router_jitter_noise | ||||
|         self.gate = MiniMaxTopKRouter(config) | ||||
|         self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False) | ||||
|         self.experts = MiniMaxExperts(config) | ||||
|  | ||||
|     def route_tokens_to_experts(self, router_logits): | ||||
|         routing_weights = torch.nn.functional.softmax(router_logits.float(), dim=-1) | ||||
|         top_k_weights, top_k_index = torch.topk(routing_weights, self.top_k, dim=-1) | ||||
|         top_k_weights /= top_k_weights.sum(dim=-1, keepdim=True) | ||||
|         return top_k_index, top_k_weights.to(router_logits.dtype) | ||||
|  | ||||
|     def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: | ||||
|         batch_size, sequence_length, hidden_dim = hidden_states.shape | ||||
|         if self.training and self.jitter_noise > 0: | ||||
|             hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise) | ||||
|         hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) | ||||
|         top_k_weights, top_k_index = self.gate(hidden_states) | ||||
|         router_logits = self.gate(hidden_states) | ||||
|         top_k_index, top_k_weights = self.route_tokens_to_experts(router_logits) | ||||
|         hidden_states = self.experts(hidden_states, top_k_index, top_k_weights.to(hidden_states.dtype)) | ||||
|         hidden_states = hidden_states.reshape(batch_size, sequence_length, hidden_dim) | ||||
|         return hidden_states | ||||
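
# --- Illustrative sketch (not part of the diff): what `route_tokens_to_experts` above does on toy
# logits - softmax over experts, keep the top-k entries, renormalize them to sum to 1.
import torch

router_logits = torch.tensor([[2.0, 0.5, -1.0, 0.0]])         # one token, four experts
routing_weights = torch.softmax(router_logits.float(), dim=-1)
top_k_weights, top_k_index = torch.topk(routing_weights, k=2, dim=-1)
top_k_weights = top_k_weights / top_k_weights.sum(dim=-1, keepdim=True)
print(top_k_index.tolist(), top_k_weights.tolist())           # experts [0, 1], weights summing to 1.0
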
@@ -536,6 +537,8 @@ class MiniMaxDecoderLayer(GradientCheckpointingLayer):
|         self.hidden_size = config.hidden_size | ||||
|  | ||||
|         self.self_attn = MiniMaxAttention(config, layer_idx) | ||||
|  | ||||
|         self.block_sparse_moe = MiniMaxSparseMoeBlock(config) | ||||
|         self.input_layernorm = MiniMaxRMSNorm(config.hidden_size, eps=config.rms_norm_eps) | ||||
|         self.post_attention_layernorm = MiniMaxRMSNorm(config.hidden_size, eps=config.rms_norm_eps) | ||||
|  | ||||
@@ -543,7 +546,7 @@ class MiniMaxDecoderLayer(GradientCheckpointingLayer):
|         self.layer_type = config.layer_types[layer_idx] if hasattr(config, "layer_types") else None | ||||
|         self.mlp_alpha_factor = config.mlp_alpha_factor | ||||
|         self.mlp_beta_factor = config.mlp_beta_factor | ||||
|         self.block_sparse_moe = MiniMaxSparseMoeBlock(config) | ||||
|  | ||||
|         if self.layer_type == "linear_attention": | ||||
|             self.self_attn = MiniMaxLightningAttention(config, layer_idx) | ||||
|             self.attn_alpha_factor = config.linear_attn_alpha_factor | ||||
|  | ||||
@@ -476,8 +476,7 @@ class MiniMaxDecoderLayer(MixtralDecoderLayer, GradientCheckpointingLayer):
|         self.layer_type = config.layer_types[layer_idx] if hasattr(config, "layer_types") else None | ||||
|         self.mlp_alpha_factor = config.mlp_alpha_factor | ||||
|         self.mlp_beta_factor = config.mlp_beta_factor | ||||
|         del self.mlp | ||||
|         self.block_sparse_moe = MiniMaxSparseMoeBlock(config) | ||||
|  | ||||
|         if self.layer_type == "linear_attention": | ||||
|             self.self_attn = MiniMaxLightningAttention(config, layer_idx) | ||||
|             self.attn_alpha_factor = config.linear_attn_alpha_factor | ||||
|  | ||||
@@ -115,16 +115,14 @@ class MixtralConfig(PreTrainedConfig):
|     model_type = "mixtral" | ||||
|     keys_to_ignore_at_inference = ["past_key_values"] | ||||
|     base_model_tp_plan = { | ||||
|         "layers.*.self_attn.q_proj": "local_colwise", | ||||
|         "layers.*.self_attn.k_proj": "local_colwise", | ||||
|         "layers.*.self_attn.v_proj": "local_colwise", | ||||
|         "layers.*.self_attn.o_proj": "local_rowwise", | ||||
|         "layers.*.self_attn": "gather", | ||||
|         "layers.*.mlp.gate": "ep_router",  # we need to replicate here to correctly route experts | ||||
|         "layers.*.mlp.experts.gate_up_proj": "local_colwise", | ||||
|         "layers.*.mlp.experts.down_proj": "local_rowwise", | ||||
|         "layers.*.mlp.experts": "gather", | ||||
|         # "layers.*.mlp.experts.gate_up_proj": "local_packed_rowwise" ? if you load from | ||||
|         "layers.*.self_attn.q_proj": "colwise", | ||||
|         "layers.*.self_attn.k_proj": "colwise", | ||||
|         "layers.*.self_attn.v_proj": "colwise", | ||||
|         "layers.*.self_attn.o_proj": "rowwise", | ||||
|         "layers.*.block_sparse_moe.gate": "colwise_rep",  # we need to replicate here to correctly route experts | ||||
|         "layers.*.block_sparse_moe.experts.*.w1": "colwise", | ||||
|         "layers.*.block_sparse_moe.experts.*.w2": "rowwise", | ||||
|         "layers.*.block_sparse_moe.experts.*.w3": "colwise", | ||||
|     } | ||||
|     base_model_pp_plan = { | ||||
|         "embed_tokens": (["input_ids"], ["inputs_embeds"]), | ||||
|  | ||||
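# --- Illustrative sketch (not part of the diff): why w1/w3 are sharded "colwise" and w2 "rowwise"
# in the plan above. Splitting the up/gate projections by output features and the down projection
# by input features lets each rank produce a partial result that only needs a final sum
# (all-reduce). Two ranks are simulated locally; the activation is omitted and all sizes are
# assumptions for demonstration only.
import torch

hidden, ffn = 8, 12
x = torch.randn(3, hidden)
w1 = torch.randn(ffn, hidden)                                  # up/gate-style projection
w2 = torch.randn(hidden, ffn)                                  # down projection

reference = (x @ w1.T) @ w2.T
partials = []
for shard in range(2):                                         # simulate 2 tensor-parallel ranks
    w1_shard = w1.chunk(2, dim=0)[shard]                       # colwise: split output features
    w2_shard = w2.chunk(2, dim=1)[shard]                       # rowwise: split input features
    partials.append((x @ w1_shard.T) @ w2_shard.T)
assert torch.allclose(reference, sum(partials), atol=1e-4)     # the "all-reduce" is just a sum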
@@ -28,7 +28,6 @@ from collections.abc import Callable
| from typing import Optional, Union | ||||
|  | ||||
| import torch | ||||
| import torch.nn.functional as F | ||||
| from torch import nn | ||||
|  | ||||
| from transformers.utils.generic import check_model_inputs | ||||
@@ -54,63 +53,57 @@ from ...utils.generic import OutputRecorder
| from .configuration_mixtral import MixtralConfig | ||||
|  | ||||
|  | ||||
| class MixtralExperts(nn.Module): | ||||
|     """Collection of expert weights stored as 3D tensors.""" | ||||
| class MixtralMLP(nn.Module): | ||||
|     def __init__(self, config: MixtralConfig): | ||||
|         super().__init__() | ||||
|         self.ffn_dim = config.intermediate_size | ||||
|         self.hidden_dim = config.hidden_size | ||||
|  | ||||
|         self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) | ||||
|         self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False) | ||||
|         self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) | ||||
|  | ||||
|         self.act_fn = ACT2FN[config.hidden_act] | ||||
|  | ||||
|     def forward(self, hidden_states): | ||||
|         current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states) | ||||
|         current_hidden_states = self.w2(current_hidden_states) | ||||
|         return current_hidden_states | ||||
|  | ||||
|  | ||||
| class MixtralExperts(nn.ModuleList): | ||||
|     """ | ||||
|     ModuleList of experts. | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, config: MixtralConfig): | ||||
|         super().__init__() | ||||
|         self.num_experts = config.num_local_experts | ||||
|         self.hidden_dim = config.hidden_size | ||||
|         self.intermediate_dim = config.intermediate_size | ||||
|         self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim)) | ||||
|         self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim)) | ||||
|         self.act_fn = ACT2FN[config.hidden_act] | ||||
|  | ||||
|     def forward( | ||||
|         self, | ||||
|         hidden_states: torch.Tensor, | ||||
|         top_k_index: torch.Tensor, | ||||
|         top_k_weights: torch.Tensor, | ||||
|     ) -> torch.Tensor: | ||||
|         final_hidden_states = torch.zeros_like(hidden_states) | ||||
|  | ||||
|         num_experts = top_k_weights.shape[1] | ||||
|         expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1).permute(2, 1, 0) | ||||
|         expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() | ||||
|         for expert_idx in expert_hit: | ||||
|             expert_idx = expert_idx[0] | ||||
|             if expert_idx == num_experts: | ||||
|                 continue | ||||
|             with torch.no_grad(): | ||||
|                 _, token_idx = torch.where(expert_mask[expert_idx]) | ||||
|             current_state = hidden_states[token_idx] | ||||
|             gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1) | ||||
|             current_hidden_states = self.act_fn(gate) * up | ||||
|             current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx]) | ||||
|  | ||||
|             routing_weights = top_k_weights[token_idx, expert_idx].unsqueeze(-1) | ||||
|             current_hidden_states = current_hidden_states * routing_weights.to(current_hidden_states.dtype) | ||||
|             final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype)) | ||||
|  | ||||
|         return final_hidden_states | ||||
|  | ||||
|  | ||||
| class MixtralTopKRouter(nn.Module): | ||||
|     def __init__(self, config): | ||||
|         super().__init__() | ||||
|         self.top_k = config.num_experts_per_tok | ||||
|         self.num_experts = config.num_local_experts | ||||
|         self.hidden_dim = config.hidden_size | ||||
|         self.weight = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim)) | ||||
|         for _ in range(self.num_experts): | ||||
|             self.append(MixtralMLP(config)) | ||||
|  | ||||
|     def forward(self, hidden_states): | ||||
|         hidden_states = hidden_states.reshape(-1, self.hidden_dim) | ||||
|         router_logits = F.linear(hidden_states, self.weight)  # (seq_len, num_experts) | ||||
|         router_logits = torch.nn.functional.softmax(router_logits.float(), dim=-1) | ||||
|         router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1)  # (seq_len, top_k) | ||||
|         router_top_value /= router_top_value.sum(dim=-1, keepdim=True) | ||||
|         router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value) | ||||
|         return router_scores, router_indices | ||||
|     def forward( | ||||
|         self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor | ||||
|     ) -> torch.Tensor: | ||||
|         """ | ||||
|         Args: | ||||
|             hidden_states: (batch_size * sequence_length, hidden_dim) | ||||
            top_k_index: (batch_size * sequence_length, top_k)
            top_k_weights: (batch_size * sequence_length, top_k)
|         Returns: | ||||
|             (batch_size * sequence_length, hidden_dim) | ||||
|         """ | ||||
|         final_hidden_states = torch.zeros_like(hidden_states) | ||||
|         expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0) | ||||
|  | ||||
|         expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() | ||||
|         for expert_idx in expert_hit: | ||||
|             idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0)) | ||||
|             current_state = hidden_states[None, top_x].reshape(-1, hidden_states.shape[-1]) | ||||
|             current_hidden_states = self[expert_idx](current_state) * top_k_weights[top_x, idx, None] | ||||
|             final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) | ||||
|         return final_hidden_states | ||||
|  | ||||
|  | ||||
| class MixtralSparseMoeBlock(nn.Module): | ||||
@@ -118,15 +111,22 @@ class MixtralSparseMoeBlock(nn.Module):
|         super().__init__() | ||||
|         self.top_k = config.num_experts_per_tok | ||||
|         self.jitter_noise = config.router_jitter_noise | ||||
|         self.gate = MixtralTopKRouter(config) | ||||
|         self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False) | ||||
|         self.experts = MixtralExperts(config) | ||||
|  | ||||
|     def route_tokens_to_experts(self, router_logits): | ||||
|         routing_weights = torch.nn.functional.softmax(router_logits.float(), dim=-1) | ||||
|         top_k_weights, top_k_index = torch.topk(routing_weights, self.top_k, dim=-1) | ||||
|         top_k_weights /= top_k_weights.sum(dim=-1, keepdim=True) | ||||
|         return top_k_index, top_k_weights.to(router_logits.dtype) | ||||
|  | ||||
|     def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: | ||||
|         batch_size, sequence_length, hidden_dim = hidden_states.shape | ||||
|         if self.training and self.jitter_noise > 0: | ||||
|             hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise) | ||||
|         hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) | ||||
|         top_k_weights, top_k_index = self.gate(hidden_states) | ||||
|         router_logits = self.gate(hidden_states) | ||||
|         top_k_index, top_k_weights = self.route_tokens_to_experts(router_logits) | ||||
|         hidden_states = self.experts(hidden_states, top_k_index, top_k_weights.to(hidden_states.dtype)) | ||||
|         hidden_states = hidden_states.reshape(batch_size, sequence_length, hidden_dim) | ||||
|         return hidden_states | ||||
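
# --- Illustrative sketch (not part of the diff): the training-time jitter applied above is a
# per-element multiplicative noise; with jitter_noise=0.1 each activation is scaled by a random
# factor in [0.9, 1.1]. Values are assumptions for demonstration only.
import torch

jitter_noise = 0.1
hidden_states = torch.ones(2, 3)
noise = torch.empty_like(hidden_states).uniform_(1.0 - jitter_noise, 1.0 + jitter_noise)
jittered = hidden_states * noise
print(bool((jittered >= 0.9).all() and (jittered <= 1.1).all()))  # True: stays inside the band
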
@@ -359,7 +359,7 @@ class MixtralDecoderLayer(GradientCheckpointingLayer):
|  | ||||
|         self.self_attn = MixtralAttention(config, layer_idx) | ||||
|  | ||||
|         self.mlp = MixtralSparseMoeBlock(config) | ||||
|         self.block_sparse_moe = MixtralSparseMoeBlock(config) | ||||
|         self.input_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) | ||||
|         self.post_attention_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) | ||||
|  | ||||
@@ -387,7 +387,7 @@ class MixtralDecoderLayer(GradientCheckpointingLayer):
|         hidden_states = residual + hidden_states | ||||
|         residual = hidden_states | ||||
|         hidden_states = self.post_attention_layernorm(hidden_states) | ||||
|         hidden_states = self.mlp(hidden_states) | ||||
|         hidden_states = self.block_sparse_moe(hidden_states) | ||||
|         hidden_states = residual + hidden_states | ||||
|         return hidden_states | ||||
|  | ||||
@@ -405,7 +405,7 @@ class MixtralPreTrainedModel(PreTrainedModel):
|     _can_compile_fullgraph = False  # MoE models don't work with torch.compile (`torch.where(condition)` not supported) | ||||
|     _supports_attention_backend = True | ||||
|     _can_record_outputs = { | ||||
|         "router_logits": OutputRecorder(nn.Linear, layer_name="mlp.gate", index=0), | ||||
|         "router_logits": OutputRecorder(nn.Linear, layer_name="block_sparse_moe.gate", index=0), | ||||
|         "hidden_states": MixtralDecoderLayer, | ||||
|         "attentions": MixtralAttention, | ||||
|     } | ||||
|  | ||||
@@ -22,7 +22,6 @@
| from typing import Optional, Union | ||||
|  | ||||
| import torch | ||||
| import torch.nn.functional as F | ||||
| from torch import nn | ||||
|  | ||||
| from ...activations import ACT2FN | ||||
@@ -132,63 +131,57 @@ def load_balancing_loss_func(
|     return overall_loss * num_experts | ||||
|  | ||||
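# --- Illustrative sketch (not part of the diff): the balancing idea behind the function whose tail
# is shown above - penalize the dot product of (fraction of assignments per expert) and (mean
# router probability per expert). This is a simplified standalone version, not the exact
# transformers implementation.
import torch

def simple_balancing_loss(router_logits, top_k):
    probs = torch.softmax(router_logits, dim=-1)               # (tokens, num_experts)
    _, selected = torch.topk(probs, top_k, dim=-1)
    mask = torch.nn.functional.one_hot(selected, probs.shape[-1]).float()
    tokens_per_expert = mask.mean(dim=(0, 1))                  # fraction of routed slots per expert
    prob_per_expert = probs.mean(dim=0)                        # mean router probability per expert
    return probs.shape[-1] * torch.sum(tokens_per_expert * prob_per_expert)

print(float(simple_balancing_loss(torch.zeros(16, 4), top_k=2)))  # -> 1.0 for a uniform (balanced) router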
|  | ||||
| class MixtralExperts(nn.Module): | ||||
|     """Collection of expert weights stored as 3D tensors.""" | ||||
| class MixtralMLP(nn.Module): | ||||
|     def __init__(self, config: MixtralConfig): | ||||
|         super().__init__() | ||||
|         self.ffn_dim = config.intermediate_size | ||||
|         self.hidden_dim = config.hidden_size | ||||
|  | ||||
|         self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) | ||||
|         self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False) | ||||
|         self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) | ||||
|  | ||||
|         self.act_fn = ACT2FN[config.hidden_act] | ||||
|  | ||||
|     def forward(self, hidden_states): | ||||
|         current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states) | ||||
|         current_hidden_states = self.w2(current_hidden_states) | ||||
|         return current_hidden_states | ||||
|  | ||||
|  | ||||
| class MixtralExperts(nn.ModuleList): | ||||
|     """ | ||||
|     ModuleList of experts. | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, config: MixtralConfig): | ||||
|         super().__init__() | ||||
|         self.num_experts = config.num_local_experts | ||||
|         self.hidden_dim = config.hidden_size | ||||
|         self.intermediate_dim = config.intermediate_size | ||||
|         self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim)) | ||||
|         self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim)) | ||||
|         self.act_fn = ACT2FN[config.hidden_act] | ||||
|  | ||||
|     def forward( | ||||
|         self, | ||||
|         hidden_states: torch.Tensor, | ||||
|         top_k_index: torch.Tensor, | ||||
|         top_k_weights: torch.Tensor, | ||||
|     ) -> torch.Tensor: | ||||
|         final_hidden_states = torch.zeros_like(hidden_states) | ||||
|  | ||||
|         num_experts = top_k_weights.shape[1] | ||||
|         expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1).permute(2, 1, 0) | ||||
|         expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() | ||||
|         for expert_idx in expert_hit: | ||||
|             expert_idx = expert_idx[0] | ||||
|             if expert_idx == num_experts: | ||||
|                 continue | ||||
|             with torch.no_grad(): | ||||
|                 _, token_idx = torch.where(expert_mask[expert_idx]) | ||||
|             current_state = hidden_states[token_idx] | ||||
|             gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1) | ||||
|             current_hidden_states = self.act_fn(gate) * up | ||||
|             current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx]) | ||||
|  | ||||
|             routing_weights = top_k_weights[token_idx, expert_idx].unsqueeze(-1) | ||||
|             current_hidden_states = current_hidden_states * routing_weights.to(current_hidden_states.dtype) | ||||
|             final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype)) | ||||
|  | ||||
|         return final_hidden_states | ||||
|  | ||||
|  | ||||
| class MixtralTopKRouter(nn.Module): | ||||
|     def __init__(self, config): | ||||
|         super().__init__() | ||||
|         self.top_k = config.num_experts_per_tok | ||||
|         self.num_experts = config.num_local_experts | ||||
|         self.hidden_dim = config.hidden_size | ||||
|         self.weight = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim)) | ||||
|         for _ in range(self.num_experts): | ||||
|             self.append(MixtralMLP(config)) | ||||
|  | ||||
|     def forward(self, hidden_states): | ||||
|         hidden_states = hidden_states.reshape(-1, self.hidden_dim) | ||||
|         router_logits = F.linear(hidden_states, self.weight)  # (seq_len, num_experts) | ||||
|         router_logits = torch.nn.functional.softmax(router_logits.float(), dim=-1) | ||||
|         router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1)  # (seq_len, top_k) | ||||
|         router_top_value /= router_top_value.sum(dim=-1, keepdim=True) | ||||
|         router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value) | ||||
|         return router_scores, router_indices | ||||
|     def forward( | ||||
|         self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor | ||||
|     ) -> torch.Tensor: | ||||
|         """ | ||||
|         Args: | ||||
|             hidden_states: (batch_size * sequence_length, hidden_dim) | ||||
            top_k_index: (batch_size * sequence_length, top_k)
            top_k_weights: (batch_size * sequence_length, top_k)
|         Returns: | ||||
|             (batch_size * sequence_length, hidden_dim) | ||||
|         """ | ||||
|         final_hidden_states = torch.zeros_like(hidden_states) | ||||
|         expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0) | ||||
|  | ||||
|         expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() | ||||
|         for expert_idx in expert_hit: | ||||
|             idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0)) | ||||
|             current_state = hidden_states[None, top_x].reshape(-1, hidden_states.shape[-1]) | ||||
|             current_hidden_states = self[expert_idx](current_state) * top_k_weights[top_x, idx, None] | ||||
|             final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) | ||||
|         return final_hidden_states | ||||
|  | ||||
|  | ||||
| class MixtralSparseMoeBlock(nn.Module): | ||||
@@ -196,15 +189,22 @@ class MixtralSparseMoeBlock(nn.Module):
|         super().__init__() | ||||
|         self.top_k = config.num_experts_per_tok | ||||
|         self.jitter_noise = config.router_jitter_noise | ||||
|         self.gate = MixtralTopKRouter(config) | ||||
|         self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False) | ||||
|         self.experts = MixtralExperts(config) | ||||
|  | ||||
|     def route_tokens_to_experts(self, router_logits): | ||||
|         routing_weights = torch.nn.functional.softmax(router_logits.float(), dim=-1) | ||||
|         top_k_weights, top_k_index = torch.topk(routing_weights, self.top_k, dim=-1) | ||||
|         top_k_weights /= top_k_weights.sum(dim=-1, keepdim=True) | ||||
|         return top_k_index, top_k_weights.to(router_logits.dtype) | ||||
|  | ||||
|     def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: | ||||
|         batch_size, sequence_length, hidden_dim = hidden_states.shape | ||||
|         if self.training and self.jitter_noise > 0: | ||||
|             hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise) | ||||
|         hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) | ||||
|         top_k_weights, top_k_index = self.gate(hidden_states) | ||||
|         router_logits = self.gate(hidden_states) | ||||
|         top_k_index, top_k_weights = self.route_tokens_to_experts(router_logits) | ||||
|         hidden_states = self.experts(hidden_states, top_k_index, top_k_weights.to(hidden_states.dtype)) | ||||
|         hidden_states = hidden_states.reshape(batch_size, sequence_length, hidden_dim) | ||||
|         return hidden_states | ||||
@@ -229,7 +229,7 @@ class MixtralDecoderLayer(GradientCheckpointingLayer):
|  | ||||
|         self.self_attn = MixtralAttention(config, layer_idx) | ||||
|  | ||||
|         self.mlp = MixtralSparseMoeBlock(config) | ||||
|         self.block_sparse_moe = MixtralSparseMoeBlock(config) | ||||
|         self.input_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) | ||||
|         self.post_attention_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) | ||||
|  | ||||
@@ -257,7 +257,7 @@ class MixtralDecoderLayer(GradientCheckpointingLayer):
|         hidden_states = residual + hidden_states | ||||
|         residual = hidden_states | ||||
|         hidden_states = self.post_attention_layernorm(hidden_states) | ||||
|         hidden_states = self.mlp(hidden_states) | ||||
|         hidden_states = self.block_sparse_moe(hidden_states) | ||||
|         hidden_states = residual + hidden_states | ||||
|         return hidden_states | ||||
|  | ||||
@@ -265,7 +265,7 @@ class MixtralDecoderLayer(GradientCheckpointingLayer):
| class MixtralPreTrainedModel(MistralPreTrainedModel): | ||||
|     _can_compile_fullgraph = False  # MoE models don't work with torch.compile (`torch.where(condition)` not supported) | ||||
|     _can_record_outputs = { | ||||
|         "router_logits": OutputRecorder(nn.Linear, layer_name="mlp.gate", index=0), | ||||
|         "router_logits": OutputRecorder(nn.Linear, layer_name="block_sparse_moe.gate", index=0), | ||||
|         "hidden_states": MixtralDecoderLayer, | ||||
|         "attentions": MixtralAttention, | ||||
|     } | ||||
|  | ||||
@@ -104,7 +104,6 @@ class OlmoeConfig(PreTrainedConfig):
|  | ||||
|     model_type = "olmoe" | ||||
|     keys_to_ignore_at_inference = ["past_key_values"] | ||||
|     attribute_map = {"num_local_experts": "num_experts"} | ||||
|  | ||||
|     def __init__( | ||||
|         self, | ||||
|  | ||||
@@ -20,7 +20,6 @@ from collections.abc import Callable
| from typing import Optional, Union | ||||
|  | ||||
| import torch | ||||
| import torch.nn.functional as F | ||||
| from torch import nn | ||||
|  | ||||
| from ...activations import ACT2FN | ||||
@@ -295,78 +294,64 @@ class OlmoeAttention(nn.Module):
|         return attn_output, attn_weights | ||||
|  | ||||
|  | ||||
| class OlmoeExperts(nn.Module): | ||||
|     """Collection of expert weights stored as 3D tensors.""" | ||||
| class OlmoeExperts(nn.ModuleList): | ||||
|     """ | ||||
|     ModuleList of experts. | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, config: OlmoeConfig): | ||||
|         super().__init__() | ||||
|         self.num_experts = config.num_local_experts | ||||
|         self.hidden_dim = config.hidden_size | ||||
|         self.intermediate_dim = config.intermediate_size | ||||
|         self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim)) | ||||
|         self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim)) | ||||
|         self.act_fn = ACT2FN[config.hidden_act] | ||||
|  | ||||
|     def forward( | ||||
|         self, | ||||
|         hidden_states: torch.Tensor, | ||||
|         top_k_index: torch.Tensor, | ||||
|         top_k_weights: torch.Tensor, | ||||
|     ) -> torch.Tensor: | ||||
|         final_hidden_states = torch.zeros_like(hidden_states) | ||||
|  | ||||
|         num_experts = top_k_weights.shape[1] | ||||
|         expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1).permute(2, 1, 0) | ||||
|         expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() | ||||
|         for expert_idx in expert_hit: | ||||
|             expert_idx = expert_idx[0] | ||||
|             if expert_idx == num_experts: | ||||
|                 continue | ||||
|             with torch.no_grad(): | ||||
|                 _, token_idx = torch.where(expert_mask[expert_idx]) | ||||
|             current_state = hidden_states[token_idx] | ||||
|             gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1) | ||||
|             current_hidden_states = self.act_fn(gate) * up | ||||
|             current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx]) | ||||
|  | ||||
|             routing_weights = top_k_weights[token_idx, expert_idx].unsqueeze(-1) | ||||
|             current_hidden_states = current_hidden_states * routing_weights.to(current_hidden_states.dtype) | ||||
|             final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype)) | ||||
|  | ||||
|         return final_hidden_states | ||||
|  | ||||
|  | ||||
| class OlmoeTopKRouter(nn.Module): | ||||
|     def __init__(self, config): | ||||
|         super().__init__() | ||||
|         self.top_k = config.num_experts_per_tok | ||||
|         for _ in range(config.num_experts): | ||||
|             self.append(OlmoeMLP(config)) | ||||
|         self.num_experts = config.num_experts | ||||
|         self.top_k = config.num_experts_per_tok | ||||
|         self.norm_topk_prob = config.norm_topk_prob | ||||
|         self.hidden_dim = config.hidden_size | ||||
|         self.weight = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim)) | ||||
|  | ||||
|     def forward(self, hidden_states): | ||||
|         hidden_states = hidden_states.reshape(-1, self.hidden_dim) | ||||
|         router_logits = F.linear(hidden_states, self.weight)  # (seq_len, num_experts) | ||||
|         router_logits = torch.nn.functional.softmax(router_logits, dtype=torch.float, dim=-1) | ||||
|         router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1)  # (seq_len, top_k) | ||||
|         if self.norm_topk_prob: | ||||
|             router_top_value /= router_top_value.sum(dim=-1, keepdim=True) | ||||
|         router_top_value = router_top_value.to(router_logits.dtype) | ||||
|         router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value) | ||||
|         return router_scores, router_indices | ||||
|     def forward( | ||||
|         self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor | ||||
|     ) -> torch.Tensor: | ||||
|         """ | ||||
|         Args: | ||||
|             hidden_states: (batch_size * sequence_length, hidden_dim) | ||||
            top_k_index: (batch_size * sequence_length, top_k)
            top_k_weights: (batch_size * sequence_length, top_k)
|         Returns: | ||||
|             (batch_size * sequence_length, hidden_dim) | ||||
|         """ | ||||
|         final_hidden_states = torch.zeros_like(hidden_states) | ||||
|         expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0) | ||||
|  | ||||
|         expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() | ||||
|         for expert_idx in expert_hit: | ||||
|             idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0)) | ||||
|             current_state = hidden_states[None, top_x].reshape(-1, hidden_states.shape[-1]) | ||||
|             current_hidden_states = self[expert_idx](current_state) * top_k_weights[top_x, idx, None] | ||||
|             final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) | ||||
|         return final_hidden_states | ||||
|  | ||||
|  | ||||
| class OlmoeSparseMoeBlock(nn.Module): | ||||
|     def __init__(self, config): | ||||
|         super().__init__() | ||||
|         self.gate = OlmoeTopKRouter(config) | ||||
|         self.num_experts = config.num_experts | ||||
|         self.top_k = config.num_experts_per_tok | ||||
|         self.norm_topk_prob = config.norm_topk_prob | ||||
|         self.gate = nn.Linear(config.hidden_size, self.num_experts, bias=False) | ||||
|         self.experts = OlmoeExperts(config) | ||||
|  | ||||
|     def route_tokens_to_experts(self, hidden_states, router_logits): | ||||
|         routing_weights = torch.nn.functional.softmax(router_logits.float(), dim=-1) | ||||
|         top_k_weights, top_k_index = torch.topk(routing_weights, self.top_k, dim=-1) | ||||
|         if self.norm_topk_prob: | ||||
|             top_k_weights /= top_k_weights.sum(dim=-1, keepdim=True) | ||||
|         top_k_weights = top_k_weights.to(hidden_states.dtype) | ||||
|         return top_k_index, top_k_weights | ||||
|  | ||||
|     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: | ||||
|         batch_size, sequence_length, hidden_dim = hidden_states.shape | ||||
|         hidden_states = hidden_states.view(-1, hidden_dim) | ||||
|         top_k_weights, top_k_index = self.gate(hidden_states) | ||||
|         router_logits = self.gate(hidden_states) | ||||
|         top_k_index, top_k_weights = self.route_tokens_to_experts(hidden_states, router_logits) | ||||
|         final_hidden_states = self.experts(hidden_states, top_k_index, top_k_weights).reshape( | ||||
|             batch_size, sequence_length, hidden_dim | ||||
|         ) | ||||
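
# --- Illustrative sketch (not part of the diff): the effect of `norm_topk_prob` in the routing
# above. Without renormalization the kept top-k probabilities sum to less than 1, so the expert
# mixture is slightly down-weighted. Toy values are assumptions for demonstration only.
import torch

probs = torch.softmax(torch.tensor([[1.5, 0.5, 0.0, -0.5]]), dim=-1)
top_k_weights, top_k_index = torch.topk(probs, k=2, dim=-1)
print(top_k_weights.sum(dim=-1))                                   # < 1.0 (norm_topk_prob=False)
print((top_k_weights / top_k_weights.sum(dim=-1, keepdim=True)).sum(dim=-1))  # 1.0 (norm_topk_prob=True)
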
@@ -426,7 +411,7 @@ class OlmoePreTrainedModel(PreTrainedModel):
|     _supports_flash_attn = True | ||||
|     _supports_sdpa = True | ||||
|     _can_record_outputs = { | ||||
|         "router_logits": OutputRecorder(OlmoeTopKRouter, index=0), | ||||
|         "router_logits": OutputRecorder(nn.Linear, layer_name="gate", index=1), | ||||
|         "hidden_states": OlmoeDecoderLayer, | ||||
|         "attentions": OlmoeAttention, | ||||
|     } | ||||
|  | ||||
@@ -35,7 +35,6 @@ from ..llama.modeling_llama import (
|     eager_attention_forward, | ||||
| ) | ||||
| from ..mixtral.modeling_mixtral import MixtralExperts, MixtralForCausalLM, MixtralModel | ||||
| from ..qwen2_moe.modeling_qwen2_moe import Qwen2MoeTopKRouter | ||||
| from .configuration_olmoe import OlmoeConfig | ||||
|  | ||||
|  | ||||
@@ -116,24 +115,38 @@ class OlmoeAttention(LlamaAttention):
|         return attn_output, attn_weights | ||||
|  | ||||
|  | ||||
| class OlmoeExperts(MixtralExperts): | ||||
|     pass | ||||
|  | ||||
|  | ||||
| class OlmoeTopKRouter(Qwen2MoeTopKRouter): | ||||
|     pass | ||||
| class OlmoeExperts(MixtralExperts, nn.ModuleList): | ||||
|     def __init__(self, config): | ||||
|         nn.ModuleList.__init__(self) | ||||
|         for _ in range(config.num_experts): | ||||
|             self.append(OlmoeMLP(config)) | ||||
|         self.num_experts = config.num_experts | ||||
|         self.top_k = config.num_experts_per_tok | ||||
|         self.norm_topk_prob = config.norm_topk_prob | ||||
|  | ||||
|  | ||||
| class OlmoeSparseMoeBlock(nn.Module): | ||||
|     def __init__(self, config): | ||||
|         super().__init__() | ||||
|         self.gate = OlmoeTopKRouter(config) | ||||
|         self.num_experts = config.num_experts | ||||
|         self.top_k = config.num_experts_per_tok | ||||
|         self.norm_topk_prob = config.norm_topk_prob | ||||
|         self.gate = nn.Linear(config.hidden_size, self.num_experts, bias=False) | ||||
|         self.experts = OlmoeExperts(config) | ||||
|  | ||||
|     def route_tokens_to_experts(self, hidden_states, router_logits): | ||||
|         routing_weights = torch.nn.functional.softmax(router_logits.float(), dim=-1) | ||||
|         top_k_weights, top_k_index = torch.topk(routing_weights, self.top_k, dim=-1) | ||||
|         if self.norm_topk_prob: | ||||
|             top_k_weights /= top_k_weights.sum(dim=-1, keepdim=True) | ||||
|         top_k_weights = top_k_weights.to(hidden_states.dtype) | ||||
|         return top_k_index, top_k_weights | ||||
|  | ||||
|     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: | ||||
|         batch_size, sequence_length, hidden_dim = hidden_states.shape | ||||
|         hidden_states = hidden_states.view(-1, hidden_dim) | ||||
|         top_k_weights, top_k_index = self.gate(hidden_states) | ||||
|         router_logits = self.gate(hidden_states) | ||||
|         top_k_index, top_k_weights = self.route_tokens_to_experts(hidden_states, router_logits) | ||||
|         final_hidden_states = self.experts(hidden_states, top_k_index, top_k_weights).reshape( | ||||
|             batch_size, sequence_length, hidden_dim | ||||
|         ) | ||||
| @ -160,7 +173,7 @@ class OlmoePreTrainedModel(PreTrainedModel): | ||||
|     _supports_flash_attn = True | ||||
|     _supports_sdpa = True | ||||
|     _can_record_outputs = { | ||||
|         "router_logits": OutputRecorder(OlmoeTopKRouter, index=0), | ||||
|         "router_logits": OutputRecorder(nn.Linear, layer_name="mlp.gate", index=0), | ||||
|         "hidden_states": OlmoeDecoderLayer, | ||||
|         "attentions": OlmoeAttention, | ||||
|     } | ||||
|  | ||||
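As a reading aid for the `route_tokens_to_experts` helpers introduced above: the routing is a plain softmax over all experts, a per-token top-k selection, and an optional renormalization of the kept weights. A minimal standalone sketch (tensor sizes and `norm_topk_prob=True` are assumptions for illustration, not values from this diff):

```python
import torch

torch.manual_seed(0)

num_tokens, num_experts, top_k = 5, 8, 2  # assumed sizes, for illustration only
router_logits = torch.randn(num_tokens, num_experts)

# softmax over every expert, then keep the top-k probabilities per token
routing_weights = torch.nn.functional.softmax(router_logits.float(), dim=-1)
top_k_weights, top_k_index = torch.topk(routing_weights, top_k, dim=-1)

# optional renormalization (norm_topk_prob) so the kept weights sum to 1 per token
top_k_weights = top_k_weights / top_k_weights.sum(dim=-1, keepdim=True)

print(top_k_index.shape)      # torch.Size([5, 2]) -> expert ids per token
print(top_k_weights.sum(-1))  # tensor of ones after renormalization
```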
| @ -262,6 +262,24 @@ class PhimoeAttention(nn.Module): | ||||
|         return attn_output, attn_weights | ||||
|  | ||||
|  | ||||
| class PhimoeMLP(nn.Module): | ||||
|     def __init__(self, config: PhimoeConfig): | ||||
|         super().__init__() | ||||
|         self.ffn_dim = config.intermediate_size | ||||
|         self.hidden_dim = config.hidden_size | ||||
|  | ||||
|         self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) | ||||
|         self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False) | ||||
|         self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) | ||||
|  | ||||
|         self.act_fn = ACT2FN[config.hidden_act] | ||||
|  | ||||
|     def forward(self, hidden_states): | ||||
|         current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states) | ||||
|         current_hidden_states = self.w2(current_hidden_states) | ||||
|         return current_hidden_states | ||||
|  | ||||
|  | ||||
| class PhimoeMultiplier(torch.autograd.Function): | ||||
|     @staticmethod | ||||
|     def forward( | ||||
| @ -324,47 +342,58 @@ class PhimoeMultiplier(torch.autograd.Function): | ||||
|         ) | ||||
|  | ||||
|  | ||||
| class PhimoeExperts(nn.Module): | ||||
|     """Collection of expert weights stored as 3D tensors.""" | ||||
| class PhimoeExperts(nn.ModuleList): | ||||
|     """ | ||||
|     ModuleList of experts. | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, config: PhimoeConfig): | ||||
|         super().__init__() | ||||
|         self.top_k = config.num_experts_per_tok | ||||
|         self.num_experts = config.num_local_experts | ||||
|         self.hidden_dim = config.hidden_size | ||||
|         self.intermediate_dim = config.intermediate_size | ||||
|         self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim)) | ||||
|         self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim)) | ||||
|         self.act_fn = ACT2FN[config.hidden_act] | ||||
|         for _ in range(self.num_experts): | ||||
|             self.append(PhimoeMLP(config)) | ||||
|  | ||||
|     def forward( | ||||
|         self, | ||||
|         hidden_states: torch.Tensor, | ||||
|         top_k_index: torch.Tensor, | ||||
|         top_k_weights: torch.Tensor, | ||||
|         self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor | ||||
|     ) -> torch.Tensor: | ||||
|         """ | ||||
|         Args: | ||||
|             hidden_states: (batch_size * sequence_length, hidden_dim) | ||||
|             top_k_index: (batch_size * sequence_length, top_k) | ||||
|             top_k_weights: (batch_size * sequence_length, top_k) | ||||
|         Returns: | ||||
|             (batch_size * sequence_length, hidden_dim) | ||||
|         """ | ||||
|         final_hidden_states = torch.zeros_like(hidden_states) | ||||
|         expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0) | ||||
|  | ||||
|         num_experts = self.num_experts  # top_k_weights has width top_k, not the number of experts | ||||
|         expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1).permute(2, 1, 0) | ||||
|         expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() | ||||
|         for expert_idx in expert_hit: | ||||
|             expert_idx = expert_idx[0] | ||||
|             if expert_idx == num_experts: | ||||
|                 continue | ||||
|             with torch.no_grad(): | ||||
|                 _, token_idx = torch.where(expert_mask[expert_idx]) | ||||
|             current_state = hidden_states[token_idx] | ||||
|             gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1) | ||||
|             current_hidden_states = self.act_fn(gate) * up | ||||
|             current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx]) | ||||
|  | ||||
|             routing_weights = top_k_weights[token_idx, expert_idx].unsqueeze(-1) | ||||
|             current_hidden_states = current_hidden_states * routing_weights.to(current_hidden_states.dtype) | ||||
|             final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype)) | ||||
|  | ||||
|             idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0)) | ||||
|             current_state = hidden_states[None, top_x].reshape(-1, hidden_states.shape[-1]) | ||||
|             current_hidden_states = self[expert_idx](current_state) * top_k_weights[top_x, idx, None] | ||||
|             final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) | ||||
|         return final_hidden_states | ||||
|  | ||||
|  | ||||
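The `PhimoeExperts.forward` above (and its Qwen/Olmoe counterparts) dispatches tokens with a one-hot expert mask instead of looping over tokens. A small sketch of just that indexing step, under assumed sizes, showing how `torch.where` recovers which tokens an expert has to process and through which top-k slot:

```python
import torch

num_tokens, num_experts, top_k = 4, 6, 2  # assumed sizes for illustration
top_k_index = torch.tensor([[0, 3], [3, 5], [1, 0], [3, 2]])  # routed expert ids per token

# (num_tokens, top_k, num_experts) -> (num_experts, top_k, num_tokens)
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts).permute(2, 1, 0)

expert_idx = 3
idx, top_x = torch.where(expert_mask[expert_idx])
print(top_x.tolist())  # [1, 3, 0] -> tokens routed to expert 3
print(idx.tolist())    # [0, 0, 1] -> the top-k slot that selected it (used to look up the weight)
```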
| class PhimoeRouter(nn.Linear): | ||||
|     def __init__(self, config: PhimoeConfig): | ||||
|         super().__init__(config.hidden_size, config.num_local_experts, bias=False) | ||||
|         self.top_k = config.num_experts_per_tok | ||||
|         self.hidden_dim = config.hidden_size | ||||
|         self.router_jitter_noise = config.router_jitter_noise | ||||
|         self.input_jitter_noise = config.input_jitter_noise | ||||
|  | ||||
|     def forward(self, hidden_states): | ||||
|         if self.training and self.input_jitter_noise > 0: | ||||
|             hidden_states *= torch.empty_like(hidden_states).uniform_( | ||||
|                 1.0 - self.input_jitter_noise, 1.0 + self.input_jitter_noise | ||||
|             ) | ||||
|         router_logits = super().forward(hidden_states) | ||||
|         return router_logits | ||||
|  | ||||
|  | ||||
| def sparsemixer(scores, jitter_eps, training, top_k=2): | ||||
|     """ | ||||
|     Sparse mixer function to select top-k experts and compute multipliers. | ||||
| @ -488,27 +517,6 @@ def sparsemixer(scores, jitter_eps, training, top_k=2): | ||||
|     ) | ||||
|  | ||||
|  | ||||
| class PhimoeTopKRouter(nn.Linear): | ||||
|     def __init__(self, config: PhimoeConfig): | ||||
|         super().__init__(config.hidden_size, config.num_local_experts, bias=False) | ||||
|         self.router_jitter_noise = config.router_jitter_noise | ||||
|         self.input_jitter_noise = config.input_jitter_noise | ||||
|  | ||||
|     def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: | ||||
|         if self.training and self.input_jitter_noise > 0: | ||||
|             hidden_states *= torch.empty_like(hidden_states).uniform_( | ||||
|                 1.0 - self.input_jitter_noise, 1.0 + self.input_jitter_noise | ||||
|             ) | ||||
|         router_logits = super().forward(hidden_states) | ||||
|         routing_weights, selected_experts = sparsemixer( | ||||
|             router_logits, | ||||
|             jitter_eps=self.router_jitter_noise, | ||||
|             training=self.training, | ||||
|         ) | ||||
|         routing_weights = torch.zeros_like(router_logits).scatter_(1, selected_experts, routing_weights) | ||||
|         return routing_weights, selected_experts | ||||
|  | ||||
|  | ||||
| class PhimoeSparseMoeBlock(nn.Module): | ||||
|     """ | ||||
|     This implementation is | ||||
| @ -527,10 +535,19 @@ class PhimoeSparseMoeBlock(nn.Module): | ||||
|         self.ffn_dim = config.intermediate_size | ||||
|         self.num_experts = config.num_local_experts | ||||
|         self.top_k = config.num_experts_per_tok | ||||
|         self.router = PhimoeTopKRouter(config) | ||||
|         self.router_jitter_noise = config.router_jitter_noise | ||||
|         self.gate = PhimoeRouter(config) | ||||
|         self.experts = PhimoeExperts(config) | ||||
|         self.input_jitter_noise = config.input_jitter_noise | ||||
|  | ||||
|     def route_tokens_to_experts(self, router_logits): | ||||
|         routing_weights, selected_experts = sparsemixer( | ||||
|             router_logits, | ||||
|             jitter_eps=self.router_jitter_noise, | ||||
|             training=self.training, | ||||
|         ) | ||||
|         return routing_weights, selected_experts | ||||
|  | ||||
|     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: | ||||
|         batch_size, sequence_length, hidden_dim = hidden_states.shape | ||||
|         if self.training and self.input_jitter_noise > 0: | ||||
| @ -540,7 +557,8 @@ class PhimoeSparseMoeBlock(nn.Module): | ||||
|  | ||||
|         batch_size, sequence_length, hidden_dim = hidden_states.shape | ||||
|         hidden_states = hidden_states.reshape(-1, hidden_dim) | ||||
|         routing_weights, selected_experts = self.router(hidden_states) | ||||
|         router_logits = self.gate(hidden_states) | ||||
|         routing_weights, selected_experts = self.route_tokens_to_experts(router_logits) | ||||
|         final_hidden_states = self.experts(hidden_states, selected_experts, routing_weights) | ||||
|         return final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) | ||||
|  | ||||
| @ -573,7 +591,7 @@ class PhimoeDecoderLayer(GradientCheckpointingLayer): | ||||
|  | ||||
|         self.self_attn = PhimoeAttention(config, layer_idx) | ||||
|  | ||||
|         self.mlp = PhimoeSparseMoeBlock(config) | ||||
|         self.block_sparse_moe = PhimoeSparseMoeBlock(config) | ||||
|         self.input_layernorm = PhimoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps) | ||||
|         self.post_attention_layernorm = PhimoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps) | ||||
|  | ||||
| @ -601,7 +619,7 @@ class PhimoeDecoderLayer(GradientCheckpointingLayer): | ||||
|         hidden_states = residual + hidden_states | ||||
|         residual = hidden_states | ||||
|         hidden_states = self.post_attention_layernorm(hidden_states) | ||||
|         hidden_states = self.mlp(hidden_states) | ||||
|         hidden_states = self.block_sparse_moe(hidden_states) | ||||
|         hidden_states = residual + hidden_states | ||||
|         return hidden_states | ||||
|  | ||||
| @ -619,7 +637,7 @@ class PhimoePreTrainedModel(PreTrainedModel): | ||||
|     _can_compile_fullgraph = False  # MoE models don't work with torch.compile (`torch.where(condition)` not supported) | ||||
|     _supports_attention_backend = True | ||||
|     _can_record_outputs = { | ||||
|         "router_logits": OutputRecorder(PhimoeTopKRouter, layer_name="mlp.router", index=0), | ||||
|         "router_logits": OutputRecorder(nn.Linear, layer_name="block_sparse_moe.gate", index=0), | ||||
|         "hidden_states": PhimoeDecoderLayer, | ||||
|         "attentions": PhimoeAttention, | ||||
|     } | ||||
|  | ||||
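The input jitter that `PhimoeSparseMoeBlock` applies during training is per-element multiplicative noise drawn uniformly around 1. A minimal sketch of the same operation (the noise level and shapes are made-up example values):

```python
import torch

input_jitter_noise = 0.01             # assumed example value
hidden_states = torch.randn(2, 3, 8)  # (batch, seq, hidden), sizes for illustration

# scale every element by a factor sampled uniformly from [1 - eps, 1 + eps]
noise = torch.empty_like(hidden_states).uniform_(1.0 - input_jitter_noise, 1.0 + input_jitter_noise)
hidden_states = hidden_states * noise
```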
| @ -30,6 +30,7 @@ from ..mixtral.modeling_mixtral import ( | ||||
|     MixtralDecoderLayer, | ||||
|     MixtralExperts, | ||||
|     MixtralForCausalLM, | ||||
|     MixtralMLP, | ||||
|     MixtralModel, | ||||
|     MixtralPreTrainedModel, | ||||
|     MixtralRotaryEmbedding, | ||||
| @ -86,6 +87,10 @@ class PhimoeAttention(LlamaAttention): | ||||
|     pass | ||||
|  | ||||
|  | ||||
| class PhimoeMLP(MixtralMLP): | ||||
|     pass | ||||
|  | ||||
|  | ||||
| class PhimoeMultiplier(torch.autograd.Function): | ||||
|     @staticmethod | ||||
|     def forward( | ||||
| @ -271,29 +276,30 @@ def sparsemixer(scores, jitter_eps, training, top_k=2): | ||||
|     ) | ||||
|  | ||||
|  | ||||
| class PhimoeExperts(MixtralExperts): | ||||
|     pass | ||||
| class PhimoeExperts(MixtralExperts, nn.ModuleList): | ||||
|     def __init__(self, config: PhimoeConfig): | ||||
|         nn.ModuleList.__init__(self) | ||||
|         self.top_k = config.num_experts_per_tok | ||||
|         self.num_experts = config.num_local_experts | ||||
|         for _ in range(self.num_experts): | ||||
|             self.append(PhimoeMLP(config)) | ||||
|  | ||||
|  | ||||
| class PhimoeTopKRouter(nn.Linear): | ||||
| class PhimoeRouter(nn.Linear): | ||||
|     def __init__(self, config: PhimoeConfig): | ||||
|         super().__init__(config.hidden_size, config.num_local_experts, bias=False) | ||||
|         self.top_k = config.num_experts_per_tok | ||||
|         self.hidden_dim = config.hidden_size | ||||
|         self.router_jitter_noise = config.router_jitter_noise | ||||
|         self.input_jitter_noise = config.input_jitter_noise | ||||
|         self.input_jitter_noise = config.input_jitter_noise | ||||
|  | ||||
|     def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: | ||||
|     def forward(self, hidden_states): | ||||
|         if self.training and self.input_jitter_noise > 0: | ||||
|             hidden_states *= torch.empty_like(hidden_states).uniform_( | ||||
|                 1.0 - self.input_jitter_noise, 1.0 + self.input_jitter_noise | ||||
|             ) | ||||
|         router_logits = super().forward(hidden_states) | ||||
|         routing_weights, selected_experts = sparsemixer( | ||||
|             router_logits, | ||||
|             jitter_eps=self.router_jitter_noise, | ||||
|             training=self.training, | ||||
|         ) | ||||
|         routing_weights = torch.zeros_like(router_logits).scatter_(1, selected_experts, routing_weights) | ||||
|         return routing_weights, selected_experts | ||||
|         return router_logits | ||||
|  | ||||
|  | ||||
| class PhimoeSparseMoeBlock(nn.Module): | ||||
| @ -314,10 +320,19 @@ class PhimoeSparseMoeBlock(nn.Module): | ||||
|         self.ffn_dim = config.intermediate_size | ||||
|         self.num_experts = config.num_local_experts | ||||
|         self.top_k = config.num_experts_per_tok | ||||
|         self.router = PhimoeTopKRouter(config) | ||||
|         self.router_jitter_noise = config.router_jitter_noise | ||||
|         self.gate = PhimoeRouter(config) | ||||
|         self.experts = PhimoeExperts(config) | ||||
|         self.input_jitter_noise = config.input_jitter_noise | ||||
|  | ||||
|     def route_tokens_to_experts(self, router_logits): | ||||
|         routing_weights, selected_experts = sparsemixer( | ||||
|             router_logits, | ||||
|             jitter_eps=self.router_jitter_noise, | ||||
|             training=self.training, | ||||
|         ) | ||||
|         return routing_weights, selected_experts | ||||
|  | ||||
|     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: | ||||
|         batch_size, sequence_length, hidden_dim = hidden_states.shape | ||||
|         if self.training and self.input_jitter_noise > 0: | ||||
| @ -327,7 +342,8 @@ class PhimoeSparseMoeBlock(nn.Module): | ||||
|  | ||||
|         batch_size, sequence_length, hidden_dim = hidden_states.shape | ||||
|         hidden_states = hidden_states.reshape(-1, hidden_dim) | ||||
|         routing_weights, selected_experts = self.router(hidden_states) | ||||
|         router_logits = self.gate(hidden_states) | ||||
|         routing_weights, selected_experts = self.route_tokens_to_experts(router_logits) | ||||
|         final_hidden_states = self.experts(hidden_states, selected_experts, routing_weights) | ||||
|         return final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) | ||||
|  | ||||
| @ -338,7 +354,7 @@ class PhimoeDecoderLayer(MixtralDecoderLayer): | ||||
|  | ||||
| class PhimoePreTrainedModel(MixtralPreTrainedModel): | ||||
|     _can_record_outputs = { | ||||
|         "router_logits": OutputRecorder(PhimoeTopKRouter, layer_name="mlp.router", index=0), | ||||
|         "router_logits": OutputRecorder(nn.Linear, layer_name="block_sparse_moe.gate", index=0), | ||||
|         "hidden_states": PhimoeDecoderLayer, | ||||
|         "attentions": PhimoeAttention, | ||||
|     } | ||||
|  | ||||
| @ -289,81 +289,66 @@ class Qwen2MoeAttention(nn.Module): | ||||
|         return attn_output, attn_weights | ||||
|  | ||||
|  | ||||
| class Qwen2MoeExperts(nn.Module): | ||||
|     """Collection of expert weights stored as 3D tensors.""" | ||||
| class Qwen2MoeExperts(nn.ModuleList): | ||||
|     """ | ||||
|     ModuleList of experts. | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, config): | ||||
|         super().__init__() | ||||
|         self.num_experts = config.num_experts | ||||
|         self.hidden_dim = config.hidden_size | ||||
|         self.intermediate_dim = config.moe_intermediate_size | ||||
|         self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim)) | ||||
|         self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim)) | ||||
|         self.act_fn = ACT2FN[config.hidden_act] | ||||
|         for _ in range(config.num_experts): | ||||
|             self.append(Qwen2MoeMLP(config, intermediate_size=config.moe_intermediate_size)) | ||||
|  | ||||
|     def forward( | ||||
|         self, | ||||
|         hidden_states: torch.Tensor, | ||||
|         top_k_index: torch.Tensor, | ||||
|         top_k_weights: torch.Tensor, | ||||
|         self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor | ||||
|     ) -> torch.Tensor: | ||||
|         """ | ||||
|         Args: | ||||
|             hidden_states: (batch_size * sequence_length, hidden_dim) | ||||
|             top_k_index: (batch_size * sequence_length, top_k) | ||||
|             top_k_weights: (batch_size * sequence_length, top_k) | ||||
|         Returns: | ||||
|             (batch_size * sequence_length, hidden_dim) | ||||
|         """ | ||||
|         final_hidden_states = torch.zeros_like(hidden_states) | ||||
|         expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0) | ||||
|  | ||||
|         num_experts = self.num_experts | ||||
|         expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1).permute(2, 1, 0) | ||||
|         expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() | ||||
|         for expert_idx in expert_hit: | ||||
|             expert_idx = expert_idx[0] | ||||
|             if expert_idx == num_experts: | ||||
|                 continue | ||||
|             with torch.no_grad(): | ||||
|                 _, token_idx = torch.where(expert_mask[expert_idx]) | ||||
|             current_state = hidden_states[token_idx] | ||||
|             gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1) | ||||
|             current_hidden_states = self.act_fn(gate) * up | ||||
|             current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx]) | ||||
|  | ||||
|             routing_weights = top_k_weights[token_idx, expert_idx].unsqueeze(-1) | ||||
|             current_hidden_states = current_hidden_states * routing_weights.to(current_hidden_states.dtype) | ||||
|             final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype)) | ||||
|  | ||||
|             idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0)) | ||||
|             current_state = hidden_states[None, top_x].reshape(-1, hidden_states.shape[-1]) | ||||
|             current_hidden_states = self[expert_idx](current_state) * top_k_weights[top_x, idx, None] | ||||
|             final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) | ||||
|         return final_hidden_states | ||||
|  | ||||
|  | ||||
| class Qwen2MoeTopKRouter(nn.Module): | ||||
|     def __init__(self, config): | ||||
|         super().__init__() | ||||
|         self.top_k = config.num_experts_per_tok | ||||
|         self.num_experts = config.num_experts | ||||
|         self.norm_topk_prob = config.norm_topk_prob | ||||
|         self.hidden_dim = config.hidden_size | ||||
|         self.weight = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim)) | ||||
|  | ||||
|     def forward(self, hidden_states): | ||||
|         hidden_states = hidden_states.reshape(-1, self.hidden_dim) | ||||
|         router_logits = F.linear(hidden_states, self.weight)  # (seq_len, num_experts) | ||||
|         router_logits = torch.nn.functional.softmax(router_logits, dtype=torch.float, dim=-1) | ||||
|         router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1)  # (seq_len, top_k) | ||||
|         if self.norm_topk_prob: | ||||
|             router_top_value /= router_top_value.sum(dim=-1, keepdim=True) | ||||
|         router_top_value = router_top_value.to(router_logits.dtype) | ||||
|         router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value) | ||||
|         return router_scores, router_indices | ||||
|  | ||||
|  | ||||
| class Qwen2MoeSparseMoeBlock(nn.Module): | ||||
|     def __init__(self, config): | ||||
|         super().__init__() | ||||
|         self.gate = Qwen2MoeTopKRouter(config) | ||||
|         # gating | ||||
|         self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False) | ||||
|         self.experts = Qwen2MoeExperts(config) | ||||
|         self.num_experts_per_tok = config.num_experts_per_tok | ||||
|         self.norm_topk_prob = config.norm_topk_prob | ||||
|  | ||||
|         self.shared_expert = Qwen2MoeMLP(config, intermediate_size=config.shared_expert_intermediate_size) | ||||
|         self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False) | ||||
|  | ||||
|     def route_tokens_to_experts(self, hidden_states, router_logits): | ||||
|         routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float) | ||||
|         routing_weights, selected_experts = torch.topk(routing_weights, self.num_experts_per_tok, dim=-1) | ||||
|         if self.norm_topk_prob: | ||||
|             routing_weights /= routing_weights.sum(dim=-1, keepdim=True) | ||||
|         routing_weights = routing_weights.to(router_logits.dtype) | ||||
|         return selected_experts, routing_weights | ||||
|  | ||||
|     def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: | ||||
|         batch_size, sequence_length, hidden_dim = hidden_states.shape | ||||
|         hidden_states_reshaped = hidden_states.view(-1, hidden_dim) | ||||
|         shared_expert_output = self.shared_expert(hidden_states_reshaped) | ||||
|         routing_weights, selected_experts = self.gate(hidden_states_reshaped) | ||||
|         router_logits = self.gate(hidden_states_reshaped) | ||||
|         selected_experts, routing_weights = self.route_tokens_to_experts(hidden_states_reshaped, router_logits) | ||||
|         expert_output = self.experts(hidden_states_reshaped, selected_experts, routing_weights) | ||||
|  | ||||
|         shared_expert_output = F.sigmoid(self.shared_expert_gate(hidden_states_reshaped)) * shared_expert_output | ||||
| @ -434,7 +419,7 @@ class Qwen2MoePreTrainedModel(PreTrainedModel): | ||||
|     _can_compile_fullgraph = False  # MoE models don't work with torch.compile (`torch.where(condition)` not supported) | ||||
|     _supports_attention_backend = True | ||||
|     _can_record_outputs = { | ||||
|         "router_logits": OutputRecorder(Qwen2MoeTopKRouter, index=0), | ||||
|         "router_logits": OutputRecorder(nn.Linear, layer_name="mlp.gate", index=0), | ||||
|         "hidden_states": Qwen2MoeDecoderLayer, | ||||
|         "attentions": Qwen2MoeAttention, | ||||
|     } | ||||
|  | ||||
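In `Qwen2MoeSparseMoeBlock` (and the Qwen3-Next / Omni talker variants below), the routed expert output is combined with a shared expert whose contribution is scaled by a learned sigmoid gate. A standalone sketch of that combination with stand-in modules and assumed sizes; the final sum follows the usual Qwen2-MoE formulation rather than code shown in this hunk:

```python
import torch
import torch.nn.functional as F
from torch import nn

hidden_size, num_tokens = 16, 5                       # assumed sizes
hidden_states = torch.randn(num_tokens, hidden_size)

shared_expert = nn.Linear(hidden_size, hidden_size)   # stand-in for the shared Qwen2MoeMLP
shared_expert_gate = nn.Linear(hidden_size, 1, bias=False)

expert_output = torch.randn(num_tokens, hidden_size)  # pretend output of the routed experts

shared_expert_output = F.sigmoid(shared_expert_gate(hidden_states)) * shared_expert(hidden_states)
final_hidden_states = expert_output + shared_expert_output
```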
| @ -82,47 +82,40 @@ class Qwen2MoeAttention(LlamaAttention): | ||||
|         self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) | ||||
|  | ||||
|  | ||||
| class Qwen2MoeExperts(MixtralExperts): | ||||
| class Qwen2MoeExperts(MixtralExperts, nn.ModuleList): | ||||
|     def __init__(self, config): | ||||
|         super().__init__(config) | ||||
|         nn.ModuleList.__init__(self) | ||||
|         self.num_experts = config.num_experts | ||||
|         self.intermediate_dim = config.moe_intermediate_size | ||||
|  | ||||
|  | ||||
| class Qwen2MoeTopKRouter(nn.Module): | ||||
|     def __init__(self, config): | ||||
|         super().__init__() | ||||
|         self.top_k = config.num_experts_per_tok | ||||
|         self.num_experts = config.num_experts | ||||
|         self.norm_topk_prob = config.norm_topk_prob | ||||
|         self.hidden_dim = config.hidden_size | ||||
|         self.weight = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim)) | ||||
|  | ||||
|     def forward(self, hidden_states): | ||||
|         hidden_states = hidden_states.reshape(-1, self.hidden_dim) | ||||
|         router_logits = F.linear(hidden_states, self.weight)  # (seq_len, num_experts) | ||||
|         router_logits = torch.nn.functional.softmax(router_logits, dtype=torch.float, dim=-1) | ||||
|         router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1)  # (seq_len, top_k) | ||||
|         if self.norm_topk_prob: | ||||
|             router_top_value /= router_top_value.sum(dim=-1, keepdim=True) | ||||
|         router_top_value = router_top_value.to(router_logits.dtype) | ||||
|         router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value) | ||||
|         return router_scores, router_indices | ||||
|         for _ in range(config.num_experts): | ||||
|             self.append(Qwen2MoeMLP(config, intermediate_size=config.moe_intermediate_size)) | ||||
|  | ||||
|  | ||||
| class Qwen2MoeSparseMoeBlock(nn.Module): | ||||
|     def __init__(self, config): | ||||
|         super().__init__() | ||||
|         self.gate = Qwen2MoeTopKRouter(config) | ||||
|         # gating | ||||
|         self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False) | ||||
|         self.experts = Qwen2MoeExperts(config) | ||||
|         self.num_experts_per_tok = config.num_experts_per_tok | ||||
|         self.norm_topk_prob = config.norm_topk_prob | ||||
|  | ||||
|         self.shared_expert = Qwen2MoeMLP(config, intermediate_size=config.shared_expert_intermediate_size) | ||||
|         self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False) | ||||
|  | ||||
|     def route_tokens_to_experts(self, hidden_states, router_logits): | ||||
|         routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float) | ||||
|         routing_weights, selected_experts = torch.topk(routing_weights, self.num_experts_per_tok, dim=-1) | ||||
|         if self.norm_topk_prob: | ||||
|             routing_weights /= routing_weights.sum(dim=-1, keepdim=True) | ||||
|         routing_weights = routing_weights.to(router_logits.dtype) | ||||
|         return selected_experts, routing_weights | ||||
|  | ||||
|     def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: | ||||
|         batch_size, sequence_length, hidden_dim = hidden_states.shape | ||||
|         hidden_states_reshaped = hidden_states.view(-1, hidden_dim) | ||||
|         shared_expert_output = self.shared_expert(hidden_states_reshaped) | ||||
|         routing_weights, selected_experts = self.gate(hidden_states_reshaped) | ||||
|         router_logits = self.gate(hidden_states_reshaped) | ||||
|         selected_experts, routing_weights = self.route_tokens_to_experts(hidden_states_reshaped, router_logits) | ||||
|         expert_output = self.experts(hidden_states_reshaped, selected_experts, routing_weights) | ||||
|  | ||||
|         shared_expert_output = F.sigmoid(self.shared_expert_gate(hidden_states_reshaped)) * shared_expert_output | ||||
| @ -150,7 +143,7 @@ class Qwen2MoeDecoderLayer(LlamaDecoderLayer, nn.Module): | ||||
| @auto_docstring | ||||
| class Qwen2MoePreTrainedModel(MixtralPreTrainedModel): | ||||
|     _can_record_outputs = { | ||||
|         "router_logits": OutputRecorder(Qwen2MoeTopKRouter, index=0), | ||||
|         "router_logits": OutputRecorder(nn.Linear, layer_name="mlp.gate", index=0), | ||||
|         "hidden_states": Qwen2MoeDecoderLayer, | ||||
|         "attentions": Qwen2MoeAttention, | ||||
|     } | ||||
|  | ||||
| @ -209,78 +209,61 @@ class Qwen3MoeMLP(nn.Module): | ||||
|         return down_proj | ||||
|  | ||||
|  | ||||
| class Qwen3Moe(nn.Module): | ||||
|     """Collection of expert weights stored as 3D tensors.""" | ||||
| class Qwen3MoeExperts(nn.ModuleList): | ||||
|     """ | ||||
|     ModuleList of experts. | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, config): | ||||
|     def __init__(self, config: Qwen3MoeConfig): | ||||
|         super().__init__() | ||||
|         self.num_experts = config.num_experts | ||||
|         self.hidden_dim = config.hidden_size | ||||
|         self.intermediate_dim = config.moe_intermediate_size | ||||
|         self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim)) | ||||
|         self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim)) | ||||
|         self.act_fn = ACT2FN[config.hidden_act] | ||||
|         for _ in range(self.num_experts): | ||||
|             self.append(Qwen3MoeMLP(config, intermediate_size=config.moe_intermediate_size)) | ||||
|  | ||||
|     def forward( | ||||
|         self, | ||||
|         hidden_states: torch.Tensor, | ||||
|         top_k_index: torch.Tensor, | ||||
|         top_k_weights: torch.Tensor, | ||||
|         self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor | ||||
|     ) -> torch.Tensor: | ||||
|         """ | ||||
|         Args: | ||||
|             hidden_states: (batch_size * sequence_length, hidden_dim) | ||||
|             top_k_index: (batch_size * sequence_length, top_k) | ||||
|             top_k_weights: (batch_size * sequence_length, top_k) | ||||
|         Returns: | ||||
|             (batch_size * sequence_length, hidden_dim) | ||||
|         """ | ||||
|         final_hidden_states = torch.zeros_like(hidden_states) | ||||
|         expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0) | ||||
|  | ||||
|         num_experts = self.num_experts | ||||
|         expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1).permute(2, 1, 0) | ||||
|         expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() | ||||
|         for expert_idx in expert_hit: | ||||
|             expert_idx = expert_idx[0] | ||||
|             if expert_idx == num_experts: | ||||
|                 continue | ||||
|             with torch.no_grad(): | ||||
|                 _, token_idx = torch.where(expert_mask[expert_idx]) | ||||
|             current_state = hidden_states[token_idx] | ||||
|             gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1) | ||||
|             current_hidden_states = self.act_fn(gate) * up | ||||
|             current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx]) | ||||
|  | ||||
|             routing_weights = top_k_weights[token_idx, expert_idx].unsqueeze(-1) | ||||
|             current_hidden_states = current_hidden_states * routing_weights.to(current_hidden_states.dtype) | ||||
|             final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype)) | ||||
|  | ||||
|             idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0)) | ||||
|             current_state = hidden_states[None, top_x].reshape(-1, hidden_states.shape[-1]) | ||||
|             current_hidden_states = self[expert_idx](current_state) * top_k_weights[top_x, idx, None] | ||||
|             final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) | ||||
|         return final_hidden_states | ||||
|  | ||||
|  | ||||
| class Qwen3MoeTopKRouter(nn.Module): | ||||
|     def __init__(self, config): | ||||
|         super().__init__() | ||||
|         self.top_k = config.num_experts_per_tok | ||||
|         self.num_experts = config.num_experts | ||||
|         self.norm_topk_prob = config.norm_topk_prob | ||||
|         self.hidden_dim = config.hidden_size | ||||
|         self.weight = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim)) | ||||
|  | ||||
|     def forward(self, hidden_states): | ||||
|         hidden_states = hidden_states.reshape(-1, self.hidden_dim) | ||||
|         router_logits = F.linear(hidden_states, self.weight)  # (seq_len, num_experts) | ||||
|         router_logits = torch.nn.functional.softmax(router_logits, dtype=torch.float, dim=-1) | ||||
|         router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1)  # (seq_len, top_k) | ||||
|         if self.norm_topk_prob: | ||||
|             router_top_value /= router_top_value.sum(dim=-1, keepdim=True) | ||||
|         router_top_value = router_top_value.to(router_logits.dtype) | ||||
|         router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value) | ||||
|         return router_scores, router_indices | ||||
|  | ||||
|  | ||||
| class Qwen3MoeSparseMoeBlock(nn.Module): | ||||
|     def __init__(self, config: Qwen3MoeConfig): | ||||
|         super().__init__() | ||||
|         self.experts = Qwen3Moe(config) | ||||
|         self.router = Qwen3MoeTopKRouter(config) | ||||
|         self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False) | ||||
|         self.experts = Qwen3MoeExperts(config) | ||||
|         self.num_experts_per_tok = config.num_experts_per_tok | ||||
|         self.norm_topk_prob = config.norm_topk_prob | ||||
|  | ||||
|     def route_tokens_to_experts(self, hidden_states, router_logits): | ||||
|         routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float) | ||||
|         routing_weights, selected_experts = torch.topk(routing_weights, self.num_experts_per_tok, dim=-1) | ||||
|         if self.norm_topk_prob: | ||||
|             routing_weights /= routing_weights.sum(dim=-1, keepdim=True) | ||||
|         routing_weights = routing_weights.to(router_logits.dtype) | ||||
|         return selected_experts, routing_weights | ||||
|  | ||||
|     def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: | ||||
|         batch_size, sequence_length, hidden_dim = hidden_states.shape | ||||
|         hidden_states_reshaped = hidden_states.view(-1, hidden_dim) | ||||
|         routing_weights, selected_experts = self.router(hidden_states_reshaped) | ||||
|         router_logits = self.gate(hidden_states_reshaped) | ||||
|         selected_experts, routing_weights = self.route_tokens_to_experts(hidden_states_reshaped, router_logits) | ||||
|         final_hidden_states = self.experts(hidden_states_reshaped, selected_experts, routing_weights) | ||||
|         return final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) | ||||
|  | ||||
| @ -367,7 +350,7 @@ class Qwen3MoePreTrainedModel(PreTrainedModel): | ||||
|     _can_compile_fullgraph = False  # MoE models don't work with torch.compile (`torch.where(condition)` not supported) | ||||
|     _supports_attention_backend = True | ||||
|     _can_record_outputs = { | ||||
|         "router_logits": OutputRecorder(Qwen3MoeTopKRouter, layer_name="mlp.router", index=0), | ||||
|         "router_logits": OutputRecorder(nn.Linear, layer_name="mlp.gate", index=0), | ||||
|         "hidden_states": Qwen3MoeDecoderLayer, | ||||
|         "attentions": Qwen3MoeAttention, | ||||
|     } | ||||
|  | ||||
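Because `_can_record_outputs` now points at the `mlp.gate` linear, router logits are still surfaced through the usual `output_router_logits` flag. A rough usage sketch with a deliberately tiny, randomly initialized configuration (every size below is an assumption for illustration, not a real checkpoint value):

```python
import torch
from transformers import Qwen3MoeConfig, Qwen3MoeForCausalLM

config = Qwen3MoeConfig(
    vocab_size=128,
    hidden_size=32,
    intermediate_size=64,
    moe_intermediate_size=16,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=2,
    head_dim=8,
    num_experts=8,
    num_experts_per_tok=2,
)
model = Qwen3MoeForCausalLM(config).eval()

input_ids = torch.randint(0, config.vocab_size, (1, 6))
with torch.no_grad():
    out = model(input_ids, output_router_logits=True)

# one router-logits tensor per MoE layer, each of shape (batch * seq_len, num_experts)
print(len(out.router_logits), out.router_logits[0].shape)
```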
| @ -17,6 +17,7 @@ | ||||
| from typing import Optional, Union | ||||
|  | ||||
| import torch | ||||
| import torch.nn.functional as F | ||||
| from torch import nn | ||||
|  | ||||
| from ...cache_utils import Cache | ||||
| @ -31,12 +32,13 @@ from ..llama.modeling_llama import ( | ||||
|     LlamaRMSNorm, | ||||
| ) | ||||
| from ..mixtral.modeling_mixtral import ( | ||||
|     MixtralExperts, | ||||
|     MixtralForCausalLM, | ||||
|     MixtralModel, | ||||
|     MixtralPreTrainedModel, | ||||
|     load_balancing_loss_func, | ||||
| ) | ||||
| from ..qwen2_moe.modeling_qwen2_moe import Qwen2MoeDecoderLayer, Qwen2MoeExperts, Qwen2MoeMLP, Qwen2MoeTopKRouter | ||||
| from ..qwen2_moe.modeling_qwen2_moe import Qwen2MoeDecoderLayer, Qwen2MoeMLP | ||||
| from ..qwen3.modeling_qwen3 import Qwen3Attention | ||||
| from .configuration_qwen3_moe import Qwen3MoeConfig | ||||
|  | ||||
| @ -55,24 +57,35 @@ class Qwen3MoeMLP(Qwen2MoeMLP): | ||||
|     pass | ||||
|  | ||||
|  | ||||
| class Qwen3Moe(Qwen2MoeExperts): | ||||
|     pass | ||||
|  | ||||
|  | ||||
| class Qwen3MoeTopKRouter(Qwen2MoeTopKRouter): | ||||
|     pass | ||||
| class Qwen3MoeExperts(MixtralExperts, nn.ModuleList): | ||||
|     def __init__(self, config: Qwen3MoeConfig): | ||||
|         nn.ModuleList.__init__(self) | ||||
|         self.num_experts = config.num_experts | ||||
|         for _ in range(self.num_experts): | ||||
|             self.append(Qwen3MoeMLP(config, intermediate_size=config.moe_intermediate_size)) | ||||
|  | ||||
|  | ||||
| class Qwen3MoeSparseMoeBlock(nn.Module): | ||||
|     def __init__(self, config: Qwen3MoeConfig): | ||||
|         super().__init__() | ||||
|         self.experts = Qwen3Moe(config) | ||||
|         self.router = Qwen3MoeTopKRouter(config) | ||||
|         self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False) | ||||
|         self.experts = Qwen3MoeExperts(config) | ||||
|         self.num_experts_per_tok = config.num_experts_per_tok | ||||
|         self.norm_topk_prob = config.norm_topk_prob | ||||
|  | ||||
|     def route_tokens_to_experts(self, hidden_states, router_logits): | ||||
|         routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float) | ||||
|         routing_weights, selected_experts = torch.topk(routing_weights, self.num_experts_per_tok, dim=-1) | ||||
|         if self.norm_topk_prob: | ||||
|             routing_weights /= routing_weights.sum(dim=-1, keepdim=True) | ||||
|         routing_weights = routing_weights.to(router_logits.dtype) | ||||
|         return selected_experts, routing_weights | ||||
|  | ||||
|     def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: | ||||
|         batch_size, sequence_length, hidden_dim = hidden_states.shape | ||||
|         hidden_states_reshaped = hidden_states.view(-1, hidden_dim) | ||||
|         routing_weights, selected_experts = self.router(hidden_states_reshaped) | ||||
|         router_logits = self.gate(hidden_states_reshaped) | ||||
|         selected_experts, routing_weights = self.route_tokens_to_experts(hidden_states_reshaped, router_logits) | ||||
|         final_hidden_states = self.experts(hidden_states_reshaped, selected_experts, routing_weights) | ||||
|         return final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) | ||||
|  | ||||
| @ -87,7 +100,7 @@ class Qwen3MoeDecoderLayer(Qwen2MoeDecoderLayer): | ||||
|  | ||||
| class Qwen3MoePreTrainedModel(MixtralPreTrainedModel): | ||||
|     _can_record_outputs = { | ||||
|         "router_logits": OutputRecorder(Qwen3MoeTopKRouter, layer_name="mlp.router", index=0), | ||||
|         "router_logits": OutputRecorder(nn.Linear, layer_name="mlp.gate", index=0), | ||||
|         "hidden_states": Qwen3MoeDecoderLayer, | ||||
|         "attentions": Qwen3MoeAttention, | ||||
|     } | ||||
|  | ||||
| @ -819,81 +819,66 @@ class Qwen3NextMLP(nn.Module): | ||||
|         return down_proj | ||||
|  | ||||
|  | ||||
| class Qwen3NextExperts(nn.Module): | ||||
|     """Collection of expert weights stored as 3D tensors.""" | ||||
| class Qwen3NextExperts(nn.ModuleList): | ||||
|     """ | ||||
|     ModuleList of experts. | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, config): | ||||
|         super().__init__() | ||||
|         self.num_experts = config.num_experts | ||||
|         self.hidden_dim = config.hidden_size | ||||
|         self.intermediate_dim = config.moe_intermediate_size | ||||
|         self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim)) | ||||
|         self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim)) | ||||
|         self.act_fn = ACT2FN[config.hidden_act] | ||||
|         for _ in range(config.num_experts): | ||||
|             self.append(Qwen3NextMLP(config, intermediate_size=config.moe_intermediate_size)) | ||||
|  | ||||
|     def forward( | ||||
|         self, | ||||
|         hidden_states: torch.Tensor, | ||||
|         top_k_index: torch.Tensor, | ||||
|         top_k_weights: torch.Tensor, | ||||
|         self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor | ||||
|     ) -> torch.Tensor: | ||||
|         """ | ||||
|         Args: | ||||
|             hidden_states: (batch_size * sequence_length, hidden_dim) | ||||
|             top_k_index: (batch_size * sequence_length, top_k) | ||||
|             top_k_weights: (batch_size * sequence_length, top_k) | ||||
|         Returns: | ||||
|             (batch_size * sequence_length, hidden_dim) | ||||
|         """ | ||||
|         final_hidden_states = torch.zeros_like(hidden_states) | ||||
|         expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0) | ||||
|  | ||||
|         num_experts = self.num_experts | ||||
|         expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1).permute(2, 1, 0) | ||||
|         expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() | ||||
|         for expert_idx in expert_hit: | ||||
|             expert_idx = expert_idx[0] | ||||
|             if expert_idx == num_experts: | ||||
|                 continue | ||||
|             with torch.no_grad(): | ||||
|                 _, token_idx = torch.where(expert_mask[expert_idx]) | ||||
|             current_state = hidden_states[token_idx] | ||||
|             gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1) | ||||
|             current_hidden_states = self.act_fn(gate) * up | ||||
|             current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx]) | ||||
|  | ||||
|             routing_weights = top_k_weights[token_idx, expert_idx].unsqueeze(-1) | ||||
|             current_hidden_states = current_hidden_states * routing_weights.to(current_hidden_states.dtype) | ||||
|             final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype)) | ||||
|  | ||||
|             idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0)) | ||||
|             current_state = hidden_states[None, top_x].reshape(-1, hidden_states.shape[-1]) | ||||
|             current_hidden_states = self[expert_idx](current_state) * top_k_weights[top_x, idx, None] | ||||
|             final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) | ||||
|         return final_hidden_states | ||||
|  | ||||
|  | ||||
| class Qwen3NextTopKRouter(nn.Module): | ||||
|     def __init__(self, config): | ||||
|         super().__init__() | ||||
|         self.top_k = config.num_experts_per_tok | ||||
|         self.num_experts = config.num_experts | ||||
|         self.norm_topk_prob = config.norm_topk_prob | ||||
|         self.hidden_dim = config.hidden_size | ||||
|         self.weight = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim)) | ||||
|  | ||||
|     def forward(self, hidden_states): | ||||
|         hidden_states = hidden_states.reshape(-1, self.hidden_dim) | ||||
|         router_logits = F.linear(hidden_states, self.weight)  # (seq_len, num_experts) | ||||
|         router_logits = torch.nn.functional.softmax(router_logits, dtype=torch.float, dim=-1) | ||||
|         router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1)  # (seq_len, top_k) | ||||
|         if self.norm_topk_prob: | ||||
|             router_top_value /= router_top_value.sum(dim=-1, keepdim=True) | ||||
|         router_top_value = router_top_value.to(router_logits.dtype) | ||||
|         router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value) | ||||
|         return router_scores, router_indices | ||||
|  | ||||
|  | ||||
| class Qwen3NextSparseMoeBlock(nn.Module): | ||||
|     def __init__(self, config): | ||||
|         super().__init__() | ||||
|         self.gate = Qwen3NextTopKRouter(config) | ||||
|         # gating | ||||
|         self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False) | ||||
|         self.experts = Qwen3NextExperts(config) | ||||
|         self.num_experts_per_tok = config.num_experts_per_tok | ||||
|         self.norm_topk_prob = config.norm_topk_prob | ||||
|  | ||||
|         self.shared_expert = Qwen3NextMLP(config, intermediate_size=config.shared_expert_intermediate_size) | ||||
|         self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False) | ||||
|  | ||||
|     def route_tokens_to_experts(self, hidden_states, router_logits): | ||||
|         routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float) | ||||
|         routing_weights, selected_experts = torch.topk(routing_weights, self.num_experts_per_tok, dim=-1) | ||||
|         if self.norm_topk_prob: | ||||
|             routing_weights /= routing_weights.sum(dim=-1, keepdim=True) | ||||
|         routing_weights = routing_weights.to(router_logits.dtype) | ||||
|         return selected_experts, routing_weights | ||||
|  | ||||
|     def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: | ||||
|         batch_size, sequence_length, hidden_dim = hidden_states.shape | ||||
|         hidden_states_reshaped = hidden_states.view(-1, hidden_dim) | ||||
|         shared_expert_output = self.shared_expert(hidden_states_reshaped) | ||||
|         routing_weights, selected_experts = self.gate(hidden_states_reshaped) | ||||
|         router_logits = self.gate(hidden_states_reshaped) | ||||
|         selected_experts, routing_weights = self.route_tokens_to_experts(hidden_states_reshaped, router_logits) | ||||
|         expert_output = self.experts(hidden_states_reshaped, selected_experts, routing_weights) | ||||
|  | ||||
|         shared_expert_output = F.sigmoid(self.shared_expert_gate(hidden_states_reshaped)) * shared_expert_output | ||||
|  | ||||
| @ -1323,43 +1323,37 @@ class Qwen3OmniMoeThinkerTextMLP(nn.Module): | ||||
|         return down_proj | ||||
|  | ||||
|  | ||||
| class Qwen3OmniMoeThinkerTextExperts(nn.Module): | ||||
| class Qwen3OmniMoeThinkerTextExperts(nn.ModuleList): | ||||
|     """ | ||||
|     ModuleList of experts. | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, config: Qwen3OmniMoeThinkerConfig): | ||||
|         nn.ModuleList.__init__(self) | ||||
|         super().__init__() | ||||
|         self.num_experts = config.num_experts | ||||
|         for _ in range(self.num_experts): | ||||
|             self.append(Qwen3OmniMoeThinkerTextMLP(config, intermediate_size=config.moe_intermediate_size)) | ||||
|  | ||||
|     def forward( | ||||
|         self, | ||||
|         hidden_states: torch.Tensor, | ||||
|         top_k_index: torch.Tensor, | ||||
|         top_k_weights: torch.Tensor, | ||||
|         self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor | ||||
|     ) -> torch.Tensor: | ||||
|         """ | ||||
|         Args: | ||||
|             hidden_states: (batch_size * sequence_length, hidden_dim) | ||||
|             top_k_index: (batch_size * sequence_length, top_k) | ||||
|             top_k_weights: (batch_size * sequence_length, top_k) | ||||
|         Returns: | ||||
|             (batch_size * sequence_length, hidden_dim) | ||||
|         """ | ||||
|         final_hidden_states = torch.zeros_like(hidden_states) | ||||
|         expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0) | ||||
|  | ||||
|         num_experts = self.num_experts | ||||
|         expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1).permute(2, 1, 0) | ||||
|         expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() | ||||
|         for expert_idx in expert_hit: | ||||
|             expert_idx = expert_idx[0] | ||||
|             if expert_idx == num_experts: | ||||
|                 continue | ||||
|             with torch.no_grad(): | ||||
|                 _, token_idx = torch.where(expert_mask[expert_idx]) | ||||
|             current_state = hidden_states[token_idx] | ||||
|             gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1) | ||||
|             current_hidden_states = self.act_fn(gate) * up | ||||
|             current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx]) | ||||
|  | ||||
|             routing_weights = top_k_weights[token_idx, expert_idx].unsqueeze(-1) | ||||
|             current_hidden_states = current_hidden_states * routing_weights.to(current_hidden_states.dtype) | ||||
|             final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype)) | ||||
|  | ||||
|             idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0)) | ||||
|             current_state = hidden_states[None, top_x].reshape(-1, hidden_states.shape[-1]) | ||||
|             current_hidden_states = self[expert_idx](current_state) * top_k_weights[top_x, idx, None] | ||||
|             final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) | ||||
|         return final_hidden_states | ||||
|  | ||||
|  | ||||
| @ -2774,83 +2768,68 @@ class Qwen3OmniMoeTalkerTextMLP(nn.Module): | ||||
|         return down_proj | ||||
|  | ||||
|  | ||||
| class Qwen3OmniMoeTalkerTextExperts(nn.Module): | ||||
|     """Collection of expert weights stored as 3D tensors.""" | ||||
| class Qwen3OmniMoeTalkerTextExperts(nn.ModuleList): | ||||
|     """ | ||||
|     ModuleList of experts. | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, config): | ||||
|         super().__init__() | ||||
|         self.num_experts = config.num_experts | ||||
|         self.hidden_dim = config.hidden_size | ||||
|         self.intermediate_dim = config.moe_intermediate_size | ||||
|         self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim)) | ||||
|         self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim)) | ||||
|         self.act_fn = ACT2FN[config.hidden_act] | ||||
|         for _ in range(config.num_experts): | ||||
|             self.append(Qwen3OmniMoeTalkerTextMLP(config, intermediate_size=config.moe_intermediate_size)) | ||||
|  | ||||
|     def forward( | ||||
|         self, | ||||
|         hidden_states: torch.Tensor, | ||||
|         top_k_index: torch.Tensor, | ||||
|         top_k_weights: torch.Tensor, | ||||
|         self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor | ||||
|     ) -> torch.Tensor: | ||||
|         """ | ||||
|         Args: | ||||
|             hidden_states: (batch_size * sequence_length, hidden_dim) | ||||
|             top_k_index: (batch_size * sequence_length, top_k) | ||||
|             top_k_weights: (batch_size * sequence_length, top_k) | ||||
|         Returns: | ||||
|             (batch_size * sequence_length, hidden_dim) | ||||
|         """ | ||||
|         final_hidden_states = torch.zeros_like(hidden_states) | ||||
|         expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0) | ||||
|  | ||||
|         num_experts = self.num_experts | ||||
|         expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1).permute(2, 1, 0) | ||||
|         expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() | ||||
|         for expert_idx in expert_hit: | ||||
|             expert_idx = expert_idx[0] | ||||
|             if expert_idx == num_experts: | ||||
|                 continue | ||||
|             with torch.no_grad(): | ||||
|                 _, token_idx = torch.where(expert_mask[expert_idx]) | ||||
|             current_state = hidden_states[token_idx] | ||||
|             gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1) | ||||
|             current_hidden_states = self.act_fn(gate) * up | ||||
|             current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx]) | ||||
|  | ||||
|             routing_weights = top_k_weights[token_idx, expert_idx].unsqueeze(-1) | ||||
|             current_hidden_states = current_hidden_states * routing_weights.to(current_hidden_states.dtype) | ||||
|             final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype)) | ||||
|  | ||||
|             idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0)) | ||||
|             current_state = hidden_states[None, top_x].reshape(-1, hidden_states.shape[-1]) | ||||
|             current_hidden_states = self[expert_idx](current_state) * top_k_weights[top_x, idx, None] | ||||
|             final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) | ||||
|         return final_hidden_states | ||||
|  | ||||
|  | ||||
| class Qwen3OmniMoeTalkerTextTopKRouter(nn.Module): | ||||
|     def __init__(self, config): | ||||
|         super().__init__() | ||||
|         self.top_k = config.num_experts_per_tok | ||||
|         self.num_experts = config.num_experts | ||||
|         self.norm_topk_prob = config.norm_topk_prob | ||||
|         self.hidden_dim = config.hidden_size | ||||
|         self.weight = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim)) | ||||
|  | ||||
|     def forward(self, hidden_states): | ||||
|         hidden_states = hidden_states.reshape(-1, self.hidden_dim) | ||||
|         router_logits = F.linear(hidden_states, self.weight)  # (seq_len, num_experts) | ||||
|         router_logits = torch.nn.functional.softmax(router_logits, dtype=torch.float, dim=-1) | ||||
|         router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1)  # (seq_len, top_k) | ||||
|         if self.norm_topk_prob: | ||||
|             router_top_value /= router_top_value.sum(dim=-1, keepdim=True) | ||||
|         router_top_value = router_top_value.to(router_logits.dtype) | ||||
|         router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value) | ||||
|         return router_scores, router_indices | ||||
|  | ||||
|  | ||||
| class Qwen3OmniMoeTalkerTextSparseMoeBlock(nn.Module): | ||||
|     def __init__(self, config): | ||||
|         super().__init__() | ||||
|         self.gate = Qwen3OmniMoeTalkerTextTopKRouter(config) | ||||
|         # gating | ||||
|         self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False) | ||||
|         self.experts = Qwen3OmniMoeTalkerTextExperts(config) | ||||
|         self.num_experts_per_tok = config.num_experts_per_tok | ||||
|         self.norm_topk_prob = config.norm_topk_prob | ||||
|  | ||||
|         self.shared_expert = Qwen3OmniMoeTalkerTextMLP( | ||||
|             config, intermediate_size=config.shared_expert_intermediate_size | ||||
|         ) | ||||
|         self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False) | ||||
|  | ||||
|     def route_tokens_to_experts(self, hidden_states, router_logits): | ||||
|         routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float) | ||||
|         routing_weights, selected_experts = torch.topk(routing_weights, self.num_experts_per_tok, dim=-1) | ||||
|         if self.norm_topk_prob: | ||||
|             routing_weights /= routing_weights.sum(dim=-1, keepdim=True) | ||||
|         routing_weights = routing_weights.to(router_logits.dtype) | ||||
|         return selected_experts, routing_weights | ||||
|  | ||||
|     def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: | ||||
|         batch_size, sequence_length, hidden_dim = hidden_states.shape | ||||
|         hidden_states_reshaped = hidden_states.view(-1, hidden_dim) | ||||
|         shared_expert_output = self.shared_expert(hidden_states_reshaped) | ||||
|         router_logits = self.gate(hidden_states_reshaped) | ||||
|         selected_experts, routing_weights = self.route_tokens_to_experts(hidden_states_reshaped, router_logits) | ||||
|         expert_output = self.experts(hidden_states_reshaped, selected_experts, routing_weights) | ||||
|  | ||||
|         shared_expert_output = F.sigmoid(self.shared_expert_gate(hidden_states_reshaped)) * shared_expert_output | ||||
|  | ||||
| @ -365,27 +365,6 @@ class Qwen3VLMoeTextDecoderLayer(GradientCheckpointingLayer): | ||||
|         return hidden_states | ||||
|  | ||||
|  | ||||
| class Qwen3VLMoeTextTopKRouter(nn.Module): | ||||
|     def __init__(self, config): | ||||
|         super().__init__() | ||||
|         self.top_k = config.num_experts_per_tok | ||||
|         self.num_experts = config.num_experts | ||||
|         self.norm_topk_prob = config.norm_topk_prob | ||||
|         self.hidden_dim = config.hidden_size | ||||
|         self.weight = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim)) | ||||
|  | ||||
|     def forward(self, hidden_states): | ||||
|         hidden_states = hidden_states.reshape(-1, self.hidden_dim) | ||||
|         router_logits = F.linear(hidden_states, self.weight)  # (seq_len, num_experts) | ||||
|         router_logits = torch.nn.functional.softmax(router_logits, dtype=torch.float, dim=-1) | ||||
|         router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1)  # (seq_len, top_k) | ||||
|         if self.norm_topk_prob: | ||||
|             router_top_value /= router_top_value.sum(dim=-1, keepdim=True) | ||||
|         router_top_value = router_top_value.to(router_logits.dtype) | ||||
|         router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value) | ||||
|         return router_scores, router_indices | ||||
|  | ||||
|  | ||||
| @auto_docstring | ||||
| class Qwen3VLMoePreTrainedModel(PreTrainedModel): | ||||
|     config: Qwen3VLMoeConfig | ||||
| @ -399,7 +378,7 @@ class Qwen3VLMoePreTrainedModel(PreTrainedModel): | ||||
|     _can_compile_fullgraph = False  # MoE models don't work with torch.compile (`torch.where(condition)` not supported) | ||||
|     _supports_attention_backend = True | ||||
|     _can_record_outputs = { | ||||
|         "router_logits": OutputRecorder(Qwen3VLMoeTextTopKRouter, layer_name="mlp.router", index=0), | ||||
|         "router_logits": OutputRecorder(nn.Linear, layer_name="mlp.gate", index=0), | ||||
|         "hidden_states": Qwen3VLMoeTextDecoderLayer, | ||||
|         "attentions": Qwen3VLMoeTextAttention, | ||||
|     } | ||||
|  | ||||
| @ -510,7 +510,7 @@ class SiglipPreTrainedModel(PreTrainedModel): | ||||
|             nn.init.xavier_uniform_(module.fc2.weight) | ||||
|             nn.init.normal_(module.fc1.bias, std=1e-6) | ||||
|             nn.init.normal_(module.fc2.bias, std=1e-6) | ||||
|         elif "MultiheadAttentionPoolingHead" in module.__class__.__name__: | ||||
|             nn.init.xavier_uniform_(module.probe.data) | ||||
|             nn.init.xavier_uniform_(module.attention.in_proj_weight.data) | ||||
|             nn.init.zeros_(module.attention.in_proj_bias.data) | ||||
| @ -678,9 +678,14 @@ class SiglipTextModel(SiglipPreTrainedModel): | ||||
|         ) | ||||
|  | ||||
|  | ||||
| class SiglipVisionTransformer(SiglipPreTrainedModel): | ||||
|     _can_record_outputs = { | ||||
|         "hidden_states": SiglipEncoderLayer, | ||||
|         "attentions": SiglipAttention, | ||||
|     } | ||||
|  | ||||
|     def __init__(self, config: SiglipVisionConfig): | ||||
|         super().__init__(config) | ||||
|         self.config = config | ||||
|         embed_dim = config.hidden_size | ||||
|  | ||||
| @ -691,6 +696,7 @@ class SiglipVisionTransformer(nn.Module): | ||||
|         if self.use_head: | ||||
|             self.head = SiglipMultiheadAttentionPoolingHead(config) | ||||
|  | ||||
|     @check_model_inputs(tie_last_hidden_states=False) | ||||
|     @auto_docstring | ||||
|     def forward( | ||||
|         self, | ||||
|  | ||||
| @ -349,99 +349,6 @@ class Siglip2EncoderLayer(GradientCheckpointingLayer): | ||||
|         return hidden_states | ||||
|  | ||||
|  | ||||
| class Siglip2Encoder(nn.Module): | ||||
|     """ | ||||
|     Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a | ||||
|     [`Siglip2EncoderLayer`]. | ||||
|  | ||||
|     Args: | ||||
|         config: Siglip2Config | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, config: Siglip2Config): | ||||
|         super().__init__() | ||||
|         self.config = config | ||||
|         self.layers = nn.ModuleList([Siglip2EncoderLayer(config) for _ in range(config.num_hidden_layers)]) | ||||
|         self.gradient_checkpointing = False | ||||
|  | ||||
|     # Ignore copy | ||||
|     @auto_docstring | ||||
|     def forward( | ||||
|         self, | ||||
|         inputs_embeds, | ||||
|         attention_mask: Optional[torch.Tensor] = None, | ||||
|         **kwargs: Unpack[TransformersKwargs], | ||||
|     ) -> BaseModelOutput: | ||||
|         hidden_states = inputs_embeds | ||||
|         for encoder_layer in self.layers: | ||||
|             hidden_states = encoder_layer( | ||||
|                 hidden_states, | ||||
|                 attention_mask, | ||||
|                 **kwargs, | ||||
|             ) | ||||
|  | ||||
|         return BaseModelOutput(last_hidden_state=hidden_states) | ||||
|  | ||||
|  | ||||
| class Siglip2VisionTransformer(nn.Module): | ||||
|     def __init__(self, config: Siglip2VisionConfig): | ||||
|         super().__init__() | ||||
|         self.config = config | ||||
|         embed_dim = config.hidden_size | ||||
|  | ||||
|         self.embeddings = Siglip2VisionEmbeddings(config) | ||||
|         self.encoder = Siglip2Encoder(config) | ||||
|         self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) | ||||
|         self.use_head = True if not hasattr(config, "vision_use_head") else config.vision_use_head | ||||
|         if self.use_head: | ||||
|             self.head = Siglip2MultiheadAttentionPoolingHead(config) | ||||
|  | ||||
|     @auto_docstring | ||||
|     def forward( | ||||
|         self, | ||||
|         pixel_values: torch.FloatTensor, | ||||
|         attention_mask: torch.Tensor, | ||||
|         spatial_shapes: torch.LongTensor, | ||||
|         output_attentions: Optional[bool] = None, | ||||
|         output_hidden_states: Optional[bool] = None, | ||||
|     ) -> BaseModelOutputWithPooling: | ||||
|         r""" | ||||
|         spatial_shapes (`torch.LongTensor` of shape `(batch_size, 2)`): | ||||
|             Tensor containing the spatial dimensions (height, width) of the input images. | ||||
|         """ | ||||
|         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions | ||||
|         output_hidden_states = ( | ||||
|             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states | ||||
|         ) | ||||
|  | ||||
|         hidden_states = self.embeddings(pixel_values, spatial_shapes) | ||||
|  | ||||
|         if attention_mask is not None and self.config._attn_implementation != "flash_attention_2": | ||||
|             # [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len] | ||||
|             encoder_attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype) | ||||
|         else: | ||||
|             encoder_attention_mask = attention_mask | ||||
|  | ||||
|         encoder_outputs: BaseModelOutput = self.encoder( | ||||
|             inputs_embeds=hidden_states, | ||||
|             attention_mask=encoder_attention_mask, | ||||
|             output_attentions=output_attentions, | ||||
|             output_hidden_states=output_hidden_states, | ||||
|         ) | ||||
|  | ||||
|         last_hidden_state = encoder_outputs.last_hidden_state | ||||
|         last_hidden_state = self.post_layernorm(last_hidden_state) | ||||
|  | ||||
|         pooler_output = self.head(last_hidden_state, attention_mask) if self.use_head else None | ||||
|  | ||||
|         return BaseModelOutputWithPooling( | ||||
|             last_hidden_state=last_hidden_state, | ||||
|             pooler_output=pooler_output, | ||||
|             hidden_states=encoder_outputs.hidden_states, | ||||
|             attentions=encoder_outputs.attentions, | ||||
|         ) | ||||
|  | ||||
|  | ||||
| def _trunc_normal_(tensor, mean, std, a, b): | ||||
|     # Cut & paste from PyTorch official master until it's in a few official releases - RW | ||||
|     # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf | ||||
| @ -585,7 +492,7 @@ class Siglip2PreTrainedModel(PreTrainedModel): | ||||
|             nn.init.xavier_uniform_(module.fc2.weight) | ||||
|             nn.init.normal_(module.fc1.bias, std=1e-6) | ||||
|             nn.init.normal_(module.fc2.bias, std=1e-6) | ||||
|         elif "MultiheadAttentionPoolingHead" in module.__class__.__name__: | ||||
|             nn.init.xavier_uniform_(module.probe.data) | ||||
|             nn.init.xavier_uniform_(module.attention.in_proj_weight.data) | ||||
|             nn.init.zeros_(module.attention.in_proj_bias.data) | ||||
| @ -607,6 +514,105 @@ class Siglip2PreTrainedModel(PreTrainedModel): | ||||
|             module.weight.data.fill_(1.0) | ||||
|  | ||||
|  | ||||
| class Siglip2Encoder(nn.Module): | ||||
|     """ | ||||
|     Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a | ||||
|     [`Siglip2EncoderLayer`]. | ||||
|  | ||||
|     Args: | ||||
|         config: Siglip2Config | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, config: Siglip2Config): | ||||
|         super().__init__() | ||||
|         self.config = config | ||||
|         self.layers = nn.ModuleList([Siglip2EncoderLayer(config) for _ in range(config.num_hidden_layers)]) | ||||
|         self.gradient_checkpointing = False | ||||
|  | ||||
|     # Ignore copy | ||||
|     @auto_docstring | ||||
|     def forward( | ||||
|         self, | ||||
|         inputs_embeds, | ||||
|         attention_mask: Optional[torch.Tensor] = None, | ||||
|         **kwargs: Unpack[TransformersKwargs], | ||||
|     ) -> BaseModelOutput: | ||||
|         hidden_states = inputs_embeds | ||||
|         for encoder_layer in self.layers: | ||||
|             hidden_states = encoder_layer( | ||||
|                 hidden_states, | ||||
|                 attention_mask, | ||||
|                 **kwargs, | ||||
|             ) | ||||
|  | ||||
|         return BaseModelOutput(last_hidden_state=hidden_states) | ||||
|  | ||||
|  | ||||
| class Siglip2VisionTransformer(Siglip2PreTrainedModel): | ||||
|     _can_record_outputs = { | ||||
|         "hidden_states": Siglip2EncoderLayer, | ||||
|         "attentions": Siglip2Attention, | ||||
|     } | ||||
|  | ||||
|     def __init__(self, config: Siglip2VisionConfig): | ||||
|         super().__init__(config) | ||||
|         self.config = config | ||||
|         embed_dim = config.hidden_size | ||||
|  | ||||
|         self.embeddings = Siglip2VisionEmbeddings(config) | ||||
|         self.encoder = Siglip2Encoder(config) | ||||
|         self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) | ||||
|         self.use_head = True if not hasattr(config, "vision_use_head") else config.vision_use_head | ||||
|         if self.use_head: | ||||
|             self.head = Siglip2MultiheadAttentionPoolingHead(config) | ||||
|  | ||||
|     @check_model_inputs(tie_last_hidden_states=False) | ||||
|     @auto_docstring | ||||
|     def forward( | ||||
|         self, | ||||
|         pixel_values: torch.FloatTensor, | ||||
|         attention_mask: torch.Tensor, | ||||
|         spatial_shapes: torch.LongTensor, | ||||
|         output_attentions: Optional[bool] = None, | ||||
|         output_hidden_states: Optional[bool] = None, | ||||
|     ) -> BaseModelOutputWithPooling: | ||||
|         r""" | ||||
|         spatial_shapes (`torch.LongTensor` of shape `(batch_size, 2)`): | ||||
|             Tensor containing the spatial dimensions (height, width) of the input images. | ||||
|         """ | ||||
|         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions | ||||
|         output_hidden_states = ( | ||||
|             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states | ||||
|         ) | ||||
|  | ||||
|         hidden_states = self.embeddings(pixel_values, spatial_shapes) | ||||
|  | ||||
|         if attention_mask is not None and self.config._attn_implementation != "flash_attention_2": | ||||
|             # [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len] | ||||
|             encoder_attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype) | ||||
|         else: | ||||
|             encoder_attention_mask = attention_mask | ||||
|  | ||||
|         encoder_outputs: BaseModelOutput = self.encoder( | ||||
|             inputs_embeds=hidden_states, | ||||
|             attention_mask=encoder_attention_mask, | ||||
|             output_attentions=output_attentions, | ||||
|             output_hidden_states=output_hidden_states, | ||||
|         ) | ||||
|  | ||||
|         last_hidden_state = encoder_outputs.last_hidden_state | ||||
|         last_hidden_state = self.post_layernorm(last_hidden_state) | ||||
|  | ||||
|         pooler_output = self.head(last_hidden_state, attention_mask) if self.use_head else None | ||||
|  | ||||
|         return BaseModelOutputWithPooling( | ||||
|             last_hidden_state=last_hidden_state, | ||||
|             pooler_output=pooler_output, | ||||
|             hidden_states=encoder_outputs.hidden_states, | ||||
|             attentions=encoder_outputs.attentions, | ||||
|         ) | ||||
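|  | ||||
| # Sketch (not part of the model): `_prepare_4d_attention_mask` above broadcasts a [batch_size, seq_len] | ||||
| # padding mask to an additive bias of shape [batch_size, 1, tgt_seq_len, src_seq_len], with 0 for visible | ||||
| # positions and the dtype's most negative value for padded ones. A rough stand-in for intuition: | ||||
| # | ||||
| #     mask = torch.tensor([[1, 1, 0]])  # 1 = keep, 0 = pad | ||||
| #     bias = (1.0 - mask[:, None, None, :].float()) * torch.finfo(torch.float32).min | ||||
| #     # bias.shape == (1, 1, 1, 3); the last position carries a large negative bias | ||||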
|  | ||||
|  | ||||
| class Siglip2TextEmbeddings(nn.Module): | ||||
|     def __init__(self, config: Siglip2TextConfig): | ||||
|         super().__init__() | ||||
|  | ||||
| @ -37,6 +37,7 @@ from transformers.models.siglip.modeling_siglip import ( | ||||
|  | ||||
| from ...modeling_attn_mask_utils import _prepare_4d_attention_mask | ||||
| from ...utils import auto_docstring, filter_out_non_signature_kwargs | ||||
| from ...utils.generic import check_model_inputs | ||||
|  | ||||
|  | ||||
| class Siglip2TextConfig(SiglipTextConfig): | ||||
| @ -230,6 +231,10 @@ class Siglip2VisionEmbeddings(nn.Module): | ||||
|         return embeddings | ||||
|  | ||||
|  | ||||
| class Siglip2PreTrainedModel(SiglipPreTrainedModel): | ||||
|     pass | ||||
|  | ||||
|  | ||||
| class Siglip2VisionTransformer(SiglipVisionTransformer): | ||||
|     def __init__(self, config: Siglip2VisionConfig): | ||||
|         super().__init__(config) | ||||
| @ -280,10 +285,6 @@ class Siglip2VisionTransformer(SiglipVisionTransformer): | ||||
|         ) | ||||
|  | ||||
|  | ||||
| class Siglip2PreTrainedModel(SiglipPreTrainedModel): | ||||
|     pass | ||||
|  | ||||
|  | ||||
| class Siglip2TextModel(SiglipTextModel): | ||||
|     pass | ||||
|  | ||||
| @ -314,6 +315,8 @@ class Siglip2MultiheadAttentionPoolingHead(SiglipMultiheadAttentionPoolingHead): | ||||
|  | ||||
| class Siglip2VisionModel(SiglipVisionModel): | ||||
|     # Update: add `spatial_shapes` and `pixel_attention_mask` | ||||
|     @check_model_inputs(tie_last_hidden_states=False) | ||||
|     @auto_docstring | ||||
|     def forward( | ||||
|         self, | ||||
|         pixel_values: torch.FloatTensor, | ||||
|  | ||||
| @ -59,15 +59,11 @@ class HfQuantizer(ABC): | ||||
|         requires_parameters_quantization (`bool`): | ||||
|             Whether the quantization method requires to create a new Parameter. For example, for bitsandbytes, it is | ||||
|             required to create a new xxxParameter in order to properly quantize the model. | ||||
|         requires_full_weights (`bool`): | ||||
|             Whether the quantization method needs the full (non-sharded) weights for conversion. If set to `False`, only | ||||
|             the relevant tensor slices will be provided during weight loading. | ||||
|     """ | ||||
|  | ||||
|     requires_calibration = False | ||||
|     required_packages = None | ||||
|     requires_parameters_quantization = False | ||||
|     requires_full_weights = True | ||||
|  | ||||
|     def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs): | ||||
|         self.quantization_config = quantization_config | ||||
|  | ||||
| @ -75,8 +75,6 @@ class FineGrainedFP8HfQuantizer(HfQuantizer): | ||||
|             dtype = torch.float32 | ||||
|         return dtype | ||||
|  | ||||
|     # TODO: make this into a `ConversionType` ops -> potentially requires all weights on all ranks | ||||
|     # depending on the layer type (moe -> no if ep) | ||||
|     def create_quantized_param( | ||||
|         self, | ||||
|         model: "PreTrainedModel", | ||||
| @ -95,9 +93,8 @@ class FineGrainedFP8HfQuantizer(HfQuantizer): | ||||
|                 if tensor_name == "weight" and param_value.dtype != torch.float8_e4m3fn: | ||||
|                     raise ValueError("Expect quantized weights but got an unquantized weight") | ||||
|             else: | ||||
|                 if tensor_name == "weight_scale_inv": | ||||
|                     raise ValueError("Expect unquantized weights but got a quantized weight_scale") | ||||
|  | ||||
|         param_value = param_value.to(target_device) | ||||
|  | ||||
| @ -140,10 +137,10 @@ class FineGrainedFP8HfQuantizer(HfQuantizer): | ||||
|         _load_parameter_into_model(model, param_name.rsplit(".", 1)[0] + ".weight_scale_inv", scale) | ||||
|  | ||||
|     def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool: | ||||
|         from ..integrations.finegrained_fp8 import FP8Linear | ||||
|  | ||||
|         module, tensor_name = get_module_from_name(model, param_name) | ||||
|         if isinstance(module, FP8Linear): | ||||
|             if self.pre_quantized or tensor_name == "bias": | ||||
|                 return False | ||||
|             else: | ||||
| @ -185,9 +182,6 @@ class FineGrainedFP8HfQuantizer(HfQuantizer): | ||||
|                         not_missing_keys.append(missing) | ||||
|         return [k for k in missing_keys if k not in not_missing_keys] | ||||
|  | ||||
|     # TODO: similarly, just as we have a weight weight remapping we | ||||
|     # need to have a cleaner way to remap the quantized keys. | ||||
|     # 1. A SINGLE normal_key -> quantized keys used for ckpt renaming and for TP_plan as well | ||||
|     def update_tp_plan(self, config): | ||||
|         if "Qwen3" in config.__class__.__name__: | ||||
|             text_plan = { | ||||
|  | ||||
| @ -877,7 +877,13 @@ def check_model_inputs(tie_last_hidden_states=True): | ||||
|                 def wrapped_forward(*args, **kwargs): | ||||
|                     if key == "hidden_states" and len(collected_outputs[key]) == 0: | ||||
|                         collected_outputs[key] += (args[0],) | ||||
|                     if kwargs.get("debug_io", False): | ||||
|                         with model_addition_debugger_context( | ||||
|                             module, kwargs.get("debug_io_dir", "~/model_debug"), kwargs.get("prune_layers") | ||||
|                         ): | ||||
|                             output = orig_forward(*args, **kwargs) | ||||
|                     else: | ||||
|                         output = orig_forward(*args, **kwargs) | ||||
|                     if not isinstance(output, tuple): | ||||
|                         collected_outputs[key] += (output,) | ||||
|                     elif output[index] is not None: | ||||
| @ -918,13 +924,7 @@ def check_model_inputs(tie_last_hidden_states=True): | ||||
|                             monkey_patched_layers.append((module, original_forward)) | ||||
|  | ||||
|             try: | ||||
|                 if kwargs.get("debug_io", False): | ||||
|                     with model_addition_debugger_context( | ||||
|                         self, kwargs.get("debug_io_dir", "model_debug"), kwargs.get("prune_layers") | ||||
|                     ): | ||||
|                         outputs = func(self, *args, **kwargs) | ||||
|                 else: | ||||
|                     outputs = func(self, *args, **kwargs) | ||||
|                 outputs = func(self, *args, **kwargs) | ||||
|             except TypeError as original_exception: | ||||
|                 # If we get a TypeError, it's possible that the model is not receiving the recordable kwargs correctly. | ||||
|                 # Get a TypeError even after removing the recordable kwargs -> re-raise the original exception | ||||
|  | ||||
| @ -1176,12 +1176,9 @@ def is_mistral_common_available() -> bool: | ||||
|  | ||||
| @lru_cache | ||||
| def is_opentelemetry_available() -> bool: | ||||
|     return _is_package_available("opentelemetry") and version.parse( | ||||
|         importlib.metadata.version("opentelemetry-api") | ||||
|     ) >= version.parse("1.30.0") | ||||
|  | ||||
|  | ||||
| def check_torch_load_is_safe() -> None: | ||||
|  | ||||
| @ -1,236 +0,0 @@ | ||||
| import logging | ||||
| import re | ||||
| import shutil | ||||
| import sys | ||||
| from collections import OrderedDict, defaultdict | ||||
| from collections.abc import Iterable | ||||
| from typing import Any, Optional | ||||
|  | ||||
|  | ||||
| _DIGIT_RX = re.compile(r"(?<=\.)(\d+)(?=\.|$)")  # numbers between dots or at the end | ||||
|  | ||||
|  | ||||
| def _pattern_of(key: str) -> str: | ||||
|     """Replace every dot-delimited integer with '*' to get the structure.""" | ||||
|     return _DIGIT_RX.sub("*", key) | ||||
|  | ||||
|  | ||||
| def _fmt_indices(values: list[int]) -> str: | ||||
|     """Format a list of ints as single number, {a, b, ...}, or first...last.""" | ||||
|     if len(values) == 1: | ||||
|         return str(values[0]) | ||||
|     values = sorted(values) | ||||
|     if len(values) > 10: | ||||
|         return f"{values[0]}...{values[-1]}" | ||||
|     return ", ".join(map(str, values)) | ||||
|  | ||||
|  | ||||
| def update_key_name(mapping: dict[str, Any]) -> dict[str, Any]: | ||||
|     """ | ||||
|     Merge keys like 'layers.0.x', 'layers.1.x' into 'layers.{0, 1}.x' | ||||
|     Keys are grouped by their structural pattern (dot-delimited integers replaced by '*'). | ||||
|     Returns a new dict {merged_key: value}, or just the merged keys when a non-dict iterable is passed. | ||||
|     """ | ||||
|     # pattern -> list of per-star index sets, with the grouped values appended at the end | ||||
|     not_mapping = False | ||||
|     if not isinstance(mapping, dict): | ||||
|         mapping = {k: k for k in mapping} | ||||
|         not_mapping = True | ||||
|  | ||||
|     bucket: dict[str, list] = defaultdict(list) | ||||
|     for key, val in mapping.items(): | ||||
|         digs = _DIGIT_RX.findall(key) | ||||
|         patt = _pattern_of(key) | ||||
|         for i, d in enumerate(digs): | ||||
|             if len(bucket[patt]) <= i: | ||||
|                 bucket[patt].append(set()) | ||||
|             bucket[patt][i].add(int(d)) | ||||
|         bucket[patt].append(val) | ||||
|  | ||||
|     out_items = {} | ||||
|     for patt, values in bucket.items(): | ||||
|         sets, val = values[:-1], values[-1] | ||||
|         parts = patt.split("*")  # stars are between parts | ||||
|         final = parts[0] | ||||
|         for i in range(1, len(parts)): | ||||
|             # i-1 is the star index before parts[i] | ||||
|             if i - 1 < len(sets) and sets[i - 1]: | ||||
|                 insert = _fmt_indices(sorted(sets[i - 1])) | ||||
|                 if len(sets[i - 1]) > 1: | ||||
|                     final += "{" + insert + "}" | ||||
|                 else: | ||||
|                     final += insert | ||||
|             else: | ||||
|                 # If no digits observed for this star position, keep a literal '*' | ||||
|                 final += "*" | ||||
|             final += parts[i] | ||||
|  | ||||
|         out_items[final] = val | ||||
|  | ||||
|     # Stable ordering by merged key | ||||
|     out = OrderedDict(out_items) | ||||
|     if not_mapping: | ||||
|         return out.keys() | ||||
|     return out | ||||
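|  | ||||
| # Minimal illustration (hypothetical keys): indices sharing the same structure collapse into one entry, e.g. | ||||
| #     update_key_name(["model.layers.0.mlp.weight", "model.layers.1.mlp.weight", "lm_head.weight"]) | ||||
| #     yields the keys "model.layers.{0, 1}.mlp.weight" and "lm_head.weight". | ||||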
|  | ||||
|  | ||||
| class ANSI: | ||||
|     palette = { | ||||
|         "reset": "[0m", | ||||
|         "red": "[31m", | ||||
|         "yellow": "[33m", | ||||
|         "orange": "[38;5;208m", | ||||
|         "purple": "[35m", | ||||
|         "bold": "[1m", | ||||
|         "italic": "[3m", | ||||
|         "dim": "[2m", | ||||
|     } | ||||
|  | ||||
|     def __init__(self, enable): | ||||
|         self.enable = enable | ||||
|  | ||||
|     def __getitem__(self, key): | ||||
|         return self.palette[key] if self.enable else "" | ||||
|  | ||||
|  | ||||
| _ansi_re = re.compile(r"\x1b\[[0-9;]*m") | ||||
|  | ||||
|  | ||||
| def _strip_ansi(s: str) -> str: | ||||
|     return _ansi_re.sub("", str(s)) | ||||
|  | ||||
|  | ||||
| def _pad(text, width): | ||||
|     t = str(text) | ||||
|     pad = max(0, width - len(_strip_ansi(t))) | ||||
|     return t + " " * pad | ||||
|  | ||||
|  | ||||
| def _make_table(rows, headers): | ||||
|     # compute display widths while ignoring ANSI codes | ||||
|     cols = list(zip(*([headers] + rows))) if rows else [headers] | ||||
|     widths = [max(len(_strip_ansi(x)) for x in col) for col in cols] | ||||
|     header_line = " | ".join(_pad(h, w) for h, w in zip(headers, widths)) | ||||
|     sep_line = "-+-".join("-" * w for w in widths) | ||||
|     body = [" | ".join(_pad(c, w) for c, w in zip(r, widths)) for r in rows] | ||||
|     return "\n".join([header_line, sep_line] + body) | ||||
|  | ||||
|  | ||||
| def _color(s, color, ansi): | ||||
|     # ansi returns empty strings when disabled, so safe to interpolate | ||||
|     return f"{ansi[color]}{s}{ansi['reset']}" | ||||
|  | ||||
|  | ||||
| def _get_terminal_width(default=80): | ||||
|     try: | ||||
|         return shutil.get_terminal_size().columns | ||||
|     except Exception: | ||||
|         return default | ||||
|  | ||||
|  | ||||
| def log_state_dict_report( | ||||
|     *, | ||||
|     model, | ||||
|     pretrained_model_name_or_path, | ||||
|     logger: Optional[logging.Logger] = None, | ||||
|     error_msgs: Optional[Iterable[str]] = None, | ||||
|     unexpected_keys=None, | ||||
|     missing_keys=None, | ||||
|     mismatched_keys=None, | ||||
|     mismatched_shapes=None, | ||||
|     ignore_mismatched_sizes=True, | ||||
|     misc=None, | ||||
|     limit_rows=50,  # safety for huge checkpoints | ||||
|     color=True,  # allow disabling for plain logs | ||||
|     min_width_full_table=60,  # terminal min width to attempt full table | ||||
| ): | ||||
|     """Log a readable report about state_dict loading issues. | ||||
|  | ||||
|     This version is terminal-size aware: for very small terminals it falls back to a compact | ||||
|     Key | Status view so output doesn't wrap badly. | ||||
|     """ | ||||
|     if logger is None: | ||||
|         logger = logging.getLogger(__name__) | ||||
|  | ||||
|     error_msgs = error_msgs or [] | ||||
|     unexpected_keys = unexpected_keys or [] | ||||
|     missing_keys = missing_keys or [] | ||||
|     mismatched_keys = mismatched_keys or [] | ||||
|     mismatched_shapes = mismatched_shapes or [] | ||||
|     misc = misc or {} | ||||
|  | ||||
|     # Detect whether the current stdout supports ANSI colors; allow callers to pass `color=False` to force no color | ||||
|     color_enabled = bool(color and sys.stdout.isatty()) | ||||
|     ansi = ANSI(color_enabled) | ||||
|  | ||||
|     if error_msgs: | ||||
|         error_msg = "\n\t".join(error_msgs) | ||||
|         if "size mismatch" in error_msg: | ||||
|             error_msg += ( | ||||
|                 "\n\tYou may consider adding `ignore_mismatched_sizes=True` to `from_pretrained(...)` if appropriate." | ||||
|             ) | ||||
|         raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}") | ||||
|  | ||||
|     term_w = _get_terminal_width() | ||||
|     rows = [] | ||||
|     if unexpected_keys: | ||||
|         for k in update_key_name(unexpected_keys): | ||||
|             status = "UNEXPECTED" | ||||
|             status = _color(status, "orange", ansi) | ||||
|             rows.append([k, status, "", ""]) | ||||
|  | ||||
|     if missing_keys: | ||||
|         for k in update_key_name(missing_keys): | ||||
|             status = "MISSING" | ||||
|             status = _color(status, "red", ansi) | ||||
|             rows.append([k, status, ""]) | ||||
|  | ||||
|     if mismatched_keys: | ||||
|         iterator = {a: (b, c) for a, b, c in mismatched_shapes} | ||||
|         for key, (shape_ckpt, shape_model) in update_key_name(iterator).items(): | ||||
|             status = "MISMATCH" | ||||
|             status = _color(status, "yellow", ansi) | ||||
|             data = [key, status] | ||||
|             if term_w > min_width_full_table: | ||||
|                 data.append( | ||||
|                     " ".join(["Reinit due to size mismatch", f"ckpt: {str(shape_ckpt)} vs model:{str(shape_model)}"]) | ||||
|                 ) | ||||
|             rows.append(data) | ||||
|  | ||||
|     if misc: | ||||
|         for k, v in update_key_name(misc).items(): | ||||
|             status = "MISC" | ||||
|             status = _color(status, "purple", ansi) | ||||
|             _details = v[:term_w] | ||||
|             rows.append([k, status, _details]) | ||||
|  | ||||
|     if not rows: | ||||
|         print(f"No key issues when initializing {model.__class__.__name__} from {pretrained_model_name_or_path}.") | ||||
|         return | ||||
|  | ||||
|     headers = ["Key", "Status"] | ||||
|     if term_w > 200: | ||||
|         headers += ["Details"] | ||||
|     else: | ||||
|         headers += ["", ""] | ||||
|     table = _make_table(rows, headers=headers) | ||||
|  | ||||
|     prelude = ( | ||||
|         f"{ansi['bold']}{model.__class__.__name__} LOAD REPORT{ansi['reset']} from: {pretrained_model_name_or_path}\n" | ||||
|     ) | ||||
|     tips = f"\n\n{ansi['italic']}Notes:" | ||||
|     if unexpected_keys: | ||||
|         tips += f"\n- {_color('UNEXPECTED', 'orange', ansi) + ansi['italic']}\t:can be ignored when loading from different task/architecture; not ok if you expect identical arch." | ||||
|     if missing_keys: | ||||
|         tips += f"\n- {_color('MISSING', 'red', ansi) + ansi['italic']}\t:those params were newly initialized because missing form the checkpoint. Consider training on your downstream task." | ||||
|     if mismatched_keys: | ||||
|         tips += f"\n- {_color('MISMATCH', 'yellow', ansi) + ansi['italic']}\t:ckpt weights were loaded, but they did not match the original empty weight." | ||||
|     if misc: | ||||
|         tips += f"\n- {_color('MISC', 'purple', ansi) + ansi['italic']}\t:originate from the conversion scheme" | ||||
|     tips += f"{ansi['reset']}" | ||||
|  | ||||
|     logger.warning(prelude + table + tips) | ||||
|     if not ignore_mismatched_sizes and mismatched_keys: | ||||
|         raise RuntimeError( | ||||
|             "You set `ignore_mismatched_sizes` to `False`, thus raising an error. For details look at the above report!" | ||||
|         ) | ||||
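|  | ||||
| # Hypothetical usage sketch (the model and key names below are made up): | ||||
| # | ||||
| #     log_state_dict_report( | ||||
| #         model=model, | ||||
| #         pretrained_model_name_or_path="dummy/checkpoint", | ||||
| #         missing_keys=["lm_head.weight"], | ||||
| #         unexpected_keys=["pooler.dense.weight"], | ||||
| #     ) | ||||
| # | ||||
| # logs a "<ModelClass> LOAD REPORT" table with one MISSING and one UNEXPECTED row, followed by the notes legend. | ||||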
| @ -232,7 +232,7 @@ class AutoformerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa | ||||
|             with tempfile.TemporaryDirectory() as tmpdirname: | ||||
|                 model.save_pretrained(tmpdirname) | ||||
|                 model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True) | ||||
|             self.assertEqual(info["missing_keys"], set()) | ||||
|             self.assertEqual(info["missing_keys"], []) | ||||
|  | ||||
|     def test_encoder_decoder_model_standalone(self): | ||||
|         config_and_inputs = self.model_tester.prepare_config_and_inputs_for_common() | ||||
|  | ||||
| @ -539,7 +539,7 @@ class BarkSemanticModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Te | ||||
|             with tempfile.TemporaryDirectory() as tmpdirname: | ||||
|                 model.save_pretrained(tmpdirname) | ||||
|                 model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True) | ||||
|             self.assertEqual(info["missing_keys"], set()) | ||||
|             self.assertEqual(info["missing_keys"], []) | ||||
|  | ||||
|     def test_decoder_model_past_with_large_inputs(self): | ||||
|         config_and_inputs = self.model_tester.prepare_config_and_inputs() | ||||
| @ -625,7 +625,7 @@ class BarkCoarseModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Test | ||||
|             with tempfile.TemporaryDirectory() as tmpdirname: | ||||
|                 model.save_pretrained(tmpdirname) | ||||
|                 model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True) | ||||
|             self.assertEqual(info["missing_keys"], set()) | ||||
|             self.assertEqual(info["missing_keys"], []) | ||||
|  | ||||
|     def test_decoder_model_past_with_large_inputs(self): | ||||
|         config_and_inputs = self.model_tester.prepare_config_and_inputs() | ||||
| @ -708,7 +708,7 @@ class BarkFineModelTest(ModelTesterMixin, unittest.TestCase): | ||||
|             with tempfile.TemporaryDirectory() as tmpdirname: | ||||
|                 model.save_pretrained(tmpdirname) | ||||
|                 model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True) | ||||
|             self.assertEqual(info["missing_keys"], set()) | ||||
|             self.assertEqual(info["missing_keys"], []) | ||||
|  | ||||
|     def test_inputs_embeds(self): | ||||
|         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() | ||||
|  | ||||
| @ -438,7 +438,7 @@ class BartModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin | ||||
|             with tempfile.TemporaryDirectory() as tmpdirname: | ||||
|                 model.save_pretrained(tmpdirname) | ||||
|                 model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True) | ||||
|             self.assertEqual(info["missing_keys"], set()) | ||||
|             self.assertEqual(info["missing_keys"], []) | ||||
|  | ||||
|     def test_decoder_model_past_with_large_inputs(self): | ||||
|         config_and_inputs = self.model_tester.prepare_config_and_inputs() | ||||
|  | ||||
| @ -297,7 +297,7 @@ class BigBirdPegasusModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineT | ||||
|             with tempfile.TemporaryDirectory() as tmpdirname: | ||||
|                 model.save_pretrained(tmpdirname) | ||||
|                 model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True) | ||||
|             self.assertEqual(info["missing_keys"], set()) | ||||
|             self.assertEqual(info["missing_keys"], []) | ||||
|  | ||||
|     def test_decoder_model_past_with_large_inputs(self): | ||||
|         config_and_inputs = self.model_tester.prepare_config_and_inputs() | ||||
|  | ||||
| @ -241,7 +241,7 @@ class BlenderbotModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste | ||||
|             with tempfile.TemporaryDirectory() as tmpdirname: | ||||
|                 model.save_pretrained(tmpdirname) | ||||
|                 model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True) | ||||
|             self.assertEqual(info["missing_keys"], set()) | ||||
|             self.assertEqual(info["missing_keys"], []) | ||||
|  | ||||
|     def test_decoder_model_past_with_large_inputs(self): | ||||
|         config_and_inputs = self.model_tester.prepare_config_and_inputs() | ||||
|  | ||||
| @ -246,7 +246,7 @@ class BlenderbotSmallModelTest(ModelTesterMixin, GenerationTesterMixin, Pipeline | ||||
|             with tempfile.TemporaryDirectory() as tmpdirname: | ||||
|                 model.save_pretrained(tmpdirname) | ||||
|                 model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True) | ||||
|             self.assertEqual(info["missing_keys"], set()) | ||||
|             self.assertEqual(info["missing_keys"], []) | ||||
|  | ||||
|     def test_decoder_model_past_with_large_inputs(self): | ||||
|         config_and_inputs = self.model_tester.prepare_config_and_inputs() | ||||
|  | ||||
| @ -200,7 +200,7 @@ class FastSpeech2ConformerModelTest(ModelTesterMixin, unittest.TestCase): | ||||
|         with tempfile.TemporaryDirectory() as tmpdirname: | ||||
|             model.save_pretrained(tmpdirname) | ||||
|             _, info = FastSpeech2ConformerModel.from_pretrained(tmpdirname, output_loading_info=True) | ||||
|         self.assertEqual(info["missing_keys"], set()) | ||||
|         self.assertEqual(info["missing_keys"], []) | ||||
|  | ||||
|     def test_forward_signature(self): | ||||
|         config, _ = self.model_tester.prepare_config_and_inputs_for_common() | ||||
| @ -618,7 +618,7 @@ class FastSpeech2ConformerWithHifiGanTest(ModelTesterMixin, unittest.TestCase): | ||||
|         with tempfile.TemporaryDirectory() as tmpdirname: | ||||
|             model.save_pretrained(tmpdirname) | ||||
|             _, info = FastSpeech2ConformerWithHifiGan.from_pretrained(tmpdirname, output_loading_info=True) | ||||
|         self.assertEqual(info["missing_keys"], set()) | ||||
|         self.assertEqual(info["missing_keys"], []) | ||||
|  | ||||
|     def test_forward_signature(self): | ||||
|         config, _ = self.model_tester.prepare_config_and_inputs_for_common() | ||||
|  | ||||
| @ -248,7 +248,7 @@ class FSMTModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin | ||||
|             with tempfile.TemporaryDirectory() as tmpdirname: | ||||
|                 model.save_pretrained(tmpdirname) | ||||
|                 model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True) | ||||
|             self.assertEqual(info["missing_keys"], set()) | ||||
|             self.assertEqual(info["missing_keys"], []) | ||||
|  | ||||
|     def test_ensure_weights_are_shared(self): | ||||
|         config, inputs_dict = self.model_tester.prepare_config_and_inputs() | ||||
|  | ||||
| @ -218,7 +218,7 @@ class InformerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase | ||||
|             with tempfile.TemporaryDirectory() as tmpdirname: | ||||
|                 model.save_pretrained(tmpdirname) | ||||
|                 model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True) | ||||
|             self.assertEqual(info["missing_keys"], set()) | ||||
|             self.assertEqual(info["missing_keys"], []) | ||||
|  | ||||
|     def test_encoder_decoder_model_standalone(self): | ||||
|         config_and_inputs = self.model_tester.prepare_config_and_inputs_for_common() | ||||
|  | ||||
| @ -316,7 +316,7 @@ class LEDModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, | ||||
|             with tempfile.TemporaryDirectory() as tmpdirname: | ||||
|                 model.save_pretrained(tmpdirname) | ||||
|                 model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True) | ||||
|             self.assertEqual(info["missing_keys"], set()) | ||||
|             self.assertEqual(info["missing_keys"], []) | ||||
|  | ||||
|     def test_decoder_model_past_with_large_inputs(self): | ||||
|         config_and_inputs = self.model_tester.prepare_config_and_inputs() | ||||
|  | ||||
| @ -272,7 +272,7 @@ class M2M100ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix | ||||
|             with tempfile.TemporaryDirectory() as tmpdirname: | ||||
|                 model.save_pretrained(tmpdirname) | ||||
|                 model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True) | ||||
|             self.assertEqual(info["missing_keys"], set()) | ||||
|             self.assertEqual(info["missing_keys"], []) | ||||
|  | ||||
|     def test_decoder_model_past_with_large_inputs(self): | ||||
|         config_and_inputs = self.model_tester.prepare_config_and_inputs() | ||||
|  | ||||
| @ -246,7 +246,7 @@ class MarianModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix | ||||
|             with tempfile.TemporaryDirectory() as tmpdirname: | ||||
|                 model.save_pretrained(tmpdirname) | ||||
|                 model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True) | ||||
|             self.assertEqual(info["missing_keys"], set()) | ||||
|             self.assertEqual(info["missing_keys"], []) | ||||
|  | ||||
|     def test_decoder_model_past_with_large_inputs(self): | ||||
|         config_and_inputs = self.model_tester.prepare_config_and_inputs() | ||||
|  | ||||
| @ -265,7 +265,7 @@ class MBartModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi | ||||
|             with tempfile.TemporaryDirectory() as tmpdirname: | ||||
|                 model.save_pretrained(tmpdirname) | ||||
|                 model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True) | ||||
|             self.assertEqual(info["missing_keys"], set()) | ||||
|             self.assertEqual(info["missing_keys"], []) | ||||
|  | ||||
|     def test_decoder_model_past_with_large_inputs(self): | ||||
|         config_and_inputs = self.model_tester.prepare_config_and_inputs() | ||||
|  | ||||
| @ -462,7 +462,7 @@ class MvpModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, | ||||
|             with tempfile.TemporaryDirectory() as tmpdirname: | ||||
|                 model.save_pretrained(tmpdirname) | ||||
|                 model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True) | ||||
|             self.assertEqual(info["missing_keys"], set()) | ||||
|             self.assertEqual(info["missing_keys"], []) | ||||
|  | ||||
|     def test_decoder_model_past_with_large_inputs(self): | ||||
|         config_and_inputs = self.model_tester.prepare_config_and_inputs() | ||||
|  | ||||
| @ -274,7 +274,7 @@ class NllbMoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi | ||||
|             with tempfile.TemporaryDirectory() as tmpdirname: | ||||
|                 model.save_pretrained(tmpdirname) | ||||
|                 model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True) | ||||
|             self.assertEqual(info["missing_keys"], set()) | ||||
|             self.assertEqual(info["missing_keys"], []) | ||||
|  | ||||
|     def test_decoder_model_past_with_large_inputs(self): | ||||
|         config, inputs_dict = self.model_tester.prepare_config_and_inputs() | ||||
|  | ||||
| @ -256,7 +256,7 @@ class OPTModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, | ||||
|             with tempfile.TemporaryDirectory() as tmpdirname: | ||||
|                 model.save_pretrained(tmpdirname) | ||||
|                 model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True) | ||||
|             self.assertEqual(info["missing_keys"], set()) | ||||
|             self.assertEqual(info["missing_keys"], []) | ||||
|  | ||||
|     def test_decoder_model_past_with_large_inputs(self): | ||||
|         config_and_inputs = self.model_tester.prepare_config_and_inputs() | ||||
|  | ||||
| @ -276,7 +276,7 @@ class PatchTSMixerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.Test | ||||
|             with tempfile.TemporaryDirectory() as tmpdirname: | ||||
|                 model.save_pretrained(tmpdirname) | ||||
|                 model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True) | ||||
|             self.assertEqual(info["missing_keys"], set()) | ||||
|             self.assertEqual(info["missing_keys"], []) | ||||
|  | ||||
|     def test_hidden_states_output(self): | ||||
|         def check_hidden_states_output(inputs_dict, config, model_class): | ||||
|  | ||||
| @ -208,7 +208,7 @@ class PatchTSTModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase | ||||
|             with tempfile.TemporaryDirectory() as tmpdirname: | ||||
|                 model.save_pretrained(tmpdirname) | ||||
|                 model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True) | ||||
|             self.assertEqual(info["missing_keys"], set()) | ||||
|             self.assertEqual(info["missing_keys"], []) | ||||
|  | ||||
|     def test_hidden_states_output(self): | ||||
|         def check_hidden_states_output(inputs_dict, config, model_class): | ||||
|  | ||||
| @ -253,7 +253,7 @@ class PegasusModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi | ||||
|             with tempfile.TemporaryDirectory() as tmpdirname: | ||||
|                 model.save_pretrained(tmpdirname) | ||||
|                 model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True) | ||||
|             self.assertEqual(info["missing_keys"], set()) | ||||
|             self.assertEqual(info["missing_keys"], []) | ||||
|  | ||||
|     def test_decoder_model_past_with_large_inputs(self): | ||||
|         config_and_inputs = self.model_tester.prepare_config_and_inputs() | ||||
|  | ||||
| @ -231,7 +231,7 @@ class PegasusXModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM | ||||
|             with tempfile.TemporaryDirectory() as tmpdirname: | ||||
|                 model.save_pretrained(tmpdirname) | ||||
|                 model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True) | ||||
|             self.assertEqual(info["missing_keys"], set()) | ||||
|             self.assertEqual(info["missing_keys"], []) | ||||
|  | ||||
|     def test_decoder_model_past_with_large_inputs(self): | ||||
|         config_and_inputs = self.model_tester.prepare_config_and_inputs() | ||||
|  | ||||
| @ -261,7 +261,7 @@ class PLBartModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix | ||||
|             with tempfile.TemporaryDirectory() as tmpdirname: | ||||
|                 model.save_pretrained(tmpdirname) | ||||
|                 model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True) | ||||
|             self.assertEqual(info["missing_keys"], set()) | ||||
|             self.assertEqual(info["missing_keys"], []) | ||||
|  | ||||
|     def test_decoder_model_past_with_large_inputs(self): | ||||
|         config_and_inputs = self.model_tester.prepare_config_and_inputs() | ||||
|  | ||||
| @ -282,7 +282,7 @@ class Speech2TextModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTest | ||||
|             with tempfile.TemporaryDirectory() as tmpdirname: | ||||
|                 model.save_pretrained(tmpdirname) | ||||
|                 model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True) | ||||
|             self.assertEqual(info["missing_keys"], set()) | ||||
|             self.assertEqual(info["missing_keys"], []) | ||||
|  | ||||
|     def test_model_forward(self): | ||||
|         config_and_inputs = self.model_tester.prepare_config_and_inputs() | ||||
|  | ||||
| @ -354,7 +354,7 @@ class SpeechT5ForSpeechToTextTest(ModelTesterMixin, unittest.TestCase, Generatio | ||||
|             with tempfile.TemporaryDirectory() as tmpdirname: | ||||
|                 model.save_pretrained(tmpdirname) | ||||
|                 model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True) | ||||
|             self.assertEqual(info["missing_keys"], set()) | ||||
|             self.assertEqual(info["missing_keys"], []) | ||||
|  | ||||
|     def test_model_forward(self): | ||||
|         config_and_inputs = self.model_tester.prepare_config_and_inputs() | ||||
| @ -859,7 +859,7 @@ class SpeechT5ForTextToSpeechTest(ModelTesterMixin, unittest.TestCase): | ||||
|             with tempfile.TemporaryDirectory() as tmpdirname: | ||||
|                 model.save_pretrained(tmpdirname) | ||||
|                 model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True) | ||||
|             self.assertEqual(info["missing_keys"], set()) | ||||
|             self.assertEqual(info["missing_keys"], []) | ||||
|  | ||||
|     def test_model_forward(self): | ||||
|         config_and_inputs = self.model_tester.prepare_config_and_inputs() | ||||
| @ -1359,7 +1359,7 @@ class SpeechT5ForSpeechToSpeechTest(ModelTesterMixin, unittest.TestCase): | ||||
|             with tempfile.TemporaryDirectory() as tmpdirname: | ||||
|                 model.save_pretrained(tmpdirname) | ||||
|                 model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True) | ||||
|             self.assertEqual(info["missing_keys"], set()) | ||||
|             self.assertEqual(info["missing_keys"], []) | ||||
|  | ||||
|     def test_model_forward(self): | ||||
|         config_and_inputs = self.model_tester.prepare_config_and_inputs() | ||||
|  | ||||
| @ -205,7 +205,7 @@ class TimeSeriesTransformerModelTest(ModelTesterMixin, PipelineTesterMixin, unit | ||||
|             with tempfile.TemporaryDirectory() as tmpdirname: | ||||
|                 model.save_pretrained(tmpdirname) | ||||
|                 model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True) | ||||
|             self.assertEqual(info["missing_keys"], set()) | ||||
|             self.assertEqual(info["missing_keys"], []) | ||||
|  | ||||
|     def test_encoder_decoder_model_standalone(self): | ||||
|         config_and_inputs = self.model_tester.prepare_config_and_inputs_for_common() | ||||
|  | ||||
| @ -422,7 +422,7 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi | ||||
|             with tempfile.TemporaryDirectory() as tmpdirname: | ||||
|                 model.save_pretrained(tmpdirname) | ||||
|                 model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True) | ||||
|             self.assertEqual(info["missing_keys"], set()) | ||||
|             self.assertEqual(info["missing_keys"], []) | ||||
|  | ||||
|     def test_model_forward(self): | ||||
|         config_and_inputs = self.model_tester.prepare_config_and_inputs() | ||||
|  | ||||
| @ -1924,7 +1924,7 @@ class ModelTesterMixin: | ||||
|                         v, reloaded_state[k], msg=lambda x: f"{model_class.__name__}: Tensor {k}: {x}" | ||||
|                     ) | ||||
|                 # Checking there was no complain of missing weights | ||||
|                 self.assertEqual(infos["missing_keys"], set()) | ||||
|                 self.assertEqual(infos["missing_keys"], []) | ||||
|  | ||||
|                 # Checking the tensor sharing are correct | ||||
|                 ptrs = defaultdict(list) | ||||
| @ -1958,7 +1958,7 @@ class ModelTesterMixin: | ||||
|                         v, reloaded_state[k], msg=lambda x: f"{model_class.__name__}: Tensor {k}: {x}" | ||||
|                     ) | ||||
|                 # Checking there was no complain of missing weights | ||||
|                 self.assertEqual(infos["missing_keys"], set()) | ||||
|                 self.assertEqual(infos["missing_keys"], []) | ||||
|  | ||||
|     def test_tied_weights_keys(self): | ||||
|         original_config, _ = self.model_tester.prepare_config_and_inputs_for_common() | ||||
| @ -2485,17 +2485,17 @@ class ModelTesterMixin: | ||||
|                         new_model = AutoModelForSequenceClassification.from_pretrained( | ||||
|                             tmp_dir, num_labels=42, ignore_mismatched_sizes=True | ||||
|                         ) | ||||
|                     self.assertIn("Reinit due to size mismatch", cl.out) | ||||
|                     self.assertIn("the shapes did not match", cl.out) | ||||
|                     new_model.to(torch_device) | ||||
|                     inputs = self._prepare_for_class(inputs_dict, model_class) | ||||
|                     logits = new_model(**inputs).logits | ||||
|                     self.assertEqual(logits.shape[1], 42) | ||||
|  | ||||
|                     with CaptureLogger(logger) as cl: | ||||
|                         new_model_without_prefix = AutoModel.from_pretrained( | ||||
|                             tmp_dir, vocab_size=10, ignore_mismatched_sizes=True | ||||
|                         ) | ||||
|                     self.assertIn("Reinit due to size mismatch", cl.out) | ||||
|                     self.assertIn("the shapes did not match", cl.out) | ||||
|                     input_ids = ids_tensor((2, 8), 10) | ||||
|                     new_model_without_prefix.to(torch_device) | ||||
|                     if self.is_encoder_decoder: | ||||
| @ -2536,7 +2536,7 @@ class ModelTesterMixin: | ||||
|  | ||||
|                     with CaptureLogger(logger) as cl: | ||||
|                         new_model = model_class.from_pretrained(tmp_dir, num_labels=42, ignore_mismatched_sizes=True) | ||||
|                     self.assertIn("Reinit due to size mismatch", cl.out) | ||||
|                     self.assertIn("the shapes did not match", cl.out) | ||||
|  | ||||
|                     # Find the name of the module with the mismatched size | ||||
|                     top_linear_modules = [ | ||||
| @ -3895,125 +3895,7 @@ class ModelTesterMixin: | ||||
|                     ): | ||||
|                         self.assertEqual(k1, k2) | ||||
|                         self.assertEqual(v1.dtype, v2.dtype) | ||||
|                     self.assertTrue((v1 == v2).all()) | ||||
|  | ||||
|  | ||||
| @require_torch | ||||
| def test_weight_conversion_operations_roundtrip(): | ||||
|     import torch | ||||
|  | ||||
|     from transformers.core_model_loading import ( | ||||
|         Chunk, | ||||
|         Concatenate, | ||||
|         Fp8Dequantize, | ||||
|         Fp8Quantize, | ||||
|         MergeModuleList, | ||||
|         Shard, | ||||
|         WeightConversion, | ||||
|         convert_state_dict, | ||||
|     ) | ||||
|  | ||||
|     state_dict = { | ||||
|         "experts.0.w1.weight": torch.tensor([[0.0, 1.0], [2.0, 3.0]]), | ||||
|         "experts.1.w1.weight": torch.tensor([[4.0, 5.0], [6.0, 7.0]]), | ||||
|         "experts.0.w3.weight": torch.tensor([[10.0, 11.0], [12.0, 13.0]]), | ||||
|         "experts.1.w3.weight": torch.tensor([[14.0, 15.0], [16.0, 17.0]]), | ||||
|         "self_attn.q_proj.weight": torch.tensor([[1.0, 2.0], [3.0, 4.0]]), | ||||
|         "self_attn.k_proj.weight": torch.tensor([[5.0, 6.0], [7.0, 8.0]]), | ||||
|         "self_attn.v_proj.weight": torch.tensor([[9.0, 10.0], [11.0, 12.0]]), | ||||
|         "self_attn.out_proj.weight": torch.arange(12.0).reshape(6, 2), | ||||
|         "mlp.w2.weight": torch.tensor([[1.0, 0.0], [0.0, 1.0]]), | ||||
|     } | ||||
|  | ||||
|     forward_mapping = [ | ||||
|         WeightConversion( | ||||
|             ["experts.*.w1.weight", "experts.*.w3.weight"], | ||||
|             "experts.gate_up_proj.weight", | ||||
|             [MergeModuleList(dim=0), Concatenate(dim=0), Fp8Quantize(block_size=(1, 1))], | ||||
|         ), | ||||
|         WeightConversion( | ||||
|             ["self_attn.q_proj.weight", "self_attn.k_proj.weight", "self_attn.v_proj.weight"], | ||||
|             "self_attn.qkv_proj.weight", | ||||
|             Concatenate(dim=0), | ||||
|         ), | ||||
|         WeightConversion( | ||||
|             "self_attn.out_proj.weight", | ||||
|             ["self_attn.out_proj.weight.shard0", "self_attn.out_proj.weight.shard1"], | ||||
|             Shard(dim=0, world_size=2, return_all=True), | ||||
|         ), | ||||
|         WeightConversion("mlp.w2.weight", "mlp.down_proj.weight"), | ||||
|     ] | ||||
|  | ||||
|     converted_state, _ = convert_state_dict(None, state_dict, forward_mapping, tp_plan=None, quantization_config=None) | ||||
|  | ||||
|     expected_qkv = torch.cat( | ||||
|         ( | ||||
|             state_dict["self_attn.q_proj.weight"], | ||||
|             state_dict["self_attn.k_proj.weight"], | ||||
|             state_dict["self_attn.v_proj.weight"], | ||||
|         ), | ||||
|         dim=0, | ||||
|     ) | ||||
|     torch.testing.assert_close(converted_state["self_attn.qkv_proj.weight"], expected_qkv) | ||||
|  | ||||
|     reconstructed_out_proj = torch.cat( | ||||
|         (converted_state["self_attn.out_proj.weight.shard0"], converted_state["self_attn.out_proj.weight.shard1"]), | ||||
|         dim=0, | ||||
|     ) | ||||
|     torch.testing.assert_close(reconstructed_out_proj, state_dict["self_attn.out_proj.weight"]) | ||||
|     torch.testing.assert_close(converted_state["mlp.down_proj.weight"], state_dict["mlp.w2.weight"]) | ||||
|  | ||||
|     inverse_mapping = [ | ||||
|         WeightConversion( | ||||
|             ["experts.gate_up_proj.weight", "experts.gate_up_proj.scale"], | ||||
|             "experts.gate_up_proj.dequantized", | ||||
|             Fp8Dequantize(block_size=(1, 1)), | ||||
|         ), | ||||
|         WeightConversion( | ||||
|             "experts.gate_up_proj.dequantized", | ||||
|             ["experts.w1.concat", "experts.w3.concat"], | ||||
|             Chunk(dim=0, sizes=[4, 4]), | ||||
|         ), | ||||
|         WeightConversion( | ||||
|             "experts.w1.concat", | ||||
|             ["experts.0.w1.weight", "experts.1.w1.weight"], | ||||
|             Chunk(dim=0, sizes=[2, 2]), | ||||
|         ), | ||||
|         WeightConversion( | ||||
|             "experts.w3.concat", | ||||
|             ["experts.0.w3.weight", "experts.1.w3.weight"], | ||||
|             Chunk(dim=0, sizes=[2, 2]), | ||||
|         ), | ||||
|         WeightConversion( | ||||
|             "self_attn.qkv_proj.weight", | ||||
|             [ | ||||
|                 "self_attn.q_proj.weight", | ||||
|                 "self_attn.k_proj.weight", | ||||
|                 "self_attn.v_proj.weight", | ||||
|             ], | ||||
|             Chunk(dim=0, sizes=[2, 2, 2]), | ||||
|         ), | ||||
|         WeightConversion( | ||||
|             ["self_attn.out_proj.weight.shard0", "self_attn.out_proj.weight.shard1"], | ||||
|             "self_attn.out_proj.weight", | ||||
|             Concatenate(dim=0), | ||||
|         ), | ||||
|         WeightConversion("mlp.down_proj.weight", "mlp.w2.weight"), | ||||
|     ] | ||||
|  | ||||
|     roundtrip_state, _ = convert_state_dict( | ||||
|         None, converted_state, inverse_mapping, tp_plan=None, quantization_config=None | ||||
|     ) | ||||
|  | ||||
|     torch.testing.assert_close(roundtrip_state["experts.0.w1.weight"], state_dict["experts.0.w1.weight"]) | ||||
|     torch.testing.assert_close(roundtrip_state["experts.1.w1.weight"], state_dict["experts.1.w1.weight"]) | ||||
|     torch.testing.assert_close(roundtrip_state["experts.0.w3.weight"], state_dict["experts.0.w3.weight"]) | ||||
|     torch.testing.assert_close(roundtrip_state["experts.1.w3.weight"], state_dict["experts.1.w3.weight"]) | ||||
|     torch.testing.assert_close(roundtrip_state["self_attn.q_proj.weight"], state_dict["self_attn.q_proj.weight"]) | ||||
|     torch.testing.assert_close(roundtrip_state["self_attn.k_proj.weight"], state_dict["self_attn.k_proj.weight"]) | ||||
|     torch.testing.assert_close(roundtrip_state["self_attn.v_proj.weight"], state_dict["self_attn.v_proj.weight"]) | ||||
|     torch.testing.assert_close(roundtrip_state["self_attn.out_proj.weight"], state_dict["self_attn.out_proj.weight"]) | ||||
|     torch.testing.assert_close(roundtrip_state["mlp.w2.weight"], state_dict["mlp.w2.weight"]) | ||||
|                         self.assertTrue((v1 == v2).all()) | ||||
|  | ||||
|  | ||||
| global_rng = random.Random() | ||||
|  | ||||
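At the tensor level, the round-trip property that the removed test above exercised reduces to a few plain torch identities: concatenating q/k/v along dim 0 and chunking the fused tensor back with matching sizes is an exact inverse, and so is reassembling row shards by concatenation. Below is a minimal standalone sketch of those identities (plain torch only; it does not touch the WeightConversion/convert_state_dict API, whose exact signatures are assumed only from the removed test):

    import torch

    q = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
    k = torch.tensor([[5.0, 6.0], [7.0, 8.0]])
    v = torch.tensor([[9.0, 10.0], [11.0, 12.0]])

    # Fuse: concatenate along the output dimension, as the Concatenate(dim=0) step in the mapping does.
    qkv = torch.cat((q, k, v), dim=0)  # shape (6, 2)

    # Un-fuse: split back with the original row counts, mirroring Chunk(dim=0, sizes=[2, 2, 2]).
    q2, k2, v2 = torch.split(qkv, [2, 2, 2], dim=0)
    assert torch.equal(q2, q) and torch.equal(k2, k) and torch.equal(v2, v)

    # Shard / unshard: splitting rows across world_size=2 and concatenating them back is lossless.
    out_proj = torch.arange(12.0).reshape(6, 2)
    shard0, shard1 = torch.chunk(out_proj, chunks=2, dim=0)
    assert torch.equal(torch.cat((shard0, shard1), dim=0), out_proj)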
| @ -1,112 +0,0 @@ | ||||
| # Copyright 2019 HuggingFace Inc. | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| import unittest | ||||
|  | ||||
| from transformers.core_model_loading import build_glob_alt, match_glob | ||||
|  | ||||
|  | ||||
| class TestWeightGlobMatching(unittest.TestCase): | ||||
|     def setUp(self): | ||||
|         self.weight_globs_digits = [ | ||||
|             "model.layers.*.mlp.gate_up_proj.weight", | ||||
|             "model.layers.*.self_attn.q_proj.weight", | ||||
|             "embed_tokens.weight", | ||||
|         ] | ||||
|         self.alt_digits, self.map_digits = build_glob_alt(self.weight_globs_digits, digits_only=True) | ||||
|  | ||||
|         self.weight_globs_any = [ | ||||
|             "model.layers.*.mlp.gate_up_proj.weight", | ||||
|             "model.layers.*.self_attn.q_proj.weight", | ||||
|             "embed_tokens.weight", | ||||
|         ] | ||||
|         self.alt_any, self.map_any = build_glob_alt(self.weight_globs_any, digits_only=False) | ||||
|  | ||||
|     def test_exact_match(self): | ||||
|         self.assertEqual(match_glob("embed_tokens.weight", self.alt_digits, self.map_digits), "embed_tokens.weight") | ||||
|  | ||||
|     def test_digits_only_star_accepts_digits(self): | ||||
|         self.assertEqual( | ||||
|             match_glob("model.layers.0.mlp.gate_up_proj.weight", self.alt_digits, self.map_digits), | ||||
|             "model.layers.*.mlp.gate_up_proj.weight", | ||||
|         ) | ||||
|         self.assertEqual( | ||||
|             match_glob("model.layers.12.self_attn.q_proj.weight", self.alt_digits, self.map_digits), | ||||
|             "model.layers.*.self_attn.q_proj.weight", | ||||
|         ) | ||||
|  | ||||
|     def test_digits_only_star_rejects_nondigits(self): | ||||
|         # 'a' is not digits, so it should not match with digits_only=True | ||||
|         self.assertIsNone(match_glob("model.layers.a.mlp.gate_up_proj.weight", self.alt_digits, self.map_digits)) | ||||
|  | ||||
|     def test_anychar_star_accepts_nondigits(self): | ||||
|         self.assertEqual( | ||||
|             match_glob("model.layers.a.mlp.gate_up_proj.weight", self.alt_any, self.map_any), | ||||
|             "model.layers.*.mlp.gate_up_proj.weight", | ||||
|         ) | ||||
|         self.assertEqual( | ||||
|             match_glob("model.layers.00x.mlp.gate_up_proj.weight", self.alt_any, self.map_any), | ||||
|             "model.layers.*.mlp.gate_up_proj.weight", | ||||
|         ) | ||||
|  | ||||
|     def test_no_match(self): | ||||
|         self.assertIsNone(match_glob("model.layers.0.mlp.up_proj.weight", self.alt_digits, self.map_digits)) | ||||
|  | ||||
|     def test_leftmost_alternative_wins_for_overlapping_patterns(self): | ||||
|         # Overlapping patterns: both could match; ensure leftmost wins | ||||
|         globs = [ | ||||
|             "model.layers.*.mlp.*.weight",  # broader (first) | ||||
|             "model.layers.0.mlp.gate_up_proj.weight",  # more specific (second) | ||||
|         ] | ||||
|         alt, mapping = build_glob_alt(globs, digits_only=False) | ||||
|  | ||||
|         # Both branches match; Python's regex picks the leftmost alternative → index 0 | ||||
|         self.assertEqual( | ||||
|             match_glob("model.layers.0.mlp.gate_up_proj.weight", alt, mapping), "model.layers.*.mlp.*.weight" | ||||
|         ) | ||||
|  | ||||
|     def test_multiple_patterns_same_prefix(self): | ||||
|         globs = [ | ||||
|             "model.layers.*.self_attn.q_proj.weight", | ||||
|             "model.layers.*.self_attn.k_proj.weight", | ||||
|             "model.layers.*.self_attn.v_proj.weight", | ||||
|         ] | ||||
|         alt, mapping = build_glob_alt(globs, digits_only=True) | ||||
|  | ||||
|         self.assertEqual( | ||||
|             match_glob("model.layers.3.self_attn.q_proj.weight", alt, mapping), | ||||
|             "model.layers.*.self_attn.q_proj.weight", | ||||
|         ) | ||||
|         self.assertEqual( | ||||
|             match_glob("model.layers.3.self_attn.k_proj.weight", alt, mapping), | ||||
|             "model.layers.*.self_attn.k_proj.weight", | ||||
|         ) | ||||
|         self.assertEqual( | ||||
|             match_glob("model.layers.3.self_attn.v_proj.weight", alt, mapping), | ||||
|             "model.layers.*.self_attn.v_proj.weight", | ||||
|         ) | ||||
|  | ||||
|     def test_anchor_full_match_only(self): | ||||
|         # Make sure partial strings don't match—anchors ^...$ are in each branch | ||||
|         self.assertIsNone(match_glob("foo.model.layers.0.mlp.gate_up_proj.weight.bar", self.alt_any, self.map_any)) | ||||
|  | ||||
|     def test_large_batch_performance_smoke(self): | ||||
|         # Not a perf benchmark, but ensures building and matching a larger alternation is OK | ||||
|         globs = [f"model.layers.*.mlp.block{i}.weight" for i in range(200)] | ||||
|         alt, mapping = build_glob_alt(globs, digits_only=True) | ||||
|         key = "model.layers.123.mlp.block57.weight" | ||||
|         self.assertEqual(match_glob(key, alt, mapping), "model.layers.*.mlp.block57.weight") | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     unittest.main() | ||||
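The behaviour asserted above — anchored full-string matches, a '*' that is either digits-only or any-characters, and the leftmost pattern winning when several overlap — can be reproduced by compiling all globs into one alternation regex. The following is a hypothetical sketch of that idea only; the real build_glob_alt and match_glob in transformers.core_model_loading may differ in details such as prefix handling and group naming:

    import re


    def build_alt(globs, digits_only=True):
        """Compile globs into one anchored alternation; return (regex, group-name -> glob)."""
        star = r"\d+" if digits_only else r".+"
        branches, mapping = [], {}
        for i, glob in enumerate(globs):
            src = re.escape(glob).replace(r"\*", star)  # keep '.' literal, expand '*' to the wildcard
            name = f"g{i}"
            branches.append(f"(?P<{name}>^{src}$)")  # every branch is anchored: full matches only
            mapping[name] = glob
        return re.compile("|".join(branches)), mapping


    def match(key, alt, mapping):
        m = alt.search(key)
        return mapping[m.lastgroup] if m else None  # leftmost matching branch wins


    alt, mapping = build_alt(["model.layers.*.mlp.gate_up_proj.weight", "embed_tokens.weight"])
    assert match("model.layers.0.mlp.gate_up_proj.weight", alt, mapping) == "model.layers.*.mlp.gate_up_proj.weight"
    assert match("model.layers.a.mlp.gate_up_proj.weight", alt, mapping) is None  # digits-only '*'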
| @ -1,216 +0,0 @@ | ||||
| # Copyright 2024 HuggingFace Inc. | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| import re | ||||
| import unittest | ||||
|  | ||||
| import torch | ||||
| import torch.nn as nn | ||||
|  | ||||
| from transformers.core_model_loading import ( | ||||
|     Chunk, | ||||
|     Concatenate, | ||||
|     MergeModulelist, | ||||
|     WeightConverter, | ||||
|     _apply_star_subst, | ||||
|     _glob_to_regex_src, | ||||
|     build_glob_alt, | ||||
|     convert_and_load_state_dict_in_model, | ||||
|     glob_to_re, | ||||
|     match_glob, | ||||
| ) | ||||
|  | ||||
|  | ||||
| class TestGlobRegexHelpers(unittest.TestCase): | ||||
|     def test_glob_to_regex_src_digits_only(self): | ||||
|         pattern = _glob_to_regex_src("model.layers.*.mlp.weight", digits_only=True) | ||||
|         self.assertEqual(pattern, r"model\.layers\.(\d+)\.mlp\.weight") | ||||
|  | ||||
|     def test_glob_to_regex_src_any_chars(self): | ||||
|         pattern = _glob_to_regex_src("model.layers.*.mlp.weight", digits_only=False) | ||||
|         self.assertEqual(pattern, r"model\.layers\.(.+)\.mlp\.weight") | ||||
|  | ||||
|     def test_glob_to_re_fullmatch(self): | ||||
|         regex_src = glob_to_re("model.layers.*.mlp.weight", digits_only=True) | ||||
|         regex = re.compile(f"^{regex_src}$") | ||||
|         self.assertIsNotNone(regex.fullmatch("model.layers.12.mlp.weight")) | ||||
|         self.assertIsNone(regex.fullmatch("model.layers.foo.mlp.weight")) | ||||
|  | ||||
|     def test_apply_star_subst(self): | ||||
|         pattern = "model.layers.*.block.*.weight" | ||||
|         replaced = _apply_star_subst(pattern, ["03", "attn"]) | ||||
|         self.assertEqual(replaced, "model.layers.03.block.attn.weight") | ||||
|  | ||||
|     def test_build_glob_alt_without_prefix(self): | ||||
|         globs = ["model.layers.*.weight"] | ||||
|         alt, mapping = build_glob_alt(globs, allow_prefix=False) | ||||
|         self.assertIsNone(match_glob("foo.model.layers.0.weight", alt, mapping)) | ||||
|         self.assertEqual(match_glob("model.layers.0.weight", alt, mapping), "model.layers.*.weight") | ||||
|  | ||||
|     def test_build_glob_alt_with_prefix(self): | ||||
|         globs = ["layers.*.weight"] | ||||
|         alt, mapping = build_glob_alt(globs, allow_prefix=True) | ||||
|         self.assertEqual(match_glob("model.layers.0.weight", alt, mapping), "layers.*.weight") | ||||
|  | ||||
|  | ||||
| class DummyParamModule(nn.Module): | ||||
|     def __init__(self, shape): | ||||
|         super().__init__() | ||||
|         self.weight = nn.Parameter(torch.zeros(shape)) | ||||
|  | ||||
|  | ||||
| class DummySelfAttn(nn.Module): | ||||
|     def __init__(self): | ||||
|         super().__init__() | ||||
|         self.q_proj = DummyParamModule((1, 2)) | ||||
|         self.k_proj = DummyParamModule((1, 2)) | ||||
|         self.v_proj = DummyParamModule((1, 2)) | ||||
|  | ||||
|  | ||||
| class DummyExperts(nn.Module): | ||||
|     def __init__(self): | ||||
|         super().__init__() | ||||
|         self.gate_up_proj = DummyParamModule((2, 4, 2)) | ||||
|         self.down_proj = DummyParamModule((2, 2, 2)) | ||||
|  | ||||
|  | ||||
| class DummyLayer(nn.Module): | ||||
|     def __init__(self): | ||||
|         super().__init__() | ||||
|         self.self_attn = DummySelfAttn() | ||||
|         self.experts = DummyExperts() | ||||
|  | ||||
|  | ||||
| class DummyTopModel(nn.Module): | ||||
|     def __init__(self): | ||||
|         super().__init__() | ||||
|         self.layers = nn.ModuleList([DummyLayer(), DummyLayer()]) | ||||
|  | ||||
|  | ||||
| class DummyMLP(nn.Module): | ||||
|     def __init__(self): | ||||
|         super().__init__() | ||||
|         self.down_proj = DummyParamModule((2, 2)) | ||||
|  | ||||
|  | ||||
| class DummyRoot(nn.Module): | ||||
|     def __init__(self): | ||||
|         super().__init__() | ||||
|         self.model = DummyTopModel() | ||||
|         self.mlp = DummyMLP() | ||||
|  | ||||
|  | ||||
| class TestConvertAndLoadStateDict(unittest.TestCase): | ||||
|     def test_moe_and_qkv_conversion(self): | ||||
|         model = DummyRoot() | ||||
|  | ||||
|         raw_tensors = { | ||||
|             "model.layers.0.experts.0.w1.weight": torch.tensor([[0.0, 1.0], [2.0, 3.0]]), | ||||
|             "model.layers.0.experts.1.w1.weight": torch.tensor([[10.0, 11.0], [12.0, 13.0]]), | ||||
|             "model.layers.0.experts.0.w3.weight": torch.tensor([[4.0, 5.0], [6.0, 7.0]]), | ||||
|             "model.layers.0.experts.1.w3.weight": torch.tensor([[14.0, 15.0], [16.0, 17.0]]), | ||||
|             "model.layers.0.experts.0.w2.weight": torch.tensor([[20.0, 21.0], [22.0, 23.0]]), | ||||
|             "model.layers.0.experts.1.w2.weight": torch.tensor([[24.0, 25.0], [26.0, 27.0]]), | ||||
|             "model.layers.1.experts.0.w1.weight": torch.tensor([[30.0, 31.0], [32.0, 33.0]]), | ||||
|             "model.layers.1.experts.1.w1.weight": torch.tensor([[34.0, 35.0], [36.0, 37.0]]), | ||||
|             "model.layers.1.experts.0.w3.weight": torch.tensor([[38.0, 39.0], [40.0, 41.0]]), | ||||
|             "model.layers.1.experts.1.w3.weight": torch.tensor([[42.0, 43.0], [44.0, 45.0]]), | ||||
|             "model.layers.1.experts.0.w2.weight": torch.tensor([[46.0, 47.0], [48.0, 49.0]]), | ||||
|             "model.layers.1.experts.1.w2.weight": torch.tensor([[50.0, 51.0], [52.0, 53.0]]), | ||||
|             "model.layers.0.self_attn.qkv_proj.weight": torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]), | ||||
|             "model.layers.1.self_attn.qkv_proj.weight": torch.tensor([[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]]), | ||||
|             "mlp.w2.weight": torch.tensor([[60.0, 61.0], [62.0, 63.0]]), | ||||
|         } | ||||
|         state_dict = {k: (k, v.clone()) for k, v in raw_tensors.items()} | ||||
|  | ||||
|         weight_mapping = [ | ||||
|             WeightConverter( | ||||
|                 ["model.layers.*.experts.*.w1.weight", "model.layers.*.experts.*.w3.weight"], | ||||
|                 "model.layers.*.experts.gate_up_proj.weight", | ||||
|                 operations=[MergeModulelist(dim=0), Concatenate(dim=1)], | ||||
|             ), | ||||
|             WeightConverter( | ||||
|                 "model.layers.*.experts.*.w2.weight", | ||||
|                 "model.layers.*.experts.down_proj.weight", | ||||
|                 operations=[MergeModulelist(dim=0)], | ||||
|             ), | ||||
|             WeightConverter( | ||||
|                 "model.layers.*.self_attn.qkv_proj.weight", | ||||
|                 [ | ||||
|                     "model.layers.*.self_attn.q_proj.weight", | ||||
|                     "model.layers.*.self_attn.k_proj.weight", | ||||
|                     "model.layers.*.self_attn.v_proj.weight", | ||||
|                 ], | ||||
|                 operations=[Concatenate(dim=0), Chunk(dim=0, chunks=3)], | ||||
|             ), | ||||
|             WeightConverter("mlp.w2.weight", "mlp.down_proj.weight"), | ||||
|         ] | ||||
|  | ||||
|         missing, unexpected, mismatch, misc = convert_and_load_state_dict_in_model( | ||||
|             model, state_dict, weight_mapping, tp_plan=None, quantizer=None | ||||
|         ) | ||||
|  | ||||
|         self.assertEqual(missing, set()) | ||||
|         self.assertEqual(unexpected, set()) | ||||
|         self.assertEqual(mismatch, set()) | ||||
|         self.assertEqual(misc, {}) | ||||
|  | ||||
|         model_state = model.state_dict() | ||||
|  | ||||
|         def cat_gate(layer_prefix: str) -> torch.Tensor: | ||||
|             w1 = [ | ||||
|                 raw_tensors[f"{layer_prefix}.experts.0.w1.weight"], | ||||
|                 raw_tensors[f"{layer_prefix}.experts.1.w1.weight"], | ||||
|             ] | ||||
|             w3 = [ | ||||
|                 raw_tensors[f"{layer_prefix}.experts.0.w3.weight"], | ||||
|                 raw_tensors[f"{layer_prefix}.experts.1.w3.weight"], | ||||
|             ] | ||||
|             return torch.cat([torch.stack(w1, dim=0), torch.stack(w3, dim=0)], dim=1) | ||||
|  | ||||
|         torch.testing.assert_close( | ||||
|             model_state["model.layers.0.experts.gate_up_proj.weight"], cat_gate("model.layers.0") | ||||
|         ) | ||||
|         torch.testing.assert_close( | ||||
|             model_state["model.layers.1.experts.gate_up_proj.weight"], cat_gate("model.layers.1") | ||||
|         ) | ||||
|  | ||||
|         def stack_down(layer_prefix: str) -> torch.Tensor: | ||||
|             return torch.stack( | ||||
|                 [ | ||||
|                     raw_tensors[f"{layer_prefix}.experts.0.w2.weight"], | ||||
|                     raw_tensors[f"{layer_prefix}.experts.1.w2.weight"], | ||||
|                 ], | ||||
|                 dim=0, | ||||
|             ) | ||||
|  | ||||
|         torch.testing.assert_close( | ||||
|             model_state["model.layers.0.experts.down_proj.weight"], stack_down("model.layers.0") | ||||
|         ) | ||||
|         torch.testing.assert_close( | ||||
|             model_state["model.layers.1.experts.down_proj.weight"], stack_down("model.layers.1") | ||||
|         ) | ||||
|  | ||||
|         for layer_idx in range(2): | ||||
|             key = f"model.layers.{layer_idx}.self_attn.qkv_proj.weight" | ||||
|             expected_q, expected_k, expected_v = torch.chunk(raw_tensors[key], chunks=3, dim=0) | ||||
|             prefix = f"model.layers.{layer_idx}.self_attn" | ||||
|             torch.testing.assert_close(model_state[f"{prefix}.q_proj.weight"], expected_q) | ||||
|             torch.testing.assert_close(model_state[f"{prefix}.k_proj.weight"], expected_k) | ||||
|             torch.testing.assert_close(model_state[f"{prefix}.v_proj.weight"], expected_v) | ||||
|  | ||||
|         torch.testing.assert_close(model_state["mlp.down_proj.weight"], raw_tensors["mlp.w2.weight"]) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     unittest.main() | ||||
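For the MoE part of the deleted test, the expected layout follows directly from the cat_gate/stack_down helpers: per-expert w1 and w3 matrices are stacked along a new leading experts dimension and the two stacks are concatenated along the feature dimension, while w2 is only stacked. A short plain-torch sketch of that layout, using the same 2-expert, 2x2 shapes as the dummy model above (the operation names in the comments refer to the weight_mapping in the deleted test):

    import torch

    num_experts, out_features, in_features = 2, 2, 2
    w1 = [torch.randn(out_features, in_features) for _ in range(num_experts)]  # per-expert gate projections
    w3 = [torch.randn(out_features, in_features) for _ in range(num_experts)]  # per-expert up projections
    w2 = [torch.randn(out_features, in_features) for _ in range(num_experts)]  # per-expert down projections

    # MergeModulelist(dim=0): stack each expert list into one tensor -> (num_experts, out, in)
    gate = torch.stack(w1, dim=0)
    up = torch.stack(w3, dim=0)

    # Concatenate(dim=1): fuse gate and up per expert -> (num_experts, 2 * out, in)
    gate_up_proj = torch.cat([gate, up], dim=1)
    down_proj = torch.stack(w2, dim=0)  # down projections are only stacked

    assert gate_up_proj.shape == (num_experts, 2 * out_features, in_features)  # matches the (2, 4, 2) dummy parameter
    assert down_proj.shape == (num_experts, out_features, in_features)  # matches the (2, 2, 2) dummy parameter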
| @ -90,6 +90,8 @@ PRIVATE_MODELS = [ | ||||
|     "Kosmos2_5TextForCausalLM", | ||||
|     "Kosmos2_5VisionModel", | ||||
|     "SmolVLMVisionTransformer", | ||||
|     "SiglipVisionTransformer", | ||||
|     "Siglip2VisionTransformer", | ||||
|     "AriaTextForCausalLM", | ||||
|     "AriaTextModel", | ||||
|     "Phi4MultimodalAudioModel", | ||||
| @ -358,7 +360,9 @@ IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [ | ||||
|     "SegGptForImageSegmentation", | ||||
|     "SiglipVisionModel", | ||||
|     "SiglipTextModel", | ||||
|     "SiglipVisionTransformer", | ||||
|     "Siglip2VisionModel", | ||||
|     "Siglip2VisionTransformer", | ||||
|     "Siglip2TextModel", | ||||
|     "ChameleonVQVAE",  # no autoclass for VQ-VAE models | ||||
|     "VitPoseForPoseEstimation", | ||||
|  | ||||