# openmind/tests/unit/test_pipeline.py
# Copyright (c) 2024 Huawei Technologies Co., Ltd.
# openMind is licensed under Mulan PSL v2.
# You can use this software according to the terms and conditions of the Mulan PSL v2.
# You may obtain a copy of Mulan PSL v2 at:
# http://license.coscl.org.cn/MulanPSL2
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
# See the Mulan PSL v2 for more details.

import sys
import unittest
from io import StringIO

from openmind import pipeline
from openmind.archived.pipelines.builder import _parse_native_json
from openmind.archived.pipelines.interface import generate_pipeline_report


class FakeHFPipeline:
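    """Minimal stand-in for a pipeline object; used to pass a non-string model without a task."""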
    __name__ = "FakeHFPipeline"


class TestPipeline(unittest.TestCase):
    def test_pipeline_task_and_model_are_none(self):
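        """pipeline() raises RuntimeError when neither task nor model is provided."""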
        task = None
        model = None
        with self.assertRaises(RuntimeError):
            pipeline(task=task, model=model)

    def test_pipeline_model_is_none_and_tokenizer_is_not_none(self):
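        """pipeline() raises RuntimeError when a tokenizer is given without a model."""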
        model = None
        tokenizer = "tokenizer"
        with self.assertRaises(RuntimeError):
            pipeline(model=model, tokenizer=tokenizer)

    def test_pipeline_model_is_none_and_feature_extractor_is_not_none(self):
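        """pipeline() raises RuntimeError when a feature extractor is given without a model."""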
        model = None
        feature_extractor = "feature_extractor"
        with self.assertRaises(RuntimeError):
            pipeline(model=model, feature_extractor=feature_extractor)

    def test_pipeline_model_is_none_and_image_processor_is_not_none(self):
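        """pipeline() raises RuntimeError when an image processor is given without a model."""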
        model = None
        image_processor = "image_processor"
        with self.assertRaises(RuntimeError):
            pipeline(model=model, image_processor=image_processor)

    def test_pipeline_task_is_none_and_model_is_object(self):
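        """pipeline() raises RuntimeError when a model object is given without a task."""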
        model = FakeHFPipeline()
        with self.assertRaises(RuntimeError):
            pipeline(model=model)

    def test_parse_native_json(self):
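        """_parse_native_json resolves a known task to its pipeline config and raises KeyError for an unknown task."""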
        pipeline_config, framework, backend = _parse_native_json(
            task="visual-question-answering", framework="pt", backend="transformers"
        )
        self.assertEqual(
            pipeline_config,
            {
                "pipeline_class": "common.hf.VisualQuestionAnsweringPipeline",
                "supported_models": ["PyTorch-NPU/blip_vqa_base@4450392"],
            },
        )
        self.assertEqual(framework, "pt")
        self.assertEqual(backend, "transformers")
        with self.assertRaises(KeyError):
            _parse_native_json(task="test", framework="pt", backend="transformers")

    def test_pipeline_report(self):
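        """Capture stdout to verify generate_pipeline_report output with and without print_report."""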
        original_stdout = sys.stdout
        sys.stdout = StringIO()
        # test generate_pipeline_report without printing the report
        generate_pipeline_report(print_report=False)
        output = sys.stdout.getvalue()
        self.assertIn("", output)
        # test generate_pipeline_report with the report printed
        generate_pipeline_report()
        output = sys.stdout.getvalue()
        self.assertIn("text-classification", output)
        sys.stdout = original_stdout