Compare commits


12 Commits

Author SHA1 Message Date
b356fce1da solve unequal cropping 2025-08-11 19:20:28 +02:00
af7f75e682 use existing methods, add default image 2025-08-11 16:44:06 +02:00
34ba5909a2 add an unnormalize image method 2025-08-11 16:43:27 +02:00
fbec904fb0 Merge branch 'main' into vision_visualizer 2025-08-06 19:19:09 +02:00
a1263dfe7b fixup 2025-08-06 19:17:38 +02:00
1878d6c4ff add captions and better tiling detection 2025-08-06 19:16:14 +02:00
a6a18efe53 better namings 2025-08-05 17:30:05 +02:00
e581d2f2ce fixup 2025-07-25 08:02:39 +00:00
1f6822d114 move processor visualizer 2025-07-25 07:58:35 +00:00
edb70ae15c Merge branch 'main' into vision_visualizer 2025-07-24 12:50:27 +00:00
27bc371bea Merge branch 'main' into vision_visualizer 2025-07-22 13:01:45 +02:00
58c619e809 draft the vision visualizer 2025-03-21 18:53:04 +01:00
2 changed files with 472 additions and 0 deletions

View File

@@ -452,6 +452,105 @@ def normalize(
    return image


def unnormalize(
    image: Union[np.ndarray, "torch.Tensor"],
    mean: Union[float, Collection[float]],
    std: Union[float, Collection[float]],
    data_format: Optional[ChannelDimension] = None,
    input_data_format: Optional[Union[str, ChannelDimension]] = None,
):
    """
    Inverse of `normalize`:

        image = image * std + mean

    Accepts NumPy arrays or PyTorch tensors and mirrors `normalize`'s API,
    but also handles 4D/5D inputs by broadcasting along the channel axis and
    collapsing leading batch dims. Defaults to NHWC output for visualization.
    """
    # type check
    is_np = isinstance(image, np.ndarray)
    is_torch = isinstance(image, torch.Tensor)
    if not (is_np or is_torch):
        raise TypeError("image must be a numpy array or a torch tensor")

    # infer layout
    if input_data_format is None:
        input_data_format = infer_channel_dimension_format(image)

    # cast policy (match `normalize`): cast only if not already floating point
    if is_np:
        if not np.issubdtype(image.dtype, np.floating):
            image = image.astype(np.float32)
    else:
        if not image.is_floating_point():
            image = image.float()

    # channel axis and sizes
    ch_axis = get_channel_dimension_axis(image, input_data_format=input_data_format)
    num_channels = int(image.shape[ch_axis])

    # normalize mean/std to per-channel vectors
    def _as_seq(x, n):
        if isinstance(x, Collection):
            if len(x) != n:
                raise ValueError(f"value must have {n} elements if it is an iterable, got {len(x)}")
            return x
        return [x] * n

    mean_seq = _as_seq(mean, num_channels)
    std_seq = _as_seq(std, num_channels)

    # make broadcastable tensors/arrays shaped [1, ..., C (at ch_axis), ..., 1]
    bshape = [1] * image.ndim
    bshape[ch_axis] = num_channels
    if is_np:
        mean_arr = np.asarray(mean_seq, dtype=image.dtype).reshape(bshape)
        std_arr = np.asarray(std_seq, dtype=image.dtype).reshape(bshape)
        image = image * std_arr + mean_arr
    else:
        mean_arr = torch.as_tensor(mean_seq, dtype=image.dtype, device=image.device).view(bshape)
        std_arr = torch.as_tensor(std_seq, dtype=image.dtype, device=image.device).view(bshape)
        image = image * std_arr + mean_arr

    # convert to numpy for plotting
    if is_torch:
        image = image.detach().cpu().numpy()
        is_np = True  # from here on

    # target layout: default to NHWC so downstream viz works out of the box
    target_format = data_format or ChannelDimension.LAST

    # collapse any leading batch dims into one, preserving (C, H, W) or (H, W, C)
    if input_data_format == ChannelDimension.FIRST:
        # layout: [*, C, H, W]
        lead = int(np.prod(image.shape[: image.ndim - 3])) if image.ndim > 3 else 1
        if image.ndim == 3:
            c, h, w = image.shape
            image = image.reshape(1, c, h, w)
            lead = 1
        else:
            c, h, w = image.shape[-3:]
            image = image.reshape(lead, c, h, w)
        if target_format == ChannelDimension.LAST:
            image = np.transpose(image, (0, 2, 3, 1))  # -> [N, H, W, C]
    else:
        # layout: [*, H, W, C]
        lead = int(np.prod(image.shape[: image.ndim - 3])) if image.ndim > 3 else 1
        if image.ndim == 3:
            h, w, c = image.shape
            image = image.reshape(1, h, w, c)
            lead = 1
        else:
            h, w, c = image.shape[-3:]
            image = image.reshape(lead, h, w, c)
        if target_format == ChannelDimension.FIRST:
            image = np.transpose(image, (0, 3, 1, 2))  # -> [N, C, H, W]

    return image


def center_crop(
    image: np.ndarray,
    size: tuple[int, int],

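A quick round-trip sketch of the new helper (a minimal example assuming this branch is installed; the tensor shape and mean/std values below are purely illustrative):

    import torch
    from transformers.image_transforms import unnormalize

    pixel_values = torch.rand(2, 3, 224, 224)  # (N, C, H, W), e.g. what an image processor returns
    mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
    restored = unnormalize(pixel_values, mean=mean, std=std)
    print(type(restored), restored.shape)  # numpy array, shape (2, 224, 224, 3): NHWC, ready for plt.imshow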
View File

@@ -0,0 +1,373 @@
import re
from typing import Optional, Union

import matplotlib.pyplot as plt
import numpy as np
import requests
import torch
from PIL import Image

from ..image_transforms import convert_to_rgb, to_pil_image, unnormalize
from ..models.auto import AutoConfig, AutoProcessor

# Architectures known to be incompatible; this util should raise immediately for them:
INCOMPATIBLE_MODELS = [
    "bit",
    "colpali",
    "colqwen2",
    "convnext",
    "d_fine",
    "data2vec",
    "efficientloftr",
    "efficientnet",
    "fuyu",
    "gemma3",
    "glm4v",
    "glpn",
    "hgnet_v2",
    "hiera",
    "internvl",
    "janus",
    "layoutlmv3",
    "levit",
    "lightglue",
    "llama4",
    "mistral3",
    "mllama",
    "mobilevit",
    "mobilevitv2",
    "musicgen",
    "musicgen_melody",
    "oneformer",
    "perceiver",
    "perception_lm",
    "phi4_multimodal",
    "qwen2_5_omni",
    "qwen2_5_vl",
    "qwen2_vl",
    "regnet",
    "resnet",
    "superglue",
    "superpoint",
    "swin2sr",
    "timm_wrapper",
    "tvp",
    "udop",
    "vitmatte",
    "vitpose",
    "vjepa2",
    "whisper",
]

DEFAULT_IMAGE_URL = (
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/hf-logo-224x224.png"
)


def _looks_like_global(tile: np.ndarray, base: Image.Image, *, mae_tol: float = 0.3) -> bool:
    """
    Very simple heuristic: treat a tile as the "global" (whole-image) view if its mean absolute
    error against the source image, resized to the tile's resolution, is below `mae_tol`.
    """
    base_r = base.convert("RGB").resize(tile.shape[:2][::-1], Image.BILINEAR)
    base_np = np.asarray(base_r).astype(np.float32) / 255.0
    tile_f32 = tile.astype(np.float32)
    if tile_f32.max() > 1.5:
        tile_f32 /= 255.0
    mae = np.abs(tile_f32 - base_np).mean()
    return mae < mae_tol


class ImageVisualizer:
    def __init__(self, repo_id: str):
        self.processor = AutoProcessor.from_pretrained(repo_id, trust_remote_code=False)
        self.config = AutoConfig.from_pretrained(repo_id, trust_remote_code=False)
        if hasattr(self.processor, "image_processor"):
            image_processor = self.processor.image_processor
        elif hasattr(self.processor, "image_mean"):
            image_processor = self.processor  # weak test, but works most of the time
        else:
            raise ValueError(f"No image processor found for {repo_id}.")
        self.channel_means = getattr(image_processor, "image_mean", [0.485, 0.456, 0.406])
        self.channel_stds = getattr(image_processor, "image_std", [0.229, 0.224, 0.225])
        if hasattr(self.processor, "image_token"):
            self.image_token_marker = self.processor.image_token
        elif hasattr(self.processor, "image_token_id"):
            self.image_token_marker = self.processor.decode(self.processor.image_token_id)
        else:
            self.image_token_marker = "<image>"
        self.default_prompt = f"{self.image_token_marker} How does it look?"
        self.vision_config = getattr(self.config, "vision_config", None)
        self.patch_size = getattr(self.vision_config, "patch_size", getattr(image_processor, "patch_size", 14))
        self.merge_size = getattr(image_processor, "merge_size", 1)

    def _pixel_values_as_tensor(
        self, pixel_values: Union[torch.Tensor, np.ndarray, list[np.ndarray], list[torch.Tensor]]
    ):
        """
        Normalize input to a 4D tensor with shape (batch, channels, height, width).

        Supports input of shape:
            - (B, C, H, W)
            - (B, N, C, H, W) -> flattened to (B*N, C, H, W)
            - (C, H, W)       -> expanded to (1, C, H, W)
            - list/tuple of arrays or tensors
        """
        if isinstance(pixel_values, (list, tuple)):
            tensor_list = [pv if isinstance(pv, torch.Tensor) else torch.tensor(pv) for pv in pixel_values]
            pixel_values = torch.stack(tensor_list, dim=0)
        if not isinstance(pixel_values, torch.Tensor):
            pixel_values = torch.tensor(pixel_values)
        if pixel_values.ndim == 5:
            batch_size, num_images, num_channels, height, width = pixel_values.shape
            pixel_values = pixel_values.view(batch_size * num_images, num_channels, height, width)
        elif pixel_values.ndim == 4:
            pass
        elif pixel_values.ndim == 3:
            pixel_values = pixel_values.unsqueeze(0)
        else:
            raise ValueError(f"Unexpected pixel tensor shape {pixel_values.shape}")
        return pixel_values

    def _display_single_image(self, image_array: np.ndarray, show_patch_grid: bool, figsize=(7, 7)):
        plt.figure(figsize=figsize)
        plt.imshow(image_array)
        plt.xticks([])
        plt.yticks([])
        height, width = image_array.shape[:2]  # needed for the caption even when the grid is off
        if show_patch_grid:
            step = max(1, min(height, width) // self.patch_size)
            for x_pos in range(0, width, step):
                plt.axvline(x_pos, color="red", linewidth=0.5)
            for y_pos in range(0, height, step):
                plt.axhline(y_pos, color="red", linewidth=0.5)
        caption = (
            f"{width}×{height}"
            f" | mean={', '.join(f'{m:.3f}' for m in self.channel_means)}"
            f" | std={', '.join(f'{s:.3f}' for s in self.channel_stds)}"
        )
        plt.tight_layout()
        plt.figtext(0.5, -0.02, caption, ha="center", va="top", fontsize=12)
        plt.show()

    def _display_tiled_images(
        self,
        tiles_array: np.ndarray,
        source_image: Image.Image,
        rows: Optional[int] = None,
        cols: Optional[int] = None,
        aspect_ratio: float = 1.0,
        add_grid: bool = True,
        figsize=(7, 7),
    ):
        """
        Display a grid of image tiles. Attempts to detect and preserve the original/global image tile,
        which is then shown separately at the end.
        """
        num_tiles = tiles_array.shape[0]
        original_tile_index = None
        saved_original_tile = None
        for idx in (0, num_tiles - 1):
            if _looks_like_global(tiles_array[idx], source_image):
                original_tile_index = idx
                break
        if original_tile_index is not None:
            saved_original_tile = tiles_array[original_tile_index]
            tiles_array = np.delete(tiles_array, original_tile_index, axis=0)
            num_tiles -= 1
        # Infer grid if not specified
        grid_rows, grid_cols = rows, cols
        if grid_rows is None or grid_cols is None:
            if aspect_ratio >= 1:
                guessed_cols = int(np.ceil(np.sqrt(num_tiles * aspect_ratio)))
                guessed_rows = int(np.ceil(num_tiles / max(guessed_cols, 1)))
            else:
                guessed_rows = int(np.ceil(np.sqrt(num_tiles / max(aspect_ratio, 1e-8))))
                guessed_cols = int(np.ceil(num_tiles / max(guessed_rows, 1)))
            grid_rows = grid_rows if grid_rows is not None else guessed_rows
            grid_cols = grid_cols if grid_cols is not None else guessed_cols
        fig, axes = plt.subplots(grid_rows, grid_cols, figsize=figsize, squeeze=False)
        tile_index = 0
        for row_idx in range(grid_rows):
            for col_idx in range(grid_cols):
                ax = axes[row_idx, col_idx]
                if tile_index < num_tiles:
                    tile_image = tiles_array[tile_index]
                    ax.imshow(tile_image)
                    ax.set_xticks([])
                    ax.set_yticks([])
                    if add_grid:
                        height, width = tile_image.shape[:2]
                        step = max(1, min(height, width) // self.patch_size)
                        for x_pos in range(0, width, step):
                            ax.axvline(x_pos, color="red", linewidth=0.5)
                        for y_pos in range(0, height, step):
                            ax.axhline(y_pos, color="red", linewidth=0.5)
                else:
                    ax.axis("off")
                tile_index += 1
        unique = sorted({f"{t.shape[1]}×{t.shape[0]}" for t in tiles_array})
        sizes = ", ".join(unique)
        caption = (
            f"{tiles_array.shape[0]} patches | {sizes}"
            f" | mean={', '.join(f'{m:.3f}' for m in self.channel_means)}"
            f" | std={', '.join(f'{s:.3f}' for s in self.channel_stds)}"
        )
        plt.tight_layout()
        fig.text(0.5, 0.02, caption, ha="center", va="bottom", fontsize=12)
        plt.show()
        if saved_original_tile is not None:
            fig2, ax2 = plt.subplots(figsize=figsize)
            ax2.imshow(saved_original_tile)
            ax2.set_xticks([])
            ax2.set_yticks([])
            ax2.set_aspect("equal", adjustable="box")
            fig2.subplots_adjust(left=0, right=1, top=1, bottom=0)  # no clipping
            h0, w0 = saved_original_tile.shape[:2]
            caption = (
                f"{w0}×{h0}"
                f" | mean={', '.join(f'{m:.3f}' for m in self.channel_means)}"
                f" | std={', '.join(f'{s:.3f}' for s in self.channel_stds)}"
            )
            fig2.text(0.5, 0.02, caption, ha="center", va="bottom", fontsize=12)
            plt.show()

    def default_message(self, full_output: bool = False) -> str:
        """
        Build a single formatted prompt string using the processor's chat template.
        Contains one image (the HF logo) and one user text message.
        If available, adds the generation prompt as well.
        Falls back to a minimal '<image>' string if no template is available.
        """
        # ensure this is a multimodal processor with image + tokenizer
        if not (
            hasattr(self.processor, "attributes")
            and "image_processor" in self.processor.attributes
            and "tokenizer" in self.processor.attributes
        ):
            raise RuntimeError(
                "Processor does not expose both 'image_processor' and 'tokenizer'; cannot build multimodal example."
            )
        conversation = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "image",
                        "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/hf-logo-224x224.png",
                    },
                    {"type": "text", "text": "Please describe this image."},
                ],
            }
        ]
        try:
            print("For a 224x224 RGB png image: \n")
            decoded_message = self.processor.batch_decode(
                self.processor.apply_chat_template(
                    conversation,
                    add_generation_prompt=True,
                    tokenize=True,
                    return_dict=False,
                    truncation=False,
                ),
                skip_special_tokens=False,
            )[0]
            image_token_string = getattr(self.processor, "image_token", "<image>")
            token_escaped = re.escape(image_token_string)
            image_token_run_pattern = re.compile(rf"(?:{token_escaped})(?:\s*{token_escaped}){{2,}}")

            def compress_image_token_run(match: re.Match) -> str:
                n_tokens = match.group(0).count(image_token_string)
                return f"{image_token_string}[...{n_tokens} tokens...]{image_token_string}"

            if full_output:
                return decoded_message
            else:
                return image_token_run_pattern.sub(compress_image_token_run, decoded_message)
        except ValueError:
            image_token_string = getattr(
                self.processor,
                "image_token",
                getattr(getattr(self.processor, "tokenizer", None), "image_token", "<image>"),
            )
            return f"{image_token_string} Please describe this image."

    def visualize(
        self,
        images: Optional[Union[Image.Image, np.ndarray, str, list[Union[Image.Image, np.ndarray, str]]]] = None,
        rows: Optional[int] = None,
        cols: Optional[int] = None,
        add_grid: bool = True,
        figsize=(12, 12),
    ):
        """
        Visualize the model-processed image(s). Only single images are supported.
        If the processor returns multiple tiles, display them in a grid with an optional patch-grid overlay.
        """
        if images is None:
            images = Image.open(requests.get(DEFAULT_IMAGE_URL, stream=True).raw)
        if not isinstance(images, list):
            images = [images]
        else:
            if len(images) > 1:
                raise ValueError(
                    "You passed a list of several images. Only single images are accepted by the visualizer."
                )
        pil_images = [convert_to_rgb(to_pil_image(x)) for x in images]
        img_width, img_height = pil_images[0].size
        aspect_ratio = img_width / max(img_height, 1)
        processed_inputs = self.processor(images=pil_images, text=self.default_prompt, return_tensors="pt")
        pixel_values = processed_inputs["pixel_values"]
        unnormalized = unnormalize(pixel_values, mean=self.channel_means, std=self.channel_stds)
        if unnormalized.ndim == 3 or unnormalized.shape[0] == 1:
            self._display_single_image(
                unnormalized[0] if unnormalized.ndim == 4 else unnormalized,
                show_patch_grid=add_grid,
                figsize=figsize,
            )
            return
        elif unnormalized.ndim != 4:
            raise ValueError(f"Unsupported shape after unnormalization: {unnormalized.shape}")
        num_tiles = unnormalized.shape[0]
        if rows is None or cols is None:
            # Pick the rows/cols pair whose grid aspect best matches the source image's aspect ratio
            tile_h, tile_w = unnormalized.shape[1:3]
            tile_aspect = tile_w / tile_h if tile_h > 0 else 1.0
            target_aspect = aspect_ratio / tile_aspect
            best_rows, best_cols = 1, num_tiles
            min_diff = float("inf")
            for r in range(1, num_tiles + 1):
                c = int(np.ceil(num_tiles / r))
                diff = abs((c / r) - target_aspect)
                if diff < min_diff:
                    min_diff = diff
                    best_rows, best_cols = r, c
            rows = best_rows
            cols = best_cols
        self._display_tiled_images(
            unnormalized,
            pil_images[0],
            rows=rows,
            cols=cols,
            aspect_ratio=aspect_ratio,
            add_grid=add_grid,
            figsize=figsize,
        )
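For reference, a minimal end-to-end usage sketch of the visualizer. The checkpoint name is only an example of a multimodal processor with a chat template, the local file name is a placeholder, and the import path is inferred from this file's relative imports, so all of these may differ in the final PR:

    from PIL import Image
    from transformers.utils.visualizer import ImageVisualizer  # module path assumed; adjust to wherever this file lands

    viz = ImageVisualizer("llava-hf/llava-1.5-7b-hf")  # example checkpoint with an image processor + tokenizer
    print(viz.default_message())  # chat-templated prompt with long image-token runs compressed
    viz.visualize()  # processes the default HF logo and shows the resulting tiles with a patch-grid overlay
    viz.visualize(images=Image.open("my_photo.jpg"), add_grid=False)  # or pass your own PIL image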