[Model] Lfm2Moe (#41401)

* [new-models] LFM2-MoE

Signed-off-by: Paul Pak <paulpak58@gmail.com>

* [docs] add in template lfm2_moe doc files

Signed-off-by: Paul Pak <paulpak58@gmail.com>

* [configuration] update configuration class

Signed-off-by: Paul Pak <paulpak58@gmail.com>

* [modular][lfm] minor: fix rotary_emb typo

Signed-off-by: Paul Pak <paulpak58@gmail.com>

* [modeling] modular/modeling files for Lfm2Moe

Signed-off-by: Paul Pak <paulpak58@gmail.com>

* [modeling][lfm2_moe] fix Lfm2Moe modular/modeling

Signed-off-by: Paul Pak <paulpak58@gmail.com>

* [configuration][lfm2_moe] update configuration keys with latest config changes

Signed-off-by: Paul Pak <paulpak58@gmail.com>

* [misc] make fixup

Signed-off-by: Paul Pak <paulpak58@gmail.com>

* [modular][lfm2_moe] address comments: dtype, mlp, buffers

Signed-off-by: Paul Pak <paulpak58@gmail.com>

* [configuration][lfm2_moe] add initializer_range

Signed-off-by: Paul Pak <paulpak58@gmail.com>

* [modular][lfm2_moe] include init_weights to pass test_initialization

Signed-off-by: Paul Pak <paulpak58@gmail.com>

* [tests][causal_lm] include pos_emb as possible rope attribute

Signed-off-by: Paul Pak <paulpak58@gmail.com>

* [modeling][lfm2_moe] remove load_balancing_loss_func due to lack of support for hooking expert biases

Signed-off-by: Paul Pak <paulpak58@gmail.com>

* [misc] make style

Signed-off-by: Paul Pak <paulpak58@gmail.com>

* [modeling][lfm2_moe] MoE refactor PR update in LFM2Moe

Signed-off-by: Paul Pak <paulpak58@gmail.com>

* [tests] lfm2_moe: unit tests

Signed-off-by: Paul Pak <paulpak58@gmail.com>

* [misc] update LFM2-8B-A1B repo id

Signed-off-by: Paul Pak <paulpak58@gmail.com>

* [tests] lfm2: update ModelTests for lfm2

Signed-off-by: Paul Pak <paulpak58@gmail.com>

* Update LFM2 documentation

Updated the LFM2 documentation to reflect the addition of a new model size and clarified architectural details.

* Add Lfm2Moe documentation

Add Lfm2Moe model documentation with overview and example usage.

* [misc] fix ci

Signed-off-by: Paul Pak <paulpak58@gmail.com>

* [docs] remove trust_remote_code

Signed-off-by: Paul Pak <paulpak58@gmail.com>

* [misc] ci: fix modular

Signed-off-by: Paul Pak <paulpak58@gmail.com>

* reapply modular

* simplify

* remove static address and inplace op

* simplify

* simplify a bit more the modular

* imports

---------

Signed-off-by: Paul Pak <paulpak58@gmail.com>
Co-authored-by: Maxime Labonne <81252890+mlabonne@users.noreply.github.com>
Co-authored-by: Cyril Vallez <cyril.vallez@huggingface.co>
Co-authored-by: Cyril Vallez <cyril.vallez@gmail.com>
This commit is contained in:
Paul Pak
2025-10-07 22:09:58 +09:00
committed by GitHub
parent b4428d545f
commit 0c9a72e457
17 changed files with 1633 additions and 20 deletions

View File

@ -562,6 +562,8 @@
title: LED
- local: model_doc/lfm2
title: LFM2
- local: model_doc/lfm2_moe
title: LFM2Moe
- local: model_doc/llama
title: LLaMA
- local: model_doc/llama2

View File

@ -23,15 +23,15 @@ rendered properly in your Markdown viewer.
## Overview
[LFM2](https://www.liquid.ai/blog/liquid-foundation-models-v2-our-second-series-of-generative-ai-models) represents a new generation of Liquid Foundation Models developed by [Liquid AI](https://liquid.ai/), specifically designed for edge AI and on-device deployment.
[LFM2](https://www.liquid.ai/blog/liquid-foundation-models-v2-our-second-series-of-generative-ai-models) represents a new generation of Liquid Foundation Models developed by Liquid AI, specifically designed for edge AI and on-device deployment.
The models are available in three sizes (350M, 700M, and 1.2B parameters) and are engineered to run efficiently on CPU, GPU, and NPU hardware, making them particularly well-suited for applications requiring low latency, offline operation, and privacy.
The models are available in four sizes (350M, 700M, 1.2B, and 2.6B parameters) and are engineered to run efficiently on CPU, GPU, and NPU hardware, making them particularly well-suited for applications requiring low latency, offline operation, and privacy.
## Architecture
The architecture consists of 16 blocks total: 10 double-gated short-range convolution blocks and 6 blocks of grouped query attention. This design stems from the concept of dynamical systems, where linear operations are modulated by input-dependent gates, allowing for "liquid" dynamics that can adapt in real-time. The short convolutions are particularly optimized for embedded SoC CPUs, making them ideal for devices that require fast, local inference without relying on cloud connectivity.
The architecture consists of blocks of gated short convolution blocks and blocks of grouped query attention with QK layernorm. This design stems from the concept of dynamical systems, where linear operations are modulated by input-dependent gates. The short convolutions are particularly optimized for embedded SoC CPUs, making them ideal for devices that require fast, local inference without relying on cloud connectivity.
The key architectural innovation of LFM2 lies in its systematic approach to balancing quality, latency, and memory efficiency through our STAR neural architecture search engine. Using STAR, Liquid AI optimized the models for real-world performance on embedded hardware, measuring actual peak memory usage and inference speed on Qualcomm Snapdragon processors. This results in models that achieve 2x faster decode and prefill performance compared to similar-sized models, while maintaining superior benchmark performance across knowledge, mathematics, instruction following, and multilingual tasks.
LFM2 was designed to maximize quality under strict speed and memory constraints. This was accomplished through a systematic architecture search to optimize the models for real-world performance on embedded hardware by measuring actual peak memory usage and inference speed on Qualcomm Snapdragon processors. This results in models that achieve 2x faster decode and prefill performance compared to similar-sized models, while maintaining superior benchmark performance across knowledge, mathematics, instruction following, and multilingual tasks.
## Example

View File

@ -0,0 +1,83 @@
<!--Copyright 2025 the HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be rendered properly in your Markdown viewer.
-->
# Lfm2Moe
## Overview
LFM2-MoE is a Mixture-of-Experts (MoE) variant of [LFM2](https://huggingface.co/collections/LiquidAI/lfm2-686d721927015b2ad73eaa38). The LFM2 family is optimized for on-device inference by combining shortrange, inputaware gated convolutions with groupedquery attention (GQA) in a layout tuned to maximize quality under strict speed and memory constraints.
LFM2MoE keeps this fast backbone and introduces sparse MoE feedforward networks to add representational capacity without significantly increasing the active compute path. The first LFM2-MoE release is LFM2-8B-A1B, with 8.3B total parameters and 1.5B active parameters. The model excels in quality (comparable to 3-4B dense models) and speed (faster than other 1.5B class models).
## Example
The following example shows how to generate an answer using the `AutoModelForCausalLM` class.
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
# Load model and tokenizer
model_id = "LiquidAI/LFM2-8B-A1B"
model = AutoModelForCausalLM.from_pretrained(
model_id,
device_map="auto",
dtype="bfloat16",
# attn_implementation="flash_attention_2" <- uncomment on compatible GPU
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Generate answer
prompt = "What is C. elegans?"
input_ids = tokenizer.apply_chat_template(
[{"role": "user", "content": prompt}],
add_generation_prompt=True,
return_tensors="pt",
tokenize=True,
).to(model.device)
output = model.generate(
input_ids,
do_sample=True,
temperature=0.3,
min_p=0.15,
repetition_penalty=1.05,
max_new_tokens=512,
)
print(tokenizer.decode(output[0], skip_special_tokens=False))
```
## Lfm2MoeConfig
[[autodoc]] Lfm2MoeConfig
## Lfm2MoeForCausalLM
[[autodoc]] Lfm2MoeForCausalLM
## Lfm2MoeModel
[[autodoc]] Lfm2MoeModel
- forward
## Lfm2MoePreTrainedModel
[[autodoc]] Lfm2MoePreTrainedModel
- forward

View File

@ -186,6 +186,7 @@ if TYPE_CHECKING:
from .led import *
from .levit import *
from .lfm2 import *
from .lfm2_moe import *
from .lfm2_vl import *
from .lightglue import *
from .lilt import *

View File

@ -226,6 +226,7 @@ CONFIG_MAPPING_NAMES = OrderedDict[str, str](
("led", "LEDConfig"),
("levit", "LevitConfig"),
("lfm2", "Lfm2Config"),
("lfm2_moe", "Lfm2MoeConfig"),
("lfm2_vl", "Lfm2VlConfig"),
("lightglue", "LightGlueConfig"),
("lilt", "LiltConfig"),
@ -670,6 +671,7 @@ MODEL_NAMES_MAPPING = OrderedDict[str, str](
("led", "LED"),
("levit", "LeViT"),
("lfm2", "Lfm2"),
("lfm2_moe", "Lfm2Moe"),
("lfm2_vl", "Lfm2Vl"),
("lightglue", "LightGlue"),
("lilt", "LiLT"),

View File

@ -226,6 +226,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
("led", "LEDModel"),
("levit", "LevitModel"),
("lfm2", "Lfm2Model"),
("lfm2_moe", "Lfm2MoeModel"),
("lfm2_vl", "Lfm2VlModel"),
("lightglue", "LightGlueForKeypointMatching"),
("lilt", "LiltModel"),
@ -694,6 +695,7 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
("jamba", "JambaForCausalLM"),
("jetmoe", "JetMoeForCausalLM"),
("lfm2", "Lfm2ForCausalLM"),
("lfm2_moe", "Lfm2MoeForCausalLM"),
("llama", "LlamaForCausalLM"),
("llama4", "Llama4ForCausalLM"),
("llama4_text", "Llama4ForCausalLM"),

View File

@ -163,7 +163,6 @@ class Lfm2HybridConvCache:
dtype=self._dtype,
device=device,
)
torch._dynamo.mark_static_address(conv_state)
self.conv_cache.append(conv_state)
self.key_cache.append(torch.tensor([]))
self.value_cache.append(torch.tensor([]))
@ -595,7 +594,6 @@ class Lfm2Model(Lfm2PreTrainedModel):
self.layers = nn.ModuleList(
[Lfm2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
self.rotary_emb = Lfm2RotaryEmbedding(config=config)
self.gradient_checkpointing = False
self.pos_emb = Lfm2RotaryEmbedding(config)
self.embedding_norm = Lfm2RMSNorm(config.hidden_size, eps=config.norm_eps)

View File

@ -121,7 +121,6 @@ class Lfm2HybridConvCache:
dtype=self._dtype,
device=device,
)
torch._dynamo.mark_static_address(conv_state)
self.conv_cache.append(conv_state)
self.key_cache.append(torch.tensor([]))
self.value_cache.append(torch.tensor([]))
@ -441,7 +440,7 @@ class Lfm2Model(LlamaModel):
self.pos_emb = Lfm2RotaryEmbedding(config)
self.embedding_norm = Lfm2RMSNorm(config.hidden_size, eps=config.norm_eps)
del self.norm
del self.rotary_emv
del self.rotary_emb
def forward(
self,

View File

@ -0,0 +1,29 @@
# coding=utf-8
# Copyright 2025 the HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure
if TYPE_CHECKING:
from .configuration_lfm2_moe import *
from .modeling_lfm2_moe import *
else:
import sys
_file = globals()["__file__"]
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)

View File

@ -0,0 +1,169 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
from ...configuration_utils import PretrainedConfig
class Lfm2MoeConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`Lfm2MoeModel`]. It is used to instantiate a LFM2 Moe
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
defaults will yield a similar configuration to that of the LFM2-8B-A1B model.
e.g. [LiquidAI/LFM2-8B-A1B](https://huggingface.co/LiquidAI/LFM2-8B-A1B)
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 65536):
Vocabulary size of the LLaMA model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`Lfm2Model`]
hidden_size (`int`, *optional*, defaults to 2048):
Dimension of the hidden representations.
intermediate_size (`int`, *optional*, defaults to 7168):
Dimension of the MLP representations.
moe_intermediate_size (`int`, *optional*, defaults to 1792):
Intermediate size of the routed expert.
num_hidden_layers (`int`, *optional*, defaults to 32):
Number of hidden layers in the Transformer decoder.
pad_token_id (`int`, *optional*, defaults to 0):
Padding token id.
bos_token_id (`int`, *optional*, defaults to 1):
Beginning of stream token id.
eos_token_id (`int`, *optional*, defaults to 2):
End of stream token id.
tie_word_embeddings (`bool`, *optional*, defaults to `True`):
Whether to tie weight embeddings
rope_theta (`float`, *optional*, defaults to 1000000.0):
The base period of the RoPE embeddings.
max_position_embeddings (`int`, *optional*, defaults to 128000):
The maximum sequence length that this model might ever be used with.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if `config.is_decoder=True`.
norm_eps (`float`, *optional*, defaults to 1e-05):
The epsilon used by the rms normalization layers.
num_attention_heads (`int`, *optional*, defaults to 32):
Number of attention heads for each attention layer in the Transformer decoder.
num_key_value_heads (`int`, *optional*, defaults to 8):
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
by meanpooling all the original heads within that group. For more details, check out [this
paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to
`num_attention_heads`.
conv_bias (`bool`, *optional*, defaults to `False`):
Whether to use bias in the conv layers.
conv_L_cache (`int`, *optional*, defaults to 3):
L_cache dim in the conv layers.
num_dense_layers (`int`, *optional*, defaults to 2):
Number of dense Lfm2MoeMLP layers in shallow layers(embed->dense->dense->...->dense->moe->moe...->lm_head).
num_experts_per_tok (`int`, *optional*, defaults to 4):
Number of selected experts.
num_experts (`int`, *optional*, defaults to 32):
Number of routed experts.
use_expert_bias (`bool`, *optional*, defaults to `True`):
Whether to use the expert bias on the routing weights.
routed_scaling_factor (`float`, *optional*, defaults to 1.0):
Scaling factor for routed experts in MoE models.
norm_topk_prob (`bool`, *optional*, defaults to `True`):
Whether to normalize the topk probabilities.
layer_types (`Optional`, *optional*):
Type of each layers.
```python
>>> from transformers import Lfm2MoeModel, Lfm2MoeConfig
>>> # Initializing a LFM2 Moe model
>>> configuration = Lfm2MoeConfig()
>>> # Initializing a model from the LFM2-8B-A1B style configuration
>>> model = Lfm2MoeModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "lfm2_moe"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
vocab_size: int = 65536,
hidden_size: int = 2048,
intermediate_size: int = 7168,
moe_intermediate_size: int = 1792,
num_hidden_layers: int = 32,
pad_token_id: int = 0,
bos_token_id: int = 1,
eos_token_id: int = 2,
tie_word_embeddings: bool = True,
rope_theta: float = 1000000.0,
max_position_embeddings: int = 128_000,
use_cache: bool = True,
norm_eps: float = 0.00001,
num_attention_heads: int = 32,
num_key_value_heads: int = 8,
conv_bias: bool = False,
conv_L_cache: int = 3,
num_dense_layers: int = 2,
num_experts_per_tok: int = 4,
num_experts: int = 32,
use_expert_bias: bool = True,
routed_scaling_factor: float = 1.0,
norm_topk_prob: bool = True,
layer_types: Optional[list[str]] = None,
**kwargs,
):
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
self.use_cache = use_cache
self.norm_eps = norm_eps
# attn operator config
self.num_attention_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
# custom operator config
self.conv_bias = conv_bias
self.conv_L_cache = conv_L_cache
# moe config
self.num_dense_layers = num_dense_layers
self.moe_intermediate_size = moe_intermediate_size
self.num_experts_per_tok = num_experts_per_tok
self.num_experts = num_experts
self.use_expert_bias = use_expert_bias
self.routed_scaling_factor = routed_scaling_factor
self.norm_topk_prob = norm_topk_prob
self.layer_types = layer_types
tie_word_embeddings = kwargs.get("tie_embedding", tie_word_embeddings) # to fit original config keys
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
__all__ = ["Lfm2MoeConfig"]

View File

@ -0,0 +1,813 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/lfm2_moe/modular_lfm2_moe.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_lfm2_moe.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Optional, Union
import torch
import torch.nn.functional as F
from torch import nn
from ...cache_utils import Cache
from ...generation import GenerationMixin
from ...integrations import use_kernel_forward_from_hub
from ...masking_utils import create_causal_mask
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, MoeModelOutputWithPast
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
from ...utils.deprecation import deprecate_kwarg
from ...utils.generic import check_model_inputs
from ...utils.import_utils import is_causal_conv1d_available
from .configuration_lfm2_moe import Lfm2MoeConfig
if is_causal_conv1d_available():
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
else:
causal_conv1d_fn, causal_conv1d_update = None, None
@use_kernel_forward_from_hub("RMSNorm")
class Lfm2MoeRMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
Lfm2MoeRMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
def extra_repr(self):
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
class Lfm2MoeRotaryEmbedding(nn.Module):
inv_freq: torch.Tensor # fix linting for `register_buffer`
def __init__(self, config: Lfm2MoeConfig, device=None):
super().__init__()
# BC: "rope_type" was originally "type"
if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
else:
self.rope_type = "default"
self.max_seq_len_cached = config.max_position_embeddings
self.original_max_seq_len = config.max_position_embeddings
self.config = config
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq
@torch.no_grad()
@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
def forward(self, x, position_ids):
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
position_ids_expanded = position_ids[:, None, :].float()
device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False): # Force float32
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos() * self.attention_scaling
sin = emb.sin() * self.attention_scaling
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
class Lfm2MoeMLP(nn.Module):
def __init__(self, config: Lfm2MoeConfig, intermediate_size: Optional[int] = None):
super().__init__()
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size if intermediate_size is None else intermediate_size
self.w1 = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.w3 = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.w2 = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
def forward(self, x):
return self.w2(F.silu(self.w1(x)) * self.w3(x))
class Lfm2MoeExperts(nn.ModuleList):
"""
ModuleList of experts.
"""
def __init__(self, config):
super().__init__()
self.num_experts = config.num_experts
for _ in range(config.num_experts):
self.append(Lfm2MoeMLP(config, intermediate_size=config.moe_intermediate_size))
def forward(
self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor
) -> torch.Tensor:
"""
Args:
hidden_states: (batch_size * sequence_length, hidden_dim)
selected_experts: (batch_size * sequence_length, top_k)
routing_weights: (batch_size * sequence_length, top_k)
Returns:
(batch_size * sequence_length, hidden_dim)
"""
final_hidden_states = torch.zeros_like(hidden_states)
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0)
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
for expert_idx in expert_hit:
idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
current_state = hidden_states[None, top_x].reshape(-1, hidden_states.shape[-1])
current_hidden_states = self[expert_idx](current_state) * top_k_weights[top_x, idx, None]
final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
return final_hidden_states
class Lfm2MoeSparseMoeBlock(nn.Module):
def __init__(self, config):
super().__init__()
self.top_k = config.num_experts_per_tok
self.routed_scaling_factor = config.routed_scaling_factor
self.norm_topk_prob = config.norm_topk_prob
self.use_expert_bias = config.use_expert_bias
self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)
self.experts = Lfm2MoeExperts(config)
if self.use_expert_bias:
self.register_buffer("expert_bias", torch.zeros(config.num_experts, dtype=torch.float32))
def route_tokens_to_experts(self, router_logits):
routing_weights = router_logits.sigmoid()
if self.use_expert_bias:
scores_for_routing = routing_weights + self.expert_bias
_, selected_experts = torch.topk(scores_for_routing, k=self.top_k, dim=-1)
routing_weights = torch.gather(routing_weights, dim=1, index=selected_experts).type_as(router_logits)
else:
routing_weights, selected_experts = torch.topk(routing_weights, k=self.top_k, dim=-1)
if self.norm_topk_prob:
routing_weights = routing_weights / (routing_weights.sum(dim=-1, keepdim=True) + 1e-6)
routing_weights = routing_weights * self.routed_scaling_factor
return selected_experts, routing_weights
def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
batch_size, sequence_length, hidden_dim = hidden_states.shape
hidden_states_reshaped = hidden_states.view(-1, hidden_dim)
router_logits = self.gate(hidden_states_reshaped)
selected_experts, routing_weights = self.route_tokens_to_experts(router_logits)
final_hidden_states = self.experts(hidden_states_reshaped, selected_experts, routing_weights)
return final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
class Lfm2MoeHybridConvCache:
"""
Attention and conv cache for Lfm2Moe.
It stores the Key and Value states as a list of tensors, one for each layer.
Attention layer cache shape: `[batch_size, num_heads, seq_len, head_dim]`.
Conv layer cache shape: `[batch_size, hidden_size, L_cache-1]`.
"""
# Override @property existing in Cache
max_batch_size = None
is_compileable = False
key_cache = None
value_cache = None
def __init__(
self,
config: Lfm2MoeConfig,
max_batch_size: int,
dtype: torch.dtype = torch.float32,
device: Union[torch.device, str, None] = None,
):
self.key_cache = []
self.value_cache = []
self.max_batch_size = max_batch_size
self.layer_types = config.layer_types
self.first_attention_layer = self.layer_types.index("full_attention")
self.conv_L_cache = config.conv_L_cache
self._dtype = dtype
self.conv_cache: list[torch.Tensor] = []
device = torch.device(device) if device is not None else None
for _ in range(config.num_hidden_layers):
conv_state = torch.zeros(
self.max_batch_size,
config.hidden_size,
self.conv_L_cache,
dtype=self._dtype,
device=device,
)
self.conv_cache.append(conv_state)
self.key_cache.append(torch.tensor([]))
self.value_cache.append(torch.tensor([]))
def update(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
layer_idx: int,
cache_kwargs: Optional[dict[str, Any]] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
Parameters:
key_states (`torch.Tensor`):
The new key states to cache.
value_states (`torch.Tensor`):
The new value states to cache.
layer_idx (`int`):
The index of the layer to cache the states for.
cache_kwargs (`Dict[str, Any]`, `optional`):
Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`.
Return:
A tuple containing the updated key and value states.
"""
# Update the cache
if self.key_cache[layer_idx].numel() == 0:
self.key_cache[layer_idx] = key_states
self.value_cache[layer_idx] = value_states
else:
self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
return self.key_cache[layer_idx], self.value_cache[layer_idx]
def reorder_cache(self, beam_idx: torch.LongTensor):
"""Reorders the cache for beam search, given the selected beam indices."""
for layer_idx in range(len(self.key_cache)):
if self.key_cache[layer_idx].numel():
device = self.key_cache[layer_idx].device
self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
device = self.value_cache[layer_idx].device
self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))
if self.conv_cache[layer_idx].numel():
device = self.conv_cache[layer_idx].device
self.conv_cache[layer_idx] = self.conv_cache[layer_idx].index_select(0, beam_idx.to(device))
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
"""Returns the sequence length of the cached states. A layer index can be optionally passed."""
# take any layer that contains cache and not empty tensor
layer_idx = self.first_attention_layer if self.layer_types[layer_idx] != "full_attention" else layer_idx
if len(self.key_cache) <= layer_idx or self.key_cache[layer_idx].numel() == 0:
return 0
return self.key_cache[layer_idx].shape[-2]
def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]:
"""
Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for
the given layer at `layer_idx`.
The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns (i.e. sliding_window, chunk_size),
for each layer.
"""
full_mask_kv_offset = 0
query_length = cache_position.shape[0]
past_seen_tokens = self.get_seq_length()
kv_length = query_length + past_seen_tokens
return kv_length, full_mask_kv_offset
def crop(self, max_length: int):
"""Crop the cache to the given length"""
if max_length < 0:
max_length = self.get_seq_length() - abs(max_length)
if self.get_seq_length() <= max_length:
return
for idx in range(len(self.key_cache)):
if self.key_cache[idx].numel():
self.key_cache[idx] = self.key_cache[idx][..., :max_length, :]
self.value_cache[idx] = self.value_cache[idx][..., :max_length, :]
def __len__(self) -> int:
return len(self.key_cache)
def __getitem__(self, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor]:
return self.key_cache[layer_idx], self.value_cache[layer_idx]
def reset(self):
for layer_idx in range(len(self.conv_cache)):
# In-place ops prevent breaking the static address
self.conv_cache[layer_idx].zero_()
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
Args:
q (`torch.Tensor`): The query tensor.
k (`torch.Tensor`): The key tensor.
cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding.
position_ids (`torch.Tensor`, *optional*):
Deprecated and unused.
unsqueeze_dim (`int`, *optional*, defaults to 1):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
Returns:
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
"""
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
"""
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
def eager_attention_forward(
module: nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor],
scaling: float,
dropout: float = 0.0,
**kwargs: Unpack[TransformersKwargs],
):
key_states = repeat_kv(key, module.num_key_value_groups)
value_states = repeat_kv(value, module.num_key_value_groups)
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
attn_output = torch.matmul(attn_weights, value_states)
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output, attn_weights
class Lfm2MoeAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, config: Lfm2MoeConfig, layer_idx: int):
super().__init__()
self.config = config
self.layer_idx = layer_idx
self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
self.scaling = self.head_dim**-0.5
self.is_causal = True
self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False)
self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
self.out_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
self.q_layernorm = Lfm2MoeRMSNorm(self.head_dim, eps=config.norm_eps)
self.k_layernorm = Lfm2MoeRMSNorm(self.head_dim, eps=config.norm_eps)
@deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor],
past_key_values: Optional[Lfm2MoeHybridConvCache] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim)
query_states = self.q_layernorm(self.q_proj(hidden_states).view(*hidden_shape)).transpose(1, 2)
key_states = self.k_layernorm(self.k_proj(hidden_states).view(*hidden_shape)).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(*hidden_shape).transpose(1, 2)
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_values is not None:
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
attn_output, attn_weights = attention_interface(
self,
query_states,
key_states,
value_states,
attention_mask,
dropout=0.0,
scaling=self.scaling,
**kwargs,
)
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
output = self.out_proj(attn_output)
return output, attn_weights
def apply_mask_to_padding_states(hidden_states, attention_mask):
"""
Tunes out the hidden states for padding tokens, see https://github.com/state-spaces/mamba/issues/66
"""
if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1:
dtype = hidden_states.dtype
hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype)
return hidden_states
kernel_modules = (causal_conv1d_fn, causal_conv1d_update)
is_fast_path_available = all(kernel_modules)
class Lfm2MoeShortConv(nn.Module):
def __init__(
self,
config: Lfm2MoeConfig,
layer_idx: int,
):
super().__init__()
self.config = config
self.layer_idx = layer_idx
self.L_cache = config.conv_L_cache
self.bias = config.conv_bias
self.conv = nn.Conv1d(
in_channels=config.hidden_size,
out_channels=config.hidden_size,
kernel_size=self.L_cache,
groups=config.hidden_size,
bias=self.bias,
padding=self.L_cache - 1,
)
self.in_proj = nn.Linear(config.hidden_size, 3 * config.hidden_size, bias=self.bias)
self.out_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=self.bias)
@deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
def cuda_kernels_forward(
self,
x: torch.Tensor,
past_key_values: Optional[Lfm2MoeHybridConvCache] = None,
cache_position: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
):
x = apply_mask_to_padding_states(x, attention_mask)
BCx = self.in_proj(x).transpose(-1, -2)
B, C, x = BCx.chunk(3, dim=-2)
Bx = B * x
conv_weights = self.conv.weight.view(self.conv.weight.size(0), self.conv.weight.size(2))
if past_key_values is not None and cache_position[0] > 0:
conv_out = causal_conv1d_update(
Bx.squeeze(-1),
past_key_values.conv_cache[self.layer_idx],
conv_weights,
self.conv.bias,
None,
)
conv_out = conv_out.unsqueeze(-1)
else:
if past_key_values is not None:
conv_state = nn.functional.pad(Bx, (self.L_cache - Bx.shape[-1], 0))
past_key_values.conv_cache[self.layer_idx].copy_(conv_state)
conv_out = causal_conv1d_fn(Bx, conv_weights, self.conv.bias, activation=None)
y = C * conv_out
y = self.out_proj(y.transpose(-1, -2).contiguous())
return y
@deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
def slow_forward(
self,
x: torch.Tensor,
past_key_values: Optional[Lfm2MoeHybridConvCache] = None,
cache_position: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
):
seqlen = x.shape[1]
x = apply_mask_to_padding_states(x, attention_mask)
BCx = self.in_proj(x).transpose(-1, -2)
B, C, x = BCx.chunk(3, dim=-2)
Bx = B * x
if past_key_values is not None and cache_position[0] > 0:
conv_state = past_key_values.conv_cache[self.layer_idx]
cache_position = cache_position.clamp(0, self.L_cache - 1)
conv_state = conv_state.roll(shifts=-1, dims=-1)
conv_state[:, :, cache_position] = Bx.to(device=conv_state.device, dtype=conv_state.dtype)
past_key_values.conv_cache[self.layer_idx].copy_(conv_state)
conv_out = torch.sum(conv_state.to(Bx.device) * self.conv.weight[:, 0, :], dim=-1)
if self.bias:
conv_out += self.conv.bias
conv_out = conv_out.unsqueeze(-1)
else:
if past_key_values is not None:
conv_state = nn.functional.pad(Bx, (self.L_cache - Bx.shape[-1], 0))
past_key_values.conv_cache[self.layer_idx].copy_(conv_state)
conv_out = self.conv(Bx)[..., :seqlen]
y = C * conv_out
y = y.transpose(-1, -2).contiguous()
y = self.out_proj(y)
return y
@deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
def forward(
self,
hidden_states: torch.Tensor,
past_key_values: Optional[Lfm2MoeHybridConvCache] = None,
cache_position: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
):
if is_fast_path_available and "cuda" in hidden_states.device.type and not torch._dynamo.is_compiling():
return self.cuda_kernels_forward(hidden_states, past_key_values, cache_position, attention_mask)
return self.slow_forward(hidden_states, past_key_values, cache_position, attention_mask)
class Lfm2MoeDecoderLayer(GradientCheckpointingLayer):
def __init__(self, config: Lfm2MoeConfig, layer_idx: int):
super().__init__()
self.is_attention_layer = config.layer_types[layer_idx] == "full_attention"
if self.is_attention_layer:
self.self_attn = Lfm2MoeAttention(config, layer_idx)
else:
self.conv = Lfm2MoeShortConv(config, layer_idx)
self.feed_forward = (
Lfm2MoeMLP(config, intermediate_size=config.intermediate_size)
if layer_idx < config.num_dense_layers
else Lfm2MoeSparseMoeBlock(config)
)
self.operator_norm = Lfm2MoeRMSNorm(config.hidden_size, eps=config.norm_eps)
self.ffn_norm = Lfm2MoeRMSNorm(config.hidden_size, eps=config.norm_eps)
@deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Lfm2MoeHybridConvCache] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> torch.Tensor:
residual = hidden_states
if self.is_attention_layer:
hidden_states, _ = self.self_attn(
hidden_states=self.operator_norm(hidden_states),
position_embeddings=position_embeddings,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
cache_position=cache_position,
**kwargs,
)
else:
hidden_states = self.conv(
hidden_states=self.operator_norm(hidden_states),
past_key_values=past_key_values,
cache_position=cache_position,
attention_mask=attention_mask,
)
hidden_states = hidden_states + residual
hidden_states = hidden_states + self.feed_forward(self.ffn_norm(hidden_states))
return hidden_states
@auto_docstring
class Lfm2MoePreTrainedModel(PreTrainedModel):
config: Lfm2MoeConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["Lfm2MoeDecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn = True
_supports_sdpa = True
_supports_flex_attn = True
_can_compile_fullgraph = False
_supports_attention_backend = True
_can_record_outputs = {
"hidden_states": Lfm2MoeDecoderLayer,
"attentions": Lfm2MoeAttention,
}
@auto_docstring
class Lfm2MoeModel(Lfm2MoePreTrainedModel):
def __init__(self, config: Lfm2MoeConfig):
super().__init__(config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
self.layers = nn.ModuleList(
[Lfm2MoeDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
self.gradient_checkpointing = False
self.pos_emb = Lfm2MoeRotaryEmbedding(config)
self.embedding_norm = Lfm2MoeRMSNorm(config.hidden_size, eps=config.norm_eps)
# Initialize weights and apply final processing
self.post_init()
@check_model_inputs()
@auto_docstring
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Lfm2MoeHybridConvCache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[TransformersKwargs],
) -> MoeModelOutputWithPast:
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
if use_cache and past_key_values is None:
batch_size = inputs_embeds.shape[0]
past_key_values = Lfm2MoeHybridConvCache(
config=self.config, max_batch_size=batch_size, dtype=self.dtype, device=self.device
)
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
causal_mask = create_causal_mask(
config=self.config,
input_embeds=inputs_embeds,
attention_mask=attention_mask,
cache_position=cache_position,
past_key_values=past_key_values,
position_ids=position_ids,
)
hidden_states = inputs_embeds
position_embeddings = self.pos_emb(hidden_states, position_ids)
# decoder layers
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
hidden_states = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_values=past_key_values,
cache_position=cache_position,
position_embeddings=position_embeddings,
**kwargs,
)
hidden_states = self.embedding_norm(hidden_states)
return MoeModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=past_key_values,
)
@auto_docstring
class Lfm2MoeForCausalLM(Lfm2MoePreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"}
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
def __init__(self, config):
super().__init__(config)
self.model = Lfm2MoeModel(config)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# Initialize weights and apply final processing
self.post_init()
@can_return_tuple
@auto_docstring
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs: Unpack[TransformersKwargs],
) -> CausalLMOutputWithPast:
r"""
Example:
```python
>>> from transformers import AutoTokenizer, Lfm2MoeForCausalLM
>>> model = Lfm2MoeForCausalLM.from_pretrained("meta-lfm2_moe/Lfm2Moe-2-7b-hf")
>>> tokenizer = AutoTokenizer.from_pretrained("meta-lfm2_moe/Lfm2Moe-2-7b-hf")
>>> prompt = "Hey, are you conscious? Can you talk to me?"
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
```"""
outputs: BaseModelOutputWithPast = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
cache_position=cache_position,
**kwargs,
)
hidden_states = outputs.last_hidden_state
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.lm_head(hidden_states[:, slice_indices, :])
loss = None
if labels is not None:
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
__all__ = ["Lfm2MoeForCausalLM", "Lfm2MoeModel", "Lfm2MoePreTrainedModel"]

View File

@ -0,0 +1,204 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
import torch
from torch import nn
from ...masking_utils import create_causal_mask
from ...modeling_outputs import MoeModelOutputWithPast
from ...processing_utils import Unpack
from ...utils import TransformersKwargs, logging
from ...utils.import_utils import is_causal_conv1d_available
from ..lfm2.modeling_lfm2 import Lfm2Attention, Lfm2DecoderLayer, Lfm2HybridConvCache, Lfm2MLP, Lfm2ShortConv
from ..llama.modeling_llama import LlamaForCausalLM, LlamaPreTrainedModel, LlamaRMSNorm, LlamaRotaryEmbedding
from ..mixtral.modeling_mixtral import MixtralModel
from ..qwen2_moe.modeling_qwen2_moe import Qwen2MoeExperts
from .configuration_lfm2_moe import Lfm2MoeConfig
if is_causal_conv1d_available():
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
else:
causal_conv1d_fn, causal_conv1d_update = None, None
kernel_modules = (causal_conv1d_fn, causal_conv1d_update)
is_fast_path_available = all(kernel_modules)
logger = logging.get_logger(__name__)
class Lfm2MoeRMSNorm(LlamaRMSNorm):
pass
class Lfm2MoeRotaryEmbedding(LlamaRotaryEmbedding):
pass
class Lfm2MoeMLP(Lfm2MLP):
def __init__(self, config: Lfm2MoeConfig, intermediate_size: Optional[int] = None):
nn.Module.__init__(self)
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size if intermediate_size is None else intermediate_size
self.w1 = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.w3 = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.w2 = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
class Lfm2MoeExperts(Qwen2MoeExperts):
pass
class Lfm2MoeSparseMoeBlock(nn.Module):
def __init__(self, config):
super().__init__()
self.top_k = config.num_experts_per_tok
self.routed_scaling_factor = config.routed_scaling_factor
self.norm_topk_prob = config.norm_topk_prob
self.use_expert_bias = config.use_expert_bias
self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)
self.experts = Lfm2MoeExperts(config)
if self.use_expert_bias:
self.register_buffer("expert_bias", torch.zeros(config.num_experts, dtype=torch.float32))
def route_tokens_to_experts(self, router_logits):
routing_weights = router_logits.sigmoid()
if self.use_expert_bias:
scores_for_routing = routing_weights + self.expert_bias
_, selected_experts = torch.topk(scores_for_routing, k=self.top_k, dim=-1)
routing_weights = torch.gather(routing_weights, dim=1, index=selected_experts).type_as(router_logits)
else:
routing_weights, selected_experts = torch.topk(routing_weights, k=self.top_k, dim=-1)
if self.norm_topk_prob:
routing_weights = routing_weights / (routing_weights.sum(dim=-1, keepdim=True) + 1e-6)
routing_weights = routing_weights * self.routed_scaling_factor
return selected_experts, routing_weights
def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
batch_size, sequence_length, hidden_dim = hidden_states.shape
hidden_states_reshaped = hidden_states.view(-1, hidden_dim)
router_logits = self.gate(hidden_states_reshaped)
selected_experts, routing_weights = self.route_tokens_to_experts(router_logits)
final_hidden_states = self.experts(hidden_states_reshaped, selected_experts, routing_weights)
return final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
class Lfm2MoeHybridConvCache(Lfm2HybridConvCache):
pass
class Lfm2MoeAttention(Lfm2Attention):
pass
class Lfm2MoeShortConv(Lfm2ShortConv):
pass
class Lfm2MoeDecoderLayer(Lfm2DecoderLayer):
def __init__(self, config: Lfm2MoeConfig, layer_idx: int):
super().__init__(config, layer_idx)
self.feed_forward = (
Lfm2MoeMLP(config, intermediate_size=config.intermediate_size)
if layer_idx < config.num_dense_layers
else Lfm2MoeSparseMoeBlock(config)
)
class Lfm2MoePreTrainedModel(LlamaPreTrainedModel):
_can_compile_fullgraph = False
class Lfm2MoeModel(MixtralModel):
def __init__(self, config: Lfm2MoeConfig):
super().__init__(config)
self.pos_emb = Lfm2MoeRotaryEmbedding(config)
self.embedding_norm = Lfm2MoeRMSNorm(config.hidden_size, eps=config.norm_eps)
del self.norm
del self.rotary_emb
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Lfm2MoeHybridConvCache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[TransformersKwargs],
) -> MoeModelOutputWithPast:
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
if use_cache and past_key_values is None:
batch_size = inputs_embeds.shape[0]
past_key_values = Lfm2MoeHybridConvCache(
config=self.config, max_batch_size=batch_size, dtype=self.dtype, device=self.device
)
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
causal_mask = create_causal_mask(
config=self.config,
input_embeds=inputs_embeds,
attention_mask=attention_mask,
cache_position=cache_position,
past_key_values=past_key_values,
position_ids=position_ids,
)
hidden_states = inputs_embeds
position_embeddings = self.pos_emb(hidden_states, position_ids)
# decoder layers
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
hidden_states = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_values=past_key_values,
cache_position=cache_position,
position_embeddings=position_embeddings,
**kwargs,
)
hidden_states = self.embedding_norm(hidden_states)
return MoeModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=past_key_values,
)
class Lfm2MoeForCausalLM(LlamaForCausalLM):
pass
__all__ = ["Lfm2MoeForCausalLM", "Lfm2MoeModel", "Lfm2MoePreTrainedModel"]

View File

@ -448,6 +448,7 @@ class CausalLMModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM
# named location of the RoPE layer class.
base_model = self.model_tester.base_model_class(config)
possible_rope_attributes = [
"pos_emb",
"rotary_emb", # most common case
"global_rotary_emb",
"local_rotary_emb",

View File

@ -23,12 +23,15 @@ from transformers.testing_utils import (
require_torch,
require_torch_accelerator,
slow,
torch_device,
)
from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester
if is_torch_available():
import torch
from transformers import Lfm2ForCausalLM, Lfm2Model
@ -60,22 +63,82 @@ class Lfm2ModelTest(CausalLMModelTest, unittest.TestCase):
# used in `test_torch_compile_for_training`
_torch_compile_train_cls = Lfm2ForCausalLM if is_torch_available() else None
@unittest.skip(
"Lfm2 alternates between attention and conv layers, so attention are only returned for attention layers"
)
def test_attention_outputs(self):
pass
"""Lfm2Moe alternates between attention and short-conv layers."""
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.return_dict = True
# force eager attention to support output attentions
config._attn_implementation = "eager"
seq_len = getattr(self.model_tester, "seq_length", None)
@unittest.skip("Lfm2 has a special cache format as it alternates between attention and conv layers")
for model_class in self.all_model_classes:
inputs_dict["output_attentions"] = True
inputs_dict["output_hidden_states"] = False
config.return_dict = True
model = model_class._from_config(config, attn_implementation="eager").to(torch_device).eval()
config = model.config
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
attentions = outputs.attentions
self.assertEqual(len(attentions), sum(layer == "full_attention" for layer in config.layer_types))
# check that output_attentions also work using config
del inputs_dict["output_attentions"]
config.output_attentions = True
model = model_class(config).to(torch_device).eval()
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
attentions = outputs.attentions
self.assertEqual(len(attentions), sum(layer == "full_attention" for layer in config.layer_types))
self.assertListEqual(list(attentions[0].shape[-3:]), [config.num_attention_heads, seq_len, seq_len])
out_len = len(outputs)
# Check attention is always last and order is fine
inputs_dict["output_attentions"] = True
inputs_dict["output_hidden_states"] = True
model = model_class(config).to(torch_device).eval()
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
self_attentions = outputs.attentions
self.assertEqual(out_len + 1, len(outputs))
self.assertEqual(len(self_attentions), sum(layer == "full_attention" for layer in config.layer_types))
self.assertListEqual(list(self_attentions[0].shape[-3:]), [config.num_attention_heads, seq_len, seq_len])
@pytest.mark.generate
def test_past_key_values_format(self):
pass
"""Lfm2Moe has a special cache format as it alternates between attention and conv layers"""
for model_class in self.all_generative_model_classes:
config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
@unittest.skip(
"Lfm2 has a special cache format which is not compatible with compile as it has static address for conv cache"
)
@pytest.mark.torch_compile_test
def test_sdpa_can_compile_dynamic(self):
pass
model = model_class(config).to(torch_device).eval()
if "use_cache" not in inputs:
inputs["use_cache"] = True
outputs = model(**inputs)
past_kv = outputs["past_key_values"]
num_query_attention_heads = config.num_attention_heads
embed_dim = config.hidden_size
per_head_embed_dim = embed_dim // num_query_attention_heads
num_key_value_heads = getattr(config, "num_key_value_heads", num_query_attention_heads)
batch_size, seq_length = inputs["input_ids"].shape[:2]
default_self_attention_shape = (batch_size, num_key_value_heads, seq_length, per_head_embed_dim)
default_conv_shape = (batch_size, config.hidden_size, config.conv_L_cache)
num_cache_decoder_layers = len(past_kv)
self.assertEqual(num_cache_decoder_layers, config.num_hidden_layers)
for i in range(config.num_hidden_layers):
if config.layer_types[i] == "full_attention":
self_attention_layer_keys = past_kv.key_cache[i]
self_attention_layer_values = past_kv.value_cache[i]
self.assertEqual(self_attention_layer_keys.shape, default_self_attention_shape)
self.assertEqual(self_attention_layer_values.shape, default_self_attention_shape)
else:
conv_layer = past_kv.conv_cache[i]
self.assertEqual(conv_layer.shape, default_conv_shape)
@require_torch_accelerator

View File

View File

@ -0,0 +1,246 @@
# coding=utf-8
# Copyright 2025 the HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing suite for the PyTorch LLaMA model."""
import unittest
import pytest
from transformers import AutoTokenizer, is_torch_available, set_seed
from transformers.testing_utils import (
cleanup,
require_read_token,
require_torch,
require_torch_accelerator,
slow,
torch_device,
)
from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester
if is_torch_available():
import torch
from transformers import Lfm2MoeConfig, Lfm2MoeForCausalLM, Lfm2MoeModel
class Lfm2MoeModelTester(CausalLMModelTester):
if is_torch_available():
config_class = Lfm2MoeConfig
base_model_class = Lfm2MoeModel
causal_lm_class = Lfm2MoeForCausalLM
def __init__(
self,
parent,
layer_types=["full_attention", "conv"],
):
super().__init__(parent)
self.layer_types = layer_types
@require_torch
class Lfm2MoeModelTest(CausalLMModelTest, unittest.TestCase):
all_model_classes = (Lfm2MoeModel, Lfm2MoeForCausalLM) if is_torch_available() else ()
pipeline_model_mapping = (
{
"feature-extraction": Lfm2MoeModel,
"text-generation": Lfm2MoeForCausalLM,
}
if is_torch_available()
else {}
)
test_headmasking = False
test_pruning = False
fx_compatible = False
model_tester_class = Lfm2MoeModelTester
# used in `test_torch_compile_for_training`
_torch_compile_train_cls = Lfm2MoeForCausalLM if is_torch_available() else None
def test_attention_outputs(self):
"""Lfm2Moe alternates between attention and short-conv layers."""
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.return_dict = True
# force eager attention to support output attentions
config._attn_implementation = "eager"
seq_len = getattr(self.model_tester, "seq_length", None)
for model_class in self.all_model_classes:
inputs_dict["output_attentions"] = True
inputs_dict["output_hidden_states"] = False
config.return_dict = True
model = model_class._from_config(config, attn_implementation="eager").to(torch_device).eval()
config = model.config
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
attentions = outputs.attentions
self.assertEqual(len(attentions), sum(layer == "full_attention" for layer in config.layer_types))
# check that output_attentions also work using config
del inputs_dict["output_attentions"]
config.output_attentions = True
model = model_class(config).to(torch_device).eval()
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
attentions = outputs.attentions
self.assertEqual(len(attentions), sum(layer == "full_attention" for layer in config.layer_types))
self.assertListEqual(list(attentions[0].shape[-3:]), [config.num_attention_heads, seq_len, seq_len])
out_len = len(outputs)
# Check attention is always last and order is fine
inputs_dict["output_attentions"] = True
inputs_dict["output_hidden_states"] = True
model = model_class(config).to(torch_device).eval()
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
self_attentions = outputs.attentions
self.assertEqual(out_len + 1, len(outputs))
self.assertEqual(len(self_attentions), sum(layer == "full_attention" for layer in config.layer_types))
self.assertListEqual(list(self_attentions[0].shape[-3:]), [config.num_attention_heads, seq_len, seq_len])
@pytest.mark.generate
def test_past_key_values_format(self):
"""Lfm2Moe has a special cache format as it alternates between attention and conv layers"""
for model_class in self.all_generative_model_classes:
config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config).to(torch_device).eval()
if "use_cache" not in inputs:
inputs["use_cache"] = True
outputs = model(**inputs)
past_kv = outputs["past_key_values"]
num_query_attention_heads = config.num_attention_heads
embed_dim = config.hidden_size
per_head_embed_dim = embed_dim // num_query_attention_heads
num_key_value_heads = getattr(config, "num_key_value_heads", num_query_attention_heads)
batch_size, seq_length = inputs["input_ids"].shape[:2]
default_self_attention_shape = (batch_size, num_key_value_heads, seq_length, per_head_embed_dim)
default_conv_shape = (batch_size, config.hidden_size, config.conv_L_cache)
num_cache_decoder_layers = len(past_kv)
self.assertEqual(num_cache_decoder_layers, config.num_hidden_layers)
for i in range(config.num_hidden_layers):
if config.layer_types[i] == "full_attention":
self_attention_layer_keys = past_kv.key_cache[i]
self_attention_layer_values = past_kv.value_cache[i]
self.assertEqual(self_attention_layer_keys.shape, default_self_attention_shape)
self.assertEqual(self_attention_layer_values.shape, default_self_attention_shape)
else:
conv_layer = past_kv.conv_cache[i]
self.assertEqual(conv_layer.shape, default_conv_shape)
@require_torch_accelerator
@require_read_token
@slow
class Lfm2MoeIntegrationTest(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.model = None
@classmethod
def tearDownClass(cls):
del cls.model
cleanup(torch_device, gc_collect=True)
def tearDown(self):
cleanup(torch_device, gc_collect=True)
@classmethod
def get_model(cls):
if cls.model is None:
cls.model = Lfm2MoeForCausalLM.from_pretrained(
"LiquidAI/LFM2-8B-A1B", device_map="auto", dtype=torch.bfloat16
)
return cls.model
@slow
def test_model_1a8b_logits(self):
set_seed(1789)
input_ids = [1, 22998, 768, 1947, 797, 22017, 811, 6332, 928, 5743, 797, 779, 48123, 772, 33551, 60996, 523]
model = self.get_model()
input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device)
with torch.no_grad():
out = model(input_ids).logits.float().cpu()
# Expected mean on dim = -1
EXPECTED_MEAN = torch.tensor(
[
[
-1.3855,
-0.5123,
-1.3143,
-1.2144,
-1.0791,
-1.2117,
-1.4704,
-0.7648,
-0.6175,
-1.2402,
-1.1459,
-1.0083,
-1.0247,
-0.8830,
-1.5643,
-1.7266,
-1.6254,
]
]
)
torch.testing.assert_close(out.mean(-1), EXPECTED_MEAN, rtol=1e-2, atol=1e-2)
# Expected portion of the logits
EXPECTED_SLICE = torch.tensor(
[-1.2656, 2.4844, 5.5000, -1.3359, -1.3203, -1.3438, 1.9375, 5.8438, -0.6523, -1.2891]
)
torch.testing.assert_close(out[0, 0, :10], EXPECTED_SLICE, rtol=1e-4, atol=1e-4)
@slow
def test_model_1a8b_generation(self):
EXPECTED_TEXT_COMPLETION = """In 1st century A.D., the Roman Empire controlled much of Europe, North Africa, and parts of the Middle East."""
set_seed(1789)
prompt = "In 1st century A.D., the Roman Empire"
tokenizer = AutoTokenizer.from_pretrained("LiquidAI/LFM2-8B-A1B", use_fast=False)
model = self.get_model()
input_ids = tokenizer.encode(prompt, return_tensors="pt", add_special_tokens=True).to(
model.model.embed_tokens.weight.device
)
with torch.no_grad():
generated_ids = model.generate(input_ids, max_new_tokens=15, do_sample=False)
text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
self.assertEqual(EXPECTED_TEXT_COMPLETION, text)
@slow
def test_model_1a8b_batched_chat_generation(self):
prompts = ["Who are you?", "Complete the text: Lorem ipsum dolor ", "The Meji Restoration in Japan ended"]
EXPECTED_TEXT_COMPLETIONS = [
"Who are you?? \nI am an artificial intelligence assistant designed to provide information, answer questions",
"Complete the text: Lorem ipsum dolor ipsum dolor ipsum dolor ipsum dolor ipsum dolor",
"The Meji Restoration in Japan ended (1868) marked the: \nA) Establishment of a constitutional",
]
set_seed(1789)
tokenizer = AutoTokenizer.from_pretrained("LiquidAI/LFM2-8B-A1B", use_fast=False)
model = self.get_model()
batched_input_ids = tokenizer(prompts, return_tensors="pt", padding=True).to(
model.model.embed_tokens.weight.device
)
with torch.no_grad():
generated_ids = model.generate(**batched_input_ids, max_new_tokens=15, do_sample=False)
text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
self.assertEqual(EXPECTED_TEXT_COMPLETIONS, text)

View File

@ -36,6 +36,7 @@ SPECIAL_CASES_TO_ALLOW = {
"Ernie4_5Config": ["tie_word_embeddings"],
"Ernie4_5_MoeConfig": ["tie_word_embeddings"],
"Lfm2Config": ["full_attn_idxs", "tie_word_embeddings"],
"Lfm2MoeConfig": ["tie_word_embeddings"],
# used internally during generation to provide the custom logit processors with their necessary information
"DiaConfig": [
"delay_pattern",