Mirror of https://github.com/huggingface/transformers.git, synced 2025-10-20 17:13:56 +08:00
Fix EncoderDecoder cache (#41612)
* Fix EncoderDecoder cache
* Add the option for the ddp data tuples to have 2 elems
* Modify the order of the KV and sliding
* Adapted RAG and Whisper to the new EncoderDecoderCache
* A single comma
* Remove kwargs in map
* Fixed order in manual injection cache test
* Slight changes to support the legacy format
* Removed Nones
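For context, a minimal sketch of how the two cache objects touched by this commit fit together, assuming EncoderDecoderCache still takes a (self_attention_cache, cross_attention_cache) pair and DynamicCache.update(key, value, layer_idx) keeps its current signature; nothing below is taken from the diff itself:

    import torch
    from transformers import DynamicCache, EncoderDecoderCache

    # One cache for decoder self-attention, one for cross-attention over encoder states.
    self_attention_cache = DynamicCache()
    cross_attention_cache = DynamicCache()
    cache = EncoderDecoderCache(self_attention_cache, cross_attention_cache)

    # Per-layer key/value states are written through the usual update() hook.
    key = torch.randn(1, 2, 4, 8)    # (batch, heads, seq_len, head_dim)
    value = torch.randn(1, 2, 4, 8)
    cache.self_attention_cache.update(key, value, layer_idx=0)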
@@ -1807,8 +1807,8 @@ class ModelUtilsTest(TestCasePlus):
         # simulate injecting virtual tokens like in prefix tuning
         num_virtual_tokens = 3
         past_key_values = [
-            (None, torch.randn(1, 2, num_virtual_tokens, 8), torch.randn(1, 2, num_virtual_tokens, 8)),
-            (None, torch.randn(1, 2, num_virtual_tokens, 8), torch.randn(1, 2, num_virtual_tokens, 8)),
+            (torch.randn(1, 2, num_virtual_tokens, 8), torch.randn(1, 2, num_virtual_tokens, 8)),
+            (torch.randn(1, 2, num_virtual_tokens, 8), torch.randn(1, 2, num_virtual_tokens, 8)),
         ]
         past_key_values = DynamicCache(past_key_values)
         model_inputs["attention_mask"] = torch.cat(
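The hunk above switches the manually injected past_key_values to the 2-element (key, value) tuple format that this commit adds support for. As a standalone sketch of that injection pattern (the tensor shapes and tuple format follow the diff; the layer count and mask handling are illustrative and assume the virtual tokens are prepended):

    import torch
    from transformers import DynamicCache

    num_layers = 2
    num_virtual_tokens = 3
    # One (key, value) tuple per decoder layer; shape is (batch, heads, tokens, head_dim).
    past_key_values = DynamicCache(
        [
            (
                torch.randn(1, 2, num_virtual_tokens, 8),
                torch.randn(1, 2, num_virtual_tokens, 8),
            )
            for _ in range(num_layers)
        ]
    )

    # Widen the attention mask so the injected virtual tokens are attended to.
    attention_mask = torch.ones(1, 5, dtype=torch.long)  # mask for the real input tokens
    attention_mask = torch.cat(
        [torch.ones(1, num_virtual_tokens, dtype=torch.long), attention_mask], dim=-1
    )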