Compare commits

...

4 Commits

Author SHA1 Message Date
28e278728d Release: 4.10.1 2021-09-10 16:12:44 +02:00
e5e0452c29 Fixing #13381 (#13400)
* Fixing #13381

* Enabling automatic LED models.
2021-09-10 16:11:33 +02:00
4afbd7ebf3 Fixing backward compatibility for non prefixed tokens (B-, I-). (#13493) 2021-09-10 16:11:25 +02:00
60eb416a13 [Wav2Vec2] Fix normalization for non-padded tensors (#13512)
* finalize

* Apply suggestions from code review

* finish cleaner implementation

* more tests

* small fix

* finish

* up
2021-09-10 16:11:13 +02:00
13 changed files with 248 additions and 77 deletions

View File

@ -27,7 +27,8 @@ author = "huggingface"
# The short X.Y version # The short X.Y version
version = "" version = ""
# The full version, including alpha/beta/rc tags # The full version, including alpha/beta/rc tags
release = "4.10.0" release = "4.10.1"

View File

@ -105,9 +105,8 @@ Supported models
3. :doc:`BARThez <model_doc/barthez>` (from École polytechnique) released with the paper `BARThez: a Skilled Pretrained 3. :doc:`BARThez <model_doc/barthez>` (from École polytechnique) released with the paper `BARThez: a Skilled Pretrained
French Sequence-to-Sequence Model <https://arxiv.org/abs/2010.12321>`__ by Moussa Kamal Eddine, Antoine J.-P. French Sequence-to-Sequence Model <https://arxiv.org/abs/2010.12321>`__ by Moussa Kamal Eddine, Antoine J.-P.
Tixier, Michalis Vazirgiannis. Tixier, Michalis Vazirgiannis.
4. `BEiT <https://huggingface.co/transformers/master/model_doc/beit.html>`__ (from Microsoft) released with the paper 4. :doc:`BEiT <model_doc/beit>` (from Microsoft) released with the paper `BEiT: BERT Pre-Training of Image Transformers
`BEiT: BERT Pre-Training of Image Transformers <https://arxiv.org/abs/2106.08254>`__ by Hangbo Bao, Li Dong, Furu <https://arxiv.org/abs/2106.08254>`__ by Hangbo Bao, Li Dong, Furu Wei.
Wei.
5. :doc:`BERT <model_doc/bert>` (from Google) released with the paper `BERT: Pre-training of Deep Bidirectional 5. :doc:`BERT <model_doc/bert>` (from Google) released with the paper `BERT: Pre-training of Deep Bidirectional
Transformers for Language Understanding <https://arxiv.org/abs/1810.04805>`__ by Jacob Devlin, Ming-Wei Chang, Transformers for Language Understanding <https://arxiv.org/abs/1810.04805>`__ by Jacob Devlin, Ming-Wei Chang,
Kenton Lee and Kristina Toutanova. Kenton Lee and Kristina Toutanova.
@ -264,9 +263,9 @@ Supported models
55. :doc:`SpeechToTextTransformer <model_doc/speech_to_text>` (from Facebook), released together with the paper 55. :doc:`SpeechToTextTransformer <model_doc/speech_to_text>` (from Facebook), released together with the paper
`fairseq S2T: Fast Speech-to-Text Modeling with fairseq <https://arxiv.org/abs/2010.05171>`__ by Changhan Wang, Yun `fairseq S2T: Fast Speech-to-Text Modeling with fairseq <https://arxiv.org/abs/2010.05171>`__ by Changhan Wang, Yun
Tang, Xutai Ma, Anne Wu, Dmytro Okhonko, Juan Pino. Tang, Xutai Ma, Anne Wu, Dmytro Okhonko, Juan Pino.
56. `Splinter <https://huggingface.co/transformers/master/model_doc/splinter.html>`__ (from Tel Aviv University), 56. :doc:`Splinter <model_doc/splinter>` (from Tel Aviv University), released together with the paper `Few-Shot
released together with the paper `Few-Shot Question Answering by Pretraining Span Selection Question Answering by Pretraining Span Selection <https://arxiv.org/abs/2101.00438>`__ by Ori Ram, Yuval Kirstain,
<https://arxiv.org/abs/2101.00438>`__ by Ori Ram, Yuval Kirstain, Jonathan Berant, Amir Globerson, Omer Levy. Jonathan Berant, Amir Globerson, Omer Levy.
57. :doc:`SqueezeBert <model_doc/squeezebert>` released with the paper `SqueezeBERT: What can computer vision teach NLP 57. :doc:`SqueezeBert <model_doc/squeezebert>` released with the paper `SqueezeBERT: What can computer vision teach NLP
about efficient neural networks? <https://arxiv.org/abs/2006.11316>`__ by Forrest N. Iandola, Albert E. Shaw, Ravi about efficient neural networks? <https://arxiv.org/abs/2006.11316>`__ by Forrest N. Iandola, Albert E. Shaw, Ravi
Krishna, and Kurt W. Keutzer. Krishna, and Kurt W. Keutzer.

View File

@ -342,7 +342,7 @@ install_requires = [
setup( setup(
name="transformers", name="transformers",
version="4.10.0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots) version="4.10.1", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
author="Thomas Wolf, Lysandre Debut, Victor Sanh, Julien Chaumond, Sam Shleifer, Patrick von Platen, Sylvain Gugger, Suraj Patil, Stas Bekman, Google AI Language Team Authors, Open AI team Authors, Facebook AI Authors, Carnegie Mellon University Authors", author="Thomas Wolf, Lysandre Debut, Victor Sanh, Julien Chaumond, Sam Shleifer, Patrick von Platen, Sylvain Gugger, Suraj Patil, Stas Bekman, Google AI Language Team Authors, Open AI team Authors, Facebook AI Authors, Carnegie Mellon University Authors",
author_email="thomas@huggingface.co", author_email="thomas@huggingface.co",
description="State-of-the-art Natural Language Processing for TensorFlow 2.0 and PyTorch", description="State-of-the-art Natural Language Processing for TensorFlow 2.0 and PyTorch",

View File

@ -22,7 +22,7 @@
# to defer the actual importing for when the objects are requested. This way `import transformers` provides the names # to defer the actual importing for when the objects are requested. This way `import transformers` provides the names
# in the namespace without actually importing anything (and especially none of the backends). # in the namespace without actually importing anything (and especially none of the backends).
__version__ = "4.10.0" __version__ = "4.10.1"
# Work around to update TensorFlow's absl.logging threshold which alters the # Work around to update TensorFlow's absl.logging threshold which alters the
# default Python logging output behavior when present. # default Python logging output behavior when present.

View File

@ -341,7 +341,7 @@ class SequenceFeatureExtractor(FeatureExtractionMixin):
return processed_features return processed_features
def _get_padding_strategies(self, padding=False, max_length=None, pad_to_multiple_of=None, **kwargs): def _get_padding_strategies(self, padding=False, max_length=None):
""" """
Find the correct padding strategy Find the correct padding strategy
""" """

View File

@ -93,10 +93,13 @@ class Speech2TextFeatureExtractor(SequenceFeatureExtractor):
@staticmethod @staticmethod
def utterance_cmvn( def utterance_cmvn(
x: np.ndarray, input_length: int, normalize_means: Optional[bool] = True, normalize_vars: Optional[bool] = True x: np.ndarray,
input_length: int,
normalize_means: Optional[bool] = True,
normalize_vars: Optional[bool] = True,
padding_value: float = 0.0,
) -> np.ndarray: ) -> np.ndarray:
# make sure we normalize float32 arrays # make sure we normalize float32 arrays
mean = x[:input_length].mean(axis=0) mean = x[:input_length].mean(axis=0)
square_sums = (x[:input_length] ** 2).sum(axis=0) square_sums = (x[:input_length] ** 2).sum(axis=0)
@ -107,15 +110,21 @@ class Speech2TextFeatureExtractor(SequenceFeatureExtractor):
std = np.sqrt(np.maximum(var, 1e-10)) std = np.sqrt(np.maximum(var, 1e-10))
x = np.divide(x, std) x = np.divide(x, std)
if x.shape[0] > input_length:
x[input_length:] = padding_value
# make sure array is in float32 # make sure array is in float32
x = x.astype(np.float32) x = x.astype(np.float32)
return x return x
def normalize(self, input_values: List[np.ndarray], input_lengths: List[int]) -> List[np.ndarray]: def normalize(
self, input_features: List[np.ndarray], attention_mask: Optional[np.ndarray] = None
) -> List[np.ndarray]:
lengths = attention_mask.sum(-1) if attention_mask is not None else [x.shape[0] for x in input_features]
return [ return [
self.utterance_cmvn(x, n, self.normalize_means, self.normalize_vars) self.utterance_cmvn(x, n, self.normalize_means, self.normalize_vars, self.padding_value)
for x, n in zip(input_values, input_lengths) for x, n in zip(input_features, lengths)
] ]
def __call__( def __call__(
@ -197,7 +206,6 @@ class Speech2TextFeatureExtractor(SequenceFeatureExtractor):
and (isinstance(raw_speech[0], np.ndarray) or isinstance(raw_speech[0], (tuple, list))) and (isinstance(raw_speech[0], np.ndarray) or isinstance(raw_speech[0], (tuple, list)))
) )
# make sure input is in list format
if is_batched and not isinstance(raw_speech[0], np.ndarray): if is_batched and not isinstance(raw_speech[0], np.ndarray):
raw_speech = [np.asarray(speech) for speech in raw_speech] raw_speech = [np.asarray(speech) for speech in raw_speech]
elif not is_batched and not isinstance(raw_speech, np.ndarray): elif not is_batched and not isinstance(raw_speech, np.ndarray):
@ -225,21 +233,25 @@ class Speech2TextFeatureExtractor(SequenceFeatureExtractor):
**kwargs, **kwargs,
) )
if "attention_mask" in padded_inputs: # make sure list is in array format
input_lengths = padded_inputs["attention_mask"].sum(-1) input_features = padded_inputs.get("input_features")
else: if isinstance(input_features[0], list):
padded_input_values = padded_inputs["input_features"] padded_inputs["input_features"] = [np.asarray(feature, dtype=np.float32) for feature in input_features]
input_lengths = [padded_input_values.shape[-1] for _ in range(padded_input_values.shape[0])]
attention_mask = padded_inputs.get("attention_mask")
if attention_mask is not None:
padded_inputs["attention_mask"] = [np.asarray(array, dtype=np.bool) for array in attention_mask]
# Utterance-level cepstral mean and variance normalization # Utterance-level cepstral mean and variance normalization
if self.do_ceptral_normalize: if self.do_ceptral_normalize:
input_features = padded_inputs["input_features"] attention_mask = (
np.array(attention_mask, dtype=np.bool)
# make sure list is in array format if self._get_padding_strategies(padding, max_length=max_length) is not PaddingStrategy.DO_NOT_PAD
if isinstance(input_features[0], list): else None
input_features = [np.asarray(feature, dtype=np.float32) for feature in input_features] )
padded_inputs["input_features"] = self.normalize(
padded_inputs["input_features"] = self.normalize(input_features, input_lengths=input_lengths) padded_inputs["input_features"], attention_mask=attention_mask
)
if return_tensors is not None: if return_tensors is not None:
padded_inputs = padded_inputs.convert_to_tensors(return_tensors) padded_inputs = padded_inputs.convert_to_tensors(return_tensors)

View File

@ -79,13 +79,25 @@ class Wav2Vec2FeatureExtractor(SequenceFeatureExtractor):
self.do_normalize = do_normalize self.do_normalize = do_normalize
@staticmethod @staticmethod
def zero_mean_unit_var_norm(input_values: List[np.ndarray], input_lengths: List[int]) -> List[np.ndarray]: def zero_mean_unit_var_norm(
input_values: List[np.ndarray], attention_mask: List[np.ndarray], padding_value: float = 0.0
) -> List[np.ndarray]:
""" """
Every array in the list is normalized to have zero mean and unit variance Every array in the list is normalized to have zero mean and unit variance
""" """
normed_input_values = [ if attention_mask is not None:
(x - np.mean(x[:i])) / np.sqrt(np.var(x[:i]) + 1e-5) for x, i in zip(input_values, input_lengths) attention_mask = np.array(attention_mask, np.bool)
] normed_input_values = []
for vector, length in zip(input_values, attention_mask.sum(-1)):
normed_slice = (vector - vector[:length].mean()) / np.sqrt(vector[:length].var() + 1e-7)
if length > normed_slice.shape[0]:
normed_slice[length:] = padding_value
normed_input_values.append(normed_slice)
else:
normed_input_values = [(x - x.mean()) / np.sqrt(x.var() + 1e-7) for x in input_values]
return normed_input_values return normed_input_values
def __call__( def __call__(
@ -172,14 +184,6 @@ class Wav2Vec2FeatureExtractor(SequenceFeatureExtractor):
and (isinstance(raw_speech[0], np.ndarray) or isinstance(raw_speech[0], (tuple, list))) and (isinstance(raw_speech[0], np.ndarray) or isinstance(raw_speech[0], (tuple, list)))
) )
# make sure input is in list format
if is_batched and not isinstance(raw_speech[0], np.ndarray):
raw_speech = [np.asarray(speech, dtype=np.float32) for speech in raw_speech]
elif not is_batched and not isinstance(raw_speech, np.ndarray):
raw_speech = np.asarray(raw_speech, dtype=np.float32)
elif isinstance(raw_speech, np.ndarray) and raw_speech.dtype is np.float64:
raw_speech = raw_speech.astype(np.float32)
# always return batch # always return batch
if not is_batched: if not is_batched:
raw_speech = [raw_speech] raw_speech = [raw_speech]
@ -196,19 +200,33 @@ class Wav2Vec2FeatureExtractor(SequenceFeatureExtractor):
return_attention_mask=return_attention_mask, return_attention_mask=return_attention_mask,
) )
if "attention_mask" in padded_inputs: # convert input values to correct format
input_lengths = padded_inputs["attention_mask"].sum(-1) input_values = padded_inputs["input_values"]
else: if not isinstance(input_values[0], np.ndarray):
padded_input_values = padded_inputs["input_values"] padded_inputs["input_values"] = [np.asarray(array, dtype=np.float32) for array in input_values]
input_lengths = [padded_input_values.shape[-1] for _ in range(padded_input_values.shape[0])] elif (
not isinstance(input_values, np.ndarray)
and isinstance(input_values[0], np.ndarray)
and input_values[0].dtype is np.float64
):
padded_inputs["input_values"] = [array.astype(np.float32) for array in input_values]
elif isinstance(input_values, np.ndarray) and input_values.dtype is np.float64:
padded_inputs["input_values"] = input_values.astype(np.float32)
if isinstance(padded_inputs["input_values"][0], np.ndarray): # convert attention_mask to correct format
padded_inputs["input_values"] = [x.astype(np.float32) for x in padded_inputs["input_values"]] attention_mask = padded_inputs.get("attention_mask")
if attention_mask is not None:
padded_inputs["attention_mask"] = [np.asarray(array, dtype=np.bool) for array in attention_mask]
# zero-mean and unit-variance normalization # zero-mean and unit-variance normalization
if self.do_normalize: if self.do_normalize:
attention_mask = (
attention_mask
if self._get_padding_strategies(padding, max_length=max_length) is not PaddingStrategy.DO_NOT_PAD
else None
)
padded_inputs["input_values"] = self.zero_mean_unit_var_norm( padded_inputs["input_values"] = self.zero_mean_unit_var_norm(
padded_inputs["input_values"], input_lengths=input_lengths padded_inputs["input_values"], attention_mask=attention_mask, padding_value=self.padding_value
) )
if return_tensors is not None: if return_tensors is not None:

View File

@ -411,7 +411,8 @@ class TokenClassificationPipeline(Pipeline):
tag = entity_name[2:] tag = entity_name[2:]
else: else:
# It's not in B-, I- format # It's not in B-, I- format
bi = "B" # Default to I- for continuation.
bi = "I"
tag = entity_name tag = entity_name
return bi, tag return bi, tag

View File

@ -88,7 +88,7 @@ class ZeroShotClassificationPipeline(Pipeline):
hypothesis_template, hypothesis_template,
padding=True, padding=True,
add_special_tokens=True, add_special_tokens=True,
truncation=TruncationStrategy.DO_NOT_TRUNCATE, truncation=TruncationStrategy.ONLY_FIRST,
**kwargs **kwargs
): ):
""" """
@ -113,13 +113,31 @@ class ZeroShotClassificationPipeline(Pipeline):
) )
inputs.append(model_input) inputs.append(model_input)
else: else:
inputs = self.tokenizer( try:
sequence_pairs, inputs = self.tokenizer(
add_special_tokens=add_special_tokens, sequence_pairs,
return_tensors=return_tensors, add_special_tokens=add_special_tokens,
padding=padding, return_tensors=return_tensors,
truncation=truncation, padding=padding,
) truncation=truncation,
)
except Exception as e:
if "too short" in str(e):
# tokenizers might yell that we want to truncate
# to a value that is not even reached by the input.
# In that case we don't want to truncate.
# It seems there's not a really better way to catch that
# exception.
inputs = self.tokenizer(
sequence_pairs,
add_special_tokens=add_special_tokens,
return_tensors=return_tensors,
padding=padding,
truncation=TruncationStrategy.DO_NOT_TRUNCATE,
)
else:
raise e
return inputs return inputs

