Compare commits

...

1 Commits

Author SHA1 Message Date
d404483547 fix 2025-07-05 13:28:53 +02:00

View File

@ -2298,7 +2298,7 @@ class GenerationTesterMixin:
NOTE: despite the test logic being the same, different implementations actually need different decorators, hence
this separate function.
"""
max_new_tokens = 30
max_new_tokens = 5
support_flag = {
"sdpa": "_supports_sdpa",
"flash_attention_2": "_supports_flash_attn_2",
@ -2351,8 +2351,6 @@ class GenerationTesterMixin:
attn_implementation="eager",
).to(torch_device)
res_eager = model_eager.generate(**inputs_dict, **generate_kwargs)
del model_eager
gc.collect()
model_attn = model_class.from_pretrained(
tmpdirname,
@ -2360,8 +2358,6 @@ class GenerationTesterMixin:
attn_implementation=attn_implementation,
).to(torch_device)
res_attn = model_attn.generate(**inputs_dict, **generate_kwargs)
del model_attn
gc.collect()
self._check_similar_generate_outputs(res_eager, res_attn, atol=1e-3, rtol=1e-3)