Compare commits

...

4 Commits

2 changed files with 13 additions and 1 deletions

View File

@ -34,7 +34,13 @@ from ...modeling_tf_utils import (
shape_list,
unpack_inputs,
)
from ...utils import ModelOutput, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
from ...utils import (
DUMMY_INPUTS,
ModelOutput,
add_start_docstrings_to_model_forward,
logging,
replace_return_docstrings,
)
from .configuration_rag import RagConfig
from .retrieval_rag import RagRetriever
@ -228,6 +234,10 @@ class TFRagPreTrainedModel(TFPreTrainedModel):
base_model_prefix = "rag"
_keys_to_ignore_on_load_missing = [r"position_ids"]
@property
def dummy_inputs(self):
return {"input_ids": tf.constant(DUMMY_INPUTS)}
@classmethod
def from_pretrained_question_encoder_generator(
cls,

View File

@ -1,5 +1,6 @@
from __future__ import annotations
import gc
import json
import os
import shutil
@ -904,6 +905,7 @@ class TFRagModelIntegrationTests(unittest.TestCase):
@slow
def test_rag_sequence_generate_batch_from_context_input_ids(self):
gc.collect()
tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
retriever = RagRetriever.from_pretrained(
"facebook/rag-sequence-nq", index_name="exact", use_dummy_dataset=True