Fix EncoderDecoder cache (#41612)

* Fix EncoderDecoder cache

* Add the option for the ddp data tuples to have 2 elems

* Modify the order of the KV and sliding

* Adapted RAG and Whisper to new EncoderDecoderCache

* A single comma

* Remove kwargs in map

* Fixed order in manual injection cache test

* Slight changes to support legacy format

* Removed Nones
This commit is contained in:
Rémi Ouazan
2025-10-16 14:55:41 +02:00
committed by GitHub
parent 35dc8f0a2e
commit eef9fb2af3
4 changed files with 57 additions and 38 deletions

View File

@ -1807,8 +1807,8 @@ class ModelUtilsTest(TestCasePlus):
# simulate injecting virtual tokens like in prefix tuning
num_virtual_tokens = 3
past_key_values = [
(None, torch.randn(1, 2, num_virtual_tokens, 8), torch.randn(1, 2, num_virtual_tokens, 8)),
(None, torch.randn(1, 2, num_virtual_tokens, 8), torch.randn(1, 2, num_virtual_tokens, 8)),
(torch.randn(1, 2, num_virtual_tokens, 8), torch.randn(1, 2, num_virtual_tokens, 8)),
(torch.randn(1, 2, num_virtual_tokens, 8), torch.randn(1, 2, num_virtual_tokens, 8)),
]
past_key_values = DynamicCache(past_key_values)
model_inputs["attention_mask"] = torch.cat(