Compare commits


5 Commits

SHA1        Message    Date
3963a17762  nit        2024-07-31 09:25:18 +02:00
6c1cc67a38  fix?       2024-07-31 09:23:21 +02:00
60f1f426d6  fix test   2024-07-31 08:17:29 +02:00
606aa37a4a  add tests  2024-07-31 08:16:23 +02:00
8deb370946  fix        2024-07-31 08:11:30 +02:00
2 changed files with 11 additions and 1 deletion


@@ -1316,7 +1316,7 @@ class GemmaConvert(SpmConverter):
         return vocab
 
     def pre_tokenizer(self, replacement, add_prefix_space):
-        return pre_tokenizers.Split(" ", "merged_with_previous")
+        return pre_tokenizers.Split(Regex('(?<!▁)▁'), "merged_with_next")
 
     def unk_id(self, proto):
         unk_id = 3
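The new pre-tokenizer splits before every "▁" that is not itself preceded by a "▁" and merges the match with the piece that follows, so each word keeps its leading space run as a single pre-token. A minimal sketch of that behaviour with the standalone tokenizers library (the sample string and the expected output are illustrative, not taken from the PR):

from tokenizers import Regex, pre_tokenizers

# Same Split rule as in the converter: break before a "▁" that does not
# follow another "▁", and attach the match to the piece after it.
pre_tok = pre_tokenizers.Split(Regex("(?<!▁)▁"), "merged_with_next")

# "▁" stands for a space after the Metaspace-style replacement.
pieces = pre_tok.pre_tokenize_str("▁▁rust▁.▁▁x.")
print([piece for piece, _offsets in pieces])
# Expected (not verified here): ['▁▁rust', '▁.', '▁▁x.'] -- each word keeps its
# leading "▁" run, which is what yields the contiguous word_ids checked below.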


@@ -188,6 +188,16 @@ class GemmaIntegrationTest(unittest.TestCase):
             },
         )
 
+    @require_read_token
+    def test_word_ids(self):
+        fast_tokenizer = GemmaTokenizerFast.from_pretrained("google/gemma-2b-it", from_slow=True)
+        input_text = " rust . x. . . "
+        EXPECTED_WORD_IDS = [None, 0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5]
+        EXPECTED_INPUT_IDS = [2, 161, 14783, 160, 235265, 1141, 235265, 150, 235265, 140, 235265, 144]
+        output = fast_tokenizer(input_text)
+        self.assertEqual(output["input_ids"], EXPECTED_INPUT_IDS)
+        self.assertEqual(output.word_ids(), EXPECTED_WORD_IDS)
+
     def test_user_added_tokens(self):
         # Ensure that user added tokens are not split in the fast tokenizer
         slow_tokenizer = self.tokenizer
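For context, word_ids() on a fast tokenizer's output maps each token position back to the index of the word it came from, with None for special tokens such as BOS. A hypothetical usage sketch along the lines of the new test (the input string is arbitrary, and running it requires access to the gated google/gemma-2b-it checkpoint):

from transformers import GemmaTokenizerFast

# Gated repo: needs an authenticated read token (see @require_read_token).
tok = GemmaTokenizerFast.from_pretrained("google/gemma-2b-it", from_slow=True)

enc = tok("Hello world")
# Pair every produced token with the index of the word it belongs to;
# BOS maps to None, and with the fixed pre-tokenizer each word's leading
# spaces stay grouped with that word.
for token, word_id in zip(enc.tokens(), enc.word_ids()):
    print(f"{token!r} -> word {word_id}")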