[testing] update test_longcat_generation_cpu (#41368)

* fix

* Update tests/models/longcat_flash/test_modeling_longcat_flash.py

Co-authored-by: Pablo Montalvo <39954772+molbap@users.noreply.github.com>

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
Co-authored-by: Pablo Montalvo <39954772+molbap@users.noreply.github.com>
Author: Yih-Dar
Date: 2025-10-06 15:21:29 +02:00
Committed by: GitHub
Parent: 4903cd4087
Commit: 6bf6e36d3b


@@ -435,16 +435,16 @@ class LongcatFlashIntegrationTest(unittest.TestCase):
     @require_large_cpu_ram
     def test_longcat_generation_cpu(self):
         # takes absolutely forever and a lot RAM, but allows to test the output in the CI
-        model = LongcatFlashForCausalLM.from_pretrained(self.model_id, device_map="cpu", dtype=torch.bfloat16)
+        model = LongcatFlashForCausalLM.from_pretrained(self.model_id, device_map="auto", dtype=torch.bfloat16)
         tokenizer = AutoTokenizer.from_pretrained(self.model_id)
         chat = [{"role": "user", "content": "Paris is..."}]
         inputs = tokenizer.apply_chat_template(chat, tokenize=True, add_generation_prompt=True, return_tensors="pt")
         with torch.no_grad():
-            outputs = model.generate(inputs, max_new_tokens=10, do_sample=False)
+            outputs = model.generate(inputs, max_new_tokens=3, do_sample=False)
         response = tokenizer.batch_decode(outputs, skip_special_tokens=False)[0]
-        expected_output = "[Round 0] USER:Paris is... ASSISTANT:Paris is... a city of timeless charm, where"
+        expected_output = "[Round 0] USER:Paris is... ASSISTANT:Paris is..."
         self.assertEqual(response, expected_output)
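
For readers reproducing the change locally, the test exercises the standard chat-template plus greedy-generate flow from transformers; below is a minimal standalone sketch of the same pattern. The model id is an illustrative placeholder (any small chat model), not the gated LongCat checkpoint used by the CI test.

    # Minimal sketch of the pattern tested above: load with device_map="auto",
    # apply the chat template, and greedily generate a few tokens.
    # The model id is a hypothetical stand-in, not the LongCat checkpoint.
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_id = "Qwen/Qwen2.5-0.5B-Instruct"  # assumption: any small instruct model works here
    model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", dtype=torch.bfloat16)
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    chat = [{"role": "user", "content": "Paris is..."}]
    inputs = tokenizer.apply_chat_template(chat, tokenize=True, add_generation_prompt=True, return_tensors="pt")

    with torch.no_grad():
        # max_new_tokens=3 mirrors the updated test, which keeps the CPU run short
        outputs = model.generate(inputs.to(model.device), max_new_tokens=3, do_sample=False)

    print(tokenizer.batch_decode(outputs, skip_special_tokens=False)[0])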