Mirror of https://github.com/huggingface/transformers.git
[testing] update test_longcat_generation_cpu (#41368)
* fix

* Update tests/models/longcat_flash/test_modeling_longcat_flash.py

Co-authored-by: Pablo Montalvo <39954772+molbap@users.noreply.github.com>

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
Co-authored-by: Pablo Montalvo <39954772+molbap@users.noreply.github.com>
@@ -435,16 +435,16 @@ class LongcatFlashIntegrationTest(unittest.TestCase):
     @require_large_cpu_ram
     def test_longcat_generation_cpu(self):
         # takes absolutely forever and a lot RAM, but allows to test the output in the CI
-        model = LongcatFlashForCausalLM.from_pretrained(self.model_id, device_map="cpu", dtype=torch.bfloat16)
+        model = LongcatFlashForCausalLM.from_pretrained(self.model_id, device_map="auto", dtype=torch.bfloat16)
         tokenizer = AutoTokenizer.from_pretrained(self.model_id)

         chat = [{"role": "user", "content": "Paris is..."}]
         inputs = tokenizer.apply_chat_template(chat, tokenize=True, add_generation_prompt=True, return_tensors="pt")

         with torch.no_grad():
-            outputs = model.generate(inputs, max_new_tokens=10, do_sample=False)
+            outputs = model.generate(inputs, max_new_tokens=3, do_sample=False)

         response = tokenizer.batch_decode(outputs, skip_special_tokens=False)[0]
-        expected_output = "[Round 0] USER:Paris is... ASSISTANT:Paris is... a city of timeless charm, where"
+        expected_output = "[Round 0] USER:Paris is... ASSISTANT:Paris is..."

         self.assertEqual(response, expected_output)
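For reference, a minimal sketch of what the updated test does when written as a standalone script. The checkpoint id below is a placeholder, not the value the test actually uses (the real test reads self.model_id from the LongcatFlashIntegrationTest class), and the import assumes LongcatFlashForCausalLM is exposed at the transformers top level as in the test module; actually running this needs the LongCat Flash weights and a very large amount of RAM.

import torch
from transformers import AutoTokenizer, LongcatFlashForCausalLM

model_id = "<longcat-flash-checkpoint>"  # placeholder, not the id used in the test

# Load across whatever devices are available; the commit switches the test
# from device_map="cpu" to device_map="auto".
model = LongcatFlashForCausalLM.from_pretrained(model_id, device_map="auto", dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(model_id)

chat = [{"role": "user", "content": "Paris is..."}]
inputs = tokenizer.apply_chat_template(chat, tokenize=True, add_generation_prompt=True, return_tensors="pt")

with torch.no_grad():
    # Greedy decoding, now capped at 3 new tokens to keep the CPU CI run tractable.
    outputs = model.generate(inputs, max_new_tokens=3, do_sample=False)

response = tokenizer.batch_decode(outputs, skip_special_tokens=False)[0]
# The test asserts this equals "[Round 0] USER:Paris is... ASSISTANT:Paris is..."
print(response)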