mirror of
https://github.com/huggingface/transformers.git
synced 2025-10-20 17:13:56 +08:00
Fix cache-related tests (#39676)
* fix * fix kyutai at last * fix unrelated tests and copies * update musicgen as well * revert tensor * fix old test failures * why it wasn't added?
This commit is contained in:
committed by
GitHub
parent
fc2bd1eac0
commit
1c6b47451d
@ -2055,7 +2055,7 @@ class GenerationMixin(ContinuousMixin):
|
||||
generation_config.cache_implementation = None
|
||||
|
||||
generation_config.cache_implementation = generation_config.cache_implementation or getattr(
|
||||
self.config.get_text_config(), "cache_implementation", None
|
||||
self.config.get_text_config(decoder=True), "cache_implementation", None
|
||||
)
|
||||
if generation_config.cache_implementation is not None:
|
||||
if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING:
|
||||
|
@ -1215,12 +1215,15 @@ class KyutaiSpeechToTextForConditionalGeneration(KyutaiSpeechToTextPreTrainedMod
|
||||
cache_methods = [
|
||||
"_prepare_cache_for_generation",
|
||||
"_get_cache",
|
||||
"_supports_default_dynamic_cache",
|
||||
"_get_layer_device_map_for_cache_init",
|
||||
]
|
||||
for method in cache_methods:
|
||||
setattr(self.codec_model, method, types.MethodType(getattr(self, method).__func__, self.codec_model))
|
||||
|
||||
setattr(
|
||||
self.codec_model, "_supports_default_dynamic_cache", types.MethodType(lambda x: True, self.codec_model)
|
||||
)
|
||||
|
||||
self.codec_model._prepare_cache_for_generation(
|
||||
generation_config=self.codec_model.generation_config,
|
||||
model_kwargs=temporary_model_kwargs,
|
||||
|
@ -344,12 +344,15 @@ class KyutaiSpeechToTextForConditionalGeneration(LlamaForCausalLM, GenerationMix
|
||||
cache_methods = [
|
||||
"_prepare_cache_for_generation",
|
||||
"_get_cache",
|
||||
"_supports_default_dynamic_cache",
|
||||
"_get_layer_device_map_for_cache_init",
|
||||
]
|
||||
for method in cache_methods:
|
||||
setattr(self.codec_model, method, types.MethodType(getattr(self, method).__func__, self.codec_model))
|
||||
|
||||
setattr(
|
||||
self.codec_model, "_supports_default_dynamic_cache", types.MethodType(lambda x: True, self.codec_model)
|
||||
)
|
||||
|
||||
self.codec_model._prepare_cache_for_generation(
|
||||
generation_config=self.codec_model.generation_config,
|
||||
model_kwargs=temporary_model_kwargs,
|
||||
|
@ -1246,7 +1246,29 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel, GenerationMixin):
|
||||
input_ids_length=input_ids_length,
|
||||
)
|
||||
|
||||
# 6. Prepare `input_ids` which will be used for auto-regressive generation
|
||||
self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)
|
||||
|
||||
# 6. Prepare the cache.
|
||||
# - `model_kwargs` may be updated in place with a cache as defined by the parameters in `generation_config`.
|
||||
# - different models have a different cache name expected by the model (default = "past_key_values")
|
||||
# - `max_length`, prepared above, is used to determine the maximum cache length
|
||||
max_cache_length = generation_config.max_length - 1
|
||||
if (
|
||||
input_ids_length.shape[1] != input_ids_length
|
||||
and model_input_name == "inputs_embeds"
|
||||
and not self.config.is_encoder_decoder
|
||||
):
|
||||
max_cache_length += input_ids_length.shape[1]
|
||||
self._prepare_cache_for_generation(
|
||||
generation_config,
|
||||
model_kwargs,
|
||||
assistant_model=None,
|
||||
batch_size=batch_size,
|
||||
max_cache_length=max_cache_length,
|
||||
device=input_ids_length.device,
|
||||
)
|
||||
|
||||
# 7. Prepare `input_ids` which will be used for auto-regressive generation
|
||||
# Build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to MusicGen)
|
||||
input_ids, delay_pattern_mask = self.build_delay_pattern_mask(
|
||||
input_ids,
|
||||
@ -1260,15 +1282,15 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel, GenerationMixin):
|
||||
# stash the delay mask so that we don't have to recompute it in each forward pass
|
||||
model_kwargs["delay_pattern_mask"] = delay_pattern_mask
|
||||
|
||||
# 7. determine generation mode
|
||||
# 8. determine generation mode
|
||||
generation_mode = generation_config.get_generation_mode()
|
||||
|
||||
# 8. prepare batched CFG externally (to enable coexistence with the unbatched CFG)
|
||||
# 9. prepare batched CFG externally (to enable coexistence with the unbatched CFG)
|
||||
if generation_config.guidance_scale is not None and generation_config.guidance_scale > 1:
|
||||
logits_processor.append(ClassifierFreeGuidanceLogitsProcessor(generation_config.guidance_scale))
|
||||
generation_config.guidance_scale = None
|
||||
|
||||
# 9. prepare distribution pre_processing samplers
|
||||
# 10. prepare distribution pre_processing samplers
|
||||
logits_processor = self._get_logits_processor(
|
||||
generation_config=generation_config,
|
||||
input_ids_seq_length=input_ids_length,
|
||||
|
@ -2162,6 +2162,28 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel, GenerationMixin):
|
||||
input_ids_length=input_ids_length,
|
||||
)
|
||||
|
||||
self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)
|
||||
|
||||
# 7. Prepare the cache.
|
||||
# - `model_kwargs` may be updated in place with a cache as defined by the parameters in `generation_config`.
|
||||
# - different models have a different cache name expected by the model (default = "past_key_values")
|
||||
# - `max_length`, prepared above, is used to determine the maximum cache length
|
||||
max_cache_length = generation_config.max_length - 1
|
||||
if (
|
||||
inputs_tensor.shape[1] != input_ids_length
|
||||
and model_input_name == "inputs_embeds"
|
||||
and not self.config.is_encoder_decoder
|
||||
):
|
||||
max_cache_length += inputs_tensor.shape[1]
|
||||
self._prepare_cache_for_generation(
|
||||
generation_config,
|
||||
model_kwargs,
|
||||
assistant_model=None,
|
||||
batch_size=batch_size,
|
||||
max_cache_length=max_cache_length,
|
||||
device=inputs_tensor.device,
|
||||
)
|
||||
|
||||
# build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to MusicGen)
|
||||
input_ids, decoder_delay_pattern_mask = self.decoder.build_delay_pattern_mask(
|
||||
input_ids,
|
||||
@ -2175,15 +2197,15 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel, GenerationMixin):
|
||||
if streamer is not None:
|
||||
streamer.put(input_ids.cpu())
|
||||
|
||||
# 7. determine generation mode
|
||||
# 8. determine generation mode
|
||||
generation_mode = generation_config.get_generation_mode()
|
||||
|
||||
# 8. prepare batched CFG externally (to enable coexistence with the unbatched CFG)
|
||||
# 9. prepare batched CFG externally (to enable coexistence with the unbatched CFG)
|
||||
if generation_config.guidance_scale is not None and generation_config.guidance_scale > 1:
|
||||
logits_processor.append(ClassifierFreeGuidanceLogitsProcessor(generation_config.guidance_scale))
|
||||
generation_config.guidance_scale = None
|
||||
|
||||
# 9. prepare distribution pre_processing samplers
|
||||
# 10. prepare distribution pre_processing samplers
|
||||
logits_processor = self._get_logits_processor(
|
||||
generation_config=generation_config,
|
||||
input_ids_seq_length=input_ids_length,
|
||||
|
@ -1204,8 +1204,6 @@ class RagTokenForGeneration(RagPreTrainedModel, GenerationMixin):
|
||||
if isinstance(past_key_values, EncoderDecoderCache):
|
||||
reordered_past = EncoderDecoderCache.from_legacy_cache(reordered_past)
|
||||
|
||||
if isinstance(past_key_values, EncoderDecoderCache):
|
||||
reordered_past = EncoderDecoderCache.from_legacy_cache(reordered_past)
|
||||
return reordered_past
|
||||
|
||||
def marginalize(self, seq_logits, doc_scores, n_docs=None):
|
||||
@ -1593,13 +1591,6 @@ class RagTokenForGeneration(RagPreTrainedModel, GenerationMixin):
|
||||
if generation_config.num_return_sequences > generation_config.num_beams:
|
||||
raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")
|
||||
|
||||
# 11. interleave input_ids with `num_beams` additional sequences per batch
|
||||
input_ids, model_kwargs = self._expand_inputs_for_generation(
|
||||
input_ids=input_ids,
|
||||
expand_size=generation_config.num_beams,
|
||||
is_encoder_decoder=self.config.is_encoder_decoder,
|
||||
**model_kwargs,
|
||||
)
|
||||
return self._beam_search(
|
||||
input_ids,
|
||||
logits_processor=pre_processor,
|
||||
|
@ -261,6 +261,17 @@ class RoFormerSelfAttention(nn.Module):
|
||||
.transpose(1, 2)
|
||||
)
|
||||
|
||||
# Apply RoPE if self attention
|
||||
if not is_cross_attention and sinusoidal_pos is not None:
|
||||
if self.rotary_value:
|
||||
query_layer, key_layer, value_layer = self.apply_rotary_position_embeddings(
|
||||
sinusoidal_pos, query_layer, key_layer, value_layer
|
||||
)
|
||||
else:
|
||||
query_layer, key_layer = self.apply_rotary_position_embeddings(
|
||||
sinusoidal_pos, query_layer, key_layer
|
||||
)
|
||||
|
||||
if past_key_value is not None:
|
||||
# save all key/value_layer to cache to be re-used for fast auto-regressive generation
|
||||
cache_position = cache_position if not is_cross_attention else None
|
||||
@ -381,13 +392,13 @@ class RoFormerAttention(nn.Module):
|
||||
):
|
||||
self_outputs = self.self(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
sinusoidal_pos,
|
||||
head_mask,
|
||||
encoder_hidden_states,
|
||||
past_key_value,
|
||||
output_attentions,
|
||||
cache_position,
|
||||
attention_mask=attention_mask,
|
||||
sinusoidal_pos=sinusoidal_pos,
|
||||
head_mask=head_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
attention_output = self.output(self_outputs[0], hidden_states)
|
||||
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
|
||||
|
@ -274,7 +274,7 @@ class SuperGlueSelfAttention(nn.Module):
|
||||
# such that the encoder's padding tokens are not attended to.
|
||||
is_cross_attention = encoder_hidden_states is not None
|
||||
current_states = encoder_hidden_states if is_cross_attention else hidden_states
|
||||
attention_mask = encoder_attention_mask if is_cross_attention else encoder_attention_mask
|
||||
attention_mask = encoder_attention_mask if is_cross_attention else attention_mask
|
||||
|
||||
batch_size = hidden_states.shape[0]
|
||||
key_layer = (
|
||||
|
@ -515,7 +515,7 @@ class LlavaNextForConditionalGenerationIntegrationTest(unittest.TestCase):
|
||||
# test that changing `strategy` won't error out
|
||||
model.vision_feature_select_strategy = "full"
|
||||
|
||||
inputs = self.processor(self.prompt, self.image, return_tensors="pt").to(model.device)
|
||||
inputs = self.processor(text=self.prompt, images=self.image, return_tensors="pt").to(model.device)
|
||||
|
||||
# verify generation
|
||||
output = model.generate(**inputs, max_new_tokens=30)
|
||||
@ -536,7 +536,7 @@ class LlavaNextForConditionalGenerationIntegrationTest(unittest.TestCase):
|
||||
model = LlavaNextForConditionalGeneration.from_pretrained(granite_model_path)
|
||||
self.processor = AutoProcessor.from_pretrained(granite_model_path)
|
||||
prompt = "<|user|>\n<image>\nWhat is shown in this image?\n<|assistant|>\n"
|
||||
inputs = self.processor(prompt, self.image, return_tensors="pt").to(model.device)
|
||||
inputs = self.processor(text=prompt, images=self.image, return_tensors="pt").to(model.device)
|
||||
|
||||
# verify generation
|
||||
output = model.generate(**inputs, max_new_tokens=30)
|
||||
|
@ -467,7 +467,9 @@ class LlavaNextVideoForConditionalGenerationIntegrationTest(unittest.TestCase):
|
||||
padding=True,
|
||||
).to(torch_device)
|
||||
|
||||
inputs_single = self.processor(self.prompt_video, videos=[self.video], return_tensors="pt").to(torch_device)
|
||||
inputs_single = self.processor(text=self.prompt_video, videos=[self.video], return_tensors="pt").to(
|
||||
torch_device
|
||||
)
|
||||
|
||||
# verify generation
|
||||
output_batched = model.generate(**inputs_batched, do_sample=False, max_new_tokens=50)
|
||||
|
@ -413,7 +413,6 @@ class Qwen2_5OmniThinkerForConditionalGenerationModelTest(ModelTesterMixin, Gene
|
||||
logits_padded = res_padded.logits[inputs_dict["attention_mask"].bool()]
|
||||
logits_padfree = res_padfree.logits[0]
|
||||
|
||||
torch.testing.assert_close(logits_padded.argmax(-1), logits_padfree.argmax(-1), rtol=0, atol=0)
|
||||
# acceptable numerical instability
|
||||
tol = torch.finfo(torch.bfloat16).eps
|
||||
torch.testing.assert_close(logits_padded, logits_padfree, rtol=tol, atol=tol)
|
||||
@ -698,7 +697,7 @@ class Qwen2_5OmniModelIntegrationTest(unittest.TestCase):
|
||||
)
|
||||
text = self.processor.apply_chat_template(self.messages, tokenize=False, add_generation_prompt=True)
|
||||
inputs = self.processor(
|
||||
text=text * 2,
|
||||
text=[text] * 2,
|
||||
audio=[self.raw_audio, self.raw_audio],
|
||||
images=[self.raw_image, self.raw_image],
|
||||
return_tensors="pt",
|
||||
|
@ -403,7 +403,6 @@ class Qwen2_5_VLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Test
|
||||
logits_padded = res_padded.logits[inputs_dict["attention_mask"].bool()]
|
||||
logits_padfree = res_padfree.logits[0]
|
||||
|
||||
torch.testing.assert_close(logits_padded.argmax(-1), logits_padfree.argmax(-1), rtol=0, atol=0)
|
||||
# acceptable numerical instability
|
||||
tol = torch.finfo(torch.bfloat16).eps
|
||||
torch.testing.assert_close(logits_padded, logits_padfree, rtol=tol, atol=tol)
|
||||
|
@ -362,7 +362,6 @@ class Qwen2VLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas
|
||||
logits_padded = res_padded.logits[inputs_dict["attention_mask"].bool()]
|
||||
logits_padfree = res_padfree.logits[0]
|
||||
|
||||
torch.testing.assert_close(logits_padded.argmax(-1), logits_padfree.argmax(-1), rtol=0, atol=0)
|
||||
# acceptable numerical instability
|
||||
tol = torch.finfo(torch.bfloat16).eps
|
||||
torch.testing.assert_close(logits_padded, logits_padfree, rtol=tol, atol=tol)
|
||||
|
@ -119,7 +119,7 @@ class ImageTextToTextPipelineTests(unittest.TestCase):
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "Hugging Face, a company of minds\nWith tools and services that make our lives easier\nFrom",
|
||||
"content": "Hugging Face, a company of minds\nWith tools and services that make our lives easier\nFrom natural language processing\nTo machine learning and more, they've got it all\n\nThey've made it possible for us to be more\nInformed and efficient, with their tools and services\nFrom image and speech recognition\nTo text and language translation, they've got it all\n\nThey've made it possible for us to be more\nInformed and efficient, with their tools and services\nFrom image and speech recognition\nTo text and language translation, they've got it all\n\nThey've made it possible for us to be more\nInformed and efficient, with their tools and services\nFrom image and speech recognition\nTo text and language translation, they've got it all\n\nThey've made it possible for us to be more\nInformed and efficient, with their tools and services\nFrom image and speech recognition\nTo text and language translation, they've got it all\n\nThey've made it possible for us to be more\nInformed and efficient, with their tools and services\nFrom image and speech recognition\nTo text and language translation, they've got it all\n\nThey've made it possible for us to be more\nInformed and efficient, with their tools and",
|
||||
},
|
||||
],
|
||||
}
|
||||
@ -150,7 +150,7 @@ class ImageTextToTextPipelineTests(unittest.TestCase):
|
||||
[
|
||||
{
|
||||
"input_text": "<image> What this is? Assistant: This is",
|
||||
"generated_text": "<image> What this is? Assistant: This is a photo of two cats lying on a pink blanket. The cats are sleeping and appear to be comfortable",
|
||||
"generated_text": "<image> What this is? Assistant: This is a photo of two cats lying on a pink blanket. The cats are sleeping and appear to be comfortable. The photo captures a moment of tranquility and companionship between the two feline friends.",
|
||||
}
|
||||
],
|
||||
)
|
||||
@ -161,11 +161,11 @@ class ImageTextToTextPipelineTests(unittest.TestCase):
|
||||
[
|
||||
{
|
||||
"input_text": "<image> What this is? Assistant: This is",
|
||||
"generated_text": "<image> What this is? Assistant: This is a photo of two cats lying on a pink blanket. The cats are facing the camera, and they",
|
||||
"generated_text": "<image> What this is? Assistant: This is a photo of two cats lying on a pink blanket. The cats are facing the camera, and they appear to be sleeping or resting. The blanket is placed on a couch, and the cats are positioned in such a way that they are facing the camera. The image captures a peaceful moment between the two cats, and it's a great way to showcase their cuteness and relaxed demeanor.",
|
||||
},
|
||||
{
|
||||
"input_text": "<image> What this is? Assistant: This is",
|
||||
"generated_text": "<image> What this is? Assistant: This is a photo of two cats lying on a pink blanket. The cats are facing the camera, and they",
|
||||
"generated_text": "<image> What this is? Assistant: This is a photo of two cats lying on a pink blanket. The cats are facing the camera, and they appear to be sleeping or resting. The blanket is placed on a couch, and the cats are positioned in such a way that they are facing the camera. The image captures a peaceful moment between the two cats, and it's a great way to showcase their cuteness and relaxed demeanor.",
|
||||
},
|
||||
],
|
||||
)
|
||||
|
Reference in New Issue
Block a user