Add Whole Word Masking and Padding Strategy to DataCollatorForLanguageModeling (#39485)
* Add whole word masking
* Vectorize whole word masking functions
* Unit test whole word masking
* Remove support for TF in whole word masking
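A minimal usage sketch of the new option, assuming a fast tokenizer and features tokenized with return_offsets_mapping=True; the checkpoint name and example text are placeholders, not part of this change:

    from transformers import AutoTokenizer, DataCollatorForLanguageModeling

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased", use_fast=True)
    collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=True,
        whole_word_mask=True,  # new: all sub-word pieces of a sampled word are masked together
        mlm_probability=0.15,
    )

    # Offset mappings let the collator group sub-words back into words.
    texts = ["Whole word masking example", "Whole word masking example"]
    features = [tokenizer(text, return_offsets_mapping=True) for text in texts]
    batch = collator(features)  # contains "input_ids" and "labels"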
@@ -13,7 +13,6 @@
# limitations under the License.

import multiprocessing as mp
import random
import warnings
from collections.abc import Mapping
from dataclasses import dataclass
@@ -22,7 +21,6 @@ from typing import Any, Callable, NewType, Optional, Union

import numpy as np

from ..models.bert import BertTokenizer, BertTokenizerFast
from ..tokenization_utils_base import PreTrainedTokenizerBase
from ..utils import PaddingStrategy
@@ -630,6 +628,8 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
            Whether or not to use masked language modeling. If set to `False`, the labels are the same as the inputs
            with the padding tokens ignored (by setting them to -100). Otherwise, the labels are -100 for non-masked
            tokens and the value to predict for the masked token.
        whole_word_mask (`bool`, *optional*, defaults to `False`):
            Whether or not to mask whole words instead of individual tokens.
        mlm_probability (`float`, *optional*, defaults to 0.15):
            The probability with which to (randomly) mask tokens in the input, when `mlm` is set to `True`.
        mask_replace_prob (`float`, *optional*, defaults to 0.8):
@@ -681,6 +681,7 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):

    tokenizer: PreTrainedTokenizerBase
    mlm: bool = True
    whole_word_mask: bool = False
    mlm_probability: Optional[float] = 0.15
    mask_replace_prob: float = 0.8
    random_replace_prob: float = 0.1
@@ -698,6 +699,11 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
            if self.mlm_probability is None or self.mlm_probability < 0 or self.mlm_probability > 1:
                raise ValueError("mlm_probability should be between 0 and 1.")
            self.mlm_probability = float(self.mlm_probability)
        elif self.whole_word_mask:
            raise ValueError(
                "Whole word masking can only be used with mlm=True. "
                "If you want to use whole word masking, please set mlm=True."
            )
        if self.mask_replace_prob + self.random_replace_prob > 1:
            raise ValueError("The sum of mask_replace_prob and random_replace_prob should not exceed 1")
        if self.mask_replace_prob < 0 or self.mask_replace_prob > 1:
@@ -708,6 +714,21 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
        self.mask_replace_prob = float(self.mask_replace_prob)
        self.random_replace_prob = float(self.random_replace_prob)

        if self.whole_word_mask:
            if not self.tokenizer.is_fast:
                warnings.warn(
                    "Whole word masking depends on offset mapping which is only natively available with fast tokenizers.",
                    UserWarning,
                )

            if self.mask_replace_prob < 1:
                warnings.warn(
                    "Random token replacement is not supported with whole word masking. "
                    "Setting mask_replace_prob to 1."
                )
                self.mask_replace_prob = 1
                self.random_replace_prob = 0

        self.generator = None

    def get_generator(self, seed):
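Because whole word masking groups sub-words via the tokenizer's offset mapping, the features handed to the collator need to carry it. A rough sketch of tokenizing a dataset accordingly; the "text" column and the datasets.Dataset.map call are assumptions for illustration only:

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased", use_fast=True)

    def tokenize(batch):
        # return_offsets_mapping adds the character spans used to decide which tokens belong to the same word
        return tokenizer(batch["text"], truncation=True, return_offsets_mapping=True)

    # e.g. tokenized = dataset.map(tokenize, batched=True, remove_columns=["text"])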
@@ -762,9 +783,10 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):

        # If special token mask has been preprocessed, pop it from the dict.
        special_tokens_mask = batch.pop("special_tokens_mask", None)
        offset_mapping = batch.pop("offset_mapping", None)
        if self.mlm:
            batch["input_ids"], batch["labels"] = self.torch_mask_tokens(
                batch["input_ids"], special_tokens_mask=special_tokens_mask
                batch["input_ids"], special_tokens_mask=special_tokens_mask, offset_mapping=offset_mapping
            )
        else:
            labels = batch["input_ids"].clone()
@@ -773,9 +795,11 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
            batch["labels"] = labels
        return batch

    def torch_mask_tokens(self, inputs: Any, special_tokens_mask: Optional[Any] = None) -> tuple[Any, Any]:
    def torch_mask_tokens(
        self, inputs: Any, special_tokens_mask: Optional[Any] = None, offset_mapping: Optional[Any] = None
    ) -> tuple[Any, Any]:
        """
        Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
        Prepare masked tokens inputs/labels for masked language modeling.
        """
        import torch

@@ -786,12 +810,24 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
            special_tokens_mask = [
                self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
            ]
            special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool)
        else:
            special_tokens_mask = special_tokens_mask.bool()

        probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
        if self.whole_word_mask:
            word_ids, no_mask_mask = self._calc_word_ids_and_prob_mask(
                to_numpy(offset_mapping), to_numpy(special_tokens_mask)
            )
            no_mask_mask = torch.tensor(no_mask_mask, dtype=torch.bool)
        else:
            no_mask_mask = (
                special_tokens_mask.bool()
                if isinstance(special_tokens_mask, torch.Tensor)
                else torch.tensor(special_tokens_mask, dtype=torch.bool)
            )

        probability_matrix.masked_fill_(no_mask_mask, value=0.0)
        masked_indices = torch.bernoulli(probability_matrix, generator=self.generator).bool()
        if self.whole_word_mask:
            masked_indices = torch.BoolTensor(self._whole_word_mask(word_ids, masked_indices))

        labels[~masked_indices] = -100  # We only compute loss on masked tokens

        # mask_replace_prob% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
@@ -841,9 +877,10 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):

        # If special token mask has been preprocessed, pop it from the dict.
        special_tokens_mask = batch.pop("special_tokens_mask", None)
        offset_mapping = batch.pop("offset_mapping", None)
        if self.mlm:
            batch["input_ids"], batch["labels"] = self.numpy_mask_tokens(
                batch["input_ids"], special_tokens_mask=special_tokens_mask
                batch["input_ids"], special_tokens_mask=special_tokens_mask, offset_mapping=offset_mapping
            )
        else:
            labels = np.copy(batch["input_ids"])
@@ -852,9 +889,14 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
            batch["labels"] = labels
        return batch

    def numpy_mask_tokens(self, inputs: Any, special_tokens_mask: Optional[Any] = None) -> tuple[Any, Any]:
    def numpy_mask_tokens(
        self,
        inputs: Any,
        special_tokens_mask: Optional[Any] = None,
        offset_mapping: Optional[Any] = None,
    ) -> tuple[Any, Any]:
        """
        Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
        Prepare masked tokens inputs/labels for masked language modeling.
        """
        labels = np.copy(inputs)
        # We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`)
@@ -863,16 +905,28 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
            special_tokens_mask = [
                self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
            ]
            special_tokens_mask = np.array(special_tokens_mask, dtype=bool)
        else:
            special_tokens_mask = special_tokens_mask.astype(bool)

        probability_matrix[special_tokens_mask] = 0
        if self.whole_word_mask:
            word_ids, no_mask_mask = self._calc_word_ids_and_prob_mask(
                to_numpy(offset_mapping), to_numpy(special_tokens_mask)
            )
        else:
            no_mask_mask = (
                special_tokens_mask.astype(bool)
                if isinstance(special_tokens_mask, np.ndarray)
                else np.array(special_tokens_mask, dtype=bool)
            )

        probability_matrix[no_mask_mask] = 0
        # Numpy doesn't have bernoulli, so we use a binomial with 1 trial
        if self.generator:
            masked_indices = self.generator.binomial(1, probability_matrix, size=probability_matrix.shape).astype(bool)
        else:
            masked_indices = np.random.binomial(1, probability_matrix, size=probability_matrix.shape).astype(bool)

        if self.whole_word_mask:
            masked_indices = self._whole_word_mask(word_ids, masked_indices)

        labels[~masked_indices] = -100  # We only compute loss on masked tokens

        # mask_replace_prob% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
@@ -917,6 +971,51 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
        # The rest of the time (10% of the time) we keep the masked input tokens unchanged
        return inputs, labels

    @staticmethod
    def _calc_word_ids_and_prob_mask(
        offsets: np.ndarray[np.ndarray[tuple[int, int]]], special_tokens_mask: np.ndarray[np.ndarray[int]]
    ) -> tuple[np.ndarray[np.ndarray[int]], np.ndarray[np.ndarray[int]]]:
        """
        Map tokens to word ids and create a mask of tokens that should not be masked.
        Tokens that are part of the same word will have the same word id, and we only
        set a mask probability for the first token of each word.
        """

        token_starts = offsets[:, :, 0]
        token_ends = offsets[:, :, 1]

        prev_token_ends = np.roll(token_ends, 1, axis=1)
        prev_token_ends[:, 0] = -1  # First token has no previous token

        prev_token_special = np.roll(special_tokens_mask, 1, axis=1)
        prev_token_special[:, 0] = 0

        # Not a special token AND (gap from previous token or previous token was special)
        special_tokens_mask = special_tokens_mask.astype(bool)
        is_new_word = (~special_tokens_mask) & ((token_starts != prev_token_ends) | (prev_token_special == 1))

        word_ids = np.cumsum(is_new_word, axis=1)
        word_ids[special_tokens_mask] = -1

        prob_mask = ~is_new_word

        return word_ids, prob_mask

    @staticmethod
    def _whole_word_mask(word_ids: np.ndarray[np.ndarray[int]], mask: Any) -> Any:
        """
        Mask whole words based on word ids and mask.
        """
        mask = to_numpy(mask)

        valid_ids = word_ids != -1

        # Create a 3D mask where [batch, token_i, token_j] is True if token_i and token_j belong to the same word
        same_word = (word_ids[:, :, None] == word_ids[:, None, :]) & valid_ids[:, :, None] & valid_ids[:, None, :]

        # For each token, set True if any token in the same word is masked
        return np.any(same_word & mask[:, None, :], axis=2)

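A small worked example of what these two helpers return for a single sequence; the offsets below are invented for the tokens [CLS] un ##want ##ed runn ##ing [SEP], and the static methods are called directly the same way the new unit tests do:

    import numpy as np
    from transformers import DataCollatorForLanguageModeling

    # Hypothetical offsets and special-tokens mask for one sequence of 7 tokens.
    offsets = np.array([[(0, 0), (0, 2), (2, 6), (6, 8), (9, 13), (13, 16), (0, 0)]])
    special = np.array([[1, 0, 0, 0, 0, 0, 1]])

    word_ids, prob_mask = DataCollatorForLanguageModeling._calc_word_ids_and_prob_mask(offsets, special)
    # word_ids  -> [[-1, 1, 1, 1, 2, 2, -1]]   (sub-words share an id, special tokens get -1)
    # prob_mask -> [[True, False, True, True, False, True, True]]   (only the first sub-word of each word can be drawn)

    # If sampling happens to pick the first sub-word of word 1, the whole word ends up masked:
    drawn = np.array([[0, 1, 0, 0, 0, 0, 0]], dtype=bool)
    print(DataCollatorForLanguageModeling._whole_word_mask(word_ids, drawn))
    # [[False  True  True  True False False False]]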
@dataclass
class DataCollatorForWholeWordMask(DataCollatorForLanguageModeling):
@@ -925,261 +1024,20 @@ class DataCollatorForWholeWordMask(DataCollatorForLanguageModeling):

    - collates batches of tensors, honoring their tokenizer's pad_token
    - preprocesses batches for masked language modeling
    """

    <Tip>

    This collator relies on details of the implementation of subword tokenization by [`BertTokenizer`], specifically
    that subword tokens are prefixed with *##*. For tokenizers that do not adhere to this scheme, this collator will
    produce an output that is roughly equivalent to [`.DataCollatorForLanguageModeling`].

    </Tip>"""

    def torch_call(self, examples: list[Union[list[int], Any, dict[str, Any]]]) -> dict[str, Any]:
        if self.seed and self.generator is None:
            # If we have a seed, we need to create a generator object. Subsequent calls to this function will use the same generator.
            # If no seed supplied, we will use the global RNG
            self.create_rng()

        if isinstance(examples[0], Mapping):
            input_ids = [e["input_ids"] for e in examples]
        else:
            input_ids = examples
            examples = [{"input_ids": e} for e in examples]

        batch_input = _torch_collate_batch(input_ids, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)

        mask_labels = []
        for e in examples:
            ref_tokens = []
            for id in tolist(e["input_ids"]):
                token = self.tokenizer._convert_id_to_token(id)
                ref_tokens.append(token)

            # For Chinese tokens, we need extra info to mark sub-words, e.g. [喜,欢] -> [喜,##欢]
            if "chinese_ref" in e:
                ref_pos = tolist(e["chinese_ref"])
                len_seq = len(e["input_ids"])
                for i in range(len_seq):
                    if i in ref_pos:
                        ref_tokens[i] = "##" + ref_tokens[i]
            mask_labels.append(self._whole_word_mask(ref_tokens))
        batch_mask = _torch_collate_batch(mask_labels, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
        inputs, labels = self.torch_mask_tokens(batch_input, batch_mask)
        return {"input_ids": inputs, "labels": labels}

    def numpy_call(self, examples: list[Union[list[int], Any, dict[str, Any]]]) -> dict[str, Any]:
        if self.seed and self.generator is None:
            # If we have a seed, we need to create a generator object. Subsequent calls to this function will use the same generator.
            # If no seed supplied, we will use the global RNG
            self.create_rng()

        if isinstance(examples[0], Mapping):
            input_ids = [e["input_ids"] for e in examples]
        else:
            input_ids = examples
            examples = [{"input_ids": e} for e in examples]

        batch_input = _numpy_collate_batch(input_ids, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)

        mask_labels = []
        for e in examples:
            ref_tokens = []
            for id in tolist(e["input_ids"]):
                token = self.tokenizer._convert_id_to_token(id)
                ref_tokens.append(token)

            # For Chinese tokens, we need extra info to mark sub-words, e.g. [喜,欢] -> [喜,##欢]
            if "chinese_ref" in e:
                ref_pos = tolist(e["chinese_ref"])
                len_seq = len(e["input_ids"])
                for i in range(len_seq):
                    if i in ref_pos:
                        ref_tokens[i] = "##" + ref_tokens[i]
            mask_labels.append(self._whole_word_mask(ref_tokens))
        batch_mask = _numpy_collate_batch(mask_labels, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
        inputs, labels = self.numpy_mask_tokens(batch_input, batch_mask)
        return {"input_ids": inputs, "labels": labels}

    def _shuffle(self, cand_indexes):
        # if no seed, just use random's shuffle
        if self.seed is None:
            random.shuffle(cand_indexes)
            return cand_indexes

        # if seed is provided, use the generator to shuffle
        if self.return_tensors == "pt":
            import torch

            indices = torch.randperm(len(cand_indexes), generator=self.generator)
            return [cand_indexes[i] for i in indices]

        elif self.return_tensors == "np":
            self.generator.shuffle(cand_indexes)
            return cand_indexes

    def _whole_word_mask(self, input_tokens: list[str], max_predictions=512):
        """
        Get 0/1 labels for masked tokens with whole word mask proxy
        """
        if not isinstance(self.tokenizer, (BertTokenizer, BertTokenizerFast)):
            warnings.warn(
                "DataCollatorForWholeWordMask is only suitable for BertTokenizer-like tokenizers. "
                "Please refer to the documentation for more information."
            )

        cand_indexes = []
        for i, token in enumerate(input_tokens):
            if token == "[CLS]" or token == "[SEP]":
                continue

            if len(cand_indexes) >= 1 and token.startswith("##"):
                cand_indexes[-1].append(i)
            else:
                cand_indexes.append([i])

        cand_indexes = self._shuffle(cand_indexes)
        num_to_predict = min(max_predictions, max(1, int(round(len(input_tokens) * self.mlm_probability))))
        masked_lms = []
        covered_indexes = set()
        for index_set in cand_indexes:
            if len(masked_lms) >= num_to_predict:
                break
            # If adding a whole-word mask would exceed the maximum number of
            # predictions, then just skip this candidate.
            if len(masked_lms) + len(index_set) > num_to_predict:
                continue
            for index in index_set:
                covered_indexes.add(index)
                masked_lms.append(index)

        if len(covered_indexes) != len(masked_lms):
            raise ValueError("Length of covered_indexes is not equal to length of masked_lms.")
        mask_labels = [1 if i in covered_indexes else 0 for i in range(len(input_tokens))]
        return mask_labels

    def torch_mask_tokens(self, inputs: Any, mask_labels: Any) -> tuple[Any, Any]:
        """
        Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. Set
        'mask_labels' means we use whole word mask (wwm); we directly mask idxs according to its ref.
        """
        import torch

        if self.tokenizer.mask_token is None:
            raise ValueError(
                "This tokenizer does not have a mask token which is necessary for masked language modeling. Remove the"
                " --mlm flag if you want to use this tokenizer."
            )
        labels = inputs.clone()
        # We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa)

        probability_matrix = mask_labels

        special_tokens_mask = [
            self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
        ]
        probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0)
        if self.tokenizer.pad_token is not None:
            padding_mask = labels.eq(self.tokenizer.pad_token_id)
            probability_matrix.masked_fill_(padding_mask, value=0.0)

        masked_indices = probability_matrix.bool()
        labels[~masked_indices] = -100  # We only compute loss on masked tokens

        # mask_replace_prob% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
        indices_replaced = (
            torch.bernoulli(torch.full(labels.shape, self.mask_replace_prob), generator=self.generator).bool()
            & masked_indices
    def __init__(self, *args, **kwargs):
        warnings.warn(
            "DataCollatorForWholeWordMask is deprecated and will be removed in a future version, you can now use "
            "DataCollatorForLanguageModeling with whole_word_mask=True instead.",
            FutureWarning,
        )
        inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)

        if self.mask_replace_prob == 1 or self.random_replace_prob == 0:
            return inputs, labels

        remaining_prob = 1 - self.mask_replace_prob
        # scaling the random_replace_prob to the remaining probability, for example if
        # mask_replace_prob = 0.8 and random_replace_prob = 0.1,
        # then random_replace_prob_scaled = 0.1 / 0.2 = 0.5
        random_replace_prob_scaled = self.random_replace_prob / remaining_prob

        # random_replace_prob% of the time, we replace masked input tokens with a random word
        indices_random = (
            torch.bernoulli(torch.full(labels.shape, random_replace_prob_scaled), generator=self.generator).bool()
            & masked_indices
            & ~indices_replaced
        )
        random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long, generator=self.generator)
        inputs[indices_random] = random_words[indices_random]

        # The rest of the time ((1-random_replace_prob-mask_replace_prob)% of the time) we keep the masked input tokens unchanged
        return inputs, labels

    def numpy_mask_tokens(self, inputs: Any, mask_labels: Any) -> tuple[Any, Any]:
        """
        Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. Set
        'mask_labels' means we use whole word mask (wwm); we directly mask idxs according to its ref.
        """
        if self.tokenizer.mask_token is None:
            raise ValueError(
                "This tokenizer does not have a mask token which is necessary for masked language modeling. Remove the"
                " --mlm flag if you want to use this tokenizer."
            )
        labels = np.copy(inputs)
        # We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa)

        masked_indices = mask_labels.astype(bool)

        special_tokens_mask = [
            self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
        ]
        masked_indices[np.array(special_tokens_mask, dtype=bool)] = 0
        if self.tokenizer.pad_token is not None:
            padding_mask = labels == self.tokenizer.pad_token_id
            masked_indices[padding_mask] = 0

        labels[~masked_indices] = -100  # We only compute loss on masked tokens

        # mask_replace_prob% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
        if self.generator:
            indices_replaced = (
                self.generator.binomial(1, self.mask_replace_prob, size=labels.shape).astype(bool) & masked_indices
            )
        else:
            indices_replaced = (
                np.random.binomial(1, self.mask_replace_prob, size=labels.shape).astype(bool) & masked_indices
            )
        inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)

        if self.mask_replace_prob == 1 or self.random_replace_prob == 0:
            return inputs, labels

        remaining_prob = 1 - self.mask_replace_prob
        # scaling the random_replace_prob to the remaining probability, for example if
        # mask_replace_prob = 0.8 and random_replace_prob = 0.1,
        # then random_replace_prob_scaled = 0.1 / 0.2 = 0.5
        random_replace_prob_scaled = self.random_replace_prob / remaining_prob

        if self.generator:
            indices_random = (
                self.generator.binomial(1, random_replace_prob_scaled, size=labels.shape).astype(bool)
                & masked_indices
                & ~indices_replaced
            )
            random_words = self.generator.integers(low=0, high=len(self.tokenizer), size=labels.shape, dtype=np.int64)
        else:
            indices_random = (
                np.random.binomial(1, random_replace_prob_scaled, size=labels.shape).astype(bool)
                & masked_indices
                & ~indices_replaced
            )
            random_words = np.random.randint(low=0, high=len(self.tokenizer), size=labels.shape, dtype=np.int64)

        inputs[indices_random] = random_words[indices_random]

        # The rest of the time ((1-mask_replace_prob-random_replace_prob)% of the time) we keep the masked input tokens unchanged
        return inputs, labels
        super().__init__(*args, **kwargs)
        self.mlm = True  # Force masked language modeling
        self.whole_word_mask = True  # Force whole word masking

def tolist(x):
def tolist(x) -> list[Any]:
    if isinstance(x, list):
        return x
    elif hasattr(x, "numpy"):
@@ -1187,6 +1045,15 @@ def tolist(x):
        return x.tolist()


def to_numpy(x) -> np.ndarray[Any]:
    if isinstance(x, np.ndarray):
        return x
    elif hasattr(x, "detach"):
        return x.detach().cpu().numpy()
    else:
        return np.array(x)


@dataclass
class DataCollatorForSOP(DataCollatorForLanguageModeling):
    """
@@ -21,6 +21,7 @@ import numpy as np

from transformers import (
    BertTokenizer,
    BertTokenizerFast,
    DataCollatorForLanguageModeling,
    DataCollatorForPermutationLanguageModeling,
    DataCollatorForSeq2Seq,
@@ -525,99 +526,120 @@ class DataCollatorIntegrationTest(unittest.TestCase):
        self.assertFalse(torch.all(batch_3_labels == batch_5_labels))

    def test_data_collator_for_whole_word_mask(self):
        tokenizer = BertTokenizer(self.vocab_file)
        tokenizer = BertTokenizerFast(self.vocab_file)

        input_tokens = [f"token_{i}" for i in range(8)]
        tokenizer.add_tokens(input_tokens)
        features = [tokenizer(" ".join(input_tokens), return_offsets_mapping=True) for _ in range(2)]

        data_collator = DataCollatorForWholeWordMask(tokenizer, return_tensors="pt")

        features = [{"input_ids": list(range(10))}, {"input_ids": list(range(10))}]
        batch = data_collator(features)
        self.assertEqual(batch["input_ids"].shape, torch.Size((2, 10)))
        self.assertEqual(batch["labels"].shape, torch.Size((2, 10)))
        self.assertEqual(batch["input_ids"].shape, (2, 10))
        self.assertEqual(batch["labels"].shape, (2, 10))

        # Features can already be tensors
        features = [{"input_ids": np.arange(10)}, {"input_ids": np.arange(10)}]
        features = [
            tokenizer(" ".join(input_tokens), return_offsets_mapping=True).convert_to_tensors("np") for _ in range(2)
        ]
        batch = data_collator(features)
        self.assertEqual(batch["input_ids"].shape, torch.Size((2, 10)))
        self.assertEqual(batch["labels"].shape, torch.Size((2, 10)))
        self.assertEqual(batch["input_ids"].shape, (2, 10))
        self.assertEqual(batch["labels"].shape, (2, 10))

        if is_torch_available():
            # Features can already be tensors
            features = [
                tokenizer(" ".join(input_tokens), return_offsets_mapping=True).convert_to_tensors("pt")
                for _ in range(2)
            ]
            data_collator = DataCollatorForWholeWordMask(tokenizer, return_tensors="pt")
            batch = data_collator(features)
            self.assertEqual(batch["input_ids"].shape, torch.Size((2, 10)))
            self.assertEqual(batch["labels"].shape, torch.Size((2, 10)))

    def test_data_collator_for_whole_word_mask_with_seed(self):
        tokenizer = BertTokenizer(self.vocab_file)
        features = [{"input_ids": list(range(1000))}, {"input_ids": list(range(1000))}]
        tokenizer = BertTokenizerFast(self.vocab_file)

        input_tokens = [f"token_{i}" for i in range(998)]
        tokenizer.add_tokens(input_tokens)
        features = [tokenizer(" ".join(input_tokens), return_offsets_mapping=True) for _ in range(2)]

        # check if seed is respected between two different DataCollatorForWholeWordMask instances
        data_collator = DataCollatorForWholeWordMask(tokenizer, seed=42)
        data_collator = DataCollatorForWholeWordMask(tokenizer, seed=42, return_tensors="np")
        batch_1 = data_collator(features)
        self.assertEqual(batch_1["input_ids"].shape, torch.Size((2, 1000)))
        self.assertEqual(batch_1["labels"].shape, torch.Size((2, 1000)))
        self.assertEqual(batch_1["input_ids"].shape, (2, 1000))
        self.assertEqual(batch_1["labels"].shape, (2, 1000))

        data_collator = DataCollatorForWholeWordMask(tokenizer, seed=42)
        data_collator = DataCollatorForWholeWordMask(tokenizer, seed=42, return_tensors="np")
        batch_2 = data_collator(features)
        self.assertEqual(batch_2["input_ids"].shape, torch.Size((2, 1000)))
        self.assertEqual(batch_2["labels"].shape, torch.Size((2, 1000)))
        self.assertEqual(batch_2["input_ids"].shape, (2, 1000))
        self.assertEqual(batch_2["labels"].shape, (2, 1000))

        self.assertTrue(torch.all(batch_1["input_ids"] == batch_2["input_ids"]))
        self.assertTrue(torch.all(batch_1["labels"] == batch_2["labels"]))
        self.assertTrue(np.all(batch_1["input_ids"] == batch_2["input_ids"]))
        self.assertTrue(np.all(batch_1["labels"] == batch_2["labels"]))

        # check if seed is respected in multiple workers situation
        features = [{"input_ids": list(range(1000))} for _ in range(10)]
        dataloader = torch.utils.data.DataLoader(
            features,
            batch_size=2,
            num_workers=2,
            generator=torch.Generator().manual_seed(42),
            collate_fn=DataCollatorForWholeWordMask(tokenizer, seed=42),
        )
        if is_torch_available():
            features = [tokenizer(" ".join(input_tokens), return_offsets_mapping=True) for _ in range(10)]
            dataloader = torch.utils.data.DataLoader(
                features,
                batch_size=2,
                num_workers=2,
                generator=torch.Generator().manual_seed(42),
                collate_fn=DataCollatorForWholeWordMask(tokenizer, seed=42),
            )

        batch_3_input_ids = []
        batch_3_labels = []
        for batch in dataloader:
            batch_3_input_ids.append(batch["input_ids"])
            batch_3_labels.append(batch["labels"])
            batch_3_input_ids = []
            batch_3_labels = []
            for batch in dataloader:
                batch_3_input_ids.append(batch["input_ids"])
                batch_3_labels.append(batch["labels"])

        batch_3_input_ids = torch.stack(batch_3_input_ids)
        batch_3_labels = torch.stack(batch_3_labels)
        self.assertEqual(batch_3_input_ids.shape, torch.Size((5, 2, 1000)))
        self.assertEqual(batch_3_labels.shape, torch.Size((5, 2, 1000)))
            batch_3_input_ids = torch.stack(batch_3_input_ids)
            batch_3_labels = torch.stack(batch_3_labels)
            self.assertEqual(batch_3_input_ids.shape, torch.Size((5, 2, 1000)))
            self.assertEqual(batch_3_labels.shape, torch.Size((5, 2, 1000)))

        dataloader = torch.utils.data.DataLoader(
            features,
            batch_size=2,
            num_workers=2,
            collate_fn=DataCollatorForWholeWordMask(tokenizer, seed=42),
        )
            dataloader = torch.utils.data.DataLoader(
                features,
                batch_size=2,
                num_workers=2,
                collate_fn=DataCollatorForWholeWordMask(tokenizer, seed=42),
            )

        batch_4_input_ids = []
        batch_4_labels = []
        for batch in dataloader:
            batch_4_input_ids.append(batch["input_ids"])
            batch_4_labels.append(batch["labels"])
        batch_4_input_ids = torch.stack(batch_4_input_ids)
        batch_4_labels = torch.stack(batch_4_labels)
        self.assertEqual(batch_4_input_ids.shape, torch.Size((5, 2, 1000)))
        self.assertEqual(batch_4_labels.shape, torch.Size((5, 2, 1000)))
            batch_4_input_ids = []
            batch_4_labels = []
            for batch in dataloader:
                batch_4_input_ids.append(batch["input_ids"])
                batch_4_labels.append(batch["labels"])
            batch_4_input_ids = torch.stack(batch_4_input_ids)
            batch_4_labels = torch.stack(batch_4_labels)
            self.assertEqual(batch_4_input_ids.shape, torch.Size((5, 2, 1000)))
            self.assertEqual(batch_4_labels.shape, torch.Size((5, 2, 1000)))

        self.assertTrue(torch.all(batch_3_input_ids == batch_4_input_ids))
        self.assertTrue(torch.all(batch_3_labels == batch_4_labels))
            self.assertTrue(torch.all(batch_3_input_ids == batch_4_input_ids))
            self.assertTrue(torch.all(batch_3_labels == batch_4_labels))

        # try with different seed
        dataloader = torch.utils.data.DataLoader(
            features,
            batch_size=2,
            num_workers=2,
            collate_fn=DataCollatorForWholeWordMask(tokenizer, seed=43),
        )
            # try with different seed
            dataloader = torch.utils.data.DataLoader(
                features,
                batch_size=2,
                num_workers=2,
                collate_fn=DataCollatorForWholeWordMask(tokenizer, seed=43),
            )

        batch_5_input_ids = []
        batch_5_labels = []
        for batch in dataloader:
            batch_5_input_ids.append(batch["input_ids"])
            batch_5_labels.append(batch["labels"])
        batch_5_input_ids = torch.stack(batch_5_input_ids)
        batch_5_labels = torch.stack(batch_5_labels)
        self.assertEqual(batch_5_input_ids.shape, torch.Size((5, 2, 1000)))
        self.assertEqual(batch_5_labels.shape, torch.Size((5, 2, 1000)))
            batch_5_input_ids = []
            batch_5_labels = []
            for batch in dataloader:
                batch_5_input_ids.append(batch["input_ids"])
                batch_5_labels.append(batch["labels"])
            batch_5_input_ids = torch.stack(batch_5_input_ids)
            batch_5_labels = torch.stack(batch_5_labels)
            self.assertEqual(batch_5_input_ids.shape, torch.Size((5, 2, 1000)))
            self.assertEqual(batch_5_labels.shape, torch.Size((5, 2, 1000)))

        self.assertFalse(torch.all(batch_3_input_ids == batch_5_input_ids))
        self.assertFalse(torch.all(batch_3_labels == batch_5_labels))
            self.assertFalse(torch.all(batch_3_input_ids == batch_5_input_ids))
            self.assertFalse(torch.all(batch_3_labels == batch_5_labels))

    def test_plm(self):
        tokenizer = BertTokenizer(self.vocab_file)
@@ -929,24 +951,23 @@ class DataCollatorImmutabilityTest(unittest.TestCase):
        )

    def test_whole_world_masking_collator_immutability(self):
        tokenizer = BertTokenizer(self.vocab_file)
        tokenizer = BertTokenizerFast(self.vocab_file)

        features_base = [
            {"input_ids": list(range(10)), "labels": (1,)},
            {"input_ids": list(range(10)), "labels": (1,)},
        ]
        whole_word_masking_collator = DataCollatorForWholeWordMask(tokenizer, return_tensors="pt")
        input_tokens = [f"token_{i}" for i in range(8)]
        tokenizer.add_tokens(input_tokens)
        original_data = [tokenizer(" ".join(input_tokens), return_offsets_mapping=True) for _ in range(2)]
        for feature in original_data:
            feature["labels"] = (1,)

        for datatype_input, datatype_label in [(list, list), (np.array, np.array)]:
            self._validate_original_data_against_collated_data_on_specified_keys_and_datatypes(
                collator=whole_word_masking_collator,
                base_data=features_base,
                input_key="input_ids",
                input_datatype=datatype_input,
                label_key="labels",
                label_datatype=datatype_label,
                ignore_label=True,
            )
        batch_data = [tokenizer(" ".join(input_tokens), return_offsets_mapping=True) for _ in range(2)]
        for feature in batch_data:
            feature["labels"] = (1,)

        whole_word_masking_collator = DataCollatorForWholeWordMask(tokenizer)

        self._validate_original_data_against_collated_data(
            collator=whole_word_masking_collator, original_data=original_data, batch_data=batch_data
        )

    def test_permutation_language_modelling_collator_immutability(self):
        tokenizer = BertTokenizer(self.vocab_file)
@@ -1400,23 +1421,31 @@ class NumpyDataCollatorIntegrationTest(unittest.TestCase):
        self.assertFalse(np.all(batch_1["labels"] == batch_3["labels"]))

    def test_data_collator_for_whole_word_mask(self):
        tokenizer = BertTokenizer(self.vocab_file)
        tokenizer = BertTokenizerFast(self.vocab_file)
        data_collator = DataCollatorForWholeWordMask(tokenizer, return_tensors="np")

        features = [{"input_ids": list(range(10))}, {"input_ids": list(range(10))}]
        input_tokens = [f"token_{i}" for i in range(8)]
        tokenizer.add_tokens(input_tokens)
        features = [tokenizer(" ".join(input_tokens), return_offsets_mapping=True) for _ in range(2)]

        batch = data_collator(features)
        self.assertEqual(batch["input_ids"].shape, (2, 10))
        self.assertEqual(batch["labels"].shape, (2, 10))

        # Features can already be tensors
        features = [{"input_ids": np.arange(10)}, {"input_ids": np.arange(10)}]
        features = [
            tokenizer(" ".join(input_tokens), return_offsets_mapping=True).convert_to_tensors("np") for _ in range(2)
        ]
        batch = data_collator(features)
        self.assertEqual(batch["input_ids"].shape, (2, 10))
        self.assertEqual(batch["labels"].shape, (2, 10))

    def test_data_collator_for_whole_word_mask_with_seed(self):
        tokenizer = BertTokenizer(self.vocab_file)
        features = [{"input_ids": list(range(1000))}, {"input_ids": list(range(1000))}]
        tokenizer = BertTokenizerFast(self.vocab_file)

        input_tokens = [f"token_{i}" for i in range(998)]
        tokenizer.add_tokens(input_tokens)
        features = [tokenizer(" ".join(input_tokens), return_offsets_mapping=True) for _ in range(2)]

        # check if seed is respected between two different DataCollatorForWholeWordMask instances
        data_collator = DataCollatorForWholeWordMask(tokenizer, seed=42, return_tensors="np")
@@ -1755,24 +1784,23 @@ class NumpyDataCollatorImmutabilityTest(unittest.TestCase):
        )

    def test_whole_world_masking_collator_immutability(self):
        tokenizer = BertTokenizer(self.vocab_file)
        tokenizer = BertTokenizerFast(self.vocab_file)

        input_tokens = [f"token_{i}" for i in range(8)]
        tokenizer.add_tokens(input_tokens)
        original_data = [tokenizer(" ".join(input_tokens), return_offsets_mapping=True) for _ in range(2)]
        for feature in original_data:
            feature["labels"] = (1,)

        batch_data = [tokenizer(" ".join(input_tokens), return_offsets_mapping=True) for _ in range(2)]
        for feature in batch_data:
            feature["labels"] = (1,)

        features_base = [
            {"input_ids": list(range(10)), "labels": (1,)},
            {"input_ids": list(range(10)), "labels": (1,)},
        ]
        whole_word_masking_collator = DataCollatorForWholeWordMask(tokenizer, return_tensors="np")

        for datatype_input, datatype_label in [(list, list), (np.array, np.array)]:
            self._validate_original_data_against_collated_data_on_specified_keys_and_datatypes(
                collator=whole_word_masking_collator,
                base_data=features_base,
                input_key="input_ids",
                input_datatype=datatype_input,
                label_key="labels",
                label_datatype=datatype_label,
                ignore_label=True,
            )
        self._validate_original_data_against_collated_data(
            collator=whole_word_masking_collator, original_data=original_data, batch_data=batch_data
        )

    def test_permutation_language_modelling_collator_immutability(self):
        tokenizer = BertTokenizer(self.vocab_file)
@@ -1842,3 +1870,98 @@ class NumpyDataCollatorImmutabilityTest(unittest.TestCase):
        self._validate_original_data_against_collated_data(
            collator=sop_collator, original_data=features_original, batch_data=features_batch
        )


class DataCollatorForLanguageModelingUnitTest(unittest.TestCase):
    def test__calc_word_ids_and_prob_mask(self):
        offsets = np.array(
            [
                [(0, 0), (0, 3), (3, 4), (5, 6), (6, 7), (8, 9)],
                [(0, 0), (0, 3), (3, 4), (5, 6), (6, 7), (0, 0)],
                [(0, 0), (0, 3), (3, 4), (0, 0), (6, 7), (0, 0)],
                [(1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 7)],
                [(1, 1), (2, 2), (3, 4), (5, 6), (7, 8), (9, 10)],
                [(0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0)],
            ]
        )

        special_tokens_mask = np.array(
            [
                [1, 0, 0, 0, 0, 0],
                [1, 0, 0, 0, 0, 1],
                [1, 0, 0, 1, 0, 1],
                [0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0],
                [1, 1, 1, 1, 1, 1],
            ]
        )

        output_word_ids, output_prob_mask = DataCollatorForLanguageModeling._calc_word_ids_and_prob_mask(
            offsets, special_tokens_mask
        )

        expected_word_ids = np.array(
            [
                [-1, 1, 1, 2, 2, 3],
                [-1, 1, 1, 2, 2, -1],
                [-1, 1, 1, -1, 2, -1],
                [1, 1, 1, 1, 1, 1],
                [1, 2, 3, 4, 5, 6],
                [-1, -1, -1, -1, -1, -1],
            ]
        )

        expected_prob_mask = np.array(
            [
                [1, 0, 1, 0, 1, 0],
                [1, 0, 1, 0, 1, 1],
                [1, 0, 1, 1, 0, 1],
                [0, 1, 1, 1, 1, 1],
                [0, 0, 0, 0, 0, 0],
                [1, 1, 1, 1, 1, 1],
            ]
        )

        np.testing.assert_array_equal(output_word_ids, expected_word_ids)
        np.testing.assert_array_equal(output_prob_mask, expected_prob_mask)

    def test__whole_word_mask(self):
        word_ids = np.array(
            [
                [-1, 1, 1, 2, 2, 3],
                [-1, 1, 1, 2, 2, -1],
                [-1, 1, 1, -1, 2, -1],
                [1, 1, 1, 1, 1, 1],
                [1, 2, 3, 4, 5, 6],
                [1, 2, 3, 4, 5, 6],
                [-1, -1, -1, -1, -1, -1],
            ]
        )

        mask = np.array(
            [
                [0, 1, 0, 0, 0, 0],
                [0, 1, 0, 1, 0, 0],
                [0, 0, 0, 0, 1, 0],
                [1, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0],
                [0, 1, 0, 1, 0, 1],
                [0, 0, 0, 0, 0, 0],
            ]
        ).astype(bool)

        output_mask = DataCollatorForLanguageModeling._whole_word_mask(word_ids, mask)

        expected_mask = np.array(
            [
                [0, 1, 1, 0, 0, 0],
                [0, 1, 1, 1, 1, 0],
                [0, 0, 0, 0, 1, 0],
                [1, 1, 1, 1, 1, 1],
                [0, 0, 0, 0, 0, 0],
                [0, 1, 0, 1, 0, 1],
                [0, 0, 0, 0, 0, 0],
            ]
        ).astype(bool)

        np.testing.assert_array_equal(output_mask, expected_mask)