View File

@ -136,18 +136,49 @@ class Speech2TextFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unitt
def test_cepstral_mean_and_variance_normalization(self): def test_cepstral_mean_and_variance_normalization(self):
feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict()) feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)] speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)]
inputs = feature_extractor(speech_inputs, padding=True, return_tensors="np", return_attention_mask=True)
input_features = inputs.input_features
attention_mask = inputs.attention_mask
fbank_feat_lengths = np.sum(attention_mask == 1, axis=1)
def _check_zero_mean_unit_variance(input_vector): paddings = ["longest", "max_length", "do_not_pad"]
self.assertTrue(np.all(np.mean(input_vector, axis=0) < 1e-3)) max_lengths = [None, 16, None]
self.assertTrue(np.all(np.abs(np.var(input_vector, axis=0) - 1) < 1e-3)) var_tolerances = [1e-3, 1e-3, 1e-1]
for max_length, padding, var_tol in zip(max_lengths, paddings, var_tolerances):
_check_zero_mean_unit_variance(input_features[0, : fbank_feat_lengths[0]]) inputs = feature_extractor(
_check_zero_mean_unit_variance(input_features[1, : fbank_feat_lengths[1]]) speech_inputs, padding=padding, max_length=max_length, return_attention_mask=True
_check_zero_mean_unit_variance(input_features[2, : fbank_feat_lengths[2]]) )
input_features = inputs.input_features
attention_mask = inputs.attention_mask
fbank_feat_lengths = [np.sum(x) for x in attention_mask]
def _check_zero_mean_unit_variance(input_vector, var_tol=1e-3):
self.assertTrue(np.all(np.mean(input_vector, axis=0) < 1e-3))
self.assertTrue(np.all(np.abs(np.var(input_vector, axis=0) - 1) < var_tol))
_check_zero_mean_unit_variance(input_features[0][: fbank_feat_lengths[0]], var_tol)
_check_zero_mean_unit_variance(input_features[1][: fbank_feat_lengths[1]], var_tol)
_check_zero_mean_unit_variance(input_features[2][: fbank_feat_lengths[2]], var_tol)
def test_cepstral_mean_and_variance_normalization_np(self):
feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)]
paddings = ["longest", "max_length", "do_not_pad"]
max_lengths = [None, 16, None]
var_tolerances = [1e-3, 1e-3, 1e-1]
for max_length, padding, var_tol in zip(max_lengths, paddings, var_tolerances):
inputs = feature_extractor(
speech_inputs, max_length=max_length, padding=padding, return_tensors="np", return_attention_mask=True
)
input_features = inputs.input_features
attention_mask = inputs.attention_mask
fbank_feat_lengths = [np.sum(x) for x in attention_mask]
def _check_zero_mean_unit_variance(input_vector, var_tol=1e-3):
self.assertTrue(np.all(np.mean(input_vector, axis=0) < 1e-3))
self.assertTrue(np.all(np.abs(np.var(input_vector, axis=0) - 1) < var_tol))
_check_zero_mean_unit_variance(input_features[0][: fbank_feat_lengths[0]], var_tol)
_check_zero_mean_unit_variance(input_features[1][: fbank_feat_lengths[1]], var_tol)
_check_zero_mean_unit_variance(input_features[2][: fbank_feat_lengths[2]], var_tol)
def test_cepstral_mean_and_variance_normalization_trunc(self): def test_cepstral_mean_and_variance_normalization_trunc(self):
feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict()) feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())

