Mirror of https://github.com/huggingface/transformers.git (synced 2025-10-20 17:13:56 +08:00)
🚨 [v5] generate delegates default cache initialization to the model (#41505)
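The headline change is behavioral: starting with v5, generate() no longer constructs the default cache itself and instead delegates that choice to the model. The snippet below is a minimal sketch, not taken from this commit, of what that means for callers: omitting past_key_values leaves the default cache to the model, while an explicit DynamicCache can still be passed when a specific cache type is needed. The gpt2 checkpoint, prompt, and token counts are illustrative assumptions.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache

model_name = "gpt2"  # hypothetical checkpoint, chosen only for illustration
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name).eval()

inputs = tokenizer("RoFormer uses rotary position embeddings", return_tensors="pt")

with torch.no_grad():
    # Default path: the model now decides which default cache to build.
    out_default = model.generate(**inputs, max_new_tokens=8)

    # Explicit path: callers who need a particular cache can still pass one.
    out_explicit = model.generate(**inputs, max_new_tokens=8, past_key_values=DynamicCache())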
@@ -13,6 +13,7 @@
 # limitations under the License.
 """Testing suite for the PyTorch RoFormer model."""
 
+import copy
 import unittest
 
 from transformers import RoFormerConfig, is_torch_available
@@ -209,6 +210,8 @@ class RoFormerModelTester:
         token_labels,
         choice_labels,
     ):
+        config = copy.deepcopy(config)
         config.is_decoder = True
         model = RoFormerForCausalLM(config=config).to(torch_device).eval()
+        torch.manual_seed(0)
         output_without_past_cache = model.generate(
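The visible test line, output_without_past_cache = model.generate(, belongs to a check that compares cached and uncached generation; the added torch.manual_seed(0) presumably pins sampling so the two runs are comparable, and copy.deepcopy(config) presumably guards a shared config fixture when flipping is_decoder. A minimal sketch of that comparison pattern, under the assumption that the test toggles use_cache (the checkpoint and generation arguments are illustrative, not the actual RoFormer test body):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustrative "with vs. without cache" generation check, not the real test.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").eval()
input_ids = tokenizer("RoFormer test prompt", return_tensors="pt").input_ids

torch.manual_seed(0)  # seed before each run so sampling is reproducible
output_without_past_cache = model.generate(
    input_ids, max_new_tokens=5, do_sample=True, use_cache=False
)

torch.manual_seed(0)
output_with_past_cache = model.generate(
    input_ids, max_new_tokens=5, do_sample=True, use_cache=True
)

# With deterministic kernels the sampled token ids should match exactly.
assert torch.equal(output_without_past_cache, output_with_past_cache)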