Compare commits


9 Commits

SHA1 Message Date
a0fafa4e8b trigger CI 2024-10-21 17:30:21 +01:00
ae8744b6a0 Split off the zero-shot part of this PR 2024-10-18 14:49:11 +01:00
f8fa11ed30 make fixup 2024-10-18 14:48:03 +01:00
ad27b5cf4b Importing the wrong element for zero-shot 2024-10-18 14:48:03 +01:00
4c1205aaf3 Importing the wrong element for zero-shot 2024-10-18 14:48:03 +01:00
53a1bd7ac1 Add extra test in token classifier 2024-10-18 14:48:03 +01:00
bcbfb8b0e6 Remove extra newline 2024-10-18 14:48:03 +01:00
b9b1174074 make fixup 2024-10-18 14:48:03 +01:00
2fdc8dcb3a Synchronize zero-shot and token-classification 2024-10-18 14:48:03 +01:00
3 changed files with 39 additions and 36 deletions

View File

@@ -7,11 +7,10 @@ import numpy as np
from ..models.bert.tokenization_bert import BasicTokenizer
from ..utils import (
ExplicitEnum,
add_end_docstrings,
is_tf_available,
is_torch_available,
)
from .base import ArgumentHandler, ChunkPipeline, Dataset, build_pipeline_init_args
from .base import ArgumentHandler, ChunkPipeline, Dataset
if is_tf_available():
@@ -60,40 +59,6 @@ class AggregationStrategy(ExplicitEnum):
MAX = "max"
@add_end_docstrings(
build_pipeline_init_args(has_tokenizer=True),
r"""
ignore_labels (`List[str]`, defaults to `["O"]`):
A list of labels to ignore.
grouped_entities (`bool`, *optional*, defaults to `False`):
DEPRECATED, use `aggregation_strategy` instead. Whether or not to group the tokens corresponding to the
same entity together in the predictions.
stride (`int`, *optional*):
If stride is provided, the pipeline is applied to the whole text. The text is split into chunks of size
`model_max_length`. Works only with fast tokenizers and an `aggregation_strategy` different from `NONE`. The
value of this argument defines the number of overlapping tokens between chunks. In other words, the model
will shift forward by `tokenizer.model_max_length - stride` tokens each step.
aggregation_strategy (`str`, *optional*, defaults to `"none"`):
The strategy to fuse (or not) tokens based on the model prediction.
- "none" : Will simply not do any aggregation and simply return raw results from the model
- "simple" : Will attempt to group entities following the default schema. (A, B-TAG), (B, I-TAG), (C,
I-TAG), (D, B-TAG2) (E, B-TAG2) will end up being [{"word": ABC, "entity": "TAG"}, {"word": "D",
"entity": "TAG2"}, {"word": "E", "entity": "TAG2"}] Notice that two consecutive B tags will end up as
different entities. On word based languages, we might end up splitting words undesirably : Imagine
Microsoft being tagged as [{"word": "Micro", "entity": "ENTERPRISE"}, {"word": "soft", "entity":
"NAME"}]. Look for FIRST, MAX, AVERAGE for ways to mitigate that and disambiguate words (on languages
that support that meaning, which is basically tokens separated by a space). These mitigations will
only work on real words, "New york" might still be tagged with two different entities.
- "first" : (works only on word based models) Will use the `SIMPLE` strategy except that words, cannot
end up with different tags. Words will simply use the tag of the first token of the word when there
is ambiguity.
- "average" : (works only on word based models) Will use the `SIMPLE` strategy except that words,
cannot end up with different tags. scores will be averaged first across tokens, and then the maximum
label is applied.
- "max" : (works only on word based models) Will use the `SIMPLE` strategy except that words, cannot
end up with different tags. Word entity will simply be the token with the maximum score.""",
)
class TokenClassificationPipeline(ChunkPipeline):
"""
Named Entity Recognition pipeline using any `ModelForTokenClassification`. See the [named entity recognition
@@ -224,6 +189,33 @@ class TokenClassificationPipeline(ChunkPipeline):
Args:
inputs (`str` or `List[str]`):
One or several texts (or one list of texts) for token classification.
ignore_labels (`List[str]`, defaults to `["O"]`):
A list of labels to ignore.
stride (`int`, *optional*):
If stride is provided, the pipeline is applied to the whole text. The text is split into chunks of size
`model_max_length`. Works only with fast tokenizers and an `aggregation_strategy` different from `NONE`. The
value of this argument defines the number of overlapping tokens between chunks. In other words, the model
will shift forward by `tokenizer.model_max_length - stride` tokens each step.
aggregation_strategy (`str`, *optional*, defaults to `"none"`):
The strategy to fuse (or not) tokens based on the model prediction.
- "none" : Will simply not do any aggregation and simply return raw results from the model
- "simple" : Will attempt to group entities following the default schema. (A, B-TAG), (B, I-TAG), (C,
I-TAG), (D, B-TAG2) (E, B-TAG2) will end up being [{"word": ABC, "entity": "TAG"}, {"word": "D",
"entity": "TAG2"}, {"word": "E", "entity": "TAG2"}] Notice that two consecutive B tags will end up as
different entities. On word based languages, we might end up splitting words undesirably : Imagine
Microsoft being tagged as [{"word": "Micro", "entity": "ENTERPRISE"}, {"word": "soft", "entity":
"NAME"}]. Look for FIRST, MAX, AVERAGE for ways to mitigate that and disambiguate words (on languages
that support that meaning, which is basically tokens separated by a space). These mitigations will
only work on real words, "New york" might still be tagged with two different entities.
- "first" : (works only on word based models) Will use the `SIMPLE` strategy except that words, cannot
end up with different tags. Words will simply use the tag of the first token of the word when there
is ambiguity.
- "average" : (works only on word based models) Will use the `SIMPLE` strategy except that words,
cannot end up with different tags. scores will be averaged first across tokens, and then the maximum
label is applied.
- "max" : (works only on word based models) Will use the `SIMPLE` strategy except that words, cannot
end up with different tags. Word entity will simply be the token with the maximum score.
Return:
A list or a list of list of `dict`: Each result comes as a list of dictionaries (one for each token in the

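The hunk above moves the `ignore_labels`, `stride` and `aggregation_strategy` documentation from the class decorator into the `__call__` docstring. As a quick illustration of those parameters, here is a minimal usage sketch; the checkpoint name (`dslim/bert-base-NER`) and the example sentence are assumptions for illustration only and are not part of this diff.

from transformers import pipeline

# Hypothetical checkpoint chosen only for illustration; any ModelForTokenClassification works.
token_classifier = pipeline(
    "token-classification",
    model="dslim/bert-base-NER",
    aggregation_strategy="simple",  # group B-/I- tokens into whole entities
    stride=16,                      # token overlap between chunks on long inputs (fast tokenizers only)
    ignore_labels=["O"],            # drop non-entity tokens from the output
)

outputs = token_classifier("My name is Wolfgang and I live in Berlin.")
for entity in outputs:
    # With an aggregation strategy other than "none", each dict describes a grouped
    # entity with keys such as "entity_group", "score", "word", "start" and "end".
    print(entity["entity_group"], entity["word"], float(entity["score"]))
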
View File

@@ -15,6 +15,7 @@
import unittest
import numpy as np
from huggingface_hub import TokenClassificationOutputElement
from transformers import (
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
@@ -26,6 +27,7 @@ from transformers import (
)
from transformers.pipelines import AggregationStrategy, TokenClassificationArgumentHandler
from transformers.testing_utils import (
compare_pipeline_output_to_hub_spec,
is_pipeline_test,
is_torch_available,
nested_simplify,
@@ -103,6 +105,9 @@ class TokenClassificationPipelineTests(unittest.TestCase):
for i in range(n)
],
)
for output_element in nested_simplify(outputs):
compare_pipeline_output_to_hub_spec(output_element, TokenClassificationOutputElement)
outputs = token_classifier(["list of strings", "A simple string that is quite a bit longer"])
self.assertIsInstance(outputs, list)
self.assertEqual(len(outputs), 2)
@@ -137,6 +142,9 @@
],
)
for output_element in nested_simplify(outputs):
compare_pipeline_output_to_hub_spec(output_element, TokenClassificationOutputElement)
self.run_aggregation_strategy(model, tokenizer)
def run_aggregation_strategy(self, model, tokenizer):

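The new assertions above check each simplified pipeline output element against the `TokenClassificationOutputElement` spec published by `huggingface_hub`. Below is only a rough approximation of what such a check amounts to; the real helper is `compare_pipeline_output_to_hub_spec` from `transformers.testing_utils` and may differ in detail, and `rough_spec_check` is a hypothetical name used just for this sketch.

import dataclasses

from huggingface_hub import TokenClassificationOutputElement


def rough_spec_check(output_element: dict) -> None:
    # The hub spec is a dataclass; its field names define the keys an output
    # element is expected to use. This only illustrates the idea of the check.
    spec_fields = {field.name for field in dataclasses.fields(TokenClassificationOutputElement)}
    unexpected = set(output_element) - spec_fields
    if unexpected:
        raise AssertionError(f"Keys not described by the hub spec: {sorted(unexpected)}")
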
View File

@@ -34,6 +34,7 @@ from huggingface_hub import (
ImageToTextInput,
ObjectDetectionInput,
QuestionAnsweringInput,
TokenClassificationInput,
ZeroShotImageClassificationInput,
)
@@ -47,6 +48,7 @@ from transformers.pipelines import (
ImageToTextPipeline,
ObjectDetectionPipeline,
QuestionAnsweringPipeline,
TokenClassificationPipeline,
ZeroShotImageClassificationPipeline,
)
from transformers.testing_utils import (
@@ -132,6 +134,7 @@ task_to_pipeline_and_spec_mapping = {
"image-to-text": (ImageToTextPipeline, ImageToTextInput),
"object-detection": (ObjectDetectionPipeline, ObjectDetectionInput),
"question-answering": (QuestionAnsweringPipeline, QuestionAnsweringInput),
"token-classification": (TokenClassificationPipeline, TokenClassificationInput),
"zero-shot-image-classification": (ZeroShotImageClassificationPipeline, ZeroShotImageClassificationInput),
}
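
Registering the task in `task_to_pipeline_and_spec_mapping` is what lets the shared spec tests pick up token classification. The loop below is only a hedged sketch of how a task-to-(pipeline class, hub input spec) mapping like this can be consumed; the actual test harness in the transformers test suite may use it differently.

import dataclasses

from huggingface_hub import TokenClassificationInput
from transformers.pipelines import TokenClassificationPipeline

# A one-entry stand-in for the mapping shown in the diff above.
task_to_pipeline_and_spec_mapping = {
    "token-classification": (TokenClassificationPipeline, TokenClassificationInput),
}

for task, (pipeline_class, input_spec) in task_to_pipeline_and_spec_mapping.items():
    # Each hub input spec is a dataclass whose fields name the arguments a
    # conforming request for this task is expected to carry.
    expected_fields = [field.name for field in dataclasses.fields(input_spec)]
    print(task, pipeline_class.__name__, expected_fields)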