Compare commits

...

4 Commits

Author SHA1 Message Date
297d8549fd add version 2025-10-17 09:29:33 +00:00
1eb45cd61d Fix ckpt in docs (#41659)
* fix ckpt in docs

* fix config ckpt
2025-10-17 11:00:34 +02:00
354567d955 Adding superglue fast image processing (#41394)
* Default implementation - no time improvement

* Improved implementation - apparently 2 times faster with only simple function refactor

* elementary torch first approach, still need further implementation of torch first method

* torch-first approach finished

* refactor processor

* refactor test

* partial doc update

* EfficientLoFTRImageProcessorFast based implementation

* EfficientLoFTRImageProcessorFast based implementation

* Logic checked - Test Passed - Validated execution speed

* use modular for efficientloftr

* fix import

---------

Co-authored-by: yonigozlan <yoni.gozlan@huggingface.co>
Co-authored-by: Yoni Gozlan <74535834+yonigozlan@users.noreply.github.com>
2025-10-16 19:34:09 +00:00
4dd4133d32 🌐 [i18n-KO] Translated ko-LFM2.md to Korean (#41502)
* feat: nmt draft

* fix: manual edits

* Update docs/source/ko/model_doc/lfm2.md

Co-authored-by: Yijun Lee <119404328+yijun-lee@users.noreply.github.com>

* Update docs/source/ko/model_doc/lfm2.md

Co-authored-by: Ahnjj_DEV <ahnjj.dev@gmail.com>

* Update docs/source/ko/model_doc/lfm2.md

Co-authored-by: Ahnjj_DEV <ahnjj.dev@gmail.com>

* Update docs/source/ko/model_doc/lfm2.md

Co-authored-by: Ahnjj_DEV <ahnjj.dev@gmail.com>

---------

Co-authored-by: Yijun Lee <119404328+yijun-lee@users.noreply.github.com>
Co-authored-by: Ahnjj_DEV <ahnjj.dev@gmail.com>
2025-10-16 11:29:04 -07:00
17 changed files with 527 additions and 122 deletions

View File

@@ -70,8 +70,8 @@ from transformers import AutoProcessor, Florence2ForConditionalGeneration
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg?download=true"
image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
model = Florence2ForConditionalGeneration.from_pretrained("microsoft/Florence-2-base", dtype=torch.bfloat16, device_map="auto")
processor = AutoProcessor.from_pretrained("microsoft/Florence-2-base")
model = Florence2ForConditionalGeneration.from_pretrained("florence-community/Florence-2-base", dtype=torch.bfloat16, device_map="auto")
processor = AutoProcessor.from_pretrained("florence-community/Florence-2-base")
task_prompt = "<OD>"
inputs = processor(text=task_prompt, images=image, return_tensors="pt").to(model.device)
@@ -105,12 +105,12 @@ from transformers import AutoProcessor, Florence2ForConditionalGeneration, BitsA
quantization_config = BitsAndBytesConfig(load_in_4bit=True)
model = Florence2ForConditionalGeneration.from_pretrained(
"microsoft/Florence-2-large",
"florence-community/Florence-2-base",
dtype=torch.bfloat16,
device_map="auto",
quantization_config=quantization_config
)
processor = AutoProcessor.from_pretrained("microsoft/Florence-2-large")
processor = AutoProcessor.from_pretrained("florence-community/Florence-2-base")
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg?download=true"
image = Image.open(requests.get(url, stream=True).raw).convert("RGB")

View File

@@ -88,16 +88,16 @@ processed_outputs = processor.post_process_keypoint_matching(outputs, image_size
import torch
from PIL import Image
import requests
processor = AutoImageProcessor.from_pretrained("magic-leap-community/superglue_outdoor")
model = AutoModel.from_pretrained("magic-leap-community/superglue_outdoor")
# SuperGlue requires pairs of images
images = [image1, image2]
inputs = processor(images, return_tensors="pt")
with torch.inference_mode():
outputs = model(**inputs)
# Extract matching information
keypoints0 = outputs.keypoints0 # Keypoints in first image
keypoints1 = outputs.keypoints1 # Keypoints in second image
@@ -112,7 +112,7 @@ processed_outputs = processor.post_process_keypoint_matching(outputs, image_size
# Process outputs for visualization
image_sizes = [[(image.height, image.width) for image in images]]
processed_outputs = processor.post_process_keypoint_matching(outputs, image_sizes, threshold=0.2)
for i, output in enumerate(processed_outputs):
print(f"For the image pair {i}")
for keypoint0, keypoint1, matching_score in zip(
@@ -147,6 +147,13 @@ processed_outputs = processor.post_process_keypoint_matching(outputs, image_size
- post_process_keypoint_matching
- visualize_keypoint_matching
## SuperGlueImageProcessorFast
[[autodoc]] SuperGlueImageProcessorFast
- preprocess
- post_process_keypoint_matching
- visualize_keypoint_matching
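The snippet below is a minimal sketch, not part of this change, of how the new fast processor can be selected; it assumes the `magic-leap-community/superglue_outdoor` checkpoint used earlier on this page and the standard `use_fast` switch of `AutoImageProcessor`.

```python
from transformers import AutoImageProcessor

# With the fast class registered (see the auto-mapping change further down),
# use_fast=True should resolve to SuperGlueImageProcessorFast for SuperGlue checkpoints.
processor = AutoImageProcessor.from_pretrained(
    "magic-leap-community/superglue_outdoor", use_fast=True
)
print(type(processor).__name__)  # expected: SuperGlueImageProcessorFast
```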
## SuperGlueForKeypointMatching
[[autodoc]] SuperGlueForKeypointMatching

View File

@@ -603,7 +603,7 @@
title: Jukebox
- local: in_translation
title: LED
- local: in_translation
- local: model_doc/lfm2
title: LFM2
- local: in_translation
title: LFM2-VL

View File

@@ -0,0 +1,85 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->
*This model was released on July 10, 2025 and was added to Hugging Face Transformers on July 10, 2025.*
<div class="flex flex-wrap space-x-1">
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
</div>
# LFM2
## Overview[[overview]]
[LFM2](https://www.liquid.ai/blog/liquid-foundation-models-v2-our-second-series-of-generative-ai-models) is a next-generation Liquid Foundation Model developed by Liquid AI, designed specifically for edge AI and on-device deployment.
The models come in four parameter sizes (350M, 700M, 1.2B, and 2.6B) and are built to run efficiently on CPU, GPU, and NPU hardware, which makes them especially well suited to applications that require low latency, offline operation, and privacy.
## Architecture[[architecture]]
The architecture consists of gated short-convolution blocks and grouped-query attention blocks with QK layer normalization. The design stems from the idea of dynamical systems in which linear operations are modulated by input-dependent gates. The short convolutions are optimized in particular for embedded SoC CPUs, making them ideal for devices that need fast, local inference without depending on cloud connectivity.
LFM2 is designed to maximize quality under tight speed and memory constraints. This was achieved through a systematic architecture search that measured actual peak memory usage and inference speed on Qualcomm Snapdragon processors, optimizing the models for real performance on embedded hardware. The result is a model that decodes and prefills twice as fast as similarly sized models while maintaining strong benchmark performance across knowledge, math, instruction-following, and multilingual tasks.
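To see how this architecture is parameterized in Transformers, the configuration can be inspected without downloading any weights. This is a minimal sketch, assuming only the `LiquidAI/LFM2-1.2B` checkpoint used in the example below.

```python
from transformers import AutoConfig

# Loading the config is lightweight; printing it lists the model hyperparameters,
# including those of the convolution and attention blocks described above.
config = AutoConfig.from_pretrained("LiquidAI/LFM2-1.2B")
print(type(config).__name__)  # expected: Lfm2Config
print(config)
```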
## Example[[example]]
The following example shows how to generate an answer using the `AutoModelForCausalLM` class.
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
# Load the model and tokenizer
model_id = "LiquidAI/LFM2-1.2B"
model = AutoModelForCausalLM.from_pretrained(
model_id,
device_map="auto",
dtype="bfloat16",
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Generate an answer
prompt = "What is C. elegans?"
input_ids = tokenizer.apply_chat_template(
[{"role": "user", "content": prompt}],
add_generation_prompt=True,
return_tensors="pt",
tokenize=True,
)
output = model.generate(
input_ids,
do_sample=True,
temperature=0.3,
min_p=0.15,
repetition_penalty=1.05,
max_new_tokens=512,
)
print(tokenizer.decode(output[0], skip_special_tokens=False))
```
## Lfm2Config [[transformers.Lfm2Config]]
[[autodoc]] Lfm2Config
## Lfm2Model [[transformers.Lfm2Model]]
[[autodoc]] Lfm2Model
- forward
## Lfm2ForCausalLM [[transformers.Lfm2ForCausalLM]]
[[autodoc]] Lfm2ForCausalLM
- forward

View File

@@ -162,8 +162,8 @@ except ImportError:
raise RuntimeError("register_kernel_mapping requires `kernels` to be installed. Run `pip install kernels`.")
_HUB_KERNEL_MAPPING: dict[str, str] = {
"causal-conv1d": "kernels-community/causal-conv1d",
_HUB_KERNEL_MAPPING: dict[str, dict[str, str]] = {
"causal-conv1d": {"repo_id": "kernels-community/causal-conv1d"},
}
_KERNEL_MODULE_MAPPING: dict[str, Optional[ModuleType]] = {}
@@ -242,7 +242,9 @@ def lazy_load_kernel(kernel_name: str, mapping: dict[str, Optional[ModuleType]]
from kernels import get_kernel
try:
kernel = get_kernel(_HUB_KERNEL_MAPPING[kernel_name])
repo_id = _HUB_KERNEL_MAPPING[kernel_name]["repo_id"]
version = _HUB_KERNEL_MAPPING[kernel_name].get("version", None)
kernel = get_kernel(repo_id, version=version)
mapping[kernel_name] = kernel
except FileNotFoundError:
mapping[kernel_name] = None
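For context, a hedged sketch of what the new mapping shape permits; the `flash-attn` entry and its version range below are purely illustrative and not part of this change.

```python
# Hypothetical extra entry: with the dict-valued mapping, a kernel may pin a version
# range in addition to its Hub repo id (the existing entry keeps only "repo_id").
_HUB_KERNEL_MAPPING: dict[str, dict[str, str]] = {
    "causal-conv1d": {"repo_id": "kernels-community/causal-conv1d"},
    "flash-attn": {"repo_id": "kernels-community/flash-attn", "version": ">=0.0.1"},  # illustrative
}

repo_id = _HUB_KERNEL_MAPPING["flash-attn"]["repo_id"]
version = _HUB_KERNEL_MAPPING["flash-attn"].get("version", None)
# lazy_load_kernel would then call: get_kernel(repo_id, version=version)
```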

View File

@@ -171,7 +171,7 @@ else:
("siglip", ("SiglipImageProcessor", "SiglipImageProcessorFast")),
("siglip2", ("Siglip2ImageProcessor", "Siglip2ImageProcessorFast")),
("smolvlm", ("SmolVLMImageProcessor", "SmolVLMImageProcessorFast")),
("superglue", ("SuperGlueImageProcessor", None)),
("superglue", ("SuperGlueImageProcessor", "SuperGlueImageProcessorFast")),
("superpoint", ("SuperPointImageProcessor", "SuperPointImageProcessorFast")),
("swiftformer", ("ViTImageProcessor", "ViTImageProcessorFast")),
("swin", ("ViTImageProcessor", "ViTImageProcessorFast")),

View File

@@ -1,30 +1,17 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Fast Image processor class for EfficientLoFTR."""
from typing import TYPE_CHECKING, Optional, Union
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/efficientloftr/modular_efficientloftr.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_efficientloftr.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
from typing import Optional, Union
import torch
from PIL import Image, ImageDraw
from torchvision.transforms.v2 import functional as F
from ...image_processing_utils import BatchFeature
from ...image_processing_utils_fast import (
BaseImageProcessorFast,
group_images_by_shape,
reorder_images,
)
from ...image_processing_utils_fast import BaseImageProcessorFast, BatchFeature
from ...image_transforms import group_images_by_shape, reorder_images
from ...image_utils import (
ImageInput,
ImageType,
@@ -35,17 +22,9 @@ from ...image_utils import (
is_valid_image,
)
from ...processing_utils import Unpack
from ...utils import (
TensorType,
auto_docstring,
)
from ...utils import TensorType, auto_docstring
from .image_processing_efficientloftr import EfficientLoFTRImageProcessorKwargs
if TYPE_CHECKING:
from .modeling_efficientloftr import KeypointMatchingOutput
import torchvision.transforms.v2.functional as F
from .modeling_efficientloftr import KeypointMatchingOutput
def _is_valid_image(image):
@@ -299,7 +278,7 @@ class EfficientLoFTRImageProcessorFast(BaseImageProcessorFast):
r = int(255 * (1 - score))
g = int(255 * score)
b = 0
return (r, g, b)
return r, g, b
__all__ = ["EfficientLoFTRImageProcessorFast"]

View File

@@ -0,0 +1,8 @@
from ..superglue.image_processing_superglue_fast import SuperGlueImageProcessorFast
class EfficientLoFTRImageProcessorFast(SuperGlueImageProcessorFast):
pass
__all__ = ["EfficientLoFTRImageProcessorFast"]

View File

@@ -140,7 +140,7 @@ class Florence2Config(PreTrainedConfig):
Florence-2 model according to the specified arguments, defining the model architecture.
Instantiating a configuration with the defaults will yield a similar configuration to that of the Florence-2
[microsoft/Florence-2-base](https://huggingface.co/microsoft/Florence-2-base) architecture.
[florence-community/Florence-2-base](https://huggingface.co/florence-community/Florence-2-base) architecture.
Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PreTrainedConfig`] for more information.

View File

@@ -884,8 +884,8 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel, GenerationMixi
>>> import requests
>>> from transformers import AutoProcessor, Florence2ForConditionalGeneration
>>> model = Florence2ForConditionalGeneration.from_pretrained("microsoft/Florence-2-large")
>>> processor = AutoProcessor.from_pretrained("microsoft/Florence-2-large")
>>> model = Florence2ForConditionalGeneration.from_pretrained("florence-community/Florence-2-large")
>>> processor = AutoProcessor.from_pretrained("florence-community/Florence-2-large")
>>> prompt = "<CAPTION>"
>>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg"

View File

@@ -160,7 +160,7 @@ class Florence2Config(PreTrainedConfig):
Florence-2 model according to the specified arguments, defining the model architecture.
Instantiating a configuration with the defaults will yield a similar configuration to that of the Florence-2
[microsoft/Florence-2-base](https://huggingface.co/microsoft/Florence-2-base) architecture.
[florence-community/Florence-2-base](https://huggingface.co/florence-community/Florence-2-base) architecture.
Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PreTrainedConfig`] for more information.
@@ -1674,8 +1674,8 @@ class Florence2ForConditionalGeneration(LlavaForConditionalGeneration):
>>> import requests
>>> from transformers import AutoProcessor, Florence2ForConditionalGeneration
>>> model = Florence2ForConditionalGeneration.from_pretrained("microsoft/Florence-2-large")
>>> processor = AutoProcessor.from_pretrained("microsoft/Florence-2-large")
>>> model = Florence2ForConditionalGeneration.from_pretrained("florence-community/Florence-2-large")
>>> processor = AutoProcessor.from_pretrained("florence-community/Florence-2-large")
>>> prompt = "<CAPTION>"
>>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg"

View File

@@ -20,6 +20,7 @@ from ...utils.import_utils import define_import_structure
if TYPE_CHECKING:
from .configuration_superglue import *
from .image_processing_superglue import *
from .image_processing_superglue_fast import *
from .modeling_superglue import *
else:
import sys

View File

@@ -35,6 +35,7 @@ from ...image_utils import (
valid_images,
validate_preprocess_arguments,
)
from ...processing_utils import ImagesKwargs
from ...utils import TensorType, logging, requires_backends
from ...utils.import_utils import requires
@@ -133,6 +134,15 @@ def validate_and_format_image_pairs(images: ImageInput):
raise ValueError(error_message)
class SuperGlueImageProcessorKwargs(ImagesKwargs, total=False):
r"""
do_grayscale (`bool`, *optional*, defaults to `True`):
Whether to convert the image to grayscale. Can be overridden by `do_grayscale` in the `preprocess` method.
"""
do_grayscale: bool
@requires(backends=("torch",))
class SuperGlueImageProcessor(BaseImageProcessor):
r"""

View File

@@ -0,0 +1,292 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Union
import torch
from PIL import Image, ImageDraw
from torchvision.transforms.v2 import functional as F
from ...image_processing_utils_fast import BaseImageProcessorFast, BatchFeature
from ...image_transforms import group_images_by_shape, reorder_images
from ...image_utils import (
ImageInput,
ImageType,
PILImageResampling,
SizeDict,
get_image_type,
is_pil_image,
is_valid_image,
)
from ...processing_utils import Unpack
from ...utils import TensorType, auto_docstring
from .image_processing_superglue import SuperGlueImageProcessorKwargs
from .modeling_superglue import KeypointMatchingOutput
def _is_valid_image(image):
return is_pil_image(image) or (
is_valid_image(image) and get_image_type(image) != ImageType.PIL and len(image.shape) == 3
)
def flatten_pair_images(images):
# Handle pair validation and flattening similarly to the slow processor
if isinstance(images, list):
if len(images) == 2 and all((_is_valid_image(image) or isinstance(image, torch.Tensor)) for image in images):
# Single pair of images - keep as is, they'll be processed by the base class
return images
elif all(
isinstance(image_pair, list)
and len(image_pair) == 2
and all(_is_valid_image(image) or isinstance(image, torch.Tensor) for image in image_pair)
for image_pair in images
):
# Multiple pairs - flatten them
images = [image for image_pair in images for image in image_pair]
return images
raise ValueError(
"Input images must be a one of the following :",
" - A pair of PIL images.",
" - A pair of 3D arrays.",
" - A list of pairs of PIL images.",
" - A list of pairs of 3D arrays.",
)
def is_grayscale(
image: "torch.Tensor",
):
"""Checks if an image is grayscale (all RGB channels are identical)."""
if image.ndim < 3 or image.shape[0 if image.ndim == 3 else 1] == 1:
return True
return torch.all(image[..., 0, :, :] == image[..., 1, :, :]) and torch.all(
image[..., 1, :, :] == image[..., 2, :, :]
)
def convert_to_grayscale(
image: "torch.Tensor",
) -> "torch.Tensor":
"""
Converts an image to grayscale format using the NTSC formula. Only supports torch.Tensor.
This function is supposed to return a 1-channel image, but it returns a 3-channel image with the same value in each
channel, because of an issue that is discussed in:
https://github.com/huggingface/transformers/pull/25786#issuecomment-1730176446
Args:
image (torch.Tensor):
The image to convert.
"""
if is_grayscale(image):
return image
return F.rgb_to_grayscale(image, num_output_channels=3)
@auto_docstring
class SuperGlueImageProcessorFast(BaseImageProcessorFast):
resample = PILImageResampling.BILINEAR
size = {"height": 480, "width": 640}
default_to_square = False
do_resize = True
do_rescale = True
rescale_factor = 1 / 255
do_normalize = None
valid_kwargs = SuperGlueImageProcessorKwargs
def __init__(self, **kwargs: Unpack[SuperGlueImageProcessorKwargs]):
super().__init__(**kwargs)
@auto_docstring
def preprocess(self, images: ImageInput, **kwargs: Unpack[SuperGlueImageProcessorKwargs]) -> BatchFeature:
return super().preprocess(images, **kwargs)
def _prepare_images_structure(
self,
images: ImageInput,
**kwargs,
) -> ImageInput:
# we need to handle image pairs validation and flattening
return flatten_pair_images(images)
def _preprocess(
self,
images: list["torch.Tensor"],
size: Union[dict[str, int], SizeDict],
rescale_factor: float,
do_rescale: bool,
do_resize: bool,
interpolation: Optional["F.InterpolationMode"],
do_grayscale: bool,
disable_grouping: bool,
return_tensors: Union[str, TensorType],
**kwargs,
) -> BatchFeature:
grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
processed_images_grouped = {}
for shape, stacked_images in grouped_images.items():
if do_resize:
stacked_images = self.resize(stacked_images, size=size, interpolation=interpolation)
processed_images_grouped[shape] = stacked_images
resized_images = reorder_images(processed_images_grouped, grouped_images_index)
grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
processed_images_grouped = {}
for shape, stacked_images in grouped_images.items():
if do_rescale:
stacked_images = self.rescale(stacked_images, rescale_factor)
if do_grayscale:
stacked_images = convert_to_grayscale(stacked_images)
processed_images_grouped[shape] = stacked_images
processed_images = reorder_images(processed_images_grouped, grouped_images_index)
# Convert back to pairs format
image_pairs = [processed_images[i : i + 2] for i in range(0, len(processed_images), 2)]
# Stack each pair into a single tensor to match slow processor format
stacked_pairs = [torch.stack(pair, dim=0) for pair in image_pairs]
# Return in same format as slow processor
image_pairs = torch.stack(stacked_pairs, dim=0) if return_tensors else stacked_pairs
return BatchFeature(data={"pixel_values": image_pairs})
def post_process_keypoint_matching(
self,
outputs: "KeypointMatchingOutput",
target_sizes: Union[TensorType, list[tuple]],
threshold: float = 0.0,
) -> list[dict[str, torch.Tensor]]:
"""
Converts the raw output of [`KeypointMatchingOutput`] into lists of matched keypoints and matching scores,
with keypoint coordinates given in absolute pixels of the original image sizes.
Args:
outputs ([`KeypointMatchingOutput`]):
Raw outputs of the model.
target_sizes (`torch.Tensor` or `List[Tuple[Tuple[int, int]]]`, *optional*):
Tensor of shape `(batch_size, 2, 2)` or list of tuples of tuples (`Tuple[int, int]`) containing the
target size `(height, width)` of each image in the batch. This must be the original image size (before
any processing).
threshold (`float`, *optional*, defaults to 0.0):
Threshold to filter out the matches with low scores.
Returns:
`List[Dict]`: A list of dictionaries, each dictionary containing the matched keypoints in the first and second
image of the pair and the corresponding matching scores.
"""
if outputs.matches.shape[0] != len(target_sizes):
raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the mask")
if not all(len(target_size) == 2 for target_size in target_sizes):
raise ValueError("Each element of target_sizes must contain the size (h, w) of each image of the batch")
if isinstance(target_sizes, list):
image_pair_sizes = torch.tensor(target_sizes, device=outputs.matches.device)
else:
if target_sizes.shape[1] != 2 or target_sizes.shape[2] != 2:
raise ValueError(
"Each element of target_sizes must contain the size (h, w) of each image of the batch"
)
image_pair_sizes = target_sizes
keypoints = outputs.keypoints.clone()
keypoints = keypoints * image_pair_sizes.flip(-1).reshape(-1, 2, 1, 2)
keypoints = keypoints.to(torch.int32)
results = []
for keypoints_pair, matches, scores in zip(keypoints, outputs.matches, outputs.matching_scores):
# Filter out matches with low scores
valid_matches = torch.logical_and(scores > threshold, matches > -1)
matched_keypoints0 = keypoints_pair[0][valid_matches[0]]
matched_keypoints1 = keypoints_pair[1][valid_matches[1]]
matching_scores = scores[0][valid_matches[0]]
results.append(
{
"keypoints0": matched_keypoints0,
"keypoints1": matched_keypoints1,
"matching_scores": matching_scores,
}
)
return results
def visualize_keypoint_matching(
self,
images,
keypoint_matching_output: list[dict[str, torch.Tensor]],
) -> list["Image.Image"]:
"""
Plots the image pairs side by side with the detected keypoints as well as the matching between them.
Args:
images:
Image pairs to plot. Same as `SuperGlueImageProcessor.preprocess`. Expects either a list of 2
images or a list of lists of 2 images with pixel values ranging from 0 to 255.
keypoint_matching_output (List[Dict[str, torch.Tensor]]):
A post-processed keypoint matching output.
Returns:
`List[PIL.Image.Image]`: A list of PIL images, each containing the image pairs side by side with the detected
keypoints as well as the matching between them.
"""
from ...image_utils import to_numpy_array
from .image_processing_superglue import validate_and_format_image_pairs
images = validate_and_format_image_pairs(images)
images = [to_numpy_array(image) for image in images]
image_pairs = [images[i : i + 2] for i in range(0, len(images), 2)]
results = []
for image_pair, pair_output in zip(image_pairs, keypoint_matching_output):
height0, width0 = image_pair[0].shape[:2]
height1, width1 = image_pair[1].shape[:2]
plot_image = torch.zeros((max(height0, height1), width0 + width1, 3), dtype=torch.uint8)
plot_image[:height0, :width0] = torch.from_numpy(image_pair[0])
plot_image[:height1, width0:] = torch.from_numpy(image_pair[1])
plot_image_pil = Image.fromarray(plot_image.numpy())
draw = ImageDraw.Draw(plot_image_pil)
keypoints0_x, keypoints0_y = pair_output["keypoints0"].unbind(1)
keypoints1_x, keypoints1_y = pair_output["keypoints1"].unbind(1)
for keypoint0_x, keypoint0_y, keypoint1_x, keypoint1_y, matching_score in zip(
keypoints0_x, keypoints0_y, keypoints1_x, keypoints1_y, pair_output["matching_scores"]
):
color = self._get_color(matching_score)
draw.line(
(keypoint0_x, keypoint0_y, keypoint1_x + width0, keypoint1_y),
fill=color,
width=3,
)
draw.ellipse((keypoint0_x - 2, keypoint0_y - 2, keypoint0_x + 2, keypoint0_y + 2), fill="black")
draw.ellipse(
(keypoint1_x + width0 - 2, keypoint1_y - 2, keypoint1_x + width0 + 2, keypoint1_y + 2),
fill="black",
)
results.append(plot_image_pil)
return results
def _get_color(self, score):
"""Maps a score to a color."""
r = int(255 * (1 - score))
g = int(255 * score)
b = 0
return r, g, b
__all__ = ["SuperGlueImageProcessorFast"]

View File

@@ -15,8 +15,6 @@ import time
import unittest
import numpy as np
import pytest
from packaging import version
from tests.models.superglue.test_image_processing_superglue import (
SuperGlueImageProcessingTest,
@@ -24,10 +22,7 @@ from tests.models.superglue.test_image_processing_superglue import (
)
from transformers.testing_utils import (
require_torch,
require_torch_accelerator,
require_vision,
slow,
torch_device,
)
from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available
@@ -103,46 +98,6 @@ class EfficientLoFTRImageProcessingTest(SuperGlueImageProcessingTest, unittest.T
super().setUp()
self.image_processor_tester = EfficientLoFTRImageProcessingTester(self)
def test_slow_fast_equivalence(self):
"""Override the generic test since EfficientLoFTR requires image pairs."""
if not self.test_slow_image_processor or not self.test_fast_image_processor:
self.skipTest(reason="Skipping slow/fast equivalence test")
if self.image_processing_class is None or self.fast_image_processing_class is None:
self.skipTest(reason="Skipping slow/fast equivalence test as one of the image processors is not defined")
# Create image pairs instead of single images
dummy_images = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=False)
image_processor_slow = self.image_processing_class(**self.image_processor_dict)
image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict)
encoding_slow = image_processor_slow(dummy_images, return_tensors="pt")
encoding_fast = image_processor_fast(dummy_images, return_tensors="pt")
self._assert_slow_fast_tensors_equivalence(encoding_slow.pixel_values, encoding_fast.pixel_values)
def test_slow_fast_equivalence_batched(self):
"""Override the generic test since EfficientLoFTR requires image pairs."""
if not self.test_slow_image_processor or not self.test_fast_image_processor:
self.skipTest(reason="Skipping slow/fast equivalence test")
if self.image_processing_class is None or self.fast_image_processing_class is None:
self.skipTest(reason="Skipping slow/fast equivalence test as one of the image processors is not defined")
if hasattr(self.image_processor_tester, "do_center_crop") and self.image_processor_tester.do_center_crop:
self.skipTest(
reason="Skipping as do_center_crop is True and center_crop functions are not equivalent for fast and slow processors"
)
# Create image pairs instead of single images
dummy_images = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True)
image_processor_slow = self.image_processing_class(**self.image_processor_dict)
image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict)
encoding_slow = image_processor_slow(dummy_images, return_tensors="pt")
encoding_fast = image_processor_fast(dummy_images, return_tensors="pt")
self._assert_slow_fast_tensors_equivalence(encoding_slow.pixel_values, encoding_fast.pixel_values)
@unittest.skip(reason="Many failing cases. This test needs a deeper investigation.")
def test_fast_is_faster_than_slow(self):
"""Override the generic test since EfficientLoFTR requires image pairs."""
@@ -173,25 +128,3 @@ class EfficientLoFTRImageProcessingTest(SuperGlueImageProcessingTest, unittest.T
self.assertLessEqual(
fast_time, slow_time * 1.2, "Fast processor should not be significantly slower than slow processor"
)
@slow
@require_torch_accelerator
@require_vision
@pytest.mark.torch_compile_test
def test_can_compile_fast_image_processor(self):
"""Override the generic test since EfficientLoFTR requires image pairs."""
if self.fast_image_processing_class is None:
self.skipTest("Skipping compilation test as fast image processor is not defined")
if version.parse(torch.__version__) < version.parse("2.3"):
self.skipTest(reason="This test requires torch >= 2.3 to run.")
torch.compiler.reset()
input_image = self.image_processor_tester.prepare_image_inputs(equal_resolution=True, torchify=False)
image_processor = self.fast_image_processing_class(**self.image_processor_dict)
output_eager = image_processor(input_image, device=torch_device, return_tensors="pt")
image_processor = torch.compile(image_processor, mode="reduce-overhead")
output_compiled = image_processor(input_image, device=torch_device, return_tensors="pt")
self._assert_slow_fast_tensors_equivalence(
output_eager.pixel_values, output_compiled.pixel_values, atol=1e-4, rtol=1e-4, mean_atol=1e-5
)

View File

@@ -90,6 +90,7 @@ class LightGlueImageProcessingTester(SuperGlueImageProcessingTester):
@require_vision
class LightGlueImageProcessingTest(SuperGlueImageProcessingTest, unittest.TestCase):
image_processing_class = LightGlueImageProcessor if is_vision_available() else None
fast_image_processing_class = None
def setUp(self) -> None:
super().setUp()

View File

@@ -11,12 +11,22 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
import unittest
import numpy as np
import pytest
from packaging import version
from parameterized import parameterized
from transformers.testing_utils import require_torch, require_vision
from transformers.utils import is_torch_available, is_vision_available
from transformers.testing_utils import (
require_torch,
require_torch_accelerator,
require_vision,
slow,
torch_device,
)
from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available
from ...test_image_processing_common import (
ImageProcessingTestMixin,
@@ -33,6 +43,9 @@ if is_torch_available():
if is_vision_available():
from transformers import SuperGlueImageProcessor
if is_torchvision_available():
from transformers import SuperGlueImageProcessorFast
def random_array(size):
return np.random.randint(255, size=size)
@@ -119,6 +132,7 @@ class SuperGlueImageProcessingTester:
@require_vision
class SuperGlueImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
image_processing_class = SuperGlueImageProcessor if is_vision_available() else None
fast_image_processing_class = SuperGlueImageProcessorFast if is_torchvision_available() else None
def setUp(self) -> None:
super().setUp()
@@ -397,3 +411,76 @@ class SuperGlueImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
tensor_post_processed_outputs = image_processor.post_process_keypoint_matching(outputs, tensor_image_sizes)
check_post_processed_output(tensor_post_processed_outputs, tensor_image_sizes)
@unittest.skip(reason="Many failing cases. This test needs a deeper investigation.")
def test_fast_is_faster_than_slow(self):
"""Override the generic test since SuperGlue requires image pairs."""
if not self.test_slow_image_processor or not self.test_fast_image_processor:
self.skipTest(reason="Skipping slow/fast speed test")
if self.image_processing_class is None or self.fast_image_processing_class is None:
self.skipTest(reason="Skipping slow/fast speed test as one of the image processors is not defined")
# Create image pairs for speed test
dummy_images = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=False)
image_processor_slow = self.image_processing_class(**self.image_processor_dict)
image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict)
# Time slow processor
start_time = time.time()
for _ in range(10):
_ = image_processor_slow(dummy_images, return_tensors="pt")
slow_time = time.time() - start_time
# Time fast processor
start_time = time.time()
for _ in range(10):
_ = image_processor_fast(dummy_images, return_tensors="pt")
fast_time = time.time() - start_time
# Fast should be faster (or at least not significantly slower)
self.assertLessEqual(
fast_time, slow_time * 1.2, "Fast processor should not be significantly slower than slow processor"
)
@require_vision
@require_torch
def test_slow_fast_equivalence(self):
if not self.test_slow_image_processor or not self.test_fast_image_processor:
self.skipTest(reason="Skipping slow/fast equivalence test")
if self.image_processing_class is None or self.fast_image_processing_class is None:
self.skipTest(reason="Skipping slow/fast equivalence test as one of the image processors is not defined")
dummy_image = self.image_processor_tester.prepare_image_inputs(
equal_resolution=False, numpify=True, batch_size=2, pairs=False
)
image_processor_slow = self.image_processing_class(**self.image_processor_dict)
image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict)
encoding_slow = image_processor_slow(dummy_image, return_tensors="pt")
encoding_fast = image_processor_fast(dummy_image, return_tensors="pt")
self._assert_slow_fast_tensors_equivalence(encoding_slow.pixel_values, encoding_fast.pixel_values)
@slow
@require_torch_accelerator
@require_vision
@pytest.mark.torch_compile_test
def test_can_compile_fast_image_processor(self):
"""Override the generic test since EfficientLoFTR requires image pairs."""
if self.fast_image_processing_class is None:
self.skipTest("Skipping compilation test as fast image processor is not defined")
if version.parse(torch.__version__) < version.parse("2.3"):
self.skipTest(reason="This test requires torch >= 2.3 to run.")
torch.compiler.reset()
input_image = self.image_processor_tester.prepare_image_inputs(equal_resolution=True, torchify=False)
image_processor = self.fast_image_processing_class(**self.image_processor_dict)
output_eager = image_processor(input_image, device=torch_device, return_tensors="pt")
image_processor = torch.compile(image_processor, mode="reduce-overhead")
output_compiled = image_processor(input_image, device=torch_device, return_tensors="pt")
self._assert_slow_fast_tensors_equivalence(
output_eager.pixel_values, output_compiled.pixel_values, atol=1e-4, rtol=1e-4, mean_atol=1e-5
)