Compare commits

...

2 Commits

Author SHA1 Message Date
694bd63099 make fixup 2024-11-18 17:07:58 +00:00
abe58f0a96 Sync TableQuestionAnswering 2024-11-18 17:07:30 +00:00
2 changed files with 11 additions and 0 deletions

View File

@ -14,6 +14,8 @@
import unittest
from huggingface_hub import TableQuestionAnsweringOutputElement
from transformers import (
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING,
AutoModelForTableQuestionAnswering,
@ -24,6 +26,7 @@ from transformers import (
pipeline,
)
from transformers.testing_utils import (
compare_pipeline_output_to_hub_spec,
is_pipeline_test,
require_pandas,
require_tensorflow_probability,
@ -66,6 +69,7 @@ class TQAPipelineTests(unittest.TestCase):
},
query="how many movies has george clooney played in?",
)
compare_pipeline_output_to_hub_spec(outputs, TableQuestionAnsweringOutputElement)
self.assertEqual(
outputs,
{"answer": "AVERAGE > ", "coordinates": [], "cells": [], "aggregator": "AVERAGE"},
@ -87,6 +91,8 @@ class TQAPipelineTests(unittest.TestCase):
{"answer": "AVERAGE > ", "coordinates": [], "cells": [], "aggregator": "AVERAGE"},
],
)
for output in outputs:
compare_pipeline_output_to_hub_spec(output, TableQuestionAnsweringOutputElement)
outputs = table_querier(
table={
"Repository": ["Transformers", "Datasets", "Tokenizers"],
@ -113,6 +119,8 @@ class TQAPipelineTests(unittest.TestCase):
{"answer": "AVERAGE > ", "coordinates": [], "cells": [], "aggregator": "AVERAGE"},
],
)
for output in outputs:
compare_pipeline_output_to_hub_spec(output, TableQuestionAnsweringOutputElement)
with self.assertRaises(ValueError):
table_querier(query="What does it do with empty context ?", table=None)

View File

@ -34,6 +34,7 @@ from huggingface_hub import (
ImageToTextInput,
ObjectDetectionInput,
QuestionAnsweringInput,
TableQuestionAnsweringInput,
VideoClassificationInput,
ZeroShotImageClassificationInput,
)
@ -48,6 +49,7 @@ from transformers.pipelines import (
ImageToTextPipeline,
ObjectDetectionPipeline,
QuestionAnsweringPipeline,
TableQuestionAnsweringPipeline,
VideoClassificationPipeline,
ZeroShotImageClassificationPipeline,
)
@ -136,6 +138,7 @@ task_to_pipeline_and_spec_mapping = {
"image-to-text": (ImageToTextPipeline, ImageToTextInput),
"object-detection": (ObjectDetectionPipeline, ObjectDetectionInput),
"question-answering": (QuestionAnsweringPipeline, QuestionAnsweringInput),
"table-question-answering": (TableQuestionAnsweringPipeline, TableQuestionAnsweringInput),
"video-classification": (VideoClassificationPipeline, VideoClassificationInput),
"zero-shot-image-classification": (ZeroShotImageClassificationPipeline, ZeroShotImageClassificationInput),
}