Fix cache-related tests (#39676)

* fix

* fix kyutai at last

* fix unrelated tests and copies

* update musicgen as well

* revert tensor

* fix old test failures

* why it wasn't added?
Raushan Turganbay
2025-07-28 17:30:11 +02:00
committed by GitHub
parent fc2bd1eac0
commit 1c6b47451d
14 changed files with 89 additions and 38 deletions


@@ -2055,7 +2055,7 @@ class GenerationMixin(ContinuousMixin):
generation_config.cache_implementation = None
generation_config.cache_implementation = generation_config.cache_implementation or getattr(
self.config.get_text_config(), "cache_implementation", None
self.config.get_text_config(decoder=True), "cache_implementation", None
)
if generation_config.cache_implementation is not None:
if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING:
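
The change above makes the default `cache_implementation` come from the decoder's text config (`get_text_config(decoder=True)`) rather than whichever text config happens to be returned first. A minimal sketch of the lookup; the `LlavaConfig` example is purely illustrative, any composite config behaves the same way:

```python
from transformers import LlavaConfig

# Illustrative composite config: a vision sub-config plus a text (decoder) sub-config.
config = LlavaConfig()

# With decoder=True, the decoder-side text config is returned, so a default
# `cache_implementation` stored there is the one generation should honour.
text_config = config.get_text_config(decoder=True)
cache_implementation = getattr(text_config, "cache_implementation", None)
print(type(text_config).__name__, cache_implementation)  # e.g. "LlamaConfig None"
```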


@@ -1215,12 +1215,15 @@ class KyutaiSpeechToTextForConditionalGeneration(KyutaiSpeechToTextPreTrainedMod
cache_methods = [
"_prepare_cache_for_generation",
"_get_cache",
"_supports_default_dynamic_cache",
"_get_layer_device_map_for_cache_init",
]
for method in cache_methods:
setattr(self.codec_model, method, types.MethodType(getattr(self, method).__func__, self.codec_model))
setattr(
self.codec_model, "_supports_default_dynamic_cache", types.MethodType(lambda x: True, self.codec_model)
)
self.codec_model._prepare_cache_for_generation(
generation_config=self.codec_model.generation_config,
model_kwargs=temporary_model_kwargs,
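
In the Kyutai change above, `_supports_default_dynamic_cache` is no longer borrowed from the outer model; it is pinned to `True` on the codec model, while the remaining cache helpers are still re-bound with `types.MethodType`. A self-contained sketch of that re-binding pattern (class and method names are made up for illustration):

```python
import types

class OuterModel:
    def cache_helper(self):
        # Stand-in for one of the GenerationMixin cache helpers.
        return f"helper bound to {type(self).__name__}"

class CodecModel:
    pass

outer, codec = OuterModel(), CodecModel()

# Take the bound method from `outer`, unwrap the plain function via `.__func__`,
# and re-bind it to the codec instance, the same trick used in the diff above.
setattr(codec, "cache_helper", types.MethodType(getattr(outer, "cache_helper").__func__, codec))
print(codec.cache_helper())  # helper bound to CodecModel

# The capability flag is simply forced to True with a bound lambda.
setattr(codec, "_supports_default_dynamic_cache", types.MethodType(lambda self: True, codec))
print(codec._supports_default_dynamic_cache())  # True
```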


@@ -344,12 +344,15 @@ class KyutaiSpeechToTextForConditionalGeneration(LlamaForCausalLM, GenerationMix
cache_methods = [
"_prepare_cache_for_generation",
"_get_cache",
"_supports_default_dynamic_cache",
"_get_layer_device_map_for_cache_init",
]
for method in cache_methods:
setattr(self.codec_model, method, types.MethodType(getattr(self, method).__func__, self.codec_model))
setattr(
self.codec_model, "_supports_default_dynamic_cache", types.MethodType(lambda x: True, self.codec_model)
)
self.codec_model._prepare_cache_for_generation(
generation_config=self.codec_model.generation_config,
model_kwargs=temporary_model_kwargs,


@@ -1246,7 +1246,29 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel, GenerationMixin):
input_ids_length=input_ids_length,
)
# 6. Prepare `input_ids` which will be used for auto-regressive generation
self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)
# 6. Prepare the cache.
# - `model_kwargs` may be updated in place with a cache as defined by the parameters in `generation_config`.
# - different models have a different cache name expected by the model (default = "past_key_values")
# - `max_length`, prepared above, is used to determine the maximum cache length
max_cache_length = generation_config.max_length - 1
if (
input_ids.shape[1] != input_ids_length
and model_input_name == "inputs_embeds"
and not self.config.is_encoder_decoder
):
max_cache_length += input_ids.shape[1]
self._prepare_cache_for_generation(
generation_config,
model_kwargs,
assistant_model=None,
batch_size=batch_size,
max_cache_length=max_cache_length,
device=input_ids.device,
)
# 7. Prepare `input_ids` which will be used for auto-regressive generation
# Build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to MusicGen)
input_ids, delay_pattern_mask = self.build_delay_pattern_mask(
input_ids,
@@ -1260,15 +1282,15 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel, GenerationMixin):
# stash the delay mask so that we don't have to recompute it in each forward pass
model_kwargs["delay_pattern_mask"] = delay_pattern_mask
# 7. determine generation mode
# 8. determine generation mode
generation_mode = generation_config.get_generation_mode()
# 8. prepare batched CFG externally (to enable coexistence with the unbatched CFG)
# 9. prepare batched CFG externally (to enable coexistence with the unbatched CFG)
if generation_config.guidance_scale is not None and generation_config.guidance_scale > 1:
logits_processor.append(ClassifierFreeGuidanceLogitsProcessor(generation_config.guidance_scale))
generation_config.guidance_scale = None
# 9. prepare distribution pre_processing samplers
# 10. prepare distribution pre_processing samplers
logits_processor = self._get_logits_processor(
generation_config=generation_config,
input_ids_seq_length=input_ids_length,
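
The MusicGen change above sizes the cache explicitly before decoding, mirroring the generic `generate` logic: the cap is `max_length - 1`, extended by the prompt width when a decoder-only model receives its prompt as `inputs_embeds` (in that case `input_ids` only carries the tokens still to be generated). A rough sketch of the arithmetic with made-up numbers:

```python
# Hypothetical values, only to show how max_cache_length above is derived.
max_length = 256            # generation_config.max_length
prompt_width = 32           # width of the prompt tensor handed to the model
input_ids_length = 1        # just a BOS token when the prompt is passed as embeddings
model_input_name = "inputs_embeds"
is_encoder_decoder = False

max_cache_length = max_length - 1
if prompt_width != input_ids_length and model_input_name == "inputs_embeds" and not is_encoder_decoder:
    # The cache must also hold the embedded prompt, not just the generated ids.
    max_cache_length += prompt_width

print(max_cache_length)  # 255 + 32 = 287
```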


@@ -2162,6 +2162,28 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel, GenerationMixin):
input_ids_length=input_ids_length,
)
self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)
# 7. Prepare the cache.
# - `model_kwargs` may be updated in place with a cache as defined by the parameters in `generation_config`.
# - different models have a different cache name expected by the model (default = "past_key_values")
# - `max_length`, prepared above, is used to determine the maximum cache length
max_cache_length = generation_config.max_length - 1
if (
inputs_tensor.shape[1] != input_ids_length
and model_input_name == "inputs_embeds"
and not self.config.is_encoder_decoder
):
max_cache_length += inputs_tensor.shape[1]
self._prepare_cache_for_generation(
generation_config,
model_kwargs,
assistant_model=None,
batch_size=batch_size,
max_cache_length=max_cache_length,
device=inputs_tensor.device,
)
# build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to MusicGen)
input_ids, decoder_delay_pattern_mask = self.decoder.build_delay_pattern_mask(
input_ids,
@@ -2175,15 +2197,15 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel, GenerationMixin):
if streamer is not None:
streamer.put(input_ids.cpu())
# 7. determine generation mode
# 8. determine generation mode
generation_mode = generation_config.get_generation_mode()
# 8. prepare batched CFG externally (to enable coexistence with the unbatched CFG)
# 9. prepare batched CFG externally (to enable coexistence with the unbatched CFG)
if generation_config.guidance_scale is not None and generation_config.guidance_scale > 1:
logits_processor.append(ClassifierFreeGuidanceLogitsProcessor(generation_config.guidance_scale))
generation_config.guidance_scale = None
# 9. prepare distribution pre_processing samplers
# 10. prepare distribution pre_processing samplers
logits_processor = self._get_logits_processor(
generation_config=generation_config,
input_ids_seq_length=input_ids_length,
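
Both MusicGen files also keep the renumbered classifier-free-guidance step: when `guidance_scale > 1` a `ClassifierFreeGuidanceLogitsProcessor` is appended and `generation_config.guidance_scale` is then cleared so it is not applied a second time downstream. The underlying blend is the usual CFG extrapolation from unconditional to conditional logits; a library-free sketch:

```python
import torch

def cfg_blend(cond_logits: torch.Tensor, uncond_logits: torch.Tensor, guidance_scale: float) -> torch.Tensor:
    # Move away from the unconditional distribution in the direction of the conditional one.
    return uncond_logits + guidance_scale * (cond_logits - uncond_logits)

cond = torch.randn(1, 2048)    # logits from the text-conditioned forward pass
uncond = torch.randn(1, 2048)  # logits from the unconditional (null prompt) pass
print(cfg_blend(cond, uncond, guidance_scale=3.0).shape)  # torch.Size([1, 2048])
```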


@@ -1204,8 +1204,6 @@ class RagTokenForGeneration(RagPreTrainedModel, GenerationMixin):
if isinstance(past_key_values, EncoderDecoderCache):
reordered_past = EncoderDecoderCache.from_legacy_cache(reordered_past)
if isinstance(past_key_values, EncoderDecoderCache):
reordered_past = EncoderDecoderCache.from_legacy_cache(reordered_past)
return reordered_past
def marginalize(self, seq_logits, doc_scores, n_docs=None):
@@ -1593,13 +1591,6 @@ class RagTokenForGeneration(RagPreTrainedModel, GenerationMixin):
if generation_config.num_return_sequences > generation_config.num_beams:
raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")
# 11. interleave input_ids with `num_beams` additional sequences per batch
input_ids, model_kwargs = self._expand_inputs_for_generation(
input_ids=input_ids,
expand_size=generation_config.num_beams,
is_encoder_decoder=self.config.is_encoder_decoder,
**model_kwargs,
)
return self._beam_search(
input_ids,
logits_processor=pre_processor,
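
The RAG hunk above drops the explicit `_expand_inputs_for_generation` call in front of `_beam_search`, presumably because the shared beam-search path now performs the expansion itself; keeping both would tile every batch row `num_beams` times too many. The expansion itself is just an interleaved repeat, as this illustrative sketch shows:

```python
import torch

input_ids = torch.tensor([[101, 7592], [101, 2088]])  # a batch of 2 prompts
num_beams = 3

# Each row is repeated num_beams times with same-example rows kept adjacent,
# which is what the removed expansion step used to do before beam search.
expanded = input_ids.repeat_interleave(num_beams, dim=0)
print(expanded.shape)  # torch.Size([6, 2])

# Doing it twice would leave num_beams**2 rows per example.
print(expanded.repeat_interleave(num_beams, dim=0).shape)  # torch.Size([18, 2])
```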


@@ -261,6 +261,17 @@ class RoFormerSelfAttention(nn.Module):
.transpose(1, 2)
)
# Apply RoPE if self attention
if not is_cross_attention and sinusoidal_pos is not None:
if self.rotary_value:
query_layer, key_layer, value_layer = self.apply_rotary_position_embeddings(
sinusoidal_pos, query_layer, key_layer, value_layer
)
else:
query_layer, key_layer = self.apply_rotary_position_embeddings(
sinusoidal_pos, query_layer, key_layer
)
if past_key_value is not None:
# save all key/value_layer to cache to be re-used for fast auto-regressive generation
cache_position = cache_position if not is_cross_attention else None
@@ -381,13 +392,13 @@ class RoFormerAttention(nn.Module):
):
self_outputs = self.self(
hidden_states,
attention_mask,
sinusoidal_pos,
head_mask,
encoder_hidden_states,
past_key_value,
output_attentions,
cache_position,
attention_mask=attention_mask,
sinusoidal_pos=sinusoidal_pos,
head_mask=head_mask,
encoder_hidden_states=encoder_hidden_states,
past_key_value=past_key_value,
output_attentions=output_attentions,
cache_position=cache_position,
)
attention_output = self.output(self_outputs[0], hidden_states)
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
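
The RoFormer hunk applies the rotary position embeddings to the query and key (and to the value when `rotary_value` is set) before the key/value tensors are written to the cache, and switches the attention call to keyword arguments. As a reminder of what a rotary embedding does, here is a generic sketch; it is not RoFormer's exact `apply_rotary_position_embeddings`, whose channel layout differs (it interleaves even/odd pairs rather than rotating halves):

```python
import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rope(q, k, cos, sin):
    # Rotate query and key by position-dependent angles; values are normally
    # left untouched unless an option like RoFormer's rotary_value is enabled.
    return q * cos + rotate_half(q) * sin, k * cos + rotate_half(k) * sin

batch, heads, seq, dim = 1, 2, 4, 8
q, k = torch.randn(batch, heads, seq, dim), torch.randn(batch, heads, seq, dim)
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
angles = torch.outer(torch.arange(seq, dtype=torch.float32), inv_freq)  # (seq, dim/2)
cos, sin = torch.cat((angles.cos(),) * 2, dim=-1), torch.cat((angles.sin(),) * 2, dim=-1)
q_rot, k_rot = apply_rope(q, k, cos, sin)
print(q_rot.shape, k_rot.shape)  # torch.Size([1, 2, 4, 8]) twice
```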


@@ -274,7 +274,7 @@ class SuperGlueSelfAttention(nn.Module):
# such that the encoder's padding tokens are not attended to.
is_cross_attention = encoder_hidden_states is not None
current_states = encoder_hidden_states if is_cross_attention else hidden_states
attention_mask = encoder_attention_mask if is_cross_attention else encoder_attention_mask
attention_mask = encoder_attention_mask if is_cross_attention else attention_mask
batch_size = hidden_states.shape[0]
key_layer = (


@@ -515,7 +515,7 @@ class LlavaNextForConditionalGenerationIntegrationTest(unittest.TestCase):
# test that changing `strategy` won't error out
model.vision_feature_select_strategy = "full"
inputs = self.processor(self.prompt, self.image, return_tensors="pt").to(model.device)
inputs = self.processor(text=self.prompt, images=self.image, return_tensors="pt").to(model.device)
# verify generation
output = model.generate(**inputs, max_new_tokens=30)
@@ -536,7 +536,7 @@ class LlavaNextForConditionalGenerationIntegrationTest(unittest.TestCase):
model = LlavaNextForConditionalGeneration.from_pretrained(granite_model_path)
self.processor = AutoProcessor.from_pretrained(granite_model_path)
prompt = "<|user|>\n<image>\nWhat is shown in this image?\n<|assistant|>\n"
inputs = self.processor(prompt, self.image, return_tensors="pt").to(model.device)
inputs = self.processor(text=prompt, images=self.image, return_tensors="pt").to(model.device)
# verify generation
output = model.generate(**inputs, max_new_tokens=30)


@@ -467,7 +467,9 @@ class LlavaNextVideoForConditionalGenerationIntegrationTest(unittest.TestCase):
padding=True,
).to(torch_device)
inputs_single = self.processor(self.prompt_video, videos=[self.video], return_tensors="pt").to(torch_device)
inputs_single = self.processor(text=self.prompt_video, videos=[self.video], return_tensors="pt").to(
torch_device
)
# verify generation
output_batched = model.generate(**inputs_batched, do_sample=False, max_new_tokens=50)


@@ -413,7 +413,6 @@ class Qwen2_5OmniThinkerForConditionalGenerationModelTest(ModelTesterMixin, Gene
logits_padded = res_padded.logits[inputs_dict["attention_mask"].bool()]
logits_padfree = res_padfree.logits[0]
torch.testing.assert_close(logits_padded.argmax(-1), logits_padfree.argmax(-1), rtol=0, atol=0)
# acceptable numerical instability
tol = torch.finfo(torch.bfloat16).eps
torch.testing.assert_close(logits_padded, logits_padfree, rtol=tol, atol=tol)
@@ -698,7 +697,7 @@ class Qwen2_5OmniModelIntegrationTest(unittest.TestCase):
)
text = self.processor.apply_chat_template(self.messages, tokenize=False, add_generation_prompt=True)
inputs = self.processor(
text=text * 2,
text=[text] * 2,
audio=[self.raw_audio, self.raw_audio],
images=[self.raw_image, self.raw_image],
return_tensors="pt",


@@ -403,7 +403,6 @@ class Qwen2_5_VLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Test
logits_padded = res_padded.logits[inputs_dict["attention_mask"].bool()]
logits_padfree = res_padfree.logits[0]
torch.testing.assert_close(logits_padded.argmax(-1), logits_padfree.argmax(-1), rtol=0, atol=0)
# acceptable numerical instability
tol = torch.finfo(torch.bfloat16).eps
torch.testing.assert_close(logits_padded, logits_padfree, rtol=tol, atol=tol)


@@ -362,7 +362,6 @@ class Qwen2VLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas
logits_padded = res_padded.logits[inputs_dict["attention_mask"].bool()]
logits_padfree = res_padfree.logits[0]
torch.testing.assert_close(logits_padded.argmax(-1), logits_padfree.argmax(-1), rtol=0, atol=0)
# acceptable numerical instability
tol = torch.finfo(torch.bfloat16).eps
torch.testing.assert_close(logits_padded, logits_padfree, rtol=tol, atol=tol)
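
The three Qwen test hunks drop the stray "# acceptable numerical instability" comment but keep the bfloat16-epsilon tolerance for comparing the padded and padding-free forward passes. That tolerance is just the machine epsilon of bfloat16; a small sketch of the value and of the mask-based comparison pattern, with dummy tensors standing in for the model outputs:

```python
import torch

# bfloat16 has 8 bits of significand precision, so its machine epsilon is 2**-7.
tol = torch.finfo(torch.bfloat16).eps
print(tol)  # 0.0078125

# Dummy stand-ins for the two runs: valid positions of the padded output are
# selected with the boolean attention mask, mirroring the tests above.
attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 1, 1]]).bool()
logits_padded = torch.randn(2, 4, 16)
logits_padfree = logits_padded[attention_mask]  # shape (7, 16): only real tokens

torch.testing.assert_close(logits_padded[attention_mask], logits_padfree, rtol=tol, atol=tol)
```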


@@ -119,7 +119,7 @@ class ImageTextToTextPipelineTests(unittest.TestCase):
},
{
"role": "assistant",
"content": "Hugging Face, a company of minds\nWith tools and services that make our lives easier\nFrom",
"content": "Hugging Face, a company of minds\nWith tools and services that make our lives easier\nFrom natural language processing\nTo machine learning and more, they've got it all\n\nThey've made it possible for us to be more\nInformed and efficient, with their tools and services\nFrom image and speech recognition\nTo text and language translation, they've got it all\n\nThey've made it possible for us to be more\nInformed and efficient, with their tools and services\nFrom image and speech recognition\nTo text and language translation, they've got it all\n\nThey've made it possible for us to be more\nInformed and efficient, with their tools and services\nFrom image and speech recognition\nTo text and language translation, they've got it all\n\nThey've made it possible for us to be more\nInformed and efficient, with their tools and services\nFrom image and speech recognition\nTo text and language translation, they've got it all\n\nThey've made it possible for us to be more\nInformed and efficient, with their tools and services\nFrom image and speech recognition\nTo text and language translation, they've got it all\n\nThey've made it possible for us to be more\nInformed and efficient, with their tools and",
},
],
}
@@ -150,7 +150,7 @@ class ImageTextToTextPipelineTests(unittest.TestCase):
[
{
"input_text": "<image> What this is? Assistant: This is",
"generated_text": "<image> What this is? Assistant: This is a photo of two cats lying on a pink blanket. The cats are sleeping and appear to be comfortable",
"generated_text": "<image> What this is? Assistant: This is a photo of two cats lying on a pink blanket. The cats are sleeping and appear to be comfortable. The photo captures a moment of tranquility and companionship between the two feline friends.",
}
],
)
@@ -161,11 +161,11 @@ class ImageTextToTextPipelineTests(unittest.TestCase):
[
{
"input_text": "<image> What this is? Assistant: This is",
"generated_text": "<image> What this is? Assistant: This is a photo of two cats lying on a pink blanket. The cats are facing the camera, and they",
"generated_text": "<image> What this is? Assistant: This is a photo of two cats lying on a pink blanket. The cats are facing the camera, and they appear to be sleeping or resting. The blanket is placed on a couch, and the cats are positioned in such a way that they are facing the camera. The image captures a peaceful moment between the two cats, and it's a great way to showcase their cuteness and relaxed demeanor.",
},
{
"input_text": "<image> What this is? Assistant: This is",
"generated_text": "<image> What this is? Assistant: This is a photo of two cats lying on a pink blanket. The cats are facing the camera, and they",
"generated_text": "<image> What this is? Assistant: This is a photo of two cats lying on a pink blanket. The cats are facing the camera, and they appear to be sleeping or resting. The blanket is placed on a couch, and the cats are positioned in such a way that they are facing the camera. The image captures a peaceful moment between the two cats, and it's a great way to showcase their cuteness and relaxed demeanor.",
},
],
)