Mirror of https://github.com/huggingface/transformers.git
Adding superglue fast image processing (#41394)
* Default implementation - no time improvement
* Improved implementation - apparently 2 times faster with only simple function refactor
* Elementary torch-first approach, still needs further implementation of the torch-first method
* Torch-first approach finished
* Refactor processor
* Refactor test
* Partial doc update
* EfficientLoFTRImageProcessorFast based implementation
* Logic checked - Test passed - Validated execution speed
* Use modular for efficientloftr
* Fix import

---------

Co-authored-by: yonigozlan <yoni.gozlan@huggingface.co>
Co-authored-by: Yoni Gozlan <74535834+yonigozlan@users.noreply.github.com>
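For context, a minimal usage sketch of what this change enables. It mirrors the SuperGlue docs snippet further down in this diff; the image URLs are placeholders, and passing `use_fast=True` to `AutoImageProcessor.from_pretrained` is assumed to resolve to the new `SuperGlueImageProcessorFast` once the auto-mapping change below is in place.

import torch
import requests
from PIL import Image

from transformers import AutoImageProcessor, AutoModel

# use_fast=True should pick the fast processor registered by this PR
processor = AutoImageProcessor.from_pretrained("magic-leap-community/superglue_outdoor", use_fast=True)
model = AutoModel.from_pretrained("magic-leap-community/superglue_outdoor")

url0 = "https://example.com/image0.jpg"  # placeholder URLs, replace with a real image pair
url1 = "https://example.com/image1.jpg"
image1 = Image.open(requests.get(url0, stream=True).raw)
image2 = Image.open(requests.get(url1, stream=True).raw)

# SuperGlue works on pairs of images
inputs = processor([image1, image2], return_tensors="pt")
with torch.inference_mode():
    outputs = model(**inputs)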
@@ -88,16 +88,16 @@ processed_outputs = processor.post_process_keypoint_matching(outputs, image_size
import torch
from PIL import Image
import requests

processor = AutoImageProcessor.from_pretrained("magic-leap-community/superglue_outdoor")
model = AutoModel.from_pretrained("magic-leap-community/superglue_outdoor")

# SuperGlue requires pairs of images
images = [image1, image2]
inputs = processor(images, return_tensors="pt")
with torch.inference_mode():
    outputs = model(**inputs)

# Extract matching information
keypoints0 = outputs.keypoints0  # Keypoints in first image
keypoints1 = outputs.keypoints1  # Keypoints in second image
@@ -112,7 +112,7 @@ processed_outputs = processor.post_process_keypoint_matching(outputs, image_size
# Process outputs for visualization
image_sizes = [[(image.height, image.width) for image in images]]
processed_outputs = processor.post_process_keypoint_matching(outputs, image_sizes, threshold=0.2)

for i, output in enumerate(processed_outputs):
    print(f"For the image pair {i}")
    for keypoint0, keypoint1, matching_score in zip(
@@ -147,6 +147,13 @@ processed_outputs = processor.post_process_keypoint_matching(outputs, image_size
    - post_process_keypoint_matching
    - visualize_keypoint_matching

## SuperGlueImageProcessorFast

[[autodoc]] SuperGlueImageProcessorFast
    - preprocess
    - post_process_keypoint_matching
    - visualize_keypoint_matching

## SuperGlueForKeypointMatching

[[autodoc]] SuperGlueForKeypointMatching
@@ -171,7 +171,7 @@ else:
        ("siglip", ("SiglipImageProcessor", "SiglipImageProcessorFast")),
        ("siglip2", ("Siglip2ImageProcessor", "Siglip2ImageProcessorFast")),
        ("smolvlm", ("SmolVLMImageProcessor", "SmolVLMImageProcessorFast")),
        ("superglue", ("SuperGlueImageProcessor", None)),
        ("superglue", ("SuperGlueImageProcessor", "SuperGlueImageProcessorFast")),
        ("superpoint", ("SuperPointImageProcessor", "SuperPointImageProcessorFast")),
        ("swiftformer", ("ViTImageProcessor", "ViTImageProcessorFast")),
        ("swin", ("ViTImageProcessor", "ViTImageProcessorFast")),
@@ -1,30 +1,17 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Fast Image processor class for EfficientLoFTR."""

from typing import TYPE_CHECKING, Optional, Union
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/efficientloftr/modular_efficientloftr.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_efficientloftr.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
from typing import Optional, Union

import torch
from PIL import Image, ImageDraw
from torchvision.transforms.v2 import functional as F

from ...image_processing_utils import BatchFeature
from ...image_processing_utils_fast import (
    BaseImageProcessorFast,
    group_images_by_shape,
    reorder_images,
)
from ...image_processing_utils_fast import BaseImageProcessorFast, BatchFeature
from ...image_transforms import group_images_by_shape, reorder_images
from ...image_utils import (
    ImageInput,
    ImageType,
@@ -35,17 +22,9 @@ from ...image_utils import (
    is_valid_image,
)
from ...processing_utils import Unpack
from ...utils import (
    TensorType,
    auto_docstring,
)
from ...utils import TensorType, auto_docstring
from .image_processing_efficientloftr import EfficientLoFTRImageProcessorKwargs


if TYPE_CHECKING:
    from .modeling_efficientloftr import KeypointMatchingOutput

import torchvision.transforms.v2.functional as F
from .modeling_efficientloftr import KeypointMatchingOutput


def _is_valid_image(image):
@@ -299,7 +278,7 @@ class EfficientLoFTRImageProcessorFast(BaseImageProcessorFast):
        r = int(255 * (1 - score))
        g = int(255 * score)
        b = 0
        return (r, g, b)
        return r, g, b


__all__ = ["EfficientLoFTRImageProcessorFast"]
@@ -0,0 +1,8 @@
from ..superglue.image_processing_superglue_fast import SuperGlueImageProcessorFast


class EfficientLoFTRImageProcessorFast(SuperGlueImageProcessorFast):
    pass


__all__ = ["EfficientLoFTRImageProcessorFast"]
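Since the modular file above only subclasses `SuperGlueImageProcessorFast` with `pass`, the generated EfficientLoFTR fast processor simply inherits SuperGlue's behaviour (pair flattening, optional grayscale conversion, the 480x640 resize default). A small sketch of that relationship, assuming both classes are importable from the top-level `transformers` namespace after this change:

from transformers import EfficientLoFTRImageProcessorFast, SuperGlueImageProcessorFast

processor = EfficientLoFTRImageProcessorFast()
# The subclass adds no behaviour of its own; everything comes from SuperGlueImageProcessorFast
assert isinstance(processor, SuperGlueImageProcessorFast)
print(processor.size)  # inherited default, {"height": 480, "width": 640}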
@@ -20,6 +20,7 @@ from ...utils.import_utils import define_import_structure
if TYPE_CHECKING:
    from .configuration_superglue import *
    from .image_processing_superglue import *
    from .image_processing_superglue_fast import *
    from .modeling_superglue import *
else:
    import sys
@@ -35,6 +35,7 @@ from ...image_utils import (
    valid_images,
    validate_preprocess_arguments,
)
from ...processing_utils import ImagesKwargs
from ...utils import TensorType, logging, requires_backends
from ...utils.import_utils import requires

@@ -133,6 +134,15 @@ def validate_and_format_image_pairs(images: ImageInput):
        raise ValueError(error_message)


class SuperGlueImageProcessorKwargs(ImagesKwargs, total=False):
    r"""
    do_grayscale (`bool`, *optional*, defaults to `True`):
        Whether to convert the image to grayscale. Can be overridden by `do_grayscale` in the `preprocess` method.
    """

    do_grayscale: bool


@requires(backends=("torch",))
class SuperGlueImageProcessor(BaseImageProcessor):
    r"""
@@ -0,0 +1,292 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Union

import torch
from PIL import Image, ImageDraw
from torchvision.transforms.v2 import functional as F

from ...image_processing_utils_fast import BaseImageProcessorFast, BatchFeature
from ...image_transforms import group_images_by_shape, reorder_images
from ...image_utils import (
    ImageInput,
    ImageType,
    PILImageResampling,
    SizeDict,
    get_image_type,
    is_pil_image,
    is_valid_image,
)
from ...processing_utils import Unpack
from ...utils import TensorType, auto_docstring
from .image_processing_superglue import SuperGlueImageProcessorKwargs
from .modeling_superglue import KeypointMatchingOutput


def _is_valid_image(image):
    return is_pil_image(image) or (
        is_valid_image(image) and get_image_type(image) != ImageType.PIL and len(image.shape) == 3
    )


def flatten_pair_images(images):
    # Handle the pair validation and flattening similar to slow processor
    if isinstance(images, list):
        if len(images) == 2 and all((_is_valid_image(image) or isinstance(image, torch.Tensor)) for image in images):
            # Single pair of images - keep as is, they'll be processed by the base class
            return images
        elif all(
            isinstance(image_pair, list)
            and len(image_pair) == 2
            and all(_is_valid_image(image) or isinstance(image, torch.Tensor) for image in image_pair)
            for image_pair in images
        ):
            # Multiple pairs - flatten them
            images = [image for image_pair in images for image in image_pair]
            return images
    raise ValueError(
        "Input images must be a one of the following :",
        " - A pair of PIL images.",
        " - A pair of 3D arrays.",
        " - A list of pairs of PIL images.",
        " - A list of pairs of 3D arrays.",
    )


def is_grayscale(
    image: "torch.Tensor",
):
    """Checks if an image is grayscale (all RGB channels are identical)."""
    if image.ndim < 3 or image.shape[0 if image.ndim == 3 else 1] == 1:
        return True
    return torch.all(image[..., 0, :, :] == image[..., 1, :, :]) and torch.all(
        image[..., 1, :, :] == image[..., 2, :, :]
    )


def convert_to_grayscale(
    image: "torch.Tensor",
) -> "torch.Tensor":
    """
    Converts an image to grayscale format using the NTSC formula. Only support torch.Tensor.

    This function is supposed to return a 1-channel image, but it returns a 3-channel image with the same value in each
    channel, because of an issue that is discussed in :
    https://github.com/huggingface/transformers/pull/25786#issuecomment-1730176446

    Args:
        image (torch.Tensor):
            The image to convert.
    """
    if is_grayscale(image):
        return image
    return F.rgb_to_grayscale(image, num_output_channels=3)


@auto_docstring
class SuperGlueImageProcessorFast(BaseImageProcessorFast):
    resample = PILImageResampling.BILINEAR
    size = {"height": 480, "width": 640}
    default_to_square = False
    do_resize = True
    do_rescale = True
    rescale_factor = 1 / 255
    do_normalize = None
    valid_kwargs = SuperGlueImageProcessorKwargs

    def __init__(self, **kwargs: Unpack[SuperGlueImageProcessorKwargs]):
        super().__init__(**kwargs)

    @auto_docstring
    def preprocess(self, images: ImageInput, **kwargs: Unpack[SuperGlueImageProcessorKwargs]) -> BatchFeature:
        return super().preprocess(images, **kwargs)

    def _prepare_images_structure(
        self,
        images: ImageInput,
        **kwargs,
    ) -> ImageInput:
        # we need to handle image pairs validation and flattening
        return flatten_pair_images(images)

    def _preprocess(
        self,
        images: list["torch.Tensor"],
        size: Union[dict[str, int], SizeDict],
        rescale_factor: float,
        do_rescale: bool,
        do_resize: bool,
        interpolation: Optional["F.InterpolationMode"],
        do_grayscale: bool,
        disable_grouping: bool,
        return_tensors: Union[str, TensorType],
        **kwargs,
    ) -> BatchFeature:
        grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
        processed_images_grouped = {}

        for shape, stacked_images in grouped_images.items():
            if do_resize:
                stacked_images = self.resize(stacked_images, size=size, interpolation=interpolation)
            processed_images_grouped[shape] = stacked_images
        resized_images = reorder_images(processed_images_grouped, grouped_images_index)

        grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
        processed_images_grouped = {}
        for shape, stacked_images in grouped_images.items():
            if do_rescale:
                stacked_images = self.rescale(stacked_images, rescale_factor)
            if do_grayscale:
                stacked_images = convert_to_grayscale(stacked_images)
            processed_images_grouped[shape] = stacked_images

        processed_images = reorder_images(processed_images_grouped, grouped_images_index)

        # Convert back to pairs format
        image_pairs = [processed_images[i : i + 2] for i in range(0, len(processed_images), 2)]

        # Stack each pair into a single tensor to match slow processor format
        stacked_pairs = [torch.stack(pair, dim=0) for pair in image_pairs]

        # Return in same format as slow processor
        image_pairs = torch.stack(stacked_pairs, dim=0) if return_tensors else stacked_pairs

        return BatchFeature(data={"pixel_values": image_pairs})

    def post_process_keypoint_matching(
        self,
        outputs: "KeypointMatchingOutput",
        target_sizes: Union[TensorType, list[tuple]],
        threshold: float = 0.0,
    ) -> list[dict[str, torch.Tensor]]:
        """
        Converts the raw output of [`KeypointMatchingOutput`] into lists of keypoints, scores and descriptors
        with coordinates absolute to the original image sizes.
        Args:
            outputs ([`KeypointMatchingOutput`]):
                Raw outputs of the model.
            target_sizes (`torch.Tensor` or `List[Tuple[Tuple[int, int]]]`, *optional*):
                Tensor of shape `(batch_size, 2, 2)` or list of tuples of tuples (`Tuple[int, int]`) containing the
                target size `(height, width)` of each image in the batch. This must be the original image size (before
                any processing).
            threshold (`float`, *optional*, defaults to 0.0):
                Threshold to filter out the matches with low scores.
        Returns:
            `List[Dict]`: A list of dictionaries, each dictionary containing the keypoints in the first and second image
            of the pair, the matching scores and the matching indices.
        """
        if outputs.matches.shape[0] != len(target_sizes):
            raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the mask")
        if not all(len(target_size) == 2 for target_size in target_sizes):
            raise ValueError("Each element of target_sizes must contain the size (h, w) of each image of the batch")

        if isinstance(target_sizes, list):
            image_pair_sizes = torch.tensor(target_sizes, device=outputs.matches.device)
        else:
            if target_sizes.shape[1] != 2 or target_sizes.shape[2] != 2:
                raise ValueError(
                    "Each element of target_sizes must contain the size (h, w) of each image of the batch"
                )
            image_pair_sizes = target_sizes

        keypoints = outputs.keypoints.clone()
        keypoints = keypoints * image_pair_sizes.flip(-1).reshape(-1, 2, 1, 2)
        keypoints = keypoints.to(torch.int32)

        results = []
        for keypoints_pair, matches, scores in zip(keypoints, outputs.matches, outputs.matching_scores):
            # Filter out matches with low scores
            valid_matches = torch.logical_and(scores > threshold, matches > -1)

            matched_keypoints0 = keypoints_pair[0][valid_matches[0]]
            matched_keypoints1 = keypoints_pair[1][valid_matches[1]]
            matching_scores = scores[0][valid_matches[0]]

            results.append(
                {
                    "keypoints0": matched_keypoints0,
                    "keypoints1": matched_keypoints1,
                    "matching_scores": matching_scores,
                }
            )

        return results

    def visualize_keypoint_matching(
        self,
        images,
        keypoint_matching_output: list[dict[str, torch.Tensor]],
    ) -> list["Image.Image"]:
        """
        Plots the image pairs side by side with the detected keypoints as well as the matching between them.

        Args:
            images:
                Image pairs to plot. Same as `EfficientLoFTRImageProcessor.preprocess`. Expects either a list of 2
                images or a list of list of 2 images list with pixel values ranging from 0 to 255.
            keypoint_matching_output (List[Dict[str, torch.Tensor]]]):
                A post processed keypoint matching output

        Returns:
            `List[PIL.Image.Image]`: A list of PIL images, each containing the image pairs side by side with the detected
            keypoints as well as the matching between them.
        """
        from ...image_utils import to_numpy_array
        from .image_processing_superglue import validate_and_format_image_pairs

        images = validate_and_format_image_pairs(images)
        images = [to_numpy_array(image) for image in images]
        image_pairs = [images[i : i + 2] for i in range(0, len(images), 2)]

        results = []
        for image_pair, pair_output in zip(image_pairs, keypoint_matching_output):
            height0, width0 = image_pair[0].shape[:2]
            height1, width1 = image_pair[1].shape[:2]
            plot_image = torch.zeros((max(height0, height1), width0 + width1, 3), dtype=torch.uint8)
            plot_image[:height0, :width0] = torch.from_numpy(image_pair[0])
            plot_image[:height1, width0:] = torch.from_numpy(image_pair[1])

            plot_image_pil = Image.fromarray(plot_image.numpy())
            draw = ImageDraw.Draw(plot_image_pil)

            keypoints0_x, keypoints0_y = pair_output["keypoints0"].unbind(1)
            keypoints1_x, keypoints1_y = pair_output["keypoints1"].unbind(1)
            for keypoint0_x, keypoint0_y, keypoint1_x, keypoint1_y, matching_score in zip(
                keypoints0_x, keypoints0_y, keypoints1_x, keypoints1_y, pair_output["matching_scores"]
            ):
                color = self._get_color(matching_score)
                draw.line(
                    (keypoint0_x, keypoint0_y, keypoint1_x + width0, keypoint1_y),
                    fill=color,
                    width=3,
                )
                draw.ellipse((keypoint0_x - 2, keypoint0_y - 2, keypoint0_x + 2, keypoint0_y + 2), fill="black")
                draw.ellipse(
                    (keypoint1_x + width0 - 2, keypoint1_y - 2, keypoint1_x + width0 + 2, keypoint1_y + 2),
                    fill="black",
                )

            results.append(plot_image_pil)
        return results

    def _get_color(self, score):
        """Maps a score to a color."""
        r = int(255 * (1 - score))
        g = int(255 * score)
        b = 0
        return r, g, b


__all__ = ["SuperGlueImageProcessorFast"]
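To show how the two public post-processing methods defined above fit together, a short hedged sketch; `model`, `processor`, `image1` and `image2` are assumed to be set up as in the SuperGlue docs snippet at the top of this diff:

images = [image1, image2]
inputs = processor(images, return_tensors="pt")
with torch.inference_mode():
    outputs = model(**inputs)

# Map normalized keypoints back to the original image sizes and drop weak matches
image_sizes = [[(image.height, image.width) for image in images]]
matches = processor.post_process_keypoint_matching(outputs, image_sizes, threshold=0.2)
print(matches[0]["keypoints0"].shape, matches[0]["matching_scores"].shape)

# Draw the pair side by side with score-colored match lines (green = high, red = low)
plots = processor.visualize_keypoint_matching(images, matches)
plots[0].save("superglue_matches.png")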
@@ -15,8 +15,6 @@ import time
import unittest

import numpy as np
import pytest
from packaging import version

from tests.models.superglue.test_image_processing_superglue import (
    SuperGlueImageProcessingTest,
@@ -24,10 +22,7 @@ from tests.models.superglue.test_image_processing_superglue import (
)
from transformers.testing_utils import (
    require_torch,
    require_torch_accelerator,
    require_vision,
    slow,
    torch_device,
)
from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available

@@ -103,46 +98,6 @@ class EfficientLoFTRImageProcessingTest(SuperGlueImageProcessingTest, unittest.T
        super().setUp()
        self.image_processor_tester = EfficientLoFTRImageProcessingTester(self)

    def test_slow_fast_equivalence(self):
        """Override the generic test since EfficientLoFTR requires image pairs."""
        if not self.test_slow_image_processor or not self.test_fast_image_processor:
            self.skipTest(reason="Skipping slow/fast equivalence test")

        if self.image_processing_class is None or self.fast_image_processing_class is None:
            self.skipTest(reason="Skipping slow/fast equivalence test as one of the image processors is not defined")

        # Create image pairs instead of single images
        dummy_images = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=False)
        image_processor_slow = self.image_processing_class(**self.image_processor_dict)
        image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict)

        encoding_slow = image_processor_slow(dummy_images, return_tensors="pt")
        encoding_fast = image_processor_fast(dummy_images, return_tensors="pt")
        self._assert_slow_fast_tensors_equivalence(encoding_slow.pixel_values, encoding_fast.pixel_values)

    def test_slow_fast_equivalence_batched(self):
        """Override the generic test since EfficientLoFTR requires image pairs."""
        if not self.test_slow_image_processor or not self.test_fast_image_processor:
            self.skipTest(reason="Skipping slow/fast equivalence test")

        if self.image_processing_class is None or self.fast_image_processing_class is None:
            self.skipTest(reason="Skipping slow/fast equivalence test as one of the image processors is not defined")

        if hasattr(self.image_processor_tester, "do_center_crop") and self.image_processor_tester.do_center_crop:
            self.skipTest(
                reason="Skipping as do_center_crop is True and center_crop functions are not equivalent for fast and slow processors"
            )

        # Create image pairs instead of single images
        dummy_images = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True)
        image_processor_slow = self.image_processing_class(**self.image_processor_dict)
        image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict)

        encoding_slow = image_processor_slow(dummy_images, return_tensors="pt")
        encoding_fast = image_processor_fast(dummy_images, return_tensors="pt")

        self._assert_slow_fast_tensors_equivalence(encoding_slow.pixel_values, encoding_fast.pixel_values)

    @unittest.skip(reason="Many failing cases. This test needs a more deep investigation.")
    def test_fast_is_faster_than_slow(self):
        """Override the generic test since EfficientLoFTR requires image pairs."""
@@ -173,25 +128,3 @@ class EfficientLoFTRImageProcessingTest(SuperGlueImageProcessingTest, unittest.T
        self.assertLessEqual(
            fast_time, slow_time * 1.2, "Fast processor should not be significantly slower than slow processor"
        )

    @slow
    @require_torch_accelerator
    @require_vision
    @pytest.mark.torch_compile_test
    def test_can_compile_fast_image_processor(self):
        """Override the generic test since EfficientLoFTR requires image pairs."""
        if self.fast_image_processing_class is None:
            self.skipTest("Skipping compilation test as fast image processor is not defined")
        if version.parse(torch.__version__) < version.parse("2.3"):
            self.skipTest(reason="This test requires torch >= 2.3 to run.")

        torch.compiler.reset()
        input_image = self.image_processor_tester.prepare_image_inputs(equal_resolution=True, torchify=False)
        image_processor = self.fast_image_processing_class(**self.image_processor_dict)
        output_eager = image_processor(input_image, device=torch_device, return_tensors="pt")

        image_processor = torch.compile(image_processor, mode="reduce-overhead")
        output_compiled = image_processor(input_image, device=torch_device, return_tensors="pt")
        self._assert_slow_fast_tensors_equivalence(
            output_eager.pixel_values, output_compiled.pixel_values, atol=1e-4, rtol=1e-4, mean_atol=1e-5
        )

@@ -90,6 +90,7 @@ class LightGlueImageProcessingTester(SuperGlueImageProcessingTester):
@require_vision
class LightGlueImageProcessingTest(SuperGlueImageProcessingTest, unittest.TestCase):
    image_processing_class = LightGlueImageProcessor if is_vision_available() else None
    fast_image_processing_class = None

    def setUp(self) -> None:
        super().setUp()
@@ -11,12 +11,22 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
import unittest

import numpy as np
import pytest
from packaging import version
from parameterized import parameterized

from transformers.testing_utils import require_torch, require_vision
from transformers.utils import is_torch_available, is_vision_available
from transformers.testing_utils import (
    require_torch,
    require_torch_accelerator,
    require_vision,
    slow,
    torch_device,
)
from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available

from ...test_image_processing_common import (
    ImageProcessingTestMixin,
@@ -33,6 +43,9 @@ if is_torch_available():
if is_vision_available():
    from transformers import SuperGlueImageProcessor

if is_torchvision_available():
    from transformers import SuperGlueImageProcessorFast


def random_array(size):
    return np.random.randint(255, size=size)
@@ -119,6 +132,7 @@ class SuperGlueImageProcessingTester:
@require_vision
class SuperGlueImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
    image_processing_class = SuperGlueImageProcessor if is_vision_available() else None
    fast_image_processing_class = SuperGlueImageProcessorFast if is_torchvision_available() else None

    def setUp(self) -> None:
        super().setUp()
@@ -397,3 +411,76 @@ class SuperGlueImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
        tensor_post_processed_outputs = image_processor.post_process_keypoint_matching(outputs, tensor_image_sizes)

        check_post_processed_output(tensor_post_processed_outputs, tensor_image_sizes)

    @unittest.skip(reason="Many failing cases. This test needs a more deep investigation.")
    def test_fast_is_faster_than_slow(self):
        """Override the generic test since EfficientLoFTR requires image pairs."""
        if not self.test_slow_image_processor or not self.test_fast_image_processor:
            self.skipTest(reason="Skipping slow/fast speed test")

        if self.image_processing_class is None or self.fast_image_processing_class is None:
            self.skipTest(reason="Skipping slow/fast speed test as one of the image processors is not defined")

        # Create image pairs for speed test
        dummy_images = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=False)
        image_processor_slow = self.image_processing_class(**self.image_processor_dict)
        image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict)

        # Time slow processor
        start_time = time.time()
        for _ in range(10):
            _ = image_processor_slow(dummy_images, return_tensors="pt")
        slow_time = time.time() - start_time

        # Time fast processor
        start_time = time.time()
        for _ in range(10):
            _ = image_processor_fast(dummy_images, return_tensors="pt")
        fast_time = time.time() - start_time

        # Fast should be faster (or at least not significantly slower)
        self.assertLessEqual(
            fast_time, slow_time * 1.2, "Fast processor should not be significantly slower than slow processor"
        )

    @require_vision
    @require_torch
    def test_slow_fast_equivalence(self):
        if not self.test_slow_image_processor or not self.test_fast_image_processor:
            self.skipTest(reason="Skipping slow/fast equivalence test")

        if self.image_processing_class is None or self.fast_image_processing_class is None:
            self.skipTest(reason="Skipping slow/fast equivalence test as one of the image processors is not defined")

        dummy_image = self.image_processor_tester.prepare_image_inputs(
            equal_resolution=False, numpify=True, batch_size=2, pairs=False
        )
        image_processor_slow = self.image_processing_class(**self.image_processor_dict)
        image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict)

        encoding_slow = image_processor_slow(dummy_image, return_tensors="pt")
        encoding_fast = image_processor_fast(dummy_image, return_tensors="pt")

        self._assert_slow_fast_tensors_equivalence(encoding_slow.pixel_values, encoding_fast.pixel_values)

    @slow
    @require_torch_accelerator
    @require_vision
    @pytest.mark.torch_compile_test
    def test_can_compile_fast_image_processor(self):
        """Override the generic test since EfficientLoFTR requires image pairs."""
        if self.fast_image_processing_class is None:
            self.skipTest("Skipping compilation test as fast image processor is not defined")
        if version.parse(torch.__version__) < version.parse("2.3"):
            self.skipTest(reason="This test requires torch >= 2.3 to run.")

        torch.compiler.reset()
        input_image = self.image_processor_tester.prepare_image_inputs(equal_resolution=True, torchify=False)
        image_processor = self.fast_image_processing_class(**self.image_processor_dict)
        output_eager = image_processor(input_image, device=torch_device, return_tensors="pt")

        image_processor = torch.compile(image_processor, mode="reduce-overhead")
        output_compiled = image_processor(input_image, device=torch_device, return_tensors="pt")
        self._assert_slow_fast_tensors_equivalence(
            output_eager.pixel_values, output_compiled.pixel_values, atol=1e-4, rtol=1e-4, mean_atol=1e-5
        )