# Copyright 2023 HuggingFace Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import inspect import io import json import os import pathlib import subprocess import tempfile import time import unittest import warnings from copy import deepcopy from datetime import datetime import httpx import numpy as np import pytest from packaging import version from transformers import AutoImageProcessor, BatchFeature from transformers.image_utils import AnnotationFormat, AnnotionFormat from transformers.models.auto.image_processing_auto import IMAGE_PROCESSOR_MAPPING_NAMES from transformers.testing_utils import ( check_json_file_has_correct_format, require_torch, require_torch_accelerator, require_vision, slow, torch_device, ) from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available if is_torchvision_available(): from transformers.image_processing_utils_fast import BaseImageProcessorFast if is_torch_available(): import torch if is_vision_available(): from PIL import Image def prepare_image_inputs( batch_size, min_resolution, max_resolution, num_channels, size_divisor=None, equal_resolution=False, numpify=False, torchify=False, ): """This function prepares a list of PIL images, or a list of numpy arrays if one specifies numpify=True, or a list of PyTorch tensors if one specifies torchify=True. One can specify whether the images are of the same resolution or not. """ assert not (numpify and torchify), "You cannot specify both numpy and PyTorch tensors at the same time" image_inputs = [] for i in range(batch_size): if equal_resolution: width = height = max_resolution else: # To avoid getting image width/height 0 if size_divisor is not None: # If `size_divisor` is defined, the image needs to have width/size >= `size_divisor` min_resolution = max(size_divisor, min_resolution) width, height = np.random.choice(np.arange(min_resolution, max_resolution), 2) image_inputs.append(np.random.randint(255, size=(num_channels, width, height), dtype=np.uint8)) if not numpify and not torchify: # PIL expects the channel dimension as last dimension image_inputs = [Image.fromarray(np.moveaxis(image, 0, -1)) for image in image_inputs] if torchify: image_inputs = [torch.from_numpy(image) for image in image_inputs] if numpify: # Numpy images are typically in channels last format image_inputs = [image.transpose(1, 2, 0) for image in image_inputs] return image_inputs def prepare_video(num_frames, num_channels, width=10, height=10, numpify=False, torchify=False): """This function prepares a video as a list of PIL images/NumPy arrays/PyTorch tensors.""" video = [] for i in range(num_frames): video.append(np.random.randint(255, size=(num_channels, width, height), dtype=np.uint8)) if not numpify and not torchify: # PIL expects the channel dimension as last dimension video = [Image.fromarray(np.moveaxis(frame, 0, -1)) for frame in video] if torchify: video = [torch.from_numpy(frame) for frame in video] return video def prepare_video_inputs( batch_size, num_frames, num_channels, min_resolution, max_resolution, equal_resolution=False, numpify=False, torchify=False, ): """This function prepares a batch of videos: a list of list of PIL images, or a list of list of numpy arrays if one specifies numpify=True, or a list of list of PyTorch tensors if one specifies torchify=True. One can specify whether the videos are of the same resolution or not. """ assert not (numpify and torchify), "You cannot specify both numpy and PyTorch tensors at the same time" video_inputs = [] for _ in range(batch_size): if equal_resolution: width = height = max_resolution else: width, height = np.random.choice(np.arange(min_resolution, max_resolution), 2) video = prepare_video( num_frames=num_frames, num_channels=num_channels, width=width, height=height, numpify=numpify, torchify=torchify, ) video_inputs.append(video) return video_inputs class ImageProcessingTestMixin: test_cast_dtype = None image_processing_class = None fast_image_processing_class = None image_processors_list = None test_slow_image_processor = True test_fast_image_processor = True def setUp(self): image_processor_list = [] if self.test_slow_image_processor and self.image_processing_class: image_processor_list.append(self.image_processing_class) if self.test_fast_image_processor and self.fast_image_processing_class: image_processor_list.append(self.fast_image_processing_class) self.image_processor_list = image_processor_list def _assert_slow_fast_tensors_equivalence(self, slow_tensor, fast_tensor, atol=1e-1, rtol=1e-3, mean_atol=5e-3): torch.testing.assert_close(slow_tensor, fast_tensor, atol=atol, rtol=rtol) self.assertLessEqual(torch.mean(torch.abs(slow_tensor - fast_tensor)).item(), mean_atol) @require_vision @require_torch def test_slow_fast_equivalence(self): if not self.test_slow_image_processor or not self.test_fast_image_processor: self.skipTest(reason="Skipping slow/fast equivalence test") if self.image_processing_class is None or self.fast_image_processing_class is None: self.skipTest(reason="Skipping slow/fast equivalence test as one of the image processors is not defined") dummy_image = Image.open( io.BytesIO( httpx.get("http://images.cocodataset.org/val2017/000000039769.jpg", follow_redirects=True).content ) ) image_processor_slow = self.image_processing_class(**self.image_processor_dict) image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict) encoding_slow = image_processor_slow(dummy_image, return_tensors="pt") encoding_fast = image_processor_fast(dummy_image, return_tensors="pt") self._assert_slow_fast_tensors_equivalence(encoding_slow.pixel_values, encoding_fast.pixel_values) @require_vision @require_torch def test_slow_fast_equivalence_batched(self): if not self.test_slow_image_processor or not self.test_fast_image_processor: self.skipTest(reason="Skipping slow/fast equivalence test") if self.image_processing_class is None or self.fast_image_processing_class is None: self.skipTest(reason="Skipping slow/fast equivalence test as one of the image processors is not defined") dummy_images = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True) image_processor_slow = self.image_processing_class(**self.image_processor_dict) image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict) encoding_slow = image_processor_slow(dummy_images, return_tensors="pt") encoding_fast = image_processor_fast(dummy_images, return_tensors="pt") self._assert_slow_fast_tensors_equivalence(encoding_slow.pixel_values, encoding_fast.pixel_values) @require_vision @require_torch @unittest.skip(reason="Many failing cases. This test needs a more deep investigation.") def test_fast_is_faster_than_slow(self): if not self.test_slow_image_processor or not self.test_fast_image_processor: self.skipTest(reason="Skipping speed test") if self.image_processing_class is None or self.fast_image_processing_class is None: self.skipTest(reason="Skipping speed test as one of the image processors is not defined") def measure_time(image_processor, image): # Warmup for _ in range(5): _ = image_processor(image, return_tensors="pt") all_times = [] for _ in range(10): start = time.time() _ = image_processor(image, return_tensors="pt") all_times.append(time.time() - start) # Take the average of the fastest 3 runs avg_time = sum(sorted(all_times[:3])) / 3.0 return avg_time dummy_images = [torch.randint(0, 255, (3, 224, 224), dtype=torch.uint8) for _ in range(4)] image_processor_slow = self.image_processing_class(**self.image_processor_dict) image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict) fast_time = measure_time(image_processor_fast, dummy_images) slow_time = measure_time(image_processor_slow, dummy_images) self.assertLessEqual(fast_time, slow_time) def test_is_fast(self): for image_processing_class in self.image_processor_list: image_processor = image_processing_class(**self.image_processor_dict) # Check is_fast is set correctly if is_torchvision_available() and issubclass(image_processing_class, BaseImageProcessorFast): self.assertTrue(image_processor.is_fast) else: self.assertFalse(image_processor.is_fast) def test_image_processor_to_json_string(self): for image_processing_class in self.image_processor_list: image_processor = image_processing_class(**self.image_processor_dict) obj = json.loads(image_processor.to_json_string()) for key, value in self.image_processor_dict.items(): self.assertEqual(obj[key], value) def test_image_processor_to_json_file(self): for image_processing_class in self.image_processor_list: image_processor_first = image_processing_class(**self.image_processor_dict) with tempfile.TemporaryDirectory() as tmpdirname: json_file_path = os.path.join(tmpdirname, "image_processor.json") image_processor_first.to_json_file(json_file_path) image_processor_second = image_processing_class.from_json_file(json_file_path) self.assertEqual(image_processor_second.to_dict(), image_processor_first.to_dict()) def test_image_processor_from_and_save_pretrained(self): for image_processing_class in self.image_processor_list: image_processor_first = image_processing_class(**self.image_processor_dict) with tempfile.TemporaryDirectory() as tmpdirname: saved_file = image_processor_first.save_pretrained(tmpdirname)[0] check_json_file_has_correct_format(saved_file) image_processor_second = image_processing_class.from_pretrained(tmpdirname) self.assertEqual(image_processor_second.to_dict(), image_processor_first.to_dict()) def test_image_processor_save_load_with_autoimageprocessor(self): for i, image_processing_class in enumerate(self.image_processor_list): image_processor_first = image_processing_class(**self.image_processor_dict) with tempfile.TemporaryDirectory() as tmpdirname: saved_file = image_processor_first.save_pretrained(tmpdirname)[0] check_json_file_has_correct_format(saved_file) use_fast = i == 1 or not self.test_slow_image_processor image_processor_second = AutoImageProcessor.from_pretrained(tmpdirname, use_fast=use_fast) self.assertEqual(image_processor_second.to_dict(), image_processor_first.to_dict()) def test_save_load_fast_slow(self): "Test that we can load a fast image processor from a slow one and vice-versa." if self.image_processing_class is None or self.fast_image_processing_class is None: self.skipTest("Skipping slow/fast save/load test as one of the image processors is not defined") image_processor_dict = self.image_processor_tester.prepare_image_processor_dict() image_processor_slow_0 = self.image_processing_class(**image_processor_dict) # Load fast image processor from slow one with tempfile.TemporaryDirectory() as tmpdirname: image_processor_slow_0.save_pretrained(tmpdirname) image_processor_fast_0 = self.fast_image_processing_class.from_pretrained(tmpdirname) image_processor_fast_1 = self.fast_image_processing_class(**image_processor_dict) # Load slow image processor from fast one with tempfile.TemporaryDirectory() as tmpdirname: image_processor_fast_1.save_pretrained(tmpdirname) image_processor_slow_1 = self.image_processing_class.from_pretrained(tmpdirname) dict_slow_0 = image_processor_slow_0.to_dict() dict_slow_1 = image_processor_slow_1.to_dict() difference = { key: dict_slow_0.get(key) if key in dict_slow_0 else dict_slow_1.get(key) for key in set(dict_slow_0) ^ set(dict_slow_1) } dict_slow_0 = {key: dict_slow_0[key] for key in set(dict_slow_0) & set(dict_slow_1)} dict_slow_1 = {key: dict_slow_1[key] for key in set(dict_slow_0) & set(dict_slow_1)} # check that all additional keys are None, except for `default_to_square` and `data_format` which are only set in fast processors self.assertTrue( all(value is None for key, value in difference.items() if key not in ["default_to_square", "data_format"]) ) # check that the remaining keys are the same self.assertEqual(dict_slow_0, dict_slow_1) dict_fast_0 = image_processor_fast_0.to_dict() dict_fast_1 = image_processor_fast_1.to_dict() difference = { key: dict_fast_0.get(key) if key in dict_fast_0 else dict_fast_1.get(key) for key in set(dict_fast_0) ^ set(dict_fast_1) } dict_fast_0 = {key: dict_fast_0[key] for key in set(dict_fast_0) & set(dict_fast_1)} dict_fast_1 = {key: dict_fast_1[key] for key in set(dict_fast_0) & set(dict_fast_1)} # check that all additional keys are None, except for `default_to_square` and `data_format` which are only set in fast processors self.assertTrue( all(value is None for key, value in difference.items() if key not in ["default_to_square", "data_format"]) ) # check that the remaining keys are the same self.assertEqual(dict_fast_0, dict_fast_1) def test_save_load_fast_slow_auto(self): "Test that we can load a fast image processor from a slow one and vice-versa using AutoImageProcessor." if self.image_processing_class is None or self.fast_image_processing_class is None: self.skipTest("Skipping slow/fast save/load test as one of the image processors is not defined") image_processor_dict = self.image_processor_tester.prepare_image_processor_dict() image_processor_slow_0 = self.image_processing_class(**image_processor_dict) # Load fast image processor from slow one with tempfile.TemporaryDirectory() as tmpdirname: image_processor_slow_0.save_pretrained(tmpdirname) image_processor_fast_0 = AutoImageProcessor.from_pretrained(tmpdirname, use_fast=True) image_processor_fast_1 = self.fast_image_processing_class(**image_processor_dict) # Load slow image processor from fast one with tempfile.TemporaryDirectory() as tmpdirname: image_processor_fast_1.save_pretrained(tmpdirname) image_processor_slow_1 = AutoImageProcessor.from_pretrained(tmpdirname, use_fast=False) dict_slow_0 = image_processor_slow_0.to_dict() dict_slow_1 = image_processor_slow_1.to_dict() difference = { key: dict_slow_0.get(key) if key in dict_slow_0 else dict_slow_1.get(key) for key in set(dict_slow_0) ^ set(dict_slow_1) } dict_slow_0 = {key: dict_slow_0[key] for key in set(dict_slow_0) & set(dict_slow_1)} dict_slow_1 = {key: dict_slow_1[key] for key in set(dict_slow_0) & set(dict_slow_1)} # check that all additional keys are None, except for `default_to_square` and `data_format` which are only set in fast processors self.assertTrue( all(value is None for key, value in difference.items() if key not in ["default_to_square", "data_format"]) ) # check that the remaining keys are the same self.assertEqual(dict_slow_0, dict_slow_1) dict_fast_0 = image_processor_fast_0.to_dict() dict_fast_1 = image_processor_fast_1.to_dict() difference = { key: dict_fast_0.get(key) if key in dict_fast_0 else dict_fast_1.get(key) for key in set(dict_fast_0) ^ set(dict_fast_1) } dict_fast_0 = {key: dict_fast_0[key] for key in set(dict_fast_0) & set(dict_fast_1)} dict_fast_1 = {key: dict_fast_1[key] for key in set(dict_fast_0) & set(dict_fast_1)} # check that all additional keys are None, except for `default_to_square` and `data_format` which are only set in fast processors self.assertTrue( all(value is None for key, value in difference.items() if key not in ["default_to_square", "data_format"]) ) # check that the remaining keys are the same self.assertEqual(dict_fast_0, dict_fast_1) def test_init_without_params(self): for image_processing_class in self.image_processor_list: image_processor = image_processing_class() self.assertIsNotNone(image_processor) @require_torch @require_vision def test_cast_dtype_device(self): for image_processing_class in self.image_processor_list: if self.test_cast_dtype is not None: # Initialize image_processor image_processor = image_processing_class(**self.image_processor_dict) # create random PyTorch tensors image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True) encoding = image_processor(image_inputs, return_tensors="pt") # for layoutLM compatibility self.assertEqual(encoding.pixel_values.device, torch.device("cpu")) self.assertEqual(encoding.pixel_values.dtype, torch.float32) encoding = image_processor(image_inputs, return_tensors="pt").to(torch.float16) self.assertEqual(encoding.pixel_values.device, torch.device("cpu")) self.assertEqual(encoding.pixel_values.dtype, torch.float16) encoding = image_processor(image_inputs, return_tensors="pt").to("cpu", torch.bfloat16) self.assertEqual(encoding.pixel_values.device, torch.device("cpu")) self.assertEqual(encoding.pixel_values.dtype, torch.bfloat16) with self.assertRaises(TypeError): _ = image_processor(image_inputs, return_tensors="pt").to(torch.bfloat16, "cpu") # Try with text + image feature encoding = image_processor(image_inputs, return_tensors="pt") encoding.update({"input_ids": torch.LongTensor([[1, 2, 3], [4, 5, 6]])}) encoding = encoding.to(torch.float16) self.assertEqual(encoding.pixel_values.device, torch.device("cpu")) self.assertEqual(encoding.pixel_values.dtype, torch.float16) self.assertEqual(encoding.input_ids.dtype, torch.long) def test_call_pil(self): for image_processing_class in self.image_processor_list: # Initialize image_processing image_processing = image_processing_class(**self.image_processor_dict) # create random PIL images image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False) for image in image_inputs: self.assertIsInstance(image, Image.Image) # Test not batched input encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values expected_output_image_shape = self.image_processor_tester.expected_output_image_shape([image_inputs[0]]) self.assertEqual(tuple(encoded_images.shape), (1, *expected_output_image_shape)) # Test batched encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs) self.assertEqual( tuple(encoded_images.shape), (self.image_processor_tester.batch_size, *expected_output_image_shape) ) def test_call_numpy(self): for image_processing_class in self.image_processor_list: # Initialize image_processing image_processing = image_processing_class(**self.image_processor_dict) # create random numpy tensors image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, numpify=True) for image in image_inputs: self.assertIsInstance(image, np.ndarray) # Test not batched input encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values expected_output_image_shape = self.image_processor_tester.expected_output_image_shape([image_inputs[0]]) self.assertEqual(tuple(encoded_images.shape), (1, *expected_output_image_shape)) # Test batched encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs) self.assertEqual( tuple(encoded_images.shape), (self.image_processor_tester.batch_size, *expected_output_image_shape) ) def test_call_pytorch(self): for image_processing_class in self.image_processor_list: # Initialize image_processing image_processing = image_processing_class(**self.image_processor_dict) # create random PyTorch tensors image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True) for image in image_inputs: self.assertIsInstance(image, torch.Tensor) # Test not batched input encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values expected_output_image_shape = self.image_processor_tester.expected_output_image_shape([image_inputs[0]]) self.assertEqual(tuple(encoded_images.shape), (1, *expected_output_image_shape)) # Test batched expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs) encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values self.assertEqual( tuple(encoded_images.shape), (self.image_processor_tester.batch_size, *expected_output_image_shape), ) def test_call_numpy_4_channels(self): for image_processing_class in self.image_processor_list: # Test that can process images which have an arbitrary number of channels # Initialize image_processing image_processor = image_processing_class(**self.image_processor_dict) # create random numpy tensors self.image_processor_tester.num_channels = 4 image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, numpify=True) # Test not batched input encoded_images = image_processor( image_inputs[0], return_tensors="pt", input_data_format="channels_last", image_mean=[0.0, 0.0, 0.0, 0.0], image_std=[1.0, 1.0, 1.0, 1.0], ).pixel_values expected_output_image_shape = self.image_processor_tester.expected_output_image_shape([image_inputs[0]]) self.assertEqual(tuple(encoded_images.shape), (1, *expected_output_image_shape)) # Test batched encoded_images = image_processor( image_inputs, return_tensors="pt", input_data_format="channels_last", image_mean=[0.0, 0.0, 0.0, 0.0], image_std=[1.0, 1.0, 1.0, 1.0], ).pixel_values expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs) self.assertEqual( tuple(encoded_images.shape), (self.image_processor_tester.batch_size, *expected_output_image_shape) ) def test_image_processor_preprocess_arguments(self): is_tested = False for image_processing_class in self.image_processor_list: image_processor = image_processing_class(**self.image_processor_dict) # validation done by _valid_processor_keys attribute if hasattr(image_processor, "_valid_processor_keys") and hasattr(image_processor, "preprocess"): preprocess_parameter_names = inspect.getfullargspec(image_processor.preprocess).args preprocess_parameter_names.remove("self") preprocess_parameter_names.sort() valid_processor_keys = image_processor._valid_processor_keys valid_processor_keys.sort() self.assertEqual(preprocess_parameter_names, valid_processor_keys) is_tested = True # validation done by @filter_out_non_signature_kwargs decorator if hasattr(image_processor.preprocess, "_filter_out_non_signature_kwargs"): if hasattr(self.image_processor_tester, "prepare_image_inputs"): inputs = self.image_processor_tester.prepare_image_inputs() elif hasattr(self.image_processor_tester, "prepare_video_inputs"): inputs = self.image_processor_tester.prepare_video_inputs() else: self.skipTest(reason="No valid input preparation method found") with warnings.catch_warnings(record=True) as raised_warnings: warnings.simplefilter("always") image_processor(inputs, extra_argument=True) messages = " ".join([str(w.message) for w in raised_warnings]) self.assertGreaterEqual(len(raised_warnings), 1) self.assertIn("extra_argument", messages) is_tested = True if not is_tested: self.skipTest(reason="No validation found for `preprocess` method") def test_override_instance_attributes_does_not_affect_other_instances(self): if self.fast_image_processing_class is None: self.skipTest( "Only testing fast image processor, as most slow processors break this test and are to be deprecated" ) image_processing_class = self.fast_image_processing_class image_processor_1 = image_processing_class() image_processor_2 = image_processing_class() if not (hasattr(image_processor_1, "size") and isinstance(image_processor_1.size, dict)) or not ( hasattr(image_processor_1, "image_mean") and isinstance(image_processor_1.image_mean, list) ): self.skipTest( reason="Skipping test as the image processor does not have dict size or list image_mean attributes" ) original_size_2 = deepcopy(image_processor_2.size) for key in image_processor_1.size: image_processor_1.size[key] = -1 modified_copied_size_1 = deepcopy(image_processor_1.size) original_image_mean_2 = deepcopy(image_processor_2.image_mean) image_processor_1.image_mean[0] = -1 modified_copied_image_mean_1 = deepcopy(image_processor_1.image_mean) # check that the original attributes of the second instance are not affected self.assertEqual(image_processor_2.size, original_size_2) self.assertEqual(image_processor_2.image_mean, original_image_mean_2) for key in image_processor_2.size: image_processor_2.size[key] = -2 image_processor_2.image_mean[0] = -2 # check that the modified attributes of the first instance are not affected by the second instance self.assertEqual(image_processor_1.size, modified_copied_size_1) self.assertEqual(image_processor_1.image_mean, modified_copied_image_mean_1) @slow @require_torch_accelerator @require_vision @pytest.mark.torch_compile_test def test_can_compile_fast_image_processor(self): if self.fast_image_processing_class is None: self.skipTest("Skipping compilation test as fast image processor is not defined") if version.parse(torch.__version__) < version.parse("2.3"): self.skipTest(reason="This test requires torch >= 2.3 to run.") torch.compiler.reset() input_image = torch.randint(0, 255, (3, 224, 224), dtype=torch.uint8) image_processor = self.fast_image_processing_class(**self.image_processor_dict) output_eager = image_processor(input_image, device=torch_device, return_tensors="pt") image_processor = torch.compile(image_processor, mode="reduce-overhead") output_compiled = image_processor(input_image, device=torch_device, return_tensors="pt") self._assert_slow_fast_tensors_equivalence( output_eager.pixel_values, output_compiled.pixel_values, atol=1e-4, rtol=1e-4, mean_atol=1e-5 ) def test_new_models_require_fast_image_processor(self): """ Test that new models have a fast image processor. For more information on how to implement a fast image processor, see this issue: https://github.com/huggingface/transformers/issues/36978, and ping @yonigozlan for help. """ if self.fast_image_processing_class is not None: return if self.image_processing_class is None: self.skipTest("No image processing class defined") def _is_old_model_by_commit_date(model_type, date_cutoff=(2025, 9, 1)): try: # Convert model_type to directory name and construct file path model_dir = model_type.replace("-", "_") slow_processor_file = f"src/transformers/models/{model_dir}/image_processing_{model_dir}.py" # Check if the file exists otherwise skip the test if not os.path.exists(slow_processor_file): return None # Get the first commit date of the slow processor file result = subprocess.run( ["git", "log", "--reverse", "--pretty=format:%ad", "--date=iso", slow_processor_file], capture_output=True, text=True, cwd=os.getcwd(), ) if result.returncode != 0 or not result.stdout.strip(): return None # Parse the first line (earliest commit) first_line = result.stdout.strip().split("\n")[0] date_part = first_line.split(" ")[0] # Extract just the date part commit_date = datetime.strptime(date_part, "%Y-%m-%d") # Check if committed before the cutoff date cutoff_date = datetime(*date_cutoff) return commit_date <= cutoff_date except Exception: # If any error occurs, skip the test return None image_processor_name = self.image_processing_class.__name__ model_type = None for mapping_model_type, (slow_class, _) in IMAGE_PROCESSOR_MAPPING_NAMES.items(): if slow_class == image_processor_name: model_type = mapping_model_type break if model_type is None: self.skipTest(f"Could not find model type for {image_processor_name} in IMAGE_PROCESSOR_MAPPING_NAMES") # Check if this is a new model (added after 2024-01-01) based on git history is_old_model = _is_old_model_by_commit_date(model_type) if is_old_model is None: self.skipTest(f"Could not determine if {model_type} is new based on git history") # New models must have fast processors self.assertTrue( is_old_model, f"Model '{model_type}' (processor: {image_processor_name}) was added after the cutoff date and must have " f"a fast image processor implementation. Please implement the corresponding fast processor.", ) class AnnotationFormatTestMixin: # this mixin adds a test to assert that usages of the # to-be-deprecated `AnnotionFormat` continue to be # supported for the time being def test_processor_can_use_legacy_annotation_format(self): image_processor_dict = self.image_processor_tester.prepare_image_processor_dict() fixtures_path = pathlib.Path(__file__).parent / "fixtures" / "tests_samples" / "COCO" with open(fixtures_path / "coco_annotations.txt") as f: detection_target = json.loads(f.read()) detection_annotations = {"image_id": 39769, "annotations": detection_target} detection_params = { "images": Image.open(fixtures_path / "000000039769.png"), "annotations": detection_annotations, "return_tensors": "pt", } with open(fixtures_path / "coco_panoptic_annotations.txt") as f: panoptic_target = json.loads(f.read()) panoptic_annotations = {"file_name": "000000039769.png", "image_id": 39769, "segments_info": panoptic_target} masks_path = pathlib.Path(fixtures_path / "coco_panoptic") panoptic_params = { "images": Image.open(fixtures_path / "000000039769.png"), "annotations": panoptic_annotations, "return_tensors": "pt", "masks_path": masks_path, } test_cases = [ ("coco_detection", detection_params), ("coco_panoptic", panoptic_params), (AnnotionFormat.COCO_DETECTION, detection_params), (AnnotionFormat.COCO_PANOPTIC, panoptic_params), (AnnotationFormat.COCO_DETECTION, detection_params), (AnnotationFormat.COCO_PANOPTIC, panoptic_params), ] def _compare(a, b) -> None: if isinstance(a, (dict, BatchFeature)): self.assertEqual(a.keys(), b.keys()) for k, v in a.items(): _compare(v, b[k]) elif isinstance(a, list): self.assertEqual(len(a), len(b)) for idx in range(len(a)): _compare(a[idx], b[idx]) elif isinstance(a, torch.Tensor): torch.testing.assert_close(a, b, rtol=1e-3, atol=1e-3) elif isinstance(a, str): self.assertEqual(a, b) for annotation_format, params in test_cases: with self.subTest(annotation_format): image_processor_params = {**image_processor_dict, **{"format": annotation_format}} image_processor_first = self.image_processing_class(**image_processor_params) with tempfile.TemporaryDirectory() as tmpdirname: image_processor_first.save_pretrained(tmpdirname) image_processor_second = self.image_processing_class.from_pretrained(tmpdirname) # check the 'format' key exists and that the dicts of the # first and second processors are equal self.assertIn("format", image_processor_first.to_dict().keys()) self.assertEqual(image_processor_second.to_dict(), image_processor_first.to_dict()) # perform encoding using both processors and compare # the resulting BatchFeatures first_encoding = image_processor_first(**params) second_encoding = image_processor_second(**params) _compare(first_encoding, second_encoding)