Adding superglue fast image processing (#41394)

* Default implementation - no speed improvement

* Improved implementation - roughly 2x faster with only a simple function refactor

* elementary torch-first approach; the torch-first method still needs further implementation

* torch-first approach finished

* refactor processor

* refactor test

* partial doc update

* EfficientLoFTRImageProcessorFast-based implementation

* EfficientLoFTRImageProcessorFast-based implementation

* Logic checked - tests passed - execution speed validated

* use modular for efficientloftr

* fix import

---------

Co-authored-by: yonigozlan <yoni.gozlan@huggingface.co>
Co-authored-by: Yoni Gozlan <74535834+yonigozlan@users.noreply.github.com>
Author: Julien
Date: 2025-10-16 21:34:09 +02:00
Committed by: GitHub
Parent: 4dd4133d32
Commit: 354567d955
10 changed files with 426 additions and 108 deletions
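
The headline change: SuperGlue gains a torchvision-backed SuperGlueImageProcessorFast, registered in the auto mapping below. A minimal usage sketch (image paths are placeholders; `use_fast=True` selects the fast class):

import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModel

# use_fast=True resolves to SuperGlueImageProcessorFast via the updated auto mapping
processor = AutoImageProcessor.from_pretrained("magic-leap-community/superglue_outdoor", use_fast=True)
model = AutoModel.from_pretrained("magic-leap-community/superglue_outdoor")

# SuperGlue consumes pairs of images of the same scene
image1 = Image.open("image1.jpg")  # placeholder path
image2 = Image.open("image2.jpg")  # placeholder path
inputs = processor([image1, image2], return_tensors="pt")
with torch.inference_mode():
    outputs = model(**inputs)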

View File

@@ -88,16 +88,16 @@ processed_outputs = processor.post_process_keypoint_matching(outputs, image_size
import torch
from PIL import Image
import requests
processor = AutoImageProcessor.from_pretrained("magic-leap-community/superglue_outdoor")
model = AutoModel.from_pretrained("magic-leap-community/superglue_outdoor")
# SuperGlue requires pairs of images
images = [image1, image2]
inputs = processor(images, return_tensors="pt")
with torch.inference_mode():
outputs = model(**inputs)
# Extract matching information
keypoints0 = outputs.keypoints0 # Keypoints in first image
keypoints1 = outputs.keypoints1 # Keypoints in second image
@@ -112,7 +112,7 @@ processed_outputs = processor.post_process_keypoint_matching(outputs, image_size
# Process outputs for visualization
image_sizes = [[(image.height, image.width) for image in images]]
processed_outputs = processor.post_process_keypoint_matching(outputs, image_sizes, threshold=0.2)
for i, output in enumerate(processed_outputs):
print(f"For the image pair {i}")
for keypoint0, keypoint1, matching_score in zip(
@@ -147,6 +147,13 @@ processed_outputs = processor.post_process_keypoint_matching(outputs, image_size
- post_process_keypoint_matching
- visualize_keypoint_matching
## SuperGlueImageProcessorFast
[[autodoc]] SuperGlueImageProcessorFast
- preprocess
- post_process_keypoint_matching
- visualize_keypoint_matching
## SuperGlueForKeypointMatching
[[autodoc]] SuperGlueForKeypointMatching
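
The newly documented visualize_keypoint_matching returns one PIL image per pair with the matches drawn. A short sketch continuing the doc snippet above (output filename is illustrative):

# draw matched keypoints for each pair and save the first plot
plots = processor.visualize_keypoint_matching(images, processed_outputs)
plots[0].save("matched_keypoints.png")  # illustrative filename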

View File

@@ -171,7 +171,7 @@ else:
("siglip", ("SiglipImageProcessor", "SiglipImageProcessorFast")),
("siglip2", ("Siglip2ImageProcessor", "Siglip2ImageProcessorFast")),
("smolvlm", ("SmolVLMImageProcessor", "SmolVLMImageProcessorFast")),
("superglue", ("SuperGlueImageProcessor", None)),
("superglue", ("SuperGlueImageProcessor", "SuperGlueImageProcessorFast")),
("superpoint", ("SuperPointImageProcessor", "SuperPointImageProcessorFast")),
("swiftformer", ("ViTImageProcessor", "ViTImageProcessorFast")),
("swin", ("ViTImageProcessor", "ViTImageProcessorFast")),

View File

@@ -1,30 +1,17 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Fast Image processor class for EfficientLoFTR."""
from typing import TYPE_CHECKING, Optional, Union
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/efficientloftr/modular_efficientloftr.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_efficientloftr.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
from typing import Optional, Union
import torch
from PIL import Image, ImageDraw
from torchvision.transforms.v2 import functional as F
from ...image_processing_utils import BatchFeature
from ...image_processing_utils_fast import (
BaseImageProcessorFast,
group_images_by_shape,
reorder_images,
)
from ...image_processing_utils_fast import BaseImageProcessorFast, BatchFeature
from ...image_transforms import group_images_by_shape, reorder_images
from ...image_utils import (
ImageInput,
ImageType,
@@ -35,17 +22,9 @@ from ...image_utils import (
is_valid_image,
)
from ...processing_utils import Unpack
from ...utils import (
TensorType,
auto_docstring,
)
from ...utils import TensorType, auto_docstring
from .image_processing_efficientloftr import EfficientLoFTRImageProcessorKwargs
if TYPE_CHECKING:
from .modeling_efficientloftr import KeypointMatchingOutput
import torchvision.transforms.v2.functional as F
from .modeling_efficientloftr import KeypointMatchingOutput
def _is_valid_image(image):
@@ -299,7 +278,7 @@ class EfficientLoFTRImageProcessorFast(BaseImageProcessorFast):
r = int(255 * (1 - score))
g = int(255 * score)
b = 0
return (r, g, b)
return r, g, b
__all__ = ["EfficientLoFTRImageProcessorFast"]

View File

@@ -0,0 +1,8 @@
from ..superglue.image_processing_superglue_fast import SuperGlueImageProcessorFast
class EfficientLoFTRImageProcessorFast(SuperGlueImageProcessorFast):
pass
__all__ = ["EfficientLoFTRImageProcessorFast"]

View File

@@ -20,6 +20,7 @@ from ...utils.import_utils import define_import_structure
if TYPE_CHECKING:
from .configuration_superglue import *
from .image_processing_superglue import *
from .image_processing_superglue_fast import *
from .modeling_superglue import *
else:
import sys

View File

@@ -35,6 +35,7 @@ from ...image_utils import (
valid_images,
validate_preprocess_arguments,
)
from ...processing_utils import ImagesKwargs
from ...utils import TensorType, logging, requires_backends
from ...utils.import_utils import requires
@@ -133,6 +134,15 @@ def validate_and_format_image_pairs(images: ImageInput):
raise ValueError(error_message)
class SuperGlueImageProcessorKwargs(ImagesKwargs, total=False):
r"""
do_grayscale (`bool`, *optional*, defaults to `True`):
Whether to convert the image to grayscale. Can be overridden by `do_grayscale` in the `preprocess` method.
"""
do_grayscale: bool
@requires(backends=("torch",))
class SuperGlueImageProcessor(BaseImageProcessor):
r"""

View File

@@ -0,0 +1,292 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Union
import torch
from PIL import Image, ImageDraw
from torchvision.transforms.v2 import functional as F
from ...image_processing_utils_fast import BaseImageProcessorFast, BatchFeature
from ...image_transforms import group_images_by_shape, reorder_images
from ...image_utils import (
ImageInput,
ImageType,
PILImageResampling,
SizeDict,
get_image_type,
is_pil_image,
is_valid_image,
)
from ...processing_utils import Unpack
from ...utils import TensorType, auto_docstring
from .image_processing_superglue import SuperGlueImageProcessorKwargs
from .modeling_superglue import KeypointMatchingOutput
def _is_valid_image(image):
return is_pil_image(image) or (
is_valid_image(image) and get_image_type(image) != ImageType.PIL and len(image.shape) == 3
)
def flatten_pair_images(images):
# Handle pair validation and flattening, mirroring the slow processor
if isinstance(images, list):
if len(images) == 2 and all((_is_valid_image(image) or isinstance(image, torch.Tensor)) for image in images):
# Single pair of images - keep as is, they'll be processed by the base class
return images
elif all(
isinstance(image_pair, list)
and len(image_pair) == 2
and all(_is_valid_image(image) or isinstance(image, torch.Tensor) for image in image_pair)
for image_pair in images
):
# Multiple pairs - flatten them
images = [image for image_pair in images for image in image_pair]
return images
raise ValueError(
"Input images must be a one of the following :",
" - A pair of PIL images.",
" - A pair of 3D arrays.",
" - A list of pairs of PIL images.",
" - A list of pairs of 3D arrays.",
)
def is_grayscale(
image: "torch.Tensor",
):
"""Checks if an image is grayscale (all RGB channels are identical)."""
if image.ndim < 3 or image.shape[0 if image.ndim == 3 else 1] == 1:
return True
return torch.all(image[..., 0, :, :] == image[..., 1, :, :]) and torch.all(
image[..., 1, :, :] == image[..., 2, :, :]
)
def convert_to_grayscale(
image: "torch.Tensor",
) -> "torch.Tensor":
"""
Converts an image to grayscale format using the NTSC formula. Only supports torch.Tensor.
This function is supposed to return a 1-channel image, but it returns a 3-channel image with the same value in each
channel, because of an issue discussed in:
https://github.com/huggingface/transformers/pull/25786#issuecomment-1730176446
Args:
image (torch.Tensor):
The image to convert.
"""
if is_grayscale(image):
return image
return F.rgb_to_grayscale(image, num_output_channels=3)
@auto_docstring
class SuperGlueImageProcessorFast(BaseImageProcessorFast):
resample = PILImageResampling.BILINEAR
size = {"height": 480, "width": 640}
default_to_square = False
do_resize = True
do_rescale = True
rescale_factor = 1 / 255
do_normalize = None
valid_kwargs = SuperGlueImageProcessorKwargs
def __init__(self, **kwargs: Unpack[SuperGlueImageProcessorKwargs]):
super().__init__(**kwargs)
@auto_docstring
def preprocess(self, images: ImageInput, **kwargs: Unpack[SuperGlueImageProcessorKwargs]) -> BatchFeature:
return super().preprocess(images, **kwargs)
def _prepare_images_structure(
self,
images: ImageInput,
**kwargs,
) -> ImageInput:
# we need to handle image pairs validation and flattening
return flatten_pair_images(images)
def _preprocess(
self,
images: list["torch.Tensor"],
size: Union[dict[str, int], SizeDict],
rescale_factor: float,
do_rescale: bool,
do_resize: bool,
interpolation: Optional["F.InterpolationMode"],
do_grayscale: bool,
disable_grouping: bool,
return_tensors: Union[str, TensorType],
**kwargs,
) -> BatchFeature:
grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
processed_images_grouped = {}
for shape, stacked_images in grouped_images.items():
if do_resize:
stacked_images = self.resize(stacked_images, size=size, interpolation=interpolation)
processed_images_grouped[shape] = stacked_images
resized_images = reorder_images(processed_images_grouped, grouped_images_index)
grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
processed_images_grouped = {}
for shape, stacked_images in grouped_images.items():
if do_rescale:
stacked_images = self.rescale(stacked_images, rescale_factor)
if do_grayscale:
stacked_images = convert_to_grayscale(stacked_images)
processed_images_grouped[shape] = stacked_images
processed_images = reorder_images(processed_images_grouped, grouped_images_index)
# Convert back to pairs format
image_pairs = [processed_images[i : i + 2] for i in range(0, len(processed_images), 2)]
# Stack each pair into a single tensor to match slow processor format
stacked_pairs = [torch.stack(pair, dim=0) for pair in image_pairs]
# Return in same format as slow processor
image_pairs = torch.stack(stacked_pairs, dim=0) if return_tensors else stacked_pairs
return BatchFeature(data={"pixel_values": image_pairs})
def post_process_keypoint_matching(
self,
outputs: "KeypointMatchingOutput",
target_sizes: Union[TensorType, list[tuple]],
threshold: float = 0.0,
) -> list[dict[str, torch.Tensor]]:
"""
Converts the raw output of [`KeypointMatchingOutput`] into lists of matched keypoints and matching scores
with coordinates absolute to the original image sizes.
Args:
outputs ([`KeypointMatchingOutput`]):
Raw outputs of the model.
target_sizes (`torch.Tensor` or `List[Tuple[Tuple[int, int]]]`, *optional*):
Tensor of shape `(batch_size, 2, 2)` or list of tuples of tuples (`Tuple[int, int]`) containing the
target size `(height, width)` of each image in the batch. This must be the original image size (before
any processing).
threshold (`float`, *optional*, defaults to 0.0):
Threshold to filter out the matches with low scores.
Returns:
`List[Dict]`: A list of dictionaries, each containing the matched keypoints in the first and second image
of the pair and their matching scores.
"""
if outputs.matches.shape[0] != len(target_sizes):
raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the mask")
if not all(len(target_size) == 2 for target_size in target_sizes):
raise ValueError("Each element of target_sizes must contain the size (h, w) of each image of the batch")
if isinstance(target_sizes, list):
image_pair_sizes = torch.tensor(target_sizes, device=outputs.matches.device)
else:
if target_sizes.shape[1] != 2 or target_sizes.shape[2] != 2:
raise ValueError(
"Each element of target_sizes must contain the size (h, w) of each image of the batch"
)
image_pair_sizes = target_sizes
keypoints = outputs.keypoints.clone()
keypoints = keypoints * image_pair_sizes.flip(-1).reshape(-1, 2, 1, 2)
keypoints = keypoints.to(torch.int32)
results = []
for keypoints_pair, matches, scores in zip(keypoints, outputs.matches, outputs.matching_scores):
# Filter out matches with low scores
valid_matches = torch.logical_and(scores > threshold, matches > -1)
matched_keypoints0 = keypoints_pair[0][valid_matches[0]]
matched_keypoints1 = keypoints_pair[1][valid_matches[1]]
matching_scores = scores[0][valid_matches[0]]
results.append(
{
"keypoints0": matched_keypoints0,
"keypoints1": matched_keypoints1,
"matching_scores": matching_scores,
}
)
return results
def visualize_keypoint_matching(
self,
images,
keypoint_matching_output: list[dict[str, torch.Tensor]],
) -> list["Image.Image"]:
"""
Plots the image pairs side by side with the detected keypoints as well as the matching between them.
Args:
images:
Image pairs to plot. Same as `SuperGlueImageProcessor.preprocess`. Expects either a list of 2
images or a list of pairs of 2 images, with pixel values ranging from 0 to 255.
keypoint_matching_output (List[Dict[str, torch.Tensor]]):
A post-processed keypoint matching output.
Returns:
`List[PIL.Image.Image]`: A list of PIL images, each containing the image pairs side by side with the detected
keypoints as well as the matching between them.
"""
from ...image_utils import to_numpy_array
from .image_processing_superglue import validate_and_format_image_pairs
images = validate_and_format_image_pairs(images)
images = [to_numpy_array(image) for image in images]
image_pairs = [images[i : i + 2] for i in range(0, len(images), 2)]
results = []
for image_pair, pair_output in zip(image_pairs, keypoint_matching_output):
height0, width0 = image_pair[0].shape[:2]
height1, width1 = image_pair[1].shape[:2]
plot_image = torch.zeros((max(height0, height1), width0 + width1, 3), dtype=torch.uint8)
plot_image[:height0, :width0] = torch.from_numpy(image_pair[0])
plot_image[:height1, width0:] = torch.from_numpy(image_pair[1])
plot_image_pil = Image.fromarray(plot_image.numpy())
draw = ImageDraw.Draw(plot_image_pil)
keypoints0_x, keypoints0_y = pair_output["keypoints0"].unbind(1)
keypoints1_x, keypoints1_y = pair_output["keypoints1"].unbind(1)
for keypoint0_x, keypoint0_y, keypoint1_x, keypoint1_y, matching_score in zip(
keypoints0_x, keypoints0_y, keypoints1_x, keypoints1_y, pair_output["matching_scores"]
):
color = self._get_color(matching_score)
draw.line(
(keypoint0_x, keypoint0_y, keypoint1_x + width0, keypoint1_y),
fill=color,
width=3,
)
draw.ellipse((keypoint0_x - 2, keypoint0_y - 2, keypoint0_x + 2, keypoint0_y + 2), fill="black")
draw.ellipse(
(keypoint1_x + width0 - 2, keypoint1_y - 2, keypoint1_x + width0 + 2, keypoint1_y + 2),
fill="black",
)
results.append(plot_image_pil)
return results
def _get_color(self, score):
"""Maps a score to a color."""
r = int(255 * (1 - score))
g = int(255 * score)
b = 0
return r, g, b
__all__ = ["SuperGlueImageProcessorFast"]

View File

@@ -15,8 +15,6 @@ import time
import unittest
import numpy as np
import pytest
from packaging import version
from tests.models.superglue.test_image_processing_superglue import (
SuperGlueImageProcessingTest,
@@ -24,10 +22,7 @@ from tests.models.superglue.test_image_processing_superglue import (
)
from transformers.testing_utils import (
require_torch,
require_torch_accelerator,
require_vision,
slow,
torch_device,
)
from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available
@@ -103,46 +98,6 @@ class EfficientLoFTRImageProcessingTest(SuperGlueImageProcessingTest, unittest.T
super().setUp()
self.image_processor_tester = EfficientLoFTRImageProcessingTester(self)
def test_slow_fast_equivalence(self):
"""Override the generic test since EfficientLoFTR requires image pairs."""
if not self.test_slow_image_processor or not self.test_fast_image_processor:
self.skipTest(reason="Skipping slow/fast equivalence test")
if self.image_processing_class is None or self.fast_image_processing_class is None:
self.skipTest(reason="Skipping slow/fast equivalence test as one of the image processors is not defined")
# Create image pairs instead of single images
dummy_images = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=False)
image_processor_slow = self.image_processing_class(**self.image_processor_dict)
image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict)
encoding_slow = image_processor_slow(dummy_images, return_tensors="pt")
encoding_fast = image_processor_fast(dummy_images, return_tensors="pt")
self._assert_slow_fast_tensors_equivalence(encoding_slow.pixel_values, encoding_fast.pixel_values)
def test_slow_fast_equivalence_batched(self):
"""Override the generic test since EfficientLoFTR requires image pairs."""
if not self.test_slow_image_processor or not self.test_fast_image_processor:
self.skipTest(reason="Skipping slow/fast equivalence test")
if self.image_processing_class is None or self.fast_image_processing_class is None:
self.skipTest(reason="Skipping slow/fast equivalence test as one of the image processors is not defined")
if hasattr(self.image_processor_tester, "do_center_crop") and self.image_processor_tester.do_center_crop:
self.skipTest(
reason="Skipping as do_center_crop is True and center_crop functions are not equivalent for fast and slow processors"
)
# Create image pairs instead of single images
dummy_images = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True)
image_processor_slow = self.image_processing_class(**self.image_processor_dict)
image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict)
encoding_slow = image_processor_slow(dummy_images, return_tensors="pt")
encoding_fast = image_processor_fast(dummy_images, return_tensors="pt")
self._assert_slow_fast_tensors_equivalence(encoding_slow.pixel_values, encoding_fast.pixel_values)
@unittest.skip(reason="Many failing cases. This test needs a deeper investigation.")
def test_fast_is_faster_than_slow(self):
"""Override the generic test since EfficientLoFTR requires image pairs."""
@@ -173,25 +128,3 @@ class EfficientLoFTRImageProcessingTest(SuperGlueImageProcessingTest, unittest.T
self.assertLessEqual(
fast_time, slow_time * 1.2, "Fast processor should not be significantly slower than slow processor"
)
@slow
@require_torch_accelerator
@require_vision
@pytest.mark.torch_compile_test
def test_can_compile_fast_image_processor(self):
"""Override the generic test since EfficientLoFTR requires image pairs."""
if self.fast_image_processing_class is None:
self.skipTest("Skipping compilation test as fast image processor is not defined")
if version.parse(torch.__version__) < version.parse("2.3"):
self.skipTest(reason="This test requires torch >= 2.3 to run.")
torch.compiler.reset()
input_image = self.image_processor_tester.prepare_image_inputs(equal_resolution=True, torchify=False)
image_processor = self.fast_image_processing_class(**self.image_processor_dict)
output_eager = image_processor(input_image, device=torch_device, return_tensors="pt")
image_processor = torch.compile(image_processor, mode="reduce-overhead")
output_compiled = image_processor(input_image, device=torch_device, return_tensors="pt")
self._assert_slow_fast_tensors_equivalence(
output_eager.pixel_values, output_compiled.pixel_values, atol=1e-4, rtol=1e-4, mean_atol=1e-5
)

View File

@@ -90,6 +90,7 @@ class LightGlueImageProcessingTester(SuperGlueImageProcessingTester):
@require_vision
class LightGlueImageProcessingTest(SuperGlueImageProcessingTest, unittest.TestCase):
image_processing_class = LightGlueImageProcessor if is_vision_available() else None
fast_image_processing_class = None
def setUp(self) -> None:
super().setUp()

View File

@@ -11,12 +11,22 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
import unittest
import numpy as np
import pytest
from packaging import version
from parameterized import parameterized
from transformers.testing_utils import require_torch, require_vision
from transformers.utils import is_torch_available, is_vision_available
from transformers.testing_utils import (
require_torch,
require_torch_accelerator,
require_vision,
slow,
torch_device,
)
from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available
from ...test_image_processing_common import (
ImageProcessingTestMixin,
@@ -33,6 +43,9 @@ if is_torch_available():
if is_vision_available():
from transformers import SuperGlueImageProcessor
if is_torchvision_available():
from transformers import SuperGlueImageProcessorFast
def random_array(size):
return np.random.randint(255, size=size)
@@ -119,6 +132,7 @@ class SuperGlueImageProcessingTester:
@require_vision
class SuperGlueImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
image_processing_class = SuperGlueImageProcessor if is_vision_available() else None
fast_image_processing_class = SuperGlueImageProcessorFast if is_torchvision_available() else None
def setUp(self) -> None:
super().setUp()
@@ -397,3 +411,76 @@ class SuperGlueImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
tensor_post_processed_outputs = image_processor.post_process_keypoint_matching(outputs, tensor_image_sizes)
check_post_processed_output(tensor_post_processed_outputs, tensor_image_sizes)
@unittest.skip(reason="Many failing cases. This test needs a deeper investigation.")
def test_fast_is_faster_than_slow(self):
"""Override the generic test since SuperGlue requires image pairs."""
if not self.test_slow_image_processor or not self.test_fast_image_processor:
self.skipTest(reason="Skipping slow/fast speed test")
if self.image_processing_class is None or self.fast_image_processing_class is None:
self.skipTest(reason="Skipping slow/fast speed test as one of the image processors is not defined")
# Create image pairs for speed test
dummy_images = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=False)
image_processor_slow = self.image_processing_class(**self.image_processor_dict)
image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict)
# Time slow processor
start_time = time.time()
for _ in range(10):
_ = image_processor_slow(dummy_images, return_tensors="pt")
slow_time = time.time() - start_time
# Time fast processor
start_time = time.time()
for _ in range(10):
_ = image_processor_fast(dummy_images, return_tensors="pt")
fast_time = time.time() - start_time
# Fast should be faster (or at least not significantly slower)
self.assertLessEqual(
fast_time, slow_time * 1.2, "Fast processor should not be significantly slower than slow processor"
)
@require_vision
@require_torch
def test_slow_fast_equivalence(self):
if not self.test_slow_image_processor or not self.test_fast_image_processor:
self.skipTest(reason="Skipping slow/fast equivalence test")
if self.image_processing_class is None or self.fast_image_processing_class is None:
self.skipTest(reason="Skipping slow/fast equivalence test as one of the image processors is not defined")
dummy_image = self.image_processor_tester.prepare_image_inputs(
equal_resolution=False, numpify=True, batch_size=2, pairs=False
)
image_processor_slow = self.image_processing_class(**self.image_processor_dict)
image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict)
encoding_slow = image_processor_slow(dummy_image, return_tensors="pt")
encoding_fast = image_processor_fast(dummy_image, return_tensors="pt")
self._assert_slow_fast_tensors_equivalence(encoding_slow.pixel_values, encoding_fast.pixel_values)
@slow
@require_torch_accelerator
@require_vision
@pytest.mark.torch_compile_test
def test_can_compile_fast_image_processor(self):
"""Override the generic test since EfficientLoFTR requires image pairs."""
if self.fast_image_processing_class is None:
self.skipTest("Skipping compilation test as fast image processor is not defined")
if version.parse(torch.__version__) < version.parse("2.3"):
self.skipTest(reason="This test requires torch >= 2.3 to run.")
torch.compiler.reset()
input_image = self.image_processor_tester.prepare_image_inputs(equal_resolution=True, torchify=False)
image_processor = self.fast_image_processing_class(**self.image_processor_dict)
output_eager = image_processor(input_image, device=torch_device, return_tensors="pt")
image_processor = torch.compile(image_processor, mode="reduce-overhead")
output_compiled = image_processor(input_image, device=torch_device, return_tensors="pt")
self._assert_slow_fast_tensors_equivalence(
output_eager.pixel_values, output_compiled.pixel_values, atol=1e-4, rtol=1e-4, mean_atol=1e-5
)