Compare commits

...

13 Commits

Author SHA1 Message Date
22fddd7058 Merge branch 'main' into fix-t5-gemma-config 2025-08-28 11:55:36 +02:00
5e07abf0c9 Fix style 2025-08-28 09:53:20 +00:00
df57c4545b Fix edge cases 2025-08-28 09:47:16 +00:00
859075fe5e Merge branch 'main' into fix-t5-gemma-config 2025-08-28 11:01:24 +02:00
4464609b67 Simplify logic 2025-08-28 08:53:01 +00:00
086df615ac Merge branch 'main' into fix-t5-gemma-config 2025-08-27 11:52:02 +02:00
ef13c59d5b Fix cache 2025-08-27 09:50:46 +00:00
4567081a7e Revert T5Gemma 2025-08-27 09:43:28 +00:00
3942fb850e Fix qualiry 2025-08-27 09:15:24 +00:00
879e295cc0 Fix config 2025-08-27 09:09:28 +00:00
cc21a55e95 Add missing setter 2025-08-27 08:24:20 +00:00
189cfe96f2 Merge branch 'main' into fix-t5-gemma-config 2025-08-27 10:11:40 +02:00
725d8f3f74 Fix num_hidden_layers 2025-08-27 08:10:03 +00:00
3 changed files with 51 additions and 11 deletions

View File

@ -2001,11 +2001,17 @@ class GenerationMixin(ContinuousMixin):
# Use DynamicCache() instance by default. This will avoid back and forth from legacy format that
# keeps copying the cache thus using much more memory
else:
model_kwargs[cache_name] = (
DynamicCache(**dynamic_cache_kwargs)
if not requires_cross_attention_cache
else EncoderDecoderCache(DynamicCache(**dynamic_cache_kwargs), DynamicCache(**dynamic_cache_kwargs))
)
if not requires_cross_attention_cache:
model_kwargs[cache_name] = DynamicCache(**dynamic_cache_kwargs)
else:
# For encoder-decoder models, we need to use separate configs for encoder and decoder
decoder_cache_kwargs = {}
encoder_cache_kwargs = {}
decoder_cache_kwargs["config"] = self.config.get_text_config(decoder=True)
encoder_cache_kwargs["config"] = self.config.get_text_config(encoder=True)
model_kwargs[cache_name] = EncoderDecoderCache(
DynamicCache(**decoder_cache_kwargs), DynamicCache(**encoder_cache_kwargs)
)
def _supports_logits_to_keep(self) -> bool:
"""

View File

@ -323,9 +323,26 @@ class T5GemmaConfig(PretrainedConfig):
setattr(self.decoder, key, value)
super().__setattr__(key, value)
def get_text_config(self, *args, **kwargs):
# Always return self, regardless of the decoder option.
return self
def get_text_config(self, decoder=None, encoder=None):
    """
    Returns the text config related to the text input (encoder) or text output (decoder) of the model.

    Args:
        decoder (`Optional[bool]`, *optional*):
            If set to `True` (and `encoder` is not `True`), returns the decoder config.
        encoder (`Optional[bool]`, *optional*):
            If set to `True` (and `decoder` is not `True`), returns the encoder config.

    Returns:
        `self.decoder` when only the decoder is requested or when both arguments are left unset,
        `self.encoder` when only the encoder is requested, and `self` in every other case
        (both `True`, both `False`, or a lone `False`).
    """
    if decoder is None and encoder is None:
        # Default case - return decoder for generation compatibility
        return self.decoder
    # An unset (`None`) counterpart is treated like `False`, so that a call such as
    # `get_text_config(decoder=True)` returns the sub-config the caller asked for instead of
    # silently falling through to `self`.
    if decoder is True and encoder is not True:
        return self.decoder
    if encoder is True and decoder is not True:
        return self.encoder
    # Ambiguous or negative-only requests keep returning the full composite config.
    return self
__all__ = ["T5GemmaConfig", "T5GemmaModuleConfig"]

View File

@ -207,9 +207,26 @@ class T5GemmaConfig(PretrainedConfig):
setattr(self.decoder, key, value)
super().__setattr__(key, value)
def get_text_config(self, *args, **kwargs):
# Always return self, regardless of the decoder option.
return self
def get_text_config(self, decoder=None, encoder=None):
    """
    Returns the text config related to the text input (encoder) or text output (decoder) of the model.

    Args:
        decoder (`Optional[bool]`, *optional*):
            If set to `True` (and `encoder` is not `True`), returns the decoder config.
        encoder (`Optional[bool]`, *optional*):
            If set to `True` (and `decoder` is not `True`), returns the encoder config.

    Returns:
        `self.decoder` when only the decoder is requested or when both arguments are left unset,
        `self.encoder` when only the encoder is requested, and `self` in every other case
        (both `True`, both `False`, or a lone `False`).
    """
    if decoder is None and encoder is None:
        # Default case - return decoder for generation compatibility
        return self.decoder
    # An unset (`None`) counterpart is treated like `False`, so that a call such as
    # `get_text_config(decoder=True)` returns the sub-config the caller asked for instead of
    # silently falling through to `self`.
    if decoder is True and encoder is not True:
        return self.decoder
    if encoder is True and decoder is not True:
        return self.encoder
    # Ambiguous or negative-only requests keep returning the full composite config.
    return self
class T5GemmaRMSNorm(Gemma2RMSNorm):