# Copyright 2025 Meituan and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing suite for the PyTorch LongcatFlash model."""

import copy
import tempfile
import unittest

from parameterized import parameterized
from pytest import mark

from transformers import LongcatFlashConfig, is_torch_available, set_seed
from transformers.testing_utils import (
    require_bitsandbytes,
    require_flash_attn,
    require_large_cpu_ram,
    require_torch,
    require_torch_gpu,
    slow,
    torch_device,
)

from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester
from ...test_modeling_common import ids_tensor


if is_torch_available():
    import torch

    from transformers import AutoTokenizer, LongcatFlashForCausalLM, LongcatFlashModel


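# Note: the dimensions below are deliberately tiny so the shared CausalLM test suite can build and run
# LongcatFlash models quickly; they do not reflect the real checkpoint sizes.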
class LongcatFlashModelTester(CausalLMModelTester):
    if is_torch_available():
        base_model_class = LongcatFlashModel

    def __init__(
        self,
        parent,
        batch_size=2,
        seq_length=7,
        is_training=True,
        use_input_mask=True,
        use_labels=True,
        vocab_size=99,
        hidden_size=144,
        ffn_hidden_size=288,
        expert_ffn_hidden_size=48,
        num_layers=1,  # We have `self.num_hidden_layers = 2 * num_layers` in the body. See `LongcatFlashConfig`.
        num_attention_heads=8,
        num_key_value_heads=8,
        kv_lora_rank=16,
        q_lora_rank=48,
        qk_rope_head_dim=4,
        v_head_dim=8,
        qk_nope_head_dim=8,
        head_dim=4,
        n_routed_experts=4,
        zero_expert_num=2,
        moe_topk=2,
        routed_scaling_factor=1.0,
        hidden_act="silu",
        max_position_embeddings=128,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        bos_token_id=1,
        eos_token_id=2,
        pad_token_id=3,
        type_sequence_label_size=2,
        num_labels=3,
        num_choices=4,
    ):
        super().__init__(parent)
        self.parent = parent
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.is_training = is_training
        self.use_input_mask = use_input_mask
        self.use_labels = use_labels
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.ffn_hidden_size = ffn_hidden_size
        self.expert_ffn_hidden_size = expert_ffn_hidden_size
        self.num_layers = num_layers
        self.num_hidden_layers = 2 * num_layers  # for compatibility
        self.expected_num_hidden_layers = 2  # embedding + 2 layers
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.kv_lora_rank = kv_lora_rank
        self.q_lora_rank = q_lora_rank
        self.qk_rope_head_dim = qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.qk_nope_head_dim = qk_nope_head_dim
        self.head_dim = head_dim
        self.n_routed_experts = n_routed_experts
        self.zero_expert_num = zero_expert_num
        self.moe_topk = moe_topk
        self.routed_scaling_factor = routed_scaling_factor
        self.hidden_act = hidden_act
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.pad_token_id = pad_token_id
        self.type_sequence_label_size = type_sequence_label_size
        self.num_labels = num_labels
        self.num_choices = num_choices

    def get_config(self):
        return LongcatFlashConfig(
            vocab_size=self.vocab_size,
            hidden_size=self.hidden_size,
            ffn_hidden_size=self.ffn_hidden_size,
            expert_ffn_hidden_size=self.expert_ffn_hidden_size,
            num_layers=self.num_layers,
            num_attention_heads=self.num_attention_heads,
            num_key_value_heads=self.num_key_value_heads,
            kv_lora_rank=self.kv_lora_rank,
            q_lora_rank=self.q_lora_rank,
            qk_rope_head_dim=self.qk_rope_head_dim,
            v_head_dim=self.v_head_dim,
            qk_nope_head_dim=self.qk_nope_head_dim,
            head_dim=self.head_dim,
            n_routed_experts=self.n_routed_experts,
            zero_expert_num=self.zero_expert_num,
            moe_topk=self.moe_topk,
            routed_scaling_factor=self.routed_scaling_factor,
            hidden_act=self.hidden_act,
            max_position_embeddings=self.max_position_embeddings,
            initializer_range=self.initializer_range,
            rms_norm_eps=self.rms_norm_eps,
            pad_token_id=self.pad_token_id,
        )

    def create_and_check_model(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        model = LongcatFlashModel(config=config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=input_mask)
        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))

    def create_and_check_for_causal_lm(
        self,
        config,
        input_ids,
        token_type_ids,
        input_mask,
        sequence_labels,
        token_labels,
        choice_labels,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
    ):
        model = LongcatFlashForCausalLM(config=config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=input_mask, labels=token_labels)
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))

    def prepare_config_and_inputs(self):
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

        input_mask = None
        if self.use_input_mask:
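            # torch.tril over a (batch_size, seq_length) block of ones gives batch item i an attention mask
            # covering only its first i + 1 tokens, i.e. right-padded sequences of increasing length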
            input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length)).to(torch_device)

        token_type_ids = None

        sequence_labels = None
        token_labels = None
        choice_labels = None
        if self.use_labels:
            sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
            token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
            choice_labels = ids_tensor([self.batch_size], self.num_choices)

        config = self.get_config()

        return (
            config,
            input_ids,
            token_type_ids,
            input_mask,
            sequence_labels,
            token_labels,
            choice_labels,
        )

    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels = config_and_inputs

        inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
        return config, inputs_dict


@require_torch
class LongcatFlashModelTest(CausalLMModelTest, unittest.TestCase):
    pipeline_model_mapping = (
        {
            "feature-extraction": LongcatFlashModel,
            "text-generation": LongcatFlashForCausalLM,
        }
        if is_torch_available()
        else {}
    )

    model_split_percents = [0.5, 0.8]

    model_tester_class = LongcatFlashModelTester

    @unittest.skip("LongcatFlash buffers include complex numbers, which breaks this test")
    def test_save_load_fast_init_from_base(self):
        pass

    @unittest.skip("LongcatFlash buffers include complex numbers, which breaks this test")
    def test_save_load_fast_init_to_base(self):
        pass

    def test_past_key_values_format(self):
        config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
        batch_size, seq_length = inputs["input_ids"].shape

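        # Key/value cache dims follow the MLA-style split: cached keys concatenate the nope and rope
        # head dims, while cached values use v_head_dim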
        k_embed_dim = config.qk_nope_head_dim + config.qk_rope_head_dim
        v_embed_dim = config.v_head_dim

        self_attention_keys_shape = (batch_size, config.num_key_value_heads, seq_length, k_embed_dim)
        self_attention_values_shape = (batch_size, config.num_key_value_heads, seq_length, v_embed_dim)

        num_hidden_layers = config.num_hidden_layers
        all_cache_shapes = [[self_attention_keys_shape, self_attention_values_shape] for _ in range(num_hidden_layers)]

        super().test_past_key_values_format(custom_all_cache_shapes=all_cache_shapes)

    def _check_past_key_values_for_generate(self, batch_size, decoder_past_key_values, cache_length, config):
        from transformers.cache_utils import Cache

        self.assertIsInstance(decoder_past_key_values, (tuple, Cache))

        k_embed_dim = config.qk_nope_head_dim + config.qk_rope_head_dim
        v_embed_dim = config.v_head_dim

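        # Each per-layer cache entry is expected to be [batch, num_kv_heads, cache_length, head_dim],
        # with the key/value head dims computed above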
        expected_key_shape = (batch_size, config.num_key_value_heads, cache_length, k_embed_dim)
        expected_value_shape = (batch_size, config.num_key_value_heads, cache_length, v_embed_dim)

        if isinstance(decoder_past_key_values, Cache):
            for layer_idx in range(config.num_hidden_layers):
                self.assertEqual(decoder_past_key_values.layers[layer_idx].keys.shape, expected_key_shape)
                self.assertEqual(decoder_past_key_values.layers[layer_idx].values.shape, expected_value_shape)
        else:
            for layer_past in decoder_past_key_values:
                self.assertEqual(layer_past[0].shape, expected_key_shape)
                self.assertEqual(layer_past[1].shape, expected_value_shape)

    @unittest.skip("MoE experts may not receive gradients with small test data")
    def test_training_gradient_checkpointing(self):
        pass

    @unittest.skip("MoE experts may not receive gradients with small test data")
    def test_training_gradient_checkpointing_use_reentrant(self):
        pass

    @unittest.skip("MoE experts may not receive gradients with small test data")
    def test_training_gradient_checkpointing_use_reentrant_false(self):
        pass

    @unittest.skip("LongcatFlash router uses weight.type() directly in forward which prevents offloading")
    def test_cpu_offload(self):
        pass

    @unittest.skip("LongcatFlash router uses weight.type() directly in forward which prevents offloading")
    def test_disk_offload_bin(self):
        pass

    @unittest.skip("LongcatFlash router uses weight.type() directly in forward which prevents offloading")
    def test_disk_offload_safetensors(self):
        pass

    @unittest.skip("Most probably because of the MoE: the experts and the router do not ignore padding tokens")
    def test_eager_padding_matches_padding_free_with_position_ids(self):
        pass

    @unittest.skip(reason="SDPA can't dispatch on flash due to unsupported head dims")
    def test_sdpa_can_dispatch_on_flash(self):
        pass

    @staticmethod
    def _prepare_config_headdim(config, requested_dim):
        # there are specific head dims due to the lora compressions in longcat
        config = copy.deepcopy(config)
        config.attention_dropout = 0

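        # When a larger head dim is requested, grow the dependent dims (rope/nope split, value dim,
        # lora ranks and hidden size) together so the compressed q/kv projections stay consistent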
        if requested_dim > config.qk_rope_head_dim:
            config.qk_rope_head_dim = requested_dim
            config.qk_nope_head_dim = max(config.qk_nope_head_dim, requested_dim)
            config.v_head_dim = max(config.v_head_dim, requested_dim)
            config.qk_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim
            config.head_dim = requested_dim
            config.q_lora_rank = max(config.q_lora_rank, requested_dim * 4)
            config.kv_lora_rank = max(config.kv_lora_rank, requested_dim * 2)
            config.hidden_size = max(config.hidden_size, config.num_attention_heads * requested_dim)

        return config

    @parameterized.expand([("linear",), ("dynamic",), ("yarn",)])
    def test_model_rope_scaling_from_config(self, scaling_type):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
        short_input = ids_tensor([1, 10], config.vocab_size)
        long_input = ids_tensor([1, int(config.max_position_embeddings * 1.5)], config.vocab_size)

        set_seed(42)
        original_model = self.model_tester_class.base_model_class(config)
        original_model.to(torch_device)
        original_model.eval()
        original_short_output = original_model(short_input).last_hidden_state
        original_long_output = original_model(long_input).last_hidden_state

        set_seed(42)
        config.rope_scaling = {"type": scaling_type, "factor": 10.0}
        scaled_model = self.model_tester_class.base_model_class(config)
        scaled_model.to(torch_device)
        scaled_model.eval()
        scaled_short_output = scaled_model(short_input).last_hidden_state
        scaled_long_output = scaled_model(long_input).last_hidden_state

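        # Dynamic scaling does not change the RoPE embeddings until it receives an input longer than the
        # original maximum sequence length, so the short-input outputs should match the unscaled model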
        if scaling_type == "dynamic":
            torch.testing.assert_close(original_short_output, scaled_short_output, rtol=1e-5, atol=1e-5)
        else:
            self.assertFalse(torch.allclose(original_short_output, scaled_short_output, atol=1e-5))

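        # The output should be different for long inputs with every scaling type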
        self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5))

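    # Flash-attention smoke test: reload the saved model in fp16 with attn_implementation="flash_attention_2"
    # and run forward passes with and without an attention mask (no layer norm upcasting, see note below)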
    @require_flash_attn
    @require_torch_gpu
    @require_bitsandbytes
    @mark.flash_attn_test
    @slow
    def test_flash_attn_2_fp32_ln(self):
        if not self.has_attentions:
            self.skipTest(reason="Model architecture does not support attentions")

        for model_class in self.all_generative_model_classes:  # TODO: this test should run on all classes instead
            if not model_class._supports_flash_attn:
                self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
            model = model_class(config)
            with tempfile.TemporaryDirectory() as tmpdirname:
                model.save_pretrained(tmpdirname)

                dummy_input = inputs_dict[model.main_input_name]
                dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
                batch_size = dummy_attention_mask.shape[0]

                is_padding_right = dummy_attention_mask[:, -1].sum().item() != batch_size

                # To avoid errors with padding_side=="right"
                if is_padding_right:
                    dummy_attention_mask = torch.ones_like(dummy_input)

                model = model_class.from_pretrained(
                    tmpdirname,
                    dtype=torch.float16,
                    attn_implementation="flash_attention_2",
                    device_map="auto",  # small change to ensure device placement
                )

                # no upcasting at all

                if model.config.is_encoder_decoder:
                    dummy_decoder_input_ids = inputs_dict["decoder_input_ids"]
                    dummy_decoder_attention_mask = inputs_dict["decoder_attention_mask"]

                    _ = model(dummy_input, decoder_input_ids=dummy_decoder_input_ids)
                    # with attention mask
                    _ = model(
                        dummy_input,
                        attention_mask=dummy_attention_mask,
                        decoder_input_ids=dummy_decoder_input_ids,
                        decoder_attention_mask=dummy_decoder_attention_mask,
                    )
                else:
                    _ = model(dummy_input)
                    # with attention mask
                    _ = model(dummy_input, attention_mask=dummy_attention_mask)


@slow
class LongcatFlashIntegrationTest(unittest.TestCase):
    short_model_id = "hf-internal-testing/LongCat-ShortCat"
    # This is a cut-down model that matches part of the early logits of the larger one.
    # It has only a couple of experts and layers,
    # but if it fails, the larger model might have issues as well.
    model_id = "meituan-longcat/LongCat-Flash-Chat"

    @slow
    def test_shortcat_generation(self):
        self.model = LongcatFlashForCausalLM.from_pretrained(
            self.short_model_id,
            device_map="auto",
            dtype=torch.bfloat16,
        )
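        # Special token ids are set explicitly to match the tester defaults (bos=1, eos=2, pad=3), in case
        # the tiny test checkpoint's generation config does not define them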
        self.model.generation_config.bos_token_id = 1
        self.model.generation_config.pad_token_id = 3
        self.model.generation_config.eos_token_id = 2
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)

        chat = [{"role": "user", "content": "Paris is..."}]
        inputs = self.tokenizer.apply_chat_template(
            chat, tokenize=True, add_generation_prompt=True, return_tensors="pt"
        ).to(self.model.device)

        with torch.no_grad():
            outputs = self.model.generate(inputs, max_new_tokens=10, do_sample=False)

        response = self.tokenizer.batch_decode(outputs, skip_special_tokens=False)[0]
        expected_output = "[Round 0] USER:Paris is... ASSISTANT: dig年车龄juanaheast稍achaotingupebarebones"

        self.assertEqual(response, expected_output)

    @slow
    @require_large_cpu_ram
    def test_longcat_generation_cpu(self):
        # Takes a very long time and a lot of RAM, but allows checking the output in the CI
        model = LongcatFlashForCausalLM.from_pretrained(self.model_id, device_map="auto", dtype=torch.bfloat16)
        tokenizer = AutoTokenizer.from_pretrained(self.model_id)

        chat = [{"role": "user", "content": "Paris is..."}]
        inputs = tokenizer.apply_chat_template(chat, tokenize=True, add_generation_prompt=True, return_tensors="pt")

        with torch.no_grad():
            outputs = model.generate(inputs, max_new_tokens=3, do_sample=False)

        response = tokenizer.batch_decode(outputs, skip_special_tokens=False)[0]
        expected_output = "[Round 0] USER:Paris is... ASSISTANT:Paris is..."

        self.assertEqual(response, expected_output)