🚨 [v5] generate delegates default cache initialization to the model (#41505)

Joao Gante
2025-10-13 13:20:48 +01:00
committed by GitHub
parent d7c9fbdb64
commit d621be8286
31 changed files with 114 additions and 50 deletions

View File

@@ -2167,7 +2167,10 @@ class WhisperTimeStampLogitsProcessor(LogitsProcessor):
class WhisperNoSpeechDetection(LogitsProcessor):
r"""This processor can be used to detect silence when using Whisper. It should take as input unprocessed logits to follow the original implementation"""
"""
This processor can be used to detect silence when using Whisper. It should take as input unprocessed logits
to follow the original implementation
"""
def __init__(self, no_speech_token: int, begin_index: int, scores_is_logprobs: bool = False):
self.no_speech_token = no_speech_token
@@ -2188,6 +2191,10 @@ class WhisperNoSpeechDetection(LogitsProcessor):
self.model = model
def set_inputs(self, inputs):
# build `cache_position` on the fly
seq_length = inputs["input_ids"].shape[1]
inputs = self.model._get_initial_cache_position(seq_length, self.model.device, inputs)
# prepare other inputs
self.inputs = {**self.model.prepare_inputs_for_generation(**inputs), **inputs}
self.inputs["input_features"] = self.inputs.pop("inputs")

View File

@@ -548,6 +548,12 @@ class GenerationMixin(ContinuousMixin):
# 2. Generic cache-dependent input preparation
if past_key_values is not None:
model_inputs["past_key_values"] = past_key_values
# We check `use_cache` below because some stateful models (like `recurrent_gemma`) expect input slicing if
# their caching mechanism is used. To define `use_cache`, the user-defined argument takes precedence.
use_cache = kwargs.get("use_cache")
if use_cache is None:
use_cache = getattr(self.config, "use_cache", False)
if past_key_values is None or use_cache:
# TODO (joao): handle the case where cache length == input_ids length. The function below results in an
# exception because we get empty input_ids after slicing. In essence, we need to roll back the cache 1
# token to recompute the logits for the first token to be generated (but not all caches support roll backs)
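
Note: the new branch resolves `use_cache` with a simple precedence rule: an explicit keyword argument wins, otherwise fall back to `config.use_cache`, defaulting to `False`. A self-contained illustration (the config class is a stand-in):

class StubConfig:
    use_cache = True  # stand-in for a model config

def resolve_use_cache(kwargs, config):
    use_cache = kwargs.get("use_cache")
    if use_cache is None:
        use_cache = getattr(config, "use_cache", False)
    return use_cache

assert resolve_use_cache({}, StubConfig()) is True                      # config default applies
assert resolve_use_cache({"use_cache": False}, StubConfig()) is False   # user argument wins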
@@ -589,7 +595,7 @@ class GenerationMixin(ContinuousMixin):
for model_input_name in ["position_ids", "token_type_ids", "decoder_position_ids"]:
model_input = kwargs.get(model_input_name)
if model_input is not None:
if past_key_values is not None:
if past_key_values is not None or use_cache:
current_input_length = (
model_inputs["inputs_embeds"].shape[1]
if model_inputs.get("inputs_embeds") is not None
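
Note: for models that now hit this branch, the subsequent slicing keeps only the trailing positions that match the tokens being fed this step; with made-up tensors:

import torch

position_ids = torch.arange(10).unsqueeze(0)  # positions for the full prompt, shape (1, 10)
current_input_length = 1                      # during decoding, only the newest token is passed
print(position_ids[:, -current_input_length:])  # tensor([[9]])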
@@ -1999,17 +2005,15 @@ class GenerationMixin(ContinuousMixin):
elif "dynamic" in generation_config.cache_implementation:
model_kwargs[cache_name] = DynamicCache(**dynamic_cache_kwargs)
# TODO (joao): remove this `else` when we remove the last traces of the legacy cache format (v4.58.0, search
# for `isinstance(past_key_values, Cache)` as well). In general, if `cache_implementation` is unset, cache
# initialization should happen inside the model at prefill time.
else:
model_kwargs[cache_name] = DynamicCache(**dynamic_cache_kwargs)
# TODO (joao): this logic is incomplete, e.g. `offloaded` should apply to both caches. Refactor this function
# to correctly pass parameterization to both caches.
if requires_cross_attention_cache and not isinstance(model_kwargs[cache_name], EncoderDecoderCache):
model_kwargs[cache_name] = EncoderDecoderCache(
model_kwargs[cache_name], # self-attention cache
if (
requires_cross_attention_cache
and "past_key_values" in model_kwargs
and not isinstance(model_kwargs["past_key_values"], EncoderDecoderCache)
):
model_kwargs["past_key_values"] = EncoderDecoderCache(
model_kwargs["past_key_values"], # self-attention cache
DynamicCache(**dynamic_cache_kwargs), # cross-attention cache
)
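
Note: `EncoderDecoderCache` simply pairs a self-attention cache with a cross-attention cache, and the stricter guard above only wraps when a self-attention cache actually exists in `model_kwargs`. A rough sketch of the wrapping step (assuming default `DynamicCache` arguments):

from transformers import DynamicCache, EncoderDecoderCache

model_kwargs = {"past_key_values": DynamicCache()}  # self-attention cache created earlier

if "past_key_values" in model_kwargs and not isinstance(model_kwargs["past_key_values"], EncoderDecoderCache):
    model_kwargs["past_key_values"] = EncoderDecoderCache(
        model_kwargs["past_key_values"],  # self-attention cache
        DynamicCache(),                   # cross-attention cache
    )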
@@ -3335,7 +3339,7 @@ class GenerationMixin(ContinuousMixin):
# pluck the cache from the beam indices that will be used in the next iteration
# NOTE: we need to check if `self._reorder_cache` exists for special models like RAG, RecurrentGemma etc.
if model_kwargs.get("past_key_values", None) is not None:
if model_kwargs.get("past_key_values") is not None:
beam_idx = self._flatten_beam_dim(running_beam_indices[..., cur_len - decoder_prompt_len])
if hasattr(self, "_reorder_cache"):
model_kwargs["past_key_values"] = self._reorder_cache(model_kwargs["past_key_values"], beam_idx)

View File

@@ -802,7 +802,7 @@ class BartDecoder(BartPreTrainedModel):
if use_cache and past_key_values is None:
past_key_values = (
EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
if encoder_hidden_states is not None
if encoder_hidden_states is not None or self.config.is_encoder_decoder
else DynamicCache(config=self.config)
)
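
Note: this is the pattern that repeats across the remaining model files: when caching is enabled and no cache was passed in, the decoder now creates its own default cache at prefill time, choosing an `EncoderDecoderCache` only if cross-attention will actually be used. The selection rule in isolation (with a stand-in `config` and default `DynamicCache` arguments):

from transformers import DynamicCache, EncoderDecoderCache

def default_cache(config, encoder_hidden_states):
    # cross-attention needs its own cache, so encoder-decoder use gets a pair of caches
    if encoder_hidden_states is not None or getattr(config, "is_encoder_decoder", False):
        return EncoderDecoderCache(DynamicCache(), DynamicCache())
    # decoder-only use: a single self-attention cache is enough
    return DynamicCache()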

View File

@@ -671,7 +671,11 @@ class BertModel(BertPreTrainedModel):
use_cache = False
if use_cache and past_key_values is None:
past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
past_key_values = (
EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
if encoder_hidden_states is not None or self.config.is_encoder_decoder
else DynamicCache(config=self.config)
)
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

View File

@@ -533,7 +533,11 @@ class BertGenerationEncoder(BertGenerationPreTrainedModel):
use_cache = False
if use_cache and past_key_values is None:
past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
past_key_values = (
EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
if encoder_hidden_states is not None or self.config.is_encoder_decoder
else DynamicCache(config=self.config)
)
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

View File

@@ -1976,7 +1976,7 @@ class BigBirdPegasusDecoder(BigBirdPegasusPreTrainedModel):
if use_cache and past_key_values is None:
past_key_values = (
EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
if encoder_hidden_states is not None
if encoder_hidden_states is not None or self.config.is_encoder_decoder
else DynamicCache(config=self.config)
)

View File

@@ -752,7 +752,7 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel):
if use_cache and past_key_values is None:
past_key_values = (
EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
if encoder_hidden_states is not None
if encoder_hidden_states is not None or self.config.is_encoder_decoder
else DynamicCache(config=self.config)
)

View File

@@ -739,7 +739,7 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel):
if use_cache and past_key_values is None:
past_key_values = (
EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
if encoder_hidden_states is not None
if encoder_hidden_states is not None or self.config.is_encoder_decoder
else DynamicCache(config=self.config)
)

View File

@@ -651,7 +651,11 @@ class CamembertModel(CamembertPreTrainedModel):
use_cache = False
if use_cache and past_key_values is None:
past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
past_key_values = (
EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
if encoder_hidden_states is not None or self.config.is_encoder_decoder
else DynamicCache(config=self.config)
)
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

View File

@@ -611,7 +611,11 @@ class Data2VecTextModel(Data2VecTextPreTrainedModel):
use_cache = False
if use_cache and past_key_values is None:
past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
past_key_values = (
EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
if encoder_hidden_states is not None or self.config.is_encoder_decoder
else DynamicCache(config=self.config)
)
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

View File

@@ -610,7 +610,11 @@ class ElectraModel(ElectraPreTrainedModel):
use_cache = False
if use_cache and past_key_values is None:
past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
past_key_values = (
EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
if encoder_hidden_states is not None or self.config.is_encoder_decoder
else DynamicCache(config=self.config)
)
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

View File

@@ -647,7 +647,11 @@ class ErnieModel(ErniePreTrainedModel):
use_cache = False
if use_cache and past_key_values is None:
past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
past_key_values = (
EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
if encoder_hidden_states is not None or self.config.is_encoder_decoder
else DynamicCache(config=self.config)
)
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")

View File

@@ -234,7 +234,11 @@ class ErnieModel(BertModel):
use_cache = False
if use_cache and past_key_values is None:
past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
past_key_values = (
EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
if encoder_hidden_states is not None or self.config.is_encoder_decoder
else DynamicCache(config=self.config)
)
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")

View File

@@ -616,10 +616,6 @@ class FSMTDecoder(nn.Module):
else:
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
# initialize `past_key_values`
if use_cache and past_key_values is None:
past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
x += positions
x = nn.functional.dropout(x, p=self.dropout, training=self.training)
@@ -912,6 +908,9 @@ class FSMTModel(PretrainedFSMTModel):
if decoder_input_ids is None and decoder_inputs_embeds is None:
raise ValueError("Make sure that `decoder_input_ids` or `decoder_inputs_embeds` are passed.")
if use_cache and past_key_values is None:
past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
if encoder_outputs is None:
encoder_outputs = self.encoder(
input_ids=input_ids,

View File

@@ -1029,7 +1029,7 @@ class Kosmos2TextTransformer(nn.Module):
if use_cache and past_key_values is None:
past_key_values = (
EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
if encoder_hidden_states is not None
if encoder_hidden_states is not None or self.config.is_encoder_decoder
else DynamicCache(config=self.config)
)

View File

@@ -750,7 +750,7 @@ class MarianDecoder(MarianPreTrainedModel):
if use_cache and past_key_values is None:
past_key_values = (
EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
if encoder_hidden_states is not None
if encoder_hidden_states is not None or self.config.is_encoder_decoder
else DynamicCache(config=self.config)
)

View File

@@ -800,7 +800,7 @@ class MBartDecoder(MBartPreTrainedModel):
if use_cache and past_key_values is None:
past_key_values = (
EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
if encoder_hidden_states is not None
if encoder_hidden_states is not None or self.config.is_encoder_decoder
else DynamicCache(config=self.config)
)

View File

@@ -802,7 +802,7 @@ class MvpDecoder(MvpPreTrainedModel):
if use_cache and past_key_values is None:
past_key_values = (
EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
if encoder_hidden_states is not None
if encoder_hidden_states is not None or self.config.is_encoder_decoder
else DynamicCache(config=self.config)
)

View File

@@ -802,7 +802,7 @@ class PegasusDecoder(PegasusPreTrainedModel):
if use_cache and past_key_values is None:
past_key_values = (
EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
if encoder_hidden_states is not None
if encoder_hidden_states is not None or self.config.is_encoder_decoder
else DynamicCache(config=self.config)
)

View File

@@ -715,7 +715,7 @@ class PLBartDecoder(PLBartPreTrainedModel):
if use_cache and past_key_values is None:
past_key_values = (
EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
if encoder_hidden_states is not None
if encoder_hidden_states is not None or self.config.is_encoder_decoder
else DynamicCache(config=self.config)
)

View File

@@ -1182,7 +1182,7 @@ class ProphetNetDecoder(ProphetNetPreTrainedModel):
if use_cache and past_key_values is None:
past_key_values = (
EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
if encoder_hidden_states is not None
if encoder_hidden_states is not None or self.config.is_encoder_decoder
else DynamicCache(config=self.config)
)

View File

@@ -621,7 +621,11 @@ class RobertaModel(RobertaPreTrainedModel):
use_cache = False
if use_cache and past_key_values is None:
past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
past_key_values = (
EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
if encoder_hidden_states is not None or self.config.is_encoder_decoder
else DynamicCache(config=self.config)
)
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

View File

@@ -645,7 +645,11 @@ class RobertaPreLayerNormModel(RobertaPreLayerNormPreTrainedModel):
use_cache = False
if use_cache and past_key_values is None:
past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
past_key_values = (
EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
if encoder_hidden_states is not None or self.config.is_encoder_decoder
else DynamicCache(config=self.config)
)
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

View File

@@ -430,9 +430,6 @@ class RoFormerEncoder(nn.Module):
)
use_cache = False
if use_cache and past_key_values is None:
past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
@@ -736,6 +733,13 @@ class RoFormerModel(RoFormerPreTrainedModel):
batch_size, seq_length = input_shape
device = input_ids.device if input_ids is not None else inputs_embeds.device
if use_cache and past_key_values is None:
past_key_values = (
EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
if encoder_hidden_states is not None or self.config.is_encoder_decoder
else DynamicCache(config=self.config)
)
past_key_values_length = 0 if past_key_values is None else past_key_values.get_seq_length()
if attention_mask is None:
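
Note: `past_key_values_length` above is just the number of tokens already stored in the cache, which is zero for the freshly created one; for example:

from transformers import DynamicCache

cache = DynamicCache()
print(cache.get_seq_length())  # 0 until the first forward pass writes key/value states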

View File

@@ -550,7 +550,7 @@ class TrOCRDecoder(TrOCRPreTrainedModel):
if use_cache and past_key_values is None:
past_key_values = (
EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
if encoder_hidden_states is not None
if encoder_hidden_states is not None or self.config.is_encoder_decoder
else DynamicCache(config=self.config)
)

View File

@@ -808,12 +808,11 @@ class WhisperDecoder(WhisperPreTrainedModel):
inputs_embeds = self.embed_tokens(input_ids)
if use_cache and past_key_values is None:
if self.config.is_encoder_decoder:
past_key_values = EncoderDecoderCache(
DynamicCache(config=self.config), DynamicCache(config=self.config)
)
else:
past_key_values = DynamicCache(config=self.config)
past_key_values = (
EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
if encoder_hidden_states is not None or self.config.is_encoder_decoder
else DynamicCache(config=self.config)
)
past_key_values_length = 0
if cache_position is not None:

View File

@@ -469,7 +469,7 @@ class XGLMModel(XGLMPreTrainedModel):
if use_cache and past_key_values is None:
past_key_values = (
EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
if encoder_hidden_states is not None
if encoder_hidden_states is not None or self.config.is_encoder_decoder
else DynamicCache(config=self.config)
)

View File

@@ -640,7 +640,11 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
use_cache = False
if use_cache and past_key_values is None:
past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
past_key_values = (
EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
if encoder_hidden_states is not None or self.config.is_encoder_decoder
else DynamicCache(config=self.config)
)
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

View File

@@ -627,7 +627,11 @@ class XLMRobertaXLModel(XLMRobertaXLPreTrainedModel):
use_cache = False
if use_cache and past_key_values is None:
past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
past_key_values = (
EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
if encoder_hidden_states is not None or self.config.is_encoder_decoder
else DynamicCache(config=self.config)
)
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

View File

@@ -742,7 +742,11 @@ class XmodModel(XmodPreTrainedModel):
use_cache = False
if use_cache and past_key_values is None:
past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
past_key_values = (
EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
if encoder_hidden_states is not None or self.config.is_encoder_decoder
else DynamicCache(config=self.config)
)
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

View File

@@ -13,6 +13,7 @@
# limitations under the License.
"""Testing suite for the PyTorch RoFormer model."""
import copy
import unittest
from transformers import RoFormerConfig, is_torch_available
@@ -209,6 +210,8 @@ class RoFormerModelTester:
token_labels,
choice_labels,
):
config = copy.deepcopy(config)
config.is_decoder = True
model = RoFormerForCausalLM(config=config).to(torch_device).eval()
torch.manual_seed(0)
output_without_past_cache = model.generate(
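
Note: the added `copy.deepcopy` keeps the tester's shared config from being mutated when `is_decoder` is flipped for this test; the pattern in isolation (with a hypothetical config object):

import copy

class StubTestConfig:
    is_decoder = False  # hypothetical shared test config

base_config = StubTestConfig()
config = copy.deepcopy(base_config)
config.is_decoder = True
assert base_config.is_decoder is False  # the shared config is left untouched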