Compare commits

...

3 Commits

SHA1 Message Date
f718f067c8 make style 2022-09-27 19:30:08 +02:00
b7e1eef25d add sep 2022-09-27 19:30:08 +02:00
407ba3540f First draft 2022-09-27 19:30:07 +02:00
2 changed files with 67 additions and 0 deletions

@@ -1043,6 +1043,72 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
        else:
            raise NotImplementedError

    def warn_if_pad_token_in_input_ids_no_attention_mask(self, input_ids, attention_mask):
        if attention_mask is not None:
            # if the attention_mask is defined, it's all good
            return

        if not hasattr(self, "warnings_issued"):
            self.warnings_issued = {}

        if self.warnings_issued.get("pad_token_in_input_ids", False):
            # if the warning has already been thrown, don't throw it again
            return

        is_pad_token_in_input_ids = self.config.pad_token_id is not None and self.config.pad_token_id in input_ids

        if is_pad_token_in_input_ids:
            # things become tricky if <pad> is equal to either BOS, EOS, or SEP:
            # in that case we cannot reasonably know whether the user should pass an attention_mask or not,
            # because with left padding both <pad> and BOS could sit at the beginning,
            # and with right padding both <pad> and EOS could sit at the end.
            # Both cases look the same in `input_ids`, but one is correct behavior and the other is not.
            # We should still throw a warning here, because most models don't have
            # <pad> == EOS, <pad> == BOS, or <pad> == SEP.
            is_pad_token_equal_to_bos_token = (
                self.config.bos_token_id is not None and self.config.bos_token_id == self.config.pad_token_id
            )
            is_pad_token_equal_to_eos_token = (
                self.config.eos_token_id is not None and self.config.eos_token_id == self.config.pad_token_id
            )
            is_pad_token_equal_to_sep_token = (
                self.config.sep_token_id is not None and self.config.sep_token_id == self.config.pad_token_id
            )

            warn_string = (
                f"The input IDs {input_ids} contain the `pad_token_id` {self.config.pad_token_id}, "
                "but NO `attention_mask` is passed."
            )

            if is_pad_token_equal_to_bos_token:
                warn_string += (
                    "\nWe strongly recommend passing an `attention_mask` to avoid possibly incorrectly computing the"
                    " attention weights. \nYou can ignore this warning if your `pad_token_id`"
                    f" {self.config.pad_token_id} is identical to your `bos_token_id` {self.config.bos_token_id} AND"
                    " your input is NOT padded."
                )

            if is_pad_token_equal_to_eos_token:
                warn_string += (
                    "\nWe strongly recommend passing an `attention_mask` to avoid possibly incorrectly computing the"
                    " attention weights. \nYou can ignore this warning if your `pad_token_id`"
                    f" {self.config.pad_token_id} is identical to your `eos_token_id` {self.config.eos_token_id} AND"
                    " your input is NOT padded."
                )

            if is_pad_token_equal_to_sep_token:
                warn_string += (
                    "\nWe strongly recommend passing an `attention_mask` to avoid possibly incorrectly computing the"
                    " attention weights. \nYou can ignore this warning if your `pad_token_id`"
                    f" {self.config.pad_token_id} is identical to your `sep_token_id` {self.config.sep_token_id} AND"
                    " your input is NOT padded."
                )

            if not (
                is_pad_token_equal_to_bos_token or is_pad_token_equal_to_eos_token or is_pad_token_equal_to_sep_token
            ):
                warn_string += (
                    "\nPadding the input IDs without passing an `attention_mask` leads to "
                    "unexpected, possibly incorrect outputs."
                )

            logger.warning(warn_string)
            self.warnings_issued["pad_token_in_input_ids"] = True

    def set_input_embeddings(self, value: nn.Module):
        """
        Set model's input embeddings.
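
The added check gates on three things: an explicit `attention_mask` short-circuits it, a previously issued warning short-circuits it, and otherwise it fires whenever `config.pad_token_id` appears anywhere in `input_ids`. Below is a minimal standalone sketch of that decision, not part of the patch: `would_warn` and the `SimpleNamespace` config are hypothetical stand-ins for illustration, and the one-shot `warnings_issued` de-duplication is omitted.

import torch
from types import SimpleNamespace

def would_warn(input_ids, attention_mask, config):
    # Mirrors the gating above: an explicit attention_mask silences the check ...
    if attention_mask is not None:
        return False
    # ... otherwise warn whenever the pad token shows up anywhere in input_ids.
    return config.pad_token_id is not None and config.pad_token_id in input_ids

config = SimpleNamespace(pad_token_id=0, bos_token_id=101, eos_token_id=102, sep_token_id=103)
padded = torch.tensor([[101, 7592, 2003, 102, 0, 0]])
mask = torch.tensor([[1, 1, 1, 1, 0, 0]])

print(would_warn(padded, None, config))         # True: pad tokens present, no mask
print(would_warn(padded, mask, config))         # False: mask provided
print(would_warn(padded[:, :4], None, config))  # False: no pad token in this slice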


@@ -960,6 +960,7 @@ class BertModel(BertPreTrainedModel):
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
            self.warn_if_pad_token_in_input_ids_no_attention_mask(input_ids, attention_mask)
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
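
With the hook wired into `BertModel.forward`, the warning should surface on an ordinary forward pass over padded inputs when no mask is passed. A hedged usage sketch against a checkout of this branch, assuming the standard `bert-base-uncased` checkpoint (whose `pad_token_id` is 0):

import torch
from transformers import BertModel, BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertModel.from_pretrained("bert-base-uncased")

# Two sequences of different lengths, padded to a common length.
batch = tokenizer(["a short input", "a noticeably longer second input"], padding=True, return_tensors="pt")

with torch.no_grad():
    # No attention_mask: pad_token_id (0) sits in input_ids, so the new warning is logged.
    model(input_ids=batch["input_ids"])
    # Passing the mask keeps the check silent.
    model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"])

The warning is emitted only on the first offending call, because the method records it in `self.warnings_issued`.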