Mirror of https://github.com/huggingface/transformers.git (synced 2025-10-21 01:23:56 +08:00)

Compare commits: v4.51.0...add_pipeli (35 Commits)
Commit SHA1s:
1f93b12f69, 4c3fb8ffff, 123a9e64a8, a0bcb233e8, 5e040e66f5, 35d0592327, bdfde76bec,
8f7050fd7d, 2221ceffe1, b15cf4ff87, ddeb6790c8, 3d2c477371, 6db83ab5e5, 12fcfcfdb0,
9ce2faa767, 98b5dd447b, 3fd294dd31, bd84dc2427, c6069d72b5, 021ae205c6, 04fe7ad4ba,
c4b5c1f21f, 6c7de21eb8, 046f4367f8, 298fb37491, 3983fcc247, 8d2a10bd9a, 80a578d4cc,
71569fd19a, d951519efe, f7e508735d, aa4d9c4f11, 9b3d877def, 55614541e1, 7a6090b28a
@@ -126,6 +126,11 @@ class AudioClassificationPipeline(Pipeline):
                 The number of top labels that will be returned by the pipeline. If the provided number is `None` or
                 higher than the number of labels available in the model configuration, it will default to the number of
                 labels.
+            function_to_apply(`str`, *optional*, defaults to "softmax"):
+                The function to apply to the model output. By default, the pipeline will apply the softmax function to
+                the output of the model. Valid options: ["softmax", "sigmoid", "none"]. Note that passing Python's
+                built-in `None` will default to "softmax", so you need to pass the string "none" to disable any
+                post-processing.
 
         Return:
             A list of `dict` with the following keys:
@@ -135,13 +140,22 @@ class AudioClassificationPipeline(Pipeline):
         """
         return super().__call__(inputs, **kwargs)
 
-    def _sanitize_parameters(self, top_k=None, **kwargs):
+    def _sanitize_parameters(self, top_k=None, function_to_apply=None, **kwargs):
         # No parameters on this pipeline right now
         postprocess_params = {}
         if top_k is not None:
             if top_k > self.model.config.num_labels:
                 top_k = self.model.config.num_labels
             postprocess_params["top_k"] = top_k
+        if function_to_apply is not None:
+            if function_to_apply not in ["softmax", "sigmoid", "none"]:
+                raise ValueError(
+                    f"Invalid value for `function_to_apply`: {function_to_apply}. "
+                    "Valid options are ['softmax', 'sigmoid', 'none']"
+                )
+            postprocess_params["function_to_apply"] = function_to_apply
+        else:
+            postprocess_params["function_to_apply"] = "softmax"
         return {}, {}, postprocess_params
 
     def preprocess(self, inputs):
@@ -203,8 +217,13 @@ class AudioClassificationPipeline(Pipeline):
         model_outputs = self.model(**model_inputs)
         return model_outputs
 
-    def postprocess(self, model_outputs, top_k=5):
-        probs = model_outputs.logits[0].softmax(-1)
+    def postprocess(self, model_outputs, top_k=5, function_to_apply="softmax"):
+        if function_to_apply == "softmax":
+            probs = model_outputs.logits[0].softmax(-1)
+        elif function_to_apply == "sigmoid":
+            probs = model_outputs.logits[0].sigmoid()
+        else:
+            probs = model_outputs.logits[0]
         scores, ids = probs.topk(top_k)
 
         scores = scores.tolist()
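A minimal usage sketch of the new `function_to_apply` option; the checkpoint name below is illustrative and not part of this diff:

from transformers import pipeline

# Any audio-classification checkpoint works here; this one is just an example.
classifier = pipeline("audio-classification", model="superb/hubert-base-superb-er")

# "sigmoid" scores labels independently (useful for multi-label tagging);
# pass the string "none" -- not Python's None -- to get raw logits back.
preds = classifier("sample.wav", top_k=3, function_to_apply="sigmoid")
print(preds)  # [{"score": ..., "label": ...}, ...]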
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import warnings
 from collections import defaultdict
 from typing import TYPE_CHECKING, Dict, Optional, Union
 
@@ -269,8 +270,6 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
                 The dictionary of ad-hoc parametrization of `generate_config` to be used for the generation call. For a
                 complete overview of generate, check the [following
                 guide](https://huggingface.co/docs/transformers/en/main_classes/text_generation).
-            max_new_tokens (`int`, *optional*):
-                The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt.
 
         Return:
             `Dict`: A dictionary with the following keys:
@@ -310,6 +309,10 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
 
         forward_params = defaultdict(dict)
         if max_new_tokens is not None:
+            warnings.warn(
+                "`max_new_tokens` is deprecated and will be removed in version 5 of Transformers. To remove this warning, pass `max_new_tokens` as a keyword argument inside `generate_kwargs` instead.",
+                FutureWarning,
+            )
             forward_params["max_new_tokens"] = max_new_tokens
         if generate_kwargs is not None:
             if max_new_tokens is not None and "max_new_tokens" in generate_kwargs:
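A sketch of the call pattern the deprecation points to, assuming a Whisper-style checkpoint (the model name is illustrative):

from transformers import pipeline

asr = pipeline("automatic-speech-recognition", model="openai/whisper-tiny")

# Pass max_new_tokens inside generate_kwargs; the top-level argument still
# works but now emits a FutureWarning and is removed in v5.
out = asr("sample.wav", generate_kwargs={"max_new_tokens": 128})
print(out["text"])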
@@ -1,3 +1,4 @@
+import warnings
 from typing import List, Union
 
 import numpy as np
@@ -50,12 +51,12 @@ class DepthEstimationPipeline(Pipeline):
         requires_backends(self, "vision")
         self.check_model_type(MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES)
 
-    def __call__(self, images: Union[str, List[str], "Image.Image", List["Image.Image"]], **kwargs):
+    def __call__(self, inputs: Union[str, List[str], "Image.Image", List["Image.Image"]] = None, **kwargs):
         """
         Predict the depth(s) of the image(s) passed as inputs.
 
         Args:
-            images (`str`, `List[str]`, `PIL.Image` or `List[PIL.Image]`):
+            inputs (`str`, `List[str]`, `PIL.Image` or `List[PIL.Image]`):
                 The pipeline handles three types of images:
 
                 - A string containing a http link pointing to an image
@@ -65,9 +66,10 @@ class DepthEstimationPipeline(Pipeline):
                 The pipeline accepts either a single image or a batch of images, which must then be passed as a string.
                 Images in a batch must all be in the same format: all as http links, all as local paths, or all as PIL
                 images.
-            timeout (`float`, *optional*, defaults to None):
-                The maximum time in seconds to wait for fetching images from the web. If None, no timeout is set and
-                the call may block forever.
+            parameters (`Dict`, *optional*):
+                A dictionary of argument names to parameter values, to control pipeline behaviour.
+                The only parameter available right now is `timeout`, which is the length of time, in seconds,
+                that the pipeline should wait before giving up on trying to download an image.
 
         Return:
             A dictionary or a list of dictionaries containing result. If the input is a single image, will return a
@@ -79,12 +81,26 @@ class DepthEstimationPipeline(Pipeline):
             - **predicted_depth** (`torch.Tensor`) -- The predicted depth by the model as a `torch.Tensor`.
             - **depth** (`PIL.Image`) -- The predicted depth by the model as a `PIL.Image`.
         """
-        return super().__call__(images, **kwargs)
+        # After deprecation of this is completed, remove the default `None` value for `images`
+        if "images" in kwargs:
+            warnings.warn(
+                "The `images` argument has been renamed to `inputs`. In version 5 of Transformers, `images` will no longer be accepted",
+                FutureWarning,
+            )
+            inputs = kwargs.pop("images")
+        if inputs is None:
+            raise ValueError("Cannot call the depth-estimation pipeline without an inputs argument!")
+        return super().__call__(inputs, **kwargs)
 
-    def _sanitize_parameters(self, timeout=None, **kwargs):
+    def _sanitize_parameters(self, timeout=None, parameters=None, **kwargs):
         preprocess_params = {}
         if timeout is not None:
+            warnings.warn(
+                "The `timeout` argument is deprecated and will be removed in version 5 of Transformers", FutureWarning
+            )
             preprocess_params["timeout"] = timeout
+        if isinstance(parameters, dict) and "timeout" in parameters:
+            preprocess_params["timeout"] = parameters["timeout"]
         return preprocess_params, {}, {}
 
     def preprocess(self, image, timeout=None):
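A sketch of the renamed call, assuming the `Intel/dpt-large` checkpoint (illustrative):

from transformers import pipeline

depth = pipeline("depth-estimation", model="Intel/dpt-large")

# `inputs` replaces the deprecated `images` keyword, and `timeout` now rides
# inside the `parameters` dict rather than as a top-level argument.
result = depth(
    inputs="http://images.cocodataset.org/val2017/000000039769.jpg",
    parameters={"timeout": 10.0},
)
print(result["depth"].size)  # PIL.Image holding the predicted depth map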
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import re
+import warnings
 from typing import List, Optional, Tuple, Union
 
 import numpy as np
@@ -44,6 +45,7 @@ if is_pytesseract_available():
     TESSERACT_LOADED = True
     import pytesseract
 
 
 logger = logging.get_logger(__name__)
 
 
@@ -245,11 +247,6 @@ class DocumentQuestionAnsweringPipeline(ChunkPipeline):
                 Whether or not we accept impossible as an answer.
             lang (`str`, *optional*):
                 Language to use while running OCR. Defaults to english.
-            tesseract_config (`str`, *optional*):
-                Additional flags to pass to tesseract while running OCR.
-            timeout (`float`, *optional*, defaults to None):
-                The maximum time in seconds to wait for fetching images from the web. If None, no timeout is set and
-                the call may block forever.
 
         Return:
             A `dict` or a list of `dict`: Each result comes as a dictionary with the following keys:
@@ -291,6 +288,15 @@ class DocumentQuestionAnsweringPipeline(ChunkPipeline):
 
         image = None
         image_features = {}
+        if timeout is not None:
+            warnings.warn(
+                "The `timeout` argument is deprecated and will be removed in version 5 of Transformers", FutureWarning
+            )
+        if tesseract_config:
+            warnings.warn(
+                "The `tesseract_config` argument is deprecated and will be removed in version 5 of Transformers",
+                FutureWarning,
+            )
         if input.get("image", None) is not None:
             image = load_image(input["image"], timeout=timeout)
             if self.image_processor is not None:
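A call sketch that sidesteps the deprecated arguments, assuming the `impira/layoutlm-document-qa` checkpoint (illustrative) and a local scan named invoice.png:

from transformers import pipeline

doc_qa = pipeline("document-question-answering", model="impira/layoutlm-document-qa")

# `lang` remains supported; `timeout` and `tesseract_config` now trigger
# FutureWarnings and are slated for removal in v5.
answers = doc_qa(image="invoice.png", question="What is the invoice number?", lang="eng")
print(answers[0]["answer"])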
@@ -37,7 +37,9 @@ class FeatureExtractionPipeline(Pipeline):
     [huggingface.co/models](https://huggingface.co/models).
     """
 
-    def _sanitize_parameters(self, truncation=None, tokenize_kwargs=None, return_tensors=None, **kwargs):
+    def _sanitize_parameters(
+        self, truncation=None, truncation_direction=None, tokenize_kwargs=None, return_tensors=None, **kwargs
+    ):
         if tokenize_kwargs is None:
             tokenize_kwargs = {}
 
@@ -47,6 +49,13 @@ class FeatureExtractionPipeline(Pipeline):
                 "truncation parameter defined twice (given as keyword argument as well as in tokenize_kwargs)"
             )
             tokenize_kwargs["truncation"] = truncation
+        if truncation_direction is not None:
+            if "truncation_side" in tokenize_kwargs:
+                raise ValueError(
+                    "truncation_side parameter defined twice (given as keyword argument as well as in tokenize_kwargs)"
+                )
+            # The JS spec uses title-case, transformers uses lower, so we normalize
+            tokenize_kwargs["truncation_side"] = truncation_direction.lower()
 
         preprocess_params = tokenize_kwargs
 
@@ -73,14 +82,19 @@ class FeatureExtractionPipeline(Pipeline):
         elif self.framework == "tf":
             return model_outputs[0].numpy().tolist()
 
-    def __call__(self, *args, **kwargs):
+    def __call__(self, *inputs, **kwargs):
         """
         Extract the features of the input(s).
 
         Args:
-            args (`str` or `List[str]`): One or several texts (or one list of texts) to get the features of.
+            inputs (`str` or `List[str]`): One or several texts (or one list of texts) to get the features of.
+            truncate (`bool`, *optional*, defaults to `None`):
+                Whether to truncate the input to max length or not. Overrides the value passed when initializing the
+                pipeline.
+            truncation_direction (`str`, *optional*, defaults to `None`): The side to truncate from the input sequence
+                if truncation is enabled. Can be 'left' or 'right'.
 
         Return:
             A nested list of `float`: The features computed by the model.
         """
-        return super().__call__(*args, **kwargs)
+        return super().__call__(*inputs, **kwargs)
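A sketch of `truncation_direction`, assuming a small sentence encoder (the checkpoint is illustrative) and that the tokenizer honors `truncation_side` as a call-time keyword:

from transformers import pipeline

extractor = pipeline("feature-extraction", model="sentence-transformers/all-MiniLM-L6-v2")

# Title-case values such as "Left" are normalized to lower case before being
# forwarded to the tokenizer as `truncation_side`.
features = extractor("a very long document " * 200, truncation=True, truncation_direction="Left")
print(len(features[0]))  # token count after left-side truncation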
@@ -245,12 +245,12 @@ class FillMaskPipeline(Pipeline):
             )
         return preprocess_params, {}, postprocess_params
 
-    def __call__(self, inputs, *args, **kwargs):
+    def __call__(self, *inputs, **kwargs):
         """
         Fill the masked token in the text(s) given as inputs.
 
         Args:
-            args (`str` or `List[str]`):
+            inputs (`str` or `List[str]`):
                 One or several texts (or one list of prompts) with masked tokens.
             targets (`str` or `List[str]`, *optional*):
                 When passed, the model will limit the scores to the passed targets instead of looking up in the whole
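A quick sketch of the varargs form, assuming a RoBERTa-style checkpoint whose mask token is `<mask>` (the model name is illustrative):

from transformers import pipeline

fill = pipeline("fill-mask", model="distilbert/distilroberta-base")

# Inputs are plain positional arguments; `targets` limits scoring to the
# given candidates (note the leading space for BPE vocabularies).
preds = fill("The capital of France is <mask>.", targets=[" Paris", " London"])
print(preds[0]["token_str"], preds[0]["score"])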
@@ -1,3 +1,4 @@
+import warnings
 from typing import List, Union
 
 import numpy as np
@@ -99,6 +100,9 @@ class ImageClassificationPipeline(Pipeline):
     def _sanitize_parameters(self, top_k=None, function_to_apply=None, timeout=None):
         preprocess_params = {}
         if timeout is not None:
+            warnings.warn(
+                "The `timeout` argument is deprecated and will be removed in version 5 of Transformers", FutureWarning
+            )
             preprocess_params["timeout"] = timeout
         postprocess_params = {}
         if top_k is not None:
@@ -109,12 +113,12 @@ class ImageClassificationPipeline(Pipeline):
             postprocess_params["function_to_apply"] = function_to_apply
         return preprocess_params, {}, postprocess_params
 
-    def __call__(self, images: Union[str, List[str], "Image.Image", List["Image.Image"]], **kwargs):
+    def __call__(self, inputs: Union[str, List[str], "Image.Image", List["Image.Image"]] = None, **kwargs):
         """
         Assign labels to the image(s) passed as inputs.
 
         Args:
-            images (`str`, `List[str]`, `PIL.Image` or `List[PIL.Image]`):
+            inputs (`str`, `List[str]`, `PIL.Image` or `List[PIL.Image]`):
                 The pipeline handles three types of images:
 
                 - A string containing a http link pointing to an image
@@ -142,9 +146,6 @@ class ImageClassificationPipeline(Pipeline):
             top_k (`int`, *optional*, defaults to 5):
                 The number of top labels that will be returned by the pipeline. If the provided number is higher than
                 the number of labels available in the model configuration, it will default to the number of labels.
-            timeout (`float`, *optional*, defaults to None):
-                The maximum time in seconds to wait for fetching images from the web. If None, no timeout is set and
-                the call may block forever.
 
         Return:
             A dictionary or a list of dictionaries containing result. If the input is a single image, will return a
@@ -156,7 +157,16 @@ class ImageClassificationPipeline(Pipeline):
             - **label** (`str`) -- The label identified by the model.
             - **score** (`int`) -- The score attributed by the model for that label.
         """
-        return super().__call__(images, **kwargs)
+        # After deprecation of this is completed, remove the default `None` value for `images`
+        if "images" in kwargs:
+            warnings.warn(
+                "The `images` argument has been renamed to `inputs`. In version 5 of Transformers, `images` will no longer be accepted",
+                FutureWarning,
+            )
+            inputs = kwargs.pop("images")
+        if inputs is None:
+            raise ValueError("Cannot call the image-classification pipeline without an inputs argument!")
+        return super().__call__(inputs, **kwargs)
 
     def preprocess(self, image, timeout=None):
         image = load_image(image, timeout=timeout)
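A sketch of the rename, assuming `google/vit-base-patch16-224` (illustrative):

from transformers import pipeline

clf = pipeline("image-classification", model="google/vit-base-patch16-224")

# The keyword form is now `inputs=...`; `images=...` still works but emits a
# FutureWarning until it is dropped in v5.
preds = clf(inputs="http://images.cocodataset.org/val2017/000000039769.jpg", top_k=2)
print(preds)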
@@ -1,3 +1,4 @@
+import warnings
 from typing import Any, Dict, List, Union
 
 import numpy as np
@@ -90,16 +91,19 @@ class ImageSegmentationPipeline(Pipeline):
         if "overlap_mask_area_threshold" in kwargs:
             postprocess_kwargs["overlap_mask_area_threshold"] = kwargs["overlap_mask_area_threshold"]
         if "timeout" in kwargs:
+            warnings.warn(
+                "The `timeout` argument is deprecated and will be removed in version 5 of Transformers", FutureWarning
+            )
             preprocess_kwargs["timeout"] = kwargs["timeout"]
 
         return preprocess_kwargs, {}, postprocess_kwargs
 
-    def __call__(self, images, **kwargs) -> Union[Predictions, List[Prediction]]:
+    def __call__(self, inputs=None, **kwargs) -> Union[Predictions, List[Prediction]]:
         """
         Perform segmentation (detect masks & classes) in the image(s) passed as inputs.
 
         Args:
-            images (`str`, `List[str]`, `PIL.Image` or `List[PIL.Image]`):
+            inputs (`str`, `List[str]`, `PIL.Image` or `List[PIL.Image]`):
                 The pipeline handles three types of images:
 
                 - A string containing an HTTP(S) link pointing to an image
@@ -118,9 +122,6 @@ class ImageSegmentationPipeline(Pipeline):
                 Threshold to use when turning the predicted masks into binary values.
             overlap_mask_area_threshold (`float`, *optional*, defaults to 0.5):
                 Mask overlap threshold to eliminate small, disconnected segments.
-            timeout (`float`, *optional*, defaults to None):
-                The maximum time in seconds to wait for fetching images from the web. If None, no timeout is set and
-                the call may block forever.
 
         Return:
             A dictionary or a list of dictionaries containing the result. If the input is a single image, will return a
@@ -136,7 +137,16 @@ class ImageSegmentationPipeline(Pipeline):
             - **score** (*optional* `float`) -- Optionally, when the model is capable of estimating a confidence of the
               "object" described by the label and the mask.
         """
-        return super().__call__(images, **kwargs)
+        # After deprecation of this is completed, remove the default `None` value for `images`
+        if "images" in kwargs:
+            warnings.warn(
+                "The `images` argument has been renamed to `inputs`. In version 5 of Transformers, `images` will no longer be accepted",
+                FutureWarning,
+            )
+            inputs = kwargs.pop("images")
+        if inputs is None:
+            raise ValueError("Cannot call the image-segmentation pipeline without an inputs argument!")
+        return super().__call__(inputs, **kwargs)
 
     def preprocess(self, image, subtask=None, timeout=None):
         image = load_image(image, timeout=timeout)
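A usage sketch under the same rename, assuming a panoptic DETR checkpoint (illustrative):

from transformers import pipeline

segmenter = pipeline("image-segmentation", model="facebook/detr-resnet-50-panoptic")

# `inputs` replaces `images`; threshold handling is unchanged.
segments = segmenter(inputs="http://images.cocodataset.org/val2017/000000039769.jpg", threshold=0.9)
for seg in segments:
    print(seg["label"], seg.get("score"))  # the binary mask sits in seg["mask"]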
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import warnings
 from typing import List, Union
 
 from ..utils import (
@@ -80,6 +81,9 @@ class ImageToTextPipeline(Pipeline):
         if prompt is not None:
             preprocess_params["prompt"] = prompt
         if timeout is not None:
+            warnings.warn(
+                "The `timeout` argument is deprecated and will be removed in version 5 of Transformers", FutureWarning
+            )
             preprocess_params["timeout"] = timeout
 
         if max_new_tokens is not None:
@@ -94,12 +98,12 @@ class ImageToTextPipeline(Pipeline):
 
         return preprocess_params, forward_params, {}
 
-    def __call__(self, images: Union[str, List[str], "Image.Image", List["Image.Image"]], **kwargs):
+    def __call__(self, inputs: Union[str, List[str], "Image.Image", List["Image.Image"]] = None, **kwargs):
         """
         Assign labels to the image(s) passed as inputs.
 
         Args:
-            images (`str`, `List[str]`, `PIL.Image` or `List[PIL.Image]`):
+            inputs (`str`, `List[str]`, `PIL.Image` or `List[PIL.Image]`):
                 The pipeline handles three types of images:
 
                 - A string containing a HTTP(s) link pointing to an image
@@ -113,16 +117,22 @@ class ImageToTextPipeline(Pipeline):
 
             generate_kwargs (`Dict`, *optional*):
                 Pass it to send all of these arguments directly to `generate` allowing full control of this function.
-            timeout (`float`, *optional*, defaults to None):
-                The maximum time in seconds to wait for fetching images from the web. If None, no timeout is set and
-                the call may block forever.
 
         Return:
             A list or a list of list of `dict`: Each result comes as a dictionary with the following key:
 
             - **generated_text** (`str`) -- The generated text.
         """
-        return super().__call__(images, **kwargs)
+        # After deprecation of this is completed, remove the default `None` value for `images`
+        if "images" in kwargs:
+            warnings.warn(
+                "The `images` argument has been renamed to `inputs`. In version 5 of Transformers, `images` will no longer be accepted",
+                FutureWarning,
+            )
+            inputs = kwargs.pop("images")
+        if inputs is None:
+            raise ValueError("Cannot call the image-to-text pipeline without an inputs argument!")
+        return super().__call__(inputs, **kwargs)
 
     def preprocess(self, image, prompt=None, timeout=None):
         image = load_image(image, timeout=timeout)
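A captioning sketch with the renamed keyword, assuming a BLIP checkpoint (illustrative):

from transformers import pipeline

captioner = pipeline("image-to-text", model="Salesforce/blip-image-captioning-base")

# `inputs` replaces the deprecated `images` keyword here as well.
out = captioner(inputs="http://images.cocodataset.org/val2017/000000039769.jpg")
print(out[0]["generated_text"])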
@@ -1,3 +1,4 @@
+import warnings
 from typing import Any, Dict, List, Union
 
 from ..utils import add_end_docstrings, is_torch_available, is_vision_available, logging, requires_backends
@@ -63,6 +64,9 @@ class ObjectDetectionPipeline(Pipeline):
     def _sanitize_parameters(self, **kwargs):
         preprocess_params = {}
         if "timeout" in kwargs:
+            warnings.warn(
+                "The `timeout` argument is deprecated and will be removed in version 5 of Transformers", FutureWarning
+            )
             preprocess_params["timeout"] = kwargs["timeout"]
         postprocess_kwargs = {}
         if "threshold" in kwargs:
@@ -74,7 +78,7 @@ class ObjectDetectionPipeline(Pipeline):
         Detect objects (bounding boxes & classes) in the image(s) passed as inputs.
 
         Args:
-            images (`str`, `List[str]`, `PIL.Image` or `List[PIL.Image]`):
+            inputs (`str`, `List[str]`, `PIL.Image` or `List[PIL.Image]`):
                 The pipeline handles three types of images:
 
                 - A string containing an HTTP(S) link pointing to an image
@@ -85,9 +89,6 @@ class ObjectDetectionPipeline(Pipeline):
                 same format: all as HTTP(S) links, all as local paths, or all as PIL images.
             threshold (`float`, *optional*, defaults to 0.5):
                 The probability necessary to make a prediction.
-            timeout (`float`, *optional*, defaults to None):
-                The maximum time in seconds to wait for fetching images from the web. If None, no timeout is set and
-                the call may block forever.
 
         Return:
             A list of dictionaries or a list of list of dictionaries containing the result. If the input is a single
@@ -100,7 +101,13 @@ class ObjectDetectionPipeline(Pipeline):
             - **score** (`float`) -- The score attributed by the model for that label.
             - **box** (`List[Dict[str, int]]`) -- The bounding box of detected object in image's original size.
         """
-
+        # After deprecation of this is completed, remove the default `None` value for `images`
+        if "images" in kwargs and "inputs" not in kwargs:
+            warnings.warn(
+                "The `images` argument has been renamed to `inputs`. In version 5 of Transformers, `images` will no longer be accepted",
+                FutureWarning,
+            )
+            kwargs["inputs"] = kwargs.pop("images")
         return super().__call__(*args, **kwargs)
 
     def preprocess(self, image, timeout=None):
@@ -183,8 +183,14 @@ class QuestionAnsweringArgumentHandler(ArgumentHandler):
         # Generic compatibility with sklearn and Keras
         # Batched data
         elif "X" in kwargs:
+            warnings.warn(
+                "Passing the `X` argument to the pipeline is deprecated and will be removed in v5.", FutureWarning
+            )
             inputs = kwargs["X"]
         elif "data" in kwargs:
+            warnings.warn(
+                "Passing the `data` argument to the pipeline is deprecated and will be removed in v5.", FutureWarning
+            )
             inputs = kwargs["data"]
         elif "question" in kwargs and "context" in kwargs:
             if isinstance(kwargs["question"], list) and isinstance(kwargs["context"], str):
@@ -345,20 +351,12 @@ class QuestionAnsweringPipeline(ChunkPipeline):
         Answer the question(s) given as inputs by using the context(s).
 
         Args:
-            args ([`SquadExample`] or a list of [`SquadExample`]):
-                One or several [`SquadExample`] containing the question and context.
-            X ([`SquadExample`] or a list of [`SquadExample`], *optional*):
-                One or several [`SquadExample`] containing the question and context (will be treated the same way as if
-                passed as the first positional argument).
-            data ([`SquadExample`] or a list of [`SquadExample`], *optional*):
-                One or several [`SquadExample`] containing the question and context (will be treated the same way as if
-                passed as the first positional argument).
             question (`str` or `List[str]`):
                 One or several question(s) (must be used in conjunction with the `context` argument).
             context (`str` or `List[str]`):
                 One or several context(s) associated with the question(s) (must be used in conjunction with the
                 `question` argument).
-            topk (`int`, *optional*, defaults to 1):
+            top_k (`int`, *optional*, defaults to 1):
                 The number of answers to return (will be chosen by order of likelihood). Note that we return less than
                 topk answers if there are not enough options available within the context.
             doc_stride (`int`, *optional*, defaults to 128):
@@ -387,6 +385,11 @@ class QuestionAnsweringPipeline(ChunkPipeline):
         """
 
         # Convert inputs to features
+        if args:
+            warnings.warn(
+                "Passing a list of SQuAD examples to the pipeline is deprecated and will be removed in v5. Inputs should be passed using the `question` and `context` keyword arguments instead.",
+                FutureWarning,
+            )
 
         examples = self._args_parser(*args, **kwargs)
         if isinstance(examples, (list, tuple)) and len(examples) == 1:
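A sketch of the supported keyword form, assuming a SQuAD-tuned checkpoint (illustrative):

from transformers import pipeline

qa = pipeline("question-answering", model="distilbert/distilbert-base-cased-distilled-squad")

# `question`/`context` replace the deprecated X=/data= SquadExample inputs,
# and `top_k` replaces the old `topk` spelling.
answer = qa(question="Where do I live?", context="My name is Merve and I live in Istanbul.", top_k=1)
print(answer["answer"])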
@@ -1,5 +1,6 @@
 import collections
 import types
+import warnings
 
 import numpy as np
 
@@ -34,7 +35,7 @@ class TableQuestionAnsweringArgumentHandler(ArgumentHandler):
     Handles arguments for the TableQuestionAnsweringPipeline
     """
 
-    def __call__(self, table=None, query=None, **kwargs):
+    def __call__(self, table=None, question=None, **kwargs):
         # Returns tqa_pipeline_inputs of shape:
         # [
         #   {"table": pd.DataFrame, "query": List[str]},
@@ -46,7 +47,7 @@ class TableQuestionAnsweringArgumentHandler(ArgumentHandler):
 
         if table is None:
             raise ValueError("Keyword argument `table` cannot be None.")
-        elif query is None:
+        elif question is None:
             if isinstance(table, dict) and table.get("query") is not None and table.get("table") is not None:
                 tqa_pipeline_inputs = [table]
             elif isinstance(table, list) and len(table) > 0:
@@ -70,7 +71,7 @@ class TableQuestionAnsweringArgumentHandler(ArgumentHandler):
                     f"is {type(table)})"
                 )
         else:
-            tqa_pipeline_inputs = [{"table": table, "query": query}]
+            tqa_pipeline_inputs = [{"table": table, "query": question}]
 
         for tqa_pipeline_input in tqa_pipeline_inputs:
             if not isinstance(tqa_pipeline_input["table"], pd.DataFrame):
@@ -305,30 +306,11 @@ class TableQuestionAnsweringPipeline(Pipeline):
             table (`pd.DataFrame` or `Dict`):
                 Pandas DataFrame or dictionary that will be converted to a DataFrame containing all the table values.
                 See above for an example of dictionary.
-            query (`str` or `List[str]`):
-                Query or list of queries that will be sent to the model alongside the table.
-            sequential (`bool`, *optional*, defaults to `False`):
-                Whether to do inference sequentially or as a batch. Batching is faster, but models like SQA require the
-                inference to be done sequentially to extract relations within sequences, given their conversational
-                nature.
-            padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`):
-                Activates and controls padding. Accepts the following values:
-
-                - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
-                  sequence if provided).
-                - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
-                  acceptable input length for the model if that argument is not provided.
-                - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
-                  lengths).
-
-            truncation (`bool`, `str` or [`TapasTruncationStrategy`], *optional*, defaults to `False`):
-                Activates and controls truncation. Accepts the following values:
-
-                - `True` or `'drop_rows_to_fit'`: Truncate to a maximum length specified with the argument `max_length`
-                  or to the maximum acceptable input length for the model if that argument is not provided. This will
-                  truncate row by row, removing rows from the table.
-                - `False` or `'do_not_truncate'` (default): No truncation (i.e., can output batch with sequence lengths
-                  greater than the model maximum admissible input size).
+            question (`str`):
+                Query that will be sent to the model alongside the table.
+            parameters (`Dict[str, Any]`, *optional*):
+                Additional parameters passed to the tokenizer and model. Currently supported parameters are
+                `sequential`, `padding` and `truncation`.
 
 
         Return:
@@ -341,6 +323,36 @@ class TableQuestionAnsweringPipeline(Pipeline):
             - **cells** (`List[str]`) -- List of strings made up of the answer cell values.
             - **aggregator** (`str`) -- If the model has an aggregator, this returns the aggregator.
         """
+        # This block just for deprecation / input checking
+        if args:
+            table = args[0]
+        elif "table" in kwargs:
+            table = kwargs["table"]
+        if isinstance(table, dict):
+            if "query" in table:
+                warnings.warn(
+                    "Passing the query as a key with the input table is deprecated and will be removed in Transformers v5. Use the `question` keyword argument instead.",
+                    FutureWarning,
+                )
+        elif isinstance(table, list) and isinstance(table[0], dict):
+            if "query" in table[0]:
+                warnings.warn(
+                    "Passing the query as a key with the input table is deprecated and will be removed in Transformers v5. Use the `question` keyword argument instead.",
+                    FutureWarning,
+                )
+        if "query" in kwargs:
+            warnings.warn(
+                "The `query` keyword argument is deprecated and will be removed in Transformers v5. Use the `question` keyword argument instead.",
+                FutureWarning,
+            )
+            kwargs["question"] = kwargs.pop("query")
+        if isinstance(kwargs.get("question", None), list):
+            warnings.warn(
+                "Passing a list of queries to TableQuestionAnsweringPipeline is deprecated and will be removed in Transformers v5. Pass queries singly instead.",
+                FutureWarning,
+            )
+
+        # The default input parser does not actually use or modify any kwargs except table and query/question
         pipeline_inputs = self._args_parser(*args, **kwargs)
 
         results = super().__call__(pipeline_inputs, **kwargs)
@@ -348,16 +360,34 @@ class TableQuestionAnsweringPipeline(Pipeline):
             return results[0]
         return results
 
-    def _sanitize_parameters(self, sequential=None, padding=None, truncation=None, **kwargs):
+    def _sanitize_parameters(self, sequential=None, padding=None, truncation=None, parameters=None, **kwargs):
         preprocess_params = {}
         if padding is not None:
+            warnings.warn(
+                "The `padding` argument is deprecated and will be removed in version 5 of Transformers. Please pass it as a key in the `parameters` dict instead.",
+                FutureWarning,
+            )
             preprocess_params["padding"] = padding
+        elif parameters is not None and "padding" in parameters:
+            preprocess_params["padding"] = parameters["padding"]
         if truncation is not None:
+            warnings.warn(
+                "The `truncation` argument is deprecated and will be removed in version 5 of Transformers. Please pass it as a key in the `parameters` dict instead.",
+                FutureWarning,
+            )
             preprocess_params["truncation"] = truncation
+        elif parameters is not None and "truncation" in parameters:
+            preprocess_params["truncation"] = parameters["truncation"]
 
         forward_params = {}
         if sequential is not None:
+            warnings.warn(
+                "The `sequential` argument is deprecated and will be removed in version 5 of Transformers. Please pass it as a key in the `parameters` dict instead.",
+                FutureWarning,
+            )
            forward_params["sequential"] = sequential
+        elif parameters is not None and "sequential" in parameters:
+            forward_params["sequential"] = parameters["sequential"]
         return preprocess_params, forward_params, {}
 
     def preprocess(self, pipeline_input, sequential=None, padding=True, truncation=None):
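A sketch of the new argument names, assuming a TAPAS checkpoint (illustrative):

from transformers import pipeline

table_qa = pipeline("table-question-answering", model="google/tapas-base-finetuned-wtq")

data = {"repository": ["transformers", "datasets"], "stars": ["36542", "4512"]}

# `question` replaces the deprecated `query` keyword; sequential/padding/
# truncation now travel inside the `parameters` dict.
result = table_qa(table=data, question="How many stars does transformers have?",
                  parameters={"truncation": "drop_rows_to_fit"})
print(result["answer"])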
@@ -89,7 +89,16 @@ class Text2TextGenerationPipeline(Pipeline):
         forward_params = generate_kwargs
 
         postprocess_params = {}
+        if return_text is not None:
+            warnings.warn(
+                "The `return_text` argument is deprecated and will be removed in version 5 of Transformers. ",
+                FutureWarning,
+            )
         if return_tensors is not None and return_type is None:
+            warnings.warn(
+                "The `return_tensors` argument is deprecated and will be removed in version 5 of Transformers. ",
+                FutureWarning,
+            )
             return_type = ReturnType.TENSORS if return_tensors else ReturnType.TEXT
         if return_type is not None:
             postprocess_params["return_type"] = return_type
@@ -135,17 +144,13 @@ class Text2TextGenerationPipeline(Pipeline):
             del inputs["token_type_ids"]
         return inputs
 
-    def __call__(self, *args, **kwargs):
+    def __call__(self, *inputs, **kwargs):
         r"""
         Generate the output text(s) using text(s) given as inputs.
 
         Args:
-            args (`str` or `List[str]`):
+            inputs (`str` or `List[str]`):
                 Input text for the encoder.
-            return_tensors (`bool`, *optional*, defaults to `False`):
-                Whether or not to include the tensors of predictions (as token indices) in the outputs.
-            return_text (`bool`, *optional*, defaults to `True`):
-                Whether or not to include the decoded texts in the outputs.
             clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
                 Whether or not to clean up the potential extra spaces in the text output.
             truncation (`TruncationStrategy`, *optional*, defaults to `TruncationStrategy.DO_NOT_TRUNCATE`):
@@ -164,10 +169,10 @@ class Text2TextGenerationPipeline(Pipeline):
             ids of the generated text.
         """
 
-        result = super().__call__(*args, **kwargs)
+        result = super().__call__(*inputs, **kwargs)
         if (
-            isinstance(args[0], list)
-            and all(isinstance(el, str) for el in args[0])
+            isinstance(inputs[0], list)
+            and all(isinstance(el, str) for el in inputs[0])
             and all(len(res) == 1 for res in result)
         ):
             return [res[0] for res in result]
@@ -252,14 +257,14 @@ class SummarizationPipeline(Text2TextGenerationPipeline):
         Summarize the text(s) given as inputs.
 
         Args:
-            documents (*str* or `List[str]`):
+            inputs (*str* or `List[str]`):
                 One or several articles (or one list of articles) to summarize.
             return_text (`bool`, *optional*, defaults to `True`):
                 Whether or not to include the decoded texts in the outputs
             return_tensors (`bool`, *optional*, defaults to `False`):
                 Whether or not to include the tensors of predictions (as token indices) in the outputs.
             clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
                 Whether or not to clean up the potential extra spaces in the text output.
             truncation (`TruncationStrategy`, *optional*, defaults to `TruncationStrategy.DO_NOT_TRUNCATE`):
                 The truncation strategy for the tokenization within the pipeline. `TruncationStrategy.DO_NOT_TRUNCATE`
                 (default) will never truncate, but it is sometimes desirable to truncate the input to fit the model's
                 max_length instead of throwing an error down the line.
             generate_kwargs:
                 Additional keyword arguments to pass along to the generate method of the model (see the generate method
                 corresponding to your framework [here](./text_generation)).
@@ -343,19 +348,19 @@ class TranslationPipeline(Text2TextGenerationPipeline):
         preprocess_params["tgt_lang"] = items[3]
         return preprocess_params, forward_params, postprocess_params
 
-    def __call__(self, *args, **kwargs):
+    def __call__(self, *inputs, **kwargs):
         r"""
         Translate the text(s) given as inputs.
 
         Args:
-            args (`str` or `List[str]`):
+            inputs (`str` or `List[str]`):
                 Texts to be translated.
             return_tensors (`bool`, *optional*, defaults to `False`):
                 Whether or not to include the tensors of predictions (as token indices) in the outputs.
             return_text (`bool`, *optional*, defaults to `True`):
                 Whether or not to include the decoded texts in the outputs.
             clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
                 Whether or not to clean up the potential extra spaces in the text output.
             truncation (`TruncationStrategy`, *optional*, defaults to `TruncationStrategy.DO_NOT_TRUNCATE`):
                 The truncation strategy for the tokenization within the pipeline. `TruncationStrategy.DO_NOT_TRUNCATE`
                 (default) will never truncate, but it is sometimes desirable to truncate the input to fit the model's
                 max_length instead of throwing an error down the line.
             src_lang (`str`, *optional*):
                 The language of the input. Might be required for multilingual models. Will not have any effect for
                 single pair translation models
@@ -373,4 +378,4 @@ class TranslationPipeline(Text2TextGenerationPipeline):
             - **translation_token_ids** (`torch.Tensor` or `tf.Tensor`, present when `return_tensors=True`) -- The
               token ids of the translation.
         """
-        return super().__call__(*args, **kwargs)
+        return super().__call__(*inputs, **kwargs)
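A sketch of the unchanged positional call, assuming a T5 checkpoint (illustrative); only the documented parameter name moved from `args` to `inputs`:

from transformers import pipeline

translator = pipeline("translation_en_to_fr", model="google-t5/t5-small")

out = translator("How old are you?")
print(out[0]["translation_text"])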
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import warnings
 from typing import List, Union
 
 from ..utils import is_torch_available
@@ -161,16 +162,13 @@ class TextToAudioPipeline(Pipeline):
 
         return output
 
-    def __call__(self, text_inputs: Union[str, List[str]], **forward_params):
+    def __call__(self, inputs: Union[str, List[str]] = None, **forward_params):
         """
         Generates speech/audio from the inputs. See the [`TextToAudioPipeline`] documentation for more information.
 
         Args:
-            text_inputs (`str` or `List[str]`):
+            inputs (`str` or `List[str]`):
                 The text(s) to generate.
             forward_params (`dict`, *optional*):
                 Parameters passed to the model generation/forward method. `forward_params` are always passed to the
                 underlying model.
             generate_kwargs (`dict`, *optional*):
                 The dictionary of ad-hoc parametrization of `generate_config` to be used for the generation call. For a
                 complete overview of generate, check the [following
@@ -183,7 +181,22 @@ class TextToAudioPipeline(Pipeline):
             - **audio** (`np.ndarray` of shape `(nb_channels, audio_length)`) -- The generated audio waveform.
             - **sampling_rate** (`int`) -- The sampling rate of the generated audio waveform.
         """
-        return super().__call__(text_inputs, **forward_params)
+        # After deprecation of this is completed, remove the default `None` value for `inputs`
+        if "text_inputs" in forward_params:
+            warnings.warn(
+                "The `text_inputs` argument has been renamed to `inputs`. In version 5 of Transformers, `text_inputs` will no longer be accepted",
+                FutureWarning,
+            )
+            inputs = forward_params.pop("text_inputs")
+        elif inputs is None:
+            raise ValueError("Cannot call the text-to-audio pipeline without an inputs argument!")
+        # After deprecation of this is completed, rename the input kwarg to `generate_kwargs`
+        if {key for key in forward_params if key != "generate_kwargs"}:
+            warnings.warn(
+                "Kwargs other than `generate_kwargs` are deprecated. In version 5 of Transformers, only the `generate_kwargs` kwarg will be accepted",
+                FutureWarning,
+            )
+        return super().__call__(inputs, **forward_params)
 
     def _sanitize_parameters(
         self,
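A sketch of the renamed text-to-audio call, assuming a Bark-style checkpoint (illustrative) and that it accepts sampling options through generate:

from transformers import pipeline

tts = pipeline("text-to-audio", model="suno/bark-small")

# `inputs` replaces `text_inputs`; generation options belong inside
# `generate_kwargs`, since other loose kwargs now raise a FutureWarning.
speech = tts("Hello, world!", generate_kwargs={"do_sample": True})
print(speech["sampling_rate"], speech["audio"].shape)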
@@ -7,11 +7,10 @@ import numpy as np
 from ..models.bert.tokenization_bert import BasicTokenizer
 from ..utils import (
     ExplicitEnum,
-    add_end_docstrings,
     is_tf_available,
     is_torch_available,
 )
-from .base import ArgumentHandler, ChunkPipeline, Dataset, build_pipeline_init_args
+from .base import ArgumentHandler, ChunkPipeline, Dataset
 
 
 if is_tf_available():
@@ -58,40 +57,6 @@ class AggregationStrategy(ExplicitEnum):
     MAX = "max"
 
 
-@add_end_docstrings(
-    build_pipeline_init_args(has_tokenizer=True),
-    r"""
-        ignore_labels (`List[str]`, defaults to `["O"]`):
-            A list of labels to ignore.
-        grouped_entities (`bool`, *optional*, defaults to `False`):
-            DEPRECATED, use `aggregation_strategy` instead. Whether or not to group the tokens corresponding to the
-            same entity together in the predictions or not.
-        stride (`int`, *optional*):
-            If stride is provided, the pipeline is applied on all the text. The text is split into chunks of size
-            model_max_length. Works only with fast tokenizers and `aggregation_strategy` different from `NONE`. The
-            value of this argument defines the number of overlapping tokens between chunks. In other words, the model
-            will shift forward by `tokenizer.model_max_length - stride` tokens each step.
-        aggregation_strategy (`str`, *optional*, defaults to `"none"`):
-            The strategy to fuse (or not) tokens based on the model prediction.
-
-            - "none" : Will simply not do any aggregation and simply return raw results from the model
-            - "simple" : Will attempt to group entities following the default schema. (A, B-TAG), (B, I-TAG), (C,
-              I-TAG), (D, B-TAG2) (E, B-TAG2) will end up being [{"word": ABC, "entity": "TAG"}, {"word": "D",
-              "entity": "TAG2"}, {"word": "E", "entity": "TAG2"}] Notice that two consecutive B tags will end up as
-              different entities. On word based languages, we might end up splitting words undesirably : Imagine
-              Microsoft being tagged as [{"word": "Micro", "entity": "ENTERPRISE"}, {"word": "soft", "entity":
-              "NAME"}]. Look for FIRST, MAX, AVERAGE for ways to mitigate that and disambiguate words (on languages
-              that support that meaning, which is basically tokens separated by a space). These mitigations will
-              only work on real words, "New york" might still be tagged with two different entities.
-            - "first" : (works only on word based models) Will use the `SIMPLE` strategy except that words, cannot
-              end up with different tags. Words will simply use the tag of the first token of the word when there
-              is ambiguity.
-            - "average" : (works only on word based models) Will use the `SIMPLE` strategy except that words,
-              cannot end up with different tags. scores will be averaged first across tokens, and then the maximum
-              label is applied.
-            - "max" : (works only on word based models) Will use the `SIMPLE` strategy except that words, cannot
-              end up with different tags. Word entity will simply be the token with the maximum score.""",
-)
 class TokenClassificationPipeline(ChunkPipeline):
     """
     Named Entity Recognition pipeline using any `ModelForTokenClassification`. See the [named entity recognition
@@ -222,6 +187,33 @@ class TokenClassificationPipeline(ChunkPipeline):
         Args:
             inputs (`str` or `List[str]`):
                 One or several texts (or one list of texts) for token classification.
+            ignore_labels (`List[str]`, defaults to `["O"]`):
+                A list of labels to ignore.
+            stride (`int`, *optional*):
+                If stride is provided, the pipeline is applied on all the text. The text is split into chunks of size
+                model_max_length. Works only with fast tokenizers and `aggregation_strategy` different from `NONE`. The
+                value of this argument defines the number of overlapping tokens between chunks. In other words, the model
+                will shift forward by `tokenizer.model_max_length - stride` tokens each step.
+            aggregation_strategy (`str`, *optional*, defaults to `"none"`):
+                The strategy to fuse (or not) tokens based on the model prediction.
+
+                - "none" : Will simply not do any aggregation and simply return raw results from the model
+                - "simple" : Will attempt to group entities following the default schema. (A, B-TAG), (B, I-TAG), (C,
+                  I-TAG), (D, B-TAG2) (E, B-TAG2) will end up being [{"word": ABC, "entity": "TAG"}, {"word": "D",
+                  "entity": "TAG2"}, {"word": "E", "entity": "TAG2"}] Notice that two consecutive B tags will end up as
+                  different entities. On word based languages, we might end up splitting words undesirably : Imagine
+                  Microsoft being tagged as [{"word": "Micro", "entity": "ENTERPRISE"}, {"word": "soft", "entity":
+                  "NAME"}]. Look for FIRST, MAX, AVERAGE for ways to mitigate that and disambiguate words (on languages
+                  that support that meaning, which is basically tokens separated by a space). These mitigations will
+                  only work on real words, "New york" might still be tagged with two different entities.
+                - "first" : (works only on word based models) Will use the `SIMPLE` strategy except that words, cannot
+                  end up with different tags. Words will simply use the tag of the first token of the word when there
+                  is ambiguity.
+                - "average" : (works only on word based models) Will use the `SIMPLE` strategy except that words,
+                  cannot end up with different tags. scores will be averaged first across tokens, and then the maximum
+                  label is applied.
+                - "max" : (works only on word based models) Will use the `SIMPLE` strategy except that words, cannot
+                  end up with different tags. Word entity will simply be the token with the maximum score.
 
         Return:
             A list or a list of list of `dict`: Each result comes as a list of dictionaries (one for each token in the
@@ -239,6 +231,7 @@ class TokenClassificationPipeline(ChunkPipeline):
               exists if the offsets are available within the tokenizer
             - **end** (`int`, *optional*) -- The index of the end of the corresponding entity in the sentence. Only
               exists if the offsets are available within the tokenizer
+
         """
 
         _inputs, offset_mapping = self._args_parser(inputs, **kwargs)
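A sketch of the per-call aggregation option now documented on `__call__`, assuming a CoNLL-style NER checkpoint (illustrative):

from transformers import pipeline

ner = pipeline("token-classification", model="dslim/bert-base-NER")

# "simple" fuses B-/I- tokens into whole entities, as described above.
entities = ner("My name is Wolfgang and I live in Berlin", aggregation_strategy="simple")
for ent in entities:
    print(ent["word"], ent["entity_group"], round(float(ent["score"]), 3))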
@@ -1,3 +1,4 @@
+import warnings
 from io import BytesIO
 from typing import List, Union
 
@@ -42,7 +43,7 @@ class VideoClassificationPipeline(Pipeline):
         requires_backends(self, "av")
         self.check_model_type(MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING_NAMES)
 
-    def _sanitize_parameters(self, top_k=None, num_frames=None, frame_sampling_rate=None):
+    def _sanitize_parameters(self, top_k=None, num_frames=None, frame_sampling_rate=None, function_to_apply=None):
         preprocess_params = {}
         if frame_sampling_rate is not None:
             preprocess_params["frame_sampling_rate"] = frame_sampling_rate
@@ -52,14 +53,23 @@ class VideoClassificationPipeline(Pipeline):
         postprocess_params = {}
         if top_k is not None:
             postprocess_params["top_k"] = top_k
+        if function_to_apply is not None:
+            if function_to_apply not in ["softmax", "sigmoid", "none"]:
+                raise ValueError(
+                    f"Invalid value for `function_to_apply`: {function_to_apply}. "
+                    "Valid options are ['softmax', 'sigmoid', 'none']"
+                )
+            postprocess_params["function_to_apply"] = function_to_apply
+        else:
+            postprocess_params["function_to_apply"] = "softmax"
         return preprocess_params, {}, postprocess_params
 
-    def __call__(self, videos: Union[str, List[str]], **kwargs):
+    def __call__(self, inputs: Union[str, List[str]] = None, **kwargs):
         """
         Assign labels to the video(s) passed as inputs.
 
         Args:
-            videos (`str`, `List[str]`):
+            inputs (`str`, `List[str]`):
                 The pipeline handles three types of videos:
 
                 - A string containing a http link pointing to a video
@@ -76,6 +86,11 @@ class VideoClassificationPipeline(Pipeline):
             frame_sampling_rate (`int`, *optional*, defaults to 1):
                 The sampling rate used to select frames from the video. If not provided, will default to 1, i.e. every
                 frame will be used.
+            function_to_apply(`str`, *optional*, defaults to "softmax"):
+                The function to apply to the model output. By default, the pipeline will apply the softmax function to
+                the output of the model. Valid options: ["softmax", "sigmoid", "none"]. Note that passing Python's
+                built-in `None` will default to "softmax", so you need to pass the string "none" to disable any
+                post-processing.
 
         Return:
             A dictionary or a list of dictionaries containing result. If the input is a single video, will return a
@@ -87,7 +102,16 @@ class VideoClassificationPipeline(Pipeline):
             - **label** (`str`) -- The label identified by the model.
             - **score** (`int`) -- The score attributed by the model for that label.
         """
-        return super().__call__(videos, **kwargs)
+        # After deprecation of this is completed, remove the default `None` value for `videos`
+        if "videos" in kwargs:
+            warnings.warn(
+                "The `videos` argument has been renamed to `inputs`. In version 5 of Transformers, `videos` will no longer be accepted",
+                FutureWarning,
+            )
+            inputs = kwargs.pop("videos")
+        if inputs is None:
+            raise ValueError("Cannot call the video-classification pipeline without an inputs argument!")
+        return super().__call__(inputs, **kwargs)
 
     def preprocess(self, video, num_frames=None, frame_sampling_rate=1):
         if num_frames is None:
@@ -114,12 +138,17 @@ class VideoClassificationPipeline(Pipeline):
         model_outputs = self.model(**model_inputs)
         return model_outputs
 
-    def postprocess(self, model_outputs, top_k=5):
+    def postprocess(self, model_outputs, top_k=5, function_to_apply="softmax"):
         if top_k > self.model.config.num_labels:
             top_k = self.model.config.num_labels
 
         if self.framework == "pt":
-            probs = model_outputs.logits.softmax(-1)[0]
+            if function_to_apply == "softmax":
+                probs = model_outputs.logits[0].softmax(-1)
+            elif function_to_apply == "sigmoid":
+                probs = model_outputs.logits[0].sigmoid()
+            else:
+                probs = model_outputs.logits[0]
             scores, ids = probs.topk(top_k)
         else:
             raise ValueError(f"Unsupported framework: {self.framework}")
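A sketch of the video rename plus raw logits, assuming a VideoMAE checkpoint (illustrative) and a local clip named archery.mp4:

from transformers import pipeline

video_clf = pipeline("video-classification", model="MCG-NJU/videomae-base-finetuned-kinetics")

# `inputs` replaces `videos`; the string "none" (not Python's None) disables
# post-processing and returns raw logits as scores.
preds = video_clf(inputs="archery.mp4", top_k=3, function_to_apply="none")
print(preds)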
@@ -1,3 +1,4 @@
+import warnings
 from typing import List, Union
 
 from ..utils import add_end_docstrings, is_torch_available, is_vision_available, logging
@@ -63,6 +64,9 @@ class VisualQuestionAnsweringPipeline(Pipeline):
         if truncation is not None:
             preprocess_params["truncation"] = truncation
         if timeout is not None:
+            warnings.warn(
+                "The `timeout` argument is deprecated and will be removed in version 5 of Transformers", FutureWarning
+            )
             preprocess_params["timeout"] = timeout
         if top_k is not None:
             postprocess_params["top_k"] = top_k
@@ -110,9 +114,6 @@ class VisualQuestionAnsweringPipeline(Pipeline):
             top_k (`int`, *optional*, defaults to 5):
                 The number of top labels that will be returned by the pipeline. If the provided number is higher than
                 the number of labels available in the model configuration, it will default to the number of labels.
-            timeout (`float`, *optional*, defaults to None):
-                The maximum time in seconds to wait for fetching images from the web. If None, no timeout is set and
-                the call may block forever.
         Return:
             A dictionary or a list of dictionaries containing the result. The dictionaries contain the following keys:
 
@@ -1,4 +1,5 @@
 import inspect
+import warnings
 from typing import List, Union
 
 import numpy as np
@@ -162,7 +163,7 @@ class ZeroShotClassificationPipeline(ChunkPipeline):
 
     def __call__(
         self,
-        sequences: Union[str, List[str]],
+        text: Union[str, List[str]] = None,
         *args,
         **kwargs,
     ):
@@ -171,7 +172,7 @@ class ZeroShotClassificationPipeline(ChunkPipeline):
         information.
 
         Args:
-            sequences (`str` or `List[str]`):
+            text (`str` or `List[str]`):
                 The sequence(s) to classify, will be truncated if the model input is too large.
             candidate_labels (`str` or `List[str]`):
                 The set of possible class labels to classify each sequence into. Can be a single label, a string of
@@ -203,7 +204,17 @@ class ZeroShotClassificationPipeline(ChunkPipeline):
         else:
             raise ValueError(f"Unable to understand extra arguments {args}")
 
-        return super().__call__(sequences, **kwargs)
+        # After deprecation of this is completed, remove the default `None` value for `text`
+        if "sequences" in kwargs:
+            warnings.warn(
+                "The `sequences` argument has been renamed to `text`. In version 5 of Transformers, `sequences` will no longer be accepted",
+                FutureWarning,
+            )
+            text = kwargs.pop("sequences")
+        if text is None:
+            raise ValueError("Cannot call the zero_shot_classification pipeline without a text argument!")
+
+        return super().__call__(text, **kwargs)
 
     def preprocess(self, inputs, candidate_labels=None, hypothesis_template="This example is {}."):
         sequence_pairs, sequences = self._args_parser(inputs, candidate_labels, hypothesis_template)
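A sketch of the renamed keyword, assuming an NLI checkpoint (illustrative):

from transformers import pipeline

zsc = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")

# `text` replaces the deprecated `sequences` keyword.
result = zsc(text="I love this new phone!", candidate_labels=["positive", "negative"])
print(result["labels"][0], result["scores"][0])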
@ -1,3 +1,4 @@
import warnings
from collections import UserDict
from typing import List, Union

@ -73,12 +74,12 @@ class ZeroShotImageClassificationPipeline(Pipeline):
            else MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES
        )

    def __call__(self, images: Union[str, List[str], "Image", List["Image"]], **kwargs):
    def __call__(self, image: Union[str, List[str], "Image", List["Image"]] = None, **kwargs):
        """
        Assign labels to the image(s) passed as inputs.

        Args:
            images (`str`, `List[str]`, `PIL.Image` or `List[PIL.Image]`):
            image (`str`, `List[str]`, `PIL.Image` or `List[PIL.Image]`):
                The pipeline handles three types of images:

                - A string containing a http link pointing to an image
@ -93,13 +94,6 @@ class ZeroShotImageClassificationPipeline(Pipeline):
                replacing the placeholder with the candidate_labels. Pass "{}" if *candidate_labels* are
                already formatted.

            timeout (`float`, *optional*, defaults to None):
                The maximum time in seconds to wait for fetching images from the web. If None, no timeout is set and
                the call may block forever.

            tokenizer_kwargs (`dict`, *optional*):
                Additional dictionary of keyword arguments passed along to the tokenizer.

        Return:
            A list of dictionaries containing one entry per proposed label. Each dictionary contains the
            following keys:
@ -107,17 +101,33 @@ class ZeroShotImageClassificationPipeline(Pipeline):
            - **score** (`float`) -- The score attributed by the model to that label. It is a value between
              0 and 1, computed as the `softmax` of `logits_per_image`.
        """
        return super().__call__(images, **kwargs)
        # After deprecation of this is completed, remove the default `None` value for `image`
        if "images" in kwargs:
            warnings.warn(
                "The `images` argument has been renamed to `image`. In version 5 of Transformers, `images` will no longer be accepted",
                FutureWarning,
            )
            image = kwargs.pop("images")
        if image is None:
            raise ValueError("Cannot call the zero-shot-image-classification pipeline without an image argument!")
        return super().__call__(image, **kwargs)

    def _sanitize_parameters(self, tokenizer_kwargs=None, **kwargs):
        preprocess_params = {}
        if "candidate_labels" in kwargs:
            preprocess_params["candidate_labels"] = kwargs["candidate_labels"]
        if "timeout" in kwargs:
            warnings.warn(
                "The `timeout` argument is deprecated and will be removed in version 5 of Transformers", FutureWarning
            )
            preprocess_params["timeout"] = kwargs["timeout"]
        if "hypothesis_template" in kwargs:
            preprocess_params["hypothesis_template"] = kwargs["hypothesis_template"]
        if tokenizer_kwargs is not None:
            warnings.warn(
                "The `tokenizer_kwargs` argument is deprecated and will be removed in version 5 of Transformers",
                FutureWarning,
            )
            preprocess_params["tokenizer_kwargs"] = tokenizer_kwargs

        return preprocess_params, {}, {}
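A corresponding sketch for the image pipeline (checkpoint and image URL are assumed examples): the input is now `image`, with `images=` kept as a deprecated alias, and `timeout`/`tokenizer_kwargs` likewise deprecated in `_sanitize_parameters`.

    from transformers import pipeline

    # Assumed example checkpoint and image URL, for illustration only.
    classifier = pipeline("zero-shot-image-classification", model="openai/clip-vit-base-patch32")
    url = "http://images.cocodataset.org/val2017/000000039769.jpg"

    # New-style call with `image`; passing `images=url` would still work but emit a FutureWarning.
    classifier(url, candidate_labels=["a photo of a cat", "a photo of a dog"])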
@ -66,7 +66,7 @@ class ZeroShotObjectDetectionPipeline(ChunkPipeline):
        self,
        image: Union[str, "Image.Image", List[Dict[str, Any]]],
        candidate_labels: Union[str, List[str]] = None,
        **kwargs,
        **parameters,
    ):
        """
        Detect objects (bounding boxes & classes) in the image(s) passed as inputs.
@ -104,16 +104,8 @@ class ZeroShotObjectDetectionPipeline(ChunkPipeline):
            candidate_labels (`str` or `List[str]` or `List[List[str]]`):
                What the model should recognize in the image.

            threshold (`float`, *optional*, defaults to 0.1):
                The probability necessary to make a prediction.

            top_k (`int`, *optional*, defaults to None):
                The number of top predictions that will be returned by the pipeline. If the provided number is `None`
                or higher than the number of predictions available, it will default to the number of predictions.

            timeout (`float`, *optional*, defaults to None):
                The maximum time in seconds to wait for fetching images from the web. If None, no timeout is set and
                the call may block forever.
            parameters (`Dict[str, Any]`, *optional*): Additional inference parameters. Valid parameters include
                `threshold`, `top_k`, and `timeout`.


        Return:
@ -125,14 +117,14 @@ class ZeroShotObjectDetectionPipeline(ChunkPipeline):
            - **box** (`Dict[str,int]`) -- Bounding box of the detected object in image's original size. It is a
              dictionary with `x_min`, `x_max`, `y_min`, `y_max` keys.
        """
        if "text_queries" in kwargs:
            candidate_labels = kwargs.pop("text_queries")
        if "text_queries" in parameters:
            candidate_labels = parameters.pop("text_queries")

        if isinstance(image, (str, Image.Image)):
            inputs = {"image": image, "candidate_labels": candidate_labels}
        else:
            inputs = image
        results = super().__call__(inputs, **kwargs)
        results = super().__call__(inputs, **parameters)
        return results

    def _sanitize_parameters(self, **kwargs):
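The object-detection change is a docstring/signature consolidation rather than a rename: per-call knobs now travel through `**parameters`. A hedged sketch (checkpoint assumed, not from this diff):

    from transformers import pipeline

    # Assumed example checkpoint; OWL-ViT is a common zero-shot detector.
    detector = pipeline("zero-shot-object-detection", model="google/owlvit-base-patch32")

    # `threshold` and `top_k` are forwarded through **parameters and routed by _sanitize_parameters.
    detector(
        "http://images.cocodataset.org/val2017/000000039769.jpg",
        candidate_labels=["cat", "remote"],
        threshold=0.1,
        top_k=2,
    )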
@ -16,14 +16,22 @@

import datetime
import gc
import inspect
import math
import re
import unittest
from dataclasses import fields
from inspect import isclass
from textwrap import dedent
from typing import get_args

import pytest
from huggingface_hub.inference._generated import types as inference_specs

from transformers import GPT2Config, is_torch_available
from transformers.testing_utils import (
    backend_empty_cache,
    is_pipeline_test,
    require_flash_attn,
    require_torch,
    require_torch_gpu,
@ -939,3 +947,174 @@ class GPT2ModelLanguageGenerationTest(unittest.TestCase):

        self.assertListEqual(output_native, output_fa_2)
        self.assertListEqual(output_native, expected_output)


@is_pipeline_test
class HuggingfaceJSEquivalencetest(unittest.TestCase):
    def test_huggingface_js_equivalence(self):
        from transformers import (
            AudioClassificationPipeline,
            AutomaticSpeechRecognitionPipeline,
            DepthEstimationPipeline,
            DocumentQuestionAnsweringPipeline,
            FeatureExtractionPipeline,
            FillMaskPipeline,
            ImageClassificationPipeline,
            ImageFeatureExtractionPipeline,
            ImageSegmentationPipeline,
            ImageToImagePipeline,
            ImageToTextPipeline,
            MaskGenerationPipeline,
            ObjectDetectionPipeline,
            QuestionAnsweringPipeline,
            SummarizationPipeline,
            TableQuestionAnsweringPipeline,
            Text2TextGenerationPipeline,
            TextClassificationPipeline,
            TextGenerationPipeline,
            TextToAudioPipeline,
            TokenClassificationPipeline,
            TranslationPipeline,
            VideoClassificationPipeline,
            VisualQuestionAnsweringPipeline,
            ZeroShotAudioClassificationPipeline,
            ZeroShotClassificationPipeline,
            ZeroShotImageClassificationPipeline,
            ZeroShotObjectDetectionPipeline,
        )

        # Putting this here for now because this file is already tested by CircleCI, will move later

        PIPELINES_TO_TEST = {
            "audio-classification": (AudioClassificationPipeline, inference_specs.AudioClassificationInput),
            "automatic-speech-recognition": (
                AutomaticSpeechRecognitionPipeline,
                inference_specs.AutomaticSpeechRecognitionInput,
            ),
            "document-question-answering": (
                DocumentQuestionAnsweringPipeline,
                inference_specs.DocumentQuestionAnsweringInput,
            ),
            "feature-extraction": (FeatureExtractionPipeline, inference_specs.FeatureExtractionInput),
            "image-classification": (ImageClassificationPipeline, inference_specs.ImageClassificationInput),
            "text-to-audio": (TextToAudioPipeline, inference_specs.TextToAudioInput),
            "text-classification": (TextClassificationPipeline, inference_specs.TextClassificationInput),
            "token-classification": (TokenClassificationPipeline, inference_specs.TokenClassificationInput),
            "question-answering": (QuestionAnsweringPipeline, inference_specs.QuestionAnsweringInput),
            "table-question-answering": (TableQuestionAnsweringPipeline, inference_specs.TableQuestionAnsweringInput),
            "visual-question-answering": (
                VisualQuestionAnsweringPipeline,
                inference_specs.VisualQuestionAnsweringInput,
            ),
            "fill-mask": (FillMaskPipeline, inference_specs.FillMaskInput),
            "summarization": (SummarizationPipeline, inference_specs.SummarizationInput),
            "translation": (TranslationPipeline, inference_specs.TranslationInput),
            "text2text-generation": (Text2TextGenerationPipeline, inference_specs.Text2TextGenerationInput),
            "text-generation": (TextGenerationPipeline, inference_specs.TextGenerationInput),
            "image-segmentation": (ImageSegmentationPipeline, inference_specs.ImageSegmentationInput),
            "image-to-text": (ImageToTextPipeline, inference_specs.ImageToTextInput),
            "object-detection": (ObjectDetectionPipeline, inference_specs.ObjectDetectionInput),
            "depth-estimation": (DepthEstimationPipeline, inference_specs.DepthEstimationInput),
            "zero-shot-object-detection": (
                ZeroShotObjectDetectionPipeline,
                inference_specs.ZeroShotObjectDetectionInput,
            ),
            "zero-shot-classification": (ZeroShotClassificationPipeline, inference_specs.ZeroShotClassificationInput),
            "zero-shot-image-classification": (
                ZeroShotImageClassificationPipeline,
                inference_specs.ZeroShotImageClassificationInput,
            ),
            "video-classification": (VideoClassificationPipeline, inference_specs.VideoClassificationInput),
        }
        PIPELINES_WITHOUT_SPEC = {  # noqa: F841
            "image-feature-extraction": ImageFeatureExtractionPipeline,
            "zero-shot-audio-classification": ZeroShotAudioClassificationPipeline,
            "mask-generation": MaskGenerationPipeline,
            "image-to-image": ImageToImagePipeline,  # The huggingface_hub version of this looks like diffusers
        }

        mismatches = []
        for task, (pipeline_cls, js_spec) in PIPELINES_TO_TEST.items():
            docstring = inspect.getdoc(pipeline_cls.__call__).strip()
            docstring_args = set(self._parse_google_format_docstring_by_indentation(docstring))
            js_args = set(self.get_arg_names_from_hub_spec(js_spec))

            # Special casing: We allow the name of this arg to differ
            js_generate_args = [js_arg for js_arg in js_args if js_arg.startswith("generate")]
            docstring_generate_args = [
                docstring_arg for docstring_arg in docstring_args if docstring_arg.startswith("generate")
            ]
            if (
                len(js_generate_args) == 1
                and len(docstring_generate_args) == 1
                and js_generate_args != docstring_generate_args
            ):
                js_args.remove(js_generate_args[0])
                docstring_args.remove(docstring_generate_args[0])

            if js_args != docstring_args:
                mismatches.append((task, js_args, docstring_args))

        if mismatches:
            error = ["The following tasks have divergent input specs:", ""]
            for mismatch in mismatches:
                task = mismatch[0]
                matching_args = mismatch[1] & mismatch[2]
                huggingface_js_only = mismatch[1] - mismatch[2]
                transformers_only = mismatch[2] - mismatch[1]
                error.append(f"Task: {task}")
                if matching_args:
                    error.append(f"Matching args: {matching_args}")
                if huggingface_js_only:
                    error.append(f"Huggingface.js only: {huggingface_js_only}")
                if transformers_only:
                    error.append(f"Transformers only: {transformers_only}")
                error.append("")
            raise ValueError("\n".join(error))

    def get_arg_names_from_hub_spec(self, hub_spec, first_level=True):
        arg_names = []
        for field in fields(hub_spec):
            # First, recurse into nested fields
            if first_level and isclass(field.type) and issubclass(field.type, inference_specs.BaseInferenceType):
                arg_names.extend(self.get_arg_names_from_hub_spec(field.type, first_level=False))
                continue
            # Next, catch nested fields that are part of a Union[], which is usually caused by Optional[]
            for param_type in get_args(field.type):
                if first_level and isclass(param_type) and issubclass(param_type, inference_specs.BaseInferenceType):
                    arg_names.extend(
                        self.get_arg_names_from_hub_spec(param_type, first_level=False)
                    )  # Recurse into nested fields
                    break
            else:
                # Finally, this line triggers if it's not a nested field
                arg_names.append(field.name)
        return arg_names

    @staticmethod
    def _parse_google_format_docstring_by_indentation(docstring):
        docstring = dedent(docstring)
        lines_by_indent = [
            (len(line) - len(line.lstrip()), line.strip()) for line in docstring.split("\n") if line.strip()
        ]
        args_lineno = None
        args_indent = None
        args_end = None
        for lineno, (indent, line) in enumerate(lines_by_indent):
            if line == "Args:":
                args_lineno = lineno
                args_indent = indent
                continue
            elif args_lineno is not None and indent == args_indent:
                args_end = lineno
                break
        if args_lineno is None:
            raise ValueError("No args block to parse!")
        elif args_end is None:
            args_block = lines_by_indent[args_lineno + 1 :]
        else:
            args_block = lines_by_indent[args_lineno + 1 : args_end]
        outer_indent_level = min(line[0] for line in args_block)
        outer_lines = [line for line in args_block if line[0] == outer_indent_level]
        arg_names = [re.match(r"(\w+)\W", line[1]).group(1) for line in outer_lines]
        return arg_names
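To make the docstring-parsing step of the test concrete, here is a standalone toy run of the same indentation-based algorithm used by `_parse_google_format_docstring_by_indentation` (the docstring below is invented for illustration):

    import re
    from textwrap import dedent

    docstring = dedent(
        """\
        Classify the input.

        Args:
            text (`str`):
                The input to classify.
            top_k (`int`, *optional*):
                How many labels to return.

        Return:
            A list of dicts.
        """
    )
    # Index each non-blank line by its indentation depth.
    lines_by_indent = [
        (len(line) - len(line.lstrip()), line.strip()) for line in docstring.split("\n") if line.strip()
    ]
    # The "Args:" block ends at the next line sharing its indent (here, "Return:").
    args_lineno = next(i for i, (_, line) in enumerate(lines_by_indent) if line == "Args:")
    args_indent = lines_by_indent[args_lineno][0]
    args_end = next(i for i, (ind, _) in enumerate(lines_by_indent) if i > args_lineno and ind == args_indent)
    args_block = lines_by_indent[args_lineno + 1 : args_end]
    # Argument names sit at the shallowest indent inside the block; descriptions are nested deeper.
    outer_indent = min(ind for ind, _ in args_block)
    outer_lines = [line for ind, line in args_block if ind == outer_indent]
    print([re.match(r"(\w+)\W", line).group(1) for line in outer_lines])  # -> ['text', 'top_k']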
@ -97,7 +97,7 @@ class ZeroShotClassificationPipelineTests(unittest.TestCase):
        with self.assertRaises(ValueError):
            classifier("", candidate_labels="politics")

        with self.assertRaises(TypeError):
        with self.assertRaises(ValueError):
            classifier(None, candidate_labels="politics")

        with self.assertRaises(ValueError):
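This assertion flip follows directly from the new `text=None` default above: a `None` input is now caught inside `__call__` itself and raised as a ValueError, instead of surfacing later as a TypeError from deeper in the stack. A sketch (same assumed checkpoint as earlier):

    from transformers import pipeline

    classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")  # assumed checkpoint
    try:
        classifier(None, candidate_labels="politics")
    except ValueError:
        pass  # previously this path raised TypeError from further down the stack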