Compare commits

...

4 Commits

SHA1 Message Date
95185c2eed Merge branch 'main' into fix_whatever 2025-07-09 16:44:31 +02:00
36adaea730 fix 2025-07-08 12:56:09 +02:00
de020917bf fix 2025-07-08 12:56:09 +02:00
801e062c99 fix 2025-07-08 12:56:09 +02:00
2 changed files with 14 additions and 7 deletions

View File

@@ -993,16 +993,23 @@ def check_model_inputs(func):
     @wraps(func)
     def wrapper(self, *args, **kwargs):
-        use_cache = kwargs.get("use_cache", getattr(self.config, "use_cache", False))
-        return_dict = kwargs.pop("return_dict", getattr(self.config, "return_dict", True))
-        all_args = kwargs.copy()
+        use_cache = kwargs.get("use_cache", None)
+        if use_cache is None:
+            use_cache = getattr(self.config, "use_cache", False)
+        return_dict = kwargs.pop("return_dict", None)
+        if return_dict is None:
+            return_dict = getattr(self.config, "return_dict", True)
         if getattr(self, "gradient_checkpointing", False) and self.training and use_cache:
             logger.warning_once(
                 "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
             )
-            kwargs["use_cache"] = False
+            use_cache = False
+        kwargs["use_cache"] = use_cache
+        all_args = kwargs.copy()
         if "kwargs" in all_args:
             for k, v in all_args["kwargs"].items():
                 all_args[k] = v
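
A plausible reading of this hunk: dict.get(key, default) only falls back to the default when the key is absent, so a caller passing an explicit use_cache=None would previously propagate None instead of the config value. A minimal standalone sketch of the difference (the True stand-in for the config lookup is illustrative only):

    kwargs = {"use_cache": None}  # caller passed use_cache=None explicitly

    # Old lookup: the default is never consulted, because the key exists.
    old = kwargs.get("use_cache", True)  # -> None

    # New lookup: an explicit None is treated the same as "not given".
    new = kwargs.get("use_cache", None)
    if new is None:
        new = True  # stand-in for getattr(self.config, "use_cache", False)

    assert old is None and new is True

The same reasoning applies to the return_dict lookup, and moving all_args = kwargs.copy() below kwargs["use_cache"] = use_cache ensures the copy sees the resolved value.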

View File

@@ -311,7 +311,7 @@ class T5GemmaModelTester:
             decoder_attention_mask=decoder_attention_mask,
             labels=lm_labels,
         )
-        self.parent.assertEqual(len(outputs), 4)
+        self.parent.assertEqual(len(outputs), 5)
        self.parent.assertEqual(outputs["logits"].size(), (self.batch_size, self.seq_length, self.vocab_size))
         self.parent.assertEqual(outputs["loss"].size(), ())
@@ -1067,7 +1067,7 @@ class T5GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
         for i in range(num_decoder_layers):
             if is_legacy_cache:
-                self.assertEqual(len(past_kv[0]), 4)  # legacy check: confirm number of elements in tuple
+                self.assertEqual(len(past_kv[0]), 5)  # legacy check: confirm number of elements in tuple
             # Self attention
             self_attention_layer_key_cache = (
@@ -1687,7 +1687,7 @@ class TestAsymmetricT5Gemma(unittest.TestCase):
             labels=lm_labels,
         )
         # outputs = model(*inputs)
-        assert len(outputs) == 4
+        assert len(outputs) == 5
         assert outputs["logits"].size() == (tester.batch_size, tester.seq_length, tester.vocab_size)
         assert outputs["loss"].size() == ()
         return model.model
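
All three test hunks track the same change: the forward pass now returns one more populated output element, so every length check moves from 4 to 5. Assuming these outputs follow the usual transformers ModelOutput convention, where the output behaves like an ordered mapping and len() counts only the fields that were actually set, a toy illustration (field names are hypothetical):

    from collections import OrderedDict

    # Toy stand-in for a ModelOutput: only populated fields contribute to len().
    outputs = OrderedDict(
        loss=0.1,
        logits="logits-tensor",
        past_key_values="cache",
        encoder_last_hidden_state="hidden-tensor",
        encoder_hidden_states="tuple-of-tensors",  # hypothetical fifth populated field
    )
    assert len(outputs) == 5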