Mirror of https://github.com/huggingface/transformers.git (synced 2025-10-20 17:13:56 +08:00)

Compare commits: 12 commits on branch vision_vis (3c7552f733...vision_vis)
SHA1: b356fce1da, af7f75e682, 34ba5909a2, fbec904fb0, a1263dfe7b, 1878d6c4ff,
a6a18efe53, e581d2f2ce, 1f6822d114, edb70ae15c, 27bc371bea, 58c619e809
@@ -452,6 +452,105 @@ def normalize(
    return image


def unnormalize(
    image: Union[np.ndarray, "torch.Tensor"],
    mean: Union[float, Collection[float]],
    std: Union[float, Collection[float]],
    data_format: Optional[ChannelDimension] = None,
    input_data_format: Optional[Union[str, ChannelDimension]] = None,
):
    """
    Inverse of `normalize`:

        image = image * std + mean

    Accepts NumPy arrays or PyTorch tensors and mirrors `normalize`'s API,
    but also handles 4D/5D inputs by broadcasting along the channel axis and
    collapsing leading batch dims. Defaults to NHWC output for visualization.
    """
    # type check
    is_np = isinstance(image, np.ndarray)
    is_torch = isinstance(image, torch.Tensor)
    if not (is_np or is_torch):
        raise TypeError("image must be a numpy array or a torch tensor")

    # infer layout
    if input_data_format is None:
        input_data_format = infer_channel_dimension_format(image)

    # cast policy (match `normalize`): cast only if not floating
    if is_np:
        if not np.issubdtype(image.dtype, np.floating):
            image = image.astype(np.float32)
    else:
        if not image.is_floating_point():
            image = image.float()

    # channel axis and sizes
    ch_axis = get_channel_dimension_axis(image, input_data_format=input_data_format)
    num_channels = int(image.shape[ch_axis])

    # normalize mean/std to per-channel vectors
    def _as_seq(x, n):
        if isinstance(x, Collection):
            if len(x) != n:
                raise ValueError(f"value must have {n} elements if it is an iterable, got {len(x)}")
            return x
        return [x] * n

    mean_seq = _as_seq(mean, num_channels)
    std_seq = _as_seq(std, num_channels)

    # make broadcastable tensors/arrays shaped [1, ..., C (at ch_axis), ..., 1]
    bshape = [1] * image.ndim
    bshape[ch_axis] = num_channels

    if is_np:
        mean_arr = np.asarray(mean_seq, dtype=image.dtype).reshape(bshape)
        std_arr = np.asarray(std_seq, dtype=image.dtype).reshape(bshape)
        image = image * std_arr + mean_arr
    else:
        mean_arr = torch.as_tensor(mean_seq, dtype=image.dtype, device=image.device).view(bshape)
        std_arr = torch.as_tensor(std_seq, dtype=image.dtype, device=image.device).view(bshape)
        image = image * std_arr + mean_arr

    # convert to numpy for plotting
    if is_torch:
        image = image.detach().cpu().numpy()
        is_np = True  # from here on

    # target layout: default to NHWC so downstream viz works out of the box
    target_format = data_format or ChannelDimension.LAST

    # collapse any leading batch dims into one, preserving (C, H, W) or (H, W, C)
    if input_data_format == ChannelDimension.FIRST:
        # layout: [*, C, H, W]
        lead = int(np.prod(image.shape[: image.ndim - 3])) if image.ndim > 3 else 1
        if image.ndim == 3:
            c, h, w = image.shape
            image = image.reshape(1, c, h, w)
            lead = 1
        else:
            c, h, w = image.shape[-3:]
            image = image.reshape(lead, c, h, w)
        if target_format == ChannelDimension.LAST:
            image = np.transpose(image, (0, 2, 3, 1))  # -> [N, H, W, C]
    else:
        # layout: [*, H, W, C]
        lead = int(np.prod(image.shape[: image.ndim - 3])) if image.ndim > 3 else 1
        if image.ndim == 3:
            h, w, c = image.shape
            image = image.reshape(1, h, w, c)
            lead = 1
        else:
            h, w, c = image.shape[-3:]
            image = image.reshape(lead, h, w, c)
        if target_format == ChannelDimension.FIRST:
            image = np.transpose(image, (0, 3, 1, 2))  # -> [N, C, H, W]

    return image


def center_crop(
    image: np.ndarray,
    size: tuple[int, int],
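A quick usage sketch for the new helper (illustrative only; assumes ImageNet statistics and a
channels-first float batch, with `unnormalize` imported from `transformers.image_transforms`):

    import torch
    from transformers.image_transforms import unnormalize

    pixel_values = torch.randn(2, 3, 224, 224)  # stands in for a normalized (B, C, H, W) batch
    images = unnormalize(pixel_values, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    # `images` is a NumPy array of shape (2, 224, 224, 3); plt.imshow(images[0]) works directly.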
src/transformers/utils/processor_visualizer_utils.py (new file, 373 lines)

@@ -0,0 +1,373 @@
import re
from typing import Optional, Union

import matplotlib.pyplot as plt
import numpy as np
import requests
import torch
from PIL import Image

from ..image_transforms import convert_to_rgb, to_pil_image, unnormalize
from ..models.auto import AutoConfig, AutoProcessor


# Architectures known to fail with this util; they should raise immediately:
INCOMPATIBLE_MODELS = [
    "bit",
    "colpali",
    "colqwen2",
    "convnext",
    "d_fine",
    "data2vec",
    "efficientloftr",
    "efficientnet",
    "fuyu",
    "gemma3",
    "glm4v",
    "glpn",
    "hgnet_v2",
    "hiera",
    "internvl",
    "janus",
    "layoutlmv3",
    "levit",
    "lightglue",
    "llama4",
    "mistral3",
    "mllama",
    "mobilevit",
    "mobilevitv2",
    "musicgen",
    "musicgen_melody",
    "oneformer",
    "perceiver",
    "perception_lm",
    "phi4_multimodal",
    "qwen2_5_omni",
    "qwen2_5_vl",
    "qwen2_vl",
    "regnet",
    "resnet",
    "superglue",
    "superpoint",
    "swin2sr",
    "timm_wrapper",
    "tvp",
    "udop",
    "vitmatte",
    "vitpose",
    "vjepa2",
    "whisper",
]


DEFAULT_IMAGE_URL = (
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/hf-logo-224x224.png"
)


def _looks_like_global(tile: np.ndarray, base: Image.Image, *, mae_tol: float = 0.3) -> bool:
    """
    Very simple heuristic: a tile counts as the "global" view if its mean absolute error
    against the bilinearly resized source image (both scaled to [0, 1]) is below `mae_tol`.
    """
    base_r = base.convert("RGB").resize(tile.shape[:2][::-1], Image.BILINEAR)
    base_np = np.asarray(base_r).astype(np.float32) / 255.0

    tile_f32 = tile.astype(np.float32)
    if tile_f32.max() > 1.5:
        tile_f32 /= 255.0

    mae = np.abs(tile_f32 - base_np).mean()
    return mae < mae_tol


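# Illustrative use of the heuristic (not executed at import): given `tiles` returned by a
# processor and the original PIL image `img`, `_looks_like_global(tiles[-1], img)` is True
# when the last tile is a resized copy of the whole image rather than a local crop.

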
class ImageVisualizer:
    def __init__(self, repo_id: str):
        self.processor = AutoProcessor.from_pretrained(repo_id, trust_remote_code=False)
        self.config = AutoConfig.from_pretrained(repo_id, trust_remote_code=False)

        if hasattr(self.processor, "image_processor"):
            image_processor = self.processor.image_processor
        elif hasattr(self.processor, "image_mean"):
            image_processor = self.processor  # weak test, but works most of the time
        else:
            raise ValueError(f"No image processor found for {repo_id}.")

        self.channel_means = getattr(image_processor, "image_mean", [0.485, 0.456, 0.406])
        self.channel_stds = getattr(image_processor, "image_std", [0.229, 0.224, 0.225])
        if hasattr(self.processor, "image_token"):
            self.image_token_marker = self.processor.image_token
        elif hasattr(self.processor, "image_token_id"):
            self.image_token_marker = self.processor.decode(self.processor.image_token_id)
        else:
            self.image_token_marker = "<image>"

        self.default_prompt = f"{self.image_token_marker} How does it look?"

        self.vision_config = getattr(self.config, "vision_config", None)
        self.patch_size = getattr(self.vision_config, "patch_size", getattr(image_processor, "patch_size", 14))
        self.merge_size = getattr(image_processor, "merge_size", 1)

    def _pixel_values_as_tensor(
        self, pixel_values: Union[torch.Tensor, np.ndarray, list[np.ndarray], list[torch.Tensor]]
    ):
        """
        Normalize input to a 4D tensor with shape (batch, channels, height, width).
        Supports input of shape:
        - (B, C, H, W)
        - (B, N, C, H, W) -> flattened to (B*N, C, H, W)
        - (C, H, W) -> expanded to (1, C, H, W)
        - list/tuple of arrays or tensors
        """
        if isinstance(pixel_values, (list, tuple)):
            tensor_list = [pv if isinstance(pv, torch.Tensor) else torch.tensor(pv) for pv in pixel_values]
            pixel_values = torch.stack(tensor_list, dim=0)

        if not isinstance(pixel_values, torch.Tensor):
            pixel_values = torch.tensor(pixel_values)

        if pixel_values.ndim == 5:
            batch_size, num_images, num_channels, height, width = pixel_values.shape
            pixel_values = pixel_values.view(batch_size * num_images, num_channels, height, width)
        elif pixel_values.ndim == 4:
            pass
        elif pixel_values.ndim == 3:
            pixel_values = pixel_values.unsqueeze(0)
        else:
            raise ValueError(f"Unexpected pixel tensor shape {pixel_values.shape}")

        return pixel_values

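    # Shape-normalization sketch (illustrative values):
    #   (2, 3, 3, 336, 336)   -> (6, 3, 336, 336)   # batch of 2 images, 3 tiles each
    #   (3, 336, 336)         -> (1, 3, 336, 336)
    #   [tensor_a, tensor_b]  -> stacked along a new leading dim, then handled as above
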
    def _display_single_image(self, image_array: np.ndarray, show_patch_grid: bool, figsize=(7, 7)):
        plt.figure(figsize=figsize)
        plt.imshow(image_array)
        plt.xticks([])
        plt.yticks([])

        # compute the size up front so the caption also works when the grid is disabled
        height, width = image_array.shape[:2]
        if show_patch_grid:
            step = max(1, min(height, width) // self.patch_size)
            for x_pos in range(0, width, step):
                plt.axvline(x_pos, color="red", linewidth=0.5)
            for y_pos in range(0, height, step):
                plt.axhline(y_pos, color="red", linewidth=0.5)

        caption = f"{width}×{height} | mean={', '.join(f'{m:.3f}' for m in self.channel_means)} | std={', '.join(f'{s:.3f}' for s in self.channel_stds)}"
        plt.tight_layout()
        plt.figtext(0.5, -0.02, caption, ha="center", va="top", fontsize=12)
        plt.show()

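    # Caption sketch (illustrative): a 224x224 input with ImageNet statistics is captioned
    # "224×224 | mean=0.485, 0.456, 0.406 | std=0.229, 0.224, 0.225".
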
    def _display_tiled_images(
        self,
        tiles_array: np.ndarray,
        source_image: Image.Image,
        rows: Optional[int] = None,
        cols: Optional[int] = None,
        aspect_ratio: float = 1.0,
        add_grid: bool = True,
        figsize=(7, 7),
    ):
        """
        Display a grid of image tiles. Attempts to detect and preserve the original/global image tile,
        which is then shown separately at the end.
        """
        num_tiles = tiles_array.shape[0]

        original_tile_index = None
        saved_original_tile = None

        for idx in (0, num_tiles - 1):
            if _looks_like_global(tiles_array[idx], source_image):
                original_tile_index = idx
                break

        if original_tile_index is not None:
            saved_original_tile = tiles_array[original_tile_index]
            tiles_array = np.delete(tiles_array, original_tile_index, axis=0)
            num_tiles -= 1

        # Infer grid if not specified
        grid_rows, grid_cols = rows, cols
        if grid_rows is None or grid_cols is None:
            if aspect_ratio >= 1:
                guessed_cols = int(np.ceil(np.sqrt(num_tiles * aspect_ratio)))
                guessed_rows = int(np.ceil(num_tiles / max(guessed_cols, 1)))
            else:
                guessed_rows = int(np.ceil(np.sqrt(num_tiles / max(aspect_ratio, 1e-8))))
                guessed_cols = int(np.ceil(num_tiles / max(guessed_rows, 1)))
            grid_rows = grid_rows if grid_rows is not None else guessed_rows
            grid_cols = grid_cols if grid_cols is not None else guessed_cols

        fig, axes = plt.subplots(grid_rows, grid_cols, figsize=figsize, squeeze=False)
        tile_index = 0
        for row_idx in range(grid_rows):
            for col_idx in range(grid_cols):
                ax = axes[row_idx, col_idx]
                if tile_index < num_tiles:
                    tile_image = tiles_array[tile_index]
                    ax.imshow(tile_image)
                    ax.set_xticks([])
                    ax.set_yticks([])

                    if add_grid:
                        height, width = tile_image.shape[:2]
                        step = max(1, min(height, width) // self.patch_size)
                        for x_pos in range(0, width, step):
                            ax.axvline(x_pos, color="red", linewidth=0.5)
                        for y_pos in range(0, height, step):
                            ax.axhline(y_pos, color="red", linewidth=0.5)
                else:
                    ax.axis("off")
                tile_index += 1

        unique = sorted({f"{t.shape[1]}×{t.shape[0]}" for t in tiles_array})
        sizes = ", ".join(unique)
        caption = f"{tiles_array.shape[0]} patches | {sizes} | mean={', '.join(f'{m:.3f}' for m in self.channel_means)} | std={', '.join(f'{s:.3f}' for s in self.channel_stds)}"
        plt.tight_layout()
        fig.text(0.5, 0.02, caption, ha="center", va="bottom", fontsize=12)
        plt.show()

        if saved_original_tile is not None:
            fig2, ax2 = plt.subplots(figsize=figsize)
            ax2.imshow(saved_original_tile)
            ax2.set_xticks([])
            ax2.set_yticks([])
            ax2.set_aspect("equal", adjustable="box")
            fig2.subplots_adjust(left=0, right=1, top=1, bottom=0)  # no clipping
            h0, w0 = saved_original_tile.shape[:2]
            caption = f"{w0}×{h0} | mean={', '.join(f'{m:.3f}' for m in self.channel_means)} | std={', '.join(f'{s:.3f}' for s in self.channel_stds)}"
            fig2.text(0.5, 0.02, caption, ha="center", va="bottom", fontsize=12)
            plt.show()

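    # Grid-inference sketch for the fallback above (illustrative numbers): with num_tiles = 8
    # and aspect_ratio = 2.0, guessed_cols = ceil(sqrt(8 * 2.0)) = 4 and guessed_rows = ceil(8 / 4) = 2,
    # i.e. a 2 x 4 layout for a wide source image.
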
    def default_message(self, full_output: bool = False) -> str:
        """
        Build a single formatted prompt string using the processor's chat template.
        Contains one image (the HF logo) and one user text message.
        If available, adds the generation prompt as well.
        Falls back to a minimal '<image>' string if no template is available.
        """
        # ensure this is a multimodal processor with image + tokenizer
        if not (
            hasattr(self.processor, "attributes")
            and "image_processor" in self.processor.attributes
            and "tokenizer" in self.processor.attributes
        ):
            raise RuntimeError(
                "Processor does not expose both 'image_processor' and 'tokenizer'; cannot build multimodal example."
            )

        conversation = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "image",
                        "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/hf-logo-224x224.png",
                    },
                    {"type": "text", "text": "Please describe this image."},
                ],
            }
        ]

        try:
            print("For a 224x224 RGB PNG image: \n")
            decoded_message = self.processor.batch_decode(
                self.processor.apply_chat_template(
                    conversation,
                    add_generation_prompt=True,
                    tokenize=True,
                    return_dict=False,
                    truncation=False,
                ),
                skip_special_tokens=False,
            )[0]

            image_token_string = getattr(self.processor, "image_token", "<image>")
            token_escaped = re.escape(image_token_string)
            image_token_run_pattern = re.compile(rf"(?:{token_escaped})(?:\s*{token_escaped}){{2,}}")

            def compress_image_token_run(match: re.Match) -> str:
                n_tokens = match.group(0).count(image_token_string)
                return f"{image_token_string}[...{n_tokens} tokens...]{image_token_string}"

            if full_output:
                return decoded_message
            else:
                return image_token_run_pattern.sub(compress_image_token_run, decoded_message)

        except ValueError:
            image_token_string = getattr(
                self.processor,
                "image_token",
                getattr(getattr(self.processor, "tokenizer", None), "image_token", "<image>"),
            )
            return f"{image_token_string} Please describe this image."

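    # Illustrative output: if the chat template expands the image placeholder into a run of,
    # say, 576 identical image tokens, the compressed return value reads
    # "<image>[...576 tokens...]<image>" instead of repeating the token 576 times.
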
    def visualize(
        self,
        images: Optional[Union[Image.Image, np.ndarray, str, list[Union[Image.Image, np.ndarray, str]]]] = None,
        rows: Optional[int] = None,
        cols: Optional[int] = None,
        add_grid: bool = True,
        figsize=(12, 12),
    ):
        """
        Visualize the model-processed image(s). Only single images are supported.
        If the processor returns multiple tiles, display them in a grid with an optional patch-grid overlay.
        """
        if images is None:
            images = Image.open(requests.get(DEFAULT_IMAGE_URL, stream=True).raw)

        if not isinstance(images, list):
            images = [images]
        elif len(images) > 1:
            raise ValueError(
                "You passed a list of several images. Only single images are accepted by the visualizer."
            )

        pil_images = [convert_to_rgb(to_pil_image(x)) for x in images]
        img_width, img_height = pil_images[0].size
        aspect_ratio = img_width / max(img_height, 1)

        processed_inputs = self.processor(images=pil_images, text=self.default_prompt, return_tensors="pt")
        pixel_values = processed_inputs["pixel_values"]
        unnormalized = unnormalize(pixel_values, mean=self.channel_means, std=self.channel_stds)
        if unnormalized.ndim == 3 or unnormalized.shape[0] == 1:
            self._display_single_image(
                unnormalized[0] if unnormalized.ndim == 4 else unnormalized,
                show_patch_grid=add_grid,
                figsize=figsize,
            )
            return
        elif unnormalized.ndim != 4:
            raise ValueError(f"Unsupported shape after unnormalization: {unnormalized.shape}")

        num_tiles = unnormalized.shape[0]

        if rows is None or cols is None:
            tile_h, tile_w = unnormalized.shape[1:3]
            tile_aspect = tile_w / tile_h if tile_h > 0 else 1.0
            target_aspect = aspect_ratio / tile_aspect

            best_rows, best_cols = 1, num_tiles
            min_diff = float("inf")
            for r in range(1, num_tiles + 1):
                c = int(np.ceil(num_tiles / r))
                diff = abs((c / r) - target_aspect)
                if diff < min_diff:
                    min_diff = diff
                    best_rows, best_cols = r, c

            rows = best_rows
            cols = best_cols
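            # Worked example (illustrative): 7 tiles with target_aspect = 1.0 gives
            #   r=1: c=7, |7/1 - 1| = 6;  r=2: c=4, |4/2 - 1| = 1;  r=3: c=3, |3/3 - 1| = 0
            # so the search settles on a 3 x 3 grid, with the trailing cells left empty.
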
        self._display_tiled_images(
            unnormalized,
            pil_images[0],
            rows=rows,
            cols=cols,
            aspect_ratio=aspect_ratio,
            add_grid=add_grid,
            figsize=figsize,
        )
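
A usage sketch for the visualizer (illustrative; the checkpoint id is an assumption, any multimodal
repo with an image processor and a chat template should behave similarly):

    from transformers.utils.processor_visualizer_utils import ImageVisualizer

    viz = ImageVisualizer("llava-hf/llava-1.5-7b-hf")  # hypothetical example checkpoint
    print(viz.default_message())  # chat-templated prompt with long image-token runs compressed
    viz.visualize(add_grid=True)  # shows the processed tiles with a patch-grid overlay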