View File

@ -120,21 +120,45 @@ class Wav2Vec2FeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest
for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2): for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3)) self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))
def test_zero_mean_unit_variance_normalization(self): def test_zero_mean_unit_variance_normalization_np(self):
feat_extract = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict()) feat_extract = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)] speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)]
processed = feat_extract(speech_inputs, padding="longest", return_tensors="np")
input_values = processed.input_values
def _check_zero_mean_unit_variance(input_vector): paddings = ["longest", "max_length", "do_not_pad"]
self.assertTrue(np.abs(np.mean(input_vector)) < 1e-3) max_lengths = [None, 1600, None]
self.assertTrue(np.abs(np.var(input_vector) - 1) < 1e-3) for max_length, padding in zip(max_lengths, paddings):
processed = feat_extract(speech_inputs, padding=padding, max_length=max_length, return_tensors="np")
input_values = processed.input_values
_check_zero_mean_unit_variance(input_values[0, :800]) def _check_zero_mean_unit_variance(input_vector):
_check_zero_mean_unit_variance(input_values[1, :1000]) self.assertTrue(np.abs(np.mean(input_vector)) < 1e-3)
_check_zero_mean_unit_variance(input_values[2]) self.assertTrue(np.abs(np.var(input_vector) - 1) < 1e-3)
def test_zero_mean_unit_variance_normalization_trunc(self): _check_zero_mean_unit_variance(input_values[0][:800])
_check_zero_mean_unit_variance(input_values[1][:1000])
_check_zero_mean_unit_variance(input_values[2][:1200])
def test_zero_mean_unit_variance_normalization(self):
feat_extract = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
lengths = range(800, 1400, 200)
speech_inputs = [floats_list((1, x))[0] for x in lengths]
paddings = ["longest", "max_length", "do_not_pad"]
max_lengths = [None, 1600, None]
for max_length, padding in zip(max_lengths, paddings):
processed = feat_extract(speech_inputs, max_length=max_length, padding=padding)
input_values = processed.input_values
def _check_zero_mean_unit_variance(input_vector):
self.assertTrue(np.abs(np.mean(input_vector)) < 1e-3)
self.assertTrue(np.abs(np.var(input_vector) - 1) < 1e-3)
_check_zero_mean_unit_variance(input_values[0][:800])
_check_zero_mean_unit_variance(input_values[1][:1000])
_check_zero_mean_unit_variance(input_values[2][:1200])
def test_zero_mean_unit_variance_normalization_trunc_np(self):
feat_extract = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict()) feat_extract = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)] speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)]
processed = feat_extract( processed = feat_extract(

View File

@ -318,6 +318,59 @@ class TokenClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTest
], ],
) )
@require_torch
def test_aggregation_strategy_no_b_i_prefix(self):
model_name = "sshleifer/tiny-dbmdz-bert-large-cased-finetuned-conll03-english"
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
token_classifier = pipeline(task="ner", model=model_name, tokenizer=tokenizer, framework="pt")
# Just to understand scores indexes in this test
token_classifier.model.config.id2label = {0: "O", 1: "MISC", 2: "PER", 3: "ORG", 4: "LOC"}
example = [
{
# fmt : off
"scores": np.array([0, 0, 0, 0, 0.9968166351318359]),
"index": 1,
"is_subword": False,
"word": "En",
"start": 0,
"end": 2,
},
{
# fmt : off
"scores": np.array([0, 0, 0, 0, 0.9957635998725891]),
"index": 2,
"is_subword": True,
"word": "##zo",
"start": 2,
"end": 4,
},
{
# fmt: off
"scores": np.array([0, 0, 0, 0.9986497163772583, 0]),
# fmt: on
"index": 7,
"word": "UN",
"is_subword": False,
"start": 11,
"end": 13,
},
]
self.assertEqual(
nested_simplify(token_classifier.aggregate(example, AggregationStrategy.NONE)),
[
{"end": 2, "entity": "LOC", "score": 0.997, "start": 0, "word": "En", "index": 1},
{"end": 4, "entity": "LOC", "score": 0.996, "start": 2, "word": "##zo", "index": 2},
{"end": 13, "entity": "ORG", "score": 0.999, "start": 11, "word": "UN", "index": 7},
],
)
self.assertEqual(
nested_simplify(token_classifier.aggregate(example, AggregationStrategy.SIMPLE)),
[
{"entity_group": "LOC", "score": 0.996, "word": "Enzo", "start": 0, "end": 4},
{"entity_group": "ORG", "score": 0.999, "word": "UN", "start": 11, "end": 13},
],
)
@require_torch @require_torch
def test_aggregation_strategy(self): def test_aggregation_strategy(self):
model_name = "sshleifer/tiny-dbmdz-bert-large-cased-finetuned-conll03-english" model_name = "sshleifer/tiny-dbmdz-bert-large-cased-finetuned-conll03-english"

View File

@ -105,6 +105,20 @@ class ZeroShotClassificationPipelineTests(unittest.TestCase, metaclass=PipelineT
zero_shot_classifier.model.config.label2id = original_label2id zero_shot_classifier.model.config.label2id = original_label2id
self.assertEqual(original_entailment, zero_shot_classifier.entailment_id) self.assertEqual(original_entailment, zero_shot_classifier.entailment_id)
@require_torch
def test_truncation(self):
zero_shot_classifier = pipeline(
"zero-shot-classification",
model="sshleifer/tiny-distilbert-base-cased-distilled-squad",
framework="pt",
)
# There was a regression in 4.10 for this
# Adding a test so we don't make the mistake again.
# https://github.com/huggingface/transformers/issues/13381#issuecomment-912343499
zero_shot_classifier(
"Who are you voting for in 2020?" * 100, candidate_labels=["politics", "public health", "science"]
)
@require_torch @require_torch
def test_small_model_pt(self): def test_small_model_pt(self):
zero_shot_classifier = pipeline( zero_shot_classifier = pipeline(