diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index d5cd109227f..86d19cb09b1 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -562,6 +562,8 @@ title: LED - local: model_doc/lfm2 title: LFM2 + - local: model_doc/lfm2_moe + title: LFM2Moe - local: model_doc/llama title: LLaMA - local: model_doc/llama2 diff --git a/docs/source/en/model_doc/lfm2.md b/docs/source/en/model_doc/lfm2.md index 58f1d754588..131733ed6ec 100644 --- a/docs/source/en/model_doc/lfm2.md +++ b/docs/source/en/model_doc/lfm2.md @@ -23,15 +23,15 @@ rendered properly in your Markdown viewer. ## Overview -[LFM2](https://www.liquid.ai/blog/liquid-foundation-models-v2-our-second-series-of-generative-ai-models) represents a new generation of Liquid Foundation Models developed by [Liquid AI](https://liquid.ai/), specifically designed for edge AI and on-device deployment. +[LFM2](https://www.liquid.ai/blog/liquid-foundation-models-v2-our-second-series-of-generative-ai-models) represents a new generation of Liquid Foundation Models developed by Liquid AI, specifically designed for edge AI and on-device deployment. -The models are available in three sizes (350M, 700M, and 1.2B parameters) and are engineered to run efficiently on CPU, GPU, and NPU hardware, making them particularly well-suited for applications requiring low latency, offline operation, and privacy. +The models are available in four sizes (350M, 700M, 1.2B, and 2.6B parameters) and are engineered to run efficiently on CPU, GPU, and NPU hardware, making them particularly well-suited for applications requiring low latency, offline operation, and privacy. ## Architecture -The architecture consists of 16 blocks total: 10 double-gated short-range convolution blocks and 6 blocks of grouped query attention. This design stems from the concept of dynamical systems, where linear operations are modulated by input-dependent gates, allowing for "liquid" dynamics that can adapt in real-time. The short convolutions are particularly optimized for embedded SoC CPUs, making them ideal for devices that require fast, local inference without relying on cloud connectivity. +The architecture consists of blocks of gated short convolution blocks and blocks of grouped query attention with QK layernorm. This design stems from the concept of dynamical systems, where linear operations are modulated by input-dependent gates. The short convolutions are particularly optimized for embedded SoC CPUs, making them ideal for devices that require fast, local inference without relying on cloud connectivity. -The key architectural innovation of LFM2 lies in its systematic approach to balancing quality, latency, and memory efficiency through our STAR neural architecture search engine. Using STAR, Liquid AI optimized the models for real-world performance on embedded hardware, measuring actual peak memory usage and inference speed on Qualcomm Snapdragon processors. This results in models that achieve 2x faster decode and prefill performance compared to similar-sized models, while maintaining superior benchmark performance across knowledge, mathematics, instruction following, and multilingual tasks. +LFM2 was designed to maximize quality under strict speed and memory constraints. This was accomplished through a systematic architecture search to optimize the models for real-world performance on embedded hardware by measuring actual peak memory usage and inference speed on Qualcomm Snapdragon processors. This results in models that achieve 2x faster decode and prefill performance compared to similar-sized models, while maintaining superior benchmark performance across knowledge, mathematics, instruction following, and multilingual tasks. ## Example diff --git a/docs/source/en/model_doc/lfm2_moe.md b/docs/source/en/model_doc/lfm2_moe.md new file mode 100644 index 00000000000..bdaaebaa604 --- /dev/null +++ b/docs/source/en/model_doc/lfm2_moe.md @@ -0,0 +1,83 @@ + + + +# Lfm2Moe + +## Overview + +LFM2-MoE is a Mixture-of-Experts (MoE) variant of [LFM2](https://huggingface.co/collections/LiquidAI/lfm2-686d721927015b2ad73eaa38). The LFM2 family is optimized for on-device inference by combining short‑range, input‑aware gated convolutions with grouped‑query attention (GQA) in a layout tuned to maximize quality under strict speed and memory constraints. + +LFM2‑MoE keeps this fast backbone and introduces sparse MoE feed‑forward networks to add representational capacity without significantly increasing the active compute path. The first LFM2-MoE release is LFM2-8B-A1B, with 8.3B total parameters and 1.5B active parameters. The model excels in quality (comparable to 3-4B dense models) and speed (faster than other 1.5B class models). + +## Example + +The following example shows how to generate an answer using the `AutoModelForCausalLM` class. + +```python +from transformers import AutoModelForCausalLM, AutoTokenizer + +# Load model and tokenizer +model_id = "LiquidAI/LFM2-8B-A1B" +model = AutoModelForCausalLM.from_pretrained( + model_id, + device_map="auto", + dtype="bfloat16", +# attn_implementation="flash_attention_2" <- uncomment on compatible GPU +) +tokenizer = AutoTokenizer.from_pretrained(model_id) + +# Generate answer +prompt = "What is C. elegans?" +input_ids = tokenizer.apply_chat_template( + [{"role": "user", "content": prompt}], + add_generation_prompt=True, + return_tensors="pt", + tokenize=True, +).to(model.device) + +output = model.generate( + input_ids, + do_sample=True, + temperature=0.3, + min_p=0.15, + repetition_penalty=1.05, + max_new_tokens=512, +) + +print(tokenizer.decode(output[0], skip_special_tokens=False)) +``` + +## Lfm2MoeConfig + +[[autodoc]] Lfm2MoeConfig + +## Lfm2MoeForCausalLM + +[[autodoc]] Lfm2MoeForCausalLM + +## Lfm2MoeModel + +[[autodoc]] Lfm2MoeModel + - forward + +## Lfm2MoePreTrainedModel + +[[autodoc]] Lfm2MoePreTrainedModel + - forward diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index c721f24a506..9d0c4445475 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -186,6 +186,7 @@ if TYPE_CHECKING: from .led import * from .levit import * from .lfm2 import * + from .lfm2_moe import * from .lfm2_vl import * from .lightglue import * from .lilt import * diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index a2e4f05bae5..1cb5f37dad6 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -226,6 +226,7 @@ CONFIG_MAPPING_NAMES = OrderedDict[str, str]( ("led", "LEDConfig"), ("levit", "LevitConfig"), ("lfm2", "Lfm2Config"), + ("lfm2_moe", "Lfm2MoeConfig"), ("lfm2_vl", "Lfm2VlConfig"), ("lightglue", "LightGlueConfig"), ("lilt", "LiltConfig"), @@ -670,6 +671,7 @@ MODEL_NAMES_MAPPING = OrderedDict[str, str]( ("led", "LED"), ("levit", "LeViT"), ("lfm2", "Lfm2"), + ("lfm2_moe", "Lfm2Moe"), ("lfm2_vl", "Lfm2Vl"), ("lightglue", "LightGlue"), ("lilt", "LiLT"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 298834bebe9..4248fdabdad 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -226,6 +226,7 @@ MODEL_MAPPING_NAMES = OrderedDict( ("led", "LEDModel"), ("levit", "LevitModel"), ("lfm2", "Lfm2Model"), + ("lfm2_moe", "Lfm2MoeModel"), ("lfm2_vl", "Lfm2VlModel"), ("lightglue", "LightGlueForKeypointMatching"), ("lilt", "LiltModel"), @@ -694,6 +695,7 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict( ("jamba", "JambaForCausalLM"), ("jetmoe", "JetMoeForCausalLM"), ("lfm2", "Lfm2ForCausalLM"), + ("lfm2_moe", "Lfm2MoeForCausalLM"), ("llama", "LlamaForCausalLM"), ("llama4", "Llama4ForCausalLM"), ("llama4_text", "Llama4ForCausalLM"), diff --git a/src/transformers/models/lfm2/modeling_lfm2.py b/src/transformers/models/lfm2/modeling_lfm2.py index 07aced67f65..752621d6c0c 100644 --- a/src/transformers/models/lfm2/modeling_lfm2.py +++ b/src/transformers/models/lfm2/modeling_lfm2.py @@ -163,7 +163,6 @@ class Lfm2HybridConvCache: dtype=self._dtype, device=device, ) - torch._dynamo.mark_static_address(conv_state) self.conv_cache.append(conv_state) self.key_cache.append(torch.tensor([])) self.value_cache.append(torch.tensor([])) @@ -595,7 +594,6 @@ class Lfm2Model(Lfm2PreTrainedModel): self.layers = nn.ModuleList( [Lfm2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] ) - self.rotary_emb = Lfm2RotaryEmbedding(config=config) self.gradient_checkpointing = False self.pos_emb = Lfm2RotaryEmbedding(config) self.embedding_norm = Lfm2RMSNorm(config.hidden_size, eps=config.norm_eps) diff --git a/src/transformers/models/lfm2/modular_lfm2.py b/src/transformers/models/lfm2/modular_lfm2.py index 16a69fa0dc0..355d3baff3a 100644 --- a/src/transformers/models/lfm2/modular_lfm2.py +++ b/src/transformers/models/lfm2/modular_lfm2.py @@ -121,7 +121,6 @@ class Lfm2HybridConvCache: dtype=self._dtype, device=device, ) - torch._dynamo.mark_static_address(conv_state) self.conv_cache.append(conv_state) self.key_cache.append(torch.tensor([])) self.value_cache.append(torch.tensor([])) @@ -441,7 +440,7 @@ class Lfm2Model(LlamaModel): self.pos_emb = Lfm2RotaryEmbedding(config) self.embedding_norm = Lfm2RMSNorm(config.hidden_size, eps=config.norm_eps) del self.norm - del self.rotary_emv + del self.rotary_emb def forward( self, diff --git a/src/transformers/models/lfm2_moe/__init__.py b/src/transformers/models/lfm2_moe/__init__.py new file mode 100644 index 00000000000..3ebaf8f93e8 --- /dev/null +++ b/src/transformers/models/lfm2_moe/__init__.py @@ -0,0 +1,29 @@ +# coding=utf-8 +# Copyright 2025 the HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_lfm2_moe import * + from .modeling_lfm2_moe import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/src/transformers/models/lfm2_moe/configuration_lfm2_moe.py b/src/transformers/models/lfm2_moe/configuration_lfm2_moe.py new file mode 100644 index 00000000000..550954ecfd2 --- /dev/null +++ b/src/transformers/models/lfm2_moe/configuration_lfm2_moe.py @@ -0,0 +1,169 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional + +from ...configuration_utils import PretrainedConfig + + +class Lfm2MoeConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Lfm2MoeModel`]. It is used to instantiate a LFM2 Moe + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the LFM2-8B-A1B model. + e.g. [LiquidAI/LFM2-8B-A1B](https://huggingface.co/LiquidAI/LFM2-8B-A1B) + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 65536): + Vocabulary size of the LLaMA model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`Lfm2Model`] + hidden_size (`int`, *optional*, defaults to 2048): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 7168): + Dimension of the MLP representations. + moe_intermediate_size (`int`, *optional*, defaults to 1792): + Intermediate size of the routed expert. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer decoder. + pad_token_id (`int`, *optional*, defaults to 0): + Padding token id. + bos_token_id (`int`, *optional*, defaults to 1): + Beginning of stream token id. + eos_token_id (`int`, *optional*, defaults to 2): + End of stream token id. + tie_word_embeddings (`bool`, *optional*, defaults to `True`): + Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to 1000000.0): + The base period of the RoPE embeddings. + max_position_embeddings (`int`, *optional*, defaults to 128000): + The maximum sequence length that this model might ever be used with. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the rms normalization layers. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer decoder. + num_key_value_heads (`int`, *optional*, defaults to 8): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details, check out [this + paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to + `num_attention_heads`. + conv_bias (`bool`, *optional*, defaults to `False`): + Whether to use bias in the conv layers. + conv_L_cache (`int`, *optional*, defaults to 3): + L_cache dim in the conv layers. + num_dense_layers (`int`, *optional*, defaults to 2): + Number of dense Lfm2MoeMLP layers in shallow layers(embed->dense->dense->...->dense->moe->moe...->lm_head). + num_experts_per_tok (`int`, *optional*, defaults to 4): + Number of selected experts. + num_experts (`int`, *optional*, defaults to 32): + Number of routed experts. + use_expert_bias (`bool`, *optional*, defaults to `True`): + Whether to use the expert bias on the routing weights. + routed_scaling_factor (`float`, *optional*, defaults to 1.0): + Scaling factor for routed experts in MoE models. + norm_topk_prob (`bool`, *optional*, defaults to `True`): + Whether to normalize the topk probabilities. + layer_types (`Optional`, *optional*): + Type of each layers. + + ```python + >>> from transformers import Lfm2MoeModel, Lfm2MoeConfig + + >>> # Initializing a LFM2 Moe model + >>> configuration = Lfm2MoeConfig() + + >>> # Initializing a model from the LFM2-8B-A1B style configuration + >>> model = Lfm2MoeModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "lfm2_moe" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size: int = 65536, + hidden_size: int = 2048, + intermediate_size: int = 7168, + moe_intermediate_size: int = 1792, + num_hidden_layers: int = 32, + pad_token_id: int = 0, + bos_token_id: int = 1, + eos_token_id: int = 2, + tie_word_embeddings: bool = True, + rope_theta: float = 1000000.0, + max_position_embeddings: int = 128_000, + use_cache: bool = True, + norm_eps: float = 0.00001, + num_attention_heads: int = 32, + num_key_value_heads: int = 8, + conv_bias: bool = False, + conv_L_cache: int = 3, + num_dense_layers: int = 2, + num_experts_per_tok: int = 4, + num_experts: int = 32, + use_expert_bias: bool = True, + routed_scaling_factor: float = 1.0, + norm_topk_prob: bool = True, + layer_types: Optional[list[str]] = None, + **kwargs, + ): + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + self.use_cache = use_cache + self.norm_eps = norm_eps + + # attn operator config + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + + # custom operator config + self.conv_bias = conv_bias + self.conv_L_cache = conv_L_cache + + # moe config + self.num_dense_layers = num_dense_layers + self.moe_intermediate_size = moe_intermediate_size + self.num_experts_per_tok = num_experts_per_tok + self.num_experts = num_experts + self.use_expert_bias = use_expert_bias + self.routed_scaling_factor = routed_scaling_factor + self.norm_topk_prob = norm_topk_prob + self.layer_types = layer_types + + tie_word_embeddings = kwargs.get("tie_embedding", tie_word_embeddings) # to fit original config keys + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +__all__ = ["Lfm2MoeConfig"] diff --git a/src/transformers/models/lfm2_moe/modeling_lfm2_moe.py b/src/transformers/models/lfm2_moe/modeling_lfm2_moe.py new file mode 100644 index 00000000000..6f879ec9c5e --- /dev/null +++ b/src/transformers/models/lfm2_moe/modeling_lfm2_moe.py @@ -0,0 +1,813 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/lfm2_moe/modular_lfm2_moe.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_lfm2_moe.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Callable, Optional, Union + +import torch +import torch.nn.functional as F +from torch import nn + +from ...cache_utils import Cache +from ...generation import GenerationMixin +from ...integrations import use_kernel_forward_from_hub +from ...masking_utils import create_causal_mask +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, MoeModelOutputWithPast +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import TransformersKwargs, auto_docstring, can_return_tuple +from ...utils.deprecation import deprecate_kwarg +from ...utils.generic import check_model_inputs +from ...utils.import_utils import is_causal_conv1d_available +from .configuration_lfm2_moe import Lfm2MoeConfig + + +if is_causal_conv1d_available(): + from causal_conv1d import causal_conv1d_fn, causal_conv1d_update +else: + causal_conv1d_fn, causal_conv1d_update = None, None + + +@use_kernel_forward_from_hub("RMSNorm") +class Lfm2MoeRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Lfm2MoeRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class Lfm2MoeRotaryEmbedding(nn.Module): + inv_freq: torch.Tensor # fix linting for `register_buffer` + + def __init__(self, config: Lfm2MoeConfig, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict): + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) + def forward(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class Lfm2MoeMLP(nn.Module): + def __init__(self, config: Lfm2MoeConfig, intermediate_size: Optional[int] = None): + super().__init__() + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size if intermediate_size is None else intermediate_size + self.w1 = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.w3 = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.w2 = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + + def forward(self, x): + return self.w2(F.silu(self.w1(x)) * self.w3(x)) + + +class Lfm2MoeExperts(nn.ModuleList): + """ + ModuleList of experts. + """ + + def __init__(self, config): + super().__init__() + self.num_experts = config.num_experts + for _ in range(config.num_experts): + self.append(Lfm2MoeMLP(config, intermediate_size=config.moe_intermediate_size)) + + def forward( + self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor + ) -> torch.Tensor: + """ + Args: + hidden_states: (batch_size * sequence_length, hidden_dim) + selected_experts: (batch_size * sequence_length, top_k) + routing_weights: (batch_size * sequence_length, top_k) + Returns: + (batch_size * sequence_length, hidden_dim) + """ + final_hidden_states = torch.zeros_like(hidden_states) + expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0) + + expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() + for expert_idx in expert_hit: + idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0)) + current_state = hidden_states[None, top_x].reshape(-1, hidden_states.shape[-1]) + current_hidden_states = self[expert_idx](current_state) * top_k_weights[top_x, idx, None] + final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) + return final_hidden_states + + +class Lfm2MoeSparseMoeBlock(nn.Module): + def __init__(self, config): + super().__init__() + self.top_k = config.num_experts_per_tok + self.routed_scaling_factor = config.routed_scaling_factor + self.norm_topk_prob = config.norm_topk_prob + self.use_expert_bias = config.use_expert_bias + + self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False) + self.experts = Lfm2MoeExperts(config) + if self.use_expert_bias: + self.register_buffer("expert_bias", torch.zeros(config.num_experts, dtype=torch.float32)) + + def route_tokens_to_experts(self, router_logits): + routing_weights = router_logits.sigmoid() + if self.use_expert_bias: + scores_for_routing = routing_weights + self.expert_bias + _, selected_experts = torch.topk(scores_for_routing, k=self.top_k, dim=-1) + routing_weights = torch.gather(routing_weights, dim=1, index=selected_experts).type_as(router_logits) + else: + routing_weights, selected_experts = torch.topk(routing_weights, k=self.top_k, dim=-1) + + if self.norm_topk_prob: + routing_weights = routing_weights / (routing_weights.sum(dim=-1, keepdim=True) + 1e-6) + routing_weights = routing_weights * self.routed_scaling_factor + return selected_experts, routing_weights + + def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + batch_size, sequence_length, hidden_dim = hidden_states.shape + hidden_states_reshaped = hidden_states.view(-1, hidden_dim) + router_logits = self.gate(hidden_states_reshaped) + selected_experts, routing_weights = self.route_tokens_to_experts(router_logits) + final_hidden_states = self.experts(hidden_states_reshaped, selected_experts, routing_weights) + return final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) + + +class Lfm2MoeHybridConvCache: + """ + Attention and conv cache for Lfm2Moe. + + It stores the Key and Value states as a list of tensors, one for each layer. + Attention layer cache shape: `[batch_size, num_heads, seq_len, head_dim]`. + Conv layer cache shape: `[batch_size, hidden_size, L_cache-1]`. + """ + + # Override @property existing in Cache + max_batch_size = None + is_compileable = False + key_cache = None + value_cache = None + + def __init__( + self, + config: Lfm2MoeConfig, + max_batch_size: int, + dtype: torch.dtype = torch.float32, + device: Union[torch.device, str, None] = None, + ): + self.key_cache = [] + self.value_cache = [] + self.max_batch_size = max_batch_size + self.layer_types = config.layer_types + self.first_attention_layer = self.layer_types.index("full_attention") + self.conv_L_cache = config.conv_L_cache + self._dtype = dtype + + self.conv_cache: list[torch.Tensor] = [] + device = torch.device(device) if device is not None else None + + for _ in range(config.num_hidden_layers): + conv_state = torch.zeros( + self.max_batch_size, + config.hidden_size, + self.conv_L_cache, + dtype=self._dtype, + device=device, + ) + self.conv_cache.append(conv_state) + self.key_cache.append(torch.tensor([])) + self.value_cache.append(torch.tensor([])) + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[dict[str, Any]] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. + + Parameters: + key_states (`torch.Tensor`): + The new key states to cache. + value_states (`torch.Tensor`): + The new value states to cache. + layer_idx (`int`): + The index of the layer to cache the states for. + cache_kwargs (`Dict[str, Any]`, `optional`): + Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`. + + Return: + A tuple containing the updated key and value states. + """ + # Update the cache + if self.key_cache[layer_idx].numel() == 0: + self.key_cache[layer_idx] = key_states + self.value_cache[layer_idx] = value_states + else: + self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2) + self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2) + + return self.key_cache[layer_idx], self.value_cache[layer_idx] + + def reorder_cache(self, beam_idx: torch.LongTensor): + """Reorders the cache for beam search, given the selected beam indices.""" + for layer_idx in range(len(self.key_cache)): + if self.key_cache[layer_idx].numel(): + device = self.key_cache[layer_idx].device + self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) + device = self.value_cache[layer_idx].device + self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) + + if self.conv_cache[layer_idx].numel(): + device = self.conv_cache[layer_idx].device + self.conv_cache[layer_idx] = self.conv_cache[layer_idx].index_select(0, beam_idx.to(device)) + + def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + """Returns the sequence length of the cached states. A layer index can be optionally passed.""" + # take any layer that contains cache and not empty tensor + layer_idx = self.first_attention_layer if self.layer_types[layer_idx] != "full_attention" else layer_idx + if len(self.key_cache) <= layer_idx or self.key_cache[layer_idx].numel() == 0: + return 0 + return self.key_cache[layer_idx].shape[-2] + + def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]: + """ + Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for + the given layer at `layer_idx`. + The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns (i.e. sliding_window, chunk_size), + for each layer. + """ + full_mask_kv_offset = 0 + query_length = cache_position.shape[0] + past_seen_tokens = self.get_seq_length() + kv_length = query_length + past_seen_tokens + return kv_length, full_mask_kv_offset + + def crop(self, max_length: int): + """Crop the cache to the given length""" + if max_length < 0: + max_length = self.get_seq_length() - abs(max_length) + + if self.get_seq_length() <= max_length: + return + + for idx in range(len(self.key_cache)): + if self.key_cache[idx].numel(): + self.key_cache[idx] = self.key_cache[idx][..., :max_length, :] + self.value_cache[idx] = self.value_cache[idx][..., :max_length, :] + + def __len__(self) -> int: + return len(self.key_cache) + + def __getitem__(self, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor]: + return self.key_cache[layer_idx], self.value_cache[layer_idx] + + def reset(self): + for layer_idx in range(len(self.conv_cache)): + # In-place ops prevent breaking the static address + self.conv_cache[layer_idx].zero_() + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs: Unpack[TransformersKwargs], +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class Lfm2MoeAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: Lfm2MoeConfig, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + self.is_causal = True + self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False) + self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False) + self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False) + self.out_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) + self.q_layernorm = Lfm2MoeRMSNorm(self.head_dim, eps=config.norm_eps) + self.k_layernorm = Lfm2MoeRMSNorm(self.head_dim, eps=config.norm_eps) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_values: Optional[Lfm2MoeHybridConvCache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_layernorm(self.q_proj(hidden_states).view(*hidden_shape)).transpose(1, 2) + key_states = self.k_layernorm(self.k_proj(hidden_states).view(*hidden_shape)).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(*hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_values is not None: + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0, + scaling=self.scaling, + **kwargs, + ) + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + output = self.out_proj(attn_output) + return output, attn_weights + + +def apply_mask_to_padding_states(hidden_states, attention_mask): + """ + Tunes out the hidden states for padding tokens, see https://github.com/state-spaces/mamba/issues/66 + """ + if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1: + dtype = hidden_states.dtype + hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) + + return hidden_states + + +kernel_modules = (causal_conv1d_fn, causal_conv1d_update) +is_fast_path_available = all(kernel_modules) + + +class Lfm2MoeShortConv(nn.Module): + def __init__( + self, + config: Lfm2MoeConfig, + layer_idx: int, + ): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.L_cache = config.conv_L_cache + self.bias = config.conv_bias + + self.conv = nn.Conv1d( + in_channels=config.hidden_size, + out_channels=config.hidden_size, + kernel_size=self.L_cache, + groups=config.hidden_size, + bias=self.bias, + padding=self.L_cache - 1, + ) + self.in_proj = nn.Linear(config.hidden_size, 3 * config.hidden_size, bias=self.bias) + self.out_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=self.bias) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def cuda_kernels_forward( + self, + x: torch.Tensor, + past_key_values: Optional[Lfm2MoeHybridConvCache] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ): + x = apply_mask_to_padding_states(x, attention_mask) + BCx = self.in_proj(x).transpose(-1, -2) + B, C, x = BCx.chunk(3, dim=-2) + + Bx = B * x + + conv_weights = self.conv.weight.view(self.conv.weight.size(0), self.conv.weight.size(2)) + if past_key_values is not None and cache_position[0] > 0: + conv_out = causal_conv1d_update( + Bx.squeeze(-1), + past_key_values.conv_cache[self.layer_idx], + conv_weights, + self.conv.bias, + None, + ) + conv_out = conv_out.unsqueeze(-1) + else: + if past_key_values is not None: + conv_state = nn.functional.pad(Bx, (self.L_cache - Bx.shape[-1], 0)) + past_key_values.conv_cache[self.layer_idx].copy_(conv_state) + + conv_out = causal_conv1d_fn(Bx, conv_weights, self.conv.bias, activation=None) + + y = C * conv_out + y = self.out_proj(y.transpose(-1, -2).contiguous()) + return y + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def slow_forward( + self, + x: torch.Tensor, + past_key_values: Optional[Lfm2MoeHybridConvCache] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ): + seqlen = x.shape[1] + + x = apply_mask_to_padding_states(x, attention_mask) + BCx = self.in_proj(x).transpose(-1, -2) + B, C, x = BCx.chunk(3, dim=-2) + + Bx = B * x + + if past_key_values is not None and cache_position[0] > 0: + conv_state = past_key_values.conv_cache[self.layer_idx] + cache_position = cache_position.clamp(0, self.L_cache - 1) + conv_state = conv_state.roll(shifts=-1, dims=-1) + conv_state[:, :, cache_position] = Bx.to(device=conv_state.device, dtype=conv_state.dtype) + past_key_values.conv_cache[self.layer_idx].copy_(conv_state) + conv_out = torch.sum(conv_state.to(Bx.device) * self.conv.weight[:, 0, :], dim=-1) + if self.bias: + conv_out += self.conv.bias + + conv_out = conv_out.unsqueeze(-1) + else: + if past_key_values is not None: + conv_state = nn.functional.pad(Bx, (self.L_cache - Bx.shape[-1], 0)) + past_key_values.conv_cache[self.layer_idx].copy_(conv_state) + + conv_out = self.conv(Bx)[..., :seqlen] + + y = C * conv_out + y = y.transpose(-1, -2).contiguous() + y = self.out_proj(y) + return y + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + past_key_values: Optional[Lfm2MoeHybridConvCache] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ): + if is_fast_path_available and "cuda" in hidden_states.device.type and not torch._dynamo.is_compiling(): + return self.cuda_kernels_forward(hidden_states, past_key_values, cache_position, attention_mask) + return self.slow_forward(hidden_states, past_key_values, cache_position, attention_mask) + + +class Lfm2MoeDecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: Lfm2MoeConfig, layer_idx: int): + super().__init__() + self.is_attention_layer = config.layer_types[layer_idx] == "full_attention" + + if self.is_attention_layer: + self.self_attn = Lfm2MoeAttention(config, layer_idx) + else: + self.conv = Lfm2MoeShortConv(config, layer_idx) + self.feed_forward = ( + Lfm2MoeMLP(config, intermediate_size=config.intermediate_size) + if layer_idx < config.num_dense_layers + else Lfm2MoeSparseMoeBlock(config) + ) + self.operator_norm = Lfm2MoeRMSNorm(config.hidden_size, eps=config.norm_eps) + self.ffn_norm = Lfm2MoeRMSNorm(config.hidden_size, eps=config.norm_eps) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Lfm2MoeHybridConvCache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> torch.Tensor: + residual = hidden_states + if self.is_attention_layer: + hidden_states, _ = self.self_attn( + hidden_states=self.operator_norm(hidden_states), + position_embeddings=position_embeddings, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + cache_position=cache_position, + **kwargs, + ) + else: + hidden_states = self.conv( + hidden_states=self.operator_norm(hidden_states), + past_key_values=past_key_values, + cache_position=cache_position, + attention_mask=attention_mask, + ) + hidden_states = hidden_states + residual + hidden_states = hidden_states + self.feed_forward(self.ffn_norm(hidden_states)) + + return hidden_states + + +@auto_docstring +class Lfm2MoePreTrainedModel(PreTrainedModel): + config: Lfm2MoeConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["Lfm2MoeDecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn = True + _supports_sdpa = True + _supports_flex_attn = True + _can_compile_fullgraph = False + _supports_attention_backend = True + _can_record_outputs = { + "hidden_states": Lfm2MoeDecoderLayer, + "attentions": Lfm2MoeAttention, + } + + +@auto_docstring +class Lfm2MoeModel(Lfm2MoePreTrainedModel): + def __init__(self, config: Lfm2MoeConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [Lfm2MoeDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.gradient_checkpointing = False + self.pos_emb = Lfm2MoeRotaryEmbedding(config) + self.embedding_norm = Lfm2MoeRMSNorm(config.hidden_size, eps=config.norm_eps) + + # Initialize weights and apply final processing + self.post_init() + + @check_model_inputs() + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Lfm2MoeHybridConvCache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> MoeModelOutputWithPast: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + batch_size = inputs_embeds.shape[0] + past_key_values = Lfm2MoeHybridConvCache( + config=self.config, max_batch_size=batch_size, dtype=self.dtype, device=self.device + ) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = create_causal_mask( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + position_ids=position_ids, + ) + + hidden_states = inputs_embeds + position_embeddings = self.pos_emb(hidden_states, position_ids) + + # decoder layers + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + hidden_states = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_values=past_key_values, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + + hidden_states = self.embedding_norm(hidden_states) + + return MoeModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + ) + + +@auto_docstring +class Lfm2MoeForCausalLM(Lfm2MoePreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} + + def __init__(self, config): + super().__init__(config) + self.model = Lfm2MoeModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> CausalLMOutputWithPast: + r""" + Example: + + ```python + >>> from transformers import AutoTokenizer, Lfm2MoeForCausalLM + + >>> model = Lfm2MoeForCausalLM.from_pretrained("meta-lfm2_moe/Lfm2Moe-2-7b-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-lfm2_moe/Lfm2Moe-2-7b-hf") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + outputs: BaseModelOutputWithPast = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +__all__ = ["Lfm2MoeForCausalLM", "Lfm2MoeModel", "Lfm2MoePreTrainedModel"] diff --git a/src/transformers/models/lfm2_moe/modular_lfm2_moe.py b/src/transformers/models/lfm2_moe/modular_lfm2_moe.py new file mode 100644 index 00000000000..9a4f5ff73c8 --- /dev/null +++ b/src/transformers/models/lfm2_moe/modular_lfm2_moe.py @@ -0,0 +1,204 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional + +import torch +from torch import nn + +from ...masking_utils import create_causal_mask +from ...modeling_outputs import MoeModelOutputWithPast +from ...processing_utils import Unpack +from ...utils import TransformersKwargs, logging +from ...utils.import_utils import is_causal_conv1d_available +from ..lfm2.modeling_lfm2 import Lfm2Attention, Lfm2DecoderLayer, Lfm2HybridConvCache, Lfm2MLP, Lfm2ShortConv +from ..llama.modeling_llama import LlamaForCausalLM, LlamaPreTrainedModel, LlamaRMSNorm, LlamaRotaryEmbedding +from ..mixtral.modeling_mixtral import MixtralModel +from ..qwen2_moe.modeling_qwen2_moe import Qwen2MoeExperts +from .configuration_lfm2_moe import Lfm2MoeConfig + + +if is_causal_conv1d_available(): + from causal_conv1d import causal_conv1d_fn, causal_conv1d_update +else: + causal_conv1d_fn, causal_conv1d_update = None, None + + +kernel_modules = (causal_conv1d_fn, causal_conv1d_update) +is_fast_path_available = all(kernel_modules) + + +logger = logging.get_logger(__name__) + + +class Lfm2MoeRMSNorm(LlamaRMSNorm): + pass + + +class Lfm2MoeRotaryEmbedding(LlamaRotaryEmbedding): + pass + + +class Lfm2MoeMLP(Lfm2MLP): + def __init__(self, config: Lfm2MoeConfig, intermediate_size: Optional[int] = None): + nn.Module.__init__(self) + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size if intermediate_size is None else intermediate_size + self.w1 = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.w3 = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.w2 = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + + +class Lfm2MoeExperts(Qwen2MoeExperts): + pass + + +class Lfm2MoeSparseMoeBlock(nn.Module): + def __init__(self, config): + super().__init__() + self.top_k = config.num_experts_per_tok + self.routed_scaling_factor = config.routed_scaling_factor + self.norm_topk_prob = config.norm_topk_prob + self.use_expert_bias = config.use_expert_bias + + self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False) + self.experts = Lfm2MoeExperts(config) + if self.use_expert_bias: + self.register_buffer("expert_bias", torch.zeros(config.num_experts, dtype=torch.float32)) + + def route_tokens_to_experts(self, router_logits): + routing_weights = router_logits.sigmoid() + if self.use_expert_bias: + scores_for_routing = routing_weights + self.expert_bias + _, selected_experts = torch.topk(scores_for_routing, k=self.top_k, dim=-1) + routing_weights = torch.gather(routing_weights, dim=1, index=selected_experts).type_as(router_logits) + else: + routing_weights, selected_experts = torch.topk(routing_weights, k=self.top_k, dim=-1) + + if self.norm_topk_prob: + routing_weights = routing_weights / (routing_weights.sum(dim=-1, keepdim=True) + 1e-6) + routing_weights = routing_weights * self.routed_scaling_factor + return selected_experts, routing_weights + + def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + batch_size, sequence_length, hidden_dim = hidden_states.shape + hidden_states_reshaped = hidden_states.view(-1, hidden_dim) + router_logits = self.gate(hidden_states_reshaped) + selected_experts, routing_weights = self.route_tokens_to_experts(router_logits) + final_hidden_states = self.experts(hidden_states_reshaped, selected_experts, routing_weights) + return final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) + + +class Lfm2MoeHybridConvCache(Lfm2HybridConvCache): + pass + + +class Lfm2MoeAttention(Lfm2Attention): + pass + + +class Lfm2MoeShortConv(Lfm2ShortConv): + pass + + +class Lfm2MoeDecoderLayer(Lfm2DecoderLayer): + def __init__(self, config: Lfm2MoeConfig, layer_idx: int): + super().__init__(config, layer_idx) + self.feed_forward = ( + Lfm2MoeMLP(config, intermediate_size=config.intermediate_size) + if layer_idx < config.num_dense_layers + else Lfm2MoeSparseMoeBlock(config) + ) + + +class Lfm2MoePreTrainedModel(LlamaPreTrainedModel): + _can_compile_fullgraph = False + + +class Lfm2MoeModel(MixtralModel): + def __init__(self, config: Lfm2MoeConfig): + super().__init__(config) + self.pos_emb = Lfm2MoeRotaryEmbedding(config) + self.embedding_norm = Lfm2MoeRMSNorm(config.hidden_size, eps=config.norm_eps) + del self.norm + del self.rotary_emb + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Lfm2MoeHybridConvCache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> MoeModelOutputWithPast: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + batch_size = inputs_embeds.shape[0] + past_key_values = Lfm2MoeHybridConvCache( + config=self.config, max_batch_size=batch_size, dtype=self.dtype, device=self.device + ) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = create_causal_mask( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + position_ids=position_ids, + ) + + hidden_states = inputs_embeds + position_embeddings = self.pos_emb(hidden_states, position_ids) + + # decoder layers + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + hidden_states = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_values=past_key_values, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + + hidden_states = self.embedding_norm(hidden_states) + + return MoeModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + ) + + +class Lfm2MoeForCausalLM(LlamaForCausalLM): + pass + + +__all__ = ["Lfm2MoeForCausalLM", "Lfm2MoeModel", "Lfm2MoePreTrainedModel"] diff --git a/tests/causal_lm_tester.py b/tests/causal_lm_tester.py index 577359e22ca..dbd52c8307c 100644 --- a/tests/causal_lm_tester.py +++ b/tests/causal_lm_tester.py @@ -448,6 +448,7 @@ class CausalLMModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM # named location of the RoPE layer class. base_model = self.model_tester.base_model_class(config) possible_rope_attributes = [ + "pos_emb", "rotary_emb", # most common case "global_rotary_emb", "local_rotary_emb", diff --git a/tests/models/lfm2/test_modeling_lfm2.py b/tests/models/lfm2/test_modeling_lfm2.py index 8007d0db87a..e3cac3927ee 100644 --- a/tests/models/lfm2/test_modeling_lfm2.py +++ b/tests/models/lfm2/test_modeling_lfm2.py @@ -23,12 +23,15 @@ from transformers.testing_utils import ( require_torch, require_torch_accelerator, slow, + torch_device, ) from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester if is_torch_available(): + import torch + from transformers import Lfm2ForCausalLM, Lfm2Model @@ -60,22 +63,82 @@ class Lfm2ModelTest(CausalLMModelTest, unittest.TestCase): # used in `test_torch_compile_for_training` _torch_compile_train_cls = Lfm2ForCausalLM if is_torch_available() else None - @unittest.skip( - "Lfm2 alternates between attention and conv layers, so attention are only returned for attention layers" - ) def test_attention_outputs(self): - pass + """Lfm2Moe alternates between attention and short-conv layers.""" + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.return_dict = True + # force eager attention to support output attentions + config._attn_implementation = "eager" + seq_len = getattr(self.model_tester, "seq_length", None) - @unittest.skip("Lfm2 has a special cache format as it alternates between attention and conv layers") + for model_class in self.all_model_classes: + inputs_dict["output_attentions"] = True + inputs_dict["output_hidden_states"] = False + config.return_dict = True + model = model_class._from_config(config, attn_implementation="eager").to(torch_device).eval() + config = model.config + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + attentions = outputs.attentions + self.assertEqual(len(attentions), sum(layer == "full_attention" for layer in config.layer_types)) + + # check that output_attentions also work using config + del inputs_dict["output_attentions"] + config.output_attentions = True + model = model_class(config).to(torch_device).eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + attentions = outputs.attentions + self.assertEqual(len(attentions), sum(layer == "full_attention" for layer in config.layer_types)) + self.assertListEqual(list(attentions[0].shape[-3:]), [config.num_attention_heads, seq_len, seq_len]) + out_len = len(outputs) + + # Check attention is always last and order is fine + inputs_dict["output_attentions"] = True + inputs_dict["output_hidden_states"] = True + model = model_class(config).to(torch_device).eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + self_attentions = outputs.attentions + + self.assertEqual(out_len + 1, len(outputs)) + self.assertEqual(len(self_attentions), sum(layer == "full_attention" for layer in config.layer_types)) + self.assertListEqual(list(self_attentions[0].shape[-3:]), [config.num_attention_heads, seq_len, seq_len]) + + @pytest.mark.generate def test_past_key_values_format(self): - pass + """Lfm2Moe has a special cache format as it alternates between attention and conv layers""" + for model_class in self.all_generative_model_classes: + config, inputs = self.model_tester.prepare_config_and_inputs_for_common() - @unittest.skip( - "Lfm2 has a special cache format which is not compatible with compile as it has static address for conv cache" - ) - @pytest.mark.torch_compile_test - def test_sdpa_can_compile_dynamic(self): - pass + model = model_class(config).to(torch_device).eval() + if "use_cache" not in inputs: + inputs["use_cache"] = True + outputs = model(**inputs) + + past_kv = outputs["past_key_values"] + + num_query_attention_heads = config.num_attention_heads + embed_dim = config.hidden_size + per_head_embed_dim = embed_dim // num_query_attention_heads + num_key_value_heads = getattr(config, "num_key_value_heads", num_query_attention_heads) + + batch_size, seq_length = inputs["input_ids"].shape[:2] + default_self_attention_shape = (batch_size, num_key_value_heads, seq_length, per_head_embed_dim) + default_conv_shape = (batch_size, config.hidden_size, config.conv_L_cache) + + num_cache_decoder_layers = len(past_kv) + self.assertEqual(num_cache_decoder_layers, config.num_hidden_layers) + + for i in range(config.num_hidden_layers): + if config.layer_types[i] == "full_attention": + self_attention_layer_keys = past_kv.key_cache[i] + self_attention_layer_values = past_kv.value_cache[i] + self.assertEqual(self_attention_layer_keys.shape, default_self_attention_shape) + self.assertEqual(self_attention_layer_values.shape, default_self_attention_shape) + else: + conv_layer = past_kv.conv_cache[i] + self.assertEqual(conv_layer.shape, default_conv_shape) @require_torch_accelerator diff --git a/tests/models/lfm2_moe/__init__.py b/tests/models/lfm2_moe/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/models/lfm2_moe/test_modeling_lfm2_moe.py b/tests/models/lfm2_moe/test_modeling_lfm2_moe.py new file mode 100644 index 00000000000..9d24a5fde32 --- /dev/null +++ b/tests/models/lfm2_moe/test_modeling_lfm2_moe.py @@ -0,0 +1,246 @@ +# coding=utf-8 +# Copyright 2025 the HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Testing suite for the PyTorch LLaMA model.""" + +import unittest + +import pytest + +from transformers import AutoTokenizer, is_torch_available, set_seed +from transformers.testing_utils import ( + cleanup, + require_read_token, + require_torch, + require_torch_accelerator, + slow, + torch_device, +) + +from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester + + +if is_torch_available(): + import torch + + from transformers import Lfm2MoeConfig, Lfm2MoeForCausalLM, Lfm2MoeModel + + +class Lfm2MoeModelTester(CausalLMModelTester): + if is_torch_available(): + config_class = Lfm2MoeConfig + base_model_class = Lfm2MoeModel + causal_lm_class = Lfm2MoeForCausalLM + + def __init__( + self, + parent, + layer_types=["full_attention", "conv"], + ): + super().__init__(parent) + self.layer_types = layer_types + + +@require_torch +class Lfm2MoeModelTest(CausalLMModelTest, unittest.TestCase): + all_model_classes = (Lfm2MoeModel, Lfm2MoeForCausalLM) if is_torch_available() else () + pipeline_model_mapping = ( + { + "feature-extraction": Lfm2MoeModel, + "text-generation": Lfm2MoeForCausalLM, + } + if is_torch_available() + else {} + ) + test_headmasking = False + test_pruning = False + fx_compatible = False + model_tester_class = Lfm2MoeModelTester + # used in `test_torch_compile_for_training` + _torch_compile_train_cls = Lfm2MoeForCausalLM if is_torch_available() else None + + def test_attention_outputs(self): + """Lfm2Moe alternates between attention and short-conv layers.""" + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.return_dict = True + # force eager attention to support output attentions + config._attn_implementation = "eager" + seq_len = getattr(self.model_tester, "seq_length", None) + + for model_class in self.all_model_classes: + inputs_dict["output_attentions"] = True + inputs_dict["output_hidden_states"] = False + config.return_dict = True + model = model_class._from_config(config, attn_implementation="eager").to(torch_device).eval() + config = model.config + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + attentions = outputs.attentions + self.assertEqual(len(attentions), sum(layer == "full_attention" for layer in config.layer_types)) + + # check that output_attentions also work using config + del inputs_dict["output_attentions"] + config.output_attentions = True + model = model_class(config).to(torch_device).eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + attentions = outputs.attentions + self.assertEqual(len(attentions), sum(layer == "full_attention" for layer in config.layer_types)) + self.assertListEqual(list(attentions[0].shape[-3:]), [config.num_attention_heads, seq_len, seq_len]) + out_len = len(outputs) + + # Check attention is always last and order is fine + inputs_dict["output_attentions"] = True + inputs_dict["output_hidden_states"] = True + model = model_class(config).to(torch_device).eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + self_attentions = outputs.attentions + + self.assertEqual(out_len + 1, len(outputs)) + self.assertEqual(len(self_attentions), sum(layer == "full_attention" for layer in config.layer_types)) + self.assertListEqual(list(self_attentions[0].shape[-3:]), [config.num_attention_heads, seq_len, seq_len]) + + @pytest.mark.generate + def test_past_key_values_format(self): + """Lfm2Moe has a special cache format as it alternates between attention and conv layers""" + for model_class in self.all_generative_model_classes: + config, inputs = self.model_tester.prepare_config_and_inputs_for_common() + + model = model_class(config).to(torch_device).eval() + if "use_cache" not in inputs: + inputs["use_cache"] = True + outputs = model(**inputs) + + past_kv = outputs["past_key_values"] + + num_query_attention_heads = config.num_attention_heads + embed_dim = config.hidden_size + per_head_embed_dim = embed_dim // num_query_attention_heads + num_key_value_heads = getattr(config, "num_key_value_heads", num_query_attention_heads) + + batch_size, seq_length = inputs["input_ids"].shape[:2] + default_self_attention_shape = (batch_size, num_key_value_heads, seq_length, per_head_embed_dim) + default_conv_shape = (batch_size, config.hidden_size, config.conv_L_cache) + + num_cache_decoder_layers = len(past_kv) + self.assertEqual(num_cache_decoder_layers, config.num_hidden_layers) + + for i in range(config.num_hidden_layers): + if config.layer_types[i] == "full_attention": + self_attention_layer_keys = past_kv.key_cache[i] + self_attention_layer_values = past_kv.value_cache[i] + self.assertEqual(self_attention_layer_keys.shape, default_self_attention_shape) + self.assertEqual(self_attention_layer_values.shape, default_self_attention_shape) + else: + conv_layer = past_kv.conv_cache[i] + self.assertEqual(conv_layer.shape, default_conv_shape) + + +@require_torch_accelerator +@require_read_token +@slow +class Lfm2MoeIntegrationTest(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.model = None + + @classmethod + def tearDownClass(cls): + del cls.model + cleanup(torch_device, gc_collect=True) + + def tearDown(self): + cleanup(torch_device, gc_collect=True) + + @classmethod + def get_model(cls): + if cls.model is None: + cls.model = Lfm2MoeForCausalLM.from_pretrained( + "LiquidAI/LFM2-8B-A1B", device_map="auto", dtype=torch.bfloat16 + ) + return cls.model + + @slow + def test_model_1a8b_logits(self): + set_seed(1789) + input_ids = [1, 22998, 768, 1947, 797, 22017, 811, 6332, 928, 5743, 797, 779, 48123, 772, 33551, 60996, 523] + model = self.get_model() + input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device) + with torch.no_grad(): + out = model(input_ids).logits.float().cpu() + # Expected mean on dim = -1 + EXPECTED_MEAN = torch.tensor( + [ + [ + -1.3855, + -0.5123, + -1.3143, + -1.2144, + -1.0791, + -1.2117, + -1.4704, + -0.7648, + -0.6175, + -1.2402, + -1.1459, + -1.0083, + -1.0247, + -0.8830, + -1.5643, + -1.7266, + -1.6254, + ] + ] + ) + torch.testing.assert_close(out.mean(-1), EXPECTED_MEAN, rtol=1e-2, atol=1e-2) + # Expected portion of the logits + EXPECTED_SLICE = torch.tensor( + [-1.2656, 2.4844, 5.5000, -1.3359, -1.3203, -1.3438, 1.9375, 5.8438, -0.6523, -1.2891] + ) + torch.testing.assert_close(out[0, 0, :10], EXPECTED_SLICE, rtol=1e-4, atol=1e-4) + + @slow + def test_model_1a8b_generation(self): + EXPECTED_TEXT_COMPLETION = """In 1st century A.D., the Roman Empire controlled much of Europe, North Africa, and parts of the Middle East.""" + set_seed(1789) + prompt = "In 1st century A.D., the Roman Empire" + tokenizer = AutoTokenizer.from_pretrained("LiquidAI/LFM2-8B-A1B", use_fast=False) + model = self.get_model() + input_ids = tokenizer.encode(prompt, return_tensors="pt", add_special_tokens=True).to( + model.model.embed_tokens.weight.device + ) + with torch.no_grad(): + generated_ids = model.generate(input_ids, max_new_tokens=15, do_sample=False) + text = tokenizer.decode(generated_ids[0], skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION, text) + + @slow + def test_model_1a8b_batched_chat_generation(self): + prompts = ["Who are you?", "Complete the text: Lorem ipsum dolor ", "The Meji Restoration in Japan ended"] + EXPECTED_TEXT_COMPLETIONS = [ + "Who are you?? \nI am an artificial intelligence assistant designed to provide information, answer questions", + "Complete the text: Lorem ipsum dolor ipsum dolor ipsum dolor ipsum dolor ipsum dolor", + "The Meji Restoration in Japan ended (1868) marked the: \nA) Establishment of a constitutional", + ] + set_seed(1789) + tokenizer = AutoTokenizer.from_pretrained("LiquidAI/LFM2-8B-A1B", use_fast=False) + model = self.get_model() + batched_input_ids = tokenizer(prompts, return_tensors="pt", padding=True).to( + model.model.embed_tokens.weight.device + ) + with torch.no_grad(): + generated_ids = model.generate(**batched_input_ids, max_new_tokens=15, do_sample=False) + text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETIONS, text) diff --git a/utils/check_config_attributes.py b/utils/check_config_attributes.py index 9e857f250cb..8545d91c07b 100644 --- a/utils/check_config_attributes.py +++ b/utils/check_config_attributes.py @@ -36,6 +36,7 @@ SPECIAL_CASES_TO_ALLOW = { "Ernie4_5Config": ["tie_word_embeddings"], "Ernie4_5_MoeConfig": ["tie_word_embeddings"], "Lfm2Config": ["full_attn_idxs", "tie_word_embeddings"], + "Lfm2MoeConfig": ["tie_word_embeddings"], # used internally during generation to provide the custom logit processors with their necessary information "DiaConfig": [ "delay_pattern",